mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-11 01:42:43 +00:00
Compare commits
108 Commits
multi-mode
...
codex/agen
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4f793ff870 | ||
|
|
4a96ef13d7 | ||
|
|
55f570261f | ||
|
|
289a7b807e | ||
|
|
822b0c99be | ||
|
|
bcf2851a85 | ||
|
|
a5a59bd8f0 | ||
|
|
32d2e7985a | ||
|
|
c4f8d5370b | ||
|
|
9e434f6a5a | ||
|
|
67dc819319 | ||
|
|
2d12274050 | ||
|
|
c727ba13ee | ||
|
|
6193dd5326 | ||
|
|
387a7d1cea | ||
|
|
869578eeed | ||
|
|
e68648ab74 | ||
|
|
da01002099 | ||
|
|
f5d66f389c | ||
|
|
82d89f78c6 | ||
|
|
6f49c5e32c | ||
|
|
41f2bd2f19 | ||
|
|
bfa2f672f9 | ||
|
|
a823c3ead1 | ||
|
|
bd7d378a9a | ||
|
|
dcec0c8ef3 | ||
|
|
6456b51dcf | ||
|
|
7cfe27e31e | ||
|
|
3c5f77f5a4 | ||
|
|
ab4d1dce01 | ||
|
|
80c928eb58 | ||
|
|
77528876b1 | ||
|
|
3bf53495f3 | ||
|
|
e4cfcda0bf | ||
|
|
475e8f6cdc | ||
|
|
945272c1d2 | ||
|
|
185b057483 | ||
|
|
ac89b42b38 | ||
|
|
e19198f1f2 | ||
|
|
45a4c5c28f | ||
|
|
7a3e7fad7a | ||
|
|
3a8ba15c8d | ||
|
|
67b7d115db | ||
|
|
0e6759135f | ||
|
|
a95e2fd99a | ||
|
|
10ad7f92da | ||
|
|
f9f8f56ec1 | ||
|
|
91ed204f7a | ||
|
|
e519490c85 | ||
|
|
93251cf558 | ||
|
|
c31338e9b7 | ||
|
|
1c32a83dc2 | ||
|
|
4a2ff7e0ef | ||
|
|
c3f8fad729 | ||
|
|
d50a5e0e27 | ||
|
|
697a679409 | ||
|
|
0c95650176 | ||
|
|
0d3a6b255b | ||
|
|
01748efe6a | ||
|
|
de6c4f4a51 | ||
|
|
689f61ce08 | ||
|
|
dec836a172 | ||
|
|
b6e623ef5c | ||
|
|
ec9e340656 | ||
|
|
885006cb7a | ||
|
|
472073cac0 | ||
|
|
5e61659e3a | ||
|
|
7b18949b63 | ||
|
|
efe51c108e | ||
|
|
c092d16c01 | ||
|
|
da715eaa58 | ||
|
|
bb18d39765 | ||
|
|
abc2cd5572 | ||
|
|
a704acbf73 | ||
|
|
8737122133 | ||
|
|
c5d7cfa896 | ||
|
|
297c931191 | ||
|
|
ae343c718b | ||
|
|
ce39442478 | ||
|
|
256996f27c | ||
|
|
9dbe7acac6 | ||
|
|
8d43d73f83 | ||
|
|
559bac9f78 | ||
|
|
e81bbe6f69 | ||
|
|
b59f8cf453 | ||
|
|
456ecc7b9a | ||
|
|
fdc2bc9ee2 | ||
|
|
1c3f371549 | ||
|
|
a120add37b | ||
|
|
757e4e979b | ||
|
|
cbcdfee56e | ||
|
|
b06700314b | ||
|
|
01f573cdcb | ||
|
|
d4a96d70f3 | ||
|
|
5b000c2173 | ||
|
|
d62af28e40 | ||
|
|
593678a14f | ||
|
|
e6f7c2b45c | ||
|
|
f77128d929 | ||
|
|
1d4ca769e7 | ||
|
|
e002f6c195 | ||
|
|
10d696262f | ||
|
|
608e151443 | ||
|
|
41d1a33093 | ||
|
|
f396ebbdbb | ||
|
|
67c8df002e | ||
|
|
722f7de335 | ||
|
|
df14bbe0e2 |
@@ -1,186 +0,0 @@
|
||||
---
|
||||
name: onyx-cli
|
||||
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
|
||||
---
|
||||
|
||||
# Onyx CLI — Agent Tool
|
||||
|
||||
Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` CLI provides non-interactive commands to query the Onyx knowledge base and list available agents.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### 1. Check if installed
|
||||
|
||||
```bash
|
||||
which onyx-cli
|
||||
```
|
||||
|
||||
### 2. Install (if needed)
|
||||
|
||||
**Primary — pip:**
|
||||
|
||||
```bash
|
||||
pip install onyx-cli
|
||||
```
|
||||
|
||||
**From source (Go):**
|
||||
|
||||
```bash
|
||||
cd cli && go build -o onyx-cli . && sudo mv onyx-cli /usr/local/bin/
|
||||
```
|
||||
|
||||
### 3. Check if configured
|
||||
|
||||
```bash
|
||||
onyx-cli validate-config
|
||||
```
|
||||
|
||||
This checks the config file exists, API key is present, and tests the server connection via `/api/me`. Exit code 0 on success, non-zero with a descriptive error on failure.
|
||||
|
||||
If unconfigured, you have two options:
|
||||
|
||||
**Option A — Interactive setup (requires user input):**
|
||||
|
||||
```bash
|
||||
onyx-cli configure
|
||||
```
|
||||
|
||||
This prompts for the Onyx server URL and API key, tests the connection, and saves config.
|
||||
|
||||
**Option B — Environment variables (non-interactive, preferred for agents):**
|
||||
|
||||
```bash
|
||||
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: https://cloud.onyx.app
|
||||
export ONYX_API_KEY="your-api-key"
|
||||
```
|
||||
|
||||
Environment variables override the config file. If these are set, no config file is needed.
|
||||
|
||||
| Variable | Required | Description |
|
||||
|----------|----------|-------------|
|
||||
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `https://cloud.onyx.app`) |
|
||||
| `ONYX_API_KEY` | Yes | API key for authentication |
|
||||
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
|
||||
|
||||
If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:
|
||||
- Run `onyx-cli configure` interactively, or
|
||||
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables
|
||||
|
||||
## Commands
|
||||
|
||||
### Validate configuration
|
||||
|
||||
```bash
|
||||
onyx-cli validate-config
|
||||
```
|
||||
|
||||
Checks config file exists, API key is present, and tests the server connection. Use this before `ask` or `agents` to confirm the CLI is properly set up.
|
||||
|
||||
### List available agents
|
||||
|
||||
```bash
|
||||
onyx-cli agents
|
||||
```
|
||||
|
||||
Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:
|
||||
|
||||
```bash
|
||||
onyx-cli agents --json
|
||||
```
|
||||
|
||||
Use agent IDs with `ask --agent-id` to query a specific agent.
|
||||
|
||||
### Basic query (plain text output)
|
||||
|
||||
```bash
|
||||
onyx-cli ask "What is our company's PTO policy?"
|
||||
```
|
||||
|
||||
Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.
|
||||
|
||||
### JSON output (structured events)
|
||||
|
||||
```bash
|
||||
onyx-cli ask --json "What authentication methods do we support?"
|
||||
```
|
||||
|
||||
Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.
|
||||
|
||||
Each line is a JSON object with this envelope:
|
||||
|
||||
```json
|
||||
{"type": "<event_type>", "event": { ... }}
|
||||
```
|
||||
|
||||
| Event Type | Description |
|
||||
|------------|-------------|
|
||||
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
|
||||
| `stop` | Stream complete |
|
||||
| `error` | Error with `error` message field |
|
||||
| `search_tool_start` | Onyx started searching documents |
|
||||
| `citation_info` | Source citation — see shape below |
|
||||
|
||||
`citation_info` event shape:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "citation_info",
|
||||
"event": {
|
||||
"citation_number": 1,
|
||||
"document_id": "abc123def456",
|
||||
"placement": {"turn_index": 0, "tab_index": 0, "sub_turn_index": null}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
|
||||
|
||||
### Specify an agent
|
||||
|
||||
```bash
|
||||
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
|
||||
```
|
||||
|
||||
Uses a specific Onyx agent/persona instead of the default.
|
||||
|
||||
### All flags
|
||||
|
||||
| Flag | Type | Description |
|
||||
|------|------|-------------|
|
||||
| `--agent-id` | int | Agent ID to use (overrides default) |
|
||||
| `--json` | bool | Output raw NDJSON events instead of plain text |
|
||||
|
||||
## Statelessness
|
||||
|
||||
Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
|
||||
|
||||
## When to Use
|
||||
|
||||
Use `onyx-cli ask` when:
|
||||
|
||||
- The user asks about company-specific information (policies, docs, processes)
|
||||
- You need to search internal knowledge bases or connected data sources
|
||||
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
|
||||
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources
|
||||
|
||||
Do NOT use when:
|
||||
|
||||
- The question is about general programming knowledge (use your own knowledge)
|
||||
- The user is asking about code in the current repository (use grep/read tools)
|
||||
- The user hasn't mentioned Onyx and the question doesn't require internal company data
|
||||
|
||||
## Examples
|
||||
|
||||
```bash
|
||||
# Simple question
|
||||
onyx-cli ask "What are the steps to deploy to production?"
|
||||
|
||||
# Get structured output for parsing
|
||||
onyx-cli ask --json "List all active API integrations"
|
||||
|
||||
# Use a specialized agent
|
||||
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"
|
||||
|
||||
# Pipe the answer into another command
|
||||
onyx-cli ask "What is the database schema for users?" | head -20
|
||||
```
|
||||
1
.cursor/skills/onyx-cli/SKILL.md
Symbolic link
1
.cursor/skills/onyx-cli/SKILL.md
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../cli/internal/embedded/SKILL.md
|
||||
65
.devcontainer/Dockerfile
Normal file
65
.devcontainer/Dockerfile
Normal file
@@ -0,0 +1,65 @@
|
||||
FROM ubuntu:26.04@sha256:cc925e589b7543b910fea57a240468940003fbfc0515245a495dd0ad8fe7cef1
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
acl \
|
||||
curl \
|
||||
fd-find \
|
||||
fzf \
|
||||
git \
|
||||
jq \
|
||||
less \
|
||||
make \
|
||||
neovim \
|
||||
openssh-client \
|
||||
python3-venv \
|
||||
ripgrep \
|
||||
sudo \
|
||||
ca-certificates \
|
||||
iptables \
|
||||
ipset \
|
||||
iproute2 \
|
||||
dnsutils \
|
||||
unzip \
|
||||
wget \
|
||||
zsh \
|
||||
&& curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \
|
||||
&& apt-get install -y nodejs \
|
||||
&& install -m 0755 -d /etc/apt/keyrings \
|
||||
&& curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc \
|
||||
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu $(. /etc/os-release && echo "$VERSION_CODENAME") stable" > /etc/apt/sources.list.d/docker.list \
|
||||
&& curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg -o /etc/apt/keyrings/githubcli-archive-keyring.gpg \
|
||||
&& chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
|
||||
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" > /etc/apt/sources.list.d/github-cli.list \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends docker-ce-cli docker-compose-plugin gh \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# fd-find installs as fdfind on Debian/Ubuntu — symlink to fd
|
||||
RUN ln -sf "$(which fdfind)" /usr/local/bin/fd
|
||||
|
||||
# Install uv (Python package manager)
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/
|
||||
|
||||
# Create non-root dev user with passwordless sudo
|
||||
RUN useradd -m -s /bin/zsh dev && \
|
||||
echo "dev ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/dev && \
|
||||
chmod 0440 /etc/sudoers.d/dev
|
||||
|
||||
ENV DEVCONTAINER=true
|
||||
|
||||
RUN mkdir -p /workspace && \
|
||||
chown -R dev:dev /workspace
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
# Install Claude Code
|
||||
ARG CLAUDE_CODE_VERSION=latest
|
||||
RUN npm install -g @anthropic-ai/claude-code@${CLAUDE_CODE_VERSION}
|
||||
|
||||
# Configure zsh — source the repo-local zshrc so shell customization
|
||||
# doesn't require an image rebuild.
|
||||
RUN chsh -s /bin/zsh root && \
|
||||
for rc in /root/.zshrc /home/dev/.zshrc; do \
|
||||
echo '[ -f /workspace/.devcontainer/zshrc ] && . /workspace/.devcontainer/zshrc' >> "$rc"; \
|
||||
done && \
|
||||
chown dev:dev /home/dev/.zshrc
|
||||
126
.devcontainer/README.md
Normal file
126
.devcontainer/README.md
Normal file
@@ -0,0 +1,126 @@
|
||||
# Onyx Dev Container
|
||||
|
||||
A containerized development environment for working on Onyx.
|
||||
|
||||
## What's included
|
||||
|
||||
- Ubuntu 26.04 base image
|
||||
- Node.js 20, uv, Claude Code
|
||||
- Docker CLI, GitHub CLI (`gh`)
|
||||
- Neovim, ripgrep, fd, fzf, jq, make, wget, unzip
|
||||
- Zsh as default shell (sources host `~/.zshrc` if available)
|
||||
- Python venv auto-activation
|
||||
- Network firewall (default-deny, whitelists npm, GitHub, Anthropic APIs, Sentry, and VS Code update servers)
|
||||
|
||||
## Usage
|
||||
|
||||
### VS Code
|
||||
|
||||
1. Install the [Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers)
|
||||
2. Open this repo in VS Code
|
||||
3. "Reopen in Container" when prompted
|
||||
|
||||
### CLI (`ods dev`)
|
||||
|
||||
The [`ods` devtools CLI](../tools/ods/README.md) provides workspace-aware wrappers
|
||||
for all devcontainer operations (also available as `ods dc`):
|
||||
|
||||
```bash
|
||||
# Start the container
|
||||
ods dev up
|
||||
|
||||
# Open a shell
|
||||
ods dev into
|
||||
|
||||
# Run a command
|
||||
ods dev exec npm test
|
||||
|
||||
# Stop the container
|
||||
ods dev stop
|
||||
```
|
||||
|
||||
If you don't have `ods` installed, use the `devcontainer` CLI directly:
|
||||
|
||||
```bash
|
||||
npm install -g @devcontainers/cli
|
||||
|
||||
devcontainer up --workspace-folder .
|
||||
devcontainer exec --workspace-folder . zsh
|
||||
```
|
||||
|
||||
## Restarting the container
|
||||
|
||||
### VS Code
|
||||
|
||||
Open the Command Palette (`Ctrl+Shift+P` / `Cmd+Shift+P`) and run:
|
||||
|
||||
- **Dev Containers: Reopen in Container** — restarts the container without rebuilding
|
||||
|
||||
### CLI
|
||||
|
||||
```bash
|
||||
# Restart the container
|
||||
ods dev restart
|
||||
|
||||
# Pull the latest published image and recreate
|
||||
ods dev rebuild
|
||||
```
|
||||
|
||||
Or without `ods`:
|
||||
|
||||
```bash
|
||||
devcontainer up --workspace-folder . --remove-existing-container
|
||||
```
|
||||
|
||||
## Image
|
||||
|
||||
The devcontainer uses a prebuilt image published to `onyxdotapp/onyx-devcontainer`.
|
||||
The tag is pinned in `devcontainer.json` — no local build is required.
|
||||
|
||||
To build the image locally (e.g. while iterating on the Dockerfile):
|
||||
|
||||
```bash
|
||||
docker buildx bake devcontainer
|
||||
```
|
||||
|
||||
The `devcontainer` target is defined in `docker-bake.hcl` at the repo root.
|
||||
|
||||
## User & permissions
|
||||
|
||||
The container runs as the `dev` user by default (`remoteUser` in devcontainer.json).
|
||||
An init script (`init-dev-user.sh`) runs at container start to ensure `dev` has
|
||||
read/write access to the bind-mounted workspace:
|
||||
|
||||
- **Standard Docker** — `dev`'s UID/GID is remapped to match the workspace owner,
|
||||
so file permissions work seamlessly.
|
||||
- **Rootless Docker** — The workspace appears as root-owned (UID 0) inside the
|
||||
container due to user-namespace mapping. The init script grants `dev` access via
|
||||
POSIX ACLs (`setfacl`), which adds a few seconds to the first container start on
|
||||
large repos.
|
||||
|
||||
## Docker socket
|
||||
|
||||
The container mounts the host's Docker socket so you can run `docker` commands
|
||||
from inside. `ods dev` auto-detects the socket path and sets `DOCKER_SOCK`:
|
||||
|
||||
| Environment | Socket path |
|
||||
| ----------------------- | ------------------------------ |
|
||||
| Linux (rootless Docker) | `$XDG_RUNTIME_DIR/docker.sock` |
|
||||
| macOS (Docker Desktop) | `~/.docker/run/docker.sock` |
|
||||
| Linux (standard Docker) | `/var/run/docker.sock` |
|
||||
|
||||
To override, set `DOCKER_SOCK` before running `ods dev up`. When using the
|
||||
VS Code extension or `devcontainer` CLI directly (without `ods`), you must set
|
||||
`DOCKER_SOCK` yourself.
|
||||
|
||||
## Firewall
|
||||
|
||||
The container starts with a default-deny firewall (`init-firewall.sh`) that only allows outbound traffic to:
|
||||
|
||||
- npm registry
|
||||
- GitHub
|
||||
- Anthropic API
|
||||
- Sentry
|
||||
- VS Code update servers
|
||||
|
||||
This requires the `NET_ADMIN` and `NET_RAW` capabilities, which are added via `runArgs` in `devcontainer.json`.
|
||||
22
.devcontainer/devcontainer.json
Normal file
22
.devcontainer/devcontainer.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"name": "Onyx Dev Sandbox",
|
||||
"image": "onyxdotapp/onyx-devcontainer@sha256:12184169c5bcc9cca0388286d5ffe504b569bc9c37bfa631b76ee8eee2064055",
|
||||
"runArgs": ["--cap-add=NET_ADMIN", "--cap-add=NET_RAW"],
|
||||
"mounts": [
|
||||
"source=${localEnv:DOCKER_SOCK},target=/var/run/docker.sock,type=bind",
|
||||
"source=${localEnv:HOME}/.claude,target=/home/dev/.claude,type=bind",
|
||||
"source=${localEnv:HOME}/.claude.json,target=/home/dev/.claude.json,type=bind",
|
||||
"source=${localEnv:HOME}/.zshrc,target=/home/dev/.zshrc.host,type=bind,readonly",
|
||||
"source=${localEnv:HOME}/.gitconfig,target=/home/dev/.gitconfig.host,type=bind,readonly",
|
||||
"source=${localEnv:HOME}/.ssh,target=/home/dev/.ssh.host,type=bind,readonly",
|
||||
"source=${localEnv:HOME}/.config/nvim,target=/home/dev/.config/nvim.host,type=bind,readonly",
|
||||
"source=onyx-devcontainer-cache,target=/home/dev/.cache,type=volume",
|
||||
"source=onyx-devcontainer-local,target=/home/dev/.local,type=volume"
|
||||
],
|
||||
"remoteUser": "dev",
|
||||
"updateRemoteUserUID": false,
|
||||
"workspaceMount": "source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=delegated",
|
||||
"workspaceFolder": "/workspace",
|
||||
"postStartCommand": "sudo bash /workspace/.devcontainer/init-dev-user.sh && sudo bash /workspace/.devcontainer/init-firewall.sh",
|
||||
"waitFor": "postStartCommand"
|
||||
}
|
||||
105
.devcontainer/init-dev-user.sh
Normal file
105
.devcontainer/init-dev-user.sh
Normal file
@@ -0,0 +1,105 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Remap the dev user's UID/GID to match the workspace owner so that
|
||||
# bind-mounted files are accessible without running as root.
|
||||
#
|
||||
# Standard Docker: Workspace is owned by the host user's UID (e.g. 1000).
|
||||
# We remap dev to that UID -- fast and seamless.
|
||||
#
|
||||
# Rootless Docker: Workspace appears as root-owned (UID 0) inside the
|
||||
# container due to user-namespace mapping. We can't remap
|
||||
# dev to UID 0 (that's root), so we grant access with
|
||||
# POSIX ACLs instead.
|
||||
|
||||
WORKSPACE=/workspace
|
||||
TARGET_USER=dev
|
||||
|
||||
WS_UID=$(stat -c '%u' "$WORKSPACE")
|
||||
WS_GID=$(stat -c '%g' "$WORKSPACE")
|
||||
DEV_UID=$(id -u "$TARGET_USER")
|
||||
DEV_GID=$(id -g "$TARGET_USER")
|
||||
|
||||
DEV_HOME=/home/"$TARGET_USER"
|
||||
|
||||
# Ensure directories that tools expect exist under ~dev.
|
||||
# ~/.local and ~/.cache are named Docker volumes -- ensure they are owned by dev.
|
||||
mkdir -p "$DEV_HOME"/.local/state "$DEV_HOME"/.local/share
|
||||
chown -R "$TARGET_USER":"$TARGET_USER" "$DEV_HOME"/.local
|
||||
chown -R "$TARGET_USER":"$TARGET_USER" "$DEV_HOME"/.cache
|
||||
|
||||
# Copy host configs mounted as *.host into their real locations.
|
||||
# This gives the dev user owned copies without touching host originals.
|
||||
if [ -d "$DEV_HOME/.ssh.host" ]; then
|
||||
cp -a "$DEV_HOME/.ssh.host" "$DEV_HOME/.ssh"
|
||||
chmod 700 "$DEV_HOME/.ssh"
|
||||
chmod 600 "$DEV_HOME"/.ssh/id_* 2>/dev/null || true
|
||||
chown -R "$TARGET_USER":"$TARGET_USER" "$DEV_HOME/.ssh"
|
||||
fi
|
||||
if [ -d "$DEV_HOME/.config/nvim.host" ]; then
|
||||
mkdir -p "$DEV_HOME/.config"
|
||||
cp -a "$DEV_HOME/.config/nvim.host" "$DEV_HOME/.config/nvim"
|
||||
chown -R "$TARGET_USER":"$TARGET_USER" "$DEV_HOME/.config/nvim"
|
||||
fi
|
||||
|
||||
# Already matching -- nothing to do.
|
||||
if [ "$WS_UID" = "$DEV_UID" ] && [ "$WS_GID" = "$DEV_GID" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [ "$WS_UID" != "0" ]; then
|
||||
# ── Standard Docker ──────────────────────────────────────────────
|
||||
# Workspace is owned by a non-root UID (the host user).
|
||||
# Remap dev's UID/GID to match.
|
||||
if [ "$DEV_GID" != "$WS_GID" ]; then
|
||||
if ! groupmod -g "$WS_GID" "$TARGET_USER" 2>&1; then
|
||||
echo "warning: failed to remap $TARGET_USER GID to $WS_GID" >&2
|
||||
fi
|
||||
fi
|
||||
if [ "$DEV_UID" != "$WS_UID" ]; then
|
||||
if ! usermod -u "$WS_UID" -g "$WS_GID" "$TARGET_USER" 2>&1; then
|
||||
echo "warning: failed to remap $TARGET_USER UID to $WS_UID" >&2
|
||||
fi
|
||||
fi
|
||||
if ! chown -R "$TARGET_USER":"$TARGET_USER" /home/"$TARGET_USER" 2>&1; then
|
||||
echo "warning: failed to chown /home/$TARGET_USER" >&2
|
||||
fi
|
||||
else
|
||||
# ── Rootless Docker ──────────────────────────────────────────────
|
||||
# Workspace is root-owned inside the container. Grant dev access
|
||||
# via POSIX ACLs (preserves ownership, works across the namespace
|
||||
# boundary).
|
||||
if command -v setfacl &>/dev/null; then
|
||||
setfacl -Rm "u:${TARGET_USER}:rwX" "$WORKSPACE"
|
||||
setfacl -Rdm "u:${TARGET_USER}:rwX" "$WORKSPACE" # default ACL for new files
|
||||
|
||||
# Git refuses to operate in repos owned by a different UID.
|
||||
# Host gitconfig is mounted readonly as ~/.gitconfig.host.
|
||||
# Create a real ~/.gitconfig that includes it plus container overrides.
|
||||
printf '[include]\n\tpath = %s/.gitconfig.host\n[safe]\n\tdirectory = %s\n' \
|
||||
"$DEV_HOME" "$WORKSPACE" > "$DEV_HOME/.gitconfig"
|
||||
chown "$TARGET_USER":"$TARGET_USER" "$DEV_HOME/.gitconfig"
|
||||
|
||||
# If this is a worktree, the main .git dir is bind-mounted at its
|
||||
# host absolute path. Grant dev access so git operations work.
|
||||
GIT_COMMON_DIR=$(git -C "$WORKSPACE" rev-parse --git-common-dir 2>/dev/null || true)
|
||||
if [ -n "$GIT_COMMON_DIR" ] && [ "$GIT_COMMON_DIR" != "$WORKSPACE/.git" ]; then
|
||||
[ ! -d "$GIT_COMMON_DIR" ] && GIT_COMMON_DIR="$WORKSPACE/$GIT_COMMON_DIR"
|
||||
if [ -d "$GIT_COMMON_DIR" ]; then
|
||||
setfacl -Rm "u:${TARGET_USER}:rwX" "$GIT_COMMON_DIR"
|
||||
setfacl -Rdm "u:${TARGET_USER}:rwX" "$GIT_COMMON_DIR"
|
||||
git config -f "$DEV_HOME/.gitconfig" --add safe.directory "$(dirname "$GIT_COMMON_DIR")"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Also fix bind-mounted dirs under ~dev that appear root-owned.
|
||||
for dir in /home/"$TARGET_USER"/.claude; do
|
||||
[ -d "$dir" ] && setfacl -Rm "u:${TARGET_USER}:rwX" "$dir" && setfacl -Rdm "u:${TARGET_USER}:rwX" "$dir"
|
||||
done
|
||||
[ -f /home/"$TARGET_USER"/.claude.json ] && \
|
||||
setfacl -m "u:${TARGET_USER}:rw" /home/"$TARGET_USER"/.claude.json
|
||||
else
|
||||
echo "warning: setfacl not found; dev user may not have write access to workspace" >&2
|
||||
echo " install the 'acl' package or set remoteUser to root" >&2
|
||||
fi
|
||||
fi
|
||||
104
.devcontainer/init-firewall.sh
Executable file
104
.devcontainer/init-firewall.sh
Executable file
@@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
echo "Setting up firewall..."
|
||||
|
||||
# Preserve docker dns resolution
|
||||
DOCKER_DNS_RULES=$(iptables-save | grep -E "^-A.*-d 127.0.0.11/32" || true)
|
||||
|
||||
# Flush all rules
|
||||
iptables -t nat -F
|
||||
iptables -t nat -X
|
||||
iptables -t mangle -F
|
||||
iptables -t mangle -X
|
||||
iptables -F
|
||||
iptables -X
|
||||
|
||||
# Restore docker dns rules
|
||||
if [ -n "$DOCKER_DNS_RULES" ]; then
|
||||
echo "$DOCKER_DNS_RULES" | iptables-restore -n
|
||||
fi
|
||||
|
||||
# Create ipset for allowed destinations
|
||||
ipset create allowed-domains hash:net || true
|
||||
ipset flush allowed-domains
|
||||
|
||||
# Fetch GitHub IP ranges (IPv4 only -- ipset hash:net and iptables are IPv4)
|
||||
GITHUB_IPS=$(curl -s https://api.github.com/meta | jq -r '.api[]' 2>/dev/null | grep -v ':' || echo "")
|
||||
for ip in $GITHUB_IPS; do
|
||||
if ! ipset add allowed-domains "$ip" -exist 2>&1; then
|
||||
echo "warning: failed to add GitHub IP $ip to allowlist" >&2
|
||||
fi
|
||||
done
|
||||
|
||||
# Resolve allowed domains
|
||||
ALLOWED_DOMAINS=(
|
||||
"registry.npmjs.org"
|
||||
"api.anthropic.com"
|
||||
"api-staging.anthropic.com"
|
||||
"files.anthropic.com"
|
||||
"sentry.io"
|
||||
"update.code.visualstudio.com"
|
||||
"pypi.org"
|
||||
"files.pythonhosted.org"
|
||||
"go.dev"
|
||||
"storage.googleapis.com"
|
||||
"static.rust-lang.org"
|
||||
)
|
||||
|
||||
for domain in "${ALLOWED_DOMAINS[@]}"; do
|
||||
IPS=$(getent ahosts "$domain" 2>/dev/null | awk '{print $1}' | grep -v ':' | sort -u || echo "")
|
||||
for ip in $IPS; do
|
||||
if ! ipset add allowed-domains "$ip/32" -exist 2>&1; then
|
||||
echo "warning: failed to add $domain ($ip) to allowlist" >&2
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
# Detect host network
|
||||
if [[ "${DOCKER_HOST:-}" == "unix://"* ]]; then
|
||||
DOCKER_GATEWAY=$(ip -4 route show | grep "^default" | awk '{print $3}')
|
||||
if ! ipset add allowed-domains "$DOCKER_GATEWAY/32" -exist 2>&1; then
|
||||
echo "warning: failed to add Docker gateway $DOCKER_GATEWAY to allowlist" >&2
|
||||
fi
|
||||
fi
|
||||
|
||||
# Set default policies to DROP
|
||||
iptables -P FORWARD DROP
|
||||
iptables -P INPUT DROP
|
||||
iptables -P OUTPUT DROP
|
||||
|
||||
# Allow established connections
|
||||
iptables -A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
|
||||
iptables -A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
|
||||
|
||||
# Allow loopback
|
||||
iptables -A INPUT -i lo -j ACCEPT
|
||||
iptables -A OUTPUT -o lo -j ACCEPT
|
||||
|
||||
# Allow DNS
|
||||
iptables -A OUTPUT -p udp --dport 53 -j ACCEPT
|
||||
iptables -A OUTPUT -p tcp --dport 53 -j ACCEPT
|
||||
|
||||
# Allow outbound to allowed destinations
|
||||
iptables -A OUTPUT -m set --match-set allowed-domains dst -j ACCEPT
|
||||
|
||||
# Reject unauthorized outbound
|
||||
iptables -A OUTPUT -j REJECT --reject-with icmp-host-unreachable
|
||||
|
||||
# Validate firewall configuration
|
||||
echo "Validating firewall configuration..."
|
||||
|
||||
BLOCKED_SITES=("example.com" "google.com" "facebook.com")
|
||||
for site in "${BLOCKED_SITES[@]}"; do
|
||||
if timeout 2 ping -c 1 "$site" &>/dev/null; then
|
||||
echo "Warning: $site is still reachable"
|
||||
fi
|
||||
done
|
||||
|
||||
if ! timeout 5 curl -s https://api.github.com/meta > /dev/null; then
|
||||
echo "Warning: GitHub API is not accessible"
|
||||
fi
|
||||
|
||||
echo "Firewall setup complete"
|
||||
10
.devcontainer/zshrc
Normal file
10
.devcontainer/zshrc
Normal file
@@ -0,0 +1,10 @@
|
||||
# Devcontainer zshrc — sourced automatically for both root and dev users.
|
||||
# Edit this file to customize the shell without rebuilding the image.
|
||||
|
||||
# Auto-activate Python venv
|
||||
if [ -f /workspace/.venv/bin/activate ]; then
|
||||
. /workspace/.venv/bin/activate
|
||||
fi
|
||||
|
||||
# Source host zshrc if bind-mounted
|
||||
[ -f ~/.zshrc.host ] && . ~/.zshrc.host
|
||||
6
.github/workflows/deployment.yml
vendored
6
.github/workflows/deployment.yml
vendored
@@ -13,7 +13,7 @@ permissions:
|
||||
id-token: write # zizmor: ignore[excessive-permissions]
|
||||
|
||||
env:
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') || github.ref_name == 'edge' }}
|
||||
|
||||
jobs:
|
||||
# Determine which components to build based on the tag
|
||||
@@ -156,7 +156,7 @@ jobs:
|
||||
check-version-tag:
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
if: ${{ !startsWith(github.ref_name, 'nightly-latest') && github.event_name != 'workflow_dispatch' }}
|
||||
if: ${{ !startsWith(github.ref_name, 'nightly-latest') && github.ref_name != 'edge' && github.event_name != 'workflow_dispatch' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
@@ -228,7 +228,7 @@ jobs:
|
||||
|
||||
- name: Create GitHub Release
|
||||
id: create-release
|
||||
uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631 # ratchet:softprops/action-gh-release@v2
|
||||
uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # ratchet:softprops/action-gh-release@v2
|
||||
with:
|
||||
tag_name: ${{ steps.release-tag.outputs.tag }}
|
||||
name: ${{ steps.release-tag.outputs.tag }}
|
||||
|
||||
2
.github/workflows/helm-chart-releases.yml
vendored
2
.github/workflows/helm-chart-releases.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Helm CLI
|
||||
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4
|
||||
uses: azure/setup-helm@dda3372f752e03dde6b3237bc9431cdc2f7a02a2 # ratchet:azure/setup-helm@v5.0.0
|
||||
with:
|
||||
version: v3.12.1
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # ratchet:actions/stale@v10
|
||||
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # ratchet:actions/stale@v10
|
||||
with:
|
||||
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
|
||||
2
.github/workflows/pr-helm-chart-testing.yml
vendored
2
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -36,7 +36,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
|
||||
uses: azure/setup-helm@dda3372f752e03dde6b3237bc9431cdc2f7a02a2 # ratchet:azure/setup-helm@v5.0.0
|
||||
with:
|
||||
version: v3.19.0
|
||||
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -59,3 +59,6 @@ node_modules
|
||||
|
||||
# plans
|
||||
plans/
|
||||
|
||||
# Added context for LLMs
|
||||
onyx-llm-context/
|
||||
|
||||
@@ -9,7 +9,6 @@ repos:
|
||||
rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
|
||||
hooks:
|
||||
- id: uv-sync
|
||||
args: ["--locked", "--all-extras"]
|
||||
- id: uv-lock
|
||||
- id: uv-export
|
||||
name: uv-export default.txt
|
||||
@@ -18,7 +17,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"--group",
|
||||
"backend",
|
||||
"-o",
|
||||
"backend/requirements/default.txt",
|
||||
@@ -31,7 +30,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"--group",
|
||||
"dev",
|
||||
"-o",
|
||||
"backend/requirements/dev.txt",
|
||||
@@ -44,7 +43,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"--group",
|
||||
"ee",
|
||||
"-o",
|
||||
"backend/requirements/ee.txt",
|
||||
@@ -57,7 +56,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"--group",
|
||||
"model_server",
|
||||
"-o",
|
||||
"backend/requirements/model_server.txt",
|
||||
|
||||
3
.vscode/launch.json
vendored
3
.vscode/launch.json
vendored
@@ -531,8 +531,7 @@
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"sync",
|
||||
"--all-extras"
|
||||
"sync"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
|
||||
416
AGENTS.md
416
AGENTS.md
@@ -1,361 +1,55 @@
|
||||
# PROJECT KNOWLEDGE BASE
|
||||
|
||||
This file provides guidance to AI agents when working with code in this repository.
|
||||
|
||||
## KEY NOTES
|
||||
|
||||
- If you run into any missing python dependency errors, try running your command with `source .venv/bin/activate` \
|
||||
to assume the python venv.
|
||||
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@example.com` and password
|
||||
`a`. The app can be accessed at `http://localhost:3000`.
|
||||
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
|
||||
make sure we see logs coming out from the relevant service.
|
||||
- To connect to the Postgres database, use: `docker exec -it onyx-relational_db-1 psql -U postgres -c "<SQL>"`
|
||||
- When making calls to the backend, always go through the frontend. E.g. make a call to `http://localhost:3000/api/persona` not `http://localhost:8080/api/persona`
|
||||
- Put ALL db operations under the `backend/onyx/db` / `backend/ee/onyx/db` directories. Don't run queries
|
||||
outside of those directories.
|
||||
|
||||
## Project Overview
|
||||
|
||||
**Onyx** (formerly Danswer) is an open-source Gen-AI and Enterprise Search platform that connects to company documents, apps, and people. It features a modular architecture with both Community Edition (MIT licensed) and Enterprise Edition offerings.
|
||||
|
||||
### Background Workers (Celery)
|
||||
|
||||
Onyx uses Celery for asynchronous task processing with multiple specialized workers:
|
||||
|
||||
#### Worker Types
|
||||
|
||||
1. **Primary Worker** (`celery_app.py`)
|
||||
- Coordinates core background tasks and system-wide operations
|
||||
- Handles connector management, document sync, pruning, and periodic checks
|
||||
- Runs with 4 threads concurrency
|
||||
- Tasks: connector deletion, vespa sync, pruning, LLM model updates, user file sync
|
||||
|
||||
2. **Docfetching Worker** (`docfetching`)
|
||||
- Fetches documents from external data sources (connectors)
|
||||
- Spawns docprocessing tasks for each document batch
|
||||
- Implements watchdog monitoring for stuck connectors
|
||||
- Configurable concurrency (default from env)
|
||||
|
||||
3. **Docprocessing Worker** (`docprocessing`)
|
||||
- Processes fetched documents through the indexing pipeline:
|
||||
- Upserts documents to PostgreSQL
|
||||
- Chunks documents and adds contextual information
|
||||
- Embeds chunks via model server
|
||||
- Writes chunks to Vespa vector database
|
||||
- Updates document metadata
|
||||
- Configurable concurrency (default from env)
|
||||
|
||||
4. **Light Worker** (`light`)
|
||||
- Handles lightweight, fast operations
|
||||
- Tasks: vespa operations, document permissions sync, external group sync
|
||||
- Higher concurrency for quick tasks
|
||||
|
||||
5. **Heavy Worker** (`heavy`)
|
||||
- Handles resource-intensive operations
|
||||
- Primary task: document pruning operations
|
||||
- Runs with 4 threads concurrency
|
||||
|
||||
6. **KG Processing Worker** (`kg_processing`)
|
||||
- Handles Knowledge Graph processing and clustering
|
||||
- Builds relationships between documents
|
||||
- Runs clustering algorithms
|
||||
- Configurable concurrency
|
||||
|
||||
7. **Monitoring Worker** (`monitoring`)
|
||||
- System health monitoring and metrics collection
|
||||
- Monitors Celery queues, process memory, and system status
|
||||
- Single thread (monitoring doesn't need parallelism)
|
||||
- Cloud-specific monitoring tasks
|
||||
|
||||
8. **User File Processing Worker** (`user_file_processing`)
|
||||
- Processes user-uploaded files
|
||||
- Handles user file indexing and project synchronization
|
||||
- Configurable concurrency
|
||||
|
||||
9. **Beat Worker** (`beat`)
|
||||
- Celery's scheduler for periodic tasks
|
||||
- Uses DynamicTenantScheduler for multi-tenant support
|
||||
- Schedules tasks like:
|
||||
- Indexing checks (every 15 seconds)
|
||||
- Connector deletion checks (every 20 seconds)
|
||||
- Vespa sync checks (every 20 seconds)
|
||||
- Pruning checks (every 20 seconds)
|
||||
- KG processing (every 60 seconds)
|
||||
- Monitoring tasks (every 5 minutes)
|
||||
- Cleanup tasks (hourly)
|
||||
|
||||
#### Key Features
|
||||
|
||||
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
|
||||
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
|
||||
middleware layer that automatically finds the appropriate tenant ID when sending tasks
|
||||
via Celery Beat.
|
||||
- **Task Prioritization**: High, Medium, Low priority queues
|
||||
- **Monitoring**: Built-in heartbeat and liveness checking
|
||||
- **Failure Handling**: Automatic retry and failure recovery mechanisms
|
||||
- **Redis Coordination**: Inter-process communication via Redis
|
||||
- **PostgreSQL State**: Task state and metadata stored in PostgreSQL
|
||||
|
||||
#### Important Notes
|
||||
|
||||
**Defining Tasks**:
|
||||
|
||||
- Always use `@shared_task` rather than `@celery_app`
|
||||
- Put tasks under `background/celery/tasks/` or `ee/background/celery/tasks`
|
||||
- Never enqueue a task without an expiration. Always supply `expires=` when
|
||||
sending tasks, either from the beat schedule or directly from another task. It
|
||||
should never be acceptable to submit code which enqueues tasks without an
|
||||
expiration, as doing so can lead to unbounded task queue growth.
|
||||
|
||||
**Defining APIs**:
|
||||
When creating new FastAPI APIs, do NOT use the `response_model` field. Instead, just type the
|
||||
function.
|
||||
|
||||
**Testing Updates**:
|
||||
If you make any updates to a celery worker and you want to test these changes, you will need
|
||||
to ask me to restart the celery worker. There is no auto-restart on code-change mechanism.
|
||||
|
||||
**Task Time Limits**:
|
||||
Since all tasks are executed in thread pools, the time limit features of Celery are silently
|
||||
disabled and won't work. Timeout logic must be implemented within the task itself.
|
||||
|
||||
### Code Quality
|
||||
|
||||
```bash
|
||||
# Install and run pre-commit hooks
|
||||
pre-commit install
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
NOTE: Always make sure everything is strictly typed (both in Python and Typescript).
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Technology Stack
|
||||
|
||||
- **Backend**: Python 3.11, FastAPI, SQLAlchemy, Alembic, Celery
|
||||
- **Frontend**: Next.js 15+, React 18, TypeScript, Tailwind CSS
|
||||
- **Database**: PostgreSQL with Redis caching
|
||||
- **Search**: Vespa vector database
|
||||
- **Auth**: OAuth2, SAML, multi-provider support
|
||||
- **AI/ML**: LangChain, LiteLLM, multiple embedding models
|
||||
|
||||
### Directory Structure
|
||||
|
||||
```
|
||||
backend/
|
||||
├── onyx/
|
||||
│ ├── auth/ # Authentication & authorization
|
||||
│ ├── chat/ # Chat functionality & LLM interactions
|
||||
│ ├── connectors/ # Data source connectors
|
||||
│ ├── db/ # Database models & operations
|
||||
│ ├── document_index/ # Vespa integration
|
||||
│ ├── federated_connectors/ # External search connectors
|
||||
│ ├── llm/ # LLM provider integrations
|
||||
│ └── server/ # API endpoints & routers
|
||||
├── ee/ # Enterprise Edition features
|
||||
├── alembic/ # Database migrations
|
||||
└── tests/ # Test suites
|
||||
|
||||
web/
|
||||
├── src/app/ # Next.js app router pages
|
||||
├── src/components/ # Reusable React components
|
||||
└── src/lib/ # Utilities & business logic
|
||||
```
|
||||
|
||||
## Frontend Standards
|
||||
|
||||
Frontend standards for the `web/` and `desktop/` projects live in `web/AGENTS.md`.
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
### Running Migrations
|
||||
|
||||
```bash
|
||||
# Standard migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Multi-tenant (Enterprise)
|
||||
alembic -n schema_private upgrade head
|
||||
```
|
||||
|
||||
### Creating Migrations
|
||||
|
||||
```bash
|
||||
# Create migration
|
||||
alembic revision -m "description"
|
||||
|
||||
# Multi-tenant migration
|
||||
alembic -n schema_private revision -m "description"
|
||||
```
|
||||
|
||||
Write the migration manually and place it in the file that alembic creates when running the above command.
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
First, you must activate the virtual environment with `source .venv/bin/activate`.
|
||||
|
||||
There are 4 main types of tests within Onyx:
|
||||
|
||||
### Unit Tests
|
||||
|
||||
These should not assume any Onyx/external services are available to be called.
|
||||
Interactions with the outside world should be mocked using `unittest.mock`. Generally, only
|
||||
write these for complex, isolated modules e.g. `citation_processing.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
pytest -xv backend/tests/unit
|
||||
```
|
||||
|
||||
### External Dependency Unit Tests
|
||||
|
||||
These tests assume that all external dependencies of Onyx are available and callable (e.g. Postgres, Redis,
|
||||
MinIO/S3, Vespa are running + OpenAI can be called + any request to the internet is fine + etc.).
|
||||
|
||||
However, the actual Onyx containers are not running and with these tests we call the function to test directly.
|
||||
We can also mock components/calls at will.
|
||||
|
||||
The goal with these tests are to minimize mocking while giving some flexibility to mock things that are flakey,
|
||||
need strictly controlled behavior, or need to have their internal behavior validated (e.g. verify a function is called
|
||||
with certain args, something that would be impossible with proper integration tests).
|
||||
|
||||
A great example of this type of test is `backend/tests/external_dependency_unit/connectors/confluence/test_confluence_group_sync.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/external_dependency_unit
|
||||
```
|
||||
|
||||
### Integration Tests
|
||||
|
||||
Standard integration tests. Every test in `backend/tests/integration` runs against a real Onyx deployment. We cannot
|
||||
mock anything in these tests. Prefer writing integration tests (or External Dependency Unit Tests if mocking/internal
|
||||
verification is necessary) over any other type of test.
|
||||
|
||||
Tests are parallelized at a directory level.
|
||||
|
||||
When writing integration tests, make sure to check the root `conftest.py` for useful fixtures + the `backend/tests/integration/common_utils` directory for utilities. Prefer (if one exists), calling the appropriate Manager
|
||||
class in the utils over directly calling the APIs with a library like `requests`. Prefer using fixtures rather than
|
||||
calling the utilities directly (e.g. do NOT create admin users with
|
||||
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).
|
||||
|
||||
A great example of this type of test is `backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/integration
|
||||
```
|
||||
|
||||
### Playwright (E2E) Tests
|
||||
|
||||
These tests are an even more complete version of the Integration Tests mentioned above. Has all services of Onyx
|
||||
running, _including_ the Web Server.
|
||||
|
||||
Use these tests for anything that requires significant frontend <-> backend coordination.
|
||||
|
||||
Tests are located at `web/tests/e2e`. Tests are written in TypeScript.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
npx playwright test <TEST_NAME>
|
||||
```
|
||||
|
||||
For shared fixtures, best practices, and detailed guidance, see `backend/tests/README.md`.
|
||||
|
||||
## Logs
|
||||
|
||||
When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access
|
||||
to logs via the `backend/log/<service_name>_debug.log` file. All Onyx services (api_server, web_server, celery_X)
|
||||
will be tailing their logs to this file.
|
||||
|
||||
## Security Considerations
|
||||
|
||||
- Never commit API keys or secrets to repository
|
||||
- Use encrypted credential storage for connector credentials
|
||||
- Follow RBAC patterns for new features
|
||||
- Implement proper input validation with Pydantic models
|
||||
- Use parameterized queries to prevent SQL injection
|
||||
|
||||
## AI/LLM Integration
|
||||
|
||||
- Multiple LLM providers supported via LiteLLM
|
||||
- Configurable models per feature (chat, search, embeddings)
|
||||
- Streaming support for real-time responses
|
||||
- Token management and rate limiting
|
||||
- Custom prompts and agent actions
|
||||
|
||||
## Creating a Plan
|
||||
|
||||
When creating a plan in the `plans` directory, make sure to include at least these elements:
|
||||
|
||||
**Issues to Address**
|
||||
What the change is meant to do.
|
||||
|
||||
**Important Notes**
|
||||
Things you come across in your research that are important to the implementation.
|
||||
|
||||
**Implementation strategy**
|
||||
How you are going to make the changes happen. High level approach.
|
||||
|
||||
**Tests**
|
||||
What unit (use rarely), external dependency unit, integration, and playwright tests you plan to write to
|
||||
verify the correct behavior. Don't overtest. Usually, a given change only needs one type of test.
|
||||
|
||||
Do NOT include these: _Timeline_, _Rollback plan_
|
||||
|
||||
This is a minimal list - feel free to include more. Do NOT write code as part of your plan.
|
||||
Keep it high level. You can reference certain files or functions though.
|
||||
|
||||
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
|
||||
|
||||
## Error Handling
|
||||
|
||||
**Always raise `OnyxError` from `onyx.error_handling.exceptions` instead of `HTTPException`.
|
||||
Never hardcode status codes or use `starlette.status` / `fastapi.status` constants directly.**
|
||||
|
||||
A global FastAPI exception handler converts `OnyxError` into a JSON response with the standard
|
||||
`{"error_code": "...", "detail": "..."}` shape. This eliminates boilerplate and keeps error
|
||||
handling consistent across the entire backend.
|
||||
|
||||
```python
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
# ✅ Good
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found")
|
||||
|
||||
# ✅ Good — no extra message needed
|
||||
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED)
|
||||
|
||||
# ✅ Good — upstream service with dynamic status code
|
||||
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)
|
||||
|
||||
# ❌ Bad — using HTTPException directly
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# ❌ Bad — starlette constant
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||
```
|
||||
|
||||
Available error codes are defined in `backend/onyx/error_handling/error_codes.py`. If a new error
|
||||
category is needed, add it there first — do not invent ad-hoc codes.
|
||||
|
||||
**Upstream service errors:** When forwarding errors from an upstream service where the HTTP
|
||||
status code is dynamic (comes from the upstream response), use `status_code_override`:
|
||||
|
||||
```python
|
||||
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=e.response.status_code)
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
In addition to the other content in this file, best practices for contributing
|
||||
to the codebase can be found in the "Engineering Best Practices" section of
|
||||
`CONTRIBUTING.md`. Understand its contents and follow them.
|
||||
# Project Knowledge Base
|
||||
|
||||
This file is the entrypoint for agents working in this repository. Keep it small.
|
||||
|
||||
## Start Here
|
||||
|
||||
- General development workflow and repo conventions: [CONTRIBUTING.md](./CONTRIBUTING.md)
|
||||
- Frontend standards for `web/` and `desktop/`: [web/AGENTS.md](./web/AGENTS.md)
|
||||
- Backend testing strategy and commands: [backend/tests/README.md](./backend/tests/README.md)
|
||||
- Celery worker and task guidance: [backend/onyx/background/celery/README.md](./backend/onyx/background/celery/README.md)
|
||||
- Backend API error-handling rules: [backend/onyx/error_handling/README.md](./backend/onyx/error_handling/README.md)
|
||||
- Plan-writing guidance: [plans/README.md](./plans/README.md)
|
||||
|
||||
## Agent-Lab Docs
|
||||
|
||||
When working on `agent-lab` or on tasks explicitly about agent-engineering, use:
|
||||
|
||||
- [docs/agent/README.md](./docs/agent/README.md)
|
||||
|
||||
These docs are the system of record for the `agent-lab` workflow.
|
||||
|
||||
## Universal Notes
|
||||
|
||||
- For non-trivial work, create the target worktree first and keep the edit, test, and PR loop
|
||||
inside that worktree. Do not prototype in one checkout and copy the patch into another unless
|
||||
you are explicitly debugging the harness itself.
|
||||
- Use `ods worktree create` for harness-managed worktrees. Do not use raw `git worktree add` when
|
||||
you want the `agent-lab` workflow, because it will skip the manifest, env overlays, dependency
|
||||
bootstrap, and lane-aware base-ref selection.
|
||||
- When a change needs browser proof, use the harness journey flow instead of ad hoc screen capture:
|
||||
record `before` in the target worktree before making the change, then record `after` in that
|
||||
same worktree after validation. Use `ods journey compare` only when you need to recover a missed
|
||||
baseline or compare two explicit revisions after the fact.
|
||||
- After opening a PR, treat review feedback and failing checks as part of the same loop:
|
||||
use `ods pr-review ...` for GitHub review threads and `ods pr-checks diagnose` plus `ods trace`
|
||||
for failing Playwright runs.
|
||||
- PR titles and commit messages should use conventional-commit style such as `fix: ...` or
|
||||
`feat: ...`. Do not use `[codex]` prefixes in this repo.
|
||||
- If Python dependencies appear missing, activate the root venv with `source .venv/bin/activate`.
|
||||
- To make tests work, check the root `.env` file for an OpenAI key.
|
||||
- If using Playwright to explore the frontend, you can usually log in with username `a@example.com`
|
||||
and password `a` at `http://localhost:3000`.
|
||||
- Assume Onyx services are already running unless the task indicates otherwise. Check `backend/log`
|
||||
if you need to verify service activity.
|
||||
- When making backend calls in local development flows, go through the frontend proxy:
|
||||
`http://localhost:3000/api/...`, not `http://localhost:8080/...`.
|
||||
- Put DB operations under `backend/onyx/db/` or `backend/ee/onyx/db/`. Do not add ad hoc DB access
|
||||
elsewhere.
|
||||
|
||||
## How To Use This File
|
||||
|
||||
- Use this file as a map, not a manual.
|
||||
- Follow the nearest authoritative doc for the subsystem you are changing.
|
||||
- If a repeated rule matters enough to teach every future agent, document it near the code it
|
||||
governs or encode it mechanically.
|
||||
|
||||
@@ -117,7 +117,7 @@ If using PowerShell, the command slightly differs:
|
||||
Install the required Python dependencies:
|
||||
|
||||
```bash
|
||||
uv sync --all-extras
|
||||
uv sync
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
from onyx.db.engine.iam_auth import get_iam_auth_token
|
||||
from onyx.configs.app_configs import USE_IAM_AUTH
|
||||
from onyx.configs.app_configs import POSTGRES_HOST
|
||||
@@ -19,7 +19,6 @@ from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
from shared_configs.configs import (
|
||||
MULTI_TENANT,
|
||||
@@ -45,8 +44,6 @@ if config.config_file_name is not None and config.attributes.get(
|
||||
|
||||
target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ssl_context: ssl.SSLContext | None = None
|
||||
@@ -56,25 +53,6 @@ if USE_IAM_AUTH:
|
||||
ssl_context = ssl.create_default_context(cafile=SSL_CERT_FILE)
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem, # noqa: ARG001
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool, # noqa: ARG001
|
||||
compare_to: SchemaItem | None, # noqa: ARG001
|
||||
) -> bool:
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def filter_tenants_by_range(
|
||||
tenant_ids: list[str], start_range: int | None = None, end_range: int | None = None
|
||||
) -> list[str]:
|
||||
@@ -231,7 +209,6 @@ def do_run_migrations(
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
@@ -405,7 +382,6 @@ def run_migrations_offline() -> None:
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
@@ -447,7 +423,6 @@ def run_migrations_offline() -> None:
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
@@ -490,7 +465,6 @@ def run_migrations_online() -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
|
||||
108
backend/alembic/versions/03d085c5c38d_backfill_account_type.py
Normal file
108
backend/alembic/versions/03d085c5c38d_backfill_account_type.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""backfill_account_type
|
||||
|
||||
Revision ID: 03d085c5c38d
|
||||
Revises: 977e834c1427
|
||||
Create Date: 2026-03-25 16:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "03d085c5c38d"
|
||||
down_revision = "977e834c1427"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
_STANDARD = "STANDARD"
|
||||
_BOT = "BOT"
|
||||
_EXT_PERM_USER = "EXT_PERM_USER"
|
||||
_SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
|
||||
_ANONYMOUS = "ANONYMOUS"
|
||||
|
||||
# Well-known anonymous user UUID
|
||||
ANONYMOUS_USER_ID = "00000000-0000-0000-0000-000000000002"
|
||||
|
||||
# Email pattern for API key virtual users
|
||||
API_KEY_EMAIL_PATTERN = r"API\_KEY\_\_%"
|
||||
|
||||
# Reflect the table structure for use in DML
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("email", sa.String),
|
||||
sa.column("role", sa.String),
|
||||
sa.column("account_type", sa.String),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ------------------------------------------------------------------
|
||||
# Step 1: Backfill account_type from role.
|
||||
# Order matters — most-specific matches first so the final catch-all
|
||||
# only touches rows that haven't been classified yet.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# 1a. API key virtual users → SERVICE_ACCOUNT
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.email.ilike(API_KEY_EMAIL_PATTERN),
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_SERVICE_ACCOUNT)
|
||||
)
|
||||
|
||||
# 1b. Anonymous user → ANONYMOUS
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.id == ANONYMOUS_USER_ID,
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_ANONYMOUS)
|
||||
)
|
||||
|
||||
# 1c. SLACK_USER role → BOT
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.role == "SLACK_USER",
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_BOT)
|
||||
)
|
||||
|
||||
# 1d. EXT_PERM_USER role → EXT_PERM_USER
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.role == "EXT_PERM_USER",
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_EXT_PERM_USER)
|
||||
)
|
||||
|
||||
# 1e. Everything else → STANDARD
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(user_table.c.account_type.is_(None))
|
||||
.values(account_type=_STANDARD)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 2: Set account_type to NOT NULL now that every row is filled.
|
||||
# ------------------------------------------------------------------
|
||||
op.alter_column(
|
||||
"user",
|
||||
"account_type",
|
||||
nullable=False,
|
||||
server_default="STANDARD",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column("user", "account_type", nullable=True, server_default=None)
|
||||
op.execute(sa.update(user_table).values(account_type=None))
|
||||
@@ -0,0 +1,104 @@
|
||||
"""add_effective_permissions
|
||||
|
||||
Adds a JSONB column `effective_permissions` to the user table to store
|
||||
directly granted permissions (e.g. ["admin"] or ["basic"]). Implied
|
||||
permissions are expanded at read time, not stored.
|
||||
|
||||
Backfill: joins user__user_group → permission_grant to collect each
|
||||
user's granted permissions into a JSON array. Users without group
|
||||
memberships keep the default [].
|
||||
|
||||
Revision ID: 503883791c39
|
||||
Revises: b4b7e1028dfd
|
||||
Create Date: 2026-03-30 14:49:22.261748
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "503883791c39"
|
||||
down_revision = "b4b7e1028dfd"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("effective_permissions", postgresql.JSONB),
|
||||
)
|
||||
|
||||
user_user_group = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_id", sa.Uuid),
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
)
|
||||
|
||||
permission_grant = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("is_deleted", sa.Boolean),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"effective_permissions",
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
server_default=sa.text("'[]'::jsonb"),
|
||||
),
|
||||
)
|
||||
|
||||
conn = op.get_bind()
|
||||
|
||||
# Deduplicated permissions per user
|
||||
deduped = (
|
||||
sa.select(
|
||||
user_user_group.c.user_id,
|
||||
permission_grant.c.permission,
|
||||
)
|
||||
.select_from(
|
||||
user_user_group.join(
|
||||
permission_grant,
|
||||
sa.and_(
|
||||
permission_grant.c.group_id == user_user_group.c.user_group_id,
|
||||
permission_grant.c.is_deleted == sa.false(),
|
||||
),
|
||||
)
|
||||
)
|
||||
.distinct()
|
||||
.subquery("deduped")
|
||||
)
|
||||
|
||||
# Aggregate into JSONB array per user (order is not guaranteed;
|
||||
# consumers read this as a set so ordering does not matter)
|
||||
perms_per_user = (
|
||||
sa.select(
|
||||
deduped.c.user_id,
|
||||
sa.func.jsonb_agg(
|
||||
deduped.c.permission,
|
||||
type_=postgresql.JSONB,
|
||||
).label("perms"),
|
||||
)
|
||||
.group_by(deduped.c.user_id)
|
||||
.subquery("sub")
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
user_table.update()
|
||||
.where(user_table.c.id == perms_per_user.c.user_id)
|
||||
.values(effective_permissions=perms_per_user.c.perms)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "effective_permissions")
|
||||
139
backend/alembic/versions/977e834c1427_seed_default_groups.py
Normal file
139
backend/alembic/versions/977e834c1427_seed_default_groups.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""seed_default_groups
|
||||
|
||||
Revision ID: 977e834c1427
|
||||
Revises: 8188861f4e92
|
||||
Create Date: 2026-03-25 14:59:41.313091
|
||||
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "977e834c1427"
|
||||
down_revision = "8188861f4e92"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# (group_name, permission_value)
|
||||
DEFAULT_GROUPS = [
|
||||
("Admin", "admin"),
|
||||
("Basic", "basic"),
|
||||
]
|
||||
|
||||
CUSTOM_SUFFIX = "(Custom)"
|
||||
|
||||
MAX_RENAME_ATTEMPTS = 100
|
||||
|
||||
# Reflect table structures for use in DML
|
||||
user_group_table = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("name", sa.String),
|
||||
sa.column("is_up_to_date", sa.Boolean),
|
||||
sa.column("is_up_for_deletion", sa.Boolean),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
permission_grant_table = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("grant_source", sa.String),
|
||||
)
|
||||
|
||||
user__user_group_table = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
sa.column("user_id", sa.Uuid),
|
||||
)
|
||||
|
||||
|
||||
def _find_available_name(conn: sa.engine.Connection, base: str) -> str:
|
||||
"""Return a name like 'Admin (Custom)' or 'Admin (Custom 2)' that is not taken."""
|
||||
candidate = f"{base} {CUSTOM_SUFFIX}"
|
||||
attempt = 1
|
||||
while attempt <= MAX_RENAME_ATTEMPTS:
|
||||
exists: Any = conn.execute(
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(user_group_table)
|
||||
.where(user_group_table.c.name == candidate)
|
||||
.limit(1)
|
||||
).fetchone()
|
||||
if exists is None:
|
||||
return candidate
|
||||
attempt += 1
|
||||
candidate = f"{base} (Custom {attempt})"
|
||||
raise RuntimeError(
|
||||
f"Could not find an available name for group '{base}' "
|
||||
f"after {MAX_RENAME_ATTEMPTS} attempts"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
for group_name, permission_value in DEFAULT_GROUPS:
|
||||
# Step 1: Rename ALL existing groups that clash with the canonical name.
|
||||
conflicting = conn.execute(
|
||||
sa.select(user_group_table.c.id, user_group_table.c.name).where(
|
||||
user_group_table.c.name == group_name
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for row_id, row_name in conflicting:
|
||||
new_name = _find_available_name(conn, row_name)
|
||||
op.execute(
|
||||
sa.update(user_group_table)
|
||||
.where(user_group_table.c.id == row_id)
|
||||
.values(name=new_name, is_up_to_date=False)
|
||||
)
|
||||
|
||||
# Step 2: Create a fresh default group.
|
||||
result = conn.execute(
|
||||
user_group_table.insert()
|
||||
.values(
|
||||
name=group_name,
|
||||
is_up_to_date=True,
|
||||
is_up_for_deletion=False,
|
||||
is_default=True,
|
||||
)
|
||||
.returning(user_group_table.c.id)
|
||||
).fetchone()
|
||||
assert result is not None
|
||||
group_id = result[0]
|
||||
|
||||
# Step 3: Upsert permission grant.
|
||||
op.execute(
|
||||
pg_insert(permission_grant_table)
|
||||
.values(
|
||||
group_id=group_id,
|
||||
permission=permission_value,
|
||||
grant_source="SYSTEM",
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=["group_id", "permission"])
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the default groups created by this migration.
|
||||
# First remove user-group memberships that reference default groups
|
||||
# to avoid FK violations, then delete the groups themselves.
|
||||
default_group_ids = sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.is_default == True # noqa: E712
|
||||
)
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.delete(user__user_group_table).where(
|
||||
user__user_group_table.c.user_group_id.in_(default_group_ids)
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
sa.delete(user_group_table).where(
|
||||
user_group_table.c.is_default == True # noqa: E712
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,84 @@
|
||||
"""grant_basic_to_existing_groups
|
||||
|
||||
Grants the "basic" permission to all existing groups that don't already
|
||||
have it. Every group should have at least "basic" so that its members
|
||||
get basic access when effective_permissions is backfilled.
|
||||
|
||||
Revision ID: b4b7e1028dfd
|
||||
Revises: b7bcc991d722
|
||||
Create Date: 2026-03-30 16:15:17.093498
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b4b7e1028dfd"
|
||||
down_revision = "b7bcc991d722"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
user_group = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
permission_grant = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("grant_source", sa.String),
|
||||
sa.column("is_deleted", sa.Boolean),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
already_has_basic = (
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(permission_grant)
|
||||
.where(
|
||||
permission_grant.c.group_id == user_group.c.id,
|
||||
permission_grant.c.permission == "basic",
|
||||
)
|
||||
.exists()
|
||||
)
|
||||
|
||||
groups_needing_basic = sa.select(
|
||||
user_group.c.id,
|
||||
sa.literal("basic").label("permission"),
|
||||
sa.literal("SYSTEM").label("grant_source"),
|
||||
sa.literal(False).label("is_deleted"),
|
||||
).where(
|
||||
user_group.c.is_default == sa.false(),
|
||||
~already_has_basic,
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
permission_grant.insert().from_select(
|
||||
["group_id", "permission", "grant_source", "is_deleted"],
|
||||
groups_needing_basic,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
non_default_group_ids = sa.select(user_group.c.id).where(
|
||||
user_group.c.is_default == sa.false()
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
permission_grant.delete().where(
|
||||
permission_grant.c.permission == "basic",
|
||||
permission_grant.c.grant_source == "SYSTEM",
|
||||
permission_grant.c.group_id.in_(non_default_group_ids),
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,125 @@
|
||||
"""assign_users_to_default_groups
|
||||
|
||||
Revision ID: b7bcc991d722
|
||||
Revises: 03d085c5c38d
|
||||
Create Date: 2026-03-25 16:30:39.529301
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b7bcc991d722"
|
||||
down_revision = "03d085c5c38d"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# The no-auth placeholder user must NOT be assigned to default groups.
|
||||
# A database trigger (migrate_no_auth_data_to_user) will try to DELETE this
|
||||
# user when the first real user registers; group membership rows would cause
|
||||
# an FK violation on that DELETE.
|
||||
NO_AUTH_PLACEHOLDER_USER_UUID = "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
# Reflect table structures for use in DML
|
||||
user_group_table = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("name", sa.String),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("role", sa.String),
|
||||
sa.column("account_type", sa.String),
|
||||
sa.column("is_active", sa.Boolean),
|
||||
)
|
||||
|
||||
user__user_group_table = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
sa.column("user_id", sa.Uuid),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Look up default group IDs
|
||||
admin_row = conn.execute(
|
||||
sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.name == "Admin",
|
||||
user_group_table.c.is_default == True, # noqa: E712
|
||||
)
|
||||
).fetchone()
|
||||
|
||||
basic_row = conn.execute(
|
||||
sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.name == "Basic",
|
||||
user_group_table.c.is_default == True, # noqa: E712
|
||||
)
|
||||
).fetchone()
|
||||
|
||||
if admin_row is None:
|
||||
raise RuntimeError(
|
||||
"Default 'Admin' group not found. "
|
||||
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
|
||||
)
|
||||
|
||||
if basic_row is None:
|
||||
raise RuntimeError(
|
||||
"Default 'Basic' group not found. "
|
||||
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
|
||||
)
|
||||
|
||||
# Users with role=admin → Admin group
|
||||
# Include inactive users so reactivation doesn't require reconciliation.
|
||||
# Exclude non-human account types (mirrors assign_user_to_default_groups logic).
|
||||
admin_users = sa.select(
|
||||
sa.literal(admin_row[0]).label("user_group_id"),
|
||||
user_table.c.id.label("user_id"),
|
||||
).where(
|
||||
user_table.c.role == "ADMIN",
|
||||
user_table.c.account_type.notin_(["BOT", "EXT_PERM_USER", "ANONYMOUS"]),
|
||||
user_table.c.id != NO_AUTH_PLACEHOLDER_USER_UUID,
|
||||
)
|
||||
op.execute(
|
||||
pg_insert(user__user_group_table)
|
||||
.from_select(["user_group_id", "user_id"], admin_users)
|
||||
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
|
||||
)
|
||||
|
||||
# STANDARD users (non-admin) and SERVICE_ACCOUNT users (role=basic) → Basic group
|
||||
# Include inactive users so reactivation doesn't require reconciliation.
|
||||
basic_users = sa.select(
|
||||
sa.literal(basic_row[0]).label("user_group_id"),
|
||||
user_table.c.id.label("user_id"),
|
||||
).where(
|
||||
user_table.c.account_type.notin_(["BOT", "EXT_PERM_USER", "ANONYMOUS"]),
|
||||
user_table.c.id != NO_AUTH_PLACEHOLDER_USER_UUID,
|
||||
sa.or_(
|
||||
sa.and_(
|
||||
user_table.c.account_type == "STANDARD",
|
||||
user_table.c.role != "ADMIN",
|
||||
),
|
||||
sa.and_(
|
||||
user_table.c.account_type == "SERVICE_ACCOUNT",
|
||||
user_table.c.role == "BASIC",
|
||||
),
|
||||
),
|
||||
)
|
||||
op.execute(
|
||||
pg_insert(user__user_group_table)
|
||||
.from_select(["user_group_id", "user_id"], basic_users)
|
||||
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Group memberships are left in place — removing them risks
|
||||
# deleting memberships that existed before this migration.
|
||||
pass
|
||||
@@ -1,11 +1,9 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.schema import SchemaItem
|
||||
|
||||
from alembic import context
|
||||
from onyx.db.engine.sql_engine import build_connection_string
|
||||
@@ -35,27 +33,6 @@ target_metadata = [PublicBase.metadata]
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem, # noqa: ARG001
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool, # noqa: ARG001
|
||||
compare_to: SchemaItem | None, # noqa: ARG001
|
||||
) -> bool:
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
@@ -85,7 +62,6 @@ def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore[arg-type]
|
||||
include_object=include_object,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
|
||||
@@ -10,9 +10,10 @@ from fastapi import status
|
||||
from ee.onyx.configs.app_configs import SUPER_CLOUD_API_KEY
|
||||
from ee.onyx.configs.app_configs import SUPER_USERS
|
||||
from ee.onyx.server.seeding import get_seed_config
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -39,7 +40,7 @@ def get_default_admin_user_emails_() -> list[str]:
|
||||
|
||||
async def current_cloud_superuser(
|
||||
request: Request,
|
||||
user: User = Depends(current_admin_user),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> User:
|
||||
api_key = request.headers.get("Authorization", "").replace("Bearer ", "")
|
||||
if api_key != SUPER_CLOUD_API_KEY:
|
||||
|
||||
@@ -5,6 +5,7 @@ from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from ee.onyx.server.tenants.product_gating import get_gated_tenants
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
@@ -30,6 +31,7 @@ def cloud_beat_task_generator(
|
||||
queue: str = OnyxCeleryTask.DEFAULT,
|
||||
priority: int = OnyxCeleryPriority.MEDIUM,
|
||||
expires: int = BEAT_EXPIRES_DEFAULT,
|
||||
skip_gated: bool = True,
|
||||
) -> bool | None:
|
||||
"""a lightweight task used to kick off individual beat tasks per tenant."""
|
||||
time_start = time.monotonic()
|
||||
@@ -48,20 +50,22 @@ def cloud_beat_task_generator(
|
||||
last_lock_time = time.monotonic()
|
||||
tenant_ids: list[str] = []
|
||||
num_processed_tenants = 0
|
||||
num_skipped_gated = 0
|
||||
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
# NOTE: for now, we are running tasks for gated tenants, since we want to allow
|
||||
# connector deletion to run successfully. The new plan is to continously prune
|
||||
# the gated tenants set, so we won't have a build up of old, unused gated tenants.
|
||||
# Keeping this around in case we want to revert to the previous behavior.
|
||||
# gated_tenants = get_gated_tenants()
|
||||
# Per-task control over whether gated tenants are included. Most periodic tasks
|
||||
# do no useful work on gated tenants and just waste DB connections fanning out
|
||||
# to ~10k+ inactive tenants. A small number of cleanup tasks (connector deletion,
|
||||
# checkpoint/index attempt cleanup) need to run on gated tenants and pass
|
||||
# `skip_gated=False` from the beat schedule.
|
||||
gated_tenants: set[str] = get_gated_tenants() if skip_gated else set()
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
# Same comment here as the above NOTE
|
||||
# if tenant_id in gated_tenants:
|
||||
# continue
|
||||
if tenant_id in gated_tenants:
|
||||
num_skipped_gated += 1
|
||||
continue
|
||||
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
|
||||
@@ -104,6 +108,7 @@ def cloud_beat_task_generator(
|
||||
f"cloud_beat_task_generator finished: "
|
||||
f"task={task_name} "
|
||||
f"num_processed_tenants={num_processed_tenants} "
|
||||
f"num_skipped_gated={num_skipped_gated} "
|
||||
f"num_tenants={len(tenant_ids)} "
|
||||
f"elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
|
||||
@@ -27,13 +27,13 @@ from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
|
||||
# Maximum tenants to provision in a single task run.
|
||||
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
|
||||
_MAX_TENANTS_PER_RUN = 5
|
||||
# Each tenant takes ~80s (alembic migrations), so 15 tenants ≈ 20 minutes.
|
||||
_MAX_TENANTS_PER_RUN = 15
|
||||
|
||||
# Time limits sized for worst-case: provisioning up to _MAX_TENANTS_PER_RUN new tenants
|
||||
# (~90s each) plus migrating up to TARGET_AVAILABLE_TENANTS pool tenants (~90s each).
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 20 # 20 minutes
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 25 # 25 minutes
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 40 # 40 minutes
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 45 # 45 minutes
|
||||
|
||||
|
||||
@shared_task(
|
||||
|
||||
@@ -1,20 +1,14 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
|
||||
from ee.onyx.background.task_name_builders import name_chat_ttl_task
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.chat import delete_chat_session
|
||||
from onyx.db.chat import get_chat_sessions_older_than
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.tasks import mark_task_as_finished_with_id
|
||||
from onyx.db.tasks import register_task
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -29,59 +23,42 @@ logger = setup_logger()
|
||||
trail=False,
|
||||
)
|
||||
def perform_ttl_management_task(
|
||||
self: Task, retention_limit_days: int, *, tenant_id: str
|
||||
self: Task, retention_limit_days: int, *, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
task_id = self.request.id
|
||||
if not task_id:
|
||||
raise RuntimeError("No task id defined for this task; cannot identify it")
|
||||
|
||||
start_time = datetime.now(tz=timezone.utc)
|
||||
|
||||
user_id: UUID | None = None
|
||||
session_id: UUID | None = None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# we generally want to move off this, but keeping for now
|
||||
register_task(
|
||||
db_session=db_session,
|
||||
task_name=name_chat_ttl_task(retention_limit_days, tenant_id),
|
||||
task_id=task_id,
|
||||
status=TaskStatus.STARTED,
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
old_chat_sessions = get_chat_sessions_older_than(
|
||||
retention_limit_days, db_session
|
||||
)
|
||||
|
||||
for user_id, session_id in old_chat_sessions:
|
||||
# one session per delete so that we don't blow up if a deletion fails.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_chat_session(
|
||||
user_id,
|
||||
session_id,
|
||||
db_session,
|
||||
include_deleted=True,
|
||||
hard_delete=True,
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_chat_session(
|
||||
user_id,
|
||||
session_id,
|
||||
db_session,
|
||||
include_deleted=True,
|
||||
hard_delete=True,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to delete chat session "
|
||||
f"user_id={user_id} session_id={session_id}, "
|
||||
"continuing with remaining sessions"
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=True,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"delete_chat_session exceptioned. user_id={user_id} session_id={session_id}"
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=False,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@@ -36,13 +36,16 @@ from ee.onyx.server.scim.filtering import ScimFilter
|
||||
from ee.onyx.server.scim.filtering import ScimFilterOperator
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from onyx.db.dal import DAL
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import ScimGroupMapping
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -280,7 +283,9 @@ class ScimDAL(DAL):
|
||||
query = (
|
||||
select(User)
|
||||
.join(ScimUserMapping, ScimUserMapping.user_id == User.id)
|
||||
.where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
|
||||
.where(
|
||||
User.account_type.notin_([AccountType.BOT, AccountType.EXT_PERM_USER])
|
||||
)
|
||||
)
|
||||
|
||||
if scim_filter:
|
||||
@@ -521,6 +526,22 @@ class ScimDAL(DAL):
|
||||
self._session.add(group)
|
||||
self._session.flush()
|
||||
|
||||
def add_permission_grant_to_group(
|
||||
self,
|
||||
group_id: int,
|
||||
permission: Permission,
|
||||
grant_source: GrantSource,
|
||||
) -> None:
|
||||
"""Grant a permission to a group and flush."""
|
||||
self._session.add(
|
||||
PermissionGrant(
|
||||
group_id=group_id,
|
||||
permission=permission,
|
||||
grant_source=grant_source,
|
||||
)
|
||||
)
|
||||
self._session.flush()
|
||||
|
||||
def update_group(
|
||||
self,
|
||||
group: UserGroup,
|
||||
|
||||
@@ -19,6 +19,8 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Credential__UserGroup
|
||||
@@ -28,6 +30,7 @@ from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__UserGroup
|
||||
from onyx.db.models import FederatedConnector__DocumentSet
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.models import TokenRateLimit__UserGroup
|
||||
@@ -36,6 +39,8 @@ from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.permissions import recompute_permissions_for_group__no_commit
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import fetch_user_by_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -255,6 +260,7 @@ def fetch_user_groups(
|
||||
db_session: Session,
|
||||
only_up_to_date: bool = True,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
include_default: bool = True,
|
||||
) -> Sequence[UserGroup]:
|
||||
"""
|
||||
Fetches user groups from the database.
|
||||
@@ -269,6 +275,7 @@ def fetch_user_groups(
|
||||
to include only up to date user groups. Defaults to `True`.
|
||||
eager_load_for_snapshot: If True, adds eager loading for all relationships
|
||||
needed by UserGroup.from_model snapshot creation.
|
||||
include_default: If False, excludes system default groups (is_default=True).
|
||||
|
||||
Returns:
|
||||
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
|
||||
@@ -276,6 +283,8 @@ def fetch_user_groups(
|
||||
stmt = select(UserGroup)
|
||||
if only_up_to_date:
|
||||
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
@@ -286,6 +295,7 @@ def fetch_user_groups_for_user(
|
||||
user_id: UUID,
|
||||
only_curator_groups: bool = False,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
include_default: bool = True,
|
||||
) -> Sequence[UserGroup]:
|
||||
stmt = (
|
||||
select(UserGroup)
|
||||
@@ -295,6 +305,8 @@ def fetch_user_groups_for_user(
|
||||
)
|
||||
if only_curator_groups:
|
||||
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
@@ -478,6 +490,16 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
|
||||
db_session.add(db_user_group)
|
||||
db_session.flush() # give the group an ID
|
||||
|
||||
# Every group gets the "basic" permission by default
|
||||
db_session.add(
|
||||
PermissionGrant(
|
||||
group_id=db_user_group.id,
|
||||
permission=Permission.BASIC_ACCESS,
|
||||
grant_source=GrantSource.SYSTEM,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
|
||||
_add_user__user_group_relationships__no_commit(
|
||||
db_session=db_session,
|
||||
user_group_id=db_user_group.id,
|
||||
@@ -489,6 +511,8 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
|
||||
cc_pair_ids=user_group.cc_pair_ids,
|
||||
)
|
||||
|
||||
recompute_user_permissions__no_commit(user_group.user_ids, db_session)
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
@@ -796,6 +820,10 @@ def update_user_group(
|
||||
# update "time_updated" to now
|
||||
db_user_group.time_last_modified_by_user = func.now()
|
||||
|
||||
recompute_user_permissions__no_commit(
|
||||
list(set(added_user_ids) | set(removed_user_ids)), db_session
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
@@ -835,6 +863,19 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
|
||||
|
||||
_check_user_group_is_modifiable(db_user_group)
|
||||
|
||||
# Collect affected user IDs before cleanup deletes the relationships
|
||||
affected_user_ids: list[UUID] = [
|
||||
uid
|
||||
for uid in db_session.execute(
|
||||
select(User__UserGroup.user_id).where(
|
||||
User__UserGroup.user_group_id == user_group_id
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
if uid is not None
|
||||
]
|
||||
|
||||
_mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
@@ -863,6 +904,10 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
|
||||
# Recompute permissions for affected users now that their
|
||||
# membership in this group has been removed
|
||||
recompute_user_permissions__no_commit(affected_user_ids, db_session)
|
||||
|
||||
db_user_group.is_up_to_date = False
|
||||
db_user_group.is_up_for_deletion = True
|
||||
db_session.commit()
|
||||
@@ -908,3 +953,46 @@ def delete_user_group_cc_pair_relationship__no_commit(
|
||||
UserGroup__ConnectorCredentialPair.cc_pair_id == cc_pair_id,
|
||||
)
|
||||
db_session.execute(delete_stmt)
|
||||
|
||||
|
||||
def set_group_permission__no_commit(
|
||||
group_id: int,
|
||||
permission: Permission,
|
||||
enabled: bool,
|
||||
granted_by: UUID,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Grant or revoke a single permission for a group using soft-delete.
|
||||
|
||||
Does NOT commit — caller must commit the session.
|
||||
"""
|
||||
existing = db_session.execute(
|
||||
select(PermissionGrant)
|
||||
.where(
|
||||
PermissionGrant.group_id == group_id,
|
||||
PermissionGrant.permission == permission,
|
||||
)
|
||||
.with_for_update()
|
||||
).scalar_one_or_none()
|
||||
|
||||
if enabled:
|
||||
if existing is not None:
|
||||
if existing.is_deleted:
|
||||
existing.is_deleted = False
|
||||
existing.granted_by = granted_by
|
||||
existing.granted_at = func.now()
|
||||
else:
|
||||
db_session.add(
|
||||
PermissionGrant(
|
||||
group_id=group_id,
|
||||
permission=permission,
|
||||
grant_source=GrantSource.USER,
|
||||
granted_by=granted_by,
|
||||
)
|
||||
)
|
||||
else:
|
||||
if existing is not None and not existing.is_deleted:
|
||||
existing.is_deleted = True
|
||||
|
||||
db_session.flush()
|
||||
recompute_permissions_for_group__no_commit(group_id, db_session)
|
||||
|
||||
@@ -155,7 +155,7 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, license_router)
|
||||
|
||||
# Unified billing API - always registered in EE.
|
||||
# Each endpoint is protected by the `current_admin_user` dependency (admin auth).
|
||||
# Each endpoint is protected by admin permission checks.
|
||||
include_router_with_global_prefix_prepended(application, billing_router)
|
||||
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -17,10 +17,10 @@ from ee.onyx.db.analytics import fetch_persona_message_analytics
|
||||
from ee.onyx.db.analytics import fetch_persona_unique_users
|
||||
from ee.onyx.db.analytics import fetch_query_analytics
|
||||
from ee.onyx.db.analytics import user_can_view_assistant_stats
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
|
||||
router = APIRouter(prefix="/analytics", tags=PUBLIC_API_TAGS)
|
||||
@@ -40,7 +40,7 @@ class QueryAnalyticsResponse(BaseModel):
|
||||
def get_query_analytics(
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[QueryAnalyticsResponse]:
|
||||
daily_query_usage_info = fetch_query_analytics(
|
||||
@@ -71,7 +71,7 @@ class UserAnalyticsResponse(BaseModel):
|
||||
def get_user_analytics(
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserAnalyticsResponse]:
|
||||
daily_query_usage_info_per_user = fetch_per_user_query_analytics(
|
||||
@@ -105,7 +105,7 @@ class OnyxbotAnalyticsResponse(BaseModel):
|
||||
def get_onyxbot_analytics(
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[OnyxbotAnalyticsResponse]:
|
||||
daily_onyxbot_info = fetch_onyxbot_analytics(
|
||||
@@ -141,7 +141,7 @@ def get_persona_messages(
|
||||
persona_id: int,
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[PersonaMessageAnalyticsResponse]:
|
||||
"""Fetch daily message counts for a single persona within the given time range."""
|
||||
@@ -179,7 +179,7 @@ def get_persona_unique_users(
|
||||
persona_id: int,
|
||||
start: datetime.datetime,
|
||||
end: datetime.datetime,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[PersonaUniqueUsersResponse]:
|
||||
"""Get unique users per day for a single persona."""
|
||||
@@ -218,7 +218,7 @@ def get_assistant_stats(
|
||||
assistant_id: int,
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
user: User = Depends(current_user),
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AssistantStatsResponse:
|
||||
"""
|
||||
|
||||
@@ -29,7 +29,6 @@ from fastapi import Depends
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.db.license import get_license
|
||||
from ee.onyx.db.license import get_used_seats
|
||||
from ee.onyx.server.billing.models import BillingInformationResponse
|
||||
@@ -51,11 +50,13 @@ from ee.onyx.server.billing.service import (
|
||||
get_billing_information as get_billing_service,
|
||||
)
|
||||
from ee.onyx.server.billing.service import update_seat_count as update_seat_service
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
@@ -147,7 +148,7 @@ def _get_tenant_id() -> str | None:
|
||||
@router.post("/create-checkout-session")
|
||||
async def create_checkout_session(
|
||||
request: CreateCheckoutSessionRequest | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CreateCheckoutSessionResponse:
|
||||
"""Create a Stripe checkout session for new subscription or renewal.
|
||||
@@ -191,7 +192,7 @@ async def create_checkout_session(
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def create_customer_portal_session(
|
||||
request: CreateCustomerPortalSessionRequest | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CreateCustomerPortalSessionResponse:
|
||||
"""Create a Stripe customer portal session for managing subscription.
|
||||
@@ -216,7 +217,7 @@ async def create_customer_portal_session(
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def get_billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> BillingInformationResponse | SubscriptionStatusResponse:
|
||||
"""Get billing information for the current subscription.
|
||||
@@ -258,7 +259,7 @@ async def get_billing_information(
|
||||
@router.post("/seats/update")
|
||||
async def update_seats(
|
||||
request: SeatUpdateRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SeatUpdateResponse:
|
||||
"""Update the seat count for the current subscription.
|
||||
@@ -364,7 +365,7 @@ class ResetConnectionResponse(BaseModel):
|
||||
|
||||
@router.post("/reset-connection")
|
||||
async def reset_stripe_connection(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> ResetConnectionResponse:
|
||||
"""Reset the Stripe connection circuit breaker.
|
||||
|
||||
|
||||
@@ -27,11 +27,12 @@ from ee.onyx.server.scim.auth import generate_scim_token
|
||||
from ee.onyx.server.scim.models import ScimTokenCreate
|
||||
from ee.onyx.server.scim.models import ScimTokenCreatedResponse
|
||||
from ee.onyx.server.scim.models import ScimTokenResponse
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_user_with_expired_token
|
||||
from onyx.auth.users import get_user_manager
|
||||
from onyx.auth.users import UserManager
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
@@ -120,7 +121,8 @@ async def refresh_access_token(
|
||||
|
||||
@admin_router.put("")
|
||||
def admin_ee_put_settings(
|
||||
settings: EnterpriseSettings, _: User = Depends(current_admin_user)
|
||||
settings: EnterpriseSettings,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> None:
|
||||
store_settings(settings)
|
||||
|
||||
@@ -139,7 +141,7 @@ def ee_fetch_settings() -> EnterpriseSettings:
|
||||
def put_logo(
|
||||
file: UploadFile,
|
||||
is_logotype: bool = False,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> None:
|
||||
upload_logo(file=file, is_logotype=is_logotype)
|
||||
|
||||
@@ -196,7 +198,8 @@ def fetch_logo(
|
||||
|
||||
@admin_router.put("/custom-analytics-script")
|
||||
def upload_custom_analytics_script(
|
||||
script_upload: AnalyticsScriptUpload, _: User = Depends(current_admin_user)
|
||||
script_upload: AnalyticsScriptUpload,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> None:
|
||||
try:
|
||||
store_analytics_script(script_upload)
|
||||
@@ -220,7 +223,7 @@ def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
|
||||
|
||||
@admin_router.get("/scim/token")
|
||||
def get_active_scim_token(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
dal: ScimDAL = Depends(_get_scim_dal),
|
||||
) -> ScimTokenResponse:
|
||||
"""Return the currently active SCIM token's metadata, or 404 if none."""
|
||||
@@ -250,7 +253,7 @@ def get_active_scim_token(
|
||||
@admin_router.post("/scim/token", status_code=201)
|
||||
def create_scim_token(
|
||||
body: ScimTokenCreate,
|
||||
user: User = Depends(current_admin_user),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
dal: ScimDAL = Depends(_get_scim_dal),
|
||||
) -> ScimTokenCreatedResponse:
|
||||
"""Create a new SCIM bearer token.
|
||||
|
||||
@@ -4,12 +4,13 @@ from fastapi import Depends
|
||||
from fastapi import Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.constants import UNSET
|
||||
from onyx.db.constants import UnsetType
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.hook import create_hook__no_commit
|
||||
from onyx.db.hook import delete_hook__no_commit
|
||||
from onyx.db.hook import get_hook_by_id
|
||||
@@ -178,7 +179,7 @@ router = APIRouter(prefix="/admin/hooks")
|
||||
|
||||
@router.get("/specs")
|
||||
def get_hook_point_specs(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
) -> list[HookPointMetaResponse]:
|
||||
return [
|
||||
@@ -199,7 +200,7 @@ def get_hook_point_specs(
|
||||
|
||||
@router.get("")
|
||||
def list_hooks(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[HookResponse]:
|
||||
@@ -210,7 +211,7 @@ def list_hooks(
|
||||
@router.post("")
|
||||
def create_hook(
|
||||
req: HookCreateRequest,
|
||||
user: User = Depends(current_admin_user),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
@@ -246,7 +247,7 @@ def create_hook(
|
||||
@router.get("/{hook_id}")
|
||||
def get_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
@@ -258,7 +259,7 @@ def get_hook(
|
||||
def update_hook(
|
||||
hook_id: int,
|
||||
req: HookUpdateRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
@@ -328,7 +329,7 @@ def update_hook(
|
||||
@router.delete("/{hook_id}")
|
||||
def delete_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
@@ -339,7 +340,7 @@ def delete_hook(
|
||||
@router.post("/{hook_id}/activate")
|
||||
def activate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
@@ -381,7 +382,7 @@ def activate_hook(
|
||||
@router.post("/{hook_id}/validate")
|
||||
def validate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookValidateResponse:
|
||||
@@ -409,7 +410,7 @@ def validate_hook(
|
||||
@router.post("/{hook_id}/deactivate")
|
||||
def deactivate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
@@ -432,7 +433,7 @@ def deactivate_hook(
|
||||
def list_hook_execution_logs(
|
||||
hook_id: int,
|
||||
limit: int = Query(default=10, ge=1, le=100),
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[HookExecutionRecord]:
|
||||
|
||||
@@ -17,7 +17,6 @@ from fastapi import File
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.configs.app_configs import CLOUD_DATA_PLANE_URL
|
||||
from ee.onyx.db.license import delete_license as db_delete_license
|
||||
from ee.onyx.db.license import get_license
|
||||
@@ -32,8 +31,10 @@ from ee.onyx.server.license.models import LicenseStatusResponse
|
||||
from ee.onyx.server.license.models import LicenseUploadResponse
|
||||
from ee.onyx.server.license.models import SeatUsageResponse
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -60,7 +61,7 @@ def _strip_pem_delimiters(content: str) -> str:
|
||||
|
||||
@router.get("")
|
||||
async def get_license_status(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseStatusResponse:
|
||||
"""Get current license status and seat usage."""
|
||||
@@ -84,7 +85,7 @@ async def get_license_status(
|
||||
|
||||
@router.get("/seats")
|
||||
async def get_seat_usage(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SeatUsageResponse:
|
||||
"""Get detailed seat usage information."""
|
||||
@@ -107,7 +108,7 @@ async def get_seat_usage(
|
||||
@router.post("/claim")
|
||||
async def claim_license(
|
||||
session_id: str | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseResponse:
|
||||
"""
|
||||
@@ -215,7 +216,7 @@ async def claim_license(
|
||||
@router.post("/upload")
|
||||
async def upload_license(
|
||||
license_file: UploadFile = File(...),
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseUploadResponse:
|
||||
"""
|
||||
@@ -263,7 +264,7 @@ async def upload_license(
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh_license_cache_endpoint(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseStatusResponse:
|
||||
"""
|
||||
@@ -292,7 +293,7 @@ async def refresh_license_cache_endpoint(
|
||||
|
||||
@router.delete("")
|
||||
async def delete_license(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""
|
||||
|
||||
@@ -12,8 +12,9 @@ from ee.onyx.db.standard_answer import insert_standard_answer_category
|
||||
from ee.onyx.db.standard_answer import remove_standard_answer
|
||||
from ee.onyx.db.standard_answer import update_standard_answer
|
||||
from ee.onyx.db.standard_answer import update_standard_answer_category
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.server.manage.models import StandardAnswer
|
||||
from onyx.server.manage.models import StandardAnswerCategory
|
||||
@@ -27,7 +28,7 @@ router = APIRouter(prefix="/manage")
|
||||
def create_standard_answer(
|
||||
standard_answer_creation_request: StandardAnswerCreationRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> StandardAnswer:
|
||||
standard_answer_model = insert_standard_answer(
|
||||
keyword=standard_answer_creation_request.keyword,
|
||||
@@ -43,7 +44,7 @@ def create_standard_answer(
|
||||
@router.get("/admin/standard-answer")
|
||||
def list_standard_answers(
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> list[StandardAnswer]:
|
||||
standard_answer_models = fetch_standard_answers(db_session=db_session)
|
||||
return [
|
||||
@@ -57,7 +58,7 @@ def patch_standard_answer(
|
||||
standard_answer_id: int,
|
||||
standard_answer_creation_request: StandardAnswerCreationRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> StandardAnswer:
|
||||
existing_standard_answer = fetch_standard_answer(
|
||||
standard_answer_id=standard_answer_id,
|
||||
@@ -83,7 +84,7 @@ def patch_standard_answer(
|
||||
def delete_standard_answer(
|
||||
standard_answer_id: int,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> None:
|
||||
return remove_standard_answer(
|
||||
standard_answer_id=standard_answer_id,
|
||||
@@ -95,7 +96,7 @@ def delete_standard_answer(
|
||||
def create_standard_answer_category(
|
||||
standard_answer_category_creation_request: StandardAnswerCategoryCreationRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> StandardAnswerCategory:
|
||||
standard_answer_category_model = insert_standard_answer_category(
|
||||
category_name=standard_answer_category_creation_request.name,
|
||||
@@ -107,7 +108,7 @@ def create_standard_answer_category(
|
||||
@router.get("/admin/standard-answer/category")
|
||||
def list_standard_answer_categories(
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> list[StandardAnswerCategory]:
|
||||
standard_answer_category_models = fetch_standard_answer_categories(
|
||||
db_session=db_session
|
||||
@@ -123,7 +124,7 @@ def patch_standard_answer_category(
|
||||
standard_answer_category_id: int,
|
||||
standard_answer_category_creation_request: StandardAnswerCategoryCreationRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> StandardAnswerCategory:
|
||||
existing_standard_answer_category = fetch_standard_answer_category(
|
||||
standard_answer_category_id=standard_answer_category_id,
|
||||
|
||||
@@ -9,9 +9,10 @@ from ee.onyx.server.oauth.api_router import router
|
||||
from ee.onyx.server.oauth.confluence_cloud import ConfluenceCloudOAuth
|
||||
from ee.onyx.server.oauth.google_drive import GoogleDriveOAuth
|
||||
from ee.onyx.server.oauth.slack import SlackOAuth
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -24,7 +25,7 @@ logger = setup_logger()
|
||||
def prepare_authorization_request(
|
||||
connector: DocumentSource,
|
||||
redirect_on_success: str | None,
|
||||
user: User = Depends(current_admin_user),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""Used by the frontend to generate the url for the user's browser during auth request.
|
||||
|
||||
@@ -15,7 +15,7 @@ from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
|
||||
@@ -26,6 +26,7 @@ from onyx.db.credentials import create_credential
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.credentials import update_credential_json
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
@@ -146,7 +147,7 @@ class ConfluenceCloudOAuth:
|
||||
def confluence_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_admin_user),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
@@ -258,7 +259,7 @@ def confluence_oauth_callback(
|
||||
@router.get("/connector/confluence/accessible-resources")
|
||||
def confluence_oauth_accessible_resources(
|
||||
credential_id: int,
|
||||
user: User = Depends(current_admin_user),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id), # noqa: ARG001
|
||||
) -> JSONResponse:
|
||||
@@ -325,7 +326,7 @@ def confluence_oauth_finalize(
|
||||
cloud_id: str,
|
||||
cloud_name: str,
|
||||
cloud_url: str,
|
||||
user: User = Depends(current_admin_user),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id), # noqa: ARG001
|
||||
) -> JSONResponse:
|
||||
|
||||
@@ -12,7 +12,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
@@ -34,6 +34,7 @@ from onyx.connectors.google_utils.shared_constants import (
|
||||
)
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
@@ -114,7 +115,7 @@ class GoogleDriveOAuth:
|
||||
def handle_google_drive_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_admin_user),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
|
||||
@@ -10,7 +10,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
|
||||
@@ -18,6 +18,7 @@ from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
@@ -98,7 +99,7 @@ class SlackOAuth:
|
||||
def handle_slack_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_admin_user),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
|
||||
@@ -8,8 +8,9 @@ from ee.onyx.onyxbot.slack.handlers.handle_standard_answers import (
|
||||
)
|
||||
from ee.onyx.server.query_and_chat.models import StandardAnswerRequest
|
||||
from ee.onyx.server.query_and_chat.models import StandardAnswerResponse
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -22,7 +23,7 @@ basic_router = APIRouter(prefix="/query")
|
||||
def get_standard_answer(
|
||||
request: StandardAnswerRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(current_user),
|
||||
_: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
) -> StandardAnswerResponse:
|
||||
try:
|
||||
standard_answers = oneoff_standard_answers(
|
||||
|
||||
@@ -19,10 +19,11 @@ from ee.onyx.server.query_and_chat.models import SearchHistoryResponse
|
||||
from ee.onyx.server.query_and_chat.models import SearchQueryResponse
|
||||
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
|
||||
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.configs.app_configs import ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
|
||||
@@ -39,7 +40,7 @@ router = APIRouter(prefix="/search")
|
||||
@router.post("/search-flow-classification")
|
||||
def search_flow_classification(
|
||||
request: SearchFlowClassificationRequest,
|
||||
_: User = Depends(current_user),
|
||||
_: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SearchFlowClassificationResponse:
|
||||
query = request.user_query
|
||||
@@ -79,7 +80,7 @@ def search_flow_classification(
|
||||
)
|
||||
def handle_send_search_message(
|
||||
request: SendSearchQueryRequest,
|
||||
user: User = Depends(current_user),
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse | SearchFullResponse:
|
||||
"""
|
||||
@@ -129,7 +130,7 @@ def handle_send_search_message(
|
||||
def get_search_history(
|
||||
limit: int = 100,
|
||||
filter_days: int | None = None,
|
||||
user: User = Depends(current_user),
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SearchHistoryResponse:
|
||||
"""
|
||||
|
||||
@@ -20,7 +20,7 @@ from ee.onyx.server.query_history.models import ChatSessionMinimal
|
||||
from ee.onyx.server.query_history.models import ChatSessionSnapshot
|
||||
from ee.onyx.server.query_history.models import MessageSnapshot
|
||||
from ee.onyx.server.query_history.models import QueryHistoryExport
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import get_display_email
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.background.task_utils import construct_query_history_report_name
|
||||
@@ -39,6 +39,7 @@ from onyx.configs.constants import SessionType
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.file_record import get_query_history_export_files
|
||||
from onyx.db.models import ChatSession
|
||||
@@ -153,7 +154,7 @@ def snapshot_from_chat_session(
|
||||
@router.get("/admin/chat-sessions")
|
||||
def admin_get_chat_sessions(
|
||||
user_id: UUID,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionsResponse:
|
||||
# we specifically don't allow this endpoint if "anonymized" since
|
||||
@@ -196,7 +197,7 @@ def get_chat_session_history(
|
||||
feedback_type: QAFeedbackType | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PaginatedReturn[ChatSessionMinimal]:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
@@ -234,7 +235,7 @@ def get_chat_session_history(
|
||||
@router.get("/admin/chat-session-history/{chat_session_id}")
|
||||
def get_chat_session_admin(
|
||||
chat_session_id: UUID,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionSnapshot:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
@@ -269,7 +270,7 @@ def get_chat_session_admin(
|
||||
|
||||
@router.get("/admin/query-history/list")
|
||||
def list_all_query_history_exports(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[QueryHistoryExport]:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
@@ -297,7 +298,7 @@ def list_all_query_history_exports(
|
||||
|
||||
@router.post("/admin/query-history/start-export", tags=PUBLIC_API_TAGS)
|
||||
def start_query_history_export(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
@@ -344,7 +345,7 @@ def start_query_history_export(
|
||||
@router.get("/admin/query-history/export-status", tags=PUBLIC_API_TAGS)
|
||||
def get_query_history_export_status(
|
||||
request_id: str,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
@@ -378,7 +379,7 @@ def get_query_history_export_status(
|
||||
@router.get("/admin/query-history/download", tags=PUBLIC_API_TAGS)
|
||||
def download_query_history_csv(
|
||||
request_id: str,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
|
||||
@@ -12,10 +12,11 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.usage_export import get_all_usage_reports
|
||||
from ee.onyx.db.usage_export import get_usage_report_data
|
||||
from ee.onyx.db.usage_export import UsageReportMetadata
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.file_store.constants import STANDARD_CHUNK_SIZE
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -31,7 +32,7 @@ class GenerateUsageReportParams(BaseModel):
|
||||
@router.post("/admin/usage-report", status_code=204)
|
||||
def generate_report(
|
||||
params: GenerateUsageReportParams,
|
||||
user: User = Depends(current_admin_user),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> None:
|
||||
# Validate period parameters
|
||||
if params.period_from and params.period_to:
|
||||
@@ -58,7 +59,7 @@ def generate_report(
|
||||
@router.get("/admin/usage-report/{report_name}")
|
||||
def read_usage_report(
|
||||
report_name: str,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session), # noqa: ARG001
|
||||
) -> Response:
|
||||
try:
|
||||
@@ -82,7 +83,7 @@ def read_usage_report(
|
||||
|
||||
@router.get("/admin/usage-report")
|
||||
def fetch_usage_reports(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UsageReportMetadata]:
|
||||
try:
|
||||
|
||||
@@ -52,16 +52,25 @@ from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
|
||||
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.permissions import recompute_permissions_for_group__no_commit
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Group names reserved for system default groups (seeded by migration).
|
||||
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
@@ -486,6 +495,7 @@ def create_user(
|
||||
email=email,
|
||||
hashed_password=_pw_helper.hash(_pw_helper.generate()),
|
||||
role=UserRole.BASIC,
|
||||
account_type=AccountType.STANDARD,
|
||||
is_active=user_resource.active,
|
||||
is_verified=True,
|
||||
personal_name=personal_name,
|
||||
@@ -506,13 +516,25 @@ def create_user(
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
|
||||
# Assign user to default group BEFORE commit so everything is atomic.
|
||||
# If this fails, the entire user creation rolls back and IdP can retry.
|
||||
try:
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
except Exception:
|
||||
dal.rollback()
|
||||
logger.exception(f"Failed to assign SCIM user {email} to default groups")
|
||||
return _scim_error_response(
|
||||
500, f"Failed to assign user {email} to default group"
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
@@ -542,7 +564,8 @@ def replace_user(
|
||||
user = result
|
||||
|
||||
# Handle activation (need seat check) / deactivation
|
||||
if user_resource.active and not user.is_active:
|
||||
is_reactivation = user_resource.active and not user.is_active
|
||||
if is_reactivation:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
@@ -556,6 +579,12 @@ def replace_user(
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
# Reconcile default-group membership on reactivation
|
||||
if is_reactivation:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=(user.role == UserRole.ADMIN)
|
||||
)
|
||||
|
||||
new_external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
fields = _fields_from_resource(user_resource)
|
||||
@@ -621,6 +650,7 @@ def patch_user(
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
# Apply changes back to the DB model
|
||||
is_reactivation = patched.active and not user.is_active
|
||||
if patched.active != user.is_active:
|
||||
if patched.active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
@@ -649,6 +679,12 @@ def patch_user(
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
# Reconcile default-group membership on reactivation
|
||||
if is_reactivation:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=(user.role == UserRole.ADMIN)
|
||||
)
|
||||
|
||||
# Build updated fields by merging PATCH enterprise data with current values
|
||||
cf = current_fields or ScimMappingFields()
|
||||
fields = ScimMappingFields(
|
||||
@@ -857,6 +893,11 @@ def create_group(
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
if group_resource.displayName in _RESERVED_GROUP_NAMES:
|
||||
return _scim_error_response(
|
||||
409, f"'{group_resource.displayName}' is a reserved group name."
|
||||
)
|
||||
|
||||
if dal.get_group_by_name(group_resource.displayName):
|
||||
return _scim_error_response(
|
||||
409, f"Group with name '{group_resource.displayName}' already exists"
|
||||
@@ -879,8 +920,18 @@ def create_group(
|
||||
409, f"Group with name '{group_resource.displayName}' already exists"
|
||||
)
|
||||
|
||||
# Every group gets the "basic" permission by default.
|
||||
dal.add_permission_grant_to_group(
|
||||
group_id=db_group.id,
|
||||
permission=Permission.BASIC_ACCESS,
|
||||
grant_source=GrantSource.SYSTEM,
|
||||
)
|
||||
|
||||
dal.upsert_group_members(db_group.id, member_uuids)
|
||||
|
||||
# Recompute permissions for initial members.
|
||||
recompute_user_permissions__no_commit(member_uuids, db_session)
|
||||
|
||||
external_id = group_resource.externalId
|
||||
if external_id:
|
||||
dal.create_group_mapping(external_id=external_id, user_group_id=db_group.id)
|
||||
@@ -911,14 +962,36 @@ def replace_group(
|
||||
return result
|
||||
group = result
|
||||
|
||||
if group.name in _RESERVED_GROUP_NAMES and group_resource.displayName != group.name:
|
||||
return _scim_error_response(
|
||||
409, f"'{group.name}' is a reserved group name and cannot be renamed."
|
||||
)
|
||||
|
||||
if (
|
||||
group_resource.displayName in _RESERVED_GROUP_NAMES
|
||||
and group_resource.displayName != group.name
|
||||
):
|
||||
return _scim_error_response(
|
||||
409, f"'{group_resource.displayName}' is a reserved group name."
|
||||
)
|
||||
|
||||
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
|
||||
if err:
|
||||
return _scim_error_response(400, err)
|
||||
|
||||
# Capture old member IDs before replacing so we can recompute their
|
||||
# permissions after they are removed from the group.
|
||||
old_member_ids = {uid for uid, _ in dal.get_group_members(group.id)}
|
||||
|
||||
dal.update_group(group, name=group_resource.displayName)
|
||||
dal.replace_group_members(group.id, member_uuids)
|
||||
dal.sync_group_external_id(group.id, group_resource.externalId)
|
||||
|
||||
# Recompute permissions for current members (batch) and removed members.
|
||||
recompute_permissions_for_group__no_commit(group.id, db_session)
|
||||
removed_ids = list(old_member_ids - set(member_uuids))
|
||||
recompute_user_permissions__no_commit(removed_ids, db_session)
|
||||
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
@@ -961,8 +1034,19 @@ def patch_group(
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
new_name = patched.displayName if patched.displayName != group.name else None
|
||||
|
||||
if group.name in _RESERVED_GROUP_NAMES and new_name:
|
||||
return _scim_error_response(
|
||||
409, f"'{group.name}' is a reserved group name and cannot be renamed."
|
||||
)
|
||||
|
||||
if new_name and new_name in _RESERVED_GROUP_NAMES:
|
||||
return _scim_error_response(409, f"'{new_name}' is a reserved group name.")
|
||||
|
||||
dal.update_group(group, name=new_name)
|
||||
|
||||
affected_uuids: list[UUID] = []
|
||||
|
||||
if added_ids:
|
||||
add_uuids = [UUID(mid) for mid in added_ids if _is_valid_uuid(mid)]
|
||||
if add_uuids:
|
||||
@@ -973,10 +1057,15 @@ def patch_group(
|
||||
f"Member(s) not found: {', '.join(str(u) for u in missing)}",
|
||||
)
|
||||
dal.upsert_group_members(group.id, add_uuids)
|
||||
affected_uuids.extend(add_uuids)
|
||||
|
||||
if removed_ids:
|
||||
remove_uuids = [UUID(mid) for mid in removed_ids if _is_valid_uuid(mid)]
|
||||
dal.remove_group_members(group.id, remove_uuids)
|
||||
affected_uuids.extend(remove_uuids)
|
||||
|
||||
# Recompute permissions for all users whose group membership changed.
|
||||
recompute_user_permissions__no_commit(affected_uuids, db_session)
|
||||
|
||||
dal.sync_group_external_id(group.id, patched.externalId)
|
||||
dal.commit()
|
||||
@@ -1002,11 +1091,21 @@ def delete_group(
|
||||
return result
|
||||
group = result
|
||||
|
||||
if group.name in _RESERVED_GROUP_NAMES:
|
||||
return _scim_error_response(409, f"'{group.name}' is a reserved group name.")
|
||||
|
||||
# Capture member IDs before deletion so we can recompute their permissions.
|
||||
affected_user_ids = [uid for uid, _ in dal.get_group_members(group.id)]
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
if mapping:
|
||||
dal.delete_group_mapping(mapping.id)
|
||||
|
||||
dal.delete_group_with_members(group)
|
||||
|
||||
# Recompute permissions for users who lost this group membership.
|
||||
recompute_user_permissions__no_commit(affected_user_ids, db_session)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
@@ -12,12 +12,13 @@ from ee.onyx.server.tenants.anonymous_user_path import (
|
||||
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
|
||||
from ee.onyx.server.tenants.models import AnonymousUserPath
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.db.engine.sql_engine import get_session_with_shared_schema
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -28,7 +29,7 @@ router = APIRouter(prefix="/tenants")
|
||||
|
||||
@router.get("/anonymous-user-path")
|
||||
async def get_anonymous_user_path_api(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> AnonymousUserPath:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
@@ -44,7 +45,7 @@ async def get_anonymous_user_path_api(
|
||||
@router.post("/anonymous-user-path")
|
||||
async def set_anonymous_user_path_api(
|
||||
anonymous_user_path: str,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
|
||||
@@ -22,7 +22,6 @@ import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_customer_portal_session
|
||||
@@ -38,10 +37,12 @@ from ee.onyx.server.tenants.models import SubscriptionSessionResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.product_gating import overwrite_full_gated_set
|
||||
from ee.onyx.server.tenants.product_gating import store_product_gating
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -99,7 +100,7 @@ def gate_product_full_sync(
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
logger.info("Fetching billing information")
|
||||
tenant_id = get_current_tenant_id()
|
||||
@@ -108,7 +109,7 @@ async def billing_information(
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def create_customer_portal_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> dict:
|
||||
"""Create a Stripe customer portal session via the control plane."""
|
||||
tenant_id = get_current_tenant_id()
|
||||
@@ -130,7 +131,7 @@ async def create_customer_portal_session(
|
||||
@router.post("/create-checkout-session")
|
||||
async def create_checkout_session(
|
||||
request: CreateCheckoutSessionRequest | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> dict:
|
||||
"""Create a Stripe checkout session via the control plane."""
|
||||
tenant_id = get_current_tenant_id()
|
||||
@@ -153,7 +154,7 @@ async def create_checkout_session(
|
||||
@router.post("/create-subscription-session")
|
||||
async def create_subscription_session(
|
||||
request: CreateSubscriptionSessionRequest | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
@@ -6,10 +6,11 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
|
||||
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
@@ -24,7 +25,9 @@ router = APIRouter(prefix="/tenants")
|
||||
@router.post("/leave-team")
|
||||
async def leave_organization(
|
||||
user_email: UserByEmail,
|
||||
current_user: User = Depends(current_admin_user),
|
||||
current_user: User = Depends(
|
||||
require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)
|
||||
),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
@@ -3,8 +3,9 @@ from fastapi import Depends
|
||||
|
||||
from ee.onyx.server.tenants.models import TenantByDomainResponse
|
||||
from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -26,7 +27,7 @@ FORBIDDEN_COMMON_EMAIL_SUBSTRINGS = [
|
||||
|
||||
@router.get("/existing-team-by-domain")
|
||||
def get_existing_tenant_by_domain(
|
||||
user: User = Depends(current_user),
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
) -> TenantByDomainResponse | None:
|
||||
domain = user.email.split("@")[1]
|
||||
if any(substring in domain for substring in FORBIDDEN_COMMON_EMAIL_SUBSTRINGS):
|
||||
|
||||
@@ -10,9 +10,9 @@ from ee.onyx.server.tenants.user_mapping import approve_user_invite
|
||||
from ee.onyx.server.tenants.user_mapping import deny_user_invite
|
||||
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
|
||||
from onyx.auth.invited_users import get_pending_users
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -24,7 +24,7 @@ router = APIRouter(prefix="/tenants")
|
||||
@router.post("/users/invite/request")
|
||||
async def request_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User = Depends(current_admin_user),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> None:
|
||||
try:
|
||||
invite_self_to_tenant(user.email, invite_request.tenant_id)
|
||||
@@ -37,7 +37,7 @@ async def request_invite(
|
||||
|
||||
@router.get("/users/pending")
|
||||
def list_pending_users(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> list[PendingUserSnapshot]:
|
||||
pending_emails = get_pending_users()
|
||||
return [PendingUserSnapshot(email=email) for email in pending_emails]
|
||||
@@ -46,7 +46,7 @@ def list_pending_users(
|
||||
@router.post("/users/invite/approve")
|
||||
async def approve_user(
|
||||
approve_user_request: ApproveUserRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
approve_user_invite(approve_user_request.email, tenant_id)
|
||||
@@ -55,7 +55,7 @@ async def approve_user(
|
||||
@router.post("/users/invite/accept")
|
||||
async def accept_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User = Depends(current_user),
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
) -> None:
|
||||
"""
|
||||
Accept an invitation to join a tenant.
|
||||
@@ -70,7 +70,7 @@ async def accept_invite(
|
||||
@router.post("/users/invite/deny")
|
||||
async def deny_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User = Depends(current_user),
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
) -> None:
|
||||
"""
|
||||
Deny an invitation to join a tenant.
|
||||
|
||||
@@ -7,10 +7,11 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.token_limit import fetch_all_user_group_token_rate_limits_by_group
|
||||
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
|
||||
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.db.token_limit import fetch_all_user_token_rate_limits
|
||||
from onyx.db.token_limit import insert_user_token_rate_limit
|
||||
@@ -28,7 +29,7 @@ Group Token Limit Settings
|
||||
|
||||
@router.get("/user-groups")
|
||||
def get_all_group_token_limit_settings(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, list[TokenRateLimitDisplay]]:
|
||||
user_groups_to_token_rate_limits = fetch_all_user_group_token_rate_limits_by_group(
|
||||
@@ -64,7 +65,7 @@ def get_group_token_limit_settings(
|
||||
def create_group_token_limit_settings(
|
||||
group_id: int,
|
||||
token_limit_settings: TokenRateLimitArgs,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> TokenRateLimitDisplay:
|
||||
rate_limit_display = TokenRateLimitDisplay.from_db(
|
||||
@@ -86,7 +87,7 @@ User Token Limit Settings
|
||||
|
||||
@router.get("/users")
|
||||
def get_user_token_limit_settings(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[TokenRateLimitDisplay]:
|
||||
return [
|
||||
@@ -98,7 +99,7 @@ def get_user_token_limit_settings(
|
||||
@router.post("/users")
|
||||
def create_user_token_limit_settings(
|
||||
token_limit_settings: TokenRateLimitArgs,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> TokenRateLimitDisplay:
|
||||
rate_limit_display = TokenRateLimitDisplay.from_db(
|
||||
|
||||
@@ -13,22 +13,26 @@ from ee.onyx.db.user_group import fetch_user_groups_for_user
|
||||
from ee.onyx.db.user_group import insert_user_group
|
||||
from ee.onyx.db.user_group import prepare_user_group_for_deletion
|
||||
from ee.onyx.db.user_group import rename_user_group
|
||||
from ee.onyx.db.user_group import set_group_permission__no_commit
|
||||
from ee.onyx.db.user_group import update_user_curator_relationship
|
||||
from ee.onyx.db.user_group import update_user_group
|
||||
from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest
|
||||
from ee.onyx.server.user_group.models import MinimalUserGroupSnapshot
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
from ee.onyx.server.user_group.models import SetPermissionRequest
|
||||
from ee.onyx.server.user_group.models import SetPermissionResponse
|
||||
from ee.onyx.server.user_group.models import UpdateGroupAgentsRequest
|
||||
from ee.onyx.server.user_group.models import UserGroup
|
||||
from ee.onyx.server.user_group.models import UserGroupCreate
|
||||
from ee.onyx.server.user_group.models import UserGroupRename
|
||||
from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.permissions import NON_TOGGLEABLE_PERMISSIONS
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
@@ -43,12 +47,16 @@ router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@router.get("/admin/user-group")
|
||||
def list_user_groups(
|
||||
include_default: bool = False,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserGroup]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session, only_up_to_date=False, eager_load_for_snapshot=True
|
||||
db_session,
|
||||
only_up_to_date=False,
|
||||
eager_load_for_snapshot=True,
|
||||
include_default=include_default,
|
||||
)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
@@ -56,31 +64,81 @@ def list_user_groups(
|
||||
user_id=user.id,
|
||||
only_curator_groups=user.role == UserRole.CURATOR,
|
||||
eager_load_for_snapshot=True,
|
||||
include_default=include_default,
|
||||
)
|
||||
return [UserGroup.from_model(user_group) for user_group in user_groups]
|
||||
|
||||
|
||||
@router.get("/user-groups/minimal")
|
||||
def list_minimal_user_groups(
|
||||
user: User = Depends(current_user),
|
||||
include_default: bool = False,
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[MinimalUserGroupSnapshot]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
|
||||
user_groups = fetch_user_groups(
|
||||
db_session,
|
||||
only_up_to_date=False,
|
||||
include_default=include_default,
|
||||
)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
include_default=include_default,
|
||||
)
|
||||
return [
|
||||
MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
|
||||
]
|
||||
|
||||
|
||||
@router.get("/admin/user-group/{user_group_id}/permissions")
|
||||
def get_user_group_permissions(
|
||||
user_group_id: int,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[Permission]:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
|
||||
return [
|
||||
grant.permission for grant in group.permission_grants if not grant.is_deleted
|
||||
]
|
||||
|
||||
|
||||
@router.put("/admin/user-group/{user_group_id}/permissions")
|
||||
def set_user_group_permission(
|
||||
user_group_id: int,
|
||||
request: SetPermissionRequest,
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SetPermissionResponse:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
|
||||
|
||||
if request.permission in NON_TOGGLEABLE_PERMISSIONS:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
f"Permission '{request.permission}' cannot be toggled via this endpoint",
|
||||
)
|
||||
|
||||
set_group_permission__no_commit(
|
||||
group_id=user_group_id,
|
||||
permission=request.permission,
|
||||
enabled=request.enabled,
|
||||
granted_by=user.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
return SetPermissionResponse(permission=request.permission, enabled=request.enabled)
|
||||
|
||||
|
||||
@router.post("/admin/user-group")
|
||||
def create_user_group(
|
||||
user_group: UserGroupCreate,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
try:
|
||||
@@ -97,9 +155,12 @@ def create_user_group(
|
||||
@router.patch("/admin/user-group/rename")
|
||||
def rename_user_group_endpoint(
|
||||
rename_request: UserGroupRename,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
group = fetch_user_group(db_session, rename_request.id)
|
||||
if group and group.is_default:
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot rename a default system group.")
|
||||
try:
|
||||
return UserGroup.from_model(
|
||||
rename_user_group(
|
||||
@@ -182,9 +243,12 @@ def set_user_curator(
|
||||
@router.delete("/admin/user-group/{user_group_id}")
|
||||
def delete_user_group(
|
||||
user_group_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group and group.is_default:
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot delete a default system group.")
|
||||
try:
|
||||
prepare_user_group_for_deletion(db_session, user_group_id)
|
||||
except ValueError as e:
|
||||
@@ -200,7 +264,7 @@ def delete_user_group(
|
||||
def update_group_agents(
|
||||
user_group_id: int,
|
||||
request: UpdateGroupAgentsRequest,
|
||||
user: User = Depends(current_admin_user),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
for agent_id in request.added_agent_ids:
|
||||
|
||||
@@ -2,6 +2,7 @@ from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.permissions import Permission
|
||||
from onyx.db.models import UserGroup as UserGroupModel
|
||||
from onyx.server.documents.models import ConnectorCredentialPairDescriptor
|
||||
from onyx.server.documents.models import ConnectorSnapshot
|
||||
@@ -22,6 +23,7 @@ class UserGroup(BaseModel):
|
||||
personas: list[PersonaSnapshot]
|
||||
is_up_to_date: bool
|
||||
is_up_for_deletion: bool
|
||||
is_default: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, user_group_model: UserGroupModel) -> "UserGroup":
|
||||
@@ -74,18 +76,21 @@ class UserGroup(BaseModel):
|
||||
],
|
||||
is_up_to_date=user_group_model.is_up_to_date,
|
||||
is_up_for_deletion=user_group_model.is_up_for_deletion,
|
||||
is_default=user_group_model.is_default,
|
||||
)
|
||||
|
||||
|
||||
class MinimalUserGroupSnapshot(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
is_default: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
|
||||
return cls(
|
||||
id=user_group_model.id,
|
||||
name=user_group_model.name,
|
||||
is_default=user_group_model.is_default,
|
||||
)
|
||||
|
||||
|
||||
@@ -117,3 +122,13 @@ class SetCuratorRequest(BaseModel):
|
||||
class UpdateGroupAgentsRequest(BaseModel):
|
||||
added_agent_ids: list[int]
|
||||
removed_agent_ids: list[int]
|
||||
|
||||
|
||||
class SetPermissionRequest(BaseModel):
|
||||
permission: Permission
|
||||
enabled: bool
|
||||
|
||||
|
||||
class SetPermissionResponse(BaseModel):
|
||||
permission: Permission
|
||||
enabled: bool
|
||||
|
||||
125
backend/onyx/auth/permissions.py
Normal file
125
backend/onyx/auth/permissions.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Permission resolution for group-based authorization.
|
||||
|
||||
Granted permissions are stored as a JSONB column on the User table and
|
||||
loaded for free with every auth query. Implied permissions are expanded
|
||||
at read time — only directly granted permissions are persisted.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Coroutine
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
ALL_PERMISSIONS: frozenset[str] = frozenset(p.value for p in Permission)
|
||||
|
||||
# Implication map: granted permission -> set of permissions it implies.
|
||||
IMPLIED_PERMISSIONS: dict[str, set[str]] = {
|
||||
Permission.ADD_AGENTS.value: {Permission.READ_AGENTS.value},
|
||||
Permission.MANAGE_AGENTS.value: {
|
||||
Permission.ADD_AGENTS.value,
|
||||
Permission.READ_AGENTS.value,
|
||||
},
|
||||
Permission.MANAGE_DOCUMENT_SETS.value: {
|
||||
Permission.READ_DOCUMENT_SETS.value,
|
||||
Permission.READ_CONNECTORS.value,
|
||||
},
|
||||
Permission.ADD_CONNECTORS.value: {Permission.READ_CONNECTORS.value},
|
||||
Permission.MANAGE_CONNECTORS.value: {
|
||||
Permission.ADD_CONNECTORS.value,
|
||||
Permission.READ_CONNECTORS.value,
|
||||
},
|
||||
Permission.MANAGE_USER_GROUPS.value: {
|
||||
Permission.READ_CONNECTORS.value,
|
||||
Permission.READ_DOCUMENT_SETS.value,
|
||||
Permission.READ_AGENTS.value,
|
||||
Permission.READ_USERS.value,
|
||||
},
|
||||
}
|
||||
|
||||
# Permissions that cannot be toggled via the group-permission API.
|
||||
# BASIC_ACCESS is always granted, FULL_ADMIN_PANEL_ACCESS is too broad,
|
||||
# and READ_* permissions are implied (never stored directly).
|
||||
NON_TOGGLEABLE_PERMISSIONS: frozenset[Permission] = frozenset(
|
||||
{
|
||||
Permission.BASIC_ACCESS,
|
||||
Permission.FULL_ADMIN_PANEL_ACCESS,
|
||||
Permission.READ_CONNECTORS,
|
||||
Permission.READ_DOCUMENT_SETS,
|
||||
Permission.READ_AGENTS,
|
||||
Permission.READ_USERS,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def resolve_effective_permissions(granted: set[str]) -> set[str]:
|
||||
"""Expand granted permissions with their implied permissions.
|
||||
|
||||
If "admin" is present, returns all 19 permissions.
|
||||
"""
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS.value in granted:
|
||||
return set(ALL_PERMISSIONS)
|
||||
|
||||
effective = set(granted)
|
||||
changed = True
|
||||
while changed:
|
||||
changed = False
|
||||
for perm in list(effective):
|
||||
implied = IMPLIED_PERMISSIONS.get(perm)
|
||||
if implied and not implied.issubset(effective):
|
||||
effective |= implied
|
||||
changed = True
|
||||
return effective
|
||||
|
||||
|
||||
def get_effective_permissions(user: User) -> set[Permission]:
|
||||
"""Read granted permissions from the column and expand implied permissions."""
|
||||
granted: set[Permission] = set()
|
||||
for p in user.effective_permissions:
|
||||
try:
|
||||
granted.add(Permission(p))
|
||||
except ValueError:
|
||||
logger.warning(f"Skipping unknown permission '{p}' for user {user.id}")
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS in granted:
|
||||
return set(Permission)
|
||||
expanded = resolve_effective_permissions({p.value for p in granted})
|
||||
return {Permission(p) for p in expanded}
|
||||
|
||||
|
||||
def require_permission(
|
||||
required: Permission,
|
||||
) -> Callable[..., Coroutine[Any, Any, User]]:
|
||||
"""FastAPI dependency factory for permission-based access control.
|
||||
|
||||
Usage:
|
||||
@router.get("/endpoint")
|
||||
def endpoint(user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS))):
|
||||
...
|
||||
"""
|
||||
|
||||
async def dependency(user: User = Depends(current_user)) -> User:
|
||||
effective = get_effective_permissions(user)
|
||||
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS in effective:
|
||||
return user
|
||||
|
||||
if required not in effective:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
"You do not have the required permissions for this action.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
dependency._is_require_permission = True # type: ignore[attr-defined] # sentinel for auth_check detection
|
||||
return dependency
|
||||
@@ -5,6 +5,8 @@ from typing import Any
|
||||
from fastapi_users import schemas
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
|
||||
|
||||
class UserRole(str, Enum):
|
||||
"""
|
||||
@@ -41,6 +43,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
account_type: AccountType = AccountType.STANDARD
|
||||
tenant_id: str | None = None
|
||||
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
|
||||
# Excluded from create_update_dict so it never reaches the DB layer
|
||||
@@ -50,19 +53,19 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
def create_update_dict(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict()
|
||||
d.pop("captcha_token", None)
|
||||
# Force STANDARD for self-registration; only trusted paths
|
||||
# (SCIM, API key creation) supply a different account_type directly.
|
||||
d["account_type"] = AccountType.STANDARD
|
||||
return d
|
||||
|
||||
@override
|
||||
def create_update_dict_superuser(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict_superuser()
|
||||
d.pop("captcha_token", None)
|
||||
d.setdefault("account_type", self.account_type)
|
||||
return d
|
||||
|
||||
|
||||
class UserUpdateWithRole(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
"""
|
||||
Role updates are not allowed through the user update endpoint for security reasons
|
||||
|
||||
@@ -80,7 +80,6 @@ from onyx.auth.pat import get_hashed_pat_from_request
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.schemas import UserUpdateWithRole
|
||||
from onyx.configs.app_configs import AUTH_BACKEND
|
||||
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
@@ -120,12 +119,15 @@ from onyx.db.engine.async_sql_engine import get_async_session
|
||||
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.pat import fetch_user_for_pat
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.db.users import is_limited_user
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import log_onyx_error
|
||||
from onyx.error_handling.exceptions import onyx_error_to_json_response
|
||||
@@ -500,18 +502,21 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user = user_by_session
|
||||
|
||||
if (
|
||||
user.role.is_web_login()
|
||||
user.account_type.is_web_login()
|
||||
or not isinstance(user_create, UserCreate)
|
||||
or not user_create.role.is_web_login()
|
||||
or not user_create.account_type.is_web_login()
|
||||
):
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
user_update = UserUpdateWithRole(
|
||||
password=user_create.password,
|
||||
is_verified=user_create.is_verified,
|
||||
role=user_create.role,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
# Cache id before expire — accessing attrs on an expired
|
||||
# object triggers a sync lazy-load which raises MissingGreenlet
|
||||
# in this async context.
|
||||
user_id = user.id
|
||||
self._upgrade_user_to_standard__sync(user_id, user_create)
|
||||
# Expire so the async session re-fetches the row updated by
|
||||
# the sync session above.
|
||||
self.user_db.session.expire(user)
|
||||
user = await self.user_db.get(user_id) # type: ignore[assignment]
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
|
||||
@@ -525,18 +530,21 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if (
|
||||
user.role.is_web_login()
|
||||
user.account_type.is_web_login()
|
||||
or not isinstance(user_create, UserCreate)
|
||||
or not user_create.role.is_web_login()
|
||||
or not user_create.account_type.is_web_login()
|
||||
):
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
user_update = UserUpdateWithRole(
|
||||
password=user_create.password,
|
||||
is_verified=user_create.is_verified,
|
||||
role=user_create.role,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
# Cache id before expire — accessing attrs on an expired
|
||||
# object triggers a sync lazy-load which raises MissingGreenlet
|
||||
# in this async context.
|
||||
user_id = user.id
|
||||
self._upgrade_user_to_standard__sync(user_id, user_create)
|
||||
# Expire so the async session re-fetches the row updated by
|
||||
# the sync session above.
|
||||
self.user_db.session.expire(user)
|
||||
user = await self.user_db.get(user_id) # type: ignore[assignment]
|
||||
if user_created:
|
||||
await self._assign_default_pinned_assistants(user, db_session)
|
||||
remove_user_from_invited_users(user_create.email)
|
||||
@@ -573,6 +581,38 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)
|
||||
user.pinned_assistants = default_persona_ids
|
||||
|
||||
def _upgrade_user_to_standard__sync(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
user_create: UserCreate,
|
||||
) -> None:
|
||||
"""Upgrade a non-web user to STANDARD and assign default groups atomically.
|
||||
|
||||
All writes happen in a single sync transaction so neither the field
|
||||
update nor the group assignment is visible without the other.
|
||||
"""
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
sync_user = sync_db.query(User).filter(User.id == user_id).first() # type: ignore[arg-type]
|
||||
if sync_user:
|
||||
sync_user.hashed_password = self.password_helper.hash(
|
||||
user_create.password
|
||||
)
|
||||
sync_user.is_verified = user_create.is_verified or False
|
||||
sync_user.role = user_create.role
|
||||
sync_user.account_type = AccountType.STANDARD
|
||||
assign_user_to_default_groups__no_commit(
|
||||
sync_db,
|
||||
sync_user,
|
||||
is_admin=(user_create.role == UserRole.ADMIN),
|
||||
)
|
||||
sync_db.commit()
|
||||
else:
|
||||
logger.warning(
|
||||
"User %s not found in sync session during upgrade to standard; "
|
||||
"skipping upgrade",
|
||||
user_id,
|
||||
)
|
||||
|
||||
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
|
||||
# Validate password according to configurable security policy (defined via environment variables)
|
||||
if len(password) < PASSWORD_MIN_LENGTH:
|
||||
@@ -694,6 +734,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
"email": account_email,
|
||||
"hashed_password": self.password_helper.hash(password),
|
||||
"is_verified": is_verified_by_default,
|
||||
"account_type": AccountType.STANDARD,
|
||||
}
|
||||
|
||||
user = await self.user_db.create(user_dict)
|
||||
@@ -726,7 +767,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.role.is_web_login():
|
||||
if not user.account_type.is_web_login():
|
||||
# We must use the existing user in the session if it matches
|
||||
# the user we just got by email/oauth. Note that this only applies
|
||||
# to multi-tenant, due to the overwriting of the user_db
|
||||
@@ -743,14 +784,25 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
enforce_seat_limit(sync_db)
|
||||
|
||||
await self.user_db.update(
|
||||
user,
|
||||
{
|
||||
"is_verified": is_verified_by_default,
|
||||
"role": UserRole.BASIC,
|
||||
**({"is_active": True} if not user.is_active else {}),
|
||||
},
|
||||
)
|
||||
# Upgrade the user and assign default groups in a single
|
||||
# transaction so neither change is visible without the other.
|
||||
was_inactive = not user.is_active
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
sync_user = sync_db.query(User).filter(User.id == user.id).first() # type: ignore[arg-type]
|
||||
if sync_user:
|
||||
sync_user.is_verified = is_verified_by_default
|
||||
sync_user.role = UserRole.BASIC
|
||||
sync_user.account_type = AccountType.STANDARD
|
||||
if was_inactive:
|
||||
sync_user.is_active = True
|
||||
assign_user_to_default_groups__no_commit(sync_db, sync_user)
|
||||
sync_db.commit()
|
||||
|
||||
# Refresh the async user object so downstream code
|
||||
# (e.g. oidc_expiry check) sees the updated fields.
|
||||
self.user_db.session.expire(user)
|
||||
user = await self.user_db.get(user.id)
|
||||
assert user is not None
|
||||
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
@@ -836,6 +888,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
event=MilestoneRecordType.TENANT_CREATED,
|
||||
)
|
||||
|
||||
# Assign user to the appropriate default group (Admin or Basic).
|
||||
# Must happen inside the try block while tenant context is active,
|
||||
# otherwise get_session_with_current_tenant() targets the wrong schema.
|
||||
is_admin = user_count == 1 or user.email in get_default_admin_user_emails()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=is_admin
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
@@ -975,7 +1037,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
self.password_helper.hash(credentials.password)
|
||||
return None
|
||||
|
||||
if not user.role.is_web_login():
|
||||
if not user.account_type.is_web_login():
|
||||
raise BasicAuthenticationError(
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
@@ -1471,7 +1533,7 @@ async def _get_or_create_user_from_jwt(
|
||||
if not user.is_active:
|
||||
logger.warning("Inactive user %s attempted JWT login; skipping", email)
|
||||
return None
|
||||
if not user.role.is_web_login():
|
||||
if not user.account_type.is_web_login():
|
||||
raise exceptions.UserNotExists()
|
||||
except exceptions.UserNotExists:
|
||||
logger.info("Provisioning user %s from JWT login", email)
|
||||
@@ -1492,7 +1554,7 @@ async def _get_or_create_user_from_jwt(
|
||||
email,
|
||||
)
|
||||
return None
|
||||
if not user.role.is_web_login():
|
||||
if not user.account_type.is_web_login():
|
||||
logger.warning(
|
||||
"Non-web-login user %s attempted JWT login during provisioning race; skipping",
|
||||
email,
|
||||
@@ -1554,6 +1616,7 @@ def get_anonymous_user() -> User:
|
||||
is_verified=True,
|
||||
is_superuser=False,
|
||||
role=UserRole.LIMITED,
|
||||
account_type=AccountType.ANONYMOUS,
|
||||
use_memories=False,
|
||||
enable_memory_tool=False,
|
||||
)
|
||||
@@ -1619,9 +1682,9 @@ async def current_user(
|
||||
) -> User:
|
||||
user = await double_check_user(user)
|
||||
|
||||
if user.role == UserRole.LIMITED:
|
||||
if is_limited_user(user):
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
|
||||
detail="Access denied. User has limited permissions.",
|
||||
)
|
||||
return user
|
||||
|
||||
@@ -1638,15 +1701,6 @@ async def current_curator_or_admin_user(
|
||||
return user
|
||||
|
||||
|
||||
async def current_admin_user(user: User = Depends(current_user)) -> User:
|
||||
if user.role != UserRole.ADMIN:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User must be an admin to perform this action.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def _get_user_from_token_data(token_data: dict) -> User | None:
|
||||
"""Shared logic: token data dict → User object.
|
||||
|
||||
@@ -1755,11 +1809,11 @@ async def current_user_from_websocket(
|
||||
# Apply same checks as HTTP auth (verification, OIDC expiry, role)
|
||||
user = await double_check_user(user)
|
||||
|
||||
# Block LIMITED users (same as current_user)
|
||||
if user.role == UserRole.LIMITED:
|
||||
logger.warning(f"WS auth: user {user.email} has LIMITED role")
|
||||
# Block limited users (same as current_user)
|
||||
if is_limited_user(user):
|
||||
logger.warning(f"WS auth: user {user.email} is limited")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
|
||||
detail="Access denied. User has limited permissions.",
|
||||
)
|
||||
|
||||
logger.debug(f"WS auth: authenticated {user.email}")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Overview of Onyx Background Jobs
|
||||
|
||||
The background jobs take care of:
|
||||
|
||||
1. Pulling/Indexing documents (from connectors)
|
||||
2. Updating document metadata (from connectors)
|
||||
3. Cleaning up checkpoints and logic around indexing work (indexing indexing checkpoints and index attempt metadata)
|
||||
@@ -9,37 +10,41 @@ The background jobs take care of:
|
||||
|
||||
## Worker → Queue Mapping
|
||||
|
||||
| Worker | File | Queues |
|
||||
|--------|------|--------|
|
||||
| Primary | `apps/primary.py` | `celery` |
|
||||
| Light | `apps/light.py` | `vespa_metadata_sync`, `connector_deletion`, `doc_permissions_upsert`, `checkpoint_cleanup`, `index_attempt_cleanup` |
|
||||
| Heavy | `apps/heavy.py` | `connector_pruning`, `connector_doc_permissions_sync`, `connector_external_group_sync`, `csv_generation`, `sandbox` |
|
||||
| Docprocessing | `apps/docprocessing.py` | `docprocessing` |
|
||||
| Docfetching | `apps/docfetching.py` | `connector_doc_fetching` |
|
||||
| User File Processing | `apps/user_file_processing.py` | `user_file_processing`, `user_file_project_sync`, `user_file_delete` |
|
||||
| Monitoring | `apps/monitoring.py` | `monitoring` |
|
||||
| Background (consolidated) | `apps/background.py` | All queues above except `celery` |
|
||||
| Worker | File | Queues |
|
||||
| ------------------------- | ------------------------------ | -------------------------------------------------------------------------------------------------------------------- |
|
||||
| Primary | `apps/primary.py` | `celery` |
|
||||
| Light | `apps/light.py` | `vespa_metadata_sync`, `connector_deletion`, `doc_permissions_upsert`, `checkpoint_cleanup`, `index_attempt_cleanup` |
|
||||
| Heavy | `apps/heavy.py` | `connector_pruning`, `connector_doc_permissions_sync`, `connector_external_group_sync`, `csv_generation`, `sandbox` |
|
||||
| Docprocessing | `apps/docprocessing.py` | `docprocessing` |
|
||||
| Docfetching | `apps/docfetching.py` | `connector_doc_fetching` |
|
||||
| User File Processing | `apps/user_file_processing.py` | `user_file_processing`, `user_file_project_sync`, `user_file_delete` |
|
||||
| Monitoring | `apps/monitoring.py` | `monitoring` |
|
||||
| Background (consolidated) | `apps/background.py` | All queues above except `celery` |
|
||||
|
||||
## Non-Worker Apps
|
||||
| App | File | Purpose |
|
||||
|-----|------|---------|
|
||||
| **Beat** | `beat.py` | Celery beat scheduler with `DynamicTenantScheduler` that generates per-tenant periodic task schedules |
|
||||
| **Client** | `client.py` | Minimal app for task submission from non-worker processes (e.g., API server) |
|
||||
|
||||
| App | File | Purpose |
|
||||
| ---------- | ----------- | ----------------------------------------------------------------------------------------------------- |
|
||||
| **Beat** | `beat.py` | Celery beat scheduler with `DynamicTenantScheduler` that generates per-tenant periodic task schedules |
|
||||
| **Client** | `client.py` | Minimal app for task submission from non-worker processes (e.g., API server) |
|
||||
|
||||
### Shared Module
|
||||
|
||||
`app_base.py` provides:
|
||||
|
||||
- `TenantAwareTask` - Base task class that sets tenant context
|
||||
- Signal handlers for logging, cleanup, and lifecycle events
|
||||
- Readiness probes and health checks
|
||||
|
||||
|
||||
## Worker Details
|
||||
|
||||
### Primary (Coordinator and task dispatcher)
|
||||
|
||||
It is the single worker which handles tasks from the default celery queue. It is a singleton worker ensured by the `PRIMARY_WORKER` Redis lock
|
||||
which it touches every `CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8` seconds (using Celery Bootsteps)
|
||||
|
||||
On startup:
|
||||
|
||||
- waits for redis, postgres, document index to all be healthy
|
||||
- acquires the singleton lock
|
||||
- cleans all the redis states associated with background jobs
|
||||
@@ -47,34 +52,34 @@ On startup:
|
||||
|
||||
Then it cycles through its tasks as scheduled by Celery Beat:
|
||||
|
||||
| Task | Frequency | Description |
|
||||
|------|-----------|-------------|
|
||||
| `check_for_indexing` | 15s | Scans for connectors needing indexing → dispatches to `DOCFETCHING` queue |
|
||||
| `check_for_vespa_sync_task` | 20s | Finds stale documents/document sets → dispatches sync tasks to `VESPA_METADATA_SYNC` queue |
|
||||
| `check_for_pruning` | 20s | Finds connectors due for pruning → dispatches to `CONNECTOR_PRUNING` queue |
|
||||
| `check_for_connector_deletion` | 20s | Processes deletion requests → dispatches to `CONNECTOR_DELETION` queue |
|
||||
| `check_for_user_file_processing` | 20s | Checks for user uploads → dispatches to `USER_FILE_PROCESSING` queue |
|
||||
| `check_for_checkpoint_cleanup` | 1h | Cleans up old indexing checkpoints |
|
||||
| `check_for_index_attempt_cleanup` | 30m | Cleans up old index attempts |
|
||||
| `kombu_message_cleanup_task` | periodic | Cleans orphaned Kombu messages from DB (Kombu being the messaging framework used by Celery) |
|
||||
| `celery_beat_heartbeat` | 1m | Heartbeat for Beat watchdog |
|
||||
| Task | Frequency | Description |
|
||||
| --------------------------------- | --------- | ------------------------------------------------------------------------------------------ |
|
||||
| `check_for_indexing` | 15s | Scans for connectors needing indexing → dispatches to `DOCFETCHING` queue |
|
||||
| `check_for_vespa_sync_task` | 20s | Finds stale documents/document sets → dispatches sync tasks to `VESPA_METADATA_SYNC` queue |
|
||||
| `check_for_pruning` | 20s | Finds connectors due for pruning → dispatches to `CONNECTOR_PRUNING` queue |
|
||||
| `check_for_connector_deletion` | 20s | Processes deletion requests → dispatches to `CONNECTOR_DELETION` queue |
|
||||
| `check_for_user_file_processing` | 20s | Checks for user uploads → dispatches to `USER_FILE_PROCESSING` queue |
|
||||
| `check_for_checkpoint_cleanup` | 1h | Cleans up old indexing checkpoints |
|
||||
| `check_for_index_attempt_cleanup` | 30m | Cleans up old index attempts |
|
||||
| `celery_beat_heartbeat` | 1m | Heartbeat for Beat watchdog |
|
||||
|
||||
Watchdog is a separate Python process managed by supervisord which runs alongside celery workers. It checks the ONYX_CELERY_BEAT_HEARTBEAT_KEY in
|
||||
Redis to ensure Celery Beat is not dead. Beat schedules the celery_beat_heartbeat for Primary to touch the key and share that it's still alive.
|
||||
See supervisord.conf for watchdog config.
|
||||
|
||||
|
||||
### Light
|
||||
|
||||
Fast and short living tasks that are not resource intensive. High concurrency:
|
||||
Can have 24 concurrent workers, each with a prefetch of 8 for a total of 192 tasks in flight at once.
|
||||
|
||||
Tasks it handles:
|
||||
|
||||
- Syncs access/permissions, document sets, boosts, hidden state
|
||||
- Deletes documents that are marked for deletion in Postgres
|
||||
- Cleanup of checkpoints and index attempts
|
||||
|
||||
|
||||
### Heavy
|
||||
|
||||
Long running, resource intensive tasks, handles pruning and sandbox operations. Low concurrency - max concurrency of 4 with 1 prefetch.
|
||||
|
||||
Does not interact with the Document Index, it handles the syncs with external systems. Large volume API calls to handle pruning and fetching permissions, etc.
|
||||
@@ -83,16 +88,24 @@ Generates CSV exports which may take a long time with significant data in Postgr
|
||||
|
||||
Sandbox (new feature) for running Next.js, Python virtual env, OpenCode AI Agent, and access to knowledge files
|
||||
|
||||
|
||||
### Docprocessing, Docfetching, User File Processing
|
||||
Docprocessing and Docfetching are for indexing documents:
|
||||
- Docfetching runs connectors to pull documents from external APIs (Google Drive, Confluence, etc.), stores batches to file storage, and dispatches docprocessing tasks
|
||||
- Docprocessing retrieves batches, runs the indexing pipeline (chunking, embedding), and indexes into the Document Index
|
||||
User Files come from uploads directly via the input bar
|
||||
|
||||
Docprocessing and Docfetching are for indexing documents:
|
||||
|
||||
- Docfetching runs connectors to pull documents from external APIs (Google Drive, Confluence, etc.), stores batches to file storage, and dispatches docprocessing tasks
|
||||
- Docprocessing retrieves batches, runs the indexing pipeline (chunking, embedding), and indexes into the Document Index
|
||||
- User Files come from uploads directly via the input bar
|
||||
|
||||
### Monitoring
|
||||
|
||||
Observability and metrics collections:
|
||||
- Queue lengths, connector success/failure, lconnector latencies
|
||||
|
||||
- Queue lengths, connector success/failure, connector latencies
|
||||
- Memory of supervisor managed processes (workers, beat, slack)
|
||||
- Cloud and multitenant specific monitorings
|
||||
|
||||
## Prometheus Metrics
|
||||
|
||||
Workers can expose Prometheus metrics via a standalone HTTP server. Currently docfetching and docprocessing have push-based task lifecycle metrics; the monitoring worker runs pull-based collectors for queue depth and connector health.
|
||||
|
||||
For the full metric reference, integration guide, and PromQL examples, see [`docs/METRICS.md`](../../../docs/METRICS.md#celery-worker-metrics).
|
||||
|
||||
37
backend/onyx/background/celery/README.md
Normal file
37
backend/onyx/background/celery/README.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# Celery Development Notes
|
||||
|
||||
This document is the local reference for Celery worker structure and task-writing rules in Onyx.
|
||||
|
||||
## Worker Types
|
||||
|
||||
Onyx uses multiple specialized workers:
|
||||
|
||||
1. `primary`: coordinates core background tasks and system-wide operations.
|
||||
2. `docfetching`: fetches documents from connectors and schedules downstream work.
|
||||
3. `docprocessing`: runs the indexing pipeline for fetched documents.
|
||||
4. `light`: handles lightweight and fast operations.
|
||||
5. `heavy`: handles more resource-intensive operations.
|
||||
6. `kg_processing`: runs knowledge-graph processing and clustering.
|
||||
7. `monitoring`: collects health and system metrics.
|
||||
8. `user_file_processing`: processes user-uploaded files.
|
||||
9. `beat`: schedules periodic work.
|
||||
|
||||
For actual implementation details, inspect:
|
||||
|
||||
- `backend/onyx/background/celery/apps/`
|
||||
- `backend/onyx/background/celery/configs/`
|
||||
- `backend/onyx/background/celery/tasks/`
|
||||
|
||||
## Task Rules
|
||||
|
||||
- Always use `@shared_task` rather than `@celery_app`.
|
||||
- Put tasks under `background/celery/tasks/` or `ee/background/celery/tasks/`.
|
||||
- Never enqueue a task without `expires=`. This is a hard requirement because stale queued work can
|
||||
accumulate without bound.
|
||||
- Do not rely on Celery time-limit enforcement. These workers run in thread pools, so timeout logic
|
||||
must be implemented inside the task itself.
|
||||
|
||||
## Testing Note
|
||||
|
||||
If you change Celery worker code and want to validate it against a running local worker, the worker
|
||||
usually needs to be restarted manually. There is no general auto-restart on code change.
|
||||
@@ -13,6 +13,12 @@ from celery.signals import worker_shutdown
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -34,6 +40,7 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -48,6 +55,31 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -76,6 +108,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("heavy")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -317,7 +317,6 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.evals",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
@@ -30,6 +31,8 @@ from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.server.metrics.pruning_metrics import inc_pruning_rate_limit_error
|
||||
from onyx.server.metrics.pruning_metrics import observe_pruning_enumeration_duration
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -130,6 +133,7 @@ def _extract_from_batch(
|
||||
def extract_ids_from_runnable_connector(
|
||||
runnable_connector: BaseConnector,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
connector_type: str = "unknown",
|
||||
) -> SlimConnectorExtractionResult:
|
||||
"""
|
||||
Extract document IDs and hierarchy nodes from a runnable connector.
|
||||
@@ -179,21 +183,38 @@ def extract_ids_from_runnable_connector(
|
||||
)
|
||||
|
||||
# process raw batches to extract both IDs and hierarchy nodes
|
||||
for doc_list in raw_batch_generator:
|
||||
if callback and callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"extract_ids_from_runnable_connector: Stop signal detected"
|
||||
)
|
||||
enumeration_start = time.monotonic()
|
||||
try:
|
||||
for doc_list in raw_batch_generator:
|
||||
if callback and callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"extract_ids_from_runnable_connector: Stop signal detected"
|
||||
)
|
||||
|
||||
batch_result = _extract_from_batch(doc_list)
|
||||
batch_ids = batch_result.raw_id_to_parent
|
||||
batch_nodes = batch_result.hierarchy_nodes
|
||||
doc_batch_processing_func(batch_ids)
|
||||
all_raw_id_to_parent.update(batch_ids)
|
||||
all_hierarchy_nodes.extend(batch_nodes)
|
||||
batch_result = _extract_from_batch(doc_list)
|
||||
batch_ids = batch_result.raw_id_to_parent
|
||||
batch_nodes = batch_result.hierarchy_nodes
|
||||
doc_batch_processing_func(batch_ids)
|
||||
all_raw_id_to_parent.update(batch_ids)
|
||||
all_hierarchy_nodes.extend(batch_nodes)
|
||||
|
||||
if callback:
|
||||
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
|
||||
if callback:
|
||||
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
|
||||
except Exception as e:
|
||||
# Best-effort rate limit detection via string matching.
|
||||
# Connectors surface rate limits inconsistently — some raise HTTP 429,
|
||||
# some use SDK-specific exceptions (e.g. google.api_core.exceptions.ResourceExhausted)
|
||||
# that may or may not include "rate limit" or "429" in the message.
|
||||
# TODO(Bo): replace with a standard ConnectorRateLimitError exception that all
|
||||
# connectors raise when rate limited, making this check precise.
|
||||
error_str = str(e)
|
||||
if "rate limit" in error_str.lower() or "429" in error_str:
|
||||
inc_pruning_rate_limit_error(connector_type)
|
||||
raise
|
||||
finally:
|
||||
observe_pruning_enumeration_duration(
|
||||
time.monotonic() - enumeration_start, connector_type
|
||||
)
|
||||
|
||||
return SlimConnectorExtractionResult(
|
||||
raw_id_to_parent=all_raw_id_to_parent,
|
||||
|
||||
@@ -75,6 +75,8 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
# Run on gated tenants too — they may still have stale checkpoints to clean.
|
||||
"skip_gated": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -84,6 +86,8 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
# Run on gated tenants too — they may still have stale index attempts.
|
||||
"skip_gated": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -93,6 +97,8 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
# Gated tenants may still have connectors awaiting deletion.
|
||||
"skip_gated": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -136,7 +142,14 @@ beat_task_templates: list[dict] = [
|
||||
{
|
||||
"name": "cleanup-idle-sandboxes",
|
||||
"task": OnyxCeleryTask.CLEANUP_IDLE_SANDBOXES,
|
||||
"schedule": timedelta(minutes=1),
|
||||
# SANDBOX_IDLE_TIMEOUT_SECONDS defaults to 1 hour, so there is no
|
||||
# functional reason to scan more often than every ~15 minutes. In the
|
||||
# cloud this is multiplied by CLOUD_BEAT_MULTIPLIER_DEFAULT (=8) so
|
||||
# the effective cadence becomes ~2 hours, which still meets the
|
||||
# idle-detection SLA. The previous 1-minute base schedule produced
|
||||
# an 8-minute per-tenant fan-out and was the dominant source of
|
||||
# background DB load on the cloud cluster.
|
||||
"schedule": timedelta(minutes=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
@@ -266,7 +279,7 @@ def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
|
||||
cloud_task["kwargs"] = {}
|
||||
cloud_task["kwargs"]["task_name"] = task["task"]
|
||||
|
||||
optional_fields = ["queue", "priority", "expires"]
|
||||
optional_fields = ["queue", "priority", "expires", "skip_gated"]
|
||||
for field in optional_fields:
|
||||
if field in task["options"]:
|
||||
cloud_task["kwargs"][field] = task["options"][field]
|
||||
@@ -302,7 +315,7 @@ beat_cloud_tasks: list[dict] = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-available-tenants",
|
||||
"task": OnyxCeleryTask.CLOUD_CHECK_AVAILABLE_TENANTS,
|
||||
"schedule": timedelta(minutes=10),
|
||||
"schedule": timedelta(minutes=2),
|
||||
"options": {
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
@@ -359,7 +372,13 @@ if not MULTI_TENANT:
|
||||
]
|
||||
)
|
||||
|
||||
tasks_to_schedule.extend(beat_task_templates)
|
||||
# `skip_gated` is a cloud-only hint consumed by `cloud_beat_task_generator`. Strip
|
||||
# it before extending the self-hosted schedule so it doesn't leak into apply_async
|
||||
# as an unrecognised option on every fired task message.
|
||||
for _template in beat_task_templates:
|
||||
_self_hosted_template = copy.deepcopy(_template)
|
||||
_self_hosted_template["options"].pop("skip_gated", None)
|
||||
tasks_to_schedule.append(_self_hosted_template)
|
||||
|
||||
|
||||
def generate_cloud_tasks(
|
||||
|
||||
@@ -36,6 +36,7 @@ from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.opensearch_migration import build_sanitized_to_original_doc_id_mapping
|
||||
from onyx.db.opensearch_migration import get_vespa_visit_state
|
||||
from onyx.db.opensearch_migration import is_migration_completed
|
||||
from onyx.db.opensearch_migration import (
|
||||
mark_migration_completed_time_if_not_set_with_commit,
|
||||
)
|
||||
@@ -106,14 +107,19 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
acquired; effectively a no-op. True if the task completed
|
||||
successfully. False if the task errored.
|
||||
"""
|
||||
# 1. Check if we should run the task.
|
||||
# 1.a. If OpenSearch indexing is disabled, we don't run the task.
|
||||
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
task_logger.warning(
|
||||
"OpenSearch migration is not enabled, skipping chunk migration task."
|
||||
)
|
||||
return None
|
||||
|
||||
task_logger.info("Starting chunk-level migration from Vespa to OpenSearch.")
|
||||
task_start_time = time.monotonic()
|
||||
|
||||
# 1.b. Only one instance per tenant of this task may run concurrently at
|
||||
# once. If we fail to acquire a lock, we assume it is because another task
|
||||
# has one and we exit.
|
||||
r = get_redis_client()
|
||||
lock: RedisLock = r.lock(
|
||||
name=OnyxRedisLocks.OPENSEARCH_MIGRATION_BEAT_LOCK,
|
||||
@@ -136,10 +142,11 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
f"Token: {lock.local.token}"
|
||||
)
|
||||
|
||||
# 2. Prepare to migrate.
|
||||
total_chunks_migrated_this_task = 0
|
||||
total_chunks_errored_this_task = 0
|
||||
try:
|
||||
# Double check that tenant info is correct.
|
||||
# 2.a. Double-check that tenant info is correct.
|
||||
if tenant_id != get_current_tenant_id():
|
||||
err_str = (
|
||||
f"Tenant ID mismatch in the OpenSearch migration task: "
|
||||
@@ -148,16 +155,62 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
task_logger.error(err_str)
|
||||
return False
|
||||
|
||||
with (
|
||||
get_session_with_current_tenant() as db_session,
|
||||
get_vespa_http_client(
|
||||
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
) as vespa_client,
|
||||
):
|
||||
# Do as much as we can with a DB session in one spot to not hold a
|
||||
# session during a migration batch.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# 2.b. Immediately check to see if this tenant is done, to save
|
||||
# having to do any other work. This function does not require a
|
||||
# migration record to necessarily exist.
|
||||
if is_migration_completed(db_session):
|
||||
return True
|
||||
|
||||
# 2.c. Try to insert the OpenSearchTenantMigrationRecord table if it
|
||||
# does not exist.
|
||||
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
|
||||
|
||||
# 2.d. Get search settings.
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
|
||||
# 2.e. Build sanitized to original doc ID mapping to check for
|
||||
# conflicts in the event we sanitize a doc ID to an
|
||||
# already-existing doc ID.
|
||||
# We reconstruct this mapping for every task invocation because
|
||||
# a document may have been added in the time between two tasks.
|
||||
sanitized_doc_start_time = time.monotonic()
|
||||
sanitized_to_original_doc_id_mapping = (
|
||||
build_sanitized_to_original_doc_id_mapping(db_session)
|
||||
)
|
||||
task_logger.debug(
|
||||
f"Built sanitized_to_original_doc_id_mapping with {len(sanitized_to_original_doc_id_mapping)} entries "
|
||||
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
|
||||
)
|
||||
|
||||
# 2.f. Get the current migration state.
|
||||
continuation_token_map, total_chunks_migrated = get_vespa_visit_state(
|
||||
db_session
|
||||
)
|
||||
# 2.f.1. Double-check that the migration state does not imply
|
||||
# completion. Really we should never have to enter this block as we
|
||||
# would expect is_migration_completed to return True, but in the
|
||||
# strange event that the migration is complete but the migration
|
||||
# completed time was never stamped, we do so here.
|
||||
if is_continuation_token_done_for_all_slices(continuation_token_map):
|
||||
task_logger.info(
|
||||
f"OpenSearch migration COMPLETED for tenant {tenant_id}. Total chunks migrated: {total_chunks_migrated}."
|
||||
)
|
||||
mark_migration_completed_time_if_not_set_with_commit(db_session)
|
||||
return True
|
||||
task_logger.debug(
|
||||
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
|
||||
f"Continuation token map: {continuation_token_map}"
|
||||
)
|
||||
|
||||
with get_vespa_http_client(
|
||||
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
) as vespa_client:
|
||||
# 2.g. Create the OpenSearch and Vespa document indexes.
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
|
||||
opensearch_document_index = OpenSearchDocumentIndex(
|
||||
tenant_state=tenant_state,
|
||||
index_name=search_settings.index_name,
|
||||
@@ -171,22 +224,14 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
httpx_client=vespa_client,
|
||||
)
|
||||
|
||||
sanitized_doc_start_time = time.monotonic()
|
||||
# We reconstruct this mapping for every task invocation because a
|
||||
# document may have been added in the time between two tasks.
|
||||
sanitized_to_original_doc_id_mapping = (
|
||||
build_sanitized_to_original_doc_id_mapping(db_session)
|
||||
)
|
||||
task_logger.debug(
|
||||
f"Built sanitized_to_original_doc_id_mapping with {len(sanitized_to_original_doc_id_mapping)} entries "
|
||||
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
|
||||
)
|
||||
|
||||
# 2.h. Get the approximate chunk count in Vespa as of this time to
|
||||
# update the migration record.
|
||||
approx_chunk_count_in_vespa: int | None = None
|
||||
get_chunk_count_start_time = time.monotonic()
|
||||
try:
|
||||
approx_chunk_count_in_vespa = vespa_document_index.get_chunk_count()
|
||||
except Exception:
|
||||
# This failure should not be blocking.
|
||||
task_logger.exception(
|
||||
"Error getting approximate chunk count in Vespa. Moving on..."
|
||||
)
|
||||
@@ -195,25 +240,12 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
f"approximate chunk count in Vespa. Got {approx_chunk_count_in_vespa}."
|
||||
)
|
||||
|
||||
# 3. Do the actual migration in batches until we run out of time.
|
||||
while (
|
||||
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
|
||||
and lock.owned()
|
||||
):
|
||||
(
|
||||
continuation_token_map,
|
||||
total_chunks_migrated,
|
||||
) = get_vespa_visit_state(db_session)
|
||||
if is_continuation_token_done_for_all_slices(continuation_token_map):
|
||||
task_logger.info(
|
||||
f"OpenSearch migration COMPLETED for tenant {tenant_id}. Total chunks migrated: {total_chunks_migrated}."
|
||||
)
|
||||
mark_migration_completed_time_if_not_set_with_commit(db_session)
|
||||
break
|
||||
task_logger.debug(
|
||||
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
|
||||
f"Continuation token map: {continuation_token_map}"
|
||||
)
|
||||
|
||||
# 3.a. Get the next batch of raw chunks from Vespa.
|
||||
get_vespa_chunks_start_time = time.monotonic()
|
||||
raw_vespa_chunks, next_continuation_token_map = (
|
||||
vespa_document_index.get_all_raw_document_chunks_paginated(
|
||||
@@ -226,6 +258,7 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
f"seconds. Next continuation token map: {next_continuation_token_map}"
|
||||
)
|
||||
|
||||
# 3.b. Transform the raw chunks to OpenSearch chunks in memory.
|
||||
opensearch_document_chunks, errored_chunks = (
|
||||
transform_vespa_chunks_to_opensearch_chunks(
|
||||
raw_vespa_chunks,
|
||||
@@ -240,6 +273,7 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
"errored."
|
||||
)
|
||||
|
||||
# 3.c. Index the OpenSearch chunks into OpenSearch.
|
||||
index_opensearch_chunks_start_time = time.monotonic()
|
||||
opensearch_document_index.index_raw_chunks(
|
||||
chunks=opensearch_document_chunks
|
||||
@@ -251,12 +285,38 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
|
||||
total_chunks_migrated_this_task += len(opensearch_document_chunks)
|
||||
total_chunks_errored_this_task += len(errored_chunks)
|
||||
update_vespa_visit_progress_with_commit(
|
||||
db_session,
|
||||
continuation_token_map=next_continuation_token_map,
|
||||
chunks_processed=len(opensearch_document_chunks),
|
||||
chunks_errored=len(errored_chunks),
|
||||
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
|
||||
|
||||
# Do as much as we can with a DB session in one spot to not hold a
|
||||
# session during a migration batch.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# 3.d. Update the migration state.
|
||||
update_vespa_visit_progress_with_commit(
|
||||
db_session,
|
||||
continuation_token_map=next_continuation_token_map,
|
||||
chunks_processed=len(opensearch_document_chunks),
|
||||
chunks_errored=len(errored_chunks),
|
||||
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
|
||||
)
|
||||
|
||||
# 3.e. Get the current migration state. Even thought we
|
||||
# technically have it in-memory since we just wrote it, we
|
||||
# want to reference the DB as the source of truth at all
|
||||
# times.
|
||||
continuation_token_map, total_chunks_migrated = (
|
||||
get_vespa_visit_state(db_session)
|
||||
)
|
||||
# 3.e.1. Check if the migration is done.
|
||||
if is_continuation_token_done_for_all_slices(
|
||||
continuation_token_map
|
||||
):
|
||||
task_logger.info(
|
||||
f"OpenSearch migration COMPLETED for tenant {tenant_id}. Total chunks migrated: {total_chunks_migrated}."
|
||||
)
|
||||
mark_migration_completed_time_if_not_set_with_commit(db_session)
|
||||
return True
|
||||
task_logger.debug(
|
||||
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
|
||||
f"Continuation token map: {continuation_token_map}"
|
||||
)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
#####
|
||||
# Periodic Tasks
|
||||
#####
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from celery.contrib.abortable import AbortableTask # type: ignore
|
||||
from celery.exceptions import TaskRevokedError
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import PostgresAdvisoryLocks
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
)
|
||||
def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int: # noqa: ARG001
|
||||
"""Runs periodically to clean up the kombu_message table"""
|
||||
|
||||
# we will select messages older than this amount to clean up
|
||||
KOMBU_MESSAGE_CLEANUP_AGE = 7 # days
|
||||
KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000
|
||||
|
||||
ctx = {}
|
||||
ctx["last_processed_id"] = 0
|
||||
ctx["deleted"] = 0
|
||||
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
|
||||
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Exit the task if we can't take the advisory lock
|
||||
result = db_session.execute(
|
||||
text("SELECT pg_try_advisory_lock(:id)"),
|
||||
{"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value},
|
||||
).scalar()
|
||||
if not result:
|
||||
return 0
|
||||
|
||||
while True:
|
||||
if self.is_aborted():
|
||||
raise TaskRevokedError("kombu_message_cleanup_task was aborted.")
|
||||
|
||||
b = kombu_message_cleanup_task_helper(ctx, db_session)
|
||||
if not b:
|
||||
break
|
||||
|
||||
db_session.commit()
|
||||
|
||||
if ctx["deleted"] > 0:
|
||||
task_logger.info(
|
||||
f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
|
||||
)
|
||||
|
||||
return ctx["deleted"]
|
||||
|
||||
|
||||
def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
|
||||
"""
|
||||
Helper function to clean up old messages from the `kombu_message` table that are no longer relevant.
|
||||
|
||||
This function retrieves messages from the `kombu_message` table that are no longer visible and
|
||||
older than a specified interval. It checks if the corresponding task_id exists in the
|
||||
`celery_taskmeta` table. If the task_id does not exist, the message is deleted.
|
||||
|
||||
Args:
|
||||
ctx (dict): A context dictionary containing configuration parameters such as:
|
||||
- 'cleanup_age' (int): The age in days after which messages are considered old.
|
||||
- 'page_limit' (int): The maximum number of messages to process in one batch.
|
||||
- 'last_processed_id' (int): The ID of the last processed message to handle pagination.
|
||||
- 'deleted' (int): A counter to track the number of deleted messages.
|
||||
db_session (Session): The SQLAlchemy database session for executing queries.
|
||||
|
||||
Returns:
|
||||
bool: Returns True if there are more rows to process, False if not.
|
||||
"""
|
||||
|
||||
inspector = inspect(db_session.bind)
|
||||
if not inspector:
|
||||
return False
|
||||
|
||||
# With the move to redis as celery's broker and backend, kombu tables may not even exist.
|
||||
# We can fail silently.
|
||||
if not inspector.has_table("kombu_message"):
|
||||
return False
|
||||
|
||||
query = text(
|
||||
"""
|
||||
SELECT id, timestamp, payload
|
||||
FROM kombu_message WHERE visible = 'false'
|
||||
AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days
|
||||
AND id > :last_processed_id
|
||||
ORDER BY id
|
||||
LIMIT :page_limit
|
||||
"""
|
||||
)
|
||||
kombu_messages = db_session.execute(
|
||||
query,
|
||||
{
|
||||
"interval_days": f"{ctx['cleanup_age']} days",
|
||||
"page_limit": ctx["page_limit"],
|
||||
"last_processed_id": ctx["last_processed_id"],
|
||||
},
|
||||
).fetchall()
|
||||
|
||||
if len(kombu_messages) == 0:
|
||||
return False
|
||||
|
||||
for msg in kombu_messages:
|
||||
payload = json.loads(msg[2])
|
||||
task_id = payload["headers"]["id"]
|
||||
|
||||
# Check if task_id exists in celery_taskmeta
|
||||
task_exists = db_session.execute(
|
||||
text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"),
|
||||
{"task_id": task_id},
|
||||
).fetchone()
|
||||
|
||||
# If task_id does not exist, delete the message
|
||||
if not task_exists:
|
||||
result = db_session.execute(
|
||||
text("DELETE FROM kombu_message WHERE id = :message_id"),
|
||||
{"message_id": msg[0]},
|
||||
)
|
||||
if result.rowcount > 0: # type: ignore
|
||||
ctx["deleted"] += 1
|
||||
|
||||
ctx["last_processed_id"] = msg[0]
|
||||
|
||||
return True
|
||||
@@ -72,6 +72,7 @@ from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
|
||||
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.metrics.pruning_metrics import observe_pruning_diff_duration
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.server.utils import make_short_id
|
||||
from onyx.utils.logger import format_error_for_logging
|
||||
@@ -217,7 +218,7 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
try:
|
||||
# the entire task needs to run frequently in order to finalize pruning
|
||||
|
||||
# but pruning only kicks off once per hour
|
||||
# but pruning only kicks off once per min
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):
|
||||
task_logger.info("Checking for pruning due")
|
||||
|
||||
@@ -570,8 +571,9 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
|
||||
# Extract docs and hierarchy nodes from the source
|
||||
connector_type = cc_pair.connector.source.value
|
||||
extraction_result = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback
|
||||
runnable_connector, callback, connector_type=connector_type
|
||||
)
|
||||
all_connector_doc_ids = extraction_result.raw_id_to_parent
|
||||
|
||||
@@ -636,40 +638,46 @@ def connector_pruning_generator_task(
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# a list of docs in our local index
|
||||
all_indexed_document_ids = {
|
||||
doc.id
|
||||
for doc in get_documents_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
diff_start = time.monotonic()
|
||||
try:
|
||||
# a list of docs in our local index
|
||||
all_indexed_document_ids = {
|
||||
doc.id
|
||||
for doc in get_documents_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
}
|
||||
|
||||
# generate list of docs to remove (no longer in the source)
|
||||
doc_ids_to_remove = list(
|
||||
all_indexed_document_ids - all_connector_doc_ids.keys()
|
||||
)
|
||||
}
|
||||
|
||||
# generate list of docs to remove (no longer in the source)
|
||||
doc_ids_to_remove = list(
|
||||
all_indexed_document_ids - all_connector_doc_ids.keys()
|
||||
)
|
||||
task_logger.info(
|
||||
"Pruning set collected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector_source={cc_pair.connector.source} "
|
||||
f"docs_to_remove={len(doc_ids_to_remove)}"
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
"Pruning set collected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector_source={cc_pair.connector.source} "
|
||||
f"docs_to_remove={len(doc_ids_to_remove)}"
|
||||
)
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.prune.generate_tasks(
|
||||
set(doc_ids_to_remove), self.app, db_session, None
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.prune.generate_tasks(
|
||||
set(doc_ids_to_remove), self.app, db_session, None
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks finished. cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks finished. cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
finally:
|
||||
observe_pruning_diff_duration(
|
||||
time.monotonic() - diff_start, connector_type
|
||||
)
|
||||
|
||||
redis_connector.prune.generator_complete = tasks_generated
|
||||
|
||||
|
||||
@@ -996,6 +996,7 @@ def _run_models(
|
||||
|
||||
def _run_model(model_idx: int) -> None:
|
||||
"""Run one LLM loop inside a worker thread, writing packets to ``merged_queue``."""
|
||||
|
||||
model_emitter = Emitter(
|
||||
model_idx=model_idx,
|
||||
merged_queue=merged_queue,
|
||||
@@ -1102,33 +1103,33 @@ def _run_models(
|
||||
finally:
|
||||
merged_queue.put((model_idx, _MODEL_DONE))
|
||||
|
||||
def _delete_orphaned_message(model_idx: int, context: str) -> None:
|
||||
"""Delete a reserved ChatMessage that was never populated due to a model error."""
|
||||
def _save_errored_message(model_idx: int, context: str) -> None:
|
||||
"""Save an error message to a reserved ChatMessage that failed during execution."""
|
||||
try:
|
||||
orphaned = db_session.get(
|
||||
ChatMessage, setup.reserved_messages[model_idx].id
|
||||
)
|
||||
if orphaned is not None:
|
||||
db_session.delete(orphaned)
|
||||
msg = db_session.get(ChatMessage, setup.reserved_messages[model_idx].id)
|
||||
if msg is not None:
|
||||
error_text = f"Error from {setup.model_display_names[model_idx]}: model encountered an error during generation."
|
||||
msg.message = error_text
|
||||
msg.error = error_text
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"%s orphan cleanup failed for model %d (%s)",
|
||||
"%s error save failed for model %d (%s)",
|
||||
context,
|
||||
model_idx,
|
||||
setup.model_display_names[model_idx],
|
||||
)
|
||||
|
||||
# Copy contextvars before submitting futures — ThreadPoolExecutor does NOT
|
||||
# auto-propagate contextvars in Python 3.11; threads would inherit a blank context.
|
||||
worker_context = contextvars.copy_context()
|
||||
# Each worker thread needs its own Context copy — a single Context object
|
||||
# cannot be entered concurrently by multiple threads (RuntimeError).
|
||||
executor = ThreadPoolExecutor(
|
||||
max_workers=n_models, thread_name_prefix="multi-model"
|
||||
)
|
||||
completion_persisted: bool = False
|
||||
try:
|
||||
for i in range(n_models):
|
||||
executor.submit(worker_context.run, _run_model, i)
|
||||
ctx = contextvars.copy_context()
|
||||
executor.submit(ctx.run, _run_model, i)
|
||||
|
||||
# ── Main thread: merge and yield packets ────────────────────────────
|
||||
models_remaining = n_models
|
||||
@@ -1145,7 +1146,7 @@ def _run_models(
|
||||
# save "stopped by user" for a model that actually threw an exception.
|
||||
for i in range(n_models):
|
||||
if model_errored[i]:
|
||||
_delete_orphaned_message(i, "stop-button")
|
||||
_save_errored_message(i, "stop-button")
|
||||
continue
|
||||
try:
|
||||
succeeded = model_succeeded[i]
|
||||
@@ -1211,7 +1212,7 @@ def _run_models(
|
||||
for i in range(n_models):
|
||||
if not model_succeeded[i]:
|
||||
# Model errored — delete its orphaned reserved message.
|
||||
_delete_orphaned_message(i, "normal")
|
||||
_save_errored_message(i, "normal")
|
||||
continue
|
||||
try:
|
||||
llm_loop_completion_handle(
|
||||
@@ -1264,7 +1265,7 @@ def _run_models(
|
||||
setup.model_display_names[i],
|
||||
)
|
||||
elif model_errored[i]:
|
||||
_delete_orphaned_message(i, "disconnect")
|
||||
_save_errored_message(i, "disconnect")
|
||||
# 4. Drain buffered packets from memory — no consumer is running.
|
||||
while not merged_queue.empty():
|
||||
try:
|
||||
|
||||
@@ -379,6 +379,14 @@ POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "127.0.0.1"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
|
||||
# Comma-separated replica / multi-host list. If unset, defaults to POSTGRES_HOST
|
||||
# only.
|
||||
_POSTGRES_HOSTS_STR = os.environ.get("POSTGRES_HOSTS", "").strip()
|
||||
POSTGRES_HOSTS: list[str] = (
|
||||
[h.strip() for h in _POSTGRES_HOSTS_STR.split(",") if h.strip()]
|
||||
if _POSTGRES_HOSTS_STR
|
||||
else [POSTGRES_HOST]
|
||||
)
|
||||
|
||||
POSTGRES_API_SERVER_POOL_SIZE = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
|
||||
|
||||
@@ -12,6 +12,11 @@ SLACK_USER_TOKEN_PREFIX = "xoxp-"
|
||||
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
|
||||
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
|
||||
|
||||
# The mask_string() function in encryption.py uses "•" (U+2022 BULLET) to mask secrets.
|
||||
MASK_CREDENTIAL_CHAR = "\u2022"
|
||||
# Pattern produced by mask_string for strings >= 14 chars: "abcd...wxyz" (exactly 11 chars)
|
||||
MASK_CREDENTIAL_LONG_RE = re.compile(r"^.{4}\.{3}.{4}$")
|
||||
|
||||
SOURCE_TYPE = "source_type"
|
||||
# stored in the `metadata` of a chunk. Used to signify that this chunk should
|
||||
# not be used for QA. For example, Google Drive file types which can't be parsed
|
||||
@@ -391,10 +396,6 @@ class MilestoneRecordType(str, Enum):
|
||||
REQUESTED_CONNECTOR = "requested_connector"
|
||||
|
||||
|
||||
class PostgresAdvisoryLocks(Enum):
|
||||
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()
|
||||
|
||||
|
||||
class OnyxCeleryQueues:
|
||||
# "celery" is the default queue defined by celery and also the queue
|
||||
# we are running in the primary worker to run system tasks
|
||||
@@ -577,7 +578,6 @@ class OnyxCeleryTask:
|
||||
MONITOR_PROCESS_MEMORY = "monitor_process_memory"
|
||||
CELERY_BEAT_HEARTBEAT = "celery_beat_heartbeat"
|
||||
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
)
|
||||
|
||||
@@ -44,7 +44,7 @@ _NOTION_CALL_TIMEOUT = 30 # 30 seconds
|
||||
_MAX_PAGES = 1000
|
||||
|
||||
|
||||
# TODO: Tables need to be ingested, Pages need to have their metadata ingested
|
||||
# TODO: Pages need to have their metadata ingested
|
||||
|
||||
|
||||
class NotionPage(BaseModel):
|
||||
@@ -452,6 +452,19 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
sub_inner_dict: dict[str, Any] | list[Any] | str = inner_dict
|
||||
while isinstance(sub_inner_dict, dict) and "type" in sub_inner_dict:
|
||||
type_name = sub_inner_dict["type"]
|
||||
|
||||
# Notion user objects (people properties, created_by, etc.) have
|
||||
# "name" at the same level as "type": "person"/"bot". If we drill
|
||||
# into the person/bot sub-dict we lose the name. Capture it here
|
||||
# before descending, but skip "title"-type properties where "name"
|
||||
# is not the display value we want.
|
||||
if (
|
||||
"name" in sub_inner_dict
|
||||
and isinstance(sub_inner_dict["name"], str)
|
||||
and type_name not in ("title",)
|
||||
):
|
||||
return sub_inner_dict["name"]
|
||||
|
||||
sub_inner_dict = sub_inner_dict[type_name]
|
||||
|
||||
# If the innermost layer is None, the value is not set
|
||||
@@ -663,6 +676,19 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
text = rich_text["text"]["content"]
|
||||
cur_result_text_arr.append(text)
|
||||
|
||||
# table_row blocks store content in "cells" (list of lists
|
||||
# of rich text objects) rather than "rich_text"
|
||||
if "cells" in result_obj:
|
||||
row_cells: list[str] = []
|
||||
for cell in result_obj["cells"]:
|
||||
cell_texts = [
|
||||
rt.get("plain_text", "")
|
||||
for rt in cell
|
||||
if isinstance(rt, dict)
|
||||
]
|
||||
row_cells.append(" ".join(cell_texts))
|
||||
cur_result_text_arr.append("\t".join(row_cells))
|
||||
|
||||
if result["has_children"]:
|
||||
if result_type == "child_page":
|
||||
# Child pages will not be included at this top level, it will be a separate document.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import uuid
|
||||
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
@@ -10,14 +11,23 @@ from onyx.auth.api_key import ApiKeyDescriptor
|
||||
from onyx.auth.api_key import build_displayable_api_key
|
||||
from onyx.auth.api_key import generate_api_key
|
||||
from onyx.auth.api_key import hash_api_key
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import ApiKey
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_api_key_email_pattern() -> str:
|
||||
return DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
@@ -85,6 +95,7 @@ def insert_api_key(
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=api_key_args.role,
|
||||
account_type=AccountType.SERVICE_ACCOUNT,
|
||||
)
|
||||
db_session.add(api_key_user_row)
|
||||
|
||||
@@ -97,7 +108,18 @@ def insert_api_key(
|
||||
)
|
||||
db_session.add(api_key_row)
|
||||
|
||||
# Assign the API key virtual user to the appropriate default group
|
||||
# before commit so everything is atomic.
|
||||
# Only ADMIN and BASIC roles get default group membership.
|
||||
if api_key_args.role in (UserRole.ADMIN, UserRole.BASIC):
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
api_key_user_row,
|
||||
is_admin=(api_key_args.role == UserRole.ADMIN),
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return ApiKeyDescriptor(
|
||||
api_key_id=api_key_row.id,
|
||||
api_key_role=api_key_user_row.role,
|
||||
@@ -124,7 +146,33 @@ def update_api_key(
|
||||
|
||||
email_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
|
||||
api_key_user.email = get_api_key_fake_email(email_name, str(api_key_user.id))
|
||||
|
||||
old_role = api_key_user.role
|
||||
api_key_user.role = api_key_args.role
|
||||
|
||||
# Reconcile default-group membership when the role changes.
|
||||
if old_role != api_key_args.role:
|
||||
# Remove from all default groups first.
|
||||
delete_stmt = delete(User__UserGroup).where(
|
||||
User__UserGroup.user_id == api_key_user.id,
|
||||
User__UserGroup.user_group_id.in_(
|
||||
select(UserGroup.id).where(UserGroup.is_default.is_(True))
|
||||
),
|
||||
)
|
||||
db_session.execute(delete_stmt)
|
||||
|
||||
# Re-assign to the correct default group (only for ADMIN/BASIC).
|
||||
if api_key_args.role in (UserRole.ADMIN, UserRole.BASIC):
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
api_key_user,
|
||||
is_admin=(api_key_args.role == UserRole.ADMIN),
|
||||
)
|
||||
else:
|
||||
# No group assigned for LIMITED, but we still need to recompute
|
||||
# since we just removed the old default-group membership above.
|
||||
recompute_user_permissions__no_commit(api_key_user.id, db_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return ApiKeyDescriptor(
|
||||
|
||||
@@ -190,16 +190,23 @@ def delete_messages_and_files_from_chat_session(
|
||||
chat_session_id: UUID, db_session: Session
|
||||
) -> None:
|
||||
# Select messages older than cutoff_time with files
|
||||
messages_with_files = db_session.execute(
|
||||
select(ChatMessage.id, ChatMessage.files).where(
|
||||
ChatMessage.chat_session_id == chat_session_id,
|
||||
messages_with_files = (
|
||||
db_session.execute(
|
||||
select(ChatMessage.id, ChatMessage.files).where(
|
||||
ChatMessage.chat_session_id == chat_session_id,
|
||||
)
|
||||
)
|
||||
).fetchall()
|
||||
.tuples()
|
||||
.all()
|
||||
)
|
||||
|
||||
file_store = get_default_file_store()
|
||||
for _, files in messages_with_files:
|
||||
file_store = get_default_file_store()
|
||||
for file_info in files or []:
|
||||
file_store.delete_file(file_id=file_info.get("id"))
|
||||
if file_info.get("user_file_id"):
|
||||
# user files are managed by the user file lifecycle
|
||||
continue
|
||||
file_store.delete_file(file_id=file_info["id"], error_on_missing=False)
|
||||
|
||||
# Delete ChatMessage records - CASCADE constraints will automatically handle:
|
||||
# - ChatMessage__StandardAnswer relationship records
|
||||
|
||||
@@ -13,19 +13,26 @@ class AccountType(str, PyEnum):
|
||||
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
|
||||
"""
|
||||
|
||||
STANDARD = "standard"
|
||||
BOT = "bot"
|
||||
EXT_PERM_USER = "ext_perm_user"
|
||||
SERVICE_ACCOUNT = "service_account"
|
||||
ANONYMOUS = "anonymous"
|
||||
STANDARD = "STANDARD"
|
||||
BOT = "BOT"
|
||||
EXT_PERM_USER = "EXT_PERM_USER"
|
||||
SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
|
||||
ANONYMOUS = "ANONYMOUS"
|
||||
|
||||
def is_web_login(self) -> bool:
|
||||
"""Whether this account type supports interactive web login."""
|
||||
return self not in (
|
||||
AccountType.BOT,
|
||||
AccountType.EXT_PERM_USER,
|
||||
)
|
||||
|
||||
|
||||
class GrantSource(str, PyEnum):
|
||||
"""How a permission grant was created."""
|
||||
|
||||
USER = "user"
|
||||
SCIM = "scim"
|
||||
SYSTEM = "system"
|
||||
USER = "USER"
|
||||
SCIM = "SCIM"
|
||||
SYSTEM = "SYSTEM"
|
||||
|
||||
|
||||
class IndexingStatus(str, PyEnum):
|
||||
|
||||
@@ -8,6 +8,8 @@ from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import FederatedConnectorSource
|
||||
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
|
||||
from onyx.configs.constants import MASK_CREDENTIAL_LONG_RE
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import FederatedConnector
|
||||
@@ -45,6 +47,23 @@ def fetch_all_federated_connectors_parallel() -> list[FederatedConnector]:
|
||||
return fetch_all_federated_connectors(db_session)
|
||||
|
||||
|
||||
def _reject_masked_credentials(credentials: dict[str, Any]) -> None:
|
||||
"""Raise if any credential string value contains mask placeholder characters.
|
||||
|
||||
mask_string() has two output formats:
|
||||
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
|
||||
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
|
||||
Both must be rejected.
|
||||
"""
|
||||
for key, val in credentials.items():
|
||||
if isinstance(val, str) and (
|
||||
MASK_CREDENTIAL_CHAR in val or MASK_CREDENTIAL_LONG_RE.match(val)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Credential field '{key}' contains masked placeholder characters. Please provide the actual credential value."
|
||||
)
|
||||
|
||||
|
||||
def validate_federated_connector_credentials(
|
||||
source: FederatedConnectorSource,
|
||||
credentials: dict[str, Any],
|
||||
@@ -66,6 +85,8 @@ def create_federated_connector(
|
||||
config: dict[str, Any] | None = None,
|
||||
) -> FederatedConnector:
|
||||
"""Create a new federated connector with credential and config validation."""
|
||||
_reject_masked_credentials(credentials)
|
||||
|
||||
# Validate credentials before creating
|
||||
if not validate_federated_connector_credentials(source, credentials):
|
||||
raise ValueError(
|
||||
@@ -277,6 +298,8 @@ def update_federated_connector(
|
||||
)
|
||||
|
||||
if credentials is not None:
|
||||
_reject_masked_credentials(credentials)
|
||||
|
||||
# Validate credentials before updating
|
||||
if not validate_federated_connector_credentials(
|
||||
federated_connector.source, credentials
|
||||
|
||||
@@ -236,14 +236,15 @@ def upsert_llm_provider(
|
||||
db_session.add(existing_llm_provider)
|
||||
|
||||
# Filter out empty strings and None values from custom_config to allow
|
||||
# providers like Bedrock to fall back to IAM roles when credentials are not provided
|
||||
# providers like Bedrock to fall back to IAM roles when credentials are not provided.
|
||||
# NOTE: An empty dict ({}) is preserved as-is — it signals that the provider was
|
||||
# created via the custom modal and must be reopened with CustomModal, not a
|
||||
# provider-specific modal. Only None means "no custom config at all".
|
||||
custom_config = llm_provider_upsert_request.custom_config
|
||||
if custom_config:
|
||||
custom_config = {
|
||||
k: v for k, v in custom_config.items() if v is not None and v.strip() != ""
|
||||
}
|
||||
# Set to None if the dict is empty after filtering
|
||||
custom_config = custom_config or None
|
||||
|
||||
api_base = llm_provider_upsert_request.api_base or None
|
||||
existing_llm_provider.provider = llm_provider_upsert_request.provider
|
||||
@@ -303,16 +304,7 @@ def upsert_llm_provider(
|
||||
).delete(synchronize_session="fetch")
|
||||
db_session.flush()
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
|
||||
for model_config in llm_provider_upsert_request.model_configurations:
|
||||
max_input_tokens = model_config.max_input_tokens
|
||||
if max_input_tokens is None:
|
||||
max_input_tokens = get_max_input_tokens(
|
||||
model_name=model_config.name,
|
||||
model_provider=llm_provider_upsert_request.provider,
|
||||
)
|
||||
|
||||
supported_flows = [LLMModelFlowType.CHAT]
|
||||
if model_config.supports_image_input:
|
||||
@@ -325,7 +317,7 @@ def upsert_llm_provider(
|
||||
model_configuration_id=existing.id,
|
||||
supported_flows=supported_flows,
|
||||
is_visible=model_config.is_visible,
|
||||
max_input_tokens=max_input_tokens,
|
||||
max_input_tokens=model_config.max_input_tokens,
|
||||
display_name=model_config.display_name,
|
||||
)
|
||||
else:
|
||||
@@ -335,7 +327,7 @@ def upsert_llm_provider(
|
||||
model_name=model_config.name,
|
||||
supported_flows=supported_flows,
|
||||
is_visible=model_config.is_visible,
|
||||
max_input_tokens=max_input_tokens,
|
||||
max_input_tokens=model_config.max_input_tokens,
|
||||
display_name=model_config.display_name,
|
||||
)
|
||||
|
||||
|
||||
@@ -305,8 +305,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||
)
|
||||
account_type: Mapped[AccountType | None] = mapped_column(
|
||||
Enum(AccountType, native_enum=False), nullable=True
|
||||
account_type: Mapped[AccountType] = mapped_column(
|
||||
Enum(AccountType, native_enum=False),
|
||||
nullable=False,
|
||||
default=AccountType.STANDARD,
|
||||
server_default="STANDARD",
|
||||
)
|
||||
|
||||
"""
|
||||
@@ -353,6 +356,13 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
)
|
||||
|
||||
effective_permissions: Mapped[list[str]] = mapped_column(
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
default=list,
|
||||
server_default=text("'[]'::jsonb"),
|
||||
)
|
||||
|
||||
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
|
||||
TIMESTAMPAware(timezone=True), nullable=True
|
||||
)
|
||||
@@ -4016,7 +4026,12 @@ class PermissionGrant(Base):
|
||||
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
permission: Mapped[Permission] = mapped_column(
|
||||
Enum(Permission, native_enum=False), nullable=False
|
||||
Enum(
|
||||
Permission,
|
||||
native_enum=False,
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
nullable=False,
|
||||
)
|
||||
grant_source: Mapped[GrantSource] = mapped_column(
|
||||
Enum(GrantSource, native_enum=False), nullable=False
|
||||
|
||||
@@ -324,6 +324,15 @@ def mark_migration_completed_time_if_not_set_with_commit(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def is_migration_completed(db_session: Session) -> bool:
|
||||
"""Returns True if the migration is completed.
|
||||
|
||||
Can be run even if the migration record does not exist.
|
||||
"""
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
return record is not None and record.migration_completed_at is not None
|
||||
|
||||
|
||||
def build_sanitized_to_original_doc_id_mapping(
|
||||
db_session: Session,
|
||||
) -> dict[str, str]:
|
||||
|
||||
95
backend/onyx/db/permissions.py
Normal file
95
backend/onyx/db/permissions.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
DB operations for recomputing user effective_permissions.
|
||||
|
||||
These live in onyx/db/ (not onyx/auth/) because they are pure DB operations
|
||||
that query PermissionGrant rows and update the User.effective_permissions
|
||||
JSONB column. Keeping them here avoids circular imports when called from
|
||||
other onyx/db/ modules such as users.py.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
|
||||
|
||||
def recompute_user_permissions__no_commit(
|
||||
user_ids: UUID | str | list[UUID] | list[str], db_session: Session
|
||||
) -> None:
|
||||
"""Recompute granted permissions for one or more users.
|
||||
|
||||
Accepts a single UUID or a list. Uses a single query regardless of
|
||||
how many users are passed, avoiding N+1 issues.
|
||||
|
||||
Stores only directly granted permissions — implication expansion
|
||||
happens at read time via get_effective_permissions().
|
||||
|
||||
Does NOT commit — caller must commit the session.
|
||||
"""
|
||||
if isinstance(user_ids, (UUID, str)):
|
||||
uid_list = [user_ids]
|
||||
else:
|
||||
uid_list = list(user_ids)
|
||||
|
||||
if not uid_list:
|
||||
return
|
||||
|
||||
# Single query to fetch ALL permissions for these users across ALL their
|
||||
# groups (a user may belong to multiple groups with different grants).
|
||||
rows = db_session.execute(
|
||||
select(User__UserGroup.user_id, PermissionGrant.permission)
|
||||
.join(
|
||||
PermissionGrant,
|
||||
PermissionGrant.group_id == User__UserGroup.user_group_id,
|
||||
)
|
||||
.where(
|
||||
User__UserGroup.user_id.in_(uid_list),
|
||||
PermissionGrant.is_deleted.is_(False),
|
||||
)
|
||||
).all()
|
||||
|
||||
# Group permissions by user; users with no grants get an empty set.
|
||||
perms_by_user: dict[UUID | str, set[str]] = defaultdict(set)
|
||||
for uid in uid_list:
|
||||
perms_by_user[uid] # ensure every user has an entry
|
||||
for uid, perm in rows:
|
||||
perms_by_user[uid].add(perm.value)
|
||||
|
||||
for uid, perms in perms_by_user.items():
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == uid) # type: ignore[arg-type]
|
||||
.values(effective_permissions=sorted(perms))
|
||||
)
|
||||
|
||||
|
||||
def recompute_permissions_for_group__no_commit(
|
||||
group_id: int, db_session: Session
|
||||
) -> None:
|
||||
"""Recompute granted permissions for all users in a group.
|
||||
|
||||
Does NOT commit — caller must commit the session.
|
||||
"""
|
||||
user_ids: list[UUID] = [
|
||||
uid
|
||||
for uid in db_session.execute(
|
||||
select(User__UserGroup.user_id).where(
|
||||
User__UserGroup.user_group_id == group_id,
|
||||
User__UserGroup.user_id.isnot(None),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
if uid is not None
|
||||
]
|
||||
|
||||
if not user_ids:
|
||||
return
|
||||
|
||||
recompute_user_permissions__no_commit(user_ids, db_session)
|
||||
@@ -5,11 +5,11 @@ from urllib.parse import urlencode
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import INSTANCE_TYPE
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import ONYX_UTM_SOURCE
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.notification import batch_create_notifications
|
||||
from onyx.server.features.release_notes.constants import DOCS_CHANGELOG_BASE_URL
|
||||
@@ -49,7 +49,7 @@ def create_release_notifications_for_versions(
|
||||
db_session.scalars(
|
||||
select(User.id).where( # type: ignore
|
||||
User.is_active == True, # noqa: E712
|
||||
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]),
|
||||
User.account_type.notin_([AccountType.BOT, AccountType.EXT_PERM_USER]),
|
||||
User.email.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN).is_(False), # type: ignore[attr-defined]
|
||||
)
|
||||
).all()
|
||||
|
||||
@@ -9,12 +9,18 @@ from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import DefaultAppMode
|
||||
from onyx.db.enums import ThemePreference
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import Assistant__UserSpecificConfig
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.db.users import is_limited_user
|
||||
from onyx.server.manage.models import MemoryItem
|
||||
from onyx.server.manage.models import UserSpecificAssistantPreference
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -23,13 +29,56 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_ROLE_TO_ACCOUNT_TYPE: dict[UserRole, AccountType] = {
|
||||
UserRole.SLACK_USER: AccountType.BOT,
|
||||
UserRole.EXT_PERM_USER: AccountType.EXT_PERM_USER,
|
||||
}
|
||||
|
||||
|
||||
def update_user_role(
|
||||
user: User,
|
||||
new_role: UserRole,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update a user's role in the database."""
|
||||
"""Update a user's role in the database.
|
||||
Dual-writes account_type to keep it in sync with role and
|
||||
reconciles default-group membership (Admin / Basic)."""
|
||||
old_role = user.role
|
||||
user.role = new_role
|
||||
# Note: setting account_type to BOT or EXT_PERM_USER causes
|
||||
# assign_user_to_default_groups__no_commit to early-return, which is
|
||||
# intentional — these account types should not be in default groups.
|
||||
if new_role in _ROLE_TO_ACCOUNT_TYPE:
|
||||
user.account_type = _ROLE_TO_ACCOUNT_TYPE[new_role]
|
||||
elif user.account_type in (AccountType.BOT, AccountType.EXT_PERM_USER):
|
||||
# Upgrading from a non-web-login account type to a web role
|
||||
user.account_type = AccountType.STANDARD
|
||||
|
||||
# Reconcile default-group membership when the role changes.
|
||||
if old_role != new_role:
|
||||
# Remove from all default groups first.
|
||||
db_session.execute(
|
||||
delete(User__UserGroup).where(
|
||||
User__UserGroup.user_id == user.id,
|
||||
User__UserGroup.user_group_id.in_(
|
||||
select(UserGroup.id).where(UserGroup.is_default.is_(True))
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Re-assign to the correct default group.
|
||||
# assign_user_to_default_groups__no_commit internally skips
|
||||
# ANONYMOUS, BOT, and EXT_PERM_USER account types.
|
||||
# Also skip limited users (no group assignment).
|
||||
if not is_limited_user(user):
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
user,
|
||||
is_admin=(new_role == UserRole.ADMIN),
|
||||
)
|
||||
|
||||
recompute_user_permissions__no_commit(user.id, db_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -47,8 +96,19 @@ def activate_user(
|
||||
user: User,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Activate a user by setting is_active to True."""
|
||||
"""Activate a user by setting is_active to True.
|
||||
|
||||
Also reconciles default-group membership — the user may have been
|
||||
created while inactive or deactivated before the backfill migration.
|
||||
"""
|
||||
user.is_active = True
|
||||
# assign_user_to_default_groups__no_commit internally skips
|
||||
# ANONYMOUS, BOT, and EXT_PERM_USER account types.
|
||||
# Also skip limited users (no group assignment).
|
||||
if not is_limited_user(user):
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=(user.role == UserRole.ADMIN)
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -17,8 +17,9 @@ from sqlalchemy.sql.expression import or_
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import NO_AUTH_PLACEHOLDER_USER_EMAIL
|
||||
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
@@ -27,11 +28,35 @@ from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_limited_user(user: User) -> bool:
|
||||
"""Check if a user is effectively limited — i.e. should be denied
|
||||
access by ``current_user`` and should not receive default-group
|
||||
membership.
|
||||
|
||||
A user is limited when they are:
|
||||
* an anonymous user, or
|
||||
* a service account with no effective permissions (no group membership).
|
||||
"""
|
||||
if user.account_type == AccountType.ANONYMOUS:
|
||||
return True
|
||||
if (
|
||||
user.account_type == AccountType.SERVICE_ACCOUNT
|
||||
and not user.effective_permissions
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def validate_user_role_update(
|
||||
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
|
||||
requested_role: UserRole,
|
||||
current_account_type: AccountType,
|
||||
explicit_override: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Validate that a user role update is valid.
|
||||
@@ -41,28 +66,27 @@ def validate_user_role_update(
|
||||
- requested role is a slack user
|
||||
- requested role is an external permissioned user
|
||||
- requested role is a limited user
|
||||
- current role is a slack user
|
||||
- current role is an external permissioned user
|
||||
- current role is a limited user
|
||||
- current account type is BOT (slack user)
|
||||
- current account type is EXT_PERM_USER
|
||||
- current account type is ANONYMOUS or SERVICE_ACCOUNT
|
||||
"""
|
||||
|
||||
if current_role == UserRole.SLACK_USER:
|
||||
if current_account_type == AccountType.BOT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="To change a Slack User's role, they must first login to Onyx via the web app.",
|
||||
)
|
||||
|
||||
if current_role == UserRole.EXT_PERM_USER:
|
||||
# This shouldn't happen, but just in case
|
||||
if current_account_type == AccountType.EXT_PERM_USER:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="To change an External Permissioned User's role, they must first login to Onyx via the web app.",
|
||||
)
|
||||
|
||||
if current_role == UserRole.LIMITED:
|
||||
if current_account_type in (AccountType.ANONYMOUS, AccountType.SERVICE_ACCOUNT):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="To change a Limited User's role, they must first login to Onyx via the web app.",
|
||||
detail="Cannot change the role of an anonymous or service account user.",
|
||||
)
|
||||
|
||||
if explicit_override:
|
||||
@@ -298,6 +322,7 @@ def _generate_slack_user(email: str) -> User:
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
role=UserRole.SLACK_USER,
|
||||
account_type=AccountType.BOT,
|
||||
)
|
||||
|
||||
|
||||
@@ -306,8 +331,9 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
|
||||
user = get_user_by_email(email, db_session)
|
||||
if user is not None:
|
||||
# If the user is an external permissioned user, we update it to a slack user
|
||||
if user.role == UserRole.EXT_PERM_USER:
|
||||
if user.account_type == AccountType.EXT_PERM_USER:
|
||||
user.role = UserRole.SLACK_USER
|
||||
user.account_type = AccountType.BOT
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
@@ -344,6 +370,7 @@ def _generate_ext_permissioned_user(email: str) -> User:
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
|
||||
|
||||
@@ -375,6 +402,81 @@ def batch_add_ext_perm_user_if_not_exists(
|
||||
return all_users
|
||||
|
||||
|
||||
def assign_user_to_default_groups__no_commit(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
is_admin: bool = False,
|
||||
) -> None:
|
||||
"""Assign a newly created user to the appropriate default group.
|
||||
|
||||
Does NOT commit — callers must commit the session themselves so that
|
||||
group assignment can be part of the same transaction as user creation.
|
||||
|
||||
Args:
|
||||
is_admin: If True, assign to Admin default group; otherwise Basic.
|
||||
Callers determine this from their own context (e.g. user_count,
|
||||
admin email list, explicit choice). Defaults to False (Basic).
|
||||
"""
|
||||
if user.account_type in (
|
||||
AccountType.BOT,
|
||||
AccountType.EXT_PERM_USER,
|
||||
AccountType.ANONYMOUS,
|
||||
):
|
||||
return
|
||||
|
||||
target_group_name = "Admin" if is_admin else "Basic"
|
||||
|
||||
default_group = (
|
||||
db_session.query(UserGroup)
|
||||
.filter(
|
||||
UserGroup.name == target_group_name,
|
||||
UserGroup.is_default.is_(True),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if default_group is None:
|
||||
raise RuntimeError(
|
||||
f"Default group '{target_group_name}' not found. "
|
||||
f"Cannot assign user {user.email} to a group. "
|
||||
f"Ensure the seed_default_groups migration has run."
|
||||
)
|
||||
|
||||
# Check if the user is already in the group
|
||||
existing = (
|
||||
db_session.query(User__UserGroup)
|
||||
.filter(
|
||||
User__UserGroup.user_id == user.id,
|
||||
User__UserGroup.user_group_id == default_group.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing is not None:
|
||||
return
|
||||
|
||||
savepoint = db_session.begin_nested()
|
||||
try:
|
||||
db_session.add(
|
||||
User__UserGroup(
|
||||
user_id=user.id,
|
||||
user_group_id=default_group.id,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
except IntegrityError:
|
||||
# Race condition: another transaction inserted this membership
|
||||
# between our SELECT and INSERT. The savepoint isolates the failure
|
||||
# so the outer transaction (user creation) stays intact.
|
||||
savepoint.rollback()
|
||||
return
|
||||
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
|
||||
recompute_user_permissions__no_commit(user.id, db_session)
|
||||
|
||||
logger.info(f"Assigned user {user.email} to default group '{default_group.name}'")
|
||||
|
||||
|
||||
def delete_user_from_db(
|
||||
user_to_delete: User,
|
||||
db_session: Session,
|
||||
@@ -421,13 +523,14 @@ def delete_user_from_db(
|
||||
def batch_get_user_groups(
|
||||
db_session: Session,
|
||||
user_ids: list[UUID],
|
||||
include_default: bool = False,
|
||||
) -> dict[UUID, list[tuple[int, str]]]:
|
||||
"""Fetch group memberships for a batch of users in a single query.
|
||||
Returns a mapping of user_id -> list of (group_id, group_name) tuples."""
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
rows = db_session.execute(
|
||||
stmt = (
|
||||
select(
|
||||
User__UserGroup.user_id,
|
||||
UserGroup.id,
|
||||
@@ -435,7 +538,11 @@ def batch_get_user_groups(
|
||||
)
|
||||
.join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
|
||||
.where(User__UserGroup.user_id.in_(user_ids))
|
||||
).all()
|
||||
)
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
|
||||
rows = db_session.execute(stmt).all()
|
||||
|
||||
result: dict[UUID, list[tuple[int, str]]] = {uid: [] for uid in user_ids}
|
||||
for user_id, group_id, group_name in rows:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
@@ -20,9 +21,13 @@ from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
|
||||
from onyx.document_index.opensearch.constants import EF_SEARCH
|
||||
from onyx.document_index.opensearch.constants import M
|
||||
from onyx.document_index.opensearch.string_filtering import DocumentIDTooLongError
|
||||
from onyx.document_index.opensearch.string_filtering import (
|
||||
filter_and_validate_document_id,
|
||||
)
|
||||
from onyx.document_index.opensearch.string_filtering import (
|
||||
MAX_DOCUMENT_ID_ENCODED_LENGTH,
|
||||
)
|
||||
from onyx.utils.tenant import get_tenant_id_short_string
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -75,17 +80,50 @@ def get_opensearch_doc_chunk_id(
|
||||
|
||||
This will be the string used to identify the chunk in OpenSearch. Any direct
|
||||
chunk queries should use this function.
|
||||
|
||||
If the document ID is too long, a hash of the ID is used instead.
|
||||
"""
|
||||
sanitized_document_id = filter_and_validate_document_id(document_id)
|
||||
opensearch_doc_chunk_id = (
|
||||
f"{sanitized_document_id}__{max_chunk_size}__{chunk_index}"
|
||||
opensearch_doc_chunk_id_suffix: str = f"__{max_chunk_size}__{chunk_index}"
|
||||
encoded_suffix_length: int = len(opensearch_doc_chunk_id_suffix.encode("utf-8"))
|
||||
max_encoded_permissible_doc_id_length: int = (
|
||||
MAX_DOCUMENT_ID_ENCODED_LENGTH - encoded_suffix_length
|
||||
)
|
||||
opensearch_doc_chunk_id_tenant_prefix: str = ""
|
||||
if tenant_state.multitenant:
|
||||
short_tenant_id: str = get_tenant_id_short_string(tenant_state.tenant_id)
|
||||
# Use tenant ID because in multitenant mode each tenant has its own
|
||||
# Documents table, so there is a very small chance that doc IDs are not
|
||||
# actually unique across all tenants.
|
||||
short_tenant_id = get_tenant_id_short_string(tenant_state.tenant_id)
|
||||
opensearch_doc_chunk_id = f"{short_tenant_id}__{opensearch_doc_chunk_id}"
|
||||
opensearch_doc_chunk_id_tenant_prefix = f"{short_tenant_id}__"
|
||||
encoded_prefix_length: int = len(
|
||||
opensearch_doc_chunk_id_tenant_prefix.encode("utf-8")
|
||||
)
|
||||
max_encoded_permissible_doc_id_length -= encoded_prefix_length
|
||||
|
||||
try:
|
||||
sanitized_document_id: str = filter_and_validate_document_id(
|
||||
document_id, max_encoded_length=max_encoded_permissible_doc_id_length
|
||||
)
|
||||
except DocumentIDTooLongError:
|
||||
# If the document ID is too long, use a hash instead.
|
||||
# We use blake2b because it is faster and equally secure as SHA256, and
|
||||
# accepts digest_size which controls the number of bytes returned in the
|
||||
# hash.
|
||||
# digest_size is the size of the returned hash in bytes. Since we're
|
||||
# decoding the hash bytes as a hex string, the digest_size should be
|
||||
# half the max target size of the hash string.
|
||||
# Subtract 1 because filter_and_validate_document_id compares on >= on
|
||||
# max_encoded_length.
|
||||
# 64 is the max digest_size blake2b returns.
|
||||
digest_size: int = min((max_encoded_permissible_doc_id_length - 1) // 2, 64)
|
||||
sanitized_document_id = hashlib.blake2b(
|
||||
document_id.encode("utf-8"), digest_size=digest_size
|
||||
).hexdigest()
|
||||
|
||||
opensearch_doc_chunk_id: str = (
|
||||
f"{opensearch_doc_chunk_id_tenant_prefix}{sanitized_document_id}{opensearch_doc_chunk_id_suffix}"
|
||||
)
|
||||
|
||||
# Do one more validation to ensure we haven't exceeded the max length.
|
||||
opensearch_doc_chunk_id = filter_and_validate_document_id(opensearch_doc_chunk_id)
|
||||
return opensearch_doc_chunk_id
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
import re
|
||||
|
||||
MAX_DOCUMENT_ID_ENCODED_LENGTH: int = 512
|
||||
|
||||
def filter_and_validate_document_id(document_id: str) -> str:
|
||||
|
||||
class DocumentIDTooLongError(ValueError):
|
||||
"""Raised when a document ID is too long for OpenSearch after filtering."""
|
||||
|
||||
|
||||
def filter_and_validate_document_id(
|
||||
document_id: str, max_encoded_length: int = MAX_DOCUMENT_ID_ENCODED_LENGTH
|
||||
) -> str:
|
||||
"""
|
||||
Filters and validates a document ID such that it can be used as an ID in
|
||||
OpenSearch.
|
||||
@@ -19,9 +27,13 @@ def filter_and_validate_document_id(document_id: str) -> str:
|
||||
|
||||
Args:
|
||||
document_id: The document ID to filter and validate.
|
||||
max_encoded_length: The maximum length of the document ID after
|
||||
filtering in bytes. Compared with >= for extra resilience, so
|
||||
encoded values of this length will fail.
|
||||
|
||||
Raises:
|
||||
ValueError: If the document ID is empty or too long after filtering.
|
||||
DocumentIDTooLongError: If the document ID is too long after filtering.
|
||||
ValueError: If the document ID is empty after filtering.
|
||||
|
||||
Returns:
|
||||
str: The filtered document ID.
|
||||
@@ -29,6 +41,8 @@ def filter_and_validate_document_id(document_id: str) -> str:
|
||||
filtered_document_id = re.sub(r"[^A-Za-z0-9_.\-~]", "", document_id)
|
||||
if not filtered_document_id:
|
||||
raise ValueError(f"Document ID {document_id} is empty after filtering.")
|
||||
if len(filtered_document_id.encode("utf-8")) >= 512:
|
||||
raise ValueError(f"Document ID {document_id} is too long after filtering.")
|
||||
if len(filtered_document_id.encode("utf-8")) >= max_encoded_length:
|
||||
raise DocumentIDTooLongError(
|
||||
f"Document ID {document_id} is too long after filtering."
|
||||
)
|
||||
return filtered_document_id
|
||||
|
||||
47
backend/onyx/error_handling/README.md
Normal file
47
backend/onyx/error_handling/README.md
Normal file
@@ -0,0 +1,47 @@
|
||||
# Error Handling
|
||||
|
||||
This directory is the local source of truth for backend API error handling.
|
||||
|
||||
## Primary Rule
|
||||
|
||||
Raise `OnyxError` from `onyx.error_handling.exceptions` instead of `HTTPException`.
|
||||
|
||||
The global FastAPI exception handler converts `OnyxError` into the standard JSON shape:
|
||||
|
||||
```json
|
||||
{"error_code": "...", "detail": "..."}
|
||||
```
|
||||
|
||||
This keeps API behavior consistent and avoids repetitive route-level boilerplate.
|
||||
|
||||
## Examples
|
||||
|
||||
```python
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
# Good
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found")
|
||||
|
||||
# Good
|
||||
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED)
|
||||
|
||||
# Good: preserve a dynamic upstream status code
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
detail,
|
||||
status_code_override=e.response.status_code,
|
||||
)
|
||||
```
|
||||
|
||||
Avoid:
|
||||
|
||||
```python
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- Available error codes are defined in `backend/onyx/error_handling/error_codes.py`.
|
||||
- If a new error category is needed, add it there first rather than inventing ad hoc strings.
|
||||
- When forwarding upstream service failures with dynamic status codes, use `status_code_override`.
|
||||
@@ -52,9 +52,21 @@ KNOWN_OPENPYXL_BUGS = [
|
||||
|
||||
def get_markitdown_converter() -> "MarkItDown":
|
||||
global _MARKITDOWN_CONVERTER
|
||||
from markitdown import MarkItDown
|
||||
|
||||
if _MARKITDOWN_CONVERTER is None:
|
||||
from markitdown import MarkItDown
|
||||
|
||||
# Patch this function to effectively no-op because we were seeing this
|
||||
# module take an inordinate amount of time to convert charts to markdown,
|
||||
# making some powerpoint files with many or complicated charts nearly
|
||||
# unindexable.
|
||||
from markitdown.converters._pptx_converter import PptxConverter
|
||||
|
||||
setattr(
|
||||
PptxConverter,
|
||||
"_convert_chart_to_markdown",
|
||||
lambda self, chart: "\n\n[chart omitted]\n\n", # noqa: ARG005
|
||||
)
|
||||
_MARKITDOWN_CONVERTER = MarkItDown(enable_plugins=False)
|
||||
return _MARKITDOWN_CONVERTER
|
||||
|
||||
@@ -205,18 +217,26 @@ def read_pdf_file(
|
||||
try:
|
||||
pdf_reader = PdfReader(file)
|
||||
|
||||
if pdf_reader.is_encrypted and pdf_pass is not None:
|
||||
if pdf_reader.is_encrypted:
|
||||
# Try the explicit password first, then fall back to an empty
|
||||
# string. Owner-password-only PDFs (permission restrictions but
|
||||
# no open password) decrypt successfully with "".
|
||||
# See https://github.com/onyx-dot-app/onyx/issues/9754
|
||||
passwords = [p for p in [pdf_pass, ""] if p is not None]
|
||||
decrypt_success = False
|
||||
try:
|
||||
decrypt_success = pdf_reader.decrypt(pdf_pass) != 0
|
||||
except Exception:
|
||||
logger.error("Unable to decrypt pdf")
|
||||
for pw in passwords:
|
||||
try:
|
||||
if pdf_reader.decrypt(pw) != 0:
|
||||
decrypt_success = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not decrypt_success:
|
||||
logger.error(
|
||||
"Encrypted PDF could not be decrypted, returning empty text."
|
||||
)
|
||||
return "", metadata, []
|
||||
elif pdf_reader.is_encrypted:
|
||||
logger.warning("No Password for an encrypted PDF, returning empty text.")
|
||||
return "", metadata, []
|
||||
|
||||
# Basic PDF metadata
|
||||
if pdf_reader.metadata is not None:
|
||||
|
||||
@@ -33,8 +33,20 @@ def is_pdf_protected(file: IO[Any]) -> bool:
|
||||
|
||||
with preserve_position(file):
|
||||
reader = PdfReader(file)
|
||||
if not reader.is_encrypted:
|
||||
return False
|
||||
|
||||
return bool(reader.is_encrypted)
|
||||
# PDFs with only an owner password (permission restrictions like
|
||||
# print/copy disabled) use an empty user password — any viewer can open
|
||||
# them without prompting. decrypt("") returns 0 only when a real user
|
||||
# password is required. See https://github.com/onyx-dot-app/onyx/issues/9754
|
||||
try:
|
||||
return reader.decrypt("") == 0
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to evaluate PDF encryption; treating as password protected"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def is_docx_protected(file: IO[Any]) -> bool:
|
||||
|
||||
@@ -136,12 +136,14 @@ class FileStore(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def delete_file(self, file_id: str) -> None:
|
||||
def delete_file(self, file_id: str, error_on_missing: bool = True) -> None:
|
||||
"""
|
||||
Delete a file by its ID.
|
||||
|
||||
Parameters:
|
||||
- file_name: Name of file to delete
|
||||
- file_id: ID of file to delete
|
||||
- error_on_missing: If False, silently return when the file record
|
||||
does not exist instead of raising.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@@ -452,12 +454,23 @@ class S3BackedFileStore(FileStore):
|
||||
logger.warning(f"Error getting file size for {file_id}: {e}")
|
||||
return None
|
||||
|
||||
def delete_file(self, file_id: str, db_session: Session | None = None) -> None:
|
||||
def delete_file(
|
||||
self,
|
||||
file_id: str,
|
||||
error_on_missing: bool = True,
|
||||
db_session: Session | None = None,
|
||||
) -> None:
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
try:
|
||||
file_record = get_filerecord_by_file_id(
|
||||
file_record = get_filerecord_by_file_id_optional(
|
||||
file_id=file_id, db_session=db_session
|
||||
)
|
||||
if file_record is None:
|
||||
if error_on_missing:
|
||||
raise RuntimeError(
|
||||
f"File by id {file_id} does not exist or was deleted"
|
||||
)
|
||||
return
|
||||
if not file_record.bucket_name:
|
||||
logger.error(
|
||||
f"File record {file_id} with key {file_record.object_key} "
|
||||
|
||||
@@ -222,12 +222,23 @@ class PostgresBackedFileStore(FileStore):
|
||||
logger.warning(f"Error getting file size for {file_id}: {e}")
|
||||
return None
|
||||
|
||||
def delete_file(self, file_id: str, db_session: Session | None = None) -> None:
|
||||
def delete_file(
|
||||
self,
|
||||
file_id: str,
|
||||
error_on_missing: bool = True,
|
||||
db_session: Session | None = None,
|
||||
) -> None:
|
||||
with get_session_with_current_tenant_if_none(db_session) as session:
|
||||
try:
|
||||
file_content = get_file_content_by_file_id(
|
||||
file_content = get_file_content_by_file_id_optional(
|
||||
file_id=file_id, db_session=session
|
||||
)
|
||||
if file_content is None:
|
||||
if error_on_missing:
|
||||
raise RuntimeError(
|
||||
f"File content for file_id {file_id} does not exist or was deleted"
|
||||
)
|
||||
return
|
||||
raw_conn = _get_raw_connection(session)
|
||||
|
||||
try:
|
||||
|
||||
@@ -26,6 +26,7 @@ class LlmProviderNames(str, Enum):
|
||||
MISTRAL = "mistral"
|
||||
LITELLM_PROXY = "litellm_proxy"
|
||||
BIFROST = "bifrost"
|
||||
OPENAI_COMPATIBLE = "openai_compatible"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Needed so things like:
|
||||
@@ -46,6 +47,7 @@ WELL_KNOWN_PROVIDER_NAMES = [
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
]
|
||||
|
||||
|
||||
@@ -64,6 +66,7 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
LlmProviderNames.LM_STUDIO: "LM Studio",
|
||||
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
|
||||
LlmProviderNames.BIFROST: "Bifrost",
|
||||
LlmProviderNames.OPENAI_COMPATIBLE: "OpenAI-Compatible",
|
||||
"groq": "Groq",
|
||||
"anyscale": "Anyscale",
|
||||
"deepseek": "DeepSeek",
|
||||
@@ -84,6 +87,44 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
"gemini": "Gemini",
|
||||
"stability": "Stability",
|
||||
"writer": "Writer",
|
||||
# Custom provider display names (used in the custom provider picker)
|
||||
"aiml": "AI/ML",
|
||||
"assemblyai": "AssemblyAI",
|
||||
"aws_polly": "AWS Polly",
|
||||
"azure_ai": "Azure AI",
|
||||
"chatgpt": "ChatGPT",
|
||||
"cohere_chat": "Cohere Chat",
|
||||
"datarobot": "DataRobot",
|
||||
"deepgram": "Deepgram",
|
||||
"deepinfra": "DeepInfra",
|
||||
"elevenlabs": "ElevenLabs",
|
||||
"fal_ai": "fal.ai",
|
||||
"featherless_ai": "Featherless AI",
|
||||
"fireworks_ai": "Fireworks AI",
|
||||
"friendliai": "FriendliAI",
|
||||
"gigachat": "GigaChat",
|
||||
"github_copilot": "GitHub Copilot",
|
||||
"gradient_ai": "Gradient AI",
|
||||
"huggingface": "HuggingFace",
|
||||
"jina_ai": "Jina AI",
|
||||
"lambda_ai": "Lambda AI",
|
||||
"llamagate": "LlamaGate",
|
||||
"meta_llama": "Meta Llama",
|
||||
"minimax": "MiniMax",
|
||||
"nlp_cloud": "NLP Cloud",
|
||||
"nvidia_nim": "NVIDIA NIM",
|
||||
"oci": "OCI",
|
||||
"ovhcloud": "OVHcloud",
|
||||
"palm": "PaLM",
|
||||
"publicai": "PublicAI",
|
||||
"runwayml": "RunwayML",
|
||||
"sambanova": "SambaNova",
|
||||
"together_ai": "Together AI",
|
||||
"vercel_ai_gateway": "Vercel AI Gateway",
|
||||
"volcengine": "Volcengine",
|
||||
"wandb": "W&B",
|
||||
"watsonx": "IBM watsonx",
|
||||
"zai": "ZAI",
|
||||
}
|
||||
|
||||
# Map vendors to their brand names (used for provider_display_name generation)
|
||||
@@ -116,6 +157,7 @@ AGGREGATOR_PROVIDERS: set[str] = {
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
}
|
||||
|
||||
# Model family name mappings for display name generation
|
||||
|
||||
@@ -327,12 +327,19 @@ class LitellmLLM(LLM):
|
||||
):
|
||||
model_kwargs[VERTEX_LOCATION_KWARG] = "global"
|
||||
|
||||
# Bifrost: OpenAI-compatible proxy that expects model names in
|
||||
# provider/model format (e.g. "anthropic/claude-sonnet-4-6").
|
||||
# We route through LiteLLM's openai provider with the Bifrost base URL,
|
||||
# and ensure /v1 is appended.
|
||||
if model_provider == LlmProviderNames.BIFROST:
|
||||
# Bifrost and OpenAI-compatible: OpenAI-compatible proxies that send
|
||||
# model names directly to the endpoint. We route through LiteLLM's
|
||||
# openai provider with the server's base URL, and ensure /v1 is appended.
|
||||
if model_provider in (
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
):
|
||||
self._custom_llm_provider = "openai"
|
||||
# LiteLLM's OpenAI client requires an api_key to be set.
|
||||
# Many OpenAI-compatible servers don't need auth, so supply a
|
||||
# placeholder to prevent LiteLLM from raising AuthenticationError.
|
||||
if not self._api_key:
|
||||
model_kwargs.setdefault("api_key", "not-needed")
|
||||
if self._api_base is not None:
|
||||
base = self._api_base.rstrip("/")
|
||||
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
|
||||
@@ -449,17 +456,20 @@ class LitellmLLM(LLM):
|
||||
optional_kwargs: dict[str, Any] = {}
|
||||
|
||||
# Model name
|
||||
is_bifrost = self._model_provider == LlmProviderNames.BIFROST
|
||||
is_openai_compatible_proxy = self._model_provider in (
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
)
|
||||
model_provider = (
|
||||
f"{self.config.model_provider}/responses"
|
||||
if is_openai_model # Uses litellm's completions -> responses bridge
|
||||
else self.config.model_provider
|
||||
)
|
||||
if is_bifrost:
|
||||
# Bifrost expects model names in provider/model format
|
||||
# (e.g. "anthropic/claude-sonnet-4-6") sent directly to its
|
||||
# OpenAI-compatible endpoint. We use custom_llm_provider="openai"
|
||||
# so LiteLLM doesn't try to route based on the provider prefix.
|
||||
if is_openai_compatible_proxy:
|
||||
# OpenAI-compatible proxies (Bifrost, generic OpenAI-compatible
|
||||
# servers) expect model names sent directly to their endpoint.
|
||||
# We use custom_llm_provider="openai" so LiteLLM doesn't try
|
||||
# to route based on the provider prefix.
|
||||
model = self.config.deployment_name or self.config.model_name
|
||||
else:
|
||||
model = f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
|
||||
@@ -550,7 +560,10 @@ class LitellmLLM(LLM):
|
||||
if structured_response_format:
|
||||
optional_kwargs["response_format"] = structured_response_format
|
||||
|
||||
if not (is_claude_model or is_ollama or is_mistral) or is_bifrost:
|
||||
if (
|
||||
not (is_claude_model or is_ollama or is_mistral)
|
||||
or is_openai_compatible_proxy
|
||||
):
|
||||
# Litellm bug: tool_choice is dropped silently if not specified here for OpenAI
|
||||
# However, this param breaks Anthropic and Mistral models,
|
||||
# so it must be conditionally included unless the request is
|
||||
|
||||
@@ -15,6 +15,8 @@ LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
|
||||
|
||||
BIFROST_PROVIDER_NAME = "bifrost"
|
||||
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME = "openai_compatible"
|
||||
|
||||
# Providers that use optional Bearer auth from custom_config
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.llm.well_known_providers.constants import BIFROST_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_COMPATIBLE_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import VERTEXAI_PROVIDER_NAME
|
||||
@@ -51,6 +52,7 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
|
||||
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
|
||||
BIFROST_PROVIDER_NAME: [], # Dynamic - fetched from Bifrost API
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME: [], # Dynamic - fetched from OpenAI-compatible API
|
||||
}
|
||||
|
||||
|
||||
@@ -336,6 +338,7 @@ def get_provider_display_name(provider_name: str) -> str:
|
||||
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
|
||||
OPENROUTER_PROVIDER_NAME: "OpenRouter",
|
||||
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME: "OpenAI-Compatible",
|
||||
}
|
||||
|
||||
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.mcp_server.api import mcp_server
|
||||
from onyx.mcp_server.utils import get_http_client
|
||||
@@ -15,6 +17,21 @@ from onyx.utils.variable_functionality import global_version
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_error_detail(response: httpx.Response) -> str:
|
||||
"""Extract a human-readable error message from a failed backend response.
|
||||
|
||||
The backend returns OnyxError responses as
|
||||
``{"error_code": "...", "detail": "..."}``.
|
||||
"""
|
||||
try:
|
||||
body = response.json()
|
||||
if detail := body.get("detail"):
|
||||
return str(detail)
|
||||
except Exception:
|
||||
pass
|
||||
return f"Request failed with status {response.status_code}"
|
||||
|
||||
|
||||
@mcp_server.tool()
|
||||
async def search_indexed_documents(
|
||||
query: str,
|
||||
@@ -158,7 +175,14 @@ async def search_indexed_documents(
|
||||
json=search_request,
|
||||
headers=auth_headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
if not response.is_success:
|
||||
error_detail = _extract_error_detail(response)
|
||||
return {
|
||||
"documents": [],
|
||||
"total_results": 0,
|
||||
"query": query,
|
||||
"error": error_detail,
|
||||
}
|
||||
result = response.json()
|
||||
|
||||
# Check for error in response
|
||||
@@ -234,7 +258,13 @@ async def search_web(
|
||||
json=request_payload,
|
||||
headers={"Authorization": f"Bearer {access_token.token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
if not response.is_success:
|
||||
error_detail = _extract_error_detail(response)
|
||||
return {
|
||||
"error": error_detail,
|
||||
"results": [],
|
||||
"query": query,
|
||||
}
|
||||
response_payload = response.json()
|
||||
results = response_payload.get("results", [])
|
||||
return {
|
||||
@@ -280,7 +310,12 @@ async def open_urls(
|
||||
json={"urls": urls},
|
||||
headers={"Authorization": f"Bearer {access_token.token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
if not response.is_success:
|
||||
error_detail = _extract_error_detail(response)
|
||||
return {
|
||||
"error": error_detail,
|
||||
"results": [],
|
||||
}
|
||||
response_payload = response.json()
|
||||
results = response_payload.get("results", [])
|
||||
return {
|
||||
|
||||
@@ -6,6 +6,7 @@ from onyx.configs.app_configs import MCP_SERVER_ENABLED
|
||||
from onyx.configs.app_configs import MCP_SERVER_HOST
|
||||
from onyx.configs.app_configs import MCP_SERVER_PORT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -16,6 +17,7 @@ def main() -> None:
|
||||
logger.info("MCP server is disabled (MCP_SERVER_ENABLED=false)")
|
||||
return
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
logger.info(f"Starting MCP server on {MCP_SERVER_HOST}:{MCP_SERVER_PORT}")
|
||||
|
||||
from onyx.mcp_server.api import mcp_app
|
||||
|
||||
@@ -3,10 +3,10 @@ import datetime
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_FEEDBACK_REMINDER
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.user_preferences import activate_user
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
@@ -247,7 +247,7 @@ def handle_message(
|
||||
|
||||
elif (
|
||||
not existing_user.is_active
|
||||
and existing_user.role == UserRole.SLACK_USER
|
||||
and existing_user.account_type == AccountType.BOT
|
||||
):
|
||||
check_seat_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license",
|
||||
|
||||
@@ -90,6 +90,7 @@ from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import TenantSocketModeClient
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.manage.models import SlackBotTokens
|
||||
from onyx.tracing.setup import setup_tracing
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
@@ -1206,6 +1207,7 @@ if __name__ == "__main__":
|
||||
tenant_handler = SlackbotHandler()
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
setup_tracing()
|
||||
|
||||
try:
|
||||
# Keep the main thread alive
|
||||
|
||||
@@ -2,7 +2,7 @@ from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.db.api_key import ApiKeyDescriptor
|
||||
from onyx.db.api_key import fetch_api_keys
|
||||
from onyx.db.api_key import insert_api_key
|
||||
@@ -10,6 +10,7 @@ from onyx.db.api_key import regenerate_api_key
|
||||
from onyx.db.api_key import remove_api_key
|
||||
from onyx.db.api_key import update_api_key
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
|
||||
@@ -19,7 +20,7 @@ router = APIRouter(prefix="/admin/api-key")
|
||||
|
||||
@router.get("")
|
||||
def list_api_keys(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[ApiKeyDescriptor]:
|
||||
return fetch_api_keys(db_session)
|
||||
@@ -28,7 +29,7 @@ def list_api_keys(
|
||||
@router.post("")
|
||||
def create_api_key(
|
||||
api_key_args: APIKeyArgs,
|
||||
user: User = Depends(current_admin_user),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ApiKeyDescriptor:
|
||||
return insert_api_key(db_session, api_key_args, user.id)
|
||||
@@ -37,7 +38,7 @@ def create_api_key(
|
||||
@router.post("/{api_key_id}/regenerate")
|
||||
def regenerate_existing_api_key(
|
||||
api_key_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ApiKeyDescriptor:
|
||||
return regenerate_api_key(db_session, api_key_id)
|
||||
@@ -47,7 +48,7 @@ def regenerate_existing_api_key(
|
||||
def update_existing_api_key(
|
||||
api_key_id: int,
|
||||
api_key_args: APIKeyArgs,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ApiKeyDescriptor:
|
||||
return update_api_key(db_session, api_key_id, api_key_args)
|
||||
@@ -56,7 +57,7 @@ def update_existing_api_key(
|
||||
@router.delete("/{api_key_id}")
|
||||
def delete_api_key(
|
||||
api_key_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
remove_api_key(db_session, api_key_id)
|
||||
|
||||
@@ -4,7 +4,6 @@ from fastapi import FastAPI
|
||||
from fastapi.dependencies.models import Dependant
|
||||
from starlette.routing import BaseRoute
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_limited_user
|
||||
@@ -91,6 +90,11 @@ def is_route_in_spec_list(
|
||||
return False
|
||||
|
||||
|
||||
def _is_require_permission_dependency(fn: object) -> bool:
|
||||
"""Detect closures generated by ``require_permission()``."""
|
||||
return bool(getattr(fn, "_is_require_permission", False))
|
||||
|
||||
|
||||
def check_router_auth(
|
||||
application: FastAPI,
|
||||
public_endpoint_specs: list[tuple[str, set[str]]] = PUBLIC_ENDPOINT_SPECS,
|
||||
@@ -126,7 +130,6 @@ def check_router_auth(
|
||||
if (
|
||||
depends_fn == current_limited_user
|
||||
or depends_fn == current_user
|
||||
or depends_fn == current_admin_user
|
||||
or depends_fn == current_curator_or_admin_user
|
||||
or depends_fn == current_user_with_expired_token
|
||||
or depends_fn == current_chat_accessible_user
|
||||
@@ -134,6 +137,7 @@ def check_router_auth(
|
||||
or depends_fn == control_plane_dep
|
||||
or depends_fn == current_cloud_superuser
|
||||
or depends_fn == verify_scim_token
|
||||
or _is_require_permission_dependency(depends_fn)
|
||||
):
|
||||
found_auth = True
|
||||
break
|
||||
|
||||
@@ -10,8 +10,8 @@ from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.background.celery.tasks.pruning.tasks import (
|
||||
try_creating_prune_generator_task,
|
||||
)
|
||||
@@ -38,6 +38,7 @@ from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.enums import PermissionSyncStatus
|
||||
from onyx.db.index_attempt import count_index_attempt_errors_for_cc_pair
|
||||
from onyx.db.index_attempt import count_index_attempts_for_cc_pair
|
||||
@@ -622,7 +623,7 @@ def associate_credential_to_connector(
|
||||
def dissociate_credential_from_connector(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
user: User = Depends(current_user),
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse[int]:
|
||||
return remove_credential_from_connector(
|
||||
|
||||
@@ -22,10 +22,9 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.email_utils import send_email
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.background.celery.tasks.pruning.tasks import (
|
||||
try_creating_prune_generator_task,
|
||||
)
|
||||
@@ -105,6 +104,7 @@ from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingMode
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.enums import ProcessingMode
|
||||
from onyx.db.federated import fetch_all_federated_connectors_parallel
|
||||
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
|
||||
@@ -189,7 +189,8 @@ def check_google_app_gmail_credentials_exist(
|
||||
|
||||
@router.put("/admin/connector/gmail/app-credential")
|
||||
def upsert_google_app_gmail_credentials(
|
||||
app_credentials: GoogleAppCredentials, _: User = Depends(current_admin_user)
|
||||
app_credentials: GoogleAppCredentials,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> StatusResponse:
|
||||
try:
|
||||
upsert_google_app_cred(app_credentials, DocumentSource.GMAIL)
|
||||
@@ -203,7 +204,7 @@ def upsert_google_app_gmail_credentials(
|
||||
|
||||
@router.delete("/admin/connector/gmail/app-credential")
|
||||
def delete_google_app_gmail_credentials(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
try:
|
||||
@@ -231,7 +232,8 @@ def check_google_app_credentials_exist(
|
||||
|
||||
@router.put("/admin/connector/google-drive/app-credential")
|
||||
def upsert_google_app_credentials(
|
||||
app_credentials: GoogleAppCredentials, _: User = Depends(current_admin_user)
|
||||
app_credentials: GoogleAppCredentials,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> StatusResponse:
|
||||
try:
|
||||
upsert_google_app_cred(app_credentials, DocumentSource.GOOGLE_DRIVE)
|
||||
@@ -245,7 +247,7 @@ def upsert_google_app_credentials(
|
||||
|
||||
@router.delete("/admin/connector/google-drive/app-credential")
|
||||
def delete_google_app_credentials(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
try:
|
||||
@@ -277,7 +279,8 @@ def check_google_service_gmail_account_key_exist(
|
||||
|
||||
@router.put("/admin/connector/gmail/service-account-key")
|
||||
def upsert_google_service_gmail_account_key(
|
||||
service_account_key: GoogleServiceAccountKey, _: User = Depends(current_admin_user)
|
||||
service_account_key: GoogleServiceAccountKey,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> StatusResponse:
|
||||
try:
|
||||
upsert_service_account_key(service_account_key, DocumentSource.GMAIL)
|
||||
@@ -291,7 +294,7 @@ def upsert_google_service_gmail_account_key(
|
||||
|
||||
@router.delete("/admin/connector/gmail/service-account-key")
|
||||
def delete_google_service_gmail_account_key(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
try:
|
||||
@@ -323,7 +326,8 @@ def check_google_service_account_key_exist(
|
||||
|
||||
@router.put("/admin/connector/google-drive/service-account-key")
|
||||
def upsert_google_service_account_key(
|
||||
service_account_key: GoogleServiceAccountKey, _: User = Depends(current_admin_user)
|
||||
service_account_key: GoogleServiceAccountKey,
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> StatusResponse:
|
||||
try:
|
||||
upsert_service_account_key(service_account_key, DocumentSource.GOOGLE_DRIVE)
|
||||
@@ -337,7 +341,7 @@ def upsert_google_service_account_key(
|
||||
|
||||
@router.delete("/admin/connector/google-drive/service-account-key")
|
||||
def delete_google_service_account_key(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
try:
|
||||
@@ -407,7 +411,7 @@ def upsert_gmail_service_account_credential(
|
||||
@router.get("/admin/connector/google-drive/check-auth/{credential_id}")
|
||||
def check_drive_tokens(
|
||||
credential_id: int,
|
||||
user: User = Depends(current_admin_user),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AuthStatus:
|
||||
db_credentials = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
||||
@@ -1794,7 +1798,9 @@ def connector_run_once(
|
||||
|
||||
@router.get("/connector/gmail/authorize/{credential_id}")
|
||||
def gmail_auth(
|
||||
response: Response, credential_id: str, _: User = Depends(current_user)
|
||||
response: Response,
|
||||
credential_id: str,
|
||||
_: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
) -> AuthUrl:
|
||||
# set a cookie that we can read in the callback (used for `verify_csrf`)
|
||||
response.set_cookie(
|
||||
@@ -1808,7 +1814,9 @@ def gmail_auth(
|
||||
|
||||
@router.get("/connector/google-drive/authorize/{credential_id}")
|
||||
def google_drive_auth(
|
||||
response: Response, credential_id: str, _: User = Depends(current_user)
|
||||
response: Response,
|
||||
credential_id: str,
|
||||
_: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
) -> AuthUrl:
|
||||
# set a cookie that we can read in the callback (used for `verify_csrf`)
|
||||
response.set_cookie(
|
||||
@@ -1826,7 +1834,7 @@ def google_drive_auth(
|
||||
def gmail_callback(
|
||||
request: Request,
|
||||
callback: GmailCallback = Depends(),
|
||||
user: User = Depends(current_user),
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
credential_id_cookie = request.cookies.get(_GMAIL_CREDENTIAL_ID_COOKIE_NAME)
|
||||
@@ -1856,7 +1864,7 @@ def gmail_callback(
|
||||
def google_drive_callback(
|
||||
request: Request,
|
||||
callback: GDriveCallback = Depends(),
|
||||
user: User = Depends(current_user),
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
credential_id_cookie = request.cookies.get(_GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME)
|
||||
@@ -1885,7 +1893,7 @@ def google_drive_callback(
|
||||
|
||||
@router.get("/connector", tags=PUBLIC_API_TAGS)
|
||||
def get_connectors(
|
||||
_: User = Depends(current_user),
|
||||
_: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[ConnectorSnapshot]:
|
||||
connectors = fetch_connectors(db_session)
|
||||
@@ -1900,7 +1908,7 @@ def get_connectors(
|
||||
|
||||
@router.get("/indexed-sources", tags=PUBLIC_API_TAGS)
|
||||
def get_indexed_sources(
|
||||
_: User = Depends(current_user),
|
||||
_: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> IndexedSourcesResponse:
|
||||
sources = sorted(
|
||||
@@ -1912,7 +1920,7 @@ def get_indexed_sources(
|
||||
@router.get("/connector/{connector_id}", tags=PUBLIC_API_TAGS)
|
||||
def get_connector_by_id(
|
||||
connector_id: int,
|
||||
_: User = Depends(current_user),
|
||||
_: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ConnectorSnapshot | StatusResponse[int]:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
@@ -1941,7 +1949,7 @@ def get_connector_by_id(
|
||||
@router.post("/connector-request")
|
||||
def submit_connector_request(
|
||||
request_data: ConnectorRequestSubmission,
|
||||
user: User = Depends(current_user),
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
) -> StatusResponse:
|
||||
"""
|
||||
Submit a connector request for Cloud deployments.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user