mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-22 01:46:47 +00:00
Compare commits
130 Commits
new-permis
...
argus
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4feca94f08 | ||
|
|
e9538c2a8f | ||
|
|
bc39465a5c | ||
|
|
48463a353d | ||
|
|
0c61b3bb97 | ||
|
|
0948f58fa0 | ||
|
|
ac448cf3c5 | ||
|
|
f7771847fb | ||
|
|
c61adc6560 | ||
|
|
4b0cb5b9c3 | ||
|
|
21293f6621 | ||
|
|
0f31c490fa | ||
|
|
c9a4a6e42b | ||
|
|
558c9df3c7 | ||
|
|
30003036d3 | ||
|
|
4b2f18c239 | ||
|
|
4290b097f5 | ||
|
|
b0f621a08b | ||
|
|
112edf41c5 | ||
|
|
74eb1d7212 | ||
|
|
e62d592b11 | ||
|
|
57a0d25321 | ||
|
|
887f79d7a5 | ||
|
|
65fd1c3ec8 | ||
|
|
6e3ee287b9 | ||
|
|
dee0b7867e | ||
|
|
77beb8044e | ||
|
|
750d3ac4ed | ||
|
|
6c02087ba4 | ||
|
|
0425283ed0 | ||
|
|
da97a57c58 | ||
|
|
8087ddb97c | ||
|
|
d9d5943dc4 | ||
|
|
97a7fa6f7f | ||
|
|
8027e62446 | ||
|
|
571e860d4f | ||
|
|
89b91ac384 | ||
|
|
069b1f3efb | ||
|
|
ef2fffcd6e | ||
|
|
925be18424 | ||
|
|
38fffc8ad8 | ||
|
|
3e9e2f08d5 | ||
|
|
243d93ecd8 | ||
|
|
4effe77225 | ||
|
|
ef2df458a3 | ||
|
|
d3000da3d0 | ||
|
|
a5c703f9ca | ||
|
|
d10c901c43 | ||
|
|
f1ac555c57 | ||
|
|
ed52384c21 | ||
|
|
cb10376a0d | ||
|
|
5a25b70b9c | ||
|
|
8cbc37f281 | ||
|
|
9d78f71f23 | ||
|
|
fbf3179d84 | ||
|
|
779470b553 | ||
|
|
151e189898 | ||
|
|
72e08f81a4 | ||
|
|
65792a8ad8 | ||
|
|
497b700b3d | ||
|
|
c3ed2135f1 | ||
|
|
a969d56818 | ||
|
|
a31d862f48 | ||
|
|
a4e6d4cf43 | ||
|
|
1e6f94e00d | ||
|
|
a769b87a9d | ||
|
|
278fc7e9b1 | ||
|
|
eb34df470f | ||
|
|
9d1785273f | ||
|
|
ef69b17d26 | ||
|
|
787c961802 | ||
|
|
62bc4fa2a3 | ||
|
|
bb1c44daff | ||
|
|
f26ecafb51 | ||
|
|
9fdb425c0d | ||
|
|
47e20e89c5 | ||
|
|
8b28c127f2 | ||
|
|
9a861a71ad | ||
|
|
b4bc12f6dc | ||
|
|
9af9148ca7 | ||
|
|
8a517c4f10 | ||
|
|
6959d851ea | ||
|
|
6a2550fc2d | ||
|
|
b1cc0c2bf9 | ||
|
|
c28b17064b | ||
|
|
4dab92ab52 | ||
|
|
7eb68d61b0 | ||
|
|
8c7810d688 | ||
|
|
712e6fdf5e | ||
|
|
f1a9a3b41e | ||
|
|
c3405fb6bf | ||
|
|
3e962935f4 | ||
|
|
0aa1aa7ea0 | ||
|
|
771d2cf101 | ||
|
|
7ec50280ed | ||
|
|
5b2ba5caeb | ||
|
|
4a96ef13d7 | ||
|
|
822b0c99be | ||
|
|
bcf2851a85 | ||
|
|
a5a59bd8f0 | ||
|
|
32d2e7985a | ||
|
|
c4f8d5370b | ||
|
|
9e434f6a5a | ||
|
|
67dc819319 | ||
|
|
2d12274050 | ||
|
|
c727ba13ee | ||
|
|
6193dd5326 | ||
|
|
387a7d1cea | ||
|
|
869578eeed | ||
|
|
e68648ab74 | ||
|
|
da01002099 | ||
|
|
f5d66f389c | ||
|
|
82d89f78c6 | ||
|
|
6f49c5e32c | ||
|
|
41f2bd2f19 | ||
|
|
bfa2f672f9 | ||
|
|
a823c3ead1 | ||
|
|
bd7d378a9a | ||
|
|
dcec0c8ef3 | ||
|
|
6456b51dcf | ||
|
|
7cfe27e31e | ||
|
|
3c5f77f5a4 | ||
|
|
ab4d1dce01 | ||
|
|
80c928eb58 | ||
|
|
77528876b1 | ||
|
|
3bf53495f3 | ||
|
|
e4cfcda0bf | ||
|
|
475e8f6cdc | ||
|
|
945272c1d2 | ||
|
|
185b057483 |
62
.devcontainer/Dockerfile
Normal file
62
.devcontainer/Dockerfile
Normal file
@@ -0,0 +1,62 @@
|
||||
FROM ubuntu:26.04@sha256:cc925e589b7543b910fea57a240468940003fbfc0515245a495dd0ad8fe7cef1
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
fd-find \
|
||||
fzf \
|
||||
git \
|
||||
jq \
|
||||
less \
|
||||
make \
|
||||
neovim \
|
||||
openssh-client \
|
||||
python3-venv \
|
||||
ripgrep \
|
||||
sudo \
|
||||
ca-certificates \
|
||||
iptables \
|
||||
ipset \
|
||||
iproute2 \
|
||||
dnsutils \
|
||||
unzip \
|
||||
wget \
|
||||
zsh \
|
||||
&& curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \
|
||||
&& apt-get install -y nodejs \
|
||||
&& install -m 0755 -d /etc/apt/keyrings \
|
||||
&& curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg -o /etc/apt/keyrings/githubcli-archive-keyring.gpg \
|
||||
&& chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
|
||||
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" > /etc/apt/sources.list.d/github-cli.list \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends gh \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# fd-find installs as fdfind on Debian/Ubuntu — symlink to fd
|
||||
RUN ln -sf "$(which fdfind)" /usr/local/bin/fd
|
||||
|
||||
# Install uv (Python package manager)
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/
|
||||
|
||||
# Create non-root dev user with passwordless sudo
|
||||
RUN useradd -m -s /bin/zsh dev && \
|
||||
echo "dev ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/dev && \
|
||||
chmod 0440 /etc/sudoers.d/dev
|
||||
|
||||
ENV DEVCONTAINER=true
|
||||
|
||||
RUN mkdir -p /workspace && \
|
||||
chown -R dev:dev /workspace
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
# Install Claude Code
|
||||
ARG CLAUDE_CODE_VERSION=latest
|
||||
RUN npm install -g @anthropic-ai/claude-code@${CLAUDE_CODE_VERSION}
|
||||
|
||||
# Configure zsh — source the repo-local zshrc so shell customization
|
||||
# doesn't require an image rebuild.
|
||||
RUN chsh -s /bin/zsh root && \
|
||||
for rc in /root/.zshrc /home/dev/.zshrc; do \
|
||||
echo '[ -f /workspace/.devcontainer/zshrc ] && . /workspace/.devcontainer/zshrc' >> "$rc"; \
|
||||
done && \
|
||||
chown dev:dev /home/dev/.zshrc
|
||||
86
.devcontainer/README.md
Normal file
86
.devcontainer/README.md
Normal file
@@ -0,0 +1,86 @@
|
||||
# Onyx Dev Container
|
||||
|
||||
A containerized development environment for working on Onyx.
|
||||
|
||||
## What's included
|
||||
|
||||
- Ubuntu 26.04 base image
|
||||
- Node.js 20, uv, Claude Code
|
||||
- GitHub CLI (`gh`)
|
||||
- Neovim, ripgrep, fd, fzf, jq, make, wget, unzip
|
||||
- Zsh as default shell (sources host `~/.zshrc` if available)
|
||||
- Python venv auto-activation
|
||||
- Network firewall (default-deny, whitelists npm, GitHub, Anthropic APIs, Sentry, and VS Code update servers)
|
||||
|
||||
## Usage
|
||||
|
||||
### CLI (`ods dev`)
|
||||
|
||||
The [`ods` devtools CLI](../tools/ods/README.md) provides workspace-aware wrappers
|
||||
for all devcontainer operations (also available as `ods dc`):
|
||||
|
||||
```bash
|
||||
# Start the container
|
||||
ods dev up
|
||||
|
||||
# Open a shell
|
||||
ods dev into
|
||||
|
||||
# Run a command
|
||||
ods dev exec npm test
|
||||
|
||||
# Stop the container
|
||||
ods dev stop
|
||||
```
|
||||
|
||||
## Restarting the container
|
||||
|
||||
```bash
|
||||
# Restart the container
|
||||
ods dev restart
|
||||
|
||||
# Pull the latest published image and recreate
|
||||
ods dev rebuild
|
||||
```
|
||||
|
||||
## Image
|
||||
|
||||
The devcontainer uses a prebuilt image published to `onyxdotapp/onyx-devcontainer`.
|
||||
The tag is pinned in `devcontainer.json` — no local build is required.
|
||||
|
||||
To build the image locally (e.g. while iterating on the Dockerfile):
|
||||
|
||||
```bash
|
||||
docker buildx bake devcontainer
|
||||
```
|
||||
|
||||
The `devcontainer` target is defined in `docker-bake.hcl` at the repo root.
|
||||
|
||||
## User & permissions
|
||||
|
||||
The container runs as the `dev` user by default (`remoteUser` in devcontainer.json).
|
||||
An init script (`init-dev-user.sh`) runs at container start to ensure the active
|
||||
user has read/write access to the bind-mounted workspace:
|
||||
|
||||
- **Standard Docker** — `dev`'s UID/GID is remapped to match the workspace owner,
|
||||
so file permissions work seamlessly.
|
||||
- **Rootless Docker** — The workspace appears as root-owned (UID 0) inside the
|
||||
container due to user-namespace mapping. `ods dev up` auto-detects rootless Docker
|
||||
and sets `DEVCONTAINER_REMOTE_USER=root` so the container runs as root — which
|
||||
maps back to your host user via the user namespace. New files are owned by your
|
||||
host UID and no ACL workarounds are needed.
|
||||
|
||||
To override the auto-detection, set `DEVCONTAINER_REMOTE_USER` before running
|
||||
`ods dev up`.
|
||||
|
||||
## Firewall
|
||||
|
||||
The container starts with a default-deny firewall (`init-firewall.sh`) that only allows outbound traffic to:
|
||||
|
||||
- npm registry
|
||||
- GitHub
|
||||
- Anthropic API
|
||||
- Sentry
|
||||
- VS Code update servers
|
||||
|
||||
This requires the `NET_ADMIN` and `NET_RAW` capabilities, which are added via `runArgs` in `devcontainer.json`.
|
||||
23
.devcontainer/devcontainer.json
Normal file
23
.devcontainer/devcontainer.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"name": "Onyx Dev Sandbox",
|
||||
"image": "onyxdotapp/onyx-devcontainer@sha256:12184169c5bcc9cca0388286d5ffe504b569bc9c37bfa631b76ee8eee2064055",
|
||||
"runArgs": ["--cap-add=NET_ADMIN", "--cap-add=NET_RAW"],
|
||||
"mounts": [
|
||||
"source=${localEnv:HOME}/.claude,target=/home/dev/.claude,type=bind",
|
||||
"source=${localEnv:HOME}/.claude.json,target=/home/dev/.claude.json,type=bind",
|
||||
"source=${localEnv:HOME}/.zshrc,target=/home/dev/.zshrc.host,type=bind,readonly",
|
||||
"source=${localEnv:HOME}/.gitconfig,target=/home/dev/.gitconfig,type=bind,readonly",
|
||||
"source=${localEnv:HOME}/.config/nvim,target=/home/dev/.config/nvim,type=bind,readonly",
|
||||
"source=onyx-devcontainer-cache,target=/home/dev/.cache,type=volume",
|
||||
"source=onyx-devcontainer-local,target=/home/dev/.local,type=volume"
|
||||
],
|
||||
"containerEnv": {
|
||||
"SSH_AUTH_SOCK": "/tmp/ssh-agent.sock"
|
||||
},
|
||||
"remoteUser": "${localEnv:DEVCONTAINER_REMOTE_USER:dev}",
|
||||
"updateRemoteUserUID": false,
|
||||
"workspaceMount": "source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=delegated",
|
||||
"workspaceFolder": "/workspace",
|
||||
"postStartCommand": "sudo bash /workspace/.devcontainer/init-dev-user.sh && sudo bash /workspace/.devcontainer/init-firewall.sh",
|
||||
"waitFor": "postStartCommand"
|
||||
}
|
||||
107
.devcontainer/init-dev-user.sh
Normal file
107
.devcontainer/init-dev-user.sh
Normal file
@@ -0,0 +1,107 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Remap the dev user's UID/GID to match the workspace owner so that
|
||||
# bind-mounted files are accessible without running as root.
|
||||
#
|
||||
# Standard Docker: Workspace is owned by the host user's UID (e.g. 1000).
|
||||
# We remap dev to that UID -- fast and seamless.
|
||||
#
|
||||
# Rootless Docker: Workspace appears as root-owned (UID 0) inside the
|
||||
# container due to user-namespace mapping. Requires
|
||||
# DEVCONTAINER_REMOTE_USER=root (set automatically by
|
||||
# ods dev up). Container root IS the host user, so
|
||||
# bind-mounts and named volumes are symlinked into /root.
|
||||
|
||||
WORKSPACE=/workspace
|
||||
TARGET_USER=dev
|
||||
REMOTE_USER="${SUDO_USER:-$TARGET_USER}"
|
||||
|
||||
WS_UID=$(stat -c '%u' "$WORKSPACE")
|
||||
WS_GID=$(stat -c '%g' "$WORKSPACE")
|
||||
DEV_UID=$(id -u "$TARGET_USER")
|
||||
DEV_GID=$(id -g "$TARGET_USER")
|
||||
|
||||
# devcontainer.json bind-mounts and named volumes target /home/dev regardless
|
||||
# of remoteUser. When running as root ($HOME=/root), Phase 1 bridges the gap
|
||||
# with symlinks from ACTIVE_HOME → MOUNT_HOME.
|
||||
MOUNT_HOME=/home/"$TARGET_USER"
|
||||
|
||||
if [ "$REMOTE_USER" = "root" ]; then
|
||||
ACTIVE_HOME="/root"
|
||||
else
|
||||
ACTIVE_HOME="$MOUNT_HOME"
|
||||
fi
|
||||
|
||||
# ── Phase 1: home directory setup ───────────────────────────────────
|
||||
|
||||
# ~/.local and ~/.cache are named Docker volumes mounted under MOUNT_HOME.
|
||||
mkdir -p "$MOUNT_HOME"/.local/state "$MOUNT_HOME"/.local/share
|
||||
|
||||
# When running as root, symlink bind-mounts and named volumes into /root
|
||||
# so that $HOME-relative tools (Claude Code, git, etc.) find them.
|
||||
if [ "$ACTIVE_HOME" != "$MOUNT_HOME" ]; then
|
||||
for item in .claude .cache .local; do
|
||||
[ -d "$MOUNT_HOME/$item" ] || continue
|
||||
if [ -e "$ACTIVE_HOME/$item" ] && [ ! -L "$ACTIVE_HOME/$item" ]; then
|
||||
echo "warning: replacing $ACTIVE_HOME/$item with symlink to $MOUNT_HOME/$item" >&2
|
||||
rm -rf "$ACTIVE_HOME/$item"
|
||||
fi
|
||||
ln -sfn "$MOUNT_HOME/$item" "$ACTIVE_HOME/$item"
|
||||
done
|
||||
# Symlink files (not directories).
|
||||
for file in .claude.json .gitconfig .zshrc.host; do
|
||||
[ -f "$MOUNT_HOME/$file" ] && ln -sf "$MOUNT_HOME/$file" "$ACTIVE_HOME/$file"
|
||||
done
|
||||
|
||||
# Nested mount: .config/nvim
|
||||
if [ -d "$MOUNT_HOME/.config/nvim" ]; then
|
||||
mkdir -p "$ACTIVE_HOME/.config"
|
||||
if [ -e "$ACTIVE_HOME/.config/nvim" ] && [ ! -L "$ACTIVE_HOME/.config/nvim" ]; then
|
||||
echo "warning: replacing $ACTIVE_HOME/.config/nvim with symlink" >&2
|
||||
rm -rf "$ACTIVE_HOME/.config/nvim"
|
||||
fi
|
||||
ln -sfn "$MOUNT_HOME/.config/nvim" "$ACTIVE_HOME/.config/nvim"
|
||||
fi
|
||||
fi
|
||||
|
||||
# ── Phase 2: workspace access ───────────────────────────────────────
|
||||
|
||||
# Root always has workspace access; Phase 1 handled home setup.
|
||||
if [ "$REMOTE_USER" = "root" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Already matching -- nothing to do.
|
||||
if [ "$WS_UID" = "$DEV_UID" ] && [ "$WS_GID" = "$DEV_GID" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [ "$WS_UID" != "0" ]; then
|
||||
# ── Standard Docker ──────────────────────────────────────────────
|
||||
# Workspace is owned by a non-root UID (the host user).
|
||||
# Remap dev's UID/GID to match.
|
||||
if [ "$DEV_GID" != "$WS_GID" ]; then
|
||||
if ! groupmod -g "$WS_GID" "$TARGET_USER" 2>&1; then
|
||||
echo "warning: failed to remap $TARGET_USER GID to $WS_GID" >&2
|
||||
fi
|
||||
fi
|
||||
if [ "$DEV_UID" != "$WS_UID" ]; then
|
||||
if ! usermod -u "$WS_UID" -g "$WS_GID" "$TARGET_USER" 2>&1; then
|
||||
echo "warning: failed to remap $TARGET_USER UID to $WS_UID" >&2
|
||||
fi
|
||||
fi
|
||||
if ! chown -R "$TARGET_USER":"$TARGET_USER" "$MOUNT_HOME" 2>&1; then
|
||||
echo "warning: failed to chown $MOUNT_HOME" >&2
|
||||
fi
|
||||
else
|
||||
# ── Rootless Docker ──────────────────────────────────────────────
|
||||
# Workspace is root-owned (UID 0) due to user-namespace mapping.
|
||||
# The supported path is remoteUser=root (set DEVCONTAINER_REMOTE_USER=root),
|
||||
# which is handled above. If we reach here, the user is running as dev
|
||||
# under rootless Docker without the override.
|
||||
echo "error: rootless Docker detected but remoteUser is not root." >&2
|
||||
echo " Set DEVCONTAINER_REMOTE_USER=root before starting the container," >&2
|
||||
echo " or use 'ods dev up' which sets it automatically." >&2
|
||||
exit 1
|
||||
fi
|
||||
105
.devcontainer/init-firewall.sh
Executable file
105
.devcontainer/init-firewall.sh
Executable file
@@ -0,0 +1,105 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
echo "Setting up firewall..."
|
||||
|
||||
# Preserve docker dns resolution
|
||||
DOCKER_DNS_RULES=$(iptables-save | grep -E "^-A.*-d 127.0.0.11/32" || true)
|
||||
|
||||
# Flush all rules
|
||||
iptables -t nat -F
|
||||
iptables -t nat -X
|
||||
iptables -t mangle -F
|
||||
iptables -t mangle -X
|
||||
iptables -F
|
||||
iptables -X
|
||||
|
||||
# Restore docker dns rules
|
||||
if [ -n "$DOCKER_DNS_RULES" ]; then
|
||||
echo "$DOCKER_DNS_RULES" | iptables-restore -n
|
||||
fi
|
||||
|
||||
# Create ipset for allowed destinations
|
||||
ipset create allowed-domains hash:net || true
|
||||
ipset flush allowed-domains
|
||||
|
||||
# Fetch GitHub IP ranges (IPv4 only -- ipset hash:net and iptables are IPv4)
|
||||
GITHUB_IPS=$(curl -s https://api.github.com/meta | jq -r '.api[]' 2>/dev/null | grep -v ':' || echo "")
|
||||
for ip in $GITHUB_IPS; do
|
||||
if ! ipset add allowed-domains "$ip" -exist 2>&1; then
|
||||
echo "warning: failed to add GitHub IP $ip to allowlist" >&2
|
||||
fi
|
||||
done
|
||||
|
||||
# Resolve allowed domains
|
||||
ALLOWED_DOMAINS=(
|
||||
"registry.npmjs.org"
|
||||
"api.anthropic.com"
|
||||
"api-staging.anthropic.com"
|
||||
"files.anthropic.com"
|
||||
"sentry.io"
|
||||
"update.code.visualstudio.com"
|
||||
"pypi.org"
|
||||
"files.pythonhosted.org"
|
||||
"go.dev"
|
||||
"storage.googleapis.com"
|
||||
"static.rust-lang.org"
|
||||
)
|
||||
|
||||
for domain in "${ALLOWED_DOMAINS[@]}"; do
|
||||
IPS=$(getent ahosts "$domain" 2>/dev/null | awk '{print $1}' | grep -v ':' | sort -u || echo "")
|
||||
for ip in $IPS; do
|
||||
if ! ipset add allowed-domains "$ip/32" -exist 2>&1; then
|
||||
echo "warning: failed to add $domain ($ip) to allowlist" >&2
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
# Allow traffic to the Docker gateway so the container can reach host services
|
||||
# (e.g. the Onyx stack at localhost:3000, localhost:8080, etc.)
|
||||
DOCKER_GATEWAY=$(ip -4 route show default | awk '{print $3}')
|
||||
if [ -n "$DOCKER_GATEWAY" ]; then
|
||||
if ! ipset add allowed-domains "$DOCKER_GATEWAY/32" -exist 2>&1; then
|
||||
echo "warning: failed to add Docker gateway $DOCKER_GATEWAY to allowlist" >&2
|
||||
fi
|
||||
fi
|
||||
|
||||
# Set default policies to DROP
|
||||
iptables -P FORWARD DROP
|
||||
iptables -P INPUT DROP
|
||||
iptables -P OUTPUT DROP
|
||||
|
||||
# Allow established connections
|
||||
iptables -A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
|
||||
iptables -A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
|
||||
|
||||
# Allow loopback
|
||||
iptables -A INPUT -i lo -j ACCEPT
|
||||
iptables -A OUTPUT -o lo -j ACCEPT
|
||||
|
||||
# Allow DNS
|
||||
iptables -A OUTPUT -p udp --dport 53 -j ACCEPT
|
||||
iptables -A OUTPUT -p tcp --dport 53 -j ACCEPT
|
||||
|
||||
# Allow outbound to allowed destinations
|
||||
iptables -A OUTPUT -m set --match-set allowed-domains dst -j ACCEPT
|
||||
|
||||
# Reject unauthorized outbound
|
||||
iptables -A OUTPUT -j REJECT --reject-with icmp-host-unreachable
|
||||
|
||||
# Validate firewall configuration
|
||||
echo "Validating firewall configuration..."
|
||||
|
||||
BLOCKED_SITES=("example.com" "google.com" "facebook.com")
|
||||
for site in "${BLOCKED_SITES[@]}"; do
|
||||
if timeout 2 ping -c 1 "$site" &>/dev/null; then
|
||||
echo "Warning: $site is still reachable"
|
||||
fi
|
||||
done
|
||||
|
||||
if ! timeout 5 curl -s https://api.github.com/meta > /dev/null; then
|
||||
echo "Warning: GitHub API is not accessible"
|
||||
fi
|
||||
|
||||
echo "Firewall setup complete"
|
||||
10
.devcontainer/zshrc
Normal file
10
.devcontainer/zshrc
Normal file
@@ -0,0 +1,10 @@
|
||||
# Devcontainer zshrc — sourced automatically for both root and dev users.
|
||||
# Edit this file to customize the shell without rebuilding the image.
|
||||
|
||||
# Auto-activate Python venv
|
||||
if [ -f /workspace/.venv/bin/activate ]; then
|
||||
. /workspace/.venv/bin/activate
|
||||
fi
|
||||
|
||||
# Source host zshrc if bind-mounted
|
||||
[ -f ~/.zshrc.host ] && . ~/.zshrc.host
|
||||
8
.github/workflows/deployment.yml
vendored
8
.github/workflows/deployment.yml
vendored
@@ -13,7 +13,7 @@ permissions:
|
||||
id-token: write # zizmor: ignore[excessive-permissions]
|
||||
|
||||
env:
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') || github.ref_name == 'edge' }}
|
||||
|
||||
jobs:
|
||||
# Determine which components to build based on the tag
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
fetch-tags: true
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
version: "0.9.9"
|
||||
enable-cache: false
|
||||
@@ -156,7 +156,7 @@ jobs:
|
||||
check-version-tag:
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
if: ${{ !startsWith(github.ref_name, 'nightly-latest') && github.event_name != 'workflow_dispatch' }}
|
||||
if: ${{ !startsWith(github.ref_name, 'nightly-latest') && github.ref_name != 'edge' && github.event_name != 'workflow_dispatch' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
@@ -165,7 +165,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
version: "0.9.9"
|
||||
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
|
||||
|
||||
@@ -114,7 +114,7 @@ jobs:
|
||||
ref: main
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
4
.github/workflows/pr-playwright-tests.yml
vendored
4
.github/workflows/pr-playwright-tests.yml
vendored
@@ -471,7 +471,7 @@ jobs:
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: always()
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
@@ -710,7 +710,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Download visual diff summaries
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c
|
||||
with:
|
||||
pattern: screenshot-diff-summary-*
|
||||
path: summaries/
|
||||
|
||||
2
.github/workflows/pr-quality-checks.yml
vendored
2
.github/workflows/pr-quality-checks.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
- uses: j178/prek-action@0bb87d7f00b0c99306c8bcb8b8beba1eb581c037 # ratchet:j178/prek-action@v1
|
||||
- uses: j178/prek-action@cbc2f23eb5539cf20d82d1aabd0d0ecbcc56f4e3
|
||||
with:
|
||||
prek-version: '0.3.4'
|
||||
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
|
||||
|
||||
2
.github/workflows/release-cli.yml
vendored
2
.github/workflows/release-cli.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/release-devtools.yml
vendored
2
.github/workflows/release-devtools.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/zizmor.yml
vendored
2
.github/workflows/zizmor.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
@@ -1,64 +1,57 @@
|
||||
{
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 3,
|
||||
"statusCheck": true,
|
||||
"commentTypes": [
|
||||
"logic",
|
||||
"syntax",
|
||||
"style"
|
||||
],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": [
|
||||
"dependabot[bot]",
|
||||
"renovate[bot]"
|
||||
],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 3,
|
||||
"statusCheck": true,
|
||||
"commentTypes": ["logic", "syntax", "style"],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": ["dependabot[bot]", "renovate[bot]"],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": false,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"rules": [
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
|
||||
},
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": true,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
|
||||
},
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"rules": [
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
]
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ repos:
|
||||
rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
|
||||
hooks:
|
||||
- id: uv-sync
|
||||
args: ["--locked", "--all-extras"]
|
||||
- id: uv-lock
|
||||
- id: uv-export
|
||||
name: uv-export default.txt
|
||||
@@ -18,7 +17,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"--group",
|
||||
"backend",
|
||||
"-o",
|
||||
"backend/requirements/default.txt",
|
||||
@@ -31,7 +30,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"--group",
|
||||
"dev",
|
||||
"-o",
|
||||
"backend/requirements/dev.txt",
|
||||
@@ -44,7 +43,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"--group",
|
||||
"ee",
|
||||
"-o",
|
||||
"backend/requirements/ee.txt",
|
||||
@@ -57,7 +56,7 @@ repos:
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"--group",
|
||||
"model_server",
|
||||
"-o",
|
||||
"backend/requirements/model_server.txt",
|
||||
|
||||
3
.vscode/launch.json
vendored
3
.vscode/launch.json
vendored
@@ -531,8 +531,7 @@
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"sync",
|
||||
"--all-extras"
|
||||
"sync"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
|
||||
@@ -49,12 +49,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
|
||||
4. **Light Worker** (`light`)
|
||||
- Handles lightweight, fast operations
|
||||
- Tasks: vespa operations, document permissions sync, external group sync
|
||||
- Tasks: vespa metadata sync, connector deletion, doc permissions upsert, checkpoint cleanup, index attempt cleanup
|
||||
- Higher concurrency for quick tasks
|
||||
|
||||
5. **Heavy Worker** (`heavy`)
|
||||
- Handles resource-intensive operations
|
||||
- Primary task: document pruning operations
|
||||
- Tasks: connector pruning, document permissions sync, external group sync, CSV generation
|
||||
- Runs with 4 threads concurrency
|
||||
|
||||
6. **KG Processing Worker** (`kg_processing`)
|
||||
|
||||
@@ -117,7 +117,7 @@ If using PowerShell, the command slightly differs:
|
||||
Install the required Python dependencies:
|
||||
|
||||
```bash
|
||||
uv sync --all-extras
|
||||
uv sync
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector):
|
||||
|
||||
@@ -208,7 +208,7 @@ def do_run_migrations(
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
@@ -380,7 +380,7 @@ def run_migrations_offline() -> None:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
@@ -421,7 +421,7 @@ def run_migrations_offline() -> None:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
@@ -464,7 +464,7 @@ def run_migrations_online() -> None:
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
|
||||
@@ -25,7 +25,7 @@ def upgrade() -> None:
|
||||
|
||||
# Use batch mode to modify the enum type
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column( # type: ignore[attr-defined]
|
||||
batch_op.alter_column(
|
||||
"role",
|
||||
type_=sa.Enum(
|
||||
"BASIC",
|
||||
@@ -71,7 +71,7 @@ def downgrade() -> None:
|
||||
op.drop_column("user__user_group", "is_curator")
|
||||
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column( # type: ignore[attr-defined]
|
||||
batch_op.alter_column(
|
||||
"role",
|
||||
type_=sa.Enum(
|
||||
"BASIC", "ADMIN", name="userrole", native_enum=False, length=20
|
||||
|
||||
@@ -0,0 +1,499 @@
|
||||
"""add proposal review tables
|
||||
|
||||
Revision ID: 61ea78857c97
|
||||
Revises: d129f37b3d87
|
||||
Create Date: 2026-04-09 10:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
import fastapi_users_db_sqlalchemy
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "61ea78857c97"
|
||||
down_revision = "d129f37b3d87"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# -- proposal_review_ruleset --
|
||||
op.create_table(
|
||||
"proposal_review_ruleset",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("tenant_id", sa.Text(), nullable=False),
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"is_default",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"is_active",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("true"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"created_by",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(["created_by"], ["user.id"]),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_ruleset_tenant_id",
|
||||
"proposal_review_ruleset",
|
||||
["tenant_id"],
|
||||
)
|
||||
|
||||
# -- proposal_review_rule --
|
||||
op.create_table(
|
||||
"proposal_review_rule",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"ruleset_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("category", sa.Text(), nullable=True),
|
||||
sa.Column("rule_type", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"rule_intent",
|
||||
sa.Text(),
|
||||
server_default=sa.text("'CHECK'"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("prompt_template", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"source",
|
||||
sa.Text(),
|
||||
server_default=sa.text("'MANUAL'"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("authority", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"is_hard_stop",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"priority",
|
||||
sa.Integer(),
|
||||
server_default=sa.text("0"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"is_active",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("true"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"refinement_needed",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("refinement_question", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["ruleset_id"],
|
||||
["proposal_review_ruleset.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_rule_ruleset_id",
|
||||
"proposal_review_rule",
|
||||
["ruleset_id"],
|
||||
)
|
||||
|
||||
# -- proposal_review_proposal --
|
||||
# Includes inline proposal-level decision fields (no separate decision table).
|
||||
op.create_table(
|
||||
"proposal_review_proposal",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("document_id", sa.Text(), nullable=False),
|
||||
sa.Column("tenant_id", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Text(),
|
||||
server_default=sa.text("'PENDING'"),
|
||||
nullable=False,
|
||||
),
|
||||
# Inline proposal-level decision fields
|
||||
sa.Column("decision_notes", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"decision_officer_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("decision_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"jira_synced",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("jira_synced_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(["decision_officer_id"], ["user.id"]),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("document_id", "tenant_id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_proposal_tenant_id",
|
||||
"proposal_review_proposal",
|
||||
["tenant_id"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_proposal_document_id",
|
||||
"proposal_review_proposal",
|
||||
["document_id"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_proposal_status",
|
||||
"proposal_review_proposal",
|
||||
["status"],
|
||||
)
|
||||
|
||||
# -- proposal_review_run --
|
||||
op.create_table(
|
||||
"proposal_review_run",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"proposal_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"ruleset_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"triggered_by",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Text(),
|
||||
server_default=sa.text("'PENDING'"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("total_rules", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"completed_rules",
|
||||
sa.Integer(),
|
||||
server_default=sa.text("0"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["proposal_id"],
|
||||
["proposal_review_proposal.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["ruleset_id"],
|
||||
["proposal_review_ruleset.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(["triggered_by"], ["user.id"]),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_run_proposal_id",
|
||||
"proposal_review_run",
|
||||
["proposal_id"],
|
||||
)
|
||||
|
||||
# -- proposal_review_finding --
|
||||
# Includes inline per-finding decision fields (no separate decision table).
|
||||
op.create_table(
|
||||
"proposal_review_finding",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"proposal_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"rule_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"review_run_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("verdict", sa.Text(), nullable=False),
|
||||
sa.Column("confidence", sa.Text(), nullable=True),
|
||||
sa.Column("evidence", sa.Text(), nullable=True),
|
||||
sa.Column("explanation", sa.Text(), nullable=True),
|
||||
sa.Column("suggested_action", sa.Text(), nullable=True),
|
||||
sa.Column("llm_model", sa.Text(), nullable=True),
|
||||
sa.Column("llm_tokens_used", sa.Integer(), nullable=True),
|
||||
# Inline per-finding decision fields
|
||||
sa.Column("decision_action", sa.Text(), nullable=True),
|
||||
sa.Column("decision_notes", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"decision_officer_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("decided_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["proposal_id"],
|
||||
["proposal_review_proposal.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["rule_id"],
|
||||
["proposal_review_rule.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["review_run_id"],
|
||||
["proposal_review_run.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(["decision_officer_id"], ["user.id"]),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_finding_proposal_id",
|
||||
"proposal_review_finding",
|
||||
["proposal_id"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_finding_review_run_id",
|
||||
"proposal_review_finding",
|
||||
["review_run_id"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_finding_rule_id",
|
||||
"proposal_review_finding",
|
||||
["rule_id"],
|
||||
)
|
||||
|
||||
# -- proposal_review_document --
|
||||
op.create_table(
|
||||
"proposal_review_document",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"proposal_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("file_name", sa.Text(), nullable=False),
|
||||
sa.Column("file_type", sa.Text(), nullable=True),
|
||||
sa.Column("file_store_id", sa.Text(), nullable=True),
|
||||
sa.Column("extracted_text", sa.Text(), nullable=True),
|
||||
sa.Column("document_role", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"uploaded_by",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["proposal_id"],
|
||||
["proposal_review_proposal.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(["uploaded_by"], ["user.id"]),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_document_proposal_id",
|
||||
"proposal_review_document",
|
||||
["proposal_id"],
|
||||
)
|
||||
|
||||
# -- proposal_review_import_job --
|
||||
op.create_table(
|
||||
"proposal_review_import_job",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"ruleset_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("tenant_id", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Text(),
|
||||
server_default=sa.text("'PENDING'"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("source_filename", sa.Text(), nullable=False),
|
||||
sa.Column("extracted_text", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"rules_created",
|
||||
sa.Integer(),
|
||||
server_default=sa.text("0"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("error_message", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["ruleset_id"],
|
||||
["proposal_review_ruleset.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_import_job_ruleset_id",
|
||||
"proposal_review_import_job",
|
||||
["ruleset_id"],
|
||||
)
|
||||
|
||||
# -- proposal_review_config --
|
||||
op.create_table(
|
||||
"proposal_review_config",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("tenant_id", sa.Text(), nullable=False, unique=True),
|
||||
sa.Column("jira_connector_id", sa.Integer(), nullable=True),
|
||||
sa.Column("jira_project_key", sa.Text(), nullable=True),
|
||||
sa.Column("field_mapping", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("jira_writeback", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("review_model", sa.Text(), nullable=True),
|
||||
sa.Column("import_model", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("proposal_review_import_job")
|
||||
op.drop_table("proposal_review_config")
|
||||
op.drop_table("proposal_review_document")
|
||||
op.drop_table("proposal_review_finding")
|
||||
op.drop_table("proposal_review_run")
|
||||
op.drop_table("proposal_review_proposal")
|
||||
op.drop_table("proposal_review_rule")
|
||||
op.drop_table("proposal_review_ruleset")
|
||||
@@ -63,7 +63,7 @@ def upgrade() -> None:
|
||||
"time_created",
|
||||
existing_type=postgresql.TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
existing_server_default=sa.text("now()"), # type: ignore
|
||||
existing_server_default=sa.text("now()"),
|
||||
)
|
||||
op.alter_column(
|
||||
"index_attempt",
|
||||
@@ -85,7 +85,7 @@ def downgrade() -> None:
|
||||
"time_created",
|
||||
existing_type=postgresql.TIMESTAMP(timezone=True),
|
||||
nullable=True,
|
||||
existing_server_default=sa.text("now()"), # type: ignore
|
||||
existing_server_default=sa.text("now()"),
|
||||
)
|
||||
op.drop_index(op.f("ix_accesstoken_created_at"), table_name="accesstoken")
|
||||
op.drop_table("accesstoken")
|
||||
|
||||
@@ -19,7 +19,7 @@ depends_on: None = None
|
||||
|
||||
def upgrade() -> None:
|
||||
sequence = Sequence("connector_credential_pair_id_seq")
|
||||
op.execute(CreateSequence(sequence)) # type: ignore
|
||||
op.execute(CreateSequence(sequence))
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
"""make user role nullable
|
||||
|
||||
The ``user.role`` column is no longer written or read by application
|
||||
code — admin status is derived from group membership and classification
|
||||
lives on ``user.account_type``. Relax the NOT NULL constraint so inserts
|
||||
that omit the column (which is now the default path after the write-path
|
||||
cleanup) succeed. The column itself is kept as a tombstone for rollback
|
||||
safety and will be dropped in a follow-up migration once the new model
|
||||
has been in production for a release cycle.
|
||||
|
||||
Revision ID: c8e316473aaa
|
||||
Revises: 503883791c39
|
||||
Create Date: 2026-04-14 14:57:29.520645
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c8e316473aaa"
|
||||
down_revision = "503883791c39"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"user",
|
||||
"role",
|
||||
existing_type=sa.VARCHAR(length=14),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Backfill any NULLs written while the column was optional before we
|
||||
# restore the NOT NULL constraint, otherwise the downgrade would fail
|
||||
# against rows inserted after the upgrade.
|
||||
op.execute("UPDATE \"user\" SET role = 'BASIC' WHERE role IS NULL")
|
||||
op.alter_column(
|
||||
"user",
|
||||
"role",
|
||||
existing_type=sa.VARCHAR(length=14),
|
||||
nullable=False,
|
||||
)
|
||||
@@ -0,0 +1,32 @@
|
||||
"""add failed_rules to proposal_review_run
|
||||
|
||||
Revision ID: ce2aa573d445
|
||||
Revises: 61ea78857c97
|
||||
Create Date: 2026-04-14 16:34:57.276707
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ce2aa573d445"
|
||||
down_revision = "61ea78857c97"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"proposal_review_run",
|
||||
sa.Column(
|
||||
"failed_rules",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default=sa.text("0"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("proposal_review_run", "failed_rules")
|
||||
@@ -0,0 +1,28 @@
|
||||
"""add_error_tracking_fields_to_index_attempt_errors
|
||||
|
||||
Revision ID: d129f37b3d87
|
||||
Revises: 503883791c39
|
||||
Create Date: 2026-04-06 19:11:18.261800
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d129f37b3d87"
|
||||
down_revision = "503883791c39"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"index_attempt_errors",
|
||||
sa.Column("error_type", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("index_attempt_errors", "error_type")
|
||||
@@ -49,7 +49,7 @@ def run_migrations_offline() -> None:
|
||||
url = build_connection_string()
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
@@ -61,7 +61,7 @@ def run_migrations_offline() -> None:
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore[arg-type]
|
||||
target_metadata=target_metadata,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
|
||||
@@ -11,14 +11,13 @@ from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import has_permission
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatMessageFeedback
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
|
||||
|
||||
def fetch_query_analytics(
|
||||
@@ -339,7 +338,7 @@ def fetch_assistant_unique_users_total(
|
||||
def user_can_view_assistant_stats(
|
||||
db_session: Session, user: User, assistant_id: int
|
||||
) -> bool:
|
||||
if has_permission(user, Permission.FULL_ADMIN_PANEL_ACCESS):
|
||||
if user.role == UserRole.ADMIN:
|
||||
return True
|
||||
|
||||
# Check if the user created the persona
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicensePayload
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.db.enums import AccountType
|
||||
@@ -107,12 +108,13 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
Get current seat usage directly from database.
|
||||
|
||||
For multi-tenant: counts users in UserTenantMapping for this tenant.
|
||||
For self-hosted: counts all active users (excludes EXT_PERM_USER role
|
||||
and the anonymous system user).
|
||||
For self-hosted: counts all active users.
|
||||
|
||||
TODO: Exclude API key dummy users from seat counting. API keys create
|
||||
users with emails like `__DANSWER_API_KEY_*` that should not count toward
|
||||
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
|
||||
Only human accounts count toward seat limits.
|
||||
SERVICE_ACCOUNT (API key dummy users), EXT_PERM_USER, and the
|
||||
anonymous system user are excluded. BOT (Slack users) ARE counted
|
||||
because they represent real humans and get upgraded to STANDARD
|
||||
when they log in via web.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_count
|
||||
@@ -127,8 +129,9 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
.select_from(User)
|
||||
.where(
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
User.account_type != AccountType.EXT_PERM_USER,
|
||||
User.role != UserRole.EXT_PERM_USER,
|
||||
User.email != ANONYMOUS_USER_EMAIL, # type: ignore
|
||||
User.account_type != AccountType.SERVICE_ACCOUNT,
|
||||
)
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
@@ -1,16 +1,74 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import Row
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import TokenRateLimitScope
|
||||
from onyx.db.models import TokenRateLimit
|
||||
from onyx.db.models import TokenRateLimit__UserGroup
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.server.token_rate_limits.models import TokenRateLimitArgs
|
||||
|
||||
|
||||
def _add_user_filters(stmt: Select, user: User, get_editable: bool = True) -> Select:
|
||||
if user.role == UserRole.ADMIN:
|
||||
return stmt
|
||||
|
||||
# If anonymous user, only show global/public token_rate_limits
|
||||
if user.is_anonymous:
|
||||
where_clause = TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
|
||||
return stmt.where(where_clause)
|
||||
|
||||
stmt = stmt.distinct()
|
||||
TRLimit_UG = aliased(TokenRateLimit__UserGroup)
|
||||
User__UG = aliased(User__UserGroup)
|
||||
|
||||
"""
|
||||
Here we select token_rate_limits by relation:
|
||||
User -> User__UserGroup -> TokenRateLimit__UserGroup ->
|
||||
TokenRateLimit
|
||||
"""
|
||||
stmt = stmt.outerjoin(TRLimit_UG).outerjoin(
|
||||
User__UG,
|
||||
User__UG.user_group_id == TRLimit_UG.user_group_id,
|
||||
)
|
||||
|
||||
"""
|
||||
Filter token_rate_limits by:
|
||||
- if the user is in the user_group that owns the token_rate_limit
|
||||
- if the user is not a global_curator, they must also have a curator relationship
|
||||
to the user_group
|
||||
- if editing is being done, we also filter out token_rate_limits that are owned by groups
|
||||
that the user isn't a curator for
|
||||
- if we are not editing, we show all token_rate_limits in the groups the user curates
|
||||
"""
|
||||
|
||||
where_clause = User__UG.user_id == user.id
|
||||
if user.role == UserRole.CURATOR and get_editable:
|
||||
where_clause &= User__UG.is_curator == True # noqa: E712
|
||||
if get_editable:
|
||||
user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id)
|
||||
if user.role == UserRole.CURATOR:
|
||||
user_groups = user_groups.where(
|
||||
User__UserGroup.is_curator == True # noqa: E712
|
||||
)
|
||||
where_clause &= (
|
||||
~exists()
|
||||
.where(TRLimit_UG.rate_limit_id == TokenRateLimit.id)
|
||||
.where(~TRLimit_UG.user_group_id.in_(user_groups))
|
||||
.correlate(TokenRateLimit)
|
||||
)
|
||||
|
||||
return stmt.where(where_clause)
|
||||
|
||||
|
||||
def fetch_all_user_group_token_rate_limits_by_group(
|
||||
db_session: Session,
|
||||
) -> Sequence[Row[tuple[TokenRateLimit, str]]]:
|
||||
@@ -49,11 +107,13 @@ def insert_user_group_token_rate_limit(
|
||||
return token_limit
|
||||
|
||||
|
||||
def fetch_user_group_token_rate_limits_for_group(
|
||||
def fetch_user_group_token_rate_limits_for_user(
|
||||
db_session: Session,
|
||||
group_id: int,
|
||||
user: User,
|
||||
enabled_only: bool = False,
|
||||
ordered: bool = True,
|
||||
get_editable: bool = True,
|
||||
) -> Sequence[TokenRateLimit]:
|
||||
stmt = (
|
||||
select(TokenRateLimit)
|
||||
@@ -63,6 +123,7 @@ def fetch_user_group_token_rate_limits_for_group(
|
||||
)
|
||||
.where(TokenRateLimit__UserGroup.user_group_id == group_id)
|
||||
)
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
|
||||
if enabled_only:
|
||||
stmt = stmt.where(TokenRateLimit.enabled.is_(True))
|
||||
|
||||
@@ -2,14 +2,17 @@ from collections.abc import Sequence
|
||||
from operator import and_
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
from ee.onyx.server.user_group.models import UserGroupCreate
|
||||
from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
@@ -35,6 +38,7 @@ from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.permissions import recompute_permissions_for_group__no_commit
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import fetch_user_by_id
|
||||
@@ -130,6 +134,73 @@ def _cleanup_document_set__user_group_relationships__no_commit(
|
||||
)
|
||||
|
||||
|
||||
def validate_object_creation_for_user(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
target_group_ids: list[int] | None = None,
|
||||
object_is_public: bool | None = None,
|
||||
object_is_perm_sync: bool | None = None,
|
||||
object_is_owned_by_user: bool = False,
|
||||
object_is_new: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
All users can create/edit permission synced objects if they don't specify a group
|
||||
All admin actions are allowed.
|
||||
Curators and global curators can create public objects.
|
||||
Prevents other non-admins from creating/editing:
|
||||
- public objects
|
||||
- objects with no groups
|
||||
- objects that belong to a group they don't curate
|
||||
"""
|
||||
if object_is_perm_sync and not target_group_ids:
|
||||
return
|
||||
|
||||
# Admins are allowed
|
||||
if user.role == UserRole.ADMIN:
|
||||
return
|
||||
|
||||
# Allow curators and global curators to create public objects
|
||||
# w/o associated groups IF the object is new/owned by them
|
||||
if (
|
||||
object_is_public
|
||||
and user.role in [UserRole.CURATOR, UserRole.GLOBAL_CURATOR]
|
||||
and (object_is_new or object_is_owned_by_user)
|
||||
):
|
||||
return
|
||||
|
||||
if object_is_public and user.role == UserRole.BASIC:
|
||||
detail = "User does not have permission to create public objects"
|
||||
logger.error(detail)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
if not target_group_ids:
|
||||
detail = "Curators must specify 1+ groups"
|
||||
logger.error(detail)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
user_curated_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
# Global curators can curate all groups they are member of
|
||||
only_curator_groups=user.role != UserRole.GLOBAL_CURATOR,
|
||||
)
|
||||
user_curated_group_ids = set([group.id for group in user_curated_groups])
|
||||
target_group_ids_set = set(target_group_ids)
|
||||
if not target_group_ids_set.issubset(user_curated_group_ids):
|
||||
detail = "Curators cannot control groups they don't curate"
|
||||
logger.error(detail)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | None:
|
||||
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
|
||||
return db_session.scalar(stmt)
|
||||
@@ -222,6 +293,7 @@ def fetch_user_groups(
|
||||
def fetch_user_groups_for_user(
|
||||
db_session: Session,
|
||||
user_id: UUID,
|
||||
only_curator_groups: bool = False,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
include_default: bool = True,
|
||||
) -> Sequence[UserGroup]:
|
||||
@@ -231,6 +303,8 @@ def fetch_user_groups_for_user(
|
||||
.join(User, User.id == User__UserGroup.user_id) # type: ignore
|
||||
.where(User.id == user_id) # type: ignore
|
||||
)
|
||||
if only_curator_groups:
|
||||
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
if eager_load_for_snapshot:
|
||||
@@ -456,6 +530,167 @@ def _mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
user_group__cc_pair_relationship.is_current = False
|
||||
|
||||
|
||||
def _validate_curator_status__no_commit(
|
||||
db_session: Session,
|
||||
users: list[User],
|
||||
) -> None:
|
||||
for user in users:
|
||||
# Check if the user is a curator in any of their groups
|
||||
curator_relationships = (
|
||||
db_session.query(User__UserGroup)
|
||||
.filter(
|
||||
User__UserGroup.user_id == user.id,
|
||||
User__UserGroup.is_curator == True, # noqa: E712
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# if the user is a curator in any of their groups, set their role to CURATOR
|
||||
# otherwise, set their role to BASIC only if they were previously a CURATOR
|
||||
if curator_relationships:
|
||||
user.role = UserRole.CURATOR
|
||||
elif user.role == UserRole.CURATOR:
|
||||
user.role = UserRole.BASIC
|
||||
db_session.add(user)
|
||||
|
||||
|
||||
def remove_curator_status__no_commit(db_session: Session, user: User) -> None:
|
||||
stmt = (
|
||||
update(User__UserGroup)
|
||||
.where(User__UserGroup.user_id == user.id)
|
||||
.values(is_curator=False)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
_validate_curator_status__no_commit(db_session, [user])
|
||||
|
||||
|
||||
def _validate_curator_relationship_update_requester(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
user_making_change: User,
|
||||
) -> None:
|
||||
"""
|
||||
This function validates that the user making the change has the necessary permissions
|
||||
to update the curator relationship for the target user in the given user group.
|
||||
"""
|
||||
|
||||
# Admins can update curator relationships for any group
|
||||
if user_making_change.role == UserRole.ADMIN:
|
||||
return
|
||||
|
||||
# check if the user making the change is a curator in the group they are changing the curator relationship for
|
||||
user_making_change_curator_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user_making_change.id,
|
||||
# only check if the user making the change is a curator if they are a curator
|
||||
# otherwise, they are a global_curator and can update the curator relationship
|
||||
# for any group they are a member of
|
||||
only_curator_groups=user_making_change.role == UserRole.CURATOR,
|
||||
)
|
||||
requestor_curator_group_ids = [
|
||||
group.id for group in user_making_change_curator_groups
|
||||
]
|
||||
if user_group_id not in requestor_curator_group_ids:
|
||||
raise ValueError(
|
||||
f"user making change {user_making_change.email} is not a curator,"
|
||||
f" admin, or global_curator for group '{user_group_id}'"
|
||||
)
|
||||
|
||||
|
||||
def _validate_curator_relationship_update_request(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
target_user: User,
|
||||
) -> None:
|
||||
"""
|
||||
This function validates that the curator_relationship_update request itself is valid.
|
||||
"""
|
||||
if target_user.role == UserRole.ADMIN:
|
||||
raise ValueError(
|
||||
f"User '{target_user.email}' is an admin and therefore has all permissions "
|
||||
"of a curator. If you'd like this user to only have curator permissions, "
|
||||
"you must update their role to BASIC then assign them to be CURATOR in the "
|
||||
"appropriate groups."
|
||||
)
|
||||
elif target_user.role == UserRole.GLOBAL_CURATOR:
|
||||
raise ValueError(
|
||||
f"User '{target_user.email}' is a global_curator and therefore has all "
|
||||
"permissions of a curator for all groups. If you'd like this user to only "
|
||||
"have curator permissions for a specific group, you must update their role "
|
||||
"to BASIC then assign them to be CURATOR in the appropriate groups."
|
||||
)
|
||||
elif target_user.role not in [UserRole.CURATOR, UserRole.BASIC]:
|
||||
raise ValueError(
|
||||
f"This endpoint can only be used to update the curator relationship for "
|
||||
"users with the CURATOR or BASIC role. \n"
|
||||
f"Target user: {target_user.email} \n"
|
||||
f"Target user role: {target_user.role} \n"
|
||||
)
|
||||
|
||||
# check if the target user is in the group they are changing the curator relationship for
|
||||
requested_user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=target_user.id,
|
||||
only_curator_groups=False,
|
||||
)
|
||||
group_ids = [group.id for group in requested_user_groups]
|
||||
if user_group_id not in group_ids:
|
||||
raise ValueError(
|
||||
f"target user {target_user.email} is not in group '{user_group_id}'"
|
||||
)
|
||||
|
||||
|
||||
def update_user_curator_relationship(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
set_curator_request: SetCuratorRequest,
|
||||
user_making_change: User,
|
||||
) -> None:
|
||||
target_user = fetch_user_by_id(db_session, set_curator_request.user_id)
|
||||
if not target_user:
|
||||
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
|
||||
|
||||
_validate_curator_relationship_update_request(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
target_user=target_user,
|
||||
)
|
||||
|
||||
_validate_curator_relationship_update_requester(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
user_making_change=user_making_change,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"user_making_change={user_making_change.email if user_making_change else 'None'} is "
|
||||
f"updating the curator relationship for user={target_user.email} "
|
||||
f"in group={user_group_id} to is_curator={set_curator_request.is_curator}"
|
||||
)
|
||||
|
||||
relationship_to_update = (
|
||||
db_session.query(User__UserGroup)
|
||||
.filter(
|
||||
User__UserGroup.user_group_id == user_group_id,
|
||||
User__UserGroup.user_id == set_curator_request.user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if relationship_to_update:
|
||||
relationship_to_update.is_curator = set_curator_request.is_curator
|
||||
else:
|
||||
relationship_to_update = User__UserGroup(
|
||||
user_group_id=user_group_id,
|
||||
user_id=set_curator_request.user_id,
|
||||
is_curator=True,
|
||||
)
|
||||
db_session.add(relationship_to_update)
|
||||
|
||||
_validate_curator_status__no_commit(db_session, [target_user])
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def add_users_to_user_group(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
@@ -531,6 +766,13 @@ def update_user_group(
|
||||
f"User(s) not found: {', '.join(str(user_id) for user_id in missing_users)}"
|
||||
)
|
||||
|
||||
# LEAVING THIS HERE FOR NOW FOR GIVING DIFFERENT ROLES
|
||||
# ACCESS TO DIFFERENT PERMISSIONS
|
||||
# if (removed_user_ids or added_user_ids) and (
|
||||
# not user or user.role != UserRole.ADMIN
|
||||
# ):
|
||||
# raise ValueError("Only admins can add or remove users from user groups")
|
||||
|
||||
if removed_user_ids:
|
||||
_cleanup_user__user_group_relationships__no_commit(
|
||||
db_session=db_session,
|
||||
@@ -561,6 +803,20 @@ def update_user_group(
|
||||
if cc_pairs_updated and not DISABLE_VECTOR_DB:
|
||||
db_user_group.is_up_to_date = False
|
||||
|
||||
removed_users = db_session.scalars(
|
||||
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
|
||||
).unique()
|
||||
|
||||
# Filter out admin and global curator users before validating curator status
|
||||
users_to_validate = [
|
||||
user
|
||||
for user in removed_users
|
||||
if user.role not in [UserRole.ADMIN, UserRole.GLOBAL_CURATOR]
|
||||
]
|
||||
|
||||
if users_to_validate:
|
||||
_validate_curator_status__no_commit(db_session, users_to_validate)
|
||||
|
||||
# update "time_updated" to now
|
||||
db_user_group.time_last_modified_by_user = func.now()
|
||||
|
||||
@@ -740,72 +996,3 @@ def set_group_permission__no_commit(
|
||||
|
||||
db_session.flush()
|
||||
recompute_permissions_for_group__no_commit(group_id, db_session)
|
||||
|
||||
|
||||
def set_group_permissions_bulk__no_commit(
|
||||
group_id: int,
|
||||
desired_permissions: set[Permission],
|
||||
granted_by: UUID,
|
||||
db_session: Session,
|
||||
) -> list[Permission]:
|
||||
"""Set the full desired permission state for a group in one pass.
|
||||
|
||||
Enables permissions in `desired_permissions`, disables any toggleable
|
||||
permission not in the set. Non-toggleable permissions are ignored.
|
||||
Calls recompute once at the end. Does NOT commit.
|
||||
|
||||
Returns the resulting list of enabled permissions.
|
||||
"""
|
||||
|
||||
existing_grants = (
|
||||
db_session.execute(
|
||||
select(PermissionGrant)
|
||||
.where(PermissionGrant.group_id == group_id)
|
||||
.with_for_update()
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
grant_map: dict[Permission, PermissionGrant] = {
|
||||
g.permission: g for g in existing_grants
|
||||
}
|
||||
|
||||
# Enable desired permissions
|
||||
for perm in desired_permissions:
|
||||
existing = grant_map.get(perm)
|
||||
if existing is not None:
|
||||
if existing.is_deleted:
|
||||
existing.is_deleted = False
|
||||
existing.granted_by = granted_by
|
||||
existing.granted_at = func.now()
|
||||
else:
|
||||
db_session.add(
|
||||
PermissionGrant(
|
||||
group_id=group_id,
|
||||
permission=perm,
|
||||
grant_source=GrantSource.USER,
|
||||
granted_by=granted_by,
|
||||
)
|
||||
)
|
||||
|
||||
# Disable toggleable permissions not in the desired set
|
||||
for perm, grant in grant_map.items():
|
||||
if perm not in desired_permissions and not grant.is_deleted:
|
||||
grant.is_deleted = True
|
||||
|
||||
db_session.flush()
|
||||
recompute_permissions_for_group__no_commit(group_id, db_session)
|
||||
|
||||
# Return the resulting enabled set
|
||||
return [
|
||||
g.permission
|
||||
for g in db_session.execute(
|
||||
select(PermissionGrant).where(
|
||||
PermissionGrant.group_id == group_id,
|
||||
PermissionGrant.is_deleted.is_(False),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
]
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.background.celery.tasks.doc_permission_syncing.tasks import (
|
||||
@@ -10,16 +12,13 @@ from ee.onyx.background.celery.tasks.doc_permission_syncing.tasks import (
|
||||
from ee.onyx.background.celery.tasks.external_group_syncing.tasks import (
|
||||
try_creating_external_group_sync_task,
|
||||
)
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_credential_pair_from_id_for_user,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.models import StatusResponse
|
||||
@@ -33,7 +32,7 @@ router = APIRouter(prefix="/manage")
|
||||
@router.get("/admin/cc-pair/{cc_pair_id}/sync-permissions")
|
||||
def get_cc_pair_latest_sync(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(require_permission(Permission.READ_CONNECTORS)),
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> datetime | None:
|
||||
cc_pair = get_connector_credential_pair_from_id_for_user(
|
||||
@@ -43,9 +42,9 @@ def get_cc_pair_latest_sync(
|
||||
get_editable=False,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
"CC Pair not found for current user's permissions",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="cc_pair not found for current user's permissions",
|
||||
)
|
||||
|
||||
return cc_pair.last_time_perm_sync
|
||||
@@ -54,7 +53,7 @@ def get_cc_pair_latest_sync(
|
||||
@router.post("/admin/cc-pair/{cc_pair_id}/sync-permissions")
|
||||
def sync_cc_pair(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS)),
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse[None]:
|
||||
"""Triggers permissions sync on a particular cc_pair immediately"""
|
||||
@@ -67,18 +66,18 @@ def sync_cc_pair(
|
||||
get_editable=False,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
"Connection not found for current user's permissions",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if redis_connector.permissions.fenced:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CONFLICT,
|
||||
"Permissions sync task already in progress.",
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.CONFLICT,
|
||||
detail="Permissions sync task already in progress.",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -91,9 +90,9 @@ def sync_cc_pair(
|
||||
client_app, cc_pair_id, r, tenant_id
|
||||
)
|
||||
if not payload_id:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Permissions sync task creation failed.",
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
detail="Permissions sync task creation failed.",
|
||||
)
|
||||
|
||||
logger.info(f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}")
|
||||
@@ -107,7 +106,7 @@ def sync_cc_pair(
|
||||
@router.get("/admin/cc-pair/{cc_pair_id}/sync-groups")
|
||||
def get_cc_pair_latest_group_sync(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(require_permission(Permission.READ_CONNECTORS)),
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> datetime | None:
|
||||
cc_pair = get_connector_credential_pair_from_id_for_user(
|
||||
@@ -117,9 +116,9 @@ def get_cc_pair_latest_group_sync(
|
||||
get_editable=False,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
"CC Pair not found for current user's permissions",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="cc_pair not found for current user's permissions",
|
||||
)
|
||||
|
||||
return cc_pair.last_time_external_group_sync
|
||||
@@ -128,7 +127,7 @@ def get_cc_pair_latest_group_sync(
|
||||
@router.post("/admin/cc-pair/{cc_pair_id}/sync-groups")
|
||||
def sync_cc_pair_groups(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS)),
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse[None]:
|
||||
"""Triggers group sync on a particular cc_pair immediately"""
|
||||
@@ -141,18 +140,18 @@ def sync_cc_pair_groups(
|
||||
get_editable=False,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
"Connection not found for current user's permissions",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if redis_connector.external_group_sync.fenced:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CONFLICT,
|
||||
"External group sync task already in progress.",
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.CONFLICT,
|
||||
detail="External group sync task already in progress.",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -165,9 +164,9 @@ def sync_cc_pair_groups(
|
||||
client_app, cc_pair_id, r, tenant_id
|
||||
)
|
||||
if not payload_id:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"External group sync task creation failed.",
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
detail="External group sync task creation failed.",
|
||||
)
|
||||
|
||||
logger.info(f"External group sync queued: cc_pair={cc_pair_id} id={payload_id}")
|
||||
|
||||
@@ -25,7 +25,7 @@ logger = setup_logger()
|
||||
def prepare_authorization_request(
|
||||
connector: DocumentSource,
|
||||
redirect_on_success: str | None,
|
||||
user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS)),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""Used by the frontend to generate the url for the user's browser during auth request.
|
||||
|
||||
@@ -147,7 +147,7 @@ class ConfluenceCloudOAuth:
|
||||
def confluence_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS)),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
@@ -259,7 +259,7 @@ def confluence_oauth_callback(
|
||||
@router.get("/connector/confluence/accessible-resources")
|
||||
def confluence_oauth_accessible_resources(
|
||||
credential_id: int,
|
||||
user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS)),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id), # noqa: ARG001
|
||||
) -> JSONResponse:
|
||||
@@ -326,7 +326,7 @@ def confluence_oauth_finalize(
|
||||
cloud_id: str,
|
||||
cloud_name: str,
|
||||
cloud_url: str,
|
||||
user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS)),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id), # noqa: ARG001
|
||||
) -> JSONResponse:
|
||||
|
||||
@@ -115,7 +115,7 @@ class GoogleDriveOAuth:
|
||||
def handle_google_drive_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS)),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
|
||||
@@ -99,7 +99,7 @@ class SlackOAuth:
|
||||
def handle_slack_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS)),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
|
||||
@@ -154,7 +154,7 @@ def snapshot_from_chat_session(
|
||||
@router.get("/admin/chat-sessions")
|
||||
def admin_get_chat_sessions(
|
||||
user_id: UUID,
|
||||
_: User = Depends(require_permission(Permission.READ_QUERY_HISTORY)),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionsResponse:
|
||||
# we specifically don't allow this endpoint if "anonymized" since
|
||||
@@ -197,7 +197,7 @@ def get_chat_session_history(
|
||||
feedback_type: QAFeedbackType | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
_: User = Depends(require_permission(Permission.READ_QUERY_HISTORY)),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PaginatedReturn[ChatSessionMinimal]:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
@@ -235,7 +235,7 @@ def get_chat_session_history(
|
||||
@router.get("/admin/chat-session-history/{chat_session_id}")
|
||||
def get_chat_session_admin(
|
||||
chat_session_id: UUID,
|
||||
_: User = Depends(require_permission(Permission.READ_QUERY_HISTORY)),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionSnapshot:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
@@ -270,7 +270,7 @@ def get_chat_session_admin(
|
||||
|
||||
@router.get("/admin/query-history/list")
|
||||
def list_all_query_history_exports(
|
||||
_: User = Depends(require_permission(Permission.READ_QUERY_HISTORY)),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[QueryHistoryExport]:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
@@ -298,7 +298,7 @@ def list_all_query_history_exports(
|
||||
|
||||
@router.post("/admin/query-history/start-export", tags=PUBLIC_API_TAGS)
|
||||
def start_query_history_export(
|
||||
_: User = Depends(require_permission(Permission.READ_QUERY_HISTORY)),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
@@ -345,7 +345,7 @@ def start_query_history_export(
|
||||
@router.get("/admin/query-history/export-status", tags=PUBLIC_API_TAGS)
|
||||
def get_query_history_export_status(
|
||||
request_id: str,
|
||||
_: User = Depends(require_permission(Permission.READ_QUERY_HISTORY)),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
@@ -379,7 +379,7 @@ def get_query_history_export_status(
|
||||
@router.get("/admin/query-history/download", tags=PUBLIC_API_TAGS)
|
||||
def download_query_history_csv(
|
||||
request_id: str,
|
||||
_: User = Depends(require_permission(Permission.READ_QUERY_HISTORY)),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
|
||||
@@ -11,6 +11,8 @@ require a valid SCIM bearer token.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -22,6 +24,7 @@ from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -59,18 +62,31 @@ from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.permissions import recompute_permissions_for_group__no_commit
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.db.users import user_is_admin
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Group names reserved for system default groups (seeded by migration).
|
||||
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
|
||||
|
||||
# Namespace prefix for the seat-allocation advisory lock. Hashed together
|
||||
# with the tenant ID so the lock is scoped per-tenant (unrelated tenants
|
||||
# never block each other) and cannot collide with unrelated advisory locks.
|
||||
_SEAT_LOCK_NAMESPACE = "onyx_scim_seat_lock"
|
||||
|
||||
|
||||
def _seat_lock_id_for_tenant(tenant_id: str) -> int:
|
||||
"""Derive a stable 64-bit signed int lock id for this tenant's seat lock."""
|
||||
digest = hashlib.sha256(f"{_SEAT_LOCK_NAMESPACE}:{tenant_id}".encode()).digest()
|
||||
# pg_advisory_xact_lock takes a signed 8-byte int; unpack as such.
|
||||
return struct.unpack("q", digest[:8])[0]
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
@@ -209,12 +225,37 @@ def _apply_exclusions(
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
"""Return an error message if seat limit is reached, else None.
|
||||
|
||||
Acquires a transaction-scoped advisory lock so that concurrent
|
||||
SCIM requests are serialized. IdPs like Okta send provisioning
|
||||
requests in parallel batches — without serialization the check is
|
||||
vulnerable to a TOCTOU race where N concurrent requests each see
|
||||
"seats available", all insert, and the tenant ends up over its
|
||||
seat limit.
|
||||
|
||||
The lock is held until the caller's next COMMIT or ROLLBACK, which
|
||||
means the seat count cannot change between the check here and the
|
||||
subsequent INSERT/UPDATE. Each call site in this module follows
|
||||
the pattern: _check_seat_availability → write → dal.commit()
|
||||
(which releases the lock for the next waiting request).
|
||||
"""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)
|
||||
if check_fn is None:
|
||||
return None
|
||||
|
||||
# Transaction-scoped advisory lock — released on dal.commit() / dal.rollback().
|
||||
# The lock id is derived from the tenant so unrelated tenants never block
|
||||
# each other, and from a namespace string so it cannot collide with
|
||||
# unrelated advisory locks elsewhere in the codebase.
|
||||
lock_id = _seat_lock_id_for_tenant(get_current_tenant_id())
|
||||
dal.session.execute(
|
||||
text("SELECT pg_advisory_xact_lock(:lock_id)"),
|
||||
{"lock_id": lock_id},
|
||||
)
|
||||
|
||||
result = check_fn(dal.session, seats_needed=1)
|
||||
if not result.available:
|
||||
return result.error_message or "Seat limit reached"
|
||||
@@ -494,6 +535,7 @@ def create_user(
|
||||
user = User(
|
||||
email=email,
|
||||
hashed_password=_pw_helper.hash(_pw_helper.generate()),
|
||||
role=UserRole.BASIC,
|
||||
account_type=AccountType.STANDARD,
|
||||
is_active=user_resource.active,
|
||||
is_verified=True,
|
||||
@@ -581,7 +623,7 @@ def replace_user(
|
||||
# Reconcile default-group membership on reactivation
|
||||
if is_reactivation:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=user_is_admin(user)
|
||||
db_session, user, is_admin=(user.role == UserRole.ADMIN)
|
||||
)
|
||||
|
||||
new_external_id = user_resource.externalId
|
||||
@@ -681,7 +723,7 @@ def patch_user(
|
||||
# Reconcile default-group membership on reactivation
|
||||
if is_reactivation:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=user_is_admin(user)
|
||||
db_session, user, is_admin=(user.role == UserRole.ADMIN)
|
||||
)
|
||||
|
||||
# Build updated fields by merging PATCH enterprise data with current values
|
||||
|
||||
@@ -5,9 +5,10 @@ from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.token_limit import fetch_all_user_group_token_rate_limits_by_group
|
||||
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_group
|
||||
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
|
||||
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
@@ -47,14 +48,15 @@ def get_all_group_token_limit_settings(
|
||||
@router.get("/user-group/{group_id}")
|
||||
def get_group_token_limit_settings(
|
||||
group_id: int,
|
||||
_: User = Depends(require_permission(Permission.MANAGE_USER_GROUPS)),
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[TokenRateLimitDisplay]:
|
||||
return [
|
||||
TokenRateLimitDisplay.from_db(token_rate_limit)
|
||||
for token_rate_limit in fetch_user_group_token_rate_limits_for_group(
|
||||
for token_rate_limit in fetch_user_group_token_rate_limits_for_user(
|
||||
db_session=db_session,
|
||||
group_id=group_id,
|
||||
user=user,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -63,7 +65,7 @@ def get_group_token_limit_settings(
|
||||
def create_group_token_limit_settings(
|
||||
group_id: int,
|
||||
token_limit_settings: TokenRateLimitArgs,
|
||||
_: User = Depends(require_permission(Permission.MANAGE_USER_GROUPS)),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> TokenRateLimitDisplay:
|
||||
rate_limit_display = TokenRateLimitDisplay.from_db(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -12,26 +13,28 @@ from ee.onyx.db.user_group import fetch_user_groups_for_user
|
||||
from ee.onyx.db.user_group import insert_user_group
|
||||
from ee.onyx.db.user_group import prepare_user_group_for_deletion
|
||||
from ee.onyx.db.user_group import rename_user_group
|
||||
from ee.onyx.db.user_group import set_group_permissions_bulk__no_commit
|
||||
from ee.onyx.db.user_group import set_group_permission__no_commit
|
||||
from ee.onyx.db.user_group import update_user_curator_relationship
|
||||
from ee.onyx.db.user_group import update_user_group
|
||||
from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest
|
||||
from ee.onyx.server.user_group.models import BulkSetPermissionsRequest
|
||||
from ee.onyx.server.user_group.models import MinimalUserGroupSnapshot
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
from ee.onyx.server.user_group.models import SetPermissionRequest
|
||||
from ee.onyx.server.user_group.models import SetPermissionResponse
|
||||
from ee.onyx.server.user_group.models import UpdateGroupAgentsRequest
|
||||
from ee.onyx.server.user_group.models import UserGroup
|
||||
from ee.onyx.server.user_group.models import UserGroupCreate
|
||||
from ee.onyx.server.user_group.models import UserGroupRename
|
||||
from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.auth.permissions import get_effective_permissions
|
||||
from onyx.auth.permissions import NON_TOGGLEABLE_PERMISSIONS
|
||||
from onyx.auth.permissions import PERMISSION_REGISTRY
|
||||
from onyx.auth.permissions import PermissionRegistryEntry
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
@@ -45,15 +48,24 @@ router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
|
||||
@router.get("/admin/user-group")
|
||||
def list_user_groups(
|
||||
include_default: bool = False,
|
||||
_: User = Depends(require_permission(Permission.READ_USER_GROUPS)),
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserGroup]:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session,
|
||||
only_up_to_date=False,
|
||||
eager_load_for_snapshot=True,
|
||||
include_default=include_default,
|
||||
)
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session,
|
||||
only_up_to_date=False,
|
||||
eager_load_for_snapshot=True,
|
||||
include_default=include_default,
|
||||
)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
only_curator_groups=user.role == UserRole.CURATOR,
|
||||
eager_load_for_snapshot=True,
|
||||
include_default=include_default,
|
||||
)
|
||||
return [UserGroup.from_model(user_group) for user_group in user_groups]
|
||||
|
||||
|
||||
@@ -63,7 +75,7 @@ def list_minimal_user_groups(
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[MinimalUserGroupSnapshot]:
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS in get_effective_permissions(user):
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session,
|
||||
only_up_to_date=False,
|
||||
@@ -80,71 +92,62 @@ def list_minimal_user_groups(
|
||||
]
|
||||
|
||||
|
||||
@router.get("/admin/permissions/registry")
|
||||
def get_permission_registry(
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
) -> list[PermissionRegistryEntry]:
|
||||
return PERMISSION_REGISTRY
|
||||
|
||||
|
||||
@router.get("/admin/user-group/{user_group_id}/permissions")
|
||||
def get_user_group_permissions(
|
||||
user_group_id: int,
|
||||
_: User = Depends(require_permission(Permission.MANAGE_USER_GROUPS)),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[Permission]:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
|
||||
return [
|
||||
grant.permission
|
||||
for grant in group.permission_grants
|
||||
if not grant.is_deleted and grant.permission not in NON_TOGGLEABLE_PERMISSIONS
|
||||
grant.permission for grant in group.permission_grants if not grant.is_deleted
|
||||
]
|
||||
|
||||
|
||||
@router.put("/admin/user-group/{user_group_id}/permissions")
|
||||
def set_user_group_permissions(
|
||||
def set_user_group_permission(
|
||||
user_group_id: int,
|
||||
request: BulkSetPermissionsRequest,
|
||||
user: User = Depends(require_permission(Permission.MANAGE_USER_GROUPS)),
|
||||
request: SetPermissionRequest,
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[Permission]:
|
||||
) -> SetPermissionResponse:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
|
||||
|
||||
non_toggleable = [p for p in request.permissions if p in NON_TOGGLEABLE_PERMISSIONS]
|
||||
if non_toggleable:
|
||||
if request.permission in NON_TOGGLEABLE_PERMISSIONS:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
f"Permissions {non_toggleable} cannot be toggled via this endpoint",
|
||||
f"Permission '{request.permission}' cannot be toggled via this endpoint",
|
||||
)
|
||||
|
||||
result = set_group_permissions_bulk__no_commit(
|
||||
set_group_permission__no_commit(
|
||||
group_id=user_group_id,
|
||||
desired_permissions=set(request.permissions),
|
||||
permission=request.permission,
|
||||
enabled=request.enabled,
|
||||
granted_by=user.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
return result
|
||||
return SetPermissionResponse(permission=request.permission, enabled=request.enabled)
|
||||
|
||||
|
||||
@router.post("/admin/user-group")
|
||||
def create_user_group(
|
||||
user_group: UserGroupCreate,
|
||||
_: User = Depends(require_permission(Permission.MANAGE_USER_GROUPS)),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
try:
|
||||
db_user_group = insert_user_group(db_session, user_group)
|
||||
except IntegrityError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.DUPLICATE_RESOURCE,
|
||||
raise HTTPException(
|
||||
400,
|
||||
f"User group with name '{user_group.name}' already exists. Please "
|
||||
"choose a different name.",
|
||||
+ "choose a different name.",
|
||||
)
|
||||
return UserGroup.from_model(db_user_group)
|
||||
|
||||
@@ -152,7 +155,7 @@ def create_user_group(
|
||||
@router.patch("/admin/user-group/rename")
|
||||
def rename_user_group_endpoint(
|
||||
rename_request: UserGroupRename,
|
||||
_: User = Depends(require_permission(Permission.MANAGE_USER_GROUPS)),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
group = fetch_user_group(db_session, rename_request.id)
|
||||
@@ -182,7 +185,7 @@ def rename_user_group_endpoint(
|
||||
def patch_user_group(
|
||||
user_group_id: int,
|
||||
user_group_update: UserGroupUpdate,
|
||||
user: User = Depends(require_permission(Permission.MANAGE_USER_GROUPS)),
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
try:
|
||||
@@ -195,14 +198,14 @@ def patch_user_group(
|
||||
)
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/admin/user-group/{user_group_id}/add-users")
|
||||
def add_users(
|
||||
user_group_id: int,
|
||||
add_users_request: AddUsersToUserGroupRequest,
|
||||
user: User = Depends(require_permission(Permission.MANAGE_USER_GROUPS)),
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
try:
|
||||
@@ -215,13 +218,32 @@ def add_users(
|
||||
)
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/admin/user-group/{user_group_id}/set-curator")
|
||||
def set_user_curator(
|
||||
user_group_id: int,
|
||||
set_curator_request: SetCuratorRequest,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
update_user_curator_relationship(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
set_curator_request=set_curator_request,
|
||||
user_making_change=user,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Error setting user curator: {e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/admin/user-group/{user_group_id}")
|
||||
def delete_user_group(
|
||||
user_group_id: int,
|
||||
_: User = Depends(require_permission(Permission.MANAGE_USER_GROUPS)),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
@@ -230,7 +252,7 @@ def delete_user_group(
|
||||
try:
|
||||
prepare_user_group_for_deletion(db_session, user_group_id)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
user_group = fetch_user_group(db_session, user_group_id)
|
||||
@@ -242,7 +264,7 @@ def delete_user_group(
|
||||
def update_group_agents(
|
||||
user_group_id: int,
|
||||
request: UpdateGroupAgentsRequest,
|
||||
user: User = Depends(require_permission(Permission.MANAGE_USER_GROUPS)),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
for agent_id in request.added_agent_ids:
|
||||
|
||||
@@ -17,6 +17,7 @@ class UserGroup(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
users: list[UserInfo]
|
||||
curator_ids: list[UUID]
|
||||
cc_pairs: list[ConnectorCredentialPairDescriptor]
|
||||
document_sets: list[DocumentSet]
|
||||
personas: list[PersonaSnapshot]
|
||||
@@ -36,7 +37,7 @@ class UserGroup(BaseModel):
|
||||
is_active=user.is_active,
|
||||
is_superuser=user.is_superuser,
|
||||
is_verified=user.is_verified,
|
||||
account_type=user.account_type,
|
||||
role=user.role,
|
||||
preferences=UserPreferences(
|
||||
default_model=user.default_model,
|
||||
chosen_assistants=user.chosen_assistants,
|
||||
@@ -44,6 +45,11 @@ class UserGroup(BaseModel):
|
||||
)
|
||||
for user in user_group_model.users
|
||||
],
|
||||
curator_ids=[
|
||||
user.user_id
|
||||
for user in user_group_model.user_group_relationships
|
||||
if user.is_curator and user.user_id is not None
|
||||
],
|
||||
cc_pairs=[
|
||||
ConnectorCredentialPairDescriptor(
|
||||
id=cc_pair_relationship.cc_pair.id,
|
||||
@@ -108,6 +114,11 @@ class UserGroupRename(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class SetCuratorRequest(BaseModel):
|
||||
user_id: UUID
|
||||
is_curator: bool
|
||||
|
||||
|
||||
class UpdateGroupAgentsRequest(BaseModel):
|
||||
added_agent_ids: list[int]
|
||||
removed_agent_ids: list[int]
|
||||
@@ -121,7 +132,3 @@ class SetPermissionRequest(BaseModel):
|
||||
class SetPermissionResponse(BaseModel):
|
||||
permission: Permission
|
||||
enabled: bool
|
||||
|
||||
|
||||
class BulkSetPermissionsRequest(BaseModel):
|
||||
permissions: list[Permission]
|
||||
|
||||
@@ -2,11 +2,11 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.configs.constants import ANONYMOUS_USER_INFO_ID
|
||||
from onyx.configs.constants import KV_ANONYMOUS_USER_PERSONALIZATION_KEY
|
||||
from onyx.configs.constants import KV_ANONYMOUS_USER_PREFERENCES_KEY
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.key_value_store.store import KeyValueStore
|
||||
from onyx.key_value_store.store import KvKeyNotFoundError
|
||||
from onyx.server.manage.models import UserInfo
|
||||
@@ -55,7 +55,7 @@ def fetch_anonymous_user_info(store: KeyValueStore) -> UserInfo:
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
account_type=AccountType.ANONYMOUS,
|
||||
role=UserRole.LIMITED,
|
||||
preferences=load_anonymous_user_preferences(store),
|
||||
personalization=personalization,
|
||||
is_anonymous_user=True,
|
||||
|
||||
@@ -10,9 +10,9 @@ from pydantic import BaseModel
|
||||
from onyx.auth.constants import API_KEY_LENGTH
|
||||
from onyx.auth.constants import API_KEY_PREFIX
|
||||
from onyx.auth.constants import DEPRECATED_API_KEY_PREFIX
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.utils import get_hashed_bearer_token_from_request
|
||||
from onyx.configs.app_configs import API_KEY_HASH_ROUNDS
|
||||
from onyx.server.models import UserGroupInfo
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class ApiKeyDescriptor(BaseModel):
|
||||
api_key_display: str
|
||||
api_key: str | None = None # only present on initial creation
|
||||
api_key_name: str | None = None
|
||||
groups: list[UserGroupInfo]
|
||||
api_key_role: UserRole
|
||||
|
||||
user_id: uuid.UUID
|
||||
|
||||
|
||||
@@ -11,15 +11,13 @@ from collections.abc import Coroutine
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends
|
||||
from pydantic import BaseModel
|
||||
from pydantic import field_validator
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -31,14 +29,14 @@ IMPLIED_PERMISSIONS: dict[str, set[str]] = {
|
||||
Permission.MANAGE_AGENTS.value: {
|
||||
Permission.ADD_AGENTS.value,
|
||||
Permission.READ_AGENTS.value,
|
||||
Permission.READ_DOCUMENT_SETS.value,
|
||||
},
|
||||
Permission.MANAGE_DOCUMENT_SETS.value: {
|
||||
Permission.READ_DOCUMENT_SETS.value,
|
||||
Permission.READ_CONNECTORS.value,
|
||||
Permission.READ_USER_GROUPS.value,
|
||||
},
|
||||
Permission.ADD_CONNECTORS.value: {Permission.READ_CONNECTORS.value},
|
||||
Permission.MANAGE_CONNECTORS.value: {
|
||||
Permission.ADD_CONNECTORS.value,
|
||||
Permission.READ_CONNECTORS.value,
|
||||
},
|
||||
Permission.MANAGE_USER_GROUPS.value: {
|
||||
@@ -46,15 +44,6 @@ IMPLIED_PERMISSIONS: dict[str, set[str]] = {
|
||||
Permission.READ_DOCUMENT_SETS.value,
|
||||
Permission.READ_AGENTS.value,
|
||||
Permission.READ_USERS.value,
|
||||
Permission.READ_USER_GROUPS.value,
|
||||
},
|
||||
Permission.MANAGE_LLMS.value: {
|
||||
Permission.READ_USER_GROUPS.value,
|
||||
Permission.READ_AGENTS.value,
|
||||
Permission.READ_USERS.value,
|
||||
},
|
||||
Permission.MANAGE_SERVICE_ACCOUNT_API_KEYS.value: {
|
||||
Permission.READ_USER_GROUPS.value,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -69,138 +58,9 @@ NON_TOGGLEABLE_PERMISSIONS: frozenset[Permission] = frozenset(
|
||||
Permission.READ_DOCUMENT_SETS,
|
||||
Permission.READ_AGENTS,
|
||||
Permission.READ_USERS,
|
||||
Permission.READ_USER_GROUPS,
|
||||
}
|
||||
)
|
||||
|
||||
# Permissions auto-granted to all users in Community Edition.
|
||||
# In CE there is no group-permission UI, so these capabilities must be
|
||||
# available without explicit grants. In EE they are controlled normally
|
||||
# via group permissions.
|
||||
CE_UNGATED_PERMISSIONS: frozenset[Permission] = frozenset(
|
||||
{
|
||||
Permission.ADD_AGENTS,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class PermissionRegistryEntry(BaseModel):
|
||||
"""A UI-facing permission row served by GET /admin/permissions/registry.
|
||||
|
||||
The field_validator ensures non-toggleable permissions (BASIC_ACCESS,
|
||||
FULL_ADMIN_PANEL_ACCESS, READ_*) can never appear in the registry.
|
||||
"""
|
||||
|
||||
id: str
|
||||
display_name: str
|
||||
description: str
|
||||
permissions: list[Permission]
|
||||
group: int
|
||||
|
||||
@field_validator("permissions")
|
||||
@classmethod
|
||||
def must_be_toggleable(cls, v: list[Permission]) -> list[Permission]:
|
||||
for p in v:
|
||||
if p in NON_TOGGLEABLE_PERMISSIONS:
|
||||
raise ValueError(
|
||||
f"Permission '{p.value}' is not toggleable and "
|
||||
"cannot be included in the permission registry"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
# Registry of toggleable permissions exposed to the admin UI.
|
||||
# Single source of truth for display names, descriptions, grouping,
|
||||
# and which backend tokens each UI row controls.
|
||||
# The frontend fetches this via GET /admin/permissions/registry
|
||||
# and only adds icon mapping locally.
|
||||
PERMISSION_REGISTRY: list[PermissionRegistryEntry] = [
|
||||
# Group 0 — System Configuration
|
||||
PermissionRegistryEntry(
|
||||
id="manage_llms",
|
||||
display_name="Manage LLMs",
|
||||
description="Add and update configurations for language models (LLMs).",
|
||||
permissions=[Permission.MANAGE_LLMS],
|
||||
group=0,
|
||||
),
|
||||
PermissionRegistryEntry(
|
||||
id="manage_connectors_and_document_sets",
|
||||
display_name="Manage Connectors & Document Sets",
|
||||
description="Add and update connectors and document sets.",
|
||||
permissions=[
|
||||
Permission.MANAGE_CONNECTORS,
|
||||
Permission.MANAGE_DOCUMENT_SETS,
|
||||
],
|
||||
group=0,
|
||||
),
|
||||
PermissionRegistryEntry(
|
||||
id="manage_actions",
|
||||
display_name="Manage Actions",
|
||||
description="Add and update custom tools and MCP/OpenAPI actions.",
|
||||
permissions=[Permission.MANAGE_ACTIONS],
|
||||
group=0,
|
||||
),
|
||||
# Group 1 — User & Access Management
|
||||
PermissionRegistryEntry(
|
||||
id="manage_groups",
|
||||
display_name="Manage Groups",
|
||||
description="Add and update user groups.",
|
||||
permissions=[Permission.MANAGE_USER_GROUPS],
|
||||
group=1,
|
||||
),
|
||||
PermissionRegistryEntry(
|
||||
id="manage_service_accounts",
|
||||
display_name="Manage Service Accounts",
|
||||
description="Add and update service accounts and their API keys.",
|
||||
permissions=[Permission.MANAGE_SERVICE_ACCOUNT_API_KEYS],
|
||||
group=1,
|
||||
),
|
||||
PermissionRegistryEntry(
|
||||
id="manage_bots",
|
||||
display_name="Manage Slack/Discord Bots",
|
||||
description="Add and update Onyx integrations with Slack or Discord.",
|
||||
permissions=[Permission.MANAGE_BOTS],
|
||||
group=1,
|
||||
),
|
||||
# Group 2 — Agents
|
||||
PermissionRegistryEntry(
|
||||
id="create_agents",
|
||||
display_name="Create Agents",
|
||||
description="Create and edit the user's own agents.",
|
||||
permissions=[Permission.ADD_AGENTS],
|
||||
group=2,
|
||||
),
|
||||
PermissionRegistryEntry(
|
||||
id="manage_agents",
|
||||
display_name="Manage Agents",
|
||||
description="View and update all public and shared agents in the organization.",
|
||||
permissions=[Permission.MANAGE_AGENTS],
|
||||
group=2,
|
||||
),
|
||||
# Group 3 — Monitoring & Tokens
|
||||
PermissionRegistryEntry(
|
||||
id="view_agent_analytics",
|
||||
display_name="View Agent Analytics",
|
||||
description="View analytics for agents the group can manage.",
|
||||
permissions=[Permission.READ_AGENT_ANALYTICS],
|
||||
group=3,
|
||||
),
|
||||
PermissionRegistryEntry(
|
||||
id="view_query_history",
|
||||
display_name="View Query History",
|
||||
description="View query history of everyone in the organization.",
|
||||
permissions=[Permission.READ_QUERY_HISTORY],
|
||||
group=3,
|
||||
),
|
||||
PermissionRegistryEntry(
|
||||
id="create_user_access_token",
|
||||
display_name="Create User Access Token",
|
||||
description="Add and update the user's personal access tokens.",
|
||||
permissions=[Permission.CREATE_USER_API_KEYS],
|
||||
group=3,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def resolve_effective_permissions(granted: set[str]) -> set[str]:
|
||||
"""Expand granted permissions with their implied permissions.
|
||||
@@ -223,12 +83,7 @@ def resolve_effective_permissions(granted: set[str]) -> set[str]:
|
||||
|
||||
|
||||
def get_effective_permissions(user: User) -> set[Permission]:
|
||||
"""Read granted permissions from the column and expand implied permissions.
|
||||
|
||||
Admin-role users always receive all permissions regardless of the JSONB
|
||||
column, maintaining backward compatibility with role-based access control.
|
||||
"""
|
||||
|
||||
"""Read granted permissions from the column and expand implied permissions."""
|
||||
granted: set[Permission] = set()
|
||||
for p in user.effective_permissions:
|
||||
try:
|
||||
@@ -237,26 +92,10 @@ def get_effective_permissions(user: User) -> set[Permission]:
|
||||
logger.warning(f"Skipping unknown permission '{p}' for user {user.id}")
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS in granted:
|
||||
return set(Permission)
|
||||
|
||||
if not global_version.is_ee_version():
|
||||
granted |= CE_UNGATED_PERMISSIONS
|
||||
|
||||
expanded = resolve_effective_permissions({p.value for p in granted})
|
||||
return {Permission(p) for p in expanded}
|
||||
|
||||
|
||||
def has_permission(user: User, permission: Permission) -> bool:
|
||||
"""Check whether *user* holds *permission* (directly or via implication/admin override)."""
|
||||
return permission in get_effective_permissions(user)
|
||||
|
||||
|
||||
def _get_current_user() -> Any:
|
||||
"""Lazy import to break circular dependency between permissions and users modules."""
|
||||
from onyx.auth.users import current_user
|
||||
|
||||
return current_user
|
||||
|
||||
|
||||
def require_permission(
|
||||
required: Permission,
|
||||
) -> Callable[..., Coroutine[Any, Any, User]]:
|
||||
@@ -268,9 +107,7 @@ def require_permission(
|
||||
...
|
||||
"""
|
||||
|
||||
async def dependency(
|
||||
user: User = Depends(_get_current_user()),
|
||||
) -> User:
|
||||
async def dependency(user: User = Depends(current_user)) -> User:
|
||||
effective = get_effective_permissions(user)
|
||||
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS in effective:
|
||||
|
||||
@@ -38,10 +38,11 @@ class UserRole(str, Enum):
|
||||
|
||||
|
||||
class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
account_type: AccountType
|
||||
role: UserRole
|
||||
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
account_type: AccountType = AccountType.STANDARD
|
||||
tenant_id: str | None = None
|
||||
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
|
||||
@@ -66,8 +67,10 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
"""Intentionally empty: keeps account_type and permissions out of the
|
||||
fastapi-users PATCH endpoints."""
|
||||
"""
|
||||
Role updates are not allowed through the user update endpoint for security reasons
|
||||
Role changes should be handled through a separate, admin-only process
|
||||
"""
|
||||
|
||||
|
||||
class AuthBackend(str, Enum):
|
||||
|
||||
@@ -77,9 +77,9 @@ from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.jwt import verify_jwt_token
|
||||
from onyx.auth.pat import get_hashed_pat_from_request
|
||||
from onyx.auth.permissions import has_permission
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import AUTH_BACKEND
|
||||
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
@@ -114,12 +114,12 @@ from onyx.db.auth import get_access_token_db
|
||||
from onyx.db.auth import get_default_admin_user_emails
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.auth import get_user_db
|
||||
from onyx.db.auth import SQLAlchemyUserAdminDB
|
||||
from onyx.db.engine.async_sql_engine import get_async_session
|
||||
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import Persona
|
||||
@@ -158,8 +158,7 @@ REGISTER_INVITE_ONLY_CODE = "REGISTER_INVITE_ONLY"
|
||||
|
||||
|
||||
def is_user_admin(user: User) -> bool:
|
||||
|
||||
return has_permission(user, Permission.FULL_ADMIN_PANEL_ACCESS)
|
||||
return user.role == UserRole.ADMIN
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
@@ -361,7 +360,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)(user_email)
|
||||
async with get_async_session_context_manager(tenant_id) as db_session:
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserDatabase[User, uuid.UUID](
|
||||
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
|
||||
db_session, User, OAuthAccount
|
||||
)
|
||||
user = await tenant_user_db.get_by_email(user_email)
|
||||
@@ -457,16 +456,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# Single-tenant: Check invite list (skips if SAML/OIDC or no list configured)
|
||||
verify_email_is_invited(user_create.email)
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserDatabase[User, uuid.UUID](
|
||||
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
|
||||
db_session, User, OAuthAccount
|
||||
)
|
||||
self.user_db = tenant_user_db
|
||||
|
||||
user_count = await get_user_count()
|
||||
is_admin = (
|
||||
user_count == 0
|
||||
or user_create.email in get_default_admin_user_emails()
|
||||
)
|
||||
if hasattr(user_create, "role"):
|
||||
user_create.role = UserRole.BASIC
|
||||
|
||||
user_count = await get_user_count()
|
||||
if (
|
||||
user_count == 0
|
||||
or user_create.email in get_default_admin_user_emails()
|
||||
):
|
||||
user_create.role = UserRole.ADMIN
|
||||
|
||||
# Check seat availability for new users (single-tenant only)
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
@@ -509,7 +512,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# object triggers a sync lazy-load which raises MissingGreenlet
|
||||
# in this async context.
|
||||
user_id = user.id
|
||||
self._upgrade_user_to_standard__sync(user_id, user_create, is_admin)
|
||||
self._upgrade_user_to_standard__sync(user_id, user_create)
|
||||
# Expire so the async session re-fetches the row updated by
|
||||
# the sync session above.
|
||||
self.user_db.session.expire(user)
|
||||
@@ -537,7 +540,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# object triggers a sync lazy-load which raises MissingGreenlet
|
||||
# in this async context.
|
||||
user_id = user.id
|
||||
self._upgrade_user_to_standard__sync(user_id, user_create, is_admin)
|
||||
self._upgrade_user_to_standard__sync(user_id, user_create)
|
||||
# Expire so the async session re-fetches the row updated by
|
||||
# the sync session above.
|
||||
self.user_db.session.expire(user)
|
||||
@@ -582,7 +585,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
user_create: UserCreate,
|
||||
is_admin: bool,
|
||||
) -> None:
|
||||
"""Upgrade a non-web user to STANDARD and assign default groups atomically.
|
||||
|
||||
@@ -596,11 +598,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_create.password
|
||||
)
|
||||
sync_user.is_verified = user_create.is_verified or False
|
||||
sync_user.role = user_create.role
|
||||
sync_user.account_type = AccountType.STANDARD
|
||||
assign_user_to_default_groups__no_commit(
|
||||
sync_db,
|
||||
sync_user,
|
||||
is_admin=is_admin,
|
||||
is_admin=(user_create.role == UserRole.ADMIN),
|
||||
)
|
||||
sync_db.commit()
|
||||
else:
|
||||
@@ -682,7 +685,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# NOTE(rkuo): If this UserManager is instantiated per connection
|
||||
# should we even be doing this here?
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserDatabase[User, uuid.UUID](
|
||||
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
|
||||
db_session, User, OAuthAccount
|
||||
)
|
||||
self.user_db = tenant_user_db
|
||||
@@ -775,8 +778,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user = user_by_session
|
||||
|
||||
# If the user is inactive, check seat availability before
|
||||
# upgrading — otherwise they'd become an inactive user who
|
||||
# still can't log in.
|
||||
# upgrading role — otherwise they'd become an inactive BASIC
|
||||
# user who still can't log in.
|
||||
if not user.is_active:
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
enforce_seat_limit(sync_db)
|
||||
@@ -788,6 +791,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
sync_user = sync_db.query(User).filter(User.id == user.id).first() # type: ignore[arg-type]
|
||||
if sync_user:
|
||||
sync_user.is_verified = is_verified_by_default
|
||||
sync_user.role = UserRole.BASIC
|
||||
sync_user.account_type = AccountType.STANDARD
|
||||
if was_inactive:
|
||||
sync_user.is_active = True
|
||||
@@ -931,7 +935,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
"email": user.email,
|
||||
"onyx_cloud_user_id": str(user.id),
|
||||
"tenant_id": str(tenant_id) if tenant_id else None,
|
||||
"account_type": user.account_type.value,
|
||||
"role": user.role.value,
|
||||
"is_first_user": user_count == 1,
|
||||
"source": "marketing_site_signup",
|
||||
"conversion_timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
@@ -1519,7 +1523,7 @@ async def _get_or_create_user_from_jwt(
|
||||
verify_email_is_invited(email)
|
||||
verify_email_domain(email)
|
||||
|
||||
user_db: SQLAlchemyUserDatabase[User, uuid.UUID] = SQLAlchemyUserDatabase(
|
||||
user_db: SQLAlchemyUserAdminDB[User, uuid.UUID] = SQLAlchemyUserAdminDB(
|
||||
async_db_session, User, OAuthAccount
|
||||
)
|
||||
user_manager = UserManager(user_db)
|
||||
@@ -1611,6 +1615,7 @@ def get_anonymous_user() -> User:
|
||||
is_active=True,
|
||||
is_verified=True,
|
||||
is_superuser=False,
|
||||
role=UserRole.LIMITED,
|
||||
account_type=AccountType.ANONYMOUS,
|
||||
use_memories=False,
|
||||
enable_memory_tool=False,
|
||||
@@ -1684,6 +1689,18 @@ async def current_user(
|
||||
return user
|
||||
|
||||
|
||||
async def current_curator_or_admin_user(
|
||||
user: User = Depends(current_user),
|
||||
) -> User:
|
||||
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
|
||||
if user.role not in allowed_roles:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not a curator or admin.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def _get_user_from_token_data(token_data: dict) -> User | None:
|
||||
"""Shared logic: token data dict → User object.
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from celery import bootsteps # type: ignore
|
||||
from celery import Task
|
||||
from celery.app import trace
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.signals import before_task_publish
|
||||
from celery.signals import task_postrun
|
||||
from celery.signals import task_prerun
|
||||
from celery.states import READY_STATES
|
||||
@@ -94,6 +95,17 @@ class TenantAwareTask(Task):
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
|
||||
|
||||
|
||||
@before_task_publish.connect
|
||||
def on_before_task_publish(
|
||||
headers: dict[str, Any] | None = None,
|
||||
**kwargs: Any, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Stamp the current wall-clock time into the task message headers so that
|
||||
workers can compute queue wait time (time between publish and execution)."""
|
||||
if headers is not None:
|
||||
headers["enqueued_at"] = time.time()
|
||||
|
||||
|
||||
@task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None, # noqa: ARG001
|
||||
|
||||
@@ -16,6 +16,12 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -36,6 +42,7 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -50,6 +57,31 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -90,6 +122,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("light")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -322,6 +322,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
"onyx.server.features.proposal_review.engine",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
@@ -30,6 +31,8 @@ from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.server.metrics.pruning_metrics import inc_pruning_rate_limit_error
|
||||
from onyx.server.metrics.pruning_metrics import observe_pruning_enumeration_duration
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -130,6 +133,7 @@ def _extract_from_batch(
|
||||
def extract_ids_from_runnable_connector(
|
||||
runnable_connector: BaseConnector,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
connector_type: str = "unknown",
|
||||
) -> SlimConnectorExtractionResult:
|
||||
"""
|
||||
Extract document IDs and hierarchy nodes from a runnable connector.
|
||||
@@ -179,21 +183,38 @@ def extract_ids_from_runnable_connector(
|
||||
)
|
||||
|
||||
# process raw batches to extract both IDs and hierarchy nodes
|
||||
for doc_list in raw_batch_generator:
|
||||
if callback and callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"extract_ids_from_runnable_connector: Stop signal detected"
|
||||
)
|
||||
enumeration_start = time.monotonic()
|
||||
try:
|
||||
for doc_list in raw_batch_generator:
|
||||
if callback and callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"extract_ids_from_runnable_connector: Stop signal detected"
|
||||
)
|
||||
|
||||
batch_result = _extract_from_batch(doc_list)
|
||||
batch_ids = batch_result.raw_id_to_parent
|
||||
batch_nodes = batch_result.hierarchy_nodes
|
||||
doc_batch_processing_func(batch_ids)
|
||||
all_raw_id_to_parent.update(batch_ids)
|
||||
all_hierarchy_nodes.extend(batch_nodes)
|
||||
batch_result = _extract_from_batch(doc_list)
|
||||
batch_ids = batch_result.raw_id_to_parent
|
||||
batch_nodes = batch_result.hierarchy_nodes
|
||||
doc_batch_processing_func(batch_ids)
|
||||
all_raw_id_to_parent.update(batch_ids)
|
||||
all_hierarchy_nodes.extend(batch_nodes)
|
||||
|
||||
if callback:
|
||||
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
|
||||
if callback:
|
||||
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
|
||||
except Exception as e:
|
||||
# Best-effort rate limit detection via string matching.
|
||||
# Connectors surface rate limits inconsistently — some raise HTTP 429,
|
||||
# some use SDK-specific exceptions (e.g. google.api_core.exceptions.ResourceExhausted)
|
||||
# that may or may not include "rate limit" or "429" in the message.
|
||||
# TODO(Bo): replace with a standard ConnectorRateLimitError exception that all
|
||||
# connectors raise when rate limited, making this check precise.
|
||||
error_str = str(e)
|
||||
if "rate limit" in error_str.lower() or "429" in error_str:
|
||||
inc_pruning_rate_limit_error(connector_type)
|
||||
raise
|
||||
finally:
|
||||
observe_pruning_enumeration_duration(
|
||||
time.monotonic() - enumeration_start, connector_type
|
||||
)
|
||||
|
||||
return SlimConnectorExtractionResult(
|
||||
raw_id_to_parent=all_raw_id_to_parent,
|
||||
|
||||
@@ -79,6 +79,15 @@ beat_task_templates: list[dict] = [
|
||||
"skip_gated": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-dangling-import-jobs",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_DANGLING_IMPORT_JOBS,
|
||||
"schedule": timedelta(minutes=10),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-index-attempt-cleanup",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEX_ATTEMPT_CLEANUP,
|
||||
|
||||
@@ -59,6 +59,11 @@ from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_blocked
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_completed
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_fence_reset
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_started
|
||||
from onyx.server.metrics.deletion_metrics import observe_deletion_taskset_duration
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
@@ -102,7 +107,7 @@ def revoke_tasks_blocking_deletion(
|
||||
f"Revoked permissions sync task {permissions_sync_payload.celery_task_id}."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while revoking pruning task")
|
||||
task_logger.exception("Exception while revoking permissions sync task")
|
||||
|
||||
try:
|
||||
prune_payload = redis_connector.prune.payload
|
||||
@@ -110,7 +115,7 @@ def revoke_tasks_blocking_deletion(
|
||||
app.control.revoke(prune_payload.celery_task_id)
|
||||
task_logger.info(f"Revoked pruning task {prune_payload.celery_task_id}.")
|
||||
except Exception:
|
||||
task_logger.exception("Exception while revoking permissions sync task")
|
||||
task_logger.exception("Exception while revoking pruning task")
|
||||
|
||||
try:
|
||||
external_group_sync_payload = redis_connector.external_group_sync.payload
|
||||
@@ -300,6 +305,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
recent_index_attempts
|
||||
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
inc_deletion_blocked(tenant_id, "indexing")
|
||||
raise TaskDependencyError(
|
||||
"Connector deletion - Delayed (indexing in progress): "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
@@ -307,11 +313,13 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
)
|
||||
|
||||
if redis_connector.prune.fenced:
|
||||
inc_deletion_blocked(tenant_id, "pruning")
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (pruning in progress): cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
if redis_connector.permissions.fenced:
|
||||
inc_deletion_blocked(tenant_id, "permissions")
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (permissions in progress): cc_pair={cc_pair_id}"
|
||||
)
|
||||
@@ -359,6 +367,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
# set this only after all tasks have been added
|
||||
fence_payload.num_tasks = tasks_generated
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
inc_deletion_started(tenant_id)
|
||||
|
||||
return tasks_generated
|
||||
|
||||
@@ -508,7 +517,11 @@ def monitor_connector_deletion_taskset(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id_to_delete,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
if not connector:
|
||||
task_logger.info(
|
||||
"Connector deletion - Connector already deleted, skipping connector cleanup"
|
||||
)
|
||||
elif not len(connector.credentials):
|
||||
task_logger.info(
|
||||
"Connector deletion - Found no credentials left for connector, deleting connector"
|
||||
)
|
||||
@@ -523,6 +536,12 @@ def monitor_connector_deletion_taskset(
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
duration = (
|
||||
datetime.now(timezone.utc) - fence_data.submitted
|
||||
).total_seconds()
|
||||
observe_deletion_taskset_duration(tenant_id, "success", duration)
|
||||
inc_deletion_completed(tenant_id, "success")
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
stack_trace = traceback.format_exc()
|
||||
@@ -541,6 +560,11 @@ def monitor_connector_deletion_taskset(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
|
||||
)
|
||||
duration = (
|
||||
datetime.now(timezone.utc) - fence_data.submitted
|
||||
).total_seconds()
|
||||
observe_deletion_taskset_duration(tenant_id, "failure", duration)
|
||||
inc_deletion_completed(tenant_id, "failure")
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
@@ -717,5 +741,6 @@ def validate_connector_deletion_fence(
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
inc_deletion_fence_reset(tenant_id)
|
||||
redis_connector.delete.reset()
|
||||
return
|
||||
|
||||
@@ -172,6 +172,10 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
|
||||
task_logger.debug(
|
||||
"Verified tenant info, migration record, and search settings."
|
||||
)
|
||||
|
||||
# 2.e. Build sanitized to original doc ID mapping to check for
|
||||
# conflicts in the event we sanitize a doc ID to an
|
||||
# already-existing doc ID.
|
||||
@@ -325,6 +329,7 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
task_logger.debug("Released the OpenSearch migration lock.")
|
||||
else:
|
||||
task_logger.warning(
|
||||
"The OpenSearch migration lock was not owned on completion of the migration task."
|
||||
|
||||
@@ -38,6 +38,7 @@ from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.connector import mark_ccpair_as_pruned
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
@@ -72,6 +73,7 @@ from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
|
||||
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.metrics.pruning_metrics import observe_pruning_diff_duration
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.server.utils import make_short_id
|
||||
from onyx.utils.logger import format_error_for_logging
|
||||
@@ -524,6 +526,14 @@ def connector_pruning_generator_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
# Session 1: pre-enumeration — load cc_pair and instantiate the connector.
|
||||
# The session is closed before enumeration so the DB connection is not held
|
||||
# open during the 10–30+ minute connector crawl.
|
||||
connector_source: DocumentSource | None = None
|
||||
connector_type: str = ""
|
||||
is_connector_public: bool = False
|
||||
runnable_connector: BaseConnector | None = None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
@@ -549,48 +559,51 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
redis_connector.prune.set_fence(new_payload)
|
||||
|
||||
connector_source = cc_pair.connector.source
|
||||
connector_type = connector_source.value
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning generator running connector: cc_pair={cc_pair_id} connector_source={cc_pair.connector.source}"
|
||||
f"Pruning generator running connector: cc_pair={cc_pair_id} connector_source={connector_source}"
|
||||
)
|
||||
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session,
|
||||
cc_pair.connector.source,
|
||||
connector_source,
|
||||
InputType.SLIM_RETRIEVAL,
|
||||
cc_pair.connector.connector_specific_config,
|
||||
cc_pair.credential,
|
||||
)
|
||||
# Session 1 closed here — connection released before enumeration.
|
||||
|
||||
callback = PruneCallback(
|
||||
0,
|
||||
redis_connector,
|
||||
lock,
|
||||
r,
|
||||
timeout_seconds=JOB_TIMEOUT,
|
||||
)
|
||||
callback = PruneCallback(
|
||||
0,
|
||||
redis_connector,
|
||||
lock,
|
||||
r,
|
||||
timeout_seconds=JOB_TIMEOUT,
|
||||
)
|
||||
|
||||
# Extract docs and hierarchy nodes from the source
|
||||
extraction_result = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback
|
||||
)
|
||||
all_connector_doc_ids = extraction_result.raw_id_to_parent
|
||||
# Extract docs and hierarchy nodes from the source (no DB session held).
|
||||
extraction_result = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback, connector_type=connector_type
|
||||
)
|
||||
all_connector_doc_ids = extraction_result.raw_id_to_parent
|
||||
|
||||
# Process hierarchy nodes (same as docfetching):
|
||||
# upsert to Postgres and cache in Redis
|
||||
source = cc_pair.connector.source
|
||||
# Session 2: post-enumeration — hierarchy upserts, diff computation, task dispatch.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
source = connector_source
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
|
||||
upserted_nodes: list[DBHierarchyNode] = []
|
||||
if extraction_result.hierarchy_nodes:
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=extraction_result.hierarchy_nodes,
|
||||
source=source,
|
||||
commit=True,
|
||||
commit=False,
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
@@ -599,9 +612,13 @@ def connector_pruning_generator_task(
|
||||
hierarchy_node_ids=[n.id for n in upserted_nodes],
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
commit=True,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
# Single commit so the FK reference in the join table can never
|
||||
# outrun the parent hierarchy_node insert.
|
||||
db_session.commit()
|
||||
|
||||
cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node)
|
||||
for node in upserted_nodes
|
||||
@@ -636,40 +653,46 @@ def connector_pruning_generator_task(
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# a list of docs in our local index
|
||||
all_indexed_document_ids = {
|
||||
doc.id
|
||||
for doc in get_documents_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
diff_start = time.monotonic()
|
||||
try:
|
||||
# a list of docs in our local index
|
||||
all_indexed_document_ids = {
|
||||
doc.id
|
||||
for doc in get_documents_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
}
|
||||
|
||||
# generate list of docs to remove (no longer in the source)
|
||||
doc_ids_to_remove = list(
|
||||
all_indexed_document_ids - all_connector_doc_ids.keys()
|
||||
)
|
||||
}
|
||||
|
||||
# generate list of docs to remove (no longer in the source)
|
||||
doc_ids_to_remove = list(
|
||||
all_indexed_document_ids - all_connector_doc_ids.keys()
|
||||
)
|
||||
task_logger.info(
|
||||
"Pruning set collected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector_source={connector_source} "
|
||||
f"docs_to_remove={len(doc_ids_to_remove)}"
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
"Pruning set collected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector_source={cc_pair.connector.source} "
|
||||
f"docs_to_remove={len(doc_ids_to_remove)}"
|
||||
)
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.prune.generate_tasks(
|
||||
set(doc_ids_to_remove), self.app, db_session, None
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.prune.generate_tasks(
|
||||
set(doc_ids_to_remove), self.app, db_session, None
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks finished. cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks finished. cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
finally:
|
||||
observe_pruning_diff_duration(
|
||||
time.monotonic() - diff_start, connector_type
|
||||
)
|
||||
|
||||
redis_connector.prune.generator_complete = tasks_generated
|
||||
|
||||
|
||||
@@ -23,6 +23,8 @@ class IndexAttemptErrorPydantic(BaseModel):
|
||||
|
||||
index_attempt_id: int
|
||||
|
||||
error_type: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, model: IndexAttemptError) -> "IndexAttemptErrorPydantic":
|
||||
return cls(
|
||||
@@ -37,4 +39,5 @@ class IndexAttemptErrorPydantic(BaseModel):
|
||||
is_resolved=model.is_resolved,
|
||||
time_created=model.time_created,
|
||||
index_attempt_id=model.index_attempt_id,
|
||||
error_type=model.error_type,
|
||||
)
|
||||
|
||||
@@ -364,7 +364,7 @@ def _get_or_extract_plaintext(
|
||||
plaintext_io = file_store.read_file(plaintext_key, mode="b")
|
||||
return plaintext_io.read().decode("utf-8")
|
||||
except Exception:
|
||||
logger.exception(f"Error when reading file, id={file_id}")
|
||||
logger.info(f"Cache miss for file with id={file_id}")
|
||||
|
||||
# Cache miss — extract and store.
|
||||
content_text = extract_fn()
|
||||
|
||||
@@ -4,8 +4,6 @@ from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_utils import create_tool_call_failure_messages
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
@@ -635,7 +633,6 @@ def run_llm_loop(
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
forced_tool_id: int | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
@@ -1020,20 +1017,16 @@ def run_llm_loop(
|
||||
persisted_memory_id: int | None = None
|
||||
if user_memory_context and user_memory_context.user_id:
|
||||
if tool_response.rich_response.index_to_replace is not None:
|
||||
memory = update_memory_at_index(
|
||||
persisted_memory_id = update_memory_at_index(
|
||||
user_id=user_memory_context.user_id,
|
||||
index=tool_response.rich_response.index_to_replace,
|
||||
new_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id if memory else None
|
||||
else:
|
||||
memory = add_memory(
|
||||
persisted_memory_id = add_memory(
|
||||
user_id=user_memory_context.user_id,
|
||||
memory_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id
|
||||
operation: Literal["add", "update"] = (
|
||||
"update"
|
||||
if tool_response.rich_response.index_to_replace is not None
|
||||
|
||||
@@ -67,7 +67,6 @@ from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import reserve_multi_model_message_ids
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
@@ -1006,93 +1005,86 @@ def _run_models(
|
||||
model_llm = setup.llms[model_idx]
|
||||
|
||||
try:
|
||||
# Each worker opens its own session — SQLAlchemy sessions are not thread-safe.
|
||||
# Do NOT write to the outer db_session (or any shared DB state) from here;
|
||||
# all DB writes in this thread must go through thread_db_session.
|
||||
with get_session_with_current_tenant() as thread_db_session:
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
db_session=thread_db_session,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=setup.new_msg_req.internal_search_filters,
|
||||
project_id_filter=setup.search_params.project_id_filter,
|
||||
persona_id_filter=setup.search_params.persona_id_filter,
|
||||
bypass_acl=setup.bypass_acl,
|
||||
slack_context=setup.slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
setup.persona, setup.new_msg_req.internal_search_filters
|
||||
),
|
||||
# Each function opens short-lived DB sessions on demand.
|
||||
# Do NOT pass a long-lived session here — it would hold a
|
||||
# connection for the entire LLM loop (minutes), and cloud
|
||||
# infrastructure may drop idle connections.
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=setup.new_msg_req.internal_search_filters,
|
||||
project_id_filter=setup.search_params.project_id_filter,
|
||||
persona_id_filter=setup.search_params.persona_id_filter,
|
||||
bypass_acl=setup.bypass_acl,
|
||||
slack_context=setup.slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
setup.persona, setup.new_msg_req.internal_search_filters
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
message_id=setup.user_message.id,
|
||||
additional_headers=setup.custom_tool_additional_headers,
|
||||
mcp_headers=setup.mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=setup.available_files.user_file_ids,
|
||||
chat_file_ids=setup.available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=setup.search_params.search_usage,
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
message_id=setup.user_message.id,
|
||||
additional_headers=setup.custom_tool_additional_headers,
|
||||
mcp_headers=setup.mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=setup.available_files.user_file_ids,
|
||||
chat_file_ids=setup.available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=setup.search_params.search_usage,
|
||||
)
|
||||
model_tools = [
|
||||
tool for tool_list in thread_tool_dict.values() for tool in tool_list
|
||||
]
|
||||
|
||||
if setup.forced_tool_id and setup.forced_tool_id not in {
|
||||
tool.id for tool in model_tools
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Forced tool {setup.forced_tool_id} not found in tools"
|
||||
)
|
||||
model_tools = [
|
||||
tool
|
||||
for tool_list in thread_tool_dict.values()
|
||||
for tool in tool_list
|
||||
]
|
||||
|
||||
if setup.forced_tool_id and setup.forced_tool_id not in {
|
||||
tool.id for tool in model_tools
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Forced tool {setup.forced_tool_id} not found in tools"
|
||||
)
|
||||
|
||||
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
|
||||
if n_models == 1 and setup.new_msg_req.deep_research:
|
||||
if setup.chat_session.project_id:
|
||||
raise RuntimeError(
|
||||
"Deep research is not supported for projects"
|
||||
)
|
||||
run_deep_research_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
skip_clarification=setup.skip_clarification,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
context_files=setup.extracted_context_files,
|
||||
persona=setup.persona,
|
||||
user_memory_context=setup.user_memory_context,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
forced_tool_id=setup.forced_tool_id,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
chat_files=setup.chat_files_for_tools,
|
||||
include_citations=setup.new_msg_req.include_citations,
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
|
||||
if n_models == 1 and setup.new_msg_req.deep_research:
|
||||
if setup.chat_session.project_id:
|
||||
raise RuntimeError("Deep research is not supported for projects")
|
||||
run_deep_research_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
skip_clarification=setup.skip_clarification,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
context_files=setup.extracted_context_files,
|
||||
persona=setup.persona,
|
||||
user_memory_context=setup.user_memory_context,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
forced_tool_id=setup.forced_tool_id,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
chat_files=setup.chat_files_for_tools,
|
||||
include_citations=setup.new_msg_req.include_citations,
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
model_succeeded[model_idx] = True
|
||||
|
||||
|
||||
@@ -449,6 +449,7 @@ class OnyxRedisLocks:
|
||||
"da_lock:check_connector_external_group_sync_beat"
|
||||
)
|
||||
OPENSEARCH_MIGRATION_BEAT_LOCK = "da_lock:opensearch_migration_beat"
|
||||
CHECK_DANGLING_IMPORT_JOBS_BEAT_LOCK = "da_lock:check_dangling_import_jobs_beat"
|
||||
|
||||
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
|
||||
CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
|
||||
@@ -612,6 +613,9 @@ class OnyxCeleryTask:
|
||||
# Hook execution log retention
|
||||
HOOK_EXECUTION_LOG_CLEANUP_TASK = "hook_execution_log_cleanup_task"
|
||||
|
||||
# Proposal review import cleanup
|
||||
CHECK_FOR_DANGLING_IMPORT_JOBS = "check_for_dangling_import_jobs"
|
||||
|
||||
# Sandbox cleanup
|
||||
CLEANUP_IDLE_SANDBOXES = "cleanup_idle_sandboxes"
|
||||
CLEANUP_OLD_SNAPSHOTS = "cleanup_old_snapshots"
|
||||
|
||||
@@ -61,6 +61,9 @@ _USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
_USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
_MINIMUM_PAGINATION_LIMIT = 5
|
||||
|
||||
_SERVER_ERROR_CODES = {500, 502, 503, 504}
|
||||
|
||||
_CONFLUENCE_SPACES_API_V1 = "rest/api/space"
|
||||
_CONFLUENCE_SPACES_API_V2 = "wiki/api/v2/spaces"
|
||||
@@ -569,7 +572,8 @@ class OnyxConfluence:
|
||||
if not limit:
|
||||
limit = _DEFAULT_PAGINATION_LIMIT
|
||||
|
||||
url_suffix = update_param_in_path(url_suffix, "limit", str(limit))
|
||||
current_limit = limit
|
||||
url_suffix = update_param_in_path(url_suffix, "limit", str(current_limit))
|
||||
|
||||
while url_suffix:
|
||||
logger.debug(f"Making confluence call to {url_suffix}")
|
||||
@@ -609,40 +613,61 @@ class OnyxConfluence:
|
||||
)
|
||||
continue
|
||||
|
||||
# If we fail due to a 500, try one by one.
|
||||
# NOTE: this iterative approach only works for server, since cloud uses cursor-based
|
||||
# pagination
|
||||
if raw_response.status_code == 500 and not self._is_cloud:
|
||||
initial_start = get_start_param_from_url(url_suffix)
|
||||
if initial_start is None:
|
||||
# can't handle this if we don't have offset-based pagination
|
||||
raise
|
||||
if raw_response.status_code in _SERVER_ERROR_CODES:
|
||||
# Try reducing the page size -- Confluence often times out
|
||||
# on large result sets (especially Cloud 504s).
|
||||
if current_limit > _MINIMUM_PAGINATION_LIMIT:
|
||||
old_limit = current_limit
|
||||
current_limit = max(
|
||||
current_limit // 2, _MINIMUM_PAGINATION_LIMIT
|
||||
)
|
||||
logger.warning(
|
||||
f"Confluence returned {raw_response.status_code}. "
|
||||
f"Reducing limit from {old_limit} to {current_limit} "
|
||||
f"and retrying."
|
||||
)
|
||||
url_suffix = update_param_in_path(
|
||||
url_suffix, "limit", str(current_limit)
|
||||
)
|
||||
continue
|
||||
|
||||
# this will just yield the successful items from the batch
|
||||
new_url_suffix = yield from self._try_one_by_one_for_paginated_url(
|
||||
url_suffix,
|
||||
initial_start=initial_start,
|
||||
limit=limit,
|
||||
)
|
||||
# Limit reduction exhausted -- for Server, fall back to
|
||||
# one-by-one offset pagination as a last resort.
|
||||
if not self._is_cloud:
|
||||
initial_start = get_start_param_from_url(url_suffix)
|
||||
# this will just yield the successful items from the batch
|
||||
new_url_suffix = (
|
||||
yield from self._try_one_by_one_for_paginated_url(
|
||||
url_suffix,
|
||||
initial_start=initial_start,
|
||||
limit=current_limit,
|
||||
)
|
||||
)
|
||||
# this means we ran into an empty page
|
||||
if new_url_suffix is None:
|
||||
if next_page_callback:
|
||||
next_page_callback("")
|
||||
break
|
||||
|
||||
# this means we ran into an empty page
|
||||
if new_url_suffix is None:
|
||||
if next_page_callback:
|
||||
next_page_callback("")
|
||||
break
|
||||
url_suffix = new_url_suffix
|
||||
continue
|
||||
|
||||
url_suffix = new_url_suffix
|
||||
continue
|
||||
|
||||
else:
|
||||
logger.exception(
|
||||
f"Error in confluence call to {url_suffix} \n"
|
||||
f"Raw Response Text: {raw_response.text} \n"
|
||||
f"Full Response: {raw_response.__dict__} \n"
|
||||
f"Error: {e} \n"
|
||||
f"Error in confluence call to {url_suffix} "
|
||||
f"after reducing limit to {current_limit}.\n"
|
||||
f"Raw Response Text: {raw_response.text}\n"
|
||||
f"Error: {e}\n"
|
||||
)
|
||||
raise
|
||||
|
||||
logger.exception(
|
||||
f"Error in confluence call to {url_suffix} \n"
|
||||
f"Raw Response Text: {raw_response.text} \n"
|
||||
f"Full Response: {raw_response.__dict__} \n"
|
||||
f"Error: {e} \n"
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
next_response = raw_response.json()
|
||||
except Exception as e:
|
||||
@@ -680,6 +705,10 @@ class OnyxConfluence:
|
||||
old_url_suffix = url_suffix
|
||||
updated_start = get_start_param_from_url(old_url_suffix)
|
||||
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
|
||||
if url_suffix and current_limit != limit:
|
||||
url_suffix = update_param_in_path(
|
||||
url_suffix, "limit", str(current_limit)
|
||||
)
|
||||
for i, result in enumerate(results):
|
||||
updated_start += 1
|
||||
if url_suffix and next_page_callback and i == len(results) - 1:
|
||||
|
||||
@@ -42,6 +42,9 @@ from onyx.connectors.google_drive.file_retrieval import (
|
||||
get_all_files_in_my_drive_and_shared,
|
||||
)
|
||||
from onyx.connectors.google_drive.file_retrieval import get_external_access_for_folder
|
||||
from onyx.connectors.google_drive.file_retrieval import (
|
||||
get_files_by_web_view_links_batch,
|
||||
)
|
||||
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_folder_metadata
|
||||
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
|
||||
@@ -70,11 +73,13 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import NormalizationResult
|
||||
from onyx.connectors.interfaces import Resolver
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
@@ -202,7 +207,9 @@ class DriveIdStatus(Enum):
|
||||
|
||||
|
||||
class GoogleDriveConnector(
|
||||
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
|
||||
SlimConnectorWithPermSync,
|
||||
CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint],
|
||||
Resolver,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1665,6 +1672,82 @@ class GoogleDriveConnector(
|
||||
start, end, checkpoint, include_permissions=True
|
||||
)
|
||||
|
||||
@override
|
||||
def resolve_errors(
|
||||
self,
|
||||
errors: list[ConnectorFailure],
|
||||
include_permissions: bool = False,
|
||||
) -> Generator[Document | ConnectorFailure | HierarchyNode, None, None]:
|
||||
if self._creds is None or self._primary_admin_email is None:
|
||||
raise RuntimeError(
|
||||
"Credentials missing, should not call this method before calling load_credentials"
|
||||
)
|
||||
|
||||
logger.info(f"Resolving {len(errors)} errors")
|
||||
doc_ids = [
|
||||
failure.failed_document.document_id
|
||||
for failure in errors
|
||||
if failure.failed_document
|
||||
]
|
||||
service = get_drive_service(self.creds, self.primary_admin_email)
|
||||
field_type = (
|
||||
DriveFileFieldType.WITH_PERMISSIONS
|
||||
if include_permissions or self.exclude_domain_link_only
|
||||
else DriveFileFieldType.STANDARD
|
||||
)
|
||||
batch_result = get_files_by_web_view_links_batch(service, doc_ids, field_type)
|
||||
|
||||
for doc_id, error in batch_result.errors.items():
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=doc_id,
|
||||
document_link=doc_id,
|
||||
),
|
||||
failure_message=f"Failed to retrieve file during error resolution: {error}",
|
||||
exception=error,
|
||||
)
|
||||
|
||||
permission_sync_context = (
|
||||
PermissionSyncContext(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
)
|
||||
if include_permissions
|
||||
else None
|
||||
)
|
||||
|
||||
retrieved_files = [
|
||||
RetrievedDriveFile(
|
||||
drive_file=file,
|
||||
user_email=self.primary_admin_email,
|
||||
completion_stage=DriveRetrievalStage.DONE,
|
||||
)
|
||||
for file in batch_result.files.values()
|
||||
]
|
||||
|
||||
yield from self._get_new_ancestors_for_files(
|
||||
files=retrieved_files,
|
||||
seen_hierarchy_node_raw_ids=ThreadSafeSet(),
|
||||
fully_walked_hierarchy_node_raw_ids=ThreadSafeSet(),
|
||||
permission_sync_context=permission_sync_context,
|
||||
add_prefix=True,
|
||||
)
|
||||
|
||||
func_with_args = [
|
||||
(
|
||||
self._convert_retrieved_file_to_document,
|
||||
(rf, permission_sync_context),
|
||||
)
|
||||
for rf in retrieved_files
|
||||
]
|
||||
results = cast(
|
||||
list[Document | ConnectorFailure | None],
|
||||
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
|
||||
)
|
||||
for result in results:
|
||||
if result is not None:
|
||||
yield result
|
||||
|
||||
def _extract_slim_docs_from_google_drive(
|
||||
self,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
|
||||
@@ -9,6 +9,7 @@ from urllib.parse import urlparse
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from googleapiclient.http import BatchHttpRequest # type: ignore
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
|
||||
@@ -60,6 +61,8 @@ SLIM_FILE_FIELDS = (
|
||||
)
|
||||
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
|
||||
|
||||
MAX_BATCH_SIZE = 100
|
||||
|
||||
HIERARCHY_FIELDS = "id, name, parents, webViewLink, mimeType, driveId"
|
||||
|
||||
HIERARCHY_FIELDS_WITH_PERMISSIONS = (
|
||||
@@ -216,7 +219,7 @@ def get_external_access_for_folder(
|
||||
|
||||
|
||||
def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
|
||||
"""Get the appropriate fields string based on the field type enum"""
|
||||
"""Get the appropriate fields string for files().list() based on the field type enum."""
|
||||
if field_type == DriveFileFieldType.SLIM:
|
||||
return SLIM_FILE_FIELDS
|
||||
elif field_type == DriveFileFieldType.WITH_PERMISSIONS:
|
||||
@@ -225,6 +228,25 @@ def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
|
||||
return FILE_FIELDS
|
||||
|
||||
|
||||
def _extract_single_file_fields(list_fields: str) -> str:
|
||||
"""Convert a files().list() fields string to one suitable for files().get().
|
||||
|
||||
List fields look like "nextPageToken, files(field1, field2, ...)"
|
||||
Single-file fields should be just "field1, field2, ..."
|
||||
"""
|
||||
start = list_fields.find("files(")
|
||||
if start == -1:
|
||||
return list_fields
|
||||
inner_start = start + len("files(")
|
||||
inner_end = list_fields.rfind(")")
|
||||
return list_fields[inner_start:inner_end]
|
||||
|
||||
|
||||
def _get_single_file_fields(field_type: DriveFileFieldType) -> str:
|
||||
"""Get the appropriate fields string for files().get() based on the field type enum."""
|
||||
return _extract_single_file_fields(_get_fields_for_file_type(field_type))
|
||||
|
||||
|
||||
def _get_files_in_parent(
|
||||
service: Resource,
|
||||
parent_id: str,
|
||||
@@ -536,3 +558,74 @@ def get_file_by_web_view_link(
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
|
||||
class BatchRetrievalResult:
|
||||
"""Result of a batch file retrieval, separating successes from errors."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.files: dict[str, GoogleDriveFileType] = {}
|
||||
self.errors: dict[str, Exception] = {}
|
||||
|
||||
|
||||
def get_files_by_web_view_links_batch(
|
||||
service: GoogleDriveService,
|
||||
web_view_links: list[str],
|
||||
field_type: DriveFileFieldType,
|
||||
) -> BatchRetrievalResult:
|
||||
"""Retrieve multiple Google Drive files by webViewLink using the batch API.
|
||||
|
||||
Returns a BatchRetrievalResult containing successful file retrievals
|
||||
and errors for any files that could not be fetched.
|
||||
Automatically splits into chunks of MAX_BATCH_SIZE.
|
||||
"""
|
||||
fields = _get_single_file_fields(field_type)
|
||||
if len(web_view_links) <= MAX_BATCH_SIZE:
|
||||
return _get_files_by_web_view_links_batch(service, web_view_links, fields)
|
||||
|
||||
combined = BatchRetrievalResult()
|
||||
for i in range(0, len(web_view_links), MAX_BATCH_SIZE):
|
||||
chunk = web_view_links[i : i + MAX_BATCH_SIZE]
|
||||
chunk_result = _get_files_by_web_view_links_batch(service, chunk, fields)
|
||||
combined.files.update(chunk_result.files)
|
||||
combined.errors.update(chunk_result.errors)
|
||||
return combined
|
||||
|
||||
|
||||
def _get_files_by_web_view_links_batch(
|
||||
service: GoogleDriveService,
|
||||
web_view_links: list[str],
|
||||
fields: str,
|
||||
) -> BatchRetrievalResult:
|
||||
"""Single-batch implementation."""
|
||||
|
||||
result = BatchRetrievalResult()
|
||||
|
||||
def callback(
|
||||
request_id: str,
|
||||
response: GoogleDriveFileType,
|
||||
exception: Exception | None,
|
||||
) -> None:
|
||||
if exception:
|
||||
logger.warning(f"Error retrieving file {request_id}: {exception}")
|
||||
result.errors[request_id] = exception
|
||||
else:
|
||||
result.files[request_id] = response
|
||||
|
||||
batch = cast(BatchHttpRequest, service.new_batch_http_request(callback=callback))
|
||||
|
||||
for web_view_link in web_view_links:
|
||||
try:
|
||||
file_id = _extract_file_id_from_web_view_link(web_view_link)
|
||||
request = service.files().get(
|
||||
fileId=file_id,
|
||||
supportsAllDrives=True,
|
||||
fields=fields,
|
||||
)
|
||||
batch.add(request, request_id=web_view_link)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to extract file ID from {web_view_link}: {e}")
|
||||
result.errors[web_view_link] = e
|
||||
|
||||
batch.execute()
|
||||
return result
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import ParseResult
|
||||
@@ -53,6 +54,21 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _load_google_json(raw: object) -> dict[str, Any]:
|
||||
"""Accept both the current (dict) and legacy (JSON string) KV payload shapes.
|
||||
|
||||
Payloads written before the fix for serializing Google credentials into
|
||||
``EncryptedJson`` columns are stored as JSON strings; new writes store dicts.
|
||||
Once every install has re-uploaded their Google credentials the legacy
|
||||
``str`` branch can be removed.
|
||||
"""
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
return json.loads(raw)
|
||||
raise ValueError(f"Unexpected Google credential payload type: {type(raw)!r}")
|
||||
|
||||
|
||||
def _build_frontend_google_drive_redirect(source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
|
||||
@@ -162,12 +178,13 @@ def build_service_account_creds(
|
||||
|
||||
def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
credential_json = _load_google_json(
|
||||
get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY)
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
credential_json = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=GOOGLE_SCOPES[source],
|
||||
@@ -188,12 +205,12 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
|
||||
def get_google_app_cred(source: DocumentSource) -> GoogleAppCredentials:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
creds = _load_google_json(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
return GoogleAppCredentials(**creds)
|
||||
|
||||
|
||||
def upsert_google_app_cred(
|
||||
@@ -201,10 +218,14 @@ def upsert_google_app_cred(
|
||||
) -> None:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
KV_GOOGLE_DRIVE_CRED_KEY,
|
||||
app_credentials.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_CRED_KEY, app_credentials.model_dump(mode="json"), encrypt=True
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
@@ -220,12 +241,14 @@ def delete_google_app_cred(source: DocumentSource) -> None:
|
||||
|
||||
def get_service_account_key(source: DocumentSource) -> GoogleServiceAccountKey:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
|
||||
creds = _load_google_json(
|
||||
get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
creds = _load_google_json(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
return GoogleServiceAccountKey(**creds)
|
||||
|
||||
|
||||
def upsert_service_account_key(
|
||||
@@ -234,12 +257,14 @@ def upsert_service_account_key(
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY,
|
||||
service_account_key.json(),
|
||||
service_account_key.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY,
|
||||
service_account_key.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
@@ -298,6 +298,22 @@ class CheckpointedConnectorWithPermSync(CheckpointedConnector[CT]):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Resolver(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def resolve_errors(
|
||||
self,
|
||||
errors: list[ConnectorFailure],
|
||||
include_permissions: bool = False,
|
||||
) -> Generator[Document | ConnectorFailure | HierarchyNode, None, None]:
|
||||
"""Attempts to yield back ALL the documents described by the errors, no checkpointing.
|
||||
|
||||
Caller's responsibility is to delete the old ConnectorFailures and replace with the new ones.
|
||||
If include_permissions is True, the documents will have permissions synced.
|
||||
May also yield HierarchyNode objects for ancestor folders of resolved documents.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class HierarchyConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def load_hierarchy(
|
||||
|
||||
@@ -8,6 +8,7 @@ from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
@@ -40,6 +41,7 @@ from onyx.connectors.jira.utils import best_effort_basic_expert_info
|
||||
from onyx.connectors.jira.utils import best_effort_get_field_from_issue
|
||||
from onyx.connectors.jira.utils import build_jira_client
|
||||
from onyx.connectors.jira.utils import build_jira_url
|
||||
from onyx.connectors.jira.utils import CustomFieldExtractor
|
||||
from onyx.connectors.jira.utils import extract_text_from_adf
|
||||
from onyx.connectors.jira.utils import get_comment_strs
|
||||
from onyx.connectors.jira.utils import JIRA_CLOUD_API_VERSION
|
||||
@@ -52,6 +54,7 @@ from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -60,8 +63,11 @@ logger = setup_logger()
|
||||
|
||||
ONE_HOUR = 3600
|
||||
|
||||
_MAX_RESULTS_FETCH_IDS = 5000 # 5000
|
||||
_MAX_RESULTS_FETCH_IDS = 5000
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
|
||||
_JIRA_BULK_FETCH_LIMIT = 100
|
||||
_MAX_ATTACHMENT_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB
|
||||
|
||||
# Constants for Jira field names
|
||||
_FIELD_REPORTER = "reporter"
|
||||
@@ -255,15 +261,13 @@ def _bulk_fetch_request(
|
||||
return resp.json()["issues"]
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
def _bulk_fetch_batch(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch a single batch (must be <= _JIRA_BULK_FETCH_LIMIT).
|
||||
On JSONDecodeError, recursively bisects until it succeeds or reaches size 1."""
|
||||
try:
|
||||
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
return _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
if len(issue_ids) <= 1:
|
||||
logger.exception(
|
||||
@@ -277,12 +281,25 @@ def bulk_fetch_issues(
|
||||
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
|
||||
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
|
||||
)
|
||||
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
|
||||
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
|
||||
left = _bulk_fetch_batch(jira_client, issue_ids[:mid], fields)
|
||||
right = _bulk_fetch_batch(jira_client, issue_ids[mid:], fields)
|
||||
return left + right
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
raw_issues: list[dict[str, Any]] = []
|
||||
for batch in chunked(issue_ids, _JIRA_BULK_FETCH_LIMIT):
|
||||
try:
|
||||
raw_issues.extend(_bulk_fetch_batch(jira_client, list(batch), fields))
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise
|
||||
|
||||
return [
|
||||
Issue(jira_client._options, jira_client._session, raw=issue)
|
||||
@@ -364,6 +381,7 @@ def process_jira_issue(
|
||||
comment_email_blacklist: tuple[str, ...] = (),
|
||||
labels_to_skip: set[str] | None = None,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
custom_fields_mapping: dict[str, str] | None = None,
|
||||
) -> Document | None:
|
||||
if labels_to_skip:
|
||||
if any(label in issue.fields.labels for label in labels_to_skip):
|
||||
@@ -449,6 +467,24 @@ def process_jira_issue(
|
||||
else:
|
||||
logger.error(f"Project should exist but does not for {issue.key}")
|
||||
|
||||
# Merge custom fields into metadata if a mapping was provided
|
||||
if custom_fields_mapping:
|
||||
try:
|
||||
custom_fields = CustomFieldExtractor.get_issue_custom_fields(
|
||||
issue, custom_fields_mapping
|
||||
)
|
||||
# Filter out custom fields that collide with existing metadata keys
|
||||
for key in list(custom_fields.keys()):
|
||||
if key in metadata_dict:
|
||||
logger.warning(
|
||||
f"Custom field '{key}' on {issue.key} collides with "
|
||||
f"standard metadata key; skipping custom field value"
|
||||
)
|
||||
del custom_fields[key]
|
||||
metadata_dict.update(custom_fields)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract custom fields for {issue.key}: {e}")
|
||||
|
||||
return Document(
|
||||
id=page_url,
|
||||
sections=[TextSection(link=page_url, text=ticket_content)],
|
||||
@@ -491,6 +527,12 @@ class JiraConnector(
|
||||
# Custom JQL query to filter Jira issues
|
||||
jql_query: str | None = None,
|
||||
scoped_token: bool = False,
|
||||
# When True, extract custom fields from Jira issues and include them
|
||||
# in document metadata with human-readable field names.
|
||||
extract_custom_fields: bool = False,
|
||||
# When True, download attachments from Jira issues and yield them
|
||||
# as separate Documents linked to the parent ticket.
|
||||
fetch_attachments: bool = False,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
|
||||
@@ -504,7 +546,11 @@ class JiraConnector(
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
self.jql_query = jql_query
|
||||
self.scoped_token = scoped_token
|
||||
self.extract_custom_fields = extract_custom_fields
|
||||
self.fetch_attachments = fetch_attachments
|
||||
self._jira_client: JIRA | None = None
|
||||
# Mapping of custom field IDs to human-readable names (populated on load_credentials)
|
||||
self._custom_fields_mapping: dict[str, str] = {}
|
||||
# Cache project permissions to avoid fetching them repeatedly across runs
|
||||
self._project_permissions_cache: dict[str, Any] = {}
|
||||
|
||||
@@ -665,12 +711,134 @@ class JiraConnector(
|
||||
# the document belongs directly under the project in the hierarchy
|
||||
return project_key
|
||||
|
||||
def _process_attachments(
|
||||
self,
|
||||
issue: Issue,
|
||||
parent_hierarchy_raw_node_id: str | None,
|
||||
include_permissions: bool = False,
|
||||
project_key: str | None = None,
|
||||
) -> Generator[Document | ConnectorFailure, None, None]:
|
||||
"""Download and yield Documents for each attachment on a Jira issue.
|
||||
|
||||
Each attachment becomes a separate Document whose text is extracted
|
||||
from the downloaded file content. Failures on individual attachments
|
||||
are logged and yielded as ConnectorFailure so they never break the
|
||||
overall indexing run.
|
||||
"""
|
||||
attachments = best_effort_get_field_from_issue(issue, "attachment")
|
||||
if not attachments:
|
||||
return
|
||||
|
||||
issue_url = build_jira_url(self.jira_base, issue.key)
|
||||
|
||||
for attachment in attachments:
|
||||
try:
|
||||
filename = getattr(attachment, "filename", "unknown")
|
||||
try:
|
||||
size = int(getattr(attachment, "size", 0) or 0)
|
||||
except (ValueError, TypeError):
|
||||
size = 0
|
||||
content_url = getattr(attachment, "content", None)
|
||||
attachment_id = getattr(attachment, "id", filename)
|
||||
mime_type = getattr(attachment, "mimeType", "application/octet-stream")
|
||||
created = getattr(attachment, "created", None)
|
||||
|
||||
if size > _MAX_ATTACHMENT_SIZE_BYTES:
|
||||
logger.warning(
|
||||
f"Skipping attachment '{filename}' on {issue.key}: "
|
||||
f"size {size} bytes exceeds {_MAX_ATTACHMENT_SIZE_BYTES} byte limit"
|
||||
)
|
||||
continue
|
||||
|
||||
if not content_url:
|
||||
logger.warning(
|
||||
f"Skipping attachment '{filename}' on {issue.key}: "
|
||||
f"no content URL available"
|
||||
)
|
||||
continue
|
||||
|
||||
# Download the attachment using the public API on the
|
||||
# python-jira Attachment resource (avoids private _session access
|
||||
# and the double-copy from response.content + BytesIO wrapping).
|
||||
file_content = attachment.get()
|
||||
|
||||
# Extract text from the downloaded file
|
||||
try:
|
||||
text = extract_file_text(
|
||||
file=BytesIO(file_content),
|
||||
file_name=filename,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not extract text from attachment '{filename}' "
|
||||
f"on {issue.key}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not text or not text.strip():
|
||||
logger.info(
|
||||
f"Skipping attachment '{filename}' on {issue.key}: "
|
||||
f"no text content could be extracted"
|
||||
)
|
||||
continue
|
||||
|
||||
doc_id = f"{issue_url}/attachments/{attachment_id}"
|
||||
attachment_doc = Document(
|
||||
id=doc_id,
|
||||
sections=[TextSection(link=issue_url, text=text)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=f"{issue.key}: {filename}",
|
||||
title=filename,
|
||||
doc_updated_at=(time_str_to_utc(created) if created else None),
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
|
||||
metadata={
|
||||
"parent_ticket": issue.key,
|
||||
"attachment_filename": filename,
|
||||
"attachment_mime_type": mime_type,
|
||||
"attachment_size": str(size),
|
||||
},
|
||||
)
|
||||
if include_permissions and project_key:
|
||||
attachment_doc.external_access = self._get_project_permissions(
|
||||
project_key,
|
||||
add_prefix=True,
|
||||
)
|
||||
yield attachment_doc
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process attachment on {issue.key}: {e}")
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=f"{issue_url}/attachments/{getattr(attachment, 'id', 'unknown')}",
|
||||
document_link=issue_url,
|
||||
),
|
||||
failure_message=f"Failed to process attachment '{getattr(attachment, 'filename', 'unknown')}': {str(e)}",
|
||||
exception=e,
|
||||
)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._jira_client = build_jira_client(
|
||||
credentials=credentials,
|
||||
jira_base=self.jira_base,
|
||||
scoped_token=self.scoped_token,
|
||||
)
|
||||
|
||||
# Fetch the custom field ID-to-name mapping once at credential load time.
|
||||
# This avoids repeated API calls during issue processing.
|
||||
if self.extract_custom_fields:
|
||||
try:
|
||||
self._custom_fields_mapping = (
|
||||
CustomFieldExtractor.get_all_custom_fields(self._jira_client)
|
||||
)
|
||||
logger.info(
|
||||
f"Loaded {len(self._custom_fields_mapping)} custom field definitions"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch custom field definitions; "
|
||||
f"custom field extraction will be skipped: {e}"
|
||||
)
|
||||
self._custom_fields_mapping = {}
|
||||
|
||||
return None
|
||||
|
||||
def _get_jql_query(
|
||||
@@ -801,6 +969,11 @@ class JiraConnector(
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
|
||||
custom_fields_mapping=(
|
||||
self._custom_fields_mapping
|
||||
if self._custom_fields_mapping
|
||||
else None
|
||||
),
|
||||
):
|
||||
# Add permission information to the document if requested
|
||||
if include_permissions:
|
||||
@@ -810,6 +983,15 @@ class JiraConnector(
|
||||
)
|
||||
yield document
|
||||
|
||||
# Yield attachment documents if enabled
|
||||
if self.fetch_attachments:
|
||||
yield from self._process_attachments(
|
||||
issue=issue,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
|
||||
include_permissions=include_permissions,
|
||||
project_key=project_key,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
@@ -917,20 +1099,41 @@ class JiraConnector(
|
||||
issue_key = best_effort_get_field_from_issue(issue, _FIELD_KEY)
|
||||
doc_id = build_jira_url(self.jira_base, issue_key)
|
||||
|
||||
parent_hierarchy_raw_node_id = (
|
||||
self._get_parent_hierarchy_raw_node_id(issue, project_key)
|
||||
if project_key
|
||||
else None
|
||||
)
|
||||
project_perms = self._get_project_permissions(
|
||||
project_key, add_prefix=False
|
||||
)
|
||||
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=doc_id,
|
||||
# Permission sync path - don't prefix, upsert_document_external_perms handles it
|
||||
external_access=self._get_project_permissions(
|
||||
project_key, add_prefix=False
|
||||
),
|
||||
parent_hierarchy_raw_node_id=(
|
||||
self._get_parent_hierarchy_raw_node_id(issue, project_key)
|
||||
if project_key
|
||||
else None
|
||||
),
|
||||
external_access=project_perms,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Also emit SlimDocument entries for each attachment
|
||||
if self.fetch_attachments:
|
||||
attachments = best_effort_get_field_from_issue(issue, "attachment")
|
||||
if attachments:
|
||||
for attachment in attachments:
|
||||
attachment_id = getattr(
|
||||
attachment,
|
||||
"id",
|
||||
getattr(attachment, "filename", "unknown"),
|
||||
)
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=f"{doc_id}/attachments/{attachment_id}",
|
||||
external_access=project_perms,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
|
||||
)
|
||||
)
|
||||
current_offset += 1
|
||||
if len(slim_doc_batch) >= JIRA_SLIM_PAGE_SIZE:
|
||||
yield slim_doc_batch
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -33,9 +34,17 @@ class ConnectorMissingCredentialError(PermissionError):
|
||||
)
|
||||
|
||||
|
||||
class SectionType(str, Enum):
|
||||
"""Discriminator for Section subclasses."""
|
||||
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
|
||||
|
||||
class Section(BaseModel):
|
||||
"""Base section class with common attributes"""
|
||||
|
||||
type: SectionType
|
||||
link: str | None = None
|
||||
text: str | None = None
|
||||
image_file_id: str | None = None
|
||||
@@ -44,6 +53,7 @@ class Section(BaseModel):
|
||||
class TextSection(Section):
|
||||
"""Section containing text content"""
|
||||
|
||||
type: Literal[SectionType.TEXT] = SectionType.TEXT
|
||||
text: str
|
||||
|
||||
def __sizeof__(self) -> int:
|
||||
@@ -53,6 +63,7 @@ class TextSection(Section):
|
||||
class ImageSection(Section):
|
||||
"""Section containing an image reference"""
|
||||
|
||||
type: Literal[SectionType.IMAGE] = SectionType.IMAGE
|
||||
image_file_id: str
|
||||
|
||||
def __sizeof__(self) -> int:
|
||||
@@ -134,7 +145,6 @@ class BasicExpertInfo(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, model_dict: dict[str, Any]) -> "BasicExpertInfo":
|
||||
|
||||
first_name = cast(str, model_dict.get("FirstName"))
|
||||
last_name = cast(str, model_dict.get("LastName"))
|
||||
email = cast(str, model_dict.get("Email"))
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
|
||||
@@ -6,6 +7,14 @@ from pydantic import BaseModel
|
||||
from onyx.onyxbot.slack.models import ChannelType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DirectThreadFetch:
|
||||
"""Request to fetch a Slack thread directly by channel and timestamp."""
|
||||
|
||||
channel_id: str
|
||||
thread_ts: str
|
||||
|
||||
|
||||
class ChannelMetadata(TypedDict):
|
||||
"""Type definition for cached channel metadata."""
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.federated.models import SlackMessage
|
||||
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
|
||||
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
|
||||
@@ -49,7 +50,6 @@ from onyx.server.federated.models import FederatedConnectorDetail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -58,7 +58,6 @@ HIGHLIGHT_END_CHAR = "\ue001"
|
||||
|
||||
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
|
||||
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
|
||||
CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff)
|
||||
|
||||
@@ -421,6 +420,94 @@ class SlackQueryResult(BaseModel):
|
||||
filtered_channels: list[str] # Channels filtered out during this query
|
||||
|
||||
|
||||
def _fetch_thread_from_url(
|
||||
thread_fetch: DirectThreadFetch,
|
||||
access_token: str,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
"""Fetch a thread directly from a Slack URL via conversations.replies."""
|
||||
channel_id = thread_fetch.channel_id
|
||||
thread_ts = thread_fetch.thread_ts
|
||||
|
||||
slack_client = WebClient(token=access_token)
|
||||
try:
|
||||
response = slack_client.conversations_replies(
|
||||
channel=channel_id,
|
||||
ts=thread_ts,
|
||||
)
|
||||
response.validate()
|
||||
messages: list[dict[str, Any]] = response.get("messages", [])
|
||||
except SlackApiError as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch thread from URL (channel={channel_id}, ts={thread_ts}): {e}"
|
||||
)
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
if not messages:
|
||||
logger.warning(
|
||||
f"No messages found for URL override (channel={channel_id}, ts={thread_ts})"
|
||||
)
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
# Build thread text from all messages
|
||||
thread_text = _build_thread_text(messages, access_token, None, slack_client)
|
||||
|
||||
# Get channel name from metadata cache or API
|
||||
channel_name = "unknown"
|
||||
if channel_metadata_dict and channel_id in channel_metadata_dict:
|
||||
channel_name = channel_metadata_dict[channel_id].get("name", "unknown")
|
||||
else:
|
||||
try:
|
||||
ch_response = slack_client.conversations_info(channel=channel_id)
|
||||
ch_response.validate()
|
||||
channel_info: dict[str, Any] = ch_response.get("channel", {})
|
||||
channel_name = channel_info.get("name", "unknown")
|
||||
except SlackApiError:
|
||||
pass
|
||||
|
||||
# Build the SlackMessage
|
||||
parent_msg = messages[0]
|
||||
message_ts = parent_msg.get("ts", thread_ts)
|
||||
username = parent_msg.get("user", "unknown_user")
|
||||
parent_text = parent_msg.get("text", "")
|
||||
snippet = (
|
||||
parent_text[:50].rstrip() + "..." if len(parent_text) > 50 else parent_text
|
||||
).replace("\n", " ")
|
||||
|
||||
doc_time = datetime.fromtimestamp(float(message_ts))
|
||||
decay_factor = DOC_TIME_DECAY
|
||||
doc_age_years = (datetime.now() - doc_time).total_seconds() / (365 * 24 * 60 * 60)
|
||||
recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75)
|
||||
|
||||
permalink = (
|
||||
f"https://slack.com/archives/{channel_id}/p{message_ts.replace('.', '')}"
|
||||
)
|
||||
|
||||
slack_message = SlackMessage(
|
||||
document_id=f"{channel_id}_{message_ts}",
|
||||
channel_id=channel_id,
|
||||
message_id=message_ts,
|
||||
thread_id=None, # Prevent double-enrichment in thread context fetch
|
||||
link=permalink,
|
||||
metadata={
|
||||
"channel": channel_name,
|
||||
"time": doc_time.isoformat(),
|
||||
},
|
||||
timestamp=doc_time,
|
||||
recency_bias=recency_bias,
|
||||
semantic_identifier=f"{username} in #{channel_name}: {snippet}",
|
||||
text=thread_text,
|
||||
highlighted_texts=set(),
|
||||
slack_score=100000.0, # High priority — user explicitly asked for this thread
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"URL override: fetched thread from channel={channel_id}, ts={thread_ts}, {len(messages)} messages"
|
||||
)
|
||||
|
||||
return SlackQueryResult(messages=[slack_message], filtered_channels=[])
|
||||
|
||||
|
||||
def query_slack(
|
||||
query_string: str,
|
||||
access_token: str,
|
||||
@@ -432,7 +519,6 @@ def query_slack(
|
||||
available_channels: list[str] | None = None,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
|
||||
# Check if query has channel override (user specified channels in query)
|
||||
has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__")
|
||||
|
||||
@@ -662,7 +748,6 @@ def _fetch_thread_context(
|
||||
"""
|
||||
channel_id = message.channel_id
|
||||
thread_id = message.thread_id
|
||||
message_id = message.message_id
|
||||
|
||||
# If not a thread, return original text as success
|
||||
if thread_id is None:
|
||||
@@ -695,62 +780,37 @@ def _fetch_thread_context(
|
||||
if len(messages) <= 1:
|
||||
return ThreadContextResult.success(message.text)
|
||||
|
||||
# Build thread text from thread starter + context window around matched message
|
||||
thread_text = _build_thread_text(
|
||||
messages, message_id, thread_id, access_token, team_id, slack_client
|
||||
)
|
||||
# Build thread text from thread starter + all replies
|
||||
thread_text = _build_thread_text(messages, access_token, team_id, slack_client)
|
||||
return ThreadContextResult.success(thread_text)
|
||||
|
||||
|
||||
def _build_thread_text(
|
||||
messages: list[dict[str, Any]],
|
||||
message_id: str,
|
||||
thread_id: str,
|
||||
access_token: str,
|
||||
team_id: str | None,
|
||||
slack_client: WebClient,
|
||||
) -> str:
|
||||
"""Build the thread text from messages."""
|
||||
"""Build thread text including all replies.
|
||||
|
||||
Includes the thread parent message followed by all replies in order.
|
||||
"""
|
||||
msg_text = messages[0].get("text", "")
|
||||
msg_sender = messages[0].get("user", "")
|
||||
thread_text = f"<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# All messages after index 0 are replies
|
||||
replies = messages[1:]
|
||||
if not replies:
|
||||
return thread_text
|
||||
|
||||
logger.debug(f"Thread {messages[0].get('ts')}: {len(replies)} replies included")
|
||||
thread_text += "\n\nReplies:"
|
||||
if thread_id == message_id:
|
||||
message_id_idx = 0
|
||||
else:
|
||||
message_id_idx = next(
|
||||
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
|
||||
)
|
||||
if not message_id_idx:
|
||||
return thread_text
|
||||
|
||||
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
|
||||
|
||||
if start_idx > 1:
|
||||
thread_text += "\n..."
|
||||
|
||||
for i in range(start_idx, message_id_idx):
|
||||
msg_text = messages[i].get("text", "")
|
||||
msg_sender = messages[i].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
msg_text = messages[message_id_idx].get("text", "")
|
||||
msg_sender = messages[message_id_idx].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# Add following replies
|
||||
len_replies = 0
|
||||
for msg in messages[message_id_idx + 1 :]:
|
||||
for msg in replies:
|
||||
msg_text = msg.get("text", "")
|
||||
msg_sender = msg.get("user", "")
|
||||
reply = f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
thread_text += reply
|
||||
|
||||
len_replies += len(reply)
|
||||
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
|
||||
thread_text += "\n..."
|
||||
break
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# Replace user IDs with names using cached lookups
|
||||
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
|
||||
@@ -976,7 +1036,16 @@ def slack_retrieval(
|
||||
|
||||
# Query slack with entity filtering
|
||||
llm = get_default_llm()
|
||||
query_strings = build_slack_queries(query, llm, entities, available_channels)
|
||||
query_items = build_slack_queries(query, llm, entities, available_channels)
|
||||
|
||||
# Partition into direct thread fetches and search query strings
|
||||
direct_fetches: list[DirectThreadFetch] = []
|
||||
query_strings: list[str] = []
|
||||
for item in query_items:
|
||||
if isinstance(item, DirectThreadFetch):
|
||||
direct_fetches.append(item)
|
||||
else:
|
||||
query_strings.append(item)
|
||||
|
||||
# Determine filtering based on entities OR context (bot)
|
||||
include_dm = False
|
||||
@@ -993,8 +1062,16 @@ def slack_retrieval(
|
||||
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
|
||||
)
|
||||
|
||||
# Build search tasks
|
||||
search_tasks = [
|
||||
# Build search tasks — direct thread fetches + keyword searches
|
||||
search_tasks: list[tuple] = [
|
||||
(
|
||||
_fetch_thread_from_url,
|
||||
(fetch, access_token, channel_metadata_dict),
|
||||
)
|
||||
for fetch in direct_fetches
|
||||
]
|
||||
|
||||
search_tasks.extend(
|
||||
(
|
||||
query_slack,
|
||||
(
|
||||
@@ -1010,7 +1087,7 @@ def slack_retrieval(
|
||||
),
|
||||
)
|
||||
for query_string in query_strings
|
||||
]
|
||||
)
|
||||
|
||||
# If include_dm is True AND we're not already searching all channels,
|
||||
# add additional searches without channel filters.
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import ValidationError
|
||||
|
||||
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.federated_connectors.slack.models import SlackEntities
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -638,12 +639,38 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
return [query_text]
|
||||
|
||||
|
||||
SLACK_URL_PATTERN = re.compile(
|
||||
r"https?://[a-z0-9-]+\.slack\.com/archives/([A-Z0-9]+)/p(\d{16})"
|
||||
)
|
||||
|
||||
|
||||
def extract_slack_message_urls(
|
||||
query_text: str,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Extract Slack message URLs from query text.
|
||||
|
||||
Parses URLs like:
|
||||
https://onyx-company.slack.com/archives/C097NBWMY8Y/p1775491616524769
|
||||
|
||||
Returns list of (channel_id, thread_ts) tuples.
|
||||
The 16-digit timestamp is converted to Slack ts format (with dot).
|
||||
"""
|
||||
results = []
|
||||
for match in SLACK_URL_PATTERN.finditer(query_text):
|
||||
channel_id = match.group(1)
|
||||
raw_ts = match.group(2)
|
||||
# Convert p1775491616524769 -> 1775491616.524769
|
||||
thread_ts = f"{raw_ts[:10]}.{raw_ts[10:]}"
|
||||
results.append((channel_id, thread_ts))
|
||||
return results
|
||||
|
||||
|
||||
def build_slack_queries(
|
||||
query: ChunkIndexRequest,
|
||||
llm: LLM,
|
||||
entities: dict[str, Any] | None = None,
|
||||
available_channels: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
) -> list[str | DirectThreadFetch]:
|
||||
"""Build Slack query strings with date filtering and query expansion."""
|
||||
default_search_days = 30
|
||||
if entities:
|
||||
@@ -668,6 +695,15 @@ def build_slack_queries(
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
|
||||
time_filter = f" after:{cutoff_date.strftime('%Y-%m-%d')}"
|
||||
|
||||
# Check for Slack message URLs — if found, add direct fetch requests
|
||||
url_fetches: list[DirectThreadFetch] = []
|
||||
slack_urls = extract_slack_message_urls(query.query)
|
||||
for channel_id, thread_ts in slack_urls:
|
||||
url_fetches.append(
|
||||
DirectThreadFetch(channel_id=channel_id, thread_ts=thread_ts)
|
||||
)
|
||||
logger.info(f"Detected Slack URL: channel={channel_id}, ts={thread_ts}")
|
||||
|
||||
# ALWAYS extract channel references from the query (not just for recency queries)
|
||||
channel_references = extract_channel_references_from_query(query.query)
|
||||
|
||||
@@ -684,7 +720,9 @@ def build_slack_queries(
|
||||
|
||||
# If valid channels detected, use ONLY those channels with NO keywords
|
||||
# Return query with ONLY time filter + channel filter (no keywords)
|
||||
return [build_channel_override_query(channel_references, time_filter)]
|
||||
return url_fetches + [
|
||||
build_channel_override_query(channel_references, time_filter)
|
||||
]
|
||||
except ValueError as e:
|
||||
# If validation fails, log the error and continue with normal flow
|
||||
logger.warning(f"Channel reference validation failed: {e}")
|
||||
@@ -702,7 +740,8 @@ def build_slack_queries(
|
||||
rephrased_queries = expand_query_with_llm(query.query, llm)
|
||||
|
||||
# Build final query strings with time filters
|
||||
return [
|
||||
search_queries = [
|
||||
rephrased_query.strip() + time_filter
|
||||
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
|
||||
]
|
||||
return url_fetches + search_queries
|
||||
|
||||
@@ -3,7 +3,6 @@ import uuid
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -12,6 +11,7 @@ from onyx.auth.api_key import ApiKeyDescriptor
|
||||
from onyx.auth.api_key import build_displayable_api_key
|
||||
from onyx.auth.api_key import generate_api_key
|
||||
from onyx.auth.api_key import hash_api_key
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
@@ -21,8 +21,8 @@ from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
from onyx.server.models import UserGroupInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -37,57 +37,6 @@ def is_api_key_email_address(email: str) -> bool:
|
||||
return email.endswith(get_api_key_email_pattern())
|
||||
|
||||
|
||||
def _get_user_groups(db_session: Session, user_id: uuid.UUID) -> list[UserGroupInfo]:
|
||||
"""Get lightweight group info for a user."""
|
||||
groups = (
|
||||
db_session.scalars(
|
||||
select(UserGroup)
|
||||
.join(User__UserGroup, User__UserGroup.user_group_id == UserGroup.id)
|
||||
.where(User__UserGroup.user_id == user_id)
|
||||
)
|
||||
.unique()
|
||||
.all()
|
||||
)
|
||||
return [UserGroupInfo(id=g.id, name=g.name) for g in groups]
|
||||
|
||||
|
||||
def _set_user_groups__no_commit(
|
||||
db_session: Session,
|
||||
user_id: uuid.UUID,
|
||||
group_ids: list[int],
|
||||
) -> None:
|
||||
"""Replace all group memberships for a user with the given group_ids.
|
||||
Does NOT commit."""
|
||||
if group_ids:
|
||||
# Validate that all requested group IDs exist
|
||||
existing_ids = set(
|
||||
db_session.scalars(
|
||||
select(UserGroup.id).where(UserGroup.id.in_(group_ids))
|
||||
).all()
|
||||
)
|
||||
missing = set(group_ids) - existing_ids
|
||||
if missing:
|
||||
raise ValueError(f"Group IDs do not exist: {sorted(missing)}")
|
||||
|
||||
# Remove all existing memberships
|
||||
db_session.execute(
|
||||
delete(User__UserGroup).where(User__UserGroup.user_id == user_id)
|
||||
)
|
||||
|
||||
# Add new memberships
|
||||
if group_ids:
|
||||
insert_stmt = (
|
||||
pg_insert(User__UserGroup)
|
||||
.values([{"user_id": user_id, "user_group_id": gid} for gid in group_ids])
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[User__UserGroup.user_group_id, User__UserGroup.user_id]
|
||||
)
|
||||
)
|
||||
db_session.execute(insert_stmt)
|
||||
|
||||
recompute_user_permissions__no_commit(user_id, db_session)
|
||||
|
||||
|
||||
def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]:
|
||||
api_keys = (
|
||||
db_session.scalars(select(ApiKey).options(joinedload(ApiKey.user)))
|
||||
@@ -97,10 +46,10 @@ def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]:
|
||||
return [
|
||||
ApiKeyDescriptor(
|
||||
api_key_id=api_key.id,
|
||||
api_key_role=api_key.user.role,
|
||||
api_key_display=api_key.api_key_display,
|
||||
api_key_name=api_key.name,
|
||||
user_id=api_key.user_id,
|
||||
groups=_get_user_groups(db_session, api_key.user_id),
|
||||
)
|
||||
for api_key in api_keys
|
||||
]
|
||||
@@ -145,6 +94,7 @@ def insert_api_key(
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=api_key_args.role,
|
||||
account_type=AccountType.SERVICE_ACCOUNT,
|
||||
)
|
||||
db_session.add(api_key_user_row)
|
||||
@@ -158,18 +108,25 @@ def insert_api_key(
|
||||
)
|
||||
db_session.add(api_key_row)
|
||||
|
||||
# Assign the service account to the specified groups
|
||||
_set_user_groups__no_commit(db_session, api_key_user_id, api_key_args.group_ids)
|
||||
# Assign the API key virtual user to the appropriate default group
|
||||
# before commit so everything is atomic.
|
||||
# Only ADMIN and BASIC roles get default group membership.
|
||||
if api_key_args.role in (UserRole.ADMIN, UserRole.BASIC):
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
api_key_user_row,
|
||||
is_admin=(api_key_args.role == UserRole.ADMIN),
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return ApiKeyDescriptor(
|
||||
api_key_id=api_key_row.id,
|
||||
api_key_role=api_key_user_row.role,
|
||||
api_key_display=api_key_row.api_key_display,
|
||||
api_key=api_key,
|
||||
api_key_name=api_key_args.name,
|
||||
user_id=api_key_user_id,
|
||||
groups=_get_user_groups(db_session, api_key_user_id),
|
||||
)
|
||||
|
||||
|
||||
@@ -190,8 +147,31 @@ def update_api_key(
|
||||
email_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
|
||||
api_key_user.email = get_api_key_fake_email(email_name, str(api_key_user.id))
|
||||
|
||||
# Replace all group memberships with the specified groups
|
||||
_set_user_groups__no_commit(db_session, api_key_user.id, api_key_args.group_ids)
|
||||
old_role = api_key_user.role
|
||||
api_key_user.role = api_key_args.role
|
||||
|
||||
# Reconcile default-group membership when the role changes.
|
||||
if old_role != api_key_args.role:
|
||||
# Remove from all default groups first.
|
||||
delete_stmt = delete(User__UserGroup).where(
|
||||
User__UserGroup.user_id == api_key_user.id,
|
||||
User__UserGroup.user_group_id.in_(
|
||||
select(UserGroup.id).where(UserGroup.is_default.is_(True))
|
||||
),
|
||||
)
|
||||
db_session.execute(delete_stmt)
|
||||
|
||||
# Re-assign to the correct default group (only for ADMIN/BASIC).
|
||||
if api_key_args.role in (UserRole.ADMIN, UserRole.BASIC):
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
api_key_user,
|
||||
is_admin=(api_key_args.role == UserRole.ADMIN),
|
||||
)
|
||||
else:
|
||||
# No group assigned for LIMITED, but we still need to recompute
|
||||
# since we just removed the old default-group membership above.
|
||||
recompute_user_permissions__no_commit(api_key_user.id, db_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@@ -199,8 +179,8 @@ def update_api_key(
|
||||
api_key_id=existing_api_key.id,
|
||||
api_key_display=existing_api_key.api_key_display,
|
||||
api_key_name=api_key_args.name,
|
||||
api_key_role=api_key_user.role,
|
||||
user_id=existing_api_key.user_id,
|
||||
groups=_get_user_groups(db_session, existing_api_key.user_id),
|
||||
)
|
||||
|
||||
|
||||
@@ -229,8 +209,8 @@ def regenerate_api_key(db_session: Session, api_key_id: int) -> ApiKeyDescriptor
|
||||
api_key_display=existing_api_key.api_key_display,
|
||||
api_key=new_api_key,
|
||||
api_key_name=existing_api_key.name,
|
||||
api_key_role=api_key_user.role,
|
||||
user_id=existing_api_key.user_id,
|
||||
groups=_get_user_groups(db_session, existing_api_key.user_id),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import TypeVar
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi_users.models import ID
|
||||
from fastapi_users.models import UP
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
|
||||
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase
|
||||
from sqlalchemy import func
|
||||
@@ -12,13 +15,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.configs.constants import NO_AUTH_PLACEHOLDER_USER_EMAIL
|
||||
from onyx.db.api_key import get_api_key_email_pattern
|
||||
from onyx.db.engine.async_sql_engine import get_async_session
|
||||
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import User
|
||||
@@ -60,14 +62,10 @@ def _add_live_user_count_where_clause(
|
||||
select_stmt = select_stmt.where(User.email != NO_AUTH_PLACEHOLDER_USER_EMAIL) # type: ignore
|
||||
|
||||
if only_admin_users:
|
||||
return select_stmt.where(
|
||||
User.effective_permissions.contains(
|
||||
[Permission.FULL_ADMIN_PANEL_ACCESS.value]
|
||||
)
|
||||
)
|
||||
return select_stmt.where(User.role == UserRole.ADMIN)
|
||||
|
||||
return select_stmt.where(
|
||||
User.account_type != AccountType.EXT_PERM_USER,
|
||||
User.role != UserRole.EXT_PERM_USER,
|
||||
)
|
||||
|
||||
|
||||
@@ -97,10 +95,24 @@ async def get_user_count(only_admin_users: bool = False) -> int:
|
||||
return user_count
|
||||
|
||||
|
||||
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
|
||||
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
|
||||
async def create(
|
||||
self,
|
||||
create_dict: Dict[str, Any],
|
||||
) -> UP:
|
||||
user_count = await get_user_count()
|
||||
if user_count == 0 or create_dict["email"] in get_default_admin_user_emails():
|
||||
create_dict["role"] = UserRole.ADMIN
|
||||
else:
|
||||
create_dict["role"] = UserRole.BASIC
|
||||
return await super().create(create_dict)
|
||||
|
||||
|
||||
async def get_user_db(
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> AsyncGenerator[SQLAlchemyUserDatabase, None]:
|
||||
yield SQLAlchemyUserDatabase(session, User, OAuthAccount)
|
||||
) -> AsyncGenerator[SQLAlchemyUserAdminDB, None]:
|
||||
yield SQLAlchemyUserAdminDB(session, User, OAuthAccount)
|
||||
|
||||
|
||||
async def get_access_token_db(
|
||||
|
||||
@@ -14,7 +14,6 @@ from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import get_effective_permissions
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.credentials import fetch_credential_by_id
|
||||
@@ -22,7 +21,6 @@ from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.enums import ProcessingMode
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -33,6 +31,7 @@ from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.server.models import StatusResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
@@ -50,27 +49,48 @@ class ConnectorType(str, Enum):
|
||||
def _add_user_filters(
|
||||
stmt: Select[tuple[*R]], user: User, get_editable: bool = True
|
||||
) -> Select[tuple[*R]]:
|
||||
effective = get_effective_permissions(user)
|
||||
|
||||
if Permission.MANAGE_CONNECTORS in effective:
|
||||
if user.role == UserRole.ADMIN:
|
||||
return stmt
|
||||
|
||||
# If anonymous user, only show public cc_pairs
|
||||
if user.is_anonymous:
|
||||
return stmt.where(ConnectorCredentialPair.access_type == AccessType.PUBLIC)
|
||||
where_clause = ConnectorCredentialPair.access_type == AccessType.PUBLIC
|
||||
return stmt.where(where_clause)
|
||||
|
||||
stmt = stmt.distinct()
|
||||
UG__CCpair = aliased(UserGroup__ConnectorCredentialPair)
|
||||
User__UG = aliased(User__UserGroup)
|
||||
|
||||
"""
|
||||
Here we select cc_pairs by relation:
|
||||
User -> User__UserGroup -> UserGroup__ConnectorCredentialPair ->
|
||||
ConnectorCredentialPair
|
||||
"""
|
||||
stmt = stmt.outerjoin(UG__CCpair).outerjoin(
|
||||
User__UG,
|
||||
User__UG.user_group_id == UG__CCpair.user_group_id,
|
||||
)
|
||||
|
||||
where_clause = User__UG.user_id == user.id
|
||||
"""
|
||||
Filter cc_pairs by:
|
||||
- if the user is in the user_group that owns the cc_pair
|
||||
- if the user is not a global_curator, they must also have a curator relationship
|
||||
to the user_group
|
||||
- if editing is being done, we also filter out cc_pairs that are owned by groups
|
||||
that the user isn't a curator for
|
||||
- if we are not editing, we show all cc_pairs in the groups the user is a curator
|
||||
for (as well as public cc_pairs)
|
||||
"""
|
||||
|
||||
where_clause = User__UG.user_id == user.id
|
||||
if user.role == UserRole.CURATOR and get_editable:
|
||||
where_clause &= User__UG.is_curator == True # noqa: E712
|
||||
if get_editable:
|
||||
user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id)
|
||||
if user.role == UserRole.CURATOR:
|
||||
user_groups = user_groups.where(
|
||||
User__UserGroup.is_curator == True # noqa: E712
|
||||
)
|
||||
where_clause &= (
|
||||
~exists()
|
||||
.where(UG__CCpair.cc_pair_id == ConnectorCredentialPair.id)
|
||||
@@ -513,6 +533,7 @@ def add_credential_to_connector(
|
||||
credential_id,
|
||||
user,
|
||||
db_session,
|
||||
get_editable=False,
|
||||
)
|
||||
|
||||
if connector is None:
|
||||
@@ -595,6 +616,7 @@ def remove_credential_from_connector(
|
||||
credential_id,
|
||||
user,
|
||||
db_session,
|
||||
get_editable=False,
|
||||
)
|
||||
|
||||
if connector is None:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
@@ -7,18 +8,18 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import and_
|
||||
from sqlalchemy.sql.expression import or_
|
||||
|
||||
from onyx.auth.permissions import get_effective_permissions
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Credential__UserGroup
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -42,14 +43,16 @@ PUBLIC_CREDENTIAL_ID = 0
|
||||
def _add_user_filters(
|
||||
stmt: Select,
|
||||
user: User,
|
||||
get_editable: bool = True,
|
||||
) -> Select:
|
||||
"""Attaches filters to ensure the user can only access appropriate credentials."""
|
||||
"""Attaches filters to the statement to ensure that the user can only
|
||||
access the appropriate credentials"""
|
||||
if user.is_anonymous:
|
||||
raise ValueError("Anonymous users are not allowed to access credentials")
|
||||
|
||||
effective = get_effective_permissions(user)
|
||||
|
||||
if Permission.MANAGE_CONNECTORS in effective:
|
||||
if user.role == UserRole.ADMIN:
|
||||
# Admins can access all credentials that are public or owned by them
|
||||
# or are not associated with any user
|
||||
return stmt.where(
|
||||
or_(
|
||||
Credential.user_id == user.id,
|
||||
@@ -58,9 +61,56 @@ def _add_user_filters(
|
||||
Credential.source.in_(CREDENTIAL_PERMISSIONS_TO_IGNORE),
|
||||
)
|
||||
)
|
||||
if user.role == UserRole.BASIC:
|
||||
# Basic users can only access credentials that are owned by them
|
||||
return stmt.where(Credential.user_id == user.id)
|
||||
|
||||
# All other users: only their own credentials
|
||||
return stmt.where(Credential.user_id == user.id)
|
||||
stmt = stmt.distinct()
|
||||
"""
|
||||
THIS PART IS FOR CURATORS AND GLOBAL CURATORS
|
||||
Here we select cc_pairs by relation:
|
||||
User -> User__UserGroup -> Credential__UserGroup -> Credential
|
||||
"""
|
||||
stmt = stmt.outerjoin(Credential__UserGroup).outerjoin(
|
||||
User__UserGroup,
|
||||
User__UserGroup.user_group_id == Credential__UserGroup.user_group_id,
|
||||
)
|
||||
"""
|
||||
Filter Credentials by:
|
||||
- if the user is in the user_group that owns the Credential
|
||||
- if the user is a curator, they must also have a curator relationship
|
||||
to the user_group
|
||||
- if editing is being done, we also filter out Credentials that are owned by groups
|
||||
that the user isn't a curator for
|
||||
- if we are not editing, we show all Credentials in the groups the user is a curator
|
||||
for (as well as public Credentials)
|
||||
- if we are not editing, we return all Credentials directly connected to the user
|
||||
"""
|
||||
where_clause = User__UserGroup.user_id == user.id
|
||||
if user.role == UserRole.CURATOR:
|
||||
where_clause &= User__UserGroup.is_curator == True # noqa: E712
|
||||
|
||||
if get_editable:
|
||||
user_groups = select(User__UserGroup.user_group_id).where(
|
||||
User__UserGroup.user_id == user.id
|
||||
)
|
||||
if user.role == UserRole.CURATOR:
|
||||
user_groups = user_groups.where(
|
||||
User__UserGroup.is_curator == True # noqa: E712
|
||||
)
|
||||
where_clause &= (
|
||||
~exists()
|
||||
.where(Credential__UserGroup.credential_id == Credential.id)
|
||||
.where(~Credential__UserGroup.user_group_id.in_(user_groups))
|
||||
.correlate(Credential)
|
||||
)
|
||||
else:
|
||||
where_clause |= Credential.curator_public == True # noqa: E712
|
||||
where_clause |= Credential.user_id == user.id # noqa: E712
|
||||
|
||||
where_clause |= Credential.source.in_(CREDENTIAL_PERMISSIONS_TO_IGNORE)
|
||||
|
||||
return stmt.where(where_clause)
|
||||
|
||||
|
||||
def _relate_credential_to_user_groups__no_commit(
|
||||
@@ -82,9 +132,10 @@ def _relate_credential_to_user_groups__no_commit(
|
||||
def fetch_credentials_for_user(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
get_editable: bool = True,
|
||||
) -> list[Credential]:
|
||||
stmt = select(Credential)
|
||||
stmt = _add_user_filters(stmt, user)
|
||||
stmt = _add_user_filters(stmt, user, get_editable=get_editable)
|
||||
results = db_session.scalars(stmt)
|
||||
return list(results.all())
|
||||
|
||||
@@ -93,12 +144,14 @@ def fetch_credential_by_id_for_user(
|
||||
credential_id: int,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
get_editable: bool = True,
|
||||
) -> Credential | None:
|
||||
stmt = select(Credential).distinct()
|
||||
stmt = stmt.where(Credential.id == credential_id)
|
||||
stmt = _add_user_filters(
|
||||
stmt=stmt,
|
||||
user=user,
|
||||
get_editable=get_editable,
|
||||
)
|
||||
result = db_session.execute(stmt)
|
||||
credential = result.scalar_one_or_none()
|
||||
@@ -120,9 +173,10 @@ def fetch_credentials_by_source_for_user(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
document_source: DocumentSource | None = None,
|
||||
get_editable: bool = True,
|
||||
) -> list[Credential]:
|
||||
base_query = select(Credential).where(Credential.source == document_source)
|
||||
base_query = _add_user_filters(base_query, user)
|
||||
base_query = _add_user_filters(base_query, user, get_editable=get_editable)
|
||||
credentials = db_session.execute(base_query).scalars().all()
|
||||
return list(credentials)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.auth.api_key import build_displayable_api_key
|
||||
from onyx.auth.api_key import generate_api_key
|
||||
from onyx.auth.api_key import hash_api_key
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import DISCORD_SERVICE_API_KEY_NAME
|
||||
from onyx.db.api_key import insert_api_key
|
||||
from onyx.db.models import ApiKey
|
||||
@@ -111,6 +112,7 @@ def get_or_create_discord_service_api_key(
|
||||
logger.info(f"Creating Discord service API key for tenant {tenant_id}")
|
||||
api_key_args = APIKeyArgs(
|
||||
name=DISCORD_SERVICE_API_KEY_NAME,
|
||||
role=UserRole.LIMITED, # Limited role is sufficient for chat requests
|
||||
)
|
||||
api_key_descriptor = insert_api_key(
|
||||
db_session=db_session,
|
||||
|
||||
@@ -4,7 +4,7 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import false as sa_false
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import Select
|
||||
@@ -13,21 +13,22 @@ from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import has_permission
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.federated import create_federated_connector_document_set_mapping
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.models import DocumentSet as DocumentSetDBModel
|
||||
from onyx.db.models import DocumentSet__ConnectorCredentialPair
|
||||
from onyx.db.models import DocumentSet__UserGroup
|
||||
from onyx.db.models import FederatedConnector__DocumentSet
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.server.features.document_set.models import DocumentSetCreationRequest
|
||||
from onyx.server.features.document_set.models import DocumentSetUpdateRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -37,16 +38,54 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def _add_user_filters(stmt: Select, user: User, get_editable: bool = True) -> Select:
|
||||
# MANAGE → always return all
|
||||
if has_permission(user, Permission.MANAGE_DOCUMENT_SETS):
|
||||
if user.role == UserRole.ADMIN:
|
||||
return stmt
|
||||
# READ → return all when reading, nothing when editing
|
||||
if has_permission(user, Permission.READ_DOCUMENT_SETS):
|
||||
if get_editable:
|
||||
return stmt.where(sa_false())
|
||||
return stmt
|
||||
# No permission → return nothing
|
||||
return stmt.where(sa_false())
|
||||
|
||||
stmt = stmt.distinct()
|
||||
DocumentSet__UG = aliased(DocumentSet__UserGroup)
|
||||
User__UG = aliased(User__UserGroup)
|
||||
"""
|
||||
Here we select cc_pairs by relation:
|
||||
User -> User__UserGroup -> DocumentSet__UserGroup -> DocumentSet
|
||||
"""
|
||||
stmt = stmt.outerjoin(DocumentSet__UG).outerjoin(
|
||||
User__UserGroup,
|
||||
User__UserGroup.user_group_id == DocumentSet__UG.user_group_id,
|
||||
)
|
||||
"""
|
||||
Filter DocumentSets by:
|
||||
- if the user is in the user_group that owns the DocumentSet
|
||||
- if the user is not a global_curator, they must also have a curator relationship
|
||||
to the user_group
|
||||
- if editing is being done, we also filter out DocumentSets that are owned by groups
|
||||
that the user isn't a curator for
|
||||
- if we are not editing, we show all DocumentSets in the groups the user is a curator
|
||||
for (as well as public DocumentSets)
|
||||
"""
|
||||
|
||||
# Anonymous users only see public DocumentSets
|
||||
if user.is_anonymous:
|
||||
where_clause = DocumentSetDBModel.is_public == True # noqa: E712
|
||||
return stmt.where(where_clause)
|
||||
|
||||
where_clause = User__UserGroup.user_id == user.id
|
||||
if user.role == UserRole.CURATOR and get_editable:
|
||||
where_clause &= User__UserGroup.is_curator == True # noqa: E712
|
||||
if get_editable:
|
||||
user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id)
|
||||
if user.role == UserRole.CURATOR:
|
||||
user_groups = user_groups.where(User__UG.is_curator == True) # noqa: E712
|
||||
where_clause &= (
|
||||
~exists()
|
||||
.where(DocumentSet__UG.document_set_id == DocumentSetDBModel.id)
|
||||
.where(~DocumentSet__UG.user_group_id.in_(user_groups))
|
||||
.correlate(DocumentSetDBModel)
|
||||
)
|
||||
where_clause |= DocumentSetDBModel.user_id == user.id
|
||||
else:
|
||||
where_clause |= DocumentSetDBModel.is_public == True # noqa: E712
|
||||
|
||||
return stmt.where(where_clause)
|
||||
|
||||
|
||||
def _delete_document_set_cc_pairs__no_commit(
|
||||
@@ -296,6 +335,7 @@ def update_document_set(
|
||||
"Cannot update document set while it is syncing. Please wait for it to finish syncing, and then try again."
|
||||
)
|
||||
|
||||
document_set_row.name = document_set_update_request.name
|
||||
document_set_row.description = document_set_update_request.description
|
||||
if not DISABLE_VECTOR_DB:
|
||||
document_set_row.is_up_to_date = False
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import DBAPIError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
@@ -346,6 +347,25 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def _safe_close_session(session: Session) -> None:
|
||||
"""Close a session, catching connection-closed errors during cleanup.
|
||||
|
||||
Long-running operations (e.g. multi-model LLM loops) can hold a session
|
||||
open for minutes. If the underlying connection is dropped by cloud
|
||||
infrastructure (load-balancer timeouts, PgBouncer, idle-in-transaction
|
||||
timeouts, etc.), the implicit rollback in Session.close() raises
|
||||
OperationalError or InterfaceError. Since the work is already complete,
|
||||
we log and move on — SQLAlchemy internally invalidates the connection
|
||||
for pool recycling.
|
||||
"""
|
||||
try:
|
||||
session.close()
|
||||
except DBAPIError:
|
||||
logger.warning(
|
||||
"DB connection lost during session cleanup — the connection will be invalidated and recycled by the pool."
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
|
||||
"""
|
||||
@@ -358,8 +378,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
|
||||
# no need to use the schema translation map for self-hosted + default schema
|
||||
if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
|
||||
with Session(bind=engine, expire_on_commit=False) as session:
|
||||
session = Session(bind=engine, expire_on_commit=False)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
_safe_close_session(session)
|
||||
return
|
||||
|
||||
# Create connection with schema translation to handle querying the right schema
|
||||
@@ -367,8 +390,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
with engine.connect().execution_options(
|
||||
schema_translate_map=schema_translate_map
|
||||
) as connection:
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
session = Session(bind=connection, expire_on_commit=False)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
_safe_close_session(session)
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
|
||||
@@ -366,12 +366,12 @@ class Permission(str, PyEnum):
|
||||
READ_DOCUMENT_SETS = "read:document_sets"
|
||||
READ_AGENTS = "read:agents"
|
||||
READ_USERS = "read:users"
|
||||
READ_USER_GROUPS = "read:user_groups"
|
||||
|
||||
# Add / Manage pairs
|
||||
ADD_AGENTS = "add:agents"
|
||||
MANAGE_AGENTS = "manage:agents"
|
||||
MANAGE_DOCUMENT_SETS = "manage:document_sets"
|
||||
ADD_CONNECTORS = "add:connectors"
|
||||
MANAGE_CONNECTORS = "manage:connectors"
|
||||
MANAGE_LLMS = "manage:llms"
|
||||
|
||||
@@ -381,8 +381,8 @@ class Permission(str, PyEnum):
|
||||
READ_QUERY_HISTORY = "read:query_history"
|
||||
MANAGE_USER_GROUPS = "manage:user_groups"
|
||||
CREATE_USER_API_KEYS = "create:user_api_keys"
|
||||
MANAGE_SERVICE_ACCOUNT_API_KEYS = "manage:service_account_api_keys"
|
||||
MANAGE_BOTS = "manage:bots"
|
||||
CREATE_SERVICE_ACCOUNT_API_KEYS = "create:service_account_api_keys"
|
||||
CREATE_SLACK_DISCORD_BOTS = "create:slack_discord_bots"
|
||||
|
||||
# Override — any permission check passes
|
||||
FULL_ADMIN_PANEL_ACCESS = "admin"
|
||||
|
||||
@@ -13,12 +13,10 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import has_permission
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import SearchFeedbackType
|
||||
from onyx.db.chat import get_chat_message
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import ChatMessageFeedback
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Document as DbDocument
|
||||
@@ -27,6 +25,7 @@ from onyx.db.models import DocumentRetrievalFeedback
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -44,7 +43,7 @@ def _fetch_db_doc_by_id(doc_id: str, db_session: Session) -> DbDocument:
|
||||
|
||||
|
||||
def _add_user_filters(stmt: Select, user: User, get_editable: bool = True) -> Select:
|
||||
if has_permission(user, Permission.FULL_ADMIN_PANEL_ACCESS):
|
||||
if user.role == UserRole.ADMIN:
|
||||
return stmt
|
||||
|
||||
stmt = stmt.distinct()
|
||||
@@ -72,11 +71,14 @@ def _add_user_filters(stmt: Select, user: User, get_editable: bool = True) -> Se
|
||||
)
|
||||
|
||||
"""
|
||||
Filter Documents by group membership:
|
||||
- if get_editable, the document's CCPair must be owned exclusively by
|
||||
groups the user belongs to (prevents mutating docs that are also
|
||||
visible to groups outside the user's reach)
|
||||
- otherwise, show docs in any group the user belongs to plus public docs
|
||||
Filter Documents by:
|
||||
- if the user is in the user_group that owns the object
|
||||
- if the user is not a global_curator, they must also have a curator relationship
|
||||
to the user_group
|
||||
- if editing is being done, we also filter out objects that are owned by groups
|
||||
that the user isn't a curator for
|
||||
- if we are not editing, we show all objects in the groups the user is a curator
|
||||
for (as well as public objects as well)
|
||||
"""
|
||||
|
||||
# Anonymous users only see public documents
|
||||
@@ -85,6 +87,8 @@ def _add_user_filters(stmt: Select, user: User, get_editable: bool = True) -> Se
|
||||
return stmt.where(where_clause)
|
||||
|
||||
where_clause = User__UG.user_id == user.id
|
||||
if user.role == UserRole.CURATOR and get_editable:
|
||||
where_clause &= User__UG.is_curator == True # noqa: E712
|
||||
if get_editable:
|
||||
user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id)
|
||||
where_clause &= (
|
||||
|
||||
@@ -899,6 +899,7 @@ def create_index_attempt_error(
|
||||
failure: ConnectorFailure,
|
||||
db_session: Session,
|
||||
) -> int:
|
||||
exc = failure.exception
|
||||
new_error = IndexAttemptError(
|
||||
index_attempt_id=index_attempt_id,
|
||||
connector_credential_pair_id=connector_credential_pair_id,
|
||||
@@ -921,6 +922,7 @@ def create_index_attempt_error(
|
||||
),
|
||||
failure_message=failure.failure_message,
|
||||
is_resolved=False,
|
||||
error_type=type(exc).__name__ if exc else None,
|
||||
)
|
||||
db_session.add(new_error)
|
||||
db_session.commit()
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import ConfigDict
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
|
||||
@@ -83,47 +84,51 @@ def get_memories(user: User, db_session: Session) -> UserMemoryContext:
|
||||
def add_memory(
|
||||
user_id: UUID,
|
||||
memory_text: str,
|
||||
db_session: Session,
|
||||
) -> Memory:
|
||||
db_session: Session | None = None,
|
||||
) -> int:
|
||||
"""Insert a new Memory row for the given user.
|
||||
|
||||
If the user already has MAX_MEMORIES_PER_USER memories, the oldest
|
||||
one (lowest id) is deleted before inserting the new one.
|
||||
|
||||
Returns the id of the newly created Memory row.
|
||||
"""
|
||||
existing = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
existing = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if len(existing) >= MAX_MEMORIES_PER_USER:
|
||||
db_session.delete(existing[0])
|
||||
if len(existing) >= MAX_MEMORIES_PER_USER:
|
||||
db_session.delete(existing[0])
|
||||
|
||||
memory = Memory(
|
||||
user_id=user_id,
|
||||
memory_text=memory_text,
|
||||
)
|
||||
db_session.add(memory)
|
||||
db_session.commit()
|
||||
return memory
|
||||
memory = Memory(
|
||||
user_id=user_id,
|
||||
memory_text=memory_text,
|
||||
)
|
||||
db_session.add(memory)
|
||||
db_session.commit()
|
||||
return memory.id
|
||||
|
||||
|
||||
def update_memory_at_index(
|
||||
user_id: UUID,
|
||||
index: int,
|
||||
new_text: str,
|
||||
db_session: Session,
|
||||
) -> Memory | None:
|
||||
db_session: Session | None = None,
|
||||
) -> int | None:
|
||||
"""Update the memory at the given 0-based index (ordered by id ASC, matching get_memories()).
|
||||
|
||||
Returns the updated Memory row, or None if the index is out of range.
|
||||
Returns the id of the updated Memory row, or None if the index is out of range.
|
||||
"""
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if index < 0 or index >= len(memory_rows):
|
||||
return None
|
||||
if index < 0 or index >= len(memory_rows):
|
||||
return None
|
||||
|
||||
memory = memory_rows[index]
|
||||
memory.memory_text = new_text
|
||||
db_session.commit()
|
||||
return memory
|
||||
memory = memory_rows[index]
|
||||
memory.memory_text = new_text
|
||||
db_session.commit()
|
||||
return memory.id
|
||||
|
||||
@@ -302,11 +302,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
||||
"OAuthAccount", lazy="joined", cascade="all, delete-orphan"
|
||||
)
|
||||
# Legacy tombstone column: no longer read or written by application code.
|
||||
# Kept nullable so a pure-code rollback keeps working.
|
||||
role: Mapped[UserRole | None] = mapped_column(
|
||||
Enum(UserRole, native_enum=False),
|
||||
nullable=True,
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||
)
|
||||
account_type: Mapped[AccountType] = mapped_column(
|
||||
Enum(AccountType, native_enum=False),
|
||||
@@ -2425,6 +2422,8 @@ class IndexAttemptError(Base):
|
||||
failure_message: Mapped[str] = mapped_column(Text)
|
||||
is_resolved: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
error_type: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
|
||||
@@ -9,9 +9,8 @@ from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from onyx.auth.permissions import has_permission
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import Notification
|
||||
from onyx.db.models import User
|
||||
|
||||
@@ -77,9 +76,7 @@ def get_notification_by_id(
|
||||
if not notif:
|
||||
raise ValueError(f"No notification found with id {notification_id}")
|
||||
if notif.user_id != user_id and not (
|
||||
notif.user_id is None
|
||||
and user is not None
|
||||
and has_permission(user, Permission.FULL_ADMIN_PANEL_ACCESS)
|
||||
notif.user_id is None and user is not None and user.role == UserRole.ADMIN
|
||||
):
|
||||
raise PermissionError(
|
||||
f"User {user_id} is not authorized to access notification {notification_id}"
|
||||
|
||||
@@ -16,12 +16,12 @@ from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.hierarchy_access import get_user_external_group_ids
|
||||
from onyx.auth.permissions import has_permission
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from onyx.db.document_access import get_accessible_documents_by_ids
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentSet
|
||||
@@ -74,9 +74,7 @@ class PersonaLoadType(Enum):
|
||||
def _add_user_filters(
|
||||
stmt: Select[tuple[Persona]], user: User, get_editable: bool = True
|
||||
) -> Select[tuple[Persona]]:
|
||||
if has_permission(user, Permission.MANAGE_AGENTS):
|
||||
return stmt
|
||||
if not get_editable and has_permission(user, Permission.READ_AGENTS):
|
||||
if user.role == UserRole.ADMIN:
|
||||
return stmt
|
||||
|
||||
stmt = stmt.distinct()
|
||||
@@ -100,7 +98,12 @@ def _add_user_filters(
|
||||
"""
|
||||
Filter Personas by:
|
||||
- if the user is in the user_group that owns the Persona
|
||||
- if we are not editing, we show all public and listed Personas
|
||||
- if the user is not a global_curator, they must also have a curator relationship
|
||||
to the user_group
|
||||
- if editing is being done, we also filter out Personas that are owned by groups
|
||||
that the user isn't a curator for
|
||||
- if we are not editing, we show all Personas in the groups the user is a curator
|
||||
for (as well as public Personas)
|
||||
- if we are not editing, we return all Personas directly connected to the user
|
||||
"""
|
||||
|
||||
@@ -109,9 +112,21 @@ def _add_user_filters(
|
||||
where_clause = Persona.is_public == True # noqa: E712
|
||||
return stmt.where(where_clause)
|
||||
|
||||
# If curator ownership restriction is enabled, curators can only access their own assistants
|
||||
if CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS and user.role in [
|
||||
UserRole.CURATOR,
|
||||
UserRole.GLOBAL_CURATOR,
|
||||
]:
|
||||
where_clause = (Persona.user_id == user.id) | (Persona.user_id.is_(None))
|
||||
return stmt.where(where_clause)
|
||||
|
||||
where_clause = User__UserGroup.user_id == user.id
|
||||
if user.role == UserRole.CURATOR and get_editable:
|
||||
where_clause &= User__UserGroup.is_curator == True # noqa: E712
|
||||
if get_editable:
|
||||
user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id)
|
||||
if user.role == UserRole.CURATOR:
|
||||
user_groups = user_groups.where(User__UG.is_curator == True) # noqa: E712
|
||||
where_clause &= (
|
||||
~exists()
|
||||
.where(Persona__UG.persona_id == Persona.id)
|
||||
@@ -182,7 +197,7 @@ def _get_persona_by_name(
|
||||
- Non-admin users: can only see their own personas
|
||||
"""
|
||||
stmt = select(Persona).where(Persona.name == persona_name)
|
||||
if user and not has_permission(user, Permission.MANAGE_AGENTS):
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
stmt = stmt.where(Persona.user_id == user.id)
|
||||
result = db_session.execute(stmt).scalar_one_or_none()
|
||||
return result
|
||||
@@ -256,10 +271,12 @@ def create_update_persona(
|
||||
try:
|
||||
# Featured persona validation
|
||||
if create_persona_request.is_featured:
|
||||
if not has_permission(user, Permission.MANAGE_AGENTS):
|
||||
raise ValueError(
|
||||
"Only users with agent management permissions can make a featured persona"
|
||||
)
|
||||
# Curators can edit featured personas, but not make them
|
||||
# TODO this will be reworked soon with RBAC permissions feature
|
||||
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
|
||||
pass
|
||||
elif user.role != UserRole.ADMIN:
|
||||
raise ValueError("Only admins can make a featured persona")
|
||||
|
||||
# Convert incoming string UUIDs to UUID objects for DB operations
|
||||
converted_user_file_ids = None
|
||||
@@ -336,11 +353,7 @@ def update_persona_shared(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
if (
|
||||
user
|
||||
and not has_permission(user, Permission.MANAGE_AGENTS)
|
||||
and persona.user_id != user.id
|
||||
):
|
||||
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
|
||||
raise PermissionError("You don't have permission to modify this persona")
|
||||
|
||||
versioned_update_persona_access = fetch_versioned_implementation(
|
||||
@@ -376,10 +389,7 @@ def update_persona_public_status(
|
||||
persona = fetch_persona_by_id_for_user(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
if (
|
||||
not has_permission(user, Permission.MANAGE_AGENTS)
|
||||
and persona.user_id != user.id
|
||||
):
|
||||
if user.role != UserRole.ADMIN and persona.user_id != user.id:
|
||||
raise ValueError("You don't have permission to modify this persona")
|
||||
|
||||
persona.is_public = is_public
|
||||
@@ -1216,11 +1226,7 @@ def get_persona_by_id(
|
||||
if not include_deleted:
|
||||
persona_stmt = persona_stmt.where(Persona.deleted.is_(False))
|
||||
|
||||
if (
|
||||
not user
|
||||
or has_permission(user, Permission.MANAGE_AGENTS)
|
||||
or (not is_for_edit and has_permission(user, Permission.READ_AGENTS))
|
||||
):
|
||||
if not user or user.role == UserRole.ADMIN:
|
||||
result = db_session.execute(persona_stmt)
|
||||
persona = result.scalar_one_or_none()
|
||||
if persona is None:
|
||||
@@ -1237,6 +1243,14 @@ def get_persona_by_id(
|
||||
# if the user is in the .users of the persona
|
||||
or_conditions |= User.id == user.id
|
||||
or_conditions |= Persona.is_public == True # noqa: E712
|
||||
elif user.role == UserRole.GLOBAL_CURATOR:
|
||||
# global curators can edit personas for the groups they are in
|
||||
or_conditions |= User__UserGroup.user_id == user.id
|
||||
elif user.role == UserRole.CURATOR:
|
||||
# curators can edit personas for the groups they are curators of
|
||||
or_conditions |= (User__UserGroup.user_id == user.id) & (
|
||||
User__UserGroup.is_curator == True # noqa: E712
|
||||
)
|
||||
|
||||
persona_stmt = persona_stmt.where(or_conditions)
|
||||
result = db_session.execute(persona_stmt)
|
||||
|
||||
@@ -8,15 +8,19 @@ from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import DefaultAppMode
|
||||
from onyx.db.enums import ThemePreference
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import Assistant__UserSpecificConfig
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.db.users import is_limited_user
|
||||
from onyx.db.users import user_is_admin
|
||||
from onyx.server.manage.models import MemoryItem
|
||||
from onyx.server.manage.models import UserSpecificAssistantPreference
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -25,6 +29,59 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_ROLE_TO_ACCOUNT_TYPE: dict[UserRole, AccountType] = {
|
||||
UserRole.SLACK_USER: AccountType.BOT,
|
||||
UserRole.EXT_PERM_USER: AccountType.EXT_PERM_USER,
|
||||
}
|
||||
|
||||
|
||||
def update_user_role(
|
||||
user: User,
|
||||
new_role: UserRole,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update a user's role in the database.
|
||||
Dual-writes account_type to keep it in sync with role and
|
||||
reconciles default-group membership (Admin / Basic)."""
|
||||
old_role = user.role
|
||||
user.role = new_role
|
||||
# Note: setting account_type to BOT or EXT_PERM_USER causes
|
||||
# assign_user_to_default_groups__no_commit to early-return, which is
|
||||
# intentional — these account types should not be in default groups.
|
||||
if new_role in _ROLE_TO_ACCOUNT_TYPE:
|
||||
user.account_type = _ROLE_TO_ACCOUNT_TYPE[new_role]
|
||||
elif user.account_type in (AccountType.BOT, AccountType.EXT_PERM_USER):
|
||||
# Upgrading from a non-web-login account type to a web role
|
||||
user.account_type = AccountType.STANDARD
|
||||
|
||||
# Reconcile default-group membership when the role changes.
|
||||
if old_role != new_role:
|
||||
# Remove from all default groups first.
|
||||
db_session.execute(
|
||||
delete(User__UserGroup).where(
|
||||
User__UserGroup.user_id == user.id,
|
||||
User__UserGroup.user_group_id.in_(
|
||||
select(UserGroup.id).where(UserGroup.is_default.is_(True))
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Re-assign to the correct default group.
|
||||
# assign_user_to_default_groups__no_commit internally skips
|
||||
# ANONYMOUS, BOT, and EXT_PERM_USER account types.
|
||||
# Also skip limited users (no group assignment).
|
||||
if not is_limited_user(user):
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
user,
|
||||
is_admin=(new_role == UserRole.ADMIN),
|
||||
)
|
||||
|
||||
recompute_user_permissions__no_commit(user.id, db_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def deactivate_user(
|
||||
user: User,
|
||||
db_session: Session,
|
||||
@@ -50,7 +107,7 @@ def activate_user(
|
||||
# Also skip limited users (no group assignment).
|
||||
if not is_limited_user(user):
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=user_is_admin(user)
|
||||
db_session, user, is_admin=(user.role == UserRole.ADMIN)
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -2,6 +2,7 @@ from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import case
|
||||
from sqlalchemy import func
|
||||
@@ -14,11 +15,11 @@ from sqlalchemy.sql.elements import KeyedColumnElement
|
||||
from sqlalchemy.sql.expression import or_
|
||||
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import NO_AUTH_PLACEHOLDER_USER_EMAIL
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
@@ -52,15 +53,82 @@ def is_limited_user(user: User) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def user_is_admin(user: User) -> bool:
|
||||
"""Return True if the user holds the full admin permission.
|
||||
|
||||
Derived from effective_permissions, which is itself maintained from
|
||||
group membership — Admin-group members carry FULL_ADMIN_PANEL_ACCESS.
|
||||
def validate_user_role_update(
|
||||
requested_role: UserRole,
|
||||
current_account_type: AccountType,
|
||||
explicit_override: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
return Permission.FULL_ADMIN_PANEL_ACCESS.value in (
|
||||
user.effective_permissions or []
|
||||
)
|
||||
Validate that a user role update is valid.
|
||||
Assumed only admins can hit this endpoint.
|
||||
raise if:
|
||||
- requested role is a curator
|
||||
- requested role is a slack user
|
||||
- requested role is an external permissioned user
|
||||
- requested role is a limited user
|
||||
- current account type is BOT (slack user)
|
||||
- current account type is EXT_PERM_USER
|
||||
- current account type is ANONYMOUS or SERVICE_ACCOUNT
|
||||
"""
|
||||
|
||||
if current_account_type == AccountType.BOT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="To change a Slack User's role, they must first login to Onyx via the web app.",
|
||||
)
|
||||
|
||||
if current_account_type == AccountType.EXT_PERM_USER:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="To change an External Permissioned User's role, they must first login to Onyx via the web app.",
|
||||
)
|
||||
|
||||
if current_account_type in (AccountType.ANONYMOUS, AccountType.SERVICE_ACCOUNT):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot change the role of an anonymous or service account user.",
|
||||
)
|
||||
|
||||
if explicit_override:
|
||||
return
|
||||
|
||||
if requested_role == UserRole.CURATOR:
|
||||
# This shouldn't happen, but just in case
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Curator role must be set via the User Group Menu",
|
||||
)
|
||||
|
||||
if requested_role == UserRole.LIMITED:
|
||||
# This shouldn't happen, but just in case
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
"A user cannot be set to a Limited User role. "
|
||||
"This role is automatically assigned to users through certain endpoints in the API."
|
||||
),
|
||||
)
|
||||
|
||||
if requested_role == UserRole.SLACK_USER:
|
||||
# This shouldn't happen, but just in case
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
"A user cannot be set to a Slack User role. "
|
||||
"This role is automatically assigned to users who only use Onyx via Slack."
|
||||
),
|
||||
)
|
||||
|
||||
if requested_role == UserRole.EXT_PERM_USER:
|
||||
# This shouldn't happen, but just in case
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
"A user cannot be set to an External Permissioned User role. "
|
||||
"This role is automatically assigned to users who have been "
|
||||
"pulled in to the system via an external permissions system."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_all_users(
|
||||
@@ -77,7 +145,7 @@ def get_all_users(
|
||||
stmt = stmt.where(User.email != NO_AUTH_PLACEHOLDER_USER_EMAIL) # type: ignore
|
||||
|
||||
if not include_external:
|
||||
stmt = stmt.where(User.account_type != AccountType.EXT_PERM_USER)
|
||||
stmt = stmt.where(User.role != UserRole.EXT_PERM_USER)
|
||||
|
||||
if email_filter_string is not None:
|
||||
stmt = stmt.where(User.email.ilike(f"%{email_filter_string}%")) # type: ignore
|
||||
@@ -87,6 +155,7 @@ def get_all_users(
|
||||
|
||||
def _get_accepted_user_where_clause(
|
||||
email_filter_string: str | None = None,
|
||||
roles_filter: list[UserRole] = [],
|
||||
include_external: bool = False,
|
||||
is_active_filter: bool | None = None,
|
||||
) -> list[ColumnElement[bool]]:
|
||||
@@ -97,6 +166,7 @@ def _get_accepted_user_where_clause(
|
||||
Parameters:
|
||||
- email_filter_string: A substring to filter user emails. Only users whose emails contain this substring will be included.
|
||||
- is_active_filter: When True, only active users will be included. When False, only inactive users will be included.
|
||||
- roles_filter: A list of user roles to filter by. Only users with roles in this list will be included.
|
||||
- include_external: If False, external permissioned users will be excluded.
|
||||
|
||||
Returns:
|
||||
@@ -116,7 +186,7 @@ def _get_accepted_user_where_clause(
|
||||
]
|
||||
|
||||
if not include_external:
|
||||
where_clause.append(User.account_type != AccountType.EXT_PERM_USER)
|
||||
where_clause.append(User.role != UserRole.EXT_PERM_USER)
|
||||
|
||||
if email_filter_string is not None:
|
||||
personal_name_col: KeyedColumnElement[Any] = User.__table__.c.personal_name
|
||||
@@ -127,6 +197,9 @@ def _get_accepted_user_where_clause(
|
||||
)
|
||||
)
|
||||
|
||||
if roles_filter:
|
||||
where_clause.append(User.role.in_(roles_filter))
|
||||
|
||||
if is_active_filter is not None:
|
||||
where_clause.append(is_active_col.is_(is_active_filter))
|
||||
|
||||
@@ -139,7 +212,7 @@ def get_all_accepted_users(
|
||||
) -> Sequence[User]:
|
||||
"""Returns all accepted users without pagination.
|
||||
Uses the same filtering as the paginated endpoint but without
|
||||
search or active filters."""
|
||||
search, role, or active filters."""
|
||||
stmt = select(User)
|
||||
where_clause = _get_accepted_user_where_clause(
|
||||
include_external=include_external,
|
||||
@@ -154,12 +227,14 @@ def get_page_of_filtered_users(
|
||||
page_num: int,
|
||||
email_filter_string: str | None = None,
|
||||
is_active_filter: bool | None = None,
|
||||
roles_filter: list[UserRole] = [],
|
||||
include_external: bool = False,
|
||||
) -> Sequence[User]:
|
||||
users_stmt = select(User)
|
||||
|
||||
where_clause = _get_accepted_user_where_clause(
|
||||
email_filter_string=email_filter_string,
|
||||
roles_filter=roles_filter,
|
||||
include_external=include_external,
|
||||
is_active_filter=is_active_filter,
|
||||
)
|
||||
@@ -175,10 +250,12 @@ def get_total_filtered_users_count(
|
||||
db_session: Session,
|
||||
email_filter_string: str | None = None,
|
||||
is_active_filter: bool | None = None,
|
||||
roles_filter: list[UserRole] = [],
|
||||
include_external: bool = False,
|
||||
) -> int:
|
||||
where_clause = _get_accepted_user_where_clause(
|
||||
email_filter_string=email_filter_string,
|
||||
roles_filter=roles_filter,
|
||||
include_external=include_external,
|
||||
is_active_filter=is_active_filter,
|
||||
)
|
||||
@@ -189,46 +266,39 @@ def get_total_filtered_users_count(
|
||||
return db_session.scalar(total_count_stmt) or 0
|
||||
|
||||
|
||||
def get_user_counts_by_account_type_and_status(
|
||||
def get_user_counts_by_role_and_status(
|
||||
db_session: Session,
|
||||
) -> dict[str, dict[str, int]]:
|
||||
"""Returns user counts grouped by account_type and by active/inactive status.
|
||||
"""Returns user counts grouped by role and by active/inactive status.
|
||||
|
||||
Excludes API key users, anonymous users, and no-auth placeholder users.
|
||||
Uses a single query with conditional aggregation.
|
||||
"""
|
||||
base_where = _get_accepted_user_where_clause()
|
||||
account_type_col = User.__table__.c.account_type
|
||||
role_col = User.__table__.c.role
|
||||
is_active_col = User.__table__.c.is_active
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
account_type_col,
|
||||
role_col,
|
||||
func.count().label("total"),
|
||||
func.sum(case((is_active_col.is_(True), 1), else_=0)).label("active"),
|
||||
func.sum(case((is_active_col.is_(False), 1), else_=0)).label("inactive"),
|
||||
)
|
||||
.where(*base_where)
|
||||
.group_by(account_type_col)
|
||||
.group_by(role_col)
|
||||
)
|
||||
|
||||
account_type_counts: dict[str, int] = {}
|
||||
role_counts: dict[str, int] = {}
|
||||
status_counts: dict[str, int] = {"active": 0, "inactive": 0}
|
||||
|
||||
for account_type_val, total, active, inactive in db_session.execute(stmt).all():
|
||||
key = (
|
||||
account_type_val.value
|
||||
if hasattr(account_type_val, "value")
|
||||
else str(account_type_val)
|
||||
)
|
||||
account_type_counts[key] = total
|
||||
for role_val, total, active, inactive in db_session.execute(stmt).all():
|
||||
key = role_val.value if hasattr(role_val, "value") else str(role_val)
|
||||
role_counts[key] = total
|
||||
status_counts["active"] += active or 0
|
||||
status_counts["inactive"] += inactive or 0
|
||||
|
||||
return {
|
||||
"account_type_counts": account_type_counts,
|
||||
"status_counts": status_counts,
|
||||
}
|
||||
return {"role_counts": role_counts, "status_counts": status_counts}
|
||||
|
||||
|
||||
def get_user_by_email(email: str, db_session: Session) -> User | None:
|
||||
@@ -251,6 +321,7 @@ def _generate_slack_user(email: str) -> User:
|
||||
return User(
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
role=UserRole.SLACK_USER,
|
||||
account_type=AccountType.BOT,
|
||||
)
|
||||
|
||||
@@ -261,6 +332,7 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
|
||||
if user is not None:
|
||||
# If the user is an external permissioned user, we update it to a slack user
|
||||
if user.account_type == AccountType.EXT_PERM_USER:
|
||||
user.role = UserRole.SLACK_USER
|
||||
user.account_type = AccountType.BOT
|
||||
db_session.commit()
|
||||
return user
|
||||
@@ -297,6 +369,7 @@ def _generate_ext_permissioned_user(email: str) -> User:
|
||||
return User(
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
|
||||
|
||||
@@ -7,8 +7,6 @@ import time
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
@@ -22,6 +20,7 @@ from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.chat_configs import SKIP_DEEP_RESEARCH_CLARIFICATION
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.tools import get_tool_by_name
|
||||
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
|
||||
from onyx.deep_research.dr_mock_tools import get_orchestrator_tools
|
||||
@@ -184,6 +183,14 @@ def generate_final_report(
|
||||
return has_reasoned
|
||||
|
||||
|
||||
def _get_research_agent_tool_id() -> int:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
return get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_TOOL_NAME,
|
||||
db_session=db_session,
|
||||
).id
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def run_deep_research_llm_loop(
|
||||
emitter: Emitter,
|
||||
@@ -193,7 +200,6 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt: str | None, # noqa: ARG001
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
skip_clarification: bool = False,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
@@ -717,6 +723,7 @@ def run_deep_research_llm_loop(
|
||||
simple_chat_history.append(assistant_with_tools)
|
||||
|
||||
# Now add TOOL_CALL_RESPONSE messages and tool call info for each result
|
||||
research_agent_tool_id = _get_research_agent_tool_id()
|
||||
for tab_index, report in enumerate(
|
||||
research_results.intermediate_reports
|
||||
):
|
||||
@@ -737,10 +744,7 @@ def run_deep_research_llm_loop(
|
||||
tab_index=tab_index,
|
||||
tool_name=current_tool_call.tool_name,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
tool_id=get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_TOOL_NAME,
|
||||
db_session=db_session,
|
||||
).id,
|
||||
tool_id=research_agent_tool_id,
|
||||
reasoning_tokens=llm_step_result.reasoning
|
||||
or most_recent_reasoning,
|
||||
tool_call_arguments=current_tool_call.tool_args,
|
||||
|
||||
@@ -56,7 +56,6 @@ class OnyxErrorCode(Enum):
|
||||
DOCUMENT_NOT_FOUND = ("DOCUMENT_NOT_FOUND", 404)
|
||||
SESSION_NOT_FOUND = ("SESSION_NOT_FOUND", 404)
|
||||
USER_NOT_FOUND = ("USER_NOT_FOUND", 404)
|
||||
DOCUMENT_SET_NOT_FOUND = ("DOCUMENT_SET_NOT_FOUND", 404)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Conflict (409)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import cast
|
||||
|
||||
from chonkie import SentenceChunker
|
||||
|
||||
from onyx.configs.app_configs import AVERAGE_SUMMARY_EMBEDDINGS
|
||||
@@ -16,16 +14,14 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_metadata_keys_to_ignore,
|
||||
)
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.chunking import DocumentChunker
|
||||
from onyx.indexing.chunking import extract_blurb
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.llm.utils import MAX_CONTEXT_TOKENS
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.text_processing import clean_text
|
||||
from onyx.utils.text_processing import shared_precompare_cleanup
|
||||
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
|
||||
|
||||
# Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps
|
||||
# actually help quality at all
|
||||
@@ -154,9 +150,6 @@ class Chunker:
|
||||
self.tokenizer = tokenizer
|
||||
self.callback = callback
|
||||
|
||||
self.max_context = 0
|
||||
self.prompt_tokens = 0
|
||||
|
||||
# Create a token counter function that returns the count instead of the tokens
|
||||
def token_counter(text: str) -> int:
|
||||
return len(tokenizer.encode(text))
|
||||
@@ -186,234 +179,12 @@ class Chunker:
|
||||
else None
|
||||
)
|
||||
|
||||
def _split_oversized_chunk(self, text: str, content_token_limit: int) -> list[str]:
|
||||
"""
|
||||
Splits the text into smaller chunks based on token count to ensure
|
||||
no chunk exceeds the content_token_limit.
|
||||
"""
|
||||
tokens = self.tokenizer.tokenize(text)
|
||||
chunks = []
|
||||
start = 0
|
||||
total_tokens = len(tokens)
|
||||
while start < total_tokens:
|
||||
end = min(start + content_token_limit, total_tokens)
|
||||
token_chunk = tokens[start:end]
|
||||
chunk_text = " ".join(token_chunk)
|
||||
chunks.append(chunk_text)
|
||||
start = end
|
||||
return chunks
|
||||
|
||||
def _extract_blurb(self, text: str) -> str:
|
||||
"""
|
||||
Extract a short blurb from the text (first chunk of size `blurb_size`).
|
||||
"""
|
||||
# chunker is in `text` mode
|
||||
texts = cast(list[str], self.blurb_splitter.chunk(text))
|
||||
if not texts:
|
||||
return ""
|
||||
return texts[0]
|
||||
|
||||
def _get_mini_chunk_texts(self, chunk_text: str) -> list[str] | None:
|
||||
"""
|
||||
For "multipass" mode: additional sub-chunks (mini-chunks) for use in certain embeddings.
|
||||
"""
|
||||
if self.mini_chunk_splitter and chunk_text.strip():
|
||||
# chunker is in `text` mode
|
||||
return cast(list[str], self.mini_chunk_splitter.chunk(chunk_text))
|
||||
return None
|
||||
|
||||
# ADDED: extra param image_url to store in the chunk
|
||||
def _create_chunk(
|
||||
self,
|
||||
document: IndexingDocument,
|
||||
chunks_list: list[DocAwareChunk],
|
||||
text: str,
|
||||
links: dict[int, str],
|
||||
is_continuation: bool = False,
|
||||
title_prefix: str = "",
|
||||
metadata_suffix_semantic: str = "",
|
||||
metadata_suffix_keyword: str = "",
|
||||
image_file_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Helper to create a new DocAwareChunk, append it to chunks_list.
|
||||
"""
|
||||
new_chunk = DocAwareChunk(
|
||||
source_document=document,
|
||||
chunk_id=len(chunks_list),
|
||||
blurb=self._extract_blurb(text),
|
||||
content=text,
|
||||
source_links=links or {0: ""},
|
||||
image_file_id=image_file_id,
|
||||
section_continuation=is_continuation,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
mini_chunk_texts=self._get_mini_chunk_texts(text),
|
||||
large_chunk_id=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
contextual_rag_reserved_tokens=0, # set per-document in _handle_single_document
|
||||
self._document_chunker = DocumentChunker(
|
||||
tokenizer=tokenizer,
|
||||
blurb_splitter=self.blurb_splitter,
|
||||
chunk_splitter=self.chunk_splitter,
|
||||
mini_chunk_splitter=self.mini_chunk_splitter,
|
||||
)
|
||||
chunks_list.append(new_chunk)
|
||||
|
||||
def _chunk_document_with_sections(
|
||||
self,
|
||||
document: IndexingDocument,
|
||||
sections: list[Section],
|
||||
title_prefix: str,
|
||||
metadata_suffix_semantic: str,
|
||||
metadata_suffix_keyword: str,
|
||||
content_token_limit: int,
|
||||
) -> list[DocAwareChunk]:
|
||||
"""
|
||||
Loops through sections of the document, converting them into one or more chunks.
|
||||
Works with processed sections that are base Section objects.
|
||||
"""
|
||||
chunks: list[DocAwareChunk] = []
|
||||
link_offsets: dict[int, str] = {}
|
||||
chunk_text = ""
|
||||
|
||||
for section_idx, section in enumerate(sections):
|
||||
# Get section text and other attributes
|
||||
section_text = clean_text(str(section.text or ""))
|
||||
section_link_text = section.link or ""
|
||||
image_url = section.image_file_id
|
||||
|
||||
# If there is no useful content, skip
|
||||
if not section_text and (not document.title or section_idx > 0):
|
||||
logger.warning(
|
||||
f"Skipping empty or irrelevant section in doc {document.semantic_identifier}, link={section_link_text}"
|
||||
)
|
||||
continue
|
||||
|
||||
# CASE 1: If this section has an image, force a separate chunk
|
||||
if image_url:
|
||||
# First, if we have any partially built text chunk, finalize it
|
||||
if chunk_text.strip():
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
chunk_text,
|
||||
link_offsets,
|
||||
is_continuation=False,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
chunk_text = ""
|
||||
link_offsets = {}
|
||||
|
||||
# Create a chunk specifically for this image section
|
||||
# (Using the text summary that was generated during processing)
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
section_text,
|
||||
links={0: section_link_text} if section_link_text else {},
|
||||
image_file_id=image_url,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
# Continue to next section
|
||||
continue
|
||||
|
||||
# CASE 2: Normal text section
|
||||
section_token_count = len(self.tokenizer.encode(section_text))
|
||||
|
||||
# If the section is large on its own, split it separately
|
||||
if section_token_count > content_token_limit:
|
||||
if chunk_text.strip():
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
chunk_text,
|
||||
link_offsets,
|
||||
False,
|
||||
title_prefix,
|
||||
metadata_suffix_semantic,
|
||||
metadata_suffix_keyword,
|
||||
)
|
||||
chunk_text = ""
|
||||
link_offsets = {}
|
||||
|
||||
# chunker is in `text` mode
|
||||
split_texts = cast(list[str], self.chunk_splitter.chunk(section_text))
|
||||
for i, split_text in enumerate(split_texts):
|
||||
# If even the split_text is bigger than strict limit, further split
|
||||
if (
|
||||
STRICT_CHUNK_TOKEN_LIMIT
|
||||
and len(self.tokenizer.encode(split_text)) > content_token_limit
|
||||
):
|
||||
smaller_chunks = self._split_oversized_chunk(
|
||||
split_text, content_token_limit
|
||||
)
|
||||
for j, small_chunk in enumerate(smaller_chunks):
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
small_chunk,
|
||||
{0: section_link_text},
|
||||
is_continuation=(j != 0),
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
else:
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
split_text,
|
||||
{0: section_link_text},
|
||||
is_continuation=(i != 0),
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
continue
|
||||
|
||||
# If we can still fit this section into the current chunk, do so
|
||||
current_token_count = len(self.tokenizer.encode(chunk_text))
|
||||
current_offset = len(shared_precompare_cleanup(chunk_text))
|
||||
next_section_tokens = (
|
||||
len(self.tokenizer.encode(SECTION_SEPARATOR)) + section_token_count
|
||||
)
|
||||
|
||||
if next_section_tokens + current_token_count <= content_token_limit:
|
||||
if chunk_text:
|
||||
chunk_text += SECTION_SEPARATOR
|
||||
chunk_text += section_text
|
||||
link_offsets[current_offset] = section_link_text
|
||||
else:
|
||||
# finalize the existing chunk
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
chunk_text,
|
||||
link_offsets,
|
||||
False,
|
||||
title_prefix,
|
||||
metadata_suffix_semantic,
|
||||
metadata_suffix_keyword,
|
||||
)
|
||||
# start a new chunk
|
||||
link_offsets = {0: section_link_text}
|
||||
chunk_text = section_text
|
||||
|
||||
# finalize any leftover text chunk
|
||||
if chunk_text.strip() or not chunks:
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
chunk_text,
|
||||
link_offsets or {0: ""}, # safe default
|
||||
False,
|
||||
title_prefix,
|
||||
metadata_suffix_semantic,
|
||||
metadata_suffix_keyword,
|
||||
)
|
||||
return chunks
|
||||
|
||||
def _handle_single_document(
|
||||
self, document: IndexingDocument
|
||||
@@ -423,7 +194,10 @@ class Chunker:
|
||||
logger.debug(f"Chunking {document.semantic_identifier}")
|
||||
|
||||
# Title prep
|
||||
title = self._extract_blurb(document.get_title_for_document_index() or "")
|
||||
title = extract_blurb(
|
||||
document.get_title_for_document_index() or "",
|
||||
self.blurb_splitter,
|
||||
)
|
||||
title_prefix = title + RETURN_SEPARATOR if title else ""
|
||||
title_tokens = len(self.tokenizer.encode(title_prefix))
|
||||
|
||||
@@ -491,7 +265,7 @@ class Chunker:
|
||||
# Use processed_sections if available (IndexingDocument), otherwise use original sections
|
||||
sections_to_chunk = document.processed_sections
|
||||
|
||||
normal_chunks = self._chunk_document_with_sections(
|
||||
normal_chunks = self._document_chunker.chunk(
|
||||
document,
|
||||
sections_to_chunk,
|
||||
title_prefix,
|
||||
|
||||
7
backend/onyx/indexing/chunking/__init__.py
Normal file
7
backend/onyx/indexing/chunking/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from onyx.indexing.chunking.document_chunker import DocumentChunker
|
||||
from onyx.indexing.chunking.section_chunker import extract_blurb
|
||||
|
||||
__all__ = [
|
||||
"DocumentChunker",
|
||||
"extract_blurb",
|
||||
]
|
||||
109
backend/onyx/indexing/chunking/document_chunker.py
Normal file
109
backend/onyx/indexing/chunking/document_chunker.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from chonkie import SentenceChunker
|
||||
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SectionType
|
||||
from onyx.indexing.chunking.image_section_chunker import ImageChunker
|
||||
from onyx.indexing.chunking.section_chunker import AccumulatorState
|
||||
from onyx.indexing.chunking.section_chunker import ChunkPayload
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunker
|
||||
from onyx.indexing.chunking.text_section_chunker import TextChunker
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.text_processing import clean_text
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class DocumentChunker:
|
||||
"""Converts a document's processed sections into DocAwareChunks.
|
||||
|
||||
Drop-in replacement for `Chunker._chunk_document_with_sections`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: BaseTokenizer,
|
||||
blurb_splitter: SentenceChunker,
|
||||
chunk_splitter: SentenceChunker,
|
||||
mini_chunk_splitter: SentenceChunker | None = None,
|
||||
) -> None:
|
||||
self.blurb_splitter = blurb_splitter
|
||||
self.mini_chunk_splitter = mini_chunk_splitter
|
||||
|
||||
self._dispatch: dict[SectionType, SectionChunker] = {
|
||||
SectionType.TEXT: TextChunker(
|
||||
tokenizer=tokenizer,
|
||||
chunk_splitter=chunk_splitter,
|
||||
),
|
||||
SectionType.IMAGE: ImageChunker(),
|
||||
}
|
||||
|
||||
def chunk(
|
||||
self,
|
||||
document: IndexingDocument,
|
||||
sections: list[Section],
|
||||
title_prefix: str,
|
||||
metadata_suffix_semantic: str,
|
||||
metadata_suffix_keyword: str,
|
||||
content_token_limit: int,
|
||||
) -> list[DocAwareChunk]:
|
||||
payloads = self._collect_section_payloads(
|
||||
document=document,
|
||||
sections=sections,
|
||||
content_token_limit=content_token_limit,
|
||||
)
|
||||
|
||||
if not payloads:
|
||||
payloads.append(ChunkPayload(text="", links={0: ""}))
|
||||
|
||||
return [
|
||||
payload.to_doc_aware_chunk(
|
||||
document=document,
|
||||
chunk_id=idx,
|
||||
blurb_splitter=self.blurb_splitter,
|
||||
mini_chunk_splitter=self.mini_chunk_splitter,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
for idx, payload in enumerate(payloads)
|
||||
]
|
||||
|
||||
def _collect_section_payloads(
|
||||
self,
|
||||
document: IndexingDocument,
|
||||
sections: list[Section],
|
||||
content_token_limit: int,
|
||||
) -> list[ChunkPayload]:
|
||||
accumulator = AccumulatorState()
|
||||
payloads: list[ChunkPayload] = []
|
||||
|
||||
for section_idx, section in enumerate(sections):
|
||||
section_text = clean_text(str(section.text or ""))
|
||||
|
||||
if not section_text and (not document.title or section_idx > 0):
|
||||
logger.warning(
|
||||
f"Skipping empty or irrelevant section in doc "
|
||||
f"{document.semantic_identifier}, link={section.link}"
|
||||
)
|
||||
continue
|
||||
|
||||
chunker = self._select_chunker(section)
|
||||
result = chunker.chunk_section(
|
||||
section=section,
|
||||
accumulator=accumulator,
|
||||
content_token_limit=content_token_limit,
|
||||
)
|
||||
payloads.extend(result.payloads)
|
||||
accumulator = result.accumulator
|
||||
|
||||
payloads.extend(accumulator.flush_to_list())
|
||||
return payloads
|
||||
|
||||
def _select_chunker(self, section: Section) -> SectionChunker:
|
||||
try:
|
||||
return self._dispatch[section.type]
|
||||
except KeyError:
|
||||
raise ValueError(f"No SectionChunker registered for type={section.type}")
|
||||
35
backend/onyx/indexing/chunking/image_section_chunker.py
Normal file
35
backend/onyx/indexing/chunking/image_section_chunker.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.chunking.section_chunker import AccumulatorState
|
||||
from onyx.indexing.chunking.section_chunker import ChunkPayload
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunker
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
|
||||
from onyx.utils.text_processing import clean_text
|
||||
|
||||
|
||||
class ImageChunker(SectionChunker):
|
||||
def chunk_section(
|
||||
self,
|
||||
section: Section,
|
||||
accumulator: AccumulatorState,
|
||||
content_token_limit: int, # noqa: ARG002
|
||||
) -> SectionChunkerOutput:
|
||||
assert section.image_file_id is not None
|
||||
|
||||
section_text = clean_text(str(section.text or ""))
|
||||
section_link = section.link or ""
|
||||
|
||||
# Flush any partially built text chunks
|
||||
payloads = accumulator.flush_to_list()
|
||||
payloads.append(
|
||||
ChunkPayload(
|
||||
text=section_text,
|
||||
links={0: section_link} if section_link else {},
|
||||
image_file_id=section.image_file_id,
|
||||
is_continuation=False,
|
||||
)
|
||||
)
|
||||
|
||||
return SectionChunkerOutput(
|
||||
payloads=payloads,
|
||||
accumulator=AccumulatorState(),
|
||||
)
|
||||
100
backend/onyx/indexing/chunking/section_chunker.py
Normal file
100
backend/onyx/indexing/chunking/section_chunker.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from chonkie import SentenceChunker
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
|
||||
|
||||
def extract_blurb(text: str, blurb_splitter: SentenceChunker) -> str:
|
||||
texts = cast(list[str], blurb_splitter.chunk(text))
|
||||
if not texts:
|
||||
return ""
|
||||
return texts[0]
|
||||
|
||||
|
||||
def get_mini_chunk_texts(
|
||||
chunk_text: str,
|
||||
mini_chunk_splitter: SentenceChunker | None,
|
||||
) -> list[str] | None:
|
||||
if mini_chunk_splitter and chunk_text.strip():
|
||||
return list(cast(Sequence[str], mini_chunk_splitter.chunk(chunk_text)))
|
||||
return None
|
||||
|
||||
|
||||
class ChunkPayload(BaseModel):
|
||||
"""Section-local chunk content without document-scoped fields.
|
||||
|
||||
The orchestrator upgrades these to DocAwareChunks via
|
||||
`to_doc_aware_chunk` after assigning chunk_ids and attaching
|
||||
title/metadata.
|
||||
"""
|
||||
|
||||
text: str
|
||||
links: dict[int, str]
|
||||
is_continuation: bool = False
|
||||
image_file_id: str | None = None
|
||||
|
||||
def to_doc_aware_chunk(
|
||||
self,
|
||||
document: IndexingDocument,
|
||||
chunk_id: int,
|
||||
blurb_splitter: SentenceChunker,
|
||||
title_prefix: str = "",
|
||||
metadata_suffix_semantic: str = "",
|
||||
metadata_suffix_keyword: str = "",
|
||||
mini_chunk_splitter: SentenceChunker | None = None,
|
||||
) -> DocAwareChunk:
|
||||
return DocAwareChunk(
|
||||
source_document=document,
|
||||
chunk_id=chunk_id,
|
||||
blurb=extract_blurb(self.text, blurb_splitter),
|
||||
content=self.text,
|
||||
source_links=self.links or {0: ""},
|
||||
image_file_id=self.image_file_id,
|
||||
section_continuation=self.is_continuation,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
mini_chunk_texts=get_mini_chunk_texts(self.text, mini_chunk_splitter),
|
||||
large_chunk_id=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
)
|
||||
|
||||
|
||||
class AccumulatorState(BaseModel):
|
||||
"""Cross-section text buffer threaded through SectionChunkers."""
|
||||
|
||||
text: str = ""
|
||||
link_offsets: dict[int, str] = Field(default_factory=dict)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return not self.text.strip()
|
||||
|
||||
def flush_to_list(self) -> list[ChunkPayload]:
|
||||
if self.is_empty():
|
||||
return []
|
||||
return [ChunkPayload(text=self.text, links=self.link_offsets)]
|
||||
|
||||
|
||||
class SectionChunkerOutput(BaseModel):
|
||||
payloads: list[ChunkPayload]
|
||||
accumulator: AccumulatorState
|
||||
|
||||
|
||||
class SectionChunker(ABC):
|
||||
@abstractmethod
|
||||
def chunk_section(
|
||||
self,
|
||||
section: Section,
|
||||
accumulator: AccumulatorState,
|
||||
content_token_limit: int,
|
||||
) -> SectionChunkerOutput: ...
|
||||
129
backend/onyx/indexing/chunking/text_section_chunker.py
Normal file
129
backend/onyx/indexing/chunking/text_section_chunker.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from typing import cast
|
||||
|
||||
from chonkie import SentenceChunker
|
||||
|
||||
from onyx.configs.constants import SECTION_SEPARATOR
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.chunking.section_chunker import AccumulatorState
|
||||
from onyx.indexing.chunking.section_chunker import ChunkPayload
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunker
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.utils.text_processing import clean_text
|
||||
from onyx.utils.text_processing import shared_precompare_cleanup
|
||||
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
|
||||
|
||||
|
||||
class TextChunker(SectionChunker):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: BaseTokenizer,
|
||||
chunk_splitter: SentenceChunker,
|
||||
) -> None:
|
||||
self.tokenizer = tokenizer
|
||||
self.chunk_splitter = chunk_splitter
|
||||
|
||||
self.section_separator_token_count = count_tokens(
|
||||
SECTION_SEPARATOR,
|
||||
self.tokenizer,
|
||||
)
|
||||
|
||||
def chunk_section(
|
||||
self,
|
||||
section: Section,
|
||||
accumulator: AccumulatorState,
|
||||
content_token_limit: int,
|
||||
) -> SectionChunkerOutput:
|
||||
section_text = clean_text(str(section.text or ""))
|
||||
section_link = section.link or ""
|
||||
section_token_count = len(self.tokenizer.encode(section_text))
|
||||
|
||||
# Oversized — flush buffer and split the section
|
||||
if section_token_count > content_token_limit:
|
||||
return self._handle_oversized_section(
|
||||
section_text=section_text,
|
||||
section_link=section_link,
|
||||
accumulator=accumulator,
|
||||
content_token_limit=content_token_limit,
|
||||
)
|
||||
|
||||
current_token_count = count_tokens(accumulator.text, self.tokenizer)
|
||||
next_section_tokens = self.section_separator_token_count + section_token_count
|
||||
|
||||
# Fits — extend the accumulator
|
||||
if next_section_tokens + current_token_count <= content_token_limit:
|
||||
offset = len(shared_precompare_cleanup(accumulator.text))
|
||||
new_text = accumulator.text
|
||||
if new_text:
|
||||
new_text += SECTION_SEPARATOR
|
||||
new_text += section_text
|
||||
return SectionChunkerOutput(
|
||||
payloads=[],
|
||||
accumulator=AccumulatorState(
|
||||
text=new_text,
|
||||
link_offsets={**accumulator.link_offsets, offset: section_link},
|
||||
),
|
||||
)
|
||||
|
||||
# Doesn't fit — flush buffer and restart with this section
|
||||
return SectionChunkerOutput(
|
||||
payloads=accumulator.flush_to_list(),
|
||||
accumulator=AccumulatorState(
|
||||
text=section_text,
|
||||
link_offsets={0: section_link},
|
||||
),
|
||||
)
|
||||
|
||||
def _handle_oversized_section(
|
||||
self,
|
||||
section_text: str,
|
||||
section_link: str,
|
||||
accumulator: AccumulatorState,
|
||||
content_token_limit: int,
|
||||
) -> SectionChunkerOutput:
|
||||
payloads = accumulator.flush_to_list()
|
||||
|
||||
split_texts = cast(list[str], self.chunk_splitter.chunk(section_text))
|
||||
for i, split_text in enumerate(split_texts):
|
||||
if (
|
||||
STRICT_CHUNK_TOKEN_LIMIT
|
||||
and count_tokens(split_text, self.tokenizer) > content_token_limit
|
||||
):
|
||||
smaller_chunks = self._split_oversized_chunk(
|
||||
split_text, content_token_limit
|
||||
)
|
||||
for j, small_chunk in enumerate(smaller_chunks):
|
||||
payloads.append(
|
||||
ChunkPayload(
|
||||
text=small_chunk,
|
||||
links={0: section_link},
|
||||
is_continuation=(j != 0),
|
||||
)
|
||||
)
|
||||
else:
|
||||
payloads.append(
|
||||
ChunkPayload(
|
||||
text=split_text,
|
||||
links={0: section_link},
|
||||
is_continuation=(i != 0),
|
||||
)
|
||||
)
|
||||
|
||||
return SectionChunkerOutput(
|
||||
payloads=payloads,
|
||||
accumulator=AccumulatorState(),
|
||||
)
|
||||
|
||||
def _split_oversized_chunk(self, text: str, content_token_limit: int) -> list[str]:
|
||||
tokens = self.tokenizer.tokenize(text)
|
||||
chunks: list[str] = []
|
||||
start = 0
|
||||
total_tokens = len(tokens)
|
||||
while start < total_tokens:
|
||||
end = min(start + content_token_limit, total_tokens)
|
||||
token_chunk = tokens[start:end]
|
||||
chunk_text = " ".join(token_chunk)
|
||||
chunks.append(chunk_text)
|
||||
start = end
|
||||
return chunks
|
||||
@@ -542,6 +542,7 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
**document.model_dump(),
|
||||
processed_sections=[
|
||||
Section(
|
||||
type=section.type,
|
||||
text=section.text if isinstance(section, TextSection) else "",
|
||||
link=section.link,
|
||||
image_file_id=(
|
||||
@@ -566,6 +567,7 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
if isinstance(section, ImageSection):
|
||||
# Default section with image path preserved - ensure text is always a string
|
||||
processed_section = Section(
|
||||
type=section.type,
|
||||
link=section.link,
|
||||
image_file_id=section.image_file_id,
|
||||
text="", # Initialize with empty string
|
||||
@@ -609,6 +611,7 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
# For TextSection, create a base Section with text and link
|
||||
elif isinstance(section, TextSection):
|
||||
processed_section = Section(
|
||||
type=section.type,
|
||||
text=section.text or "", # Ensure text is always a string, not None
|
||||
link=section.link,
|
||||
image_file_id=None,
|
||||
|
||||
@@ -66,7 +66,7 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
LlmProviderNames.LM_STUDIO: "LM Studio",
|
||||
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
|
||||
LlmProviderNames.BIFROST: "Bifrost",
|
||||
LlmProviderNames.OPENAI_COMPATIBLE: "OpenAI Compatible",
|
||||
LlmProviderNames.OPENAI_COMPATIBLE: "OpenAI-Compatible",
|
||||
"groq": "Groq",
|
||||
"anyscale": "Anyscale",
|
||||
"deepseek": "DeepSeek",
|
||||
@@ -87,6 +87,44 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
"gemini": "Gemini",
|
||||
"stability": "Stability",
|
||||
"writer": "Writer",
|
||||
# Custom provider display names (used in the custom provider picker)
|
||||
"aiml": "AI/ML",
|
||||
"assemblyai": "AssemblyAI",
|
||||
"aws_polly": "AWS Polly",
|
||||
"azure_ai": "Azure AI",
|
||||
"chatgpt": "ChatGPT",
|
||||
"cohere_chat": "Cohere Chat",
|
||||
"datarobot": "DataRobot",
|
||||
"deepgram": "Deepgram",
|
||||
"deepinfra": "DeepInfra",
|
||||
"elevenlabs": "ElevenLabs",
|
||||
"fal_ai": "fal.ai",
|
||||
"featherless_ai": "Featherless AI",
|
||||
"fireworks_ai": "Fireworks AI",
|
||||
"friendliai": "FriendliAI",
|
||||
"gigachat": "GigaChat",
|
||||
"github_copilot": "GitHub Copilot",
|
||||
"gradient_ai": "Gradient AI",
|
||||
"huggingface": "HuggingFace",
|
||||
"jina_ai": "Jina AI",
|
||||
"lambda_ai": "Lambda AI",
|
||||
"llamagate": "LlamaGate",
|
||||
"meta_llama": "Meta Llama",
|
||||
"minimax": "MiniMax",
|
||||
"nlp_cloud": "NLP Cloud",
|
||||
"nvidia_nim": "NVIDIA NIM",
|
||||
"oci": "OCI",
|
||||
"ovhcloud": "OVHcloud",
|
||||
"palm": "PaLM",
|
||||
"publicai": "PublicAI",
|
||||
"runwayml": "RunwayML",
|
||||
"sambanova": "SambaNova",
|
||||
"together_ai": "Together AI",
|
||||
"vercel_ai_gateway": "Vercel AI Gateway",
|
||||
"volcengine": "Volcengine",
|
||||
"wandb": "W&B",
|
||||
"watsonx": "IBM watsonx",
|
||||
"zai": "ZAI",
|
||||
}
|
||||
|
||||
# Map vendors to their brand names (used for provider_display_name generation)
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from onyx.auth.permissions import has_permission
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_default_vision_model
|
||||
@@ -112,10 +111,7 @@ def get_llm_for_persona(
|
||||
user_group_ids = fetch_user_group_ids(db_session, user)
|
||||
|
||||
if not can_user_access_llm_provider(
|
||||
provider_model,
|
||||
user_group_ids,
|
||||
persona,
|
||||
has_permission(user, Permission.FULL_ADMIN_PANEL_ACCESS),
|
||||
provider_model, user_group_ids, persona, user.role == UserRole.ADMIN
|
||||
):
|
||||
logger.warning(
|
||||
"User %s with persona %s cannot access provider %s. Falling back to default provider.",
|
||||
|
||||
@@ -338,7 +338,7 @@ def get_provider_display_name(provider_name: str) -> str:
|
||||
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
|
||||
OPENROUTER_PROVIDER_NAME: "OpenRouter",
|
||||
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME: "OpenAI Compatible",
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME: "OpenAI-Compatible",
|
||||
}
|
||||
|
||||
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:
|
||||
|
||||
@@ -96,6 +96,9 @@ from onyx.server.features.persona.api import admin_router as admin_persona_route
|
||||
from onyx.server.features.persona.api import agents_router
|
||||
from onyx.server.features.persona.api import basic_router as persona_router
|
||||
from onyx.server.features.projects.api import router as projects_router
|
||||
from onyx.server.features.proposal_review.api.api import (
|
||||
router as proposal_review_router,
|
||||
)
|
||||
from onyx.server.features.tool.api import admin_router as admin_tool_router
|
||||
from onyx.server.features.tool.api import router as tool_router
|
||||
from onyx.server.features.user_oauth_token.api import router as user_oauth_token_router
|
||||
@@ -469,6 +472,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, projects_router)
|
||||
include_router_with_global_prefix_prepended(application, public_build_router)
|
||||
include_router_with_global_prefix_prepended(application, build_router)
|
||||
include_router_with_global_prefix_prepended(application, proposal_review_router)
|
||||
include_router_with_global_prefix_prepended(application, document_set_router)
|
||||
include_router_with_global_prefix_prepended(application, hierarchy_router)
|
||||
include_router_with_global_prefix_prepended(application, search_settings_router)
|
||||
|
||||
@@ -90,6 +90,7 @@ from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import TenantSocketModeClient
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.manage.models import SlackBotTokens
|
||||
from onyx.tracing.setup import setup_tracing
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
@@ -1206,6 +1207,7 @@ if __name__ == "__main__":
|
||||
tenant_handler = SlackbotHandler()
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
setup_tracing()
|
||||
|
||||
try:
|
||||
# Keep the main thread alive
|
||||
|
||||
@@ -20,7 +20,7 @@ router = APIRouter(prefix="/admin/api-key")
|
||||
|
||||
@router.get("")
|
||||
def list_api_keys(
|
||||
_: User = Depends(require_permission(Permission.MANAGE_SERVICE_ACCOUNT_API_KEYS)),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[ApiKeyDescriptor]:
|
||||
return fetch_api_keys(db_session)
|
||||
@@ -29,9 +29,7 @@ def list_api_keys(
|
||||
@router.post("")
|
||||
def create_api_key(
|
||||
api_key_args: APIKeyArgs,
|
||||
user: User = Depends(
|
||||
require_permission(Permission.MANAGE_SERVICE_ACCOUNT_API_KEYS)
|
||||
),
|
||||
user: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ApiKeyDescriptor:
|
||||
return insert_api_key(db_session, api_key_args, user.id)
|
||||
@@ -40,7 +38,7 @@ def create_api_key(
|
||||
@router.post("/{api_key_id}/regenerate")
|
||||
def regenerate_existing_api_key(
|
||||
api_key_id: int,
|
||||
_: User = Depends(require_permission(Permission.MANAGE_SERVICE_ACCOUNT_API_KEYS)),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ApiKeyDescriptor:
|
||||
return regenerate_api_key(db_session, api_key_id)
|
||||
@@ -50,7 +48,7 @@ def regenerate_existing_api_key(
|
||||
def update_existing_api_key(
|
||||
api_key_id: int,
|
||||
api_key_args: APIKeyArgs,
|
||||
_: User = Depends(require_permission(Permission.MANAGE_SERVICE_ACCOUNT_API_KEYS)),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ApiKeyDescriptor:
|
||||
return update_api_key(db_session, api_key_id, api_key_args)
|
||||
@@ -59,7 +57,7 @@ def update_existing_api_key(
|
||||
@router.delete("/{api_key_id}")
|
||||
def delete_api_key(
|
||||
api_key_id: int,
|
||||
_: User = Depends(require_permission(Permission.MANAGE_SERVICE_ACCOUNT_API_KEYS)),
|
||||
_: User = Depends(require_permission(Permission.FULL_ADMIN_PANEL_ACCESS)),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
remove_api_key(db_session, api_key_id)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user