mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-15 03:42:52 +00:00
Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f87e03b194 | ||
|
|
873636a095 | ||
|
|
efb194e067 | ||
|
|
3f7dfa7813 | ||
|
|
5f08af3678 | ||
|
|
1243af4f86 | ||
|
|
91e84b8278 | ||
|
|
1d6baf10db | ||
|
|
8d26357197 | ||
|
|
cd43345415 | ||
|
|
f99cf2f1b0 | ||
|
|
7332adb1e6 | ||
|
|
0ab1b76765 | ||
|
|
40cd0a78a3 | ||
|
|
28d8c5de46 | ||
|
|
004092767f | ||
|
|
eb4689a669 | ||
|
|
47dd8973c1 | ||
|
|
a1403ef78c | ||
|
|
f96b9d6804 | ||
|
|
711651276c | ||
|
|
3731110cf9 | ||
|
|
8fb7a8718e |
@@ -1,62 +0,0 @@
|
||||
FROM ubuntu:26.04@sha256:cc925e589b7543b910fea57a240468940003fbfc0515245a495dd0ad8fe7cef1
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
fd-find \
|
||||
fzf \
|
||||
git \
|
||||
jq \
|
||||
less \
|
||||
make \
|
||||
neovim \
|
||||
openssh-client \
|
||||
python3-venv \
|
||||
ripgrep \
|
||||
sudo \
|
||||
ca-certificates \
|
||||
iptables \
|
||||
ipset \
|
||||
iproute2 \
|
||||
dnsutils \
|
||||
unzip \
|
||||
wget \
|
||||
zsh \
|
||||
&& curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \
|
||||
&& apt-get install -y nodejs \
|
||||
&& install -m 0755 -d /etc/apt/keyrings \
|
||||
&& curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg -o /etc/apt/keyrings/githubcli-archive-keyring.gpg \
|
||||
&& chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
|
||||
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" > /etc/apt/sources.list.d/github-cli.list \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends gh \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# fd-find installs as fdfind on Debian/Ubuntu — symlink to fd
|
||||
RUN ln -sf "$(which fdfind)" /usr/local/bin/fd
|
||||
|
||||
# Install uv (Python package manager)
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/
|
||||
|
||||
# Create non-root dev user with passwordless sudo
|
||||
RUN useradd -m -s /bin/zsh dev && \
|
||||
echo "dev ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/dev && \
|
||||
chmod 0440 /etc/sudoers.d/dev
|
||||
|
||||
ENV DEVCONTAINER=true
|
||||
|
||||
RUN mkdir -p /workspace && \
|
||||
chown -R dev:dev /workspace
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
# Install Claude Code
|
||||
ARG CLAUDE_CODE_VERSION=latest
|
||||
RUN npm install -g @anthropic-ai/claude-code@${CLAUDE_CODE_VERSION}
|
||||
|
||||
# Configure zsh — source the repo-local zshrc so shell customization
|
||||
# doesn't require an image rebuild.
|
||||
RUN chsh -s /bin/zsh root && \
|
||||
for rc in /root/.zshrc /home/dev/.zshrc; do \
|
||||
echo '[ -f /workspace/.devcontainer/zshrc ] && . /workspace/.devcontainer/zshrc' >> "$rc"; \
|
||||
done && \
|
||||
chown dev:dev /home/dev/.zshrc
|
||||
@@ -1,86 +0,0 @@
|
||||
# Onyx Dev Container
|
||||
|
||||
A containerized development environment for working on Onyx.
|
||||
|
||||
## What's included
|
||||
|
||||
- Ubuntu 26.04 base image
|
||||
- Node.js 20, uv, Claude Code
|
||||
- GitHub CLI (`gh`)
|
||||
- Neovim, ripgrep, fd, fzf, jq, make, wget, unzip
|
||||
- Zsh as default shell (sources host `~/.zshrc` if available)
|
||||
- Python venv auto-activation
|
||||
- Network firewall (default-deny; allowlists npm, PyPI, GitHub, Anthropic APIs, Sentry, VS Code update servers, and the Go/Rust toolchain download hosts)
|
||||
|
||||
## Usage
|
||||
|
||||
### CLI (`ods dev`)
|
||||
|
||||
The [`ods` devtools CLI](../tools/ods/README.md) provides workspace-aware wrappers
|
||||
for all devcontainer operations (also available as `ods dc`):
|
||||
|
||||
```bash
|
||||
# Start the container
|
||||
ods dev up
|
||||
|
||||
# Open a shell
|
||||
ods dev into
|
||||
|
||||
# Run a command
|
||||
ods dev exec npm test
|
||||
|
||||
# Stop the container
|
||||
ods dev stop
|
||||
```
|
||||
|
||||
## Restarting the container
|
||||
|
||||
```bash
|
||||
# Restart the container
|
||||
ods dev restart
|
||||
|
||||
# Pull the latest published image and recreate
|
||||
ods dev rebuild
|
||||
```
|
||||
|
||||
## Image
|
||||
|
||||
The devcontainer uses a prebuilt image published to `onyxdotapp/onyx-devcontainer`.
|
||||
The image is pinned by digest in `devcontainer.json` — no local build is required.
|
||||
|
||||
To build the image locally (e.g. while iterating on the Dockerfile):
|
||||
|
||||
```bash
|
||||
docker buildx bake devcontainer
|
||||
```
|
||||
|
||||
The `devcontainer` target is defined in `docker-bake.hcl` at the repo root.
|
||||
|
||||
## User & permissions
|
||||
|
||||
The container runs as the `dev` user by default (`remoteUser` in devcontainer.json).
|
||||
An init script (`init-dev-user.sh`) runs at container start to ensure the active
|
||||
user has read/write access to the bind-mounted workspace:
|
||||
|
||||
- **Standard Docker** — `dev`'s UID/GID is remapped to match the workspace owner,
|
||||
so file permissions work seamlessly.
|
||||
- **Rootless Docker** — The workspace appears as root-owned (UID 0) inside the
|
||||
container due to user-namespace mapping. `ods dev up` auto-detects rootless Docker
|
||||
and sets `DEVCONTAINER_REMOTE_USER=root` so the container runs as root — which
|
||||
maps back to your host user via the user namespace. New files are owned by your
|
||||
host UID and no ACL workarounds are needed.
|
||||
|
||||
To override the auto-detection, set `DEVCONTAINER_REMOTE_USER` before running
|
||||
`ods dev up`.
|
||||
|
||||
## Firewall
|
||||
|
||||
The container starts with a default-deny firewall (`init-firewall.sh`) that only allows outbound traffic to:
|
||||
|
||||
- npm registry
|
||||
- GitHub
|
||||
- Anthropic API
|
||||
- Sentry
|
||||
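- VS Code update servers
- PyPI (`pypi.org`, `files.pythonhosted.org`)
- Go and Rust toolchain hosts (`go.dev`, `storage.googleapis.com`, `static.rust-lang.org`)
- The Docker gateway (so the container can reach services running on the host)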
|
||||
|
||||
This requires the `NET_ADMIN` and `NET_RAW` capabilities, which are added via `runArgs` in `devcontainer.json`.
|
||||
@@ -1,23 +0,0 @@
|
||||
{
|
||||
"name": "Onyx Dev Sandbox",
|
||||
"image": "onyxdotapp/onyx-devcontainer@sha256:12184169c5bcc9cca0388286d5ffe504b569bc9c37bfa631b76ee8eee2064055",
|
||||
"runArgs": ["--cap-add=NET_ADMIN", "--cap-add=NET_RAW"],
|
||||
"mounts": [
|
||||
"source=${localEnv:HOME}/.claude,target=/home/dev/.claude,type=bind",
|
||||
"source=${localEnv:HOME}/.claude.json,target=/home/dev/.claude.json,type=bind",
|
||||
"source=${localEnv:HOME}/.zshrc,target=/home/dev/.zshrc.host,type=bind,readonly",
|
||||
"source=${localEnv:HOME}/.gitconfig,target=/home/dev/.gitconfig,type=bind,readonly",
|
||||
"source=${localEnv:HOME}/.config/nvim,target=/home/dev/.config/nvim,type=bind,readonly",
|
||||
"source=onyx-devcontainer-cache,target=/home/dev/.cache,type=volume",
|
||||
"source=onyx-devcontainer-local,target=/home/dev/.local,type=volume"
|
||||
],
|
||||
"containerEnv": {
|
||||
"SSH_AUTH_SOCK": "/tmp/ssh-agent.sock"
|
||||
},
|
||||
"remoteUser": "${localEnv:DEVCONTAINER_REMOTE_USER:dev}",
|
||||
"updateRemoteUserUID": false,
|
||||
"workspaceMount": "source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=delegated",
|
||||
"workspaceFolder": "/workspace",
|
||||
"postStartCommand": "sudo bash /workspace/.devcontainer/init-dev-user.sh && sudo bash /workspace/.devcontainer/init-firewall.sh",
|
||||
"waitFor": "postStartCommand"
|
||||
}
|
||||
@@ -1,107 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Remap the dev user's UID/GID to match the workspace owner so that
|
||||
# bind-mounted files are accessible without running as root.
|
||||
#
|
||||
# Standard Docker: Workspace is owned by the host user's UID (e.g. 1000).
|
||||
# We remap dev to that UID -- fast and seamless.
|
||||
#
|
||||
# Rootless Docker: Workspace appears as root-owned (UID 0) inside the
|
||||
# container due to user-namespace mapping. Requires
|
||||
# DEVCONTAINER_REMOTE_USER=root (set automatically by
|
||||
# ods dev up). Container root IS the host user, so
|
||||
# bind-mounts and named volumes are symlinked into /root.
|
||||
|
||||
WORKSPACE=/workspace
|
||||
TARGET_USER=dev
|
||||
REMOTE_USER="${SUDO_USER:-$TARGET_USER}"
|
||||
|
||||
WS_UID=$(stat -c '%u' "$WORKSPACE")
|
||||
WS_GID=$(stat -c '%g' "$WORKSPACE")
|
||||
DEV_UID=$(id -u "$TARGET_USER")
|
||||
DEV_GID=$(id -g "$TARGET_USER")
|
||||
|
||||
# devcontainer.json bind-mounts and named volumes target /home/dev regardless
|
||||
# of remoteUser. When running as root ($HOME=/root), Phase 1 bridges the gap
|
||||
# with symlinks from ACTIVE_HOME → MOUNT_HOME.
|
||||
MOUNT_HOME=/home/"$TARGET_USER"
|
||||
|
||||
if [ "$REMOTE_USER" = "root" ]; then
|
||||
ACTIVE_HOME="/root"
|
||||
else
|
||||
ACTIVE_HOME="$MOUNT_HOME"
|
||||
fi
|
||||
|
||||
# ── Phase 1: home directory setup ───────────────────────────────────
|
||||
|
||||
# ~/.local and ~/.cache are named Docker volumes mounted under MOUNT_HOME.
|
||||
mkdir -p "$MOUNT_HOME"/.local/state "$MOUNT_HOME"/.local/share
|
||||
|
||||
# When running as root, symlink bind-mounts and named volumes into /root
|
||||
# so that $HOME-relative tools (Claude Code, git, etc.) find them.
|
||||
if [ "$ACTIVE_HOME" != "$MOUNT_HOME" ]; then
|
||||
for item in .claude .cache .local; do
|
||||
[ -d "$MOUNT_HOME/$item" ] || continue
|
||||
if [ -e "$ACTIVE_HOME/$item" ] && [ ! -L "$ACTIVE_HOME/$item" ]; then
|
||||
echo "warning: replacing $ACTIVE_HOME/$item with symlink to $MOUNT_HOME/$item" >&2
|
||||
rm -rf "$ACTIVE_HOME/$item"
|
||||
fi
|
||||
ln -sfn "$MOUNT_HOME/$item" "$ACTIVE_HOME/$item"
|
||||
done
|
||||
# Symlink files (not directories).
|
||||
for file in .claude.json .gitconfig .zshrc.host; do
|
||||
[ -f "$MOUNT_HOME/$file" ] && ln -sf "$MOUNT_HOME/$file" "$ACTIVE_HOME/$file"
|
||||
done
|
||||
|
||||
# Nested mount: .config/nvim
|
||||
if [ -d "$MOUNT_HOME/.config/nvim" ]; then
|
||||
mkdir -p "$ACTIVE_HOME/.config"
|
||||
if [ -e "$ACTIVE_HOME/.config/nvim" ] && [ ! -L "$ACTIVE_HOME/.config/nvim" ]; then
|
||||
echo "warning: replacing $ACTIVE_HOME/.config/nvim with symlink" >&2
|
||||
rm -rf "$ACTIVE_HOME/.config/nvim"
|
||||
fi
|
||||
ln -sfn "$MOUNT_HOME/.config/nvim" "$ACTIVE_HOME/.config/nvim"
|
||||
fi
|
||||
fi
|
||||
|
||||
# ── Phase 2: workspace access ───────────────────────────────────────
|
||||
|
||||
# Root always has workspace access; Phase 1 handled home setup.
|
||||
if [ "$REMOTE_USER" = "root" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Already matching -- nothing to do.
|
||||
if [ "$WS_UID" = "$DEV_UID" ] && [ "$WS_GID" = "$DEV_GID" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [ "$WS_UID" != "0" ]; then
|
||||
# ── Standard Docker ──────────────────────────────────────────────
|
||||
# Workspace is owned by a non-root UID (the host user).
|
||||
# Remap dev's UID/GID to match.
|
||||
if [ "$DEV_GID" != "$WS_GID" ]; then
|
||||
if ! groupmod -g "$WS_GID" "$TARGET_USER" 2>&1; then
|
||||
echo "warning: failed to remap $TARGET_USER GID to $WS_GID" >&2
|
||||
fi
|
||||
fi
|
||||
if [ "$DEV_UID" != "$WS_UID" ]; then
|
||||
if ! usermod -u "$WS_UID" -g "$WS_GID" "$TARGET_USER" 2>&1; then
|
||||
echo "warning: failed to remap $TARGET_USER UID to $WS_UID" >&2
|
||||
fi
|
||||
fi
|
||||
if ! chown -R "$TARGET_USER":"$TARGET_USER" "$MOUNT_HOME" 2>&1; then
|
||||
echo "warning: failed to chown $MOUNT_HOME" >&2
|
||||
fi
|
||||
else
|
||||
# ── Rootless Docker ──────────────────────────────────────────────
|
||||
# Workspace is root-owned (UID 0) due to user-namespace mapping.
|
||||
# The supported path is remoteUser=root (set DEVCONTAINER_REMOTE_USER=root),
|
||||
# which is handled above. If we reach here, the user is running as dev
|
||||
# under rootless Docker without the override.
|
||||
echo "error: rootless Docker detected but remoteUser is not root." >&2
|
||||
echo " Set DEVCONTAINER_REMOTE_USER=root before starting the container," >&2
|
||||
echo " or use 'ods dev up' which sets it automatically." >&2
|
||||
exit 1
|
||||
fi
|
||||
@@ -1,105 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
echo "Setting up firewall..."
|
||||
|
||||
# Preserve docker dns resolution
|
||||
DOCKER_DNS_RULES=$(iptables-save | grep -E "^-A.*-d 127.0.0.11/32" || true)
|
||||
|
||||
# Flush all rules
|
||||
iptables -t nat -F
|
||||
iptables -t nat -X
|
||||
iptables -t mangle -F
|
||||
iptables -t mangle -X
|
||||
iptables -F
|
||||
iptables -X
|
||||
|
||||
# Restore docker dns rules
|
||||
if [ -n "$DOCKER_DNS_RULES" ]; then
|
||||
echo "$DOCKER_DNS_RULES" | iptables-restore -n
|
||||
fi
|
||||
|
||||
# Create ipset for allowed destinations
|
||||
ipset create allowed-domains hash:net || true
|
||||
ipset flush allowed-domains
|
||||
|
||||
# Fetch GitHub IP ranges (IPv4 only -- ipset hash:net and iptables are IPv4)
|
||||
GITHUB_IPS=$(curl -s https://api.github.com/meta | jq -r '.api[]' 2>/dev/null | grep -v ':' || echo "")
|
||||
for ip in $GITHUB_IPS; do
|
||||
if ! ipset add allowed-domains "$ip" -exist 2>&1; then
|
||||
echo "warning: failed to add GitHub IP $ip to allowlist" >&2
|
||||
fi
|
||||
done
|
||||
|
||||
# Resolve allowed domains
|
||||
ALLOWED_DOMAINS=(
|
||||
"registry.npmjs.org"
|
||||
"api.anthropic.com"
|
||||
"api-staging.anthropic.com"
|
||||
"files.anthropic.com"
|
||||
"sentry.io"
|
||||
"update.code.visualstudio.com"
|
||||
"pypi.org"
|
||||
"files.pythonhosted.org"
|
||||
"go.dev"
|
||||
"storage.googleapis.com"
|
||||
"static.rust-lang.org"
|
||||
)
|
||||
|
||||
for domain in "${ALLOWED_DOMAINS[@]}"; do
|
||||
IPS=$(getent ahosts "$domain" 2>/dev/null | awk '{print $1}' | grep -v ':' | sort -u || echo "")
|
||||
for ip in $IPS; do
|
||||
if ! ipset add allowed-domains "$ip/32" -exist 2>&1; then
|
||||
echo "warning: failed to add $domain ($ip) to allowlist" >&2
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
# Allow traffic to the Docker gateway so the container can reach host services
|
||||
# (e.g. the Onyx stack at localhost:3000, localhost:8080, etc.)
|
||||
DOCKER_GATEWAY=$(ip -4 route show default | awk '{print $3}')
|
||||
if [ -n "$DOCKER_GATEWAY" ]; then
|
||||
if ! ipset add allowed-domains "$DOCKER_GATEWAY/32" -exist 2>&1; then
|
||||
echo "warning: failed to add Docker gateway $DOCKER_GATEWAY to allowlist" >&2
|
||||
fi
|
||||
fi
|
||||
|
||||
# Set default policies to DROP
|
||||
iptables -P FORWARD DROP
|
||||
iptables -P INPUT DROP
|
||||
iptables -P OUTPUT DROP
|
||||
|
||||
# Allow established connections
|
||||
iptables -A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
|
||||
iptables -A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
|
||||
|
||||
# Allow loopback
|
||||
iptables -A INPUT -i lo -j ACCEPT
|
||||
iptables -A OUTPUT -o lo -j ACCEPT
|
||||
|
||||
# Allow DNS
|
||||
iptables -A OUTPUT -p udp --dport 53 -j ACCEPT
|
||||
iptables -A OUTPUT -p tcp --dport 53 -j ACCEPT
|
||||
|
||||
# Allow outbound to allowed destinations
|
||||
iptables -A OUTPUT -m set --match-set allowed-domains dst -j ACCEPT
|
||||
|
||||
# Reject unauthorized outbound
|
||||
iptables -A OUTPUT -j REJECT --reject-with icmp-host-unreachable
|
||||
|
||||
# Validate firewall configuration
|
||||
echo "Validating firewall configuration..."
|
||||
|
||||
BLOCKED_SITES=("example.com" "google.com" "facebook.com")
|
||||
for site in "${BLOCKED_SITES[@]}"; do
|
||||
if timeout 2 ping -c 1 "$site" &>/dev/null; then
|
||||
echo "Warning: $site is still reachable"
|
||||
fi
|
||||
done
|
||||
|
||||
if ! timeout 5 curl -s https://api.github.com/meta > /dev/null; then
|
||||
echo "Warning: GitHub API is not accessible"
|
||||
fi
|
||||
|
||||
echo "Firewall setup complete"
|
||||
@@ -1,10 +0,0 @@
|
||||
# Devcontainer zshrc — sourced automatically for both root and dev users.
|
||||
# Edit this file to customize the shell without rebuilding the image.
|
||||
|
||||
# Auto-activate Python venv
|
||||
if [ -f /workspace/.venv/bin/activate ]; then
|
||||
. /workspace/.venv/bin/activate
|
||||
fi
|
||||
|
||||
# Source host zshrc if bind-mounted
|
||||
[ -f ~/.zshrc.host ] && . ~/.zshrc.host
|
||||
4
.github/workflows/deployment.yml
vendored
4
.github/workflows/deployment.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
fetch-tags: true
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
version: "0.9.9"
|
||||
enable-cache: false
|
||||
@@ -165,7 +165,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
version: "0.9.9"
|
||||
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
|
||||
|
||||
@@ -114,7 +114,7 @@ jobs:
|
||||
ref: main
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
4
.github/workflows/pr-playwright-tests.yml
vendored
4
.github/workflows/pr-playwright-tests.yml
vendored
@@ -471,7 +471,7 @@ jobs:
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: always()
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
@@ -710,7 +710,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Download visual diff summaries
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3
|
||||
with:
|
||||
pattern: screenshot-diff-summary-*
|
||||
path: summaries/
|
||||
|
||||
2
.github/workflows/pr-quality-checks.yml
vendored
2
.github/workflows/pr-quality-checks.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
- uses: j178/prek-action@cbc2f23eb5539cf20d82d1aabd0d0ecbcc56f4e3
|
||||
- uses: j178/prek-action@0bb87d7f00b0c99306c8bcb8b8beba1eb581c037 # ratchet:j178/prek-action@v1
|
||||
with:
|
||||
prek-version: '0.3.4'
|
||||
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
|
||||
|
||||
2
.github/workflows/release-cli.yml
vendored
2
.github/workflows/release-cli.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/release-devtools.yml
vendored
2
.github/workflows/release-devtools.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/zizmor.yml
vendored
2
.github/workflows/zizmor.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # ratchet:astral-sh/setup-uv@v8.0.0
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
@@ -1,57 +1,64 @@
|
||||
{
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 3,
|
||||
"statusCheck": true,
|
||||
"commentTypes": ["logic", "syntax", "style"],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": ["dependabot[bot]", "renovate[bot]"],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": false,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"rules": [
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
|
||||
"labels": [],
|
||||
"comment": "",
|
||||
"fixWithAI": true,
|
||||
"hideFooter": false,
|
||||
"strictness": 3,
|
||||
"statusCheck": true,
|
||||
"commentTypes": [
|
||||
"logic",
|
||||
"syntax",
|
||||
"style"
|
||||
],
|
||||
"instructions": "",
|
||||
"disabledLabels": [],
|
||||
"excludeAuthors": [
|
||||
"dependabot[bot]",
|
||||
"renovate[bot]"
|
||||
],
|
||||
"ignoreKeywords": "",
|
||||
"ignorePatterns": "",
|
||||
"includeAuthors": [],
|
||||
"summarySection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
|
||||
"excludeBranches": [],
|
||||
"fileChangeLimit": 300,
|
||||
"includeBranches": [],
|
||||
"includeKeywords": "",
|
||||
"triggerOnUpdates": true,
|
||||
"updateExistingSummaryComment": true,
|
||||
"updateSummaryOnly": false,
|
||||
"issuesTableSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
]
|
||||
"statusCommentsEnabled": true,
|
||||
"confidenceScoreSection": {
|
||||
"included": true,
|
||||
"collapsible": false
|
||||
},
|
||||
"sequenceDiagramSection": {
|
||||
"included": true,
|
||||
"collapsible": false,
|
||||
"defaultOpen": false
|
||||
},
|
||||
"shouldUpdateDescription": false,
|
||||
"rules": [
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs."
|
||||
},
|
||||
{
|
||||
"scope": ["web/**"],
|
||||
"rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors."
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -49,12 +49,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
|
||||
4. **Light Worker** (`light`)
|
||||
- Handles lightweight, fast operations
|
||||
- Tasks: vespa metadata sync, connector deletion, doc permissions upsert, checkpoint cleanup, index attempt cleanup
|
||||
- Tasks: vespa operations, document permissions sync, external group sync
|
||||
- Higher concurrency for quick tasks
|
||||
|
||||
5. **Heavy Worker** (`heavy`)
|
||||
- Handles resource-intensive operations
|
||||
- Tasks: connector pruning, document permissions sync, external group sync, CSV generation
|
||||
- Primary task: document pruning operations
|
||||
- Runs with 4 threads concurrency
|
||||
|
||||
6. **KG Processing Worker** (`kg_processing`)
|
||||
|
||||
@@ -208,7 +208,7 @@ def do_run_migrations(
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
@@ -380,7 +380,7 @@ def run_migrations_offline() -> None:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
@@ -421,7 +421,7 @@ def run_migrations_offline() -> None:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
@@ -464,7 +464,7 @@ def run_migrations_online() -> None:
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
|
||||
@@ -25,7 +25,7 @@ def upgrade() -> None:
|
||||
|
||||
# Use batch mode to modify the enum type
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
batch_op.alter_column( # type: ignore[attr-defined]
|
||||
"role",
|
||||
type_=sa.Enum(
|
||||
"BASIC",
|
||||
@@ -71,7 +71,7 @@ def downgrade() -> None:
|
||||
op.drop_column("user__user_group", "is_curator")
|
||||
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
batch_op.alter_column( # type: ignore[attr-defined]
|
||||
"role",
|
||||
type_=sa.Enum(
|
||||
"BASIC", "ADMIN", name="userrole", native_enum=False, length=20
|
||||
|
||||
@@ -1,499 +0,0 @@
|
||||
"""add proposal review tables
|
||||
|
||||
Revision ID: 61ea78857c97
|
||||
Revises: d129f37b3d87
|
||||
Create Date: 2026-04-09 10:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
import fastapi_users_db_sqlalchemy
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "61ea78857c97"
|
||||
down_revision = "d129f37b3d87"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# -- proposal_review_ruleset --
|
||||
op.create_table(
|
||||
"proposal_review_ruleset",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("tenant_id", sa.Text(), nullable=False),
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"is_default",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"is_active",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("true"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"created_by",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(["created_by"], ["user.id"]),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_ruleset_tenant_id",
|
||||
"proposal_review_ruleset",
|
||||
["tenant_id"],
|
||||
)
|
||||
|
||||
# -- proposal_review_rule --
|
||||
op.create_table(
|
||||
"proposal_review_rule",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"ruleset_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("category", sa.Text(), nullable=True),
|
||||
sa.Column("rule_type", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"rule_intent",
|
||||
sa.Text(),
|
||||
server_default=sa.text("'CHECK'"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("prompt_template", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"source",
|
||||
sa.Text(),
|
||||
server_default=sa.text("'MANUAL'"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("authority", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"is_hard_stop",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"priority",
|
||||
sa.Integer(),
|
||||
server_default=sa.text("0"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"is_active",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("true"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"refinement_needed",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("refinement_question", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["ruleset_id"],
|
||||
["proposal_review_ruleset.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_rule_ruleset_id",
|
||||
"proposal_review_rule",
|
||||
["ruleset_id"],
|
||||
)
|
||||
|
||||
# -- proposal_review_proposal --
|
||||
# Includes inline proposal-level decision fields (no separate decision table).
|
||||
op.create_table(
|
||||
"proposal_review_proposal",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("document_id", sa.Text(), nullable=False),
|
||||
sa.Column("tenant_id", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Text(),
|
||||
server_default=sa.text("'PENDING'"),
|
||||
nullable=False,
|
||||
),
|
||||
# Inline proposal-level decision fields
|
||||
sa.Column("decision_notes", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"decision_officer_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("decision_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"jira_synced",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("jira_synced_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(["decision_officer_id"], ["user.id"]),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("document_id", "tenant_id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_proposal_tenant_id",
|
||||
"proposal_review_proposal",
|
||||
["tenant_id"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_proposal_document_id",
|
||||
"proposal_review_proposal",
|
||||
["document_id"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_proposal_status",
|
||||
"proposal_review_proposal",
|
||||
["status"],
|
||||
)
|
||||
|
||||
# -- proposal_review_run --
|
||||
op.create_table(
|
||||
"proposal_review_run",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"proposal_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"ruleset_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"triggered_by",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Text(),
|
||||
server_default=sa.text("'PENDING'"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("total_rules", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"completed_rules",
|
||||
sa.Integer(),
|
||||
server_default=sa.text("0"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["proposal_id"],
|
||||
["proposal_review_proposal.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["ruleset_id"],
|
||||
["proposal_review_ruleset.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(["triggered_by"], ["user.id"]),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_run_proposal_id",
|
||||
"proposal_review_run",
|
||||
["proposal_id"],
|
||||
)
|
||||
|
||||
# -- proposal_review_finding --
|
||||
# Includes inline per-finding decision fields (no separate decision table).
|
||||
op.create_table(
|
||||
"proposal_review_finding",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"proposal_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"rule_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"review_run_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("verdict", sa.Text(), nullable=False),
|
||||
sa.Column("confidence", sa.Text(), nullable=True),
|
||||
sa.Column("evidence", sa.Text(), nullable=True),
|
||||
sa.Column("explanation", sa.Text(), nullable=True),
|
||||
sa.Column("suggested_action", sa.Text(), nullable=True),
|
||||
sa.Column("llm_model", sa.Text(), nullable=True),
|
||||
sa.Column("llm_tokens_used", sa.Integer(), nullable=True),
|
||||
# Inline per-finding decision fields
|
||||
sa.Column("decision_action", sa.Text(), nullable=True),
|
||||
sa.Column("decision_notes", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"decision_officer_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("decided_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["proposal_id"],
|
||||
["proposal_review_proposal.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["rule_id"],
|
||||
["proposal_review_rule.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["review_run_id"],
|
||||
["proposal_review_run.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(["decision_officer_id"], ["user.id"]),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_finding_proposal_id",
|
||||
"proposal_review_finding",
|
||||
["proposal_id"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_finding_review_run_id",
|
||||
"proposal_review_finding",
|
||||
["review_run_id"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_finding_rule_id",
|
||||
"proposal_review_finding",
|
||||
["rule_id"],
|
||||
)
|
||||
|
||||
# -- proposal_review_document --
|
||||
op.create_table(
|
||||
"proposal_review_document",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"proposal_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("file_name", sa.Text(), nullable=False),
|
||||
sa.Column("file_type", sa.Text(), nullable=True),
|
||||
sa.Column("file_store_id", sa.Text(), nullable=True),
|
||||
sa.Column("extracted_text", sa.Text(), nullable=True),
|
||||
sa.Column("document_role", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"uploaded_by",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["proposal_id"],
|
||||
["proposal_review_proposal.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(["uploaded_by"], ["user.id"]),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_document_proposal_id",
|
||||
"proposal_review_document",
|
||||
["proposal_id"],
|
||||
)
|
||||
|
||||
# -- proposal_review_import_job --
|
||||
op.create_table(
|
||||
"proposal_review_import_job",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"ruleset_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("tenant_id", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Text(),
|
||||
server_default=sa.text("'PENDING'"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("source_filename", sa.Text(), nullable=False),
|
||||
sa.Column("extracted_text", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"rules_created",
|
||||
sa.Integer(),
|
||||
server_default=sa.text("0"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("error_message", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["ruleset_id"],
|
||||
["proposal_review_ruleset.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_proposal_review_import_job_ruleset_id",
|
||||
"proposal_review_import_job",
|
||||
["ruleset_id"],
|
||||
)
|
||||
|
||||
# -- proposal_review_config --
|
||||
op.create_table(
|
||||
"proposal_review_config",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("tenant_id", sa.Text(), nullable=False, unique=True),
|
||||
sa.Column("jira_connector_id", sa.Integer(), nullable=True),
|
||||
sa.Column("jira_project_key", sa.Text(), nullable=True),
|
||||
sa.Column("field_mapping", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("jira_writeback", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("review_model", sa.Text(), nullable=True),
|
||||
sa.Column("import_model", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("proposal_review_import_job")
|
||||
op.drop_table("proposal_review_config")
|
||||
op.drop_table("proposal_review_document")
|
||||
op.drop_table("proposal_review_finding")
|
||||
op.drop_table("proposal_review_run")
|
||||
op.drop_table("proposal_review_proposal")
|
||||
op.drop_table("proposal_review_rule")
|
||||
op.drop_table("proposal_review_ruleset")
|
||||
@@ -63,7 +63,7 @@ def upgrade() -> None:
|
||||
"time_created",
|
||||
existing_type=postgresql.TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
existing_server_default=sa.text("now()"),
|
||||
existing_server_default=sa.text("now()"), # type: ignore
|
||||
)
|
||||
op.alter_column(
|
||||
"index_attempt",
|
||||
@@ -85,7 +85,7 @@ def downgrade() -> None:
|
||||
"time_created",
|
||||
existing_type=postgresql.TIMESTAMP(timezone=True),
|
||||
nullable=True,
|
||||
existing_server_default=sa.text("now()"),
|
||||
existing_server_default=sa.text("now()"), # type: ignore
|
||||
)
|
||||
op.drop_index(op.f("ix_accesstoken_created_at"), table_name="accesstoken")
|
||||
op.drop_table("accesstoken")
|
||||
|
||||
@@ -19,7 +19,7 @@ depends_on: None = None
|
||||
|
||||
def upgrade() -> None:
|
||||
sequence = Sequence("connector_credential_pair_id_seq")
|
||||
op.execute(CreateSequence(sequence))
|
||||
op.execute(CreateSequence(sequence)) # type: ignore
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
"""add failed_rules to proposal_review_run
|
||||
|
||||
Revision ID: ce2aa573d445
|
||||
Revises: 61ea78857c97
|
||||
Create Date: 2026-04-14 16:34:57.276707
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ce2aa573d445"
|
||||
down_revision = "61ea78857c97"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"proposal_review_run",
|
||||
sa.Column(
|
||||
"failed_rules",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default=sa.text("0"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("proposal_review_run", "failed_rules")
|
||||
@@ -1,28 +0,0 @@
|
||||
"""add_error_tracking_fields_to_index_attempt_errors
|
||||
|
||||
Revision ID: d129f37b3d87
|
||||
Revises: 503883791c39
|
||||
Create Date: 2026-04-06 19:11:18.261800
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d129f37b3d87"
|
||||
down_revision = "503883791c39"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"index_attempt_errors",
|
||||
sa.Column("error_type", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("index_attempt_errors", "error_type")
|
||||
@@ -49,7 +49,7 @@ def run_migrations_offline() -> None:
|
||||
url = build_connection_string()
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
@@ -61,7 +61,7 @@ def run_migrations_offline() -> None:
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
target_metadata=target_metadata, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
|
||||
@@ -10,7 +10,6 @@ from celery import bootsteps # type: ignore
|
||||
from celery import Task
|
||||
from celery.app import trace
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.signals import before_task_publish
|
||||
from celery.signals import task_postrun
|
||||
from celery.signals import task_prerun
|
||||
from celery.states import READY_STATES
|
||||
@@ -95,17 +94,6 @@ class TenantAwareTask(Task):
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
|
||||
|
||||
|
||||
@before_task_publish.connect
|
||||
def on_before_task_publish(
|
||||
headers: dict[str, Any] | None = None,
|
||||
**kwargs: Any, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Stamp the current wall-clock time into the task message headers so that
|
||||
workers can compute queue wait time (time between publish and execution)."""
|
||||
if headers is not None:
|
||||
headers["enqueued_at"] = time.time()
|
||||
|
||||
|
||||
@task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None, # noqa: ARG001
|
||||
|
||||
@@ -16,12 +16,6 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -42,7 +36,6 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -57,31 +50,6 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -122,7 +90,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("light")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -322,7 +322,6 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
"onyx.server.features.proposal_review.engine",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -79,15 +79,6 @@ beat_task_templates: list[dict] = [
|
||||
"skip_gated": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-dangling-import-jobs",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_DANGLING_IMPORT_JOBS,
|
||||
"schedule": timedelta(minutes=10),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-index-attempt-cleanup",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEX_ATTEMPT_CLEANUP,
|
||||
|
||||
@@ -59,11 +59,6 @@ from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_blocked
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_completed
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_fence_reset
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_started
|
||||
from onyx.server.metrics.deletion_metrics import observe_deletion_taskset_duration
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
@@ -107,7 +102,7 @@ def revoke_tasks_blocking_deletion(
|
||||
f"Revoked permissions sync task {permissions_sync_payload.celery_task_id}."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while revoking permissions sync task")
|
||||
task_logger.exception("Exception while revoking pruning task")
|
||||
|
||||
try:
|
||||
prune_payload = redis_connector.prune.payload
|
||||
@@ -115,7 +110,7 @@ def revoke_tasks_blocking_deletion(
|
||||
app.control.revoke(prune_payload.celery_task_id)
|
||||
task_logger.info(f"Revoked pruning task {prune_payload.celery_task_id}.")
|
||||
except Exception:
|
||||
task_logger.exception("Exception while revoking pruning task")
|
||||
task_logger.exception("Exception while revoking permissions sync task")
|
||||
|
||||
try:
|
||||
external_group_sync_payload = redis_connector.external_group_sync.payload
|
||||
@@ -305,7 +300,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
recent_index_attempts
|
||||
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
inc_deletion_blocked(tenant_id, "indexing")
|
||||
raise TaskDependencyError(
|
||||
"Connector deletion - Delayed (indexing in progress): "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
@@ -313,13 +307,11 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
)
|
||||
|
||||
if redis_connector.prune.fenced:
|
||||
inc_deletion_blocked(tenant_id, "pruning")
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (pruning in progress): cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
if redis_connector.permissions.fenced:
|
||||
inc_deletion_blocked(tenant_id, "permissions")
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (permissions in progress): cc_pair={cc_pair_id}"
|
||||
)
|
||||
@@ -367,7 +359,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
# set this only after all tasks have been added
|
||||
fence_payload.num_tasks = tasks_generated
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
inc_deletion_started(tenant_id)
|
||||
|
||||
return tasks_generated
|
||||
|
||||
@@ -517,11 +508,7 @@ def monitor_connector_deletion_taskset(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id_to_delete,
|
||||
)
|
||||
if not connector:
|
||||
task_logger.info(
|
||||
"Connector deletion - Connector already deleted, skipping connector cleanup"
|
||||
)
|
||||
elif not len(connector.credentials):
|
||||
if not connector or not len(connector.credentials):
|
||||
task_logger.info(
|
||||
"Connector deletion - Found no credentials left for connector, deleting connector"
|
||||
)
|
||||
@@ -536,12 +523,6 @@ def monitor_connector_deletion_taskset(
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
duration = (
|
||||
datetime.now(timezone.utc) - fence_data.submitted
|
||||
).total_seconds()
|
||||
observe_deletion_taskset_duration(tenant_id, "success", duration)
|
||||
inc_deletion_completed(tenant_id, "success")
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
stack_trace = traceback.format_exc()
|
||||
@@ -560,11 +541,6 @@ def monitor_connector_deletion_taskset(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
|
||||
)
|
||||
duration = (
|
||||
datetime.now(timezone.utc) - fence_data.submitted
|
||||
).total_seconds()
|
||||
observe_deletion_taskset_duration(tenant_id, "failure", duration)
|
||||
inc_deletion_completed(tenant_id, "failure")
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
@@ -741,6 +717,5 @@ def validate_connector_deletion_fence(
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
inc_deletion_fence_reset(tenant_id)
|
||||
redis_connector.delete.reset()
|
||||
return
|
||||
|
||||
@@ -172,10 +172,6 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
|
||||
task_logger.debug(
|
||||
"Verified tenant info, migration record, and search settings."
|
||||
)
|
||||
|
||||
# 2.e. Build sanitized to original doc ID mapping to check for
|
||||
# conflicts in the event we sanitize a doc ID to an
|
||||
# already-existing doc ID.
|
||||
@@ -329,7 +325,6 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
task_logger.debug("Released the OpenSearch migration lock.")
|
||||
else:
|
||||
task_logger.warning(
|
||||
"The OpenSearch migration lock was not owned on completion of the migration task."
|
||||
|
||||
@@ -38,7 +38,6 @@ from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.connector import mark_ccpair_as_pruned
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
@@ -526,14 +525,6 @@ def connector_pruning_generator_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
# Session 1: pre-enumeration — load cc_pair and instantiate the connector.
|
||||
# The session is closed before enumeration so the DB connection is not held
|
||||
# open during the 10–30+ minute connector crawl.
|
||||
connector_source: DocumentSource | None = None
|
||||
connector_type: str = ""
|
||||
is_connector_public: bool = False
|
||||
runnable_connector: BaseConnector | None = None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
@@ -559,51 +550,49 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
redis_connector.prune.set_fence(new_payload)
|
||||
|
||||
connector_source = cc_pair.connector.source
|
||||
connector_type = connector_source.value
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning generator running connector: cc_pair={cc_pair_id} connector_source={connector_source}"
|
||||
f"Pruning generator running connector: cc_pair={cc_pair_id} connector_source={cc_pair.connector.source}"
|
||||
)
|
||||
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session,
|
||||
connector_source,
|
||||
cc_pair.connector.source,
|
||||
InputType.SLIM_RETRIEVAL,
|
||||
cc_pair.connector.connector_specific_config,
|
||||
cc_pair.credential,
|
||||
)
|
||||
# Session 1 closed here — connection released before enumeration.
|
||||
|
||||
callback = PruneCallback(
|
||||
0,
|
||||
redis_connector,
|
||||
lock,
|
||||
r,
|
||||
timeout_seconds=JOB_TIMEOUT,
|
||||
)
|
||||
callback = PruneCallback(
|
||||
0,
|
||||
redis_connector,
|
||||
lock,
|
||||
r,
|
||||
timeout_seconds=JOB_TIMEOUT,
|
||||
)
|
||||
|
||||
# Extract docs and hierarchy nodes from the source (no DB session held).
|
||||
extraction_result = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback, connector_type=connector_type
|
||||
)
|
||||
all_connector_doc_ids = extraction_result.raw_id_to_parent
|
||||
# Extract docs and hierarchy nodes from the source
|
||||
connector_type = cc_pair.connector.source.value
|
||||
extraction_result = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback, connector_type=connector_type
|
||||
)
|
||||
all_connector_doc_ids = extraction_result.raw_id_to_parent
|
||||
|
||||
# Session 2: post-enumeration — hierarchy upserts, diff computation, task dispatch.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
source = connector_source
|
||||
# Process hierarchy nodes (same as docfetching):
|
||||
# upsert to Postgres and cache in Redis
|
||||
source = cc_pair.connector.source
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
|
||||
upserted_nodes: list[DBHierarchyNode] = []
|
||||
if extraction_result.hierarchy_nodes:
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=extraction_result.hierarchy_nodes,
|
||||
source=source,
|
||||
commit=False,
|
||||
commit=True,
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
|
||||
@@ -612,13 +601,9 @@ def connector_pruning_generator_task(
|
||||
hierarchy_node_ids=[n.id for n in upserted_nodes],
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
commit=False,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# Single commit so the FK reference in the join table can never
|
||||
# outrun the parent hierarchy_node insert.
|
||||
db_session.commit()
|
||||
|
||||
cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node)
|
||||
for node in upserted_nodes
|
||||
@@ -673,7 +658,7 @@ def connector_pruning_generator_task(
|
||||
task_logger.info(
|
||||
"Pruning set collected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector_source={connector_source} "
|
||||
f"connector_source={cc_pair.connector.source} "
|
||||
f"docs_to_remove={len(doc_ids_to_remove)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -23,8 +23,6 @@ class IndexAttemptErrorPydantic(BaseModel):
|
||||
|
||||
index_attempt_id: int
|
||||
|
||||
error_type: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, model: IndexAttemptError) -> "IndexAttemptErrorPydantic":
|
||||
return cls(
|
||||
@@ -39,5 +37,4 @@ class IndexAttemptErrorPydantic(BaseModel):
|
||||
is_resolved=model.is_resolved,
|
||||
time_created=model.time_created,
|
||||
index_attempt_id=model.index_attempt_id,
|
||||
error_type=model.error_type,
|
||||
)
|
||||
|
||||
@@ -364,7 +364,7 @@ def _get_or_extract_plaintext(
|
||||
plaintext_io = file_store.read_file(plaintext_key, mode="b")
|
||||
return plaintext_io.read().decode("utf-8")
|
||||
except Exception:
|
||||
logger.info(f"Cache miss for file with id={file_id}")
|
||||
logger.exception(f"Error when reading file, id={file_id}")
|
||||
|
||||
# Cache miss — extract and store.
|
||||
content_text = extract_fn()
|
||||
|
||||
@@ -4,6 +4,8 @@ from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_utils import create_tool_call_failure_messages
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
@@ -633,6 +635,7 @@ def run_llm_loop(
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
forced_tool_id: int | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
@@ -1017,16 +1020,20 @@ def run_llm_loop(
|
||||
persisted_memory_id: int | None = None
|
||||
if user_memory_context and user_memory_context.user_id:
|
||||
if tool_response.rich_response.index_to_replace is not None:
|
||||
persisted_memory_id = update_memory_at_index(
|
||||
memory = update_memory_at_index(
|
||||
user_id=user_memory_context.user_id,
|
||||
index=tool_response.rich_response.index_to_replace,
|
||||
new_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id if memory else None
|
||||
else:
|
||||
persisted_memory_id = add_memory(
|
||||
memory = add_memory(
|
||||
user_id=user_memory_context.user_id,
|
||||
memory_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id
|
||||
operation: Literal["add", "update"] = (
|
||||
"update"
|
||||
if tool_response.rich_response.index_to_replace is not None
|
||||
|
||||
@@ -67,6 +67,7 @@ from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import reserve_multi_model_message_ids
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
@@ -1005,86 +1006,93 @@ def _run_models(
|
||||
model_llm = setup.llms[model_idx]
|
||||
|
||||
try:
|
||||
# Each function opens short-lived DB sessions on demand.
|
||||
# Do NOT pass a long-lived session here — it would hold a
|
||||
# connection for the entire LLM loop (minutes), and cloud
|
||||
# infrastructure may drop idle connections.
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=setup.new_msg_req.internal_search_filters,
|
||||
project_id_filter=setup.search_params.project_id_filter,
|
||||
persona_id_filter=setup.search_params.persona_id_filter,
|
||||
bypass_acl=setup.bypass_acl,
|
||||
slack_context=setup.slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
setup.persona, setup.new_msg_req.internal_search_filters
|
||||
),
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
message_id=setup.user_message.id,
|
||||
additional_headers=setup.custom_tool_additional_headers,
|
||||
mcp_headers=setup.mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=setup.available_files.user_file_ids,
|
||||
chat_file_ids=setup.available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=setup.search_params.search_usage,
|
||||
)
|
||||
model_tools = [
|
||||
tool for tool_list in thread_tool_dict.values() for tool in tool_list
|
||||
]
|
||||
|
||||
if setup.forced_tool_id and setup.forced_tool_id not in {
|
||||
tool.id for tool in model_tools
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Forced tool {setup.forced_tool_id} not found in tools"
|
||||
)
|
||||
|
||||
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
|
||||
if n_models == 1 and setup.new_msg_req.deep_research:
|
||||
if setup.chat_session.project_id:
|
||||
raise RuntimeError("Deep research is not supported for projects")
|
||||
run_deep_research_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
skip_clarification=setup.skip_clarification,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
context_files=setup.extracted_context_files,
|
||||
# Each worker opens its own session — SQLAlchemy sessions are not thread-safe.
|
||||
# Do NOT write to the outer db_session (or any shared DB state) from here;
|
||||
# all DB writes in this thread must go through thread_db_session.
|
||||
with get_session_with_current_tenant() as thread_db_session:
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
user_memory_context=setup.user_memory_context,
|
||||
db_session=thread_db_session,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
forced_tool_id=setup.forced_tool_id,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
chat_files=setup.chat_files_for_tools,
|
||||
include_citations=setup.new_msg_req.include_citations,
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=setup.new_msg_req.internal_search_filters,
|
||||
project_id_filter=setup.search_params.project_id_filter,
|
||||
persona_id_filter=setup.search_params.persona_id_filter,
|
||||
bypass_acl=setup.bypass_acl,
|
||||
slack_context=setup.slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
setup.persona, setup.new_msg_req.internal_search_filters
|
||||
),
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
message_id=setup.user_message.id,
|
||||
additional_headers=setup.custom_tool_additional_headers,
|
||||
mcp_headers=setup.mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=setup.available_files.user_file_ids,
|
||||
chat_file_ids=setup.available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=setup.search_params.search_usage,
|
||||
)
|
||||
model_tools = [
|
||||
tool
|
||||
for tool_list in thread_tool_dict.values()
|
||||
for tool in tool_list
|
||||
]
|
||||
|
||||
if setup.forced_tool_id and setup.forced_tool_id not in {
|
||||
tool.id for tool in model_tools
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Forced tool {setup.forced_tool_id} not found in tools"
|
||||
)
|
||||
|
||||
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
|
||||
if n_models == 1 and setup.new_msg_req.deep_research:
|
||||
if setup.chat_session.project_id:
|
||||
raise RuntimeError(
|
||||
"Deep research is not supported for projects"
|
||||
)
|
||||
run_deep_research_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
skip_clarification=setup.skip_clarification,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
context_files=setup.extracted_context_files,
|
||||
persona=setup.persona,
|
||||
user_memory_context=setup.user_memory_context,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
forced_tool_id=setup.forced_tool_id,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
chat_files=setup.chat_files_for_tools,
|
||||
include_citations=setup.new_msg_req.include_citations,
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
model_succeeded[model_idx] = True
|
||||
|
||||
|
||||
@@ -449,7 +449,6 @@ class OnyxRedisLocks:
|
||||
"da_lock:check_connector_external_group_sync_beat"
|
||||
)
|
||||
OPENSEARCH_MIGRATION_BEAT_LOCK = "da_lock:opensearch_migration_beat"
|
||||
CHECK_DANGLING_IMPORT_JOBS_BEAT_LOCK = "da_lock:check_dangling_import_jobs_beat"
|
||||
|
||||
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
|
||||
CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
|
||||
@@ -613,9 +612,6 @@ class OnyxCeleryTask:
|
||||
# Hook execution log retention
|
||||
HOOK_EXECUTION_LOG_CLEANUP_TASK = "hook_execution_log_cleanup_task"
|
||||
|
||||
# Proposal review import cleanup
|
||||
CHECK_FOR_DANGLING_IMPORT_JOBS = "check_for_dangling_import_jobs"
|
||||
|
||||
# Sandbox cleanup
|
||||
CLEANUP_IDLE_SANDBOXES = "cleanup_idle_sandboxes"
|
||||
CLEANUP_OLD_SNAPSHOTS = "cleanup_old_snapshots"
|
||||
|
||||
@@ -61,9 +61,6 @@ _USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
_USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
_MINIMUM_PAGINATION_LIMIT = 5
|
||||
|
||||
_SERVER_ERROR_CODES = {500, 502, 503, 504}
|
||||
|
||||
_CONFLUENCE_SPACES_API_V1 = "rest/api/space"
|
||||
_CONFLUENCE_SPACES_API_V2 = "wiki/api/v2/spaces"
|
||||
@@ -572,8 +569,7 @@ class OnyxConfluence:
|
||||
if not limit:
|
||||
limit = _DEFAULT_PAGINATION_LIMIT
|
||||
|
||||
current_limit = limit
|
||||
url_suffix = update_param_in_path(url_suffix, "limit", str(current_limit))
|
||||
url_suffix = update_param_in_path(url_suffix, "limit", str(limit))
|
||||
|
||||
while url_suffix:
|
||||
logger.debug(f"Making confluence call to {url_suffix}")
|
||||
@@ -613,61 +609,40 @@ class OnyxConfluence:
|
||||
)
|
||||
continue
|
||||
|
||||
if raw_response.status_code in _SERVER_ERROR_CODES:
|
||||
# Try reducing the page size -- Confluence often times out
|
||||
# on large result sets (especially Cloud 504s).
|
||||
if current_limit > _MINIMUM_PAGINATION_LIMIT:
|
||||
old_limit = current_limit
|
||||
current_limit = max(
|
||||
current_limit // 2, _MINIMUM_PAGINATION_LIMIT
|
||||
)
|
||||
logger.warning(
|
||||
f"Confluence returned {raw_response.status_code}. "
|
||||
f"Reducing limit from {old_limit} to {current_limit} "
|
||||
f"and retrying."
|
||||
)
|
||||
url_suffix = update_param_in_path(
|
||||
url_suffix, "limit", str(current_limit)
|
||||
)
|
||||
continue
|
||||
# If we fail due to a 500, try one by one.
|
||||
# NOTE: this iterative approach only works for server, since cloud uses cursor-based
|
||||
# pagination
|
||||
if raw_response.status_code == 500 and not self._is_cloud:
|
||||
initial_start = get_start_param_from_url(url_suffix)
|
||||
if initial_start is None:
|
||||
# can't handle this if we don't have offset-based pagination
|
||||
raise
|
||||
|
||||
# Limit reduction exhausted -- for Server, fall back to
|
||||
# one-by-one offset pagination as a last resort.
|
||||
if not self._is_cloud:
|
||||
initial_start = get_start_param_from_url(url_suffix)
|
||||
# this will just yield the successful items from the batch
|
||||
new_url_suffix = (
|
||||
yield from self._try_one_by_one_for_paginated_url(
|
||||
url_suffix,
|
||||
initial_start=initial_start,
|
||||
limit=current_limit,
|
||||
)
|
||||
)
|
||||
# this means we ran into an empty page
|
||||
if new_url_suffix is None:
|
||||
if next_page_callback:
|
||||
next_page_callback("")
|
||||
break
|
||||
# this will just yield the successful items from the batch
|
||||
new_url_suffix = yield from self._try_one_by_one_for_paginated_url(
|
||||
url_suffix,
|
||||
initial_start=initial_start,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
url_suffix = new_url_suffix
|
||||
continue
|
||||
# this means we ran into an empty page
|
||||
if new_url_suffix is None:
|
||||
if next_page_callback:
|
||||
next_page_callback("")
|
||||
break
|
||||
|
||||
url_suffix = new_url_suffix
|
||||
continue
|
||||
|
||||
else:
|
||||
logger.exception(
|
||||
f"Error in confluence call to {url_suffix} "
|
||||
f"after reducing limit to {current_limit}.\n"
|
||||
f"Raw Response Text: {raw_response.text}\n"
|
||||
f"Error: {e}\n"
|
||||
f"Error in confluence call to {url_suffix} \n"
|
||||
f"Raw Response Text: {raw_response.text} \n"
|
||||
f"Full Response: {raw_response.__dict__} \n"
|
||||
f"Error: {e} \n"
|
||||
)
|
||||
raise
|
||||
|
||||
logger.exception(
|
||||
f"Error in confluence call to {url_suffix} \n"
|
||||
f"Raw Response Text: {raw_response.text} \n"
|
||||
f"Full Response: {raw_response.__dict__} \n"
|
||||
f"Error: {e} \n"
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
next_response = raw_response.json()
|
||||
except Exception as e:
|
||||
@@ -705,10 +680,6 @@ class OnyxConfluence:
|
||||
old_url_suffix = url_suffix
|
||||
updated_start = get_start_param_from_url(old_url_suffix)
|
||||
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
|
||||
if url_suffix and current_limit != limit:
|
||||
url_suffix = update_param_in_path(
|
||||
url_suffix, "limit", str(current_limit)
|
||||
)
|
||||
for i, result in enumerate(results):
|
||||
updated_start += 1
|
||||
if url_suffix and next_page_callback and i == len(results) - 1:
|
||||
|
||||
@@ -42,9 +42,6 @@ from onyx.connectors.google_drive.file_retrieval import (
|
||||
get_all_files_in_my_drive_and_shared,
|
||||
)
|
||||
from onyx.connectors.google_drive.file_retrieval import get_external_access_for_folder
|
||||
from onyx.connectors.google_drive.file_retrieval import (
|
||||
get_files_by_web_view_links_batch,
|
||||
)
|
||||
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_folder_metadata
|
||||
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
|
||||
@@ -73,13 +70,11 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import NormalizationResult
|
||||
from onyx.connectors.interfaces import Resolver
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
@@ -207,9 +202,7 @@ class DriveIdStatus(Enum):
|
||||
|
||||
|
||||
class GoogleDriveConnector(
|
||||
SlimConnectorWithPermSync,
|
||||
CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint],
|
||||
Resolver,
|
||||
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1672,82 +1665,6 @@ class GoogleDriveConnector(
|
||||
start, end, checkpoint, include_permissions=True
|
||||
)
|
||||
|
||||
@override
def resolve_errors(
    self,
    errors: list[ConnectorFailure],
    include_permissions: bool = False,
) -> Generator[Document | ConnectorFailure | HierarchyNode, None, None]:
    """Re-fetch the documents behind earlier failures and yield the outcome.

    Only failures that carry a ``failed_document`` can be retried; their
    ``document_id`` is passed to ``get_files_by_web_view_links_batch`` as the
    link to fetch. Output order is:

    1. a fresh ``ConnectorFailure`` for every file the batch could not fetch,
    2. whatever ``_get_new_ancestors_for_files`` yields for the fetched files,
    3. every non-``None`` result of converting the fetched files in parallel.

    Args:
        errors: Failures from a previous run.
        include_permissions: When True, files are fetched with permission
            fields and a ``PermissionSyncContext`` is passed to the ancestor
            walk and the document conversion.

    Raises:
        RuntimeError: If credentials have not been loaded. Because this is a
            generator, the error surfaces on first iteration, not at call time.
    """
    if self._creds is None or self._primary_admin_email is None:
        raise RuntimeError(
            "Credentials missing, should not call this method before calling load_credentials"
        )

    logger.info(f"Resolving {len(errors)} errors")
    doc_ids = [
        failure.failed_document.document_id
        for failure in errors
        if failure.failed_document
    ]

    # Failures without a failed_document cannot be retried by id; make the
    # drop visible instead of silently ignoring them.
    skipped_count = len(errors) - len(doc_ids)
    if skipped_count:
        logger.warning(
            f"Skipping {skipped_count} failures without a failed document "
            "during error resolution"
        )

    service = get_drive_service(self.creds, self.primary_admin_email)

    # Permission fields are requested when permissions are asked for or when
    # exclude_domain_link_only is set.
    if include_permissions or self.exclude_domain_link_only:
        field_type = DriveFileFieldType.WITH_PERMISSIONS
    else:
        field_type = DriveFileFieldType.STANDARD

    batch_result = get_files_by_web_view_links_batch(service, doc_ids, field_type)

    # Report every file that still cannot be fetched as a new failure.
    for doc_id, error in batch_result.errors.items():
        yield ConnectorFailure(
            failed_document=DocumentFailure(
                document_id=doc_id,
                document_link=doc_id,
            ),
            failure_message=f"Failed to retrieve file during error resolution: {error}",
            exception=error,
        )

    permission_sync_context = (
        PermissionSyncContext(
            primary_admin_email=self.primary_admin_email,
            google_domain=self.google_domain,
        )
        if include_permissions
        else None
    )

    retrieved_files = [
        RetrievedDriveFile(
            drive_file=file,
            user_email=self.primary_admin_email,
            completion_stage=DriveRetrievalStage.DONE,
        )
        for file in batch_result.files.values()
    ]

    # Fresh (empty) tracking sets are used for this one-off resolution pass.
    yield from self._get_new_ancestors_for_files(
        files=retrieved_files,
        seen_hierarchy_node_raw_ids=ThreadSafeSet(),
        fully_walked_hierarchy_node_raw_ids=ThreadSafeSet(),
        permission_sync_context=permission_sync_context,
        add_prefix=True,
    )

    func_with_args = [
        (
            self._convert_retrieved_file_to_document,
            (rf, permission_sync_context),
        )
        for rf in retrieved_files
    ]
    results = cast(
        list[Document | ConnectorFailure | None],
        run_functions_tuples_in_parallel(func_with_args, max_workers=8),
    )
    # None results are dropped; everything else is passed through.
    for result in results:
        if result is not None:
            yield result
|
||||
|
||||
def _extract_slim_docs_from_google_drive(
|
||||
self,
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
|
||||
@@ -9,7 +9,6 @@ from urllib.parse import urlparse
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from googleapiclient.http import BatchHttpRequest # type: ignore
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
|
||||
@@ -61,8 +60,6 @@ SLIM_FILE_FIELDS = (
|
||||
)
|
||||
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
|
||||
|
||||
MAX_BATCH_SIZE = 100
|
||||
|
||||
HIERARCHY_FIELDS = "id, name, parents, webViewLink, mimeType, driveId"
|
||||
|
||||
HIERARCHY_FIELDS_WITH_PERMISSIONS = (
|
||||
@@ -219,7 +216,7 @@ def get_external_access_for_folder(
|
||||
|
||||
|
||||
def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
|
||||
"""Get the appropriate fields string for files().list() based on the field type enum."""
|
||||
"""Get the appropriate fields string based on the field type enum"""
|
||||
if field_type == DriveFileFieldType.SLIM:
|
||||
return SLIM_FILE_FIELDS
|
||||
elif field_type == DriveFileFieldType.WITH_PERMISSIONS:
|
||||
@@ -228,25 +225,6 @@ def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
|
||||
return FILE_FIELDS
|
||||
|
||||
|
||||
def _extract_single_file_fields(list_fields: str) -> str:
    """Convert a files().list() fields string to one suitable for files().get().

    List fields look like "nextPageToken, files(field1, field2, ...)"
    Single-file fields should be just "field1, field2, ..."

    A string without a "files(" wrapper is returned unchanged. If the wrapper
    is never closed, everything after "files(" is returned.
    """
    wrapper = "files("
    start = list_fields.find(wrapper)
    if start == -1:
        # No list wrapper: treat the input as a plain field list already.
        return list_fields

    inner_start = start + len(wrapper)
    # rfind so that nested selections such as "permissions(role)" stay intact.
    inner_end = list_fields.rfind(")")
    if inner_end < inner_start:
        # No ")" after "files(" (rfind gave -1 or an earlier index). Slicing
        # with that index would drop the last character or return "".
        return list_fields[inner_start:]
    return list_fields[inner_start:inner_end]
|
||||
|
||||
|
||||
def _get_single_file_fields(field_type: DriveFileFieldType) -> str:
    """Get the appropriate fields string for files().get() based on the field type enum.

    Looks up the files().list() fields string for ``field_type`` and strips
    its list wrapper via ``_extract_single_file_fields``.
    """
    list_fields = _get_fields_for_file_type(field_type)
    return _extract_single_file_fields(list_fields)
|
||||
|
||||
|
||||
def _get_files_in_parent(
|
||||
service: Resource,
|
||||
parent_id: str,
|
||||
@@ -558,74 +536,3 @@ def get_file_by_web_view_link(
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
|
||||
class BatchRetrievalResult:
    """Result of a batch file retrieval, separating successes from errors."""

    def __init__(self) -> None:
        # Successfully retrieved files, keyed by the id of the request that
        # produced them.
        self.files: dict[str, GoogleDriveFileType] = {}
        # Exceptions for requests that failed, keyed the same way.
        self.errors: dict[str, Exception] = {}
|
||||
|
||||
|
||||
def get_files_by_web_view_links_batch(
    service: GoogleDriveService,
    web_view_links: list[str],
    field_type: DriveFileFieldType,
) -> BatchRetrievalResult:
    """Retrieve multiple Google Drive files by webViewLink using the batch API.

    Returns a BatchRetrievalResult containing successful file retrievals
    and errors for any files that could not be fetched, both keyed by link.
    Automatically splits into chunks of MAX_BATCH_SIZE. A link that appears
    more than once in the input is fetched only once.
    """
    fields = _get_single_file_fields(field_type)

    # Each link doubles as its batch request id, and a batch rejects repeated
    # ids, so drop duplicates up front (dict.fromkeys keeps first-seen order).
    unique_links = list(dict.fromkeys(web_view_links))

    # One loop covers every size: no links -> empty result, up to
    # MAX_BATCH_SIZE links -> a single chunk, more -> several chunks merged.
    combined = BatchRetrievalResult()
    for offset in range(0, len(unique_links), MAX_BATCH_SIZE):
        chunk = unique_links[offset : offset + MAX_BATCH_SIZE]
        chunk_result = _get_files_by_web_view_links_batch(service, chunk, fields)
        combined.files.update(chunk_result.files)
        combined.errors.update(chunk_result.errors)
    return combined
|
||||
|
||||
|
||||
def _get_files_by_web_view_links_batch(
    service: GoogleDriveService,
    web_view_links: list[str],
    fields: str,
) -> BatchRetrievalResult:
    """Single-batch implementation.

    Queues one files().get() request per distinct link, using the link itself
    as the request id, and executes them as one batch.

    Args:
        service: Drive service used to build the batch and the get requests.
        web_view_links: Links to fetch; duplicates are collapsed.
        fields: Fields string passed to each files().get() call.

    Returns:
        A BatchRetrievalResult whose ``files`` and ``errors`` are keyed by
        link. Links whose file id cannot be extracted are recorded in
        ``errors`` without being sent.
    """
    result = BatchRetrievalResult()

    def callback(
        request_id: str,
        response: GoogleDriveFileType,
        exception: Exception | None,
    ) -> None:
        # Invoked with the outcome of each queued request; request_id is the
        # link the request was registered under.
        if exception:
            logger.warning(f"Error retrieving file {request_id}: {exception}")
            result.errors[request_id] = exception
        else:
            result.files[request_id] = response

    batch = cast(BatchHttpRequest, service.new_batch_http_request(callback=callback))

    # BatchHttpRequest.add raises KeyError on a repeated request_id, which the
    # ValueError handler below would not catch; de-duplicate (order preserved)
    # so a repeated link cannot abort the whole batch.
    for web_view_link in dict.fromkeys(web_view_links):
        try:
            file_id = _extract_file_id_from_web_view_link(web_view_link)
            request = service.files().get(
                fileId=file_id,
                supportsAllDrives=True,
                fields=fields,
            )
            batch.add(request, request_id=web_view_link)
        except ValueError as e:
            logger.warning(f"Failed to extract file ID from {web_view_link}: {e}")
            result.errors[web_view_link] = e

    batch.execute()
    return result
|
||||
|
||||
@@ -298,22 +298,6 @@ class CheckpointedConnectorWithPermSync(CheckpointedConnector[CT]):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Resolver(BaseConnector):
    """Connector interface for re-fetching documents behind earlier failures."""

    @abc.abstractmethod
    def resolve_errors(
        self,
        errors: list[ConnectorFailure],
        include_permissions: bool = False,
    ) -> Generator[Document | ConnectorFailure | HierarchyNode, None, None]:
        """Attempts to yield back ALL the documents described by the errors, no checkpointing.

        Caller's responsibility is to delete the old ConnectorFailures and replace with the new ones.
        If include_permissions is True, the documents will have permissions synced.
        May also yield HierarchyNode objects for ancestor folders of resolved documents.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class HierarchyConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def load_hierarchy(
|
||||
|
||||
@@ -8,7 +8,6 @@ from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
@@ -41,7 +40,6 @@ from onyx.connectors.jira.utils import best_effort_basic_expert_info
|
||||
from onyx.connectors.jira.utils import best_effort_get_field_from_issue
|
||||
from onyx.connectors.jira.utils import build_jira_client
|
||||
from onyx.connectors.jira.utils import build_jira_url
|
||||
from onyx.connectors.jira.utils import CustomFieldExtractor
|
||||
from onyx.connectors.jira.utils import extract_text_from_adf
|
||||
from onyx.connectors.jira.utils import get_comment_strs
|
||||
from onyx.connectors.jira.utils import JIRA_CLOUD_API_VERSION
|
||||
@@ -54,7 +52,6 @@ from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -67,7 +64,6 @@ _MAX_RESULTS_FETCH_IDS = 5000
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
|
||||
_JIRA_BULK_FETCH_LIMIT = 100
|
||||
_MAX_ATTACHMENT_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB
|
||||
|
||||
# Constants for Jira field names
|
||||
_FIELD_REPORTER = "reporter"
|
||||
@@ -381,7 +377,6 @@ def process_jira_issue(
|
||||
comment_email_blacklist: tuple[str, ...] = (),
|
||||
labels_to_skip: set[str] | None = None,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
custom_fields_mapping: dict[str, str] | None = None,
|
||||
) -> Document | None:
|
||||
if labels_to_skip:
|
||||
if any(label in issue.fields.labels for label in labels_to_skip):
|
||||
@@ -467,24 +462,6 @@ def process_jira_issue(
|
||||
else:
|
||||
logger.error(f"Project should exist but does not for {issue.key}")
|
||||
|
||||
# Merge custom fields into metadata if a mapping was provided
|
||||
if custom_fields_mapping:
|
||||
try:
|
||||
custom_fields = CustomFieldExtractor.get_issue_custom_fields(
|
||||
issue, custom_fields_mapping
|
||||
)
|
||||
# Filter out custom fields that collide with existing metadata keys
|
||||
for key in list(custom_fields.keys()):
|
||||
if key in metadata_dict:
|
||||
logger.warning(
|
||||
f"Custom field '{key}' on {issue.key} collides with "
|
||||
f"standard metadata key; skipping custom field value"
|
||||
)
|
||||
del custom_fields[key]
|
||||
metadata_dict.update(custom_fields)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract custom fields for {issue.key}: {e}")
|
||||
|
||||
return Document(
|
||||
id=page_url,
|
||||
sections=[TextSection(link=page_url, text=ticket_content)],
|
||||
@@ -527,12 +504,6 @@ class JiraConnector(
|
||||
# Custom JQL query to filter Jira issues
|
||||
jql_query: str | None = None,
|
||||
scoped_token: bool = False,
|
||||
# When True, extract custom fields from Jira issues and include them
|
||||
# in document metadata with human-readable field names.
|
||||
extract_custom_fields: bool = False,
|
||||
# When True, download attachments from Jira issues and yield them
|
||||
# as separate Documents linked to the parent ticket.
|
||||
fetch_attachments: bool = False,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
|
||||
@@ -546,11 +517,7 @@ class JiraConnector(
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
self.jql_query = jql_query
|
||||
self.scoped_token = scoped_token
|
||||
self.extract_custom_fields = extract_custom_fields
|
||||
self.fetch_attachments = fetch_attachments
|
||||
self._jira_client: JIRA | None = None
|
||||
# Mapping of custom field IDs to human-readable names (populated on load_credentials)
|
||||
self._custom_fields_mapping: dict[str, str] = {}
|
||||
# Cache project permissions to avoid fetching them repeatedly across runs
|
||||
self._project_permissions_cache: dict[str, Any] = {}
|
||||
|
||||
@@ -711,134 +678,12 @@ class JiraConnector(
|
||||
# the document belongs directly under the project in the hierarchy
|
||||
return project_key
|
||||
|
||||
def _process_attachments(
    self,
    issue: Issue,
    parent_hierarchy_raw_node_id: str | None,
    include_permissions: bool = False,
    project_key: str | None = None,
) -> Generator[Document | ConnectorFailure, None, None]:
    """Yield a Document for every attachment of ``issue`` that contains text.

    Each attachment is downloaded, run through ``extract_file_text`` and
    emitted as its own Document whose id is
    ``<issue url>/attachments/<attachment id>``.

    Attachments are logged and skipped (nothing is yielded) when they exceed
    ``_MAX_ATTACHMENT_SIZE_BYTES``, have no content URL, fail text
    extraction, or produce no text. Any other exception raised while
    handling a single attachment is logged and yielded as a
    ConnectorFailure, so one bad attachment does not stop the loop.

    Args:
        issue: The Jira issue whose ``attachment`` field is read.
        parent_hierarchy_raw_node_id: Copied onto every yielded Document.
        include_permissions: When true (and ``project_key`` is set), the
            project's permissions are attached as ``external_access``.
        project_key: Project used for the permission lookup.
    """
    issue_attachments = best_effort_get_field_from_issue(issue, "attachment")
    if not issue_attachments:
        return

    ticket_url = build_jira_url(self.jira_base, issue.key)

    for attachment in issue_attachments:
        try:
            filename = getattr(attachment, "filename", "unknown")
            try:
                size = int(getattr(attachment, "size", 0) or 0)
            except (ValueError, TypeError):
                # An unparseable size is treated as 0, i.e. never "too large".
                size = 0
            content_url = getattr(attachment, "content", None)
            attachment_id = getattr(attachment, "id", filename)
            mime_type = getattr(attachment, "mimeType", "application/octet-stream")
            created = getattr(attachment, "created", None)

            if size > _MAX_ATTACHMENT_SIZE_BYTES:
                logger.warning(
                    f"Skipping attachment '{filename}' on {issue.key}: "
                    f"size {size} bytes exceeds {_MAX_ATTACHMENT_SIZE_BYTES} byte limit"
                )
                continue

            if not content_url:
                logger.warning(
                    f"Skipping attachment '{filename}' on {issue.key}: "
                    f"no content URL available"
                )
                continue

            # Download through the public get() call on the python-jira
            # Attachment resource rather than the client's private session.
            raw_bytes = attachment.get()

            try:
                text = extract_file_text(
                    file=BytesIO(raw_bytes),
                    file_name=filename,
                )
            except Exception as e:
                logger.warning(
                    f"Could not extract text from attachment '{filename}' "
                    f"on {issue.key}: {e}"
                )
                continue

            if not text or not text.strip():
                logger.info(
                    f"Skipping attachment '{filename}' on {issue.key}: "
                    f"no text content could be extracted"
                )
                continue

            # The section links back to the ticket, not to the raw file.
            sections = [TextSection(link=ticket_url, text=text)]
            updated_at = time_str_to_utc(created) if created else None
            metadata = {
                "parent_ticket": issue.key,
                "attachment_filename": filename,
                "attachment_mime_type": mime_type,
                "attachment_size": str(size),
            }
            attachment_doc = Document(
                id=f"{ticket_url}/attachments/{attachment_id}",
                sections=sections,
                source=DocumentSource.JIRA,
                semantic_identifier=f"{issue.key}: {filename}",
                title=filename,
                doc_updated_at=updated_at,
                parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
                metadata=metadata,
            )
            if include_permissions and project_key:
                attachment_doc.external_access = self._get_project_permissions(
                    project_key,
                    add_prefix=True,
                )
            yield attachment_doc
        except Exception as e:
            logger.error(f"Failed to process attachment on {issue.key}: {e}")
            failed_id = getattr(attachment, "id", "unknown")
            failed_name = getattr(attachment, "filename", "unknown")
            yield ConnectorFailure(
                failed_document=DocumentFailure(
                    document_id=f"{ticket_url}/attachments/{failed_id}",
                    document_link=ticket_url,
                ),
                failure_message=f"Failed to process attachment '{failed_name}': {str(e)}",
                exception=e,
            )
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
    """Build the Jira client and, if enabled, preload custom field names.

    The client is created with ``build_jira_client`` and stored on
    ``self._jira_client``. When ``self.extract_custom_fields`` is set, the
    custom-field id-to-name mapping is fetched once here so it does not
    have to be requested again while issues are processed. A failure of
    that fetch is logged and leaves the mapping empty; it does not fail
    credential loading.

    Returns:
        Always ``None``.
    """
    self._jira_client = build_jira_client(
        credentials=credentials,
        jira_base=self.jira_base,
        scoped_token=self.scoped_token,
    )

    if not self.extract_custom_fields:
        return None

    try:
        self._custom_fields_mapping = CustomFieldExtractor.get_all_custom_fields(
            self._jira_client
        )
        loaded_count = len(self._custom_fields_mapping)
        logger.info(f"Loaded {loaded_count} custom field definitions")
    except Exception as e:
        logger.warning(
            f"Failed to fetch custom field definitions; "
            f"custom field extraction will be skipped: {e}"
        )
        # Fall back to an empty mapping so later lookups find nothing.
        self._custom_fields_mapping = {}

    return None
|
||||
|
||||
def _get_jql_query(
|
||||
@@ -969,11 +814,6 @@ class JiraConnector(
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
|
||||
custom_fields_mapping=(
|
||||
self._custom_fields_mapping
|
||||
if self._custom_fields_mapping
|
||||
else None
|
||||
),
|
||||
):
|
||||
# Add permission information to the document if requested
|
||||
if include_permissions:
|
||||
@@ -983,15 +823,6 @@ class JiraConnector(
|
||||
)
|
||||
yield document
|
||||
|
||||
# Yield attachment documents if enabled
|
||||
if self.fetch_attachments:
|
||||
yield from self._process_attachments(
|
||||
issue=issue,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
|
||||
include_permissions=include_permissions,
|
||||
project_key=project_key,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
@@ -1099,41 +930,20 @@ class JiraConnector(
|
||||
issue_key = best_effort_get_field_from_issue(issue, _FIELD_KEY)
|
||||
doc_id = build_jira_url(self.jira_base, issue_key)
|
||||
|
||||
parent_hierarchy_raw_node_id = (
|
||||
self._get_parent_hierarchy_raw_node_id(issue, project_key)
|
||||
if project_key
|
||||
else None
|
||||
)
|
||||
project_perms = self._get_project_permissions(
|
||||
project_key, add_prefix=False
|
||||
)
|
||||
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=doc_id,
|
||||
# Permission sync path - don't prefix, upsert_document_external_perms handles it
|
||||
external_access=project_perms,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
|
||||
external_access=self._get_project_permissions(
|
||||
project_key, add_prefix=False
|
||||
),
|
||||
parent_hierarchy_raw_node_id=(
|
||||
self._get_parent_hierarchy_raw_node_id(issue, project_key)
|
||||
if project_key
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Also emit SlimDocument entries for each attachment
|
||||
if self.fetch_attachments:
|
||||
attachments = best_effort_get_field_from_issue(issue, "attachment")
|
||||
if attachments:
|
||||
for attachment in attachments:
|
||||
attachment_id = getattr(
|
||||
attachment,
|
||||
"id",
|
||||
getattr(attachment, "filename", "unknown"),
|
||||
)
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=f"{doc_id}/attachments/{attachment_id}",
|
||||
external_access=project_perms,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
|
||||
)
|
||||
)
|
||||
current_offset += 1
|
||||
if len(slim_doc_batch) >= JIRA_SLIM_PAGE_SIZE:
|
||||
yield slim_doc_batch
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -34,17 +33,9 @@ class ConnectorMissingCredentialError(PermissionError):
|
||||
)
|
||||
|
||||
|
||||
class SectionType(str, Enum):
    """Tag that tells the Section subclasses apart.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    # Marks a section that carries text.
    TEXT = "text"
    # Marks a section that carries an image.
    IMAGE = "image"
|
||||
|
||||
|
||||
class Section(BaseModel):
|
||||
"""Base section class with common attributes"""
|
||||
|
||||
type: SectionType
|
||||
link: str | None = None
|
||||
text: str | None = None
|
||||
image_file_id: str | None = None
|
||||
@@ -53,7 +44,6 @@ class Section(BaseModel):
|
||||
class TextSection(Section):
|
||||
"""Section containing text content"""
|
||||
|
||||
type: Literal[SectionType.TEXT] = SectionType.TEXT
|
||||
text: str
|
||||
|
||||
def __sizeof__(self) -> int:
|
||||
@@ -63,7 +53,6 @@ class TextSection(Section):
|
||||
class ImageSection(Section):
|
||||
"""Section containing an image reference"""
|
||||
|
||||
type: Literal[SectionType.IMAGE] = SectionType.IMAGE
|
||||
image_file_id: str
|
||||
|
||||
def __sizeof__(self) -> int:
|
||||
@@ -145,6 +134,7 @@ class BasicExpertInfo(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, model_dict: dict[str, Any]) -> "BasicExpertInfo":
|
||||
|
||||
first_name = cast(str, model_dict.get("FirstName"))
|
||||
last_name = cast(str, model_dict.get("LastName"))
|
||||
email = cast(str, model_dict.get("Email"))
|
||||
|
||||
@@ -335,7 +335,6 @@ def update_document_set(
|
||||
"Cannot update document set while it is syncing. Please wait for it to finish syncing, and then try again."
|
||||
)
|
||||
|
||||
document_set_row.name = document_set_update_request.name
|
||||
document_set_row.description = document_set_update_request.description
|
||||
if not DISABLE_VECTOR_DB:
|
||||
document_set_row.is_up_to_date = False
|
||||
|
||||
@@ -11,7 +11,6 @@ from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import DBAPIError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
@@ -347,25 +346,6 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def _safe_close_session(session: Session) -> None:
    """Close ``session`` without letting a dropped connection raise.

    Long-running operations (e.g. multi-model LLM loops) can hold a session
    open for minutes. If the underlying connection is dropped by cloud
    infrastructure (load-balancer timeouts, PgBouncer, idle-in-transaction
    timeouts, etc.), the implicit rollback in Session.close() raises
    OperationalError or InterfaceError. Since the work is already complete,
    the error is logged and swallowed — SQLAlchemy internally invalidates
    the connection for pool recycling.

    Args:
        session: The session to close.
    """
    try:
        session.close()
    except DBAPIError:
        # exc_info=True keeps the driver error and traceback in the log so a
        # dropped connection can still be diagnosed after it is swallowed.
        logger.warning(
            "DB connection lost during session cleanup — the connection will be invalidated and recycled by the pool.",
            exc_info=True,
        )
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
|
||||
"""
|
||||
@@ -378,11 +358,8 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
|
||||
# no need to use the schema translation map for self-hosted + default schema
|
||||
if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
|
||||
session = Session(bind=engine, expire_on_commit=False)
|
||||
try:
|
||||
with Session(bind=engine, expire_on_commit=False) as session:
|
||||
yield session
|
||||
finally:
|
||||
_safe_close_session(session)
|
||||
return
|
||||
|
||||
# Create connection with schema translation to handle querying the right schema
|
||||
@@ -390,11 +367,8 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
with engine.connect().execution_options(
|
||||
schema_translate_map=schema_translate_map
|
||||
) as connection:
|
||||
session = Session(bind=connection, expire_on_commit=False)
|
||||
try:
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
yield session
|
||||
finally:
|
||||
_safe_close_session(session)
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
|
||||
@@ -899,7 +899,6 @@ def create_index_attempt_error(
|
||||
failure: ConnectorFailure,
|
||||
db_session: Session,
|
||||
) -> int:
|
||||
exc = failure.exception
|
||||
new_error = IndexAttemptError(
|
||||
index_attempt_id=index_attempt_id,
|
||||
connector_credential_pair_id=connector_credential_pair_id,
|
||||
@@ -922,7 +921,6 @@ def create_index_attempt_error(
|
||||
),
|
||||
failure_message=failure.failure_message,
|
||||
is_resolved=False,
|
||||
error_type=type(exc).__name__ if exc else None,
|
||||
)
|
||||
db_session.add(new_error)
|
||||
db_session.commit()
|
||||
|
||||
@@ -5,7 +5,6 @@ from pydantic import ConfigDict
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
|
||||
@@ -84,51 +83,47 @@ def get_memories(user: User, db_session: Session) -> UserMemoryContext:
|
||||
def add_memory(
|
||||
user_id: UUID,
|
||||
memory_text: str,
|
||||
db_session: Session | None = None,
|
||||
) -> int:
|
||||
db_session: Session,
|
||||
) -> Memory:
|
||||
"""Insert a new Memory row for the given user.
|
||||
|
||||
If the user already has MAX_MEMORIES_PER_USER memories, the oldest
|
||||
one (lowest id) is deleted before inserting the new one.
|
||||
|
||||
Returns the id of the newly created Memory row.
|
||||
"""
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
existing = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
existing = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if len(existing) >= MAX_MEMORIES_PER_USER:
|
||||
db_session.delete(existing[0])
|
||||
if len(existing) >= MAX_MEMORIES_PER_USER:
|
||||
db_session.delete(existing[0])
|
||||
|
||||
memory = Memory(
|
||||
user_id=user_id,
|
||||
memory_text=memory_text,
|
||||
)
|
||||
db_session.add(memory)
|
||||
db_session.commit()
|
||||
return memory.id
|
||||
memory = Memory(
|
||||
user_id=user_id,
|
||||
memory_text=memory_text,
|
||||
)
|
||||
db_session.add(memory)
|
||||
db_session.commit()
|
||||
return memory
|
||||
|
||||
|
||||
def update_memory_at_index(
|
||||
user_id: UUID,
|
||||
index: int,
|
||||
new_text: str,
|
||||
db_session: Session | None = None,
|
||||
) -> int | None:
|
||||
db_session: Session,
|
||||
) -> Memory | None:
|
||||
"""Update the memory at the given 0-based index (ordered by id ASC, matching get_memories()).
|
||||
|
||||
Returns the id of the updated Memory row, or None if the index is out of range.
|
||||
Returns the updated Memory row, or None if the index is out of range.
|
||||
"""
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if index < 0 or index >= len(memory_rows):
|
||||
return None
|
||||
if index < 0 or index >= len(memory_rows):
|
||||
return None
|
||||
|
||||
memory = memory_rows[index]
|
||||
memory.memory_text = new_text
|
||||
db_session.commit()
|
||||
return memory.id
|
||||
memory = memory_rows[index]
|
||||
memory.memory_text = new_text
|
||||
db_session.commit()
|
||||
return memory
|
||||
|
||||
@@ -2422,8 +2422,6 @@ class IndexAttemptError(Base):
|
||||
failure_message: Mapped[str] = mapped_column(Text)
|
||||
is_resolved: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
error_type: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
|
||||
@@ -7,6 +7,8 @@ import time
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
@@ -20,7 +22,6 @@ from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.chat_configs import SKIP_DEEP_RESEARCH_CLARIFICATION
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.tools import get_tool_by_name
|
||||
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
|
||||
from onyx.deep_research.dr_mock_tools import get_orchestrator_tools
|
||||
@@ -183,14 +184,6 @@ def generate_final_report(
|
||||
return has_reasoned
|
||||
|
||||
|
||||
def _get_research_agent_tool_id() -> int:
    """Return the id of the tool registered under RESEARCH_AGENT_TOOL_NAME.

    A session is opened with ``get_session_with_current_tenant`` just for
    this lookup, and the id is read before that session is released.
    """
    with get_session_with_current_tenant() as db_session:
        research_tool = get_tool_by_name(
            tool_name=RESEARCH_AGENT_TOOL_NAME,
            db_session=db_session,
        )
        return research_tool.id
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def run_deep_research_llm_loop(
|
||||
emitter: Emitter,
|
||||
@@ -200,6 +193,7 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt: str | None, # noqa: ARG001
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
skip_clarification: bool = False,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
@@ -723,7 +717,6 @@ def run_deep_research_llm_loop(
|
||||
simple_chat_history.append(assistant_with_tools)
|
||||
|
||||
# Now add TOOL_CALL_RESPONSE messages and tool call info for each result
|
||||
research_agent_tool_id = _get_research_agent_tool_id()
|
||||
for tab_index, report in enumerate(
|
||||
research_results.intermediate_reports
|
||||
):
|
||||
@@ -744,7 +737,10 @@ def run_deep_research_llm_loop(
|
||||
tab_index=tab_index,
|
||||
tool_name=current_tool_call.tool_name,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
tool_id=research_agent_tool_id,
|
||||
tool_id=get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_TOOL_NAME,
|
||||
db_session=db_session,
|
||||
).id,
|
||||
reasoning_tokens=llm_step_result.reasoning
|
||||
or most_recent_reasoning,
|
||||
tool_call_arguments=current_tool_call.tool_args,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import cast
|
||||
|
||||
from chonkie import SentenceChunker
|
||||
|
||||
from onyx.configs.app_configs import AVERAGE_SUMMARY_EMBEDDINGS
|
||||
@@ -14,14 +16,16 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_metadata_keys_to_ignore,
|
||||
)
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.indexing.chunking import DocumentChunker
|
||||
from onyx.indexing.chunking import extract_blurb
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.llm.utils import MAX_CONTEXT_TOKENS
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.text_processing import clean_text
|
||||
from onyx.utils.text_processing import shared_precompare_cleanup
|
||||
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
|
||||
|
||||
# Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps
|
||||
# actually help quality at all
|
||||
@@ -150,6 +154,9 @@ class Chunker:
|
||||
self.tokenizer = tokenizer
|
||||
self.callback = callback
|
||||
|
||||
self.max_context = 0
|
||||
self.prompt_tokens = 0
|
||||
|
||||
# Create a token counter function that returns the count instead of the tokens
|
||||
def token_counter(text: str) -> int:
|
||||
return len(tokenizer.encode(text))
|
||||
@@ -179,12 +186,234 @@ class Chunker:
|
||||
else None
|
||||
)
|
||||
|
||||
self._document_chunker = DocumentChunker(
|
||||
tokenizer=tokenizer,
|
||||
blurb_splitter=self.blurb_splitter,
|
||||
chunk_splitter=self.chunk_splitter,
|
||||
mini_chunk_splitter=self.mini_chunk_splitter,
|
||||
def _split_oversized_chunk(self, text: str, content_token_limit: int) -> list[str]:
|
||||
"""
|
||||
Splits the text into smaller chunks based on token count to ensure
|
||||
no chunk exceeds the content_token_limit.
|
||||
"""
|
||||
tokens = self.tokenizer.tokenize(text)
|
||||
chunks = []
|
||||
start = 0
|
||||
total_tokens = len(tokens)
|
||||
while start < total_tokens:
|
||||
end = min(start + content_token_limit, total_tokens)
|
||||
token_chunk = tokens[start:end]
|
||||
chunk_text = " ".join(token_chunk)
|
||||
chunks.append(chunk_text)
|
||||
start = end
|
||||
return chunks
|
||||
|
||||
def _extract_blurb(self, text: str) -> str:
|
||||
"""
|
||||
Extract a short blurb from the text (first chunk of size `blurb_size`).
|
||||
"""
|
||||
# chunker is in `text` mode
|
||||
texts = cast(list[str], self.blurb_splitter.chunk(text))
|
||||
if not texts:
|
||||
return ""
|
||||
return texts[0]
|
||||
|
||||
def _get_mini_chunk_texts(self, chunk_text: str) -> list[str] | None:
|
||||
"""
|
||||
For "multipass" mode: additional sub-chunks (mini-chunks) for use in certain embeddings.
|
||||
"""
|
||||
if self.mini_chunk_splitter and chunk_text.strip():
|
||||
# chunker is in `text` mode
|
||||
return cast(list[str], self.mini_chunk_splitter.chunk(chunk_text))
|
||||
return None
|
||||
|
||||
# ADDED: extra param image_url to store in the chunk
|
||||
def _create_chunk(
    self,
    document: IndexingDocument,
    chunks_list: list[DocAwareChunk],
    text: str,
    links: dict[int, str],
    is_continuation: bool = False,
    title_prefix: str = "",
    metadata_suffix_semantic: str = "",
    metadata_suffix_keyword: str = "",
    image_file_id: str | None = None,
) -> None:
    """Build a DocAwareChunk for ``text`` and append it to ``chunks_list``.

    The new chunk's id is the current length of ``chunks_list``, so ids
    follow insertion order. Its blurb and mini-chunks are derived from
    ``text``. An empty ``links`` mapping is replaced by ``{0: ""}``.

    Args:
        document: Source document recorded on the chunk.
        chunks_list: List the chunk is appended to (mutated in place).
        text: Chunk content.
        links: Source links stored on the chunk as ``source_links``.
        is_continuation: Stored on the chunk as ``section_continuation``.
        title_prefix: Stored on the chunk unchanged.
        metadata_suffix_semantic: Stored on the chunk unchanged.
        metadata_suffix_keyword: Stored on the chunk unchanged.
        image_file_id: Image reference stored on the chunk, if any.
    """
    next_chunk_id = len(chunks_list)
    blurb = self._extract_blurb(text)
    source_links = links or {0: ""}
    mini_chunk_texts = self._get_mini_chunk_texts(text)

    chunks_list.append(
        DocAwareChunk(
            source_document=document,
            chunk_id=next_chunk_id,
            blurb=blurb,
            content=text,
            source_links=source_links,
            image_file_id=image_file_id,
            section_continuation=is_continuation,
            title_prefix=title_prefix,
            metadata_suffix_semantic=metadata_suffix_semantic,
            metadata_suffix_keyword=metadata_suffix_keyword,
            mini_chunk_texts=mini_chunk_texts,
            large_chunk_id=None,
            doc_summary="",
            chunk_context="",
            contextual_rag_reserved_tokens=0,  # set per-document in _handle_single_document
        )
    )
|
||||
|
||||
def _chunk_document_with_sections(
    self,
    document: IndexingDocument,
    sections: list[Section],
    title_prefix: str,
    metadata_suffix_semantic: str,
    metadata_suffix_keyword: str,
    content_token_limit: int,
) -> list[DocAwareChunk]:
    """Turn a document's processed sections into an ordered list of chunks.

    Sections are visited in order:

    * A section with an image file id always becomes a chunk of its own;
      any text gathered so far is emitted first.
    * A text section with more than ``content_token_limit`` tokens is split
      with ``self.chunk_splitter`` (and split again by token windows when
      STRICT_CHUNK_TOKEN_LIMIT is set and a piece is still too large).
    * Any other text section is appended to the pending text, joined by
      SECTION_SEPARATOR, until adding it would exceed the limit.

    At least one chunk is always returned, even if its text is empty.
    """
    built: list[DocAwareChunk] = []
    pending_links: dict[int, str] = {}
    pending_text = ""

    def emit(
        text: str,
        links: dict[int, str],
        *,
        is_continuation: bool = False,
        image_file_id: str | None = None,
    ) -> None:
        # Every chunk of this document shares the same prefix and suffixes.
        self._create_chunk(
            document,
            built,
            text,
            links,
            is_continuation=is_continuation,
            title_prefix=title_prefix,
            metadata_suffix_semantic=metadata_suffix_semantic,
            metadata_suffix_keyword=metadata_suffix_keyword,
            image_file_id=image_file_id,
        )

    for position, section in enumerate(sections):
        text = clean_text(str(section.text or ""))
        link = section.link or ""
        image_file_id = section.image_file_id

        # An empty section is only kept when it is the first section of a
        # document that has a title.
        if not text and not (document.title and position == 0):
            logger.warning(
                f"Skipping empty or irrelevant section in doc {document.semantic_identifier}, link={link}"
            )
            continue

        # Image sections: flush pending text, then emit a dedicated chunk
        # whose text is whatever the section carries.
        if image_file_id:
            if pending_text.strip():
                emit(pending_text, pending_links)
                pending_text = ""
                pending_links = {}

            emit(
                text,
                {0: link} if link else {},
                image_file_id=image_file_id,
            )
            continue

        section_tokens = len(self.tokenizer.encode(text))

        # Text section that is too large on its own: split it separately.
        if section_tokens > content_token_limit:
            if pending_text.strip():
                emit(pending_text, pending_links)
                pending_text = ""
                pending_links = {}

            # The splitter runs in `text` mode, so it hands back strings.
            pieces = cast(list[str], self.chunk_splitter.chunk(text))
            for i, piece in enumerate(pieces):
                if (
                    STRICT_CHUNK_TOKEN_LIMIT
                    and len(self.tokenizer.encode(piece)) > content_token_limit
                ):
                    fragments = self._split_oversized_chunk(
                        piece, content_token_limit
                    )
                    for j, fragment in enumerate(fragments):
                        # NOTE(review): only j is checked, so the first
                        # fragment of a later piece (i > 0) is not flagged
                        # as a continuation — confirm this is intended.
                        emit(fragment, {0: link}, is_continuation=(j != 0))
                else:
                    emit(piece, {0: link}, is_continuation=(i != 0))
            continue

        # Regular text section: pack it into the pending chunk if it fits.
        pending_tokens = len(self.tokenizer.encode(pending_text))
        link_offset = len(shared_precompare_cleanup(pending_text))
        tokens_if_appended = (
            len(self.tokenizer.encode(SECTION_SEPARATOR)) + section_tokens
        )

        if tokens_if_appended + pending_tokens > content_token_limit:
            # Does not fit: emit what is pending and start over with this
            # section.
            emit(pending_text, pending_links)
            pending_links = {0: link}
            pending_text = text
        else:
            if pending_text:
                pending_text += SECTION_SEPARATOR
            pending_text += text
            pending_links[link_offset] = link

    # Emit leftover text, or an empty chunk when nothing was produced at all.
    if pending_text.strip() or not built:
        emit(pending_text, pending_links or {0: ""})  # safe default
    return built
|
||||
|
||||
def _handle_single_document(
|
||||
self, document: IndexingDocument
|
||||
@@ -194,10 +423,7 @@ class Chunker:
|
||||
logger.debug(f"Chunking {document.semantic_identifier}")
|
||||
|
||||
# Title prep
|
||||
title = extract_blurb(
|
||||
document.get_title_for_document_index() or "",
|
||||
self.blurb_splitter,
|
||||
)
|
||||
title = self._extract_blurb(document.get_title_for_document_index() or "")
|
||||
title_prefix = title + RETURN_SEPARATOR if title else ""
|
||||
title_tokens = len(self.tokenizer.encode(title_prefix))
|
||||
|
||||
@@ -265,7 +491,7 @@ class Chunker:
|
||||
# Use processed_sections if available (IndexingDocument), otherwise use original sections
|
||||
sections_to_chunk = document.processed_sections
|
||||
|
||||
normal_chunks = self._document_chunker.chunk(
|
||||
normal_chunks = self._chunk_document_with_sections(
|
||||
document,
|
||||
sections_to_chunk,
|
||||
title_prefix,
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
from onyx.indexing.chunking.document_chunker import DocumentChunker
|
||||
from onyx.indexing.chunking.section_chunker import extract_blurb
|
||||
|
||||
__all__ = [
|
||||
"DocumentChunker",
|
||||
"extract_blurb",
|
||||
]
|
||||
@@ -1,109 +0,0 @@
|
||||
from chonkie import SentenceChunker
|
||||
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SectionType
|
||||
from onyx.indexing.chunking.image_section_chunker import ImageChunker
|
||||
from onyx.indexing.chunking.section_chunker import AccumulatorState
|
||||
from onyx.indexing.chunking.section_chunker import ChunkPayload
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunker
|
||||
from onyx.indexing.chunking.text_section_chunker import TextChunker
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.text_processing import clean_text
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class DocumentChunker:
    """Converts a document's processed sections into DocAwareChunks.

    Drop-in replacement for `Chunker._chunk_document_with_sections`.

    Each section is routed by its ``type`` to a registered SectionChunker.
    The payloads those chunkers return are then turned into DocAwareChunks
    whose ids follow payload order.
    """

    def __init__(
        self,
        tokenizer: BaseTokenizer,
        blurb_splitter: SentenceChunker,
        chunk_splitter: SentenceChunker,
        mini_chunk_splitter: SentenceChunker | None = None,
    ) -> None:
        """Store the splitters and register one chunker per section type.

        Args:
            tokenizer: Passed to the text chunker.
            blurb_splitter: Handed to each payload when it is converted.
            chunk_splitter: Passed to the text chunker.
            mini_chunk_splitter: Optional; handed to each payload when it
                is converted.
        """
        self.blurb_splitter = blurb_splitter
        self.mini_chunk_splitter = mini_chunk_splitter

        self._dispatch: dict[SectionType, SectionChunker] = {
            SectionType.TEXT: TextChunker(
                tokenizer=tokenizer,
                chunk_splitter=chunk_splitter,
            ),
            SectionType.IMAGE: ImageChunker(),
        }

    def chunk(
        self,
        document: IndexingDocument,
        sections: list[Section],
        title_prefix: str,
        metadata_suffix_semantic: str,
        metadata_suffix_keyword: str,
        content_token_limit: int,
    ) -> list[DocAwareChunk]:
        """Chunk ``sections`` of ``document`` into DocAwareChunks.

        Returns:
            One chunk per collected payload, with ``chunk_id`` equal to the
            payload's position. If no payload was produced, a single chunk
            with empty text is returned.

        Raises:
            ValueError: If a section's type has no registered chunker.
        """
        payloads = self._collect_section_payloads(
            document=document,
            sections=sections,
            content_token_limit=content_token_limit,
        )

        # Always hand back at least one (possibly empty) chunk.
        if not payloads:
            payloads.append(ChunkPayload(text="", links={0: ""}))

        return [
            payload.to_doc_aware_chunk(
                document=document,
                chunk_id=idx,
                blurb_splitter=self.blurb_splitter,
                mini_chunk_splitter=self.mini_chunk_splitter,
                title_prefix=title_prefix,
                metadata_suffix_semantic=metadata_suffix_semantic,
                metadata_suffix_keyword=metadata_suffix_keyword,
            )
            for idx, payload in enumerate(payloads)
        ]

    def _collect_section_payloads(
        self,
        document: IndexingDocument,
        sections: list[Section],
        content_token_limit: int,
    ) -> list[ChunkPayload]:
        """Run every usable section through its chunker and gather payloads.

        The accumulator returned by one section's chunker is passed to the
        next one. Whatever it still holds after the last section is flushed
        at the end.
        """
        accumulator = AccumulatorState()
        payloads: list[ChunkPayload] = []

        for section_idx, section in enumerate(sections):
            section_text = clean_text(str(section.text or ""))

            # An empty section is only kept when it is the first section of
            # a document that has a title.
            if not section_text and (not document.title or section_idx > 0):
                logger.warning(
                    f"Skipping empty or irrelevant section in doc "
                    f"{document.semantic_identifier}, link={section.link}"
                )
                continue

            chunker = self._select_chunker(section)
            result = chunker.chunk_section(
                section=section,
                accumulator=accumulator,
                content_token_limit=content_token_limit,
            )
            payloads.extend(result.payloads)
            accumulator = result.accumulator

        payloads.extend(accumulator.flush_to_list())
        return payloads

    def _select_chunker(self, section: Section) -> SectionChunker:
        """Return the chunker registered for ``section.type``.

        Raises:
            ValueError: If no chunker is registered for that type. The
                original KeyError is attached as the cause.
        """
        try:
            return self._dispatch[section.type]
        except KeyError as err:
            # Chain explicitly so the traceback shows the failed lookup as
            # the cause instead of an error raised while handling another.
            raise ValueError(
                f"No SectionChunker registered for type={section.type}"
            ) from err
|
||||
@@ -1,35 +0,0 @@
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.chunking.section_chunker import AccumulatorState
|
||||
from onyx.indexing.chunking.section_chunker import ChunkPayload
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunker
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
|
||||
from onyx.utils.text_processing import clean_text
|
||||
|
||||
|
||||
class ImageChunker(SectionChunker):
    """SectionChunker that emits one dedicated payload per image section."""

    def chunk_section(
        self,
        section: Section,
        accumulator: AccumulatorState,
        content_token_limit: int,  # noqa: ARG002
    ) -> SectionChunkerOutput:
        """Flush pending text, then append a payload for the image section.

        The image payload is never merged with accumulated text, and the
        returned accumulator is always a fresh, empty one.
        `content_token_limit` is accepted for interface compatibility but is
        not used.
        """
        assert section.image_file_id is not None

        image_text = clean_text(str(section.text or ""))
        image_link = section.link or ""

        # Anything buffered from earlier sections becomes its own payload
        # ahead of the image payload.
        flushed = accumulator.flush_to_list()
        image_payload = ChunkPayload(
            text=image_text,
            # No link -> empty mapping rather than {0: ""}.
            links={0: image_link} if image_link else {},
            image_file_id=section.image_file_id,
            is_continuation=False,
        )

        return SectionChunkerOutput(
            payloads=[*flushed, image_payload],
            accumulator=AccumulatorState(),
        )
|
||||
@@ -1,100 +0,0 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from chonkie import SentenceChunker
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
|
||||
|
||||
def extract_blurb(text: str, blurb_splitter: SentenceChunker) -> str:
    """Return the first piece `blurb_splitter` produces for `text`.

    Returns an empty string when the splitter yields nothing.
    """
    pieces = cast(list[str], blurb_splitter.chunk(text))
    return pieces[0] if pieces else ""
|
||||
|
||||
|
||||
def get_mini_chunk_texts(
|
||||
chunk_text: str,
|
||||
mini_chunk_splitter: SentenceChunker | None,
|
||||
) -> list[str] | None:
|
||||
if mini_chunk_splitter and chunk_text.strip():
|
||||
return list(cast(Sequence[str], mini_chunk_splitter.chunk(chunk_text)))
|
||||
return None
|
||||
|
||||
|
||||
class ChunkPayload(BaseModel):
    """Section-local chunk content, without any document-scoped fields.

    SectionChunkers emit these. The orchestrator later calls
    `to_doc_aware_chunk` to attach the chunk id, the source document and the
    title/metadata strings.
    """

    # Chunk content.
    text: str
    # Link strings keyed by integer offset.
    links: dict[int, str]
    # Passed through as DocAwareChunk.section_continuation.
    is_continuation: bool = False
    # Set by ImageChunker for image sections; None otherwise.
    image_file_id: str | None = None

    def to_doc_aware_chunk(
        self,
        document: IndexingDocument,
        chunk_id: int,
        blurb_splitter: SentenceChunker,
        title_prefix: str = "",
        metadata_suffix_semantic: str = "",
        metadata_suffix_keyword: str = "",
        mini_chunk_splitter: SentenceChunker | None = None,
    ) -> DocAwareChunk:
        """Build the DocAwareChunk for this payload.

        The blurb and the optional mini-chunks are derived from `self.text`.
        An empty `links` mapping is replaced by `{0: ""}`. The large-chunk,
        summary, context and contextual-RAG fields are initialised to
        None / "" / 0.
        """
        blurb = extract_blurb(self.text, blurb_splitter)
        # Fall back to a single empty link when the payload carries none.
        source_links = self.links if self.links else {0: ""}
        mini_chunk_texts = get_mini_chunk_texts(self.text, mini_chunk_splitter)

        return DocAwareChunk(
            source_document=document,
            chunk_id=chunk_id,
            blurb=blurb,
            content=self.text,
            source_links=source_links,
            image_file_id=self.image_file_id,
            section_continuation=self.is_continuation,
            title_prefix=title_prefix,
            metadata_suffix_semantic=metadata_suffix_semantic,
            metadata_suffix_keyword=metadata_suffix_keyword,
            mini_chunk_texts=mini_chunk_texts,
            large_chunk_id=None,
            doc_summary="",
            chunk_context="",
            contextual_rag_reserved_tokens=0,
        )
|
||||
|
||||
|
||||
class AccumulatorState(BaseModel):
    """Cross-section text buffer threaded through SectionChunkers."""

    # Text buffered so far.
    text: str = ""
    # Link strings keyed by integer offset.
    link_offsets: dict[int, str] = Field(default_factory=dict)

    def is_empty(self) -> bool:
        """Return True when the buffer holds no non-whitespace text."""
        return self.text.strip() == ""

    def flush_to_list(self) -> list[ChunkPayload]:
        """Return the buffered content as zero or one ChunkPayload.

        This does not clear the state itself; callers replace the
        accumulator with a new instance after flushing.
        """
        return (
            []
            if self.is_empty()
            else [ChunkPayload(text=self.text, links=self.link_offsets)]
        )
|
||||
|
||||
|
||||
class SectionChunkerOutput(BaseModel):
    """Result of `SectionChunker.chunk_section` for one section."""

    # Payloads completed while handling the section, in output order.
    payloads: list[ChunkPayload]

    # Accumulator to hand to the chunker of the next section.
    accumulator: AccumulatorState
|
||||
|
||||
|
||||
class SectionChunker(ABC):
    """Interface for per-section chunking strategies."""

    @abstractmethod
    def chunk_section(
        self,
        section: Section,
        accumulator: AccumulatorState,
        content_token_limit: int,
    ) -> SectionChunkerOutput:
        """Chunk one section.

        Receives the accumulator left by the previous section and returns
        the payloads completed by this section together with the accumulator
        for the next one.
        """
|
||||
@@ -1,129 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from chonkie import SentenceChunker
|
||||
|
||||
from onyx.configs.constants import SECTION_SEPARATOR
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.chunking.section_chunker import AccumulatorState
|
||||
from onyx.indexing.chunking.section_chunker import ChunkPayload
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunker
|
||||
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.natural_language_processing.utils import count_tokens
|
||||
from onyx.utils.text_processing import clean_text
|
||||
from onyx.utils.text_processing import shared_precompare_cleanup
|
||||
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
|
||||
|
||||
|
||||
class TextChunker(SectionChunker):
    """SectionChunker for text sections.

    Sections that fit are appended to the shared accumulator, joined by
    SECTION_SEPARATOR. A section that does not fit next to the buffered text
    flushes the buffer and starts a new one. A section that exceeds the
    token limit on its own is split with `chunk_splitter`.
    """

    def __init__(
        self,
        tokenizer: BaseTokenizer,
        chunk_splitter: SentenceChunker,
    ) -> None:
        self.tokenizer = tokenizer
        self.chunk_splitter = chunk_splitter

        # Token cost of the separator, computed once because it is needed
        # for every section.
        self.section_separator_token_count = count_tokens(
            SECTION_SEPARATOR,
            self.tokenizer,
        )

    def chunk_section(
        self,
        section: Section,
        accumulator: AccumulatorState,
        content_token_limit: int,
    ) -> SectionChunkerOutput:
        """Merge, restart with, or split a text section.

        Returns:
            The payloads completed by this section (possibly none) and the
            accumulator to pass to the next section.

        Raises:
            ValueError: propagated from `_split_oversized_chunk` when a hard
                split is required and `content_token_limit` is not positive.
        """
        section_text = clean_text(str(section.text or ""))
        section_link = section.link or ""
        section_token_count = len(self.tokenizer.encode(section_text))

        # Oversized — flush buffer and split the section
        if section_token_count > content_token_limit:
            return self._handle_oversized_section(
                section_text=section_text,
                section_link=section_link,
                accumulator=accumulator,
                content_token_limit=content_token_limit,
            )

        current_token_count = count_tokens(accumulator.text, self.tokenizer)
        # The separator cost is always included, even when the accumulator
        # is empty and no separator will actually be inserted.
        next_section_tokens = self.section_separator_token_count + section_token_count

        # Fits — extend the accumulator
        if next_section_tokens + current_token_count <= content_token_limit:
            # The section's link is keyed by the length of the buffered text
            # after shared_precompare_cleanup.
            offset = len(shared_precompare_cleanup(accumulator.text))
            new_text = accumulator.text
            if new_text:
                new_text += SECTION_SEPARATOR
            new_text += section_text
            return SectionChunkerOutput(
                payloads=[],
                accumulator=AccumulatorState(
                    text=new_text,
                    link_offsets={**accumulator.link_offsets, offset: section_link},
                ),
            )

        # Doesn't fit — flush buffer and restart with this section
        return SectionChunkerOutput(
            payloads=accumulator.flush_to_list(),
            accumulator=AccumulatorState(
                text=section_text,
                link_offsets={0: section_link},
            ),
        )

    def _handle_oversized_section(
        self,
        section_text: str,
        section_link: str,
        accumulator: AccumulatorState,
        content_token_limit: int,
    ) -> SectionChunkerOutput:
        """Flush the buffer and split a section that exceeds the limit.

        Each piece from `chunk_splitter` becomes its own payload. When
        STRICT_CHUNK_TOKEN_LIMIT is set, a piece that still exceeds the limit
        is hard-split on token boundaries. The returned accumulator is empty.
        """
        payloads = accumulator.flush_to_list()

        split_texts = cast(list[str], self.chunk_splitter.chunk(section_text))
        for i, split_text in enumerate(split_texts):
            if (
                STRICT_CHUNK_TOKEN_LIMIT
                and count_tokens(split_text, self.tokenizer) > content_token_limit
            ):
                smaller_chunks = self._split_oversized_chunk(
                    split_text, content_token_limit
                )
                for j, small_chunk in enumerate(smaller_chunks):
                    payloads.append(
                        ChunkPayload(
                            text=small_chunk,
                            links={0: section_link},
                            # NOTE(review): only `j` is consulted here, so the
                            # first hard-split piece of a later split (i > 0)
                            # is not flagged as a continuation. Kept as-is for
                            # parity; confirm whether that is intended.
                            is_continuation=(j != 0),
                        )
                    )
            else:
                payloads.append(
                    ChunkPayload(
                        text=split_text,
                        links={0: section_link},
                        is_continuation=(i != 0),
                    )
                )

        return SectionChunkerOutput(
            payloads=payloads,
            accumulator=AccumulatorState(),
        )

    def _split_oversized_chunk(self, text: str, content_token_limit: int) -> list[str]:
        """Hard-split `text` into pieces of at most `content_token_limit` tokens.

        Tokens come from `self.tokenizer.tokenize`, and each piece is its
        tokens joined with single spaces.

        Raises:
            ValueError: if `content_token_limit` is not positive. Without
                this check the window would never advance and the split
                would loop forever.
        """
        if content_token_limit <= 0:
            raise ValueError(
                f"content_token_limit must be positive, got {content_token_limit}"
            )

        tokens = self.tokenizer.tokenize(text)
        # Consecutive, non-overlapping windows; the last one may be shorter.
        return [
            " ".join(tokens[start : start + content_token_limit])
            for start in range(0, len(tokens), content_token_limit)
        ]
|
||||
@@ -542,7 +542,6 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
**document.model_dump(),
|
||||
processed_sections=[
|
||||
Section(
|
||||
type=section.type,
|
||||
text=section.text if isinstance(section, TextSection) else "",
|
||||
link=section.link,
|
||||
image_file_id=(
|
||||
@@ -567,7 +566,6 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
if isinstance(section, ImageSection):
|
||||
# Default section with image path preserved - ensure text is always a string
|
||||
processed_section = Section(
|
||||
type=section.type,
|
||||
link=section.link,
|
||||
image_file_id=section.image_file_id,
|
||||
text="", # Initialize with empty string
|
||||
@@ -611,7 +609,6 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
# For TextSection, create a base Section with text and link
|
||||
elif isinstance(section, TextSection):
|
||||
processed_section = Section(
|
||||
type=section.type,
|
||||
text=section.text or "", # Ensure text is always a string, not None
|
||||
link=section.link,
|
||||
image_file_id=None,
|
||||
|
||||
@@ -96,9 +96,6 @@ from onyx.server.features.persona.api import admin_router as admin_persona_route
|
||||
from onyx.server.features.persona.api import agents_router
|
||||
from onyx.server.features.persona.api import basic_router as persona_router
|
||||
from onyx.server.features.projects.api import router as projects_router
|
||||
from onyx.server.features.proposal_review.api.api import (
|
||||
router as proposal_review_router,
|
||||
)
|
||||
from onyx.server.features.tool.api import admin_router as admin_tool_router
|
||||
from onyx.server.features.tool.api import router as tool_router
|
||||
from onyx.server.features.user_oauth_token.api import router as user_oauth_token_router
|
||||
@@ -472,7 +469,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, projects_router)
|
||||
include_router_with_global_prefix_prepended(application, public_build_router)
|
||||
include_router_with_global_prefix_prepended(application, build_router)
|
||||
include_router_with_global_prefix_prepended(application, proposal_review_router)
|
||||
include_router_with_global_prefix_prepended(application, document_set_router)
|
||||
include_router_with_global_prefix_prepended(application, hierarchy_router)
|
||||
include_router_with_global_prefix_prepended(application, search_settings_router)
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
"date-fns": "^4.1.0",
|
||||
"embla-carousel-react": "^8.6.0",
|
||||
"lucide-react": "^0.562.0",
|
||||
"next": "16.2.3",
|
||||
"next": "16.1.7",
|
||||
"next-themes": "^0.4.6",
|
||||
"radix-ui": "^1.4.3",
|
||||
"react": "19.2.3",
|
||||
@@ -961,9 +961,9 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@hono/node-server": {
|
||||
"version": "1.19.13",
|
||||
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.13.tgz",
|
||||
"integrity": "sha512-TsQLe4i2gvoTtrHje625ngThGBySOgSK3Xo2XRYOdqGN1teR8+I7vchQC46uLJi8OF62YTYA3AhSpumtkhsaKQ==",
|
||||
"version": "1.19.10",
|
||||
"resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.10.tgz",
|
||||
"integrity": "sha512-hZ7nOssGqRgyV3FVVQdfi+U4q02uB23bpnYpdvNXkYTRRyWx84b7yf1ans+dnJ/7h41sGL3CeQTfO+ZGxuO+Iw==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=18.14.1"
|
||||
@@ -1711,9 +1711,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/env": {
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.2.3.tgz",
|
||||
"integrity": "sha512-ZWXyj4uNu4GCWQw9cjRxWlbD+33mcDszIo9iQxFnBX3Wmgq9ulaSJcl6VhuWx5pCWqqD+9W6Wfz7N0lM5lYPMA==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.1.7.tgz",
|
||||
"integrity": "sha512-rJJbIdJB/RQr2F1nylZr/PJzamvNNhfr3brdKP6s/GW850jbtR70QlSfFselvIBbcPUOlQwBakexjFzqLzF6pg==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@next/eslint-plugin-next": {
|
||||
@@ -1727,9 +1727,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-darwin-arm64": {
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.2.3.tgz",
|
||||
"integrity": "sha512-u37KDKTKQ+OQLvY+z7SNXixwo4Q2/IAJFDzU1fYe66IbCE51aDSAzkNDkWmLN0yjTUh4BKBd+hb69jYn6qqqSg==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.1.7.tgz",
|
||||
"integrity": "sha512-b2wWIE8sABdyafc4IM8r5Y/dS6kD80JRtOGrUiKTsACFQfWWgUQ2NwoUX1yjFMXVsAwcQeNpnucF2ZrujsBBPg==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1743,9 +1743,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-darwin-x64": {
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.2.3.tgz",
|
||||
"integrity": "sha512-gHjL/qy6Q6CG3176FWbAKyKh9IfntKZTB3RY/YOJdDFpHGsUDXVH38U4mMNpHVGXmeYW4wj22dMp1lTfmu/bTQ==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.1.7.tgz",
|
||||
"integrity": "sha512-zcnVaaZulS1WL0Ss38R5Q6D2gz7MtBu8GZLPfK+73D/hp4GFMrC2sudLky1QibfV7h6RJBJs/gOFvYP0X7UVlQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1759,9 +1759,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-gnu": {
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.2.3.tgz",
|
||||
"integrity": "sha512-U6vtblPtU/P14Y/b/n9ZY0GOxbbIhTFuaFR7F4/uMBidCi2nSdaOFhA0Go81L61Zd6527+yvuX44T4ksnf8T+Q==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.1.7.tgz",
|
||||
"integrity": "sha512-2ant89Lux/Q3VyC8vNVg7uBaFVP9SwoK2jJOOR0L8TQnX8CAYnh4uctAScy2Hwj2dgjVHqHLORQZJ2wH6VxhSQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1775,9 +1775,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-musl": {
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.2.3.tgz",
|
||||
"integrity": "sha512-/YV0LgjHUmfhQpn9bVoGc4x4nan64pkhWR5wyEV8yCOfwwrH630KpvRg86olQHTwHIn1z59uh6JwKvHq1h4QEw==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.1.7.tgz",
|
||||
"integrity": "sha512-uufcze7LYv0FQg9GnNeZ3/whYfo+1Q3HnQpm16o6Uyi0OVzLlk2ZWoY7j07KADZFY8qwDbsmFnMQP3p3+Ftprw==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1791,9 +1791,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-gnu": {
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.2.3.tgz",
|
||||
"integrity": "sha512-/HiWEcp+WMZ7VajuiMEFGZ6cg0+aYZPqCJD3YJEfpVWQsKYSjXQG06vJP6F1rdA03COD9Fef4aODs3YxKx+RDQ==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.1.7.tgz",
|
||||
"integrity": "sha512-KWVf2gxYvHtvuT+c4MBOGxuse5TD7DsMFYSxVxRBnOzok/xryNeQSjXgxSv9QpIVlaGzEn/pIuI6Koosx8CGWA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1807,9 +1807,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-musl": {
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.2.3.tgz",
|
||||
"integrity": "sha512-Kt44hGJfZSefebhk/7nIdivoDr3Ugp5+oNz9VvF3GUtfxutucUIHfIO0ZYO8QlOPDQloUVQn4NVC/9JvHRk9hw==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.1.7.tgz",
|
||||
"integrity": "sha512-HguhaGwsGr1YAGs68uRKc4aGWxLET+NevJskOcCAwXbwj0fYX0RgZW2gsOCzr9S11CSQPIkxmoSbuVaBp4Z3dA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1823,9 +1823,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-arm64-msvc": {
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.2.3.tgz",
|
||||
"integrity": "sha512-O2NZ9ie3Tq6xj5Z5CSwBT3+aWAMW2PIZ4egUi9MaWLkwaehgtB7YZjPm+UpcNpKOme0IQuqDcor7BsW6QBiQBw==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.1.7.tgz",
|
||||
"integrity": "sha512-S0n3KrDJokKTeFyM/vGGGR8+pCmXYrjNTk2ZozOL1C/JFdfUIL9O1ATaJOl5r2POe56iRChbsszrjMAdWSv7kQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1839,9 +1839,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-x64-msvc": {
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.2.3.tgz",
|
||||
"integrity": "sha512-Ibm29/GgB/ab5n7XKqlStkm54qqZE8v2FnijUPBgrd67FWrac45o/RsNlaOWjme/B5UqeWt/8KM4aWBwA1D2Kw==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.1.7.tgz",
|
||||
"integrity": "sha512-mwgtg8CNZGYm06LeEd+bNnOUfwOyNem/rOiP14Lsz+AnUY92Zq/LXwtebtUiaeVkhbroRCQ0c8GlR4UT1U+0yg==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -7427,9 +7427,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/hono": {
|
||||
"version": "4.12.12",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.12.tgz",
|
||||
"integrity": "sha512-p1JfQMKaceuCbpJKAPKVqyqviZdS0eUxH9v82oWo1kb9xjQ5wA6iP3FNVAPDFlz5/p7d45lO+BpSk1tuSZMF4Q==",
|
||||
"version": "4.12.7",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.7.tgz",
|
||||
"integrity": "sha512-jq9l1DM0zVIvsm3lv9Nw9nlJnMNPOcAtsbsgiUhWcFzPE99Gvo6yRTlszSLLYacMeQ6quHD6hMfId8crVHvexw==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=16.9.0"
|
||||
@@ -8637,9 +8637,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/lodash": {
|
||||
"version": "4.18.1",
|
||||
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.18.1.tgz",
|
||||
"integrity": "sha512-dMInicTPVE8d1e5otfwmmjlxkZoUpiVLwyeTdUsi/Caj/gfzzblBcCE5sRHV/AsjuCmxWrte2TNGSYuCeCq+0Q==",
|
||||
"version": "4.17.23",
|
||||
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.23.tgz",
|
||||
"integrity": "sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/lodash.merge": {
|
||||
@@ -8978,12 +8978,12 @@
|
||||
}
|
||||
},
|
||||
"node_modules/next": {
|
||||
"version": "16.2.3",
|
||||
"resolved": "https://registry.npmjs.org/next/-/next-16.2.3.tgz",
|
||||
"integrity": "sha512-9V3zV4oZFza3PVev5/poB9g0dEafVcgNyQ8eTRop8GvxZjV2G15FC5ARuG1eFD42QgeYkzJBJzHghNP8Ad9xtA==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/next/-/next-16.1.7.tgz",
|
||||
"integrity": "sha512-WM0L7WrSvKwoLegLYr6V+mz+RIofqQgVAfHhMp9a88ms0cFX8iX9ew+snpWlSBwpkURJOUdvCEt3uLl3NNzvWg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@next/env": "16.2.3",
|
||||
"@next/env": "16.1.7",
|
||||
"@swc/helpers": "0.5.15",
|
||||
"baseline-browser-mapping": "^2.9.19",
|
||||
"caniuse-lite": "^1.0.30001579",
|
||||
@@ -8997,15 +8997,15 @@
|
||||
"node": ">=20.9.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@next/swc-darwin-arm64": "16.2.3",
|
||||
"@next/swc-darwin-x64": "16.2.3",
|
||||
"@next/swc-linux-arm64-gnu": "16.2.3",
|
||||
"@next/swc-linux-arm64-musl": "16.2.3",
|
||||
"@next/swc-linux-x64-gnu": "16.2.3",
|
||||
"@next/swc-linux-x64-musl": "16.2.3",
|
||||
"@next/swc-win32-arm64-msvc": "16.2.3",
|
||||
"@next/swc-win32-x64-msvc": "16.2.3",
|
||||
"sharp": "^0.34.5"
|
||||
"@next/swc-darwin-arm64": "16.1.7",
|
||||
"@next/swc-darwin-x64": "16.1.7",
|
||||
"@next/swc-linux-arm64-gnu": "16.1.7",
|
||||
"@next/swc-linux-arm64-musl": "16.1.7",
|
||||
"@next/swc-linux-x64-gnu": "16.1.7",
|
||||
"@next/swc-linux-x64-musl": "16.1.7",
|
||||
"@next/swc-win32-arm64-msvc": "16.1.7",
|
||||
"@next/swc-win32-x64-msvc": "16.1.7",
|
||||
"sharp": "^0.34.4"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@opentelemetry/api": "^1.1.0",
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
"date-fns": "^4.1.0",
|
||||
"embla-carousel-react": "^8.6.0",
|
||||
"lucide-react": "^0.562.0",
|
||||
"next": "16.2.3",
|
||||
"next": "16.1.7",
|
||||
"next-themes": "^0.4.6",
|
||||
"radix-ui": "^1.4.3",
|
||||
"react": "19.2.3",
|
||||
|
||||
@@ -63,7 +63,6 @@ class DocumentSetCreationRequest(BaseModel):
|
||||
|
||||
class DocumentSetUpdateRequest(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
cc_pair_ids: list[int]
|
||||
is_public: bool
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
"""Main router for Proposal Review.
|
||||
|
||||
Mounts all sub-routers under /proposal-review prefix.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.server.features.proposal_review.configs import ENABLE_PROPOSAL_REVIEW
|
||||
|
||||
# Parent router: every route mounted below inherits the /proposal-review
# prefix and the BASIC_ACCESS permission dependency.
router = APIRouter(
    prefix="/proposal-review",
    dependencies=[Depends(require_permission(Permission.BASIC_ACCESS))],
)

if ENABLE_PROPOSAL_REVIEW:
    # Sub-router modules are imported only when the feature is enabled.
    from onyx.server.features.proposal_review.api.config_api import (
        router as config_router,
    )
    from onyx.server.features.proposal_review.api.decisions_api import (
        router as decisions_router,
    )
    from onyx.server.features.proposal_review.api.proposals_api import (
        router as proposals_router,
    )
    from onyx.server.features.proposal_review.api.review_api import (
        router as review_router,
    )
    from onyx.server.features.proposal_review.api.rulesets_api import (
        router as rulesets_router,
    )

    # Mount order matches the original explicit include_router calls.
    for _sub_router in (
        rulesets_router,
        proposals_router,
        review_router,
        decisions_router,
        config_router,
    ):
        router.include_router(_sub_router, tags=["proposal-review"])
|
||||
@@ -1,89 +0,0 @@
|
||||
"""API endpoints for tenant configuration."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.connector import fetch_connectors
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.server.features.proposal_review.api.models import ConfigResponse
|
||||
from onyx.server.features.proposal_review.api.models import ConfigUpdate
|
||||
from onyx.server.features.proposal_review.api.models import JiraConnectorInfo
|
||||
from onyx.server.features.proposal_review.db import config as config_db
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/config")
def get_config(
    _user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
    db_session: Session = Depends(get_session),
) -> ConfigResponse:
    """Get the tenant's proposal review configuration.

    When the tenant has no configuration yet, one is created through
    `upsert_config` (called with only the tenant id and session) and
    committed, so the endpoint returns a config rather than a 404.
    """
    tenant_id = get_current_tenant_id()

    existing = config_db.get_config(tenant_id, db_session)
    if existing:
        return ConfigResponse.from_model(existing)

    # First access for this tenant: persist a config row and return it.
    created = config_db.upsert_config(tenant_id, db_session)
    db_session.commit()
    return ConfigResponse.from_model(created)
|
||||
|
||||
|
||||
@router.put("/config")
def update_config(
    request: ConfigUpdate,
    _user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS)),
    db_session: Session = Depends(get_session),
) -> ConfigResponse:
    """Update the tenant's proposal review configuration.

    All fields of the request are forwarded to `upsert_config`; the change
    is committed before the response is built.
    """
    saved = config_db.upsert_config(
        tenant_id=get_current_tenant_id(),
        jira_connector_id=request.jira_connector_id,
        jira_project_key=request.jira_project_key,
        field_mapping=request.field_mapping,
        jira_writeback=request.jira_writeback,
        review_model=request.review_model,
        import_model=request.import_model,
        db_session=db_session,
    )
    db_session.commit()

    response = ConfigResponse.from_model(saved)
    return response
|
||||
|
||||
|
||||
@router.get("/jira-connectors")
def list_jira_connectors(
    _user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
    db_session: Session = Depends(get_session),
) -> list[JiraConnectorInfo]:
    """List all Jira connectors available to this tenant.

    The project key and project URL are read from each connector's
    `connector_specific_config` ("project_key" and "jira_base_url") and
    default to an empty string when absent.
    """

    def _to_info(connector) -> JiraConnectorInfo:  # type: ignore[no-untyped-def]
        # A missing config is treated as an empty mapping.
        specific = connector.connector_specific_config or {}
        return JiraConnectorInfo(
            id=connector.id,
            name=connector.name,
            project_key=specific.get("project_key", ""),
            project_url=specific.get("jira_base_url", ""),
        )

    jira_connectors = fetch_connectors(db_session, sources=[DocumentSource.JIRA])
    return [_to_info(connector) for connector in jira_connectors]
|
||||
|
||||
|
||||
@router.get("/jira-connectors/{connector_id}/metadata-keys")
def get_connector_metadata_keys(
    connector_id: int,
    _user: User = Depends(require_permission(Permission.BASIC_ACCESS)),  # noqa: ARG001
    db_session: Session = Depends(get_session),
) -> list[str]:
    """Return the distinct doc_metadata keys across all documents for a connector."""
    # The lookup is delegated entirely to the config DB layer.
    metadata_keys = config_db.get_connector_metadata_keys(connector_id, db_session)
    return metadata_keys
|
||||
@@ -1,147 +0,0 @@
|
||||
"""API endpoints for per-finding decisions, proposal decisions, and Jira sync."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.features.proposal_review.api.models import FindingDecisionCreate
|
||||
from onyx.server.features.proposal_review.api.models import FindingResponse
|
||||
from onyx.server.features.proposal_review.api.models import JiraSyncResponse
|
||||
from onyx.server.features.proposal_review.api.models import ProposalDecisionCreate
|
||||
from onyx.server.features.proposal_review.api.models import ProposalDecisionResponse
|
||||
from onyx.server.features.proposal_review.db import decisions as decisions_db
|
||||
from onyx.server.features.proposal_review.db import findings as findings_db
|
||||
from onyx.server.features.proposal_review.db import proposals as proposals_db
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
    "/findings/{finding_id}/decision",
)
def record_finding_decision(
    finding_id: UUID,
    request: FindingDecisionCreate,
    user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
    db_session: Session = Depends(get_session),
) -> FindingResponse:
    """Record or update a decision on a finding (upsert).

    Responds with NOT_FOUND ("Finding not found") both when the finding does
    not exist and when its proposal cannot be loaded for the current tenant.
    """
    tenant_id = get_current_tenant_id()

    existing = findings_db.get_finding(finding_id, db_session)
    # The proposal is looked up with the tenant id so that findings of other
    # tenants are reported as missing. It is skipped when there is no finding.
    owning_proposal = (
        proposals_db.get_proposal(existing.proposal_id, tenant_id, db_session)
        if existing
        else None
    )
    if not existing or not owning_proposal:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Finding not found")

    updated = decisions_db.upsert_finding_decision(
        finding_id=finding_id,
        officer_id=user.id,
        action=request.action,
        notes=request.notes,
        db_session=db_session,
    )
    db_session.commit()

    return FindingResponse.from_model(updated)
|
||||
|
||||
|
||||
@router.post(
    "/proposals/{proposal_id}/decision",
    status_code=201,
)
def record_proposal_decision(
    proposal_id: UUID,
    request: ProposalDecisionCreate,
    user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
    db_session: Session = Depends(get_session),
) -> ProposalDecisionResponse:
    """Record a final decision on a proposal.

    The proposal is resolved for the current tenant first (NOT_FOUND when
    missing). The decision value is then validated (INVALID_INPUT unless it
    is APPROVED, CHANGES_REQUESTED or REJECTED). Finally the decision is
    stored and committed.
    """
    tenant_id = get_current_tenant_id()

    if not proposals_db.get_proposal(proposal_id, tenant_id, db_session):
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")

    # Only these three outcomes are accepted.
    allowed_decisions = {"APPROVED", "CHANGES_REQUESTED", "REJECTED"}
    if request.decision not in allowed_decisions:
        raise OnyxError(
            OnyxErrorCode.INVALID_INPUT,
            "decision must be APPROVED, CHANGES_REQUESTED, or REJECTED",
        )

    decided = decisions_db.update_proposal_decision(
        proposal_id=proposal_id,
        tenant_id=tenant_id,
        officer_id=user.id,
        decision=request.decision,
        notes=request.notes,
        db_session=db_session,
    )
    db_session.commit()

    return ProposalDecisionResponse.from_proposal(decided)
|
||||
|
||||
|
||||
@router.post(
    "/proposals/{proposal_id}/sync-jira",
)
def sync_jira(
    proposal_id: UUID,
    user: User = Depends(require_permission(Permission.BASIC_ACCESS)),  # noqa: ARG001
    db_session: Session = Depends(get_session),
) -> JiraSyncResponse:
    """Queue a background job that pushes the proposal's decision to Jira.

    The dispatched Celery task performs 3 Jira API operations:
    1. Update custom fields (decision, completion %)
    2. Transition the issue to the appropriate column
    3. Post a structured review summary comment

    Returns early with a success response when the proposal is already
    flagged as synced, without dispatching anything.

    Raises:
        OnyxError: ``NOT_FOUND`` when the proposal lookup returns nothing for
            this tenant; ``INVALID_INPUT`` when no decision has been recorded
            yet (``decision_at`` is unset).
    """
    tenant_id = get_current_tenant_id()
    target = proposals_db.get_proposal(proposal_id, tenant_id, db_session)

    if not target:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")
    if not target.decision_at:
        raise OnyxError(
            OnyxErrorCode.INVALID_INPUT,
            "No decision to sync -- record a proposal decision first",
        )
    if target.jira_synced:
        return JiraSyncResponse(success=True, message="Decision already synced to Jira")

    # Dispatch Celery task via the client app (has Redis broker configured).
    # Imported here, after the guards, exactly as the endpoint always did.
    from onyx.background.celery.versioned_apps.client import app as celery_app

    task_args = [str(proposal_id), tenant_id]
    celery_app.send_task("sync_decision_to_jira", args=task_args, expires=300)

    db_session.commit()
    return JiraSyncResponse(success=True, message="Jira sync task dispatched")
|
||||
@@ -1,483 +0,0 @@
|
||||
"""Pydantic request/response models for Proposal Review."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewConfig
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewDocument
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewFinding
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewImportJob
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewProposal
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewRule
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewRuleset
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewRun
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Ruleset Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class RulesetCreate(BaseModel):
    """Request body for creating a ruleset."""

    # Only `name` is required; the other fields fall back to their defaults.
    name: str
    description: str | None = None
    is_default: bool = False
|
||||
|
||||
|
||||
class RulesetUpdate(BaseModel):
    """Request body for updating a ruleset; every field is optional."""

    # All fields default to None.
    # NOTE(review): presumably None means "leave unchanged" — confirm against
    # the handler that consumes this model.
    name: str | None = None
    description: str | None = None
    is_default: bool | None = None
    is_active: bool | None = None
|
||||
|
||||
|
||||
class RulesetResponse(BaseModel):
    """API representation of a ruleset, optionally with its rules inlined."""

    id: UUID
    tenant_id: str
    name: str
    description: str | None
    is_default: bool
    is_active: bool
    created_by: UUID | None
    created_at: datetime
    updated_at: datetime
    rules: list["RuleResponse"] = []

    @classmethod
    def from_model(
        cls,
        ruleset: ProposalReviewRuleset,
        include_rules: bool = True,
    ) -> "RulesetResponse":
        """Build a response from a ruleset ORM row.

        When ``include_rules`` is false, ``ruleset.rules`` is not touched and
        the response carries an empty ``rules`` list.
        """
        copied_attrs = (
            "id",
            "tenant_id",
            "name",
            "description",
            "is_default",
            "is_active",
            "created_by",
            "created_at",
            "updated_at",
        )
        payload: dict[str, Any] = {
            attr: getattr(ruleset, attr) for attr in copied_attrs
        }
        if include_rules:
            payload["rules"] = [RuleResponse.from_model(r) for r in ruleset.rules]
        else:
            payload["rules"] = []
        return cls(**payload)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Rule Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class RuleCreate(BaseModel):
    """Request body for creating a rule.

    `name`, `rule_type` and `prompt_template` are required; the enumerated
    fields are restricted to the literal values listed on each annotation.
    """

    name: str
    description: str | None = None
    category: str | None = None
    # Required; must be one of the four listed rule kinds.
    rule_type: Literal[
        "DOCUMENT_CHECK", "METADATA_CHECK", "CROSS_REFERENCE", "CUSTOM_NL"
    ]
    rule_intent: Literal["CHECK", "HIGHLIGHT"] = "CHECK"
    prompt_template: str
    # Defaults to "MANUAL" when the caller does not say otherwise.
    source: Literal["IMPORTED", "MANUAL"] = "MANUAL"
    authority: Literal["OVERRIDE", "RETURN"] | None = None
    is_hard_stop: bool = False
    priority: int = 0
|
||||
|
||||
|
||||
class RuleUpdate(BaseModel):
    """Request body for updating a rule; every field is optional."""

    # All fields default to None.
    # NOTE(review): presumably None means "leave unchanged" — confirm against
    # the handler that consumes this model.
    name: str | None = None
    description: str | None = None
    category: str | None = None
    # NOTE(review): rule_type, rule_intent and authority are plain strings
    # here, whereas RuleCreate restricts them with Literal[...]. If the update
    # path does not validate them elsewhere, arbitrary values are accepted —
    # confirm whether that is intended.
    rule_type: str | None = None
    rule_intent: str | None = None
    prompt_template: str | None = None
    authority: str | None = None
    is_hard_stop: bool | None = None
    priority: int | None = None
    is_active: bool | None = None
    refinement_needed: bool | None = None
    refinement_question: str | None = None
|
||||
|
||||
|
||||
class RuleRefinementRequest(BaseModel):
    """Request body carrying a single free-text answer for rule refinement."""

    # NOTE(review): presumably the reply to a rule's `refinement_question` —
    # confirm against the endpoint that accepts this model.
    answer: str
|
||||
|
||||
|
||||
class RuleResponse(BaseModel):
    """API representation of a single review rule."""

    id: UUID
    ruleset_id: UUID
    name: str
    description: str | None
    category: str | None
    rule_type: str
    rule_intent: str
    prompt_template: str
    source: str
    authority: str | None
    is_hard_stop: bool
    priority: int
    is_active: bool
    refinement_needed: bool
    refinement_question: str | None
    created_at: datetime
    updated_at: datetime

    @classmethod
    def from_model(cls, rule: ProposalReviewRule) -> "RuleResponse":
        """Build a response by copying each same-named attribute off the ORM row."""
        copied_attrs = (
            "id",
            "ruleset_id",
            "name",
            "description",
            "category",
            "rule_type",
            "rule_intent",
            "prompt_template",
            "source",
            "authority",
            "is_hard_stop",
            "priority",
            "is_active",
            "refinement_needed",
            "refinement_question",
            "created_at",
            "updated_at",
        )
        return cls(**{attr: getattr(rule, attr) for attr in copied_attrs})
|
||||
|
||||
|
||||
class BulkRuleUpdateRequest(BaseModel):
    """Batch activate/deactivate/delete rules."""

    # The single operation to apply; restricted to these three values.
    action: Literal["activate", "deactivate", "delete"]
    # Ids of the rules the action targets.
    rule_ids: list[UUID]
|
||||
|
||||
|
||||
class BulkRuleUpdateResponse(BaseModel):
    """Response for a bulk rule update, reporting a single count."""

    # NOTE(review): presumably the number of rules the bulk action affected —
    # confirm against the handler.
    updated_count: int
|
||||
|
||||
|
||||
class RuleTestResponse(BaseModel):
    """Response describing the outcome of testing a rule."""

    # Note: typed as a string here, not UUID like the other rule models.
    rule_id: str
    success: bool
    # Both default to None; which one is populated is decided by the caller.
    error: str | None = None
    result: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Proposal Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ProposalResponse(BaseModel):
    """Proposal response including inline decision fields."""

    id: UUID
    document_id: str
    tenant_id: str
    status: str
    # Inline decision fields
    decision_notes: str | None = None
    decision_officer_id: UUID | None = None
    decision_at: datetime | None = None
    jira_synced: bool = False
    jira_synced_at: datetime | None = None

    created_at: datetime
    updated_at: datetime
    # Resolved metadata from Document table via field_mapping
    metadata: dict[str, Any] = {}

    @classmethod
    def from_model(
        cls,
        proposal: ProposalReviewProposal,
        metadata: dict[str, Any] | None = None,
    ) -> "ProposalResponse":
        """Build a response from a proposal ORM row plus resolved metadata.

        A missing or empty ``metadata`` argument becomes an empty dict.
        """
        copied_attrs = (
            "id",
            "document_id",
            "tenant_id",
            "status",
            "decision_notes",
            "decision_officer_id",
            "decision_at",
            "jira_synced",
            "jira_synced_at",
            "created_at",
            "updated_at",
        )
        payload: dict[str, Any] = {
            attr: getattr(proposal, attr) for attr in copied_attrs
        }
        payload["metadata"] = metadata if metadata else {}
        return cls(**payload)
|
||||
|
||||
|
||||
class ProposalListResponse(BaseModel):
    """One page of proposals plus the overall match count."""

    proposals: list[ProposalResponse]
    # Supplied separately by the caller; not derived from len(proposals).
    total_count: int
    config_missing: bool = False  # True when no config exists
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Review Run Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ReviewRunTriggerRequest(BaseModel):
    """Request body for triggering a review run."""

    # Id of the ruleset to use for the run (required).
    ruleset_id: UUID
|
||||
|
||||
|
||||
class ReviewRunResponse(BaseModel):
    """API representation of a review run and its progress counters."""

    id: UUID
    proposal_id: UUID
    ruleset_id: UUID
    triggered_by: UUID
    status: str
    total_rules: int
    completed_rules: int
    failed_rules: int
    started_at: datetime | None
    completed_at: datetime | None
    created_at: datetime

    @classmethod
    def from_model(cls, run: ProposalReviewRun) -> "ReviewRunResponse":
        """Build a response by copying each same-named attribute off the ORM row."""
        copied_attrs = (
            "id",
            "proposal_id",
            "ruleset_id",
            "triggered_by",
            "status",
            "total_rules",
            "completed_rules",
            "failed_rules",
            "started_at",
            "completed_at",
            "created_at",
        )
        return cls(**{attr: getattr(run, attr) for attr in copied_attrs})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Finding Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class FindingResponse(BaseModel):
    """API representation of a finding, with rule info and decision inlined."""

    id: UUID
    proposal_id: UUID
    rule_id: UUID
    review_run_id: UUID
    verdict: str
    confidence: str | None
    evidence: str | None
    explanation: str | None
    suggested_action: str | None
    llm_model: str | None
    llm_tokens_used: int | None
    created_at: datetime
    # Nested rule info for display
    rule_name: str | None = None
    rule_category: str | None = None
    rule_is_hard_stop: bool | None = None
    # Inline decision fields
    decision_action: str | None = None
    decision_notes: str | None = None
    decided_at: datetime | None = None

    @classmethod
    def from_model(cls, finding: ProposalReviewFinding) -> "FindingResponse":
        """Build a response from a finding ORM row.

        The three ``rule_*`` display fields are taken from ``finding.rule``
        when it is present and are left as ``None`` otherwise.
        """
        linked_rule = finding.rule
        if linked_rule is None:
            rule_details: dict[str, Any] = {
                "rule_name": None,
                "rule_category": None,
                "rule_is_hard_stop": None,
            }
        else:
            rule_details = {
                "rule_name": linked_rule.name,
                "rule_category": linked_rule.category,
                "rule_is_hard_stop": linked_rule.is_hard_stop,
            }

        finding_attrs = (
            "id",
            "proposal_id",
            "rule_id",
            "review_run_id",
            "verdict",
            "confidence",
            "evidence",
            "explanation",
            "suggested_action",
            "llm_model",
            "llm_tokens_used",
            "created_at",
        )
        decision_attrs = ("decision_action", "decision_notes", "decided_at")

        core = {attr: getattr(finding, attr) for attr in finding_attrs}
        decision = {attr: getattr(finding, attr) for attr in decision_attrs}
        return cls(**core, **rule_details, **decision)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Decision Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class FindingDecisionCreate(BaseModel):
    """Request body for recording a decision on a single finding."""

    # Required; restricted to these four values.
    action: Literal["VERIFIED", "ISSUE", "NOT_APPLICABLE", "OVERRIDDEN"]
    notes: str | None = None
|
||||
|
||||
|
||||
class ProposalDecisionCreate(BaseModel):
    """Request body for recording a proposal-level decision."""

    # Required; restricted to these three values.
    decision: Literal["APPROVED", "CHANGES_REQUESTED", "REJECTED"]
    notes: str | None = None
|
||||
|
||||
|
||||
class ProposalDecisionResponse(BaseModel):
    """Response after recording a proposal-level decision."""

    proposal_id: UUID
    status: str
    decision_notes: str | None
    jira_synced: bool
    decision_at: datetime | None

    @classmethod
    def from_proposal(
        cls, proposal: ProposalReviewProposal
    ) -> "ProposalDecisionResponse":
        """Build the response from a proposal row.

        The row's ``id`` is exposed as ``proposal_id``; the remaining fields
        are copied under their own names.
        """
        same_named = ("status", "decision_notes", "jira_synced", "decision_at")
        return cls(
            proposal_id=proposal.id,
            **{attr: getattr(proposal, attr) for attr in same_named},
        )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Config Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ConfigUpdate(BaseModel):
    """Request body for updating the proposal-review config; all fields optional."""

    # All fields default to None.
    # NOTE(review): presumably None means "leave unchanged" — confirm against
    # the handler that consumes this model.
    jira_connector_id: int | None = None
    jira_project_key: str | None = None
    field_mapping: list[str] | None = None  # List of visible metadata keys
    jira_writeback: dict[str, Any] | None = None
    # LLM configuration
    review_model: str | None = None  # model name for rule evaluation
    import_model: str | None = None  # model name for checklist import
|
||||
|
||||
|
||||
class ConfigResponse(BaseModel):
    """API representation of the stored proposal-review configuration."""

    id: UUID
    tenant_id: str
    jira_connector_id: int | None
    jira_project_key: str | None
    field_mapping: list[str] | None
    jira_writeback: dict[str, Any] | None
    review_model: str | None
    import_model: str | None
    created_at: datetime
    updated_at: datetime

    @classmethod
    def from_model(cls, config: ProposalReviewConfig) -> "ConfigResponse":
        """Build a response by copying each same-named attribute off the ORM row."""
        copied_attrs = (
            "id",
            "tenant_id",
            "jira_connector_id",
            "jira_project_key",
            "field_mapping",
            "jira_writeback",
            "review_model",
            "import_model",
            "created_at",
            "updated_at",
        )
        return cls(**{attr: getattr(config, attr) for attr in copied_attrs})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Import Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ImportResponse(BaseModel):
    """Response for an import: a rule count and the rules themselves."""

    # Supplied separately by the caller; not derived from len(rules).
    rules_created: int
    rules: list[RuleResponse]
|
||||
|
||||
|
||||
class ImportJobResponse(BaseModel):
    """API representation of an import job's status and outcome."""

    id: UUID
    status: str
    source_filename: str
    rules_created: int
    error_message: str | None
    created_at: datetime
    completed_at: datetime | None

    @classmethod
    def from_model(cls, job: ProposalReviewImportJob) -> "ImportJobResponse":
        """Build a response by copying each same-named attribute off the ORM row."""
        copied_attrs = (
            "id",
            "status",
            "source_filename",
            "rules_created",
            "error_message",
            "created_at",
            "completed_at",
        )
        return cls(**{attr: getattr(job, attr) for attr in copied_attrs})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Document Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ProposalDocumentResponse(BaseModel):
    """API representation of a document attached to a proposal."""

    id: UUID
    proposal_id: UUID
    file_name: str
    file_type: str | None
    document_role: str
    uploaded_by: UUID | None
    extracted_text: str | None = None
    created_at: datetime

    @classmethod
    def from_model(cls, doc: ProposalReviewDocument) -> "ProposalDocumentResponse":
        """Build a response from a proposal-document ORM row.

        ``extracted_text`` is read defensively and falls back to ``None`` when
        the row has no such attribute.
        """
        leading_attrs = (
            "id",
            "proposal_id",
            "file_name",
            "file_type",
            "document_role",
            "uploaded_by",
        )
        payload: dict[str, Any] = {
            attr: getattr(doc, attr) for attr in leading_attrs
        }
        payload["extracted_text"] = getattr(doc, "extracted_text", None)
        payload["created_at"] = doc.created_at
        return cls(**payload)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Jira Sync Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class JiraSyncResponse(BaseModel):
    """Response for a Jira sync request: a success flag and a message."""

    success: bool
    # Human-readable description of the outcome.
    message: str
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Jira Connector Discovery Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class JiraConnectorInfo(BaseModel):
    """Summary of a Jira connector: id, name, project key and project URL."""

    # Integer id, unlike the UUID ids used by the proposal-review models.
    id: int
    name: str
    project_key: str
    project_url: str
|
||||
@@ -1,367 +0,0 @@
|
||||
"""API endpoints for proposals and proposal documents."""
|
||||
|
||||
import io
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Form
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.server.features.proposal_review.api.models import ProposalDocumentResponse
|
||||
from onyx.server.features.proposal_review.api.models import ProposalListResponse
|
||||
from onyx.server.features.proposal_review.api.models import ProposalResponse
|
||||
from onyx.server.features.proposal_review.configs import (
|
||||
DOCUMENT_UPLOAD_MAX_FILE_SIZE_BYTES,
|
||||
)
|
||||
from onyx.server.features.proposal_review.db import config as config_db
|
||||
from onyx.server.features.proposal_review.db import proposals as proposals_db
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewDocument
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewProposal
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
# Module-level logger used by the endpoints in this file.
logger = setup_logger()

# Router that the endpoint decorators in this file register against.
router = APIRouter()
|
||||
|
||||
|
||||
def _resolve_document_metadata(
    document: Document,
    visible_fields: list[str] | None,
) -> dict[str, Any]:
    """Resolve metadata from a Document's tags, filtered to visible fields.

    Jira custom fields are stored as Tag rows (tag_key / tag_value)
    linked to the document via document__tag. visible_fields selects
    which tag keys to include. If None/empty, returns all tags.

    Args:
        document: Object exposing ``tags`` (each with ``tag_key``,
            ``tag_value`` and ``is_list``), ``semantic_id`` and ``link``.
        visible_fields: Tag keys to keep, or None/empty for no filtering.

    Returns:
        A dict of tag values plus the derived ``jira_key``, ``title`` and
        ``link`` entries. When filtering, ``jira_key``, ``title``, ``link``
        and ``status`` are always present (possibly ``None``).
    """
    # Build metadata from the document's tags
    raw_metadata: dict[str, Any] = {}
    for tag in document.tags:
        key = tag.tag_key
        value = tag.tag_value
        if not tag.is_list:
            raw_metadata[key] = value
            continue

        # Tags with is_list=True can have multiple values for the same key
        existing = raw_metadata.get(key)
        if isinstance(existing, list):
            existing.append(value)
        elif key in raw_metadata:
            # A scalar was stored under this key by an earlier non-list tag;
            # promote it to a list instead of calling .append() on a string.
            raw_metadata[key] = [existing, value]
        else:
            raw_metadata[key] = [value]

    # Extract jira_key from tags and clean title from semantic_id.
    # Jira semantic_id is "KEY-123: Summary Text" — split to isolate each.
    jira_key = raw_metadata.get("key", "")
    title = document.semantic_id or ""
    if ": " in title:
        title = title.split(": ", 1)[1]

    raw_metadata["jira_key"] = jira_key
    raw_metadata["title"] = title
    raw_metadata["link"] = document.link

    if not visible_fields:
        return raw_metadata

    # Filter to only the selected fields, plus always include core fields
    # that the frontend needs for navigation, display, and filtering.
    resolved: dict[str, Any] = {
        core_key: raw_metadata.get(core_key)
        for core_key in ("jira_key", "title", "link", "status")
    }
    for key in visible_fields:
        if key in raw_metadata:
            resolved[key] = raw_metadata[key]

    return resolved
|
||||
|
||||
|
||||
@router.get("/proposals")
def list_proposals(
    status: str | None = None,
    limit: int = 100,
    offset: int = 0,
    user: User = Depends(require_permission(Permission.BASIC_ACCESS)),  # noqa: ARG001
    db_session: Session = Depends(get_session),
) -> ProposalListResponse:
    """List proposals.

    This queries the Document table filtered by the configured Jira connector
    (or, when none is configured, by Jira-source connectors), LEFT JOINs
    proposal_review_proposal for review state, and resolves metadata via the
    config's field_mapping.

    Documents without a proposal record are matched by the PENDING status
    filter. NOTE: this endpoint is not read-only — for every document on the
    returned page that has no proposal record yet, one is lazily created via
    ``get_or_create_proposal`` and committed, so the response always carries a
    stable proposal id.

    Args:
        status: Optional proposal status to filter on.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip.

    Returns:
        The requested page plus the total number of distinct matching
        documents; ``config_missing=True`` with an empty page when no config
        exists for the tenant.
    """
    tenant_id = get_current_tenant_id()

    # Get config for field mapping and Jira project filtering
    config = config_db.get_config(tenant_id, db_session)

    # When no config exists, return an empty list with a hint for the frontend.
    # The frontend can show "Configure a Jira connector in Settings to see proposals."
    if config is None:
        return ProposalListResponse(
            proposals=[],
            total_count=0,
            config_missing=True,
        )

    visible_fields = config.field_mapping

    # Query documents from the configured Jira connector only,
    # LEFT JOIN proposal state for review tracking.
    # NOTE: Tenant isolation is handled at the schema level (schema-per-tenant).
    # The DB session is already scoped to the current tenant's schema, so
    # cross-tenant data leakage is prevented by the connection itself.
    query = (
        db_session.query(Document, ProposalReviewProposal)
        .outerjoin(
            ProposalReviewProposal,
            Document.id == ProposalReviewProposal.document_id,
        )
        .options(selectinload(Document.tags))
        # Both filtering modes below go through the connector-credential
        # pair rows, so the join is shared.
        .join(
            DocumentByConnectorCredentialPair,
            Document.id == DocumentByConnectorCredentialPair.id,
        )
    )

    # `config` is known to be non-None here (early return above).
    if config.jira_connector_id:
        # Filter to only documents from the configured Jira connector
        query = query.filter(
            DocumentByConnectorCredentialPair.connector_id == config.jira_connector_id,
        )
    else:
        # No connector configured — filter to Jira source connectors only
        # to avoid showing Slack/GitHub/etc documents
        query = query.join(
            Connector,
            DocumentByConnectorCredentialPair.connector_id == Connector.id,
        ).filter(
            Connector.source == DocumentSource.JIRA,
        )

    # Exclude attachment documents — they are children of issue documents
    # and have "/attachments/" in their document ID.
    query = query.filter(~Document.id.contains("/attachments/"))

    # If status filter is specified, only show documents with matching proposal status.
    # PENDING is special: documents without a proposal record are implicitly pending.
    if status:
        if status == "PENDING":
            query = query.filter(
                or_(
                    ProposalReviewProposal.status == status,
                    ProposalReviewProposal.id.is_(None),
                ),
            )
        else:
            query = query.filter(ProposalReviewProposal.status == status)

    # Count before adding DISTINCT ON — count(distinct(...)) handles
    # deduplication on its own and conflicts with DISTINCT ON.
    total_count = (
        query.with_entities(func.count(func.distinct(Document.id))).scalar() or 0
    )

    # Deduplicate rows that can arise from multiple connector-credential pairs.
    # Applied after counting to avoid the DISTINCT ON + aggregate conflict.
    # ORDER BY Document.id is required for DISTINCT ON to be deterministic.
    query = query.distinct(Document.id).order_by(Document.id)
    results = query.offset(offset).limit(limit).all()

    proposals: list[ProposalResponse] = []
    created_any = False
    for document, proposal in results:
        if proposal is None:
            # Lazily create the proposal record so the frontend gets a
            # stable UUID it can use for navigation and subsequent API calls.
            proposal = proposals_db.get_or_create_proposal(
                document_id=document.id,
                tenant_id=tenant_id,
                db_session=db_session,
            )
            created_any = True
        metadata = _resolve_document_metadata(document, visible_fields)
        proposals.append(ProposalResponse.from_model(proposal, metadata=metadata))

    # Only commit when the loop actually created rows.
    if created_any:
        db_session.commit()

    return ProposalListResponse(proposals=proposals, total_count=total_count)
|
||||
|
||||
|
||||
@router.get("/proposals/{proposal_id}")
def get_proposal(
    proposal_id: UUID,
    user: User = Depends(require_permission(Permission.BASIC_ACCESS)),  # noqa: ARG001
    db_session: Session = Depends(get_session),
) -> ProposalResponse:
    """Return one proposal together with metadata from its linked Document.

    The metadata is empty when the linked Document row cannot be found; no
    field filtering is applied when the tenant has no config.

    Raises:
        OnyxError: ``NOT_FOUND`` when the proposal lookup returns nothing for
            this tenant.
    """
    tenant_id = get_current_tenant_id()
    proposal = proposals_db.get_proposal(proposal_id, tenant_id, db_session)
    if not proposal:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")

    # Fetch the backing Document (tags eagerly loaded) for metadata.
    document_query = db_session.query(Document).options(
        selectinload(Document.tags)
    )
    document = document_query.filter(
        Document.id == proposal.document_id
    ).one_or_none()

    config = config_db.get_config(tenant_id, db_session)
    visible_fields = None
    if config:
        visible_fields = config.field_mapping

    if not document:
        return ProposalResponse.from_model(proposal, metadata={})

    resolved = _resolve_document_metadata(document, visible_fields)
    return ProposalResponse.from_model(proposal, metadata=resolved)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Proposal Documents (manual uploads)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post(
    "/proposals/{proposal_id}/documents",
    status_code=201,
)
def upload_document(
    proposal_id: UUID,
    file: UploadFile,
    document_role: str = Form("OTHER"),
    user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
    db_session: Session = Depends(get_session),
) -> ProposalDocumentResponse:
    """Upload a document to a proposal.

    Reads the upload fully into memory, enforces the configured size limit,
    derives a file type from the filename extension, attempts text extraction
    (best-effort: a failure is logged and the document is stored without
    extracted text), then persists and commits a ``ProposalReviewDocument``.

    Raises:
        OnyxError: ``NOT_FOUND`` when the proposal lookup returns nothing for
            this tenant; ``INVALID_INPUT`` when the upload cannot be read;
            ``PAYLOAD_TOO_LARGE`` when it exceeds
            ``DOCUMENT_UPLOAD_MAX_FILE_SIZE_BYTES``.
    """
    tenant_id = get_current_tenant_id()
    proposal = proposals_db.get_proposal(proposal_id, tenant_id, db_session)
    if not proposal:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")

    # Read file content
    try:
        file_bytes = file.file.read()
    except Exception as e:
        # Chain the original error so the root cause stays in the traceback.
        raise OnyxError(
            OnyxErrorCode.INVALID_INPUT,
            f"Failed to read uploaded file: {str(e)}",
        ) from e

    # Validate file size
    if len(file_bytes) > DOCUMENT_UPLOAD_MAX_FILE_SIZE_BYTES:
        raise OnyxError(
            OnyxErrorCode.PAYLOAD_TOO_LARGE,
            f"File size {len(file_bytes)} bytes exceeds maximum "
            f"allowed size of {DOCUMENT_UPLOAD_MAX_FILE_SIZE_BYTES} bytes",
        )

    # Determine file type from filename: the upper-cased text after the last
    # dot, or None when the name contains no dot. `filename` always has a
    # non-empty fallback, so no separate emptiness check is needed.
    filename = file.filename or "untitled"
    _, dot, extension = filename.rpartition(".")
    file_type = extension.upper() if dot else None

    # Extract text from the uploaded file (skipped for empty uploads).
    extracted_text = None
    if file_bytes:
        try:
            extracted_text = extract_file_text(
                file=io.BytesIO(file_bytes),
                file_name=filename,
            )
        except Exception as e:
            # Deliberately best-effort: keep the upload even if extraction fails.
            logger.warning(
                f"Failed to extract text from uploaded file '{filename}': {e}"
            )

    doc = ProposalReviewDocument(
        proposal_id=proposal_id,
        file_name=filename,
        file_type=file_type,
        document_role=document_role,
        uploaded_by=user.id,
        extracted_text=extracted_text,
    )
    db_session.add(doc)
    db_session.commit()
    return ProposalDocumentResponse.from_model(doc)
|
||||
|
||||
|
||||
@router.get("/proposals/{proposal_id}/documents")
def list_documents(
    proposal_id: UUID,
    user: User = Depends(require_permission(Permission.BASIC_ACCESS)),  # noqa: ARG001
    db_session: Session = Depends(get_session),
) -> list[ProposalDocumentResponse]:
    """Return the documents attached to a proposal, ordered by creation time.

    Raises:
        OnyxError: ``NOT_FOUND`` when the proposal lookup returns nothing for
            this tenant.
    """
    tenant_id = get_current_tenant_id()
    if not proposals_db.get_proposal(proposal_id, tenant_id, db_session):
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")

    attached_query = db_session.query(ProposalReviewDocument).filter(
        ProposalReviewDocument.proposal_id == proposal_id
    )
    ordered_rows = attached_query.order_by(ProposalReviewDocument.created_at).all()

    responses: list[ProposalDocumentResponse] = []
    for row in ordered_rows:
        responses.append(ProposalDocumentResponse.from_model(row))
    return responses
|
||||
|
||||
|
||||
@router.delete("/proposals/{proposal_id}/documents/{doc_id}", status_code=204)
def delete_document(
    proposal_id: UUID,
    doc_id: UUID,
    user: User = Depends(require_permission(Permission.BASIC_ACCESS)),  # noqa: ARG001
    db_session: Session = Depends(get_session),
) -> None:
    """Remove one uploaded document from a proposal and commit.

    Raises:
        OnyxError: ``NOT_FOUND`` when the proposal lookup returns nothing for
            this tenant, or when no document with ``doc_id`` is attached to
            that proposal.
    """
    # The proposal must resolve for the current tenant before anything else.
    tenant_id = get_current_tenant_id()
    if not proposals_db.get_proposal(proposal_id, tenant_id, db_session):
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")

    # Match on both ids so a document of another proposal cannot be deleted
    # through this proposal's URL.
    match_conditions = (
        ProposalReviewDocument.id == doc_id,
        ProposalReviewDocument.proposal_id == proposal_id,
    )
    target = (
        db_session.query(ProposalReviewDocument)
        .filter(*match_conditions)
        .one_or_none()
    )
    if not target:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Document not found")

    db_session.delete(target)
    db_session.commit()
|
||||
@@ -1,211 +0,0 @@
|
||||
"""API endpoints for review triggers, status, and findings."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.features.proposal_review.api.models import FindingResponse
|
||||
from onyx.server.features.proposal_review.api.models import ReviewRunResponse
|
||||
from onyx.server.features.proposal_review.api.models import ReviewRunTriggerRequest
|
||||
from onyx.server.features.proposal_review.db import findings as findings_db
|
||||
from onyx.server.features.proposal_review.db import proposals as proposals_db
|
||||
from onyx.server.features.proposal_review.db import rulesets as rulesets_db
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/proposals/{proposal_id}/review", status_code=201)
def trigger_review(
    proposal_id: UUID,
    request: ReviewRunTriggerRequest,
    user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
    db_session: Session = Depends(get_session),
) -> ReviewRunResponse:
    """Start a review run of a proposal against a ruleset.

    The proposal is moved to ``IN_REVIEW``, a review-run row is created and
    committed, and only then is the ``run_proposal_review`` Celery task
    dispatched with the run id and tenant id.

    Raises:
        OnyxError: NOT_FOUND when the proposal or the ruleset cannot be
            resolved for the current tenant; INVALID_INPUT when the ruleset
            has no active rules.
    """
    tenant_id = get_current_tenant_id()

    # Both the proposal and the ruleset are looked up with the tenant id.
    if not proposals_db.get_proposal(proposal_id, tenant_id, db_session):
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")

    chosen_ruleset_id = request.ruleset_id
    if not rulesets_db.get_ruleset(chosen_ruleset_id, tenant_id, db_session):
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Ruleset not found")

    # A run with nothing to evaluate is rejected up front.
    rule_total = rulesets_db.count_active_rules(chosen_ruleset_id, db_session)
    if rule_total == 0:
        raise OnyxError(
            OnyxErrorCode.INVALID_INPUT,
            "Ruleset has no active rules",
        )

    proposals_db.update_proposal_status(proposal_id, tenant_id, "IN_REVIEW", db_session)

    new_run = findings_db.create_review_run(
        proposal_id=proposal_id,
        ruleset_id=chosen_ruleset_id,
        triggered_by=user.id,
        total_rules=rule_total,
        db_session=db_session,
    )

    # Commit before dispatching so the run row exists by the time the task
    # is picked up.
    db_session.commit()
    logger.info(
        f"Review triggered for proposal {proposal_id} "
        f"with ruleset {chosen_ruleset_id} ({rule_total} rules)"
    )

    # Dispatch Celery task via the client app (has Redis broker configured)
    from onyx.background.celery.versioned_apps.client import app as celery_app

    celery_app.send_task(
        "run_proposal_review",
        args=[str(new_run.id), tenant_id],
        expires=3600,
    )

    return ReviewRunResponse.from_model(new_run)
|
||||
|
||||
|
||||
@router.get("/proposals/{proposal_id}/review-runs")
def list_review_runs(
    proposal_id: UUID,
    user: User = Depends(require_permission(Permission.BASIC_ACCESS)),  # noqa: ARG001
    db_session: Session = Depends(get_session),
) -> list[ReviewRunResponse]:
    """Return every review run recorded for a proposal.

    Ordering is whatever ``findings_db.list_review_runs_by_proposal`` yields
    (the original endpoint documents it as most recent first).

    Raises:
        OnyxError: NOT_FOUND when the proposal cannot be resolved for the
            current tenant.
    """
    current_tenant = get_current_tenant_id()
    if not proposals_db.get_proposal(proposal_id, current_tenant, db_session):
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")

    responses: list[ReviewRunResponse] = []
    for review_run in findings_db.list_review_runs_by_proposal(
        proposal_id, db_session
    ):
        responses.append(ReviewRunResponse.from_model(review_run))
    return responses
|
||||
|
||||
|
||||
@router.get("/proposals/{proposal_id}/review-status")
def get_review_status(
    proposal_id: UUID,
    user: User = Depends(require_permission(Permission.BASIC_ACCESS)),  # noqa: ARG001
    db_session: Session = Depends(get_session),
) -> ReviewRunResponse:
    """Report the most recent review run of a proposal.

    Raises:
        OnyxError: NOT_FOUND when the proposal cannot be resolved for the
            current tenant, or when it has no review runs yet.
    """
    current_tenant = get_current_tenant_id()
    if not proposals_db.get_proposal(proposal_id, current_tenant, db_session):
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")

    latest_run = findings_db.get_latest_review_run(proposal_id, db_session)
    if latest_run:
        return ReviewRunResponse.from_model(latest_run)

    raise OnyxError(OnyxErrorCode.NOT_FOUND, "No review runs found")
|
||||
|
||||
|
||||
@router.post("/proposals/{proposal_id}/retry-failed", status_code=200)
def retry_failed_rules_endpoint(
    proposal_id: UUID,
    user: User = Depends(require_permission(Permission.BASIC_ACCESS)),  # noqa: ARG001
    db_session: Session = Depends(get_session),
) -> ReviewRunResponse:
    """Re-dispatch the failed rules of a proposal's latest review run.

    The latest run is flipped back to ``RUNNING`` (and its ``completed_at``
    cleared) before the ``run_proposal_review`` task is sent with the
    de-duplicated ids of the rules whose findings failed.

    Raises:
        OnyxError: NOT_FOUND when the proposal or a review run is missing;
            INVALID_INPUT when the latest run is neither ``COMPLETED`` nor
            ``FAILED``, or when it has no failed findings.
    """
    tenant_id = get_current_tenant_id()

    if not proposals_db.get_proposal(proposal_id, tenant_id, db_session):
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")

    latest_run = findings_db.get_latest_review_run(proposal_id, db_session)
    if not latest_run:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "No review runs found")

    # Only a finished run (successfully or not) may be retried.
    finished_states = ("COMPLETED", "FAILED")
    if latest_run.status not in finished_states:
        raise OnyxError(
            OnyxErrorCode.INVALID_INPUT,
            "Cannot retry: review is still running",
        )

    failed_findings = findings_db.get_failed_findings_for_run(
        latest_run.id, db_session
    )
    if not failed_findings:
        raise OnyxError(
            OnyxErrorCode.INVALID_INPUT,
            "No failed rules to retry",
        )

    # Several findings may point at the same rule; collapse them.
    distinct_rule_ids: set[str] = set()
    for finding in failed_findings:
        distinct_rule_ids.add(str(finding.rule_id))
    rule_ids = list(distinct_rule_ids)

    # Set status to RUNNING before dispatching so a second call is rejected
    latest_run.status = "RUNNING"
    latest_run.completed_at = None
    db_session.commit()

    logger.info(
        f"Retrying {len(rule_ids)} failed rules for run {latest_run.id} "
        f"on proposal {proposal_id}"
    )

    from onyx.background.celery.versioned_apps.client import app as celery_app

    celery_app.send_task(
        "run_proposal_review",
        args=[str(latest_run.id), tenant_id],
        kwargs={"rule_ids": rule_ids},
        expires=3600,
    )

    return ReviewRunResponse.from_model(latest_run)
|
||||
|
||||
|
||||
@router.get(
    "/proposals/{proposal_id}/findings",
)
def get_findings(
    proposal_id: UUID,
    review_run_id: UUID | None = None,
    user: User = Depends(require_permission(Permission.BASIC_ACCESS)),  # noqa: ARG001
    db_session: Session = Depends(get_session),
) -> list[FindingResponse]:
    """Get findings for a proposal.

    If review_run_id is not specified, returns findings from the latest run
    (or an empty list when the proposal has no runs yet). When it is
    specified, the run must be one of this proposal's runs.

    Raises:
        OnyxError: NOT_FOUND when the proposal cannot be resolved for the
            current tenant, or when ``review_run_id`` does not identify a
            run of this proposal.
    """
    tenant_id = get_current_tenant_id()
    proposal = proposals_db.get_proposal(proposal_id, tenant_id, db_session)
    if not proposal:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Proposal not found")

    if review_run_id is None:
        # If no run specified, get the latest
        run = findings_db.get_latest_review_run(proposal_id, db_session)
        if not run:
            return []
        review_run_id = run.id
    else:
        # The proposal check above only authorises access to *this*
        # proposal. Without tying the caller-supplied run id back to it, any
        # run's findings could be read through any accessible proposal.
        proposal_runs = findings_db.list_review_runs_by_proposal(
            proposal_id, db_session
        )
        if all(r.id != review_run_id for r in proposal_runs):
            raise OnyxError(OnyxErrorCode.NOT_FOUND, "Review run not found")

    results = findings_db.list_findings_by_run(review_run_id, db_session)
    return [FindingResponse.from_model(f) for f in results]
|
||||
@@ -1,545 +0,0 @@
|
||||
"""API endpoints for rulesets and rules."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Form
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.features.proposal_review.api.models import BulkRuleUpdateRequest
|
||||
from onyx.server.features.proposal_review.api.models import BulkRuleUpdateResponse
|
||||
from onyx.server.features.proposal_review.api.models import ImportJobResponse
|
||||
from onyx.server.features.proposal_review.api.models import RuleCreate
|
||||
from onyx.server.features.proposal_review.api.models import RuleResponse
|
||||
from onyx.server.features.proposal_review.api.models import RulesetCreate
|
||||
from onyx.server.features.proposal_review.api.models import RulesetResponse
|
||||
from onyx.server.features.proposal_review.api.models import RulesetUpdate
|
||||
from onyx.server.features.proposal_review.api.models import RuleTestResponse
|
||||
from onyx.server.features.proposal_review.api.models import RuleUpdate
|
||||
from onyx.server.features.proposal_review.configs import IMPORT_MAX_FILE_SIZE_BYTES
|
||||
from onyx.server.features.proposal_review.db import imports as imports_db
|
||||
from onyx.server.features.proposal_review.db import rulesets as rulesets_db
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Rulesets
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/rulesets")
def list_rulesets(
    user: User = Depends(require_permission(Permission.BASIC_ACCESS)),  # noqa: ARG001
    db_session: Session = Depends(get_session),
) -> list[RulesetResponse]:
    """Return the rulesets that ``rulesets_db.list_rulesets`` yields for the current tenant."""
    tenant_rulesets = rulesets_db.list_rulesets(get_current_tenant_id(), db_session)

    responses: list[RulesetResponse] = []
    for ruleset in tenant_rulesets:
        responses.append(RulesetResponse.from_model(ruleset))
    return responses
|
||||
|
||||
|
||||
@router.post("/rulesets", status_code=201)
def create_ruleset(
    request: RulesetCreate,
    user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS)),
    db_session: Session = Depends(get_session),
) -> RulesetResponse:
    """Create a ruleset for the current tenant, owned by the calling user.

    The response is built with ``include_rules=False``.
    """
    ruleset_fields = {
        "tenant_id": get_current_tenant_id(),
        "name": request.name,
        "description": request.description,
        "is_default": request.is_default,
        "created_by": user.id,
    }
    created = rulesets_db.create_ruleset(db_session=db_session, **ruleset_fields)
    db_session.commit()

    return RulesetResponse.from_model(created, include_rules=False)
|
||||
|
||||
|
||||
@router.get("/rulesets/{ruleset_id}")
def get_ruleset(
    ruleset_id: UUID,
    user: User = Depends(require_permission(Permission.BASIC_ACCESS)),  # noqa: ARG001
    db_session: Session = Depends(get_session),
) -> RulesetResponse:
    """Fetch one ruleset of the current tenant.

    Raises:
        OnyxError: NOT_FOUND when the ruleset cannot be resolved for the
            current tenant.
    """
    found = rulesets_db.get_ruleset(ruleset_id, get_current_tenant_id(), db_session)
    if found:
        return RulesetResponse.from_model(found)

    raise OnyxError(OnyxErrorCode.NOT_FOUND, "Ruleset not found")
|
||||
|
||||
|
||||
@router.put("/rulesets/{ruleset_id}")
def update_ruleset(
    ruleset_id: UUID,
    request: RulesetUpdate,
    user: User = Depends(  # noqa: ARG001
        require_permission(Permission.MANAGE_CONNECTORS)
    ),
    db_session: Session = Depends(get_session),
) -> RulesetResponse:
    """Apply a partial update to a ruleset.

    Only fields the client explicitly set are forwarded
    (``model_dump(exclude_unset=True)``).

    Raises:
        OnyxError: NOT_FOUND when ``rulesets_db.update_ruleset`` returns
            nothing for this id and tenant.
    """
    current_tenant = get_current_tenant_id()
    changed_fields = request.model_dump(exclude_unset=True)

    updated = rulesets_db.update_ruleset(
        ruleset_id=ruleset_id,
        tenant_id=current_tenant,
        db_session=db_session,
        updates=changed_fields,
    )
    if not updated:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Ruleset not found")

    db_session.commit()
    return RulesetResponse.from_model(updated)
|
||||
|
||||
|
||||
@router.delete("/rulesets/{ruleset_id}", status_code=204)
def delete_ruleset(
    ruleset_id: UUID,
    user: User = Depends(  # noqa: ARG001
        require_permission(Permission.MANAGE_CONNECTORS)
    ),
    db_session: Session = Depends(get_session),
) -> None:
    """Delete a ruleset of the current tenant.

    The original endpoint documents that the ruleset's rules go with it;
    that cascade is not visible here and lives in ``rulesets_db``.

    Raises:
        OnyxError: NOT_FOUND when ``rulesets_db.delete_ruleset`` reports that
            nothing was deleted.
    """
    was_deleted = rulesets_db.delete_ruleset(
        ruleset_id, get_current_tenant_id(), db_session
    )
    if not was_deleted:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Ruleset not found")

    db_session.commit()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Rules
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/rulesets/{ruleset_id}/rules", status_code=201)
def create_rule(
    ruleset_id: UUID,
    request: RuleCreate,
    user: User = Depends(  # noqa: ARG001
        require_permission(Permission.MANAGE_CONNECTORS)
    ),
    db_session: Session = Depends(get_session),
) -> RuleResponse:
    """Add a rule to an existing ruleset of the current tenant.

    Raises:
        OnyxError: NOT_FOUND when the ruleset cannot be resolved for the
            current tenant.
    """
    # The parent ruleset must exist for this tenant before a rule is attached.
    parent = rulesets_db.get_ruleset(ruleset_id, get_current_tenant_id(), db_session)
    if not parent:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Ruleset not found")

    # These request attributes map one-to-one onto create_rule's keywords.
    forwarded_fields = (
        "name",
        "description",
        "category",
        "rule_type",
        "rule_intent",
        "prompt_template",
        "source",
        "authority",
        "is_hard_stop",
        "priority",
    )
    rule_values = {field: getattr(request, field) for field in forwarded_fields}

    new_rule = rulesets_db.create_rule(
        ruleset_id=ruleset_id,
        db_session=db_session,
        **rule_values,
    )
    db_session.commit()
    return RuleResponse.from_model(new_rule)
|
||||
|
||||
|
||||
@router.put("/rules/{rule_id}")
def update_rule(
    rule_id: UUID,
    request: RuleUpdate,
    user: User = Depends(  # noqa: ARG001
        require_permission(Permission.MANAGE_CONNECTORS)
    ),
    db_session: Session = Depends(get_session),
) -> RuleResponse:
    """Apply a partial update to a rule.

    Only fields the client explicitly set are forwarded
    (``model_dump(exclude_unset=True)``).

    Raises:
        OnyxError: NOT_FOUND when the tenant-checked lookup finds no rule, or
            when ``rulesets_db.update_rule`` returns nothing.
    """
    # update_rule itself takes no tenant id, so ownership is established by
    # the tenant-checked lookup first.
    owned_rule = rulesets_db.get_rule_with_tenant_check(
        rule_id, get_current_tenant_id(), db_session
    )
    if not owned_rule:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Rule not found")

    changed_fields = request.model_dump(exclude_unset=True)
    result = rulesets_db.update_rule(
        rule_id=rule_id,
        db_session=db_session,
        updates=changed_fields,
    )
    if not result:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Rule not found")

    db_session.commit()
    return RuleResponse.from_model(result)
|
||||
|
||||
|
||||
@router.delete("/rules/{rule_id}", status_code=204)
def delete_rule(
    rule_id: UUID,
    user: User = Depends(  # noqa: ARG001
        require_permission(Permission.MANAGE_CONNECTORS)
    ),
    db_session: Session = Depends(get_session),
) -> None:
    """Delete a single rule.

    Raises:
        OnyxError: NOT_FOUND when the tenant-checked lookup finds no rule, or
            when ``rulesets_db.delete_rule`` reports that nothing was deleted.
    """
    # delete_rule itself takes no tenant id, so ownership is established by
    # the tenant-checked lookup first.
    owned_rule = rulesets_db.get_rule_with_tenant_check(
        rule_id, get_current_tenant_id(), db_session
    )
    if not owned_rule:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Rule not found")

    if not rulesets_db.delete_rule(rule_id, db_session):
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Rule not found")

    db_session.commit()
|
||||
|
||||
|
||||
@router.post("/rulesets/{ruleset_id}/rules/bulk-update")
def bulk_update_rules(
    ruleset_id: UUID,
    request: BulkRuleUpdateRequest,
    user: User = Depends(  # noqa: ARG001
        require_permission(Permission.MANAGE_CONNECTORS)
    ),
    db_session: Session = Depends(get_session),
) -> BulkRuleUpdateResponse:
    """Activate, deactivate, or delete several rules of one ruleset at once.

    Returns:
        The count reported by ``rulesets_db.bulk_update_rules``.

    Raises:
        OnyxError: NOT_FOUND when the ruleset cannot be resolved for the
            current tenant; INVALID_INPUT when ``request.action`` is not one
            of the three supported actions.
    """
    # Ownership check comes first, then action validation.
    owning_ruleset = rulesets_db.get_ruleset(
        ruleset_id, get_current_tenant_id(), db_session
    )
    if not owning_ruleset:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Ruleset not found")

    supported_actions = ("activate", "deactivate", "delete")
    requested_action = request.action
    if requested_action not in supported_actions:
        raise OnyxError(
            OnyxErrorCode.INVALID_INPUT,
            "action must be 'activate', 'deactivate', or 'delete'",
        )

    # The ruleset id is passed along so only this ruleset's rules are touched.
    affected = rulesets_db.bulk_update_rules(
        request.rule_ids, requested_action, ruleset_id, db_session
    )
    db_session.commit()
    return BulkRuleUpdateResponse(updated_count=affected)
|
||||
|
||||
|
||||
@router.post(
    "/rulesets/{ruleset_id}/import",
    status_code=202,
)
def import_checklist_endpoint(
    ruleset_id: UUID,
    file: UploadFile,
    user: User = Depends(  # noqa: ARG001
        require_permission(Permission.MANAGE_CONNECTORS)
    ),
    db_session: Session = Depends(get_session),
) -> dict[str, str]:
    """Upload a checklist document and parse it into rules via LLM.

    Text extraction happens synchronously (fast). The LLM decomposition
    runs in a Celery task so the request returns 202 immediately.
    Poll GET /rulesets/{ruleset_id}/import/{import_job_id}/status
    to track progress.

    Returns:
        ``{"import_job_id": <id of the created import job>}``.

    Raises:
        OnyxError: NOT_FOUND when the ruleset cannot be resolved for the
            current tenant; INVALID_INPUT when the upload cannot be read, is
            empty, or yields no text; PAYLOAD_TOO_LARGE when it exceeds
            ``IMPORT_MAX_FILE_SIZE_BYTES``.
    """
    tenant_id = get_current_tenant_id()
    ruleset = rulesets_db.get_ruleset(ruleset_id, tenant_id, db_session)
    if not ruleset:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Ruleset not found")

    # Read at most one byte past the limit: that is enough to detect an
    # oversized upload without first pulling the whole thing into memory.
    try:
        file_content = file.file.read(IMPORT_MAX_FILE_SIZE_BYTES + 1)
    except Exception as e:
        raise OnyxError(
            OnyxErrorCode.INVALID_INPUT,
            f"Failed to read uploaded file: {str(e)}",
        ) from e

    if not file_content:
        raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Uploaded file is empty")

    if len(file_content) > IMPORT_MAX_FILE_SIZE_BYTES:
        # The exact size is unknown because the read was capped.
        raise OnyxError(
            OnyxErrorCode.PAYLOAD_TOO_LARGE,
            f"File size exceeds maximum "
            f"allowed size of {IMPORT_MAX_FILE_SIZE_BYTES} bytes",
        )

    # Extract text synchronously (fast -- no LLM involved)
    extracted_text = ""
    filename = file.filename or "untitled"
    file_ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""

    if file_ext in ("txt", "text", "md"):
        # Plain-text formats are decoded directly; undecodable bytes are
        # replaced rather than rejected.
        extracted_text = file_content.decode("utf-8", errors="replace")
    else:
        try:
            import io

            from onyx.file_processing.extract_file_text import extract_file_text

            extracted_text = extract_file_text(
                file=io.BytesIO(file_content),
                file_name=filename,
            )
        except Exception as e:
            raise OnyxError(
                OnyxErrorCode.INVALID_INPUT,
                f"Failed to extract text from file: {str(e)}",
            ) from e

    if not extracted_text or not extracted_text.strip():
        raise OnyxError(
            OnyxErrorCode.INVALID_INPUT,
            "No text could be extracted from the uploaded file",
        )

    # Create the import job row and commit it before dispatching, so the row
    # exists by the time the task is picked up.
    job = imports_db.create_import_job(
        ruleset_id=ruleset_id,
        tenant_id=tenant_id,
        source_filename=filename,
        extracted_text=extracted_text,
        db_session=db_session,
    )
    db_session.commit()

    # Dispatch Celery task via the client app (has Redis broker configured)
    from onyx.background.celery.versioned_apps.client import app as celery_app

    celery_app.send_task(
        "run_checklist_import",
        args=[str(job.id), tenant_id],
        expires=600,
    )

    return {"import_job_id": str(job.id)}
|
||||
|
||||
|
||||
@router.get("/rulesets/{ruleset_id}/import/active")
def get_active_import_job(
    ruleset_id: UUID,
    user: User = Depends(require_permission(Permission.BASIC_ACCESS)),  # noqa: ARG001
    db_session: Session = Depends(get_session),
) -> ImportJobResponse | None:
    """Return the ruleset's active import job, or ``None`` when there is none.

    Which jobs count as active is decided by ``imports_db.get_active_import_job``
    (the original endpoint documents it as PENDING/RUNNING).

    Raises:
        OnyxError: NOT_FOUND when the ruleset cannot be resolved for the
            current tenant.
    """
    current_tenant = get_current_tenant_id()
    if not rulesets_db.get_ruleset(ruleset_id, current_tenant, db_session):
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Ruleset not found")

    active_job = imports_db.get_active_import_job(ruleset_id, db_session)
    return ImportJobResponse.from_model(active_job) if active_job else None
|
||||
|
||||
|
||||
@router.get("/rulesets/{ruleset_id}/import/{import_job_id}/status")
def get_import_job_status(
    ruleset_id: UUID,
    import_job_id: UUID,
    user: User = Depends(require_permission(Permission.BASIC_ACCESS)),  # noqa: ARG001
    db_session: Session = Depends(get_session),
) -> ImportJobResponse:
    """Report on one checklist import job of a ruleset.

    Raises:
        OnyxError: NOT_FOUND when the ruleset cannot be resolved for the
            current tenant, when the job does not exist, or when the job
            belongs to a different ruleset.
    """
    current_tenant = get_current_tenant_id()

    # The ruleset lookup is what ties the request to the current tenant.
    if not rulesets_db.get_ruleset(ruleset_id, current_tenant, db_session):
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Ruleset not found")

    import_job = imports_db.get_import_job(import_job_id, db_session)

    # A job of another ruleset is reported exactly like a missing job.
    if not import_job:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Import job not found")
    if import_job.ruleset_id != ruleset_id:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Import job not found")

    return ImportJobResponse.from_model(import_job)
|
||||
|
||||
|
||||
@router.post("/rules/{rule_id}/test")
def test_rule(
    rule_id: UUID,
    user: User = Depends(  # noqa: ARG001
        require_permission(Permission.MANAGE_CONNECTORS)
    ),
    db_session: Session = Depends(get_session),
) -> RuleTestResponse:
    """Dry-run a rule against a placeholder proposal.

    The rule is evaluated against a fixed context of bracketed placeholder
    texts flagged ``test_mode``. Any exception raised by the evaluation is
    returned as a failed ``RuleTestResponse`` rather than propagated.

    Raises:
        OnyxError: NOT_FOUND when the tenant-checked lookup finds no rule.
    """
    # Only rules reachable through the tenant-checked lookup may be tested.
    owned_rule = rulesets_db.get_rule_with_tenant_check(
        rule_id, get_current_tenant_id(), db_session
    )
    if not owned_rule:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Rule not found")

    from onyx.server.features.proposal_review.engine.context_assembler import (
        ProposalContext,
    )
    from onyx.server.features.proposal_review.engine.rule_evaluator import (
        evaluate_rule,
    )

    # No real proposal is loaded; every text field is a placeholder.
    placeholder_context = ProposalContext(
        proposal_text="[Sample proposal text for testing. No real proposal loaded.]",
        budget_text="[No budget text available for test.]",
        foa_text="[No FOA text available for test.]",
        metadata={"test_mode": True},
        jira_key="TEST-000",
    )

    try:
        evaluation = evaluate_rule(owned_rule, placeholder_context, db_session)
    except Exception as e:
        # A failing evaluation is the answer to the test, not a server error.
        return RuleTestResponse(
            rule_id=str(rule_id),
            success=False,
            error=str(e),
        )
    else:
        return RuleTestResponse(
            rule_id=str(rule_id),
            success=True,
            result=evaluation,
        )
|
||||
|
||||
|
||||
@router.post("/rules/{rule_id}/refine")
async def refine_rule_endpoint(
    rule_id: UUID,
    answer: str = Form(...),
    file: UploadFile | None = None,
    user: User = Depends(  # noqa: ARG001
        require_permission(Permission.MANAGE_CONNECTORS)
    ),
    db_session: Session = Depends(get_session),
) -> RuleResponse:
    """Submit an answer to a rule's refinement question.

    Re-runs the LLM to produce a refined prompt_template that incorporates
    the user's institution-specific information, then clears the refinement
    flag on the rule. An optional file attachment (pdf, docx, etc.) can be
    included — its extracted text is appended to the answer.

    Because this is an ``async`` endpoint, the blocking file-text extraction
    and LLM call are run in worker threads so they do not stall the event
    loop.

    Raises:
        OnyxError: NOT_FOUND when the tenant-checked lookup finds no rule (or
            the rule is gone when the update is written); INVALID_INPUT when
            the rule is not awaiting refinement, has no refinement question,
            or the attachment's text cannot be extracted; SERVER_ERROR when
            the refinement raises ``RuntimeError``.
    """
    import asyncio

    tenant_id = get_current_tenant_id()
    rule = rulesets_db.get_rule_with_tenant_check(rule_id, tenant_id, db_session)
    if not rule:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Rule not found")

    if not rule.refinement_needed:
        raise OnyxError(
            OnyxErrorCode.INVALID_INPUT,
            "This rule does not need refinement",
        )

    if not rule.refinement_question:
        raise OnyxError(
            OnyxErrorCode.INVALID_INPUT,
            "Rule is marked for refinement but has no refinement question",
        )

    # Build the combined answer from the text field + optional attachment
    combined_answer = answer.strip()

    if file and file.filename:
        import io

        file_content = await file.read()
        if file_content:
            filename = file.filename
            file_ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""

            if file_ext in ("txt", "text", "md"):
                # Plain-text formats are decoded directly; undecodable bytes
                # are replaced rather than rejected.
                file_text = file_content.decode("utf-8", errors="replace")
            else:
                try:
                    from onyx.file_processing.extract_file_text import (
                        extract_file_text,
                    )

                    # Extraction is synchronous; keep it off the event loop.
                    file_text = await asyncio.to_thread(
                        extract_file_text,
                        file=io.BytesIO(file_content),
                        file_name=filename,
                    )
                except Exception as e:
                    raise OnyxError(
                        OnyxErrorCode.INVALID_INPUT,
                        f"Failed to extract text from file: {str(e)}",
                    ) from e

            if file_text and file_text.strip():
                combined_answer += (
                    f"\n\n--- Attached file: {filename} ---\n{file_text.strip()}"
                )

    from onyx.llm.factory import get_default_llm
    from onyx.server.features.proposal_review.engine.checklist_importer import (
        refine_rule,
    )

    llm = get_default_llm(timeout=120)

    try:
        # The rule attributes are read here, on the calling thread; only the
        # LLM round-trip itself runs in the worker thread.
        refined = await asyncio.to_thread(
            refine_rule,
            rule_name=rule.name,
            rule_description=rule.description,
            rule_prompt_template=rule.prompt_template,
            refinement_question=rule.refinement_question,
            user_answer=combined_answer,
            llm=llm,
        )
    except RuntimeError as e:
        raise OnyxError(
            OnyxErrorCode.SERVER_ERROR,
            f"Refinement failed: {str(e)}",
        ) from e

    updated = rulesets_db.update_rule(
        rule_id=rule_id,
        db_session=db_session,
        updates={
            "name": refined["name"],
            "description": refined.get("description"),
            "prompt_template": refined["prompt_template"],
            # Fall back to the current values when the LLM omits these.
            "rule_type": refined.get("rule_type", rule.rule_type),
            "rule_intent": refined.get("rule_intent", rule.rule_intent),
            "refinement_needed": False,
            "refinement_question": None,
        },
    )
    if not updated:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Rule not found")
    db_session.commit()
    return RuleResponse.from_model(updated)
|
||||
@@ -1,18 +0,0 @@
|
||||
import os
|
||||
|
||||
# Multiplier used to turn the megabyte settings below into byte limits.
_BYTES_PER_MB = 1024 * 1024

# Feature flag for proposal review: on by default, and on only when the
# variable's value is "true" (compared case-insensitively).
ENABLE_PROPOSAL_REVIEW = os.getenv("ENABLE_PROPOSAL_REVIEW", "true").lower() == "true"

# Checklist import size limit: PROPOSAL_REVIEW_IMPORT_MAX_FILE_SIZE_MB
# (megabytes, default 50) and the same limit expressed in bytes.
IMPORT_MAX_FILE_SIZE_MB = int(
    os.getenv("PROPOSAL_REVIEW_IMPORT_MAX_FILE_SIZE_MB", "50")
)
IMPORT_MAX_FILE_SIZE_BYTES = IMPORT_MAX_FILE_SIZE_MB * _BYTES_PER_MB

# Document upload size limit: PROPOSAL_REVIEW_DOCUMENT_UPLOAD_MAX_FILE_SIZE_MB
# (megabytes, default 100) and the same limit expressed in bytes.
DOCUMENT_UPLOAD_MAX_FILE_SIZE_MB = int(
    os.getenv("PROPOSAL_REVIEW_DOCUMENT_UPLOAD_MAX_FILE_SIZE_MB", "100")
)
DOCUMENT_UPLOAD_MAX_FILE_SIZE_BYTES = DOCUMENT_UPLOAD_MAX_FILE_SIZE_MB * _BYTES_PER_MB
|
||||
@@ -1,101 +0,0 @@
|
||||
"""DB operations for tenant configuration."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import Document__Tag
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.models import Tag
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewConfig
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_config(
    tenant_id: str,
    db_session: Session,
) -> ProposalReviewConfig | None:
    """Look up the proposal-review config row of a tenant.

    Returns ``None`` when the tenant has no row. ``one_or_none`` is used, so
    more than one matching row makes SQLAlchemy raise.
    """
    tenant_rows = db_session.query(ProposalReviewConfig).filter(
        ProposalReviewConfig.tenant_id == tenant_id
    )
    return tenant_rows.one_or_none()
|
||||
|
||||
|
||||
def upsert_config(
    tenant_id: str,
    db_session: Session,
    jira_connector_id: int | None = None,
    jira_project_key: str | None = None,
    field_mapping: list[str] | None = None,
    jira_writeback: dict[str, Any] | None = None,
    review_model: str | None = None,
    import_model: str | None = None,
) -> ProposalReviewConfig:
    """Insert the tenant's config row, or update the one that exists.

    On update, only arguments that are not ``None`` overwrite stored values,
    and ``updated_at`` is refreshed. On insert, every argument is stored as
    given. The session is flushed but not committed.
    """
    supplied = {
        "jira_connector_id": jira_connector_id,
        "jira_project_key": jira_project_key,
        "field_mapping": field_mapping,
        "jira_writeback": jira_writeback,
        "review_model": review_model,
        "import_model": import_model,
    }

    current = get_config(tenant_id, db_session)

    if not current:
        # First write for this tenant: store everything, None included.
        fresh = ProposalReviewConfig(tenant_id=tenant_id, **supplied)
        db_session.add(fresh)
        db_session.flush()
        logger.info(f"Created proposal review config for tenant {tenant_id}")
        return fresh

    # None means "leave this field alone", so it can never clear a value.
    for attr_name, new_value in supplied.items():
        if new_value is not None:
            setattr(current, attr_name, new_value)
    current.updated_at = datetime.now(timezone.utc)
    db_session.flush()
    logger.info(f"Updated proposal review config for tenant {tenant_id}")
    return current
|
||||
|
||||
|
||||
def get_connector_metadata_keys(
    connector_id: int,
    db_session: Session,
) -> list[str]:
    """List the distinct tag keys found on a connector's documents, sorted.

    Tags are joined to documents through ``Document__Tag`` and documents to
    the connector through ``DocumentByConnectorCredentialPair``. The query is
    capped at 500 distinct keys; sorting happens in Python afterwards, so
    with more than 500 keys the subset returned is whichever 500 the
    database produced.
    """
    tag_to_document = Tag.id == Document__Tag.tag_id
    document_to_connector = (
        Document__Tag.document_id == DocumentByConnectorCredentialPair.id
    )

    key_query = (
        select(Tag.tag_key)
        .select_from(Tag)
        .join(Document__Tag, tag_to_document)
        .join(DocumentByConnectorCredentialPair, document_to_connector)
        .where(DocumentByConnectorCredentialPair.connector_id == connector_id)
        .distinct()
        .limit(500)
    )

    # Each result row holds a single column: the tag key.
    tag_keys = [record[0] for record in db_session.execute(key_query).all()]
    tag_keys.sort()
    return tag_keys
|
||||
@@ -1,115 +0,0 @@
|
||||
"""DB operations for finding decisions and proposal decisions.
|
||||
|
||||
Finding decisions are stored inline on the ProposalReviewFinding row.
|
||||
Proposal decisions are stored inline on the ProposalReviewProposal row.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewFinding
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewProposal
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Per-Finding Decisions (inline on finding row)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def upsert_finding_decision(
    finding_id: UUID,
    officer_id: UUID,
    action: str,
    db_session: Session,
    notes: str | None = None,
) -> ProposalReviewFinding:
    """Store an officer's decision on a finding, replacing any earlier one.

    The decision columns live on the finding row itself, so this is an
    in-place update that is flushed but not committed.

    Raises:
        ValueError: If no finding with ``finding_id`` exists.
    """
    lookup = db_session.query(ProposalReviewFinding).filter(
        ProposalReviewFinding.id == finding_id
    )
    target = lookup.one_or_none()
    if not target:
        raise ValueError(f"Finding {finding_id} not found")

    # Overwrite all four decision columns together.
    target.decision_action = action
    target.decision_notes = notes
    target.decision_officer_id = officer_id
    target.decided_at = datetime.now(timezone.utc)
    db_session.flush()

    logger.info(f"Recorded decision on finding {finding_id}: {action}")
    return target
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Proposal-Level Decisions (inline on proposal row)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def update_proposal_decision(
    proposal_id: UUID,
    tenant_id: str,
    officer_id: UUID,
    decision: str,
    db_session: Session,
    notes: str | None = None,
) -> ProposalReviewProposal:
    """Record a final decision on a proposal.

    Overwrites previous decision fields on the proposal row, stores the
    decision as the proposal status, and resets the Jira sync flags. The
    change is flushed but not committed.

    Args:
        proposal_id: ID of the proposal being decided.
        tenant_id: Tenant the proposal must belong to.
        officer_id: User recording the decision.
        decision: Value written to the proposal's ``status`` column.
        db_session: Session used for the lookup and flush.
        notes: Optional free-text notes stored with the decision.

    Returns:
        The updated proposal row.

    Raises:
        ValueError: If no proposal matches ``proposal_id`` and ``tenant_id``.
    """
    proposal = (
        db_session.query(ProposalReviewProposal)
        .filter(
            ProposalReviewProposal.id == proposal_id,
            ProposalReviewProposal.tenant_id == tenant_id,
        )
        .one_or_none()
    )
    if not proposal:
        raise ValueError(f"Proposal {proposal_id} not found")

    # Use a single timestamp so decision_at and updated_at describe the same
    # instant instead of differing by the time between two now() calls.
    now = datetime.now(timezone.utc)

    proposal.status = decision
    proposal.decision_notes = notes
    proposal.decision_officer_id = officer_id
    proposal.decision_at = now
    # A new decision has not been written to Jira yet.
    proposal.jira_synced = False
    proposal.jira_synced_at = None
    proposal.updated_at = now
    db_session.flush()

    logger.info(f"Recorded proposal decision {decision} for proposal {proposal_id}")
    return proposal
|
||||
|
||||
|
||||
def mark_proposal_jira_synced(
    proposal_id: UUID,
    tenant_id: str,
    db_session: Session,
) -> ProposalReviewProposal | None:
    """Flag a proposal's decision as synced to Jira.

    Returns the updated proposal, or ``None`` when no proposal matches the
    given ID and tenant. The change is flushed but not committed.
    """
    matches = db_session.query(ProposalReviewProposal).filter(
        ProposalReviewProposal.id == proposal_id,
        ProposalReviewProposal.tenant_id == tenant_id,
    )
    row = matches.one_or_none()
    if not row:
        return None

    row.jira_synced = True
    row.jira_synced_at = datetime.now(timezone.utc)
    db_session.flush()
    logger.info(f"Marked proposal {proposal_id} as jira_synced")
    return row
|
||||
@@ -1,206 +0,0 @@
|
||||
"""DB operations for review runs and findings."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewFinding
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewRun
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Review Runs
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_review_run(
    proposal_id: UUID,
    ruleset_id: UUID,
    triggered_by: UUID,
    total_rules: int,
    db_session: Session,
) -> ProposalReviewRun:
    """Insert a review run row and return it.

    The row is added and flushed (not committed), so ``run.id`` is
    populated by the time the log line is written.
    """
    column_values = {
        "proposal_id": proposal_id,
        "ruleset_id": ruleset_id,
        "triggered_by": triggered_by,
        "total_rules": total_rules,
    }
    run = ProposalReviewRun(**column_values)
    db_session.add(run)
    db_session.flush()
    logger.info(
        f"Created review run {run.id} for proposal {proposal_id} "
        f"with {total_rules} rules"
    )
    return run
|
||||
|
||||
|
||||
def get_review_run(
    run_id: UUID,
    db_session: Session,
) -> ProposalReviewRun | None:
    """Fetch the review run with the given ID, or ``None`` if absent."""
    by_id = db_session.query(ProposalReviewRun).filter(
        ProposalReviewRun.id == run_id
    )
    return by_id.one_or_none()
|
||||
|
||||
|
||||
def get_latest_review_run(
    proposal_id: UUID,
    db_session: Session,
) -> ProposalReviewRun | None:
    """Return the newest review run (by ``created_at``) for a proposal.

    Returns ``None`` when the proposal has no runs.
    """
    runs_for_proposal = db_session.query(ProposalReviewRun).filter(
        ProposalReviewRun.proposal_id == proposal_id
    )
    newest_first = runs_for_proposal.order_by(desc(ProposalReviewRun.created_at))
    return newest_first.first()
|
||||
|
||||
|
||||
def list_review_runs_by_proposal(
    proposal_id: UUID,
    db_session: Session,
    limit: int = 20,
) -> list[ProposalReviewRun]:
    """Return up to ``limit`` review runs for a proposal, newest first."""
    runs_for_proposal = db_session.query(ProposalReviewRun).filter(
        ProposalReviewRun.proposal_id == proposal_id
    )
    newest_first = runs_for_proposal.order_by(desc(ProposalReviewRun.created_at))
    return newest_first.limit(limit).all()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Findings
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_finding(
    proposal_id: UUID,
    rule_id: UUID,
    review_run_id: UUID,
    verdict: str,
    db_session: Session,
    confidence: str | None = None,
    evidence: str | None = None,
    explanation: str | None = None,
    suggested_action: str | None = None,
    llm_model: str | None = None,
    llm_tokens_used: int | None = None,
) -> ProposalReviewFinding:
    """Insert a finding row for one rule evaluated in a review run.

    The required identifiers and verdict are stored together with the
    optional result details and LLM metadata. The row is added and flushed,
    not committed.
    """
    # Required identifiers and outcome.
    required = {
        "proposal_id": proposal_id,
        "rule_id": rule_id,
        "review_run_id": review_run_id,
        "verdict": verdict,
    }
    # Optional details; each stays None when not supplied.
    optional = {
        "confidence": confidence,
        "evidence": evidence,
        "explanation": explanation,
        "suggested_action": suggested_action,
        "llm_model": llm_model,
        "llm_tokens_used": llm_tokens_used,
    }
    finding = ProposalReviewFinding(**required, **optional)
    db_session.add(finding)
    db_session.flush()
    logger.info(
        f"Created finding {finding.id} verdict={verdict} for proposal {proposal_id}"
    )
    return finding
|
||||
|
||||
|
||||
def get_finding(
    finding_id: UUID,
    db_session: Session,
) -> ProposalReviewFinding | None:
    """Fetch one finding by ID, loading its ``rule`` relationship eagerly.

    Returns ``None`` when no finding has that ID.
    """
    eager_rule = selectinload(ProposalReviewFinding.rule)
    by_id = db_session.query(ProposalReviewFinding).filter(
        ProposalReviewFinding.id == finding_id
    )
    return by_id.options(eager_rule).one_or_none()
|
||||
|
||||
|
||||
def list_findings_by_proposal(
    proposal_id: UUID,
    db_session: Session,
    review_run_id: UUID | None = None,
) -> list[ProposalReviewFinding]:
    """Return a proposal's findings ordered by creation time.

    When ``review_run_id`` is given, only findings from that run are
    returned. Each finding's ``rule`` relationship is loaded eagerly.
    """
    # Collect the WHERE conditions first, then build the query in one go.
    conditions = [ProposalReviewFinding.proposal_id == proposal_id]
    if review_run_id:
        conditions.append(ProposalReviewFinding.review_run_id == review_run_id)

    return (
        db_session.query(ProposalReviewFinding)
        .filter(*conditions)
        .options(selectinload(ProposalReviewFinding.rule))
        .order_by(ProposalReviewFinding.created_at)
        .all()
    )
|
||||
|
||||
|
||||
def list_findings_by_run(
    review_run_id: UUID,
    db_session: Session,
) -> list[ProposalReviewFinding]:
    """Return every finding of one review run, oldest first.

    Each finding's ``rule`` relationship is loaded eagerly.
    """
    eager_rule = selectinload(ProposalReviewFinding.rule)
    for_run = db_session.query(ProposalReviewFinding).filter(
        ProposalReviewFinding.review_run_id == review_run_id
    )
    ordered = for_run.options(eager_rule).order_by(ProposalReviewFinding.created_at)
    return ordered.all()
|
||||
|
||||
|
||||
def get_failed_findings_for_run(
    review_run_id: UUID,
    db_session: Session,
) -> list[ProposalReviewFinding]:
    """Return the findings of a run that represent system errors.

    Error findings (LLM timeout, etc.) are created by _save_error_finding
    and are recognised here by a NEEDS_REVIEW verdict combined with missing
    LLM metadata, since the call never completed successfully.
    """
    belongs_to_run = ProposalReviewFinding.review_run_id == review_run_id
    needs_review = ProposalReviewFinding.verdict == "NEEDS_REVIEW"
    # Both LLM metadata columns must be NULL.
    no_model = ProposalReviewFinding.llm_model.is_(None)
    no_tokens = ProposalReviewFinding.llm_tokens_used.is_(None)

    error_findings = db_session.query(ProposalReviewFinding).filter(
        belongs_to_run,
        needs_review,
        no_model,
        no_tokens,
    )
    return error_findings.all()
|
||||
|
||||
|
||||
def delete_findings(
    finding_ids: list[UUID],
    db_session: Session,
) -> int:
    """Bulk-delete the findings with the given IDs.

    Returns the number of rows deleted; an empty ID list deletes nothing
    and returns 0 without querying.
    """
    if not finding_ids:
        return 0

    doomed = db_session.query(ProposalReviewFinding).filter(
        ProposalReviewFinding.id.in_(finding_ids)
    )
    deleted = doomed.delete(synchronize_session="fetch")
    return deleted
|
||||
@@ -1,97 +0,0 @@
|
||||
"""DB operations for checklist import jobs."""
|
||||
|
||||
import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewImportJob
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def create_import_job(
    ruleset_id: UUID,
    tenant_id: str,
    source_filename: str,
    extracted_text: str,
    db_session: Session,
) -> ProposalReviewImportJob:
    """Insert an import job row for a ruleset and return it.

    The row is added and flushed (not committed), so ``job.id`` is
    available for the log line.
    """
    column_values = {
        "ruleset_id": ruleset_id,
        "tenant_id": tenant_id,
        "source_filename": source_filename,
        "extracted_text": extracted_text,
    }
    job = ProposalReviewImportJob(**column_values)
    db_session.add(job)
    db_session.flush()
    logger.info(
        f"Created import job {job.id} for ruleset {ruleset_id} "
        f"(file: {source_filename})"
    )
    return job
|
||||
|
||||
|
||||
def get_import_job(
    job_id: UUID,
    db_session: Session,
) -> ProposalReviewImportJob | None:
    """Fetch the import job with the given ID, or ``None`` if absent."""
    by_id = db_session.query(ProposalReviewImportJob).filter(
        ProposalReviewImportJob.id == job_id
    )
    return by_id.one_or_none()
|
||||
|
||||
|
||||
def get_active_import_job(
    ruleset_id: UUID,
    db_session: Session,
) -> ProposalReviewImportJob | None:
    """Return the newest PENDING or RUNNING import job of a ruleset.

    Returns ``None`` when the ruleset has no job in either state.
    """
    same_ruleset = ProposalReviewImportJob.ruleset_id == ruleset_id
    still_active = ProposalReviewImportJob.status.in_(["PENDING", "RUNNING"])
    active_jobs = db_session.query(ProposalReviewImportJob).filter(
        same_ruleset,
        still_active,
    )
    newest_first = active_jobs.order_by(ProposalReviewImportJob.created_at.desc())
    return newest_first.first()
|
||||
|
||||
|
||||
def get_dangling_import_jobs(
    db_session: Session,
    stale_threshold_minutes: int = 30,
) -> list[ProposalReviewImportJob]:
    """Find import jobs that have sat in PENDING or RUNNING for too long.

    A job counts as dangling when it was created more than
    ``stale_threshold_minutes`` minutes before the current UTC time.
    """
    max_age = datetime.timedelta(minutes=stale_threshold_minutes)
    cutoff = datetime.datetime.now(timezone.utc) - max_age

    unfinished = ProposalReviewImportJob.status.in_(["PENDING", "RUNNING"])
    too_old = ProposalReviewImportJob.created_at < cutoff
    stuck_jobs = db_session.query(ProposalReviewImportJob).filter(
        unfinished,
        too_old,
    )
    return stuck_jobs.all()
|
||||
|
||||
|
||||
def mark_import_job_failed(
    job: ProposalReviewImportJob,
    error_message: str,
    db_session: Session,
) -> None:
    """Set an import job to FAILED and record why.

    The completion time is stamped with the current UTC time. Only a flush
    is issued; committing is left to the caller so several jobs can be
    updated within a single transaction.
    """
    finished_at = datetime.datetime.now(timezone.utc)
    job.status = "FAILED"
    job.error_message = error_message
    job.completed_at = finished_at
    db_session.flush()
|
||||
@@ -1,367 +0,0 @@
|
||||
"""SQLAlchemy models for Proposal Review."""
|
||||
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import JSONB as PGJSONB
|
||||
from sqlalchemy.dialects.postgresql import UUID as PGUUID
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from onyx.db.models import Base
|
||||
|
||||
|
||||
class ProposalReviewRuleset(Base):
    """ORM mapping for the ``proposal_review_ruleset`` table.

    A ruleset is scoped to a tenant and owns a collection of
    ``ProposalReviewRule`` rows.
    """

    __tablename__ = "proposal_review_ruleset"

    # Primary key, generated by the database.
    id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True),
        primary_key=True,
        server_default=text("gen_random_uuid()"),
    )

    # Ownership and descriptive fields.
    tenant_id: Mapped[str] = mapped_column(Text, nullable=False, index=True)
    name: Mapped[str] = mapped_column(Text, nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Flags; the database defaults are false / true respectively.
    is_default: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        server_default=text("false"),
    )
    is_active: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        server_default=text("true"),
    )

    # Audit columns.
    created_by: Mapped[UUID | None] = mapped_column(
        PGUUID(as_uuid=True),
        ForeignKey("user.id"),
        nullable=True,
    )
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
    updated_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )

    # Child rules, ordered by priority; removed together with the ruleset.
    rules: Mapped[list["ProposalReviewRule"]] = relationship(
        "ProposalReviewRule",
        back_populates="ruleset",
        cascade="all, delete-orphan",
        order_by="ProposalReviewRule.priority",
    )
|
||||
|
||||
|
||||
class ProposalReviewRule(Base):
    """ORM mapping for the ``proposal_review_rule`` table.

    Each rule belongs to exactly one ``ProposalReviewRuleset`` and is deleted
    with it (``ON DELETE CASCADE`` on ``ruleset_id``).
    """

    __tablename__ = "proposal_review_rule"

    # Primary key, generated by the database.
    id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True),
        primary_key=True,
        server_default=text("gen_random_uuid()"),
    )

    # Owning ruleset.
    ruleset_id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True),
        ForeignKey("proposal_review_ruleset.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )

    # Rule definition.
    name: Mapped[str] = mapped_column(Text, nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    category: Mapped[str | None] = mapped_column(Text, nullable=True)
    rule_type: Mapped[str] = mapped_column(Text, nullable=False)
    rule_intent: Mapped[str] = mapped_column(
        Text,
        nullable=False,
        server_default=text("'CHECK'"),
    )
    prompt_template: Mapped[str] = mapped_column(Text, nullable=False)
    source: Mapped[str] = mapped_column(
        Text,
        nullable=False,
        server_default=text("'MANUAL'"),
    )
    authority: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Flags and ordering; database defaults are shown in server_default.
    is_hard_stop: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        server_default=text("false"),
    )
    priority: Mapped[int] = mapped_column(
        Integer,
        nullable=False,
        server_default=text("0"),
    )
    is_active: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        server_default=text("true"),
    )

    # Refinement state.
    refinement_needed: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        server_default=text("false"),
    )
    refinement_question: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Timestamps.
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
    updated_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )

    # Back-reference to the owning ruleset.
    ruleset: Mapped["ProposalReviewRuleset"] = relationship(
        "ProposalReviewRuleset",
        back_populates="rules",
    )
|
||||
|
||||
|
||||
class ProposalReviewProposal(Base):
    """ORM mapping for the ``proposal_review_proposal`` table.

    Holds the review state of one document for one tenant; the pair
    (document_id, tenant_id) is unique. Proposal-level decision fields are
    stored inline on this row.
    """

    __tablename__ = "proposal_review_proposal"
    __table_args__ = (
        UniqueConstraint("document_id", "tenant_id"),
        Index("ix_proposal_review_proposal_tenant_id", "tenant_id"),
        Index("ix_proposal_review_proposal_document_id", "document_id"),
        Index("ix_proposal_review_proposal_status", "status"),
    )

    # Primary key, generated by the database.
    id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True),
        primary_key=True,
        server_default=text("gen_random_uuid()"),
    )

    # Identity and workflow state.
    document_id: Mapped[str] = mapped_column(Text, nullable=False)
    tenant_id: Mapped[str] = mapped_column(Text, nullable=False)
    status: Mapped[str] = mapped_column(
        Text,
        nullable=False,
        server_default=text("'PENDING'"),
    )

    # Inline proposal-level decision fields
    decision_notes: Mapped[str | None] = mapped_column(Text, nullable=True)
    decision_officer_id: Mapped[UUID | None] = mapped_column(
        PGUUID(as_uuid=True),
        ForeignKey("user.id"),
        nullable=True,
    )
    decision_at: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
    )
    jira_synced: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        server_default=text("false"),
    )
    jira_synced_at: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
    )

    # Timestamps.
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
    updated_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )

    # Child collections; each is removed together with the proposal.
    review_runs: Mapped[list["ProposalReviewRun"]] = relationship(
        "ProposalReviewRun",
        back_populates="proposal",
        cascade="all, delete-orphan",
    )
    findings: Mapped[list["ProposalReviewFinding"]] = relationship(
        "ProposalReviewFinding",
        back_populates="proposal",
        cascade="all, delete-orphan",
    )
    documents: Mapped[list["ProposalReviewDocument"]] = relationship(
        "ProposalReviewDocument",
        back_populates="proposal",
        cascade="all, delete-orphan",
    )
|
||||
|
||||
|
||||
class ProposalReviewRun(Base):
    """ORM mapping for the ``proposal_review_run`` table.

    One row per review run of a proposal against a ruleset, with counters
    for total, completed and failed rules.
    """

    __tablename__ = "proposal_review_run"

    # Primary key, generated by the database.
    id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True),
        primary_key=True,
        server_default=text("gen_random_uuid()"),
    )

    # References: the reviewed proposal, the ruleset used, the triggering user.
    proposal_id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True),
        ForeignKey("proposal_review_proposal.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    ruleset_id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True),
        ForeignKey("proposal_review_ruleset.id"),
        nullable=False,
    )
    triggered_by: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True),
        ForeignKey("user.id"),
        nullable=False,
    )

    # Status and counters.
    status: Mapped[str] = mapped_column(
        Text,
        nullable=False,
        server_default=text("'PENDING'"),
    )
    total_rules: Mapped[int] = mapped_column(Integer, nullable=False)
    completed_rules: Mapped[int] = mapped_column(
        Integer,
        nullable=False,
        server_default=text("0"),
    )
    failed_rules: Mapped[int] = mapped_column(
        Integer,
        nullable=False,
        server_default=text("0"),
    )

    # Timestamps; started_at / completed_at are nullable.
    started_at: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
    )
    completed_at: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
    )
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )

    # Relationships.
    proposal: Mapped["ProposalReviewProposal"] = relationship(
        "ProposalReviewProposal",
        back_populates="review_runs",
    )
    findings: Mapped[list["ProposalReviewFinding"]] = relationship(
        "ProposalReviewFinding",
        back_populates="review_run",
        cascade="all, delete-orphan",
    )
|
||||
|
||||
|
||||
class ProposalReviewFinding(Base):
    """ORM mapping for the ``proposal_review_finding`` table.

    A finding links a proposal, a rule and a review run, stores the verdict
    with optional details, and carries the per-finding decision inline.
    """

    __tablename__ = "proposal_review_finding"

    # Primary key, generated by the database.
    id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True),
        primary_key=True,
        server_default=text("gen_random_uuid()"),
    )

    # Foreign keys; all three cascade on delete and are indexed.
    proposal_id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True),
        ForeignKey("proposal_review_proposal.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    rule_id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True),
        ForeignKey("proposal_review_rule.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    review_run_id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True),
        ForeignKey("proposal_review_run.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )

    # Evaluation result and LLM metadata.
    verdict: Mapped[str] = mapped_column(Text, nullable=False)
    confidence: Mapped[str | None] = mapped_column(Text, nullable=True)
    evidence: Mapped[str | None] = mapped_column(Text, nullable=True)
    explanation: Mapped[str | None] = mapped_column(Text, nullable=True)
    suggested_action: Mapped[str | None] = mapped_column(Text, nullable=True)
    llm_model: Mapped[str | None] = mapped_column(Text, nullable=True)
    llm_tokens_used: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Inline per-finding decision fields
    decision_action: Mapped[str | None] = mapped_column(Text, nullable=True)
    decision_notes: Mapped[str | None] = mapped_column(Text, nullable=True)
    decision_officer_id: Mapped[UUID | None] = mapped_column(
        PGUUID(as_uuid=True),
        ForeignKey("user.id"),
        nullable=True,
    )
    decided_at: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
    )

    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )

    # Relationships; `rule` has no back-reference on the rule model.
    proposal: Mapped["ProposalReviewProposal"] = relationship(
        "ProposalReviewProposal",
        back_populates="findings",
    )
    review_run: Mapped["ProposalReviewRun"] = relationship(
        "ProposalReviewRun",
        back_populates="findings",
    )
    rule: Mapped["ProposalReviewRule"] = relationship("ProposalReviewRule")
|
||||
|
||||
|
||||
class ProposalReviewDocument(Base):
    """Documents attached to a proposal: manual uploads or auto-fetched FOAs.

    Maps the ``proposal_review_document`` table.
    """

    __tablename__ = "proposal_review_document"

    # Primary key, generated by the database.
    id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True),
        primary_key=True,
        server_default=text("gen_random_uuid()"),
    )

    # Owning proposal; rows are removed when the proposal is deleted.
    proposal_id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True),
        ForeignKey("proposal_review_proposal.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )

    # File details.
    file_name: Mapped[str] = mapped_column(Text, nullable=False)
    file_type: Mapped[str | None] = mapped_column(Text, nullable=True)
    file_store_id: Mapped[str | None] = mapped_column(Text, nullable=True)
    extracted_text: Mapped[str | None] = mapped_column(Text, nullable=True)
    document_role: Mapped[str] = mapped_column(Text, nullable=False)

    # Audit columns.
    uploaded_by: Mapped[UUID | None] = mapped_column(
        PGUUID(as_uuid=True),
        ForeignKey("user.id"),
        nullable=True,
    )
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )

    proposal: Mapped["ProposalReviewProposal"] = relationship(
        "ProposalReviewProposal",
        back_populates="documents",
    )
|
||||
|
||||
|
||||
class ProposalReviewImportJob(Base):
    """Background checklist import job dispatched via Celery.

    Maps the ``proposal_review_import_job`` table.
    """

    __tablename__ = "proposal_review_import_job"

    # Primary key, generated by the database.
    id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True),
        primary_key=True,
        server_default=text("gen_random_uuid()"),
    )

    # Target ruleset; jobs are removed when the ruleset is deleted.
    ruleset_id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True),
        ForeignKey("proposal_review_ruleset.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    tenant_id: Mapped[str] = mapped_column(Text, nullable=False)

    # Job state; the database default status is 'PENDING'.
    status: Mapped[str] = mapped_column(
        Text,
        nullable=False,
        server_default=text("'PENDING'"),
    )

    # Input of the import.
    source_filename: Mapped[str] = mapped_column(Text, nullable=False)
    extracted_text: Mapped[str] = mapped_column(Text, nullable=False)

    # Outcome of the import.
    rules_created: Mapped[int] = mapped_column(
        Integer,
        nullable=False,
        server_default=text("0"),
    )
    error_message: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Timestamps.
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
    completed_at: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
    )
|
||||
|
||||
|
||||
class ProposalReviewConfig(Base):
    """Admin configuration for proposal review, one row per tenant.

    Maps the ``proposal_review_config`` table; ``tenant_id`` is unique.
    """

    __tablename__ = "proposal_review_config"

    # Primary key, generated by the database.
    id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True),
        primary_key=True,
        server_default=text("gen_random_uuid()"),
    )
    tenant_id: Mapped[str] = mapped_column(Text, nullable=False, unique=True)

    # Jira integration settings; the two JSONB columns hold a list and a dict.
    jira_connector_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
    jira_project_key: Mapped[str | None] = mapped_column(Text, nullable=True)
    field_mapping: Mapped[list | None] = mapped_column(PGJSONB(), nullable=True)
    jira_writeback: Mapped[dict | None] = mapped_column(PGJSONB(), nullable=True)

    # Model names for review and import.
    review_model: Mapped[str | None] = mapped_column(Text, nullable=True)
    import_model: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Timestamps.
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
    updated_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
|
||||
@@ -1,133 +0,0 @@
|
||||
"""DB operations for proposal state records."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewProposal
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_proposal(
    proposal_id: UUID,
    tenant_id: str,
    db_session: Session,
) -> ProposalReviewProposal | None:
    """Fetch a proposal by ID within a tenant, or ``None`` if not found."""
    id_matches = ProposalReviewProposal.id == proposal_id
    tenant_matches = ProposalReviewProposal.tenant_id == tenant_id
    scoped = db_session.query(ProposalReviewProposal).filter(
        id_matches,
        tenant_matches,
    )
    return scoped.one_or_none()
|
||||
|
||||
|
||||
def get_proposal_by_document_id(
    document_id: str,
    tenant_id: str,
    db_session: Session,
) -> ProposalReviewProposal | None:
    """Fetch the proposal linked to a document within a tenant.

    Returns ``None`` when the tenant has no proposal for that document.
    """
    document_matches = ProposalReviewProposal.document_id == document_id
    tenant_matches = ProposalReviewProposal.tenant_id == tenant_id
    scoped = db_session.query(ProposalReviewProposal).filter(
        document_matches,
        tenant_matches,
    )
    return scoped.one_or_none()
|
||||
|
||||
|
||||
def get_or_create_proposal(
    document_id: str,
    tenant_id: str,
    db_session: Session,
) -> ProposalReviewProposal:
    """Get or lazily create a proposal state record for a document.

    This is the primary entry point — the proposal record is created on first
    interaction, not when the Jira ticket is ingested.

    Args:
        document_id: ID of the document the proposal is linked to.
        tenant_id: Tenant the proposal belongs to.
        db_session: Session used for the lookup and the insert.

    Returns:
        The existing proposal, or the newly created (flushed, uncommitted) one.

    Raises:
        IntegrityError: If the insert fails and no matching proposal can be
            found afterwards.
    """
    proposal = get_proposal_by_document_id(document_id, tenant_id, db_session)
    if proposal:
        return proposal

    proposal = ProposalReviewProposal(
        document_id=document_id,
        tenant_id=tenant_id,
    )
    try:
        # Insert inside a SAVEPOINT: if another transaction created the same
        # (document_id, tenant_id) row first, only this insert is rolled back.
        # A plain db_session.rollback() would also discard any unrelated work
        # the caller has pending in the same transaction.
        with db_session.begin_nested():
            db_session.add(proposal)
            db_session.flush()
    except IntegrityError:
        existing = get_proposal_by_document_id(document_id, tenant_id, db_session)
        if existing is None:
            # Not the expected uniqueness race; surface the original error.
            raise
        return existing

    logger.info(f"Lazily created proposal {proposal.id} for document {document_id}")
    return proposal
|
||||
|
||||
|
||||
def list_proposals(
    tenant_id: str,
    db_session: Session,
    status: str | None = None,
    limit: int = 100,
    offset: int = 0,
) -> list[ProposalReviewProposal]:
    """List proposals for a tenant with optional status filter.

    Args:
        tenant_id: Tenant whose proposals are listed.
        db_session: Session used to run the query.
        status: When truthy, only proposals with this status are returned.
        limit: Maximum number of rows returned.
        offset: Number of rows skipped before the first returned row.

    Returns:
        Proposals ordered by ``updated_at`` descending, then by ``id``.
    """
    query = db_session.query(ProposalReviewProposal).filter(
        ProposalReviewProposal.tenant_id == tenant_id
    )
    if status:
        query = query.filter(ProposalReviewProposal.status == status)

    # The primary key is a tie-breaker: with OFFSET/LIMIT pagination, rows
    # sharing the same updated_at could otherwise be repeated or skipped
    # between pages because their relative order is unspecified.
    query = query.order_by(
        desc(ProposalReviewProposal.updated_at),
        ProposalReviewProposal.id,
    )
    return query.offset(offset).limit(limit).all()
|
||||
|
||||
|
||||
def count_proposals(
    tenant_id: str,
    db_session: Session,
    status: str | None = None,
) -> int:
    """Return how many proposals a tenant has, optionally for one status."""
    conditions = [ProposalReviewProposal.tenant_id == tenant_id]
    if status:
        conditions.append(ProposalReviewProposal.status == status)
    return db_session.query(ProposalReviewProposal).filter(*conditions).count()
|
||||
|
||||
|
||||
def update_proposal_status(
    proposal_id: UUID,
    tenant_id: str,
    status: str,
    db_session: Session,
) -> ProposalReviewProposal | None:
    """Set a proposal's status and refresh its ``updated_at`` timestamp.

    Returns the updated proposal, or ``None`` when no proposal matches the
    given ID and tenant. The change is flushed but not committed.
    """
    scoped = db_session.query(ProposalReviewProposal).filter(
        ProposalReviewProposal.id == proposal_id,
        ProposalReviewProposal.tenant_id == tenant_id,
    )
    row = scoped.one_or_none()
    if not row:
        return None

    row.status = status
    row.updated_at = datetime.now(timezone.utc)
    db_session.flush()
    logger.info(f"Updated proposal {proposal_id} status to {status}")
    return row
|
||||
@@ -1,337 +0,0 @@
|
||||
"""DB operations for rulesets and rules."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewRule
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewRuleset
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_RULESET_UPDATABLE_FIELDS = frozenset(
|
||||
{"name", "description", "is_default", "is_active"}
|
||||
)
|
||||
_RULE_UPDATABLE_FIELDS = frozenset(
|
||||
{
|
||||
"name",
|
||||
"description",
|
||||
"category",
|
||||
"rule_type",
|
||||
"rule_intent",
|
||||
"prompt_template",
|
||||
"authority",
|
||||
"is_hard_stop",
|
||||
"priority",
|
||||
"is_active",
|
||||
"refinement_needed",
|
||||
"refinement_question",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Ruleset CRUD
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def list_rulesets(
    tenant_id: str,
    db_session: Session,
    active_only: bool = False,
) -> list[ProposalReviewRuleset]:
    """Return a tenant's rulesets, newest first, with their rules eagerly loaded.

    Args:
        tenant_id: Tenant whose rulesets are listed.
        db_session: SQLAlchemy session used to run the query.
        active_only: When True, rulesets with is_active false are left out.
    """
    criteria = [ProposalReviewRuleset.tenant_id == tenant_id]
    if active_only:
        criteria.append(ProposalReviewRuleset.is_active.is_(True))

    rulesets = (
        db_session.query(ProposalReviewRuleset)
        .filter(*criteria)
        .options(selectinload(ProposalReviewRuleset.rules))
        .order_by(desc(ProposalReviewRuleset.created_at))
    )
    return rulesets.all()
|
||||
|
||||
|
||||
def get_ruleset(
    ruleset_id: UUID,
    tenant_id: str,
    db_session: Session,
) -> ProposalReviewRuleset | None:
    """Fetch one ruleset (rules eagerly loaded) by id, scoped to a tenant.

    Returns:
        The matching ruleset, or None when the id does not exist for this tenant.
    """
    scoped = db_session.query(ProposalReviewRuleset).filter(
        ProposalReviewRuleset.id == ruleset_id,
        ProposalReviewRuleset.tenant_id == tenant_id,
    )
    with_rules = scoped.options(selectinload(ProposalReviewRuleset.rules))
    return with_rules.one_or_none()
|
||||
|
||||
|
||||
def create_ruleset(
    tenant_id: str,
    name: str,
    db_session: Session,
    description: str | None = None,
    is_default: bool = False,
    created_by: UUID | None = None,
) -> ProposalReviewRuleset:
    """Insert a new ruleset for a tenant and flush it.

    When the new ruleset is marked default, the default flag is first cleared
    on the tenant's existing rulesets.

    Returns:
        The newly added (flushed, not committed) ruleset.
    """
    if is_default:
        # Clear the flag on any ruleset currently marked default for this tenant.
        _clear_default_ruleset(tenant_id, db_session)

    attributes = dict(
        tenant_id=tenant_id,
        name=name,
        description=description,
        is_default=is_default,
        created_by=created_by,
    )
    new_ruleset = ProposalReviewRuleset(**attributes)
    db_session.add(new_ruleset)
    db_session.flush()
    logger.info(f"Created ruleset {new_ruleset.id} '{name}' for tenant {tenant_id}")
    return new_ruleset
|
||||
|
||||
|
||||
def update_ruleset(
    ruleset_id: UUID,
    tenant_id: str,
    db_session: Session,
    updates: dict[str, Any],
) -> ProposalReviewRuleset | None:
    """Apply a set of field updates to a ruleset.

    Args:
        ruleset_id: ID of the ruleset to update.
        tenant_id: Tenant the ruleset must belong to.
        db_session: SQLAlchemy session; changes are flushed, not committed.
        updates: Mapping of field name to new value. Every key must be in
            _RULESET_UPDATABLE_FIELDS.

    Returns:
        The updated ruleset, or None if no such ruleset exists for the tenant.

    Raises:
        OnyxError: INVALID_INPUT if ``updates`` names a non-updatable field.
            Nothing is modified in that case.
    """
    ruleset = get_ruleset(ruleset_id, tenant_id, db_session)
    if not ruleset:
        return None

    # Validate every key before touching the row, so a bad key cannot leave
    # the ruleset half-updated or clear the tenant's default as a side effect.
    for field in updates:
        if field not in _RULESET_UPDATABLE_FIELDS:
            raise OnyxError(
                OnyxErrorCode.INVALID_INPUT, f"Cannot update field: {field}"
            )

    for field, value in updates.items():
        if field == "is_default" and value:
            # Clear the flag on the tenant's current default(s) before this
            # ruleset is marked default.
            _clear_default_ruleset(tenant_id, db_session)
        setattr(ruleset, field, value)

    ruleset.updated_at = datetime.now(timezone.utc)
    db_session.flush()
    return ruleset
|
||||
|
||||
|
||||
def delete_ruleset(
    ruleset_id: UUID,
    tenant_id: str,
    db_session: Session,
) -> bool:
    """Remove a tenant's ruleset.

    Returns:
        True when the ruleset was found and deleted (flushed), False when no
        such ruleset exists for the tenant.
    """
    existing = get_ruleset(ruleset_id, tenant_id, db_session)
    if existing:
        db_session.delete(existing)
        db_session.flush()
        logger.info(f"Deleted ruleset {ruleset_id}")
        return True
    return False
|
||||
|
||||
|
||||
def _clear_default_ruleset(tenant_id: str, db_session: Session) -> None:
    """Reset is_default to False on every default ruleset of the tenant."""
    current_defaults = db_session.query(ProposalReviewRuleset).filter(
        ProposalReviewRuleset.tenant_id == tenant_id,
        ProposalReviewRuleset.is_default.is_(True),
    )
    # Bulk UPDATE issued directly against the matching rows.
    current_defaults.update({ProposalReviewRuleset.is_default: False})
    db_session.flush()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Rule CRUD
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def list_rules_by_ruleset(
    ruleset_id: UUID,
    db_session: Session,
    active_only: bool = False,
) -> list[ProposalReviewRule]:
    """Return the rules of a ruleset ordered by ascending priority.

    Args:
        ruleset_id: Ruleset whose rules are listed.
        db_session: SQLAlchemy session used to run the query.
        active_only: When True, rules with is_active false are left out.
    """
    criteria = [ProposalReviewRule.ruleset_id == ruleset_id]
    if active_only:
        criteria.append(ProposalReviewRule.is_active.is_(True))

    rules = (
        db_session.query(ProposalReviewRule)
        .filter(*criteria)
        .order_by(ProposalReviewRule.priority)
    )
    return rules.all()
|
||||
|
||||
|
||||
def get_rule(
    rule_id: UUID,
    db_session: Session,
) -> ProposalReviewRule | None:
    """Fetch a rule by primary id, or None when it does not exist.

    No tenant filter is applied here; get_rule_with_tenant_check() is the
    tenant-scoped lookup.
    """
    by_id = db_session.query(ProposalReviewRule).filter(
        ProposalReviewRule.id == rule_id
    )
    return by_id.one_or_none()
|
||||
|
||||
|
||||
def get_rule_with_tenant_check(
    rule_id: UUID,
    tenant_id: str,
    db_session: Session,
) -> ProposalReviewRule | None:
    """Fetch a rule by id only if its ruleset belongs to the given tenant.

    The rule is joined to its ruleset so the ownership filter is part of the
    same query as the id lookup, rather than a separate get_rule +
    get_ruleset round trip.

    Returns:
        The rule, or None when it does not exist or belongs to another tenant.
    """
    rule_with_ruleset = db_session.query(ProposalReviewRule).join(
        ProposalReviewRuleset,
        ProposalReviewRule.ruleset_id == ProposalReviewRuleset.id,
    )
    owned = rule_with_ruleset.filter(
        ProposalReviewRule.id == rule_id,
        ProposalReviewRuleset.tenant_id == tenant_id,
    )
    return owned.one_or_none()
|
||||
|
||||
|
||||
def create_rule(
    ruleset_id: UUID,
    name: str,
    rule_type: str,
    prompt_template: str,
    db_session: Session,
    description: str | None = None,
    category: str | None = None,
    rule_intent: str = "CHECK",
    source: str = "MANUAL",
    authority: str | None = None,
    is_hard_stop: bool = False,
    priority: int = 0,
    refinement_needed: bool = False,
    refinement_question: str | None = None,
) -> ProposalReviewRule:
    """Insert a new rule into a ruleset and flush it.

    Every argument except ``db_session`` is stored on the new row under the
    attribute of the same name.

    Returns:
        The newly added (flushed, not committed) rule.
    """
    attributes = dict(
        ruleset_id=ruleset_id,
        name=name,
        description=description,
        category=category,
        rule_type=rule_type,
        rule_intent=rule_intent,
        prompt_template=prompt_template,
        source=source,
        authority=authority,
        is_hard_stop=is_hard_stop,
        priority=priority,
        refinement_needed=refinement_needed,
        refinement_question=refinement_question,
    )
    new_rule = ProposalReviewRule(**attributes)
    db_session.add(new_rule)
    db_session.flush()
    logger.info(f"Created rule {new_rule.id} '{name}' in ruleset {ruleset_id}")
    return new_rule
|
||||
|
||||
|
||||
def update_rule(
    rule_id: UUID,
    db_session: Session,
    updates: dict[str, Any],
) -> ProposalReviewRule | None:
    """Apply a set of field updates to a rule.

    Args:
        rule_id: ID of the rule to update.
        db_session: SQLAlchemy session; changes are flushed, not committed.
        updates: Mapping of field name to new value. Every key must be in
            _RULE_UPDATABLE_FIELDS.

    Returns:
        The updated rule, or None if the rule does not exist.

    Raises:
        OnyxError: INVALID_INPUT if ``updates`` names a non-updatable field.
            Nothing is modified in that case.
    """
    rule = get_rule(rule_id, db_session)
    if not rule:
        return None

    # Validate every key before touching the row, so a bad key cannot leave
    # the rule half-updated in the session.
    for field in updates:
        if field not in _RULE_UPDATABLE_FIELDS:
            raise OnyxError(
                OnyxErrorCode.INVALID_INPUT, f"Cannot update field: {field}"
            )

    for field, value in updates.items():
        setattr(rule, field, value)

    rule.updated_at = datetime.now(timezone.utc)
    db_session.flush()
    return rule
|
||||
|
||||
|
||||
def delete_rule(
    rule_id: UUID,
    db_session: Session,
) -> bool:
    """Remove a rule by id.

    Returns:
        True when the rule was found and deleted (flushed), False when no
        rule has that id.
    """
    existing = get_rule(rule_id, db_session)
    if existing:
        db_session.delete(existing)
        db_session.flush()
        logger.info(f"Deleted rule {rule_id}")
        return True
    return False
|
||||
|
||||
|
||||
def bulk_update_rules(
    rule_ids: list[UUID],
    action: str,
    ruleset_id: UUID,
    db_session: Session,
) -> int:
    """Activate, deactivate, or delete several rules of one ruleset at once.

    Args:
        rule_ids: IDs of the rules to act on.
        action: One of "activate", "deactivate", or "delete".
        ruleset_id: Only rules belonging to this ruleset are touched, even if
            rule_ids names rules from elsewhere.
        db_session: SQLAlchemy session; changes are flushed, not committed.

    Returns:
        Number of rows the bulk statement affected.

    Raises:
        OnyxError: INVALID_INPUT for any other action string.
    """
    targets = db_session.query(ProposalReviewRule).filter(
        ProposalReviewRule.id.in_(rule_ids),
        ProposalReviewRule.ruleset_id == ruleset_id,
    )

    if action not in ("delete", "activate", "deactivate"):
        raise OnyxError(OnyxErrorCode.INVALID_INPUT, f"Unknown bulk action: {action}")

    if action == "delete":
        count = targets.delete(synchronize_session="fetch")
    else:
        new_values = {
            ProposalReviewRule.is_active: action == "activate",
            ProposalReviewRule.updated_at: datetime.now(timezone.utc),
        }
        count = targets.update(new_values, synchronize_session="fetch")

    db_session.flush()
    logger.info(f"Bulk {action} on {count} rules")
    return count
|
||||
|
||||
|
||||
def count_active_rules(
    ruleset_id: UUID,
    db_session: Session,
) -> int:
    """Return the number of rules in a ruleset whose is_active flag is true."""
    active_rules = db_session.query(ProposalReviewRule).filter(
        ProposalReviewRule.ruleset_id == ruleset_id,
        ProposalReviewRule.is_active.is_(True),
    )
    return active_rules.count()
|
||||
@@ -1 +0,0 @@
|
||||
"""Proposal Review Engine — AI-powered proposal evaluation."""
|
||||
@@ -1,497 +0,0 @@
|
||||
"""Parses uploaded checklist documents into atomic review rules via LLM.
|
||||
|
||||
Uses a two-pass approach to handle checklists of any size without hitting
|
||||
output token limits:
|
||||
|
||||
Pass 1 — Enumerate: Identify all distinct checklist items from the
|
||||
document (names, categories, sub-checks). This produces a small,
|
||||
bounded output regardless of document size.
|
||||
|
||||
Pass 2 — Decompose: For each identified item, make a focused LLM call
|
||||
to generate atomic review rules with full prompt templates.
|
||||
Each call produces 1–5 rules, well within token limits.
|
||||
|
||||
Callers orchestrate persistence — this module is pure LLM + parsing, no
|
||||
DB access, no callbacks, no threads.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
|
||||
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import SystemMessage
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.utils import get_llm_max_output_tokens
|
||||
from onyx.llm.utils import get_model_map
|
||||
from onyx.llm.utils import llm_response_to_string
|
||||
from onyx.tracing.llm_utils import llm_generation_span
|
||||
from onyx.tracing.llm_utils import record_llm_response
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data structures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class ChecklistItem:
    """A single checklist item identified during pass 1.

    Instances are built by enumerate_checklist_items() from the LLM's JSON
    output and consumed by decompose_checklist_item(), which formats these
    fields into the pass-2 prompt.
    """

    # Short identifier for the item, e.g. "IR-1".
    id: str
    # The item's title or heading.
    name: str
    # Display label for the item; rules produced from it inherit this value
    # when the LLM does not supply a category of its own.
    category: str
    # Summary of what the item covers.
    description: str
    # Individual checks or requirements listed under the item.
    sub_checks: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompts — Pass 1 (Enumerate)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_ENUMERATE_SYSTEM = """\
|
||||
You are an expert at analyzing institutional review checklists for university \
|
||||
grant offices. Your task is to read a checklist document and identify every \
|
||||
distinct checklist item or section that requires review."""
|
||||
|
||||
_ENUMERATE_USER = """\
|
||||
Read the checklist document below and list every distinct checklist item.
|
||||
|
||||
CHECKLIST DOCUMENT:
|
||||
---
|
||||
{checklist_text}
|
||||
---
|
||||
|
||||
For each item, provide:
|
||||
- **id**: A short identifier derived from the document (e.g., "IR-1", \
|
||||
"KR-3", "Section-A.2"). Invent one if the document doesn't assign one.
|
||||
- **name**: The item's title or heading.
|
||||
- **category**: A display label combining the id and name \
|
||||
(e.g., "IR-2: Regulatory Compliance").
|
||||
- **description**: One sentence summarizing what this item covers.
|
||||
- **sub_checks**: A list of the individual checks or requirements \
|
||||
mentioned under this item. Be thorough — include every distinct \
|
||||
requirement even if the document groups them together.
|
||||
|
||||
Respond with ONLY a valid JSON array:
|
||||
[
|
||||
{{
|
||||
"id": "IR-1",
|
||||
"name": "Institutional and PI Eligibility",
|
||||
"category": "IR-1: Institutional and PI Eligibility Requirements",
|
||||
"description": "Verify institution and PI meet sponsor eligibility.",
|
||||
"sub_checks": ["Institutional eligibility", "PI eligibility", ...]
|
||||
}},
|
||||
...
|
||||
]"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompts — Pass 2 (Decompose one item into rules)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DECOMPOSE_SYSTEM = """\
|
||||
You are an expert at creating AI review rules for university grant proposal \
|
||||
review. Each rule you create will be independently evaluated by an LLM \
|
||||
against a grant proposal. Rules must be atomic (one criterion each) and \
|
||||
self-contained (the prompt template includes all context needed).
|
||||
|
||||
Variable placeholders available for prompt templates:
|
||||
{{{{proposal_text}}}} — full proposal and supporting documents
|
||||
{{{{budget_text}}}} — budget / financial sections
|
||||
{{{{foa_text}}}} — funding opportunity announcement
|
||||
{{{{metadata}}}} — structured metadata (PI, sponsor, etc.)
|
||||
{{{{metadata.FIELD_NAME}}}} — a specific metadata field
|
||||
|
||||
Rule types:
|
||||
DOCUMENT_CHECK — verify presence / content in documents
|
||||
METADATA_CHECK — validate a structured metadata field
|
||||
CROSS_REFERENCE — compare information across documents
|
||||
CUSTOM_NL — natural language evaluation
|
||||
|
||||
Rule intents:
|
||||
CHECK — pass / fail criterion
|
||||
HIGHLIGHT — informational flag (no pass / fail)
|
||||
|
||||
If a rule requires institution-specific info NOT present in the checklist \
|
||||
(IDC rates, mandatory cost categories, local policies, etc.), set \
|
||||
refinement_needed=true and include a refinement_question."""
|
||||
|
||||
_DECOMPOSE_USER = """\
|
||||
Create atomic review rules for the checklist item described below.
|
||||
|
||||
ITEM TO DECOMPOSE:
|
||||
ID: {item_id}
|
||||
Name: {item_name}
|
||||
Category: {item_category}
|
||||
Description: {item_description}
|
||||
Sub-checks: {sub_checks}
|
||||
|
||||
FULL CHECKLIST (for context — only create rules for the item above):
|
||||
---
|
||||
{checklist_text}
|
||||
---
|
||||
|
||||
Generate one rule per sub-check. Each rule object must have:
|
||||
{{
|
||||
"name": "Short descriptive name (max 100 chars)",
|
||||
"description": "What this rule checks",
|
||||
"category": "{item_category}",
|
||||
"rule_type": "DOCUMENT_CHECK | METADATA_CHECK | CROSS_REFERENCE | CUSTOM_NL",
|
||||
"rule_intent": "CHECK | HIGHLIGHT",
|
||||
"prompt_template": "Self-contained prompt with {{{{variable}}}} placeholders.",
|
||||
"refinement_needed": false,
|
||||
"refinement_question": null
|
||||
}}
|
||||
|
||||
Respond with ONLY a valid JSON array of rule objects."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def enumerate_checklist_items(
    checklist_text: str,
    llm: LLM,
) -> list[ChecklistItem]:
    """Pass 1: Identify all distinct checklist items from the document.

    Makes one LLM call and converts each JSON object in the response into a
    ChecklistItem. Entries that are not objects, or that lack a name, are
    skipped with a warning.

    Args:
        checklist_text: Full text extracted from the uploaded checklist file.
        llm: The LLM instance to use.

    Returns:
        Checklist items in the order the LLM listed them.

    Raises:
        RuntimeError: If the LLM call fails or returns unparseable output.
    """
    user_content = _ENUMERATE_USER.format(checklist_text=checklist_text)
    messages = [
        SystemMessage(content=_ENUMERATE_SYSTEM),
        UserMessage(content=user_content),
    ]

    max_output_tokens = _get_max_output_tokens(llm)

    try:
        with llm_generation_span(llm, "checklist_enumerate", messages) as gen_span:
            response = llm.invoke(
                messages, timeout_override=300, max_tokens=max_output_tokens
            )
            record_llm_response(gen_span, response)
            raw_text = llm_response_to_string(response)
    except Exception as e:
        logger.error(f"Pass 1 (enumerate) LLM call failed: {e}")
        raise RuntimeError(f"Failed to enumerate checklist items: {str(e)}") from e

    parsed = _parse_json_array(raw_text, context="enumerate")

    items: list[ChecklistItem] = []
    for i, raw in enumerate(parsed):
        if not isinstance(raw, dict):
            logger.warning(f"Enumerate: skipping non-dict at index {i}")
            continue

        name = raw.get("name")
        if not name:
            logger.warning(f"Enumerate: skipping item at index {i} (no name)")
            continue

        # dict.get's default only covers a *missing* key. The LLM may emit an
        # explicit JSON null, which must fall back the same way instead of
        # becoming the literal string "None".
        raw_id = raw.get("id")
        item_id = str(raw_id) if raw_id not in (None, "") else f"ITEM-{i + 1}"

        # A null sub_checks would raise on iteration and a bare string would
        # be split into single characters; normalise both to a list.
        raw_sub_checks = raw.get("sub_checks") or []
        if isinstance(raw_sub_checks, str):
            raw_sub_checks = [raw_sub_checks]
        elif not isinstance(raw_sub_checks, list):
            raw_sub_checks = []

        items.append(
            ChecklistItem(
                id=item_id,
                name=str(name),
                category=str(raw.get("category") or f"{item_id}: {name}"),
                description=str(raw.get("description") or ""),
                sub_checks=[str(s) for s in raw_sub_checks if s is not None],
            )
        )

    return items
|
||||
|
||||
|
||||
def decompose_checklist_item(
    item: ChecklistItem,
    checklist_text: str,
    llm: LLM,
) -> list[dict]:
    """Pass 2: Turn one checklist item into atomic review rules.

    Makes a single LLM call for the given item. Response entries that are not
    JSON objects, or that fail validation, are dropped. A rule that comes back
    without a category inherits the item's category.

    Args:
        item: The checklist item to decompose.
        checklist_text: Full checklist text, included in the prompt as context.
        llm: The LLM instance to use.

    Returns:
        Validated rule dicts for this item.

    Raises:
        RuntimeError: If the LLM call fails or returns unparseable output.
    """
    bullet_lines = "\n".join(f"  - {s}" for s in item.sub_checks)
    sub_checks_str = bullet_lines or "  (none listed)"

    prompt = _DECOMPOSE_USER.format(
        item_id=item.id,
        item_name=item.name,
        item_category=item.category,
        item_description=item.description,
        sub_checks=sub_checks_str,
        checklist_text=checklist_text,
    )
    messages = [
        SystemMessage(content=_DECOMPOSE_SYSTEM),
        UserMessage(content=prompt),
    ]

    token_budget = _get_max_output_tokens(llm)
    span_name = f"checklist_decompose_{item.id}"

    try:
        with llm_generation_span(llm, span_name, messages) as gen_span:
            response = llm.invoke(
                messages, timeout_override=300, max_tokens=token_budget
            )
            record_llm_response(gen_span, response)
            raw_text = llm_response_to_string(response)
    except Exception as e:
        raise RuntimeError(f"LLM call failed for item '{item.name}': {str(e)}") from e

    candidates = _parse_json_array(raw_text, context=f"decompose[{item.id}]")

    rules: list[dict] = []
    for position, candidate in enumerate(candidates):
        if not isinstance(candidate, dict):
            continue
        validated = _validate_rule(candidate, position)
        if not validated:
            continue
        if not validated["category"]:
            # The LLM left the category blank; fall back to the item's label.
            validated["category"] = item.category
        rules.append(validated)

    return rules
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompts — Refinement (single rule)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_REFINE_SYSTEM = """\
|
||||
You are an expert at creating AI review rules for university grant proposal \
|
||||
review. You are refining a rule that was previously flagged as needing \
|
||||
institution-specific information. The user has now provided that information.
|
||||
|
||||
Variable placeholders available for prompt templates:
|
||||
{{{{proposal_text}}}} — full proposal and supporting documents
|
||||
{{{{budget_text}}}} — budget / financial sections
|
||||
{{{{foa_text}}}} — funding opportunity announcement
|
||||
{{{{metadata}}}} — structured metadata (PI, sponsor, etc.)
|
||||
{{{{metadata.FIELD_NAME}}}} — a specific metadata field
|
||||
|
||||
Rule types:
|
||||
DOCUMENT_CHECK — verify presence / content in documents
|
||||
METADATA_CHECK — validate a structured metadata field
|
||||
CROSS_REFERENCE — compare information across documents
|
||||
CUSTOM_NL — natural language evaluation
|
||||
|
||||
Rule intents:
|
||||
CHECK — pass / fail criterion
|
||||
HIGHLIGHT — informational flag (no pass / fail)"""
|
||||
|
||||
_REFINE_USER = """\
|
||||
The following rule was imported from a checklist but flagged as needing \
|
||||
additional information before it can be used.
|
||||
|
||||
CURRENT RULE:
|
||||
Name: {rule_name}
|
||||
Description: {rule_description}
|
||||
Prompt Template: {rule_prompt_template}
|
||||
|
||||
QUESTION THAT WAS ASKED:
|
||||
{refinement_question}
|
||||
|
||||
USER'S ANSWER:
|
||||
{user_answer}
|
||||
|
||||
Using the user's answer, produce a refined version of this rule. \
|
||||
Incorporate the institution-specific information into the prompt_template \
|
||||
so the rule is fully self-contained and no longer needs refinement.
|
||||
|
||||
Respond with ONLY a single JSON object (not an array):
|
||||
{{
|
||||
"name": "Short descriptive name (max 100 chars)",
|
||||
"description": "What this rule checks",
|
||||
"rule_type": "DOCUMENT_CHECK | METADATA_CHECK | CROSS_REFERENCE | CUSTOM_NL",
|
||||
"rule_intent": "CHECK | HIGHLIGHT",
|
||||
"prompt_template": "Refined self-contained prompt with {{{{variable}}}} placeholders.",
|
||||
"refinement_needed": false,
|
||||
"refinement_question": null
|
||||
}}"""
|
||||
|
||||
|
||||
def refine_rule(
    rule_name: str,
    rule_description: str | None,
    rule_prompt_template: str,
    refinement_question: str,
    user_answer: str,
    llm: LLM,
) -> dict:
    """Rewrite one rule so it incorporates the user's refinement answer.

    Makes a single LLM call. The response is expected to be one JSON object;
    if an array comes back instead, its first element is used.

    Args:
        rule_name: Current name of the rule.
        rule_description: Current description; "(none)" is sent when falsy.
        rule_prompt_template: Current prompt template of the rule.
        refinement_question: The question that was posed to the user.
        user_answer: The user's answer to that question.
        llm: The LLM instance to use.

    Returns:
        A validated rule dict whose refinement_needed is False and whose
        refinement_question is None.

    Raises:
        RuntimeError: If the LLM call fails or its output is not a usable rule.
    """
    prompt = _REFINE_USER.format(
        rule_name=rule_name,
        rule_description=rule_description or "(none)",
        rule_prompt_template=rule_prompt_template,
        refinement_question=refinement_question,
        user_answer=user_answer,
    )
    messages = [
        SystemMessage(content=_REFINE_SYSTEM),
        UserMessage(content=prompt),
    ]

    try:
        with llm_generation_span(llm, "checklist_refine_rule", messages) as gen_span:
            response = llm.invoke(messages, timeout_override=120)
            record_llm_response(gen_span, response)
            raw_text = llm_response_to_string(response)
    except Exception as e:
        raise RuntimeError(f"LLM call failed during rule refinement: {str(e)}") from e

    # Drop a surrounding Markdown code fence, if any, before JSON parsing.
    payload = raw_text.strip()
    if payload.startswith("```"):
        payload = re.sub(r"^```(?:json)?\s*\n?", "", payload)
        payload = re.sub(r"\n?```\s*$", "", payload)
        payload = payload.strip()

    try:
        candidate = json.loads(payload)
    except json.JSONDecodeError as e:
        raise RuntimeError(f"LLM returned invalid JSON during refinement: {e}") from e

    if isinstance(candidate, list):
        # Tolerate an array response by taking its first element.
        if not candidate:
            raise RuntimeError("LLM returned an empty array during refinement")
        candidate = candidate[0]

    if not isinstance(candidate, dict):
        raise RuntimeError("LLM returned non-object JSON during refinement")

    refined = _validate_rule(candidate, 0)
    if not refined:
        raise RuntimeError("LLM returned an invalid rule during refinement")

    # Whatever the LLM answered, the refined rule is treated as final.
    refined.update(refinement_needed=False, refinement_question=None)
    return refined
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_max_output_tokens(llm: LLM) -> int:
    """Resolve the max output token count for the LLM's configured model.

    Falls back to GEN_AI_MODEL_FALLBACK_MAX_TOKENS (logging a warning) when
    the lookup raises for any reason.
    """
    try:
        limit = get_llm_max_output_tokens(
            model_map=get_model_map(),
            model_name=llm.config.model_name,
            model_provider=llm.config.model_provider,
        )
    except Exception as exc:
        logger.warning(f"Failed to resolve max output tokens: {exc}")
        limit = int(GEN_AI_MODEL_FALLBACK_MAX_TOKENS)
    return limit
|
||||
|
||||
|
||||
def _parse_json_array(raw_text: str, context: str) -> list:
|
||||
"""Parse an LLM response as a JSON array, stripping code fences."""
|
||||
text = raw_text.strip()
|
||||
|
||||
if text.startswith("```"):
|
||||
text = re.sub(r"^```(?:json)?\s*\n?", "", text)
|
||||
text = re.sub(r"\n?```\s*$", "", text)
|
||||
text = text.strip()
|
||||
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"[{context}] Failed to parse JSON: {e}")
|
||||
logger.debug(f"[{context}] Raw LLM response: {text[:500]}...")
|
||||
raise RuntimeError(
|
||||
f"LLM returned invalid JSON during {context}. "
|
||||
"Please try the import again."
|
||||
) from e
|
||||
|
||||
if not isinstance(parsed, list):
|
||||
raise RuntimeError(
|
||||
f"LLM returned non-array JSON during {context}. " "Expected a list."
|
||||
)
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
def _validate_rule(raw_rule: dict, index: int) -> dict | None:
|
||||
"""Validate and normalize a single parsed rule dict."""
|
||||
valid_types = {
|
||||
"DOCUMENT_CHECK",
|
||||
"METADATA_CHECK",
|
||||
"CROSS_REFERENCE",
|
||||
"CUSTOM_NL",
|
||||
}
|
||||
valid_intents = {"CHECK", "HIGHLIGHT"}
|
||||
|
||||
name = raw_rule.get("name")
|
||||
if not name:
|
||||
logger.warning(f"Rule at index {index} missing 'name', skipping")
|
||||
return None
|
||||
|
||||
prompt_template = raw_rule.get("prompt_template")
|
||||
if not prompt_template:
|
||||
logger.warning(f"Rule '{name}' missing 'prompt_template', skipping")
|
||||
return None
|
||||
|
||||
rule_type = str(raw_rule.get("rule_type", "CUSTOM_NL")).upper()
|
||||
if rule_type not in valid_types:
|
||||
rule_type = "CUSTOM_NL"
|
||||
|
||||
rule_intent = str(raw_rule.get("rule_intent", "CHECK")).upper()
|
||||
if rule_intent not in valid_intents:
|
||||
rule_intent = "CHECK"
|
||||
|
||||
return {
|
||||
"name": str(name)[:200],
|
||||
"description": raw_rule.get("description"),
|
||||
"category": raw_rule.get("category"),
|
||||
"rule_type": rule_type,
|
||||
"rule_intent": rule_intent,
|
||||
"prompt_template": str(prompt_template),
|
||||
"refinement_needed": bool(raw_rule.get("refinement_needed", False)),
|
||||
"refinement_question": (
|
||||
str(raw_rule["refinement_question"])
|
||||
if raw_rule.get("refinement_question")
|
||||
else None
|
||||
),
|
||||
}
|
||||
@@ -1,340 +0,0 @@
|
||||
"""Assembles all available text content for a proposal to pass to rule evaluation.
|
||||
|
||||
V1 LIMITATION: Document body text (the main text content extracted by connectors)
|
||||
is stored in Vespa, not in the PostgreSQL Document table. The DB row only stores
|
||||
metadata (semantic_id, link, doc_metadata, primary_owners, etc.). For Jira tickets,
|
||||
the Description and Comments text are indexed into Vespa during connector runs and
|
||||
are NOT accessible here without a Vespa query.
|
||||
|
||||
As a result, the primary source of rich text for rule evaluation in V1 is:
|
||||
- Manually uploaded documents (proposal_review_document.extracted_text)
|
||||
- Structured metadata from the Document row's doc_metadata JSONB column
|
||||
- For Jira tickets: the connector populates doc_metadata with field values,
|
||||
which often includes Description, Status, Priority, Assignee, etc.
|
||||
|
||||
Future improvement: add a Vespa retrieval step to fetch indexed text chunks for
|
||||
the parent document and its attachments.
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import Document
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewDocument
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewProposal
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# Metadata keys from the Jira connector that commonly carry useful text content.
# These are extracted from doc_metadata and presented as labeled sections to
# give the LLM more signal when evaluating rules.
# NOTE(review): the consumer of this list is outside this excerpt, so whether
# list order determines section order is unconfirmed.
_JIRA_TEXT_METADATA_KEYS = [
    "description",
    "summary",
    "comment",
    "comments",
    "acceptance_criteria",
    "story_points",
    "priority",
    "status",
    "resolution",
    "issue_type",
    "labels",
    "components",
    "fix_versions",
    "affects_versions",
    "environment",
    "assignee",
    "reporter",
    "creator",
]
|
||||
|
||||
|
||||
@dataclass
class ProposalContext:
    """All text and metadata context assembled for rule evaluation.

    Returned by get_proposal_context(). When the proposal cannot be found,
    every text field is an empty string and both metadata dicts are empty.
    """

    proposal_text: str  # concatenated text from all documents
    budget_text: str  # best-effort budget section extraction
    foa_text: str  # FOA content (auto-fetched or uploaded)
    metadata: dict  # structured metadata from Document.doc_metadata
    jira_key: str  # for display/reference
    metadata_raw: dict = field(default_factory=dict)  # full unresolved metadata
|
||||
|
||||
|
||||
def get_proposal_context(
    proposal_id: UUID,
    db_session: Session,
) -> ProposalContext:
    """Assemble the text and metadata context used for rule evaluation.

    Gathers text from three sources:
    1. Jira ticket content (from Document.semantic_id + doc_metadata)
    2. Jira attachments (child Documents linked by ID prefix convention)
    3. Manually uploaded documents (from proposal_review_document.extracted_text)

    For MVP, returns full text of everything. Future: smart section selection.

    Args:
        proposal_id: ID of the proposal whose context should be built.
        db_session: Session used for every lookup.

    Returns:
        A ProposalContext. If the proposal row does not exist, all fields are
        empty. ``metadata`` and ``metadata_raw`` refer to the same dict.
    """
    # 1. Get the proposal record to find the linked document_id
    proposal = (
        db_session.query(ProposalReviewProposal)
        .filter(ProposalReviewProposal.id == proposal_id)
        .one_or_none()
    )
    if not proposal:
        logger.warning(f"Proposal {proposal_id} not found during context assembly")
        return ProposalContext(
            proposal_text="",
            budget_text="",
            foa_text="",
            metadata={},
            jira_key="",
            metadata_raw={},
        )

    # 2. Fetch the parent Document (Jira ticket)
    parent_doc = (
        db_session.query(Document)
        .filter(Document.id == proposal.document_id)
        .one_or_none()
    )

    jira_key = ""
    metadata: dict = {}
    all_text_parts: list[str] = []
    budget_parts: list[str] = []
    foa_parts: list[str] = []

    if parent_doc:
        jira_key = parent_doc.semantic_id or ""
        metadata = parent_doc.doc_metadata or {}

        # Build text from DB-available fields. The actual ticket body text lives
        # in Vespa and is not accessible here. The doc_metadata JSONB column
        # often contains structured Jira fields that the connector extracted.
        parent_text = _build_parent_document_text(parent_doc)
        if parent_text:
            all_text_parts.append(parent_text)

        # 3. Look for child Documents (Jira attachments).
        # Jira attachment Documents have IDs of the form:
        #   "{parent_jira_url}/attachments/{attachment_id}"
        # We find them via ID prefix match.
        #
        # V1 LIMITATION: child document text content is in Vespa, not in the
        # DB. We can only extract metadata (filename, mime type, etc.) from
        # the Document row. The actual attachment text is not available here
        # without a Vespa query. See module docstring for details.
        child_docs = _find_child_documents(parent_doc, db_session)
        if child_docs:
            logger.info(
                f"Found {len(child_docs)} child documents for {jira_key}. "
                f"Note: their text content is in Vespa and only metadata is "
                f"available for rule evaluation."
            )
        for child_doc in child_docs:
            child_text = _build_child_document_text(child_doc)
            if child_text:
                all_text_parts.append(child_text)
                _classify_child_text(child_doc, child_text, budget_parts, foa_parts)
    else:
        logger.warning(
            f"Parent Document not found for proposal {proposal_id} "
            f"(document_id={proposal.document_id}). "
            f"Context will rely on manually uploaded documents only."
        )

    # 4. Fetch manually uploaded documents from proposal_review_document.
    # This is the PRIMARY source of rich text content for V1 since the
    # extracted_text column holds the full document content.
    manual_docs = (
        db_session.query(ProposalReviewDocument)
        .filter(ProposalReviewDocument.proposal_id == proposal_id)
        .order_by(ProposalReviewDocument.created_at)
        .all()
    )
    for doc in manual_docs:
        if not doc.extracted_text:
            continue

        all_text_parts.append(
            f"--- Document: {doc.file_name} (role: {doc.document_role}) ---\n"
            f"{doc.extracted_text}"
        )

        # Classify by role. An explicit FOA role wins over the filename
        # heuristic; otherwise an FOA whose file name happens to contain a
        # budget keyword (e.g. "cost") would be filed as budget text and be
        # missing from foa_text.
        role_upper = (doc.document_role or "").upper()
        if role_upper == "FOA":
            foa_parts.append(doc.extracted_text)
        elif role_upper == "BUDGET" or _is_budget_filename(doc.file_name):
            budget_parts.append(doc.extracted_text)

    # str.join on an empty list already yields "", so no extra guard is needed.
    return ProposalContext(
        proposal_text="\n\n".join(all_text_parts),
        budget_text="\n\n".join(budget_parts),
        foa_text="\n\n".join(foa_parts),
        metadata=metadata,
        jira_key=jira_key,
        metadata_raw=metadata,
    )
|
||||
|
||||
|
||||
def _build_parent_document_text(doc: Document) -> str:
    """Build text representation from a parent Document row (Jira ticket).

    The Document table does NOT store the ticket body text -- that lives in Vespa.
    What we DO have access to:
    - semantic_id: typically "{ISSUE_KEY}: {summary}"
    - link: URL to the Jira ticket
    - doc_metadata: JSONB with structured fields from the connector (may include
      description, status, priority, assignee, custom fields, etc.)
    - primary_owners / secondary_owners: people associated with the document

    We extract all available metadata and present it as labeled sections to
    maximize the signal available to the LLM for rule evaluation.

    Well-known keys are matched case-insensitively. Every non-empty metadata
    entry ends up either in a labeled section or under "Additional Metadata".

    Args:
        doc: The parent Document row.

    Returns:
        Newline-joined labeled lines, or "" when nothing is available.
    """
    parts: list[str] = []

    if doc.semantic_id:
        parts.append(f"Document: {doc.semantic_id}")
    if doc.link:
        parts.append(f"Link: {doc.link}")

    # Include owner information which may be useful for compliance checks
    if doc.primary_owners:
        parts.append(f"Primary Owners: {', '.join(doc.primary_owners)}")
    if doc.secondary_owners:
        parts.append(f"Secondary Owners: {', '.join(doc.secondary_owners)}")

    # doc_metadata contains structured data from the Jira connector.
    # Extract well-known text-bearing fields first, then include the rest.
    if doc.doc_metadata:
        metadata = doc.doc_metadata

        def _has_content(value: object) -> bool:
            return value is not None and value != "" and value != []

        # Index the real keys by their lowercase form so the well-known
        # lookup and the "remaining" filter below agree on what counts as a
        # match. Previously the lookup was exact-case while the filter was
        # case-insensitive, so a key like "Description" was dropped entirely.
        keys_by_lower: dict[str, list[str]] = {}
        for actual_key in metadata:
            keys_by_lower.setdefault(actual_key.lower(), []).append(actual_key)

        consumed: set[str] = set()

        # Extract well-known Jira fields as labeled sections
        for key in _JIRA_TEXT_METADATA_KEYS:
            label = key.replace("_", " ").title()
            for actual_key in keys_by_lower.get(key, []):
                value = metadata[actual_key]
                if not _has_content(value):
                    continue
                consumed.add(actual_key)
                if isinstance(value, list):
                    parts.append(f"{label}: {', '.join(str(v) for v in value)}")
                elif isinstance(value, dict):
                    parts.append(
                        f"{label}:\n{json.dumps(value, indent=2, default=str)}"
                    )
                else:
                    parts.append(f"{label}: {value}")

        # Include any remaining metadata keys that were not emitted above,
        # so custom fields and connector-specific data are not lost.
        remaining = {
            k: v for k, v in metadata.items() if k not in consumed and _has_content(v)
        }
        if remaining:
            parts.append(
                f"Additional Metadata:\n"
                f"{json.dumps(remaining, indent=2, default=str)}"
            )

    return "\n".join(parts)
|
||||
|
||||
|
||||
def _build_child_document_text(doc: Document) -> str:
    """Render a child Document row (Jira attachment) as labeled text lines.

    V1 LIMITATION: the attachment's extracted text lives in Vespa, not in the
    Document table, so only the row's semantic_id, link and doc_metadata
    entries are presented. The LLM therefore learns that the attachment
    exists but cannot read its contents; a trailing note says so explicitly.

    Returns:
        Newline-joined lines ending with the limitation note, or "" when the
        row has no semantic_id, link or non-empty metadata values.
    """
    header_lines: list[str] = []
    if doc.semantic_id:
        header_lines.append(f"Attachment: {doc.semantic_id}")
    if doc.link:
        header_lines.append(f"Link: {doc.link}")

    # Child document metadata typically includes:
    #   parent_ticket, attachment_filename, attachment_mime_type, attachment_size
    metadata_lines = [
        f"{name.replace('_', ' ').title()}: {val}"
        for name, val in (doc.doc_metadata or {}).items()
        if val is not None and val != ""
    ]

    lines = header_lines + metadata_lines
    if not lines:
        return ""

    # Spell the limitation out inline so the LLM sees it in its context.
    lines.append(
        "[Note: Full attachment text is indexed in Vespa and not available "
        "in this context. Upload the document manually for full text analysis.]"
    )
    return "\n".join(lines)
|
||||
|
||||
|
||||
def _find_child_documents(parent_doc: Document, db_session: Session) -> list[Document]:
    """Find child Documents linked to the parent (e.g. Jira attachments).

    Jira attachments are indexed as separate Document rows whose ID follows
    the convention: "{parent_document_id}/attachments/{attachment_id}".
    The parent_document_id for Jira is the full URL to the issue, e.g.
    "https://jira.example.com/browse/PROJ-123".

    V1 LIMITATION: These child Document rows only contain metadata in the DB.
    Their actual extracted text content is stored in Vespa. To read the
    attachment text, a Vespa query would be required. This is not implemented
    in V1 -- officers should upload key documents manually for full text
    analysis.

    Args:
        parent_doc: The parent Document row.
        db_session: Session used for the lookup.

    Returns:
        Documents whose ID starts with "{parent_doc.id}/", ordered by ID so
        the assembled context is deterministic. Empty when the parent has no ID.
    """
    if not parent_doc.id:
        return []

    # Child documents have IDs that start with the parent document's ID
    # followed by a path segment (e.g., /attachments/12345).
    # autoescape=True makes SQLAlchemy escape "%", "_" and the escape
    # character itself, and emit a matching ESCAPE clause. The previous
    # hand-rolled escaping did not escape backslashes and depended on the
    # dialect's default escape character.
    child_prefix = f"{parent_doc.id}/"
    return (
        db_session.query(Document)
        .filter(
            Document.id.startswith(child_prefix, autoescape=True),
            Document.id != parent_doc.id,
        )
        .order_by(Document.id)
        .all()
    )
|
||||
|
||||
|
||||
def _classify_child_text(
    doc: Document,
    text: str,
    budget_parts: list[str],
    foa_parts: list[str],
) -> None:
    """Append *text* to the budget or FOA bucket based on the doc's semantic_id.

    Best-effort substring matching on the lowercased semantic_id. Budget
    keywords are checked first; text matching neither bucket is left out of
    both. The two lists are mutated in place.
    """
    name = (doc.semantic_id or "").lower()

    if _is_budget_filename(name):
        budget_parts.append(text)
        return

    foa_markers = ["foa", "funding opportunity", "rfa", "solicitation", "nofo"]
    for marker in foa_markers:
        if marker in name:
            foa_parts.append(text)
            return
|
||||
|
||||
|
||||
def _is_budget_filename(filename: str) -> bool:
|
||||
"""Check if a filename suggests budget content."""
|
||||
lower = (filename or "").lower()
|
||||
return any(term in lower for term in ["budget", "cost", "financial", "expenditure"])
|
||||
@@ -1,168 +0,0 @@
|
||||
"""Auto-fetches Funding Opportunity Announcements using Onyx web search infrastructure."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewDocument
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Map known opportunity ID prefixes to federal agency domains
|
||||
_AGENCY_DOMAINS: dict[str, str] = {
|
||||
"RFA": "grants.nih.gov",
|
||||
"PA": "grants.nih.gov",
|
||||
"PAR": "grants.nih.gov",
|
||||
"R01": "grants.nih.gov",
|
||||
"R21": "grants.nih.gov",
|
||||
"U01": "grants.nih.gov",
|
||||
"NOT": "grants.nih.gov",
|
||||
"NSF": "nsf.gov",
|
||||
"DE-FOA": "energy.gov",
|
||||
"HRSA": "hrsa.gov",
|
||||
"W911": "grants.gov", # DoD
|
||||
"FA": "grants.gov", # Air Force
|
||||
"N00": "grants.gov", # Navy
|
||||
"NOFO": "grants.gov",
|
||||
}
|
||||
|
||||
|
||||
def fetch_foa(
    opportunity_id: str,
    proposal_id: UUID,
    db_session: Session,
) -> str | None:
    """Fetch FOA content given an opportunity ID.

    1. Determine domain from ID prefix (RFA/PA -> nih.gov, NSF -> nsf.gov, etc.)
    2. Build search query
    3. Call Onyx web search provider
    4. Fetch full content from best URL
    5. Save as proposal_review_document with role=FOA
    6. Return extracted text or None

    If the web search provider is not configured, logs a warning and returns None.

    Args:
        opportunity_id: Opportunity identifier; blank values skip the fetch.
        proposal_id: Proposal the FOA document is attached to.
        db_session: Session used to look up and store the FOA document. The
            new row is flushed inside a SAVEPOINT and not committed here.

    Returns:
        The FOA text (from an existing document or freshly fetched), or None
        when nothing could be fetched. If persisting the document fails, the
        fetched text is still returned.
    """
    if not opportunity_id or not opportunity_id.strip():
        logger.debug("No opportunity_id provided, skipping FOA fetch")
        return None

    opportunity_id = opportunity_id.strip()

    # Check if we already have an FOA document with text for this proposal.
    # Filtering on non-empty text in SQL keeps an empty FOA row from masking
    # a populated one, which would trigger a redundant fetch and a duplicate.
    existing_foa = (
        db_session.query(ProposalReviewDocument)
        .filter(
            ProposalReviewDocument.proposal_id == proposal_id,
            ProposalReviewDocument.document_role == "FOA",
            ProposalReviewDocument.extracted_text.isnot(None),
            ProposalReviewDocument.extracted_text != "",
        )
        .first()
    )
    if existing_foa and existing_foa.extracted_text:
        logger.info(
            f"FOA document already exists for proposal {proposal_id}, skipping fetch"
        )
        return existing_foa.extracted_text

    # Determine search domain from opportunity ID prefix
    site_domain = _determine_domain(opportunity_id)

    # Build search query
    search_query = f"{opportunity_id} funding opportunity announcement"
    if site_domain:
        search_query = f"site:{site_domain} {opportunity_id}"

    # Try to get the web search provider
    try:
        from onyx.tools.tool_implementations.web_search.providers import (
            get_default_provider,
        )

        provider = get_default_provider()
    except Exception as e:
        logger.warning(f"Failed to load web search provider: {e}")
        provider = None

    if provider is None:
        logger.warning(
            "No web search provider configured. Cannot auto-fetch FOA. "
            "Configure a web search provider in Admin settings to enable this feature."
        )
        return None

    # Search for the FOA
    try:
        results = provider.search(search_query)
    except Exception as e:
        logger.error(f"Web search failed for FOA '{opportunity_id}': {e}")
        return None

    if not results:
        logger.info(f"No search results found for FOA '{opportunity_id}'")
        return None

    # Pick the best result URL
    best_url = str(results[0].link)
    logger.info(f"Fetching FOA content from: {best_url}")

    # Fetch full content from the URL
    try:
        from onyx.tools.tool_implementations.open_url.onyx_web_crawler import (
            OnyxWebCrawler,
        )

        crawler = OnyxWebCrawler()
        contents = crawler.contents([best_url])

        if (
            not contents
            or not contents[0].scrape_successful
            or not contents[0].full_content
        ):
            logger.warning(f"No content extracted from FOA URL: {best_url}")
            return None

        foa_text = contents[0].full_content

    except Exception as e:
        logger.error(f"Failed to fetch FOA content from {best_url}: {e}")
        return None

    # Save as a proposal_review_document with role=FOA.
    # The insert runs inside a SAVEPOINT: if the flush fails, only the
    # savepoint is rolled back and the caller's session stays usable.
    # Swallowing a flush error without that would leave the session in a
    # failed state for whoever uses it next.
    try:
        with db_session.begin_nested():
            foa_doc = ProposalReviewDocument(
                proposal_id=proposal_id,
                file_name=f"FOA_{opportunity_id}.html",
                file_type="HTML",
                document_role="FOA",
                extracted_text=foa_text,
                # uploaded_by is None for auto-fetched documents
            )
            db_session.add(foa_doc)
            db_session.flush()
    except Exception as e:
        logger.error(f"Failed to save FOA document: {e}")
        # Still return the text even if save fails
        return foa_text

    logger.info(
        f"Saved FOA document for proposal {proposal_id} "
        f"(opportunity_id={opportunity_id}, {len(foa_text)} chars)"
    )
    return foa_text
|
||||
|
||||
|
||||
def _determine_domain(opportunity_id: str) -> str | None:
    """Pick the search domain for an opportunity ID, or None if unknown.

    The uppercased ID is tested against the prefixes in _AGENCY_DOMAINS in
    their declared order and the first match wins. IDs consisting only of
    digits (ignoring hyphens) fall back to grants.gov.
    """
    normalized = opportunity_id.upper()

    matched_domain = next(
        (
            domain
            for prefix, domain in _AGENCY_DOMAINS.items()
            if normalized.startswith(prefix)
        ),
        None,
    )
    if matched_domain is not None:
        return matched_domain

    # If it looks like a grants.gov number (numeric), try grants.gov
    digits_only = opportunity_id.replace("-", "")
    return "grants.gov" if digits_only.isdigit() else None
|
||||
@@ -1,391 +0,0 @@
|
||||
"""Writes officer decisions back to Jira."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.connector_credential_pair import (
|
||||
fetch_connector_credential_pair_for_connector,
|
||||
)
|
||||
from onyx.db.models import Document
|
||||
from onyx.server.features.proposal_review.db import config as config_db
|
||||
from onyx.server.features.proposal_review.db import decisions as decisions_db
|
||||
from onyx.server.features.proposal_review.db import findings as findings_db
|
||||
from onyx.server.features.proposal_review.db import proposals as proposals_db
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewFinding
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewProposal
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def sync_to_jira(
    proposal_id: UUID,
    db_session: Session,
) -> None:
    """Write the officer's final decision back to Jira.

    Performs up to 3 Jira API operations:
    1. PUT custom fields (decision, completion %)
    2. POST transition (move to configured column)
    3. POST comment (structured review summary)

    Then marks the proposal as jira_synced. Returns without calling Jira when
    the proposal is already marked as synced.

    NOTE(review): the three operations are not atomic. If a later one fails,
    the earlier ones have already been applied and a retry repeats them.

    Args:
        proposal_id: Proposal whose decision should be written back.
        db_session: Session used for lookups and for the synced flag (flushed,
            not committed, here).

    Raises:
        ValueError: If required config/data is missing, including when no
            issue key can be extracted from the linked document.
        RuntimeError: If Jira API calls fail.
    """
    tenant_id = get_current_tenant_id()

    # Load proposal
    proposal = proposals_db.get_proposal(proposal_id, tenant_id, db_session)
    if not proposal:
        raise ValueError(f"Proposal {proposal_id} not found")

    if not proposal.decision_at:
        raise ValueError(f"No decision found for proposal {proposal_id}")

    if proposal.jira_synced:
        logger.info(f"Decision for proposal {proposal_id} already synced to Jira")
        return

    # Load tenant config for Jira settings
    config = config_db.get_config(tenant_id, db_session)
    if not config:
        raise ValueError("Proposal review config not found for this tenant")

    if not config.jira_connector_id:
        raise ValueError(
            "No Jira connector configured. Set jira_connector_id in proposal review settings."
        )

    writeback_config = config.jira_writeback or {}

    # Get the Jira issue key from the linked Document
    parent_doc = (
        db_session.query(Document)
        .filter(Document.id == proposal.document_id)
        .one_or_none()
    )
    if not parent_doc:
        raise ValueError(f"Linked document {proposal.document_id} not found")

    # semantic_id is formatted as "KEY-123: Summary text" by the Jira connector.
    # Extract just the issue key (everything before the first colon).
    raw_id = parent_doc.semantic_id
    if not raw_id:
        raise ValueError(
            f"Document {proposal.document_id} has no semantic_id (Jira issue key)"
        )
    issue_key = raw_id.split(":")[0].strip()
    if not issue_key:
        # e.g. semantic_id ": summary" or whitespace only. An empty key would
        # otherwise be interpolated into the REST path as ".../issue//...".
        raise ValueError(
            f"Could not extract a Jira issue key from semantic_id {raw_id!r} "
            f"of document {proposal.document_id}"
        )

    # Get Jira credentials from the connector
    jira_base_url, auth_headers = _get_jira_credentials(
        config.jira_connector_id, db_session
    )

    # Get findings for the summary
    latest_run = findings_db.get_latest_review_run(proposal_id, db_session)
    all_findings: list[ProposalReviewFinding] = []
    if latest_run:
        all_findings = findings_db.list_findings_by_run(latest_run.id, db_session)

    # Calculate summary counts
    verdict_counts = _count_verdicts(all_findings)

    # Operation 1: Update custom fields
    _update_custom_fields(
        jira_base_url=jira_base_url,
        auth_headers=auth_headers,
        issue_key=issue_key,
        decision=proposal.status,
        verdict_counts=verdict_counts,
        writeback_config=writeback_config,
    )

    # Operation 2: Transition the issue
    _transition_issue(
        jira_base_url=jira_base_url,
        auth_headers=auth_headers,
        issue_key=issue_key,
        decision=proposal.status,
        writeback_config=writeback_config,
    )

    # Operation 3: Post review summary comment
    _post_comment(
        jira_base_url=jira_base_url,
        auth_headers=auth_headers,
        issue_key=issue_key,
        proposal=proposal,
        verdict_counts=verdict_counts,
        findings=all_findings,
    )

    # Mark as synced
    decisions_db.mark_proposal_jira_synced(proposal_id, tenant_id, db_session)
    db_session.flush()

    logger.info(
        f"Successfully synced decision for proposal {proposal_id} to Jira issue {issue_key}"
    )
|
||||
|
||||
|
||||
def _get_jira_credentials(
    connector_id: int,
    db_session: Session,
) -> tuple[str, dict[str, str]]:
    """Extract Jira base URL and auth headers from the connector's credentials.

    Args:
        connector_id: ID of the Jira connector to read from.
        db_session: Session used for the connector and credential lookups.

    Returns:
        Tuple of (jira_base_url, auth_headers_dict). The base URL has no
        trailing slash, so callers can append "/rest/..." paths directly.

    Raises:
        ValueError: If the connector, its credential pair, its credentials,
            the API token or the base URL is missing.
    """
    connector = fetch_connector_by_id(connector_id, db_session)
    if not connector:
        raise ValueError(f"Jira connector {connector_id} not found")

    # Get the connector's credential pair
    cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
    if not cc_pair:
        raise ValueError(f"No credential pair found for connector {connector_id}")

    # Extract credentials — guard against missing credential_json
    cred_json = cc_pair.credential.credential_json
    if cred_json is None:
        raise ValueError(f"No credential_json for connector {connector_id}")
    credentials = cred_json.get_value(apply_mask=False)
    if not credentials:
        raise ValueError(f"Empty credentials for connector {connector_id}")

    # Extract Jira base URL from connector config. Strip trailing slashes so
    # that f"{jira_base_url}/rest/api/..." never contains "//rest".
    connector_config = connector.connector_specific_config or {}
    jira_base_url = (connector_config.get("jira_base_url") or "").rstrip("/")

    if not jira_base_url:
        raise ValueError("Could not determine Jira base URL from connector config")

    # Build auth headers. Fail early on a missing token instead of sending a
    # request with an empty credential.
    api_token = credentials.get("jira_api_token")
    if not api_token:
        raise ValueError(f"No jira_api_token in credentials for connector {connector_id}")
    email = credentials.get("jira_user_email")

    if email:
        # Cloud auth: Basic auth with email:token
        import base64

        auth_string = base64.b64encode(f"{email}:{api_token}".encode()).decode()
        authorization = f"Basic {auth_string}"
    else:
        # Server auth: Bearer token
        authorization = f"Bearer {api_token}"

    auth_headers = {
        "Authorization": authorization,
        "Content-Type": "application/json",
    }
    return jira_base_url, auth_headers
|
||||
|
||||
|
||||
def _count_verdicts(findings: list[ProposalReviewFinding]) -> dict[str, int]:
    """Tally findings per uppercased verdict.

    The five standard verdicts are always present (zero when unused). A
    finding without a verdict counts as NEEDS_REVIEW, and any other verdict
    value gets its own key.
    """
    tally: dict[str, int] = dict.fromkeys(
        ("PASS", "FAIL", "FLAG", "NEEDS_REVIEW", "NOT_APPLICABLE"), 0
    )
    for finding in findings:
        key = (finding.verdict or "NEEDS_REVIEW").upper()
        tally[key] = tally.get(key, 0) + 1
    return tally
|
||||
|
||||
|
||||
def _update_custom_fields(
    jira_base_url: str,
    auth_headers: dict[str, str],
    issue_key: str,
    decision: str,
    verdict_counts: dict[str, int],
    writeback_config: dict,
) -> None:
    """PUT the decision and completion percentage onto the Jira issue.

    The field IDs come from ``writeback_config`` ("decision_field_id" and
    "completion_field_id"); with neither configured the call is a no-op.
    Completion is the share of findings not in NEEDS_REVIEW, rounded to one
    decimal place.

    Raises:
        RuntimeError: If the Jira request fails.
    """
    decision_field = writeback_config.get("decision_field_id")
    completion_field = writeback_config.get("completion_field_id")

    if not (decision_field or completion_field):
        logger.debug("No custom field IDs configured for Jira writeback, skipping")
        return

    updates: dict = {}
    if decision_field:
        updates[decision_field] = decision
    if completion_field:
        rule_total = sum(verdict_counts.values())
        if rule_total > 0:
            reviewed = rule_total - verdict_counts.get("NEEDS_REVIEW", 0)
            percent = reviewed / rule_total * 100
        else:
            percent = 0
        updates[completion_field] = round(percent, 1)

    endpoint = f"{jira_base_url}/rest/api/3/issue/{issue_key}"

    try:
        response = requests.put(
            endpoint, headers=auth_headers, json={"fields": updates}, timeout=30
        )
        response.raise_for_status()
    except requests.RequestException as e:
        logger.error(f"Failed to update custom fields on {issue_key}: {e}")
        raise RuntimeError(f"Jira field update failed: {e}") from e
    else:
        logger.info(f"Updated custom fields on {issue_key}")
|
||||
|
||||
|
||||
def _transition_issue(
    jira_base_url: str,
    auth_headers: dict[str, str],
    issue_key: str,
    decision: str,
    writeback_config: dict,
) -> None:
    """Move the Jira issue through the transition configured for *decision*.

    The transition name is read from ``writeback_config["transitions"]``. It
    is resolved to an ID by listing the issue's available transitions and
    matching on name, ignoring case. With no transition configured the call
    is a no-op; if the configured name is not available, a warning is logged
    and nothing is posted.

    Raises:
        RuntimeError: If listing or performing the transition fails.
    """
    wanted_name = writeback_config.get("transitions", {}).get(decision)
    if not wanted_name:
        logger.debug(f"No transition configured for decision '{decision}', skipping")
        return

    # First, get available transitions
    transitions_url = f"{jira_base_url}/rest/api/3/issue/{issue_key}/transitions"
    try:
        listing = requests.get(transitions_url, headers=auth_headers, timeout=30)
        listing.raise_for_status()
        candidates = listing.json().get("transitions", [])
    except requests.RequestException as e:
        logger.error(f"Failed to fetch transitions for {issue_key}: {e}")
        raise RuntimeError(f"Jira transition fetch failed: {e}") from e

    # Find the matching transition by name (case-insensitive)
    match = next(
        (
            candidate
            for candidate in candidates
            if candidate.get("name", "").lower() == wanted_name.lower()
        ),
        None,
    )

    if not match:
        available_names = [candidate.get("name", "") for candidate in candidates]
        logger.warning(
            f"Transition '{wanted_name}' not found for {issue_key}. "
            f"Available: {available_names}"
        )
        return

    # Perform the transition
    body = {"transition": {"id": match["id"]}}
    try:
        outcome = requests.post(
            transitions_url, headers=auth_headers, json=body, timeout=30
        )
        outcome.raise_for_status()
        logger.info(f"Transitioned {issue_key} to '{wanted_name}'")
    except requests.RequestException as e:
        logger.error(f"Failed to transition {issue_key}: {e}")
        raise RuntimeError(f"Jira transition failed: {e}") from e
|
||||
|
||||
|
||||
def _post_comment(
    jira_base_url: str,
    auth_headers: dict[str, str],
    issue_key: str,
    proposal: ProposalReviewProposal,
    verdict_counts: dict[str, int],
    findings: list[ProposalReviewFinding],
) -> None:
    """POST a structured review summary as a Jira comment.

    The summary text is multi-line. It is converted to ADF inline nodes with
    explicit ``hardBreak`` nodes between lines, because ADF expresses line
    breaks with ``hardBreak`` rather than newline characters in a text node.

    Args:
        jira_base_url: Jira base URL without a trailing slash.
        auth_headers: Headers carrying the Authorization value.
        issue_key: Key of the issue to comment on.
        proposal: Proposal whose decision is summarised.
        verdict_counts: Finding counts per verdict.
        findings: Findings listed in the detailed section.

    Raises:
        RuntimeError: If the Jira request fails.
    """
    comment_text = _build_comment_text(proposal, verdict_counts, findings)

    url = f"{jira_base_url}/rest/api/3/issue/{issue_key}/comment"
    # Jira Cloud uses ADF (Atlassian Document Format) for comments
    payload = {
        "body": {
            "version": 1,
            "type": "doc",
            "content": [
                {
                    "type": "paragraph",
                    "content": _adf_inline_nodes_from_text(comment_text),
                }
            ],
        }
    }

    try:
        resp = requests.post(url, headers=auth_headers, json=payload, timeout=30)
        resp.raise_for_status()
        logger.info(f"Posted review summary comment on {issue_key}")
    except requests.RequestException as e:
        logger.error(f"Failed to post comment on {issue_key}: {e}")
        raise RuntimeError(f"Jira comment post failed: {e}") from e


def _adf_inline_nodes_from_text(text: str) -> list[dict]:
    """Convert newline-separated text into ADF text and hardBreak nodes.

    Blank lines produce consecutive hardBreak nodes and no text node, since
    ADF text nodes must carry non-empty text.
    """
    nodes: list[dict] = []
    for index, line in enumerate(text.split("\n")):
        if index:
            nodes.append({"type": "hardBreak"})
        if line:
            nodes.append({"type": "text", "text": line})
    return nodes
|
||||
|
||||
|
||||
def _build_comment_text(
    proposal: ProposalReviewProposal,
    verdict_counts: dict[str, int],
    findings: list[ProposalReviewFinding],
) -> str:
    """Compose the plain-text review summary posted to Jira.

    Sections, in order: a header, the final decision (with notes if any),
    per-verdict counts, one entry per finding (explanations longer than 200
    characters are cut and suffixed with "..."), and a footer stamped with
    the current UTC time.

    Returns:
        The summary as a single newline-joined string.
    """
    summary: list[str] = ["=== Proposal Review Summary ===", ""]

    # Decision
    summary.append(f"Final Decision: {proposal.status or 'N/A'}")
    notes = proposal.decision_notes
    if notes:
        summary.append(f"Notes: {notes}")
    summary.append("")

    # Summary counts
    rule_total = sum(verdict_counts.values())
    summary.append(f"Review Results ({rule_total} rules evaluated):")
    count_rows = (
        ("Pass", "PASS"),
        ("Fail", "FAIL"),
        ("Flag", "FLAG"),
        ("Needs Review", "NEEDS_REVIEW"),
        ("Not Applicable", "NOT_APPLICABLE"),
    )
    for label, verdict_key in count_rows:
        summary.append(f" {label}: {verdict_counts.get(verdict_key, 0)}")
    summary.append("")

    # Individual findings (truncated for readability)
    if findings:
        summary.append("--- Detailed Findings ---")
        for finding in findings:
            rule_name = finding.rule.name if finding.rule else "Unknown Rule"
            verdict = finding.verdict or "N/A"
            if finding.decision_action:
                officer_action = f" | Officer: {finding.decision_action}"
            else:
                officer_action = ""
            summary.append(f" [{verdict}] {rule_name}{officer_action}")

            full_reason = finding.explanation
            if full_reason:
                # Truncate long explanations
                if len(full_reason) > 200:
                    reason = full_reason[:200] + "..."
                else:
                    reason = full_reason
                summary.append(f" Reason: {reason}")

    summary.extend(
        [
            "",
            f"Reviewed at: {datetime.now(timezone.utc).isoformat()}",
            "Generated by Onyx Proposal Review",
        ]
    )

    return "\n".join(summary)
|
||||
@@ -1,532 +0,0 @@
|
||||
"""Proposal review engine — private helpers for task implementations.
|
||||
|
||||
The actual Celery @shared_task definitions live in tasks.py (for autodiscovery).
|
||||
This module contains the orchestration and evaluation logic they delegate to.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import update
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.server.features.proposal_review.engine.context_assembler import (
|
||||
ProposalContext,
|
||||
)
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _execute_review(
    review_run_id: str,
    rule_ids: list[str] | None = None,
) -> None:
    """Core review logic, kept out of the Celery task for testability.

    When ``rule_ids`` is None, evaluates all active rules in the run's ruleset
    (full run). When ``rule_ids`` is provided, deletes the old error findings
    for those rules and re-evaluates only them (retry flow).

    Args:
        review_run_id: String form of the review run's UUID.
        rule_ids: Rule UUID strings to re-evaluate, or None for a full run.
            An empty list is a no-op: the run is flipped back from RUNNING
            to COMPLETED and nothing is evaluated.

    Raises:
        ValueError: If ``review_run_id`` is not a valid UUID string or no
            review run with that id exists.
    """
    from onyx.db.engine.sql_engine import get_session_with_current_tenant
    from onyx.server.features.proposal_review.db import findings as findings_db
    from onyx.server.features.proposal_review.db import rulesets as rulesets_db
    from onyx.server.features.proposal_review.db.models import ProposalReviewRun
    from onyx.server.features.proposal_review.engine.context_assembler import (
        get_proposal_context,
    )
    from onyx.server.features.proposal_review.engine.foa_fetcher import fetch_foa

    run_uuid = UUID(review_run_id)
    is_retry = rule_ids is not None

    if is_retry and not rule_ids:
        logger.warning(f"Retry called with empty rule_ids for run {review_run_id}")
        # Reset status since the API already set it to RUNNING
        with get_session_with_current_tenant() as db_session:
            run = findings_db.get_review_run(run_uuid, db_session)
            if run and run.status == "RUNNING":
                run.status = "COMPLETED"
                db_session.commit()
        return

    # Step 1: Set run status to RUNNING; for retries, clean up old findings
    with get_session_with_current_tenant() as db_session:
        run = findings_db.get_review_run(run_uuid, db_session)
        if not run:
            raise ValueError(f"Review run {review_run_id} not found")

        proposal_id = run.proposal_id
        ruleset_id = run.ruleset_id

        if is_retry:
            rule_id_set = set(rule_ids)
            # Delete old error findings for the rules being retried
            failed = findings_db.get_failed_findings_for_run(run_uuid, db_session)
            failed_for_rules = [f for f in failed if str(f.rule_id) in rule_id_set]
            if failed_for_rules:
                findings_db.delete_findings(
                    [f.id for f in failed_for_rules], db_session
                )
            # Roll back counters so re-evaluated rules are tracked correctly.
            # completed_rules is rolled back by the number of rules being
            # re-evaluated (not findings — a rule may lack a finding if
            # _save_error_finding itself failed). failed_rules is rolled back
            # by the number of error findings actually deleted.
            run.completed_rules = max(0, run.completed_rules - len(rule_ids))
            run.failed_rules = max(0, run.failed_rules - len(failed_for_rules))
            run.completed_at = None

        run.status = "RUNNING"
        if not is_retry:
            run.started_at = datetime.now(timezone.utc)
        db_session.commit()

    # Step 2: Assemble proposal context
    with get_session_with_current_tenant() as db_session:
        context = get_proposal_context(proposal_id, db_session)

    # Step 3: Try to auto-fetch FOA if opportunity_id is in metadata
    opportunity_id = context.metadata.get("opportunity_id") or context.metadata.get(
        "funding_opportunity_number"
    )
    if opportunity_id and not context.foa_text:
        logger.info(f"Attempting to auto-fetch FOA for opportunity_id={opportunity_id}")
        try:
            with get_session_with_current_tenant() as db_session:
                foa_text = fetch_foa(opportunity_id, proposal_id, db_session)
                db_session.commit()
            if foa_text:
                context.foa_text = foa_text
                logger.info(f"Auto-fetched FOA: {len(foa_text)} chars")
        except Exception as e:
            # The review can still proceed without FOA text.
            logger.warning(f"FOA auto-fetch failed (non-fatal): {e}")

    # Step 4: Determine which rules to evaluate
    if is_retry:
        # Retry: use the specific rule IDs passed in
        rules_to_eval = rule_ids
    else:
        # Full run: get all active rules for the ruleset
        with get_session_with_current_tenant() as db_session:
            rules = rulesets_db.list_rules_by_ruleset(
                ruleset_id, db_session, active_only=True
            )
            rules_to_eval = [str(rule.id) for rule in rules]

    if not rules_to_eval:
        logger.warning(f"No active rules found for ruleset {ruleset_id}")
        with get_session_with_current_tenant() as db_session:
            run = findings_db.get_review_run(run_uuid, db_session)
            if run:
                run.status = "COMPLETED"
                run.completed_at = datetime.now(timezone.utc)
                db_session.commit()
        return

    # Step 5: Update total_rules on the run (full run only). A retry covers
    # only a subset of the rules; overwriting total_rules with the subset's
    # size would put it out of step with the cumulative completed_rules
    # counter that Step 1 rolled back rather than reset.
    if not is_retry:
        with get_session_with_current_tenant() as db_session:
            run = findings_db.get_review_run(run_uuid, db_session)
            if run:
                run.total_rules = len(rules_to_eval)
                db_session.commit()

    # Step 6: Evaluate rules in parallel via ThreadPoolExecutor
    parallel_workers = int(os.environ.get("PROPOSAL_REVIEW_PARALLEL_WORKERS", "4"))
    # ThreadPoolExecutor raises ValueError for max_workers <= 0, so a
    # misconfigured env var falls back to one worker instead of failing the run.
    workers = max(1, min(parallel_workers, len(rules_to_eval)))

    with ThreadPoolExecutor(max_workers=workers) as executor:
        # Each submission runs inside its own copy of the current context so
        # context variables set by the caller are visible in the worker thread.
        future_to_rule_id = {
            executor.submit(
                contextvars.copy_context().run,
                _evaluate_single_rule,
                review_run_id,
                rid,
                proposal_id,
                context,
            ): rid
            for rid in rules_to_eval
        }

        for future in as_completed(future_to_rule_id):
            rid = future_to_rule_id[future]
            try:
                succeeded = future.result()
            except Exception as e:
                succeeded = False
                logger.error(
                    f"Rule {rid} failed: {e}",
                    exc_info=True,
                )

            # Increment completed_rules (and failed_rules on error) atomically
            # so the frontend progress bar always reaches 100%.
            updates: dict = {
                "completed_rules": ProposalReviewRun.completed_rules + 1,
            }
            if not succeeded:
                updates["failed_rules"] = ProposalReviewRun.failed_rules + 1

            with get_session_with_current_tenant() as db_session:
                db_session.execute(
                    update(ProposalReviewRun)
                    .where(ProposalReviewRun.id == run_uuid)
                    .values(**updates)
                )
                db_session.commit()

    # Step 7: Mark run as completed
    with get_session_with_current_tenant() as db_session:
        run = findings_db.get_review_run(run_uuid, db_session)
        if run:
            run.status = "COMPLETED"
            run.completed_at = datetime.now(timezone.utc)
            db_session.commit()

    logger.info(
        f"Review run {review_run_id} completed: "
        f"{len(rules_to_eval)} rules evaluated"
        f"{' (retry)' if is_retry else ''}"
    )
|
||||
|
||||
|
||||
def _evaluate_and_save(
    review_run_id: str,
    rule_id: str,
    proposal_id: "UUID",
    context: "ProposalContext",
) -> None:
    """Load one rule, evaluate it against ``context``, and persist the finding.

    The rule lookup, the evaluation, and the finding insert all share a
    single session that is committed once at the end.

    Args:
        review_run_id: String form of the review run's UUID.
        rule_id: String form of the rule's UUID.
        proposal_id: Proposal the finding is recorded against.
        context: Proposal context handed to the rule evaluator.

    Raises:
        ValueError: If either id is not a valid UUID string or the rule does
            not exist. Errors raised by the evaluator propagate unchanged.
    """
    from onyx.db.engine.sql_engine import get_session_with_current_tenant
    from onyx.server.features.proposal_review.db import findings as findings_db
    from onyx.server.features.proposal_review.db import rulesets as rulesets_db
    from onyx.server.features.proposal_review.engine.rule_evaluator import (
        evaluate_rule,
    )

    rule_uuid = UUID(rule_id)
    run_uuid = UUID(review_run_id)

    with get_session_with_current_tenant() as session:
        # Fetch the rule definition
        rule = rulesets_db.get_rule(rule_uuid, session)
        if not rule:
            raise ValueError(f"Rule {rule_id} not found")

        # Run the evaluation
        outcome = evaluate_rule(rule, context, session)

        # "verdict" is mandatory; every other field may be absent.
        verdict = outcome["verdict"]
        optional_fields = {
            key: outcome.get(key)
            for key in (
                "confidence",
                "evidence",
                "explanation",
                "suggested_action",
                "llm_model",
                "llm_tokens_used",
            )
        }

        # Persist the finding
        findings_db.create_finding(
            proposal_id=proposal_id,
            rule_id=rule_uuid,
            review_run_id=run_uuid,
            verdict=verdict,
            db_session=session,
            **optional_fields,
        )
        session.commit()

    logger.debug(f"Rule {rule_id} evaluated: verdict={outcome['verdict']}")
|
||||
|
||||
|
||||
def _save_error_finding(
    review_run_id: str,
    rule_id: str,
    proposal_id: "UUID",
    error: str,
) -> None:
    """Record a NEEDS_REVIEW / LOW-confidence finding for a failed rule.

    Best-effort: any exception raised while building or saving the finding
    is logged and swallowed, so this function never raises.

    Args:
        review_run_id: String form of the review run's UUID.
        rule_id: String form of the failed rule's UUID.
        proposal_id: Proposal the finding is recorded against.
        error: Error text embedded in the finding's explanation.
    """
    from onyx.db.engine.sql_engine import get_session_with_current_tenant
    from onyx.server.features.proposal_review.db import findings as findings_db

    try:
        with get_session_with_current_tenant() as session:
            finding_kwargs = {
                "proposal_id": proposal_id,
                "rule_id": UUID(rule_id),
                "review_run_id": UUID(review_run_id),
                "verdict": "NEEDS_REVIEW",
                "confidence": "LOW",
                "evidence": None,
                "explanation": f"Rule evaluation failed with error: {error}",
                "suggested_action": "Manual review required due to system error.",
                "db_session": session,
            }
            findings_db.create_finding(**finding_kwargs)
            session.commit()
    except Exception as exc:
        logger.error(f"Failed to save error finding for rule {rule_id}: {exc}")
|
||||
|
||||
|
||||
def _mark_run_failed(review_run_id: str) -> None:
    """Set a review run's status to FAILED and stamp its completion time.

    Does nothing if the run cannot be found. Any exception (including a
    malformed ``review_run_id``) is logged and swallowed.
    """
    from onyx.db.engine.sql_engine import get_session_with_current_tenant
    from onyx.server.features.proposal_review.db import findings as findings_db

    try:
        with get_session_with_current_tenant() as session:
            review_run = findings_db.get_review_run(UUID(review_run_id), session)
            if not review_run:
                return
            review_run.status = "FAILED"
            review_run.completed_at = datetime.now(timezone.utc)
            session.commit()
    except Exception as exc:
        logger.error(f"Failed to mark run {review_run_id} as FAILED: {exc}")
|
||||
|
||||
|
||||
# Number of retries after the first attempt. Clamped at zero so a negative
# env value cannot skip evaluation entirely (range() would be empty and the
# rule would get an error finding without ever being tried).
_MAX_RULE_RETRIES = max(
    0, int(os.environ.get("PROPOSAL_REVIEW_RULE_MAX_RETRIES", "2"))
)
_RETRY_BACKOFF_BASE = 2  # seconds — retry waits 2s, 4s, ...


def _evaluate_single_rule(
    review_run_id: str,
    rule_id: str,
    proposal_id: "UUID",
    context: "ProposalContext",
) -> bool:
    """Evaluate one rule, save the finding. Called from ThreadPoolExecutor.

    Context is shared in-memory from the parent — no DB re-fetch needed.
    Retries up to _MAX_RULE_RETRIES times with exponential backoff on failure
    (e.g. LLM timeout). On final failure, an error finding (NEEDS_REVIEW) is
    saved so the officer sees which rule failed.

    Args:
        review_run_id: String form of the review run's UUID.
        rule_id: String form of the rule's UUID.
        proposal_id: Proposal the finding is recorded against.
        context: Proposal context passed through to the evaluation.

    Returns:
        True on success, False if all attempts failed.
    """
    last_error: Exception | None = None

    for attempt in range(_MAX_RULE_RETRIES + 1):
        try:
            _evaluate_and_save(review_run_id, rule_id, proposal_id, context)
        except Exception as e:
            last_error = e
            if attempt < _MAX_RULE_RETRIES:
                # Exponential backoff: base * 2^attempt seconds.
                wait = _RETRY_BACKOFF_BASE * (2**attempt)
                logger.warning(
                    f"Rule {rule_id} attempt {attempt + 1} failed: {e}. "
                    f"Retrying in {wait}s..."
                )
                time.sleep(wait)
            else:
                logger.error(
                    f"Rule {rule_id} failed after {attempt + 1} attempts: {e}",
                    exc_info=True,
                )
        else:
            return True

    # Every attempt failed: leave a visible finding for the reviewer.
    _save_error_finding(
        review_run_id=review_run_id,
        rule_id=rule_id,
        proposal_id=proposal_id,
        error=str(last_error),
    )
    return False
|
||||
|
||||
|
||||
def _execute_checklist_import(import_job_id: str) -> None:
    """Core import logic, separated for traceability.

    Marks the job RUNNING, asks the LLM to enumerate the checklist items in
    the job's extracted text, decomposes each item into rule definitions in
    parallel, and persists the resulting rules (created inactive, with
    source "IMPORTED") as each item completes. The job is marked COMPLETED
    at the end even when individual items failed; those are only logged.

    Args:
        import_job_id: String form of the import job's UUID.

    Raises:
        ValueError: If ``import_job_id`` is not a valid UUID string or the
            job does not exist.
    """
    from onyx.db.engine.sql_engine import get_session_with_current_tenant
    from onyx.llm.factory import get_default_llm
    from onyx.server.features.proposal_review.db import imports as imports_db
    from onyx.server.features.proposal_review.db import rulesets as rulesets_db
    from onyx.server.features.proposal_review.engine.checklist_importer import (
        decompose_checklist_item,
        enumerate_checklist_items,
    )

    job_uuid = UUID(import_job_id)
    parallel_workers = int(
        os.environ.get("PROPOSAL_REVIEW_IMPORT_PARALLEL_WORKERS", "4")
    )

    # Step 1: Mark RUNNING and load job data
    with get_session_with_current_tenant() as db_session:
        job = imports_db.get_import_job(job_uuid, db_session)
        if not job:
            raise ValueError(f"Import job {import_job_id} not found")

        job.status = "RUNNING"
        db_session.commit()

        ruleset_id = job.ruleset_id
        extracted_text = job.extracted_text

    llm = get_default_llm(timeout=300)

    # Step 2: Enumerate checklist items
    items = enumerate_checklist_items(extracted_text, llm)

    if not items:
        logger.warning(f"Import {import_job_id}: no checklist items found")
        with get_session_with_current_tenant() as db_session:
            job = imports_db.get_import_job(job_uuid, db_session)
            if job:
                job.status = "COMPLETED"
                job.completed_at = datetime.now(timezone.utc)
                db_session.commit()
        return

    # Split items with too many sub-checks into smaller pieces so each
    # LLM call produces bounded output. The threshold is conservative —
    # 3 sub-checks keeps output well within token limits.
    max_sub_checks = int(
        os.environ.get("PROPOSAL_REVIEW_IMPORT_MAX_SUB_CHECKS_PER_CALL", "3")
    )
    work_items = _split_large_items(items, max_sub_checks)

    logger.info(
        f"Import {import_job_id}: enumerated {len(items)} items "
        f"({len(work_items)} work units after splitting), "
        f"decomposing with {parallel_workers} workers"
    )

    # Step 3: Decompose each work item in parallel, persist as each completes
    rules_created = 0
    failed_items: list[str] = []
    # ThreadPoolExecutor raises ValueError for max_workers <= 0, so a
    # misconfigured env var falls back to one worker instead of failing the job.
    workers = max(1, min(parallel_workers, len(work_items)))

    with ThreadPoolExecutor(max_workers=workers) as executor:
        # Each submission runs inside its own copy of the current context so
        # context variables set by the caller are visible in the worker thread.
        future_to_item = {
            executor.submit(
                contextvars.copy_context().run,
                decompose_checklist_item,
                item,
                extracted_text,
                llm,
            ): item
            for item in work_items
        }

        for future in as_completed(future_to_item):
            item = future_to_item[future]
            try:
                rule_dicts = future.result()
            except Exception as e:
                # One failed item does not abort the import.
                logger.error(
                    f"  [{item.id}] '{item.name}' failed: {e}",
                    exc_info=True,
                )
                failed_items.append(item.name)
                continue

            if not rule_dicts:
                continue

            # Persist this item's rules in their own transaction
            with get_session_with_current_tenant() as db_session:
                for rd in rule_dicts:
                    rule = rulesets_db.create_rule(
                        ruleset_id=ruleset_id,
                        name=rd["name"],
                        description=rd.get("description"),
                        category=rd.get("category"),
                        rule_type=rd.get("rule_type", "CUSTOM_NL"),
                        rule_intent=rd.get("rule_intent", "CHECK"),
                        prompt_template=rd["prompt_template"],
                        source="IMPORTED",
                        is_hard_stop=False,
                        priority=0,
                        refinement_needed=rd.get("refinement_needed", False),
                        refinement_question=rd.get("refinement_question"),
                        db_session=db_session,
                    )
                    # Imported rules start out inactive.
                    rule.is_active = False
                db_session.flush()

                rules_created += len(rule_dicts)

                # Update progress so the frontend can poll it
                job = imports_db.get_import_job(job_uuid, db_session)
                if job:
                    job.rules_created = rules_created

                db_session.commit()

            logger.info(
                f"  [{item.id}] '{item.name}': "
                f"{len(rule_dicts)} rules persisted "
                f"({rules_created} total)"
            )

    # Step 4: Mark completed
    with get_session_with_current_tenant() as db_session:
        job = imports_db.get_import_job(job_uuid, db_session)
        if job:
            job.status = "COMPLETED"
            job.completed_at = datetime.now(timezone.utc)
            db_session.commit()

    status = f"{rules_created} rules created"
    if failed_items:
        status += f", {len(failed_items)} items failed: {failed_items}"
    logger.info(f"Import job {import_job_id} completed: {status}")
|
||||
|
||||
|
||||
def _mark_import_failed(import_job_id: str, error: str) -> None:
    """Set an import job's status to FAILED and store the error message.

    Also stamps the job's completion time. Does nothing if the job cannot
    be found. Any exception (including a malformed ``import_job_id``) is
    logged and swallowed.
    """
    from onyx.db.engine.sql_engine import get_session_with_current_tenant
    from onyx.server.features.proposal_review.db import imports as imports_db

    try:
        with get_session_with_current_tenant() as session:
            import_job = imports_db.get_import_job(UUID(import_job_id), session)
            if not import_job:
                return
            import_job.status = "FAILED"
            import_job.error_message = error
            import_job.completed_at = datetime.now(timezone.utc)
            session.commit()
    except Exception as exc:
        logger.error(f"Failed to mark import job {import_job_id} as FAILED: {exc}")
|
||||
|
||||
|
||||
def _split_large_items(
    items: list,  # list[ChecklistItem] — untyped to avoid top-level import
    max_sub_checks: int,
) -> list:
    """Split checklist items with many sub-checks into smaller work units.

    Each returned item has at most *max_sub_checks* sub-checks, keeping the
    LLM output bounded regardless of how large the original item was. Items
    that are already within the limit pass through unchanged (same object).
    Split parts copy the parent's name, category and description and get
    the id ``"<parent id>-p<N>"`` with N starting at 1.

    Args:
        items: Checklist items to partition.
        max_sub_checks: Maximum sub-checks per returned item; must be >= 1.

    Returns:
        Work units in the original item order.

    Raises:
        ValueError: If ``max_sub_checks`` is less than 1. Without this check
            a negative value would silently drop every oversized item,
            because the batching range would be empty.
    """
    from onyx.server.features.proposal_review.engine.checklist_importer import (
        ChecklistItem,
    )

    if max_sub_checks < 1:
        raise ValueError(f"max_sub_checks must be at least 1, got {max_sub_checks}")

    work_items: list[ChecklistItem] = []
    for item in items:
        sub_checks = item.sub_checks
        if len(sub_checks) <= max_sub_checks:
            work_items.append(item)
            continue

        # Split into batches, each becoming its own work unit
        batch_starts = range(0, len(sub_checks), max_sub_checks)
        for part_num, start in enumerate(batch_starts, start=1):
            work_items.append(
                ChecklistItem(
                    id=f"{item.id}-p{part_num}",
                    name=item.name,
                    category=item.category,
                    description=item.description,
                    sub_checks=sub_checks[start : start + max_sub_checks],
                )
            )

    return work_items
|
||||
@@ -1,202 +0,0 @@
|
||||
"""Evaluates a single rule against a proposal context via LLM."""
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.llm.models import SystemMessage
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.utils import llm_response_to_string
|
||||
from onyx.server.features.proposal_review.db.models import ProposalReviewRule
|
||||
from onyx.server.features.proposal_review.engine.context_assembler import (
|
||||
ProposalContext,
|
||||
)
|
||||
from onyx.tracing.llm_utils import llm_generation_span
|
||||
from onyx.tracing.llm_utils import record_llm_response
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
SYSTEM_PROMPT = """\
|
||||
You are a meticulous grant proposal compliance reviewer for a university research office.
|
||||
Your role is to evaluate specific aspects of grant proposals against institutional
|
||||
and sponsor requirements.
|
||||
|
||||
You must evaluate each rule independently, focusing ONLY on the specific criterion
|
||||
described. Be precise in your assessment. When in doubt, mark for human review.
|
||||
|
||||
Always respond with a valid JSON object in the exact format specified."""
|
||||
|
||||
RESPONSE_FORMAT_INSTRUCTIONS = """
|
||||
Respond with ONLY a valid JSON object in the following format:
|
||||
{
|
||||
"verdict": "PASS | FAIL | FLAG | NEEDS_REVIEW | NOT_APPLICABLE",
|
||||
"confidence": "HIGH | MEDIUM | LOW",
|
||||
"evidence": "Direct quote or reference from the proposal documents that supports your verdict. If no relevant text found, state that clearly.",
|
||||
"explanation": "Concise reasoning for why this verdict was reached. Reference specific requirements and how the proposal does or does not meet them.",
|
||||
"suggested_action": "If verdict is FAIL or FLAG, describe what the officer or PI should do. Otherwise, null."
|
||||
}
|
||||
|
||||
Verdict meanings:
|
||||
- PASS: The proposal clearly meets this requirement.
|
||||
- FAIL: The proposal clearly does NOT meet this requirement.
|
||||
- FLAG: There is a potential issue that needs human attention.
|
||||
- NEEDS_REVIEW: Insufficient information to make a determination.
|
||||
- NOT_APPLICABLE: This rule does not apply to this proposal.
|
||||
"""
|
||||
|
||||
|
||||
def evaluate_rule(
    rule: ProposalReviewRule,
    context: ProposalContext,
    _db_session: Session | None = None,
) -> dict:
    """Evaluate one rule against the proposal context with a single LLM call.

    Steps:
    1. Fill the placeholders in ``rule.prompt_template`` from ``context``.
    2. Send it, followed by the response-format instructions, as the user
       message after the reviewer system prompt.
    3. Parse the model's reply into a findings dict.

    LLM exceptions are not caught here; they propagate so the caller's
    retry logic can handle transient failures.

    Args:
        rule: The rule to evaluate.
        context: Assembled proposal context.
        _db_session: Not used by this function; accepted so callers can
            pass a session positionally.

    Returns:
        Dict with verdict, confidence, evidence, explanation and
        suggested_action, plus llm_model and llm_tokens_used (either may
        be None).
    """
    # Build the two-message prompt: reviewer role, then the filled rule
    # prompt followed by the JSON format instructions.
    rule_prompt = _fill_template(rule.prompt_template, context)
    messages = [
        SystemMessage(content=SYSTEM_PROMPT),
        UserMessage(content=rule_prompt + "\n\n" + RESPONSE_FORMAT_INSTRUCTIONS),
    ]

    # Invoke the model inside a tracing span.
    llm = get_default_llm()
    with llm_generation_span(llm, "proposal_review", messages) as span:
        response = llm.invoke(messages)
        record_llm_response(span, response)
        raw_text = llm_response_to_string(response)

    # Bookkeeping about the call itself.
    model_name = llm.config.model_name if llm.config else None
    tokens_used = _extract_token_usage(response)

    # Turn the reply into a normalized finding and attach the bookkeeping.
    finding = _parse_llm_response(raw_text)
    finding["llm_model"] = model_name
    finding["llm_tokens_used"] = tokens_used
    return finding
|
||||
|
||||
|
||||
def _fill_template(template: str, context: ProposalContext) -> str:
    """Replace {{variable}} placeholders in the prompt template.

    Supported variables:
    - {{proposal_text}} -> context.proposal_text
    - {{budget_text}} -> context.budget_text
    - {{foa_text}} -> context.foa_text
    - {{metadata}} -> JSON dump of context.metadata
    - {{metadata.FIELD}} -> specific metadata field value
    - {{jira_key}} -> context.jira_key

    All placeholders are resolved in one pass over *template*, so text that
    is substituted in (proposal body, FOA text, metadata values) is never
    re-scanned: a document that happens to contain ``{{foa_text}}`` or
    ``{{metadata.x}}`` is inserted verbatim instead of being expanded.

    Falsy text fields become the empty string, a missing metadata field
    becomes the empty string, dict/list metadata values are JSON-encoded,
    and unrecognized ``{{...}}`` sequences are left untouched.

    Args:
        template: Prompt template containing zero or more placeholders.
        context: Object providing proposal_text, budget_text, foa_text,
            jira_key and a metadata mapping.

    Returns:
        The template with every supported placeholder substituted.
    """
    placeholder_pattern = re.compile(
        r"\{\{(proposal_text|budget_text|foa_text|jira_key"
        r"|metadata\.[^}]+|metadata)\}\}"
    )
    text_fields = {
        "proposal_text": context.proposal_text or "",
        "budget_text": context.budget_text or "",
        "foa_text": context.foa_text or "",
        "jira_key": context.jira_key or "",
    }
    metadata_prefix = "metadata."

    def _resolve(match: "re.Match[str]") -> str:
        name = match.group(1)
        if name in text_fields:
            return text_fields[name]
        if name == "metadata":
            # Whole metadata mapping as pretty-printed JSON
            return json.dumps(context.metadata, indent=2, default=str)
        # Specific metadata field: {{metadata.FIELD}}
        field_value = context.metadata.get(name[len(metadata_prefix):], "")
        if isinstance(field_value, (dict, list)):
            field_value = json.dumps(field_value, default=str)
        return str(field_value)

    # A callable replacement is inserted literally, so backslashes in the
    # substituted text are not interpreted as regex escapes.
    return placeholder_pattern.sub(_resolve, template)
|
||||
|
||||
|
||||
def _parse_llm_response(raw_text: str) -> dict:
    """Parse the LLM response text into a normalized findings dict.

    Handles cases where the LLM wraps JSON in markdown code fences or
    surrounds the JSON object with prose. If no JSON object can be
    recovered (including when the reply is valid JSON but not an object,
    such as a list or bare string), a NEEDS_REVIEW / LOW fallback finding
    is returned instead of raising.

    Verdict and confidence are stripped and upper-cased; values outside
    the allowed sets fall back to NEEDS_REVIEW and LOW respectively.

    Args:
        raw_text: The raw text returned by the model.

    Returns:
        Dict with the keys verdict, confidence, evidence, explanation and
        suggested_action.
    """
    text = raw_text.strip()

    # Strip markdown code fences if present
    if text.startswith("```"):
        # Remove opening fence (with optional language tag)
        text = re.sub(r"^```(?:json)?\s*\n?", "", text)
        # Remove closing fence
        text = re.sub(r"\n?```\s*$", "", text)
        text = text.strip()

    parsed = None
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Models sometimes add prose around the object; retry on the span
        # from the first "{" to the last "}".
        start, end = text.find("{"), text.rfind("}")
        if 0 <= start < end:
            try:
                parsed = json.loads(text[start : end + 1])
            except json.JSONDecodeError:
                parsed = None

    # Anything other than a JSON object cannot carry the expected fields.
    if not isinstance(parsed, dict):
        logger.warning(f"Failed to parse LLM response as JSON: {text[:200]}...")
        return {
            "verdict": "NEEDS_REVIEW",
            "confidence": "LOW",
            "evidence": None,
            "explanation": f"Failed to parse LLM response. Raw output: {text[:500]}",
            "suggested_action": "Manual review required due to unparseable AI response.",
        }

    # Validate and normalize the parsed result
    valid_verdicts = {"PASS", "FAIL", "FLAG", "NEEDS_REVIEW", "NOT_APPLICABLE"}
    valid_confidences = {"HIGH", "MEDIUM", "LOW"}

    verdict = str(parsed.get("verdict", "NEEDS_REVIEW")).strip().upper()
    if verdict not in valid_verdicts:
        verdict = "NEEDS_REVIEW"

    confidence = str(parsed.get("confidence", "LOW")).strip().upper()
    if confidence not in valid_confidences:
        confidence = "LOW"

    return {
        "verdict": verdict,
        "confidence": confidence,
        "evidence": parsed.get("evidence"),
        "explanation": parsed.get("explanation"),
        "suggested_action": parsed.get("suggested_action"),
    }
|
||||
|
||||
|
||||
def _extract_token_usage(response: object) -> int | None:
    """Return the token count reported on an LLM response, if any.

    Reads ``response.usage``: prefers ``total_tokens`` and otherwise sums
    ``prompt_tokens`` and ``completion_tokens``. Returns None when there is
    no usage data, when both partial counts are zero or missing, or when
    reading the values raises (extraction is best-effort).
    """
    try:
        # litellm ModelResponse has a usage attribute
        usage = getattr(response, "usage", None)
        if not usage:
            return None

        total_tokens = getattr(usage, "total_tokens", None)
        if total_tokens is not None:
            return int(total_tokens)

        # No total reported: fall back to prompt + completion.
        partial_counts = [
            getattr(usage, attr, 0) or 0
            for attr in ("prompt_tokens", "completion_tokens")
        ]
        if any(partial_counts):
            return partial_counts[0] + partial_counts[1]
    except Exception:
        # Token usage is informational only; never fail the evaluation.
        pass
    return None
|
||||
@@ -1,195 +0,0 @@
|
||||
"""Celery tasks for proposal review — discovered by autodiscover_tasks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from celery import shared_task
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
    name="run_proposal_review",
    bind=True,
    ignore_result=True,
    soft_time_limit=3600,
    time_limit=3660,
)
def run_proposal_review(
    _self: object,
    review_run_id: str,
    tenant_id: str,
    rule_ids: list[str] | None = None,
) -> None:
    """Celery task: evaluate rules for a review run.

    With ``rule_ids`` left as None every active rule in the run's ruleset is
    evaluated (full run); with a list only those rules are evaluated
    (retry flow).

    The tenant context variable is set for the duration of the task and
    cleared afterwards. On any exception the run is marked FAILED and the
    exception is re-raised.
    """
    CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

    try:
        from onyx.tracing.framework.create import trace

        trace_metadata = {"review_run_id": review_run_id}
        with trace("proposal_review", metadata=trace_metadata):
            from onyx.server.features.proposal_review.engine.review_engine import (
                _execute_review,
            )

            _execute_review(review_run_id, rule_ids=rule_ids)
    except Exception as exc:
        logger.error(f"Review run {review_run_id} failed: {exc}", exc_info=True)
        from onyx.server.features.proposal_review.engine.review_engine import (
            _mark_run_failed,
        )

        _mark_run_failed(review_run_id)
        raise
    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.set(None)
|
||||
|
||||
|
||||
@shared_task(name="run_checklist_import", bind=True, ignore_result=True)
def run_checklist_import(_self: object, import_job_id: str, tenant_id: str) -> None:
    """Celery task: decompose a checklist via LLM and save the rules.

    The tenant context variable is set for the duration of the task and
    cleared afterwards. On any exception the import job is marked FAILED
    with the exception text and the exception is re-raised.
    """
    CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

    try:
        from onyx.tracing.framework.create import trace

        trace_metadata = {"import_job_id": import_job_id}
        with trace("checklist_import", metadata=trace_metadata):
            from onyx.server.features.proposal_review.engine.review_engine import (
                _execute_checklist_import,
            )

            _execute_checklist_import(import_job_id)
    except Exception as exc:
        logger.error(f"Import job {import_job_id} failed: {exc}", exc_info=True)
        from onyx.server.features.proposal_review.engine.review_engine import (
            _mark_import_failed,
        )

        _mark_import_failed(import_job_id, str(exc))
        raise
    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.set(None)
|
||||
|
||||
|
||||
@shared_task(
    name="sync_decision_to_jira",
    bind=True,
    ignore_result=True,
    soft_time_limit=60,
    time_limit=90,
)
def sync_decision_to_jira(_self: object, proposal_id: str, tenant_id: str) -> None:
    """Celery task: write the officer decision for a proposal back to Jira.

    Dispatched from the sync-jira API endpoint. The sync runs in one DB
    session that is committed on success. Failures are logged and
    re-raised; the tenant context variable is always cleared on exit.
    """
    CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
    try:
        from onyx.db.engine.sql_engine import get_session_with_current_tenant
        from onyx.server.features.proposal_review.engine.jira_sync import sync_to_jira

        proposal_uuid = UUID(proposal_id)
        with get_session_with_current_tenant() as session:
            sync_to_jira(proposal_uuid, session)
            session.commit()

        logger.info(f"Jira sync completed for proposal {proposal_id}")
    except Exception as exc:
        logger.error(
            f"Jira sync failed for proposal {proposal_id}: {exc}", exc_info=True
        )
        raise
    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.set(None)
|
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_DANGLING_IMPORT_JOBS,
    bind=True,
    ignore_result=True,
    soft_time_limit=60,
    time_limit=90,
)
def check_for_dangling_import_jobs(_self: object, *, tenant_id: str) -> None:
    """Beat task: mark import jobs stuck in PENDING/RUNNING as FAILED.

    A job is considered stuck if it has been in a non-terminal state for
    longer than the stale threshold (60 minutes here). This handles
    cases where the Celery message was discarded (e.g. worker restart
    before the task was registered) or the task crashed without marking
    the job as FAILED.

    A non-blocking Redis lock keeps concurrent beats from doing the same
    cleanup; if the lock is already held the task returns immediately.
    Errors during the cleanup itself are logged and swallowed. The tenant
    context variable is cleared on every exit path, including the
    lock-contention return and failures while obtaining the lock.
    """
    from onyx.db.engine.sql_engine import get_session_with_current_tenant
    from onyx.server.features.proposal_review.db import imports as imports_db

    CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

    try:
        redis_client = get_redis_client(tenant_id=tenant_id)
        lock: RedisLock = redis_client.lock(
            OnyxRedisLocks.CHECK_DANGLING_IMPORT_JOBS_BEAT_LOCK,
            timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
        )

        if not lock.acquire(blocking=False):
            logger.info(
                f"check_for_dangling_import_jobs - Lock not acquired: tenant={tenant_id}"
            )
            return None

        try:
            with get_session_with_current_tenant() as db_session:
                dangling = imports_db.get_dangling_import_jobs(
                    db_session, stale_threshold_minutes=60
                )
                if not dangling:
                    return

                for job in dangling:
                    logger.warning(
                        f"Marking dangling import job {job.id} as FAILED "
                        f"(status={job.status}, created_at={job.created_at})"
                    )
                    imports_db.mark_import_job_failed(
                        job,
                        "Import timed out — the background task did not complete. "
                        "Please try importing again.",
                        db_session,
                    )

                db_session.commit()
                logger.info(
                    f"Cleaned up {len(dangling)} dangling import job(s) "
                    f"for tenant {tenant_id}"
                )
        except Exception:
            logger.exception("Unexpected error during dangling import job cleanup")
        finally:
            # The lock has a timeout, so it may have expired while we worked;
            # only release it if this task still owns it.
            if lock.owned():
                lock.release()
            else:
                logger.error(
                    f"check_for_dangling_import_jobs - "
                    f"Lock not owned on completion: tenant={tenant_id}"
                )
    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.set(None)
|
||||
@@ -1,8 +1,7 @@
|
||||
"""Generic Celery task lifecycle Prometheus metrics.
|
||||
|
||||
Provides signal handlers that track task started/completed/failed counts,
|
||||
active task gauge, task duration histograms, queue wait time histograms,
|
||||
and retry/reject/revoke counts.
|
||||
active task gauge, task duration histograms, and retry/reject/revoke counts.
|
||||
These fire for ALL tasks on the worker — no per-connector enrichment
|
||||
(see indexing_task_metrics.py for that).
|
||||
|
||||
@@ -72,32 +71,6 @@ TASK_REJECTED = Counter(
|
||||
["task_name"],
|
||||
)
|
||||
|
||||
TASK_QUEUE_WAIT = Histogram(
|
||||
"onyx_celery_task_queue_wait_seconds",
|
||||
"Time a Celery task spent waiting in the queue before execution started",
|
||||
["task_name", "queue"],
|
||||
buckets=[
|
||||
0.1,
|
||||
0.5,
|
||||
1,
|
||||
5,
|
||||
30,
|
||||
60,
|
||||
300,
|
||||
600,
|
||||
1800,
|
||||
3600,
|
||||
7200,
|
||||
14400,
|
||||
28800,
|
||||
43200,
|
||||
86400,
|
||||
172800,
|
||||
432000,
|
||||
864000,
|
||||
],
|
||||
)
|
||||
|
||||
# task_id → (monotonic start time, metric labels)
|
||||
_task_start_times: dict[str, tuple[float, dict[str, str]]] = {}
|
||||
|
||||
@@ -160,13 +133,6 @@ def on_celery_task_prerun(
|
||||
with _task_start_times_lock:
|
||||
_evict_stale_start_times()
|
||||
_task_start_times[task_id] = (time.monotonic(), labels)
|
||||
|
||||
headers = getattr(task.request, "headers", None) or {}
|
||||
enqueued_at = headers.get("enqueued_at")
|
||||
if isinstance(enqueued_at, (int, float)):
|
||||
TASK_QUEUE_WAIT.labels(**labels).observe(
|
||||
max(0.0, time.time() - enqueued_at)
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task prerun metrics", exc_info=True)
|
||||
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
"""Connector-deletion-specific Prometheus metrics.
|
||||
|
||||
Tracks the deletion lifecycle:
|
||||
1. Deletions started (taskset generated)
|
||||
2. Deletions completed (success or failure)
|
||||
3. Taskset duration (from taskset generation to completion or failure).
|
||||
Note: this measures the most recent taskset execution, NOT wall-clock
|
||||
time since the user triggered the deletion. When deletion is blocked by
|
||||
indexing/pruning/permissions, the fence is cleared and a fresh taskset
|
||||
is generated on each retry, resetting this timer.
|
||||
4. Deletion blocked by dependencies (indexing, pruning, permissions, etc.)
|
||||
5. Fence resets (stuck deletion recovery)
|
||||
|
||||
All metrics are labeled by tenant_id. cc_pair_id is intentionally excluded
|
||||
to avoid unbounded cardinality.
|
||||
|
||||
Usage:
|
||||
from onyx.server.metrics.deletion_metrics import (
|
||||
inc_deletion_started,
|
||||
inc_deletion_completed,
|
||||
observe_deletion_taskset_duration,
|
||||
inc_deletion_blocked,
|
||||
inc_deletion_fence_reset,
|
||||
)
|
||||
"""
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
# Module-level logger; the metric helpers in this module use it to report
# recording failures at debug level instead of raising.
logger = setup_logger()
|
||||
|
||||
# Counts deletions whose taskset has been generated, per tenant.
DELETION_STARTED = Counter(
    "onyx_deletion_started_total",
    "Connector deletions initiated (taskset generated)",
    labelnames=["tenant_id"],
)

# Counts finished deletions, split by tenant and by outcome label.
DELETION_COMPLETED = Counter(
    "onyx_deletion_completed_total",
    "Connector deletions completed",
    labelnames=["tenant_id", "outcome"],
)

# Taskset duration in seconds; buckets run from 10 seconds up to 6 hours.
DELETION_TASKSET_DURATION = Histogram(
    "onyx_deletion_taskset_duration_seconds",
    "Duration of a connector deletion taskset, from taskset generation "
    "to completion or failure. Does not include time spent blocked on "
    "indexing/pruning/permissions before the taskset was generated.",
    labelnames=["tenant_id", "outcome"],
    buckets=[10, 30, 60, 120, 300, 600, 1800, 3600, 7200, 21600],
)

# Counts how often a deletion could not proceed, labeled by the blocker.
DELETION_BLOCKED = Counter(
    "onyx_deletion_blocked_total",
    "Times deletion was blocked by a dependency",
    labelnames=["tenant_id", "blocker"],
)

# Counts deletion fence resets, per tenant.
DELETION_FENCE_RESET = Counter(
    "onyx_deletion_fence_reset_total",
    "Deletion fences reset due to missing celery tasks",
    labelnames=["tenant_id"],
)
|
||||
|
||||
|
||||
def inc_deletion_started(tenant_id: str) -> None:
    """Increment the started-deletions counter for ``tenant_id``.

    Best-effort: any exception raised while recording is logged at debug
    level and suppressed, so metric failures never reach the caller.
    """
    try:
        counter = DELETION_STARTED.labels(tenant_id=tenant_id)
        counter.inc()
    except Exception:
        logger.debug("Failed to record deletion started", exc_info=True)
|
||||
|
||||
|
||||
def inc_deletion_completed(tenant_id: str, outcome: str) -> None:
    """Increment the completed-deletions counter for ``tenant_id`` and ``outcome``.

    Best-effort: any exception raised while recording is logged at debug
    level and suppressed.
    """
    try:
        counter = DELETION_COMPLETED.labels(tenant_id=tenant_id, outcome=outcome)
        counter.inc()
    except Exception:
        logger.debug("Failed to record deletion completed", exc_info=True)
|
||||
|
||||
|
||||
def observe_deletion_taskset_duration(
    tenant_id: str, outcome: str, duration_seconds: float
) -> None:
    """Record one taskset duration sample for ``tenant_id`` and ``outcome``.

    ``duration_seconds`` is passed to the histogram unchanged. Best-effort:
    any exception raised while recording is logged at debug level and
    suppressed.
    """
    try:
        histogram = DELETION_TASKSET_DURATION.labels(
            tenant_id=tenant_id, outcome=outcome
        )
        histogram.observe(duration_seconds)
    except Exception:
        logger.debug("Failed to record deletion taskset duration", exc_info=True)
|
||||
|
||||
|
||||
def inc_deletion_blocked(tenant_id: str, blocker: str) -> None:
    """Increment the blocked-deletions counter for ``tenant_id`` and ``blocker``.

    Best-effort: any exception raised while recording is logged at debug
    level and suppressed.
    """
    try:
        counter = DELETION_BLOCKED.labels(tenant_id=tenant_id, blocker=blocker)
        counter.inc()
    except Exception:
        logger.debug("Failed to record deletion blocked", exc_info=True)
|
||||
|
||||
|
||||
def inc_deletion_fence_reset(tenant_id: str) -> None:
    """Increment the fence-reset counter for ``tenant_id``.

    Best-effort: any exception raised while recording is logged at debug
    level and suppressed.
    """
    try:
        counter = DELETION_FENCE_RESET.labels(tenant_id=tenant_id)
        counter.inc()
    except Exception:
        logger.debug("Failed to record deletion fence reset", exc_info=True)
|
||||
@@ -27,7 +27,6 @@ _DEFAULT_PORTS: dict[str, int] = {
|
||||
"docfetching": 9092,
|
||||
"docprocessing": 9093,
|
||||
"heavy": 9094,
|
||||
"light": 9095,
|
||||
}
|
||||
|
||||
_server_started = False
|
||||
|
||||
@@ -28,14 +28,14 @@ PRUNING_ENUMERATION_DURATION = Histogram(
|
||||
"onyx_pruning_enumeration_duration_seconds",
|
||||
"Duration of document ID enumeration from the source connector during pruning",
|
||||
["connector_type"],
|
||||
buckets=[5, 60, 600, 1800, 3600, 10800, 21600],
|
||||
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
|
||||
)
|
||||
|
||||
PRUNING_DIFF_DURATION = Histogram(
|
||||
"onyx_pruning_diff_duration_seconds",
|
||||
"Duration of diff computation and subtask dispatch during pruning",
|
||||
["connector_type"],
|
||||
buckets=[0.1, 0.25, 0.5, 1, 2, 5, 15, 30, 60],
|
||||
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
|
||||
)
|
||||
|
||||
PRUNING_RATE_LIMIT_ERRORS = Counter(
|
||||
|
||||
@@ -65,8 +65,8 @@ class Settings(BaseModel):
|
||||
anonymous_user_enabled: bool | None = None
|
||||
invite_only_enabled: bool = False
|
||||
deep_research_enabled: bool | None = None
|
||||
multi_model_chat_enabled: bool | None = True
|
||||
search_ui_enabled: bool | None = True
|
||||
multi_model_chat_enabled: bool | None = None
|
||||
search_ui_enabled: bool | None = None
|
||||
|
||||
# Whether EE features are unlocked for use.
|
||||
# Depends on license status: True when the user has a valid license
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
|
||||
@@ -11,8 +12,6 @@ from onyx.configs.app_configs import OPENROUTER_DEFAULT_API_KEY
|
||||
from onyx.db.usage import check_usage_limit
|
||||
from onyx.db.usage import UsageLimitExceededError
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.tenant_usage_limits import TenantUsageLimitKeys
|
||||
from onyx.server.tenant_usage_limits import TenantUsageLimitOverrides
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -256,14 +255,11 @@ def check_usage_and_raise(
|
||||
"Please upgrade your plan or wait for the next billing period."
|
||||
)
|
||||
elif usage_type == UsageType.API_CALLS:
|
||||
if is_trial and e.limit == 0:
|
||||
detail = "API access is not available on trial accounts. Please upgrade to a paid plan to use the API and chat widget."
|
||||
else:
|
||||
detail = (
|
||||
f"API call limit exceeded for {user_type} account. "
|
||||
f"Calls: {int(e.current)}, Limit: {int(e.limit)} per week. "
|
||||
"Please upgrade your plan or wait for the next billing period."
|
||||
)
|
||||
detail = (
|
||||
f"API call limit exceeded for {user_type} account. "
|
||||
f"Calls: {int(e.current)}, Limit: {int(e.limit)} per week. "
|
||||
"Please upgrade your plan or wait for the next billing period."
|
||||
)
|
||||
else:
|
||||
detail = (
|
||||
f"Non-streaming API call limit exceeded for {user_type} account. "
|
||||
@@ -271,4 +267,4 @@ def check_usage_and_raise(
|
||||
"Please upgrade your plan or wait for the next billing period."
|
||||
)
|
||||
|
||||
raise OnyxError(OnyxErrorCode.RATE_LIMITED, detail)
|
||||
raise HTTPException(status_code=429, detail=detail)
|
||||
|
||||
@@ -10,7 +10,6 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.mcp import get_all_mcp_tools_for_server
|
||||
@@ -114,10 +113,10 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
|
||||
|
||||
def construct_tools(
|
||||
persona: Persona,
|
||||
db_session: Session,
|
||||
emitter: Emitter,
|
||||
user: User,
|
||||
llm: LLM,
|
||||
db_session: Session | None = None,
|
||||
search_tool_config: SearchToolConfig | None = None,
|
||||
custom_tool_config: CustomToolConfig | None = None,
|
||||
file_reader_tool_config: FileReaderToolConfig | None = None,
|
||||
@@ -132,33 +131,6 @@ def construct_tools(
|
||||
``attached_documents``, and ``hierarchy_nodes`` already eager-loaded
|
||||
(e.g. via ``eager_load_persona=True`` or ``eager_load_for_tools=True``)
|
||||
to avoid lazy SQL queries after the session may have been flushed."""
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
return _construct_tools_impl(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
custom_tool_config=custom_tool_config,
|
||||
file_reader_tool_config=file_reader_tool_config,
|
||||
allowed_tool_ids=allowed_tool_ids,
|
||||
search_usage_forcing_setting=search_usage_forcing_setting,
|
||||
)
|
||||
|
||||
|
||||
def _construct_tools_impl(
|
||||
persona: Persona,
|
||||
db_session: Session,
|
||||
emitter: Emitter,
|
||||
user: User,
|
||||
llm: LLM,
|
||||
search_tool_config: SearchToolConfig | None = None,
|
||||
custom_tool_config: CustomToolConfig | None = None,
|
||||
file_reader_tool_config: FileReaderToolConfig | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
search_usage_forcing_setting: SearchToolUsage = SearchToolUsage.AUTO,
|
||||
) -> dict[int, list[Tool]]:
|
||||
tool_dict: dict[int, list[Tool]] = {}
|
||||
|
||||
# Log which tools are attached to the persona for debugging
|
||||
|
||||
@@ -17,7 +17,6 @@ def documents_to_indexing_documents(
|
||||
processed_sections = []
|
||||
for section in document.sections:
|
||||
processed_section = Section(
|
||||
type=section.type,
|
||||
text=section.text or "",
|
||||
link=section.link,
|
||||
image_file_id=None,
|
||||
|
||||
@@ -26,7 +26,7 @@ aiolimiter==1.2.1
|
||||
# via voyageai
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
alembic==1.18.4
|
||||
alembic==1.10.4
|
||||
amqp==5.3.1
|
||||
# via kombu
|
||||
annotated-doc==0.0.4
|
||||
@@ -174,7 +174,7 @@ coloredlogs==15.0.1
|
||||
# via onnxruntime
|
||||
courlan==1.3.2
|
||||
# via trafilatura
|
||||
cryptography==46.0.7
|
||||
cryptography==46.0.6
|
||||
# via
|
||||
# authlib
|
||||
# google-auth
|
||||
@@ -408,7 +408,7 @@ kombu==5.5.4
|
||||
# via celery
|
||||
kubernetes==31.0.0
|
||||
# via onyx
|
||||
langchain-core==1.2.28
|
||||
langchain-core==1.2.22
|
||||
langdetect==1.0.9
|
||||
# via unstructured
|
||||
langfuse==3.10.0
|
||||
@@ -583,7 +583,7 @@ pathable==0.4.4
|
||||
# via jsonschema-path
|
||||
pdfminer-six==20251107
|
||||
# via markitdown
|
||||
pillow==12.2.0
|
||||
pillow==12.1.1
|
||||
# via python-pptx
|
||||
platformdirs==4.5.0
|
||||
# via
|
||||
@@ -666,9 +666,7 @@ pyee==13.0.0
|
||||
# via playwright
|
||||
pygithub==2.5.0
|
||||
pygments==2.20.0
|
||||
# via
|
||||
# pytest
|
||||
# rich
|
||||
# via rich
|
||||
pyjwt==2.12.0
|
||||
# via
|
||||
# fastapi-users
|
||||
@@ -682,13 +680,13 @@ pynacl==1.6.2
|
||||
pypandoc-binary==1.16.2
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.10.0
|
||||
pypdf==6.9.2
|
||||
# via unstructured-client
|
||||
pyperclip==1.11.0
|
||||
# via fastmcp
|
||||
pyreadline3==3.5.4 ; sys_platform == 'win32'
|
||||
# via humanfriendly
|
||||
pytest==9.0.3
|
||||
pytest==8.3.5
|
||||
# via
|
||||
# pytest-base-url
|
||||
# pytest-mock
|
||||
@@ -696,7 +694,7 @@ pytest==9.0.3
|
||||
pytest-base-url==2.1.0
|
||||
# via pytest-playwright
|
||||
pytest-mock==3.12.0
|
||||
pytest-playwright==0.7.2
|
||||
pytest-playwright==0.7.0
|
||||
python-dateutil==2.8.2
|
||||
# via
|
||||
# aiobotocore
|
||||
|
||||
@@ -22,7 +22,7 @@ aiolimiter==1.2.1
|
||||
# via voyageai
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
alembic==1.18.4
|
||||
alembic==1.10.4
|
||||
# via pytest-alembic
|
||||
annotated-doc==0.0.4
|
||||
# via fastapi
|
||||
@@ -46,7 +46,7 @@ attrs==25.4.0
|
||||
# aiohttp
|
||||
# jsonschema
|
||||
# referencing
|
||||
black==26.3.1
|
||||
black==25.1.0
|
||||
boto3==1.39.11
|
||||
# via
|
||||
# aiobotocore
|
||||
@@ -95,7 +95,7 @@ comm==0.2.3
|
||||
# via ipykernel
|
||||
contourpy==1.3.3
|
||||
# via matplotlib
|
||||
cryptography==46.0.7
|
||||
cryptography==46.0.6
|
||||
# via
|
||||
# google-auth
|
||||
# pyjwt
|
||||
@@ -254,7 +254,7 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.7.5
|
||||
onyx-devtools==0.7.3
|
||||
openai==2.14.0
|
||||
# via
|
||||
# litellm
|
||||
@@ -274,13 +274,13 @@ parameterized==0.9.0
|
||||
# via cohere
|
||||
parso==0.8.5
|
||||
# via jedi
|
||||
pathspec==1.0.4
|
||||
pathspec==0.12.1
|
||||
# via
|
||||
# black
|
||||
# hatchling
|
||||
pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
|
||||
# via ipython
|
||||
pillow==12.2.0
|
||||
pillow==12.1.1
|
||||
# via matplotlib
|
||||
platformdirs==4.5.0
|
||||
# via
|
||||
@@ -339,12 +339,11 @@ pygments==2.20.0
|
||||
# via
|
||||
# ipython
|
||||
# ipython-pygments-lexers
|
||||
# pytest
|
||||
pyjwt==2.12.0
|
||||
# via mcp
|
||||
pyparsing==3.2.5
|
||||
# via matplotlib
|
||||
pytest==9.0.3
|
||||
pytest==8.3.5
|
||||
# via
|
||||
# pytest-alembic
|
||||
# pytest-asyncio
|
||||
@@ -370,8 +369,6 @@ python-dotenv==1.1.1
|
||||
# pytest-dotenv
|
||||
python-multipart==0.0.22
|
||||
# via mcp
|
||||
pytokens==0.4.1
|
||||
# via black
|
||||
pywin32==311 ; sys_platform == 'win32'
|
||||
# via mcp
|
||||
pyyaml==6.0.3
|
||||
|
||||
@@ -76,7 +76,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
|
||||
# via
|
||||
# click
|
||||
# tqdm
|
||||
cryptography==46.0.7
|
||||
cryptography==46.0.6
|
||||
# via
|
||||
# google-auth
|
||||
# pyjwt
|
||||
|
||||
@@ -91,7 +91,7 @@ colorama==0.4.6 ; sys_platform == 'win32'
|
||||
# via
|
||||
# click
|
||||
# tqdm
|
||||
cryptography==46.0.7
|
||||
cryptography==46.0.6
|
||||
# via
|
||||
# google-auth
|
||||
# pyjwt
|
||||
@@ -264,7 +264,7 @@ packaging==24.2
|
||||
# transformers
|
||||
parameterized==0.9.0
|
||||
# via cohere
|
||||
pillow==12.2.0
|
||||
pillow==12.1.1
|
||||
# via sentence-transformers
|
||||
prometheus-client==0.23.1
|
||||
# via
|
||||
|
||||
@@ -1,239 +0,0 @@
|
||||
"""Tests for GoogleDriveConnector.resolve_errors against real Google Drive."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
ALL_EXPECTED_HIERARCHY_NODES,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_ID
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_ID
|
||||
|
||||
# Location of drive_id_mapping.json, resolved relative to this module's
# directory rather than the current working directory.
_DRIVE_ID_MAPPING_PATH = os.path.join(
    os.path.dirname(__file__),
    "drive_id_mapping.json",
)
|
||||
|
||||
|
||||
def _load_web_view_links(file_ids: list[int]) -> list[str]:
    """Return the mapped link for each id in ``file_ids``, in the same order.

    The JSON file at ``_DRIVE_ID_MAPPING_PATH`` is read on every call and is
    keyed by the stringified file id; an id missing from it raises ``KeyError``.
    """
    with open(_DRIVE_ID_MAPPING_PATH) as mapping_file:
        links_by_id: dict[str, str] = json.load(mapping_file)

    return [links_by_id[str(file_id)] for file_id in file_ids]
|
||||
|
||||
|
||||
def _build_failures(web_view_links: list[str]) -> list[ConnectorFailure]:
    """Build one synthetic document-level ``ConnectorFailure`` per link.

    Each link is used as both the failed document's id and its link.
    """
    failures: list[ConnectorFailure] = []
    for link in web_view_links:
        failed_document = DocumentFailure(document_id=link, document_link=link)
        failures.append(
            ConnectorFailure(
                failed_document=failed_document,
                failure_message=f"Synthetic failure for {link}",
            )
        )
    return failures
|
||||
|
||||
|
||||
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
def test_resolve_single_file(
    mock_api_key: None,  # noqa: ARG001
    google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
    """Resolving one known file yields exactly one Document and no new failures."""
    connector_options = dict(
        primary_admin_email=ADMIN_EMAIL,
        include_shared_drives=True,
        shared_drive_urls=None,
        include_my_drives=True,
        my_drive_emails=None,
        shared_folder_urls=None,
        include_files_shared_with_me=False,
    )
    connector = google_drive_service_acct_connector_factory(**connector_options)

    failures = _build_failures(_load_web_view_links([0]))

    resolved = list(connector.resolve_errors(failures))

    # Split the mixed result stream by item type.
    docs = [item for item in resolved if isinstance(item, Document)]
    new_failures = [item for item in resolved if isinstance(item, ConnectorFailure)]
    hierarchy_nodes = [item for item in resolved if isinstance(item, HierarchyNode)]

    assert len(docs) == 1
    assert len(new_failures) == 0
    assert docs[0].semantic_identifier == "file_0.txt"

    # Should yield at least one hierarchy node (the file's parent folder chain)
    assert len(hierarchy_nodes) > 0
|
||||
|
||||
|
||||
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
def test_resolve_multiple_files(
    mock_api_key: None,  # noqa: ARG001
    google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
    """Resolving files from several folders returns every file and no failures."""
    connector_options = dict(
        primary_admin_email=ADMIN_EMAIL,
        include_shared_drives=True,
        shared_drive_urls=None,
        include_my_drives=True,
        my_drive_emails=None,
        shared_folder_urls=None,
        include_files_shared_with_me=False,
    )
    connector = google_drive_service_acct_connector_factory(**connector_options)

    # Pick files from different folders: admin files (0-4), shared drive 1 (20-24), folder_2 (45-49)
    file_ids = [0, 1, 20, 21, 45]
    failures = _build_failures(_load_web_view_links(file_ids))

    resolved = list(connector.resolve_errors(failures))

    # Split the mixed result stream by item type.
    docs = [item for item in resolved if isinstance(item, Document)]
    new_failures = [item for item in resolved if isinstance(item, ConnectorFailure)]
    hierarchy_nodes = [item for item in resolved if isinstance(item, HierarchyNode)]

    assert len(new_failures) == 0
    retrieved_names = {doc.semantic_identifier for doc in docs}
    expected_names = {f"file_{fid}.txt" for fid in file_ids}
    assert expected_names == retrieved_names

    # Files span multiple folders, so we should get hierarchy nodes
    assert len(hierarchy_nodes) > 0
|
||||
|
||||
|
||||
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
def test_resolve_hierarchy_nodes_are_valid(
    mock_api_key: None,  # noqa: ARG001
    google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
    """Hierarchy nodes yielded by resolve_errors include the expected ancestors
    and agree with the known name and type of every recognised node."""
    connector_options = dict(
        primary_admin_email=ADMIN_EMAIL,
        include_shared_drives=True,
        shared_drive_urls=None,
        include_my_drives=True,
        my_drive_emails=None,
        shared_folder_urls=None,
        include_files_shared_with_me=False,
    )
    connector = google_drive_service_acct_connector_factory(**connector_options)

    # File in folder_1 (inside shared_drive_1) — should walk up to shared_drive_1 root
    failures = _build_failures(_load_web_view_links([25]))

    resolved = list(connector.resolve_errors(failures))

    hierarchy_nodes = [item for item in resolved if isinstance(item, HierarchyNode)]
    node_ids = {node.raw_node_id for node in hierarchy_nodes}

    # File 25 is in folder_1 which is inside shared_drive_1.
    # The parent walk must yield at least these two ancestors.
    assert (
        FOLDER_1_ID in node_ids
    ), f"Expected folder_1 ({FOLDER_1_ID}) in hierarchy nodes, got: {node_ids}"
    assert (
        SHARED_DRIVE_1_ID in node_ids
    ), f"Expected shared_drive_1 ({SHARED_DRIVE_1_ID}) in hierarchy nodes, got: {node_ids}"

    # Only nodes present in the expected-node table are compared; others are ignored.
    for node in hierarchy_nodes:
        if node.raw_node_id in ALL_EXPECTED_HIERARCHY_NODES:
            expected = ALL_EXPECTED_HIERARCHY_NODES[node.raw_node_id]
            assert node.display_name == expected.display_name, (
                f"Display name mismatch for {node.raw_node_id}: "
                f"expected '{expected.display_name}', got '{node.display_name}'"
            )
            assert node.node_type == expected.node_type, (
                f"Node type mismatch for {node.raw_node_id}: "
                f"expected '{expected.node_type}', got '{node.node_type}'"
            )
|
||||
|
||||
|
||||
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
def test_resolve_with_invalid_link(
    mock_api_key: None,  # noqa: ARG001
    google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
    """A valid link resolves to a Document while an invalid one comes back as a
    ConnectorFailure carrying the original link as its document id."""
    connector_options = dict(
        primary_admin_email=ADMIN_EMAIL,
        include_shared_drives=True,
        shared_drive_urls=None,
        include_my_drives=True,
        my_drive_emails=None,
        shared_folder_urls=None,
        include_files_shared_with_me=False,
    )
    connector = google_drive_service_acct_connector_factory(**connector_options)

    invalid_link = "https://drive.google.com/file/d/NONEXISTENT_FILE_ID_12345"
    all_links = _load_web_view_links([0]) + [invalid_link]
    failures = _build_failures(all_links)

    resolved = list(connector.resolve_errors(failures))

    docs = [item for item in resolved if isinstance(item, Document)]
    new_failures = [item for item in resolved if isinstance(item, ConnectorFailure)]

    assert len(docs) == 1
    assert docs[0].semantic_identifier == "file_0.txt"
    assert len(new_failures) == 1
    assert new_failures[0].failed_document is not None
    assert new_failures[0].failed_document.document_id == invalid_link
|
||||
|
||||
|
||||
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
def test_resolve_empty_errors(
    mock_api_key: None,  # noqa: ARG001
    google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
    """Passing no failures to resolve_errors produces no results."""
    connector_options = dict(
        primary_admin_email=ADMIN_EMAIL,
        include_shared_drives=True,
        shared_drive_urls=None,
        include_my_drives=True,
        my_drive_emails=None,
        shared_folder_urls=None,
        include_files_shared_with_me=False,
    )
    connector = google_drive_service_acct_connector_factory(**connector_options)

    resolved = list(connector.resolve_errors([]))

    assert len(resolved) == 0
|
||||
|
||||
|
||||
@patch("onyx.file_processing.extract_file_text.get_unstructured_api_key")
def test_resolve_entity_failures_are_skipped(
    mock_api_key: None,  # noqa: ARG001
    google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
    """A failure that carries only a failed entity (no document) yields nothing."""
    from onyx.connectors.models import EntityFailure

    connector_options = dict(
        primary_admin_email=ADMIN_EMAIL,
        include_shared_drives=True,
        shared_drive_urls=None,
        include_my_drives=True,
        my_drive_emails=None,
        shared_folder_urls=None,
        include_files_shared_with_me=False,
    )
    connector = google_drive_service_acct_connector_factory(**connector_options)

    failed_entity = EntityFailure(entity_id="some_stage")
    entity_failure = ConnectorFailure(
        failed_entity=failed_entity,
        failure_message="retrieval failure",
    )

    resolved = list(connector.resolve_errors([entity_failure]))

    assert len(resolved) == 0
|
||||
@@ -1,85 +0,0 @@
|
||||
"""Shared fixtures for proposal review integration tests.
|
||||
|
||||
Uses the same real-PostgreSQL pattern as the parent external_dependency_unit
|
||||
conftest. Tables must already exist (via the 61ea78857c97 migration).
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# Tables emptied after every test. Order matters: child tables come before the
# tables they reference, so the DELETEs run children first.
_PROPOSAL_REVIEW_TABLES = [
    "proposal_review_finding",
    "proposal_review_run",
    "proposal_review_document",
    "proposal_review_proposal",
    "proposal_review_rule",
    "proposal_review_import_job",
    "proposal_review_ruleset",
    "proposal_review_config",
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def tenant_context() -> Generator[None, None, None]:
    """Bind ``TEST_TENANT_ID`` to the tenant context variable for one test.

    The token returned by ``set`` is used to restore the previous value on
    teardown, even when the test body raises.
    """
    reset_token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
    try:
        yield
    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.reset(reset_token)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def db_session(tenant_context: None) -> Generator[Session, None, None]:  # noqa: ARG001
    """Yield a DB session scoped to the current tenant.

    After the test completes, all proposal_review rows are deleted so tests
    don't leave artifacts in the database. Cleanup is best-effort: a failure
    is logged and rolled back rather than failing the test's teardown.
    """
    import logging

    SqlEngine.init_engine(pool_size=10, max_overflow=5)
    # NOTE(review): the scraped source lost its indentation; the cleanup is
    # assumed to run inside the session context manager — confirm.
    with get_session_with_current_tenant() as session:
        yield session

        # Clean up all proposal_review data created during this test
        try:
            # A test that provoked a DB error may leave the session in a failed
            # transaction; without this rollback the first DELETE would raise
            # and the cleanup would be skipped without any trace.
            session.rollback()
            for table in _PROPOSAL_REVIEW_TABLES:
                session.execute(text(f"DELETE FROM {table}"))  # noqa: S608
            session.commit()
        except Exception:
            session.rollback()
            logging.getLogger(__name__).warning(
                "Failed to clean up proposal_review tables after test",
                exc_info=True,
            )
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def test_user(db_session: Session) -> User:
    """Insert and return a throwaway admin user for FK references
    (triggered_by, officer_id, etc.).

    The email carries a random suffix and the password is a hashed, freshly
    generated value, so each invocation creates a distinct row.
    """
    hasher = PasswordHelper()
    random_password = hasher.generate()

    new_user = User(
        id=uuid4(),
        email=f"pr_test_{uuid4().hex[:8]}@example.com",
        hashed_password=hasher.hash(random_password),
        is_active=True,
        is_superuser=False,
        is_verified=True,
        role=UserRole.ADMIN,
        account_type=AccountType.STANDARD,
    )
    db_session.add(new_user)
    db_session.commit()
    # Reload the row so DB-populated attributes are present on the instance.
    db_session.refresh(new_user)
    return new_user
|
||||
@@ -1,328 +0,0 @@
|
||||
"""Integration tests for per-finding decisions, proposal decisions, and config."""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import User
|
||||
from onyx.server.features.proposal_review.db.config import get_config
|
||||
from onyx.server.features.proposal_review.db.config import upsert_config
|
||||
from onyx.server.features.proposal_review.db.decisions import (
|
||||
mark_proposal_jira_synced,
|
||||
)
|
||||
from onyx.server.features.proposal_review.db.decisions import (
|
||||
update_proposal_decision,
|
||||
)
|
||||
from onyx.server.features.proposal_review.db.decisions import (
|
||||
upsert_finding_decision,
|
||||
)
|
||||
from onyx.server.features.proposal_review.db.findings import create_finding
|
||||
from onyx.server.features.proposal_review.db.findings import create_review_run
|
||||
from onyx.server.features.proposal_review.db.findings import get_finding
|
||||
from onyx.server.features.proposal_review.db.proposals import get_or_create_proposal
|
||||
from onyx.server.features.proposal_review.db.rulesets import create_rule
|
||||
from onyx.server.features.proposal_review.db.rulesets import create_ruleset
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# Short alias for the test tenant id passed to the DB helpers in this module.
TENANT = TEST_TENANT_ID
|
||||
|
||||
|
||||
def _make_finding(db_session: Session, test_user: User):
    """Create and commit a full chain (ruleset -> rule -> proposal -> run -> finding).

    Returns a ``(finding, proposal)`` tuple. The ruleset name and the proposal's
    document id carry random suffixes so repeated calls do not collide.
    """
    ruleset = create_ruleset(
        tenant_id=TENANT,
        name=f"RS-{uuid4().hex[:6]}",
        db_session=db_session,
        created_by=test_user.id,
    )
    rule = create_rule(
        ruleset_id=ruleset.id,
        name="Test Rule",
        rule_type="DOCUMENT_CHECK",
        prompt_template="{{proposal_text}}",
        db_session=db_session,
    )

    document_id = f"doc-{uuid4().hex[:8]}"
    proposal = get_or_create_proposal(document_id, TENANT, db_session)

    review_run = create_review_run(
        proposal_id=proposal.id,
        ruleset_id=ruleset.id,
        triggered_by=test_user.id,
        total_rules=1,
        db_session=db_session,
    )
    finding = create_finding(
        proposal_id=proposal.id,
        rule_id=rule.id,
        review_run_id=review_run.id,
        verdict="FAIL",
        db_session=db_session,
    )
    db_session.commit()
    return finding, proposal
|
||||
|
||||
|
||||
class TestFindingDecision:
    """Checks for ``upsert_finding_decision`` against a freshly created finding."""

    def test_create_finding_decision(
        self, db_session: Session, test_user: User
    ) -> None:
        """A first decision is stored with its action, notes and a timestamp."""
        finding, _proposal = _make_finding(db_session, test_user)

        decided = upsert_finding_decision(
            finding_id=finding.id,
            officer_id=test_user.id,
            action="VERIFIED",
            db_session=db_session,
            notes="Looks good",
        )
        db_session.commit()

        assert decided.id == finding.id
        assert decided.decision_action == "VERIFIED"
        assert decided.decision_notes == "Looks good"
        assert decided.decided_at is not None

    def test_upsert_overwrites_previous_decision(
        self, db_session: Session, test_user: User
    ) -> None:
        """A second upsert replaces the earlier decision on the same finding."""
        finding, _proposal = _make_finding(db_session, test_user)

        # First decision, committed before it is overwritten.
        upsert_finding_decision(
            finding_id=finding.id,
            officer_id=test_user.id,
            action="VERIFIED",
            db_session=db_session,
        )
        db_session.commit()

        replaced = upsert_finding_decision(
            finding_id=finding.id,
            officer_id=test_user.id,
            action="ISSUE",
            db_session=db_session,
            notes="Actually, this is a problem",
        )
        db_session.commit()

        # Same row was updated
        assert replaced.id == finding.id
        assert replaced.decision_action == "ISSUE"
        assert replaced.decision_notes == "Actually, this is a problem"

    def test_finding_decision_accessible_via_finding(
        self, db_session: Session, test_user: User
    ) -> None:
        """A stored decision is visible when the finding is fetched again."""
        finding, _proposal = _make_finding(db_session, test_user)

        upsert_finding_decision(
            finding_id=finding.id,
            officer_id=test_user.id,
            action="OVERRIDDEN",
            db_session=db_session,
        )
        db_session.commit()

        reloaded = get_finding(finding.id, db_session)
        assert reloaded is not None
        assert reloaded.decision_action == "OVERRIDDEN"
        assert reloaded.decision_officer_id == test_user.id

    def test_finding_has_no_decision_by_default(
        self, db_session: Session, test_user: User
    ) -> None:
        """A new finding carries neither a decision action nor a decision time."""
        finding, _proposal = _make_finding(db_session, test_user)
        assert finding.decision_action is None
        assert finding.decided_at is None
|
||||
|
||||
|
||||
class TestProposalDecision:
    """Tests for proposal-level decisions and the Jira sync flag."""

    def test_update_proposal_decision(
        self, db_session: Session, test_user: User
    ) -> None:
        """Deciding a proposal records status, notes, officer and timestamp."""
        document_id = f"doc-{uuid4().hex[:8]}"
        proposal = get_or_create_proposal(document_id, TENANT, db_session)
        db_session.commit()

        decided = update_proposal_decision(
            proposal_id=proposal.id,
            tenant_id=TENANT,
            officer_id=test_user.id,
            decision="APPROVED",
            db_session=db_session,
            notes="All checks pass",
        )
        db_session.commit()

        assert decided.status == "APPROVED"
        assert decided.decision_notes == "All checks pass"
        assert decided.decision_officer_id == test_user.id
        assert decided.decision_at is not None
        assert decided.jira_synced is False

    def test_proposal_decision_overwrites_previous(
        self, db_session: Session, test_user: User
    ) -> None:
        """A later decision replaces the earlier status."""
        document_id = f"doc-{uuid4().hex[:8]}"
        proposal = get_or_create_proposal(document_id, TENANT, db_session)
        db_session.commit()

        update_proposal_decision(
            proposal_id=proposal.id,
            tenant_id=TENANT,
            officer_id=test_user.id,
            decision="CHANGES_REQUESTED",
            db_session=db_session,
        )
        db_session.commit()

        redecided = update_proposal_decision(
            proposal_id=proposal.id,
            tenant_id=TENANT,
            officer_id=test_user.id,
            decision="APPROVED",
            db_session=db_session,
        )
        db_session.commit()

        assert redecided.status == "APPROVED"

    def test_mark_proposal_jira_synced(
        self, db_session: Session, test_user: User
    ) -> None:
        """Marking a decided proposal as synced sets the flag and timestamp."""
        document_id = f"doc-{uuid4().hex[:8]}"
        proposal = get_or_create_proposal(document_id, TENANT, db_session)
        update_proposal_decision(
            proposal_id=proposal.id,
            tenant_id=TENANT,
            officer_id=test_user.id,
            decision="APPROVED",
            db_session=db_session,
        )
        db_session.commit()

        # Not synced until explicitly marked.
        assert proposal.jira_synced is False

        marked = mark_proposal_jira_synced(proposal.id, TENANT, db_session)
        db_session.commit()

        assert marked is not None
        assert marked.jira_synced is True
        assert marked.jira_synced_at is not None

    def test_mark_jira_synced_returns_none_for_nonexistent(
        self, db_session: Session
    ) -> None:
        """Marking a random, never-created proposal id returns ``None``."""
        assert mark_proposal_jira_synced(uuid4(), TENANT, db_session) is None

    def test_new_decision_resets_jira_synced(
        self, db_session: Session, test_user: User
    ) -> None:
        """A fresh decision clears the sync flag so it can be synced again."""
        document_id = f"doc-{uuid4().hex[:8]}"
        proposal = get_or_create_proposal(document_id, TENANT, db_session)
        update_proposal_decision(
            proposal_id=proposal.id,
            tenant_id=TENANT,
            officer_id=test_user.id,
            decision="APPROVED",
            db_session=db_session,
        )
        mark_proposal_jira_synced(proposal.id, TENANT, db_session)
        db_session.commit()
        assert proposal.jira_synced is True

        # Decide again with a different outcome.
        update_proposal_decision(
            proposal_id=proposal.id,
            tenant_id=TENANT,
            officer_id=test_user.id,
            decision="REJECTED",
            db_session=db_session,
        )
        db_session.commit()

        # Reload from the database before checking the reset fields.
        db_session.refresh(proposal)
        assert proposal.jira_synced is False
        assert proposal.jira_synced_at is None
        assert proposal.status == "REJECTED"
|
||||
|
||||
|
||||
class TestConfig:
    """Tests for the per-tenant config upsert and lookup."""

    def test_create_config(self, db_session: Session) -> None:
        """Upserting for a new tenant stores the supplied values."""
        # A per-test tenant id keeps this row from colliding with other tests.
        tenant_id = f"test-tenant-{uuid4().hex[:8]}"

        created = upsert_config(
            tenant_id=tenant_id,
            db_session=db_session,
            jira_project_key="PROJ",
            field_mapping={"title": "summary", "budget": "customfield_10001"},
        )
        db_session.commit()

        assert created.id is not None
        assert created.tenant_id == tenant_id
        assert created.jira_project_key == "PROJ"
        assert created.field_mapping == {
            "title": "summary",
            "budget": "customfield_10001",
        }

    def test_upsert_config_updates_existing(self, db_session: Session) -> None:
        """A second upsert for the same tenant returns the same id with new values."""
        tenant_id = f"test-tenant-{uuid4().hex[:8]}"

        original = upsert_config(
            tenant_id=tenant_id,
            db_session=db_session,
            jira_project_key="OLD",
        )
        db_session.commit()
        original_id = original.id

        replacement = upsert_config(
            tenant_id=tenant_id,
            db_session=db_session,
            jira_project_key="NEW",
            field_mapping={"x": "y"},
        )
        db_session.commit()

        assert replacement.id == original_id
        assert replacement.jira_project_key == "NEW"
        assert replacement.field_mapping == {"x": "y"}

    def test_get_config_returns_correct_tenant(self, db_session: Session) -> None:
        """Looking up by tenant id returns the config stored for that tenant."""
        tenant_id = f"test-tenant-{uuid4().hex[:8]}"

        upsert_config(
            tenant_id=tenant_id,
            db_session=db_session,
            jira_project_key="ABC",
            jira_writeback={"status_field": "customfield_20001"},
        )
        db_session.commit()

        loaded = get_config(tenant_id, db_session)
        assert loaded is not None
        assert loaded.jira_project_key == "ABC"
        assert loaded.jira_writeback == {"status_field": "customfield_20001"}

    def test_get_config_returns_none_for_unknown_tenant(
        self, db_session: Session
    ) -> None:
        """Looking up a tenant id that was never configured yields ``None``."""
        assert get_config(f"nonexistent-{uuid4().hex[:8]}", db_session) is None

    def test_upsert_config_preserves_unset_fields(self, db_session: Session) -> None:
        """Fields omitted from a later upsert keep their earlier values."""
        tenant_id = f"test-tenant-{uuid4().hex[:8]}"

        upsert_config(
            tenant_id=tenant_id,
            db_session=db_session,
            jira_project_key="KEEP",
            jira_connector_id=42,
        )
        db_session.commit()

        # Only field_mapping is supplied here; the other fields are omitted.
        upsert_config(
            tenant_id=tenant_id,
            db_session=db_session,
            field_mapping={"a": "b"},
        )
        db_session.commit()

        loaded = get_config(tenant_id, db_session)
        assert loaded is not None
        assert loaded.jira_project_key == "KEEP"
        assert loaded.jira_connector_id == 42
        assert loaded.field_mapping == {"a": "b"}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user