Compare commits

..

23 Commits

Author SHA1 Message Date
Alex Kim
f87e03b194 Add Datadog admission opt-out label to sandbox pods (#10040) 2026-04-14 14:00:32 -07:00
github-actions[bot]
873636a095 fix(chat): speed up text gen (#10186) to release v3.2 (#10187)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-14 13:43:15 -07:00
Justin Tahara
efb194e067 fix(llm): Fix the Auto Fetch workflow (#10181) 2026-04-14 11:16:30 -07:00
github-actions[bot]
3f7dfa7813 feat(notifications): announce upcoming group-based permissions migration (#10178) to release v3.2 (#10180)
Co-authored-by: SubashMohan <subashmohan75@gmail.com>
2026-04-14 22:26:29 +05:30
Wenxi
5f08af3678 fix(google): handle JSON credential payloads in KV storage (@jack-larch) (#10160)
Co-authored-by: Jack Larch <jack.larch@biograph.com>
2026-04-13 18:35:51 -07:00
Nikolas Garza
1243af4f86 chore(hotfix): cherry-pick 2 commits to release v3.2 (#10140)
Co-authored-by: Raunak Bhagat <r@rabh.io>
2026-04-13 14:12:33 -07:00
Nikolas Garza
91e84b8278 feat(chat): smooth character-level streaming (#10093) to release v3.2 (#10138) 2026-04-13 14:12:20 -07:00
Nikolas Garza
1d6baf10db feat(chat): scrollable tables with overflow fade (#10097) to release v3.2 (#10136) 2026-04-13 14:05:16 -07:00
github-actions[bot]
8d26357197 fix(chat): disable Deep Research in multi-model mode (ENG-4009) (#10126) to release v3.2 (#10139)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-13 14:04:36 -07:00
github-actions[bot]
cd43345415 fix: welcome message alignment in chrome extension/desktop (#10094) to release v3.2 (#10135)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-13 13:04:28 -07:00
github-actions[bot]
f99cf2f1b0 fix(chat): isolate multi-model streaming errors to their panels (#10113) to release v3.2 (#10127)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-13 12:49:20 -07:00
Jamison Lahman
7332adb1e6 fix(copy-button): fall back when Clipboard API unavailable (#10080) 2026-04-10 22:49:56 -07:00
Nikolas Garza
0ab1b76765 Revert "feat(chat): smooth character-level streaming (#10076) to release v3.2" (#10082) 2026-04-10 20:49:39 -07:00
github-actions[bot]
40cd0a78a3 feat(chat): smooth character-level streaming (#10076) to release v3.2 (#10081)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-10 20:41:49 -07:00
github-actions[bot]
28d8c5de46 fix(chat): model selection + multi-model follow-up correctness (#10075) to release v3.2 (#10078) 2026-04-10 17:25:00 -07:00
github-actions[bot]
004092767f fix(mcp): prevent masked OAuth credentials from being stored on re-auth (#10066) to release v3.2 (#10069)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-04-10 14:47:17 -07:00
Nikolas Garza
eb4689a669 fix(chat): hide ModelSelector in search mode (#10052) to release v3.2 (#10068) 2026-04-10 12:43:05 -07:00
github-actions[bot]
47dd8973c1 fix(scim): add advisory lock to prevent seat limit race condition (#10048) to release v3.2 (#10065)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-10 12:05:14 -07:00
github-actions[bot]
a1403ef78c feat(slack-bot): make agent selector searchable (#10036) to release v3.2 (#10038)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-10 12:04:51 -07:00
github-actions[bot]
f96b9d6804 fix(license): exclude service account users from seat count (#10053) to release v3.2 (#10061)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-10 12:04:30 -07:00
github-actions[bot]
711651276c fix(LLM config): resolve API Key before fetching models (#10056) to release v3.2 (#10057)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-04-10 00:02:33 -07:00
github-actions[bot]
3731110cf9 feat(federated): full thread replies + direct URL fetch in Slack search (#9940) to release v3.2 (#10050)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2026-04-09 18:24:02 -07:00
Evan Lohn
8fb7a8718e fix: jira bulk issue fetch batching (#10044) 2026-04-09 20:50:41 -04:00
329 changed files with 2889 additions and 23603 deletions

View File

@@ -1,62 +0,0 @@
FROM ubuntu:26.04@sha256:cc925e589b7543b910fea57a240468940003fbfc0515245a495dd0ad8fe7cef1
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
fd-find \
fzf \
git \
jq \
less \
make \
neovim \
openssh-client \
python3-venv \
ripgrep \
sudo \
ca-certificates \
iptables \
ipset \
iproute2 \
dnsutils \
unzip \
wget \
zsh \
&& curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \
&& apt-get install -y nodejs \
&& install -m 0755 -d /etc/apt/keyrings \
&& curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg -o /etc/apt/keyrings/githubcli-archive-keyring.gpg \
&& chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" > /etc/apt/sources.list.d/github-cli.list \
&& apt-get update \
&& apt-get install -y --no-install-recommends gh \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
# fd-find installs as fdfind on Debian/Ubuntu — symlink to fd
RUN ln -sf "$(which fdfind)" /usr/local/bin/fd
# Install uv (Python package manager)
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/
# Create non-root dev user with passwordless sudo
RUN useradd -m -s /bin/zsh dev && \
echo "dev ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/dev && \
chmod 0440 /etc/sudoers.d/dev
ENV DEVCONTAINER=true
RUN mkdir -p /workspace && \
chown -R dev:dev /workspace
WORKDIR /workspace
# Install Claude Code
ARG CLAUDE_CODE_VERSION=latest
RUN npm install -g @anthropic-ai/claude-code@${CLAUDE_CODE_VERSION}
# Configure zsh — source the repo-local zshrc so shell customization
# doesn't require an image rebuild.
RUN chsh -s /bin/zsh root && \
for rc in /root/.zshrc /home/dev/.zshrc; do \
echo '[ -f /workspace/.devcontainer/zshrc ] && . /workspace/.devcontainer/zshrc' >> "$rc"; \
done && \
chown dev:dev /home/dev/.zshrc

View File

@@ -1,86 +0,0 @@
# Onyx Dev Container
A containerized development environment for working on Onyx.
## What's included
- Ubuntu 26.04 base image
- Node.js 20, uv, Claude Code
- GitHub CLI (`gh`)
- Neovim, ripgrep, fd, fzf, jq, make, wget, unzip
- Zsh as default shell (sources host `~/.zshrc` if available)
- Python venv auto-activation
- Network firewall (default-deny, whitelists npm, GitHub, Anthropic APIs, Sentry, and VS Code update servers)
## Usage
### CLI (`ods dev`)
The [`ods` devtools CLI](../tools/ods/README.md) provides workspace-aware wrappers
for all devcontainer operations (also available as `ods dc`):
```bash
# Start the container
ods dev up
# Open a shell
ods dev into
# Run a command
ods dev exec npm test
# Stop the container
ods dev stop
```
## Restarting the container
```bash
# Restart the container
ods dev restart
# Pull the latest published image and recreate
ods dev rebuild
```
## Image
The devcontainer uses a prebuilt image published to `onyxdotapp/onyx-devcontainer`.
The tag is pinned in `devcontainer.json` — no local build is required.
To build the image locally (e.g. while iterating on the Dockerfile):
```bash
docker buildx bake devcontainer
```
The `devcontainer` target is defined in `docker-bake.hcl` at the repo root.
## User & permissions
The container runs as the `dev` user by default (`remoteUser` in devcontainer.json).
An init script (`init-dev-user.sh`) runs at container start to ensure the active
user has read/write access to the bind-mounted workspace:
- **Standard Docker** — `dev`'s UID/GID is remapped to match the workspace owner,
so file permissions work seamlessly.
- **Rootless Docker** — The workspace appears as root-owned (UID 0) inside the
container due to user-namespace mapping. `ods dev up` auto-detects rootless Docker
and sets `DEVCONTAINER_REMOTE_USER=root` so the container runs as root — which
maps back to your host user via the user namespace. New files are owned by your
host UID and no ACL workarounds are needed.
To override the auto-detection, set `DEVCONTAINER_REMOTE_USER` before running
`ods dev up`.
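As a concrete example (a minimal sketch; the variable is the one described above and is read by `remoteUser` in `devcontainer.json`):
```bash
# Force the container to run as root, e.g. under rootless Docker
# when the auto-detection needs to be overridden.
DEVCONTAINER_REMOTE_USER=root ods dev up
# Or pin the default non-root user explicitly.
DEVCONTAINER_REMOTE_USER=dev ods dev up
```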
## Firewall
The container starts with a default-deny firewall (`init-firewall.sh`) that only allows outbound traffic to:
- npm registry
- GitHub
- Anthropic API
- Sentry
- VS Code update servers
This requires the `NET_ADMIN` and `NET_RAW` capabilities, which are added via `runArgs` in `devcontainer.json`.
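To sanity-check the firewall from inside the container, a quick sketch using the `ods dev exec` wrapper shown above (domains taken from the allow/block lists in `init-firewall.sh`):
```bash
# Allowed destination: should print an HTTP status line.
ods dev exec curl -sI https://api.github.com/meta | head -n 1
# Blocked destination: should fail fast, since unauthorized outbound
# traffic is REJECTed with icmp-host-unreachable rather than dropped.
ods dev exec curl -s --max-time 5 https://example.com || echo "blocked as expected"
```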

View File

@@ -1,23 +0,0 @@
{
"name": "Onyx Dev Sandbox",
"image": "onyxdotapp/onyx-devcontainer@sha256:12184169c5bcc9cca0388286d5ffe504b569bc9c37bfa631b76ee8eee2064055",
"runArgs": ["--cap-add=NET_ADMIN", "--cap-add=NET_RAW"],
"mounts": [
"source=${localEnv:HOME}/.claude,target=/home/dev/.claude,type=bind",
"source=${localEnv:HOME}/.claude.json,target=/home/dev/.claude.json,type=bind",
"source=${localEnv:HOME}/.zshrc,target=/home/dev/.zshrc.host,type=bind,readonly",
"source=${localEnv:HOME}/.gitconfig,target=/home/dev/.gitconfig,type=bind,readonly",
"source=${localEnv:HOME}/.config/nvim,target=/home/dev/.config/nvim,type=bind,readonly",
"source=onyx-devcontainer-cache,target=/home/dev/.cache,type=volume",
"source=onyx-devcontainer-local,target=/home/dev/.local,type=volume"
],
"containerEnv": {
"SSH_AUTH_SOCK": "/tmp/ssh-agent.sock"
},
"remoteUser": "${localEnv:DEVCONTAINER_REMOTE_USER:dev}",
"updateRemoteUserUID": false,
"workspaceMount": "source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=delegated",
"workspaceFolder": "/workspace",
"postStartCommand": "sudo bash /workspace/.devcontainer/init-dev-user.sh && sudo bash /workspace/.devcontainer/init-firewall.sh",
"waitFor": "postStartCommand"
}

View File

@@ -1,107 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
# Remap the dev user's UID/GID to match the workspace owner so that
# bind-mounted files are accessible without running as root.
#
# Standard Docker: Workspace is owned by the host user's UID (e.g. 1000).
# We remap dev to that UID -- fast and seamless.
#
# Rootless Docker: Workspace appears as root-owned (UID 0) inside the
# container due to user-namespace mapping. Requires
# DEVCONTAINER_REMOTE_USER=root (set automatically by
# ods dev up). Container root IS the host user, so
# bind-mounts and named volumes are symlinked into /root.
WORKSPACE=/workspace
TARGET_USER=dev
REMOTE_USER="${SUDO_USER:-$TARGET_USER}"
WS_UID=$(stat -c '%u' "$WORKSPACE")
WS_GID=$(stat -c '%g' "$WORKSPACE")
DEV_UID=$(id -u "$TARGET_USER")
DEV_GID=$(id -g "$TARGET_USER")
# devcontainer.json bind-mounts and named volumes target /home/dev regardless
# of remoteUser. When running as root ($HOME=/root), Phase 1 bridges the gap
# with symlinks from ACTIVE_HOME → MOUNT_HOME.
MOUNT_HOME=/home/"$TARGET_USER"
if [ "$REMOTE_USER" = "root" ]; then
ACTIVE_HOME="/root"
else
ACTIVE_HOME="$MOUNT_HOME"
fi
# ── Phase 1: home directory setup ───────────────────────────────────
# ~/.local and ~/.cache are named Docker volumes mounted under MOUNT_HOME.
mkdir -p "$MOUNT_HOME"/.local/state "$MOUNT_HOME"/.local/share
# When running as root, symlink bind-mounts and named volumes into /root
# so that $HOME-relative tools (Claude Code, git, etc.) find them.
if [ "$ACTIVE_HOME" != "$MOUNT_HOME" ]; then
for item in .claude .cache .local; do
[ -d "$MOUNT_HOME/$item" ] || continue
if [ -e "$ACTIVE_HOME/$item" ] && [ ! -L "$ACTIVE_HOME/$item" ]; then
echo "warning: replacing $ACTIVE_HOME/$item with symlink to $MOUNT_HOME/$item" >&2
rm -rf "$ACTIVE_HOME/$item"
fi
ln -sfn "$MOUNT_HOME/$item" "$ACTIVE_HOME/$item"
done
# Symlink files (not directories).
for file in .claude.json .gitconfig .zshrc.host; do
[ -f "$MOUNT_HOME/$file" ] && ln -sf "$MOUNT_HOME/$file" "$ACTIVE_HOME/$file"
done
# Nested mount: .config/nvim
if [ -d "$MOUNT_HOME/.config/nvim" ]; then
mkdir -p "$ACTIVE_HOME/.config"
if [ -e "$ACTIVE_HOME/.config/nvim" ] && [ ! -L "$ACTIVE_HOME/.config/nvim" ]; then
echo "warning: replacing $ACTIVE_HOME/.config/nvim with symlink" >&2
rm -rf "$ACTIVE_HOME/.config/nvim"
fi
ln -sfn "$MOUNT_HOME/.config/nvim" "$ACTIVE_HOME/.config/nvim"
fi
fi
# ── Phase 2: workspace access ───────────────────────────────────────
# Root always has workspace access; Phase 1 handled home setup.
if [ "$REMOTE_USER" = "root" ]; then
exit 0
fi
# Already matching -- nothing to do.
if [ "$WS_UID" = "$DEV_UID" ] && [ "$WS_GID" = "$DEV_GID" ]; then
exit 0
fi
if [ "$WS_UID" != "0" ]; then
# ── Standard Docker ──────────────────────────────────────────────
# Workspace is owned by a non-root UID (the host user).
# Remap dev's UID/GID to match.
if [ "$DEV_GID" != "$WS_GID" ]; then
if ! groupmod -g "$WS_GID" "$TARGET_USER" 2>&1; then
echo "warning: failed to remap $TARGET_USER GID to $WS_GID" >&2
fi
fi
if [ "$DEV_UID" != "$WS_UID" ]; then
if ! usermod -u "$WS_UID" -g "$WS_GID" "$TARGET_USER" 2>&1; then
echo "warning: failed to remap $TARGET_USER UID to $WS_UID" >&2
fi
fi
if ! chown -R "$TARGET_USER":"$TARGET_USER" "$MOUNT_HOME" 2>&1; then
echo "warning: failed to chown $MOUNT_HOME" >&2
fi
else
# ── Rootless Docker ──────────────────────────────────────────────
# Workspace is root-owned (UID 0) due to user-namespace mapping.
# The supported path is remoteUser=root (set DEVCONTAINER_REMOTE_USER=root),
# which is handled above. If we reach here, the user is running as dev
# under rootless Docker without the override.
echo "error: rootless Docker detected but remoteUser is not root." >&2
echo " Set DEVCONTAINER_REMOTE_USER=root before starting the container," >&2
echo " or use 'ods dev up' which sets it automatically." >&2
exit 1
fi

View File

@@ -1,105 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
echo "Setting up firewall..."
# Preserve docker dns resolution
DOCKER_DNS_RULES=$(iptables-save | grep -E "^-A.*-d 127.0.0.11/32" || true)
# Flush all rules
iptables -t nat -F
iptables -t nat -X
iptables -t mangle -F
iptables -t mangle -X
iptables -F
iptables -X
# Restore docker dns rules
if [ -n "$DOCKER_DNS_RULES" ]; then
echo "$DOCKER_DNS_RULES" | iptables-restore -n
fi
# Create ipset for allowed destinations
ipset create allowed-domains hash:net || true
ipset flush allowed-domains
# Fetch GitHub IP ranges (IPv4 only -- ipset hash:net and iptables are IPv4)
GITHUB_IPS=$(curl -s https://api.github.com/meta | jq -r '.api[]' 2>/dev/null | grep -v ':' || echo "")
for ip in $GITHUB_IPS; do
if ! ipset add allowed-domains "$ip" -exist 2>&1; then
echo "warning: failed to add GitHub IP $ip to allowlist" >&2
fi
done
# Resolve allowed domains
ALLOWED_DOMAINS=(
"registry.npmjs.org"
"api.anthropic.com"
"api-staging.anthropic.com"
"files.anthropic.com"
"sentry.io"
"update.code.visualstudio.com"
"pypi.org"
"files.pythonhosted.org"
"go.dev"
"storage.googleapis.com"
"static.rust-lang.org"
)
for domain in "${ALLOWED_DOMAINS[@]}"; do
IPS=$(getent ahosts "$domain" 2>/dev/null | awk '{print $1}' | grep -v ':' | sort -u || echo "")
for ip in $IPS; do
if ! ipset add allowed-domains "$ip/32" -exist 2>&1; then
echo "warning: failed to add $domain ($ip) to allowlist" >&2
fi
done
done
# Allow traffic to the Docker gateway so the container can reach host services
# (e.g. the Onyx stack at localhost:3000, localhost:8080, etc.)
DOCKER_GATEWAY=$(ip -4 route show default | awk '{print $3}')
if [ -n "$DOCKER_GATEWAY" ]; then
if ! ipset add allowed-domains "$DOCKER_GATEWAY/32" -exist 2>&1; then
echo "warning: failed to add Docker gateway $DOCKER_GATEWAY to allowlist" >&2
fi
fi
# Set default policies to DROP
iptables -P FORWARD DROP
iptables -P INPUT DROP
iptables -P OUTPUT DROP
# Allow established connections
iptables -A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
iptables -A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
# Allow loopback
iptables -A INPUT -i lo -j ACCEPT
iptables -A OUTPUT -o lo -j ACCEPT
# Allow DNS
iptables -A OUTPUT -p udp --dport 53 -j ACCEPT
iptables -A OUTPUT -p tcp --dport 53 -j ACCEPT
# Allow outbound to allowed destinations
iptables -A OUTPUT -m set --match-set allowed-domains dst -j ACCEPT
# Reject unauthorized outbound
iptables -A OUTPUT -j REJECT --reject-with icmp-host-unreachable
# Validate firewall configuration
echo "Validating firewall configuration..."
BLOCKED_SITES=("example.com" "google.com" "facebook.com")
for site in "${BLOCKED_SITES[@]}"; do
if timeout 2 ping -c 1 "$site" &>/dev/null; then
echo "Warning: $site is still reachable"
fi
done
if ! timeout 5 curl -s https://api.github.com/meta > /dev/null; then
echo "Warning: GitHub API is not accessible"
fi
echo "Firewall setup complete"

View File

@@ -1,10 +0,0 @@
# Devcontainer zshrc — sourced automatically for both root and dev users.
# Edit this file to customize the shell without rebuilding the image.
# Auto-activate Python venv
if [ -f /workspace/.venv/bin/activate ]; then
. /workspace/.venv/bin/activate
fi
# Source host zshrc if bind-mounted
[ -f ~/.zshrc.host ] && . ~/.zshrc.host

View File

@@ -44,7 +44,7 @@ jobs:
fetch-tags: true
- name: Setup uv
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
version: "0.9.9"
enable-cache: false
@@ -165,7 +165,7 @@ jobs:
fetch-depth: 0
- name: Setup uv
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
version: "0.9.9"
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.

View File

@@ -114,7 +114,7 @@ jobs:
ref: main
- name: Install the latest version of uv
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"

View File

@@ -471,7 +471,7 @@ jobs:
- name: Install the latest version of uv
if: always()
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
@@ -710,7 +710,7 @@ jobs:
pull-requests: write
steps:
- name: Download visual diff summaries
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3
with:
pattern: screenshot-diff-summary-*
path: summaries/

View File

@@ -38,7 +38,7 @@ jobs:
- name: Install node dependencies
working-directory: ./web
run: npm ci
- uses: j178/prek-action@cbc2f23eb5539cf20d82d1aabd0d0ecbcc56f4e3
- uses: j178/prek-action@0bb87d7f00b0c99306c8bcb8b8beba1eb581c037 # ratchet:j178/prek-action@v1
with:
prek-version: '0.3.4'
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}

View File

@@ -17,7 +17,7 @@ jobs:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"

View File

@@ -24,7 +24,7 @@ jobs:
persist-credentials: false
- name: Install the latest version of uv
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"

View File

@@ -1,57 +1,64 @@
{
"labels": [],
"comment": "",
"fixWithAI": true,
"hideFooter": false,
"strictness": 3,
"statusCheck": true,
"commentTypes": ["logic", "syntax", "style"],
"instructions": "",
"disabledLabels": [],
"excludeAuthors": ["dependabot[bot]", "renovate[bot]"],
"ignoreKeywords": "",
"ignorePatterns": "",
"includeAuthors": [],
"summarySection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"excludeBranches": [],
"fileChangeLimit": 300,
"includeBranches": [],
"includeKeywords": "",
"triggerOnUpdates": false,
"updateExistingSummaryComment": true,
"updateSummaryOnly": false,
"issuesTableSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"statusCommentsEnabled": true,
"confidenceScoreSection": {
"included": true,
"collapsible": false
},
"sequenceDiagramSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"shouldUpdateDescription": false,
"rules": [
{
"scope": ["web/**"],
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
"labels": [],
"comment": "",
"fixWithAI": true,
"hideFooter": false,
"strictness": 3,
"statusCheck": true,
"commentTypes": [
"logic",
"syntax",
"style"
],
"instructions": "",
"disabledLabels": [],
"excludeAuthors": [
"dependabot[bot]",
"renovate[bot]"
],
"ignoreKeywords": "",
"ignorePatterns": "",
"includeAuthors": [],
"summarySection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
{
"scope": ["web/**"],
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
"excludeBranches": [],
"fileChangeLimit": 300,
"includeBranches": [],
"includeKeywords": "",
"triggerOnUpdates": true,
"updateExistingSummaryComment": true,
"updateSummaryOnly": false,
"issuesTableSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
{
"scope": ["backend/**/*.py"],
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
}
]
"statusCommentsEnabled": true,
"confidenceScoreSection": {
"included": true,
"collapsible": false
},
"sequenceDiagramSection": {
"included": true,
"collapsible": false,
"defaultOpen": false
},
"shouldUpdateDescription": false,
"rules": [
{
"scope": ["web/**"],
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
},
{
"scope": ["web/**"],
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
},
{
"scope": ["backend/**/*.py"],
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
}
]
}

View File

@@ -49,12 +49,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
4. **Light Worker** (`light`)
- Handles lightweight, fast operations
- Tasks: vespa metadata sync, connector deletion, doc permissions upsert, checkpoint cleanup, index attempt cleanup
- Tasks: vespa operations, document permissions sync, external group sync
- Higher concurrency for quick tasks
5. **Heavy Worker** (`heavy`)
- Handles resource-intensive operations
- Tasks: connector pruning, document permissions sync, external group sync, CSV generation
- Primary task: document pruning operations
- Runs with 4 threads concurrency
6. **KG Processing Worker** (`kg_processing`)

View File

@@ -208,7 +208,7 @@ def do_run_migrations(
context.configure(
connection=connection,
target_metadata=target_metadata,
target_metadata=target_metadata, # type: ignore
version_table_schema=schema_name,
include_schemas=True,
compare_type=True,
@@ -380,7 +380,7 @@ def run_migrations_offline() -> None:
logger.info(f"Migrating schema: {schema}")
context.configure(
url=url,
target_metadata=target_metadata,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
version_table_schema=schema,
include_schemas=True,
@@ -421,7 +421,7 @@ def run_migrations_offline() -> None:
logger.info(f"Migrating schema: {schema}")
context.configure(
url=url,
target_metadata=target_metadata,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
version_table_schema=schema,
include_schemas=True,
@@ -464,7 +464,7 @@ def run_migrations_online() -> None:
context.configure(
connection=connection,
target_metadata=target_metadata,
target_metadata=target_metadata, # type: ignore
version_table_schema=schema_name,
include_schemas=True,
compare_type=True,

View File

@@ -25,7 +25,7 @@ def upgrade() -> None:
# Use batch mode to modify the enum type
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.alter_column(
batch_op.alter_column( # type: ignore[attr-defined]
"role",
type_=sa.Enum(
"BASIC",
@@ -71,7 +71,7 @@ def downgrade() -> None:
op.drop_column("user__user_group", "is_curator")
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.alter_column(
batch_op.alter_column( # type: ignore[attr-defined]
"role",
type_=sa.Enum(
"BASIC", "ADMIN", name="userrole", native_enum=False, length=20

View File

@@ -1,499 +0,0 @@
"""add proposal review tables
Revision ID: 61ea78857c97
Revises: d129f37b3d87
Create Date: 2026-04-09 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
import fastapi_users_db_sqlalchemy
# revision identifiers, used by Alembic.
revision = "61ea78857c97"
down_revision = "d129f37b3d87"
branch_labels: str | None = None
depends_on: str | None = None
def upgrade() -> None:
# -- proposal_review_ruleset --
op.create_table(
"proposal_review_ruleset",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column("tenant_id", sa.Text(), nullable=False),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column(
"is_default",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
),
sa.Column(
"is_active",
sa.Boolean(),
server_default=sa.text("true"),
nullable=False,
),
sa.Column(
"created_by",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(["created_by"], ["user.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_proposal_review_ruleset_tenant_id",
"proposal_review_ruleset",
["tenant_id"],
)
# -- proposal_review_rule --
op.create_table(
"proposal_review_rule",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column(
"ruleset_id",
postgresql.UUID(as_uuid=True),
nullable=False,
),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("category", sa.Text(), nullable=True),
sa.Column("rule_type", sa.Text(), nullable=False),
sa.Column(
"rule_intent",
sa.Text(),
server_default=sa.text("'CHECK'"),
nullable=False,
),
sa.Column("prompt_template", sa.Text(), nullable=False),
sa.Column(
"source",
sa.Text(),
server_default=sa.text("'MANUAL'"),
nullable=False,
),
sa.Column("authority", sa.Text(), nullable=True),
sa.Column(
"is_hard_stop",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
),
sa.Column(
"priority",
sa.Integer(),
server_default=sa.text("0"),
nullable=False,
),
sa.Column(
"is_active",
sa.Boolean(),
server_default=sa.text("true"),
nullable=False,
),
sa.Column(
"refinement_needed",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
),
sa.Column("refinement_question", sa.Text(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["ruleset_id"],
["proposal_review_ruleset.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_proposal_review_rule_ruleset_id",
"proposal_review_rule",
["ruleset_id"],
)
# -- proposal_review_proposal --
# Includes inline proposal-level decision fields (no separate decision table).
op.create_table(
"proposal_review_proposal",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column("document_id", sa.Text(), nullable=False),
sa.Column("tenant_id", sa.Text(), nullable=False),
sa.Column(
"status",
sa.Text(),
server_default=sa.text("'PENDING'"),
nullable=False,
),
# Inline proposal-level decision fields
sa.Column("decision_notes", sa.Text(), nullable=True),
sa.Column(
"decision_officer_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column("decision_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"jira_synced",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
),
sa.Column("jira_synced_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(["decision_officer_id"], ["user.id"]),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("document_id", "tenant_id"),
)
op.create_index(
"ix_proposal_review_proposal_tenant_id",
"proposal_review_proposal",
["tenant_id"],
)
op.create_index(
"ix_proposal_review_proposal_document_id",
"proposal_review_proposal",
["document_id"],
)
op.create_index(
"ix_proposal_review_proposal_status",
"proposal_review_proposal",
["status"],
)
# -- proposal_review_run --
op.create_table(
"proposal_review_run",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column(
"proposal_id",
postgresql.UUID(as_uuid=True),
nullable=False,
),
sa.Column(
"ruleset_id",
postgresql.UUID(as_uuid=True),
nullable=False,
),
sa.Column(
"triggered_by",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.Column(
"status",
sa.Text(),
server_default=sa.text("'PENDING'"),
nullable=False,
),
sa.Column("total_rules", sa.Integer(), nullable=False),
sa.Column(
"completed_rules",
sa.Integer(),
server_default=sa.text("0"),
nullable=False,
),
sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["proposal_id"],
["proposal_review_proposal.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["ruleset_id"],
["proposal_review_ruleset.id"],
),
sa.ForeignKeyConstraint(["triggered_by"], ["user.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_proposal_review_run_proposal_id",
"proposal_review_run",
["proposal_id"],
)
# -- proposal_review_finding --
# Includes inline per-finding decision fields (no separate decision table).
op.create_table(
"proposal_review_finding",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column(
"proposal_id",
postgresql.UUID(as_uuid=True),
nullable=False,
),
sa.Column(
"rule_id",
postgresql.UUID(as_uuid=True),
nullable=False,
),
sa.Column(
"review_run_id",
postgresql.UUID(as_uuid=True),
nullable=False,
),
sa.Column("verdict", sa.Text(), nullable=False),
sa.Column("confidence", sa.Text(), nullable=True),
sa.Column("evidence", sa.Text(), nullable=True),
sa.Column("explanation", sa.Text(), nullable=True),
sa.Column("suggested_action", sa.Text(), nullable=True),
sa.Column("llm_model", sa.Text(), nullable=True),
sa.Column("llm_tokens_used", sa.Integer(), nullable=True),
# Inline per-finding decision fields
sa.Column("decision_action", sa.Text(), nullable=True),
sa.Column("decision_notes", sa.Text(), nullable=True),
sa.Column(
"decision_officer_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column("decided_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["proposal_id"],
["proposal_review_proposal.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["rule_id"],
["proposal_review_rule.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["review_run_id"],
["proposal_review_run.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(["decision_officer_id"], ["user.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_proposal_review_finding_proposal_id",
"proposal_review_finding",
["proposal_id"],
)
op.create_index(
"ix_proposal_review_finding_review_run_id",
"proposal_review_finding",
["review_run_id"],
)
op.create_index(
"ix_proposal_review_finding_rule_id",
"proposal_review_finding",
["rule_id"],
)
# -- proposal_review_document --
op.create_table(
"proposal_review_document",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column(
"proposal_id",
postgresql.UUID(as_uuid=True),
nullable=False,
),
sa.Column("file_name", sa.Text(), nullable=False),
sa.Column("file_type", sa.Text(), nullable=True),
sa.Column("file_store_id", sa.Text(), nullable=True),
sa.Column("extracted_text", sa.Text(), nullable=True),
sa.Column("document_role", sa.Text(), nullable=False),
sa.Column(
"uploaded_by",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["proposal_id"],
["proposal_review_proposal.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(["uploaded_by"], ["user.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_proposal_review_document_proposal_id",
"proposal_review_document",
["proposal_id"],
)
# -- proposal_review_import_job --
op.create_table(
"proposal_review_import_job",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column(
"ruleset_id",
postgresql.UUID(as_uuid=True),
nullable=False,
),
sa.Column("tenant_id", sa.Text(), nullable=False),
sa.Column(
"status",
sa.Text(),
server_default=sa.text("'PENDING'"),
nullable=False,
),
sa.Column("source_filename", sa.Text(), nullable=False),
sa.Column("extracted_text", sa.Text(), nullable=False),
sa.Column(
"rules_created",
sa.Integer(),
server_default=sa.text("0"),
nullable=False,
),
sa.Column("error_message", sa.Text(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(
["ruleset_id"],
["proposal_review_ruleset.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_proposal_review_import_job_ruleset_id",
"proposal_review_import_job",
["ruleset_id"],
)
# -- proposal_review_config --
op.create_table(
"proposal_review_config",
sa.Column(
"id",
postgresql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
sa.Column("tenant_id", sa.Text(), nullable=False, unique=True),
sa.Column("jira_connector_id", sa.Integer(), nullable=True),
sa.Column("jira_project_key", sa.Text(), nullable=True),
sa.Column("field_mapping", postgresql.JSONB(), nullable=True),
sa.Column("jira_writeback", postgresql.JSONB(), nullable=True),
sa.Column("review_model", sa.Text(), nullable=True),
sa.Column("import_model", sa.Text(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)
def downgrade() -> None:
op.drop_table("proposal_review_import_job")
op.drop_table("proposal_review_config")
op.drop_table("proposal_review_document")
op.drop_table("proposal_review_finding")
op.drop_table("proposal_review_run")
op.drop_table("proposal_review_proposal")
op.drop_table("proposal_review_rule")
op.drop_table("proposal_review_ruleset")

View File

@@ -63,7 +63,7 @@ def upgrade() -> None:
"time_created",
existing_type=postgresql.TIMESTAMP(timezone=True),
nullable=False,
existing_server_default=sa.text("now()"),
existing_server_default=sa.text("now()"), # type: ignore
)
op.alter_column(
"index_attempt",
@@ -85,7 +85,7 @@ def downgrade() -> None:
"time_created",
existing_type=postgresql.TIMESTAMP(timezone=True),
nullable=True,
existing_server_default=sa.text("now()"),
existing_server_default=sa.text("now()"), # type: ignore
)
op.drop_index(op.f("ix_accesstoken_created_at"), table_name="accesstoken")
op.drop_table("accesstoken")

View File

@@ -19,7 +19,7 @@ depends_on: None = None
def upgrade() -> None:
sequence = Sequence("connector_credential_pair_id_seq")
op.execute(CreateSequence(sequence))
op.execute(CreateSequence(sequence)) # type: ignore
op.add_column(
"connector_credential_pair",
sa.Column(

View File

@@ -1,32 +0,0 @@
"""add failed_rules to proposal_review_run
Revision ID: ce2aa573d445
Revises: 61ea78857c97
Create Date: 2026-04-14 16:34:57.276707
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ce2aa573d445"
down_revision = "61ea78857c97"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"proposal_review_run",
sa.Column(
"failed_rules",
sa.Integer(),
nullable=False,
server_default=sa.text("0"),
),
)
def downgrade() -> None:
op.drop_column("proposal_review_run", "failed_rules")

View File

@@ -1,28 +0,0 @@
"""add_error_tracking_fields_to_index_attempt_errors
Revision ID: d129f37b3d87
Revises: 503883791c39
Create Date: 2026-04-06 19:11:18.261800
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d129f37b3d87"
down_revision = "503883791c39"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"index_attempt_errors",
sa.Column("error_type", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("index_attempt_errors", "error_type")

View File

@@ -49,7 +49,7 @@ def run_migrations_offline() -> None:
url = build_connection_string()
context.configure(
url=url,
target_metadata=target_metadata,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
@@ -61,7 +61,7 @@ def run_migrations_offline() -> None:
def do_run_migrations(connection: Connection) -> None:
context.configure(
connection=connection,
target_metadata=target_metadata,
target_metadata=target_metadata, # type: ignore[arg-type]
)
with context.begin_transaction():

View File

@@ -10,7 +10,6 @@ from celery import bootsteps # type: ignore
from celery import Task
from celery.app import trace
from celery.exceptions import WorkerShutdown
from celery.signals import before_task_publish
from celery.signals import task_postrun
from celery.signals import task_prerun
from celery.states import READY_STATES
@@ -95,17 +94,6 @@ class TenantAwareTask(Task):
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
@before_task_publish.connect
def on_before_task_publish(
headers: dict[str, Any] | None = None,
**kwargs: Any, # noqa: ARG001
) -> None:
"""Stamp the current wall-clock time into the task message headers so that
workers can compute queue wait time (time between publish and execution)."""
if headers is not None:
headers["enqueued_at"] = time.time()
@task_prerun.connect
def on_task_prerun(
sender: Any | None = None, # noqa: ARG001

View File

@@ -16,12 +16,6 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
from onyx.server.metrics.metrics_server import start_metrics_server
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -42,7 +36,6 @@ def on_task_prerun(
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
on_celery_task_prerun(task_id, task)
@signals.task_postrun.connect
@@ -57,31 +50,6 @@ def on_task_postrun(
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
on_celery_task_postrun(task_id, task, state)
@signals.task_retry.connect
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
task_id = getattr(getattr(sender, "request", None), "id", None)
on_celery_task_retry(task_id, sender)
@signals.task_revoked.connect
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
task_name = getattr(sender, "name", None) or str(sender)
on_celery_task_revoked(kwargs.get("task_id"), task_name)
@signals.task_rejected.connect
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
message = kwargs.get("message")
task_name: str | None = None
if message is not None:
headers = getattr(message, "headers", None) or {}
task_name = headers.get("task")
if task_name is None:
task_name = "unknown"
on_celery_task_rejected(None, task_name)
@celeryd_init.connect
@@ -122,7 +90,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
start_metrics_server("light")
app_base.on_worker_ready(sender, **kwargs)

View File

@@ -322,7 +322,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.user_file_processing",
"onyx.server.features.proposal_review.engine",
]
)
)

View File

@@ -79,15 +79,6 @@ beat_task_templates: list[dict] = [
"skip_gated": False,
},
},
{
"name": "check-for-dangling-import-jobs",
"task": OnyxCeleryTask.CHECK_FOR_DANGLING_IMPORT_JOBS,
"schedule": timedelta(minutes=10),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-index-attempt-cleanup",
"task": OnyxCeleryTask.CHECK_FOR_INDEX_ATTEMPT_CLEANUP,

View File

@@ -59,11 +59,6 @@ from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.metrics.deletion_metrics import inc_deletion_blocked
from onyx.server.metrics.deletion_metrics import inc_deletion_completed
from onyx.server.metrics.deletion_metrics import inc_deletion_fence_reset
from onyx.server.metrics.deletion_metrics import inc_deletion_started
from onyx.server.metrics.deletion_metrics import observe_deletion_taskset_duration
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
@@ -107,7 +102,7 @@ def revoke_tasks_blocking_deletion(
f"Revoked permissions sync task {permissions_sync_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking permissions sync task")
task_logger.exception("Exception while revoking pruning task")
try:
prune_payload = redis_connector.prune.payload
@@ -115,7 +110,7 @@ def revoke_tasks_blocking_deletion(
app.control.revoke(prune_payload.celery_task_id)
task_logger.info(f"Revoked pruning task {prune_payload.celery_task_id}.")
except Exception:
task_logger.exception("Exception while revoking pruning task")
task_logger.exception("Exception while revoking permissions sync task")
try:
external_group_sync_payload = redis_connector.external_group_sync.payload
@@ -305,7 +300,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
recent_index_attempts
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
):
inc_deletion_blocked(tenant_id, "indexing")
raise TaskDependencyError(
"Connector deletion - Delayed (indexing in progress): "
f"cc_pair={cc_pair_id} "
@@ -313,13 +307,11 @@ def try_generate_document_cc_pair_cleanup_tasks(
)
if redis_connector.prune.fenced:
inc_deletion_blocked(tenant_id, "pruning")
raise TaskDependencyError(
f"Connector deletion - Delayed (pruning in progress): cc_pair={cc_pair_id}"
)
if redis_connector.permissions.fenced:
inc_deletion_blocked(tenant_id, "permissions")
raise TaskDependencyError(
f"Connector deletion - Delayed (permissions in progress): cc_pair={cc_pair_id}"
)
@@ -367,7 +359,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
# set this only after all tasks have been added
fence_payload.num_tasks = tasks_generated
redis_connector.delete.set_fence(fence_payload)
inc_deletion_started(tenant_id)
return tasks_generated
@@ -517,11 +508,7 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
connector_id=connector_id_to_delete,
)
if not connector:
task_logger.info(
"Connector deletion - Connector already deleted, skipping connector cleanup"
)
elif not len(connector.credentials):
if not connector or not len(connector.credentials):
task_logger.info(
"Connector deletion - Found no credentials left for connector, deleting connector"
)
@@ -536,12 +523,6 @@ def monitor_connector_deletion_taskset(
num_docs_synced=fence_data.num_tasks,
)
duration = (
datetime.now(timezone.utc) - fence_data.submitted
).total_seconds()
observe_deletion_taskset_duration(tenant_id, "success", duration)
inc_deletion_completed(tenant_id, "success")
except Exception as e:
db_session.rollback()
stack_trace = traceback.format_exc()
@@ -560,11 +541,6 @@ def monitor_connector_deletion_taskset(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
)
duration = (
datetime.now(timezone.utc) - fence_data.submitted
).total_seconds()
observe_deletion_taskset_duration(tenant_id, "failure", duration)
inc_deletion_completed(tenant_id, "failure")
raise e
task_logger.info(
@@ -741,6 +717,5 @@ def validate_connector_deletion_fence(
f"fence={fence_key}"
)
inc_deletion_fence_reset(tenant_id)
redis_connector.delete.reset()
return

View File

@@ -172,10 +172,6 @@ def migrate_chunks_from_vespa_to_opensearch_task(
search_settings = get_current_search_settings(db_session)
indexing_setting = IndexingSetting.from_db_model(search_settings)
task_logger.debug(
"Verified tenant info, migration record, and search settings."
)
# 2.e. Build sanitized to original doc ID mapping to check for
# conflicts in the event we sanitize a doc ID to an
# already-existing doc ID.
@@ -329,7 +325,6 @@ def migrate_chunks_from_vespa_to_opensearch_task(
finally:
if lock.owned():
lock.release()
task_logger.debug("Released the OpenSearch migration lock.")
else:
task_logger.warning(
"The OpenSearch migration lock was not owned on completion of the migration task."

View File

@@ -38,7 +38,6 @@ from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.models import InputType
from onyx.db.connector import mark_ccpair_as_pruned
from onyx.db.connector_credential_pair import get_connector_credential_pair
@@ -526,14 +525,6 @@ def connector_pruning_generator_task(
return None
try:
# Session 1: pre-enumeration — load cc_pair and instantiate the connector.
# The session is closed before enumeration so the DB connection is not held
# open during the 10-30+ minute connector crawl.
connector_source: DocumentSource | None = None
connector_type: str = ""
is_connector_public: bool = False
runnable_connector: BaseConnector | None = None
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair(
db_session=db_session,
@@ -559,51 +550,49 @@ def connector_pruning_generator_task(
)
redis_connector.prune.set_fence(new_payload)
connector_source = cc_pair.connector.source
connector_type = connector_source.value
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
task_logger.info(
f"Pruning generator running connector: cc_pair={cc_pair_id} connector_source={connector_source}"
f"Pruning generator running connector: cc_pair={cc_pair_id} connector_source={cc_pair.connector.source}"
)
runnable_connector = instantiate_connector(
db_session,
connector_source,
cc_pair.connector.source,
InputType.SLIM_RETRIEVAL,
cc_pair.connector.connector_specific_config,
cc_pair.credential,
)
# Session 1 closed here — connection released before enumeration.
callback = PruneCallback(
0,
redis_connector,
lock,
r,
timeout_seconds=JOB_TIMEOUT,
)
callback = PruneCallback(
0,
redis_connector,
lock,
r,
timeout_seconds=JOB_TIMEOUT,
)
# Extract docs and hierarchy nodes from the source (no DB session held).
extraction_result = extract_ids_from_runnable_connector(
runnable_connector, callback, connector_type=connector_type
)
all_connector_doc_ids = extraction_result.raw_id_to_parent
# Extract docs and hierarchy nodes from the source
connector_type = cc_pair.connector.source.value
extraction_result = extract_ids_from_runnable_connector(
runnable_connector, callback, connector_type=connector_type
)
all_connector_doc_ids = extraction_result.raw_id_to_parent
# Session 2: post-enumeration — hierarchy upserts, diff computation, task dispatch.
with get_session_with_current_tenant() as db_session:
source = connector_source
# Process hierarchy nodes (same as docfetching):
# upsert to Postgres and cache in Redis
source = cc_pair.connector.source
redis_client = get_redis_client(tenant_id=tenant_id)
ensure_source_node_exists(redis_client, db_session, source)
upserted_nodes: list[DBHierarchyNode] = []
if extraction_result.hierarchy_nodes:
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
upserted_nodes = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=extraction_result.hierarchy_nodes,
source=source,
commit=False,
commit=True,
is_connector_public=is_connector_public,
)
@@ -612,13 +601,9 @@ def connector_pruning_generator_task(
hierarchy_node_ids=[n.id for n in upserted_nodes],
connector_id=connector_id,
credential_id=credential_id,
commit=False,
commit=True,
)
# Single commit so the FK reference in the join table can never
# outrun the parent hierarchy_node insert.
db_session.commit()
cache_entries = [
HierarchyNodeCacheEntry.from_db_model(node)
for node in upserted_nodes
@@ -673,7 +658,7 @@ def connector_pruning_generator_task(
task_logger.info(
"Pruning set collected: "
f"cc_pair={cc_pair_id} "
f"connector_source={connector_source} "
f"connector_source={cc_pair.connector.source} "
f"docs_to_remove={len(doc_ids_to_remove)}"
)

View File

@@ -23,8 +23,6 @@ class IndexAttemptErrorPydantic(BaseModel):
index_attempt_id: int
error_type: str | None = None
@classmethod
def from_model(cls, model: IndexAttemptError) -> "IndexAttemptErrorPydantic":
return cls(
@@ -39,5 +37,4 @@ class IndexAttemptErrorPydantic(BaseModel):
is_resolved=model.is_resolved,
time_created=model.time_created,
index_attempt_id=model.index_attempt_id,
error_type=model.error_type,
)

View File

@@ -364,7 +364,7 @@ def _get_or_extract_plaintext(
plaintext_io = file_store.read_file(plaintext_key, mode="b")
return plaintext_io.read().decode("utf-8")
except Exception:
logger.info(f"Cache miss for file with id={file_id}")
logger.exception(f"Error when reading file, id={file_id}")
# Cache miss — extract and store.
content_text = extract_fn()

View File

@@ -4,6 +4,8 @@ from collections.abc import Callable
from typing import Any
from typing import Literal
from sqlalchemy.orm import Session
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_utils import create_tool_call_failure_messages
from onyx.chat.citation_processor import CitationMapping
@@ -633,6 +635,7 @@ def run_llm_loop(
user_memory_context: UserMemoryContext | None,
llm: LLM,
token_counter: Callable[[str], int],
db_session: Session,
forced_tool_id: int | None = None,
user_identity: LLMUserIdentity | None = None,
chat_session_id: str | None = None,
@@ -1017,16 +1020,20 @@ def run_llm_loop(
persisted_memory_id: int | None = None
if user_memory_context and user_memory_context.user_id:
if tool_response.rich_response.index_to_replace is not None:
persisted_memory_id = update_memory_at_index(
memory = update_memory_at_index(
user_id=user_memory_context.user_id,
index=tool_response.rich_response.index_to_replace,
new_text=tool_response.rich_response.memory_text,
db_session=db_session,
)
persisted_memory_id = memory.id if memory else None
else:
persisted_memory_id = add_memory(
memory = add_memory(
user_id=user_memory_context.user_id,
memory_text=tool_response.rich_response.memory_text,
db_session=db_session,
)
persisted_memory_id = memory.id
operation: Literal["add", "update"] = (
"update"
if tool_response.rich_response.index_to_replace is not None

View File

@@ -67,6 +67,7 @@ from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.chat import reserve_multi_model_message_ids
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookPoint
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
@@ -1005,86 +1006,93 @@ def _run_models(
model_llm = setup.llms[model_idx]
try:
# Each function opens short-lived DB sessions on demand.
# Do NOT pass a long-lived session here — it would hold a
# connection for the entire LLM loop (minutes), and cloud
# infrastructure may drop idle connections.
thread_tool_dict = construct_tools(
persona=setup.persona,
emitter=model_emitter,
user=user,
llm=model_llm,
search_tool_config=SearchToolConfig(
user_selected_filters=setup.new_msg_req.internal_search_filters,
project_id_filter=setup.search_params.project_id_filter,
persona_id_filter=setup.search_params.persona_id_filter,
bypass_acl=setup.bypass_acl,
slack_context=setup.slack_context,
enable_slack_search=_should_enable_slack_search(
setup.persona, setup.new_msg_req.internal_search_filters
),
),
custom_tool_config=CustomToolConfig(
chat_session_id=setup.chat_session.id,
message_id=setup.user_message.id,
additional_headers=setup.custom_tool_additional_headers,
mcp_headers=setup.mcp_headers,
),
file_reader_tool_config=FileReaderToolConfig(
user_file_ids=setup.available_files.user_file_ids,
chat_file_ids=setup.available_files.chat_file_ids,
),
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
search_usage_forcing_setting=setup.search_params.search_usage,
)
model_tools = [
tool for tool_list in thread_tool_dict.values() for tool in tool_list
]
if setup.forced_tool_id and setup.forced_tool_id not in {
tool.id for tool in model_tools
}:
raise ValueError(
f"Forced tool {setup.forced_tool_id} not found in tools"
)
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
if n_models == 1 and setup.new_msg_req.deep_research:
if setup.chat_session.project_id:
raise RuntimeError("Deep research is not supported for projects")
run_deep_research_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=list(setup.simple_chat_history),
tools=model_tools,
custom_agent_prompt=setup.custom_agent_prompt,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
skip_clarification=setup.skip_clarification,
user_identity=setup.user_identity,
chat_session_id=str(setup.chat_session.id),
all_injected_file_metadata=setup.all_injected_file_metadata,
)
else:
run_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=list(setup.simple_chat_history),
tools=model_tools,
custom_agent_prompt=setup.custom_agent_prompt,
context_files=setup.extracted_context_files,
# Each worker opens its own session — SQLAlchemy sessions are not thread-safe.
# Do NOT write to the outer db_session (or any shared DB state) from here;
# all DB writes in this thread must go through thread_db_session.
with get_session_with_current_tenant() as thread_db_session:
thread_tool_dict = construct_tools(
persona=setup.persona,
user_memory_context=setup.user_memory_context,
db_session=thread_db_session,
emitter=model_emitter,
user=user,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
forced_tool_id=setup.forced_tool_id,
user_identity=setup.user_identity,
chat_session_id=str(setup.chat_session.id),
chat_files=setup.chat_files_for_tools,
include_citations=setup.new_msg_req.include_citations,
all_injected_file_metadata=setup.all_injected_file_metadata,
inject_memories_in_prompt=user.use_memories,
search_tool_config=SearchToolConfig(
user_selected_filters=setup.new_msg_req.internal_search_filters,
project_id_filter=setup.search_params.project_id_filter,
persona_id_filter=setup.search_params.persona_id_filter,
bypass_acl=setup.bypass_acl,
slack_context=setup.slack_context,
enable_slack_search=_should_enable_slack_search(
setup.persona, setup.new_msg_req.internal_search_filters
),
),
custom_tool_config=CustomToolConfig(
chat_session_id=setup.chat_session.id,
message_id=setup.user_message.id,
additional_headers=setup.custom_tool_additional_headers,
mcp_headers=setup.mcp_headers,
),
file_reader_tool_config=FileReaderToolConfig(
user_file_ids=setup.available_files.user_file_ids,
chat_file_ids=setup.available_files.chat_file_ids,
),
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
search_usage_forcing_setting=setup.search_params.search_usage,
)
model_tools = [
tool
for tool_list in thread_tool_dict.values()
for tool in tool_list
]
if setup.forced_tool_id and setup.forced_tool_id not in {
tool.id for tool in model_tools
}:
raise ValueError(
f"Forced tool {setup.forced_tool_id} not found in tools"
)
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
if n_models == 1 and setup.new_msg_req.deep_research:
if setup.chat_session.project_id:
raise RuntimeError(
"Deep research is not supported for projects"
)
run_deep_research_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=list(setup.simple_chat_history),
tools=model_tools,
custom_agent_prompt=setup.custom_agent_prompt,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
db_session=thread_db_session,
skip_clarification=setup.skip_clarification,
user_identity=setup.user_identity,
chat_session_id=str(setup.chat_session.id),
all_injected_file_metadata=setup.all_injected_file_metadata,
)
else:
run_llm_loop(
emitter=model_emitter,
state_container=sc,
simple_chat_history=list(setup.simple_chat_history),
tools=model_tools,
custom_agent_prompt=setup.custom_agent_prompt,
context_files=setup.extracted_context_files,
persona=setup.persona,
user_memory_context=setup.user_memory_context,
llm=model_llm,
token_counter=get_llm_token_counter(model_llm),
db_session=thread_db_session,
forced_tool_id=setup.forced_tool_id,
user_identity=setup.user_identity,
chat_session_id=str(setup.chat_session.id),
chat_files=setup.chat_files_for_tools,
include_citations=setup.new_msg_req.include_citations,
all_injected_file_metadata=setup.all_injected_file_metadata,
inject_memories_in_prompt=user.use_memories,
)
model_succeeded[model_idx] = True
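
The comment in this hunk is the key constraint: SQLAlchemy sessions are not thread-safe, so each model worker opens its own session via get_session_with_current_tenant instead of reusing the caller's. The threading shape, as a sketch with the session factory stubbed out:

from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager

@contextmanager
def get_session_with_current_tenant():
    # Stand-in for onyx.db.engine.sql_engine.get_session_with_current_tenant:
    # yields a fresh session for this worker and disposes of it on exit.
    session = object()
    try:
        yield session
    finally:
        pass  # the real helper closes the session / returns the connection here

def run_one_model(model_idx: int) -> bool:
    # All DB work for this worker goes through its own thread_db_session;
    # nothing writes to a session owned by another thread.
    with get_session_with_current_tenant() as thread_db_session:
        _ = (model_idx, thread_db_session)  # construct_tools(...) / run_llm_loop(...) would go here
    return True

with ThreadPoolExecutor(max_workers=3) as pool:
    print(list(pool.map(run_one_model, range(3))))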

View File

@@ -449,7 +449,6 @@ class OnyxRedisLocks:
"da_lock:check_connector_external_group_sync_beat"
)
OPENSEARCH_MIGRATION_BEAT_LOCK = "da_lock:opensearch_migration_beat"
CHECK_DANGLING_IMPORT_JOBS_BEAT_LOCK = "da_lock:check_dangling_import_jobs_beat"
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
@@ -613,9 +612,6 @@ class OnyxCeleryTask:
# Hook execution log retention
HOOK_EXECUTION_LOG_CLEANUP_TASK = "hook_execution_log_cleanup_task"
# Proposal review import cleanup
CHECK_FOR_DANGLING_IMPORT_JOBS = "check_for_dangling_import_jobs"
# Sandbox cleanup
CLEANUP_IDLE_SANDBOXES = "cleanup_idle_sandboxes"
CLEANUP_OLD_SNAPSHOTS = "cleanup_old_snapshots"

View File

@@ -61,9 +61,6 @@ _USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
_USER_EMAIL_CACHE: dict[str, str | None] = {}
_DEFAULT_PAGINATION_LIMIT = 1000
_MINIMUM_PAGINATION_LIMIT = 5
_SERVER_ERROR_CODES = {500, 502, 503, 504}
_CONFLUENCE_SPACES_API_V1 = "rest/api/space"
_CONFLUENCE_SPACES_API_V2 = "wiki/api/v2/spaces"
@@ -572,8 +569,7 @@ class OnyxConfluence:
if not limit:
limit = _DEFAULT_PAGINATION_LIMIT
current_limit = limit
url_suffix = update_param_in_path(url_suffix, "limit", str(current_limit))
url_suffix = update_param_in_path(url_suffix, "limit", str(limit))
while url_suffix:
logger.debug(f"Making confluence call to {url_suffix}")
@@ -613,61 +609,40 @@ class OnyxConfluence:
)
continue
if raw_response.status_code in _SERVER_ERROR_CODES:
# Try reducing the page size -- Confluence often times out
# on large result sets (especially Cloud 504s).
if current_limit > _MINIMUM_PAGINATION_LIMIT:
old_limit = current_limit
current_limit = max(
current_limit // 2, _MINIMUM_PAGINATION_LIMIT
)
logger.warning(
f"Confluence returned {raw_response.status_code}. "
f"Reducing limit from {old_limit} to {current_limit} "
f"and retrying."
)
url_suffix = update_param_in_path(
url_suffix, "limit", str(current_limit)
)
continue
# If we fail due to a 500, try one by one.
# NOTE: this iterative approach only works for server, since cloud uses cursor-based
# pagination
if raw_response.status_code == 500 and not self._is_cloud:
initial_start = get_start_param_from_url(url_suffix)
if initial_start is None:
# can't handle this if we don't have offset-based pagination
raise
# Limit reduction exhausted -- for Server, fall back to
# one-by-one offset pagination as a last resort.
if not self._is_cloud:
initial_start = get_start_param_from_url(url_suffix)
# this will just yield the successful items from the batch
new_url_suffix = (
yield from self._try_one_by_one_for_paginated_url(
url_suffix,
initial_start=initial_start,
limit=current_limit,
)
)
# this means we ran into an empty page
if new_url_suffix is None:
if next_page_callback:
next_page_callback("")
break
# this will just yield the successful items from the batch
new_url_suffix = yield from self._try_one_by_one_for_paginated_url(
url_suffix,
initial_start=initial_start,
limit=limit,
)
url_suffix = new_url_suffix
continue
# this means we ran into an empty page
if new_url_suffix is None:
if next_page_callback:
next_page_callback("")
break
url_suffix = new_url_suffix
continue
else:
logger.exception(
f"Error in confluence call to {url_suffix} "
f"after reducing limit to {current_limit}.\n"
f"Raw Response Text: {raw_response.text}\n"
f"Error: {e}\n"
f"Error in confluence call to {url_suffix} \n"
f"Raw Response Text: {raw_response.text} \n"
f"Full Response: {raw_response.__dict__} \n"
f"Error: {e} \n"
)
raise
logger.exception(
f"Error in confluence call to {url_suffix} \n"
f"Raw Response Text: {raw_response.text} \n"
f"Full Response: {raw_response.__dict__} \n"
f"Error: {e} \n"
)
raise
try:
next_response = raw_response.json()
except Exception as e:
@@ -705,10 +680,6 @@ class OnyxConfluence:
old_url_suffix = url_suffix
updated_start = get_start_param_from_url(old_url_suffix)
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
if url_suffix and current_limit != limit:
url_suffix = update_param_in_path(
url_suffix, "limit", str(current_limit)
)
for i, result in enumerate(results):
updated_start += 1
if url_suffix and next_page_callback and i == len(results) - 1:
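
On a 5xx from Confluence the paginator halves the page size before retrying, never going below _MINIMUM_PAGINATION_LIMIT; only once that floor is reached does the Server code path fall back to one-by-one offset pagination. The reduction step on its own, using the constants declared in this hunk:

_DEFAULT_PAGINATION_LIMIT = 1000
_MINIMUM_PAGINATION_LIMIT = 5
_SERVER_ERROR_CODES = {500, 502, 503, 504}

def next_limit_after_server_error(status_code: int, current_limit: int) -> int | None:
    """Return the reduced page size to retry with, or None when reduction is exhausted."""
    if status_code not in _SERVER_ERROR_CODES:
        return None
    if current_limit <= _MINIMUM_PAGINATION_LIMIT:
        return None
    return max(current_limit // 2, _MINIMUM_PAGINATION_LIMIT)

# e.g. 1000 -> 500 -> 250 -> ... -> 5, then None (caller falls back to one-by-one on Server)
assert next_limit_after_server_error(504, _DEFAULT_PAGINATION_LIMIT) == 500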

View File

@@ -42,9 +42,6 @@ from onyx.connectors.google_drive.file_retrieval import (
get_all_files_in_my_drive_and_shared,
)
from onyx.connectors.google_drive.file_retrieval import get_external_access_for_folder
from onyx.connectors.google_drive.file_retrieval import (
get_files_by_web_view_links_batch,
)
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_folder_metadata
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
@@ -73,13 +70,11 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import NormalizationResult
from onyx.connectors.interfaces import Resolver
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
@@ -207,9 +202,7 @@ class DriveIdStatus(Enum):
class GoogleDriveConnector(
SlimConnectorWithPermSync,
CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint],
Resolver,
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
):
def __init__(
self,
@@ -1672,82 +1665,6 @@ class GoogleDriveConnector(
start, end, checkpoint, include_permissions=True
)
@override
def resolve_errors(
self,
errors: list[ConnectorFailure],
include_permissions: bool = False,
) -> Generator[Document | ConnectorFailure | HierarchyNode, None, None]:
if self._creds is None or self._primary_admin_email is None:
raise RuntimeError(
"Credentials missing, should not call this method before calling load_credentials"
)
logger.info(f"Resolving {len(errors)} errors")
doc_ids = [
failure.failed_document.document_id
for failure in errors
if failure.failed_document
]
service = get_drive_service(self.creds, self.primary_admin_email)
field_type = (
DriveFileFieldType.WITH_PERMISSIONS
if include_permissions or self.exclude_domain_link_only
else DriveFileFieldType.STANDARD
)
batch_result = get_files_by_web_view_links_batch(service, doc_ids, field_type)
for doc_id, error in batch_result.errors.items():
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=doc_id,
),
failure_message=f"Failed to retrieve file during error resolution: {error}",
exception=error,
)
permission_sync_context = (
PermissionSyncContext(
primary_admin_email=self.primary_admin_email,
google_domain=self.google_domain,
)
if include_permissions
else None
)
retrieved_files = [
RetrievedDriveFile(
drive_file=file,
user_email=self.primary_admin_email,
completion_stage=DriveRetrievalStage.DONE,
)
for file in batch_result.files.values()
]
yield from self._get_new_ancestors_for_files(
files=retrieved_files,
seen_hierarchy_node_raw_ids=ThreadSafeSet(),
fully_walked_hierarchy_node_raw_ids=ThreadSafeSet(),
permission_sync_context=permission_sync_context,
add_prefix=True,
)
func_with_args = [
(
self._convert_retrieved_file_to_document,
(rf, permission_sync_context),
)
for rf in retrieved_files
]
results = cast(
list[Document | ConnectorFailure | None],
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
)
for result in results:
if result is not None:
yield result
def _extract_slim_docs_from_google_drive(
self,
checkpoint: GoogleDriveCheckpoint,

View File

@@ -9,7 +9,6 @@ from urllib.parse import urlparse
from googleapiclient.discovery import Resource # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from googleapiclient.http import BatchHttpRequest # type: ignore
from onyx.access.models import ExternalAccess
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
@@ -61,8 +60,6 @@ SLIM_FILE_FIELDS = (
)
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
MAX_BATCH_SIZE = 100
HIERARCHY_FIELDS = "id, name, parents, webViewLink, mimeType, driveId"
HIERARCHY_FIELDS_WITH_PERMISSIONS = (
@@ -219,7 +216,7 @@ def get_external_access_for_folder(
def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
"""Get the appropriate fields string for files().list() based on the field type enum."""
"""Get the appropriate fields string based on the field type enum"""
if field_type == DriveFileFieldType.SLIM:
return SLIM_FILE_FIELDS
elif field_type == DriveFileFieldType.WITH_PERMISSIONS:
@@ -228,25 +225,6 @@ def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
return FILE_FIELDS
def _extract_single_file_fields(list_fields: str) -> str:
"""Convert a files().list() fields string to one suitable for files().get().
List fields look like "nextPageToken, files(field1, field2, ...)"
Single-file fields should be just "field1, field2, ..."
"""
start = list_fields.find("files(")
if start == -1:
return list_fields
inner_start = start + len("files(")
inner_end = list_fields.rfind(")")
return list_fields[inner_start:inner_end]
def _get_single_file_fields(field_type: DriveFileFieldType) -> str:
"""Get the appropriate fields string for files().get() based on the field type enum."""
return _extract_single_file_fields(_get_fields_for_file_type(field_type))
def _get_files_in_parent(
service: Resource,
parent_id: str,
@@ -558,74 +536,3 @@ def get_file_by_web_view_link(
)
.execute()
)
class BatchRetrievalResult:
"""Result of a batch file retrieval, separating successes from errors."""
def __init__(self) -> None:
self.files: dict[str, GoogleDriveFileType] = {}
self.errors: dict[str, Exception] = {}
def get_files_by_web_view_links_batch(
service: GoogleDriveService,
web_view_links: list[str],
field_type: DriveFileFieldType,
) -> BatchRetrievalResult:
"""Retrieve multiple Google Drive files by webViewLink using the batch API.
Returns a BatchRetrievalResult containing successful file retrievals
and errors for any files that could not be fetched.
Automatically splits into chunks of MAX_BATCH_SIZE.
"""
fields = _get_single_file_fields(field_type)
if len(web_view_links) <= MAX_BATCH_SIZE:
return _get_files_by_web_view_links_batch(service, web_view_links, fields)
combined = BatchRetrievalResult()
for i in range(0, len(web_view_links), MAX_BATCH_SIZE):
chunk = web_view_links[i : i + MAX_BATCH_SIZE]
chunk_result = _get_files_by_web_view_links_batch(service, chunk, fields)
combined.files.update(chunk_result.files)
combined.errors.update(chunk_result.errors)
return combined
def _get_files_by_web_view_links_batch(
service: GoogleDriveService,
web_view_links: list[str],
fields: str,
) -> BatchRetrievalResult:
"""Single-batch implementation."""
result = BatchRetrievalResult()
def callback(
request_id: str,
response: GoogleDriveFileType,
exception: Exception | None,
) -> None:
if exception:
logger.warning(f"Error retrieving file {request_id}: {exception}")
result.errors[request_id] = exception
else:
result.files[request_id] = response
batch = cast(BatchHttpRequest, service.new_batch_http_request(callback=callback))
for web_view_link in web_view_links:
try:
file_id = _extract_file_id_from_web_view_link(web_view_link)
request = service.files().get(
fileId=file_id,
supportsAllDrives=True,
fields=fields,
)
batch.add(request, request_id=web_view_link)
except ValueError as e:
logger.warning(f"Failed to extract file ID from {web_view_link}: {e}")
result.errors[web_view_link] = e
batch.execute()
return result
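
get_files_by_web_view_links_batch issues one Drive batch request per MAX_BATCH_SIZE links and merges each chunk's files and errors into a single BatchRetrievalResult. The chunking itself is just a fixed-size window over the link list; a standalone sketch:

from typing import Iterator

MAX_BATCH_SIZE = 100  # matches the constant in the hunk above

def chunked(items: list[str], size: int = MAX_BATCH_SIZE) -> Iterator[list[str]]:
    # Yield consecutive slices of at most `size` items; the connector issues one
    # batch HTTP request per slice and merges the per-chunk results.
    for i in range(0, len(items), size):
        yield items[i : i + size]

links = [f"https://drive.google.com/file/d/{n}/view" for n in range(250)]
print([len(chunk) for chunk in chunked(links)])  # [100, 100, 50]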

View File

@@ -298,22 +298,6 @@ class CheckpointedConnectorWithPermSync(CheckpointedConnector[CT]):
raise NotImplementedError
class Resolver(BaseConnector):
@abc.abstractmethod
def resolve_errors(
self,
errors: list[ConnectorFailure],
include_permissions: bool = False,
) -> Generator[Document | ConnectorFailure | HierarchyNode, None, None]:
"""Attempts to yield back ALL the documents described by the errors, no checkpointing.
Caller's responsibility is to delete the old ConnectorFailures and replace with the new ones.
If include_permissions is True, the documents will have permissions synced.
May also yield HierarchyNode objects for ancestor folders of resolved documents.
"""
raise NotImplementedError
class HierarchyConnector(BaseConnector):
@abc.abstractmethod
def load_hierarchy(

View File

@@ -8,7 +8,6 @@ from collections.abc import Iterator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from io import BytesIO
from typing import Any
import requests
@@ -41,7 +40,6 @@ from onyx.connectors.jira.utils import best_effort_basic_expert_info
from onyx.connectors.jira.utils import best_effort_get_field_from_issue
from onyx.connectors.jira.utils import build_jira_client
from onyx.connectors.jira.utils import build_jira_url
from onyx.connectors.jira.utils import CustomFieldExtractor
from onyx.connectors.jira.utils import extract_text_from_adf
from onyx.connectors.jira.utils import get_comment_strs
from onyx.connectors.jira.utils import JIRA_CLOUD_API_VERSION
@@ -54,7 +52,6 @@ from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.db.enums import HierarchyNodeType
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -67,7 +64,6 @@ _MAX_RESULTS_FETCH_IDS = 5000
_JIRA_FULL_PAGE_SIZE = 50
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
_JIRA_BULK_FETCH_LIMIT = 100
_MAX_ATTACHMENT_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB
# Constants for Jira field names
_FIELD_REPORTER = "reporter"
@@ -381,7 +377,6 @@ def process_jira_issue(
comment_email_blacklist: tuple[str, ...] = (),
labels_to_skip: set[str] | None = None,
parent_hierarchy_raw_node_id: str | None = None,
custom_fields_mapping: dict[str, str] | None = None,
) -> Document | None:
if labels_to_skip:
if any(label in issue.fields.labels for label in labels_to_skip):
@@ -467,24 +462,6 @@ def process_jira_issue(
else:
logger.error(f"Project should exist but does not for {issue.key}")
# Merge custom fields into metadata if a mapping was provided
if custom_fields_mapping:
try:
custom_fields = CustomFieldExtractor.get_issue_custom_fields(
issue, custom_fields_mapping
)
# Filter out custom fields that collide with existing metadata keys
for key in list(custom_fields.keys()):
if key in metadata_dict:
logger.warning(
f"Custom field '{key}' on {issue.key} collides with "
f"standard metadata key; skipping custom field value"
)
del custom_fields[key]
metadata_dict.update(custom_fields)
except Exception as e:
logger.warning(f"Failed to extract custom fields for {issue.key}: {e}")
return Document(
id=page_url,
sections=[TextSection(link=page_url, text=ticket_content)],
@@ -527,12 +504,6 @@ class JiraConnector(
# Custom JQL query to filter Jira issues
jql_query: str | None = None,
scoped_token: bool = False,
# When True, extract custom fields from Jira issues and include them
# in document metadata with human-readable field names.
extract_custom_fields: bool = False,
# When True, download attachments from Jira issues and yield them
# as separate Documents linked to the parent ticket.
fetch_attachments: bool = False,
) -> None:
self.batch_size = batch_size
@@ -546,11 +517,7 @@ class JiraConnector(
self.labels_to_skip = set(labels_to_skip)
self.jql_query = jql_query
self.scoped_token = scoped_token
self.extract_custom_fields = extract_custom_fields
self.fetch_attachments = fetch_attachments
self._jira_client: JIRA | None = None
# Mapping of custom field IDs to human-readable names (populated on load_credentials)
self._custom_fields_mapping: dict[str, str] = {}
# Cache project permissions to avoid fetching them repeatedly across runs
self._project_permissions_cache: dict[str, Any] = {}
@@ -711,134 +678,12 @@ class JiraConnector(
# the document belongs directly under the project in the hierarchy
return project_key
def _process_attachments(
self,
issue: Issue,
parent_hierarchy_raw_node_id: str | None,
include_permissions: bool = False,
project_key: str | None = None,
) -> Generator[Document | ConnectorFailure, None, None]:
"""Download and yield Documents for each attachment on a Jira issue.
Each attachment becomes a separate Document whose text is extracted
from the downloaded file content. Failures on individual attachments
are logged and yielded as ConnectorFailure so they never break the
overall indexing run.
"""
attachments = best_effort_get_field_from_issue(issue, "attachment")
if not attachments:
return
issue_url = build_jira_url(self.jira_base, issue.key)
for attachment in attachments:
try:
filename = getattr(attachment, "filename", "unknown")
try:
size = int(getattr(attachment, "size", 0) or 0)
except (ValueError, TypeError):
size = 0
content_url = getattr(attachment, "content", None)
attachment_id = getattr(attachment, "id", filename)
mime_type = getattr(attachment, "mimeType", "application/octet-stream")
created = getattr(attachment, "created", None)
if size > _MAX_ATTACHMENT_SIZE_BYTES:
logger.warning(
f"Skipping attachment '{filename}' on {issue.key}: "
f"size {size} bytes exceeds {_MAX_ATTACHMENT_SIZE_BYTES} byte limit"
)
continue
if not content_url:
logger.warning(
f"Skipping attachment '{filename}' on {issue.key}: "
f"no content URL available"
)
continue
# Download the attachment using the public API on the
# python-jira Attachment resource (avoids private _session access
# and the double-copy from response.content + BytesIO wrapping).
file_content = attachment.get()
# Extract text from the downloaded file
try:
text = extract_file_text(
file=BytesIO(file_content),
file_name=filename,
)
except Exception as e:
logger.warning(
f"Could not extract text from attachment '{filename}' "
f"on {issue.key}: {e}"
)
continue
if not text or not text.strip():
logger.info(
f"Skipping attachment '{filename}' on {issue.key}: "
f"no text content could be extracted"
)
continue
doc_id = f"{issue_url}/attachments/{attachment_id}"
attachment_doc = Document(
id=doc_id,
sections=[TextSection(link=issue_url, text=text)],
source=DocumentSource.JIRA,
semantic_identifier=f"{issue.key}: {filename}",
title=filename,
doc_updated_at=(time_str_to_utc(created) if created else None),
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
metadata={
"parent_ticket": issue.key,
"attachment_filename": filename,
"attachment_mime_type": mime_type,
"attachment_size": str(size),
},
)
if include_permissions and project_key:
attachment_doc.external_access = self._get_project_permissions(
project_key,
add_prefix=True,
)
yield attachment_doc
except Exception as e:
logger.error(f"Failed to process attachment on {issue.key}: {e}")
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=f"{issue_url}/attachments/{getattr(attachment, 'id', 'unknown')}",
document_link=issue_url,
),
failure_message=f"Failed to process attachment '{getattr(attachment, 'filename', 'unknown')}': {str(e)}",
exception=e,
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._jira_client = build_jira_client(
credentials=credentials,
jira_base=self.jira_base,
scoped_token=self.scoped_token,
)
# Fetch the custom field ID-to-name mapping once at credential load time.
# This avoids repeated API calls during issue processing.
if self.extract_custom_fields:
try:
self._custom_fields_mapping = (
CustomFieldExtractor.get_all_custom_fields(self._jira_client)
)
logger.info(
f"Loaded {len(self._custom_fields_mapping)} custom field definitions"
)
except Exception as e:
logger.warning(
f"Failed to fetch custom field definitions; "
f"custom field extraction will be skipped: {e}"
)
self._custom_fields_mapping = {}
return None
def _get_jql_query(
@@ -969,11 +814,6 @@ class JiraConnector(
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
custom_fields_mapping=(
self._custom_fields_mapping
if self._custom_fields_mapping
else None
),
):
# Add permission information to the document if requested
if include_permissions:
@@ -983,15 +823,6 @@ class JiraConnector(
)
yield document
# Yield attachment documents if enabled
if self.fetch_attachments:
yield from self._process_attachments(
issue=issue,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
include_permissions=include_permissions,
project_key=project_key,
)
except Exception as e:
yield ConnectorFailure(
failed_document=DocumentFailure(
@@ -1099,41 +930,20 @@ class JiraConnector(
issue_key = best_effort_get_field_from_issue(issue, _FIELD_KEY)
doc_id = build_jira_url(self.jira_base, issue_key)
parent_hierarchy_raw_node_id = (
self._get_parent_hierarchy_raw_node_id(issue, project_key)
if project_key
else None
)
project_perms = self._get_project_permissions(
project_key, add_prefix=False
)
slim_doc_batch.append(
SlimDocument(
id=doc_id,
# Permission sync path - don't prefix, upsert_document_external_perms handles it
external_access=project_perms,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
external_access=self._get_project_permissions(
project_key, add_prefix=False
),
parent_hierarchy_raw_node_id=(
self._get_parent_hierarchy_raw_node_id(issue, project_key)
if project_key
else None
),
)
)
# Also emit SlimDocument entries for each attachment
if self.fetch_attachments:
attachments = best_effort_get_field_from_issue(issue, "attachment")
if attachments:
for attachment in attachments:
attachment_id = getattr(
attachment,
"id",
getattr(attachment, "filename", "unknown"),
)
slim_doc_batch.append(
SlimDocument(
id=f"{doc_id}/attachments/{attachment_id}",
external_access=project_perms,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
)
)
current_offset += 1
if len(slim_doc_batch) >= JIRA_SLIM_PAGE_SIZE:
yield slim_doc_batch
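
_process_attachments guards each attachment before downloading: anything over _MAX_ATTACHMENT_SIZE_BYTES or without a content URL is skipped, and any other per-attachment error is yielded as a ConnectorFailure so one bad file cannot abort the run. The skip predicate alone, as a sketch:

_MAX_ATTACHMENT_SIZE_BYTES = 50 * 1024 * 1024  # 50 MB, matching the constant in this hunk

def should_download_attachment(size: int | str | None, content_url: str | None) -> bool:
    # Mirrors the two skip conditions above: unparseable/oversized sizes and missing content URLs.
    try:
        size_bytes = int(size or 0)
    except (ValueError, TypeError):
        size_bytes = 0
    if size_bytes > _MAX_ATTACHMENT_SIZE_BYTES:
        return False
    return bool(content_url)

assert should_download_attachment("2048", "https://example.atlassian.net/attachment/1")
assert not should_download_attachment(60 * 1024 * 1024, "https://example.atlassian.net/attachment/2")
assert not should_download_attachment(1024, None)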

View File

@@ -3,7 +3,6 @@ from datetime import datetime
from enum import Enum
from typing import Any
from typing import cast
from typing import Literal
from pydantic import BaseModel
from pydantic import Field
@@ -34,17 +33,9 @@ class ConnectorMissingCredentialError(PermissionError):
)
class SectionType(str, Enum):
"""Discriminator for Section subclasses."""
TEXT = "text"
IMAGE = "image"
class Section(BaseModel):
"""Base section class with common attributes"""
type: SectionType
link: str | None = None
text: str | None = None
image_file_id: str | None = None
@@ -53,7 +44,6 @@ class Section(BaseModel):
class TextSection(Section):
"""Section containing text content"""
type: Literal[SectionType.TEXT] = SectionType.TEXT
text: str
def __sizeof__(self) -> int:
@@ -63,7 +53,6 @@ class TextSection(Section):
class ImageSection(Section):
"""Section containing an image reference"""
type: Literal[SectionType.IMAGE] = SectionType.IMAGE
image_file_id: str
def __sizeof__(self) -> int:
@@ -145,6 +134,7 @@ class BasicExpertInfo(BaseModel):
@classmethod
def from_dict(cls, model_dict: dict[str, Any]) -> "BasicExpertInfo":
first_name = cast(str, model_dict.get("FirstName"))
last_name = cast(str, model_dict.get("LastName"))
email = cast(str, model_dict.get("Email"))
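
The SectionType enum and the Literal-typed type fields in this hunk let each Section subclass self-identify, which is what any dispatcher keyed on section.type relies on. That pattern in isolation (a sketch assuming pydantic v2):

from enum import Enum
from typing import Literal
from pydantic import BaseModel

class SectionType(str, Enum):
    TEXT = "text"
    IMAGE = "image"

class Section(BaseModel):
    type: SectionType
    link: str | None = None

class TextSection(Section):
    type: Literal[SectionType.TEXT] = SectionType.TEXT
    text: str

class ImageSection(Section):
    type: Literal[SectionType.IMAGE] = SectionType.IMAGE
    image_file_id: str

print(TextSection(text="hello").type)  # SectionType.TEXT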

View File

@@ -335,7 +335,6 @@ def update_document_set(
"Cannot update document set while it is syncing. Please wait for it to finish syncing, and then try again."
)
document_set_row.name = document_set_update_request.name
document_set_row.description = document_set_update_request.description
if not DISABLE_VECTOR_DB:
document_set_row.is_up_to_date = False

View File

@@ -11,7 +11,6 @@ from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DB_READONLY_PASSWORD
@@ -347,25 +346,6 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def _safe_close_session(session: Session) -> None:
"""Close a session, catching connection-closed errors during cleanup.
Long-running operations (e.g. multi-model LLM loops) can hold a session
open for minutes. If the underlying connection is dropped by cloud
infrastructure (load-balancer timeouts, PgBouncer, idle-in-transaction
timeouts, etc.), the implicit rollback in Session.close() raises
OperationalError or InterfaceError. Since the work is already complete,
we log and move on — SQLAlchemy internally invalidates the connection
for pool recycling.
"""
try:
session.close()
except DBAPIError:
logger.warning(
"DB connection lost during session cleanup — the connection will be invalidated and recycled by the pool."
)
@contextmanager
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
"""
@@ -378,11 +358,8 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
# no need to use the schema translation map for self-hosted + default schema
if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
session = Session(bind=engine, expire_on_commit=False)
try:
with Session(bind=engine, expire_on_commit=False) as session:
yield session
finally:
_safe_close_session(session)
return
# Create connection with schema translation to handle querying the right schema
@@ -390,11 +367,8 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
with engine.connect().execution_options(
schema_translate_map=schema_translate_map
) as connection:
session = Session(bind=connection, expire_on_commit=False)
try:
with Session(bind=connection, expire_on_commit=False) as session:
yield session
finally:
_safe_close_session(session)
def get_session() -> Generator[Session, None, None]:
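
The _safe_close_session helper in this hunk exists because a long-lived session's connection may already be gone by cleanup time; Session.close() then raises a DBAPIError subclass during its implicit rollback. A minimal standalone version of the same wrapper:

import logging
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm import Session

logger = logging.getLogger(__name__)

def safe_close_session(session: Session) -> None:
    # Close the session; if the underlying connection was already dropped
    # (load-balancer / PgBouncer / idle timeouts), close() raises OperationalError
    # or InterfaceError. The work is done at this point, so log and let the
    # pool invalidate and recycle the connection.
    try:
        session.close()
    except DBAPIError:
        logger.warning("DB connection lost during session cleanup; connection will be recycled.")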

View File

@@ -899,7 +899,6 @@ def create_index_attempt_error(
failure: ConnectorFailure,
db_session: Session,
) -> int:
exc = failure.exception
new_error = IndexAttemptError(
index_attempt_id=index_attempt_id,
connector_credential_pair_id=connector_credential_pair_id,
@@ -922,7 +921,6 @@ def create_index_attempt_error(
),
failure_message=failure.failure_message,
is_resolved=False,
error_type=type(exc).__name__ if exc else None,
)
db_session.add(new_error)
db_session.commit()

View File

@@ -5,7 +5,6 @@ from pydantic import ConfigDict
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
from onyx.db.models import Memory
from onyx.db.models import User
@@ -84,51 +83,47 @@ def get_memories(user: User, db_session: Session) -> UserMemoryContext:
def add_memory(
user_id: UUID,
memory_text: str,
db_session: Session | None = None,
) -> int:
db_session: Session,
) -> Memory:
"""Insert a new Memory row for the given user.
If the user already has MAX_MEMORIES_PER_USER memories, the oldest
one (lowest id) is deleted before inserting the new one.
Returns the newly created Memory row.
"""
with get_session_with_current_tenant_if_none(db_session) as db_session:
existing = db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
).all()
existing = db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
).all()
if len(existing) >= MAX_MEMORIES_PER_USER:
db_session.delete(existing[0])
if len(existing) >= MAX_MEMORIES_PER_USER:
db_session.delete(existing[0])
memory = Memory(
user_id=user_id,
memory_text=memory_text,
)
db_session.add(memory)
db_session.commit()
return memory.id
memory = Memory(
user_id=user_id,
memory_text=memory_text,
)
db_session.add(memory)
db_session.commit()
return memory
def update_memory_at_index(
user_id: UUID,
index: int,
new_text: str,
db_session: Session | None = None,
) -> int | None:
db_session: Session,
) -> Memory | None:
"""Update the memory at the given 0-based index (ordered by id ASC, matching get_memories()).
Returns the id of the updated Memory row, or None if the index is out of range.
Returns the updated Memory row, or None if the index is out of range.
"""
with get_session_with_current_tenant_if_none(db_session) as db_session:
memory_rows = db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
).all()
memory_rows = db_session.scalars(
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
).all()
if index < 0 or index >= len(memory_rows):
return None
if index < 0 or index >= len(memory_rows):
return None
memory = memory_rows[index]
memory.memory_text = new_text
db_session.commit()
return memory.id
memory = memory_rows[index]
memory.memory_text = new_text
db_session.commit()
return memory
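
add_memory caps each user at MAX_MEMORIES_PER_USER rows by deleting the oldest row (lowest id) before inserting, which is why the query orders by Memory.id ascending. The eviction policy reduced to a plain list (the cap value here is an assumption; the real constant is defined elsewhere in the module):

MAX_MEMORIES_PER_USER = 20  # assumption for illustration only

def add_memory_text(memories: list[str], new_text: str) -> list[str]:
    # memories is ordered oldest-first, mirroring ORDER BY Memory.id ASC above.
    if len(memories) >= MAX_MEMORIES_PER_USER:
        memories = memories[1:]  # evict the oldest entry
    return memories + [new_text]

history = [f"memory {i}" for i in range(MAX_MEMORIES_PER_USER)]
print(add_memory_text(history, "newest memory")[0])  # 'memory 1' since 'memory 0' was evicted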

View File

@@ -2422,8 +2422,6 @@ class IndexAttemptError(Base):
failure_message: Mapped[str] = mapped_column(Text)
is_resolved: Mapped[bool] = mapped_column(Boolean, default=False)
error_type: Mapped[str | None] = mapped_column(String, nullable=True)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),

View File

@@ -7,6 +7,8 @@ import time
from collections.abc import Callable
from typing import cast
from sqlalchemy.orm import Session
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.citation_processor import DynamicCitationProcessor
@@ -20,7 +22,6 @@ from onyx.chat.models import LlmStepResult
from onyx.chat.models import ToolCallSimple
from onyx.configs.chat_configs import SKIP_DEEP_RESEARCH_CLARIFICATION
from onyx.configs.constants import MessageType
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.tools import get_tool_by_name
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
from onyx.deep_research.dr_mock_tools import get_orchestrator_tools
@@ -183,14 +184,6 @@ def generate_final_report(
return has_reasoned
def _get_research_agent_tool_id() -> int:
with get_session_with_current_tenant() as db_session:
return get_tool_by_name(
tool_name=RESEARCH_AGENT_TOOL_NAME,
db_session=db_session,
).id
@log_function_time(print_only=True)
def run_deep_research_llm_loop(
emitter: Emitter,
@@ -200,6 +193,7 @@ def run_deep_research_llm_loop(
custom_agent_prompt: str | None, # noqa: ARG001
llm: LLM,
token_counter: Callable[[str], int],
db_session: Session,
skip_clarification: bool = False,
user_identity: LLMUserIdentity | None = None,
chat_session_id: str | None = None,
@@ -723,7 +717,6 @@ def run_deep_research_llm_loop(
simple_chat_history.append(assistant_with_tools)
# Now add TOOL_CALL_RESPONSE messages and tool call info for each result
research_agent_tool_id = _get_research_agent_tool_id()
for tab_index, report in enumerate(
research_results.intermediate_reports
):
@@ -744,7 +737,10 @@ def run_deep_research_llm_loop(
tab_index=tab_index,
tool_name=current_tool_call.tool_name,
tool_call_id=current_tool_call.tool_call_id,
tool_id=research_agent_tool_id,
tool_id=get_tool_by_name(
tool_name=RESEARCH_AGENT_TOOL_NAME,
db_session=db_session,
).id,
reasoning_tokens=llm_step_result.reasoning
or most_recent_reasoning,
tool_call_arguments=current_tool_call.tool_args,

View File

@@ -1,3 +1,5 @@
from typing import cast
from chonkie import SentenceChunker
from onyx.configs.app_configs import AVERAGE_SUMMARY_EMBEDDINGS
@@ -14,14 +16,16 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_metadata_keys_to_ignore,
)
from onyx.connectors.models import IndexingDocument
from onyx.indexing.chunking import DocumentChunker
from onyx.indexing.chunking import extract_blurb
from onyx.connectors.models import Section
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import DocAwareChunk
from onyx.llm.utils import MAX_CONTEXT_TOKENS
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.utils.logger import setup_logger
from onyx.utils.text_processing import clean_text
from onyx.utils.text_processing import shared_precompare_cleanup
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
# Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps
# actually help quality at all
@@ -150,6 +154,9 @@ class Chunker:
self.tokenizer = tokenizer
self.callback = callback
self.max_context = 0
self.prompt_tokens = 0
# Create a token counter function that returns the count instead of the tokens
def token_counter(text: str) -> int:
return len(tokenizer.encode(text))
@@ -179,12 +186,234 @@ class Chunker:
else None
)
self._document_chunker = DocumentChunker(
tokenizer=tokenizer,
blurb_splitter=self.blurb_splitter,
chunk_splitter=self.chunk_splitter,
mini_chunk_splitter=self.mini_chunk_splitter,
def _split_oversized_chunk(self, text: str, content_token_limit: int) -> list[str]:
"""
Splits the text into smaller chunks based on token count to ensure
no chunk exceeds the content_token_limit.
"""
tokens = self.tokenizer.tokenize(text)
chunks = []
start = 0
total_tokens = len(tokens)
while start < total_tokens:
end = min(start + content_token_limit, total_tokens)
token_chunk = tokens[start:end]
chunk_text = " ".join(token_chunk)
chunks.append(chunk_text)
start = end
return chunks
def _extract_blurb(self, text: str) -> str:
"""
Extract a short blurb from the text (first chunk of size `blurb_size`).
"""
# chunker is in `text` mode
texts = cast(list[str], self.blurb_splitter.chunk(text))
if not texts:
return ""
return texts[0]
def _get_mini_chunk_texts(self, chunk_text: str) -> list[str] | None:
"""
For "multipass" mode: additional sub-chunks (mini-chunks) for use in certain embeddings.
"""
if self.mini_chunk_splitter and chunk_text.strip():
# chunker is in `text` mode
return cast(list[str], self.mini_chunk_splitter.chunk(chunk_text))
return None
# ADDED: extra param image_url to store in the chunk
def _create_chunk(
self,
document: IndexingDocument,
chunks_list: list[DocAwareChunk],
text: str,
links: dict[int, str],
is_continuation: bool = False,
title_prefix: str = "",
metadata_suffix_semantic: str = "",
metadata_suffix_keyword: str = "",
image_file_id: str | None = None,
) -> None:
"""
Helper to create a new DocAwareChunk, append it to chunks_list.
"""
new_chunk = DocAwareChunk(
source_document=document,
chunk_id=len(chunks_list),
blurb=self._extract_blurb(text),
content=text,
source_links=links or {0: ""},
image_file_id=image_file_id,
section_continuation=is_continuation,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
mini_chunk_texts=self._get_mini_chunk_texts(text),
large_chunk_id=None,
doc_summary="",
chunk_context="",
contextual_rag_reserved_tokens=0, # set per-document in _handle_single_document
)
chunks_list.append(new_chunk)
def _chunk_document_with_sections(
self,
document: IndexingDocument,
sections: list[Section],
title_prefix: str,
metadata_suffix_semantic: str,
metadata_suffix_keyword: str,
content_token_limit: int,
) -> list[DocAwareChunk]:
"""
Loops through sections of the document, converting them into one or more chunks.
Works with processed sections that are base Section objects.
"""
chunks: list[DocAwareChunk] = []
link_offsets: dict[int, str] = {}
chunk_text = ""
for section_idx, section in enumerate(sections):
# Get section text and other attributes
section_text = clean_text(str(section.text or ""))
section_link_text = section.link or ""
image_url = section.image_file_id
# If there is no useful content, skip
if not section_text and (not document.title or section_idx > 0):
logger.warning(
f"Skipping empty or irrelevant section in doc {document.semantic_identifier}, link={section_link_text}"
)
continue
# CASE 1: If this section has an image, force a separate chunk
if image_url:
# First, if we have any partially built text chunk, finalize it
if chunk_text.strip():
self._create_chunk(
document,
chunks,
chunk_text,
link_offsets,
is_continuation=False,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
)
chunk_text = ""
link_offsets = {}
# Create a chunk specifically for this image section
# (Using the text summary that was generated during processing)
self._create_chunk(
document,
chunks,
section_text,
links={0: section_link_text} if section_link_text else {},
image_file_id=image_url,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
)
# Continue to next section
continue
# CASE 2: Normal text section
section_token_count = len(self.tokenizer.encode(section_text))
# If the section is large on its own, split it separately
if section_token_count > content_token_limit:
if chunk_text.strip():
self._create_chunk(
document,
chunks,
chunk_text,
link_offsets,
False,
title_prefix,
metadata_suffix_semantic,
metadata_suffix_keyword,
)
chunk_text = ""
link_offsets = {}
# chunker is in `text` mode
split_texts = cast(list[str], self.chunk_splitter.chunk(section_text))
for i, split_text in enumerate(split_texts):
# If even the split_text is bigger than strict limit, further split
if (
STRICT_CHUNK_TOKEN_LIMIT
and len(self.tokenizer.encode(split_text)) > content_token_limit
):
smaller_chunks = self._split_oversized_chunk(
split_text, content_token_limit
)
for j, small_chunk in enumerate(smaller_chunks):
self._create_chunk(
document,
chunks,
small_chunk,
{0: section_link_text},
is_continuation=(j != 0),
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
)
else:
self._create_chunk(
document,
chunks,
split_text,
{0: section_link_text},
is_continuation=(i != 0),
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
)
continue
# If we can still fit this section into the current chunk, do so
current_token_count = len(self.tokenizer.encode(chunk_text))
current_offset = len(shared_precompare_cleanup(chunk_text))
next_section_tokens = (
len(self.tokenizer.encode(SECTION_SEPARATOR)) + section_token_count
)
if next_section_tokens + current_token_count <= content_token_limit:
if chunk_text:
chunk_text += SECTION_SEPARATOR
chunk_text += section_text
link_offsets[current_offset] = section_link_text
else:
# finalize the existing chunk
self._create_chunk(
document,
chunks,
chunk_text,
link_offsets,
False,
title_prefix,
metadata_suffix_semantic,
metadata_suffix_keyword,
)
# start a new chunk
link_offsets = {0: section_link_text}
chunk_text = section_text
# finalize any leftover text chunk
if chunk_text.strip() or not chunks:
self._create_chunk(
document,
chunks,
chunk_text,
link_offsets or {0: ""}, # safe default
False,
title_prefix,
metadata_suffix_semantic,
metadata_suffix_keyword,
)
return chunks
def _handle_single_document(
self, document: IndexingDocument
@@ -194,10 +423,7 @@ class Chunker:
logger.debug(f"Chunking {document.semantic_identifier}")
# Title prep
title = extract_blurb(
document.get_title_for_document_index() or "",
self.blurb_splitter,
)
title = self._extract_blurb(document.get_title_for_document_index() or "")
title_prefix = title + RETURN_SEPARATOR if title else ""
title_tokens = len(self.tokenizer.encode(title_prefix))
@@ -265,7 +491,7 @@ class Chunker:
# Use processed_sections if available (IndexingDocument), otherwise use original sections
sections_to_chunk = document.processed_sections
normal_chunks = self._document_chunker.chunk(
normal_chunks = self._chunk_document_with_sections(
document,
sections_to_chunk,
title_prefix,
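
_split_oversized_chunk is the strict fallback: when a sentence-level split still exceeds content_token_limit, the text is cut into fixed token windows. The same loop with whitespace tokens standing in for the model tokenizer:

def split_oversized_chunk(text: str, content_token_limit: int) -> list[str]:
    # Whitespace tokens stand in for model tokens here; the shape of the loop
    # matches the method above: fixed windows, rejoined with single spaces.
    tokens = text.split()
    chunks: list[str] = []
    start = 0
    while start < len(tokens):
        end = min(start + content_token_limit, len(tokens))
        chunks.append(" ".join(tokens[start:end]))
        start = end
    return chunks

print(split_oversized_chunk("one two three four five", 2))
# ['one two', 'three four', 'five']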

View File

@@ -1,7 +0,0 @@
from onyx.indexing.chunking.document_chunker import DocumentChunker
from onyx.indexing.chunking.section_chunker import extract_blurb
__all__ = [
"DocumentChunker",
"extract_blurb",
]

View File

@@ -1,109 +0,0 @@
from chonkie import SentenceChunker
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import Section
from onyx.connectors.models import SectionType
from onyx.indexing.chunking.image_section_chunker import ImageChunker
from onyx.indexing.chunking.section_chunker import AccumulatorState
from onyx.indexing.chunking.section_chunker import ChunkPayload
from onyx.indexing.chunking.section_chunker import SectionChunker
from onyx.indexing.chunking.text_section_chunker import TextChunker
from onyx.indexing.models import DocAwareChunk
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.utils.logger import setup_logger
from onyx.utils.text_processing import clean_text
logger = setup_logger()
class DocumentChunker:
"""Converts a document's processed sections into DocAwareChunks.
Drop-in replacement for `Chunker._chunk_document_with_sections`.
"""
def __init__(
self,
tokenizer: BaseTokenizer,
blurb_splitter: SentenceChunker,
chunk_splitter: SentenceChunker,
mini_chunk_splitter: SentenceChunker | None = None,
) -> None:
self.blurb_splitter = blurb_splitter
self.mini_chunk_splitter = mini_chunk_splitter
self._dispatch: dict[SectionType, SectionChunker] = {
SectionType.TEXT: TextChunker(
tokenizer=tokenizer,
chunk_splitter=chunk_splitter,
),
SectionType.IMAGE: ImageChunker(),
}
def chunk(
self,
document: IndexingDocument,
sections: list[Section],
title_prefix: str,
metadata_suffix_semantic: str,
metadata_suffix_keyword: str,
content_token_limit: int,
) -> list[DocAwareChunk]:
payloads = self._collect_section_payloads(
document=document,
sections=sections,
content_token_limit=content_token_limit,
)
if not payloads:
payloads.append(ChunkPayload(text="", links={0: ""}))
return [
payload.to_doc_aware_chunk(
document=document,
chunk_id=idx,
blurb_splitter=self.blurb_splitter,
mini_chunk_splitter=self.mini_chunk_splitter,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
)
for idx, payload in enumerate(payloads)
]
def _collect_section_payloads(
self,
document: IndexingDocument,
sections: list[Section],
content_token_limit: int,
) -> list[ChunkPayload]:
accumulator = AccumulatorState()
payloads: list[ChunkPayload] = []
for section_idx, section in enumerate(sections):
section_text = clean_text(str(section.text or ""))
if not section_text and (not document.title or section_idx > 0):
logger.warning(
f"Skipping empty or irrelevant section in doc "
f"{document.semantic_identifier}, link={section.link}"
)
continue
chunker = self._select_chunker(section)
result = chunker.chunk_section(
section=section,
accumulator=accumulator,
content_token_limit=content_token_limit,
)
payloads.extend(result.payloads)
accumulator = result.accumulator
payloads.extend(accumulator.flush_to_list())
return payloads
def _select_chunker(self, section: Section) -> SectionChunker:
try:
return self._dispatch[section.type]
except KeyError:
raise ValueError(f"No SectionChunker registered for type={section.type}")

View File

@@ -1,35 +0,0 @@
from onyx.connectors.models import Section
from onyx.indexing.chunking.section_chunker import AccumulatorState
from onyx.indexing.chunking.section_chunker import ChunkPayload
from onyx.indexing.chunking.section_chunker import SectionChunker
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
from onyx.utils.text_processing import clean_text
class ImageChunker(SectionChunker):
def chunk_section(
self,
section: Section,
accumulator: AccumulatorState,
content_token_limit: int, # noqa: ARG002
) -> SectionChunkerOutput:
assert section.image_file_id is not None
section_text = clean_text(str(section.text or ""))
section_link = section.link or ""
# Flush any partially built text chunks
payloads = accumulator.flush_to_list()
payloads.append(
ChunkPayload(
text=section_text,
links={0: section_link} if section_link else {},
image_file_id=section.image_file_id,
is_continuation=False,
)
)
return SectionChunkerOutput(
payloads=payloads,
accumulator=AccumulatorState(),
)

View File

@@ -1,100 +0,0 @@
from abc import ABC
from abc import abstractmethod
from collections.abc import Sequence
from typing import cast
from chonkie import SentenceChunker
from pydantic import BaseModel
from pydantic import Field
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import Section
from onyx.indexing.models import DocAwareChunk
def extract_blurb(text: str, blurb_splitter: SentenceChunker) -> str:
texts = cast(list[str], blurb_splitter.chunk(text))
if not texts:
return ""
return texts[0]
def get_mini_chunk_texts(
chunk_text: str,
mini_chunk_splitter: SentenceChunker | None,
) -> list[str] | None:
if mini_chunk_splitter and chunk_text.strip():
return list(cast(Sequence[str], mini_chunk_splitter.chunk(chunk_text)))
return None
class ChunkPayload(BaseModel):
"""Section-local chunk content without document-scoped fields.
The orchestrator upgrades these to DocAwareChunks via
`to_doc_aware_chunk` after assigning chunk_ids and attaching
title/metadata.
"""
text: str
links: dict[int, str]
is_continuation: bool = False
image_file_id: str | None = None
def to_doc_aware_chunk(
self,
document: IndexingDocument,
chunk_id: int,
blurb_splitter: SentenceChunker,
title_prefix: str = "",
metadata_suffix_semantic: str = "",
metadata_suffix_keyword: str = "",
mini_chunk_splitter: SentenceChunker | None = None,
) -> DocAwareChunk:
return DocAwareChunk(
source_document=document,
chunk_id=chunk_id,
blurb=extract_blurb(self.text, blurb_splitter),
content=self.text,
source_links=self.links or {0: ""},
image_file_id=self.image_file_id,
section_continuation=self.is_continuation,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
mini_chunk_texts=get_mini_chunk_texts(self.text, mini_chunk_splitter),
large_chunk_id=None,
doc_summary="",
chunk_context="",
contextual_rag_reserved_tokens=0,
)
class AccumulatorState(BaseModel):
"""Cross-section text buffer threaded through SectionChunkers."""
text: str = ""
link_offsets: dict[int, str] = Field(default_factory=dict)
def is_empty(self) -> bool:
return not self.text.strip()
def flush_to_list(self) -> list[ChunkPayload]:
if self.is_empty():
return []
return [ChunkPayload(text=self.text, links=self.link_offsets)]
class SectionChunkerOutput(BaseModel):
payloads: list[ChunkPayload]
accumulator: AccumulatorState
class SectionChunker(ABC):
@abstractmethod
def chunk_section(
self,
section: Section,
accumulator: AccumulatorState,
content_token_limit: int,
) -> SectionChunkerOutput: ...
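
AccumulatorState carries the partially built text chunk between sections: text sections that fit keep extending it, while image sections and oversized sections flush it first and start fresh. The flush contract, reduced to plain strings:

from dataclasses import dataclass

@dataclass
class Accumulator:
    text: str = ""

    def flush(self) -> list[str]:
        # Emit the buffered text as one pending chunk, or nothing when the buffer is blank,
        # mirroring AccumulatorState.flush_to_list above.
        return [self.text] if self.text.strip() else []

acc = Accumulator(text="intro paragraph")
chunks = acc.flush() + ["image summary chunk"]  # an image section flushes the buffer, then adds its own chunk
print(chunks)  # ['intro paragraph', 'image summary chunk']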

View File

@@ -1,129 +0,0 @@
from typing import cast
from chonkie import SentenceChunker
from onyx.configs.constants import SECTION_SEPARATOR
from onyx.connectors.models import Section
from onyx.indexing.chunking.section_chunker import AccumulatorState
from onyx.indexing.chunking.section_chunker import ChunkPayload
from onyx.indexing.chunking.section_chunker import SectionChunker
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import count_tokens
from onyx.utils.text_processing import clean_text
from onyx.utils.text_processing import shared_precompare_cleanup
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
class TextChunker(SectionChunker):
def __init__(
self,
tokenizer: BaseTokenizer,
chunk_splitter: SentenceChunker,
) -> None:
self.tokenizer = tokenizer
self.chunk_splitter = chunk_splitter
self.section_separator_token_count = count_tokens(
SECTION_SEPARATOR,
self.tokenizer,
)
def chunk_section(
self,
section: Section,
accumulator: AccumulatorState,
content_token_limit: int,
) -> SectionChunkerOutput:
section_text = clean_text(str(section.text or ""))
section_link = section.link or ""
section_token_count = len(self.tokenizer.encode(section_text))
# Oversized — flush buffer and split the section
if section_token_count > content_token_limit:
return self._handle_oversized_section(
section_text=section_text,
section_link=section_link,
accumulator=accumulator,
content_token_limit=content_token_limit,
)
current_token_count = count_tokens(accumulator.text, self.tokenizer)
next_section_tokens = self.section_separator_token_count + section_token_count
# Fits — extend the accumulator
if next_section_tokens + current_token_count <= content_token_limit:
offset = len(shared_precompare_cleanup(accumulator.text))
new_text = accumulator.text
if new_text:
new_text += SECTION_SEPARATOR
new_text += section_text
return SectionChunkerOutput(
payloads=[],
accumulator=AccumulatorState(
text=new_text,
link_offsets={**accumulator.link_offsets, offset: section_link},
),
)
# Doesn't fit — flush buffer and restart with this section
return SectionChunkerOutput(
payloads=accumulator.flush_to_list(),
accumulator=AccumulatorState(
text=section_text,
link_offsets={0: section_link},
),
)
def _handle_oversized_section(
self,
section_text: str,
section_link: str,
accumulator: AccumulatorState,
content_token_limit: int,
) -> SectionChunkerOutput:
payloads = accumulator.flush_to_list()
split_texts = cast(list[str], self.chunk_splitter.chunk(section_text))
for i, split_text in enumerate(split_texts):
if (
STRICT_CHUNK_TOKEN_LIMIT
and count_tokens(split_text, self.tokenizer) > content_token_limit
):
smaller_chunks = self._split_oversized_chunk(
split_text, content_token_limit
)
for j, small_chunk in enumerate(smaller_chunks):
payloads.append(
ChunkPayload(
text=small_chunk,
links={0: section_link},
is_continuation=(j != 0),
)
)
else:
payloads.append(
ChunkPayload(
text=split_text,
links={0: section_link},
is_continuation=(i != 0),
)
)
return SectionChunkerOutput(
payloads=payloads,
accumulator=AccumulatorState(),
)
def _split_oversized_chunk(self, text: str, content_token_limit: int) -> list[str]:
tokens = self.tokenizer.tokenize(text)
chunks: list[str] = []
start = 0
total_tokens = len(tokens)
while start < total_tokens:
end = min(start + content_token_limit, total_tokens)
token_chunk = tokens[start:end]
chunk_text = " ".join(token_chunk)
chunks.append(chunk_text)
start = end
return chunks

View File

@@ -542,7 +542,6 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
**document.model_dump(),
processed_sections=[
Section(
type=section.type,
text=section.text if isinstance(section, TextSection) else "",
link=section.link,
image_file_id=(
@@ -567,7 +566,6 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
if isinstance(section, ImageSection):
# Default section with image path preserved - ensure text is always a string
processed_section = Section(
type=section.type,
link=section.link,
image_file_id=section.image_file_id,
text="", # Initialize with empty string
@@ -611,7 +609,6 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
# For TextSection, create a base Section with text and link
elif isinstance(section, TextSection):
processed_section = Section(
type=section.type,
text=section.text or "", # Ensure text is always a string, not None
link=section.link,
image_file_id=None,

View File

@@ -96,9 +96,6 @@ from onyx.server.features.persona.api import admin_router as admin_persona_route
from onyx.server.features.persona.api import agents_router
from onyx.server.features.persona.api import basic_router as persona_router
from onyx.server.features.projects.api import router as projects_router
from onyx.server.features.proposal_review.api.api import (
router as proposal_review_router,
)
from onyx.server.features.tool.api import admin_router as admin_tool_router
from onyx.server.features.tool.api import router as tool_router
from onyx.server.features.user_oauth_token.api import router as user_oauth_token_router
@@ -472,7 +469,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, projects_router)
include_router_with_global_prefix_prepended(application, public_build_router)
include_router_with_global_prefix_prepended(application, build_router)
include_router_with_global_prefix_prepended(application, proposal_review_router)
include_router_with_global_prefix_prepended(application, document_set_router)
include_router_with_global_prefix_prepended(application, hierarchy_router)
include_router_with_global_prefix_prepended(application, search_settings_router)

View File

@@ -15,7 +15,7 @@
"date-fns": "^4.1.0",
"embla-carousel-react": "^8.6.0",
"lucide-react": "^0.562.0",
"next": "16.2.3",
"next": "16.1.7",
"next-themes": "^0.4.6",
"radix-ui": "^1.4.3",
"react": "19.2.3",
@@ -961,9 +961,9 @@
"license": "MIT"
},
"node_modules/@hono/node-server": {
"version": "1.19.13",
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.13.tgz",
"integrity": "sha512-TsQLe4i2gvoTtrHje625ngThGBySOgSK3Xo2XRYOdqGN1teR8+I7vchQC46uLJi8OF62YTYA3AhSpumtkhsaKQ==",
"version": "1.19.10",
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.10.tgz",
"integrity": "sha512-hZ7nOssGqRgyV3FVVQdfi+U4q02uB23bpnYpdvNXkYTRRyWx84b7yf1ans+dnJ/7h41sGL3CeQTfO+ZGxuO+Iw==",
"license": "MIT",
"engines": {
"node": ">=18.14.1"
@@ -1711,9 +1711,9 @@
}
},
"node_modules/@next/env": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.2.3.tgz",
"integrity": "sha512-ZWXyj4uNu4GCWQw9cjRxWlbD+33mcDszIo9iQxFnBX3Wmgq9ulaSJcl6VhuWx5pCWqqD+9W6Wfz7N0lM5lYPMA==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.1.7.tgz",
"integrity": "sha512-rJJbIdJB/RQr2F1nylZr/PJzamvNNhfr3brdKP6s/GW850jbtR70QlSfFselvIBbcPUOlQwBakexjFzqLzF6pg==",
"license": "MIT"
},
"node_modules/@next/eslint-plugin-next": {
@@ -1727,9 +1727,9 @@
}
},
"node_modules/@next/swc-darwin-arm64": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.2.3.tgz",
"integrity": "sha512-u37KDKTKQ+OQLvY+z7SNXixwo4Q2/IAJFDzU1fYe66IbCE51aDSAzkNDkWmLN0yjTUh4BKBd+hb69jYn6qqqSg==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.1.7.tgz",
"integrity": "sha512-b2wWIE8sABdyafc4IM8r5Y/dS6kD80JRtOGrUiKTsACFQfWWgUQ2NwoUX1yjFMXVsAwcQeNpnucF2ZrujsBBPg==",
"cpu": [
"arm64"
],
@@ -1743,9 +1743,9 @@
}
},
"node_modules/@next/swc-darwin-x64": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.2.3.tgz",
"integrity": "sha512-gHjL/qy6Q6CG3176FWbAKyKh9IfntKZTB3RY/YOJdDFpHGsUDXVH38U4mMNpHVGXmeYW4wj22dMp1lTfmu/bTQ==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.1.7.tgz",
"integrity": "sha512-zcnVaaZulS1WL0Ss38R5Q6D2gz7MtBu8GZLPfK+73D/hp4GFMrC2sudLky1QibfV7h6RJBJs/gOFvYP0X7UVlQ==",
"cpu": [
"x64"
],
@@ -1759,9 +1759,9 @@
}
},
"node_modules/@next/swc-linux-arm64-gnu": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.2.3.tgz",
"integrity": "sha512-U6vtblPtU/P14Y/b/n9ZY0GOxbbIhTFuaFR7F4/uMBidCi2nSdaOFhA0Go81L61Zd6527+yvuX44T4ksnf8T+Q==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.1.7.tgz",
"integrity": "sha512-2ant89Lux/Q3VyC8vNVg7uBaFVP9SwoK2jJOOR0L8TQnX8CAYnh4uctAScy2Hwj2dgjVHqHLORQZJ2wH6VxhSQ==",
"cpu": [
"arm64"
],
@@ -1775,9 +1775,9 @@
}
},
"node_modules/@next/swc-linux-arm64-musl": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.2.3.tgz",
"integrity": "sha512-/YV0LgjHUmfhQpn9bVoGc4x4nan64pkhWR5wyEV8yCOfwwrH630KpvRg86olQHTwHIn1z59uh6JwKvHq1h4QEw==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.1.7.tgz",
"integrity": "sha512-uufcze7LYv0FQg9GnNeZ3/whYfo+1Q3HnQpm16o6Uyi0OVzLlk2ZWoY7j07KADZFY8qwDbsmFnMQP3p3+Ftprw==",
"cpu": [
"arm64"
],
@@ -1791,9 +1791,9 @@
}
},
"node_modules/@next/swc-linux-x64-gnu": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.2.3.tgz",
"integrity": "sha512-/HiWEcp+WMZ7VajuiMEFGZ6cg0+aYZPqCJD3YJEfpVWQsKYSjXQG06vJP6F1rdA03COD9Fef4aODs3YxKx+RDQ==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.1.7.tgz",
"integrity": "sha512-KWVf2gxYvHtvuT+c4MBOGxuse5TD7DsMFYSxVxRBnOzok/xryNeQSjXgxSv9QpIVlaGzEn/pIuI6Koosx8CGWA==",
"cpu": [
"x64"
],
@@ -1807,9 +1807,9 @@
}
},
"node_modules/@next/swc-linux-x64-musl": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.2.3.tgz",
"integrity": "sha512-Kt44hGJfZSefebhk/7nIdivoDr3Ugp5+oNz9VvF3GUtfxutucUIHfIO0ZYO8QlOPDQloUVQn4NVC/9JvHRk9hw==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.1.7.tgz",
"integrity": "sha512-HguhaGwsGr1YAGs68uRKc4aGWxLET+NevJskOcCAwXbwj0fYX0RgZW2gsOCzr9S11CSQPIkxmoSbuVaBp4Z3dA==",
"cpu": [
"x64"
],
@@ -1823,9 +1823,9 @@
}
},
"node_modules/@next/swc-win32-arm64-msvc": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.2.3.tgz",
"integrity": "sha512-O2NZ9ie3Tq6xj5Z5CSwBT3+aWAMW2PIZ4egUi9MaWLkwaehgtB7YZjPm+UpcNpKOme0IQuqDcor7BsW6QBiQBw==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.1.7.tgz",
"integrity": "sha512-S0n3KrDJokKTeFyM/vGGGR8+pCmXYrjNTk2ZozOL1C/JFdfUIL9O1ATaJOl5r2POe56iRChbsszrjMAdWSv7kQ==",
"cpu": [
"arm64"
],
@@ -1839,9 +1839,9 @@
}
},
"node_modules/@next/swc-win32-x64-msvc": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.2.3.tgz",
"integrity": "sha512-Ibm29/GgB/ab5n7XKqlStkm54qqZE8v2FnijUPBgrd67FWrac45o/RsNlaOWjme/B5UqeWt/8KM4aWBwA1D2Kw==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.1.7.tgz",
"integrity": "sha512-mwgtg8CNZGYm06LeEd+bNnOUfwOyNem/rOiP14Lsz+AnUY92Zq/LXwtebtUiaeVkhbroRCQ0c8GlR4UT1U+0yg==",
"cpu": [
"x64"
],
@@ -7427,9 +7427,9 @@
}
},
"node_modules/hono": {
"version": "4.12.12",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.12.tgz",
"integrity": "sha512-p1JfQMKaceuCbpJKAPKVqyqviZdS0eUxH9v82oWo1kb9xjQ5wA6iP3FNVAPDFlz5/p7d45lO+BpSk1tuSZMF4Q==",
"version": "4.12.7",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.7.tgz",
"integrity": "sha512-jq9l1DM0zVIvsm3lv9Nw9nlJnMNPOcAtsbsgiUhWcFzPE99Gvo6yRTlszSLLYacMeQ6quHD6hMfId8crVHvexw==",
"license": "MIT",
"engines": {
"node": ">=16.9.0"
@@ -8637,9 +8637,9 @@
}
},
"node_modules/lodash": {
"version": "4.18.1",
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.18.1.tgz",
"integrity": "sha512-dMInicTPVE8d1e5otfwmmjlxkZoUpiVLwyeTdUsi/Caj/gfzzblBcCE5sRHV/AsjuCmxWrte2TNGSYuCeCq+0Q==",
"version": "4.17.23",
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.23.tgz",
"integrity": "sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==",
"license": "MIT"
},
"node_modules/lodash.merge": {
@@ -8978,12 +8978,12 @@
}
},
"node_modules/next": {
"version": "16.2.3",
"resolved": "https://registry.npmjs.org/next/-/next-16.2.3.tgz",
"integrity": "sha512-9V3zV4oZFza3PVev5/poB9g0dEafVcgNyQ8eTRop8GvxZjV2G15FC5ARuG1eFD42QgeYkzJBJzHghNP8Ad9xtA==",
"version": "16.1.7",
"resolved": "https://registry.npmjs.org/next/-/next-16.1.7.tgz",
"integrity": "sha512-WM0L7WrSvKwoLegLYr6V+mz+RIofqQgVAfHhMp9a88ms0cFX8iX9ew+snpWlSBwpkURJOUdvCEt3uLl3NNzvWg==",
"license": "MIT",
"dependencies": {
"@next/env": "16.2.3",
"@next/env": "16.1.7",
"@swc/helpers": "0.5.15",
"baseline-browser-mapping": "^2.9.19",
"caniuse-lite": "^1.0.30001579",
@@ -8997,15 +8997,15 @@
"node": ">=20.9.0"
},
"optionalDependencies": {
"@next/swc-darwin-arm64": "16.2.3",
"@next/swc-darwin-x64": "16.2.3",
"@next/swc-linux-arm64-gnu": "16.2.3",
"@next/swc-linux-arm64-musl": "16.2.3",
"@next/swc-linux-x64-gnu": "16.2.3",
"@next/swc-linux-x64-musl": "16.2.3",
"@next/swc-win32-arm64-msvc": "16.2.3",
"@next/swc-win32-x64-msvc": "16.2.3",
"sharp": "^0.34.5"
"@next/swc-darwin-arm64": "16.1.7",
"@next/swc-darwin-x64": "16.1.7",
"@next/swc-linux-arm64-gnu": "16.1.7",
"@next/swc-linux-arm64-musl": "16.1.7",
"@next/swc-linux-x64-gnu": "16.1.7",
"@next/swc-linux-x64-musl": "16.1.7",
"@next/swc-win32-arm64-msvc": "16.1.7",
"@next/swc-win32-x64-msvc": "16.1.7",
"sharp": "^0.34.4"
},
"peerDependencies": {
"@opentelemetry/api": "^1.1.0",

View File

@@ -16,7 +16,7 @@
"date-fns": "^4.1.0",
"embla-carousel-react": "^8.6.0",
"lucide-react": "^0.562.0",
"next": "16.2.3",
"next": "16.1.7",
"next-themes": "^0.4.6",
"radix-ui": "^1.4.3",
"react": "19.2.3",

View File

@@ -63,7 +63,6 @@ class DocumentSetCreationRequest(BaseModel):
class DocumentSetUpdateRequest(BaseModel):
id: int
name: str
description: str
cc_pair_ids: list[int]
is_public: bool

View File

@@ -1,39 +0,0 @@
"""Main router for Proposal Review.
Mounts all sub-routers under the /proposal-review prefix.
"""
from fastapi import APIRouter
from fastapi import Depends
from onyx.auth.permissions import require_permission
from onyx.db.enums import Permission
from onyx.server.features.proposal_review.configs import ENABLE_PROPOSAL_REVIEW
router = APIRouter(
prefix="/proposal-review",
dependencies=[Depends(require_permission(Permission.BASIC_ACCESS))],
)
if ENABLE_PROPOSAL_REVIEW:
from onyx.server.features.proposal_review.api.config_api import (
router as config_router,
)
from onyx.server.features.proposal_review.api.decisions_api import (
router as decisions_router,
)
from onyx.server.features.proposal_review.api.proposals_api import (
router as proposals_router,
)
from onyx.server.features.proposal_review.api.review_api import (
router as review_router,
)
from onyx.server.features.proposal_review.api.rulesets_api import (
router as rulesets_router,
)
router.include_router(rulesets_router, tags=["proposal-review"])
router.include_router(proposals_router, tags=["proposal-review"])
router.include_router(review_router, tags=["proposal-review"])
router.include_router(decisions_router, tags=["proposal-review"])
router.include_router(config_router, tags=["proposal-review"])
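Because the sub-router imports and include_router calls are gated on ENABLE_PROPOSAL_REVIEW, a disabled deployment mounts an empty router and every /proposal-review route simply 404s. A minimal, self-contained sketch of that feature-flag pattern follows; the flag, dependency, and route names are illustrative, not taken from the repo:

from fastapi import APIRouter, Depends, FastAPI

FEATURE_ENABLED = False  # stand-in for an env-driven flag such as ENABLE_PROPOSAL_REVIEW

def require_basic_access() -> None:
    # Stand-in for require_permission(Permission.BASIC_ACCESS).
    return None

router = APIRouter(
    prefix="/example-feature",
    dependencies=[Depends(require_basic_access)],
)

if FEATURE_ENABLED:
    sub_router = APIRouter()

    @sub_router.get("/ping")
    def ping() -> dict[str, str]:
        return {"status": "ok"}

    router.include_router(sub_router, tags=["example-feature"])

app = FastAPI()
app.include_router(router)  # with the flag off, /example-feature/* returns 404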

View File

@@ -1,89 +0,0 @@
"""API endpoints for tenant configuration."""
from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.orm import Session
from onyx.auth.permissions import require_permission
from onyx.configs.constants import DocumentSource
from onyx.db.connector import fetch_connectors
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.server.features.proposal_review.api.models import ConfigResponse
from onyx.server.features.proposal_review.api.models import ConfigUpdate
from onyx.server.features.proposal_review.api.models import JiraConnectorInfo
from onyx.server.features.proposal_review.db import config as config_db
from shared_configs.contextvars import get_current_tenant_id
router = APIRouter()
@router.get("/config")
def get_config(
_user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
db_session: Session = Depends(get_session),
) -> ConfigResponse:
"""Get the tenant's proposal review configuration."""
tenant_id = get_current_tenant_id()
config = config_db.get_config(tenant_id, db_session)
if not config:
# Return a default empty config rather than 404
config = config_db.upsert_config(tenant_id, db_session)
db_session.commit()
return ConfigResponse.from_model(config)
@router.put("/config")
def update_config(
request: ConfigUpdate,
_user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS)),
db_session: Session = Depends(get_session),
) -> ConfigResponse:
"""Update the tenant's proposal review configuration."""
tenant_id = get_current_tenant_id()
config = config_db.upsert_config(
tenant_id=tenant_id,
jira_connector_id=request.jira_connector_id,
jira_project_key=request.jira_project_key,
field_mapping=request.field_mapping,
jira_writeback=request.jira_writeback,
review_model=request.review_model,
import_model=request.import_model,
db_session=db_session,
)
db_session.commit()
return ConfigResponse.from_model(config)
@router.get("/jira-connectors")
def list_jira_connectors(
_user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
db_session: Session = Depends(get_session),
) -> list[JiraConnectorInfo]:
"""List all Jira connectors available to this tenant."""
connectors = fetch_connectors(db_session, sources=[DocumentSource.JIRA])
results: list[JiraConnectorInfo] = []
for c in connectors:
cfg = c.connector_specific_config or {}
project_key = cfg.get("project_key", "")
base_url = cfg.get("jira_base_url", "")
results.append(
JiraConnectorInfo(
id=c.id,
name=c.name,
project_key=project_key,
project_url=base_url,
)
)
return results
@router.get("/jira-connectors/{connector_id}/metadata-keys")
def get_connector_metadata_keys(
connector_id: int,
_user: User = Depends(require_permission(Permission.BASIC_ACCESS)), # noqa: ARG001
db_session: Session = Depends(get_session),
) -> list[str]:
"""Return the distinct doc_metadata keys across all documents for a connector."""
return config_db.get_connector_metadata_keys(connector_id, db_session)
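get_config above never returns 404: when no row exists it upserts an empty config and commits, so the frontend always has something to render. A minimal sketch of that get-or-create shape, using a plain dict as stand-in storage rather than the repo's DB layer:

def get_or_create_config(store: dict[str, dict], tenant_id: str) -> dict:
    # Stand-in for config_db.get_config / upsert_config against a real session.
    config = store.get(tenant_id)
    if config is None:
        config = {"jira_connector_id": None, "jira_project_key": None}
        store[tenant_id] = config  # the real endpoint commits the new row here
    return config

store: dict[str, dict] = {}
print(get_or_create_config(store, "tenant-a"))  # created on first access
print(get_or_create_config(store, "tenant-a"))  # same config returned thereafter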

View File

@@ -1,147 +0,0 @@
"""API endpoints for per-finding decisions, proposal decisions, and Jira sync."""
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.orm import Session
from onyx.auth.permissions import require_permission
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.server.features.proposal_review.api.models import FindingDecisionCreate
from onyx.server.features.proposal_review.api.models import FindingResponse
from onyx.server.features.proposal_review.api.models import JiraSyncResponse
from onyx.server.features.proposal_review.api.models import ProposalDecisionCreate
from onyx.server.features.proposal_review.api.models import ProposalDecisionResponse
from onyx.server.features.proposal_review.db import decisions as decisions_db
from onyx.server.features.proposal_review.db import findings as findings_db
from onyx.server.features.proposal_review.db import proposals as proposals_db
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter()
@router.post(
"/findings/{finding_id}/decision",
)
def record_finding_decision(
finding_id: UUID,
request: FindingDecisionCreate,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
db_session: Session = Depends(get_session),
) -> FindingResponse:
"""Record or update a decision on a finding (upsert)."""
tenant_id = get_current_tenant_id()
# Verify finding exists
finding = findings_db.get_finding(finding_id, db_session)
if not finding:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Finding not found")
# Verify the finding's proposal belongs to the current tenant
proposal = proposals_db.get_proposal(finding.proposal_id, tenant_id, db_session)
if not proposal:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Finding not found")
finding = decisions_db.upsert_finding_decision(
finding_id=finding_id,
officer_id=user.id,
action=request.action,
notes=request.notes,
db_session=db_session,
)
db_session.commit()
return FindingResponse.from_model(finding)
@router.post(
"/proposals/{proposal_id}/decision",
status_code=201,
)
def record_proposal_decision(
proposal_id: UUID,
request: ProposalDecisionCreate,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
db_session: Session = Depends(get_session),
) -> ProposalDecisionResponse:
"""Record a final decision on a proposal."""
tenant_id = get_current_tenant_id()
proposal = proposals_db.get_proposal(proposal_id, tenant_id, db_session)
if not proposal:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")
# Validate decision value
valid_decisions = {"APPROVED", "CHANGES_REQUESTED", "REJECTED"}
if request.decision not in valid_decisions:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"decision must be APPROVED, CHANGES_REQUESTED, or REJECTED",
)
proposal = decisions_db.update_proposal_decision(
proposal_id=proposal_id,
tenant_id=tenant_id,
officer_id=user.id,
decision=request.decision,
notes=request.notes,
db_session=db_session,
)
db_session.commit()
return ProposalDecisionResponse.from_proposal(proposal)
@router.post(
"/proposals/{proposal_id}/sync-jira",
)
def sync_jira(
proposal_id: UUID,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)), # noqa: ARG001
db_session: Session = Depends(get_session),
) -> JiraSyncResponse:
"""Sync the latest proposal decision to Jira.
Dispatches a Celery task that performs 3 Jira API operations:
1. Update custom fields (decision, completion %)
2. Transition the issue to the appropriate column
3. Post a structured review summary comment
"""
tenant_id = get_current_tenant_id()
proposal = proposals_db.get_proposal(proposal_id, tenant_id, db_session)
if not proposal:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")
if not proposal.decision_at:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"No decision to sync -- record a proposal decision first",
)
if proposal.jira_synced:
return JiraSyncResponse(
success=True,
message="Decision already synced to Jira",
)
# Dispatch Celery task via the client app (has Redis broker configured)
from onyx.background.celery.versioned_apps.client import app as celery_app
celery_app.send_task(
"sync_decision_to_jira",
args=[str(proposal_id), tenant_id],
expires=300,
)
db_session.commit()
return JiraSyncResponse(
success=True,
message="Jira sync task dispatched",
)
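The endpoint only dispatches by task name; the worker-side handler is not part of this diff. A hypothetical skeleton of a task registered under that name (the body is an illustrative assumption, not the repo's implementation):

from celery import shared_task

@shared_task(name="sync_decision_to_jira", bind=True)
def sync_decision_to_jira_task(self, proposal_id: str, tenant_id: str) -> None:
    # Illustrative outline only: load the proposal in the tenant's schema, then
    # perform the three Jira operations described in the endpoint docstring --
    # update custom fields, transition the issue, post a review summary comment.
    ...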

View File

@@ -1,483 +0,0 @@
"""Pydantic request/response models for Proposal Review."""
from datetime import datetime
from typing import Any
from typing import Literal
from uuid import UUID
from pydantic import BaseModel
from onyx.server.features.proposal_review.db.models import ProposalReviewConfig
from onyx.server.features.proposal_review.db.models import ProposalReviewDocument
from onyx.server.features.proposal_review.db.models import ProposalReviewFinding
from onyx.server.features.proposal_review.db.models import ProposalReviewImportJob
from onyx.server.features.proposal_review.db.models import ProposalReviewProposal
from onyx.server.features.proposal_review.db.models import ProposalReviewRule
from onyx.server.features.proposal_review.db.models import ProposalReviewRuleset
from onyx.server.features.proposal_review.db.models import ProposalReviewRun
# =============================================================================
# Ruleset Schemas
# =============================================================================
class RulesetCreate(BaseModel):
name: str
description: str | None = None
is_default: bool = False
class RulesetUpdate(BaseModel):
name: str | None = None
description: str | None = None
is_default: bool | None = None
is_active: bool | None = None
class RulesetResponse(BaseModel):
id: UUID
tenant_id: str
name: str
description: str | None
is_default: bool
is_active: bool
created_by: UUID | None
created_at: datetime
updated_at: datetime
rules: list["RuleResponse"] = []
@classmethod
def from_model(
cls,
ruleset: ProposalReviewRuleset,
include_rules: bool = True,
) -> "RulesetResponse":
return cls(
id=ruleset.id,
tenant_id=ruleset.tenant_id,
name=ruleset.name,
description=ruleset.description,
is_default=ruleset.is_default,
is_active=ruleset.is_active,
created_by=ruleset.created_by,
created_at=ruleset.created_at,
updated_at=ruleset.updated_at,
rules=(
[RuleResponse.from_model(r) for r in ruleset.rules]
if include_rules
else []
),
)
# =============================================================================
# Rule Schemas
# =============================================================================
class RuleCreate(BaseModel):
name: str
description: str | None = None
category: str | None = None
rule_type: Literal[
"DOCUMENT_CHECK", "METADATA_CHECK", "CROSS_REFERENCE", "CUSTOM_NL"
]
rule_intent: Literal["CHECK", "HIGHLIGHT"] = "CHECK"
prompt_template: str
source: Literal["IMPORTED", "MANUAL"] = "MANUAL"
authority: Literal["OVERRIDE", "RETURN"] | None = None
is_hard_stop: bool = False
priority: int = 0
class RuleUpdate(BaseModel):
name: str | None = None
description: str | None = None
category: str | None = None
rule_type: str | None = None
rule_intent: str | None = None
prompt_template: str | None = None
authority: str | None = None
is_hard_stop: bool | None = None
priority: int | None = None
is_active: bool | None = None
refinement_needed: bool | None = None
refinement_question: str | None = None
class RuleRefinementRequest(BaseModel):
answer: str
class RuleResponse(BaseModel):
id: UUID
ruleset_id: UUID
name: str
description: str | None
category: str | None
rule_type: str
rule_intent: str
prompt_template: str
source: str
authority: str | None
is_hard_stop: bool
priority: int
is_active: bool
refinement_needed: bool
refinement_question: str | None
created_at: datetime
updated_at: datetime
@classmethod
def from_model(cls, rule: ProposalReviewRule) -> "RuleResponse":
return cls(
id=rule.id,
ruleset_id=rule.ruleset_id,
name=rule.name,
description=rule.description,
category=rule.category,
rule_type=rule.rule_type,
rule_intent=rule.rule_intent,
prompt_template=rule.prompt_template,
source=rule.source,
authority=rule.authority,
is_hard_stop=rule.is_hard_stop,
priority=rule.priority,
is_active=rule.is_active,
refinement_needed=rule.refinement_needed,
refinement_question=rule.refinement_question,
created_at=rule.created_at,
updated_at=rule.updated_at,
)
class BulkRuleUpdateRequest(BaseModel):
"""Batch activate/deactivate/delete rules."""
action: Literal["activate", "deactivate", "delete"]
rule_ids: list[UUID]
class BulkRuleUpdateResponse(BaseModel):
updated_count: int
class RuleTestResponse(BaseModel):
rule_id: str
success: bool
error: str | None = None
result: dict[str, Any] | None = None
# =============================================================================
# Proposal Schemas
# =============================================================================
class ProposalResponse(BaseModel):
"""Proposal response including inline decision fields."""
id: UUID
document_id: str
tenant_id: str
status: str
# Inline decision fields
decision_notes: str | None = None
decision_officer_id: UUID | None = None
decision_at: datetime | None = None
jira_synced: bool = False
jira_synced_at: datetime | None = None
created_at: datetime
updated_at: datetime
# Resolved metadata from Document table via field_mapping
metadata: dict[str, Any] = {}
@classmethod
def from_model(
cls,
proposal: ProposalReviewProposal,
metadata: dict[str, Any] | None = None,
) -> "ProposalResponse":
return cls(
id=proposal.id,
document_id=proposal.document_id,
tenant_id=proposal.tenant_id,
status=proposal.status,
decision_notes=proposal.decision_notes,
decision_officer_id=proposal.decision_officer_id,
decision_at=proposal.decision_at,
jira_synced=proposal.jira_synced,
jira_synced_at=proposal.jira_synced_at,
created_at=proposal.created_at,
updated_at=proposal.updated_at,
metadata=metadata or {},
)
class ProposalListResponse(BaseModel):
proposals: list[ProposalResponse]
total_count: int
config_missing: bool = False # True when no config exists
# =============================================================================
# Review Run Schemas
# =============================================================================
class ReviewRunTriggerRequest(BaseModel):
ruleset_id: UUID
class ReviewRunResponse(BaseModel):
id: UUID
proposal_id: UUID
ruleset_id: UUID
triggered_by: UUID
status: str
total_rules: int
completed_rules: int
failed_rules: int
started_at: datetime | None
completed_at: datetime | None
created_at: datetime
@classmethod
def from_model(cls, run: ProposalReviewRun) -> "ReviewRunResponse":
return cls(
id=run.id,
proposal_id=run.proposal_id,
ruleset_id=run.ruleset_id,
triggered_by=run.triggered_by,
status=run.status,
total_rules=run.total_rules,
completed_rules=run.completed_rules,
failed_rules=run.failed_rules,
started_at=run.started_at,
completed_at=run.completed_at,
created_at=run.created_at,
)
# =============================================================================
# Finding Schemas
# =============================================================================
class FindingResponse(BaseModel):
id: UUID
proposal_id: UUID
rule_id: UUID
review_run_id: UUID
verdict: str
confidence: str | None
evidence: str | None
explanation: str | None
suggested_action: str | None
llm_model: str | None
llm_tokens_used: int | None
created_at: datetime
# Nested rule info for display
rule_name: str | None = None
rule_category: str | None = None
rule_is_hard_stop: bool | None = None
# Inline decision fields
decision_action: str | None = None
decision_notes: str | None = None
decided_at: datetime | None = None
@classmethod
def from_model(cls, finding: ProposalReviewFinding) -> "FindingResponse":
rule_name = None
rule_category = None
rule_is_hard_stop = None
if finding.rule is not None:
rule_name = finding.rule.name
rule_category = finding.rule.category
rule_is_hard_stop = finding.rule.is_hard_stop
return cls(
id=finding.id,
proposal_id=finding.proposal_id,
rule_id=finding.rule_id,
review_run_id=finding.review_run_id,
verdict=finding.verdict,
confidence=finding.confidence,
evidence=finding.evidence,
explanation=finding.explanation,
suggested_action=finding.suggested_action,
llm_model=finding.llm_model,
llm_tokens_used=finding.llm_tokens_used,
created_at=finding.created_at,
rule_name=rule_name,
rule_category=rule_category,
rule_is_hard_stop=rule_is_hard_stop,
decision_action=finding.decision_action,
decision_notes=finding.decision_notes,
decided_at=finding.decided_at,
)
# =============================================================================
# Decision Schemas
# =============================================================================
class FindingDecisionCreate(BaseModel):
action: Literal["VERIFIED", "ISSUE", "NOT_APPLICABLE", "OVERRIDDEN"]
notes: str | None = None
class ProposalDecisionCreate(BaseModel):
decision: Literal["APPROVED", "CHANGES_REQUESTED", "REJECTED"]
notes: str | None = None
class ProposalDecisionResponse(BaseModel):
"""Response after recording a proposal-level decision."""
proposal_id: UUID
status: str
decision_notes: str | None
jira_synced: bool
decision_at: datetime | None
@classmethod
def from_proposal(
cls, proposal: ProposalReviewProposal
) -> "ProposalDecisionResponse":
return cls(
proposal_id=proposal.id,
status=proposal.status,
decision_notes=proposal.decision_notes,
jira_synced=proposal.jira_synced,
decision_at=proposal.decision_at,
)
# =============================================================================
# Config Schemas
# =============================================================================
class ConfigUpdate(BaseModel):
jira_connector_id: int | None = None
jira_project_key: str | None = None
field_mapping: list[str] | None = None # List of visible metadata keys
jira_writeback: dict[str, Any] | None = None
# LLM configuration
review_model: str | None = None # model name for rule evaluation
import_model: str | None = None # model name for checklist import
class ConfigResponse(BaseModel):
id: UUID
tenant_id: str
jira_connector_id: int | None
jira_project_key: str | None
field_mapping: list[str] | None
jira_writeback: dict[str, Any] | None
review_model: str | None
import_model: str | None
created_at: datetime
updated_at: datetime
@classmethod
def from_model(cls, config: ProposalReviewConfig) -> "ConfigResponse":
return cls(
id=config.id,
tenant_id=config.tenant_id,
jira_connector_id=config.jira_connector_id,
jira_project_key=config.jira_project_key,
field_mapping=config.field_mapping,
jira_writeback=config.jira_writeback,
review_model=config.review_model,
import_model=config.import_model,
created_at=config.created_at,
updated_at=config.updated_at,
)
# =============================================================================
# Import Schemas
# =============================================================================
class ImportResponse(BaseModel):
rules_created: int
rules: list[RuleResponse]
class ImportJobResponse(BaseModel):
id: UUID
status: str
source_filename: str
rules_created: int
error_message: str | None
created_at: datetime
completed_at: datetime | None
@classmethod
def from_model(cls, job: ProposalReviewImportJob) -> "ImportJobResponse":
return cls(
id=job.id,
status=job.status,
source_filename=job.source_filename,
rules_created=job.rules_created,
error_message=job.error_message,
created_at=job.created_at,
completed_at=job.completed_at,
)
# =============================================================================
# Document Schemas
# =============================================================================
class ProposalDocumentResponse(BaseModel):
id: UUID
proposal_id: UUID
file_name: str
file_type: str | None
document_role: str
uploaded_by: UUID | None
extracted_text: str | None = None
created_at: datetime
@classmethod
def from_model(cls, doc: ProposalReviewDocument) -> "ProposalDocumentResponse":
return cls(
id=doc.id,
proposal_id=doc.proposal_id,
file_name=doc.file_name,
file_type=doc.file_type,
document_role=doc.document_role,
uploaded_by=doc.uploaded_by,
extracted_text=getattr(doc, "extracted_text", None),
created_at=doc.created_at,
)
# =============================================================================
# Jira Sync Schemas
# =============================================================================
class JiraSyncResponse(BaseModel):
success: bool
message: str
# =============================================================================
# Jira Connector Discovery Schemas
# =============================================================================
class JiraConnectorInfo(BaseModel):
id: int
name: str
project_key: str
project_url: str

View File

@@ -1,367 +0,0 @@
"""API endpoints for proposals and proposal documents."""
import io
from typing import Any
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Form
from fastapi import UploadFile
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.permissions import require_permission
from onyx.configs.constants import DocumentSource
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.models import Connector
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import User
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.server.features.proposal_review.api.models import ProposalDocumentResponse
from onyx.server.features.proposal_review.api.models import ProposalListResponse
from onyx.server.features.proposal_review.api.models import ProposalResponse
from onyx.server.features.proposal_review.configs import (
DOCUMENT_UPLOAD_MAX_FILE_SIZE_BYTES,
)
from onyx.server.features.proposal_review.db import config as config_db
from onyx.server.features.proposal_review.db import proposals as proposals_db
from onyx.server.features.proposal_review.db.models import ProposalReviewDocument
from onyx.server.features.proposal_review.db.models import ProposalReviewProposal
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter()
def _resolve_document_metadata(
document: Document,
visible_fields: list[str] | None,
) -> dict[str, Any]:
"""Resolve metadata from a Document's tags, filtered to visible fields.
Jira custom fields are stored as Tag rows (tag_key / tag_value)
linked to the document via document__tag. visible_fields selects
which tag keys to include. If None/empty, returns all tags.
"""
# Build metadata from the document's tags
raw_metadata: dict[str, Any] = {}
for tag in document.tags:
key = tag.tag_key
value = tag.tag_value
# Tags with is_list=True can have multiple values for the same key
if tag.is_list:
raw_metadata.setdefault(key, [])
raw_metadata[key].append(value)
else:
raw_metadata[key] = value
# Extract jira_key from tags and clean title from semantic_id.
# Jira semantic_id is "KEY-123: Summary Text" — split to isolate each.
jira_key = raw_metadata.get("key", "")
title = document.semantic_id or ""
if title and ": " in title:
title = title.split(": ", 1)[1]
raw_metadata["jira_key"] = jira_key
raw_metadata["title"] = title
raw_metadata["link"] = document.link
if not visible_fields:
return raw_metadata
# Filter to only the selected fields, plus always include core fields
# that the frontend needs for navigation, display, and filtering.
resolved: dict[str, Any] = {
"jira_key": raw_metadata.get("jira_key"),
"title": raw_metadata.get("title"),
"link": raw_metadata.get("link"),
"status": raw_metadata.get("status"),
}
for key in visible_fields:
if key in raw_metadata:
resolved[key] = raw_metadata[key]
return resolved
@router.get("/proposals")
def list_proposals(
status: str | None = None,
limit: int = 100,
offset: int = 0,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)), # noqa: ARG001
db_session: Session = Depends(get_session),
) -> ProposalListResponse:
"""List proposals.
This queries the Document table filtered by the configured Jira project,
LEFT JOINs proposal_review_proposal for review state, and resolves
metadata field names via the field_mapping config.
Documents without a proposal record are lazily assigned one (status PENDING)
so the frontend gets a stable UUID to use for follow-up calls.
"""
tenant_id = get_current_tenant_id()
# Get config for field mapping and Jira project filtering
config = config_db.get_config(tenant_id, db_session)
# When no config exists, return an empty list with a hint for the frontend.
# The frontend can show "Configure a Jira connector in Settings to see proposals."
if config is None:
return ProposalListResponse(
proposals=[],
total_count=0,
config_missing=True,
)
visible_fields = config.field_mapping
# Query documents from the configured Jira connector only,
# LEFT JOIN proposal state for review tracking.
# NOTE: Tenant isolation is handled at the schema level (schema-per-tenant).
# The DB session is already scoped to the current tenant's schema, so
# cross-tenant data leakage is prevented by the connection itself.
query = (
db_session.query(Document, ProposalReviewProposal)
.outerjoin(
ProposalReviewProposal,
Document.id == ProposalReviewProposal.document_id,
)
.options(selectinload(Document.tags))
)
# Filter to only documents from the configured Jira connector
if config and config.jira_connector_id:
# Join through DocumentByConnectorCredentialPair to filter by connector
query = query.join(
DocumentByConnectorCredentialPair,
Document.id == DocumentByConnectorCredentialPair.id,
).filter(
DocumentByConnectorCredentialPair.connector_id == config.jira_connector_id,
)
else:
# No connector configured — filter to Jira source connectors only
# to avoid showing Slack/GitHub/etc documents
query = (
query.join(
DocumentByConnectorCredentialPair,
Document.id == DocumentByConnectorCredentialPair.id,
)
.join(
Connector,
DocumentByConnectorCredentialPair.connector_id == Connector.id,
)
.filter(
Connector.source == DocumentSource.JIRA,
)
)
# Exclude attachment documents — they are children of issue documents
# and have "/attachments/" in their document ID.
query = query.filter(~Document.id.contains("/attachments/"))
# If status filter is specified, only show documents with matching proposal status.
# PENDING is special: documents without a proposal record are implicitly pending.
if status:
if status == "PENDING":
query = query.filter(
or_(
ProposalReviewProposal.status == status,
ProposalReviewProposal.id.is_(None),
),
)
else:
query = query.filter(ProposalReviewProposal.status == status)
# Count before adding DISTINCT ON — count(distinct(...)) handles
# deduplication on its own and conflicts with DISTINCT ON.
total_count = (
query.with_entities(func.count(func.distinct(Document.id))).scalar() or 0
)
# Deduplicate rows that can arise from multiple connector-credential pairs.
# Applied after counting to avoid the DISTINCT ON + aggregate conflict.
# ORDER BY Document.id is required for DISTINCT ON to be deterministic.
query = query.distinct(Document.id).order_by(Document.id)
results = query.offset(offset).limit(limit).all()
proposals: list[ProposalResponse] = []
created_any = False
for document, proposal in results:
if proposal is None:
# Lazily create the proposal record so the frontend gets a
# stable UUID it can use for navigation and subsequent API calls.
proposal = proposals_db.get_or_create_proposal(
document_id=document.id,
tenant_id=tenant_id,
db_session=db_session,
)
created_any = True
metadata = _resolve_document_metadata(document, visible_fields)
proposals.append(ProposalResponse.from_model(proposal, metadata=metadata))
if created_any:
db_session.commit()
return ProposalListResponse(proposals=proposals, total_count=total_count)
@router.get("/proposals/{proposal_id}")
def get_proposal(
proposal_id: UUID,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)), # noqa: ARG001
db_session: Session = Depends(get_session),
) -> ProposalResponse:
"""Get a single proposal with its metadata from the Document table."""
tenant_id = get_current_tenant_id()
proposal = proposals_db.get_proposal(proposal_id, tenant_id, db_session)
if not proposal:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")
# Load the linked Document for metadata
document = (
db_session.query(Document)
.options(selectinload(Document.tags))
.filter(Document.id == proposal.document_id)
.one_or_none()
)
config = config_db.get_config(tenant_id, db_session)
visible_fields = config.field_mapping if config else None
metadata: dict[str, Any] = {}
if document:
metadata = _resolve_document_metadata(document, visible_fields)
return ProposalResponse.from_model(proposal, metadata=metadata)
# =============================================================================
# Proposal Documents (manual uploads)
# =============================================================================
@router.post(
"/proposals/{proposal_id}/documents",
status_code=201,
)
def upload_document(
proposal_id: UUID,
file: UploadFile,
document_role: str = Form("OTHER"),
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
db_session: Session = Depends(get_session),
) -> ProposalDocumentResponse:
"""Upload a document to a proposal."""
tenant_id = get_current_tenant_id()
proposal = proposals_db.get_proposal(proposal_id, tenant_id, db_session)
if not proposal:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")
# Read file content
try:
file_bytes = file.file.read()
except Exception as e:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
f"Failed to read uploaded file: {str(e)}",
)
# Validate file size
if len(file_bytes) > DOCUMENT_UPLOAD_MAX_FILE_SIZE_BYTES:
raise OnyxError(
OnyxErrorCode.PAYLOAD_TOO_LARGE,
f"File size {len(file_bytes)} bytes exceeds maximum "
f"allowed size of {DOCUMENT_UPLOAD_MAX_FILE_SIZE_BYTES} bytes",
)
# Determine file type from filename
filename = file.filename or "untitled"
file_type = None
if filename:
parts = filename.rsplit(".", 1)
if len(parts) > 1:
file_type = parts[1].upper()
# Extract text from the uploaded file
extracted_text = None
if file_bytes:
try:
extracted_text = extract_file_text(
file=io.BytesIO(file_bytes),
file_name=filename,
)
except Exception as e:
logger.warning(
f"Failed to extract text from uploaded file '{filename}': {e}"
)
doc = ProposalReviewDocument(
proposal_id=proposal_id,
file_name=filename,
file_type=file_type,
document_role=document_role,
uploaded_by=user.id,
extracted_text=extracted_text,
)
db_session.add(doc)
db_session.commit()
return ProposalDocumentResponse.from_model(doc)
@router.get("/proposals/{proposal_id}/documents")
def list_documents(
proposal_id: UUID,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)), # noqa: ARG001
db_session: Session = Depends(get_session),
) -> list[ProposalDocumentResponse]:
"""List documents for a proposal."""
tenant_id = get_current_tenant_id()
proposal = proposals_db.get_proposal(proposal_id, tenant_id, db_session)
if not proposal:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")
docs = (
db_session.query(ProposalReviewDocument)
.filter(ProposalReviewDocument.proposal_id == proposal_id)
.order_by(ProposalReviewDocument.created_at)
.all()
)
return [ProposalDocumentResponse.from_model(d) for d in docs]
@router.delete("/proposals/{proposal_id}/documents/{doc_id}", status_code=204)
def delete_document(
proposal_id: UUID,
doc_id: UUID,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)), # noqa: ARG001
db_session: Session = Depends(get_session),
) -> None:
"""Delete a manually uploaded document."""
# Verify the proposal belongs to the current tenant
tenant_id = get_current_tenant_id()
proposal = proposals_db.get_proposal(proposal_id, tenant_id, db_session)
if not proposal:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")
doc = (
db_session.query(ProposalReviewDocument)
.filter(
ProposalReviewDocument.id == doc_id,
ProposalReviewDocument.proposal_id == proposal_id,
)
.one_or_none()
)
if not doc:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Document not found")
db_session.delete(doc)
db_session.commit()
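The listing endpoints rely on _resolve_document_metadata to flatten tag rows into a display dict and to derive jira_key and title. A small standalone sketch of the same shaping, with stand-in dataclasses rather than the real ORM models:

from dataclasses import dataclass, field
from typing import Any

@dataclass
class FakeTag:  # stand-in for the Tag ORM model
    tag_key: str
    tag_value: str
    is_list: bool = False

@dataclass
class FakeDocument:  # stand-in for the Document ORM model
    semantic_id: str
    link: str | None
    tags: list[FakeTag] = field(default_factory=list)

def resolve_metadata(document: FakeDocument, visible_fields: list[str] | None) -> dict[str, Any]:
    raw: dict[str, Any] = {}
    for tag in document.tags:
        if tag.is_list:
            raw.setdefault(tag.tag_key, []).append(tag.tag_value)
        else:
            raw[tag.tag_key] = tag.tag_value
    # "KEY-123: Summary Text" -> keep only the summary as the title.
    title = document.semantic_id or ""
    if ": " in title:
        title = title.split(": ", 1)[1]
    raw.update(jira_key=raw.get("key", ""), title=title, link=document.link)
    if not visible_fields:
        return raw
    core = {k: raw.get(k) for k in ("jira_key", "title", "link", "status")}
    return core | {k: raw[k] for k in visible_fields if k in raw}

doc = FakeDocument(
    semantic_id="PROP-42: Replace the badger fence",
    link="https://jira.example.com/browse/PROP-42",
    tags=[
        FakeTag("key", "PROP-42"),
        FakeTag("status", "In Review"),
        FakeTag("labels", "urgent", is_list=True),
    ],
)
# Core fields always survive filtering; "labels" is included because it was selected.
print(resolve_metadata(doc, ["labels"]))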

View File

@@ -1,211 +0,0 @@
"""API endpoints for review triggers, status, and findings."""
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.orm import Session
from onyx.auth.permissions import require_permission
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.server.features.proposal_review.api.models import FindingResponse
from onyx.server.features.proposal_review.api.models import ReviewRunResponse
from onyx.server.features.proposal_review.api.models import ReviewRunTriggerRequest
from onyx.server.features.proposal_review.db import findings as findings_db
from onyx.server.features.proposal_review.db import proposals as proposals_db
from onyx.server.features.proposal_review.db import rulesets as rulesets_db
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter()
@router.post(
"/proposals/{proposal_id}/review",
status_code=201,
)
def trigger_review(
proposal_id: UUID,
request: ReviewRunTriggerRequest,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
db_session: Session = Depends(get_session),
) -> ReviewRunResponse:
"""Trigger a new review run for a proposal."""
tenant_id = get_current_tenant_id()
# Verify proposal exists
proposal = proposals_db.get_proposal(proposal_id, tenant_id, db_session)
if not proposal:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")
# Verify ruleset exists and count active rules
ruleset = rulesets_db.get_ruleset(request.ruleset_id, tenant_id, db_session)
if not ruleset:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Ruleset not found")
active_rule_count = rulesets_db.count_active_rules(request.ruleset_id, db_session)
if active_rule_count == 0:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Ruleset has no active rules",
)
# Update proposal status to IN_REVIEW
proposals_db.update_proposal_status(proposal_id, tenant_id, "IN_REVIEW", db_session)
# Create the review run record
run = findings_db.create_review_run(
proposal_id=proposal_id,
ruleset_id=request.ruleset_id,
triggered_by=user.id,
total_rules=active_rule_count,
db_session=db_session,
)
db_session.commit()
logger.info(
f"Review triggered for proposal {proposal_id} "
f"with ruleset {request.ruleset_id} ({active_rule_count} rules)"
)
# Dispatch Celery task via the client app (has Redis broker configured)
from onyx.background.celery.versioned_apps.client import app as celery_app
celery_app.send_task(
"run_proposal_review",
args=[str(run.id), tenant_id],
expires=3600,
)
return ReviewRunResponse.from_model(run)
@router.get(
"/proposals/{proposal_id}/review-runs",
)
def list_review_runs(
proposal_id: UUID,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)), # noqa: ARG001
db_session: Session = Depends(get_session),
) -> list[ReviewRunResponse]:
"""List all review runs for a proposal, most recent first."""
tenant_id = get_current_tenant_id()
proposal = proposals_db.get_proposal(proposal_id, tenant_id, db_session)
if not proposal:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")
runs = findings_db.list_review_runs_by_proposal(proposal_id, db_session)
return [ReviewRunResponse.from_model(r) for r in runs]
@router.get(
"/proposals/{proposal_id}/review-status",
)
def get_review_status(
proposal_id: UUID,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)), # noqa: ARG001
db_session: Session = Depends(get_session),
) -> ReviewRunResponse:
"""Get the status of the latest review run for a proposal."""
tenant_id = get_current_tenant_id()
proposal = proposals_db.get_proposal(proposal_id, tenant_id, db_session)
if not proposal:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")
run = findings_db.get_latest_review_run(proposal_id, db_session)
if not run:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "No review runs found")
return ReviewRunResponse.from_model(run)
@router.post(
"/proposals/{proposal_id}/retry-failed",
status_code=200,
)
def retry_failed_rules_endpoint(
proposal_id: UUID,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)), # noqa: ARG001
db_session: Session = Depends(get_session),
) -> ReviewRunResponse:
"""Retry only the rules that failed in the latest review run."""
tenant_id = get_current_tenant_id()
proposal = proposals_db.get_proposal(proposal_id, tenant_id, db_session)
if not proposal:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")
run = findings_db.get_latest_review_run(proposal_id, db_session)
if not run:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "No review runs found")
if run.status not in ("COMPLETED", "FAILED"):
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Cannot retry: review is still running",
)
failed = findings_db.get_failed_findings_for_run(run.id, db_session)
if not failed:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"No failed rules to retry",
)
rule_ids = list({str(f.rule_id) for f in failed})
# Set status to RUNNING before dispatching so a second call is rejected
run.status = "RUNNING"
run.completed_at = None
db_session.commit()
logger.info(
f"Retrying {len(rule_ids)} failed rules for run {run.id} "
f"on proposal {proposal_id}"
)
from onyx.background.celery.versioned_apps.client import app as celery_app
celery_app.send_task(
"run_proposal_review",
args=[str(run.id), tenant_id],
kwargs={"rule_ids": rule_ids},
expires=3600,
)
return ReviewRunResponse.from_model(run)
@router.get(
"/proposals/{proposal_id}/findings",
)
def get_findings(
proposal_id: UUID,
review_run_id: UUID | None = None,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)), # noqa: ARG001
db_session: Session = Depends(get_session),
) -> list[FindingResponse]:
"""Get findings for a proposal.
If review_run_id is not specified, returns findings from the latest run.
"""
tenant_id = get_current_tenant_id()
proposal = proposals_db.get_proposal(proposal_id, tenant_id, db_session)
if not proposal:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")
# If no run specified, get the latest
if review_run_id is None:
run = findings_db.get_latest_review_run(proposal_id, db_session)
if not run:
return []
review_run_id = run.id
results = findings_db.list_findings_by_run(review_run_id, db_session)
return [FindingResponse.from_model(f) for f in results]

View File

@@ -1,545 +0,0 @@
"""API endpoints for rulesets and rules."""
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Form
from fastapi import UploadFile
from sqlalchemy.orm import Session
from onyx.auth.permissions import require_permission
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.server.features.proposal_review.api.models import BulkRuleUpdateRequest
from onyx.server.features.proposal_review.api.models import BulkRuleUpdateResponse
from onyx.server.features.proposal_review.api.models import ImportJobResponse
from onyx.server.features.proposal_review.api.models import RuleCreate
from onyx.server.features.proposal_review.api.models import RuleResponse
from onyx.server.features.proposal_review.api.models import RulesetCreate
from onyx.server.features.proposal_review.api.models import RulesetResponse
from onyx.server.features.proposal_review.api.models import RulesetUpdate
from onyx.server.features.proposal_review.api.models import RuleTestResponse
from onyx.server.features.proposal_review.api.models import RuleUpdate
from onyx.server.features.proposal_review.configs import IMPORT_MAX_FILE_SIZE_BYTES
from onyx.server.features.proposal_review.db import imports as imports_db
from onyx.server.features.proposal_review.db import rulesets as rulesets_db
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter()
# =============================================================================
# Rulesets
# =============================================================================
@router.get("/rulesets")
def list_rulesets(
user: User = Depends(require_permission(Permission.BASIC_ACCESS)), # noqa: ARG001
db_session: Session = Depends(get_session),
) -> list[RulesetResponse]:
"""List all rulesets for the current tenant."""
tenant_id = get_current_tenant_id()
rulesets = rulesets_db.list_rulesets(tenant_id, db_session)
return [RulesetResponse.from_model(rs) for rs in rulesets]
@router.post("/rulesets", status_code=201)
def create_ruleset(
request: RulesetCreate,
user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS)),
db_session: Session = Depends(get_session),
) -> RulesetResponse:
"""Create a new ruleset."""
tenant_id = get_current_tenant_id()
ruleset = rulesets_db.create_ruleset(
tenant_id=tenant_id,
name=request.name,
description=request.description,
is_default=request.is_default,
created_by=user.id,
db_session=db_session,
)
db_session.commit()
return RulesetResponse.from_model(ruleset, include_rules=False)
@router.get("/rulesets/{ruleset_id}")
def get_ruleset(
ruleset_id: UUID,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)), # noqa: ARG001
db_session: Session = Depends(get_session),
) -> RulesetResponse:
"""Get a ruleset with all its rules."""
tenant_id = get_current_tenant_id()
ruleset = rulesets_db.get_ruleset(ruleset_id, tenant_id, db_session)
if not ruleset:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Ruleset not found")
return RulesetResponse.from_model(ruleset)
@router.put("/rulesets/{ruleset_id}")
def update_ruleset(
ruleset_id: UUID,
request: RulesetUpdate,
user: User = Depends( # noqa: ARG001
require_permission(Permission.MANAGE_CONNECTORS)
),
db_session: Session = Depends(get_session),
) -> RulesetResponse:
"""Update a ruleset."""
tenant_id = get_current_tenant_id()
ruleset = rulesets_db.update_ruleset(
ruleset_id=ruleset_id,
tenant_id=tenant_id,
db_session=db_session,
updates=request.model_dump(exclude_unset=True),
)
if not ruleset:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Ruleset not found")
db_session.commit()
return RulesetResponse.from_model(ruleset)
@router.delete("/rulesets/{ruleset_id}", status_code=204)
def delete_ruleset(
ruleset_id: UUID,
user: User = Depends( # noqa: ARG001
require_permission(Permission.MANAGE_CONNECTORS)
),
db_session: Session = Depends(get_session),
) -> None:
"""Delete a ruleset and all its rules."""
tenant_id = get_current_tenant_id()
deleted = rulesets_db.delete_ruleset(ruleset_id, tenant_id, db_session)
if not deleted:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Ruleset not found")
db_session.commit()
# =============================================================================
# Rules
# =============================================================================
@router.post(
"/rulesets/{ruleset_id}/rules",
status_code=201,
)
def create_rule(
ruleset_id: UUID,
request: RuleCreate,
user: User = Depends( # noqa: ARG001
require_permission(Permission.MANAGE_CONNECTORS)
),
db_session: Session = Depends(get_session),
) -> RuleResponse:
"""Create a new rule within a ruleset."""
# Verify ruleset exists and belongs to tenant
tenant_id = get_current_tenant_id()
ruleset = rulesets_db.get_ruleset(ruleset_id, tenant_id, db_session)
if not ruleset:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Ruleset not found")
rule = rulesets_db.create_rule(
ruleset_id=ruleset_id,
name=request.name,
description=request.description,
category=request.category,
rule_type=request.rule_type,
rule_intent=request.rule_intent,
prompt_template=request.prompt_template,
source=request.source,
authority=request.authority,
is_hard_stop=request.is_hard_stop,
priority=request.priority,
db_session=db_session,
)
db_session.commit()
return RuleResponse.from_model(rule)
@router.put("/rules/{rule_id}")
def update_rule(
rule_id: UUID,
request: RuleUpdate,
user: User = Depends( # noqa: ARG001
require_permission(Permission.MANAGE_CONNECTORS)
),
db_session: Session = Depends(get_session),
) -> RuleResponse:
"""Update a rule."""
# Verify the rule belongs to a ruleset owned by the current tenant
tenant_id = get_current_tenant_id()
rule = rulesets_db.get_rule_with_tenant_check(rule_id, tenant_id, db_session)
if not rule:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Rule not found")
updated_rule = rulesets_db.update_rule(
rule_id=rule_id,
db_session=db_session,
updates=request.model_dump(exclude_unset=True),
)
if not updated_rule:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Rule not found")
db_session.commit()
return RuleResponse.from_model(updated_rule)
@router.delete("/rules/{rule_id}", status_code=204)
def delete_rule(
rule_id: UUID,
user: User = Depends( # noqa: ARG001
require_permission(Permission.MANAGE_CONNECTORS)
),
db_session: Session = Depends(get_session),
) -> None:
"""Delete a rule."""
# Verify the rule belongs to a ruleset owned by the current tenant
tenant_id = get_current_tenant_id()
rule = rulesets_db.get_rule_with_tenant_check(rule_id, tenant_id, db_session)
if not rule:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Rule not found")
deleted = rulesets_db.delete_rule(rule_id, db_session)
if not deleted:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Rule not found")
db_session.commit()
@router.post(
"/rulesets/{ruleset_id}/rules/bulk-update",
)
def bulk_update_rules(
ruleset_id: UUID,
request: BulkRuleUpdateRequest,
user: User = Depends( # noqa: ARG001
require_permission(Permission.MANAGE_CONNECTORS)
),
db_session: Session = Depends(get_session),
) -> BulkRuleUpdateResponse:
"""Batch activate/deactivate/delete rules."""
# Verify the ruleset belongs to the current tenant
tenant_id = get_current_tenant_id()
ruleset = rulesets_db.get_ruleset(ruleset_id, tenant_id, db_session)
if not ruleset:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Ruleset not found")
if request.action not in ("activate", "deactivate", "delete"):
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"action must be 'activate', 'deactivate', or 'delete'",
)
# Only operate on rules that belong to this ruleset (tenant-scoped)
count = rulesets_db.bulk_update_rules(
request.rule_ids, request.action, ruleset_id, db_session
)
db_session.commit()
return BulkRuleUpdateResponse(updated_count=count)
@router.post(
"/rulesets/{ruleset_id}/import",
status_code=202,
)
def import_checklist_endpoint(
ruleset_id: UUID,
file: UploadFile,
user: User = Depends( # noqa: ARG001
require_permission(Permission.MANAGE_CONNECTORS)
),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""Upload a checklist document and parse it into rules via LLM.
Text extraction happens synchronously (fast). The LLM decomposition
runs in a Celery task so the request returns 202 immediately.
Poll GET /rulesets/{ruleset_id}/import/{import_job_id}/status
to track progress.
"""
tenant_id = get_current_tenant_id()
ruleset = rulesets_db.get_ruleset(ruleset_id, tenant_id, db_session)
if not ruleset:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Ruleset not found")
# Read the uploaded file content (synchronous -- fast)
try:
file_content = file.file.read()
except Exception as e:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
f"Failed to read uploaded file: {str(e)}",
)
if not file_content:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Uploaded file is empty")
if len(file_content) > IMPORT_MAX_FILE_SIZE_BYTES:
raise OnyxError(
OnyxErrorCode.PAYLOAD_TOO_LARGE,
f"File size {len(file_content)} bytes exceeds maximum "
f"allowed size of {IMPORT_MAX_FILE_SIZE_BYTES} bytes",
)
# Extract text synchronously (fast -- no LLM involved)
extracted_text = ""
filename = file.filename or "untitled"
file_ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
if file_ext in ("txt", "text", "md"):
extracted_text = file_content.decode("utf-8", errors="replace")
else:
try:
import io
from onyx.file_processing.extract_file_text import extract_file_text
extracted_text = extract_file_text(
file=io.BytesIO(file_content),
file_name=filename,
)
except Exception as e:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
f"Failed to extract text from file: {str(e)}",
)
if not extracted_text or not extracted_text.strip():
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"No text could be extracted from the uploaded file",
)
# Create the import job row
job = imports_db.create_import_job(
ruleset_id=ruleset_id,
tenant_id=tenant_id,
source_filename=filename,
extracted_text=extracted_text,
db_session=db_session,
)
db_session.commit()
# Dispatch Celery task via the client app (has Redis broker configured)
from onyx.background.celery.versioned_apps.client import app as celery_app
celery_app.send_task(
"run_checklist_import",
args=[str(job.id), tenant_id],
expires=600,
)
return {"import_job_id": str(job.id)}
@router.get(
"/rulesets/{ruleset_id}/import/active",
)
def get_active_import_job(
ruleset_id: UUID,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)), # noqa: ARG001
db_session: Session = Depends(get_session),
) -> ImportJobResponse | None:
"""Get the latest active (PENDING/RUNNING) import job for a ruleset."""
tenant_id = get_current_tenant_id()
ruleset = rulesets_db.get_ruleset(ruleset_id, tenant_id, db_session)
if not ruleset:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Ruleset not found")
job = imports_db.get_active_import_job(ruleset_id, db_session)
if not job:
return None
return ImportJobResponse.from_model(job)
@router.get(
"/rulesets/{ruleset_id}/import/{import_job_id}/status",
)
def get_import_job_status(
ruleset_id: UUID,
import_job_id: UUID,
user: User = Depends(require_permission(Permission.BASIC_ACCESS)), # noqa: ARG001
db_session: Session = Depends(get_session),
) -> ImportJobResponse:
"""Get the status of a checklist import job."""
tenant_id = get_current_tenant_id()
# Verify ruleset belongs to tenant
ruleset = rulesets_db.get_ruleset(ruleset_id, tenant_id, db_session)
if not ruleset:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Ruleset not found")
job = imports_db.get_import_job(import_job_id, db_session)
if not job or job.ruleset_id != ruleset_id:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Import job not found")
return ImportJobResponse.from_model(job)
@router.post("/rules/{rule_id}/test")
def test_rule(
rule_id: UUID,
user: User = Depends( # noqa: ARG001
require_permission(Permission.MANAGE_CONNECTORS)
),
db_session: Session = Depends(get_session),
) -> RuleTestResponse:
"""Test a rule against sample text.
Evaluates the rule against an empty/minimal proposal context to verify
the prompt template is well-formed and the LLM can produce a valid response.
"""
# Verify the rule belongs to a ruleset owned by the current tenant
tenant_id = get_current_tenant_id()
rule = rulesets_db.get_rule_with_tenant_check(rule_id, tenant_id, db_session)
if not rule:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Rule not found")
from onyx.server.features.proposal_review.engine.context_assembler import (
ProposalContext,
)
from onyx.server.features.proposal_review.engine.rule_evaluator import (
evaluate_rule,
)
# Build a minimal test context
test_context = ProposalContext(
proposal_text="[Sample proposal text for testing. No real proposal loaded.]",
budget_text="[No budget text available for test.]",
foa_text="[No FOA text available for test.]",
metadata={"test_mode": True},
jira_key="TEST-000",
)
try:
result = evaluate_rule(rule, test_context, db_session)
except Exception as e:
return RuleTestResponse(
rule_id=str(rule_id),
success=False,
error=str(e),
)
return RuleTestResponse(
rule_id=str(rule_id),
success=True,
result=result,
)
@router.post("/rules/{rule_id}/refine")
async def refine_rule_endpoint(
rule_id: UUID,
answer: str = Form(...),
file: UploadFile | None = None,
user: User = Depends( # noqa: ARG001
require_permission(Permission.MANAGE_CONNECTORS)
),
db_session: Session = Depends(get_session),
) -> RuleResponse:
"""Submit an answer to a rule's refinement question.
Re-runs the LLM to produce a refined prompt_template that incorporates
the user's institution-specific information, then clears the refinement
flag on the rule. An optional file attachment (pdf, docx, etc.) can be
included — its extracted text is appended to the answer.
"""
tenant_id = get_current_tenant_id()
rule = rulesets_db.get_rule_with_tenant_check(rule_id, tenant_id, db_session)
if not rule:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Rule not found")
if not rule.refinement_needed:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"This rule does not need refinement",
)
if not rule.refinement_question:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Rule is marked for refinement but has no refinement question",
)
# Build the combined answer from the text field + optional attachment
combined_answer = answer.strip()
if file and file.filename:
import io
file_content = await file.read()
if file_content:
filename = file.filename
file_ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
if file_ext in ("txt", "text", "md"):
file_text = file_content.decode("utf-8", errors="replace")
else:
try:
from onyx.file_processing.extract_file_text import (
extract_file_text,
)
file_text = extract_file_text(
file=io.BytesIO(file_content),
file_name=filename,
)
except Exception as e:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
f"Failed to extract text from file: {str(e)}",
)
if file_text and file_text.strip():
combined_answer += (
f"\n\n--- Attached file: {filename} ---\n{file_text.strip()}"
)
from onyx.llm.factory import get_default_llm
from onyx.server.features.proposal_review.engine.checklist_importer import (
refine_rule,
)
llm = get_default_llm(timeout=120)
try:
refined = refine_rule(
rule_name=rule.name,
rule_description=rule.description,
rule_prompt_template=rule.prompt_template,
refinement_question=rule.refinement_question,
user_answer=combined_answer,
llm=llm,
)
except RuntimeError as e:
raise OnyxError(
OnyxErrorCode.SERVER_ERROR,
f"Refinement failed: {str(e)}",
)
updated = rulesets_db.update_rule(
rule_id=rule_id,
db_session=db_session,
updates={
"name": refined["name"],
"description": refined.get("description"),
"prompt_template": refined["prompt_template"],
"rule_type": refined.get("rule_type", rule.rule_type),
"rule_intent": refined.get("rule_intent", rule.rule_intent),
"refinement_needed": False,
"refinement_question": None,
},
)
if not updated:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Rule not found")
db_session.commit()
return RuleResponse.from_model(updated)
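# --- Hedged usage sketch (not part of the module) ---------------------------
# Example of submitting a refinement answer to the endpoint above: "answer" is
# sent as form data and the optional attachment as a multipart file, matching
# the Form(...) / UploadFile parameters. BASE_URL and the auth header are
# placeholders, not values defined by this codebase.
import requests

BASE_URL = "https://onyx.example.com/api/proposal-review"  # assumed prefix


def submit_refinement(rule_id: str, answer: str, attachment: str | None = None) -> dict:
    files = {"file": open(attachment, "rb")} if attachment else None
    try:
        resp = requests.post(
            f"{BASE_URL}/rules/{rule_id}/refine",
            data={"answer": answer},
            files=files,
            headers={"Authorization": "Bearer <token>"},  # placeholder auth
            timeout=180,
        )
        resp.raise_for_status()
        return resp.json()
    finally:
        if files:
            files["file"].close()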

View File

@@ -1,18 +0,0 @@
import os
# Feature flag for enabling proposal review
ENABLE_PROPOSAL_REVIEW = (
os.environ.get("ENABLE_PROPOSAL_REVIEW", "true").lower() == "true"
)
# Maximum file size for checklist imports (in MB)
IMPORT_MAX_FILE_SIZE_MB = int(
os.environ.get("PROPOSAL_REVIEW_IMPORT_MAX_FILE_SIZE_MB", "50")
)
IMPORT_MAX_FILE_SIZE_BYTES = IMPORT_MAX_FILE_SIZE_MB * 1024 * 1024
# Maximum file size for document uploads (in MB)
DOCUMENT_UPLOAD_MAX_FILE_SIZE_MB = int(
os.environ.get("PROPOSAL_REVIEW_DOCUMENT_UPLOAD_MAX_FILE_SIZE_MB", "100")
)
DOCUMENT_UPLOAD_MAX_FILE_SIZE_BYTES = DOCUMENT_UPLOAD_MAX_FILE_SIZE_MB * 1024 * 1024
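# Illustrative example: setting PROPOSAL_REVIEW_IMPORT_MAX_FILE_SIZE_MB=25 in the
# environment before the process starts yields
# IMPORT_MAX_FILE_SIZE_BYTES = 25 * 1024 * 1024 = 26_214_400 bytes; with no
# overrides the defaults above give 50 MB (52_428_800 bytes) for checklist
# imports and 100 MB (104_857_600 bytes) for document uploads.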

View File

@@ -1,101 +0,0 @@
"""DB operations for tenant configuration."""
from datetime import datetime
from datetime import timezone
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.models import Document__Tag
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import Tag
from onyx.server.features.proposal_review.db.models import ProposalReviewConfig
from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_config(
tenant_id: str,
db_session: Session,
) -> ProposalReviewConfig | None:
"""Get the config row for a tenant (there is at most one)."""
return (
db_session.query(ProposalReviewConfig)
.filter(ProposalReviewConfig.tenant_id == tenant_id)
.one_or_none()
)
def upsert_config(
tenant_id: str,
db_session: Session,
jira_connector_id: int | None = None,
jira_project_key: str | None = None,
field_mapping: list[str] | None = None,
jira_writeback: dict[str, Any] | None = None,
review_model: str | None = None,
import_model: str | None = None,
) -> ProposalReviewConfig:
"""Create or update the tenant config."""
config = get_config(tenant_id, db_session)
if config:
if jira_connector_id is not None:
config.jira_connector_id = jira_connector_id
if jira_project_key is not None:
config.jira_project_key = jira_project_key
if field_mapping is not None:
config.field_mapping = field_mapping
if jira_writeback is not None:
config.jira_writeback = jira_writeback
if review_model is not None:
config.review_model = review_model
if import_model is not None:
config.import_model = import_model
config.updated_at = datetime.now(timezone.utc)
db_session.flush()
logger.info(f"Updated proposal review config for tenant {tenant_id}")
return config
config = ProposalReviewConfig(
tenant_id=tenant_id,
jira_connector_id=jira_connector_id,
jira_project_key=jira_project_key,
field_mapping=field_mapping,
jira_writeback=jira_writeback,
review_model=review_model,
import_model=import_model,
)
db_session.add(config)
db_session.flush()
logger.info(f"Created proposal review config for tenant {tenant_id}")
return config
def get_connector_metadata_keys(
connector_id: int,
db_session: Session,
) -> list[str]:
"""Return distinct metadata tag keys for documents from a connector.
Jira custom fields are stored as tags (tag_key / tag_value) linked
to documents via the document__tag join table.
"""
stmt = (
select(Tag.tag_key)
.select_from(Tag)
.join(Document__Tag, Tag.id == Document__Tag.tag_id)
.join(
DocumentByConnectorCredentialPair,
Document__Tag.document_id == DocumentByConnectorCredentialPair.id,
)
.where(
DocumentByConnectorCredentialPair.connector_id == connector_id,
)
.distinct()
.limit(500)
)
rows = db_session.execute(stmt).all()
return sorted(row[0] for row in rows)

View File

@@ -1,115 +0,0 @@
"""DB operations for finding decisions and proposal decisions.
Finding decisions are stored inline on the ProposalReviewFinding row.
Proposal decisions are stored inline on the ProposalReviewProposal row.
"""
from datetime import datetime
from datetime import timezone
from uuid import UUID
from sqlalchemy.orm import Session
from onyx.server.features.proposal_review.db.models import ProposalReviewFinding
from onyx.server.features.proposal_review.db.models import ProposalReviewProposal
from onyx.utils.logger import setup_logger
logger = setup_logger()
# =============================================================================
# Per-Finding Decisions (inline on finding row)
# =============================================================================
def upsert_finding_decision(
finding_id: UUID,
officer_id: UUID,
action: str,
db_session: Session,
notes: str | None = None,
) -> ProposalReviewFinding:
"""Record or update a decision on a finding.
The decision fields live directly on the finding row.
"""
finding = (
db_session.query(ProposalReviewFinding)
.filter(ProposalReviewFinding.id == finding_id)
.one_or_none()
)
if not finding:
raise ValueError(f"Finding {finding_id} not found")
finding.decision_action = action
finding.decision_notes = notes
finding.decision_officer_id = officer_id
finding.decided_at = datetime.now(timezone.utc)
db_session.flush()
logger.info(f"Recorded decision on finding {finding_id}: {action}")
return finding
# =============================================================================
# Proposal-Level Decisions (inline on proposal row)
# =============================================================================
def update_proposal_decision(
proposal_id: UUID,
tenant_id: str,
officer_id: UUID,
decision: str,
db_session: Session,
notes: str | None = None,
) -> ProposalReviewProposal:
"""Record a final decision on a proposal.
Overwrites previous decision fields on the proposal row.
"""
proposal = (
db_session.query(ProposalReviewProposal)
.filter(
ProposalReviewProposal.id == proposal_id,
ProposalReviewProposal.tenant_id == tenant_id,
)
.one_or_none()
)
if not proposal:
raise ValueError(f"Proposal {proposal_id} not found")
proposal.status = decision
proposal.decision_notes = notes
proposal.decision_officer_id = officer_id
proposal.decision_at = datetime.now(timezone.utc)
proposal.jira_synced = False
proposal.jira_synced_at = None
proposal.updated_at = datetime.now(timezone.utc)
db_session.flush()
logger.info(f"Recorded proposal decision {decision} for proposal {proposal_id}")
return proposal
def mark_proposal_jira_synced(
proposal_id: UUID,
tenant_id: str,
db_session: Session,
) -> ProposalReviewProposal | None:
"""Mark a proposal's decision as synced to Jira."""
proposal = (
db_session.query(ProposalReviewProposal)
.filter(
ProposalReviewProposal.id == proposal_id,
ProposalReviewProposal.tenant_id == tenant_id,
)
.one_or_none()
)
if not proposal:
return None
proposal.jira_synced = True
proposal.jira_synced_at = datetime.now(timezone.utc)
db_session.flush()
logger.info(f"Marked proposal {proposal_id} as jira_synced")
return proposal

View File

@@ -1,206 +0,0 @@
"""DB operations for review runs and findings."""
from uuid import UUID
from sqlalchemy import desc
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.server.features.proposal_review.db.models import ProposalReviewFinding
from onyx.server.features.proposal_review.db.models import ProposalReviewRun
from onyx.utils.logger import setup_logger
logger = setup_logger()
# =============================================================================
# Review Runs
# =============================================================================
def create_review_run(
proposal_id: UUID,
ruleset_id: UUID,
triggered_by: UUID,
total_rules: int,
db_session: Session,
) -> ProposalReviewRun:
"""Create a new review run record."""
run = ProposalReviewRun(
proposal_id=proposal_id,
ruleset_id=ruleset_id,
triggered_by=triggered_by,
total_rules=total_rules,
)
db_session.add(run)
db_session.flush()
logger.info(
f"Created review run {run.id} for proposal {proposal_id} "
f"with {total_rules} rules"
)
return run
def get_review_run(
run_id: UUID,
db_session: Session,
) -> ProposalReviewRun | None:
"""Get a review run by ID."""
return (
db_session.query(ProposalReviewRun)
.filter(ProposalReviewRun.id == run_id)
.one_or_none()
)
def get_latest_review_run(
proposal_id: UUID,
db_session: Session,
) -> ProposalReviewRun | None:
"""Get the most recent review run for a proposal."""
return (
db_session.query(ProposalReviewRun)
.filter(ProposalReviewRun.proposal_id == proposal_id)
.order_by(desc(ProposalReviewRun.created_at))
.first()
)
def list_review_runs_by_proposal(
proposal_id: UUID,
db_session: Session,
limit: int = 20,
) -> list[ProposalReviewRun]:
"""List review runs for a proposal, most recent first."""
return (
db_session.query(ProposalReviewRun)
.filter(ProposalReviewRun.proposal_id == proposal_id)
.order_by(desc(ProposalReviewRun.created_at))
.limit(limit)
.all()
)
# =============================================================================
# Findings
# =============================================================================
def create_finding(
proposal_id: UUID,
rule_id: UUID,
review_run_id: UUID,
verdict: str,
db_session: Session,
confidence: str | None = None,
evidence: str | None = None,
explanation: str | None = None,
suggested_action: str | None = None,
llm_model: str | None = None,
llm_tokens_used: int | None = None,
) -> ProposalReviewFinding:
"""Create a new finding."""
finding = ProposalReviewFinding(
proposal_id=proposal_id,
rule_id=rule_id,
review_run_id=review_run_id,
verdict=verdict,
confidence=confidence,
evidence=evidence,
explanation=explanation,
suggested_action=suggested_action,
llm_model=llm_model,
llm_tokens_used=llm_tokens_used,
)
db_session.add(finding)
db_session.flush()
logger.info(
f"Created finding {finding.id} verdict={verdict} for proposal {proposal_id}"
)
return finding
def get_finding(
finding_id: UUID,
db_session: Session,
) -> ProposalReviewFinding | None:
"""Get a finding by ID with its rule eagerly loaded."""
return (
db_session.query(ProposalReviewFinding)
.filter(ProposalReviewFinding.id == finding_id)
.options(
selectinload(ProposalReviewFinding.rule),
)
.one_or_none()
)
def list_findings_by_proposal(
proposal_id: UUID,
db_session: Session,
review_run_id: UUID | None = None,
) -> list[ProposalReviewFinding]:
"""List findings for a proposal, optionally filtered to a specific run."""
query = (
db_session.query(ProposalReviewFinding)
.filter(ProposalReviewFinding.proposal_id == proposal_id)
.options(
selectinload(ProposalReviewFinding.rule),
)
.order_by(ProposalReviewFinding.created_at)
)
if review_run_id:
query = query.filter(ProposalReviewFinding.review_run_id == review_run_id)
return query.all()
def list_findings_by_run(
review_run_id: UUID,
db_session: Session,
) -> list[ProposalReviewFinding]:
"""List all findings for a specific review run."""
return (
db_session.query(ProposalReviewFinding)
.filter(ProposalReviewFinding.review_run_id == review_run_id)
.options(
selectinload(ProposalReviewFinding.rule),
)
.order_by(ProposalReviewFinding.created_at)
.all()
)
def get_failed_findings_for_run(
review_run_id: UUID,
db_session: Session,
) -> list[ProposalReviewFinding]:
"""Get findings that failed due to system errors (LLM timeout, etc.).
Error findings are created by _save_error_finding and are identifiable
by having no LLM metadata (the call never completed successfully).
"""
return (
db_session.query(ProposalReviewFinding)
.filter(
ProposalReviewFinding.review_run_id == review_run_id,
ProposalReviewFinding.verdict == "NEEDS_REVIEW",
ProposalReviewFinding.llm_model.is_(None),
ProposalReviewFinding.llm_tokens_used.is_(None),
)
.all()
)
def delete_findings(
finding_ids: list[UUID],
db_session: Session,
) -> int:
"""Delete findings by ID. Returns the number deleted."""
if not finding_ids:
return 0
count = (
db_session.query(ProposalReviewFinding)
.filter(ProposalReviewFinding.id.in_(finding_ids))
.delete(synchronize_session="fetch")
)
return count

View File

@@ -1,97 +0,0 @@
"""DB operations for checklist import jobs."""
import datetime
from datetime import timezone
from uuid import UUID
from sqlalchemy.orm import Session
from onyx.server.features.proposal_review.db.models import ProposalReviewImportJob
from onyx.utils.logger import setup_logger
logger = setup_logger()
def create_import_job(
ruleset_id: UUID,
tenant_id: str,
source_filename: str,
extracted_text: str,
db_session: Session,
) -> ProposalReviewImportJob:
"""Create a new import job record."""
job = ProposalReviewImportJob(
ruleset_id=ruleset_id,
tenant_id=tenant_id,
source_filename=source_filename,
extracted_text=extracted_text,
)
db_session.add(job)
db_session.flush()
logger.info(
f"Created import job {job.id} for ruleset {ruleset_id} "
f"(file: {source_filename})"
)
return job
def get_import_job(
job_id: UUID,
db_session: Session,
) -> ProposalReviewImportJob | None:
"""Get a single import job by ID."""
return (
db_session.query(ProposalReviewImportJob)
.filter(ProposalReviewImportJob.id == job_id)
.one_or_none()
)
def get_active_import_job(
ruleset_id: UUID,
db_session: Session,
) -> ProposalReviewImportJob | None:
"""Get the latest PENDING or RUNNING import job for a ruleset, if any."""
return (
db_session.query(ProposalReviewImportJob)
.filter(
ProposalReviewImportJob.ruleset_id == ruleset_id,
ProposalReviewImportJob.status.in_(["PENDING", "RUNNING"]),
)
.order_by(ProposalReviewImportJob.created_at.desc())
.first()
)
def get_dangling_import_jobs(
db_session: Session,
stale_threshold_minutes: int = 30,
) -> list[ProposalReviewImportJob]:
"""Return import jobs stuck in PENDING or RUNNING for longer than the threshold."""
cutoff = datetime.datetime.now(timezone.utc) - datetime.timedelta(
minutes=stale_threshold_minutes
)
return (
db_session.query(ProposalReviewImportJob)
.filter(
ProposalReviewImportJob.status.in_(["PENDING", "RUNNING"]),
ProposalReviewImportJob.created_at < cutoff,
)
.all()
)
def mark_import_job_failed(
job: ProposalReviewImportJob,
error_message: str,
db_session: Session,
) -> None:
"""Mark an import job as FAILED with the given error message.
Flushes but does NOT commit — the caller is responsible for committing
so that batch operations can be done in a single transaction.
"""
job.status = "FAILED"
job.error_message = error_message
job.completed_at = datetime.datetime.now(timezone.utc)
db_session.flush()
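# --- Hedged usage sketch (not part of the module) ---------------------------
# Illustrates the flush-but-don't-commit contract above: a hypothetical cleanup
# sweep collects jobs stuck in PENDING/RUNNING past the threshold, marks each
# one FAILED, and commits the whole batch in a single transaction. The function
# name is a placeholder, not an API defined elsewhere in this codebase.
def _sweep_dangling_import_jobs(db_session: Session) -> int:
    stale_jobs = get_dangling_import_jobs(db_session, stale_threshold_minutes=30)
    for job in stale_jobs:
        mark_import_job_failed(
            job,
            error_message="Import timed out and was cancelled by the cleanup sweep",
            db_session=db_session,
        )
    db_session.commit()  # one commit for the whole batch
    return len(stale_jobs)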

View File

@@ -1,367 +0,0 @@
"""SQLAlchemy models for Proposal Review."""
import datetime
from uuid import UUID
from sqlalchemy import Boolean
from sqlalchemy import DateTime
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import Text
from sqlalchemy import text
from sqlalchemy import UniqueConstraint
from sqlalchemy.dialects.postgresql import JSONB as PGJSONB
from sqlalchemy.dialects.postgresql import UUID as PGUUID
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from onyx.db.models import Base
class ProposalReviewRuleset(Base):
__tablename__ = "proposal_review_ruleset"
id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True),
primary_key=True,
server_default=text("gen_random_uuid()"),
)
tenant_id: Mapped[str] = mapped_column(Text, nullable=False, index=True)
name: Mapped[str] = mapped_column(Text, nullable=False)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
is_default: Mapped[bool] = mapped_column(
Boolean, nullable=False, server_default=text("false")
)
is_active: Mapped[bool] = mapped_column(
Boolean, nullable=False, server_default=text("true")
)
created_by: Mapped[UUID | None] = mapped_column(
PGUUID(as_uuid=True), ForeignKey("user.id"), nullable=True
)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
rules: Mapped[list["ProposalReviewRule"]] = relationship(
"ProposalReviewRule",
back_populates="ruleset",
cascade="all, delete-orphan",
order_by="ProposalReviewRule.priority",
)
class ProposalReviewRule(Base):
__tablename__ = "proposal_review_rule"
id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True),
primary_key=True,
server_default=text("gen_random_uuid()"),
)
ruleset_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True),
ForeignKey("proposal_review_ruleset.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
name: Mapped[str] = mapped_column(Text, nullable=False)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
category: Mapped[str | None] = mapped_column(Text, nullable=True)
rule_type: Mapped[str] = mapped_column(Text, nullable=False)
rule_intent: Mapped[str] = mapped_column(
Text, nullable=False, server_default=text("'CHECK'")
)
prompt_template: Mapped[str] = mapped_column(Text, nullable=False)
source: Mapped[str] = mapped_column(
Text, nullable=False, server_default=text("'MANUAL'")
)
authority: Mapped[str | None] = mapped_column(Text, nullable=True)
is_hard_stop: Mapped[bool] = mapped_column(
Boolean, nullable=False, server_default=text("false")
)
priority: Mapped[int] = mapped_column(
Integer, nullable=False, server_default=text("0")
)
is_active: Mapped[bool] = mapped_column(
Boolean, nullable=False, server_default=text("true")
)
refinement_needed: Mapped[bool] = mapped_column(
Boolean, nullable=False, server_default=text("false")
)
refinement_question: Mapped[str | None] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
ruleset: Mapped["ProposalReviewRuleset"] = relationship(
"ProposalReviewRuleset", back_populates="rules"
)
class ProposalReviewProposal(Base):
__tablename__ = "proposal_review_proposal"
__table_args__ = (
UniqueConstraint("document_id", "tenant_id"),
Index("ix_proposal_review_proposal_tenant_id", "tenant_id"),
Index("ix_proposal_review_proposal_document_id", "document_id"),
Index("ix_proposal_review_proposal_status", "status"),
)
id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True),
primary_key=True,
server_default=text("gen_random_uuid()"),
)
document_id: Mapped[str] = mapped_column(Text, nullable=False)
tenant_id: Mapped[str] = mapped_column(Text, nullable=False)
status: Mapped[str] = mapped_column(
Text, nullable=False, server_default=text("'PENDING'")
)
# Inline proposal-level decision fields
decision_notes: Mapped[str | None] = mapped_column(Text, nullable=True)
decision_officer_id: Mapped[UUID | None] = mapped_column(
PGUUID(as_uuid=True), ForeignKey("user.id"), nullable=True
)
decision_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
jira_synced: Mapped[bool] = mapped_column(
Boolean, nullable=False, server_default=text("false")
)
jira_synced_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
review_runs: Mapped[list["ProposalReviewRun"]] = relationship(
"ProposalReviewRun",
back_populates="proposal",
cascade="all, delete-orphan",
)
findings: Mapped[list["ProposalReviewFinding"]] = relationship(
"ProposalReviewFinding",
back_populates="proposal",
cascade="all, delete-orphan",
)
documents: Mapped[list["ProposalReviewDocument"]] = relationship(
"ProposalReviewDocument",
back_populates="proposal",
cascade="all, delete-orphan",
)
class ProposalReviewRun(Base):
__tablename__ = "proposal_review_run"
id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True),
primary_key=True,
server_default=text("gen_random_uuid()"),
)
proposal_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True),
ForeignKey("proposal_review_proposal.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
ruleset_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True),
ForeignKey("proposal_review_ruleset.id"),
nullable=False,
)
triggered_by: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), ForeignKey("user.id"), nullable=False
)
status: Mapped[str] = mapped_column(
Text, nullable=False, server_default=text("'PENDING'")
)
total_rules: Mapped[int] = mapped_column(Integer, nullable=False)
completed_rules: Mapped[int] = mapped_column(
Integer, nullable=False, server_default=text("0")
)
failed_rules: Mapped[int] = mapped_column(
Integer, nullable=False, server_default=text("0")
)
started_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
completed_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
proposal: Mapped["ProposalReviewProposal"] = relationship(
"ProposalReviewProposal", back_populates="review_runs"
)
findings: Mapped[list["ProposalReviewFinding"]] = relationship(
"ProposalReviewFinding",
back_populates="review_run",
cascade="all, delete-orphan",
)
class ProposalReviewFinding(Base):
__tablename__ = "proposal_review_finding"
id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True),
primary_key=True,
server_default=text("gen_random_uuid()"),
)
proposal_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True),
ForeignKey("proposal_review_proposal.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
rule_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True),
ForeignKey("proposal_review_rule.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
review_run_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True),
ForeignKey("proposal_review_run.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
verdict: Mapped[str] = mapped_column(Text, nullable=False)
confidence: Mapped[str | None] = mapped_column(Text, nullable=True)
evidence: Mapped[str | None] = mapped_column(Text, nullable=True)
explanation: Mapped[str | None] = mapped_column(Text, nullable=True)
suggested_action: Mapped[str | None] = mapped_column(Text, nullable=True)
llm_model: Mapped[str | None] = mapped_column(Text, nullable=True)
llm_tokens_used: Mapped[int | None] = mapped_column(Integer, nullable=True)
# Inline per-finding decision fields
decision_action: Mapped[str | None] = mapped_column(Text, nullable=True)
decision_notes: Mapped[str | None] = mapped_column(Text, nullable=True)
decision_officer_id: Mapped[UUID | None] = mapped_column(
PGUUID(as_uuid=True), ForeignKey("user.id"), nullable=True
)
decided_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
proposal: Mapped["ProposalReviewProposal"] = relationship(
"ProposalReviewProposal", back_populates="findings"
)
review_run: Mapped["ProposalReviewRun"] = relationship(
"ProposalReviewRun", back_populates="findings"
)
rule: Mapped["ProposalReviewRule"] = relationship("ProposalReviewRule")
class ProposalReviewDocument(Base):
"""Manually uploaded documents or auto-fetched FOAs."""
__tablename__ = "proposal_review_document"
id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True),
primary_key=True,
server_default=text("gen_random_uuid()"),
)
proposal_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True),
ForeignKey("proposal_review_proposal.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
file_name: Mapped[str] = mapped_column(Text, nullable=False)
file_type: Mapped[str | None] = mapped_column(Text, nullable=True)
file_store_id: Mapped[str | None] = mapped_column(Text, nullable=True)
extracted_text: Mapped[str | None] = mapped_column(Text, nullable=True)
document_role: Mapped[str] = mapped_column(Text, nullable=False)
uploaded_by: Mapped[UUID | None] = mapped_column(
PGUUID(as_uuid=True), ForeignKey("user.id"), nullable=True
)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
proposal: Mapped["ProposalReviewProposal"] = relationship(
"ProposalReviewProposal", back_populates="documents"
)
class ProposalReviewImportJob(Base):
"""Tracks background checklist import jobs dispatched via Celery."""
__tablename__ = "proposal_review_import_job"
id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True),
primary_key=True,
server_default=text("gen_random_uuid()"),
)
ruleset_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True),
ForeignKey("proposal_review_ruleset.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
tenant_id: Mapped[str] = mapped_column(Text, nullable=False)
status: Mapped[str] = mapped_column(
Text, nullable=False, server_default=text("'PENDING'")
)
source_filename: Mapped[str] = mapped_column(Text, nullable=False)
extracted_text: Mapped[str] = mapped_column(Text, nullable=False)
rules_created: Mapped[int] = mapped_column(
Integer, nullable=False, server_default=text("0")
)
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
completed_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
class ProposalReviewConfig(Base):
"""Admin configuration (one row per tenant)."""
__tablename__ = "proposal_review_config"
id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True),
primary_key=True,
server_default=text("gen_random_uuid()"),
)
tenant_id: Mapped[str] = mapped_column(Text, nullable=False, unique=True)
jira_connector_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
jira_project_key: Mapped[str | None] = mapped_column(Text, nullable=True)
field_mapping: Mapped[list | None] = mapped_column(PGJSONB(), nullable=True)
jira_writeback: Mapped[dict | None] = mapped_column(PGJSONB(), nullable=True)
review_model: Mapped[str | None] = mapped_column(Text, nullable=True)
import_model: Mapped[str | None] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)

View File

@@ -1,133 +0,0 @@
"""DB operations for proposal state records."""
from datetime import datetime
from datetime import timezone
from uuid import UUID
from sqlalchemy import desc
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from onyx.server.features.proposal_review.db.models import ProposalReviewProposal
from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_proposal(
proposal_id: UUID,
tenant_id: str,
db_session: Session,
) -> ProposalReviewProposal | None:
"""Get a proposal by its ID."""
return (
db_session.query(ProposalReviewProposal)
.filter(
ProposalReviewProposal.id == proposal_id,
ProposalReviewProposal.tenant_id == tenant_id,
)
.one_or_none()
)
def get_proposal_by_document_id(
document_id: str,
tenant_id: str,
db_session: Session,
) -> ProposalReviewProposal | None:
"""Get a proposal by its linked document ID."""
return (
db_session.query(ProposalReviewProposal)
.filter(
ProposalReviewProposal.document_id == document_id,
ProposalReviewProposal.tenant_id == tenant_id,
)
.one_or_none()
)
def get_or_create_proposal(
document_id: str,
tenant_id: str,
db_session: Session,
) -> ProposalReviewProposal:
"""Get or lazily create a proposal state record for a document.
This is the primary entry point — the proposal record is created on first
interaction, not when the Jira ticket is ingested.
"""
proposal = get_proposal_by_document_id(document_id, tenant_id, db_session)
if proposal:
return proposal
proposal = ProposalReviewProposal(
document_id=document_id,
tenant_id=tenant_id,
)
db_session.add(proposal)
try:
db_session.flush()
except IntegrityError:
db_session.rollback()
proposal = get_proposal_by_document_id(document_id, tenant_id, db_session)
if proposal is None:
raise
return proposal
logger.info(f"Lazily created proposal {proposal.id} for document {document_id}")
return proposal
def list_proposals(
tenant_id: str,
db_session: Session,
status: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[ProposalReviewProposal]:
"""List proposals for a tenant with optional status filter."""
query = (
db_session.query(ProposalReviewProposal)
.filter(ProposalReviewProposal.tenant_id == tenant_id)
.order_by(desc(ProposalReviewProposal.updated_at))
)
if status:
query = query.filter(ProposalReviewProposal.status == status)
return query.offset(offset).limit(limit).all()
def count_proposals(
tenant_id: str,
db_session: Session,
status: str | None = None,
) -> int:
"""Count proposals for a tenant."""
query = db_session.query(ProposalReviewProposal).filter(
ProposalReviewProposal.tenant_id == tenant_id
)
if status:
query = query.filter(ProposalReviewProposal.status == status)
return query.count()
def update_proposal_status(
proposal_id: UUID,
tenant_id: str,
status: str,
db_session: Session,
) -> ProposalReviewProposal | None:
"""Update a proposal's status."""
proposal = (
db_session.query(ProposalReviewProposal)
.filter(
ProposalReviewProposal.id == proposal_id,
ProposalReviewProposal.tenant_id == tenant_id,
)
.one_or_none()
)
if not proposal:
return None
proposal.status = status
proposal.updated_at = datetime.now(timezone.utc)
db_session.flush()
logger.info(f"Updated proposal {proposal_id} status to {status}")
return proposal

View File

@@ -1,337 +0,0 @@
"""DB operations for rulesets and rules."""
from datetime import datetime
from datetime import timezone
from typing import Any
from uuid import UUID
from sqlalchemy import desc
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.server.features.proposal_review.db.models import ProposalReviewRule
from onyx.server.features.proposal_review.db.models import ProposalReviewRuleset
from onyx.utils.logger import setup_logger
logger = setup_logger()
_RULESET_UPDATABLE_FIELDS = frozenset(
{"name", "description", "is_default", "is_active"}
)
_RULE_UPDATABLE_FIELDS = frozenset(
{
"name",
"description",
"category",
"rule_type",
"rule_intent",
"prompt_template",
"authority",
"is_hard_stop",
"priority",
"is_active",
"refinement_needed",
"refinement_question",
}
)
# =============================================================================
# Ruleset CRUD
# =============================================================================
def list_rulesets(
tenant_id: str,
db_session: Session,
active_only: bool = False,
) -> list[ProposalReviewRuleset]:
"""List all rulesets for a tenant."""
query = (
db_session.query(ProposalReviewRuleset)
.filter(ProposalReviewRuleset.tenant_id == tenant_id)
.options(selectinload(ProposalReviewRuleset.rules))
.order_by(desc(ProposalReviewRuleset.created_at))
)
if active_only:
query = query.filter(ProposalReviewRuleset.is_active.is_(True))
return query.all()
def get_ruleset(
ruleset_id: UUID,
tenant_id: str,
db_session: Session,
) -> ProposalReviewRuleset | None:
"""Get a single ruleset by ID with all its rules."""
return (
db_session.query(ProposalReviewRuleset)
.filter(
ProposalReviewRuleset.id == ruleset_id,
ProposalReviewRuleset.tenant_id == tenant_id,
)
.options(selectinload(ProposalReviewRuleset.rules))
.one_or_none()
)
def create_ruleset(
tenant_id: str,
name: str,
db_session: Session,
description: str | None = None,
is_default: bool = False,
created_by: UUID | None = None,
) -> ProposalReviewRuleset:
"""Create a new ruleset."""
# If this ruleset is default, un-default any existing default
if is_default:
_clear_default_ruleset(tenant_id, db_session)
ruleset = ProposalReviewRuleset(
tenant_id=tenant_id,
name=name,
description=description,
is_default=is_default,
created_by=created_by,
)
db_session.add(ruleset)
db_session.flush()
logger.info(f"Created ruleset {ruleset.id} '{name}' for tenant {tenant_id}")
return ruleset
def update_ruleset(
ruleset_id: UUID,
tenant_id: str,
db_session: Session,
updates: dict[str, Any],
) -> ProposalReviewRuleset | None:
"""Update a ruleset. Returns None if not found."""
ruleset = get_ruleset(ruleset_id, tenant_id, db_session)
if not ruleset:
return None
for field, value in updates.items():
if field not in _RULESET_UPDATABLE_FIELDS:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT, f"Cannot update field: {field}"
)
if field == "is_default" and value:
_clear_default_ruleset(tenant_id, db_session)
setattr(ruleset, field, value)
ruleset.updated_at = datetime.now(timezone.utc)
db_session.flush()
return ruleset
def delete_ruleset(
ruleset_id: UUID,
tenant_id: str,
db_session: Session,
) -> bool:
"""Delete a ruleset. Returns False if not found."""
ruleset = get_ruleset(ruleset_id, tenant_id, db_session)
if not ruleset:
return False
db_session.delete(ruleset)
db_session.flush()
logger.info(f"Deleted ruleset {ruleset_id}")
return True
def _clear_default_ruleset(tenant_id: str, db_session: Session) -> None:
"""Un-default any existing default ruleset for a tenant."""
db_session.query(ProposalReviewRuleset).filter(
ProposalReviewRuleset.tenant_id == tenant_id,
ProposalReviewRuleset.is_default.is_(True),
).update({ProposalReviewRuleset.is_default: False})
db_session.flush()
# =============================================================================
# Rule CRUD
# =============================================================================
def list_rules_by_ruleset(
ruleset_id: UUID,
db_session: Session,
active_only: bool = False,
) -> list[ProposalReviewRule]:
"""List all rules in a ruleset."""
query = (
db_session.query(ProposalReviewRule)
.filter(ProposalReviewRule.ruleset_id == ruleset_id)
.order_by(ProposalReviewRule.priority)
)
if active_only:
query = query.filter(ProposalReviewRule.is_active.is_(True))
return query.all()
def get_rule(
rule_id: UUID,
db_session: Session,
) -> ProposalReviewRule | None:
"""Get a single rule by ID."""
return (
db_session.query(ProposalReviewRule)
.filter(ProposalReviewRule.id == rule_id)
.one_or_none()
)
def get_rule_with_tenant_check(
rule_id: UUID,
tenant_id: str,
db_session: Session,
) -> ProposalReviewRule | None:
"""Get a single rule by ID, validating it belongs to the given tenant.
Joins with the ruleset table so the tenant check happens in one query,
eliminating the race between separate get_rule + get_ruleset calls.
"""
return (
db_session.query(ProposalReviewRule)
.join(
ProposalReviewRuleset,
ProposalReviewRule.ruleset_id == ProposalReviewRuleset.id,
)
.filter(
ProposalReviewRule.id == rule_id,
ProposalReviewRuleset.tenant_id == tenant_id,
)
.one_or_none()
)
def create_rule(
ruleset_id: UUID,
name: str,
rule_type: str,
prompt_template: str,
db_session: Session,
description: str | None = None,
category: str | None = None,
rule_intent: str = "CHECK",
source: str = "MANUAL",
authority: str | None = None,
is_hard_stop: bool = False,
priority: int = 0,
refinement_needed: bool = False,
refinement_question: str | None = None,
) -> ProposalReviewRule:
"""Create a new rule within a ruleset."""
rule = ProposalReviewRule(
ruleset_id=ruleset_id,
name=name,
description=description,
category=category,
rule_type=rule_type,
rule_intent=rule_intent,
prompt_template=prompt_template,
source=source,
authority=authority,
is_hard_stop=is_hard_stop,
priority=priority,
refinement_needed=refinement_needed,
refinement_question=refinement_question,
)
db_session.add(rule)
db_session.flush()
logger.info(f"Created rule {rule.id} '{name}' in ruleset {ruleset_id}")
return rule
def update_rule(
rule_id: UUID,
db_session: Session,
updates: dict[str, Any],
) -> ProposalReviewRule | None:
"""Update a rule. Returns None if not found."""
rule = get_rule(rule_id, db_session)
if not rule:
return None
for field, value in updates.items():
if field not in _RULE_UPDATABLE_FIELDS:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT, f"Cannot update field: {field}"
)
setattr(rule, field, value)
rule.updated_at = datetime.now(timezone.utc)
db_session.flush()
return rule
def delete_rule(
rule_id: UUID,
db_session: Session,
) -> bool:
"""Delete a rule. Returns False if not found."""
rule = get_rule(rule_id, db_session)
if not rule:
return False
db_session.delete(rule)
db_session.flush()
logger.info(f"Deleted rule {rule_id}")
return True
def bulk_update_rules(
rule_ids: list[UUID],
action: str,
ruleset_id: UUID,
db_session: Session,
) -> int:
"""Batch activate/deactivate/delete rules.
Args:
rule_ids: list of rule IDs
action: "activate" | "deactivate" | "delete"
ruleset_id: scope operations to rules within this ruleset
Returns:
number of rules affected
"""
base_query = db_session.query(ProposalReviewRule).filter(
ProposalReviewRule.id.in_(rule_ids),
ProposalReviewRule.ruleset_id == ruleset_id,
)
if action == "delete":
count = base_query.delete(synchronize_session="fetch")
elif action in ("activate", "deactivate"):
count = base_query.update(
{
ProposalReviewRule.is_active: action == "activate",
ProposalReviewRule.updated_at: datetime.now(timezone.utc),
},
synchronize_session="fetch",
)
else:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, f"Unknown bulk action: {action}")
db_session.flush()
logger.info(f"Bulk {action} on {count} rules")
return count
def count_active_rules(
ruleset_id: UUID,
db_session: Session,
) -> int:
"""Count active rules in a ruleset."""
return (
db_session.query(ProposalReviewRule)
.filter(
ProposalReviewRule.ruleset_id == ruleset_id,
ProposalReviewRule.is_active.is_(True),
)
.count()
)

View File

@@ -1 +0,0 @@
"""Proposal Review Engine — AI-powered proposal evaluation."""

View File

@@ -1,497 +0,0 @@
"""Parses uploaded checklist documents into atomic review rules via LLM.
Uses a two-pass approach to handle checklists of any size without hitting
output token limits:
Pass 1 — Enumerate: Identify all distinct checklist items from the
document (names, categories, sub-checks). This produces a small,
bounded output regardless of document size.
Pass 2 — Decompose: For each identified item, make a focused LLM call
to generate atomic review rules with full prompt templates.
Each call produces 1-5 rules, well within token limits.
Callers orchestrate persistence — this module is pure LLM + parsing, no
DB access, no callbacks, no threads.
"""
import json
import re
from dataclasses import dataclass
from dataclasses import field
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.llm.interfaces import LLM
from onyx.llm.models import SystemMessage
from onyx.llm.models import UserMessage
from onyx.llm.utils import get_llm_max_output_tokens
from onyx.llm.utils import get_model_map
from onyx.llm.utils import llm_response_to_string
from onyx.tracing.llm_utils import llm_generation_span
from onyx.tracing.llm_utils import record_llm_response
from onyx.utils.logger import setup_logger
logger = setup_logger()
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
@dataclass
class ChecklistItem:
"""A single checklist item identified during pass 1."""
id: str
name: str
category: str
description: str
sub_checks: list[str] = field(default_factory=list)
# ---------------------------------------------------------------------------
# Prompts — Pass 1 (Enumerate)
# ---------------------------------------------------------------------------
_ENUMERATE_SYSTEM = """\
You are an expert at analyzing institutional review checklists for university \
grant offices. Your task is to read a checklist document and identify every \
distinct checklist item or section that requires review."""
_ENUMERATE_USER = """\
Read the checklist document below and list every distinct checklist item.
CHECKLIST DOCUMENT:
---
{checklist_text}
---
For each item, provide:
- **id**: A short identifier derived from the document (e.g., "IR-1", \
"KR-3", "Section-A.2"). Invent one if the document doesn't assign one.
- **name**: The item's title or heading.
- **category**: A display label combining the id and name \
(e.g., "IR-2: Regulatory Compliance").
- **description**: One sentence summarizing what this item covers.
- **sub_checks**: A list of the individual checks or requirements \
mentioned under this item. Be thorough — include every distinct \
requirement even if the document groups them together.
Respond with ONLY a valid JSON array:
[
{{
"id": "IR-1",
"name": "Institutional and PI Eligibility",
"category": "IR-1: Institutional and PI Eligibility Requirements",
"description": "Verify institution and PI meet sponsor eligibility.",
"sub_checks": ["Institutional eligibility", "PI eligibility", ...]
}},
...
]"""
# ---------------------------------------------------------------------------
# Prompts — Pass 2 (Decompose one item into rules)
# ---------------------------------------------------------------------------
_DECOMPOSE_SYSTEM = """\
You are an expert at creating AI review rules for university grant proposal \
review. Each rule you create will be independently evaluated by an LLM \
against a grant proposal. Rules must be atomic (one criterion each) and \
self-contained (the prompt template includes all context needed).
Variable placeholders available for prompt templates:
{{{{proposal_text}}}} — full proposal and supporting documents
{{{{budget_text}}}} — budget / financial sections
{{{{foa_text}}}} — funding opportunity announcement
{{{{metadata}}}} — structured metadata (PI, sponsor, etc.)
{{{{metadata.FIELD_NAME}}}} — a specific metadata field
Rule types:
DOCUMENT_CHECK — verify presence / content in documents
METADATA_CHECK — validate a structured metadata field
CROSS_REFERENCE — compare information across documents
CUSTOM_NL — natural language evaluation
Rule intents:
CHECK — pass / fail criterion
HIGHLIGHT — informational flag (no pass / fail)
If a rule requires institution-specific info NOT present in the checklist \
(IDC rates, mandatory cost categories, local policies, etc.), set \
refinement_needed=true and include a refinement_question."""
_DECOMPOSE_USER = """\
Create atomic review rules for the checklist item described below.
ITEM TO DECOMPOSE:
ID: {item_id}
Name: {item_name}
Category: {item_category}
Description: {item_description}
Sub-checks: {sub_checks}
FULL CHECKLIST (for context — only create rules for the item above):
---
{checklist_text}
---
Generate one rule per sub-check. Each rule object must have:
{{
"name": "Short descriptive name (max 100 chars)",
"description": "What this rule checks",
"category": "{item_category}",
"rule_type": "DOCUMENT_CHECK | METADATA_CHECK | CROSS_REFERENCE | CUSTOM_NL",
"rule_intent": "CHECK | HIGHLIGHT",
"prompt_template": "Self-contained prompt with {{{{variable}}}} placeholders.",
"refinement_needed": false,
"refinement_question": null
}}
Respond with ONLY a valid JSON array of rule objects."""
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def enumerate_checklist_items(
checklist_text: str,
llm: LLM,
) -> list[ChecklistItem]:
"""Pass 1: Identify all distinct checklist items from the document.
One LLM call. Output is small and bounded regardless of document size.
Args:
checklist_text: Full text extracted from the uploaded checklist file.
llm: The LLM instance to use.
Returns:
Ordered list of checklist items found in the document.
Raises:
RuntimeError: If the LLM call fails or returns unparseable output.
"""
user_content = _ENUMERATE_USER.format(checklist_text=checklist_text)
messages = [
SystemMessage(content=_ENUMERATE_SYSTEM),
UserMessage(content=user_content),
]
max_output_tokens = _get_max_output_tokens(llm)
try:
with llm_generation_span(llm, "checklist_enumerate", messages) as gen_span:
response = llm.invoke(
messages, timeout_override=300, max_tokens=max_output_tokens
)
record_llm_response(gen_span, response)
raw_text = llm_response_to_string(response)
except Exception as e:
logger.error(f"Pass 1 (enumerate) LLM call failed: {e}")
raise RuntimeError(f"Failed to enumerate checklist items: {str(e)}") from e
parsed = _parse_json_array(raw_text, context="enumerate")
items: list[ChecklistItem] = []
for i, raw in enumerate(parsed):
if not isinstance(raw, dict):
logger.warning(f"Enumerate: skipping non-dict at index {i}")
continue
item_id = str(raw.get("id", f"ITEM-{i + 1}"))
name = raw.get("name")
if not name:
logger.warning(f"Enumerate: skipping item at index {i} (no name)")
continue
items.append(
ChecklistItem(
id=item_id,
name=str(name),
category=str(raw.get("category", f"{item_id}: {name}")),
description=str(raw.get("description", "")),
sub_checks=[str(s) for s in raw.get("sub_checks", [])],
)
)
return items
def decompose_checklist_item(
item: ChecklistItem,
checklist_text: str,
llm: LLM,
) -> list[dict]:
"""Pass 2: Decompose one checklist item into atomic review rules.
One LLM call. Output is bounded (1-10 rules per item).
Args:
item: The checklist item to decompose.
checklist_text: Full checklist text (passed as context).
llm: The LLM instance to use.
Returns:
List of validated rule dicts for this item.
Raises:
RuntimeError: If the LLM call fails or returns unparseable output.
"""
sub_checks_str = "\n".join(f" - {s}" for s in item.sub_checks) or " (none listed)"
user_content = _DECOMPOSE_USER.format(
item_id=item.id,
item_name=item.name,
item_category=item.category,
item_description=item.description,
sub_checks=sub_checks_str,
checklist_text=checklist_text,
)
messages = [
SystemMessage(content=_DECOMPOSE_SYSTEM),
UserMessage(content=user_content),
]
max_output_tokens = _get_max_output_tokens(llm)
try:
with llm_generation_span(
llm, f"checklist_decompose_{item.id}", messages
) as gen_span:
response = llm.invoke(
messages, timeout_override=300, max_tokens=max_output_tokens
)
record_llm_response(gen_span, response)
raw_text = llm_response_to_string(response)
except Exception as e:
raise RuntimeError(f"LLM call failed for item '{item.name}': {str(e)}") from e
parsed = _parse_json_array(raw_text, context=f"decompose[{item.id}]")
rules: list[dict] = []
for i, raw_rule in enumerate(parsed):
if not isinstance(raw_rule, dict):
continue
rule = _validate_rule(raw_rule, i)
if rule:
if not rule["category"]:
rule["category"] = item.category
rules.append(rule)
return rules
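# --- Hedged orchestration sketch (not part of the module) -------------------
# Shows how a caller might compose the two passes above: one enumerate call,
# then one decompose call per item, with persistence delegated to a callback
# supplied by the caller (this module deliberately has no DB access). The
# function and callback signatures are illustrative, not APIs defined elsewhere.
from typing import Callable


def import_checklist_text(
    checklist_text: str,
    llm: LLM,
    persist_rule: Callable[[dict], None],
) -> int:
    items = enumerate_checklist_items(checklist_text, llm)  # pass 1: enumerate
    created = 0
    for item in items:
        rules = decompose_checklist_item(item, checklist_text, llm)  # pass 2
        for rule in rules:
            persist_rule(rule)  # caller decides how and where to store the rule
            created += 1
    return created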
# ---------------------------------------------------------------------------
# Prompts — Refinement (single rule)
# ---------------------------------------------------------------------------
_REFINE_SYSTEM = """\
You are an expert at creating AI review rules for university grant proposal \
review. You are refining a rule that was previously flagged as needing \
institution-specific information. The user has now provided that information.
Variable placeholders available for prompt templates:
{{{{proposal_text}}}} — full proposal and supporting documents
{{{{budget_text}}}} — budget / financial sections
{{{{foa_text}}}} — funding opportunity announcement
{{{{metadata}}}} — structured metadata (PI, sponsor, etc.)
{{{{metadata.FIELD_NAME}}}} — a specific metadata field
Rule types:
DOCUMENT_CHECK — verify presence / content in documents
METADATA_CHECK — validate a structured metadata field
CROSS_REFERENCE — compare information across documents
CUSTOM_NL — natural language evaluation
Rule intents:
CHECK — pass / fail criterion
HIGHLIGHT — informational flag (no pass / fail)"""
_REFINE_USER = """\
The following rule was imported from a checklist but flagged as needing \
additional information before it can be used.
CURRENT RULE:
Name: {rule_name}
Description: {rule_description}
Prompt Template: {rule_prompt_template}
QUESTION THAT WAS ASKED:
{refinement_question}
USER'S ANSWER:
{user_answer}
Using the user's answer, produce a refined version of this rule. \
Incorporate the institution-specific information into the prompt_template \
so the rule is fully self-contained and no longer needs refinement.
Respond with ONLY a single JSON object (not an array):
{{
"name": "Short descriptive name (max 100 chars)",
"description": "What this rule checks",
"rule_type": "DOCUMENT_CHECK | METADATA_CHECK | CROSS_REFERENCE | CUSTOM_NL",
"rule_intent": "CHECK | HIGHLIGHT",
"prompt_template": "Refined self-contained prompt with {{{{variable}}}} placeholders.",
"refinement_needed": false,
"refinement_question": null
}}"""
def refine_rule(
rule_name: str,
rule_description: str | None,
rule_prompt_template: str,
refinement_question: str,
user_answer: str,
llm: LLM,
) -> dict:
"""Refine a single rule using the user's answer to the refinement question.
One LLM call. Returns a validated rule dict with refinement_needed=False.
Raises:
RuntimeError: If the LLM call fails or returns unparseable output.
"""
user_content = _REFINE_USER.format(
rule_name=rule_name,
rule_description=rule_description or "(none)",
rule_prompt_template=rule_prompt_template,
refinement_question=refinement_question,
user_answer=user_answer,
)
messages = [
SystemMessage(content=_REFINE_SYSTEM),
UserMessage(content=user_content),
]
try:
with llm_generation_span(llm, "checklist_refine_rule", messages) as gen_span:
response = llm.invoke(messages, timeout_override=120)
record_llm_response(gen_span, response)
raw_text = llm_response_to_string(response)
except Exception as e:
raise RuntimeError(f"LLM call failed during rule refinement: {str(e)}") from e
# Parse the single JSON object (strip code fences)
text = raw_text.strip()
if text.startswith("```"):
text = re.sub(r"^```(?:json)?\s*\n?", "", text)
text = re.sub(r"\n?```\s*$", "", text)
text = text.strip()
try:
parsed = json.loads(text)
except json.JSONDecodeError as e:
raise RuntimeError(f"LLM returned invalid JSON during refinement: {e}") from e
if isinstance(parsed, list):
if not parsed:
raise RuntimeError("LLM returned an empty array during refinement")
parsed = parsed[0]
if not isinstance(parsed, dict):
raise RuntimeError("LLM returned non-object JSON during refinement")
rule = _validate_rule(parsed, 0)
if not rule:
raise RuntimeError("LLM returned an invalid rule during refinement")
# Force refinement_needed to False
rule["refinement_needed"] = False
rule["refinement_question"] = None
return rule
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _get_max_output_tokens(llm: LLM) -> int:
"""Look up the model's max output tokens from litellm's model cost map."""
try:
model_map = get_model_map()
return get_llm_max_output_tokens(
model_map=model_map,
model_name=llm.config.model_name,
model_provider=llm.config.model_provider,
)
except Exception as e:
logger.warning(f"Failed to resolve max output tokens: {e}")
return int(GEN_AI_MODEL_FALLBACK_MAX_TOKENS)
def _parse_json_array(raw_text: str, context: str) -> list:
"""Parse an LLM response as a JSON array, stripping code fences."""
text = raw_text.strip()
if text.startswith("```"):
text = re.sub(r"^```(?:json)?\s*\n?", "", text)
text = re.sub(r"\n?```\s*$", "", text)
text = text.strip()
try:
parsed = json.loads(text)
except json.JSONDecodeError as e:
logger.error(f"[{context}] Failed to parse JSON: {e}")
logger.debug(f"[{context}] Raw LLM response: {text[:500]}...")
raise RuntimeError(
f"LLM returned invalid JSON during {context}. "
"Please try the import again."
) from e
if not isinstance(parsed, list):
raise RuntimeError(
f"LLM returned non-array JSON during {context}. " "Expected a list."
)
return parsed
def _validate_rule(raw_rule: dict, index: int) -> dict | None:
"""Validate and normalize a single parsed rule dict."""
valid_types = {
"DOCUMENT_CHECK",
"METADATA_CHECK",
"CROSS_REFERENCE",
"CUSTOM_NL",
}
valid_intents = {"CHECK", "HIGHLIGHT"}
name = raw_rule.get("name")
if not name:
logger.warning(f"Rule at index {index} missing 'name', skipping")
return None
prompt_template = raw_rule.get("prompt_template")
if not prompt_template:
logger.warning(f"Rule '{name}' missing 'prompt_template', skipping")
return None
rule_type = str(raw_rule.get("rule_type", "CUSTOM_NL")).upper()
if rule_type not in valid_types:
rule_type = "CUSTOM_NL"
rule_intent = str(raw_rule.get("rule_intent", "CHECK")).upper()
if rule_intent not in valid_intents:
rule_intent = "CHECK"
return {
"name": str(name)[:200],
"description": raw_rule.get("description"),
"category": raw_rule.get("category"),
"rule_type": rule_type,
"rule_intent": rule_intent,
"prompt_template": str(prompt_template),
"refinement_needed": bool(raw_rule.get("refinement_needed", False)),
"refinement_question": (
str(raw_rule["refinement_question"])
if raw_rule.get("refinement_question")
else None
),
}

View File

@@ -1,340 +0,0 @@
"""Assembles all available text content for a proposal to pass to rule evaluation.
V1 LIMITATION: Document body text (the main text content extracted by connectors)
is stored in Vespa, not in the PostgreSQL Document table. The DB row only stores
metadata (semantic_id, link, doc_metadata, primary_owners, etc.). For Jira tickets,
the Description and Comments text are indexed into Vespa during connector runs and
are NOT accessible here without a Vespa query.
As a result, the primary source of rich text for rule evaluation in V1 is:
- Manually uploaded documents (proposal_review_document.extracted_text)
- Structured metadata from the Document row's doc_metadata JSONB column
- For Jira tickets: the connector populates doc_metadata with field values,
which often includes Description, Status, Priority, Assignee, etc.
Future improvement: add a Vespa retrieval step to fetch indexed text chunks for
the parent document and its attachments.
"""
import json
from dataclasses import dataclass
from dataclasses import field
from uuid import UUID
from sqlalchemy.orm import Session
from onyx.db.models import Document
from onyx.server.features.proposal_review.db.models import ProposalReviewDocument
from onyx.server.features.proposal_review.db.models import ProposalReviewProposal
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Metadata keys from Jira connector that commonly carry useful text content.
# These are extracted from doc_metadata and presented as labeled sections to
# give the LLM more signal when evaluating rules.
_JIRA_TEXT_METADATA_KEYS = [
"description",
"summary",
"comment",
"comments",
"acceptance_criteria",
"story_points",
"priority",
"status",
"resolution",
"issue_type",
"labels",
"components",
"fix_versions",
"affects_versions",
"environment",
"assignee",
"reporter",
"creator",
]
@dataclass
class ProposalContext:
"""All text and metadata context assembled for rule evaluation."""
proposal_text: str # concatenated text from all documents
budget_text: str # best-effort budget section extraction
foa_text: str # FOA content (auto-fetched or uploaded)
metadata: dict # structured metadata from Document.doc_metadata
jira_key: str # for display/reference
metadata_raw: dict = field(default_factory=dict) # full unresolved metadata
def get_proposal_context(
proposal_id: UUID,
db_session: Session,
) -> ProposalContext:
"""Assemble context for rule evaluation.
Gathers text from three sources:
1. Jira ticket content (from Document.semantic_id + doc_metadata)
2. Jira attachments (child Documents linked by ID prefix convention)
3. Manually uploaded documents (from proposal_review_document.extracted_text)
For MVP, returns full text of everything. Future: smart section selection.
"""
# 1. Get the proposal record to find the linked document_id
proposal = (
db_session.query(ProposalReviewProposal)
.filter(ProposalReviewProposal.id == proposal_id)
.one_or_none()
)
if not proposal:
logger.warning(f"Proposal {proposal_id} not found during context assembly")
return ProposalContext(
proposal_text="",
budget_text="",
foa_text="",
metadata={},
jira_key="",
metadata_raw={},
)
# 2. Fetch the parent Document (Jira ticket)
parent_doc = (
db_session.query(Document)
.filter(Document.id == proposal.document_id)
.one_or_none()
)
jira_key = ""
metadata: dict = {}
all_text_parts: list[str] = []
budget_parts: list[str] = []
foa_parts: list[str] = []
if parent_doc:
jira_key = parent_doc.semantic_id or ""
metadata = parent_doc.doc_metadata or {}
# Build text from DB-available fields. The actual ticket body text lives
# in Vespa and is not accessible here. The doc_metadata JSONB column
# often contains structured Jira fields that the connector extracted.
parent_text = _build_parent_document_text(parent_doc)
if parent_text:
all_text_parts.append(parent_text)
# 3. Look for child Documents (Jira attachments).
# Jira attachment Documents have IDs of the form:
# "{parent_jira_url}/attachments/{attachment_id}"
# We find them via ID prefix match.
#
# V1 LIMITATION: child document text content is in Vespa, not in the
# DB. We can only extract metadata (filename, mime type, etc.) from
# the Document row. The actual attachment text is not available here
# without a Vespa query. See module docstring for details.
child_docs = _find_child_documents(parent_doc, db_session)
if child_docs:
logger.info(
f"Found {len(child_docs)} child documents for {jira_key}. "
f"Note: their text content is in Vespa and only metadata is "
f"available for rule evaluation."
)
for child_doc in child_docs:
child_text = _build_child_document_text(child_doc)
if child_text:
all_text_parts.append(child_text)
_classify_child_text(child_doc, child_text, budget_parts, foa_parts)
else:
logger.warning(
f"Parent Document not found for proposal {proposal_id} "
f"(document_id={proposal.document_id}). "
f"Context will rely on manually uploaded documents only."
)
# 4. Fetch manually uploaded documents from proposal_review_document.
# This is the PRIMARY source of rich text content for V1 since the
# extracted_text column holds the full document content.
manual_docs = (
db_session.query(ProposalReviewDocument)
.filter(ProposalReviewDocument.proposal_id == proposal_id)
.order_by(ProposalReviewDocument.created_at)
.all()
)
for doc in manual_docs:
if doc.extracted_text:
all_text_parts.append(
f"--- Document: {doc.file_name} (role: {doc.document_role}) ---\n"
f"{doc.extracted_text}"
)
# Classify by role
role_upper = (doc.document_role or "").upper()
if role_upper == "BUDGET" or _is_budget_filename(doc.file_name):
budget_parts.append(doc.extracted_text)
elif role_upper == "FOA":
foa_parts.append(doc.extracted_text)
return ProposalContext(
proposal_text="\n\n".join(all_text_parts) if all_text_parts else "",
budget_text="\n\n".join(budget_parts) if budget_parts else "",
foa_text="\n\n".join(foa_parts) if foa_parts else "",
metadata=metadata,
jira_key=jira_key,
metadata_raw=metadata,
)
def _build_parent_document_text(doc: Document) -> str:
"""Build text representation from a parent Document row (Jira ticket).
The Document table does NOT store the ticket body text -- that lives in Vespa.
What we DO have access to:
- semantic_id: typically "{ISSUE_KEY}: {summary}"
- link: URL to the Jira ticket
- doc_metadata: JSONB with structured fields from the connector (may include
description, status, priority, assignee, custom fields, etc.)
- primary_owners / secondary_owners: people associated with the document
We extract all available metadata and present it as labeled sections to
maximize the signal available to the LLM for rule evaluation.
"""
parts: list[str] = []
if doc.semantic_id:
parts.append(f"Document: {doc.semantic_id}")
if doc.link:
parts.append(f"Link: {doc.link}")
# Include owner information which may be useful for compliance checks
if doc.primary_owners:
parts.append(f"Primary Owners: {', '.join(doc.primary_owners)}")
if doc.secondary_owners:
parts.append(f"Secondary Owners: {', '.join(doc.secondary_owners)}")
# doc_metadata contains structured data from the Jira connector.
# Extract well-known text-bearing fields first, then include the rest.
if doc.doc_metadata:
metadata = doc.doc_metadata
# Extract well-known Jira fields as labeled sections
for key in _JIRA_TEXT_METADATA_KEYS:
value = metadata.get(key)
if value is not None and value != "" and value != []:
label = key.replace("_", " ").title()
if isinstance(value, list):
parts.append(f"{label}: {', '.join(str(v) for v in value)}")
elif isinstance(value, dict):
parts.append(
f"{label}:\n{json.dumps(value, indent=2, default=str)}"
)
else:
parts.append(f"{label}: {value}")
# Include any remaining metadata keys not in the well-known set,
# so custom fields and connector-specific data are not lost.
remaining = {
k: v
for k, v in metadata.items()
if k.lower() not in _JIRA_TEXT_METADATA_KEYS
and v is not None
and v != ""
and v != []
}
if remaining:
parts.append(
f"Additional Metadata:\n"
f"{json.dumps(remaining, indent=2, default=str)}"
)
return "\n".join(parts) if parts else ""
def _build_child_document_text(doc: Document) -> str:
"""Build text representation from a child Document row (Jira attachment).
V1 LIMITATION: The actual extracted text of the attachment lives in Vespa,
not in the Document table. We can only present the metadata that the
connector stored in doc_metadata (filename, mime type, size, parent ticket).
This means the LLM knows an attachment EXISTS and its metadata, but cannot
read its contents. Future versions should add a Vespa retrieval step.
"""
parts: list[str] = []
if doc.semantic_id:
parts.append(f"Attachment: {doc.semantic_id}")
if doc.link:
parts.append(f"Link: {doc.link}")
# Child document metadata typically includes:
# parent_ticket, attachment_filename, attachment_mime_type, attachment_size
if doc.doc_metadata:
for key, value in doc.doc_metadata.items():
if value is not None and value != "":
label = key.replace("_", " ").title()
parts.append(f"{label}: {value}")
if not parts:
return ""
# Note the limitation inline for the LLM context
parts.append(
"[Note: Full attachment text is indexed in Vespa and not available "
"in this context. Upload the document manually for full text analysis.]"
)
return "\n".join(parts)
def _find_child_documents(parent_doc: Document, db_session: Session) -> list[Document]:
"""Find child Documents linked to the parent (e.g. Jira attachments).
Jira attachments are indexed as separate Document rows whose ID follows
the convention: "{parent_document_id}/attachments/{attachment_id}".
The parent_document_id for Jira is the full URL to the issue, e.g.
"https://jira.example.com/browse/PROJ-123".
V1 LIMITATION: These child Document rows only contain metadata in the DB.
Their actual extracted text content is stored in Vespa. To read the
attachment text, a Vespa query would be required. This is not implemented
in V1 -- officers should upload key documents manually for full text
analysis.
"""
if not parent_doc.id:
return []
# Child documents have IDs that start with the parent document's ID
# followed by a path segment (e.g., /attachments/12345)
# Escape LIKE wildcards in the document ID
escaped_id = parent_doc.id.replace("%", r"\%").replace("_", r"\_")
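# PostgreSQL treats backslash as the default LIKE escape character, so the manual
# escaping above works without an explicit escape= argument. For the docstring's
# example parent id, the pattern becomes "https://jira.example.com/browse/PROJ-123/%".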
child_docs = (
db_session.query(Document)
.filter(
Document.id.like(f"{escaped_id}/%"),
Document.id != parent_doc.id,
)
.all()
)
return child_docs
def _classify_child_text(
doc: Document,
text: str,
budget_parts: list[str],
foa_parts: list[str],
) -> None:
"""Best-effort classification of child document text into budget or FOA."""
semantic_id = (doc.semantic_id or "").lower()
if _is_budget_filename(semantic_id):
budget_parts.append(text)
elif any(
term in semantic_id
for term in ["foa", "funding opportunity", "rfa", "solicitation", "nofo"]
):
foa_parts.append(text)
def _is_budget_filename(filename: str) -> bool:
"""Check if a filename suggests budget content."""
lower = (filename or "").lower()
return any(term in lower for term in ["budget", "cost", "financial", "expenditure"])

View File

@@ -1,168 +0,0 @@
"""Auto-fetches Funding Opportunity Announcements using Onyx web search infrastructure."""
from uuid import UUID
from sqlalchemy.orm import Session
from onyx.server.features.proposal_review.db.models import ProposalReviewDocument
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Map known opportunity ID prefixes to federal agency domains
_AGENCY_DOMAINS: dict[str, str] = {
"RFA": "grants.nih.gov",
"PA": "grants.nih.gov",
"PAR": "grants.nih.gov",
"R01": "grants.nih.gov",
"R21": "grants.nih.gov",
"U01": "grants.nih.gov",
"NOT": "grants.nih.gov",
"NSF": "nsf.gov",
"DE-FOA": "energy.gov",
"HRSA": "hrsa.gov",
"W911": "grants.gov", # DoD
"FA": "grants.gov", # Air Force
"N00": "grants.gov", # Navy
"NOFO": "grants.gov",
}
def fetch_foa(
opportunity_id: str,
proposal_id: UUID,
db_session: Session,
) -> str | None:
"""Fetch FOA content given an opportunity ID.
1. Determine domain from ID prefix (RFA/PA -> nih.gov, NSF -> nsf.gov, etc.)
2. Build search query
3. Call Onyx web search provider
4. Fetch full content from best URL
5. Save as proposal_review_document with role=FOA
6. Return extracted text or None
If the web search provider is not configured, logs a warning and returns None.
"""
if not opportunity_id or not opportunity_id.strip():
logger.debug("No opportunity_id provided, skipping FOA fetch")
return None
opportunity_id = opportunity_id.strip()
# Check if we already have an FOA document for this proposal
existing_foa = (
db_session.query(ProposalReviewDocument)
.filter(
ProposalReviewDocument.proposal_id == proposal_id,
ProposalReviewDocument.document_role == "FOA",
)
.first()
)
if existing_foa and existing_foa.extracted_text:
logger.info(
f"FOA document already exists for proposal {proposal_id}, skipping fetch"
)
return existing_foa.extracted_text
# Determine search domain from opportunity ID prefix
site_domain = _determine_domain(opportunity_id)
# Build search query
search_query = f"{opportunity_id} funding opportunity announcement"
if site_domain:
search_query = f"site:{site_domain} {opportunity_id}"
# Try to get the web search provider
try:
from onyx.tools.tool_implementations.web_search.providers import (
get_default_provider,
)
provider = get_default_provider()
except Exception as e:
logger.warning(f"Failed to load web search provider: {e}")
provider = None
if provider is None:
logger.warning(
"No web search provider configured. Cannot auto-fetch FOA. "
"Configure a web search provider in Admin settings to enable this feature."
)
return None
# Search for the FOA
try:
results = provider.search(search_query)
except Exception as e:
logger.error(f"Web search failed for FOA '{opportunity_id}': {e}")
return None
if not results:
logger.info(f"No search results found for FOA '{opportunity_id}'")
return None
# Pick the best result URL
best_url = str(results[0].link)
logger.info(f"Fetching FOA content from: {best_url}")
# Fetch full content from the URL
try:
from onyx.tools.tool_implementations.open_url.onyx_web_crawler import (
OnyxWebCrawler,
)
crawler = OnyxWebCrawler()
contents = crawler.contents([best_url])
if (
not contents
or not contents[0].scrape_successful
or not contents[0].full_content
):
logger.warning(f"No content extracted from FOA URL: {best_url}")
return None
foa_text = contents[0].full_content
except Exception as e:
logger.error(f"Failed to fetch FOA content from {best_url}: {e}")
return None
# Save as a proposal_review_document with role=FOA
try:
foa_doc = ProposalReviewDocument(
proposal_id=proposal_id,
file_name=f"FOA_{opportunity_id}.html",
file_type="HTML",
document_role="FOA",
extracted_text=foa_text,
# uploaded_by is None for auto-fetched documents
)
db_session.add(foa_doc)
db_session.flush()
logger.info(
f"Saved FOA document for proposal {proposal_id} "
f"(opportunity_id={opportunity_id}, {len(foa_text)} chars)"
)
except Exception as e:
logger.error(f"Failed to save FOA document: {e}")
# Still return the text even if save fails
return foa_text
return foa_text
def _determine_domain(opportunity_id: str) -> str | None:
"""Determine the likely agency domain from the opportunity ID prefix."""
upper_id = opportunity_id.upper()
for prefix, domain in _AGENCY_DOMAINS.items():
if upper_id.startswith(prefix):
return domain
# If it looks like a grants.gov number (numeric), try grants.gov
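# e.g. a purely numeric id such as "12-3456" (hypothetical) has only digits once
# hyphens are removed, so it falls back to grants.gov.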
if opportunity_id.replace("-", "").isdigit():
return "grants.gov"
return None

View File

@@ -1,391 +0,0 @@
"""Writes officer decisions back to Jira."""
from datetime import datetime
from datetime import timezone
from uuid import UUID
import requests
from sqlalchemy.orm import Session
from onyx.db.connector import fetch_connector_by_id
from onyx.db.connector_credential_pair import (
fetch_connector_credential_pair_for_connector,
)
from onyx.db.models import Document
from onyx.server.features.proposal_review.db import config as config_db
from onyx.server.features.proposal_review.db import decisions as decisions_db
from onyx.server.features.proposal_review.db import findings as findings_db
from onyx.server.features.proposal_review.db import proposals as proposals_db
from onyx.server.features.proposal_review.db.models import ProposalReviewFinding
from onyx.server.features.proposal_review.db.models import ProposalReviewProposal
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
def sync_to_jira(
proposal_id: UUID,
db_session: Session,
) -> None:
"""Write the officer's final decision back to Jira.
Performs up to 3 Jira API operations:
1. PUT custom fields (decision, completion %)
2. POST transition (move to configured column)
3. POST comment (structured review summary)
Then marks the proposal as jira_synced.
Raises:
ValueError: If required config/data is missing.
RuntimeError: If Jira API calls fail.
"""
tenant_id = get_current_tenant_id()
# Load proposal
proposal = proposals_db.get_proposal(proposal_id, tenant_id, db_session)
if not proposal:
raise ValueError(f"Proposal {proposal_id} not found")
if not proposal.decision_at:
raise ValueError(f"No decision found for proposal {proposal_id}")
if proposal.jira_synced:
logger.info(f"Decision for proposal {proposal_id} already synced to Jira")
return
# Load tenant config for Jira settings
config = config_db.get_config(tenant_id, db_session)
if not config:
raise ValueError("Proposal review config not found for this tenant")
if not config.jira_connector_id:
raise ValueError(
"No Jira connector configured. Set jira_connector_id in proposal review settings."
)
writeback_config = config.jira_writeback or {}
# Get the Jira issue key from the linked Document
parent_doc = (
db_session.query(Document)
.filter(Document.id == proposal.document_id)
.one_or_none()
)
if not parent_doc:
raise ValueError(f"Linked document {proposal.document_id} not found")
# semantic_id is formatted as "KEY-123: Summary text" by the Jira connector.
# Extract just the issue key (everything before the first colon).
raw_id = parent_doc.semantic_id
if not raw_id:
raise ValueError(
f"Document {proposal.document_id} has no semantic_id (Jira issue key)"
)
issue_key = raw_id.split(":")[0].strip()
# Get Jira credentials from the connector
jira_base_url, auth_headers = _get_jira_credentials(
config.jira_connector_id, db_session
)
# Get findings for the summary
latest_run = findings_db.get_latest_review_run(proposal_id, db_session)
all_findings: list[ProposalReviewFinding] = []
if latest_run:
all_findings = findings_db.list_findings_by_run(latest_run.id, db_session)
# Calculate summary counts
verdict_counts = _count_verdicts(all_findings)
# Operation 1: Update custom fields
_update_custom_fields(
jira_base_url=jira_base_url,
auth_headers=auth_headers,
issue_key=issue_key,
decision=proposal.status,
verdict_counts=verdict_counts,
writeback_config=writeback_config,
)
# Operation 2: Transition the issue
_transition_issue(
jira_base_url=jira_base_url,
auth_headers=auth_headers,
issue_key=issue_key,
decision=proposal.status,
writeback_config=writeback_config,
)
# Operation 3: Post review summary comment
_post_comment(
jira_base_url=jira_base_url,
auth_headers=auth_headers,
issue_key=issue_key,
proposal=proposal,
verdict_counts=verdict_counts,
findings=all_findings,
)
# Mark as synced
decisions_db.mark_proposal_jira_synced(proposal_id, tenant_id, db_session)
db_session.flush()
logger.info(
f"Successfully synced decision for proposal {proposal_id} to Jira issue {issue_key}"
)
def _get_jira_credentials(
connector_id: int,
db_session: Session,
) -> tuple[str, dict[str, str]]:
"""Extract Jira base URL and auth headers from the connector's credentials.
Returns:
Tuple of (jira_base_url, auth_headers_dict).
"""
connector = fetch_connector_by_id(connector_id, db_session)
if not connector:
raise ValueError(f"Jira connector {connector_id} not found")
# Get the connector's credential pair
cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
if not cc_pair:
raise ValueError(f"No credential pair found for connector {connector_id}")
# Extract credentials — guard against missing credential_json
cred_json = cc_pair.credential.credential_json
if cred_json is None:
raise ValueError(f"No credential_json for connector {connector_id}")
credentials = cred_json.get_value(apply_mask=False)
if not credentials:
raise ValueError(f"Empty credentials for connector {connector_id}")
# Extract Jira base URL from connector config
connector_config = connector.connector_specific_config or {}
jira_base_url = connector_config.get("jira_base_url", "")
if not jira_base_url:
raise ValueError("Could not determine Jira base URL from connector config")
# Build auth headers
api_token = credentials.get("jira_api_token", "")
email = credentials.get("jira_user_email")
if email:
# Cloud auth: Basic auth with email:token
import base64
auth_string = base64.b64encode(f"{email}:{api_token}".encode()).decode()
auth_headers = {
"Authorization": f"Basic {auth_string}",
"Content-Type": "application/json",
}
else:
# Server auth: Bearer token
auth_headers = {
"Authorization": f"Bearer {api_token}",
"Content-Type": "application/json",
}
return jira_base_url, auth_headers
def _count_verdicts(findings: list[ProposalReviewFinding]) -> dict[str, int]:
"""Count findings by verdict."""
counts: dict[str, int] = {
"PASS": 0,
"FAIL": 0,
"FLAG": 0,
"NEEDS_REVIEW": 0,
"NOT_APPLICABLE": 0,
}
for f in findings:
verdict = f.verdict.upper() if f.verdict else "NEEDS_REVIEW"
counts[verdict] = counts.get(verdict, 0) + 1
return counts
def _update_custom_fields(
jira_base_url: str,
auth_headers: dict[str, str],
issue_key: str,
decision: str,
verdict_counts: dict[str, int],
writeback_config: dict,
) -> None:
"""PUT custom fields on the Jira issue (decision, completion %)."""
decision_field = writeback_config.get("decision_field_id")
completion_field = writeback_config.get("completion_field_id")
if not decision_field and not completion_field:
logger.debug("No custom field IDs configured for Jira writeback, skipping")
return
fields: dict = {}
if decision_field:
fields[decision_field] = decision
if completion_field:
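# Completion % = rules with a final verdict / total rules evaluated,
# e.g. 10 findings with 2 NEEDS_REVIEW -> (10 - 2) / 10 * 100 = 80.0.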
total = sum(verdict_counts.values())
completed = total - verdict_counts.get("NEEDS_REVIEW", 0)
pct = (completed / total * 100) if total > 0 else 0
fields[completion_field] = round(pct, 1)
url = f"{jira_base_url}/rest/api/3/issue/{issue_key}"
payload = {"fields": fields}
try:
resp = requests.put(url, headers=auth_headers, json=payload, timeout=30)
resp.raise_for_status()
logger.info(f"Updated custom fields on {issue_key}")
except requests.RequestException as e:
logger.error(f"Failed to update custom fields on {issue_key}: {e}")
raise RuntimeError(f"Jira field update failed: {e}") from e
def _transition_issue(
jira_base_url: str,
auth_headers: dict[str, str],
issue_key: str,
decision: str,
writeback_config: dict,
) -> None:
"""POST a transition to move the issue to the appropriate column."""
transition_map = writeback_config.get("transitions", {})
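# Assumed config shape (hypothetical values): {"APPROVED": "Done", "REJECTED": "Rejected"},
# mapping a proposal decision status to the Jira transition name to apply.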
transition_name = transition_map.get(decision)
if not transition_name:
logger.debug(f"No transition configured for decision '{decision}', skipping")
return
# First, get available transitions
transitions_url = f"{jira_base_url}/rest/api/3/issue/{issue_key}/transitions"
try:
resp = requests.get(transitions_url, headers=auth_headers, timeout=30)
resp.raise_for_status()
available = resp.json().get("transitions", [])
except requests.RequestException as e:
logger.error(f"Failed to fetch transitions for {issue_key}: {e}")
raise RuntimeError(f"Jira transition fetch failed: {e}") from e
# Find the matching transition by name (case-insensitive)
target_transition = None
for t in available:
if t.get("name", "").lower() == transition_name.lower():
target_transition = t
break
if not target_transition:
available_names = [t.get("name", "") for t in available]
logger.warning(
f"Transition '{transition_name}' not found for {issue_key}. "
f"Available: {available_names}"
)
return
# Perform the transition
payload = {"transition": {"id": target_transition["id"]}}
try:
resp = requests.post(
transitions_url, headers=auth_headers, json=payload, timeout=30
)
resp.raise_for_status()
logger.info(f"Transitioned {issue_key} to '{transition_name}'")
except requests.RequestException as e:
logger.error(f"Failed to transition {issue_key}: {e}")
raise RuntimeError(f"Jira transition failed: {e}") from e
def _post_comment(
jira_base_url: str,
auth_headers: dict[str, str],
issue_key: str,
proposal: ProposalReviewProposal,
verdict_counts: dict[str, int],
findings: list[ProposalReviewFinding],
) -> None:
"""POST a structured review summary as a Jira comment."""
comment_text = _build_comment_text(proposal, verdict_counts, findings)
url = f"{jira_base_url}/rest/api/3/issue/{issue_key}/comment"
# Jira Cloud uses ADF (Atlassian Document Format) for comments
payload = {
"body": {
"version": 1,
"type": "doc",
"content": [
{
"type": "paragraph",
"content": [
{
"type": "text",
"text": comment_text,
}
],
}
],
}
}
try:
resp = requests.post(url, headers=auth_headers, json=payload, timeout=30)
resp.raise_for_status()
logger.info(f"Posted review summary comment on {issue_key}")
except requests.RequestException as e:
logger.error(f"Failed to post comment on {issue_key}: {e}")
raise RuntimeError(f"Jira comment post failed: {e}") from e
def _build_comment_text(
proposal: ProposalReviewProposal,
verdict_counts: dict[str, int],
findings: list[ProposalReviewFinding],
) -> str:
"""Build a structured review summary text for the Jira comment."""
lines: list[str] = []
lines.append("=== Proposal Review Summary ===")
lines.append("")
# Decision
decision_text = proposal.status or "N/A"
decision_notes = proposal.decision_notes
lines.append(f"Final Decision: {decision_text}")
if decision_notes:
lines.append(f"Notes: {decision_notes}")
lines.append("")
# Summary counts
total = sum(verdict_counts.values())
lines.append(f"Review Results ({total} rules evaluated):")
lines.append(f" Pass: {verdict_counts.get('PASS', 0)}")
lines.append(f" Fail: {verdict_counts.get('FAIL', 0)}")
lines.append(f" Flag: {verdict_counts.get('FLAG', 0)}")
lines.append(f" Needs Review: {verdict_counts.get('NEEDS_REVIEW', 0)}")
lines.append(f" Not Applicable: {verdict_counts.get('NOT_APPLICABLE', 0)}")
lines.append("")
# Individual findings (truncated for readability)
if findings:
lines.append("--- Detailed Findings ---")
for f in findings:
rule_name = f.rule.name if f.rule else "Unknown Rule"
verdict = f.verdict or "N/A"
officer_action = ""
if f.decision_action:
officer_action = f" | Officer: {f.decision_action}"
lines.append(f" [{verdict}] {rule_name}{officer_action}")
if f.explanation:
# Truncate long explanations
explanation = f.explanation[:200]
if len(f.explanation) > 200:
explanation += "..."
lines.append(f" Reason: {explanation}")
lines.append("")
lines.append(f"Reviewed at: {datetime.now(timezone.utc).isoformat()}")
lines.append("Generated by Onyx Proposal Review")
return "\n".join(lines)

View File

@@ -1,532 +0,0 @@
"""Proposal review engine — private helpers for task implementations.
The actual Celery @shared_task definitions live in tasks.py (for autodiscovery).
This module contains the orchestration and evaluation logic they delegate to.
"""
from __future__ import annotations
import contextvars
import os
import time
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from datetime import timezone
from typing import TYPE_CHECKING
from uuid import UUID
from sqlalchemy import update
from onyx.utils.logger import setup_logger
if TYPE_CHECKING:
from onyx.server.features.proposal_review.engine.context_assembler import (
ProposalContext,
)
logger = setup_logger()
def _execute_review(
review_run_id: str,
rule_ids: list[str] | None = None,
) -> None:
"""Core review logic, separated for testability.
When rule_ids is None, evaluates all active rules in the run's ruleset
(full run). When rule_ids is provided, deletes the old error findings
for those rules and re-evaluates only them (retry flow).
"""
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.features.proposal_review.db import findings as findings_db
from onyx.server.features.proposal_review.db import rulesets as rulesets_db
from onyx.server.features.proposal_review.db.models import ProposalReviewRun
from onyx.server.features.proposal_review.engine.context_assembler import (
get_proposal_context,
)
from onyx.server.features.proposal_review.engine.foa_fetcher import fetch_foa
run_uuid = UUID(review_run_id)
is_retry = rule_ids is not None
if is_retry and not rule_ids:
logger.warning(f"Retry called with empty rule_ids for run {review_run_id}")
# Reset status since the API already set it to RUNNING
with get_session_with_current_tenant() as db_session:
run = findings_db.get_review_run(run_uuid, db_session)
if run and run.status == "RUNNING":
run.status = "COMPLETED"
db_session.commit()
return
# Step 1: Set run status to RUNNING; for retries, clean up old findings
with get_session_with_current_tenant() as db_session:
run = findings_db.get_review_run(run_uuid, db_session)
if not run:
raise ValueError(f"Review run {review_run_id} not found")
proposal_id = run.proposal_id
ruleset_id = run.ruleset_id
if is_retry:
rule_id_set = set(rule_ids)
# Delete old error findings for the rules being retried
failed = findings_db.get_failed_findings_for_run(run_uuid, db_session)
failed_for_rules = [f for f in failed if str(f.rule_id) in rule_id_set]
if failed_for_rules:
findings_db.delete_findings(
[f.id for f in failed_for_rules], db_session
)
# Roll back counters so re-evaluated rules are tracked correctly.
# completed_rules is rolled back by the number of rules being
# re-evaluated (not findings — a rule may lack a finding if
# _save_error_finding itself failed). failed_rules is rolled back
# by the number of error findings actually deleted.
run.completed_rules = max(0, run.completed_rules - len(rule_ids))
run.failed_rules = max(0, run.failed_rules - len(failed_for_rules))
run.completed_at = None
run.status = "RUNNING"
if not is_retry:
run.started_at = datetime.now(timezone.utc)
db_session.commit()
# Step 2: Assemble proposal context
with get_session_with_current_tenant() as db_session:
context = get_proposal_context(proposal_id, db_session)
# Step 3: Try to auto-fetch FOA if opportunity_id is in metadata
opportunity_id = context.metadata.get("opportunity_id") or context.metadata.get(
"funding_opportunity_number"
)
if opportunity_id and not context.foa_text:
logger.info(f"Attempting to auto-fetch FOA for opportunity_id={opportunity_id}")
try:
with get_session_with_current_tenant() as db_session:
foa_text = fetch_foa(opportunity_id, proposal_id, db_session)
db_session.commit()
if foa_text:
context.foa_text = foa_text
logger.info(f"Auto-fetched FOA: {len(foa_text)} chars")
except Exception as e:
logger.warning(f"FOA auto-fetch failed (non-fatal): {e}")
# Step 4: Determine which rules to evaluate
if is_retry:
# Retry: use the specific rule IDs passed in
rules_to_eval = rule_ids
else:
# Full run: get all active rules for the ruleset
with get_session_with_current_tenant() as db_session:
rules = rulesets_db.list_rules_by_ruleset(
ruleset_id, db_session, active_only=True
)
rules_to_eval = [str(rule.id) for rule in rules]
if not rules_to_eval:
logger.warning(f"No active rules found for ruleset {ruleset_id}")
with get_session_with_current_tenant() as db_session:
run = findings_db.get_review_run(run_uuid, db_session)
if run:
run.status = "COMPLETED"
run.completed_at = datetime.now(timezone.utc)
db_session.commit()
return
# Step 5: Update total_rules on the run (full run only). Retries keep the
# original total so the rolled-back completed_rules count still adds back up to it.
if not is_retry:
with get_session_with_current_tenant() as db_session:
run = findings_db.get_review_run(run_uuid, db_session)
if run:
run.total_rules = len(rules_to_eval)
db_session.commit()
# Step 6: Evaluate rules in parallel via ThreadPoolExecutor
parallel_workers = int(os.environ.get("PROPOSAL_REVIEW_PARALLEL_WORKERS", "4"))
workers = min(parallel_workers, len(rules_to_eval))
with ThreadPoolExecutor(max_workers=workers) as executor:
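# copy_context().run re-enters the caller's contextvars (including the current
# tenant id) inside each worker thread, so helpers like
# get_session_with_current_tenant resolve the correct tenant per thread.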
future_to_rule_id = {
executor.submit(
contextvars.copy_context().run,
_evaluate_single_rule,
review_run_id,
rid,
proposal_id,
context,
): rid
for rid in rules_to_eval
}
for future in as_completed(future_to_rule_id):
rid = future_to_rule_id[future]
succeeded = True
try:
succeeded = future.result()
except Exception as e:
succeeded = False
logger.error(
f"Rule {rid} failed: {e}",
exc_info=True,
)
# Increment completed_rules (and failed_rules on error) atomically
# so the frontend progress bar always reaches 100%.
updates: dict = {
"completed_rules": ProposalReviewRun.completed_rules + 1,
}
if not succeeded:
updates["failed_rules"] = ProposalReviewRun.failed_rules + 1
with get_session_with_current_tenant() as db_session:
db_session.execute(
update(ProposalReviewRun)
.where(ProposalReviewRun.id == run_uuid)
.values(**updates)
)
db_session.commit()
# Step 7: Mark run as completed
with get_session_with_current_tenant() as db_session:
run = findings_db.get_review_run(run_uuid, db_session)
if run:
run.status = "COMPLETED"
run.completed_at = datetime.now(timezone.utc)
db_session.commit()
logger.info(
f"Review run {review_run_id} completed: "
f"{len(rules_to_eval)} rules evaluated"
f"{' (retry)' if is_retry else ''}"
)
def _evaluate_and_save(
review_run_id: str,
rule_id: str,
proposal_id: "UUID",
context: "ProposalContext",
) -> None:
"""Evaluate a single rule and save the finding to DB."""
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.features.proposal_review.db import findings as findings_db
from onyx.server.features.proposal_review.db import rulesets as rulesets_db
from onyx.server.features.proposal_review.engine.rule_evaluator import (
evaluate_rule,
)
rule_uuid = UUID(rule_id)
run_uuid = UUID(review_run_id)
# Load the rule from DB
with get_session_with_current_tenant() as db_session:
rule = rulesets_db.get_rule(rule_uuid, db_session)
if not rule:
raise ValueError(f"Rule {rule_id} not found")
# Evaluate the rule
result = evaluate_rule(rule, context, db_session)
# Save finding
findings_db.create_finding(
proposal_id=proposal_id,
rule_id=rule_uuid,
review_run_id=run_uuid,
verdict=result["verdict"],
confidence=result.get("confidence"),
evidence=result.get("evidence"),
explanation=result.get("explanation"),
suggested_action=result.get("suggested_action"),
llm_model=result.get("llm_model"),
llm_tokens_used=result.get("llm_tokens_used"),
db_session=db_session,
)
db_session.commit()
logger.debug(f"Rule {rule_id} evaluated: verdict={result['verdict']}")
def _save_error_finding(
review_run_id: str,
rule_id: str,
proposal_id: "UUID",
error: str,
) -> None:
"""Save an error finding when a rule evaluation fails."""
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.features.proposal_review.db import findings as findings_db
try:
with get_session_with_current_tenant() as db_session:
findings_db.create_finding(
proposal_id=proposal_id,
rule_id=UUID(rule_id),
review_run_id=UUID(review_run_id),
verdict="NEEDS_REVIEW",
confidence="LOW",
evidence=None,
explanation=f"Rule evaluation failed with error: {error}",
suggested_action="Manual review required due to system error.",
db_session=db_session,
)
db_session.commit()
except Exception as e:
logger.error(f"Failed to save error finding for rule {rule_id}: {e}")
def _mark_run_failed(review_run_id: str) -> None:
"""Mark a review run as FAILED."""
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.features.proposal_review.db import findings as findings_db
try:
with get_session_with_current_tenant() as db_session:
run = findings_db.get_review_run(UUID(review_run_id), db_session)
if run:
run.status = "FAILED"
run.completed_at = datetime.now(timezone.utc)
db_session.commit()
except Exception as e:
logger.error(f"Failed to mark run {review_run_id} as FAILED: {e}")
_MAX_RULE_RETRIES = int(os.environ.get("PROPOSAL_REVIEW_RULE_MAX_RETRIES", "2"))
_RETRY_BACKOFF_BASE = 2 # seconds — retry waits 2s, 4s, ...
def _evaluate_single_rule(
review_run_id: str,
rule_id: str,
proposal_id: "UUID",
context: "ProposalContext",
) -> bool:
"""Evaluate one rule, save the finding. Called from ThreadPoolExecutor.
Context is shared in-memory from the parent — no DB re-fetch needed.
Retries up to _MAX_RULE_RETRIES times with exponential backoff on failure
(e.g. LLM timeout). On final failure, an error finding (NEEDS_REVIEW) is
saved so the officer sees which rule failed.
Returns True on success, False if all attempts failed.
"""
last_error: Exception | None = None
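# With the default _MAX_RULE_RETRIES of 2, each rule is attempted up to 3 times,
# sleeping 2s after the first failure and 4s after the second.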
for attempt in range(_MAX_RULE_RETRIES + 1):
try:
_evaluate_and_save(review_run_id, rule_id, proposal_id, context)
return True
except Exception as e:
last_error = e
if attempt < _MAX_RULE_RETRIES:
wait = _RETRY_BACKOFF_BASE * (2**attempt)
logger.warning(
f"Rule {rule_id} attempt {attempt + 1} failed: {e}. "
f"Retrying in {wait}s..."
)
time.sleep(wait)
else:
logger.error(
f"Rule {rule_id} failed after {attempt + 1} attempts: {e}",
exc_info=True,
)
_save_error_finding(
review_run_id=review_run_id,
rule_id=rule_id,
proposal_id=proposal_id,
error=str(last_error),
)
return False
def _execute_checklist_import(import_job_id: str) -> None:
"""Core import logic, separated for traceability."""
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.factory import get_default_llm
from onyx.server.features.proposal_review.db import imports as imports_db
from onyx.server.features.proposal_review.db import rulesets as rulesets_db
from onyx.server.features.proposal_review.engine.checklist_importer import (
decompose_checklist_item,
enumerate_checklist_items,
)
job_uuid = UUID(import_job_id)
parallel_workers = int(
os.environ.get("PROPOSAL_REVIEW_IMPORT_PARALLEL_WORKERS", "4")
)
# Step 1: Mark RUNNING and load job data
with get_session_with_current_tenant() as db_session:
job = imports_db.get_import_job(job_uuid, db_session)
if not job:
raise ValueError(f"Import job {import_job_id} not found")
job.status = "RUNNING"
db_session.commit()
ruleset_id = job.ruleset_id
extracted_text = job.extracted_text
llm = get_default_llm(timeout=300)
# Step 2: Enumerate checklist items
items = enumerate_checklist_items(extracted_text, llm)
if not items:
logger.warning(f"Import {import_job_id}: no checklist items found")
with get_session_with_current_tenant() as db_session:
job = imports_db.get_import_job(job_uuid, db_session)
if job:
job.status = "COMPLETED"
job.completed_at = datetime.now(timezone.utc)
db_session.commit()
return
# Split items with too many sub-checks into smaller pieces so each
# LLM call produces bounded output. The threshold is conservative —
# 3 sub-checks keeps output well within token limits.
max_sub_checks = int(
os.environ.get("PROPOSAL_REVIEW_IMPORT_MAX_SUB_CHECKS_PER_CALL", "3")
)
work_items = _split_large_items(items, max_sub_checks)
logger.info(
f"Import {import_job_id}: enumerated {len(items)} items "
f"({len(work_items)} work units after splitting), "
f"decomposing with {parallel_workers} workers"
)
# Step 3: Decompose each work item in parallel, persist as each completes
rules_created = 0
failed_items: list[str] = []
workers = min(parallel_workers, len(work_items))
with ThreadPoolExecutor(max_workers=workers) as executor:
future_to_item = {
executor.submit(
contextvars.copy_context().run,
decompose_checklist_item,
item,
extracted_text,
llm,
): item
for item in work_items
}
for future in as_completed(future_to_item):
item = future_to_item[future]
try:
rule_dicts = future.result()
except Exception as e:
logger.error(
f" [{item.id}] '{item.name}' failed: {e}",
exc_info=True,
)
failed_items.append(item.name)
continue
if not rule_dicts:
continue
# Persist this item's rules in their own transaction
with get_session_with_current_tenant() as db_session:
for rd in rule_dicts:
rule = rulesets_db.create_rule(
ruleset_id=ruleset_id,
name=rd["name"],
description=rd.get("description"),
category=rd.get("category"),
rule_type=rd.get("rule_type", "CUSTOM_NL"),
rule_intent=rd.get("rule_intent", "CHECK"),
prompt_template=rd["prompt_template"],
source="IMPORTED",
is_hard_stop=False,
priority=0,
refinement_needed=rd.get("refinement_needed", False),
refinement_question=rd.get("refinement_question"),
db_session=db_session,
)
rule.is_active = False
db_session.flush()
rules_created += len(rule_dicts)
# Update progress so the frontend can poll it
job = imports_db.get_import_job(job_uuid, db_session)
if job:
job.rules_created = rules_created
db_session.commit()
logger.info(
f" [{item.id}] '{item.name}': "
f"{len(rule_dicts)} rules persisted "
f"({rules_created} total)"
)
# Step 4: Mark completed
with get_session_with_current_tenant() as db_session:
job = imports_db.get_import_job(job_uuid, db_session)
if job:
job.status = "COMPLETED"
job.completed_at = datetime.now(timezone.utc)
db_session.commit()
status = f"{rules_created} rules created"
if failed_items:
status += f", {len(failed_items)} items failed: {failed_items}"
logger.info(f"Import job {import_job_id} completed: {status}")
def _mark_import_failed(import_job_id: str, error: str) -> None:
"""Mark an import job as FAILED."""
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.features.proposal_review.db import imports as imports_db
try:
with get_session_with_current_tenant() as db_session:
job = imports_db.get_import_job(UUID(import_job_id), db_session)
if job:
job.status = "FAILED"
job.error_message = error
job.completed_at = datetime.now(timezone.utc)
db_session.commit()
except Exception as e:
logger.error(f"Failed to mark import job {import_job_id} as FAILED: {e}")
def _split_large_items(
items: list, # list[ChecklistItem] — untyped to avoid top-level import
max_sub_checks: int,
) -> list:
"""Split checklist items with many sub-checks into smaller work units.
Each returned item has at most *max_sub_checks* sub-checks, keeping the
LLM output bounded regardless of how large the original item was. Items
that are already within the limit pass through unchanged.
"""
from onyx.server.features.proposal_review.engine.checklist_importer import (
ChecklistItem,
)
work_items: list[ChecklistItem] = []
for item in items:
if len(item.sub_checks) <= max_sub_checks:
work_items.append(item)
continue
# Split into batches, each becoming its own work unit
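# e.g. an item with 7 sub-checks and max_sub_checks=3 becomes three work units:
# "<id>-p1" (3 checks), "<id>-p2" (3 checks), "<id>-p3" (1 check).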
for batch_idx in range(0, len(item.sub_checks), max_sub_checks):
batch = item.sub_checks[batch_idx : batch_idx + max_sub_checks]
part_num = (batch_idx // max_sub_checks) + 1
work_items.append(
ChecklistItem(
id=f"{item.id}-p{part_num}",
name=item.name,
category=item.category,
description=item.description,
sub_checks=batch,
)
)
return work_items

View File

@@ -1,202 +0,0 @@
"""Evaluates a single rule against a proposal context via LLM."""
import json
import re
from sqlalchemy.orm import Session
from onyx.llm.factory import get_default_llm
from onyx.llm.models import SystemMessage
from onyx.llm.models import UserMessage
from onyx.llm.utils import llm_response_to_string
from onyx.server.features.proposal_review.db.models import ProposalReviewRule
from onyx.server.features.proposal_review.engine.context_assembler import (
ProposalContext,
)
from onyx.tracing.llm_utils import llm_generation_span
from onyx.tracing.llm_utils import record_llm_response
from onyx.utils.logger import setup_logger
logger = setup_logger()
SYSTEM_PROMPT = """\
You are a meticulous grant proposal compliance reviewer for a university research office.
Your role is to evaluate specific aspects of grant proposals against institutional
and sponsor requirements.
You must evaluate each rule independently, focusing ONLY on the specific criterion
described. Be precise in your assessment. When in doubt, mark for human review.
Always respond with a valid JSON object in the exact format specified."""
RESPONSE_FORMAT_INSTRUCTIONS = """
Respond with ONLY a valid JSON object in the following format:
{
"verdict": "PASS | FAIL | FLAG | NEEDS_REVIEW | NOT_APPLICABLE",
"confidence": "HIGH | MEDIUM | LOW",
"evidence": "Direct quote or reference from the proposal documents that supports your verdict. If no relevant text found, state that clearly.",
"explanation": "Concise reasoning for why this verdict was reached. Reference specific requirements and how the proposal does or does not meet them.",
"suggested_action": "If verdict is FAIL or FLAG, describe what the officer or PI should do. Otherwise, null."
}
Verdict meanings:
- PASS: The proposal clearly meets this requirement.
- FAIL: The proposal clearly does NOT meet this requirement.
- FLAG: There is a potential issue that needs human attention.
- NEEDS_REVIEW: Insufficient information to make a determination.
- NOT_APPLICABLE: This rule does not apply to this proposal.
"""
def evaluate_rule(
rule: ProposalReviewRule,
context: ProposalContext,
_db_session: Session | None = None,
) -> dict:
"""Evaluate one rule against proposal context via LLM.
1. Fills rule.prompt_template variables ({{proposal_text}}, {{metadata}}, etc.)
2. Wraps in system prompt establishing reviewer role
3. Calls llm.invoke() with structured output instructions
4. Parses response into a findings dict
Args:
rule: The rule to evaluate.
context: Assembled proposal context.
_db_session: Optional DB session (unused for the LLM call; kept for API compatibility).
Returns:
Dict with verdict, confidence, evidence, explanation, suggested_action,
plus llm_model and llm_tokens_used if available.
"""
# 1. Fill template variables
filled_prompt = _fill_template(rule.prompt_template, context)
# 2. Build full prompt
user_content = f"{filled_prompt}\n\n" f"{RESPONSE_FORMAT_INSTRUCTIONS}"
prompt_messages = [
SystemMessage(content=SYSTEM_PROMPT),
UserMessage(content=user_content),
]
# 3. Call LLM — exceptions propagate to the caller so the retry
# mechanism in _evaluate_single_rule can handle transient failures.
llm = get_default_llm()
with llm_generation_span(llm, "proposal_review", prompt_messages) as gen_span:
response = llm.invoke(prompt_messages)
record_llm_response(gen_span, response)
raw_text = llm_response_to_string(response)
# Extract model info
llm_model = llm.config.model_name if llm.config else None
llm_tokens_used = _extract_token_usage(response)
# 4. Parse JSON response
result = _parse_llm_response(raw_text)
result["llm_model"] = llm_model
result["llm_tokens_used"] = llm_tokens_used
return result
def _fill_template(template: str, context: ProposalContext) -> str:
"""Replace {{variable}} placeholders in the prompt template.
Supported variables:
- {{proposal_text}} -> context.proposal_text
- {{budget_text}} -> context.budget_text
- {{foa_text}} -> context.foa_text
- {{metadata}} -> JSON dump of context.metadata
- {{metadata.FIELD}} -> specific metadata field value
- {{jira_key}} -> context.jira_key
"""
result = template
# Direct substitutions
result = result.replace("{{proposal_text}}", context.proposal_text or "")
result = result.replace("{{budget_text}}", context.budget_text or "")
result = result.replace("{{foa_text}}", context.foa_text or "")
result = result.replace("{{jira_key}}", context.jira_key or "")
# Metadata as JSON
metadata_str = json.dumps(context.metadata, indent=2, default=str)
result = result.replace("{{metadata}}", metadata_str)
# Specific metadata fields: {{metadata.FIELD}}
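# e.g. "{{metadata.sponsor_name}}" (hypothetical field) is replaced by the field's
# value as a string (JSON-dumped for dicts/lists), or "" if the field is absent.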
metadata_field_pattern = re.compile(r"\{\{metadata\.([^}]+)\}\}")
for match in metadata_field_pattern.finditer(result):
field_name = match.group(1)
field_value = context.metadata.get(field_name, "")
if isinstance(field_value, (dict, list)):
field_value = json.dumps(field_value, default=str)
result = result.replace(match.group(0), str(field_value))
return result
def _parse_llm_response(raw_text: str) -> dict:
"""Parse the LLM response text as JSON.
Handles cases where the LLM wraps JSON in markdown code fences.
"""
text = raw_text.strip()
# Strip markdown code fences if present
if text.startswith("```"):
# Remove opening fence (with optional language tag)
text = re.sub(r"^```(?:json)?\s*\n?", "", text)
# Remove closing fence
text = re.sub(r"\n?```\s*$", "", text)
text = text.strip()
try:
parsed = json.loads(text)
except json.JSONDecodeError:
logger.warning(f"Failed to parse LLM response as JSON: {text[:200]}...")
return {
"verdict": "NEEDS_REVIEW",
"confidence": "LOW",
"evidence": None,
"explanation": f"Failed to parse LLM response. Raw output: {text[:500]}",
"suggested_action": "Manual review required due to unparseable AI response.",
}
# Validate and normalize the parsed result
valid_verdicts = {"PASS", "FAIL", "FLAG", "NEEDS_REVIEW", "NOT_APPLICABLE"}
valid_confidences = {"HIGH", "MEDIUM", "LOW"}
verdict = str(parsed.get("verdict", "NEEDS_REVIEW")).upper()
if verdict not in valid_verdicts:
verdict = "NEEDS_REVIEW"
confidence = str(parsed.get("confidence", "LOW")).upper()
if confidence not in valid_confidences:
confidence = "LOW"
return {
"verdict": verdict,
"confidence": confidence,
"evidence": parsed.get("evidence"),
"explanation": parsed.get("explanation"),
"suggested_action": parsed.get("suggested_action"),
}
def _extract_token_usage(response: object) -> int | None:
"""Best-effort extraction of token usage from the LLM response."""
try:
# litellm ModelResponse has a usage attribute
if hasattr(response, "usage") and response.usage:
usage = response.usage
total = getattr(usage, "total_tokens", None)
if total is not None:
return int(total)
# Sum prompt + completion tokens if total not available
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
if prompt_tokens or completion_tokens:
return prompt_tokens + completion_tokens
except Exception:
pass
return None

View File

@@ -1,195 +0,0 @@
"""Celery tasks for proposal review — discovered by autodiscover_tasks."""
from __future__ import annotations
from uuid import UUID
from celery import shared_task
from redis.lock import Lock as RedisLock
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@shared_task(
name="run_proposal_review",
bind=True,
ignore_result=True,
soft_time_limit=3600,
time_limit=3660,
)
def run_proposal_review(
_self: object,
review_run_id: str,
tenant_id: str,
rule_ids: list[str] | None = None,
) -> None:
"""Evaluate rules for a review run.
When rule_ids is None, evaluates all active rules in the run's ruleset
(full run). When rule_ids is provided, evaluates only those specific
rules (retry flow).
"""
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
from onyx.tracing.framework.create import trace
with trace(
"proposal_review",
metadata={"review_run_id": review_run_id},
):
from onyx.server.features.proposal_review.engine.review_engine import (
_execute_review,
)
_execute_review(review_run_id, rule_ids=rule_ids)
except Exception as e:
logger.error(f"Review run {review_run_id} failed: {e}", exc_info=True)
from onyx.server.features.proposal_review.engine.review_engine import (
_mark_run_failed,
)
_mark_run_failed(review_run_id)
raise
finally:
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
@shared_task(name="run_checklist_import", bind=True, ignore_result=True)
def run_checklist_import(_self: object, import_job_id: str, tenant_id: str) -> None:
"""Background task: decompose a checklist via LLM and save rules."""
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
from onyx.tracing.framework.create import trace
with trace(
"checklist_import",
metadata={"import_job_id": import_job_id},
):
from onyx.server.features.proposal_review.engine.review_engine import (
_execute_checklist_import,
)
_execute_checklist_import(import_job_id)
except Exception as e:
logger.error(f"Import job {import_job_id} failed: {e}", exc_info=True)
from onyx.server.features.proposal_review.engine.review_engine import (
_mark_import_failed,
)
_mark_import_failed(import_job_id, str(e))
raise
finally:
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
@shared_task(
name="sync_decision_to_jira",
bind=True,
ignore_result=True,
soft_time_limit=60,
time_limit=90,
)
def sync_decision_to_jira(_self: object, proposal_id: str, tenant_id: str) -> None:
"""Writes officer decision back to Jira.
Dispatched from the sync-jira API endpoint.
"""
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.features.proposal_review.engine.jira_sync import sync_to_jira
with get_session_with_current_tenant() as db_session:
sync_to_jira(UUID(proposal_id), db_session)
db_session.commit()
logger.info(f"Jira sync completed for proposal {proposal_id}")
except Exception as e:
logger.error(f"Jira sync failed for proposal {proposal_id}: {e}", exc_info=True)
raise
finally:
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_DANGLING_IMPORT_JOBS,
bind=True,
ignore_result=True,
soft_time_limit=60,
time_limit=90,
)
def check_for_dangling_import_jobs(_self: object, *, tenant_id: str) -> None:
"""Beat task: mark import jobs stuck in PENDING/RUNNING as FAILED.
A job is considered stuck if it has been in a non-terminal state for
longer than the stale threshold (default 60 minutes). This handles
cases where the Celery message was discarded (e.g. worker restart
before the task was registered) or the task crashed without marking
the job as FAILED.
"""
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.features.proposal_review.db import imports as imports_db
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
locked = False
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
OnyxRedisLocks.CHECK_DANGLING_IMPORT_JOBS_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
if not lock.acquire(blocking=False):
logger.info(
f"check_for_dangling_import_jobs - Lock not acquired: tenant={tenant_id}"
)
return None
try:
locked = True
with get_session_with_current_tenant() as db_session:
dangling = imports_db.get_dangling_import_jobs(
db_session, stale_threshold_minutes=60
)
if not dangling:
return
for job in dangling:
logger.warning(
f"Marking dangling import job {job.id} as FAILED "
f"(status={job.status}, created_at={job.created_at})"
)
imports_db.mark_import_job_failed(
job,
"Import timed out — the background task did not complete. "
"Please try importing again.",
db_session,
)
db_session.commit()
logger.info(
f"Cleaned up {len(dangling)} dangling import job(s) "
f"for tenant {tenant_id}"
)
except Exception:
logger.exception("Unexpected error during dangling import job cleanup")
finally:
if locked:
if lock.owned():
lock.release()
else:
logger.error(
f"check_for_dangling_import_jobs - "
f"Lock not owned on completion: tenant={tenant_id}"
)
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
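
The docstring for run_proposal_review above distinguishes a full run (rule_ids is None) from a retry of specific rules. A minimal dispatch sketch by registered task name, purely illustrative and not taken from this diff (the helper name is an assumption; celery.current_app.send_task is the standard by-name dispatch API):

from celery import current_app

def dispatch_review(
    review_run_id: str, tenant_id: str, rule_ids: list[str] | None = None
) -> None:
    # Omitting rule_ids triggers a full run; passing a subset is the retry flow.
    kwargs: dict[str, object] = {"review_run_id": review_run_id, "tenant_id": tenant_id}
    if rule_ids:
        kwargs["rule_ids"] = rule_ids
    current_app.send_task("run_proposal_review", kwargs=kwargs)

Dispatching by registered name keeps the calling layer from importing the worker module, consistent with the deferred imports used inside the tasks above.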

View File

@@ -1,8 +1,7 @@
"""Generic Celery task lifecycle Prometheus metrics.
Provides signal handlers that track task started/completed/failed counts,
active task gauge, task duration histograms, queue wait time histograms,
and retry/reject/revoke counts.
active task gauge, task duration histograms, and retry/reject/revoke counts.
These fire for ALL tasks on the worker — no per-connector enrichment
(see indexing_task_metrics.py for that).
@@ -72,32 +71,6 @@ TASK_REJECTED = Counter(
["task_name"],
)
TASK_QUEUE_WAIT = Histogram(
"onyx_celery_task_queue_wait_seconds",
"Time a Celery task spent waiting in the queue before execution started",
["task_name", "queue"],
buckets=[
0.1,
0.5,
1,
5,
30,
60,
300,
600,
1800,
3600,
7200,
14400,
28800,
43200,
86400,
172800,
432000,
864000,
],
)
# task_id → (monotonic start time, metric labels)
_task_start_times: dict[str, tuple[float, dict[str, str]]] = {}
@@ -160,13 +133,6 @@ def on_celery_task_prerun(
with _task_start_times_lock:
_evict_stale_start_times()
_task_start_times[task_id] = (time.monotonic(), labels)
headers = getattr(task.request, "headers", None) or {}
enqueued_at = headers.get("enqueued_at")
if isinstance(enqueued_at, (int, float)):
TASK_QUEUE_WAIT.labels(**labels).observe(
max(0.0, time.time() - enqueued_at)
)
except Exception:
logger.debug("Failed to record celery task prerun metrics", exc_info=True)

View File

@@ -1,104 +0,0 @@
"""Connector-deletion-specific Prometheus metrics.
Tracks the deletion lifecycle:
1. Deletions started (taskset generated)
2. Deletions completed (success or failure)
3. Taskset duration (from taskset generation to completion or failure).
Note: this measures the most recent taskset execution, NOT wall-clock
time since the user triggered the deletion. When deletion is blocked by
indexing/pruning/permissions, the fence is cleared and a fresh taskset
is generated on each retry, resetting this timer.
4. Deletion blocked by dependencies (indexing, pruning, permissions, etc.)
5. Fence resets (stuck deletion recovery)
All metrics are labeled by tenant_id. cc_pair_id is intentionally excluded
to avoid unbounded cardinality.
Usage:
from onyx.server.metrics.deletion_metrics import (
inc_deletion_started,
inc_deletion_completed,
observe_deletion_taskset_duration,
inc_deletion_blocked,
inc_deletion_fence_reset,
)
"""
from prometheus_client import Counter
from prometheus_client import Histogram
from onyx.utils.logger import setup_logger
logger = setup_logger()
DELETION_STARTED = Counter(
"onyx_deletion_started_total",
"Connector deletions initiated (taskset generated)",
["tenant_id"],
)
DELETION_COMPLETED = Counter(
"onyx_deletion_completed_total",
"Connector deletions completed",
["tenant_id", "outcome"],
)
DELETION_TASKSET_DURATION = Histogram(
"onyx_deletion_taskset_duration_seconds",
"Duration of a connector deletion taskset, from taskset generation "
"to completion or failure. Does not include time spent blocked on "
"indexing/pruning/permissions before the taskset was generated.",
["tenant_id", "outcome"],
buckets=[10, 30, 60, 120, 300, 600, 1800, 3600, 7200, 21600],
)
DELETION_BLOCKED = Counter(
"onyx_deletion_blocked_total",
"Times deletion was blocked by a dependency",
["tenant_id", "blocker"],
)
DELETION_FENCE_RESET = Counter(
"onyx_deletion_fence_reset_total",
"Deletion fences reset due to missing celery tasks",
["tenant_id"],
)
def inc_deletion_started(tenant_id: str) -> None:
try:
DELETION_STARTED.labels(tenant_id=tenant_id).inc()
except Exception:
logger.debug("Failed to record deletion started", exc_info=True)
def inc_deletion_completed(tenant_id: str, outcome: str) -> None:
try:
DELETION_COMPLETED.labels(tenant_id=tenant_id, outcome=outcome).inc()
except Exception:
logger.debug("Failed to record deletion completed", exc_info=True)
def observe_deletion_taskset_duration(
tenant_id: str, outcome: str, duration_seconds: float
) -> None:
try:
DELETION_TASKSET_DURATION.labels(tenant_id=tenant_id, outcome=outcome).observe(
duration_seconds
)
except Exception:
logger.debug("Failed to record deletion taskset duration", exc_info=True)
def inc_deletion_blocked(tenant_id: str, blocker: str) -> None:
try:
DELETION_BLOCKED.labels(tenant_id=tenant_id, blocker=blocker).inc()
except Exception:
logger.debug("Failed to record deletion blocked", exc_info=True)
def inc_deletion_fence_reset(tenant_id: str) -> None:
try:
DELETION_FENCE_RESET.labels(tenant_id=tenant_id).inc()
except Exception:
logger.debug("Failed to record deletion fence reset", exc_info=True)

View File

@@ -27,7 +27,6 @@ _DEFAULT_PORTS: dict[str, int] = {
"docfetching": 9092,
"docprocessing": 9093,
"heavy": 9094,
"light": 9095,
}
_server_started = False

View File

@@ -28,14 +28,14 @@ PRUNING_ENUMERATION_DURATION = Histogram(
"onyx_pruning_enumeration_duration_seconds",
"Duration of document ID enumeration from the source connector during pruning",
["connector_type"],
buckets=[5, 60, 600, 1800, 3600, 10800, 21600],
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
)
PRUNING_DIFF_DURATION = Histogram(
"onyx_pruning_diff_duration_seconds",
"Duration of diff computation and subtask dispatch during pruning",
["connector_type"],
buckets=[0.1, 0.25, 0.5, 1, 2, 5, 15, 30, 60],
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
)
PRUNING_RATE_LIMIT_ERRORS = Counter(

View File

@@ -65,8 +65,8 @@ class Settings(BaseModel):
anonymous_user_enabled: bool | None = None
invite_only_enabled: bool = False
deep_research_enabled: bool | None = None
multi_model_chat_enabled: bool | None = True
search_ui_enabled: bool | None = True
multi_model_chat_enabled: bool | None = None
search_ui_enabled: bool | None = None
# Whether EE features are unlocked for use.
# Depends on license status: True when the user has a valid license

View File

@@ -2,6 +2,7 @@
from collections.abc import Callable
from fastapi import HTTPException
from sqlalchemy.orm import Session
from onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
@@ -11,8 +12,6 @@ from onyx.configs.app_configs import OPENROUTER_DEFAULT_API_KEY
from onyx.db.usage import check_usage_limit
from onyx.db.usage import UsageLimitExceededError
from onyx.db.usage import UsageType
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.server.tenant_usage_limits import TenantUsageLimitKeys
from onyx.server.tenant_usage_limits import TenantUsageLimitOverrides
from onyx.utils.logger import setup_logger
@@ -256,14 +255,11 @@ def check_usage_and_raise(
"Please upgrade your plan or wait for the next billing period."
)
elif usage_type == UsageType.API_CALLS:
if is_trial and e.limit == 0:
detail = "API access is not available on trial accounts. Please upgrade to a paid plan to use the API and chat widget."
else:
detail = (
f"API call limit exceeded for {user_type} account. "
f"Calls: {int(e.current)}, Limit: {int(e.limit)} per week. "
"Please upgrade your plan or wait for the next billing period."
)
detail = (
f"API call limit exceeded for {user_type} account. "
f"Calls: {int(e.current)}, Limit: {int(e.limit)} per week. "
"Please upgrade your plan or wait for the next billing period."
)
else:
detail = (
f"Non-streaming API call limit exceeded for {user_type} account. "
@@ -271,4 +267,4 @@ def check_usage_and_raise(
"Please upgrade your plan or wait for the next billing period."
)
raise OnyxError(OnyxErrorCode.RATE_LIMITED, detail)
raise HTTPException(status_code=429, detail=detail)
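
The last two lines of the hunk swap raise OnyxError(OnyxErrorCode.RATE_LIMITED, detail) for raise HTTPException(status_code=429, detail=detail). The two are equivalent at the HTTP layer only if an application-level handler maps the domain error code to a 429; the sketch below shows that generic FastAPI pattern with a stand-in exception class and is not Onyx's actual error-handling code:

from fastapi import FastAPI
from fastapi import Request
from fastapi.responses import JSONResponse

class ExampleOnyxError(Exception):
    """Stand-in for a domain error carrying a code and a human-readable detail."""
    def __init__(self, code: str, detail: str) -> None:
        super().__init__(detail)
        self.code = code
        self.detail = detail

app = FastAPI()

@app.exception_handler(ExampleOnyxError)
async def _handle_onyx_error(_request: Request, exc: ExampleOnyxError) -> JSONResponse:
    # Map the domain error code to an HTTP status; RATE_LIMITED becomes 429.
    status = 429 if exc.code == "RATE_LIMITED" else 500
    return JSONResponse(status_code=status, content={"detail": exc.detail})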

View File

@@ -10,7 +10,6 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import PersonaSearchInfo
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
from onyx.db.enums import MCPAuthenticationPerformer
from onyx.db.enums import MCPAuthenticationType
from onyx.db.mcp import get_all_mcp_tools_for_server
@@ -114,10 +113,10 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
def construct_tools(
persona: Persona,
db_session: Session,
emitter: Emitter,
user: User,
llm: LLM,
db_session: Session | None = None,
search_tool_config: SearchToolConfig | None = None,
custom_tool_config: CustomToolConfig | None = None,
file_reader_tool_config: FileReaderToolConfig | None = None,
@@ -132,33 +131,6 @@ def construct_tools(
``attached_documents``, and ``hierarchy_nodes`` already eager-loaded
(e.g. via ``eager_load_persona=True`` or ``eager_load_for_tools=True``)
to avoid lazy SQL queries after the session may have been flushed."""
with get_session_with_current_tenant_if_none(db_session) as db_session:
return _construct_tools_impl(
persona=persona,
db_session=db_session,
emitter=emitter,
user=user,
llm=llm,
search_tool_config=search_tool_config,
custom_tool_config=custom_tool_config,
file_reader_tool_config=file_reader_tool_config,
allowed_tool_ids=allowed_tool_ids,
search_usage_forcing_setting=search_usage_forcing_setting,
)
def _construct_tools_impl(
persona: Persona,
db_session: Session,
emitter: Emitter,
user: User,
llm: LLM,
search_tool_config: SearchToolConfig | None = None,
custom_tool_config: CustomToolConfig | None = None,
file_reader_tool_config: FileReaderToolConfig | None = None,
allowed_tool_ids: list[int] | None = None,
search_usage_forcing_setting: SearchToolUsage = SearchToolUsage.AUTO,
) -> dict[int, list[Tool]]:
tool_dict: dict[int, list[Tool]] = {}
# Log which tools are attached to the persona for debugging
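
The docstring fragment above places a precondition on callers of construct_tools: the Persona must already have its tool-related relationships eager-loaded (for example via eager_load_persona=True or eager_load_for_tools=True, both named in the docstring). A call sketch under that assumption; the import path and surrounding variable names are assumptions rather than facts from this diff:

from sqlalchemy.orm import Session

# Import path is an assumption about the Onyx source layout.
from onyx.tools.tool_constructor import construct_tools

def build_tools_for_persona(persona, emitter, user, llm, db_session: Session):
    # `persona` is assumed to have been loaded with eager_load_for_tools=True
    # so no lazy relationship queries fire inside construct_tools.
    return construct_tools(
        persona=persona,
        db_session=db_session,
        emitter=emitter,
        user=user,
        llm=llm,
    )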

View File

@@ -17,7 +17,6 @@ def documents_to_indexing_documents(
processed_sections = []
for section in document.sections:
processed_section = Section(
type=section.type,
text=section.text or "",
link=section.link,
image_file_id=None,

View File

@@ -26,7 +26,7 @@ aiolimiter==1.2.1
# via voyageai
aiosignal==1.4.0
# via aiohttp
alembic==1.18.4
alembic==1.10.4
amqp==5.3.1
# via kombu
annotated-doc==0.0.4
@@ -174,7 +174,7 @@ coloredlogs==15.0.1
# via onnxruntime
courlan==1.3.2
# via trafilatura
cryptography==46.0.7
cryptography==46.0.6
# via
# authlib
# google-auth
@@ -408,7 +408,7 @@ kombu==5.5.4
# via celery
kubernetes==31.0.0
# via onyx
langchain-core==1.2.28
langchain-core==1.2.22
langdetect==1.0.9
# via unstructured
langfuse==3.10.0
@@ -583,7 +583,7 @@ pathable==0.4.4
# via jsonschema-path
pdfminer-six==20251107
# via markitdown
pillow==12.2.0
pillow==12.1.1
# via python-pptx
platformdirs==4.5.0
# via
@@ -666,9 +666,7 @@ pyee==13.0.0
# via playwright
pygithub==2.5.0
pygments==2.20.0
# via
# pytest
# rich
# via rich
pyjwt==2.12.0
# via
# fastapi-users
@@ -682,13 +680,13 @@ pynacl==1.6.2
pypandoc-binary==1.16.2
pyparsing==3.2.5
# via httplib2
pypdf==6.10.0
pypdf==6.9.2
# via unstructured-client
pyperclip==1.11.0
# via fastmcp
pyreadline3==3.5.4 ; sys_platform == 'win32'
# via humanfriendly
pytest==9.0.3
pytest==8.3.5
# via
# pytest-base-url
# pytest-mock
@@ -696,7 +694,7 @@ pytest==9.0.3
pytest-base-url==2.1.0
# via pytest-playwright
pytest-mock==3.12.0
pytest-playwright==0.7.2
pytest-playwright==0.7.0
python-dateutil==2.8.2
# via
# aiobotocore

View File

@@ -22,7 +22,7 @@ aiolimiter==1.2.1
# via voyageai
aiosignal==1.4.0
# via aiohttp
alembic==1.18.4
alembic==1.10.4
# via pytest-alembic
annotated-doc==0.0.4
# via fastapi
@@ -46,7 +46,7 @@ attrs==25.4.0
# aiohttp
# jsonschema
# referencing
black==26.3.1
black==25.1.0
boto3==1.39.11
# via
# aiobotocore
@@ -95,7 +95,7 @@ comm==0.2.3
# via ipykernel
contourpy==1.3.3
# via matplotlib
cryptography==46.0.7
cryptography==46.0.6
# via
# google-auth
# pyjwt
@@ -254,7 +254,7 @@ oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
onyx-devtools==0.7.5
onyx-devtools==0.7.3
openai==2.14.0
# via
# litellm
@@ -274,13 +274,13 @@ parameterized==0.9.0
# via cohere
parso==0.8.5
# via jedi
pathspec==1.0.4
pathspec==0.12.1
# via
# black
# hatchling
pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
# via ipython
pillow==12.2.0
pillow==12.1.1
# via matplotlib
platformdirs==4.5.0
# via
@@ -339,12 +339,11 @@ pygments==2.20.0
# via
# ipython
# ipython-pygments-lexers
# pytest
pyjwt==2.12.0
# via mcp
pyparsing==3.2.5
# via matplotlib
pytest==9.0.3
pytest==8.3.5
# via
# pytest-alembic
# pytest-asyncio
@@ -370,8 +369,6 @@ python-dotenv==1.1.1
# pytest-dotenv
python-multipart==0.0.22
# via mcp
pytokens==0.4.1
# via black
pywin32==311 ; sys_platform == 'win32'
# via mcp
pyyaml==6.0.3

View File

@@ -76,7 +76,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
cryptography==46.0.7
cryptography==46.0.6
# via
# google-auth
# pyjwt

View File

@@ -91,7 +91,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# tqdm
cryptography==46.0.7
cryptography==46.0.6
# via
# google-auth
# pyjwt
@@ -264,7 +264,7 @@ packaging==24.2
# transformers
parameterized==0.9.0
# via cohere
pillow==12.2.0
pillow==12.1.1
# via sentence-transformers
prometheus-client==0.23.1
# via

View File

@@ -1,239 +0,0 @@
"""Tests for GoogleDriveConnector.resolve_errors against real Google Drive."""
import json
import os
from collections.abc import Callable
from unittest.mock import patch
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import HierarchyNode
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import (
ALL_EXPECTED_HIERARCHY_NODES,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_ID
_DRIVE_ID_MAPPING_PATH = os.path.join(
os.path.dirname(__file__), "drive_id_mapping.json"
)
def _load_web_view_links(file_ids: list[int]) -> list[str]:
with open(_DRIVE_ID_MAPPING_PATH) as f:
mapping: dict[str, str] = json.load(f)
return [mapping[str(fid)] for fid in file_ids]
def _build_failures(web_view_links: list[str]) -> list[ConnectorFailure]:
return [
ConnectorFailure(
failed_document=DocumentFailure(
document_id=link,
document_link=link,
),
failure_message=f"Synthetic failure for {link}",
)
for link in web_view_links
]
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
def test_resolve_single_file(
mock_api_key: None, # noqa: ARG001
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
"""Resolve a single known file and verify we get back exactly one Document."""
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
shared_drive_urls=None,
include_my_drives=True,
my_drive_emails=None,
shared_folder_urls=None,
include_files_shared_with_me=False,
)
web_view_links = _load_web_view_links([0])
failures = _build_failures(web_view_links)
results = list(connector.resolve_errors(failures))
docs = [r for r in results if isinstance(r, Document)]
new_failures = [r for r in results if isinstance(r, ConnectorFailure)]
hierarchy_nodes = [r for r in results if isinstance(r, HierarchyNode)]
assert len(docs) == 1
assert len(new_failures) == 0
assert docs[0].semantic_identifier == "file_0.txt"
# Should yield at least one hierarchy node (the file's parent folder chain)
assert len(hierarchy_nodes) > 0
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
def test_resolve_multiple_files(
mock_api_key: None, # noqa: ARG001
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
"""Resolve multiple files across different folders via batch API."""
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
shared_drive_urls=None,
include_my_drives=True,
my_drive_emails=None,
shared_folder_urls=None,
include_files_shared_with_me=False,
)
# Pick files from different folders: admin files (0-4), shared drive 1 (20-24), folder_2 (45-49)
file_ids = [0, 1, 20, 21, 45]
web_view_links = _load_web_view_links(file_ids)
failures = _build_failures(web_view_links)
results = list(connector.resolve_errors(failures))
docs = [r for r in results if isinstance(r, Document)]
new_failures = [r for r in results if isinstance(r, ConnectorFailure)]
hierarchy_nodes = [r for r in results if isinstance(r, HierarchyNode)]
assert len(new_failures) == 0
retrieved_names = {doc.semantic_identifier for doc in docs}
expected_names = {f"file_{fid}.txt" for fid in file_ids}
assert expected_names == retrieved_names
# Files span multiple folders, so we should get hierarchy nodes
assert len(hierarchy_nodes) > 0
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
def test_resolve_hierarchy_nodes_are_valid(
mock_api_key: None, # noqa: ARG001
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
"""Verify that hierarchy nodes from resolve_errors match expected structure."""
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
shared_drive_urls=None,
include_my_drives=True,
my_drive_emails=None,
shared_folder_urls=None,
include_files_shared_with_me=False,
)
# File in folder_1 (inside shared_drive_1) — should walk up to shared_drive_1 root
web_view_links = _load_web_view_links([25])
failures = _build_failures(web_view_links)
results = list(connector.resolve_errors(failures))
hierarchy_nodes = [r for r in results if isinstance(r, HierarchyNode)]
node_ids = {node.raw_node_id for node in hierarchy_nodes}
# File 25 is in folder_1 which is inside shared_drive_1.
# The parent walk must yield at least these two ancestors.
assert (
FOLDER_1_ID in node_ids
), f"Expected folder_1 ({FOLDER_1_ID}) in hierarchy nodes, got: {node_ids}"
assert (
SHARED_DRIVE_1_ID in node_ids
), f"Expected shared_drive_1 ({SHARED_DRIVE_1_ID}) in hierarchy nodes, got: {node_ids}"
for node in hierarchy_nodes:
if node.raw_node_id not in ALL_EXPECTED_HIERARCHY_NODES:
continue
expected = ALL_EXPECTED_HIERARCHY_NODES[node.raw_node_id]
assert node.display_name == expected.display_name, (
f"Display name mismatch for {node.raw_node_id}: "
f"expected '{expected.display_name}', got '{node.display_name}'"
)
assert node.node_type == expected.node_type, (
f"Node type mismatch for {node.raw_node_id}: "
f"expected '{expected.node_type}', got '{node.node_type}'"
)
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
def test_resolve_with_invalid_link(
mock_api_key: None, # noqa: ARG001
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
"""Resolve with a mix of valid and invalid links — invalid ones yield ConnectorFailure."""
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
shared_drive_urls=None,
include_my_drives=True,
my_drive_emails=None,
shared_folder_urls=None,
include_files_shared_with_me=False,
)
valid_links = _load_web_view_links([0])
invalid_link = "https://drive.google.com/file/d/NONEXISTENT_FILE_ID_12345"
failures = _build_failures(valid_links + [invalid_link])
results = list(connector.resolve_errors(failures))
docs = [r for r in results if isinstance(r, Document)]
new_failures = [r for r in results if isinstance(r, ConnectorFailure)]
assert len(docs) == 1
assert docs[0].semantic_identifier == "file_0.txt"
assert len(new_failures) == 1
assert new_failures[0].failed_document is not None
assert new_failures[0].failed_document.document_id == invalid_link
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
def test_resolve_empty_errors(
mock_api_key: None, # noqa: ARG001
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
"""Resolving an empty error list should yield nothing."""
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
shared_drive_urls=None,
include_my_drives=True,
my_drive_emails=None,
shared_folder_urls=None,
include_files_shared_with_me=False,
)
results = list(connector.resolve_errors([]))
assert len(results) == 0
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
def test_resolve_entity_failures_are_skipped(
mock_api_key: None, # noqa: ARG001
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
"""Entity failures (not document failures) should be skipped by resolve_errors."""
from onyx.connectors.models import EntityFailure
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
shared_drive_urls=None,
include_my_drives=True,
my_drive_emails=None,
shared_folder_urls=None,
include_files_shared_with_me=False,
)
entity_failure = ConnectorFailure(
failed_entity=EntityFailure(entity_id="some_stage"),
failure_message="retrieval failure",
)
results = list(connector.resolve_errors([entity_failure]))
assert len(results) == 0

View File

@@ -1,85 +0,0 @@
"""Shared fixtures for proposal review integration tests.
Uses the same real-PostgreSQL pattern as the parent external_dependency_unit
conftest. Tables must already exist (via the 61ea78857c97 migration).
"""
from collections.abc import Generator
from uuid import uuid4
import pytest
from fastapi_users.password import PasswordHelper
from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.enums import AccountType
from onyx.db.models import User
from onyx.db.models import UserRole
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from tests.external_dependency_unit.constants import TEST_TENANT_ID
# Tables to clean up after each test, in dependency order (children first).
_PROPOSAL_REVIEW_TABLES = [
"proposal_review_finding",
"proposal_review_run",
"proposal_review_document",
"proposal_review_proposal",
"proposal_review_rule",
"proposal_review_import_job",
"proposal_review_ruleset",
"proposal_review_config",
]
@pytest.fixture(scope="function")
def tenant_context() -> Generator[None, None, None]:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
try:
yield
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@pytest.fixture(scope="function")
def db_session(tenant_context: None) -> Generator[Session, None, None]: # noqa: ARG001
"""Yield a DB session scoped to the current tenant.
After the test completes, all proposal_review rows are deleted so tests
don't leave artifacts in the database.
"""
SqlEngine.init_engine(pool_size=10, max_overflow=5)
with get_session_with_current_tenant() as session:
yield session
# Clean up all proposal_review data created during this test
try:
for table in _PROPOSAL_REVIEW_TABLES:
session.execute(text(f"DELETE FROM {table}")) # noqa: S608
session.commit()
except Exception:
session.rollback()
@pytest.fixture(scope="function")
def test_user(db_session: Session) -> User:
"""Create a throwaway user for FK references (triggered_by, officer_id, etc.)."""
unique_email = f"pr_test_{uuid4().hex[:8]}@example.com"
password_helper = PasswordHelper()
hashed_password = password_helper.hash(password_helper.generate())
user = User(
id=uuid4(),
email=unique_email,
hashed_password=hashed_password,
is_active=True,
is_superuser=False,
is_verified=True,
role=UserRole.ADMIN,
account_type=AccountType.STANDARD,
)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
return user

View File

@@ -1,328 +0,0 @@
"""Integration tests for per-finding decisions, proposal decisions, and config."""
from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.db.models import User
from onyx.server.features.proposal_review.db.config import get_config
from onyx.server.features.proposal_review.db.config import upsert_config
from onyx.server.features.proposal_review.db.decisions import (
mark_proposal_jira_synced,
)
from onyx.server.features.proposal_review.db.decisions import (
update_proposal_decision,
)
from onyx.server.features.proposal_review.db.decisions import (
upsert_finding_decision,
)
from onyx.server.features.proposal_review.db.findings import create_finding
from onyx.server.features.proposal_review.db.findings import create_review_run
from onyx.server.features.proposal_review.db.findings import get_finding
from onyx.server.features.proposal_review.db.proposals import get_or_create_proposal
from onyx.server.features.proposal_review.db.rulesets import create_rule
from onyx.server.features.proposal_review.db.rulesets import create_ruleset
from tests.external_dependency_unit.constants import TEST_TENANT_ID
TENANT = TEST_TENANT_ID
def _make_finding(db_session: Session, test_user: User):
"""Helper: create a full chain (ruleset -> rule -> proposal -> run -> finding)."""
rs = create_ruleset(
tenant_id=TENANT,
name=f"RS-{uuid4().hex[:6]}",
db_session=db_session,
created_by=test_user.id,
)
rule = create_rule(
ruleset_id=rs.id,
name="Test Rule",
rule_type="DOCUMENT_CHECK",
prompt_template="{{proposal_text}}",
db_session=db_session,
)
proposal = get_or_create_proposal(f"doc-{uuid4().hex[:8]}", TENANT, db_session)
run = create_review_run(
proposal_id=proposal.id,
ruleset_id=rs.id,
triggered_by=test_user.id,
total_rules=1,
db_session=db_session,
)
finding = create_finding(
proposal_id=proposal.id,
rule_id=rule.id,
review_run_id=run.id,
verdict="FAIL",
db_session=db_session,
)
db_session.commit()
return finding, proposal
class TestFindingDecision:
def test_create_finding_decision(
self, db_session: Session, test_user: User
) -> None:
finding, _ = _make_finding(db_session, test_user)
updated = upsert_finding_decision(
finding_id=finding.id,
officer_id=test_user.id,
action="VERIFIED",
db_session=db_session,
notes="Looks good",
)
db_session.commit()
assert updated.id == finding.id
assert updated.decision_action == "VERIFIED"
assert updated.decision_notes == "Looks good"
assert updated.decided_at is not None
def test_upsert_overwrites_previous_decision(
self, db_session: Session, test_user: User
) -> None:
finding, _ = _make_finding(db_session, test_user)
upsert_finding_decision(
finding_id=finding.id,
officer_id=test_user.id,
action="VERIFIED",
db_session=db_session,
)
db_session.commit()
updated = upsert_finding_decision(
finding_id=finding.id,
officer_id=test_user.id,
action="ISSUE",
db_session=db_session,
notes="Actually, this is a problem",
)
db_session.commit()
# Same row was updated
assert updated.id == finding.id
assert updated.decision_action == "ISSUE"
assert updated.decision_notes == "Actually, this is a problem"
def test_finding_decision_accessible_via_finding(
self, db_session: Session, test_user: User
) -> None:
finding, _ = _make_finding(db_session, test_user)
upsert_finding_decision(
finding_id=finding.id,
officer_id=test_user.id,
action="OVERRIDDEN",
db_session=db_session,
)
db_session.commit()
fetched = get_finding(finding.id, db_session)
assert fetched is not None
assert fetched.decision_action == "OVERRIDDEN"
assert fetched.decision_officer_id == test_user.id
def test_finding_has_no_decision_by_default(
self, db_session: Session, test_user: User
) -> None:
finding, _ = _make_finding(db_session, test_user)
assert finding.decision_action is None
assert finding.decided_at is None
class TestProposalDecision:
def test_update_proposal_decision(
self, db_session: Session, test_user: User
) -> None:
proposal = get_or_create_proposal(f"doc-{uuid4().hex[:8]}", TENANT, db_session)
db_session.commit()
updated = update_proposal_decision(
proposal_id=proposal.id,
tenant_id=TENANT,
officer_id=test_user.id,
decision="APPROVED",
db_session=db_session,
notes="All checks pass",
)
db_session.commit()
assert updated.status == "APPROVED"
assert updated.decision_notes == "All checks pass"
assert updated.decision_officer_id == test_user.id
assert updated.decision_at is not None
assert updated.jira_synced is False
def test_proposal_decision_overwrites_previous(
self, db_session: Session, test_user: User
) -> None:
proposal = get_or_create_proposal(f"doc-{uuid4().hex[:8]}", TENANT, db_session)
db_session.commit()
update_proposal_decision(
proposal_id=proposal.id,
tenant_id=TENANT,
officer_id=test_user.id,
decision="CHANGES_REQUESTED",
db_session=db_session,
)
db_session.commit()
updated = update_proposal_decision(
proposal_id=proposal.id,
tenant_id=TENANT,
officer_id=test_user.id,
decision="APPROVED",
db_session=db_session,
)
db_session.commit()
assert updated.status == "APPROVED"
def test_mark_proposal_jira_synced(
self, db_session: Session, test_user: User
) -> None:
proposal = get_or_create_proposal(f"doc-{uuid4().hex[:8]}", TENANT, db_session)
update_proposal_decision(
proposal_id=proposal.id,
tenant_id=TENANT,
officer_id=test_user.id,
decision="APPROVED",
db_session=db_session,
)
db_session.commit()
assert proposal.jira_synced is False
synced = mark_proposal_jira_synced(proposal.id, TENANT, db_session)
db_session.commit()
assert synced is not None
assert synced.jira_synced is True
assert synced.jira_synced_at is not None
def test_mark_jira_synced_returns_none_for_nonexistent(
self, db_session: Session
) -> None:
assert mark_proposal_jira_synced(uuid4(), TENANT, db_session) is None
def test_new_decision_resets_jira_synced(
self, db_session: Session, test_user: User
) -> None:
"""Re-deciding should reset jira_synced so the new decision can be synced."""
proposal = get_or_create_proposal(f"doc-{uuid4().hex[:8]}", TENANT, db_session)
update_proposal_decision(
proposal_id=proposal.id,
tenant_id=TENANT,
officer_id=test_user.id,
decision="APPROVED",
db_session=db_session,
)
mark_proposal_jira_synced(proposal.id, TENANT, db_session)
db_session.commit()
assert proposal.jira_synced is True
update_proposal_decision(
proposal_id=proposal.id,
tenant_id=TENANT,
officer_id=test_user.id,
decision="REJECTED",
db_session=db_session,
)
db_session.commit()
db_session.refresh(proposal)
assert proposal.jira_synced is False
assert proposal.jira_synced_at is None
assert proposal.status == "REJECTED"
class TestConfig:
def test_create_config(self, db_session: Session) -> None:
# Use a unique tenant to avoid collision with other tests
tenant = f"test-tenant-{uuid4().hex[:8]}"
config = upsert_config(
tenant_id=tenant,
db_session=db_session,
jira_project_key="PROJ",
field_mapping={"title": "summary", "budget": "customfield_10001"},
)
db_session.commit()
assert config.id is not None
assert config.tenant_id == tenant
assert config.jira_project_key == "PROJ"
assert config.field_mapping == {
"title": "summary",
"budget": "customfield_10001",
}
def test_upsert_config_updates_existing(self, db_session: Session) -> None:
tenant = f"test-tenant-{uuid4().hex[:8]}"
first = upsert_config(
tenant_id=tenant,
db_session=db_session,
jira_project_key="OLD",
)
db_session.commit()
first_id = first.id
second = upsert_config(
tenant_id=tenant,
db_session=db_session,
jira_project_key="NEW",
field_mapping={"x": "y"},
)
db_session.commit()
assert second.id == first_id
assert second.jira_project_key == "NEW"
assert second.field_mapping == {"x": "y"}
def test_get_config_returns_correct_tenant(self, db_session: Session) -> None:
tenant = f"test-tenant-{uuid4().hex[:8]}"
upsert_config(
tenant_id=tenant,
db_session=db_session,
jira_project_key="ABC",
jira_writeback={"status_field": "customfield_20001"},
)
db_session.commit()
fetched = get_config(tenant, db_session)
assert fetched is not None
assert fetched.jira_project_key == "ABC"
assert fetched.jira_writeback == {"status_field": "customfield_20001"}
def test_get_config_returns_none_for_unknown_tenant(
self, db_session: Session
) -> None:
assert get_config(f"nonexistent-{uuid4().hex[:8]}", db_session) is None
def test_upsert_config_preserves_unset_fields(self, db_session: Session) -> None:
tenant = f"test-tenant-{uuid4().hex[:8]}"
upsert_config(
tenant_id=tenant,
db_session=db_session,
jira_project_key="KEEP",
jira_connector_id=42,
)
db_session.commit()
# Update only field_mapping, leave jira_project_key alone
upsert_config(
tenant_id=tenant,
db_session=db_session,
field_mapping={"a": "b"},
)
db_session.commit()
fetched = get_config(tenant, db_session)
assert fetched is not None
assert fetched.jira_project_key == "KEEP"
assert fetched.jira_connector_id == 42
assert fetched.field_mapping == {"a": "b"}

Some files were not shown because too many files have changed in this diff.