mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-17 07:26:45 +00:00
Compare commits
6 Commits
v3.2.5
...
jamison/sh
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b32e2fd304 | ||
|
|
4a96ef13d7 | ||
|
|
822b0c99be | ||
|
|
bcf2851a85 | ||
|
|
a5a59bd8f0 | ||
|
|
32d2e7985a |
65
.devcontainer/Dockerfile
Normal file
65
.devcontainer/Dockerfile
Normal file
@@ -0,0 +1,65 @@
|
||||
FROM ubuntu:26.04@sha256:cc925e589b7543b910fea57a240468940003fbfc0515245a495dd0ad8fe7cef1
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
acl \
|
||||
curl \
|
||||
fd-find \
|
||||
fzf \
|
||||
git \
|
||||
jq \
|
||||
less \
|
||||
make \
|
||||
neovim \
|
||||
openssh-client \
|
||||
python3-venv \
|
||||
ripgrep \
|
||||
sudo \
|
||||
ca-certificates \
|
||||
iptables \
|
||||
ipset \
|
||||
iproute2 \
|
||||
dnsutils \
|
||||
unzip \
|
||||
wget \
|
||||
zsh \
|
||||
&& curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \
|
||||
&& apt-get install -y nodejs \
|
||||
&& install -m 0755 -d /etc/apt/keyrings \
|
||||
&& curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc \
|
||||
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu $(. /etc/os-release && echo "$VERSION_CODENAME") stable" > /etc/apt/sources.list.d/docker.list \
|
||||
&& curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg -o /etc/apt/keyrings/githubcli-archive-keyring.gpg \
|
||||
&& chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \
|
||||
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" > /etc/apt/sources.list.d/github-cli.list \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends docker-ce-cli docker-compose-plugin gh \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# fd-find installs as fdfind on Debian/Ubuntu — symlink to fd
|
||||
RUN ln -sf "$(which fdfind)" /usr/local/bin/fd
|
||||
|
||||
# Install uv (Python package manager)
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/
|
||||
|
||||
# Create non-root dev user with passwordless sudo
|
||||
RUN useradd -m -s /bin/zsh dev && \
|
||||
echo "dev ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/dev && \
|
||||
chmod 0440 /etc/sudoers.d/dev
|
||||
|
||||
ENV DEVCONTAINER=true
|
||||
|
||||
RUN mkdir -p /workspace && \
|
||||
chown -R dev:dev /workspace
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
# Install Claude Code
|
||||
ARG CLAUDE_CODE_VERSION=latest
|
||||
RUN npm install -g @anthropic-ai/claude-code@${CLAUDE_CODE_VERSION}
|
||||
|
||||
# Configure zsh — source the repo-local zshrc so shell customization
|
||||
# doesn't require an image rebuild.
|
||||
RUN chsh -s /bin/zsh root && \
|
||||
for rc in /root/.zshrc /home/dev/.zshrc; do \
|
||||
echo '[ -f /workspace/.devcontainer/zshrc ] && . /workspace/.devcontainer/zshrc' >> "$rc"; \
|
||||
done && \
|
||||
chown dev:dev /home/dev/.zshrc
|
||||
126
.devcontainer/README.md
Normal file
126
.devcontainer/README.md
Normal file
@@ -0,0 +1,126 @@
|
||||
# Onyx Dev Container
|
||||
|
||||
A containerized development environment for working on Onyx.
|
||||
|
||||
## What's included
|
||||
|
||||
- Ubuntu 26.04 base image
|
||||
- Node.js 20, uv, Claude Code
|
||||
- Docker CLI, GitHub CLI (`gh`)
|
||||
- Neovim, ripgrep, fd, fzf, jq, make, wget, unzip
|
||||
- Zsh as default shell (sources host `~/.zshrc` if available)
|
||||
- Python venv auto-activation
|
||||
- Network firewall (default-deny, whitelists npm, GitHub, Anthropic APIs, Sentry, and VS Code update servers)
|
||||
|
||||
## Usage
|
||||
|
||||
### VS Code
|
||||
|
||||
1. Install the [Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers)
|
||||
2. Open this repo in VS Code
|
||||
3. "Reopen in Container" when prompted
|
||||
|
||||
### CLI (`ods dev`)
|
||||
|
||||
The [`ods` devtools CLI](../tools/ods/README.md) provides workspace-aware wrappers
|
||||
for all devcontainer operations (also available as `ods dc`):
|
||||
|
||||
```bash
|
||||
# Start the container
|
||||
ods dev up
|
||||
|
||||
# Open a shell
|
||||
ods dev into
|
||||
|
||||
# Run a command
|
||||
ods dev exec npm test
|
||||
|
||||
# Stop the container
|
||||
ods dev stop
|
||||
```
|
||||
|
||||
If you don't have `ods` installed, use the `devcontainer` CLI directly:
|
||||
|
||||
```bash
|
||||
npm install -g @devcontainers/cli
|
||||
|
||||
devcontainer up --workspace-folder .
|
||||
devcontainer exec --workspace-folder . zsh
|
||||
```
|
||||
|
||||
## Restarting the container
|
||||
|
||||
### VS Code
|
||||
|
||||
Open the Command Palette (`Ctrl+Shift+P` / `Cmd+Shift+P`) and run:
|
||||
|
||||
- **Dev Containers: Reopen in Container** — restarts the container without rebuilding
|
||||
|
||||
### CLI
|
||||
|
||||
```bash
|
||||
# Restart the container
|
||||
ods dev restart
|
||||
|
||||
# Pull the latest published image and recreate
|
||||
ods dev rebuild
|
||||
```
|
||||
|
||||
Or without `ods`:
|
||||
|
||||
```bash
|
||||
devcontainer up --workspace-folder . --remove-existing-container
|
||||
```
|
||||
|
||||
## Image
|
||||
|
||||
The devcontainer uses a prebuilt image published to `onyxdotapp/onyx-devcontainer`.
|
||||
The tag is pinned in `devcontainer.json` — no local build is required.
|
||||
|
||||
To build the image locally (e.g. while iterating on the Dockerfile):
|
||||
|
||||
```bash
|
||||
docker buildx bake devcontainer
|
||||
```
|
||||
|
||||
The `devcontainer` target is defined in `docker-bake.hcl` at the repo root.
|
||||
|
||||
## User & permissions
|
||||
|
||||
The container runs as the `dev` user by default (`remoteUser` in devcontainer.json).
|
||||
An init script (`init-dev-user.sh`) runs at container start to ensure `dev` has
|
||||
read/write access to the bind-mounted workspace:
|
||||
|
||||
- **Standard Docker** — `dev`'s UID/GID is remapped to match the workspace owner,
|
||||
so file permissions work seamlessly.
|
||||
- **Rootless Docker** — The workspace appears as root-owned (UID 0) inside the
|
||||
container due to user-namespace mapping. The init script grants `dev` access via
|
||||
POSIX ACLs (`setfacl`), which adds a few seconds to the first container start on
|
||||
large repos.
|
||||
|
||||
## Docker socket
|
||||
|
||||
The container mounts the host's Docker socket so you can run `docker` commands
|
||||
from inside. `ods dev` auto-detects the socket path and sets `DOCKER_SOCK`:
|
||||
|
||||
| Environment | Socket path |
|
||||
| ----------------------- | ------------------------------ |
|
||||
| Linux (rootless Docker) | `$XDG_RUNTIME_DIR/docker.sock` |
|
||||
| macOS (Docker Desktop) | `~/.docker/run/docker.sock` |
|
||||
| Linux (standard Docker) | `/var/run/docker.sock` |
|
||||
|
||||
To override, set `DOCKER_SOCK` before running `ods dev up`. When using the
|
||||
VS Code extension or `devcontainer` CLI directly (without `ods`), you must set
|
||||
`DOCKER_SOCK` yourself.
|
||||
|
||||
## Firewall
|
||||
|
||||
The container starts with a default-deny firewall (`init-firewall.sh`) that only allows outbound traffic to:
|
||||
|
||||
- npm registry
|
||||
- GitHub
|
||||
- Anthropic API
|
||||
- Sentry
|
||||
- VS Code update servers
|
||||
|
||||
This requires the `NET_ADMIN` and `NET_RAW` capabilities, which are added via `runArgs` in `devcontainer.json`.
|
||||
22
.devcontainer/devcontainer.json
Normal file
22
.devcontainer/devcontainer.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"name": "Onyx Dev Sandbox",
|
||||
"image": "onyxdotapp/onyx-devcontainer@sha256:12184169c5bcc9cca0388286d5ffe504b569bc9c37bfa631b76ee8eee2064055",
|
||||
"runArgs": ["--cap-add=NET_ADMIN", "--cap-add=NET_RAW"],
|
||||
"mounts": [
|
||||
"source=${localEnv:DOCKER_SOCK},target=/var/run/docker.sock,type=bind",
|
||||
"source=${localEnv:HOME}/.claude,target=/home/dev/.claude,type=bind",
|
||||
"source=${localEnv:HOME}/.claude.json,target=/home/dev/.claude.json,type=bind",
|
||||
"source=${localEnv:HOME}/.zshrc,target=/home/dev/.zshrc.host,type=bind,readonly",
|
||||
"source=${localEnv:HOME}/.gitconfig,target=/home/dev/.gitconfig.host,type=bind,readonly",
|
||||
"source=${localEnv:HOME}/.ssh,target=/home/dev/.ssh.host,type=bind,readonly",
|
||||
"source=${localEnv:HOME}/.config/nvim,target=/home/dev/.config/nvim.host,type=bind,readonly",
|
||||
"source=onyx-devcontainer-cache,target=/home/dev/.cache,type=volume",
|
||||
"source=onyx-devcontainer-local,target=/home/dev/.local,type=volume"
|
||||
],
|
||||
"remoteUser": "dev",
|
||||
"updateRemoteUserUID": false,
|
||||
"workspaceMount": "source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=delegated",
|
||||
"workspaceFolder": "/workspace",
|
||||
"postStartCommand": "sudo bash /workspace/.devcontainer/init-dev-user.sh && sudo bash /workspace/.devcontainer/init-firewall.sh",
|
||||
"waitFor": "postStartCommand"
|
||||
}
|
||||
106
.devcontainer/init-dev-user.sh
Normal file
106
.devcontainer/init-dev-user.sh
Normal file
@@ -0,0 +1,106 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Remap the dev user's UID/GID to match the workspace owner so that
|
||||
# bind-mounted files are accessible without running as root.
|
||||
#
|
||||
# Standard Docker: Workspace is owned by the host user's UID (e.g. 1000).
|
||||
# We remap dev to that UID -- fast and seamless.
|
||||
#
|
||||
# Rootless Docker: Workspace appears as root-owned (UID 0) inside the
|
||||
# container due to user-namespace mapping. We can't remap
|
||||
# dev to UID 0 (that's root), so we grant access with
|
||||
# POSIX ACLs instead.
|
||||
|
||||
WORKSPACE=/workspace
|
||||
TARGET_USER=dev
|
||||
|
||||
WS_UID=$(stat -c '%u' "$WORKSPACE")
|
||||
WS_GID=$(stat -c '%g' "$WORKSPACE")
|
||||
DEV_UID=$(id -u "$TARGET_USER")
|
||||
DEV_GID=$(id -g "$TARGET_USER")
|
||||
|
||||
DEV_HOME=/home/"$TARGET_USER"
|
||||
|
||||
# Ensure directories that tools expect exist under ~dev.
|
||||
# ~/.local and ~/.cache are named Docker volumes -- ensure they are owned by dev.
|
||||
mkdir -p "$DEV_HOME"/.local/state "$DEV_HOME"/.local/share
|
||||
chown -R "$TARGET_USER":"$TARGET_USER" "$DEV_HOME"/.local
|
||||
chown -R "$TARGET_USER":"$TARGET_USER" "$DEV_HOME"/.cache
|
||||
|
||||
# Copy host configs mounted as *.host into their real locations.
|
||||
# This gives the dev user owned copies without touching host originals.
|
||||
if [ -d "$DEV_HOME/.ssh.host" ]; then
|
||||
cp -a "$DEV_HOME/.ssh.host" "$DEV_HOME/.ssh"
|
||||
chmod 700 "$DEV_HOME/.ssh"
|
||||
chmod 600 "$DEV_HOME"/.ssh/id_* 2>/dev/null || true
|
||||
chown -R "$TARGET_USER":"$TARGET_USER" "$DEV_HOME/.ssh"
|
||||
fi
|
||||
if [ -d "$DEV_HOME/.config/nvim.host" ]; then
|
||||
mkdir -p "$DEV_HOME/.config"
|
||||
cp -a "$DEV_HOME/.config/nvim.host" "$DEV_HOME/.config/nvim"
|
||||
chown -R "$TARGET_USER":"$TARGET_USER" "$DEV_HOME/.config/nvim"
|
||||
fi
|
||||
|
||||
# Already matching -- nothing to do.
|
||||
if [ "$WS_UID" = "$DEV_UID" ] && [ "$WS_GID" = "$DEV_GID" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [ "$WS_UID" != "0" ]; then
|
||||
# ── Standard Docker ──────────────────────────────────────────────
|
||||
# Workspace is owned by a non-root UID (the host user).
|
||||
# Remap dev's UID/GID to match.
|
||||
if [ "$DEV_GID" != "$WS_GID" ]; then
|
||||
if ! groupmod -g "$WS_GID" "$TARGET_USER" 2>&1; then
|
||||
echo "warning: failed to remap $TARGET_USER GID to $WS_GID" >&2
|
||||
fi
|
||||
fi
|
||||
if [ "$DEV_UID" != "$WS_UID" ]; then
|
||||
if ! usermod -u "$WS_UID" -g "$WS_GID" "$TARGET_USER" 2>&1; then
|
||||
echo "warning: failed to remap $TARGET_USER UID to $WS_UID" >&2
|
||||
fi
|
||||
fi
|
||||
if ! chown -R "$TARGET_USER":"$TARGET_USER" /home/"$TARGET_USER" 2>&1; then
|
||||
echo "warning: failed to chown /home/$TARGET_USER" >&2
|
||||
fi
|
||||
else
|
||||
# ── Rootless Docker ──────────────────────────────────────────────
|
||||
# Workspace is root-owned inside the container. Grant dev access
|
||||
# via POSIX ACLs (preserves ownership, works across the namespace
|
||||
# boundary).
|
||||
if command -v setfacl &>/dev/null; then
|
||||
setfacl -Rm "u:${TARGET_USER}:rwX" "$WORKSPACE"
|
||||
setfacl -Rdm "u:${TARGET_USER}:rwX" "$WORKSPACE" # default ACL for new files
|
||||
|
||||
# Git refuses to operate in repos owned by a different UID.
|
||||
# Host gitconfig is mounted readonly as ~/.gitconfig.host.
|
||||
# Create a real ~/.gitconfig that includes it plus container overrides.
|
||||
printf '[include]\n\tpath = %s/.gitconfig.host\n[safe]\n\tdirectory = %s\n' \
|
||||
"$DEV_HOME" "$WORKSPACE" > "$DEV_HOME/.gitconfig"
|
||||
chown "$TARGET_USER":"$TARGET_USER" "$DEV_HOME/.gitconfig"
|
||||
|
||||
# If this is a worktree, the main .git dir is bind-mounted at its
|
||||
# host absolute path. Grant dev access so git operations work.
|
||||
GIT_COMMON_DIR=$(git -C "$WORKSPACE" rev-parse --git-common-dir 2>/dev/null || true)
|
||||
if [ -n "$GIT_COMMON_DIR" ] && [ "$GIT_COMMON_DIR" != "$WORKSPACE/.git" ]; then
|
||||
[ ! -d "$GIT_COMMON_DIR" ] && GIT_COMMON_DIR="$WORKSPACE/$GIT_COMMON_DIR"
|
||||
if [ -d "$GIT_COMMON_DIR" ]; then
|
||||
setfacl -Rm "u:${TARGET_USER}:rwX" "$GIT_COMMON_DIR"
|
||||
setfacl -Rdm "u:${TARGET_USER}:rwX" "$GIT_COMMON_DIR"
|
||||
git config -f "$DEV_HOME/.gitconfig" --add safe.directory "$(dirname "$GIT_COMMON_DIR")"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Also fix bind-mounted dirs under ~dev that appear root-owned.
|
||||
dir="/home/${TARGET_USER}/.claude"
|
||||
if [ -d "$dir" ]; then
|
||||
setfacl -Rm "u:${TARGET_USER}:rwX" "$dir" && setfacl -Rdm "u:${TARGET_USER}:rwX" "$dir"
|
||||
fi
|
||||
[ -f /home/"$TARGET_USER"/.claude.json ] && \
|
||||
setfacl -m "u:${TARGET_USER}:rw" /home/"$TARGET_USER"/.claude.json
|
||||
else
|
||||
echo "warning: setfacl not found; dev user may not have write access to workspace" >&2
|
||||
echo " install the 'acl' package or set remoteUser to root" >&2
|
||||
fi
|
||||
fi
|
||||
104
.devcontainer/init-firewall.sh
Executable file
104
.devcontainer/init-firewall.sh
Executable file
@@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
echo "Setting up firewall..."
|
||||
|
||||
# Preserve docker dns resolution
|
||||
DOCKER_DNS_RULES=$(iptables-save | grep -E "^-A.*-d 127.0.0.11/32" || true)
|
||||
|
||||
# Flush all rules
|
||||
iptables -t nat -F
|
||||
iptables -t nat -X
|
||||
iptables -t mangle -F
|
||||
iptables -t mangle -X
|
||||
iptables -F
|
||||
iptables -X
|
||||
|
||||
# Restore docker dns rules
|
||||
if [ -n "$DOCKER_DNS_RULES" ]; then
|
||||
echo "$DOCKER_DNS_RULES" | iptables-restore -n
|
||||
fi
|
||||
|
||||
# Create ipset for allowed destinations
|
||||
ipset create allowed-domains hash:net || true
|
||||
ipset flush allowed-domains
|
||||
|
||||
# Fetch GitHub IP ranges (IPv4 only -- ipset hash:net and iptables are IPv4)
|
||||
GITHUB_IPS=$(curl -s https://api.github.com/meta | jq -r '.api[]' 2>/dev/null | grep -v ':' || echo "")
|
||||
for ip in $GITHUB_IPS; do
|
||||
if ! ipset add allowed-domains "$ip" -exist 2>&1; then
|
||||
echo "warning: failed to add GitHub IP $ip to allowlist" >&2
|
||||
fi
|
||||
done
|
||||
|
||||
# Resolve allowed domains
|
||||
ALLOWED_DOMAINS=(
|
||||
"registry.npmjs.org"
|
||||
"api.anthropic.com"
|
||||
"api-staging.anthropic.com"
|
||||
"files.anthropic.com"
|
||||
"sentry.io"
|
||||
"update.code.visualstudio.com"
|
||||
"pypi.org"
|
||||
"files.pythonhosted.org"
|
||||
"go.dev"
|
||||
"storage.googleapis.com"
|
||||
"static.rust-lang.org"
|
||||
)
|
||||
|
||||
for domain in "${ALLOWED_DOMAINS[@]}"; do
|
||||
IPS=$(getent ahosts "$domain" 2>/dev/null | awk '{print $1}' | grep -v ':' | sort -u || echo "")
|
||||
for ip in $IPS; do
|
||||
if ! ipset add allowed-domains "$ip/32" -exist 2>&1; then
|
||||
echo "warning: failed to add $domain ($ip) to allowlist" >&2
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
# Detect host network
|
||||
if [[ "${DOCKER_HOST:-}" == "unix://"* ]]; then
|
||||
DOCKER_GATEWAY=$(ip -4 route show | grep "^default" | awk '{print $3}')
|
||||
if ! ipset add allowed-domains "$DOCKER_GATEWAY/32" -exist 2>&1; then
|
||||
echo "warning: failed to add Docker gateway $DOCKER_GATEWAY to allowlist" >&2
|
||||
fi
|
||||
fi
|
||||
|
||||
# Set default policies to DROP
|
||||
iptables -P FORWARD DROP
|
||||
iptables -P INPUT DROP
|
||||
iptables -P OUTPUT DROP
|
||||
|
||||
# Allow established connections
|
||||
iptables -A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
|
||||
iptables -A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
|
||||
|
||||
# Allow loopback
|
||||
iptables -A INPUT -i lo -j ACCEPT
|
||||
iptables -A OUTPUT -o lo -j ACCEPT
|
||||
|
||||
# Allow DNS
|
||||
iptables -A OUTPUT -p udp --dport 53 -j ACCEPT
|
||||
iptables -A OUTPUT -p tcp --dport 53 -j ACCEPT
|
||||
|
||||
# Allow outbound to allowed destinations
|
||||
iptables -A OUTPUT -m set --match-set allowed-domains dst -j ACCEPT
|
||||
|
||||
# Reject unauthorized outbound
|
||||
iptables -A OUTPUT -j REJECT --reject-with icmp-host-unreachable
|
||||
|
||||
# Validate firewall configuration
|
||||
echo "Validating firewall configuration..."
|
||||
|
||||
BLOCKED_SITES=("example.com" "google.com" "facebook.com")
|
||||
for site in "${BLOCKED_SITES[@]}"; do
|
||||
if timeout 2 ping -c 1 "$site" &>/dev/null; then
|
||||
echo "Warning: $site is still reachable"
|
||||
fi
|
||||
done
|
||||
|
||||
if ! timeout 5 curl -s https://api.github.com/meta > /dev/null; then
|
||||
echo "Warning: GitHub API is not accessible"
|
||||
fi
|
||||
|
||||
echo "Firewall setup complete"
|
||||
10
.devcontainer/zshrc
Normal file
10
.devcontainer/zshrc
Normal file
@@ -0,0 +1,10 @@
|
||||
# Devcontainer zshrc — sourced automatically for both root and dev users.
|
||||
# Edit this file to customize the shell without rebuilding the image.
|
||||
|
||||
# Auto-activate Python venv
|
||||
if [ -f /workspace/.venv/bin/activate ]; then
|
||||
. /workspace/.venv/bin/activate
|
||||
fi
|
||||
|
||||
# Source host zshrc if bind-mounted
|
||||
[ -f ~/.zshrc.host ] && . ~/.zshrc.host
|
||||
@@ -86,6 +86,17 @@ repos:
|
||||
hooks:
|
||||
- id: actionlint
|
||||
|
||||
- repo: https://github.com/shellcheck-py/shellcheck-py
|
||||
rev: 745eface02aef23e168a8afb6b5737818efbea95 # frozen: v0.11.0.1
|
||||
hooks:
|
||||
- id: shellcheck
|
||||
exclude: >-
|
||||
(?x)^(
|
||||
backend/scripts/setup_craft_templates\.sh|
|
||||
deployment/docker_compose/init-letsencrypt\.sh|
|
||||
deployment/docker_compose/install\.sh
|
||||
)$
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 8a737e727ac5ab2f1d4cf5876720ed276dc8dc4b # frozen: 25.1.0
|
||||
hooks:
|
||||
|
||||
@@ -13,7 +13,6 @@ from ee.onyx.server.license.models import LicenseSource
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.cache.factory import get_cache_backend
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import License
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -108,13 +107,12 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
Get current seat usage directly from database.
|
||||
|
||||
For multi-tenant: counts users in UserTenantMapping for this tenant.
|
||||
For self-hosted: counts all active users.
|
||||
For self-hosted: counts all active users (excludes EXT_PERM_USER role
|
||||
and the anonymous system user).
|
||||
|
||||
Only human accounts count toward seat limits.
|
||||
SERVICE_ACCOUNT (API key dummy users), EXT_PERM_USER, and the
|
||||
anonymous system user are excluded. BOT (Slack users) ARE counted
|
||||
because they represent real humans and get upgraded to STANDARD
|
||||
when they log in via web.
|
||||
TODO: Exclude API key dummy users from seat counting. API keys create
|
||||
users with emails like `__DANSWER_API_KEY_*` that should not count toward
|
||||
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_count
|
||||
@@ -131,7 +129,6 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
User.role != UserRole.EXT_PERM_USER,
|
||||
User.email != ANONYMOUS_USER_EMAIL, # type: ignore
|
||||
User.account_type != AccountType.SERVICE_ACCOUNT,
|
||||
)
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
@@ -11,8 +11,6 @@ require a valid SCIM bearer token.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -24,7 +22,6 @@ from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -68,25 +65,12 @@ from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Group names reserved for system default groups (seeded by migration).
|
||||
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
|
||||
|
||||
# Namespace prefix for the seat-allocation advisory lock. Hashed together
|
||||
# with the tenant ID so the lock is scoped per-tenant (unrelated tenants
|
||||
# never block each other) and cannot collide with unrelated advisory locks.
|
||||
_SEAT_LOCK_NAMESPACE = "onyx_scim_seat_lock"
|
||||
|
||||
|
||||
def _seat_lock_id_for_tenant(tenant_id: str) -> int:
|
||||
"""Derive a stable 64-bit signed int lock id for this tenant's seat lock."""
|
||||
digest = hashlib.sha256(f"{_SEAT_LOCK_NAMESPACE}:{tenant_id}".encode()).digest()
|
||||
# pg_advisory_xact_lock takes a signed 8-byte int; unpack as such.
|
||||
return struct.unpack("q", digest[:8])[0]
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
@@ -225,37 +209,12 @@ def _apply_exclusions(
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None.
|
||||
|
||||
Acquires a transaction-scoped advisory lock so that concurrent
|
||||
SCIM requests are serialized. IdPs like Okta send provisioning
|
||||
requests in parallel batches — without serialization the check is
|
||||
vulnerable to a TOCTOU race where N concurrent requests each see
|
||||
"seats available", all insert, and the tenant ends up over its
|
||||
seat limit.
|
||||
|
||||
The lock is held until the caller's next COMMIT or ROLLBACK, which
|
||||
means the seat count cannot change between the check here and the
|
||||
subsequent INSERT/UPDATE. Each call site in this module follows
|
||||
the pattern: _check_seat_availability → write → dal.commit()
|
||||
(which releases the lock for the next waiting request).
|
||||
"""
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)
|
||||
if check_fn is None:
|
||||
return None
|
||||
|
||||
# Transaction-scoped advisory lock — released on dal.commit() / dal.rollback().
|
||||
# The lock id is derived from the tenant so unrelated tenants never block
|
||||
# each other, and from a namespace string so it cannot collide with
|
||||
# unrelated advisory locks elsewhere in the codebase.
|
||||
lock_id = _seat_lock_id_for_tenant(get_current_tenant_id())
|
||||
dal.session.execute(
|
||||
text("SELECT pg_advisory_xact_lock(:lock_id)"),
|
||||
{"lock_id": lock_id},
|
||||
)
|
||||
|
||||
result = check_fn(dal.session, seats_needed=1)
|
||||
if not result.available:
|
||||
return result.error_message or "Seat limit reached"
|
||||
|
||||
@@ -4,6 +4,8 @@ from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_utils import create_tool_call_failure_messages
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
@@ -633,6 +635,7 @@ def run_llm_loop(
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
forced_tool_id: int | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
@@ -1017,16 +1020,20 @@ def run_llm_loop(
|
||||
persisted_memory_id: int | None = None
|
||||
if user_memory_context and user_memory_context.user_id:
|
||||
if tool_response.rich_response.index_to_replace is not None:
|
||||
persisted_memory_id = update_memory_at_index(
|
||||
memory = update_memory_at_index(
|
||||
user_id=user_memory_context.user_id,
|
||||
index=tool_response.rich_response.index_to_replace,
|
||||
new_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id if memory else None
|
||||
else:
|
||||
persisted_memory_id = add_memory(
|
||||
memory = add_memory(
|
||||
user_id=user_memory_context.user_id,
|
||||
memory_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id
|
||||
operation: Literal["add", "update"] = (
|
||||
"update"
|
||||
if tool_response.rich_response.index_to_replace is not None
|
||||
|
||||
@@ -826,12 +826,6 @@ def translate_history_to_llm_format(
|
||||
base64_data = img_file.to_base64()
|
||||
image_url = f"data:{image_type};base64,{base64_data}"
|
||||
|
||||
content_parts.append(
|
||||
TextContentPart(
|
||||
type="text",
|
||||
text=f"[attached image — file_id: {img_file.file_id}]",
|
||||
)
|
||||
)
|
||||
image_part = ImageContentPart(
|
||||
type="image_url",
|
||||
image_url=ImageUrlDetail(
|
||||
|
||||
@@ -67,6 +67,7 @@ from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import reserve_multi_model_message_ids
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
@@ -1005,86 +1006,93 @@ def _run_models(
|
||||
model_llm = setup.llms[model_idx]
|
||||
|
||||
try:
|
||||
# Each function opens short-lived DB sessions on demand.
|
||||
# Do NOT pass a long-lived session here — it would hold a
|
||||
# connection for the entire LLM loop (minutes), and cloud
|
||||
# infrastructure may drop idle connections.
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=setup.new_msg_req.internal_search_filters,
|
||||
project_id_filter=setup.search_params.project_id_filter,
|
||||
persona_id_filter=setup.search_params.persona_id_filter,
|
||||
bypass_acl=setup.bypass_acl,
|
||||
slack_context=setup.slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
setup.persona, setup.new_msg_req.internal_search_filters
|
||||
),
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
message_id=setup.user_message.id,
|
||||
additional_headers=setup.custom_tool_additional_headers,
|
||||
mcp_headers=setup.mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=setup.available_files.user_file_ids,
|
||||
chat_file_ids=setup.available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=setup.search_params.search_usage,
|
||||
)
|
||||
model_tools = [
|
||||
tool for tool_list in thread_tool_dict.values() for tool in tool_list
|
||||
]
|
||||
|
||||
if setup.forced_tool_id and setup.forced_tool_id not in {
|
||||
tool.id for tool in model_tools
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Forced tool {setup.forced_tool_id} not found in tools"
|
||||
)
|
||||
|
||||
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
|
||||
if n_models == 1 and setup.new_msg_req.deep_research:
|
||||
if setup.chat_session.project_id:
|
||||
raise RuntimeError("Deep research is not supported for projects")
|
||||
run_deep_research_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
skip_clarification=setup.skip_clarification,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
context_files=setup.extracted_context_files,
|
||||
# Each worker opens its own session — SQLAlchemy sessions are not thread-safe.
|
||||
# Do NOT write to the outer db_session (or any shared DB state) from here;
|
||||
# all DB writes in this thread must go through thread_db_session.
|
||||
with get_session_with_current_tenant() as thread_db_session:
|
||||
thread_tool_dict = construct_tools(
|
||||
persona=setup.persona,
|
||||
user_memory_context=setup.user_memory_context,
|
||||
db_session=thread_db_session,
|
||||
emitter=model_emitter,
|
||||
user=user,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
forced_tool_id=setup.forced_tool_id,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
chat_files=setup.chat_files_for_tools,
|
||||
include_citations=setup.new_msg_req.include_citations,
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=setup.new_msg_req.internal_search_filters,
|
||||
project_id_filter=setup.search_params.project_id_filter,
|
||||
persona_id_filter=setup.search_params.persona_id_filter,
|
||||
bypass_acl=setup.bypass_acl,
|
||||
slack_context=setup.slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
setup.persona, setup.new_msg_req.internal_search_filters
|
||||
),
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=setup.chat_session.id,
|
||||
message_id=setup.user_message.id,
|
||||
additional_headers=setup.custom_tool_additional_headers,
|
||||
mcp_headers=setup.mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=setup.available_files.user_file_ids,
|
||||
chat_file_ids=setup.available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=setup.new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=setup.search_params.search_usage,
|
||||
)
|
||||
model_tools = [
|
||||
tool
|
||||
for tool_list in thread_tool_dict.values()
|
||||
for tool in tool_list
|
||||
]
|
||||
|
||||
if setup.forced_tool_id and setup.forced_tool_id not in {
|
||||
tool.id for tool in model_tools
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Forced tool {setup.forced_tool_id} not found in tools"
|
||||
)
|
||||
|
||||
# Per-thread copy: run_llm_loop mutates simple_chat_history in-place.
|
||||
if n_models == 1 and setup.new_msg_req.deep_research:
|
||||
if setup.chat_session.project_id:
|
||||
raise RuntimeError(
|
||||
"Deep research is not supported for projects"
|
||||
)
|
||||
run_deep_research_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
skip_clarification=setup.skip_clarification,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
run_llm_loop(
|
||||
emitter=model_emitter,
|
||||
state_container=sc,
|
||||
simple_chat_history=list(setup.simple_chat_history),
|
||||
tools=model_tools,
|
||||
custom_agent_prompt=setup.custom_agent_prompt,
|
||||
context_files=setup.extracted_context_files,
|
||||
persona=setup.persona,
|
||||
user_memory_context=setup.user_memory_context,
|
||||
llm=model_llm,
|
||||
token_counter=get_llm_token_counter(model_llm),
|
||||
db_session=thread_db_session,
|
||||
forced_tool_id=setup.forced_tool_id,
|
||||
user_identity=setup.user_identity,
|
||||
chat_session_id=str(setup.chat_session.id),
|
||||
chat_files=setup.chat_files_for_tools,
|
||||
include_citations=setup.new_msg_req.include_citations,
|
||||
all_injected_file_metadata=setup.all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
model_succeeded[model_idx] = True
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import ParseResult
|
||||
@@ -54,21 +53,6 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _load_google_json(raw: object) -> dict[str, Any]:
|
||||
"""Accept both the current (dict) and legacy (JSON string) KV payload shapes.
|
||||
|
||||
Payloads written before the fix for serializing Google credentials into
|
||||
``EncryptedJson`` columns are stored as JSON strings; new writes store dicts.
|
||||
Once every install has re-uploaded their Google credentials the legacy
|
||||
``str`` branch can be removed.
|
||||
"""
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
return json.loads(raw)
|
||||
raise ValueError(f"Unexpected Google credential payload type: {type(raw)!r}")
|
||||
|
||||
|
||||
def _build_frontend_google_drive_redirect(source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
|
||||
@@ -178,13 +162,12 @@ def build_service_account_creds(
|
||||
|
||||
def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
credential_json = _load_google_json(
|
||||
get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY)
|
||||
)
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
elif source == DocumentSource.GMAIL:
|
||||
credential_json = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=GOOGLE_SCOPES[source],
|
||||
@@ -205,12 +188,12 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
|
||||
|
||||
def get_google_app_cred(source: DocumentSource) -> GoogleAppCredentials:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds = _load_google_json(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds = _load_google_json(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return GoogleAppCredentials(**creds)
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_google_app_cred(
|
||||
@@ -218,14 +201,10 @@ def upsert_google_app_cred(
|
||||
) -> None:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_CRED_KEY,
|
||||
app_credentials.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_CRED_KEY, app_credentials.model_dump(mode="json"), encrypt=True
|
||||
)
|
||||
get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
@@ -241,14 +220,12 @@ def delete_google_app_cred(source: DocumentSource) -> None:
|
||||
|
||||
def get_service_account_key(source: DocumentSource) -> GoogleServiceAccountKey:
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
creds = _load_google_json(
|
||||
get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
)
|
||||
creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
|
||||
elif source == DocumentSource.GMAIL:
|
||||
creds = _load_google_json(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
return GoogleServiceAccountKey(**creds)
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_service_account_key(
|
||||
@@ -257,14 +234,12 @@ def upsert_service_account_key(
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
get_kv_store().store(
|
||||
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY,
|
||||
service_account_key.model_dump(mode="json"),
|
||||
service_account_key.json(),
|
||||
encrypt=True,
|
||||
)
|
||||
elif source == DocumentSource.GMAIL:
|
||||
get_kv_store().store(
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY,
|
||||
service_account_key.model_dump(mode="json"),
|
||||
encrypt=True,
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source: {source}")
|
||||
|
||||
@@ -60,10 +60,8 @@ logger = setup_logger()
|
||||
|
||||
ONE_HOUR = 3600
|
||||
|
||||
_MAX_RESULTS_FETCH_IDS = 5000
|
||||
_MAX_RESULTS_FETCH_IDS = 5000 # 5000
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
|
||||
_JIRA_BULK_FETCH_LIMIT = 100
|
||||
|
||||
# Constants for Jira field names
|
||||
_FIELD_REPORTER = "reporter"
|
||||
@@ -257,13 +255,15 @@ def _bulk_fetch_request(
|
||||
return resp.json()["issues"]
|
||||
|
||||
|
||||
def _bulk_fetch_batch(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch a single batch (must be <= _JIRA_BULK_FETCH_LIMIT).
|
||||
On JSONDecodeError, recursively bisects until it succeeds or reaches size 1."""
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
try:
|
||||
return _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
if len(issue_ids) <= 1:
|
||||
logger.exception(
|
||||
@@ -277,25 +277,12 @@ def _bulk_fetch_batch(
|
||||
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
|
||||
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
|
||||
)
|
||||
left = _bulk_fetch_batch(jira_client, issue_ids[:mid], fields)
|
||||
right = _bulk_fetch_batch(jira_client, issue_ids[mid:], fields)
|
||||
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
|
||||
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
|
||||
return left + right
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
raw_issues: list[dict[str, Any]] = []
|
||||
for batch in chunked(issue_ids, _JIRA_BULK_FETCH_LIMIT):
|
||||
try:
|
||||
raw_issues.extend(_bulk_fetch_batch(jira_client, list(batch), fields))
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise
|
||||
|
||||
return [
|
||||
Issue(jira_client._options, jira_client._session, raw=issue)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
|
||||
@@ -7,14 +6,6 @@ from pydantic import BaseModel
|
||||
from onyx.onyxbot.slack.models import ChannelType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DirectThreadFetch:
|
||||
"""Request to fetch a Slack thread directly by channel and timestamp."""
|
||||
|
||||
channel_id: str
|
||||
thread_ts: str
|
||||
|
||||
|
||||
class ChannelMetadata(TypedDict):
|
||||
"""Type definition for cached channel metadata."""
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.federated.models import SlackMessage
|
||||
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
|
||||
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
|
||||
@@ -50,6 +49,7 @@ from onyx.server.federated.models import FederatedConnectorDetail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -58,6 +58,7 @@ HIGHLIGHT_END_CHAR = "\ue001"
|
||||
|
||||
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
|
||||
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
|
||||
CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff)
|
||||
|
||||
@@ -420,94 +421,6 @@ class SlackQueryResult(BaseModel):
|
||||
filtered_channels: list[str] # Channels filtered out during this query
|
||||
|
||||
|
||||
def _fetch_thread_from_url(
|
||||
thread_fetch: DirectThreadFetch,
|
||||
access_token: str,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
"""Fetch a thread directly from a Slack URL via conversations.replies."""
|
||||
channel_id = thread_fetch.channel_id
|
||||
thread_ts = thread_fetch.thread_ts
|
||||
|
||||
slack_client = WebClient(token=access_token)
|
||||
try:
|
||||
response = slack_client.conversations_replies(
|
||||
channel=channel_id,
|
||||
ts=thread_ts,
|
||||
)
|
||||
response.validate()
|
||||
messages: list[dict[str, Any]] = response.get("messages", [])
|
||||
except SlackApiError as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch thread from URL (channel={channel_id}, ts={thread_ts}): {e}"
|
||||
)
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
if not messages:
|
||||
logger.warning(
|
||||
f"No messages found for URL override (channel={channel_id}, ts={thread_ts})"
|
||||
)
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
# Build thread text from all messages
|
||||
thread_text = _build_thread_text(messages, access_token, None, slack_client)
|
||||
|
||||
# Get channel name from metadata cache or API
|
||||
channel_name = "unknown"
|
||||
if channel_metadata_dict and channel_id in channel_metadata_dict:
|
||||
channel_name = channel_metadata_dict[channel_id].get("name", "unknown")
|
||||
else:
|
||||
try:
|
||||
ch_response = slack_client.conversations_info(channel=channel_id)
|
||||
ch_response.validate()
|
||||
channel_info: dict[str, Any] = ch_response.get("channel", {})
|
||||
channel_name = channel_info.get("name", "unknown")
|
||||
except SlackApiError:
|
||||
pass
|
||||
|
||||
# Build the SlackMessage
|
||||
parent_msg = messages[0]
|
||||
message_ts = parent_msg.get("ts", thread_ts)
|
||||
username = parent_msg.get("user", "unknown_user")
|
||||
parent_text = parent_msg.get("text", "")
|
||||
snippet = (
|
||||
parent_text[:50].rstrip() + "..." if len(parent_text) > 50 else parent_text
|
||||
).replace("\n", " ")
|
||||
|
||||
doc_time = datetime.fromtimestamp(float(message_ts))
|
||||
decay_factor = DOC_TIME_DECAY
|
||||
doc_age_years = (datetime.now() - doc_time).total_seconds() / (365 * 24 * 60 * 60)
|
||||
recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75)
|
||||
|
||||
permalink = (
|
||||
f"https://slack.com/archives/{channel_id}/p{message_ts.replace('.', '')}"
|
||||
)
|
||||
|
||||
slack_message = SlackMessage(
|
||||
document_id=f"{channel_id}_{message_ts}",
|
||||
channel_id=channel_id,
|
||||
message_id=message_ts,
|
||||
thread_id=None, # Prevent double-enrichment in thread context fetch
|
||||
link=permalink,
|
||||
metadata={
|
||||
"channel": channel_name,
|
||||
"time": doc_time.isoformat(),
|
||||
},
|
||||
timestamp=doc_time,
|
||||
recency_bias=recency_bias,
|
||||
semantic_identifier=f"{username} in #{channel_name}: {snippet}",
|
||||
text=thread_text,
|
||||
highlighted_texts=set(),
|
||||
slack_score=100000.0, # High priority — user explicitly asked for this thread
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"URL override: fetched thread from channel={channel_id}, ts={thread_ts}, {len(messages)} messages"
|
||||
)
|
||||
|
||||
return SlackQueryResult(messages=[slack_message], filtered_channels=[])
|
||||
|
||||
|
||||
def query_slack(
|
||||
query_string: str,
|
||||
access_token: str,
|
||||
@@ -519,6 +432,7 @@ def query_slack(
|
||||
available_channels: list[str] | None = None,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
|
||||
# Check if query has channel override (user specified channels in query)
|
||||
has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__")
|
||||
|
||||
@@ -748,6 +662,7 @@ def _fetch_thread_context(
|
||||
"""
|
||||
channel_id = message.channel_id
|
||||
thread_id = message.thread_id
|
||||
message_id = message.message_id
|
||||
|
||||
# If not a thread, return original text as success
|
||||
if thread_id is None:
|
||||
@@ -780,37 +695,62 @@ def _fetch_thread_context(
|
||||
if len(messages) <= 1:
|
||||
return ThreadContextResult.success(message.text)
|
||||
|
||||
# Build thread text from thread starter + all replies
|
||||
thread_text = _build_thread_text(messages, access_token, team_id, slack_client)
|
||||
# Build thread text from thread starter + context window around matched message
|
||||
thread_text = _build_thread_text(
|
||||
messages, message_id, thread_id, access_token, team_id, slack_client
|
||||
)
|
||||
return ThreadContextResult.success(thread_text)
|
||||
|
||||
|
||||
def _build_thread_text(
|
||||
messages: list[dict[str, Any]],
|
||||
message_id: str,
|
||||
thread_id: str,
|
||||
access_token: str,
|
||||
team_id: str | None,
|
||||
slack_client: WebClient,
|
||||
) -> str:
|
||||
"""Build thread text including all replies.
|
||||
|
||||
Includes the thread parent message followed by all replies in order.
|
||||
"""
|
||||
"""Build the thread text from messages."""
|
||||
msg_text = messages[0].get("text", "")
|
||||
msg_sender = messages[0].get("user", "")
|
||||
thread_text = f"<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# All messages after index 0 are replies
|
||||
replies = messages[1:]
|
||||
if not replies:
|
||||
return thread_text
|
||||
|
||||
logger.debug(f"Thread {messages[0].get('ts')}: {len(replies)} replies included")
|
||||
thread_text += "\n\nReplies:"
|
||||
if thread_id == message_id:
|
||||
message_id_idx = 0
|
||||
else:
|
||||
message_id_idx = next(
|
||||
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
|
||||
)
|
||||
if not message_id_idx:
|
||||
return thread_text
|
||||
|
||||
for msg in replies:
|
||||
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
|
||||
|
||||
if start_idx > 1:
|
||||
thread_text += "\n..."
|
||||
|
||||
for i in range(start_idx, message_id_idx):
|
||||
msg_text = messages[i].get("text", "")
|
||||
msg_sender = messages[i].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
msg_text = messages[message_id_idx].get("text", "")
|
||||
msg_sender = messages[message_id_idx].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# Add following replies
|
||||
len_replies = 0
|
||||
for msg in messages[message_id_idx + 1 :]:
|
||||
msg_text = msg.get("text", "")
|
||||
msg_sender = msg.get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
reply = f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
thread_text += reply
|
||||
|
||||
len_replies += len(reply)
|
||||
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
|
||||
thread_text += "\n..."
|
||||
break
|
||||
|
||||
# Replace user IDs with names using cached lookups
|
||||
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
|
||||
@@ -1036,16 +976,7 @@ def slack_retrieval(
|
||||
|
||||
# Query slack with entity filtering
|
||||
llm = get_default_llm()
|
||||
query_items = build_slack_queries(query, llm, entities, available_channels)
|
||||
|
||||
# Partition into direct thread fetches and search query strings
|
||||
direct_fetches: list[DirectThreadFetch] = []
|
||||
query_strings: list[str] = []
|
||||
for item in query_items:
|
||||
if isinstance(item, DirectThreadFetch):
|
||||
direct_fetches.append(item)
|
||||
else:
|
||||
query_strings.append(item)
|
||||
query_strings = build_slack_queries(query, llm, entities, available_channels)
|
||||
|
||||
# Determine filtering based on entities OR context (bot)
|
||||
include_dm = False
|
||||
@@ -1062,16 +993,8 @@ def slack_retrieval(
|
||||
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
|
||||
)
|
||||
|
||||
# Build search tasks — direct thread fetches + keyword searches
|
||||
search_tasks: list[tuple] = [
|
||||
(
|
||||
_fetch_thread_from_url,
|
||||
(fetch, access_token, channel_metadata_dict),
|
||||
)
|
||||
for fetch in direct_fetches
|
||||
]
|
||||
|
||||
search_tasks.extend(
|
||||
# Build search tasks
|
||||
search_tasks = [
|
||||
(
|
||||
query_slack,
|
||||
(
|
||||
@@ -1087,7 +1010,7 @@ def slack_retrieval(
|
||||
),
|
||||
)
|
||||
for query_string in query_strings
|
||||
)
|
||||
]
|
||||
|
||||
# If include_dm is True AND we're not already searching all channels,
|
||||
# add additional searches without channel filters.
|
||||
|
||||
@@ -10,7 +10,6 @@ from pydantic import ValidationError
|
||||
|
||||
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.federated_connectors.slack.models import SlackEntities
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -639,38 +638,12 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
return [query_text]
|
||||
|
||||
|
||||
SLACK_URL_PATTERN = re.compile(
|
||||
r"https?://[a-z0-9-]+\.slack\.com/archives/([A-Z0-9]+)/p(\d{16})"
|
||||
)
|
||||
|
||||
|
||||
def extract_slack_message_urls(
|
||||
query_text: str,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Extract Slack message URLs from query text.
|
||||
|
||||
Parses URLs like:
|
||||
https://onyx-company.slack.com/archives/C097NBWMY8Y/p1775491616524769
|
||||
|
||||
Returns list of (channel_id, thread_ts) tuples.
|
||||
The 16-digit timestamp is converted to Slack ts format (with dot).
|
||||
"""
|
||||
results = []
|
||||
for match in SLACK_URL_PATTERN.finditer(query_text):
|
||||
channel_id = match.group(1)
|
||||
raw_ts = match.group(2)
|
||||
# Convert p1775491616524769 -> 1775491616.524769
|
||||
thread_ts = f"{raw_ts[:10]}.{raw_ts[10:]}"
|
||||
results.append((channel_id, thread_ts))
|
||||
return results
|
||||
|
||||
|
||||
def build_slack_queries(
|
||||
query: ChunkIndexRequest,
|
||||
llm: LLM,
|
||||
entities: dict[str, Any] | None = None,
|
||||
available_channels: list[str] | None = None,
|
||||
) -> list[str | DirectThreadFetch]:
|
||||
) -> list[str]:
|
||||
"""Build Slack query strings with date filtering and query expansion."""
|
||||
default_search_days = 30
|
||||
if entities:
|
||||
@@ -695,15 +668,6 @@ def build_slack_queries(
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
|
||||
time_filter = f" after:{cutoff_date.strftime('%Y-%m-%d')}"
|
||||
|
||||
# Check for Slack message URLs — if found, add direct fetch requests
|
||||
url_fetches: list[DirectThreadFetch] = []
|
||||
slack_urls = extract_slack_message_urls(query.query)
|
||||
for channel_id, thread_ts in slack_urls:
|
||||
url_fetches.append(
|
||||
DirectThreadFetch(channel_id=channel_id, thread_ts=thread_ts)
|
||||
)
|
||||
logger.info(f"Detected Slack URL: channel={channel_id}, ts={thread_ts}")
|
||||
|
||||
# ALWAYS extract channel references from the query (not just for recency queries)
|
||||
channel_references = extract_channel_references_from_query(query.query)
|
||||
|
||||
@@ -720,9 +684,7 @@ def build_slack_queries(
|
||||
|
||||
# If valid channels detected, use ONLY those channels with NO keywords
|
||||
# Return query with ONLY time filter + channel filter (no keywords)
|
||||
return url_fetches + [
|
||||
build_channel_override_query(channel_references, time_filter)
|
||||
]
|
||||
return [build_channel_override_query(channel_references, time_filter)]
|
||||
except ValueError as e:
|
||||
# If validation fails, log the error and continue with normal flow
|
||||
logger.warning(f"Channel reference validation failed: {e}")
|
||||
@@ -740,8 +702,7 @@ def build_slack_queries(
|
||||
rephrased_queries = expand_query_with_llm(query.query, llm)
|
||||
|
||||
# Build final query strings with time filters
|
||||
search_queries = [
|
||||
return [
|
||||
rephrased_query.strip() + time_filter
|
||||
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
|
||||
]
|
||||
return url_fetches + search_queries
|
||||
|
||||
@@ -11,7 +11,6 @@ from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import DBAPIError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
@@ -347,25 +346,6 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def _safe_close_session(session: Session) -> None:
|
||||
"""Close a session, catching connection-closed errors during cleanup.
|
||||
|
||||
Long-running operations (e.g. multi-model LLM loops) can hold a session
|
||||
open for minutes. If the underlying connection is dropped by cloud
|
||||
infrastructure (load-balancer timeouts, PgBouncer, idle-in-transaction
|
||||
timeouts, etc.), the implicit rollback in Session.close() raises
|
||||
OperationalError or InterfaceError. Since the work is already complete,
|
||||
we log and move on — SQLAlchemy internally invalidates the connection
|
||||
for pool recycling.
|
||||
"""
|
||||
try:
|
||||
session.close()
|
||||
except DBAPIError:
|
||||
logger.warning(
|
||||
"DB connection lost during session cleanup — the connection will be invalidated and recycled by the pool."
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
|
||||
"""
|
||||
@@ -378,11 +358,8 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
|
||||
# no need to use the schema translation map for self-hosted + default schema
|
||||
if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
|
||||
session = Session(bind=engine, expire_on_commit=False)
|
||||
try:
|
||||
with Session(bind=engine, expire_on_commit=False) as session:
|
||||
yield session
|
||||
finally:
|
||||
_safe_close_session(session)
|
||||
return
|
||||
|
||||
# Create connection with schema translation to handle querying the right schema
|
||||
@@ -390,11 +367,8 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
with engine.connect().execution_options(
|
||||
schema_translate_map=schema_translate_map
|
||||
) as connection:
|
||||
session = Session(bind=connection, expire_on_commit=False)
|
||||
try:
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
yield session
|
||||
finally:
|
||||
_safe_close_session(session)
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
|
||||
@@ -5,7 +5,6 @@ from pydantic import ConfigDict
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
|
||||
@@ -84,51 +83,47 @@ def get_memories(user: User, db_session: Session) -> UserMemoryContext:
|
||||
def add_memory(
|
||||
user_id: UUID,
|
||||
memory_text: str,
|
||||
db_session: Session | None = None,
|
||||
) -> int:
|
||||
db_session: Session,
|
||||
) -> Memory:
|
||||
"""Insert a new Memory row for the given user.
|
||||
|
||||
If the user already has MAX_MEMORIES_PER_USER memories, the oldest
|
||||
one (lowest id) is deleted before inserting the new one.
|
||||
|
||||
Returns the id of the newly created Memory row.
|
||||
"""
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
existing = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
existing = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if len(existing) >= MAX_MEMORIES_PER_USER:
|
||||
db_session.delete(existing[0])
|
||||
if len(existing) >= MAX_MEMORIES_PER_USER:
|
||||
db_session.delete(existing[0])
|
||||
|
||||
memory = Memory(
|
||||
user_id=user_id,
|
||||
memory_text=memory_text,
|
||||
)
|
||||
db_session.add(memory)
|
||||
db_session.commit()
|
||||
return memory.id
|
||||
memory = Memory(
|
||||
user_id=user_id,
|
||||
memory_text=memory_text,
|
||||
)
|
||||
db_session.add(memory)
|
||||
db_session.commit()
|
||||
return memory
|
||||
|
||||
|
||||
def update_memory_at_index(
|
||||
user_id: UUID,
|
||||
index: int,
|
||||
new_text: str,
|
||||
db_session: Session | None = None,
|
||||
) -> int | None:
|
||||
db_session: Session,
|
||||
) -> Memory | None:
|
||||
"""Update the memory at the given 0-based index (ordered by id ASC, matching get_memories()).
|
||||
|
||||
Returns the id of the updated Memory row, or None if the index is out of range.
|
||||
Returns the updated Memory row, or None if the index is out of range.
|
||||
"""
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if index < 0 or index >= len(memory_rows):
|
||||
return None
|
||||
if index < 0 or index >= len(memory_rows):
|
||||
return None
|
||||
|
||||
memory = memory_rows[index]
|
||||
memory.memory_text = new_text
|
||||
db_session.commit()
|
||||
return memory.id
|
||||
memory = memory_rows[index]
|
||||
memory.memory_text = new_text
|
||||
db_session.commit()
|
||||
return memory
|
||||
|
||||
@@ -7,6 +7,8 @@ import time
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
@@ -20,7 +22,6 @@ from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.chat_configs import SKIP_DEEP_RESEARCH_CLARIFICATION
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.tools import get_tool_by_name
|
||||
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
|
||||
from onyx.deep_research.dr_mock_tools import get_orchestrator_tools
|
||||
@@ -183,14 +184,6 @@ def generate_final_report(
|
||||
return has_reasoned
|
||||
|
||||
|
||||
def _get_research_agent_tool_id() -> int:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
return get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_TOOL_NAME,
|
||||
db_session=db_session,
|
||||
).id
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def run_deep_research_llm_loop(
|
||||
emitter: Emitter,
|
||||
@@ -200,6 +193,7 @@ def run_deep_research_llm_loop(
|
||||
custom_agent_prompt: str | None, # noqa: ARG001
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
skip_clarification: bool = False,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
@@ -723,7 +717,6 @@ def run_deep_research_llm_loop(
|
||||
simple_chat_history.append(assistant_with_tools)
|
||||
|
||||
# Now add TOOL_CALL_RESPONSE messages and tool call info for each result
|
||||
research_agent_tool_id = _get_research_agent_tool_id()
|
||||
for tab_index, report in enumerate(
|
||||
research_results.intermediate_reports
|
||||
):
|
||||
@@ -744,7 +737,10 @@ def run_deep_research_llm_loop(
|
||||
tab_index=tab_index,
|
||||
tool_name=current_tool_call.tool_name,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
tool_id=research_agent_tool_id,
|
||||
tool_id=get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_TOOL_NAME,
|
||||
db_session=db_session,
|
||||
).id,
|
||||
reasoning_tokens=llm_step_result.reasoning
|
||||
or most_recent_reasoning,
|
||||
tool_call_arguments=current_tool_call.tool_args,
|
||||
|
||||
@@ -1516,10 +1516,6 @@
|
||||
"display_name": "Claude Opus 4.6",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-7": {
|
||||
"display_name": "Claude Opus 4.7",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"claude-opus-4-5-20251101": {
|
||||
"display_name": "Claude Opus 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
|
||||
@@ -46,15 +46,6 @@ ANTHROPIC_REASONING_EFFORT_BUDGET: dict[ReasoningEffort, int] = {
|
||||
ReasoningEffort.HIGH: 4096,
|
||||
}
|
||||
|
||||
# Newer Anthropic models (Claude Opus 4.7+) use adaptive thinking with
|
||||
# output_config.effort instead of thinking.type.enabled + budget_tokens.
|
||||
ANTHROPIC_ADAPTIVE_REASONING_EFFORT: dict[ReasoningEffort, str] = {
|
||||
ReasoningEffort.AUTO: "medium",
|
||||
ReasoningEffort.LOW: "low",
|
||||
ReasoningEffort.MEDIUM: "medium",
|
||||
ReasoningEffort.HIGH: "high",
|
||||
}
|
||||
|
||||
|
||||
# Content part structures for multimodal messages
|
||||
# The classes in this mirror the OpenAI Chat Completions message types and work well with routers like LiteLLM
|
||||
|
||||
@@ -23,7 +23,6 @@ from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.model_response import ModelResponse
|
||||
from onyx.llm.model_response import ModelResponseStream
|
||||
from onyx.llm.model_response import Usage
|
||||
from onyx.llm.models import ANTHROPIC_ADAPTIVE_REASONING_EFFORT
|
||||
from onyx.llm.models import ANTHROPIC_REASONING_EFFORT_BUDGET
|
||||
from onyx.llm.models import OPENAI_REASONING_EFFORT
|
||||
from onyx.llm.request_context import get_llm_mock_response
|
||||
@@ -68,13 +67,8 @@ STANDARD_MAX_TOKENS_KWARG = "max_completion_tokens"
|
||||
_VERTEX_ANTHROPIC_MODELS_REJECTING_OUTPUT_CONFIG = (
|
||||
"claude-opus-4-5",
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4-7",
|
||||
)
|
||||
|
||||
# Anthropic models that require the adaptive thinking API (thinking.type.adaptive
|
||||
# + output_config.effort) instead of the legacy thinking.type.enabled + budget_tokens.
|
||||
_ANTHROPIC_ADAPTIVE_THINKING_MODELS = ("claude-opus-4-7",)
|
||||
|
||||
|
||||
class LLMTimeoutError(Exception):
|
||||
"""
|
||||
@@ -236,14 +230,6 @@ def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _anthropic_uses_adaptive_thinking(model_name: str) -> bool:
|
||||
normalized_model_name = model_name.lower()
|
||||
return any(
|
||||
adaptive_model in normalized_model_name
|
||||
for adaptive_model in _ANTHROPIC_ADAPTIVE_THINKING_MODELS
|
||||
)
|
||||
|
||||
|
||||
class LitellmLLM(LLM):
|
||||
"""Uses Litellm library to allow easy configuration to use a multitude of LLMs
|
||||
See https://python.langchain.com/docs/integrations/chat/litellm"""
|
||||
@@ -523,6 +509,10 @@ class LitellmLLM(LLM):
|
||||
}
|
||||
|
||||
elif is_claude_model:
|
||||
budget_tokens: int | None = ANTHROPIC_REASONING_EFFORT_BUDGET.get(
|
||||
reasoning_effort
|
||||
)
|
||||
|
||||
# Anthropic requires every assistant message with tool_use
|
||||
# blocks to start with a thinking block that carries a
|
||||
# cryptographic signature. We don't preserve those blocks
|
||||
@@ -530,35 +520,24 @@ class LitellmLLM(LLM):
|
||||
# contains tool-calling assistant messages. LiteLLM's
|
||||
# modify_params workaround doesn't cover all providers
|
||||
# (notably Bedrock).
|
||||
has_tool_call_history = _prompt_contains_tool_call_history(prompt)
|
||||
can_enable_thinking = (
|
||||
budget_tokens is not None
|
||||
and not _prompt_contains_tool_call_history(prompt)
|
||||
)
|
||||
|
||||
if _anthropic_uses_adaptive_thinking(self.config.model_name):
|
||||
# Newer Anthropic models (Claude Opus 4.7+) reject
|
||||
# thinking.type.enabled — they require the adaptive
|
||||
# thinking config with output_config.effort.
|
||||
if not has_tool_call_history:
|
||||
optional_kwargs["thinking"] = {"type": "adaptive"}
|
||||
optional_kwargs["output_config"] = {
|
||||
"effort": ANTHROPIC_ADAPTIVE_REASONING_EFFORT[
|
||||
reasoning_effort
|
||||
],
|
||||
}
|
||||
else:
|
||||
budget_tokens: int | None = ANTHROPIC_REASONING_EFFORT_BUDGET.get(
|
||||
reasoning_effort
|
||||
)
|
||||
if budget_tokens is not None and not has_tool_call_history:
|
||||
if max_tokens is not None:
|
||||
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
|
||||
# and the minimum budget tokens is 1024
|
||||
# Will note that overwriting a developer set max tokens is not ideal but is the best we can do for now
|
||||
# It is better to allow the LLM to output more reasoning tokens even if it results in a fairly small tool
|
||||
# call as compared to reducing the budget for reasoning.
|
||||
max_tokens = max(budget_tokens + 1, max_tokens)
|
||||
optional_kwargs["thinking"] = {
|
||||
"type": "enabled",
|
||||
"budget_tokens": budget_tokens,
|
||||
}
|
||||
if can_enable_thinking:
|
||||
assert budget_tokens is not None # mypy
|
||||
if max_tokens is not None:
|
||||
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
|
||||
# and the minimum budget tokens is 1024
|
||||
# Will note that overwriting a developer set max tokens is not ideal but is the best we can do for now
|
||||
# It is better to allow the LLM to output more reasoning tokens even if it results in a fairly small tool
|
||||
# call as compared to reducing the budget for reasoning.
|
||||
max_tokens = max(budget_tokens + 1, max_tokens)
|
||||
optional_kwargs["thinking"] = {
|
||||
"type": "enabled",
|
||||
"budget_tokens": budget_tokens,
|
||||
}
|
||||
|
||||
# LiteLLM just does some mapping like this anyway but is incomplete for Anthropic
|
||||
optional_kwargs.pop("reasoning_effort", None)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"version": "1.2",
|
||||
"updated_at": "2026-04-16T00:00:00Z",
|
||||
"version": "1.1",
|
||||
"updated_at": "2026-03-05T00:00:00Z",
|
||||
"providers": {
|
||||
"openai": {
|
||||
"default_model": { "name": "gpt-5.4" },
|
||||
@@ -10,12 +10,8 @@
|
||||
]
|
||||
},
|
||||
"anthropic": {
|
||||
"default_model": "claude-opus-4-7",
|
||||
"default_model": "claude-opus-4-6",
|
||||
"additional_visible_models": [
|
||||
{
|
||||
"name": "claude-opus-4-7",
|
||||
"display_name": "Claude Opus 4.7"
|
||||
},
|
||||
{
|
||||
"name": "claude-opus-4-6",
|
||||
"display_name": "Claude Opus 4.6"
|
||||
|
||||
@@ -65,9 +65,8 @@ IMPORTANT: each call to this tool is independent. Variables from previous calls
|
||||
GENERATE_IMAGE_GUIDANCE = """
|
||||
## generate_image
|
||||
NEVER use generate_image unless the user specifically requests an image.
|
||||
To edit, restyle, or vary an existing image, pass its file_id in `reference_image_file_ids`. \
|
||||
File IDs come from `[attached image — file_id: <id>]` tags on user-attached images or from prior `generate_image` tool results — never invent one. \
|
||||
Leave `reference_image_file_ids` unset for a fresh generation.
|
||||
For edits/variations of a previously generated image, pass `reference_image_file_ids` with
|
||||
the `file_id` values returned by earlier `generate_image` tool results.
|
||||
""".lstrip()
|
||||
|
||||
MEMORY_GUIDANCE = """
|
||||
|
||||
@@ -618,7 +618,6 @@ done
|
||||
"app.kubernetes.io/managed-by": "onyx",
|
||||
"onyx.app/sandbox-id": sandbox_id,
|
||||
"onyx.app/tenant-id": tenant_id,
|
||||
"admission.datadoghq.com/enabled": "false",
|
||||
},
|
||||
),
|
||||
spec=pod_spec,
|
||||
|
||||
@@ -96,32 +96,6 @@ def _truncate_description(description: str | None, max_length: int = 500) -> str
|
||||
return description[: max_length - 3] + "..."
|
||||
|
||||
|
||||
# TODO: Replace mask-comparison approach with an explicit Unset sentinel from the
|
||||
# frontend indicating whether each credential field was actually modified. The current
|
||||
# approach is brittle (e.g. short credentials produce a fixed-length mask that could
|
||||
# collide) and mutates request values, which is surprising. The frontend should signal
|
||||
# "unchanged" vs "new value" directly rather than relying on masked-string equality.
|
||||
def _restore_masked_oauth_credentials(
|
||||
request_client_id: str | None,
|
||||
request_client_secret: str | None,
|
||||
existing_client: OAuthClientInformationFull,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""If the frontend sent back masked credentials, restore the real stored values."""
|
||||
if (
|
||||
request_client_id
|
||||
and existing_client.client_id
|
||||
and request_client_id == mask_string(existing_client.client_id)
|
||||
):
|
||||
request_client_id = existing_client.client_id
|
||||
if (
|
||||
request_client_secret
|
||||
and existing_client.client_secret
|
||||
and request_client_secret == mask_string(existing_client.client_secret)
|
||||
):
|
||||
request_client_secret = existing_client.client_secret
|
||||
return request_client_id, request_client_secret
|
||||
|
||||
|
||||
router = APIRouter(prefix="/mcp")
|
||||
admin_router = APIRouter(prefix="/admin/mcp")
|
||||
STATE_TTL_SECONDS = 60 * 5 # 5 minutes
|
||||
@@ -418,26 +392,6 @@ async def _connect_oauth(
|
||||
detail=f"Server was configured with authentication type {auth_type_str}",
|
||||
)
|
||||
|
||||
# If the frontend sent back masked credentials (unchanged by the user),
|
||||
# restore the real stored values so we don't overwrite them with masks.
|
||||
if mcp_server.admin_connection_config:
|
||||
existing_data = extract_connection_data(
|
||||
mcp_server.admin_connection_config, apply_mask=False
|
||||
)
|
||||
existing_client_raw = existing_data.get(MCPOAuthKeys.CLIENT_INFO.value)
|
||||
if existing_client_raw:
|
||||
existing_client = OAuthClientInformationFull.model_validate(
|
||||
existing_client_raw
|
||||
)
|
||||
(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
) = _restore_masked_oauth_credentials(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
existing_client,
|
||||
)
|
||||
|
||||
# Create admin config with client info if provided
|
||||
config_data = MCPConnectionData(headers={})
|
||||
if request.oauth_client_id and request.oauth_client_secret:
|
||||
@@ -1402,19 +1356,6 @@ def _upsert_mcp_server(
|
||||
if client_info_raw:
|
||||
client_info = OAuthClientInformationFull.model_validate(client_info_raw)
|
||||
|
||||
# If the frontend sent back masked credentials (unchanged by the user),
|
||||
# restore the real stored values so the comparison below sees no change
|
||||
# and the credentials aren't overwritten with masked strings.
|
||||
if client_info and request.auth_type == MCPAuthenticationType.OAUTH:
|
||||
(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
) = _restore_masked_oauth_credentials(
|
||||
request.oauth_client_id,
|
||||
request.oauth_client_secret,
|
||||
client_info,
|
||||
)
|
||||
|
||||
changing_connection_config = (
|
||||
not mcp_server.admin_connection_config
|
||||
or (
|
||||
|
||||
@@ -11,9 +11,6 @@ from onyx.db.notification import dismiss_notification
|
||||
from onyx.db.notification import get_notification_by_id
|
||||
from onyx.db.notification import get_notifications
|
||||
from onyx.server.features.build.utils import ensure_build_mode_intro_notification
|
||||
from onyx.server.features.notifications.utils import (
|
||||
ensure_permissions_migration_notification,
|
||||
)
|
||||
from onyx.server.features.release_notes.utils import (
|
||||
ensure_release_notes_fresh_and_notify,
|
||||
)
|
||||
@@ -52,13 +49,6 @@ def get_notifications_api(
|
||||
except Exception:
|
||||
logger.exception("Failed to check for release notes in notifications endpoint")
|
||||
|
||||
try:
|
||||
ensure_permissions_migration_notification(user, db_session)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to create permissions_migration_v1 announcement in notifications endpoint"
|
||||
)
|
||||
|
||||
notifications = [
|
||||
NotificationModel.from_model(notif)
|
||||
for notif in get_notifications(user, db_session, include_dismissed=True)
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.notification import create_notification
|
||||
|
||||
|
||||
def ensure_permissions_migration_notification(user: User, db_session: Session) -> None:
|
||||
# Feature id "permissions_migration_v1" must not change after shipping —
|
||||
# it is the dedup key on (user_id, notif_type, additional_data).
|
||||
create_notification(
|
||||
user_id=user.id,
|
||||
notif_type=NotificationType.FEATURE_ANNOUNCEMENT,
|
||||
db_session=db_session,
|
||||
title="Permissions are changing in Onyx",
|
||||
description="Roles are moving to group-based permissions. Click for details.",
|
||||
additional_data={
|
||||
"feature": "permissions_migration_v1",
|
||||
"link": "https://docs.onyx.app/admins/permissions/whats_changing",
|
||||
},
|
||||
)
|
||||
@@ -111,43 +111,6 @@ def _mask_string(value: str) -> str:
|
||||
return value[:4] + "****" + value[-4:]
|
||||
|
||||
|
||||
def _resolve_api_key(
|
||||
api_key: str | None,
|
||||
provider_name: str | None,
|
||||
api_base: str | None,
|
||||
db_session: Session,
|
||||
) -> str | None:
|
||||
"""Return the real API key for model-fetch endpoints.
|
||||
|
||||
When editing an existing provider the form value is masked (e.g.
|
||||
``sk-a****b1c2``). If *provider_name* is supplied we can look up
|
||||
the unmasked key from the database so the external request succeeds.
|
||||
|
||||
The stored key is only returned when the request's *api_base*
|
||||
matches the value stored in the database.
|
||||
"""
|
||||
if not provider_name:
|
||||
return api_key
|
||||
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
if existing_provider and existing_provider.api_key:
|
||||
# Normalise both URLs before comparing so trailing-slash
|
||||
# differences don't cause a false mismatch.
|
||||
stored_base = (existing_provider.api_base or "").strip().rstrip("/")
|
||||
request_base = (api_base or "").strip().rstrip("/")
|
||||
if stored_base != request_base:
|
||||
return api_key
|
||||
|
||||
stored_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
# Only resolve when the incoming value is the masked form of the
|
||||
# stored key — i.e. the user hasn't typed a new key.
|
||||
if api_key and api_key == _mask_string(stored_key):
|
||||
return stored_key
|
||||
return api_key
|
||||
|
||||
|
||||
def _sync_fetched_models(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
@@ -1211,17 +1174,16 @@ def get_ollama_available_models(
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_openrouter_models_response(api_base: str, api_key: str | None) -> dict:
|
||||
def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
|
||||
"""Perform GET to OpenRouter /models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/models"
|
||||
headers: dict[str, str] = {
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
# Optional headers recommended by OpenRouter for attribution
|
||||
"HTTP-Referer": "https://onyx.app",
|
||||
"X-Title": "Onyx",
|
||||
}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
@@ -1244,12 +1206,8 @@ def get_openrouter_available_models(
|
||||
Parses id, name (display), context_length, and architecture.input_modalities.
|
||||
"""
|
||||
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_openrouter_models_response(
|
||||
api_base=request.api_base, api_key=api_key
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
)
|
||||
|
||||
data = response_json.get("data", [])
|
||||
@@ -1342,18 +1300,13 @@ def get_lm_studio_available_models(
|
||||
|
||||
# If provider_name is given and the api_key hasn't been changed by the user,
|
||||
# fall back to the stored API key from the database (the form value is masked).
|
||||
# Only do so when the api_base matches what is stored.
|
||||
api_key = request.api_key
|
||||
if request.provider_name and not request.api_key_changed:
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=request.provider_name, db_session=db_session
|
||||
)
|
||||
if existing_provider and existing_provider.custom_config:
|
||||
stored_base = (existing_provider.api_base or "").strip().rstrip("/")
|
||||
if stored_base == cleaned_api_base:
|
||||
api_key = existing_provider.custom_config.get(
|
||||
LM_STUDIO_API_KEY_CONFIG_KEY
|
||||
)
|
||||
api_key = existing_provider.custom_config.get(LM_STUDIO_API_KEY_CONFIG_KEY)
|
||||
|
||||
url = f"{cleaned_api_base}/api/v1/models"
|
||||
headers: dict[str, str] = {}
|
||||
@@ -1437,12 +1390,8 @@ def get_litellm_available_models(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LitellmFinalModelResponse]:
|
||||
"""Fetch available models from Litellm proxy /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_litellm_models_response(
|
||||
api_key=api_key, api_base=request.api_base
|
||||
api_key=request.api_key, api_base=request.api_base
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
@@ -1499,7 +1448,7 @@ def get_litellm_available_models(
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_litellm_models_response(api_key: str | None, api_base: str) -> dict:
|
||||
def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
|
||||
"""Perform GET to Litellm proxy /api/v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
@@ -1574,12 +1523,8 @@ def get_bifrost_available_models(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[BifrostFinalModelResponse]:
|
||||
"""Fetch available models from Bifrost gateway /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_bifrost_models_response(
|
||||
api_base=request.api_base, api_key=api_key
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
@@ -1668,12 +1613,8 @@ def get_openai_compatible_server_available_models(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[OpenAICompatibleFinalModelResponse]:
|
||||
"""Fetch available models from a generic OpenAI-compatible /v1/models endpoint."""
|
||||
api_key = _resolve_api_key(
|
||||
request.api_key, request.provider_name, request.api_base, db_session
|
||||
)
|
||||
|
||||
response_json = _get_openai_compatible_server_response(
|
||||
api_base=request.api_base, api_key=api_key
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
|
||||
@@ -183,9 +183,6 @@ def generate_ollama_display_name(model_name: str) -> str:
|
||||
"qwen2.5:7b" → "Qwen 2.5 7B"
|
||||
"mistral:latest" → "Mistral"
|
||||
"deepseek-r1:14b" → "DeepSeek R1 14B"
|
||||
"gemma4:e4b" → "Gemma 4 E4B"
|
||||
"deepseek-v3.1:671b-cloud" → "DeepSeek V3.1 671B Cloud"
|
||||
"qwen3-vl:235b-instruct-cloud" → "Qwen 3-vl 235B Instruct Cloud"
|
||||
"""
|
||||
# Split into base name and tag
|
||||
if ":" in model_name:
|
||||
@@ -212,24 +209,13 @@ def generate_ollama_display_name(model_name: str) -> str:
|
||||
# Default: Title case with dashes converted to spaces
|
||||
display_name = base.replace("-", " ").title()
|
||||
|
||||
# Process tag (skip "latest")
|
||||
# Process tag to extract size info (skip "latest")
|
||||
if tag and tag.lower() != "latest":
|
||||
# Check for size prefix like "7b", "70b", optionally followed by modifiers
|
||||
size_match = re.match(r"^(\d+(?:\.\d+)?[bBmM])(-.+)?$", tag)
|
||||
# Extract size like "7b", "70b", "14b"
|
||||
size_match = re.match(r"^(\d+(?:\.\d+)?[bBmM])", tag)
|
||||
if size_match:
|
||||
size = size_match.group(1).upper()
|
||||
remainder = size_match.group(2)
|
||||
if remainder:
|
||||
# Format modifiers like "-cloud", "-instruct-cloud"
|
||||
modifiers = " ".join(
|
||||
p.title() for p in remainder.strip("-").split("-") if p
|
||||
)
|
||||
display_name = f"{display_name} {size} {modifiers}"
|
||||
else:
|
||||
display_name = f"{display_name} {size}"
|
||||
else:
|
||||
# Non-size tags like "e4b", "q4_0", "fp16", "cloud"
|
||||
display_name = f"{display_name} {tag.upper()}"
|
||||
display_name = f"{display_name} {size}"
|
||||
|
||||
return display_name
|
||||
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import json
|
||||
import secrets
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Query
|
||||
from fastapi import UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.permissions import require_permission
|
||||
@@ -114,47 +113,28 @@ async def transcribe_audio(
|
||||
) from exc
|
||||
|
||||
|
||||
def _extract_provider_error(exc: Exception) -> str:
|
||||
"""Extract a human-readable message from a provider exception.
|
||||
|
||||
Provider errors often embed JSON from upstream APIs (e.g. ElevenLabs).
|
||||
This tries to parse a readable ``message`` field out of common JSON
|
||||
error shapes; falls back to ``str(exc)`` if nothing better is found.
|
||||
"""
|
||||
raw = str(exc)
|
||||
try:
|
||||
# Many providers embed JSON after a prefix like "ElevenLabs TTS failed: {...}"
|
||||
json_start = raw.find("{")
|
||||
if json_start == -1:
|
||||
return raw
|
||||
parsed = json.loads(raw[json_start:])
|
||||
# Shape: {"detail": {"message": "..."}} (ElevenLabs)
|
||||
detail = parsed.get("detail", parsed)
|
||||
if isinstance(detail, dict):
|
||||
return detail.get("message") or detail.get("error") or raw
|
||||
if isinstance(detail, str):
|
||||
return detail
|
||||
except (json.JSONDecodeError, AttributeError, TypeError):
|
||||
pass
|
||||
return raw
|
||||
|
||||
|
||||
class SynthesizeRequest(BaseModel):
|
||||
text: str = Field(..., min_length=1)
|
||||
voice: str | None = None
|
||||
speed: float | None = Field(default=None, ge=0.5, le=2.0)
|
||||
|
||||
|
||||
@router.post("/synthesize")
|
||||
async def synthesize_speech(
|
||||
body: SynthesizeRequest,
|
||||
text: str | None = Query(
|
||||
default=None, description="Text to synthesize", max_length=4096
|
||||
),
|
||||
voice: str | None = Query(default=None, description="Voice ID to use"),
|
||||
speed: float | None = Query(
|
||||
default=None, description="Playback speed (0.5-2.0)", ge=0.5, le=2.0
|
||||
),
|
||||
user: User = Depends(require_permission(Permission.BASIC_ACCESS)),
|
||||
) -> StreamingResponse:
|
||||
"""Synthesize text to speech using the default TTS provider."""
|
||||
text = body.text
|
||||
voice = body.voice
|
||||
speed = body.speed
|
||||
logger.info(f"TTS request: text length={len(text)}, voice={voice}, speed={speed}")
|
||||
"""
|
||||
Synthesize text to speech using the default TTS provider.
|
||||
|
||||
Accepts parameters via query string for streaming compatibility.
|
||||
"""
|
||||
logger.info(
|
||||
f"TTS request: text length={len(text) if text else 0}, voice={voice}, speed={speed}"
|
||||
)
|
||||
|
||||
if not text:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Text is required")
|
||||
|
||||
# Use short-lived session to fetch provider config, then release connection
|
||||
# before starting the long-running streaming response
|
||||
@@ -197,36 +177,31 @@ async def synthesize_speech(
|
||||
logger.error(f"Failed to get voice provider: {exc}")
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc
|
||||
|
||||
# Pull the first chunk before returning the StreamingResponse. If the
|
||||
# provider rejects the request (e.g. text too long), the error surfaces
|
||||
# as a proper HTTP error instead of a broken audio stream.
|
||||
stream_iter = provider.synthesize_stream(
|
||||
text=text, voice=final_voice, speed=final_speed
|
||||
)
|
||||
try:
|
||||
first_chunk = await stream_iter.__anext__()
|
||||
except StopAsyncIteration:
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, "TTS provider returned no audio")
|
||||
except Exception as exc:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY, _extract_provider_error(exc)
|
||||
) from exc
|
||||
|
||||
# Session is now closed - streaming response won't hold DB connection
|
||||
async def audio_stream() -> AsyncIterator[bytes]:
|
||||
yield first_chunk
|
||||
chunk_count = 1
|
||||
async for chunk in stream_iter:
|
||||
chunk_count += 1
|
||||
yield chunk
|
||||
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
|
||||
try:
|
||||
chunk_count = 0
|
||||
async for chunk in provider.synthesize_stream(
|
||||
text=text, voice=final_voice, speed=final_speed
|
||||
):
|
||||
chunk_count += 1
|
||||
yield chunk
|
||||
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
|
||||
except NotImplementedError as exc:
|
||||
logger.error(f"TTS not implemented: {exc}")
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Synthesis failed: {exc}")
|
||||
raise
|
||||
|
||||
return StreamingResponse(
|
||||
audio_stream(),
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Content-Disposition": "inline; filename=speech.mp3",
|
||||
# Allow streaming by not setting content-length
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -65,7 +65,6 @@ class Settings(BaseModel):
|
||||
anonymous_user_enabled: bool | None = None
|
||||
invite_only_enabled: bool = False
|
||||
deep_research_enabled: bool | None = None
|
||||
multi_model_chat_enabled: bool | None = None
|
||||
search_ui_enabled: bool | None = None
|
||||
|
||||
# Whether EE features are unlocked for use.
|
||||
@@ -90,8 +89,7 @@ class Settings(BaseModel):
|
||||
default=DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB, ge=0
|
||||
)
|
||||
file_token_count_threshold_k: int | None = Field(
|
||||
default=None,
|
||||
ge=0, # thousands of tokens; None = context-aware default
|
||||
default=None, ge=0 # thousands of tokens; None = context-aware default
|
||||
)
|
||||
|
||||
# Connector settings
|
||||
|
||||
@@ -208,6 +208,12 @@ class PythonToolOverrideKwargs(BaseModel):
|
||||
chat_files: list[ChatFile] = []
|
||||
|
||||
|
||||
class ImageGenerationToolOverrideKwargs(BaseModel):
|
||||
"""Override kwargs for image generation tool calls."""
|
||||
|
||||
recent_generated_image_file_ids: list[str] = []
|
||||
|
||||
|
||||
class SearchToolRunContext(BaseModel):
|
||||
emitter: Emitter
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.mcp import get_all_mcp_tools_for_server
|
||||
@@ -114,10 +113,10 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
|
||||
|
||||
def construct_tools(
|
||||
persona: Persona,
|
||||
db_session: Session,
|
||||
emitter: Emitter,
|
||||
user: User,
|
||||
llm: LLM,
|
||||
db_session: Session | None = None,
|
||||
search_tool_config: SearchToolConfig | None = None,
|
||||
custom_tool_config: CustomToolConfig | None = None,
|
||||
file_reader_tool_config: FileReaderToolConfig | None = None,
|
||||
@@ -132,33 +131,6 @@ def construct_tools(
|
||||
``attached_documents``, and ``hierarchy_nodes`` already eager-loaded
|
||||
(e.g. via ``eager_load_persona=True`` or ``eager_load_for_tools=True``)
|
||||
to avoid lazy SQL queries after the session may have been flushed."""
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
return _construct_tools_impl(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
custom_tool_config=custom_tool_config,
|
||||
file_reader_tool_config=file_reader_tool_config,
|
||||
allowed_tool_ids=allowed_tool_ids,
|
||||
search_usage_forcing_setting=search_usage_forcing_setting,
|
||||
)
|
||||
|
||||
|
||||
def _construct_tools_impl(
|
||||
persona: Persona,
|
||||
db_session: Session,
|
||||
emitter: Emitter,
|
||||
user: User,
|
||||
llm: LLM,
|
||||
search_tool_config: SearchToolConfig | None = None,
|
||||
custom_tool_config: CustomToolConfig | None = None,
|
||||
file_reader_tool_config: FileReaderToolConfig | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
search_usage_forcing_setting: SearchToolUsage = SearchToolUsage.AUTO,
|
||||
) -> dict[int, list[Tool]]:
|
||||
tool_dict: dict[int, list[Tool]] = {}
|
||||
|
||||
# Log which tools are attached to the persona for debugging
|
||||
|
||||
@@ -26,6 +26,7 @@ from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeart
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ImageGenerationToolOverrideKwargs
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.models import ToolExecutionException
|
||||
from onyx.tools.models import ToolResponse
|
||||
@@ -47,7 +48,7 @@ PROMPT_FIELD = "prompt"
|
||||
REFERENCE_IMAGE_FILE_IDS_FIELD = "reference_image_file_ids"
|
||||
|
||||
|
||||
class ImageGenerationTool(Tool[None]):
|
||||
class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
NAME = "generate_image"
|
||||
DESCRIPTION = "Generate an image based on a prompt. Do not use unless the user specifically requests an image."
|
||||
DISPLAY_NAME = "Image Generation"
|
||||
@@ -141,11 +142,8 @@ class ImageGenerationTool(Tool[None]):
|
||||
REFERENCE_IMAGE_FILE_IDS_FIELD: {
|
||||
"type": "array",
|
||||
"description": (
|
||||
"Optional file_ids of existing images to edit or use as reference;"
|
||||
" the first is the primary edit source."
|
||||
" Get file_ids from `[attached image — file_id: <id>]` tags on"
|
||||
" user-attached images or from prior generate_image tool responses."
|
||||
" Omit for a fresh, unrelated generation."
|
||||
"Optional image file IDs to use as reference context for edits/variations. "
|
||||
"Use the file_id values returned by previous generate_image calls."
|
||||
),
|
||||
"items": {
|
||||
"type": "string",
|
||||
@@ -256,31 +254,41 @@ class ImageGenerationTool(Tool[None]):
|
||||
def _resolve_reference_image_file_ids(
|
||||
self,
|
||||
llm_kwargs: dict[str, Any],
|
||||
override_kwargs: ImageGenerationToolOverrideKwargs | None,
|
||||
) -> list[str]:
|
||||
raw_reference_ids = llm_kwargs.get(REFERENCE_IMAGE_FILE_IDS_FIELD)
|
||||
if raw_reference_ids is None:
|
||||
# No references requested — plain generation.
|
||||
return []
|
||||
|
||||
if not isinstance(raw_reference_ids, list) or not all(
|
||||
isinstance(file_id, str) for file_id in raw_reference_ids
|
||||
if raw_reference_ids is not None:
|
||||
if not isinstance(raw_reference_ids, list) or not all(
|
||||
isinstance(file_id, str) for file_id in raw_reference_ids
|
||||
):
|
||||
raise ToolCallException(
|
||||
message=(
|
||||
f"Invalid {REFERENCE_IMAGE_FILE_IDS_FIELD}: expected array of strings, got {type(raw_reference_ids)}"
|
||||
),
|
||||
llm_facing_message=(
|
||||
f"The '{REFERENCE_IMAGE_FILE_IDS_FIELD}' field must be an array of file_id strings."
|
||||
),
|
||||
)
|
||||
reference_image_file_ids = [
|
||||
file_id.strip() for file_id in raw_reference_ids if file_id.strip()
|
||||
]
|
||||
elif (
|
||||
override_kwargs
|
||||
and override_kwargs.recent_generated_image_file_ids
|
||||
and self.img_provider.supports_reference_images
|
||||
):
|
||||
raise ToolCallException(
|
||||
message=(
|
||||
f"Invalid {REFERENCE_IMAGE_FILE_IDS_FIELD}: expected array of strings, got {type(raw_reference_ids)}"
|
||||
),
|
||||
llm_facing_message=(
|
||||
f"The '{REFERENCE_IMAGE_FILE_IDS_FIELD}' field must be an array of file_id strings."
|
||||
),
|
||||
)
|
||||
# If no explicit reference was provided, default to the most recently generated image.
|
||||
reference_image_file_ids = [
|
||||
override_kwargs.recent_generated_image_file_ids[-1]
|
||||
]
|
||||
else:
|
||||
reference_image_file_ids = []
|
||||
|
||||
# Deduplicate while preserving order (first occurrence wins, so the
|
||||
# LLM's intended "primary edit source" stays at index 0).
|
||||
# Deduplicate while preserving order.
|
||||
deduped_reference_image_ids: list[str] = []
|
||||
seen_ids: set[str] = set()
|
||||
for file_id in raw_reference_ids:
|
||||
file_id = file_id.strip()
|
||||
if not file_id or file_id in seen_ids:
|
||||
for file_id in reference_image_file_ids:
|
||||
if file_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(file_id)
|
||||
deduped_reference_image_ids.append(file_id)
|
||||
@@ -294,14 +302,14 @@ class ImageGenerationTool(Tool[None]):
|
||||
f"Reference images requested but provider '{self.provider}' does not support image-editing context."
|
||||
),
|
||||
llm_facing_message=(
|
||||
"This image provider does not support editing from existing images. "
|
||||
"This image provider does not support editing from previous image context. "
|
||||
"Try text-only generation, or switch to a provider/model that supports image edits."
|
||||
),
|
||||
)
|
||||
|
||||
max_reference_images = self.img_provider.max_reference_images
|
||||
if max_reference_images > 0:
|
||||
return deduped_reference_image_ids[:max_reference_images]
|
||||
return deduped_reference_image_ids[-max_reference_images:]
|
||||
return deduped_reference_image_ids
|
||||
|
||||
def _load_reference_images(
|
||||
@@ -350,7 +358,7 @@ class ImageGenerationTool(Tool[None]):
|
||||
def run(
|
||||
self,
|
||||
placement: Placement,
|
||||
override_kwargs: None = None, # noqa: ARG002
|
||||
override_kwargs: ImageGenerationToolOverrideKwargs | None = None,
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
if PROMPT_FIELD not in llm_kwargs:
|
||||
@@ -365,6 +373,7 @@ class ImageGenerationTool(Tool[None]):
|
||||
shape = ImageShape(llm_kwargs.get("shape", ImageShape.SQUARE.value))
|
||||
reference_image_file_ids = self._resolve_reference_image_file_ids(
|
||||
llm_kwargs=llm_kwargs,
|
||||
override_kwargs=override_kwargs,
|
||||
)
|
||||
reference_images = self._load_reference_images(reference_image_file_ids)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
@@ -13,6 +14,7 @@ from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ChatFile
|
||||
from onyx.tools.models import ChatMinimalTextMessage
|
||||
from onyx.tools.models import ImageGenerationToolOverrideKwargs
|
||||
from onyx.tools.models import OpenURLToolOverrideKwargs
|
||||
from onyx.tools.models import ParallelToolCallResponse
|
||||
from onyx.tools.models import PythonToolOverrideKwargs
|
||||
@@ -22,6 +24,9 @@ from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolExecutionException
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.models import WebSearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
|
||||
from onyx.tools.tool_implementations.memory.memory_tool import MemoryToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
@@ -105,6 +110,63 @@ def _merge_tool_calls(tool_calls: list[ToolCallKickoff]) -> list[ToolCallKickoff
|
||||
return merged_calls
|
||||
|
||||
|
||||
def _extract_image_file_ids_from_tool_response_message(
|
||||
message: str,
|
||||
) -> list[str]:
|
||||
try:
|
||||
parsed_message = json.loads(message)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
parsed_items: list[Any] = (
|
||||
parsed_message if isinstance(parsed_message, list) else [parsed_message]
|
||||
)
|
||||
file_ids: list[str] = []
|
||||
for item in parsed_items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
file_id = item.get("file_id")
|
||||
if isinstance(file_id, str):
|
||||
file_ids.append(file_id)
|
||||
|
||||
return file_ids
|
||||
|
||||
|
||||
def _extract_recent_generated_image_file_ids(
|
||||
message_history: list[ChatMessageSimple],
|
||||
) -> list[str]:
|
||||
tool_name_by_tool_call_id: dict[str, str] = {}
|
||||
recent_image_file_ids: list[str] = []
|
||||
seen_file_ids: set[str] = set()
|
||||
|
||||
for message in message_history:
|
||||
if message.message_type == MessageType.ASSISTANT and message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
tool_name_by_tool_call_id[tool_call.tool_call_id] = tool_call.tool_name
|
||||
continue
|
||||
|
||||
if (
|
||||
message.message_type != MessageType.TOOL_CALL_RESPONSE
|
||||
or not message.tool_call_id
|
||||
):
|
||||
continue
|
||||
|
||||
tool_name = tool_name_by_tool_call_id.get(message.tool_call_id)
|
||||
if tool_name != ImageGenerationTool.NAME:
|
||||
continue
|
||||
|
||||
for file_id in _extract_image_file_ids_from_tool_response_message(
|
||||
message.message
|
||||
):
|
||||
if file_id in seen_file_ids:
|
||||
continue
|
||||
seen_file_ids.add(file_id)
|
||||
recent_image_file_ids.append(file_id)
|
||||
|
||||
return recent_image_file_ids
|
||||
|
||||
|
||||
def _safe_run_single_tool(
|
||||
tool: Tool,
|
||||
tool_call: ToolCallKickoff,
|
||||
@@ -324,6 +386,9 @@ def run_tool_calls(
|
||||
url_to_citation: dict[str, int] = {
|
||||
url: citation_num for citation_num, url in citation_mapping.items()
|
||||
}
|
||||
recent_generated_image_file_ids = _extract_recent_generated_image_file_ids(
|
||||
message_history
|
||||
)
|
||||
|
||||
# Prepare all tool calls with their override_kwargs
|
||||
# Each tool gets a unique starting citation number to avoid conflicts when running in parallel
|
||||
@@ -340,6 +405,7 @@ def run_tool_calls(
|
||||
| WebSearchToolOverrideKwargs
|
||||
| OpenURLToolOverrideKwargs
|
||||
| PythonToolOverrideKwargs
|
||||
| ImageGenerationToolOverrideKwargs
|
||||
| MemoryToolOverrideKwargs
|
||||
| None
|
||||
) = None
|
||||
@@ -388,6 +454,10 @@ def run_tool_calls(
|
||||
override_kwargs = PythonToolOverrideKwargs(
|
||||
chat_files=chat_files or [],
|
||||
)
|
||||
elif isinstance(tool, ImageGenerationTool):
|
||||
override_kwargs = ImageGenerationToolOverrideKwargs(
|
||||
recent_generated_image_file_ids=recent_generated_image_file_ids
|
||||
)
|
||||
elif isinstance(tool, MemoryTool):
|
||||
override_kwargs = MemoryToolOverrideKwargs(
|
||||
user_name=(
|
||||
|
||||
@@ -254,7 +254,7 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.7.3
|
||||
onyx-devtools==0.7.4
|
||||
openai==2.14.0
|
||||
# via
|
||||
# litellm
|
||||
|
||||
@@ -46,7 +46,7 @@ stop_and_remove_containers
|
||||
# Start the PostgreSQL container with optional volume
|
||||
echo "Starting PostgreSQL container..."
|
||||
if [[ -n "$POSTGRES_VOLUME" ]]; then
|
||||
docker run -p 5432:5432 --name onyx_postgres -e POSTGRES_PASSWORD=password -d -v $POSTGRES_VOLUME:/var/lib/postgresql/data postgres -c max_connections=250
|
||||
docker run -p 5432:5432 --name onyx_postgres -e POSTGRES_PASSWORD=password -d -v "$POSTGRES_VOLUME":/var/lib/postgresql/data postgres -c max_connections=250
|
||||
else
|
||||
docker run -p 5432:5432 --name onyx_postgres -e POSTGRES_PASSWORD=password -d postgres -c max_connections=250
|
||||
fi
|
||||
@@ -54,7 +54,7 @@ fi
|
||||
# Start the Vespa container with optional volume
|
||||
echo "Starting Vespa container..."
|
||||
if [[ -n "$VESPA_VOLUME" ]]; then
|
||||
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 -v $VESPA_VOLUME:/opt/vespa/var vespaengine/vespa:8
|
||||
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 -v "$VESPA_VOLUME":/opt/vespa/var vespaengine/vespa:8
|
||||
else
|
||||
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 vespaengine/vespa:8
|
||||
fi
|
||||
@@ -85,7 +85,7 @@ docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-en
|
||||
# Start the Redis container with optional volume
|
||||
echo "Starting Redis container..."
|
||||
if [[ -n "$REDIS_VOLUME" ]]; then
|
||||
docker run --detach --name onyx_redis --publish 6379:6379 -v $REDIS_VOLUME:/data redis
|
||||
docker run --detach --name onyx_redis --publish 6379:6379 -v "$REDIS_VOLUME":/data redis
|
||||
else
|
||||
docker run --detach --name onyx_redis --publish 6379:6379 redis
|
||||
fi
|
||||
@@ -93,7 +93,7 @@ fi
|
||||
# Start the MinIO container with optional volume
|
||||
echo "Starting MinIO container..."
|
||||
if [[ -n "$MINIO_VOLUME" ]]; then
|
||||
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin -v $MINIO_VOLUME:/data minio/minio server /data --console-address ":9001"
|
||||
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin -v "$MINIO_VOLUME":/data minio/minio server /data --console-address ":9001"
|
||||
else
|
||||
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin minio/minio server /data --console-address ":9001"
|
||||
fi
|
||||
@@ -111,6 +111,7 @@ sleep 1
|
||||
|
||||
# Alembic should be configured in the virtualenv for this repo
|
||||
if [[ -f "../.venv/bin/activate" ]]; then
|
||||
# shellcheck source=/dev/null
|
||||
source ../.venv/bin/activate
|
||||
else
|
||||
echo "Warning: Python virtual environment not found at .venv/bin/activate; alembic may not work."
|
||||
|
||||
@@ -38,41 +38,38 @@ class TestAddMemory:
|
||||
def test_add_memory_creates_row(self, db_session: Session, test_user: User) -> None:
|
||||
"""Verify that add_memory inserts a new Memory row."""
|
||||
user_id = test_user.id
|
||||
memory_id = add_memory(
|
||||
memory = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="User prefers dark mode",
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert memory_id is not None
|
||||
assert memory.id is not None
|
||||
assert memory.user_id == user_id
|
||||
assert memory.memory_text == "User prefers dark mode"
|
||||
|
||||
# Verify it persists
|
||||
fetched = db_session.get(Memory, memory_id)
|
||||
fetched = db_session.get(Memory, memory.id)
|
||||
assert fetched is not None
|
||||
assert fetched.user_id == user_id
|
||||
assert fetched.memory_text == "User prefers dark mode"
|
||||
|
||||
def test_add_multiple_memories(self, db_session: Session, test_user: User) -> None:
|
||||
"""Verify that multiple memories can be added for the same user."""
|
||||
user_id = test_user.id
|
||||
m1_id = add_memory(
|
||||
m1 = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="Favorite color is blue",
|
||||
db_session=db_session,
|
||||
)
|
||||
m2_id = add_memory(
|
||||
m2 = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="Works in engineering",
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert m1_id != m2_id
|
||||
fetched_m1 = db_session.get(Memory, m1_id)
|
||||
fetched_m2 = db_session.get(Memory, m2_id)
|
||||
assert fetched_m1 is not None
|
||||
assert fetched_m2 is not None
|
||||
assert fetched_m1.memory_text == "Favorite color is blue"
|
||||
assert fetched_m2.memory_text == "Works in engineering"
|
||||
assert m1.id != m2.id
|
||||
assert m1.memory_text == "Favorite color is blue"
|
||||
assert m2.memory_text == "Works in engineering"
|
||||
|
||||
|
||||
class TestUpdateMemoryAtIndex:
|
||||
@@ -85,17 +82,15 @@ class TestUpdateMemoryAtIndex:
|
||||
add_memory(user_id=user_id, memory_text="Memory 1", db_session=db_session)
|
||||
add_memory(user_id=user_id, memory_text="Memory 2", db_session=db_session)
|
||||
|
||||
updated_id = update_memory_at_index(
|
||||
updated = update_memory_at_index(
|
||||
user_id=user_id,
|
||||
index=1,
|
||||
new_text="Updated Memory 1",
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert updated_id is not None
|
||||
fetched = db_session.get(Memory, updated_id)
|
||||
assert fetched is not None
|
||||
assert fetched.memory_text == "Updated Memory 1"
|
||||
assert updated is not None
|
||||
assert updated.memory_text == "Updated Memory 1"
|
||||
|
||||
def test_update_memory_at_out_of_range_index(
|
||||
self, db_session: Session, test_user: User
|
||||
@@ -172,7 +167,7 @@ class TestMemoryCap:
|
||||
assert len(rows_before) == MAX_MEMORIES_PER_USER
|
||||
|
||||
# Add one more — should evict the oldest
|
||||
new_memory_id = add_memory(
|
||||
new_memory = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="New memory after cap",
|
||||
db_session=db_session,
|
||||
@@ -186,7 +181,7 @@ class TestMemoryCap:
|
||||
# Oldest ("Memory 0") should be gone; "Memory 1" is now the oldest
|
||||
assert rows_after[0].memory_text == "Memory 1"
|
||||
# Newest should be the one we just added
|
||||
assert rows_after[-1].id == new_memory_id
|
||||
assert rows_after[-1].id == new_memory.id
|
||||
assert rows_after[-1].memory_text == "New memory after cap"
|
||||
|
||||
|
||||
@@ -226,26 +221,22 @@ class TestGetMemoriesWithUserId:
|
||||
user_id = test_user_no_memories.id
|
||||
|
||||
# Add a memory
|
||||
memory_id = add_memory(
|
||||
memory = add_memory(
|
||||
user_id=user_id,
|
||||
memory_text="Memory with use_memories off",
|
||||
db_session=db_session,
|
||||
)
|
||||
fetched = db_session.get(Memory, memory_id)
|
||||
assert fetched is not None
|
||||
assert fetched.memory_text == "Memory with use_memories off"
|
||||
assert memory.memory_text == "Memory with use_memories off"
|
||||
|
||||
# Update that memory
|
||||
updated_id = update_memory_at_index(
|
||||
updated = update_memory_at_index(
|
||||
user_id=user_id,
|
||||
index=0,
|
||||
new_text="Updated memory with use_memories off",
|
||||
db_session=db_session,
|
||||
)
|
||||
assert updated_id is not None
|
||||
fetched_updated = db_session.get(Memory, updated_id)
|
||||
assert fetched_updated is not None
|
||||
assert fetched_updated.memory_text == "Updated memory with use_memories off"
|
||||
assert updated is not None
|
||||
assert updated.memory_text == "Updated memory with use_memories off"
|
||||
|
||||
# Verify get_memories returns the updated memory
|
||||
context = get_memories(test_user_no_memories, db_session)
|
||||
|
||||
@@ -9,7 +9,6 @@ from unittest.mock import patch
|
||||
from ee.onyx.db.license import check_seat_availability
|
||||
from ee.onyx.db.license import delete_license
|
||||
from ee.onyx.db.license import get_license
|
||||
from ee.onyx.db.license import get_used_seats
|
||||
from ee.onyx.db.license import upsert_license
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
@@ -215,43 +214,3 @@ class TestCheckSeatAvailabilityMultiTenant:
|
||||
assert result.available is False
|
||||
assert result.error_message is not None
|
||||
mock_tenant_count.assert_called_once_with("tenant-abc")
|
||||
|
||||
|
||||
class TestGetUsedSeatsAccountTypeFiltering:
|
||||
"""Verify get_used_seats query excludes SERVICE_ACCOUNT but includes BOT."""
|
||||
|
||||
@patch("ee.onyx.db.license.MULTI_TENANT", False)
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_excludes_service_accounts(self, mock_get_session: MagicMock) -> None:
|
||||
"""SERVICE_ACCOUNT users should not count toward seats."""
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session.execute.return_value.scalar.return_value = 5
|
||||
|
||||
result = get_used_seats()
|
||||
|
||||
assert result == 5
|
||||
# Inspect the compiled query to verify account_type filter
|
||||
call_args = mock_session.execute.call_args
|
||||
query = call_args[0][0]
|
||||
compiled = str(query.compile(compile_kwargs={"literal_binds": True}))
|
||||
assert "SERVICE_ACCOUNT" in compiled
|
||||
# BOT should NOT be excluded
|
||||
assert "BOT" not in compiled
|
||||
|
||||
@patch("ee.onyx.db.license.MULTI_TENANT", False)
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_still_excludes_ext_perm_user(self, mock_get_session: MagicMock) -> None:
|
||||
"""EXT_PERM_USER exclusion should still be present."""
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session.execute.return_value.scalar.return_value = 3
|
||||
|
||||
get_used_seats()
|
||||
|
||||
call_args = mock_session.execute.call_args
|
||||
query = call_args[0][0]
|
||||
compiled = str(query.compile(compile_kwargs={"literal_binds": True}))
|
||||
assert "EXT_PERM_USER" in compiled
|
||||
|
||||
@@ -301,6 +301,7 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -331,6 +332,7 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -361,6 +363,7 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -388,6 +391,7 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -419,6 +423,7 @@ class TestRunModels:
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -451,6 +456,7 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
@@ -491,6 +497,7 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -512,6 +519,7 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop"),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -534,6 +542,7 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -587,6 +596,7 @@ class TestRunModels:
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
@@ -643,6 +653,7 @@ class TestRunModels:
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
@@ -695,6 +706,7 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=fail_model_0),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
@@ -724,6 +736,7 @@ class TestRunModels:
|
||||
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
|
||||
@@ -1,182 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
|
||||
from onyx.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
from onyx.connectors.google_utils.google_kv import get_auth_url
|
||||
from onyx.connectors.google_utils.google_kv import get_google_app_cred
|
||||
from onyx.connectors.google_utils.google_kv import get_service_account_key
|
||||
from onyx.connectors.google_utils.google_kv import upsert_google_app_cred
|
||||
from onyx.connectors.google_utils.google_kv import upsert_service_account_key
|
||||
from onyx.server.documents.models import GoogleAppCredentials
|
||||
from onyx.server.documents.models import GoogleAppWebCredentials
|
||||
from onyx.server.documents.models import GoogleServiceAccountKey
|
||||
|
||||
|
||||
def _make_app_creds() -> GoogleAppCredentials:
|
||||
return GoogleAppCredentials(
|
||||
web=GoogleAppWebCredentials(
|
||||
client_id="client-id.apps.googleusercontent.com",
|
||||
project_id="test-project",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
|
||||
client_secret="secret",
|
||||
redirect_uris=["https://example.com/callback"],
|
||||
javascript_origins=["https://example.com"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _make_service_account_key() -> GoogleServiceAccountKey:
|
||||
return GoogleServiceAccountKey(
|
||||
type="service_account",
|
||||
project_id="test-project",
|
||||
private_key_id="private-key-id",
|
||||
private_key="-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----\n",
|
||||
client_email="test@test-project.iam.gserviceaccount.com",
|
||||
client_id="123",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
|
||||
client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/test",
|
||||
universe_domain="googleapis.com",
|
||||
)
|
||||
|
||||
|
||||
def test_upsert_google_app_cred_stores_dict(monkeypatch: Any) -> None:
|
||||
stored: dict[str, Any] = {}
|
||||
|
||||
class _StubKvStore:
|
||||
def store(self, key: str, value: object, encrypt: bool) -> None:
|
||||
stored["key"] = key
|
||||
stored["value"] = value
|
||||
stored["encrypt"] = encrypt
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
upsert_google_app_cred(_make_app_creds(), DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert stored["key"] == KV_GOOGLE_DRIVE_CRED_KEY
|
||||
assert stored["encrypt"] is True
|
||||
assert isinstance(stored["value"], dict)
|
||||
assert stored["value"]["web"]["client_id"] == "client-id.apps.googleusercontent.com"
|
||||
|
||||
|
||||
def test_upsert_service_account_key_stores_dict(monkeypatch: Any) -> None:
|
||||
stored: dict[str, Any] = {}
|
||||
|
||||
class _StubKvStore:
|
||||
def store(self, key: str, value: object, encrypt: bool) -> None:
|
||||
stored["key"] = key
|
||||
stored["value"] = value
|
||||
stored["encrypt"] = encrypt
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
upsert_service_account_key(_make_service_account_key(), DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert stored["key"] == KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
assert stored["encrypt"] is True
|
||||
assert isinstance(stored["value"], dict)
|
||||
assert stored["value"]["project_id"] == "test-project"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("legacy_string", [False, True])
|
||||
def test_get_google_app_cred_accepts_dict_and_legacy_string(
|
||||
monkeypatch: Any, legacy_string: bool
|
||||
) -> None:
|
||||
payload: dict[str, Any] = _make_app_creds().model_dump(mode="json")
|
||||
stored_value: object = (
|
||||
payload if not legacy_string else _make_app_creds().model_dump_json()
|
||||
)
|
||||
|
||||
class _StubKvStore:
|
||||
def load(self, key: str) -> object:
|
||||
assert key == KV_GOOGLE_DRIVE_CRED_KEY
|
||||
return stored_value
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
creds = get_google_app_cred(DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert creds.web.client_id == "client-id.apps.googleusercontent.com"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("legacy_string", [False, True])
|
||||
def test_get_service_account_key_accepts_dict_and_legacy_string(
|
||||
monkeypatch: Any, legacy_string: bool
|
||||
) -> None:
|
||||
stored_value: object = (
|
||||
_make_service_account_key().model_dump(mode="json")
|
||||
if not legacy_string
|
||||
else _make_service_account_key().model_dump_json()
|
||||
)
|
||||
|
||||
class _StubKvStore:
|
||||
def load(self, key: str) -> object:
|
||||
assert key == KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
return stored_value
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
key = get_service_account_key(DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert key.client_email == "test@test-project.iam.gserviceaccount.com"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("legacy_string", [False, True])
|
||||
def test_get_auth_url_accepts_dict_and_legacy_string(
|
||||
monkeypatch: Any, legacy_string: bool
|
||||
) -> None:
|
||||
payload = _make_app_creds().model_dump(mode="json")
|
||||
stored_value: object = (
|
||||
payload if not legacy_string else _make_app_creds().model_dump_json()
|
||||
)
|
||||
stored_state: dict[str, object] = {}
|
||||
|
||||
class _StubKvStore:
|
||||
def load(self, key: str) -> object:
|
||||
assert key == KV_GOOGLE_DRIVE_CRED_KEY
|
||||
return stored_value
|
||||
|
||||
def store(self, key: str, value: object, encrypt: bool) -> None:
|
||||
stored_state["key"] = key
|
||||
stored_state["value"] = value
|
||||
stored_state["encrypt"] = encrypt
|
||||
|
||||
class _StubFlow:
|
||||
def authorization_url(self, prompt: str) -> tuple[str, None]:
|
||||
assert prompt == "consent"
|
||||
return "https://accounts.google.com/o/oauth2/auth?state=test-state", None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.get_kv_store", lambda: _StubKvStore()
|
||||
)
|
||||
|
||||
def _from_client_config(
|
||||
_app_config: object, *, scopes: object, redirect_uri: object
|
||||
) -> _StubFlow:
|
||||
del scopes, redirect_uri
|
||||
return _StubFlow()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.google_utils.google_kv.InstalledAppFlow.from_client_config",
|
||||
_from_client_config,
|
||||
)
|
||||
|
||||
auth_url = get_auth_url(42, DocumentSource.GOOGLE_DRIVE)
|
||||
|
||||
assert auth_url.startswith("https://accounts.google.com")
|
||||
assert stored_state["value"] == {"value": "test-state"}
|
||||
assert stored_state["encrypt"] is True
|
||||
@@ -6,7 +6,6 @@ import requests
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
|
||||
from onyx.connectors.jira.connector import _JIRA_BULK_FETCH_LIMIT
|
||||
from onyx.connectors.jira.connector import bulk_fetch_issues
|
||||
|
||||
|
||||
@@ -146,29 +145,3 @@ def test_bulk_fetch_recursive_splitting_raises_on_bad_issue() -> None:
|
||||
|
||||
with pytest.raises(requests.exceptions.JSONDecodeError):
|
||||
bulk_fetch_issues(client, ["1", "2", bad_id, "3", "4", "5"])
|
||||
|
||||
|
||||
def test_bulk_fetch_respects_api_batch_limit() -> None:
|
||||
"""Requests to the bulkfetch endpoint never exceed _JIRA_BULK_FETCH_LIMIT IDs."""
|
||||
client = _mock_jira_client()
|
||||
total_issues = _JIRA_BULK_FETCH_LIMIT * 3 + 7
|
||||
all_ids = [str(i) for i in range(total_issues)]
|
||||
|
||||
batch_sizes: list[int] = []
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
ids = json["issueIdsOrKeys"]
|
||||
batch_sizes.append(len(ids))
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
result = bulk_fetch_issues(client, all_ids)
|
||||
|
||||
assert len(result) == total_issues
|
||||
# keeping this hardcoded because it's the documented limit
|
||||
# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issues/
|
||||
assert all(size <= 100 for size in batch_sizes)
|
||||
assert len(batch_sizes) == 4
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
"""Tests for _build_thread_text function."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.context.search.federated.slack_search import _build_thread_text
|
||||
|
||||
|
||||
def _make_msg(user: str, text: str, ts: str) -> dict[str, str]:
|
||||
return {"user": user, "text": text, "ts": ts}
|
||||
|
||||
|
||||
class TestBuildThreadText:
|
||||
"""Verify _build_thread_text includes full thread replies up to cap."""
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_includes_all_replies(self, mock_profiles: MagicMock) -> None:
|
||||
"""All replies within cap are included in output."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [
|
||||
_make_msg("U1", "parent msg", "1000.0"),
|
||||
_make_msg("U2", "reply 1", "1001.0"),
|
||||
_make_msg("U3", "reply 2", "1002.0"),
|
||||
_make_msg("U4", "reply 3", "1003.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "parent msg" in result
|
||||
assert "reply 1" in result
|
||||
assert "reply 2" in result
|
||||
assert "reply 3" in result
|
||||
assert "..." not in result
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_non_thread_returns_parent_only(self, mock_profiles: MagicMock) -> None:
|
||||
"""Single message (no replies) returns just the parent text."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [_make_msg("U1", "just a message", "1000.0")]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "just a message" in result
|
||||
assert "Replies:" not in result
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_parent_always_first(self, mock_profiles: MagicMock) -> None:
|
||||
"""Thread parent message is always the first line of output."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [
|
||||
_make_msg("U1", "I am the parent", "1000.0"),
|
||||
_make_msg("U2", "I am a reply", "1001.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
parent_pos = result.index("I am the parent")
|
||||
reply_pos = result.index("I am a reply")
|
||||
assert parent_pos < reply_pos
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_user_profiles_resolved(self, mock_profiles: MagicMock) -> None:
|
||||
"""User IDs in thread text are replaced with display names."""
|
||||
mock_profiles.return_value = {"U1": "Alice", "U2": "Bob"}
|
||||
messages = [
|
||||
_make_msg("U1", "hello", "1000.0"),
|
||||
_make_msg("U2", "world", "1001.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "Alice" in result
|
||||
assert "Bob" in result
|
||||
assert "<@U1>" not in result
|
||||
assert "<@U2>" not in result
|
||||
@@ -1,108 +0,0 @@
|
||||
"""Tests for Slack URL parsing and direct thread fetch via URL override."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.federated.slack_search import _fetch_thread_from_url
|
||||
from onyx.context.search.federated.slack_search_utils import extract_slack_message_urls
|
||||
|
||||
|
||||
class TestExtractSlackMessageUrls:
|
||||
"""Verify URL parsing extracts channel_id and timestamp correctly."""
|
||||
|
||||
def test_standard_url(self) -> None:
|
||||
query = "summarize https://mycompany.slack.com/archives/C097NBWMY8Y/p1775491616524769"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 1
|
||||
assert results[0] == ("C097NBWMY8Y", "1775491616.524769")
|
||||
|
||||
def test_multiple_urls(self) -> None:
|
||||
query = (
|
||||
"compare https://co.slack.com/archives/C111/p1234567890123456 "
|
||||
"and https://co.slack.com/archives/C222/p9876543210987654"
|
||||
)
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 2
|
||||
assert results[0] == ("C111", "1234567890.123456")
|
||||
assert results[1] == ("C222", "9876543210.987654")
|
||||
|
||||
def test_no_urls(self) -> None:
|
||||
query = "what happened in #general last week?"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 0
|
||||
|
||||
def test_non_slack_url_ignored(self) -> None:
|
||||
query = "check https://google.com/archives/C111/p1234567890123456"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 0
|
||||
|
||||
def test_timestamp_conversion(self) -> None:
|
||||
"""p prefix removed, dot inserted after 10th digit."""
|
||||
query = "https://x.slack.com/archives/CABC123/p1775491616524769"
|
||||
results = extract_slack_message_urls(query)
|
||||
channel_id, ts = results[0]
|
||||
assert channel_id == "CABC123"
|
||||
assert ts == "1775491616.524769"
|
||||
assert not ts.startswith("p")
|
||||
assert "." in ts
|
||||
|
||||
|
||||
class TestFetchThreadFromUrl:
|
||||
"""Verify _fetch_thread_from_url calls conversations.replies and returns SlackMessage."""
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search._build_thread_text")
|
||||
@patch("onyx.context.search.federated.slack_search.WebClient")
|
||||
def test_successful_fetch(
|
||||
self, mock_webclient_cls: MagicMock, mock_build_thread: MagicMock
|
||||
) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_webclient_cls.return_value = mock_client
|
||||
|
||||
# Mock conversations_replies
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = [
|
||||
{"user": "U1", "text": "parent", "ts": "1775491616.524769"},
|
||||
{"user": "U2", "text": "reply 1", "ts": "1775491617.000000"},
|
||||
{"user": "U3", "text": "reply 2", "ts": "1775491618.000000"},
|
||||
]
|
||||
mock_client.conversations_replies.return_value = mock_response
|
||||
|
||||
# Mock channel info
|
||||
mock_ch_response = MagicMock()
|
||||
mock_ch_response.get.return_value = {"name": "general"}
|
||||
mock_client.conversations_info.return_value = mock_ch_response
|
||||
|
||||
mock_build_thread.return_value = (
|
||||
"U1: parent\n\nReplies:\n\nU2: reply 1\n\nU3: reply 2"
|
||||
)
|
||||
|
||||
fetch = DirectThreadFetch(
|
||||
channel_id="C097NBWMY8Y", thread_ts="1775491616.524769"
|
||||
)
|
||||
result = _fetch_thread_from_url(fetch, "xoxp-token")
|
||||
|
||||
assert len(result.messages) == 1
|
||||
msg = result.messages[0]
|
||||
assert msg.channel_id == "C097NBWMY8Y"
|
||||
assert msg.thread_id is None # Prevents double-enrichment
|
||||
assert msg.slack_score == 100000.0
|
||||
assert "parent" in msg.text
|
||||
mock_client.conversations_replies.assert_called_once_with(
|
||||
channel="C097NBWMY8Y", ts="1775491616.524769"
|
||||
)
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.WebClient")
|
||||
def test_api_error_returns_empty(self, mock_webclient_cls: MagicMock) -> None:
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_webclient_cls.return_value = mock_client
|
||||
mock_client.conversations_replies.side_effect = SlackApiError(
|
||||
message="channel_not_found",
|
||||
response=MagicMock(status_code=404),
|
||||
)
|
||||
|
||||
fetch = DirectThreadFetch(channel_id="CBAD", thread_ts="1234567890.123456")
|
||||
result = _fetch_thread_from_url(fetch, "xoxp-token")
|
||||
assert len(result.messages) == 0
|
||||
@@ -29,7 +29,6 @@ from onyx.llm.utils import get_max_input_tokens
|
||||
VERTEX_OPUS_MODELS_REJECTING_OUTPUT_CONFIG = [
|
||||
"claude-opus-4-5@20251101",
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4-7",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -505,7 +505,6 @@ class TestGetLMStudioAvailableModels:
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.api_base = "http://localhost:1234"
|
||||
mock_provider.custom_config = {"LM_STUDIO_API_KEY": "stored-secret"}
|
||||
|
||||
response = {
|
||||
|
||||
@@ -100,39 +100,6 @@ class TestGenerateOllamaDisplayName:
|
||||
result = generate_ollama_display_name("llama3.3:70b")
|
||||
assert "3.3" in result or "3 3" in result # Either format is acceptable
|
||||
|
||||
def test_non_size_tag_shown(self) -> None:
|
||||
"""Test that non-size tags like 'e4b' are included in the display name."""
|
||||
result = generate_ollama_display_name("gemma4:e4b")
|
||||
assert "Gemma" in result
|
||||
assert "4" in result
|
||||
assert "E4B" in result
|
||||
|
||||
def test_size_with_cloud_modifier(self) -> None:
|
||||
"""Test size tag with cloud modifier."""
|
||||
result = generate_ollama_display_name("deepseek-v3.1:671b-cloud")
|
||||
assert "DeepSeek" in result
|
||||
assert "671B" in result
|
||||
assert "Cloud" in result
|
||||
|
||||
def test_size_with_multiple_modifiers(self) -> None:
|
||||
"""Test size tag with multiple modifiers."""
|
||||
result = generate_ollama_display_name("qwen3-vl:235b-instruct-cloud")
|
||||
assert "Qwen" in result
|
||||
assert "235B" in result
|
||||
assert "Instruct" in result
|
||||
assert "Cloud" in result
|
||||
|
||||
def test_quantization_tag_shown(self) -> None:
|
||||
"""Test that quantization tags are included in the display name."""
|
||||
result = generate_ollama_display_name("llama3:q4_0")
|
||||
assert "Llama" in result
|
||||
assert "Q4_0" in result
|
||||
|
||||
def test_cloud_only_tag(self) -> None:
|
||||
"""Test standalone cloud tag."""
|
||||
result = generate_ollama_display_name("glm-4.6:cloud")
|
||||
assert "CLOUD" in result
|
||||
|
||||
|
||||
class TestStripOpenrouterVendorPrefix:
|
||||
"""Tests for OpenRouter vendor prefix stripping."""
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
@@ -10,9 +9,7 @@ from uuid import uuid4
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from ee.onyx.server.scim.api import _check_seat_availability
|
||||
from ee.onyx.server.scim.api import _scim_name_to_str
|
||||
from ee.onyx.server.scim.api import _seat_lock_id_for_tenant
|
||||
from ee.onyx.server.scim.api import create_user
|
||||
from ee.onyx.server.scim.api import delete_user
|
||||
from ee.onyx.server.scim.api import get_user
|
||||
@@ -744,80 +741,3 @@ class TestEmailCasePreservation:
|
||||
resource = parse_scim_user(result)
|
||||
assert resource.userName == "Alice@Example.COM"
|
||||
assert resource.emails[0].value == "Alice@Example.COM"
|
||||
|
||||
|
||||
class TestSeatLock:
|
||||
"""Tests for the advisory lock in _check_seat_availability."""
|
||||
|
||||
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_abc")
|
||||
def test_acquires_advisory_lock_before_checking(
|
||||
self,
|
||||
_mock_tenant: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""The advisory lock must be acquired before the seat check runs."""
|
||||
call_order: list[str] = []
|
||||
|
||||
def track_execute(stmt: Any, _params: Any = None) -> None:
|
||||
if "pg_advisory_xact_lock" in str(stmt):
|
||||
call_order.append("lock")
|
||||
|
||||
mock_dal.session.execute.side_effect = track_execute
|
||||
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop"
|
||||
) as mock_fetch:
|
||||
mock_result = MagicMock()
|
||||
mock_result.available = True
|
||||
mock_fn = MagicMock(return_value=mock_result)
|
||||
mock_fetch.return_value = mock_fn
|
||||
|
||||
def track_check(*_args: Any, **_kwargs: Any) -> Any:
|
||||
call_order.append("check")
|
||||
return mock_result
|
||||
|
||||
mock_fn.side_effect = track_check
|
||||
|
||||
_check_seat_availability(mock_dal)
|
||||
|
||||
assert call_order == ["lock", "check"]
|
||||
|
||||
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_xyz")
|
||||
def test_lock_uses_tenant_scoped_key(
|
||||
self,
|
||||
_mock_tenant: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""The lock id must be derived from the tenant via _seat_lock_id_for_tenant."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.available = True
|
||||
mock_check = MagicMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
|
||||
return_value=mock_check,
|
||||
):
|
||||
_check_seat_availability(mock_dal)
|
||||
|
||||
mock_dal.session.execute.assert_called_once()
|
||||
params = mock_dal.session.execute.call_args[0][1]
|
||||
assert params["lock_id"] == _seat_lock_id_for_tenant("tenant_xyz")
|
||||
|
||||
def test_seat_lock_id_is_stable_and_tenant_scoped(self) -> None:
|
||||
"""Lock id must be deterministic and differ across tenants."""
|
||||
assert _seat_lock_id_for_tenant("t1") == _seat_lock_id_for_tenant("t1")
|
||||
assert _seat_lock_id_for_tenant("t1") != _seat_lock_id_for_tenant("t2")
|
||||
|
||||
def test_no_lock_when_ee_absent(
|
||||
self,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""No advisory lock should be acquired when the EE check is absent."""
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
|
||||
return_value=None,
|
||||
):
|
||||
result = _check_seat_availability(mock_dal)
|
||||
|
||||
assert result is None
|
||||
mock_dal.session.execute.assert_not_called()
|
||||
|
||||
@@ -95,9 +95,9 @@ class TestForceAddSearchToolGuard:
|
||||
without a vector DB."""
|
||||
import inspect
|
||||
|
||||
from onyx.tools.tool_constructor import _construct_tools_impl
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
|
||||
source = inspect.getsource(_construct_tools_impl)
|
||||
source = inspect.getsource(construct_tools)
|
||||
assert (
|
||||
"DISABLE_VECTOR_DB" in source
|
||||
), "construct_tools should reference DISABLE_VECTOR_DB to suppress force-adding SearchTool"
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
"""Tests for ``ImageGenerationTool._resolve_reference_image_file_ids``.
|
||||
|
||||
The resolver turns the LLM's ``reference_image_file_ids`` argument into a
|
||||
cleaned list of file IDs to hand to ``_load_reference_images``. It trusts
|
||||
the LLM's picks — the LLM can only see file IDs that actually appear in
|
||||
the conversation (via ``[attached image — file_id: <id>]`` tags on user
|
||||
messages and the JSON returned by prior generate_image calls), so we
|
||||
don't re-validate against an allow-list in the tool itself.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
REFERENCE_IMAGE_FILE_IDS_FIELD,
|
||||
)
|
||||
|
||||
|
||||
def _make_tool(
|
||||
supports_reference_images: bool = True,
|
||||
max_reference_images: int = 16,
|
||||
) -> ImageGenerationTool:
|
||||
"""Construct a tool with a mock provider so no credentials/network are needed."""
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.images.image_generation_tool.get_image_generation_provider"
|
||||
) as mock_get_provider:
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.supports_reference_images = supports_reference_images
|
||||
mock_provider.max_reference_images = max_reference_images
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
return ImageGenerationTool(
|
||||
image_generation_credentials=MagicMock(),
|
||||
tool_id=1,
|
||||
emitter=MagicMock(),
|
||||
model="gpt-image-1",
|
||||
provider="openai",
|
||||
)
|
||||
|
||||
|
||||
class TestResolveReferenceImageFileIds:
|
||||
def test_unset_returns_empty_plain_generation(self) -> None:
|
||||
tool = _make_tool()
|
||||
assert tool._resolve_reference_image_file_ids(llm_kwargs={}) == []
|
||||
|
||||
def test_empty_list_is_treated_like_unset(self) -> None:
|
||||
tool = _make_tool()
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: []},
|
||||
)
|
||||
assert result == []
|
||||
|
||||
def test_passes_llm_supplied_ids_through(self) -> None:
|
||||
tool = _make_tool()
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["upload-1", "gen-1"]},
|
||||
)
|
||||
# Order preserved — first entry is the primary edit source.
|
||||
assert result == ["upload-1", "gen-1"]
|
||||
|
||||
def test_invalid_shape_raises(self) -> None:
|
||||
tool = _make_tool()
|
||||
with pytest.raises(ToolCallException):
|
||||
tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: "not-a-list"},
|
||||
)
|
||||
|
||||
def test_non_string_element_raises(self) -> None:
|
||||
tool = _make_tool()
|
||||
with pytest.raises(ToolCallException):
|
||||
tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["ok", 123]},
|
||||
)
|
||||
|
||||
def test_deduplicates_preserving_first_occurrence(self) -> None:
|
||||
tool = _make_tool()
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["gen-1", "gen-2", "gen-1"]},
|
||||
)
|
||||
assert result == ["gen-1", "gen-2"]
|
||||
|
||||
def test_strips_whitespace_and_skips_empty_strings(self) -> None:
|
||||
tool = _make_tool()
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: [" gen-1 ", "", " "]},
|
||||
)
|
||||
assert result == ["gen-1"]
|
||||
|
||||
def test_provider_without_reference_support_raises(self) -> None:
|
||||
tool = _make_tool(supports_reference_images=False)
|
||||
with pytest.raises(ToolCallException):
|
||||
tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["gen-1"]},
|
||||
)
|
||||
|
||||
def test_truncates_to_provider_max_preserving_head(self) -> None:
|
||||
"""When the LLM lists more images than the provider allows, keep the
|
||||
HEAD of the list (the primary edit source + earliest extras) rather
|
||||
than the tail, since the LLM put the most important one first."""
|
||||
tool = _make_tool(max_reference_images=2)
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["a", "b", "c", "d"]},
|
||||
)
|
||||
assert result == ["a", "b"]
|
||||
@@ -1,5 +1,10 @@
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_runner import _extract_image_file_ids_from_tool_response_message
|
||||
from onyx.tools.tool_runner import _extract_recent_generated_image_file_ids
|
||||
from onyx.tools.tool_runner import _merge_tool_calls
|
||||
|
||||
|
||||
@@ -307,3 +312,62 @@ class TestMergeToolCalls:
|
||||
assert len(result) == 1
|
||||
# String should be converted to list item
|
||||
assert result[0].tool_args["queries"] == ["single_query", "q2"]
|
||||
|
||||
|
||||
class TestImageHistoryExtraction:
|
||||
def test_extracts_image_file_ids_from_json_response(self) -> None:
|
||||
msg = '[{"file_id":"img-1","revised_prompt":"v1"},{"file_id":"img-2","revised_prompt":"v2"}]'
|
||||
assert _extract_image_file_ids_from_tool_response_message(msg) == [
|
||||
"img-1",
|
||||
"img-2",
|
||||
]
|
||||
|
||||
def test_extracts_recent_generated_image_ids_from_history(self) -> None:
|
||||
history = [
|
||||
ChatMessageSimple(
|
||||
message="",
|
||||
token_count=1,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
tool_calls=[
|
||||
ToolCallSimple(
|
||||
tool_call_id="call_1",
|
||||
tool_name="generate_image",
|
||||
tool_arguments={"prompt": "test"},
|
||||
token_count=1,
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatMessageSimple(
|
||||
message='[{"file_id":"img-1","revised_prompt":"r1"}]',
|
||||
token_count=1,
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id="call_1",
|
||||
),
|
||||
]
|
||||
|
||||
assert _extract_recent_generated_image_file_ids(history) == ["img-1"]
|
||||
|
||||
def test_ignores_non_image_tool_responses(self) -> None:
|
||||
history = [
|
||||
ChatMessageSimple(
|
||||
message="",
|
||||
token_count=1,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
tool_calls=[
|
||||
ToolCallSimple(
|
||||
tool_call_id="call_1",
|
||||
tool_name="web_search",
|
||||
tool_arguments={"queries": ["q"]},
|
||||
token_count=1,
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatMessageSimple(
|
||||
message='[{"file_id":"img-1","revised_prompt":"r1"}]',
|
||||
token_count=1,
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id="call_1",
|
||||
),
|
||||
]
|
||||
|
||||
assert _extract_recent_generated_image_file_ids(history) == []
|
||||
|
||||
@@ -58,8 +58,7 @@ SERVICE_ORDER=(
|
||||
validate_template() {
|
||||
local template_file=$1
|
||||
echo "Validating template: $template_file..."
|
||||
aws cloudformation validate-template --template-body file://"$template_file" --region "$AWS_REGION" > /dev/null
|
||||
if [ $? -ne 0 ]; then
|
||||
if ! aws cloudformation validate-template --template-body file://"$template_file" --region "$AWS_REGION" > /dev/null; then
|
||||
echo "Error: Validation failed for $template_file. Exiting."
|
||||
exit 1
|
||||
fi
|
||||
@@ -108,13 +107,15 @@ deploy_stack() {
|
||||
fi
|
||||
|
||||
# Create temporary parameters file for this template
|
||||
local temp_params_file=$(create_parameters_from_json "$template_file")
|
||||
local temp_params_file
|
||||
temp_params_file=$(create_parameters_from_json "$template_file")
|
||||
|
||||
# Special handling for SubnetIDs parameter if needed
|
||||
if grep -q "SubnetIDs" "$template_file"; then
|
||||
echo "Template uses SubnetIDs parameter, ensuring it's properly formatted..."
|
||||
# Make sure we're passing SubnetIDs as a comma-separated list
|
||||
local subnet_ids=$(remove_comments "$CONFIG_FILE" | jq -r '.SubnetIDs // empty')
|
||||
local subnet_ids
|
||||
subnet_ids=$(remove_comments "$CONFIG_FILE" | jq -r '.SubnetIDs // empty')
|
||||
if [ -n "$subnet_ids" ]; then
|
||||
echo "Using SubnetIDs from config: $subnet_ids"
|
||||
else
|
||||
@@ -123,15 +124,13 @@ deploy_stack() {
|
||||
fi
|
||||
|
||||
echo "Deploying stack: $stack_name with template: $template_file and generated config from: $CONFIG_FILE..."
|
||||
aws cloudformation deploy \
|
||||
if ! aws cloudformation deploy \
|
||||
--stack-name "$stack_name" \
|
||||
--template-file "$template_file" \
|
||||
--parameter-overrides file://"$temp_params_file" \
|
||||
--capabilities CAPABILITY_IAM CAPABILITY_NAMED_IAM CAPABILITY_AUTO_EXPAND \
|
||||
--region "$AWS_REGION" \
|
||||
--no-cli-auto-prompt > /dev/null
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
--no-cli-auto-prompt > /dev/null; then
|
||||
echo "Error: Deployment failed for $stack_name. Exiting."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@@ -52,11 +52,9 @@ delete_stack() {
|
||||
--region "$AWS_REGION"
|
||||
|
||||
echo "Waiting for stack $stack_name to be deleted..."
|
||||
aws cloudformation wait stack-delete-complete \
|
||||
if aws cloudformation wait stack-delete-complete \
|
||||
--stack-name "$stack_name" \
|
||||
--region "$AWS_REGION"
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
--region "$AWS_REGION"; then
|
||||
echo "Stack $stack_name deleted successfully."
|
||||
sleep 10
|
||||
else
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#!/bin/sh
|
||||
# fill in the template
|
||||
export ONYX_BACKEND_API_HOST="${ONYX_BACKEND_API_HOST:-api_server}"
|
||||
export ONYX_WEB_SERVER_HOST="${ONYX_WEB_SERVER_HOST:-web_server}"
|
||||
@@ -16,12 +17,15 @@ echo "Using web server host: $ONYX_WEB_SERVER_HOST"
|
||||
echo "Using MCP server host: $ONYX_MCP_SERVER_HOST"
|
||||
echo "Using nginx proxy timeouts - connect: ${NGINX_PROXY_CONNECT_TIMEOUT}s, send: ${NGINX_PROXY_SEND_TIMEOUT}s, read: ${NGINX_PROXY_READ_TIMEOUT}s"
|
||||
|
||||
# shellcheck disable=SC2016
|
||||
envsubst '$DOMAIN $SSL_CERT_FILE_NAME $SSL_CERT_KEY_FILE_NAME $ONYX_BACKEND_API_HOST $ONYX_WEB_SERVER_HOST $ONYX_MCP_SERVER_HOST $NGINX_PROXY_CONNECT_TIMEOUT $NGINX_PROXY_SEND_TIMEOUT $NGINX_PROXY_READ_TIMEOUT' < "/etc/nginx/conf.d/$1" > /etc/nginx/conf.d/app.conf
|
||||
|
||||
# Conditionally create MCP server configuration
|
||||
if [ "${MCP_SERVER_ENABLED}" = "True" ] || [ "${MCP_SERVER_ENABLED}" = "true" ]; then
|
||||
echo "MCP server is enabled, creating MCP configuration..."
|
||||
# shellcheck disable=SC2016
|
||||
envsubst '$ONYX_MCP_SERVER_HOST' < "/etc/nginx/conf.d/mcp_upstream.conf.inc.template" > /etc/nginx/conf.d/mcp_upstream.conf.inc
|
||||
# shellcheck disable=SC2016
|
||||
envsubst '$ONYX_MCP_SERVER_HOST' < "/etc/nginx/conf.d/mcp.conf.inc.template" > /etc/nginx/conf.d/mcp.conf.inc
|
||||
else
|
||||
echo "MCP server is disabled, removing MCP configuration..."
|
||||
|
||||
@@ -5,7 +5,7 @@ home: https://www.onyx.app/
|
||||
sources:
|
||||
- "https://github.com/onyx-dot-app/onyx"
|
||||
type: application
|
||||
version: 0.4.40
|
||||
version: 0.4.41
|
||||
appVersion: latest
|
||||
annotations:
|
||||
category: Productivity
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
{{- if .Values.monitoring.serviceMonitors.enabled }}
|
||||
apiVersion: monitoring.coreos.com/v1
|
||||
kind: ServiceMonitor
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-api
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- with .Values.monitoring.serviceMonitors.labels }}
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
namespaceSelector:
|
||||
matchNames:
|
||||
- {{ .Release.Namespace }}
|
||||
selector:
|
||||
matchLabels:
|
||||
app: {{ .Values.api.deploymentLabels.app }}
|
||||
endpoints:
|
||||
- port: api-server-port
|
||||
path: /metrics
|
||||
interval: 30s
|
||||
scrapeTimeout: 10s
|
||||
{{- end }}
|
||||
@@ -74,4 +74,29 @@ spec:
|
||||
interval: 30s
|
||||
scrapeTimeout: 10s
|
||||
{{- end }}
|
||||
{{- if gt (int .Values.celery_worker_heavy.replicaCount) 0 }}
|
||||
---
|
||||
apiVersion: monitoring.coreos.com/v1
|
||||
kind: ServiceMonitor
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-celery-worker-heavy
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- with .Values.monitoring.serviceMonitors.labels }}
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
namespaceSelector:
|
||||
matchNames:
|
||||
- {{ .Release.Namespace }}
|
||||
selector:
|
||||
matchLabels:
|
||||
app: {{ .Values.celery_worker_heavy.deploymentLabels.app }}
|
||||
metrics: "true"
|
||||
endpoints:
|
||||
- port: metrics
|
||||
path: /metrics
|
||||
interval: 30s
|
||||
scrapeTimeout: 10s
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
|
||||
@@ -264,7 +264,7 @@ monitoring:
|
||||
# The sidecar must be configured with label selector: grafana_dashboard=1
|
||||
enabled: false
|
||||
serviceMonitors:
|
||||
# -- Set to true to deploy ServiceMonitor resources for Celery worker metrics endpoints.
|
||||
# -- Set to true to deploy ServiceMonitor resources for API server and Celery worker metrics endpoints.
|
||||
# Requires the Prometheus Operator CRDs (included in kube-prometheus-stack).
|
||||
# Use `labels` to match your Prometheus CR's serviceMonitorSelector (e.g. release: onyx-monitoring).
|
||||
enabled: false
|
||||
|
||||
@@ -22,6 +22,10 @@ variable "CLI_REPOSITORY" {
|
||||
default = "onyxdotapp/onyx-cli"
|
||||
}
|
||||
|
||||
variable "DEVCONTAINER_REPOSITORY" {
|
||||
default = "onyxdotapp/onyx-devcontainer"
|
||||
}
|
||||
|
||||
variable "TAG" {
|
||||
default = "latest"
|
||||
}
|
||||
@@ -90,3 +94,16 @@ target "cli" {
|
||||
|
||||
tags = ["${CLI_REPOSITORY}:${TAG}"]
|
||||
}
|
||||
|
||||
target "devcontainer" {
|
||||
context = ".devcontainer"
|
||||
dockerfile = "Dockerfile"
|
||||
|
||||
cache-from = [
|
||||
"type=registry,ref=${DEVCONTAINER_REPOSITORY}:latest",
|
||||
"type=registry,ref=${DEVCONTAINER_REPOSITORY}:edge",
|
||||
]
|
||||
cache-to = ["type=inline"]
|
||||
|
||||
tags = ["${DEVCONTAINER_REPOSITORY}:${TAG}"]
|
||||
}
|
||||
|
||||
@@ -148,7 +148,7 @@ dev = [
|
||||
"matplotlib==3.10.8",
|
||||
"mypy-extensions==1.0.0",
|
||||
"mypy==1.13.0",
|
||||
"onyx-devtools==0.7.3",
|
||||
"onyx-devtools==0.7.4",
|
||||
"openapi-generator-cli==7.17.0",
|
||||
"pandas-stubs~=2.3.3",
|
||||
"pre-commit==3.2.2",
|
||||
|
||||
@@ -244,6 +244,54 @@ ods web lint
|
||||
ods web test --watch
|
||||
```
|
||||
|
||||
### `dev` - Devcontainer Management
|
||||
|
||||
Manage the Onyx devcontainer. Also available as `ods dc`.
|
||||
|
||||
Requires the [devcontainer CLI](https://github.com/devcontainers/cli) (`npm install -g @devcontainers/cli`).
|
||||
|
||||
```shell
|
||||
ods dev <subcommand>
|
||||
```
|
||||
|
||||
**Subcommands:**
|
||||
|
||||
- `up` - Start the devcontainer (pulls the image if needed)
|
||||
- `into` - Open a zsh shell inside the running devcontainer
|
||||
- `exec` - Run an arbitrary command inside the devcontainer
|
||||
- `restart` - Remove and recreate the devcontainer
|
||||
- `rebuild` - Pull the latest published image and recreate
|
||||
- `stop` - Stop the running devcontainer
|
||||
|
||||
The devcontainer image is published to `onyxdotapp/onyx-devcontainer` and
|
||||
referenced by tag in `.devcontainer/devcontainer.json` — no local build needed.
|
||||
|
||||
**Examples:**
|
||||
|
||||
```shell
|
||||
# Start the devcontainer
|
||||
ods dev up
|
||||
|
||||
# Open a shell
|
||||
ods dev into
|
||||
|
||||
# Run a command
|
||||
ods dev exec -- npm test
|
||||
|
||||
# Restart the container
|
||||
ods dev restart
|
||||
|
||||
# Pull latest image and recreate
|
||||
ods dev rebuild
|
||||
|
||||
# Stop the container
|
||||
ods dev stop
|
||||
|
||||
# Same commands work with the dc alias
|
||||
ods dc up
|
||||
ods dc into
|
||||
```
|
||||
|
||||
### `db` - Database Administration
|
||||
|
||||
Manage PostgreSQL database dumps, restores, and migrations.
|
||||
|
||||
34
tools/ods/cmd/dev.go
Normal file
34
tools/ods/cmd/dev.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// NewDevCommand creates the parent dev command for devcontainer operations.
|
||||
func NewDevCommand() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "dev",
|
||||
Aliases: []string{"dc"},
|
||||
Short: "Manage the devcontainer",
|
||||
Long: `Manage the Onyx devcontainer.
|
||||
|
||||
Wraps the devcontainer CLI with workspace-aware defaults.
|
||||
|
||||
Commands:
|
||||
up Start the devcontainer
|
||||
into Open a shell inside the running devcontainer
|
||||
exec Run a command inside the devcontainer
|
||||
restart Remove and recreate the devcontainer
|
||||
rebuild Pull the latest image and recreate
|
||||
stop Stop the running devcontainer`,
|
||||
}
|
||||
|
||||
cmd.AddCommand(newDevUpCommand())
|
||||
cmd.AddCommand(newDevIntoCommand())
|
||||
cmd.AddCommand(newDevExecCommand())
|
||||
cmd.AddCommand(newDevRestartCommand())
|
||||
cmd.AddCommand(newDevRebuildCommand())
|
||||
cmd.AddCommand(newDevStopCommand())
|
||||
|
||||
return cmd
|
||||
}
|
||||
29
tools/ods/cmd/dev_exec.go
Normal file
29
tools/ods/cmd/dev_exec.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newDevExecCommand() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "exec [--] <command> [args...]",
|
||||
Short: "Run a command inside the devcontainer",
|
||||
Long: `Run an arbitrary command inside the running devcontainer.
|
||||
All arguments are treated as positional (flags like -it are passed through).
|
||||
|
||||
Examples:
|
||||
ods dev exec npm test
|
||||
ods dev exec -- ls -la
|
||||
ods dev exec -it echo hello`,
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
DisableFlagParsing: true,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if len(args) > 0 && args[0] == "--" {
|
||||
args = args[1:]
|
||||
}
|
||||
runDevExec(args)
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
51
tools/ods/cmd/dev_into.go
Normal file
51
tools/ods/cmd/dev_into.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/tools/ods/internal/paths"
|
||||
)
|
||||
|
||||
func newDevIntoCommand() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "into",
|
||||
Short: "Open a shell inside the running devcontainer",
|
||||
Long: `Open an interactive zsh shell inside the running devcontainer.
|
||||
|
||||
Examples:
|
||||
ods dev into`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
runDevExec([]string{"zsh"})
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// runDevExec executes "devcontainer exec --workspace-folder <root> <command...>".
|
||||
func runDevExec(command []string) {
|
||||
checkDevcontainerCLI()
|
||||
|
||||
root, err := paths.GitRoot()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to find git root: %v", err)
|
||||
}
|
||||
|
||||
args := []string{"exec", "--workspace-folder", root}
|
||||
args = append(args, command...)
|
||||
|
||||
log.Debugf("Running: devcontainer %v", args)
|
||||
|
||||
c := exec.Command("devcontainer", args...)
|
||||
c.Stdout = os.Stdout
|
||||
c.Stderr = os.Stderr
|
||||
c.Stdin = os.Stdin
|
||||
|
||||
if err := c.Run(); err != nil {
|
||||
log.Fatalf("devcontainer exec failed: %v", err)
|
||||
}
|
||||
}
|
||||
41
tools/ods/cmd/dev_rebuild.go
Normal file
41
tools/ods/cmd/dev_rebuild.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newDevRebuildCommand() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "rebuild",
|
||||
Short: "Pull the latest devcontainer image and recreate",
|
||||
Long: `Pull the latest devcontainer image and recreate the container.
|
||||
|
||||
Use after the published image has been updated or after changing devcontainer.json.
|
||||
|
||||
Examples:
|
||||
ods dev rebuild`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
runDevRebuild()
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func runDevRebuild() {
|
||||
image := devcontainerImage()
|
||||
|
||||
log.Infof("Pulling %s...", image)
|
||||
pull := exec.Command("docker", "pull", image)
|
||||
pull.Stdout = os.Stdout
|
||||
pull.Stderr = os.Stderr
|
||||
if err := pull.Run(); err != nil {
|
||||
log.Warnf("Failed to pull image (continuing with local copy): %v", err)
|
||||
}
|
||||
|
||||
runDevcontainer("up", []string{"--remove-existing-container"})
|
||||
}
|
||||
23
tools/ods/cmd/dev_restart.go
Normal file
23
tools/ods/cmd/dev_restart.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newDevRestartCommand() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "restart",
|
||||
Short: "Remove and recreate the devcontainer",
|
||||
Long: `Remove the existing devcontainer and recreate it.
|
||||
|
||||
Uses the cached image — for a full image rebuild, use "ods dev rebuild".
|
||||
|
||||
Examples:
|
||||
ods dev restart`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
runDevcontainer("up", []string{"--remove-existing-container"})
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
56
tools/ods/cmd/dev_stop.go
Normal file
56
tools/ods/cmd/dev_stop.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/tools/ods/internal/paths"
|
||||
)
|
||||
|
||||
func newDevStopCommand() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "stop",
|
||||
Short: "Stop the running devcontainer",
|
||||
Long: `Stop the running devcontainer.
|
||||
|
||||
Examples:
|
||||
ods dev stop`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
runDevStop()
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func runDevStop() {
|
||||
root, err := paths.GitRoot()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to find git root: %v", err)
|
||||
}
|
||||
|
||||
// Find the container by the devcontainer label
|
||||
out, err := exec.Command(
|
||||
"docker", "ps", "-q",
|
||||
"--filter", "label=devcontainer.local_folder="+root,
|
||||
).Output()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to find devcontainer: %v", err)
|
||||
}
|
||||
|
||||
containerID := strings.TrimSpace(string(out))
|
||||
if containerID == "" {
|
||||
log.Info("No running devcontainer found")
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Stopping devcontainer %s...", containerID)
|
||||
c := exec.Command("docker", "stop", containerID)
|
||||
if err := c.Run(); err != nil {
|
||||
log.Fatalf("Failed to stop devcontainer: %v", err)
|
||||
}
|
||||
log.Info("Devcontainer stopped")
|
||||
}
|
||||
177
tools/ods/cmd/dev_up.go
Normal file
177
tools/ods/cmd/dev_up.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/tools/ods/internal/paths"
|
||||
)
|
||||
|
||||
func newDevUpCommand() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "up",
|
||||
Short: "Start the devcontainer",
|
||||
Long: `Start the devcontainer, pulling the image if needed.
|
||||
|
||||
Examples:
|
||||
ods dev up`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
runDevcontainer("up", nil)
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// devcontainerImage reads the image field from .devcontainer/devcontainer.json.
|
||||
func devcontainerImage() string {
|
||||
root, err := paths.GitRoot()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to find git root: %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(root, ".devcontainer", "devcontainer.json"))
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to read devcontainer.json: %v", err)
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
Image string `json:"image"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
log.Fatalf("Failed to parse devcontainer.json: %v", err)
|
||||
}
|
||||
if cfg.Image == "" {
|
||||
log.Fatal("No image field in devcontainer.json")
|
||||
}
|
||||
return cfg.Image
|
||||
}
|
||||
|
||||
// checkDevcontainerCLI ensures the devcontainer CLI is installed.
|
||||
func checkDevcontainerCLI() {
|
||||
if _, err := exec.LookPath("devcontainer"); err != nil {
|
||||
log.Fatal("devcontainer CLI is not installed. Install it with: npm install -g @devcontainers/cli")
|
||||
}
|
||||
}
|
||||
|
||||
// ensureDockerSock sets the DOCKER_SOCK environment variable if not already set.
|
||||
// devcontainer.json references ${localEnv:DOCKER_SOCK} for the socket mount.
|
||||
func ensureDockerSock() {
|
||||
if os.Getenv("DOCKER_SOCK") != "" {
|
||||
return
|
||||
}
|
||||
|
||||
sock := detectDockerSock()
|
||||
if err := os.Setenv("DOCKER_SOCK", sock); err != nil {
|
||||
log.Fatalf("Failed to set DOCKER_SOCK: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// detectDockerSock returns the path to the Docker socket on the host.
|
||||
func detectDockerSock() string {
|
||||
// Prefer explicit DOCKER_HOST (strip unix:// prefix if present).
|
||||
if dh := os.Getenv("DOCKER_HOST"); dh != "" {
|
||||
const prefix = "unix://"
|
||||
if len(dh) > len(prefix) && dh[:len(prefix)] == prefix {
|
||||
return dh[len(prefix):]
|
||||
}
|
||||
// Only bare paths (starting with /) are valid socket paths.
|
||||
// Non-unix schemes (e.g. tcp://) can't be bind-mounted.
|
||||
if len(dh) > 0 && dh[0] == '/' {
|
||||
return dh
|
||||
}
|
||||
log.Warnf("DOCKER_HOST=%q is not a unix socket path; falling back to local socket detection", dh)
|
||||
}
|
||||
|
||||
// Linux rootless Docker: $XDG_RUNTIME_DIR/docker.sock
|
||||
if runtime.GOOS == "linux" {
|
||||
if xdg := os.Getenv("XDG_RUNTIME_DIR"); xdg != "" {
|
||||
sock := filepath.Join(xdg, "docker.sock")
|
||||
if _, err := os.Stat(sock); err == nil {
|
||||
return sock
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// macOS Docker Desktop: ~/.docker/run/docker.sock
|
||||
if runtime.GOOS == "darwin" {
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
sock := filepath.Join(home, ".docker", "run", "docker.sock")
|
||||
if _, err := os.Stat(sock); err == nil {
|
||||
return sock
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: standard socket path (Linux with standard Docker, macOS symlink)
|
||||
return "/var/run/docker.sock"
|
||||
}
|
||||
|
||||
// worktreeGitMount returns a --mount flag value that makes a git worktree's
|
||||
// .git reference resolve inside the container. In a worktree, .git is a file
|
||||
// containing "gitdir: /path/to/main/.git/worktrees/<name>", so we need the
|
||||
// main repo's .git directory to exist at the same absolute host path inside
|
||||
// the container.
|
||||
//
|
||||
// Returns ("", false) when the workspace is not a worktree.
|
||||
func worktreeGitMount(root string) (string, bool) {
|
||||
dotgit := filepath.Join(root, ".git")
|
||||
info, err := os.Lstat(dotgit)
|
||||
if err != nil || info.IsDir() {
|
||||
return "", false // regular repo or no .git
|
||||
}
|
||||
|
||||
// .git is a file — parse the gitdir path.
|
||||
out, err := exec.Command("git", "-C", root, "rev-parse", "--git-common-dir").Output()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to detect git common dir: %v", err)
|
||||
return "", false
|
||||
}
|
||||
commonDir := strings.TrimSpace(string(out))
|
||||
|
||||
// Resolve to absolute path.
|
||||
if !filepath.IsAbs(commonDir) {
|
||||
commonDir = filepath.Join(root, commonDir)
|
||||
}
|
||||
commonDir, _ = filepath.EvalSymlinks(commonDir)
|
||||
|
||||
mount := fmt.Sprintf("type=bind,source=%s,target=%s", commonDir, commonDir)
|
||||
log.Debugf("Worktree detected — mounting main .git: %s", commonDir)
|
||||
return mount, true
|
||||
}
|
||||
|
||||
// runDevcontainer executes "devcontainer <action> --workspace-folder <root> [extraArgs...]".
|
||||
func runDevcontainer(action string, extraArgs []string) {
|
||||
checkDevcontainerCLI()
|
||||
ensureDockerSock()
|
||||
|
||||
root, err := paths.GitRoot()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to find git root: %v", err)
|
||||
}
|
||||
|
||||
args := []string{action, "--workspace-folder", root}
|
||||
if mount, ok := worktreeGitMount(root); ok {
|
||||
args = append(args, "--mount", mount)
|
||||
}
|
||||
args = append(args, extraArgs...)
|
||||
|
||||
log.Debugf("Running: devcontainer %v", args)
|
||||
|
||||
c := exec.Command("devcontainer", args...)
|
||||
c.Stdout = os.Stdout
|
||||
c.Stderr = os.Stderr
|
||||
c.Stdin = os.Stdin
|
||||
|
||||
if err := c.Run(); err != nil {
|
||||
log.Fatalf("devcontainer %s failed: %v", action, err)
|
||||
}
|
||||
}
|
||||
@@ -53,6 +53,7 @@ func NewRootCommand() *cobra.Command {
|
||||
cmd.AddCommand(NewRunCICommand())
|
||||
cmd.AddCommand(NewScreenshotDiffCommand())
|
||||
cmd.AddCommand(NewDesktopCommand())
|
||||
cmd.AddCommand(NewDevCommand())
|
||||
cmd.AddCommand(NewWebCommand())
|
||||
cmd.AddCommand(NewLatestStableTagCommand())
|
||||
cmd.AddCommand(NewWhoisCommand())
|
||||
|
||||
16
uv.lock
generated
16
uv.lock
generated
@@ -4511,7 +4511,7 @@ dev = [
|
||||
{ name = "matplotlib", specifier = "==3.10.8" },
|
||||
{ name = "mypy", specifier = "==1.13.0" },
|
||||
{ name = "mypy-extensions", specifier = "==1.0.0" },
|
||||
{ name = "onyx-devtools", specifier = "==0.7.3" },
|
||||
{ name = "onyx-devtools", specifier = "==0.7.4" },
|
||||
{ name = "openapi-generator-cli", specifier = "==7.17.0" },
|
||||
{ name = "pandas-stubs", specifier = "~=2.3.3" },
|
||||
{ name = "pre-commit", specifier = "==3.2.2" },
|
||||
@@ -4554,19 +4554,19 @@ model-server = [
|
||||
|
||||
[[package]]
|
||||
name = "onyx-devtools"
|
||||
version = "0.7.3"
|
||||
version = "0.7.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "fastapi" },
|
||||
{ name = "openapi-generator-cli" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/72/64/c75be8ab325896cc64bccd0e1e139a03ce305bf05598967922d380fc4694/onyx_devtools-0.7.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:675e2fdbd8d291fba4b8a6dfcf2bc94c56d22d11f395a9f0d0c3c0e5b39d7f9b", size = 4220613, upload-time = "2026-04-09T00:04:36.624Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ae/1f/589ff6bd446c4498f5bcdfd2a315709e91fc15edf5440c91ff64cbf0800f/onyx_devtools-0.7.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:bf3993de8ba02d6c2f1ab12b5b9b965e005040b37502f97db8a7d88d9b0cde4b", size = 3897867, upload-time = "2026-04-09T00:04:40.781Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/10/c0/53c9173eefc13218707282c5b99753960d039684994c3b3caf90ce286094/onyx_devtools-0.7.3-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:6138a94084bed05c674ad210a0bc4006c43bc4384e8eb54d469233de85c72bd7", size = 3762408, upload-time = "2026-04-09T00:04:41.592Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d2/37/69fadb65112854a596d200f704da94b837817d4dd0f46cb4482dc0309c94/onyx_devtools-0.7.3-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:90dac91b0cdc32eb8861f6e83545009a34c439fd3c41fc7dd499acd0105b660e", size = 4184427, upload-time = "2026-04-09T00:04:41.525Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bd/45/91c829ccb45f1a15e7c9641eccc6dd154adb540e03c7dee2a8f28cea24d0/onyx_devtools-0.7.3-py3-none-win_amd64.whl", hash = "sha256:abc68d70bec06e349481beec4b212de28a1a8b7ed6ef3b41daf7093ee10b44f3", size = 4299935, upload-time = "2026-04-09T00:04:40.262Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/30/c5adcb8e3b46b71d8d92c3f9ee0c1d0bc5e2adc9f46e93931f21b36a3ee4/onyx_devtools-0.7.3-py3-none-win_arm64.whl", hash = "sha256:9e4411cadc5e81fabc9ed991402e3b4b40f02800681299c277b2142e5af0dcee", size = 3840228, upload-time = "2026-04-09T00:04:39.708Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/3f/584bb003333b6e6d632b06bbf99d410c7a71adde1711076fd44fe88d966d/onyx_devtools-0.7.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6c51d9199ff8ff8fe64a3cfcf77f8170508722b33a1de54c5474be0447b7afa8", size = 4237700, upload-time = "2026-04-09T21:28:20.694Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0a/04/8c28522d51a66b1bdc997a1c72821122eab23f048459646c6ee62a39f6eb/onyx_devtools-0.7.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:f64a4cec6d3616b9ca7354e326994882c9ff2cb3f9fc9a44e55f0eb6a6ff1c1c", size = 3912751, upload-time = "2026-04-09T21:28:23.079Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/e6/ae60307cc50064dacb58e003c9a367d5c85118fd89a597abf3de5fd66f0a/onyx_devtools-0.7.4-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:31c7cecaaa329e3f6d53864290bc53fd0b823453c6cfdb8be7931a8925f5c075", size = 3778188, upload-time = "2026-04-09T21:28:23.14Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f1/d1/5a2789efac7d8f19d30d4d8da1862dd10a16b65d8c9b200542a959094a17/onyx_devtools-0.7.4-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:4c44e3c21253ea92127af483155190c14426c729d93e244aedc33875f74d3514", size = 4200526, upload-time = "2026-04-09T21:28:23.711Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0a/40/56a467eaa7b78411971898191cf0dc3ee49b7f448d1cfe76cd432f6458d3/onyx_devtools-0.7.4-py3-none-win_amd64.whl", hash = "sha256:6fa2b63b702bc5ecbeed5f9eadec57d61ac5c4a646cf5fbd66ee340f53b7d81c", size = 4319090, upload-time = "2026-04-09T21:28:23.26Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/ef/c866fa8ce1f75e1ac67bc239e767b8944cb1a12a44950986ce57e06db17f/onyx_devtools-0.7.4-py3-none-win_arm64.whl", hash = "sha256:c84cbe6a85474dc9f005f079796cf031e80c4249897432ad9f370cd27f72970a", size = 3857229, upload-time = "2026-04-09T21:28:23.484Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -68,9 +68,7 @@ SCRIPT_DIR="$(dirname "${BASH_SOURCE[0]}")"
|
||||
|
||||
# Run the conversion into a temp file so a failed run doesn't destroy an existing .tsx
|
||||
TMPFILE="${BASE_NAME}.tsx.tmp"
|
||||
bunx @svgr/cli "$SVG_FILE" --typescript --svgo-config "$SVGO_CONFIG" --template "${SCRIPT_DIR}/icon-template.js" > "$TMPFILE"
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
if bunx @svgr/cli "$SVG_FILE" --typescript --svgo-config "$SVGO_CONFIG" --template "${SCRIPT_DIR}/icon-template.js" > "$TMPFILE"; then
|
||||
# Verify the temp file has content before replacing the destination
|
||||
if [ ! -s "$TMPFILE" ]; then
|
||||
rm -f "$TMPFILE"
|
||||
@@ -84,16 +82,14 @@ if [ $? -eq 0 ]; then
|
||||
# Using perl for cross-platform compatibility (works on macOS, Linux, Windows with WSL)
|
||||
# Note: perl -i returns 0 even on some failures, so we validate the output
|
||||
|
||||
perl -i -pe 's/<svg/<svg width={size} height={size}/g' "${BASE_NAME}.tsx"
|
||||
if [ $? -ne 0 ]; then
|
||||
if ! perl -i -pe 's/<svg/<svg width={size} height={size}/g' "${BASE_NAME}.tsx"; then
|
||||
echo "Error: Failed to add width/height attributes" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Icons additionally get stroke="currentColor"
|
||||
if [ "$MODE" = "icon" ]; then
|
||||
perl -i -pe 's/\{\.\.\.props\}/stroke="currentColor" {...props}/g' "${BASE_NAME}.tsx"
|
||||
if [ $? -ne 0 ]; then
|
||||
if ! perl -i -pe 's/\{\.\.\.props\}/stroke="currentColor" {...props}/g' "${BASE_NAME}.tsx"; then
|
||||
echo "Error: Failed to add stroke attribute" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@@ -15,7 +15,6 @@ type InteractiveStatefulVariant =
|
||||
| "select-heavy"
|
||||
| "select-card"
|
||||
| "select-tinted"
|
||||
| "select-input"
|
||||
| "select-filter"
|
||||
| "sidebar-heavy"
|
||||
| "sidebar-light";
|
||||
@@ -36,7 +35,6 @@ interface InteractiveStatefulProps
|
||||
* - `"select-heavy"` — tinted selected background (for list rows, model pickers)
|
||||
* - `"select-card"` — like select-heavy but filled state has a visible background (for cards/larger surfaces)
|
||||
* - `"select-tinted"` — like select-heavy but with a tinted rest background
|
||||
* - `"select-input"` — rests at neutral-00 (matches input bar), hover/open shows neutral-03 + border-01
|
||||
* - `"select-filter"` — like select-tinted for empty/filled; selected state uses inverted tint backgrounds and inverted text (for filter buttons)
|
||||
* - `"sidebar-heavy"` — sidebar navigation items: muted when unselected (text-03/text-02), bold when selected (text-04/text-03)
|
||||
* - `"sidebar-light"` — sidebar navigation items: uniformly muted across all states (text-02/text-02)
|
||||
|
||||
@@ -350,41 +350,6 @@
|
||||
--interactive-foreground-icon: var(--text-01);
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Select-Input — Empty
|
||||
Matches input bar background at rest, tints on hover/open.
|
||||
--------------------------------------------------------------------------- */
|
||||
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"] {
|
||||
@apply bg-background-neutral-00;
|
||||
--interactive-foreground: var(--text-04);
|
||||
--interactive-foreground-icon: var(--text-03);
|
||||
}
|
||||
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"]:hover:not(
|
||||
[data-disabled]
|
||||
),
|
||||
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"][data-interaction="hover"]:not(
|
||||
[data-disabled]
|
||||
) {
|
||||
@apply bg-background-neutral-03;
|
||||
--interactive-foreground: var(--text-04);
|
||||
--interactive-foreground-icon: var(--text-03);
|
||||
}
|
||||
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"]:active:not(
|
||||
[data-disabled]
|
||||
),
|
||||
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"][data-interaction="active"]:not(
|
||||
[data-disabled]
|
||||
) {
|
||||
@apply bg-background-neutral-03;
|
||||
--interactive-foreground: var(--text-05);
|
||||
--interactive-foreground-icon: var(--text-05);
|
||||
}
|
||||
.interactive[data-interactive-variant="select-input"][data-interactive-state="empty"][data-disabled] {
|
||||
@apply bg-transparent;
|
||||
--interactive-foreground: var(--text-01);
|
||||
--interactive-foreground-icon: var(--text-01);
|
||||
}
|
||||
|
||||
/* ---------------------------------------------------------------------------
|
||||
Select-Tinted — Filled
|
||||
--------------------------------------------------------------------------- */
|
||||
|
||||
16
web/package-lock.json
generated
16
web/package-lock.json
generated
@@ -47,7 +47,6 @@
|
||||
"clsx": "^2.1.1",
|
||||
"cmdk": "^1.0.0",
|
||||
"cookies-next": "^5.1.0",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"date-fns": "^3.6.0",
|
||||
"docx-preview": "^0.3.7",
|
||||
"favicon-fetch": "^1.0.0",
|
||||
@@ -8844,15 +8843,6 @@
|
||||
"react": ">= 16.8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/copy-to-clipboard": {
|
||||
"version": "3.3.3",
|
||||
"resolved": "https://registry.npmjs.org/copy-to-clipboard/-/copy-to-clipboard-3.3.3.tgz",
|
||||
"integrity": "sha512-2KV8NhB5JqC3ky0r9PMCAZKbUHSwtEo4CwCs0KXgruG43gX5PMqDEBbVU4OUzw2MuAWUfsuFmWvEKG5QRfSnJA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"toggle-selection": "^1.0.6"
|
||||
}
|
||||
},
|
||||
"node_modules/core-js": {
|
||||
"version": "3.46.0",
|
||||
"hasInstallScript": true,
|
||||
@@ -17436,12 +17426,6 @@
|
||||
"node": ">=8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/toggle-selection": {
|
||||
"version": "1.0.6",
|
||||
"resolved": "https://registry.npmjs.org/toggle-selection/-/toggle-selection-1.0.6.tgz",
|
||||
"integrity": "sha512-BiZS+C1OS8g/q2RRbJmy59xpyghNBqrr6k5L/uKBGRsTfxmu3ffiRnd8mlGPUVayg8pvfi5urfnu8TU7DVOkLQ==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/toposort": {
|
||||
"version": "2.0.2",
|
||||
"license": "MIT"
|
||||
|
||||
@@ -65,7 +65,6 @@
|
||||
"clsx": "^2.1.1",
|
||||
"cmdk": "^1.0.0",
|
||||
"cookies-next": "^5.1.0",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"date-fns": "^3.6.0",
|
||||
"docx-preview": "^0.3.7",
|
||||
"favicon-fetch": "^1.0.0",
|
||||
|
||||
@@ -73,10 +73,7 @@ export const MemoizedAnchor = memo(
|
||||
: undefined;
|
||||
|
||||
if (!associatedDoc && !associatedSubQuestion) {
|
||||
// Citation not resolved yet (data still streaming) — hide the
|
||||
// raw [[N]](url) link entirely. It will render as a chip once
|
||||
// the citation/document data arrives.
|
||||
return <></>;
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
||||
let icon: React.ReactNode = null;
|
||||
|
||||
@@ -44,8 +44,6 @@ export interface MultiModelPanelProps {
|
||||
errorStackTrace?: string | null;
|
||||
/** Additional error details */
|
||||
errorDetails?: Record<string, any> | null;
|
||||
/** Whether any model is still streaming — disables preferred selection */
|
||||
isGenerating?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -75,24 +73,19 @@ export default function MultiModelPanel({
|
||||
isRetryable,
|
||||
errorStackTrace,
|
||||
errorDetails,
|
||||
isGenerating,
|
||||
}: MultiModelPanelProps) {
|
||||
const ModelIcon = getModelIcon(provider, modelName);
|
||||
|
||||
const canSelect = !isHidden && !isPreferred && !isGenerating;
|
||||
|
||||
const handlePanelClick = useCallback(() => {
|
||||
if (canSelect) onSelect();
|
||||
}, [canSelect, onSelect]);
|
||||
if (!isHidden && !isPreferred) onSelect();
|
||||
}, [isHidden, isPreferred, onSelect]);
|
||||
|
||||
const header = (
|
||||
<div
|
||||
className={cn(
|
||||
"rounded-12 transition-colors",
|
||||
isPreferred ? "bg-background-tint-02" : "bg-background-tint-00",
|
||||
canSelect && "cursor-pointer hover:bg-background-tint-02"
|
||||
"rounded-12",
|
||||
isPreferred ? "bg-background-tint-02" : "bg-background-tint-00"
|
||||
)}
|
||||
onClick={handlePanelClick}
|
||||
>
|
||||
<ContentAction
|
||||
sizePreset="main-ui"
|
||||
@@ -147,7 +140,13 @@ export default function MultiModelPanel({
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-3 min-w-0 rounded-16">
|
||||
<div
|
||||
className={cn(
|
||||
"flex flex-col gap-3 min-w-0 rounded-16 transition-colors",
|
||||
!isPreferred && "cursor-pointer hover:bg-background-tint-02"
|
||||
)}
|
||||
onClick={handlePanelClick}
|
||||
>
|
||||
{header}
|
||||
{errorMessage ? (
|
||||
<div className="p-4">
|
||||
@@ -164,7 +163,6 @@ export default function MultiModelPanel({
|
||||
<AgentMessage
|
||||
{...agentMessageProps}
|
||||
hideFooter={isNonPreferredInSelection}
|
||||
disableTTS
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
useState,
|
||||
useCallback,
|
||||
useMemo,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
useRef,
|
||||
} from "react";
|
||||
import { useState, useCallback, useMemo, useEffect, useRef } from "react";
|
||||
import { FullChatState } from "@/app/app/message/messageComponents/interfaces";
|
||||
import { Message } from "@/app/app/interfaces";
|
||||
import { LlmManager } from "@/lib/hooks";
|
||||
@@ -117,27 +110,11 @@ export default function MultiModelResponseView({
|
||||
// Refs to each panel wrapper for height animation on deselect
|
||||
const panelElsRef = useRef<Map<number, HTMLDivElement>>(new Map());
|
||||
|
||||
// Tracks which non-preferred panels overflow the preferred height cap.
|
||||
// Measured via useLayoutEffect after maxHeight is applied to the DOM —
|
||||
// ref callbacks fire before layout and can't reliably detect overflow.
|
||||
// Tracks which non-preferred panels overflow the preferred height cap
|
||||
const [overflowingPanels, setOverflowingPanels] = useState<Set<number>>(
|
||||
new Set()
|
||||
);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
if (preferredPanelHeight == null || preferredIndex === null) return;
|
||||
const next = new Set<number>();
|
||||
panelElsRef.current.forEach((el, idx) => {
|
||||
if (idx === preferredIndex || hiddenPanels.has(idx)) return;
|
||||
if (el.scrollHeight > el.clientHeight) next.add(idx);
|
||||
});
|
||||
setOverflowingPanels((prev) => {
|
||||
if (prev.size === next.size && Array.from(prev).every((v) => next.has(v)))
|
||||
return prev;
|
||||
return next;
|
||||
});
|
||||
}, [preferredPanelHeight, preferredIndex, hiddenPanels, responses]);
|
||||
|
||||
const preferredPanelRef = useCallback((el: HTMLDivElement | null) => {
|
||||
if (preferredRoRef.current) {
|
||||
preferredRoRef.current.disconnect();
|
||||
@@ -233,10 +210,8 @@ export default function MultiModelResponseView({
|
||||
const response = responses.find((r) => r.modelIndex === modelIndex);
|
||||
if (!response) return;
|
||||
|
||||
// Persist preferred response + sync `latestChildNodeId`. Backend's
|
||||
// `set_preferred_response` updates `latest_child_message_id`; if the
|
||||
// frontend chain walk disagrees, the next follow-up fails with
|
||||
// "not on the latest mainline".
|
||||
// Persist preferred response to backend + update local tree so the
|
||||
// input bar unblocks (awaitingPreferredSelection clears).
|
||||
if (parentMessage?.messageId && response.messageId && currentSessionId) {
|
||||
setPreferredResponse(parentMessage.messageId, response.messageId).catch(
|
||||
(err) => console.error("Failed to persist preferred response:", err)
|
||||
@@ -252,7 +227,6 @@ export default function MultiModelResponseView({
|
||||
updated.set(parentMessage.nodeId, {
|
||||
...userMsg,
|
||||
preferredResponseId: response.messageId,
|
||||
latestChildNodeId: response.nodeId,
|
||||
});
|
||||
updateSessionMessageTree(currentSessionId, updated);
|
||||
}
|
||||
@@ -439,7 +413,6 @@ export default function MultiModelResponseView({
|
||||
isRetryable: response.isRetryable,
|
||||
errorStackTrace: response.errorStackTrace,
|
||||
errorDetails: response.errorDetails,
|
||||
isGenerating,
|
||||
}),
|
||||
[
|
||||
preferredIndex,
|
||||
@@ -453,7 +426,6 @@ export default function MultiModelResponseView({
|
||||
onMessageSelection,
|
||||
onRegenerate,
|
||||
parentMessage,
|
||||
isGenerating,
|
||||
]
|
||||
);
|
||||
|
||||
@@ -540,6 +512,17 @@ export default function MultiModelResponseView({
|
||||
panelElsRef.current.delete(r.modelIndex);
|
||||
}
|
||||
if (isPref) preferredPanelRef(el);
|
||||
if (capped && el) {
|
||||
const doesOverflow = el.scrollHeight > el.clientHeight;
|
||||
setOverflowingPanels((prev) => {
|
||||
const had = prev.has(r.modelIndex);
|
||||
if (doesOverflow === had) return prev;
|
||||
const next = new Set(prev);
|
||||
if (doesOverflow) next.add(r.modelIndex);
|
||||
else next.delete(r.modelIndex);
|
||||
return next;
|
||||
});
|
||||
}
|
||||
}}
|
||||
style={{
|
||||
width: `${selectionEntered ? finalW : startW}px`,
|
||||
@@ -550,19 +533,21 @@ export default function MultiModelResponseView({
|
||||
: "none",
|
||||
maxHeight: capped ? preferredPanelHeight : undefined,
|
||||
overflow: capped ? "hidden" : undefined,
|
||||
...(overflows
|
||||
? {
|
||||
maskImage:
|
||||
"linear-gradient(to bottom, black calc(100% - 6rem), transparent 100%)",
|
||||
WebkitMaskImage:
|
||||
"linear-gradient(to bottom, black calc(100% - 6rem), transparent 100%)",
|
||||
}
|
||||
: {}),
|
||||
position: capped ? "relative" : undefined,
|
||||
}}
|
||||
>
|
||||
<div className={cn(isNonPref && "opacity-50")}>
|
||||
<MultiModelPanel {...buildPanelProps(r, isNonPref)} />
|
||||
</div>
|
||||
{overflows && (
|
||||
<div
|
||||
className="absolute inset-x-0 bottom-0 h-24 pointer-events-none"
|
||||
style={{
|
||||
background:
|
||||
"linear-gradient(to top, var(--background-tint-01) 0%, transparent 100%)",
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
|
||||
@@ -1,25 +1,3 @@
|
||||
/* Map Tailwind Typography prose variables to the project's color tokens.
|
||||
These auto-switch for dark mode via colors.css — no dark: modifier needed.
|
||||
Note: text-05 = highest contrast, text-01 = lowest. */
|
||||
.prose-onyx {
|
||||
--tw-prose-body: var(--text-05);
|
||||
--tw-prose-headings: var(--text-05);
|
||||
--tw-prose-lead: var(--text-04);
|
||||
--tw-prose-links: var(--action-link-05);
|
||||
--tw-prose-bold: var(--text-05);
|
||||
--tw-prose-counters: var(--text-03);
|
||||
--tw-prose-bullets: var(--text-03);
|
||||
--tw-prose-hr: var(--border-02);
|
||||
--tw-prose-quotes: var(--text-04);
|
||||
--tw-prose-quote-borders: var(--border-02);
|
||||
--tw-prose-captions: var(--text-03);
|
||||
--tw-prose-code: var(--text-05);
|
||||
--tw-prose-pre-code: var(--text-04);
|
||||
--tw-prose-pre-bg: var(--background-code-01);
|
||||
--tw-prose-th-borders: var(--border-02);
|
||||
--tw-prose-td-borders: var(--border-01);
|
||||
}
|
||||
|
||||
/* Light mode syntax highlighting (Atom One Light) */
|
||||
.hljs {
|
||||
color: #383a42 !important;
|
||||
@@ -258,102 +236,23 @@ pre[class*="language-"] {
|
||||
scrollbar-color: #4b5563 #1f2937;
|
||||
}
|
||||
|
||||
/* Card wrapper — holds the background, border-radius, padding, and fade overlay.
|
||||
Does NOT scroll — the inner .markdown-table-breakout handles that. */
|
||||
.markdown-table-card {
|
||||
position: relative;
|
||||
background: var(--background-neutral-01);
|
||||
border-radius: 0.5rem;
|
||||
padding: 0.5rem 0;
|
||||
}
|
||||
|
||||
/*
|
||||
* Scrollable table container — sits inside the card.
|
||||
* Table breakout container - allows tables to extend beyond their parent's
|
||||
* constrained width to use the full container query width (100cqw).
|
||||
*
|
||||
* Requires an ancestor element with `container-type: inline-size` (@container in Tailwind).
|
||||
*
|
||||
* How the math works:
|
||||
* - width: 100cqw → expand to full container query width
|
||||
* - marginLeft: calc((100% - 100cqw) / 2) → negative margin pulls element left
|
||||
* (100% is parent width, 100cqw is larger, so result is negative)
|
||||
* - paddingLeft/Right: calc((100cqw - 100%) / 2) → padding keeps content aligned
|
||||
* with original position while allowing scroll area to extend
|
||||
*/
|
||||
.markdown-table-breakout {
|
||||
overflow-x: auto;
|
||||
|
||||
/* Always reserve scrollbar height so hover doesn't shift content.
|
||||
Thumb is transparent by default, revealed on hover. */
|
||||
scrollbar-width: thin; /* Firefox — always shows track */
|
||||
scrollbar-color: transparent transparent; /* invisible thumb + track */
|
||||
}
|
||||
.markdown-table-breakout::-webkit-scrollbar {
|
||||
height: 6px;
|
||||
}
|
||||
.markdown-table-breakout::-webkit-scrollbar-track {
|
||||
background: transparent;
|
||||
}
|
||||
.markdown-table-breakout::-webkit-scrollbar-thumb {
|
||||
background: transparent;
|
||||
border-radius: 3px;
|
||||
}
|
||||
.markdown-table-breakout:hover {
|
||||
scrollbar-color: var(--border-03) transparent; /* Firefox — reveal thumb */
|
||||
}
|
||||
.markdown-table-breakout:hover::-webkit-scrollbar-thumb {
|
||||
background: var(--border-03);
|
||||
}
|
||||
|
||||
/* Fade the right edge via an ::after overlay on the non-scrolling card.
|
||||
Stays pinned while table scrolls; doesn't affect the sticky column. */
|
||||
.markdown-table-card::after {
|
||||
content: "";
|
||||
position: absolute;
|
||||
top: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
width: 2rem;
|
||||
pointer-events: none;
|
||||
z-index: 2;
|
||||
background: linear-gradient(
|
||||
to right,
|
||||
transparent,
|
||||
var(--background-neutral-01)
|
||||
);
|
||||
border-radius: 0 0.5rem 0.5rem 0;
|
||||
opacity: 0;
|
||||
transition: opacity 0.15s;
|
||||
}
|
||||
.markdown-table-card[data-overflows="true"]::after {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* Sticky first column — inherits the container's background so it
|
||||
matches regardless of theme or custom wallpaper. */
|
||||
.markdown-table-breakout th:first-child,
|
||||
.markdown-table-breakout td:first-child {
|
||||
position: sticky;
|
||||
left: 0;
|
||||
z-index: 1;
|
||||
padding-left: 0.75rem;
|
||||
background: var(--background-neutral-01);
|
||||
}
|
||||
.markdown-table-breakout th:last-child,
|
||||
.markdown-table-breakout td:last-child {
|
||||
padding-right: 0.75rem;
|
||||
}
|
||||
|
||||
/* Shadow on sticky column when scrolled. Uses an ::after pseudo-element
|
||||
so it isn't clipped by the overflow container or the mask-image fade. */
|
||||
.markdown-table-breakout th:first-child::after,
|
||||
.markdown-table-breakout td:first-child::after {
|
||||
content: "";
|
||||
position: absolute;
|
||||
top: 0;
|
||||
right: -6px;
|
||||
bottom: 0;
|
||||
width: 6px;
|
||||
pointer-events: none;
|
||||
opacity: 0;
|
||||
transition: opacity 0.15s;
|
||||
box-shadow: inset 6px 0 8px -4px var(--alpha-grey-100-25);
|
||||
}
|
||||
.dark .markdown-table-breakout th:first-child::after,
|
||||
.dark .markdown-table-breakout td:first-child::after {
|
||||
box-shadow: inset 6px 0 8px -4px var(--alpha-grey-100-60);
|
||||
}
|
||||
.markdown-table-breakout[data-scrolled="true"] th:first-child::after,
|
||||
.markdown-table-breakout[data-scrolled="true"] td:first-child::after {
|
||||
opacity: 1;
|
||||
width: 100cqw;
|
||||
margin-left: calc((100% - 100cqw) / 2);
|
||||
padding-left: calc((100cqw - 100%) / 2);
|
||||
padding-right: calc((100cqw - 100%) / 2);
|
||||
}
|
||||
|
||||
@@ -51,8 +51,6 @@ export interface AgentMessageProps {
|
||||
processingDurationSeconds?: number;
|
||||
/** Hide the feedback/toolbar footer (used in multi-model non-preferred panels) */
|
||||
hideFooter?: boolean;
|
||||
/** Skip TTS streaming (used in multi-model where voice doesn't apply) */
|
||||
disableTTS?: boolean;
|
||||
}
|
||||
|
||||
// TODO: Consider more robust comparisons:
|
||||
@@ -101,7 +99,6 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
parentMessage,
|
||||
processingDurationSeconds,
|
||||
hideFooter,
|
||||
disableTTS,
|
||||
}: AgentMessageProps) {
|
||||
const markdownRef = useRef<HTMLDivElement>(null);
|
||||
const finalAnswerRef = useRef<HTMLDivElement>(null);
|
||||
@@ -136,49 +133,32 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
finalAnswerComing
|
||||
);
|
||||
|
||||
// Merge streaming citation/document data with chatState props.
|
||||
// NOTE: citationMap and documentMap from usePacketProcessor are mutated in
|
||||
// place (same object reference), so we use citations.length / documentMap.size
|
||||
// as change-detection proxies to bust the memo cache when new data arrives.
|
||||
// Memoize merged citations separately to avoid creating new object when neither source changed
|
||||
const mergedCitations = useMemo(
|
||||
() => ({
|
||||
...chatState.citations,
|
||||
...citationMap,
|
||||
}),
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
[chatState.citations, citationMap, citations.length]
|
||||
[chatState.citations, citationMap]
|
||||
);
|
||||
|
||||
// Merge streaming documentMap into chatState.docs so inline citation chips
|
||||
// can resolve [1] → document even when chatState.docs is empty (multi-model).
|
||||
const mergedDocs = useMemo(() => {
|
||||
const propDocs = chatState.docs ?? [];
|
||||
if (documentMap.size === 0) return propDocs;
|
||||
const seen = new Set(propDocs.map((d) => d.document_id));
|
||||
const extras = Array.from(documentMap.values()).filter(
|
||||
(d) => !seen.has(d.document_id)
|
||||
);
|
||||
return extras.length > 0 ? [...propDocs, ...extras] : propDocs;
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [chatState.docs, documentMap, documentMap.size]);
|
||||
|
||||
// Create a chatState that uses streaming citations and documents for immediate rendering.
|
||||
// Memoized with granular dependencies to prevent cascading re-renders.
|
||||
// Create a chatState that uses streaming citations for immediate rendering
|
||||
// This merges the prop citations with streaming citations, preferring streaming ones
|
||||
// Memoized with granular dependencies to prevent cascading re-renders
|
||||
// Note: chatState object is recreated upstream on every render, so we depend on
|
||||
// individual fields instead of the whole object for proper memoization.
|
||||
// individual fields instead of the whole object for proper memoization
|
||||
const effectiveChatState = useMemo<FullChatState>(
|
||||
() => ({
|
||||
...chatState,
|
||||
citations: mergedCitations,
|
||||
docs: mergedDocs,
|
||||
}),
|
||||
[
|
||||
chatState.agent,
|
||||
chatState.docs,
|
||||
chatState.setPresentingDocument,
|
||||
chatState.overriddenModel,
|
||||
chatState.researchType,
|
||||
mergedCitations,
|
||||
mergedDocs,
|
||||
]
|
||||
);
|
||||
|
||||
@@ -222,9 +202,6 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
// Skip if we've already finished TTS for this message
|
||||
if (ttsCompletedRef.current) return;
|
||||
|
||||
// Multi-model: skip TTS entirely
|
||||
if (disableTTS) return;
|
||||
|
||||
// If user cancelled generation, do not send more text to TTS.
|
||||
if (stopPacketSeen && stopReason === StopReason.USER_CANCELLED) {
|
||||
ttsCompletedRef.current = true;
|
||||
@@ -328,7 +305,7 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
onRenderComplete();
|
||||
}
|
||||
}}
|
||||
animate={!stopPacketSeen}
|
||||
animate={false}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
>
|
||||
|
||||
@@ -59,6 +59,7 @@ function TTSButton({ text, voice, speed }: TTSButtonProps) {
|
||||
// Surface streaming voice playback errors to the user via toast
|
||||
useEffect(() => {
|
||||
if (error) {
|
||||
console.error("Voice playback error:", error);
|
||||
toast.error(error);
|
||||
}
|
||||
}, [error]);
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useCallback, useEffect, useRef, useMemo, JSX } from "react";
|
||||
import React, { useCallback, useMemo, JSX } from "react";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import remarkMath from "remark-math";
|
||||
@@ -17,79 +17,10 @@ import { transformLinkUri, cn } from "@/lib/utils";
|
||||
import { InMessageImage } from "@/app/app/components/files/images/InMessageImage";
|
||||
import { extractChatImageFileId } from "@/app/app/components/files/images/utils";
|
||||
|
||||
/** Table wrapper that detects horizontal overflow and shows a fade + scrollbar. */
|
||||
interface ScrollableTableProps
|
||||
extends React.TableHTMLAttributes<HTMLTableElement> {
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
export function ScrollableTable({
|
||||
className,
|
||||
children,
|
||||
...props
|
||||
}: ScrollableTableProps) {
|
||||
const scrollRef = useRef<HTMLDivElement>(null);
|
||||
const wrapRef = useRef<HTMLDivElement>(null);
|
||||
const tableRef = useRef<HTMLTableElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const el = scrollRef.current;
|
||||
const wrap = wrapRef.current;
|
||||
const table = tableRef.current;
|
||||
if (!el || !wrap) return;
|
||||
|
||||
const check = () => {
|
||||
const overflows = el.scrollWidth > el.clientWidth;
|
||||
const atEnd = el.scrollLeft + el.clientWidth >= el.scrollWidth - 2;
|
||||
wrap.dataset.overflows = overflows && !atEnd ? "true" : "false";
|
||||
el.dataset.scrolled = el.scrollLeft > 0 ? "true" : "false";
|
||||
};
|
||||
|
||||
check();
|
||||
el.addEventListener("scroll", check, { passive: true });
|
||||
// Observe both the scroll container (parent resize) and the table
|
||||
// itself (content growth during streaming).
|
||||
const ro = new ResizeObserver(check);
|
||||
ro.observe(el);
|
||||
if (table) ro.observe(table);
|
||||
|
||||
return () => {
|
||||
el.removeEventListener("scroll", check);
|
||||
ro.disconnect();
|
||||
};
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div ref={wrapRef} className="markdown-table-card">
|
||||
<div ref={scrollRef} className="markdown-table-breakout">
|
||||
<table
|
||||
ref={tableRef}
|
||||
className={cn(
|
||||
className,
|
||||
"min-w-full !my-0 [&_th]:whitespace-nowrap [&_td]:whitespace-nowrap"
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Processes content for markdown rendering by handling code blocks and LaTeX
|
||||
*/
|
||||
export const processContent = (content: string): string => {
|
||||
// Strip incomplete citation links at the end of streaming content.
|
||||
// During typewriter animation, [[N]](url) is revealed character by character.
|
||||
// ReactMarkdown can't parse an incomplete link and renders it as raw text.
|
||||
// This regex removes any trailing partial citation pattern so only complete
|
||||
// links are passed to the markdown parser.
|
||||
content = content.replace(/\[\[\d+\]\]\([^)]*$/, "");
|
||||
// Also strip a lone [[ or [[N] or [[N]] at the very end (before the URL part arrives)
|
||||
content = content.replace(/\[\[(?:\d+\]?\]?)?$/, "");
|
||||
|
||||
const codeBlockRegex = /```(\w*)\n[\s\S]*?```|```[\s\S]*?$/g;
|
||||
const matches = content.match(codeBlockRegex);
|
||||
|
||||
@@ -196,9 +127,11 @@ export const useMarkdownComponents = (
|
||||
},
|
||||
table: ({ node, className, children, ...props }: any) => {
|
||||
return (
|
||||
<ScrollableTable className={className} {...props}>
|
||||
{children}
|
||||
</ScrollableTable>
|
||||
<div className="markdown-table-breakout">
|
||||
<table className={cn(className, "min-w-full")} {...props}>
|
||||
{children}
|
||||
</table>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
code: ({ node, className, children }: any) => {
|
||||
|
||||
@@ -1,14 +1,6 @@
|
||||
import React, { useEffect, useMemo, useRef, useState } from "react";
|
||||
import ReactMarkdown, { Components } from "react-markdown";
|
||||
import type { PluggableList } from "unified";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import remarkMath from "remark-math";
|
||||
import rehypeHighlight from "rehype-highlight";
|
||||
import rehypeKatex from "rehype-katex";
|
||||
import "katex/dist/katex.min.css";
|
||||
|
||||
import { useTypewriter } from "@/hooks/useTypewriter";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
import {
|
||||
ChatPacket,
|
||||
PacketType,
|
||||
@@ -16,22 +8,16 @@ import {
|
||||
} from "../../../services/streamingModels";
|
||||
import { MessageRenderer, FullChatState } from "../interfaces";
|
||||
import { isFinalAnswerComplete } from "../../../services/packetUtils";
|
||||
import { processContent } from "../markdownUtils";
|
||||
import { useMarkdownRenderer } from "../markdownUtils";
|
||||
import { BlinkingBar } from "../../BlinkingBar";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
import {
|
||||
MemoizedAnchor,
|
||||
MemoizedParagraph,
|
||||
} from "@/app/app/message/MemoizedTextComponents";
|
||||
import { extractCodeText } from "@/app/app/message/codeUtils";
|
||||
import { CodeBlock } from "@/app/app/message/CodeBlock";
|
||||
import { InMessageImage } from "@/app/app/components/files/images/InMessageImage";
|
||||
import { extractChatImageFileId } from "@/app/app/components/files/images/utils";
|
||||
import { cn, transformLinkUri } from "@/lib/utils";
|
||||
|
||||
/** Maps a visible-char count to a markdown index (skips formatting chars,
|
||||
* extends to word boundary). Used by the voice-sync reveal path only. */
|
||||
/**
|
||||
* Maps a cleaned character position to the corresponding position in markdown text.
|
||||
* This allows progressive reveal to work with markdown formatting.
|
||||
*/
|
||||
function getRevealPosition(markdown: string, cleanChars: number): number {
|
||||
// Skip patterns that don't contribute to visible character count
|
||||
const skipChars = new Set(["*", "`", "#"]);
|
||||
let cleanIndex = 0;
|
||||
let mdIndex = 0;
|
||||
@@ -39,11 +25,13 @@ function getRevealPosition(markdown: string, cleanChars: number): number {
|
||||
while (cleanIndex < cleanChars && mdIndex < markdown.length) {
|
||||
const char = markdown[mdIndex];
|
||||
|
||||
// Skip markdown formatting characters
|
||||
if (char !== undefined && skipChars.has(char)) {
|
||||
mdIndex++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle link syntax [text](url) - skip the (url) part but count the text
|
||||
if (
|
||||
char === "]" &&
|
||||
mdIndex + 1 < markdown.length &&
|
||||
@@ -60,6 +48,7 @@ function getRevealPosition(markdown: string, cleanChars: number): number {
|
||||
mdIndex++;
|
||||
}
|
||||
|
||||
// Extend to word boundary to avoid cutting mid-word
|
||||
while (
|
||||
mdIndex < markdown.length &&
|
||||
markdown[mdIndex] !== " " &&
|
||||
@@ -71,15 +60,8 @@ function getRevealPosition(markdown: string, cleanChars: number): number {
|
||||
return mdIndex;
|
||||
}
|
||||
|
||||
// Cheap streaming plugins (gfm only) → cheap per-frame parse. Full
|
||||
// pipeline flips in once, at the end, for syntax highlighting + math.
|
||||
const STREAMING_REMARK_PLUGINS: PluggableList = [remarkGfm];
|
||||
const STREAMING_REHYPE_PLUGINS: PluggableList = [];
|
||||
const FULL_REMARK_PLUGINS: PluggableList = [
|
||||
remarkGfm,
|
||||
[remarkMath, { singleDollarTextMath: true }],
|
||||
];
|
||||
const FULL_REHYPE_PLUGINS: PluggableList = [rehypeHighlight, rehypeKatex];
|
||||
// Control the rate of packet streaming (packets per second)
|
||||
const PACKET_DELAY_MS = 10;
|
||||
|
||||
export const MessageTextRenderer: MessageRenderer<
|
||||
ChatPacket,
|
||||
@@ -96,17 +78,19 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
stopReason,
|
||||
children,
|
||||
}) => {
|
||||
// If we're animating and the final answer is already complete, show more packets initially
|
||||
const initialPacketCount = animate
|
||||
? packets.length > 0
|
||||
? 1 // Otherwise start with 1 packet
|
||||
: 0
|
||||
: -1; // Show all if not animating
|
||||
|
||||
const [displayedPacketCount, setDisplayedPacketCount] =
|
||||
useState(initialPacketCount);
|
||||
const lastStableSyncedContentRef = useRef("");
|
||||
const lastVisibleContentRef = useRef("");
|
||||
|
||||
// Timeout guard: if TTS doesn't start within 5s of voice sync
|
||||
// activating, fall back to normal streaming. Prevents permanent
|
||||
// content suppression when the voice WebSocket fails to connect.
|
||||
const [voiceSyncTimedOut, setVoiceSyncTimedOut] = useState(false);
|
||||
const voiceSyncTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(
|
||||
null
|
||||
);
|
||||
|
||||
// Get voice mode context for progressive text reveal synced with audio
|
||||
const {
|
||||
revealedCharCount,
|
||||
autoPlayback,
|
||||
@@ -115,6 +99,7 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
isAwaitingAutoPlaybackStart,
|
||||
} = useVoiceMode();
|
||||
|
||||
// Get the full content from all packets
|
||||
const fullContent = packets
|
||||
.map((packet) => {
|
||||
if (
|
||||
@@ -129,74 +114,117 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
|
||||
const shouldUseAutoPlaybackSync =
|
||||
autoPlayback &&
|
||||
!voiceSyncTimedOut &&
|
||||
typeof messageNodeId === "number" &&
|
||||
activeMessageNodeId === messageNodeId;
|
||||
|
||||
// Start/clear the timeout when voice sync activates/deactivates.
|
||||
// Animation effect - gradually increase displayed packets at controlled rate
|
||||
useEffect(() => {
|
||||
if (shouldUseAutoPlaybackSync && isAwaitingAutoPlaybackStart) {
|
||||
if (!voiceSyncTimeoutRef.current) {
|
||||
voiceSyncTimeoutRef.current = setTimeout(() => {
|
||||
setVoiceSyncTimedOut(true);
|
||||
}, 5000);
|
||||
}
|
||||
} else {
|
||||
// TTS started or sync deactivated — clear timeout
|
||||
if (voiceSyncTimeoutRef.current) {
|
||||
clearTimeout(voiceSyncTimeoutRef.current);
|
||||
voiceSyncTimeoutRef.current = null;
|
||||
}
|
||||
if (voiceSyncTimedOut && !autoPlayback) setVoiceSyncTimedOut(false);
|
||||
if (!animate) {
|
||||
setDisplayedPacketCount(-1); // Show all packets
|
||||
return;
|
||||
}
|
||||
return () => {
|
||||
if (voiceSyncTimeoutRef.current) {
|
||||
clearTimeout(voiceSyncTimeoutRef.current);
|
||||
voiceSyncTimeoutRef.current = null;
|
||||
}
|
||||
};
|
||||
}, [
|
||||
shouldUseAutoPlaybackSync,
|
||||
isAwaitingAutoPlaybackStart,
|
||||
isAudioSyncActive,
|
||||
voiceSyncTimedOut,
|
||||
]);
|
||||
|
||||
// Normal streaming hands full text to the typewriter. Voice-sync
|
||||
// paths pre-slice and bypass. If shouldUseAutoPlaybackSync is false
|
||||
// (including after the 5s timeout), all paths fall through to fullContent.
|
||||
if (displayedPacketCount >= 0 && displayedPacketCount < packets.length) {
|
||||
const timer = setTimeout(() => {
|
||||
setDisplayedPacketCount((prev) => Math.min(prev + 1, packets.length));
|
||||
}, PACKET_DELAY_MS);
|
||||
|
||||
return () => clearTimeout(timer);
|
||||
}
|
||||
}, [animate, displayedPacketCount, packets.length]);
|
||||
|
||||
// Reset displayed count when packet array changes significantly (e.g., new message)
|
||||
useEffect(() => {
|
||||
if (animate && packets.length < displayedPacketCount) {
|
||||
const resetCount = isFinalAnswerComplete(packets)
|
||||
? Math.min(10, packets.length)
|
||||
: packets.length > 0
|
||||
? 1
|
||||
: 0;
|
||||
setDisplayedPacketCount(resetCount);
|
||||
}
|
||||
}, [animate, packets.length, displayedPacketCount]);
|
||||
|
||||
// Only mark as complete when all packets are received AND displayed
|
||||
useEffect(() => {
|
||||
if (isFinalAnswerComplete(packets)) {
|
||||
// If animating, wait until all packets are displayed
|
||||
if (
|
||||
animate &&
|
||||
displayedPacketCount >= 0 &&
|
||||
displayedPacketCount < packets.length
|
||||
) {
|
||||
return;
|
||||
}
|
||||
onComplete();
|
||||
}
|
||||
}, [packets, onComplete, animate, displayedPacketCount]);
|
||||
|
||||
// Get content based on displayed packet count or audio progress
|
||||
const computedContent = useMemo(() => {
|
||||
// Hold response in "thinking" state only while autoplay startup is pending.
|
||||
if (shouldUseAutoPlaybackSync && isAwaitingAutoPlaybackStart) {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Sync text with audio only for the message currently being spoken.
|
||||
if (shouldUseAutoPlaybackSync && isAudioSyncActive) {
|
||||
const MIN_REVEAL_CHARS = 12;
|
||||
if (revealedCharCount < MIN_REVEAL_CHARS) {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Reveal text progressively based on audio progress
|
||||
const revealPos = getRevealPosition(fullContent, revealedCharCount);
|
||||
return fullContent.slice(0, Math.max(revealPos, 0));
|
||||
}
|
||||
|
||||
// During an active synced turn, if sync temporarily drops, keep current reveal
|
||||
// instead of jumping to full content or blanking.
|
||||
if (shouldUseAutoPlaybackSync && !stopPacketSeen) {
|
||||
return lastStableSyncedContentRef.current;
|
||||
}
|
||||
|
||||
return fullContent;
|
||||
// Standard behavior when auto-playback is off
|
||||
if (!animate || displayedPacketCount === -1) {
|
||||
return fullContent; // Show all content
|
||||
}
|
||||
|
||||
// Packet-based reveal (when auto-playback is disabled)
|
||||
return packets
|
||||
.slice(0, displayedPacketCount)
|
||||
.map((packet) => {
|
||||
if (
|
||||
packet.obj.type === PacketType.MESSAGE_DELTA ||
|
||||
packet.obj.type === PacketType.MESSAGE_START
|
||||
) {
|
||||
return packet.obj.content;
|
||||
}
|
||||
return "";
|
||||
})
|
||||
.join("");
|
||||
}, [
|
||||
shouldUseAutoPlaybackSync,
|
||||
isAwaitingAutoPlaybackStart,
|
||||
isAudioSyncActive,
|
||||
revealedCharCount,
|
||||
animate,
|
||||
displayedPacketCount,
|
||||
fullContent,
|
||||
packets,
|
||||
revealedCharCount,
|
||||
autoPlayback,
|
||||
isAudioSyncActive,
|
||||
activeMessageNodeId,
|
||||
isAwaitingAutoPlaybackStart,
|
||||
messageNodeId,
|
||||
shouldUseAutoPlaybackSync,
|
||||
stopPacketSeen,
|
||||
]);
|
||||
|
||||
// Monotonic guard for voice sync + freeze on user cancel.
|
||||
// Keep synced text monotonic: once visible, never regress or disappear between chunks.
|
||||
const content = useMemo(() => {
|
||||
const wasUserCancelled = stopReason === StopReason.USER_CANCELLED;
|
||||
|
||||
// On user cancel during live streaming, freeze at exactly what was already
|
||||
// visible to prevent flicker. On history reload (animate=false), the ref
|
||||
// starts empty so we must use computedContent directly.
|
||||
if (wasUserCancelled && animate) {
|
||||
return lastVisibleContentRef.current;
|
||||
}
|
||||
@@ -214,10 +242,13 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
return computedContent;
|
||||
}
|
||||
|
||||
// If content shape changed unexpectedly mid-stream, prefer the stable version
|
||||
// to avoid flicker/dumps.
|
||||
if (!stopPacketSeen || wasUserCancelled) {
|
||||
return last;
|
||||
}
|
||||
|
||||
// For normal completed responses, allow final full content.
|
||||
return computedContent;
|
||||
}, [
|
||||
computedContent,
|
||||
@@ -227,6 +258,7 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
animate,
|
||||
]);
|
||||
|
||||
// Sync the stable ref outside of useMemo to avoid side effects during render.
|
||||
useEffect(() => {
|
||||
if (stopReason === StopReason.USER_CANCELLED) {
|
||||
return;
|
||||
@@ -238,128 +270,13 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
}
|
||||
}, [content, shouldUseAutoPlaybackSync, stopReason]);
|
||||
|
||||
// Track last actually rendered content so cancel can freeze without dumping buffered text.
|
||||
useEffect(() => {
|
||||
if (content.length > 0) {
|
||||
lastVisibleContentRef.current = content;
|
||||
}
|
||||
}, [content]);
|
||||
|
||||
const isStreamingAnimationEnabled =
|
||||
animate &&
|
||||
!shouldUseAutoPlaybackSync &&
|
||||
stopReason !== StopReason.USER_CANCELLED;
|
||||
|
||||
const isStreamFinished = isFinalAnswerComplete(packets);
|
||||
|
||||
const displayedContent = useTypewriter(content, isStreamingAnimationEnabled);
|
||||
|
||||
// One-way signal: stream done AND typewriter caught up. Do NOT derive
|
||||
// this from "typewriter currently behind" — it oscillates mid-stream
|
||||
// between packet bursts and would thrash the plugin pipeline.
|
||||
const streamFullyDisplayed =
|
||||
isStreamFinished && displayedContent.length >= content.length;
|
||||
|
||||
// Fire onComplete exactly once per mount. `onComplete` is an inline
|
||||
// arrow in AgentMessage so its identity changes on every parent render;
|
||||
// without this guard, each new identity would re-fire the effect once
|
||||
// `streamFullyDisplayed` is true.
|
||||
const onCompleteFiredRef = useRef(false);
|
||||
useEffect(() => {
|
||||
if (streamFullyDisplayed && !onCompleteFiredRef.current) {
|
||||
onCompleteFiredRef.current = true;
|
||||
onComplete();
|
||||
}
|
||||
}, [streamFullyDisplayed, onComplete]);
|
||||
|
||||
const processedContent = useMemo(
|
||||
() => processContent(displayedContent),
|
||||
[displayedContent]
|
||||
);
|
||||
|
||||
// Stable-identity components for ReactMarkdown. Dynamic data (`state`,
|
||||
// `processedContent`) flows through refs so the callback identities
|
||||
// never change — otherwise every typewriter tick would invalidate
|
||||
// React reconciliation on the markdown subtree.
|
||||
const stateRef = useRef(state);
|
||||
stateRef.current = state;
|
||||
const processedContentRef = useRef(processedContent);
|
||||
processedContentRef.current = processedContent;
|
||||
|
||||
const markdownComponents = useMemo<Components>(
|
||||
() => ({
|
||||
a: ({ href, children }) => {
|
||||
const s = stateRef.current;
|
||||
const imageFileId = extractChatImageFileId(
|
||||
href,
|
||||
String(children ?? "")
|
||||
);
|
||||
if (imageFileId) {
|
||||
return (
|
||||
<InMessageImage
|
||||
fileId={imageFileId}
|
||||
fileName={String(children ?? "")}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<MemoizedAnchor
|
||||
updatePresentingDocument={s?.setPresentingDocument || (() => {})}
|
||||
docs={s?.docs || []}
|
||||
userFiles={s?.userFiles || []}
|
||||
citations={s?.citations}
|
||||
href={href}
|
||||
>
|
||||
{children}
|
||||
</MemoizedAnchor>
|
||||
);
|
||||
},
|
||||
p: ({ children }) => (
|
||||
<MemoizedParagraph className="font-main-content-body">
|
||||
{children}
|
||||
</MemoizedParagraph>
|
||||
),
|
||||
pre: ({ children }) => <>{children}</>,
|
||||
b: ({ className, children }) => (
|
||||
<span className={className}>{children}</span>
|
||||
),
|
||||
ul: ({ className, children, ...rest }) => (
|
||||
<ul className={className} {...rest}>
|
||||
{children}
|
||||
</ul>
|
||||
),
|
||||
ol: ({ className, children, ...rest }) => (
|
||||
<ol className={className} {...rest}>
|
||||
{children}
|
||||
</ol>
|
||||
),
|
||||
li: ({ className, children, ...rest }) => (
|
||||
<li className={className} {...rest}>
|
||||
{children}
|
||||
</li>
|
||||
),
|
||||
table: ({ className, children, ...rest }) => (
|
||||
<div className="markdown-table-breakout">
|
||||
<table className={cn(className, "min-w-full")} {...rest}>
|
||||
{children}
|
||||
</table>
|
||||
</div>
|
||||
),
|
||||
code: ({ node, className, children }) => {
|
||||
const codeText = extractCodeText(
|
||||
node,
|
||||
processedContentRef.current,
|
||||
children
|
||||
);
|
||||
return (
|
||||
<CodeBlock className={className} codeText={codeText}>
|
||||
{children}
|
||||
</CodeBlock>
|
||||
);
|
||||
},
|
||||
}),
|
||||
[]
|
||||
);
|
||||
|
||||
const shouldShowThinkingPlaceholder =
|
||||
shouldUseAutoPlaybackSync &&
|
||||
isAwaitingAutoPlaybackStart &&
|
||||
@@ -375,16 +292,16 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
!stopPacketSeen;
|
||||
|
||||
const shouldShowCursor =
|
||||
displayedContent.length > 0 &&
|
||||
((isStreamingAnimationEnabled && !streamFullyDisplayed) ||
|
||||
(!isStreamingAnimationEnabled && !stopPacketSeen) ||
|
||||
content.length > 0 &&
|
||||
(!stopPacketSeen ||
|
||||
(shouldUseAutoPlaybackSync && content.length < fullContent.length));
|
||||
|
||||
// `[*]() ` is rendered by the anchor component as an inline blinking
|
||||
// caret, keeping it flush with the trailing character.
|
||||
const markdownInput = shouldShowCursor
|
||||
? processedContent + " [*]() "
|
||||
: processedContent;
|
||||
const { renderedContent } = useMarkdownRenderer(
|
||||
// the [*]() is a hack to show a blinking dot when the packet is not complete
|
||||
shouldShowCursor ? content + " [*]() " : content,
|
||||
state,
|
||||
"font-main-content-body"
|
||||
);
|
||||
|
||||
return children([
|
||||
{
|
||||
@@ -395,26 +312,8 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
<Text as="span" secondaryBody text04 className="italic">
|
||||
Thinking
|
||||
</Text>
|
||||
) : displayedContent.length > 0 ? (
|
||||
<div dir="auto">
|
||||
<ReactMarkdown
|
||||
className="prose prose-onyx font-main-content-body max-w-full"
|
||||
components={markdownComponents}
|
||||
remarkPlugins={
|
||||
streamFullyDisplayed
|
||||
? FULL_REMARK_PLUGINS
|
||||
: STREAMING_REMARK_PLUGINS
|
||||
}
|
||||
rehypePlugins={
|
||||
streamFullyDisplayed
|
||||
? FULL_REHYPE_PLUGINS
|
||||
: STREAMING_REHYPE_PLUGINS
|
||||
}
|
||||
urlTransform={transformLinkUri}
|
||||
>
|
||||
{markdownInput}
|
||||
</ReactMarkdown>
|
||||
</div>
|
||||
) : content.length > 0 ? (
|
||||
<>{renderedContent}</>
|
||||
) : (
|
||||
<BlinkingBar addMargin />
|
||||
),
|
||||
|
||||
@@ -34,8 +34,7 @@ export const PROVIDERS: ProviderConfig[] = [
|
||||
providerName: LLMProviderName.ANTHROPIC,
|
||||
recommended: true,
|
||||
models: [
|
||||
{ name: "claude-opus-4-7", label: "Claude Opus 4.7", recommended: true },
|
||||
{ name: "claude-opus-4-6", label: "Claude Opus 4.6" },
|
||||
{ name: "claude-opus-4-6", label: "Claude Opus 4.6", recommended: true },
|
||||
{ name: "claude-sonnet-4-6", label: "Claude Sonnet 4.6" },
|
||||
],
|
||||
apiKeyPlaceholder: "sk-ant-...",
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
export interface BuildLlmSelection {
|
||||
providerName: string; // e.g., "build-mode-anthropic" (LLMProviderDescriptor.name)
|
||||
provider: string; // e.g., "anthropic"
|
||||
modelName: string; // e.g., "claude-opus-4-7"
|
||||
modelName: string; // e.g., "claude-opus-4-6"
|
||||
}
|
||||
|
||||
// Priority order for smart default LLM selection
|
||||
const LLM_SELECTION_PRIORITY = [
|
||||
{ provider: "anthropic", modelName: "claude-opus-4-7" },
|
||||
{ provider: "anthropic", modelName: "claude-opus-4-6" },
|
||||
{ provider: "openai", modelName: "gpt-5.2" },
|
||||
{ provider: "openrouter", modelName: "minimax/minimax-m2.1" },
|
||||
] as const;
|
||||
@@ -63,11 +63,10 @@ export function getDefaultLlmSelection(
|
||||
export const RECOMMENDED_BUILD_MODELS = {
|
||||
preferred: {
|
||||
provider: "anthropic",
|
||||
modelName: "claude-opus-4-7",
|
||||
displayName: "Claude Opus 4.7",
|
||||
modelName: "claude-opus-4-6",
|
||||
displayName: "Claude Opus 4.6",
|
||||
},
|
||||
alternatives: [
|
||||
{ provider: "anthropic", modelName: "claude-opus-4-6" },
|
||||
{ provider: "anthropic", modelName: "claude-sonnet-4-6" },
|
||||
{ provider: "openai", modelName: "gpt-5.2" },
|
||||
{ provider: "openai", modelName: "gpt-5.1-codex" },
|
||||
@@ -149,8 +148,7 @@ export const BUILD_MODE_PROVIDERS: BuildModeProvider[] = [
|
||||
providerName: "anthropic",
|
||||
recommended: true,
|
||||
models: [
|
||||
{ name: "claude-opus-4-7", label: "Claude Opus 4.7", recommended: true },
|
||||
{ name: "claude-opus-4-6", label: "Claude Opus 4.6" },
|
||||
{ name: "claude-opus-4-6", label: "Claude Opus 4.6", recommended: true },
|
||||
{ name: "claude-sonnet-4-6", label: "Claude Sonnet 4.6" },
|
||||
],
|
||||
apiKeyPlaceholder: "sk-ant-...",
|
||||
|
||||
@@ -320,7 +320,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
onSubmit({
|
||||
message: submittedMessage,
|
||||
currentMessageFiles: currentMessageFiles,
|
||||
deepResearch: deepResearchEnabled && !multiModel.isMultiModelActive,
|
||||
deepResearch: deepResearchEnabled,
|
||||
additionalContext,
|
||||
selectedModels,
|
||||
});
|
||||
@@ -332,7 +332,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
onSubmit({
|
||||
message: chatMessage,
|
||||
currentMessageFiles: currentMessageFiles,
|
||||
deepResearch: deepResearchEnabled && !multiModel.isMultiModelActive,
|
||||
deepResearch: deepResearchEnabled,
|
||||
additionalContext,
|
||||
selectedModels,
|
||||
});
|
||||
@@ -370,16 +370,10 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
onSubmit({
|
||||
message: lastUserMsg.message,
|
||||
currentMessageFiles: currentMessageFiles,
|
||||
deepResearch: deepResearchEnabled && !multiModel.isMultiModelActive,
|
||||
deepResearch: deepResearchEnabled,
|
||||
messageIdToResend: lastUserMsg.messageId,
|
||||
});
|
||||
}, [
|
||||
messageHistory,
|
||||
onSubmit,
|
||||
currentMessageFiles,
|
||||
deepResearchEnabled,
|
||||
multiModel.isMultiModelActive,
|
||||
]);
|
||||
}, [messageHistory, onSubmit, currentMessageFiles, deepResearchEnabled]);
|
||||
|
||||
// Start a new chat session in the side panel
|
||||
const handleNewChat = useCallback(() => {
|
||||
@@ -522,7 +516,8 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
ref={inputRef}
|
||||
className={cn(
|
||||
"w-full flex flex-col",
|
||||
!isSidePanel && "max-w-[var(--app-page-main-content-width)]"
|
||||
!isSidePanel &&
|
||||
"max-w-[var(--app-page-main-content-width)] px-4"
|
||||
)}
|
||||
>
|
||||
{hasMessages && liveAgent && !llmManager.isLoadingProviders && (
|
||||
@@ -540,7 +535,6 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
ref={chatInputBarRef}
|
||||
deepResearchEnabled={deepResearchEnabled}
|
||||
toggleDeepResearch={toggleDeepResearch}
|
||||
isMultiModelActive={multiModel.isMultiModelActive}
|
||||
filterManager={filterManager}
|
||||
llmManager={llmManager}
|
||||
initialMessage={message}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { SourceIcon } from "./SourceIcon";
|
||||
import { useState } from "react";
|
||||
import { GithubIcon, OnyxIcon } from "./icons/icons";
|
||||
import { OnyxIcon } from "./icons/icons";
|
||||
|
||||
export function WebResultIcon({
|
||||
url,
|
||||
@@ -23,8 +23,6 @@ export function WebResultIcon({
|
||||
<>
|
||||
{hostname.includes("onyx.app") ? (
|
||||
<OnyxIcon size={size} className="dark:text-[#fff] text-[#000]" />
|
||||
) : hostname === "github.com" || hostname.endsWith(".github.com") ? (
|
||||
<GithubIcon size={size} />
|
||||
) : !error ? (
|
||||
<img
|
||||
className="my-0 rounded-full py-0"
|
||||
|
||||
@@ -46,7 +46,6 @@ import freshdeskIcon from "@public/Freshdesk.png";
|
||||
import geminiSVG from "@public/Gemini.svg";
|
||||
import gitbookDarkIcon from "@public/GitBookDark.png";
|
||||
import gitbookLightIcon from "@public/GitBookLight.png";
|
||||
import githubDarkIcon from "@public/GithubDarkMode.png";
|
||||
import githubLightIcon from "@public/Github.png";
|
||||
import gongIcon from "@public/Gong.png";
|
||||
import googleIcon from "@public/Google.png";
|
||||
@@ -856,7 +855,7 @@ export const GitbookIcon = createLogoIcon(gitbookDarkIcon, {
|
||||
darkSrc: gitbookLightIcon,
|
||||
});
|
||||
export const GithubIcon = createLogoIcon(githubLightIcon, {
|
||||
darkSrc: githubDarkIcon,
|
||||
monochromatic: true,
|
||||
});
|
||||
export const GitlabIcon = createLogoIcon(gitlabIcon);
|
||||
export const GmailIcon = createLogoIcon(gmailIcon);
|
||||
|
||||
@@ -644,7 +644,6 @@ export default function useChatController({
|
||||
});
|
||||
node.modelDisplayName = model.displayName;
|
||||
node.overridden_model = model.modelName;
|
||||
node.is_generating = true;
|
||||
return node;
|
||||
});
|
||||
}
|
||||
@@ -712,13 +711,6 @@ export default function useChatController({
|
||||
? selectedModels?.map((m) => m.displayName) ?? []
|
||||
: [];
|
||||
|
||||
// rAF-batched flush state. One Zustand write per frame instead of
|
||||
// one per packet.
|
||||
const dirtyModelIndices = new Set<number>();
|
||||
let singleModelDirty = false;
|
||||
let userNodeDirty = false;
|
||||
let pendingFlush = false;
|
||||
|
||||
/** Build a non-errored multi-model assistant node for upsert. */
|
||||
function buildAssistantNodeUpdate(
|
||||
idx: number,
|
||||
@@ -748,124 +740,16 @@ export default function useChatController({
|
||||
};
|
||||
}
|
||||
|
||||
/** With `onlyDirty`, rebuilds only those model nodes — unchanged
|
||||
* siblings keep their stable Message ref so React memo short-circuits. */
|
||||
function buildNonErroredNodes(
|
||||
overrides?: Partial<Message>,
|
||||
onlyDirty?: Set<number> | null
|
||||
): Message[] {
|
||||
/** Build updated nodes for all non-errored models. */
|
||||
function buildNonErroredNodes(overrides?: Partial<Message>): Message[] {
|
||||
const nodes: Message[] = [];
|
||||
for (let idx = 0; idx < initialAssistantNodes.length; idx++) {
|
||||
if (erroredModelIndices.has(idx)) continue;
|
||||
if (onlyDirty && !onlyDirty.has(idx)) continue;
|
||||
nodes.push(buildAssistantNodeUpdate(idx, overrides));
|
||||
}
|
||||
return nodes;
|
||||
}
|
||||
|
||||
/** Flush accumulated packet state into the tree as one Zustand
|
||||
* update. No-op when nothing is pending. */
|
||||
function flushPendingUpdates() {
|
||||
if (!pendingFlush) return;
|
||||
pendingFlush = false;
|
||||
|
||||
parentMessage =
|
||||
parentMessage || currentMessageTreeLocal?.get(SYSTEM_NODE_ID)!;
|
||||
|
||||
let messagesToUpsert: Message[];
|
||||
|
||||
if (isMultiModel) {
|
||||
if (dirtyModelIndices.size === 0 && !userNodeDirty) return;
|
||||
|
||||
const dirtySnapshot = new Set(dirtyModelIndices);
|
||||
dirtyModelIndices.clear();
|
||||
const dirtyNodes = buildNonErroredNodes(undefined, dirtySnapshot);
|
||||
|
||||
if (userNodeDirty) {
|
||||
userNodeDirty = false;
|
||||
// Read current user node to preserve childrenNodeIds
|
||||
// (initialUserNode's are stale from creation time).
|
||||
const currentUserNode =
|
||||
currentMessageTreeLocal.get(initialUserNode.nodeId) ||
|
||||
initialUserNode;
|
||||
const updatedUserNode: Message = {
|
||||
...currentUserNode,
|
||||
messageId: newUserMessageId ?? undefined,
|
||||
files: files,
|
||||
};
|
||||
messagesToUpsert = [updatedUserNode, ...dirtyNodes];
|
||||
} else {
|
||||
messagesToUpsert = dirtyNodes;
|
||||
}
|
||||
|
||||
if (messagesToUpsert.length === 0) return;
|
||||
} else {
|
||||
if (!singleModelDirty) return;
|
||||
singleModelDirty = false;
|
||||
|
||||
messagesToUpsert = [
|
||||
{
|
||||
...initialUserNode,
|
||||
messageId: newUserMessageId ?? undefined,
|
||||
files: files,
|
||||
},
|
||||
{
|
||||
...initialAgentNode,
|
||||
messageId: newAgentMessageId ?? undefined,
|
||||
message: error || answer,
|
||||
type: error ? "error" : "assistant",
|
||||
retrievalType,
|
||||
query: finalMessage?.rephrased_query || query,
|
||||
documents: documents,
|
||||
citations: finalMessage?.citations || citations || {},
|
||||
files: finalMessage?.files || aiMessageImages || [],
|
||||
toolCall: finalMessage?.tool_call || toolCall,
|
||||
stackTrace: stackTrace,
|
||||
overridden_model: finalMessage?.overridden_model,
|
||||
stopReason: stopReason,
|
||||
packets: packets,
|
||||
packetCount: packets.length,
|
||||
processingDurationSeconds:
|
||||
finalMessage?.processing_duration_seconds ??
|
||||
(() => {
|
||||
const startTime = useChatSessionStore
|
||||
.getState()
|
||||
.getStreamingStartTime(frozenSessionId);
|
||||
return startTime
|
||||
? Math.floor((Date.now() - startTime) / 1000)
|
||||
: undefined;
|
||||
})(),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
currentMessageTreeLocal = upsertToCompleteMessageTree({
|
||||
messages: messagesToUpsert,
|
||||
completeMessageTreeOverride: currentMessageTreeLocal,
|
||||
chatSessionId: frozenSessionId!,
|
||||
});
|
||||
}
|
||||
|
||||
/** Awaits next animation frame (or a setTimeout fallback when the
|
||||
* tab is hidden — rAF is paused in background tabs, which would
|
||||
* otherwise hang the stream loop here), then flushes. Aligns
|
||||
* React updates with the paint cycle when visible. */
|
||||
function flushViaRAF(): Promise<void> {
|
||||
return new Promise<void>((resolve) => {
|
||||
let done = false;
|
||||
const flush = () => {
|
||||
if (done) return;
|
||||
done = true;
|
||||
flushPendingUpdates();
|
||||
resolve();
|
||||
};
|
||||
requestAnimationFrame(flush);
|
||||
// Fallback for hidden tabs where rAF is paused. Throttled to
|
||||
// ~1s by browsers, matching the previous setTimeout(500) cadence.
|
||||
setTimeout(flush, 100);
|
||||
});
|
||||
}
|
||||
|
||||
let streamSucceeded = false;
|
||||
|
||||
try {
|
||||
@@ -952,12 +836,7 @@ export default function useChatController({
|
||||
await delay(50);
|
||||
while (!stack.isComplete || !stack.isEmpty()) {
|
||||
if (stack.isEmpty()) {
|
||||
// Flush the burst on the next paint, or idle briefly.
|
||||
if (pendingFlush) {
|
||||
await flushViaRAF();
|
||||
} else {
|
||||
await delay(0.5);
|
||||
}
|
||||
await delay(0.5);
|
||||
}
|
||||
|
||||
if (!stack.isEmpty() && !controller.signal.aborted) {
|
||||
@@ -981,7 +860,6 @@ export default function useChatController({
|
||||
if ((packet as MessageResponseIDInfo).user_message_id) {
|
||||
newUserMessageId = (packet as MessageResponseIDInfo)
|
||||
.user_message_id;
|
||||
userNodeDirty = true;
|
||||
|
||||
// Track extension queries in PostHog (reuses isExtension/extensionContext from above)
|
||||
if (isExtension) {
|
||||
@@ -1020,8 +898,6 @@ export default function useChatController({
|
||||
modelDisplayNames[mi] = slot.model_name;
|
||||
}
|
||||
}
|
||||
userNodeDirty = true;
|
||||
pendingFlush = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -1033,7 +909,6 @@ export default function useChatController({
|
||||
!files.some((existingFile) => existingFile.id === newFile.id)
|
||||
);
|
||||
files = files.concat(newUserFiles);
|
||||
if (newUserFiles.length > 0) userNodeDirty = true;
|
||||
}
|
||||
|
||||
if (Object.hasOwn(packet, "file_ids")) {
|
||||
@@ -1053,20 +928,15 @@ export default function useChatController({
|
||||
|
||||
// In multi-model mode, route per-model errors to the specific model's
|
||||
// node instead of killing the entire stream. Other models keep streaming.
|
||||
if (isMultiModel) {
|
||||
// Multi-model: isolate the error to its panel. Never throw
|
||||
// or set global error state — other models keep streaming.
|
||||
const errorModelIndex = streamingError.details?.model_index as
|
||||
| number
|
||||
| undefined;
|
||||
if (isMultiModel && streamingError.details?.model_index != null) {
|
||||
const errorModelIndex = streamingError.details
|
||||
.model_index as number;
|
||||
if (
|
||||
errorModelIndex != null &&
|
||||
errorModelIndex >= 0 &&
|
||||
errorModelIndex < initialAssistantNodes.length
|
||||
) {
|
||||
const errorNode = initialAssistantNodes[errorModelIndex]!;
|
||||
erroredModelIndices.add(errorModelIndex);
|
||||
dirtyModelIndices.delete(errorModelIndex);
|
||||
currentMessageTreeLocal = upsertToCompleteMessageTree({
|
||||
messages: [
|
||||
{
|
||||
@@ -1093,15 +963,8 @@ export default function useChatController({
|
||||
completeMessageTreeOverride: currentMessageTreeLocal,
|
||||
chatSessionId: frozenSessionId!,
|
||||
});
|
||||
} else {
|
||||
// Error without model_index in multi-model — can't route
|
||||
// to a specific panel. Log and continue; the stream loop
|
||||
// stays alive for other models.
|
||||
console.warn(
|
||||
"Multi-model error without model_index:",
|
||||
streamingError.error
|
||||
);
|
||||
}
|
||||
// Skip the normal per-packet upsert — we already upserted the error node
|
||||
continue;
|
||||
} else {
|
||||
// Single-model: kill the stream
|
||||
@@ -1130,21 +993,19 @@ export default function useChatController({
|
||||
|
||||
if (isMultiModel) {
|
||||
// Multi-model: route packet by placement.model_index.
|
||||
// OverallStop (type "stop") has model_index=null — it's a
|
||||
// global terminal packet that must be delivered to ALL
|
||||
// models so each panel's AgentMessage sees the stop and
|
||||
// exits "Thinking..." state.
|
||||
// OverallStop (type "stop") has model_index=null — it's a global
|
||||
// terminal packet that must be delivered to ALL models so each
|
||||
// panel's AgentMessage sees the stop and exits "Thinking..." state.
|
||||
const isGlobalStop =
|
||||
packetObj.type === "stop" &&
|
||||
typedPacket.placement?.model_index == null;
|
||||
|
||||
if (isGlobalStop) {
|
||||
for (let mi = 0; mi < packetsPerModel.length; mi++) {
|
||||
// Mutated in place — change detection uses packetCount, not array identity.
|
||||
packetsPerModel[mi]!.push(typedPacket);
|
||||
if (!erroredModelIndices.has(mi)) {
|
||||
dirtyModelIndices.add(mi);
|
||||
}
|
||||
packetsPerModel[mi] = [
|
||||
...packetsPerModel[mi]!,
|
||||
typedPacket,
|
||||
];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1154,10 +1015,10 @@ export default function useChatController({
|
||||
modelIndex >= 0 &&
|
||||
modelIndex < packetsPerModel.length
|
||||
) {
|
||||
packetsPerModel[modelIndex]!.push(typedPacket);
|
||||
if (!erroredModelIndices.has(modelIndex)) {
|
||||
dirtyModelIndices.add(modelIndex);
|
||||
}
|
||||
packetsPerModel[modelIndex] = [
|
||||
...packetsPerModel[modelIndex]!,
|
||||
typedPacket,
|
||||
];
|
||||
|
||||
if (packetObj.type === "citation_info") {
|
||||
const citationInfo = packetObj as {
|
||||
@@ -1187,7 +1048,6 @@ export default function useChatController({
|
||||
// Single-model
|
||||
packets.push(typedPacket);
|
||||
packetsVersion++;
|
||||
singleModelDirty = true;
|
||||
|
||||
if (packetObj.type === "citation_info") {
|
||||
const citationInfo = packetObj as {
|
||||
@@ -1214,16 +1074,73 @@ export default function useChatController({
|
||||
console.warn("Unknown packet:", JSON.stringify(packet));
|
||||
}
|
||||
|
||||
// Mark dirty — flushViaRAF coalesces bursts into one React update per frame.
|
||||
if (!isMultiModel) singleModelDirty = true;
|
||||
pendingFlush = true;
|
||||
// on initial message send, we insert a dummy system message
|
||||
// set this as the parent here if no parent is set
|
||||
parentMessage =
|
||||
parentMessage || currentMessageTreeLocal?.get(SYSTEM_NODE_ID)!;
|
||||
|
||||
// Build the messages to upsert based on single vs multi-model mode
|
||||
let messagesToUpsertInLoop: Message[];
|
||||
|
||||
if (isMultiModel) {
|
||||
// Read the current user node from the tree to preserve childrenNodeIds
|
||||
// (initialUserNode has stale/empty children from creation time).
|
||||
const currentUserNode =
|
||||
currentMessageTreeLocal.get(initialUserNode.nodeId) ||
|
||||
initialUserNode;
|
||||
const updatedUserNode: Message = {
|
||||
...currentUserNode,
|
||||
messageId: newUserMessageId ?? undefined,
|
||||
files: files,
|
||||
};
|
||||
messagesToUpsertInLoop = [
|
||||
updatedUserNode,
|
||||
...buildNonErroredNodes(),
|
||||
];
|
||||
} else {
|
||||
messagesToUpsertInLoop = [
|
||||
{
|
||||
...initialUserNode,
|
||||
messageId: newUserMessageId ?? undefined,
|
||||
files: files,
|
||||
},
|
||||
{
|
||||
...initialAgentNode,
|
||||
messageId: newAgentMessageId ?? undefined,
|
||||
message: error || answer,
|
||||
type: error ? "error" : "assistant",
|
||||
retrievalType,
|
||||
query: finalMessage?.rephrased_query || query,
|
||||
documents: documents,
|
||||
citations: finalMessage?.citations || citations || {},
|
||||
files: finalMessage?.files || aiMessageImages || [],
|
||||
toolCall: finalMessage?.tool_call || toolCall,
|
||||
stackTrace: stackTrace,
|
||||
overridden_model: finalMessage?.overridden_model,
|
||||
stopReason: stopReason,
|
||||
packets: packets,
|
||||
packetCount: packets.length,
|
||||
processingDurationSeconds:
|
||||
finalMessage?.processing_duration_seconds ??
|
||||
(() => {
|
||||
const startTime = useChatSessionStore
|
||||
.getState()
|
||||
.getStreamingStartTime(frozenSessionId);
|
||||
return startTime
|
||||
? Math.floor((Date.now() - startTime) / 1000)
|
||||
: undefined;
|
||||
})(),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
currentMessageTreeLocal = upsertToCompleteMessageTree({
|
||||
messages: messagesToUpsertInLoop,
|
||||
completeMessageTreeOverride: currentMessageTreeLocal,
|
||||
chatSessionId: frozenSessionId!,
|
||||
});
|
||||
}
|
||||
}
|
||||
// Flush any tail state from the final packet(s) before declaring
|
||||
// the stream complete. Without this, the last ≤1 frame of packets
|
||||
// could get stranded in local state.
|
||||
flushPendingUpdates();
|
||||
|
||||
// Surface FIFO errors (e.g. 429 before any packets arrive) so the
|
||||
// catch block replaces the thinking placeholder with an error message.
|
||||
if (stack.error) {
|
||||
@@ -1257,7 +1174,6 @@ export default function useChatController({
|
||||
errorCode,
|
||||
isRetryable,
|
||||
errorDetails,
|
||||
is_generating: false,
|
||||
})
|
||||
: [
|
||||
{
|
||||
|
||||
@@ -106,23 +106,9 @@ export default function useMultiModelChat(
|
||||
[currentLlmModel]
|
||||
);
|
||||
|
||||
const removeModel = useCallback(
|
||||
(index: number) => {
|
||||
const next = selectedModels.filter((_, i) => i !== index);
|
||||
// When dropping to single-model, switch llmManager to the surviving
|
||||
// model so it becomes the active model instead of reverting to the
|
||||
// user's default.
|
||||
if (next.length === 1 && next[0]) {
|
||||
llmManager.updateCurrentLlm({
|
||||
name: next[0].name,
|
||||
provider: next[0].provider,
|
||||
modelName: next[0].modelName,
|
||||
});
|
||||
}
|
||||
setSelectedModels(next);
|
||||
},
|
||||
[selectedModels, llmManager]
|
||||
);
|
||||
const removeModel = useCallback((index: number) => {
|
||||
setSelectedModels((prev) => prev.filter((_, i) => i !== index));
|
||||
}, []);
|
||||
|
||||
const replaceModel = useCallback(
|
||||
(index: number, model: SelectedModel) => {
|
||||
|
||||
@@ -48,7 +48,6 @@ describe("useSettings", () => {
|
||||
anonymous_user_enabled: false,
|
||||
invite_only_enabled: false,
|
||||
deep_research_enabled: true,
|
||||
multi_model_chat_enabled: true,
|
||||
temperature_override_enabled: true,
|
||||
query_history_type: QueryHistoryType.NORMAL,
|
||||
});
|
||||
@@ -66,7 +65,6 @@ describe("useSettings", () => {
|
||||
anonymous_user_enabled: false,
|
||||
invite_only_enabled: false,
|
||||
deep_research_enabled: true,
|
||||
multi_model_chat_enabled: true,
|
||||
temperature_override_enabled: true,
|
||||
query_history_type: QueryHistoryType.NORMAL,
|
||||
};
|
||||
|
||||
@@ -23,7 +23,6 @@ const DEFAULT_SETTINGS = {
|
||||
anonymous_user_enabled: false,
|
||||
invite_only_enabled: false,
|
||||
deep_research_enabled: true,
|
||||
multi_model_chat_enabled: true,
|
||||
temperature_override_enabled: true,
|
||||
query_history_type: QueryHistoryType.NORMAL,
|
||||
} satisfies Settings;
|
||||
|
||||
@@ -1,134 +0,0 @@
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
|
||||
// Fixed reveal rate — NOT adaptive. Any ceil(delta/N) formula produces
|
||||
// visible chunks on burst packet arrivals. 1 = 60 cps, 2 = 120 cps.
|
||||
const CHARS_PER_FRAME = 3;
|
||||
|
||||
/**
|
||||
* Reveals `target` one character at a time on each animation frame.
|
||||
* When `enabled` is false (historical messages), snaps to full on mount.
|
||||
* The rAF loop pauses once caught up and resumes when `target` grows.
|
||||
*/
|
||||
export function useTypewriter(target: string, enabled: boolean): string {
|
||||
// Ref so the rAF loop reads latest length without restarting.
|
||||
const targetRef = useRef(target);
|
||||
targetRef.current = target;
|
||||
|
||||
// Mirror `enabled` so the restart effect can short-circuit when the
|
||||
// caller has turned animation off (e.g. voice-mode, where display is
|
||||
// driven by audio position — the typewriter must stay idle and not
|
||||
// animate a jump after audio ends).
|
||||
const enabledRef = useRef(enabled);
|
||||
enabledRef.current = enabled;
|
||||
|
||||
// `enabled` controls initial state: animate from 0 vs snap to full for
|
||||
// history/voice. Transitions mid-stream are handled via enabledRef in
|
||||
// the restart effect so a flip to false doesn't dump the buffered tail
|
||||
// *and* doesn't spin up the rAF loop on later growth.
|
||||
const [displayedLength, setDisplayedLength] = useState<number>(
|
||||
enabled ? 0 : target.length
|
||||
);
|
||||
|
||||
// Mirror displayedLength in a ref so the rAF loop can read the latest
|
||||
// value without stale-closure issues AND without needing a functional
|
||||
// state updater (which must be pure — no ref mutations inside).
|
||||
const displayedLengthRef = useRef(displayedLength);
|
||||
|
||||
// Clamp (not reset) on target shrink — preserves already-revealed chars
|
||||
// across user-cancel freeze and regeneration.
|
||||
const prevTargetLengthRef = useRef(target.length);
|
||||
useEffect(() => {
|
||||
if (target.length < prevTargetLengthRef.current) {
|
||||
const clamped = Math.min(displayedLengthRef.current, target.length);
|
||||
displayedLengthRef.current = clamped;
|
||||
setDisplayedLength(clamped);
|
||||
}
|
||||
prevTargetLengthRef.current = target.length;
|
||||
}, [target.length]);
|
||||
|
||||
// Self-scheduling rAF loop. Pauses when caught up so idle/historical
|
||||
// messages don't run a 60fps no-op updater for their entire lifetime.
|
||||
const rafIdRef = useRef<number | null>(null);
|
||||
const runningRef = useRef(false);
|
||||
const startLoopRef = useRef<(() => void) | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const tick = () => {
|
||||
const targetLen = targetRef.current.length;
|
||||
const prev = displayedLengthRef.current;
|
||||
if (prev >= targetLen) {
|
||||
// Caught up — pause the loop. The sibling effect below will
|
||||
// restart it when `target` grows.
|
||||
runningRef.current = false;
|
||||
rafIdRef.current = null;
|
||||
return;
|
||||
}
|
||||
const next = Math.min(prev + CHARS_PER_FRAME, targetLen);
|
||||
displayedLengthRef.current = next;
|
||||
setDisplayedLength(next);
|
||||
rafIdRef.current = requestAnimationFrame(tick);
|
||||
};
|
||||
|
||||
const start = () => {
|
||||
if (runningRef.current) return;
|
||||
// Animation disabled — snap to full and stay idle. This is the
|
||||
// voice-mode path where content is driven by audio position, and
|
||||
// any "gap" (e.g. user stops audio early) must jump instantly
|
||||
// instead of animating a 1500-char typewriter burst.
|
||||
if (!enabledRef.current) {
|
||||
const targetLen = targetRef.current.length;
|
||||
if (displayedLengthRef.current !== targetLen) {
|
||||
displayedLengthRef.current = targetLen;
|
||||
setDisplayedLength(targetLen);
|
||||
}
|
||||
return;
|
||||
}
|
||||
runningRef.current = true;
|
||||
rafIdRef.current = requestAnimationFrame(tick);
|
||||
};
|
||||
|
||||
startLoopRef.current = start;
|
||||
|
||||
if (targetRef.current.length > displayedLengthRef.current) {
|
||||
start();
|
||||
}
|
||||
|
||||
return () => {
|
||||
runningRef.current = false;
|
||||
if (rafIdRef.current !== null) {
|
||||
cancelAnimationFrame(rafIdRef.current);
|
||||
rafIdRef.current = null;
|
||||
}
|
||||
startLoopRef.current = null;
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Restart the loop when target grows past what's currently displayed.
|
||||
useEffect(() => {
|
||||
if (target.length > displayedLength && startLoopRef.current) {
|
||||
startLoopRef.current();
|
||||
}
|
||||
}, [target.length, displayedLength]);
|
||||
|
||||
// When the user navigates away and back (tab switch, window focus),
|
||||
// snap to all collected content so they see the full response immediately.
|
||||
useEffect(() => {
|
||||
const handleVisibility = () => {
|
||||
if (document.visibilityState === "visible") {
|
||||
const targetLen = targetRef.current.length;
|
||||
if (displayedLengthRef.current < targetLen) {
|
||||
displayedLengthRef.current = targetLen;
|
||||
setDisplayedLength(targetLen);
|
||||
}
|
||||
}
|
||||
};
|
||||
document.addEventListener("visibilitychange", handleVisibility);
|
||||
return () =>
|
||||
document.removeEventListener("visibilitychange", handleVisibility);
|
||||
}, []);
|
||||
|
||||
return useMemo(
|
||||
() => target.slice(0, Math.min(displayedLength, target.length)),
|
||||
[target, displayedLength]
|
||||
);
|
||||
}
|
||||
@@ -27,7 +27,6 @@ export interface Settings {
|
||||
query_history_type: QueryHistoryType;
|
||||
|
||||
deep_research_enabled?: boolean;
|
||||
multi_model_chat_enabled?: boolean;
|
||||
search_ui_enabled?: boolean;
|
||||
|
||||
// Image processing settings
|
||||
|
||||
@@ -173,13 +173,8 @@ function AttachmentItemLayout({
|
||||
rightChildren,
|
||||
}: AttachmentItemLayoutProps) {
|
||||
return (
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent="start"
|
||||
gap={0.25}
|
||||
padding={0.25}
|
||||
>
|
||||
<div className={cn("h-[2.25rem] aspect-square rounded-08 flex-shrink-0")}>
|
||||
<Section flexDirection="row" gap={0.25} padding={0.25}>
|
||||
<div className={cn("h-[2.25rem] aspect-square rounded-08")}>
|
||||
<Section>
|
||||
<div
|
||||
className="attachment-button__icon-wrapper"
|
||||
@@ -194,7 +189,6 @@ function AttachmentItemLayout({
|
||||
justifyContent="between"
|
||||
alignItems="center"
|
||||
gap={1.5}
|
||||
className="min-w-0"
|
||||
>
|
||||
<div data-testid="attachment-item-title" className="flex-1 min-w-0">
|
||||
<Content
|
||||
|
||||
@@ -9,7 +9,6 @@ import { useField, useFormikContext } from "formik";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { Content } from "@opal/layouts";
|
||||
import Label from "@/refresh-components/form/Label";
|
||||
import type { TagProps } from "@opal/components/tag/components";
|
||||
|
||||
interface OrientationLayoutProps {
|
||||
name?: string;
|
||||
@@ -17,8 +16,6 @@ interface OrientationLayoutProps {
|
||||
nonInteractive?: boolean;
|
||||
children?: React.ReactNode;
|
||||
title: string | RichStr;
|
||||
/** Tag rendered inline beside the title (passed through to Content). */
|
||||
tag?: TagProps;
|
||||
description?: string | RichStr;
|
||||
suffix?: "optional" | (string & {});
|
||||
sizePreset?: "main-content" | "main-ui";
|
||||
@@ -131,7 +128,6 @@ function HorizontalInputLayout({
|
||||
children,
|
||||
center,
|
||||
title,
|
||||
tag,
|
||||
description,
|
||||
suffix,
|
||||
sizePreset = "main-content",
|
||||
@@ -148,7 +144,6 @@ function HorizontalInputLayout({
|
||||
title={title}
|
||||
description={description}
|
||||
suffix={suffix}
|
||||
tag={tag}
|
||||
sizePreset={sizePreset}
|
||||
variant="section"
|
||||
widthVariant="full"
|
||||
|
||||
@@ -694,25 +694,6 @@ export function useLlmManager(
|
||||
prevAgentIdRef.current = liveAgent?.id;
|
||||
}, [liveAgent?.id]);
|
||||
|
||||
// Clear manual override when arriving at a *different* existing session
|
||||
// from any previously-seen defined session. Tracks only the last
|
||||
// *defined* session id so a round-trip through new-chat (A → undefined
|
||||
// → B) still resets, while A → undefined (new-chat) preserves it.
|
||||
const prevDefinedSessionIdRef = useRef<string | undefined>(undefined);
|
||||
useEffect(() => {
|
||||
const nextId = currentChatSession?.id;
|
||||
if (
|
||||
nextId !== undefined &&
|
||||
prevDefinedSessionIdRef.current !== undefined &&
|
||||
nextId !== prevDefinedSessionIdRef.current
|
||||
) {
|
||||
setUserHasManuallyOverriddenLLM(false);
|
||||
}
|
||||
if (nextId !== undefined) {
|
||||
prevDefinedSessionIdRef.current = nextId;
|
||||
}
|
||||
}, [currentChatSession?.id]);
|
||||
|
||||
function getValidLlmDescriptor(
|
||||
modelName: string | null | undefined
|
||||
): LlmDescriptor {
|
||||
@@ -734,9 +715,8 @@ export function useLlmManager(
|
||||
|
||||
if (llmProviders === undefined || llmProviders === null) {
|
||||
resolved = manualLlm;
|
||||
} else if (userHasManuallyOverriddenLLM) {
|
||||
// Manual override wins over session's `current_alternate_model`.
|
||||
// Cleared on cross-session navigation by the effect above.
|
||||
} else if (userHasManuallyOverriddenLLM && !currentChatSession) {
|
||||
// User has overridden in this session and switched to a new session
|
||||
resolved = manualLlm;
|
||||
} else if (currentChatSession?.current_alternate_model) {
|
||||
resolved = getValidLlmDescriptorForProviders(
|
||||
@@ -748,6 +728,8 @@ export function useLlmManager(
|
||||
liveAgent.llm_model_version_override,
|
||||
llmProviders
|
||||
);
|
||||
} else if (userHasManuallyOverriddenLLM) {
|
||||
resolved = manualLlm;
|
||||
} else if (user?.preferences?.default_model) {
|
||||
resolved = getValidLlmDescriptorForProviders(
|
||||
user.preferences.default_model,
|
||||
|
||||
@@ -53,17 +53,18 @@ export class HTTPStreamingTTSPlayer {
|
||||
// Create abort controller for this request
|
||||
this.abortController = new AbortController();
|
||||
|
||||
const url = this.getAPIUrl();
|
||||
const body = JSON.stringify({
|
||||
text,
|
||||
...(voice && { voice }),
|
||||
speed,
|
||||
});
|
||||
// Build URL with query params
|
||||
const params = new URLSearchParams();
|
||||
params.set("text", text);
|
||||
if (voice) params.set("voice", voice);
|
||||
params.set("speed", speed.toString());
|
||||
|
||||
const url = `${this.getAPIUrl()}?${params}`;
|
||||
|
||||
// Check if MediaSource is supported
|
||||
if (!window.MediaSource || !MediaSource.isTypeSupported("audio/mpeg")) {
|
||||
// Fallback to simple buffered playback
|
||||
return this.fallbackSpeak(url, body);
|
||||
return this.fallbackSpeak(url);
|
||||
}
|
||||
|
||||
// Create MediaSource and audio element
|
||||
@@ -128,21 +129,15 @@ export class HTTPStreamingTTSPlayer {
|
||||
try {
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body,
|
||||
signal: this.abortController.signal,
|
||||
credentials: "include",
|
||||
credentials: "include", // Include cookies for authentication
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let message = `TTS request failed (${response.status})`;
|
||||
try {
|
||||
const errorJson = await response.json();
|
||||
if (errorJson.detail) message = errorJson.detail;
|
||||
} catch {
|
||||
// response wasn't JSON — use status text
|
||||
}
|
||||
throw new Error(message);
|
||||
const errorText = await response.text();
|
||||
throw new Error(
|
||||
`TTS request failed: ${response.status} - ${errorText}`
|
||||
);
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
@@ -247,24 +242,16 @@ export class HTTPStreamingTTSPlayer {
|
||||
* Fallback for browsers that don't support MediaSource Extensions.
|
||||
* Buffers all audio before playing.
|
||||
*/
|
||||
private async fallbackSpeak(url: string, body: string): Promise<void> {
|
||||
private async fallbackSpeak(url: string): Promise<void> {
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body,
|
||||
signal: this.abortController?.signal,
|
||||
credentials: "include",
|
||||
credentials: "include", // Include cookies for authentication
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let message = `TTS request failed (${response.status})`;
|
||||
try {
|
||||
const errorJson = await response.json();
|
||||
if (errorJson.detail) message = errorJson.detail;
|
||||
} catch {
|
||||
// response wasn't JSON — use status text
|
||||
}
|
||||
throw new Error(message);
|
||||
const errorText = await response.text();
|
||||
throw new Error(`TTS request failed: ${response.status} - ${errorText}`);
|
||||
}
|
||||
|
||||
const audioData = await response.arrayBuffer();
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user