Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-04-17 23:46:47 +00:00)

Compare commits: v3.2.0-clo ... bo/pruning (43 commits)
| SHA1 |
|---|
| c3e857c338 |
| 146d8522df |
| ac5bae3631 |
| b5434b2391 |
| 28e13b503b |
| 99a90ec196 |
| 8ffd7fbb56 |
| f9e88e3c72 |
| 97efdbbbc3 |
| b91a3aed53 |
| 51480e1099 |
| 70efbef95e |
| f3936e2669 |
| c933c71b59 |
| e0d9e109b5 |
| 66c361bd37 |
| 01cbea8c4b |
| 2dc2b0da84 |
| 4b58c9cda6 |
| 7eb945f060 |
| e29f948f29 |
| 7a18b896aa |
| 53e00c7989 |
| 50df53727a |
| e629574580 |
| 8d539cdf3f |
| 52524cbe57 |
| c64def6a9e |
| 2628fe1b93 |
| 96bf344f9c |
| b92d3a307d |
| c55207eeba |
| 2de56cd65f |
| 92bc13f920 |
| 3ddcf101bf |
| 9f764ee55f |
| 4d059b5e0f |
| 57e78cf4c9 |
| 48e74ad3ef |
| ca10520190 |
| d128508838 |
| f64cd1dd63 |
| 210d11aa5d |
@@ -1,6 +1,7 @@
FROM ubuntu:26.04@sha256:cc925e589b7543b910fea57a240468940003fbfc0515245a495dd0ad8fe7cef1

RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    curl \
    default-jre \
    fd-find \
@@ -61,3 +62,11 @@ RUN chsh -s /bin/zsh root && \
    echo '[ -f /workspace/.devcontainer/zshrc ] && . /workspace/.devcontainer/zshrc' >> "$rc"; \
    done && \
    chown dev:dev /home/dev/.zshrc
+
+# Pre-seed GitHub's SSH host keys so git-over-SSH never prompts. Keys are
+# pinned in-repo (verified against the fingerprints GitHub publishes at
+# https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/githubs-ssh-key-fingerprints)
+# rather than fetched at build time, so a compromised build-time network can't
+# inject a rogue key.
+COPY github_known_hosts /etc/ssh/ssh_known_hosts
+RUN chmod 644 /etc/ssh/ssh_known_hosts
@@ -1,7 +1,11 @@
{
  "name": "Onyx Dev Sandbox",
-  "image": "onyxdotapp/onyx-devcontainer@sha256:0f02d9299928849c7b15f3b348dcfdcdcb64411ff7a4580cbc026a6ee7aa1554",
-  "runArgs": ["--cap-add=NET_ADMIN", "--cap-add=NET_RAW", "--network=onyx_default"],
+  "image": "onyxdotapp/onyx-devcontainer@sha256:4986c9252289b660ce772b45f0488b938fe425d8114245e96ef64b273b3fcee4",
+  "runArgs": [
+    "--cap-add=NET_ADMIN",
+    "--cap-add=NET_RAW",
+    "--network=onyx_default"
+  ],
  "mounts": [
    "source=${localEnv:HOME}/.claude,target=/home/dev/.claude,type=bind",
    "source=${localEnv:HOME}/.claude.json,target=/home/dev/.claude.json,type=bind",

3 .devcontainer/github_known_hosts Normal file
@@ -0,0 +1,3 @@
github.com ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCj7ndNxQowgcQnjshcLrqPEiiphnt+VTTvDP6mHBL9j1aNUkY4Ue1gvwnGLVlOhGeYrnZaMgRK6+PKCUXaDbC7qtbW8gIkhL7aGCsOr/C56SJMy/BCZfxd1nWzAOxSDPgVsmerOBYfNqltV9/hWCqBywINIR+5dIg6JTJ72pcEpEjcYgXkE2YEFXV1JHnsKgbLWNlhScqb2UmyRkQyytRLtL+38TGxkxCflmO+5Z8CSSNY7GidjMIZ7Q4zMjA2n1nGrlTDkzwDCsw+wqFPGQA179cnfGWOWRVruj16z6XyvxvjJwbz0wQZ75XK5tKSb7FNyeIEs4TT4jk+S4dhPeAUC5y+bDYirYgM4GC7uEnztnZyaVWQ7B381AK4Qdrwt51ZqExKbQpTUNn+EjqoTwvqNj4kqx5QUCI0ThS/YkOxJCXmPUWZbhjpCg56i+2aB6CmK2JGhn57K5mj0MNdBXA4/WnwH6XoPWJzK5Nyu2zB3nAZp+S5hpQs+p1vN1/wsjk=
github.com ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBEmKSENjQEezOmxkZMy7opKgwFB9nkt5YRrYMjNuG5N87uRgg6CLrbo5wAdT/y6v0mKV0U2w0WZ2YB/++Tpockg=
github.com ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIOMqqnkVzrm0SdG6UOoqKLsabgH5C9okWi0dh2l9GKJl
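The three pinned host keys above can be re-checked against the fingerprints GitHub publishes, without trusting the network at image build time. A minimal local check (assumes OpenSSH's `ssh-keygen` is installed; compare the printed SHA256 values with the page linked from the Dockerfile comment):

```sh
# Print the SHA256 fingerprint of each pinned key in the file.
ssh-keygen -lf .devcontainer/github_known_hosts
```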
@@ -45,7 +45,7 @@ if [ "$ACTIVE_HOME" != "$MOUNT_HOME" ]; then
    [ -d "$MOUNT_HOME/$item" ] || continue
    if [ -e "$ACTIVE_HOME/$item" ] && [ ! -L "$ACTIVE_HOME/$item" ]; then
      echo "warning: replacing $ACTIVE_HOME/$item with symlink to $MOUNT_HOME/$item" >&2
-      rm -rf "$ACTIVE_HOME/$item"
+      rm -rf "${ACTIVE_HOME:?}/$item"
    fi
    ln -sfn "$MOUNT_HOME/$item" "$ACTIVE_HOME/$item"
  done
@@ -4,6 +4,17 @@ set -euo pipefail

echo "Setting up firewall..."

+# Reset default policies to ACCEPT before flushing rules. On re-runs the
+# previous invocation's DROP policies are still in effect; flushing rules while
+# the default is DROP would block the DNS lookups below. Register a trap so
+# that if the script exits before the DROP policies are re-applied at the end,
+# we fail closed instead of leaving the container with an unrestricted
+# firewall.
+trap 'iptables -P INPUT DROP; iptables -P OUTPUT DROP; iptables -P FORWARD DROP' EXIT
+iptables -P INPUT ACCEPT
+iptables -P OUTPUT ACCEPT
+iptables -P FORWARD ACCEPT
+
# Only flush the filter table. The nat and mangle tables are managed by Docker
# (DNS DNAT to 127.0.0.11, container networking, etc.) and must not be touched —
# flushing them breaks Docker's embedded DNS resolver.
@@ -34,8 +45,16 @@ ALLOWED_DOMAINS=(
  "pypi.org"
  "files.pythonhosted.org"
  "go.dev"
  "proxy.golang.org"
  "sum.golang.org"
  "storage.googleapis.com"
  "dl.google.com"
  "static.rust-lang.org"
  "index.crates.io"
  "static.crates.io"
  "archive.ubuntu.com"
  "security.ubuntu.com"
  "deb.nodesource.com"
)

for domain in "${ALLOWED_DOMAINS[@]}"; do
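The firewall hunk above ends before the body of the domain loop. A minimal sketch of how such an allowlist is commonly turned into firewall rules (hypothetical and shown only for orientation: it resolves each name with `getent` and inserts per-IP ACCEPT rules; the repository's actual loop may differ):

```sh
for domain in "${ALLOWED_DOMAINS[@]}"; do
  # Resolve the domain to IPv4 addresses and allow outbound traffic to each one.
  for ip in $(getent ahostsv4 "$domain" | awk '{print $1}' | sort -u); do
    iptables -A OUTPUT -d "$ip" -j ACCEPT
  done
done

# Re-apply the fail-closed default once the allow rules are in place.
iptables -P OUTPUT DROP
```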
50
.github/workflows/deployment.yml
vendored
50
.github/workflows/deployment.yml
vendored
@@ -462,7 +462,7 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
@@ -472,7 +472,7 @@ jobs:
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
@@ -536,7 +536,7 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
@@ -546,7 +546,7 @@ jobs:
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
@@ -597,7 +597,7 @@ jobs:
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
@@ -676,7 +676,7 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
@@ -686,7 +686,7 @@ jobs:
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
@@ -761,7 +761,7 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
@@ -771,7 +771,7 @@ jobs:
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
@@ -833,7 +833,7 @@ jobs:
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
@@ -908,7 +908,7 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
@@ -918,7 +918,7 @@ jobs:
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
@@ -981,7 +981,7 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
@@ -991,7 +991,7 @@ jobs:
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
@@ -1041,7 +1041,7 @@ jobs:
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
@@ -1119,7 +1119,7 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
@@ -1129,7 +1129,7 @@ jobs:
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
@@ -1192,7 +1192,7 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
@@ -1202,7 +1202,7 @@ jobs:
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
@@ -1253,7 +1253,7 @@ jobs:
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
@@ -1329,7 +1329,7 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
with:
|
||||
buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
|
||||
|
||||
@@ -1341,7 +1341,7 @@ jobs:
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
|
||||
env:
|
||||
DEBUG: ${{ vars.DOCKER_DEBUG == 'true' && 1 || 0 }}
|
||||
with:
|
||||
@@ -1409,7 +1409,7 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
with:
|
||||
buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
|
||||
|
||||
@@ -1421,7 +1421,7 @@ jobs:
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
|
||||
env:
|
||||
DEBUG: ${{ vars.DOCKER_DEBUG == 'true' && 1 || 0 }}
|
||||
with:
|
||||
@@ -1475,7 +1475,7 @@ jobs:
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
|
||||
2 .github/workflows/docker-tag-beta.yml vendored
@@ -21,7 +21,7 @@ jobs:
    timeout-minutes: 45
    steps:
      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
+        uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4

      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
2 .github/workflows/docker-tag-latest.yml vendored
@@ -21,7 +21,7 @@ jobs:
    timeout-minutes: 45
    steps:
      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
+        uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4

      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
10
.github/workflows/pr-integration-tests.yml
vendored
10
.github/workflows/pr-integration-tests.yml
vendored
@@ -115,7 +115,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
@@ -127,7 +127,7 @@ jobs:
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
@@ -175,7 +175,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
@@ -187,7 +187,7 @@ jobs:
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
@@ -220,7 +220,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
# needed for pulling openapitools/openapi-generator-cli
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
|
||||
12
.github/workflows/pr-playwright-tests.yml
vendored
12
.github/workflows/pr-playwright-tests.yml
vendored
@@ -94,7 +94,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
@@ -105,7 +105,7 @@ jobs:
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push Web Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
@@ -155,7 +155,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
@@ -166,7 +166,7 @@ jobs:
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
@@ -216,7 +216,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
@@ -227,7 +227,7 @@ jobs:
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
|
||||
2 .github/workflows/pr-python-model-tests.yml vendored
@@ -69,7 +69,7 @@ jobs:
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f
+        uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4

      - name: Build and load
        uses: docker/bake-action@82490499d2e5613fcead7e128237ef0b0ea210f7 # ratchet:docker/bake-action@v7.0.0
2 .github/workflows/pr-quality-checks.yml vendored
@@ -39,6 +39,8 @@ jobs:
        working-directory: ./web
        run: npm ci
      - uses: j178/prek-action@cbc2f23eb5539cf20d82d1aabd0d0ecbcc56f4e3
+        env:
+          SKIP: ty
        with:
          prek-version: '0.3.4'
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
10
.github/workflows/sandbox-deployment.yml
vendored
10
.github/workflows/sandbox-deployment.yml
vendored
@@ -132,7 +132,7 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
@@ -142,7 +142,7 @@ jobs:
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./backend/onyx/server/features/build/sandbox/kubernetes/docker
|
||||
file: ./backend/onyx/server/features/build/sandbox/kubernetes/docker/Dockerfile
|
||||
@@ -202,7 +202,7 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
@@ -212,7 +212,7 @@ jobs:
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./backend/onyx/server/features/build/sandbox/kubernetes/docker
|
||||
file: ./backend/onyx/server/features/build/sandbox/kubernetes/docker/Dockerfile
|
||||
@@ -258,7 +258,7 @@ jobs:
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
|
||||
@@ -68,6 +68,7 @@ repos:
        pass_filenames: true
        files: ^backend/(?!\.venv/|scripts/).*\.py$
      - id: uv-run
        alias: ty
        name: ty
        args: ["ty", "check"]
        pass_filenames: true
@@ -85,6 +86,17 @@ repos:
    hooks:
      - id: actionlint

+  - repo: https://github.com/shellcheck-py/shellcheck-py
+    rev: 745eface02aef23e168a8afb6b5737818efbea95 # frozen: v0.11.0.1
+    hooks:
+      - id: shellcheck
+        exclude: >-
+          (?x)^(
+            backend/scripts/setup_craft_templates\.sh|
+            deployment/docker_compose/init-letsencrypt\.sh|
+            deployment/docker_compose/install\.sh
+          )$
+
  - repo: https://github.com/psf/black
    rev: 8a737e727ac5ab2f1d4cf5876720ed276dc8dc4b # frozen: 25.1.0
    hooks:
@@ -141,6 +153,7 @@ repos:
    hooks:
      - id: ripsecrets
        args:
          - --strict-ignore
          - --additional-pattern
          - ^sk-[A-Za-z0-9_\-]{20,}$
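The ripsecrets `--additional-pattern` above is an ordinary regular expression; a quick way to sanity-check locally that the expression matches the intended token shape (the sample value is invented and is not a real key):

```sh
# Prints the sample line if the pattern matches sk-style tokens of 20+ characters.
echo 'sk-abcdefghijklmnopqrstuvwxyz1234' | grep -E '^sk-[A-Za-z0-9_\-]{20,}$'
```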
1 .secretsignore Normal file
@@ -0,0 +1 @@
.devcontainer/github_known_hosts
@@ -1,8 +1,10 @@
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.client import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from ee.onyx.server.tenants.product_gating import get_gated_tenants
|
||||
@@ -16,9 +18,56 @@ from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_tenant_work_gating import cleanup_expired
|
||||
from onyx.redis.redis_tenant_work_gating import get_active_tenants
|
||||
from onyx.redis.redis_tenant_work_gating import observe_active_set_size
|
||||
from onyx.redis.redis_tenant_work_gating import record_full_fanout_cycle
|
||||
from onyx.redis.redis_tenant_work_gating import record_gate_decision
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
|
||||
|
||||
_FULL_FANOUT_TIMESTAMP_KEY_PREFIX = "tenant_work_gating_last_full_fanout_ms"
|
||||
|
||||
|
||||
def _should_bypass_gate_for_full_fanout(
|
||||
redis_client: Redis, task_name: str, interval_seconds: int
|
||||
) -> bool:
|
||||
"""True if at least `interval_seconds` have elapsed since the last
|
||||
full-fanout bypass for this task. On True, updates the stored timestamp
|
||||
atomically-enough (it's a best-effort counter, not a lock)."""
|
||||
key = f"{_FULL_FANOUT_TIMESTAMP_KEY_PREFIX}:{task_name}"
|
||||
now_ms = int(time.time() * 1000)
|
||||
threshold_ms = now_ms - (interval_seconds * 1000)
|
||||
|
||||
try:
|
||||
raw = cast(bytes | None, redis_client.get(key))
|
||||
except Exception:
|
||||
task_logger.exception(f"full-fanout timestamp read failed: task={task_name}")
|
||||
# Fail open: treat as "interval elapsed" so we don't skip every
|
||||
# tenant during a Redis hiccup.
|
||||
return True
|
||||
|
||||
if raw is None:
|
||||
# First invocation — bypass so the set seeds cleanly.
|
||||
elapsed = True
|
||||
else:
|
||||
try:
|
||||
last_ms = int(raw.decode())
|
||||
elapsed = last_ms <= threshold_ms
|
||||
except ValueError:
|
||||
elapsed = True
|
||||
|
||||
if elapsed:
|
||||
try:
|
||||
redis_client.set(key, str(now_ms))
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"full-fanout timestamp write failed: task={task_name}"
|
||||
)
|
||||
return elapsed
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
ignore_result=True,
|
||||
@@ -32,6 +81,7 @@ def cloud_beat_task_generator(
|
||||
priority: int = OnyxCeleryPriority.MEDIUM,
|
||||
expires: int = BEAT_EXPIRES_DEFAULT,
|
||||
skip_gated: bool = True,
|
||||
work_gated: bool = False,
|
||||
) -> bool | None:
|
||||
"""a lightweight task used to kick off individual beat tasks per tenant."""
|
||||
time_start = time.monotonic()
|
||||
@@ -51,8 +101,56 @@ def cloud_beat_task_generator(
|
||||
tenant_ids: list[str] = []
|
||||
num_processed_tenants = 0
|
||||
num_skipped_gated = 0
|
||||
num_would_skip_work_gate = 0
|
||||
num_skipped_work_gate = 0
|
||||
|
||||
# Tenant-work-gating read path. Resolve once per invocation.
|
||||
gate_enabled = False
|
||||
gate_enforce = False
|
||||
full_fanout_cycle = False
|
||||
active_tenants: set[str] | None = None
|
||||
|
||||
try:
|
||||
# Gating setup is inside the try block so any exception still
|
||||
# reaches the finally that releases the beat lock.
|
||||
if work_gated:
|
||||
try:
|
||||
gate_enabled = OnyxRuntime.get_tenant_work_gating_enabled()
|
||||
gate_enforce = OnyxRuntime.get_tenant_work_gating_enforce()
|
||||
except Exception:
|
||||
task_logger.exception("tenant work gating: runtime flag read failed")
|
||||
gate_enabled = False
|
||||
|
||||
if gate_enabled:
|
||||
redis_failed = False
|
||||
interval_s = (
|
||||
OnyxRuntime.get_tenant_work_gating_full_fanout_interval_seconds()
|
||||
)
|
||||
full_fanout_cycle = _should_bypass_gate_for_full_fanout(
|
||||
redis_client, task_name, interval_s
|
||||
)
|
||||
if full_fanout_cycle:
|
||||
record_full_fanout_cycle(task_name)
|
||||
try:
|
||||
ttl_s = OnyxRuntime.get_tenant_work_gating_ttl_seconds()
|
||||
cleanup_expired(ttl_s)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"tenant work gating: cleanup_expired failed"
|
||||
)
|
||||
else:
|
||||
ttl_s = OnyxRuntime.get_tenant_work_gating_ttl_seconds()
|
||||
active_tenants = get_active_tenants(ttl_s)
|
||||
if active_tenants is None:
|
||||
full_fanout_cycle = True
|
||||
record_full_fanout_cycle(task_name)
|
||||
redis_failed = True
|
||||
|
||||
# Only refresh the gauge when Redis is known-reachable —
|
||||
# skip the ZCARD if we just failed open due to a Redis error.
|
||||
if not redis_failed:
|
||||
observe_active_set_size()
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
# Per-task control over whether gated tenants are included. Most periodic tasks
|
||||
@@ -76,6 +174,21 @@ def cloud_beat_task_generator(
|
||||
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
|
||||
continue
|
||||
|
||||
# Tenant work gate: if the feature is on, check membership. Skip
|
||||
# unmarked tenants when enforce=True AND we're not in a full-
|
||||
# fanout cycle. Always log/emit the shadow counter.
|
||||
if work_gated and gate_enabled and not full_fanout_cycle:
|
||||
would_skip = (
|
||||
active_tenants is not None and tenant_id not in active_tenants
|
||||
)
|
||||
if would_skip:
|
||||
num_would_skip_work_gate += 1
|
||||
if gate_enforce:
|
||||
num_skipped_work_gate += 1
|
||||
record_gate_decision(task_name, skipped=True)
|
||||
continue
|
||||
record_gate_decision(task_name, skipped=False)
|
||||
|
||||
self.app.send_task(
|
||||
task_name,
|
||||
kwargs=dict(
|
||||
@@ -109,6 +222,12 @@ def cloud_beat_task_generator(
|
||||
f"task={task_name} "
|
||||
f"num_processed_tenants={num_processed_tenants} "
|
||||
f"num_skipped_gated={num_skipped_gated} "
|
||||
f"num_would_skip_work_gate={num_would_skip_work_gate} "
|
||||
f"num_skipped_work_gate={num_skipped_work_gate} "
|
||||
f"full_fanout_cycle={full_fanout_cycle} "
|
||||
f"work_gated={work_gated} "
|
||||
f"gate_enabled={gate_enabled} "
|
||||
f"gate_enforce={gate_enforce} "
|
||||
f"num_tenants={len(tenant_ids)} "
|
||||
f"elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
|
||||
@@ -30,6 +30,7 @@ from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFI
|
||||
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
@@ -531,23 +532,26 @@ def reset_tenant_id(
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
|
||||
def wait_for_vespa_or_shutdown(
|
||||
sender: Any, # noqa: ARG001
|
||||
**kwargs: Any, # noqa: ARG001
|
||||
) -> None: # noqa: ARG001
|
||||
"""Waits for Vespa to become ready subject to a timeout.
|
||||
Raises WorkerShutdown if the timeout is reached."""
|
||||
def wait_for_document_index_or_shutdown() -> None:
|
||||
"""
|
||||
Waits for all configured document indices to become ready subject to a
|
||||
timeout.
|
||||
|
||||
Raises WorkerShutdown if the timeout is reached.
|
||||
"""
|
||||
if DISABLE_VECTOR_DB:
|
||||
logger.info(
|
||||
"DISABLE_VECTOR_DB is set — skipping Vespa/OpenSearch readiness check."
|
||||
)
|
||||
return
|
||||
|
||||
if not wait_for_vespa_with_timeout():
|
||||
msg = "[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
if not ONYX_DISABLE_VESPA:
|
||||
if not wait_for_vespa_with_timeout():
|
||||
msg = (
|
||||
"[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
if not wait_for_opensearch_with_timeout():
|
||||
|
||||
@@ -105,7 +105,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_document_index_or_shutdown()
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -111,7 +111,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_document_index_or_shutdown()
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -97,7 +97,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_document_index_or_shutdown()
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -118,7 +118,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_document_index_or_shutdown()
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -124,7 +124,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_document_index_or_shutdown()
|
||||
|
||||
logger.info(f"Running as the primary celery worker: pid={os.getpid()}")
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
app_base.wait_for_document_index_or_shutdown()
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.configs.app_configs import DISABLE_OPENSEARCH_MIGRATION_TASK
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
@@ -67,6 +68,7 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -100,6 +102,7 @@ beat_task_templates: list[dict] = [
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
# Gated tenants may still have connectors awaiting deletion.
|
||||
"skip_gated": False,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -109,6 +112,7 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -118,6 +122,7 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -155,6 +160,7 @@ beat_task_templates: list[dict] = [
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.SANDBOX,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -179,6 +185,7 @@ if ENTERPRISE_EDITION_ENABLED:
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -188,6 +195,7 @@ if ENTERPRISE_EDITION_ENABLED:
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"work_gated": True,
|
||||
},
|
||||
},
|
||||
]
|
||||
@@ -227,7 +235,11 @@ if SCHEDULED_EVAL_DATASET_NAMES:
|
||||
)
|
||||
|
||||
# Add OpenSearch migration task if enabled.
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX and not DISABLE_OPENSEARCH_MIGRATION_TASK:
|
||||
if (
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
and not DISABLE_OPENSEARCH_MIGRATION_TASK
|
||||
and not ONYX_DISABLE_VESPA
|
||||
):
|
||||
beat_task_templates.append(
|
||||
{
|
||||
"name": "migrate-chunks-from-vespa-to-opensearch",
|
||||
@@ -280,7 +292,7 @@ def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
|
||||
cloud_task["kwargs"] = {}
|
||||
cloud_task["kwargs"]["task_name"] = task["task"]
|
||||
|
||||
optional_fields = ["queue", "priority", "expires", "skip_gated"]
|
||||
optional_fields = ["queue", "priority", "expires", "skip_gated", "work_gated"]
|
||||
for field in optional_fields:
|
||||
if field in task["options"]:
|
||||
cloud_task["kwargs"][field] = task["options"][field]
|
||||
@@ -373,12 +385,14 @@ if not MULTI_TENANT:
|
||||
]
|
||||
)
|
||||
|
||||
# `skip_gated` is a cloud-only hint consumed by `cloud_beat_task_generator`. Strip
|
||||
# it before extending the self-hosted schedule so it doesn't leak into apply_async
|
||||
# as an unrecognised option on every fired task message.
|
||||
# `skip_gated` and `work_gated` are cloud-only hints consumed by
|
||||
# `cloud_beat_task_generator`. Strip them before extending the self-hosted
|
||||
# schedule so they don't leak into apply_async as unrecognised options on
|
||||
# every fired task message.
|
||||
for _template in beat_task_templates:
|
||||
_self_hosted_template = copy.deepcopy(_template)
|
||||
_self_hosted_template["options"].pop("skip_gated", None)
|
||||
_self_hosted_template["options"].pop("work_gated", None)
|
||||
tasks_to_schedule.append(_self_hosted_template)
|
||||
|
||||
|
||||
|
||||
@@ -166,16 +166,21 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
|
||||
|
||||
r.set(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES, 1, ex=300)
|
||||
|
||||
# collect cc_pair_ids
|
||||
# collect cc_pair_ids and note whether any are in DELETING status
|
||||
cc_pair_ids: list[int] = []
|
||||
has_deleting_cc_pair = False
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair.id)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
has_deleting_cc_pair = True
|
||||
|
||||
# Tenant-work-gating hook: any cc_pair means deletion could have
|
||||
# cleanup work to do for this tenant on some cycle.
|
||||
if cc_pair_ids:
|
||||
# Tenant-work-gating hook: mark only when at least one cc_pair is in
|
||||
# DELETING status. Marking on bare cc_pair existence would keep
|
||||
# nearly every tenant in the active set since most have cc_pairs
|
||||
# but almost none are actively being deleted on any given cycle.
|
||||
if has_deleting_cc_pair:
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
|
||||
# try running cleanup on the cc_pair_ids
|
||||
|
||||
@@ -897,11 +897,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
|
||||
secondary_cc_pair_ids = standard_cc_pair_ids
|
||||
|
||||
# Tenant-work-gating hook: refresh this tenant's active-set membership
|
||||
# whenever indexing actually has work to dispatch.
|
||||
if primary_cc_pair_ids or secondary_cc_pair_ids:
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
|
||||
# Flag CC pairs in repeated error state for primary/current search settings
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for cc_pair_id in primary_cc_pair_ids:
|
||||
@@ -1019,6 +1014,14 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
f"Skipping secondary indexing: switchover_type=INSTANT for search_settings={secondary_search_settings.id}"
|
||||
)
|
||||
|
||||
# Tenant-work-gating hook: refresh membership only when indexing
|
||||
# actually dispatched at least one docfetching task. `_kickoff_indexing_tasks`
|
||||
# internally calls `should_index()` to decide per-cc_pair; using
|
||||
# `tasks_created > 0` here gives us a "real work was done" signal
|
||||
# rather than just "tenant has a cc_pair somewhere."
|
||||
if tasks_created > 0:
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
|
||||
# 2/3: VALIDATE
|
||||
# Check for inconsistent index attempts - active attempts without task IDs
|
||||
# This can happen if attempt creation fails partway through
|
||||
|
||||
@@ -229,11 +229,7 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
# Tenant-work-gating hook: any cc_pair means pruning could have
|
||||
# work to do for this tenant on some cycle.
|
||||
if cc_pair_ids:
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
|
||||
prune_dispatched = False
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -256,9 +252,18 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
logger.info(f"Pruning not created: {cc_pair_id}")
|
||||
continue
|
||||
|
||||
prune_dispatched = True
|
||||
task_logger.info(
|
||||
f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}"
|
||||
)
|
||||
|
||||
# Tenant-work-gating hook: mark only when at least one cc_pair
|
||||
# was actually due for pruning AND a prune task was dispatched.
|
||||
# Marking on bare cc_pair existence over-counts the population
|
||||
# since most tenants have cc_pairs but almost none are due on
|
||||
# any given cycle.
|
||||
if prune_dispatched:
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
r.set(OnyxRedisSignals.BLOCK_PRUNING, 1, ex=_get_pruning_block_expiration())
|
||||
|
||||
# we want to run this less frequently than the overall task
|
||||
|
||||
@@ -826,6 +826,12 @@ def translate_history_to_llm_format(
|
||||
base64_data = img_file.to_base64()
|
||||
image_url = f"data:{image_type};base64,{base64_data}"
|
||||
|
||||
content_parts.append(
|
||||
TextContentPart(
|
||||
type="text",
|
||||
text=f"[attached image — file_id: {img_file.file_id}]",
|
||||
)
|
||||
)
|
||||
image_part = ImageContentPart(
|
||||
type="image_url",
|
||||
image_url=ImageUrlDetail(
|
||||
|
||||
@@ -282,6 +282,7 @@ OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
|
||||
OPENSEARCH_ADMIN_PASSWORD = os.environ.get(
|
||||
"OPENSEARCH_ADMIN_PASSWORD", "StrongPassword123!"
|
||||
)
|
||||
OPENSEARCH_USE_SSL = os.environ.get("OPENSEARCH_USE_SSL", "true").lower() == "true"
|
||||
USING_AWS_MANAGED_OPENSEARCH = (
|
||||
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
|
||||
)
|
||||
@@ -327,6 +328,7 @@ ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
|
||||
DISABLE_OPENSEARCH_MIGRATION_TASK = (
|
||||
os.environ.get("DISABLE_OPENSEARCH_MIGRATION_TASK", "").lower() == "true"
|
||||
)
|
||||
ONYX_DISABLE_VESPA = os.environ.get("ONYX_DISABLE_VESPA", "").lower() == "true"
|
||||
# Whether we should check for and create an index if necessary every time we
|
||||
# instantiate an OpenSearchDocumentIndex on multitenant cloud. Defaults to True.
|
||||
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
|
||||
@@ -843,6 +845,29 @@ MAX_FILE_SIZE_BYTES = int(
|
||||
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
|
||||
) # 2GB in bytes
|
||||
|
||||
# Maximum embedded images allowed in a single file. PDFs (and other formats)
|
||||
# with thousands of embedded images can OOM the user-file-processing worker
|
||||
# because every image is decoded with PIL and then sent to the vision LLM.
|
||||
# Enforced both at upload time (rejects the file) and during extraction
|
||||
# (defense-in-depth: caps the number of images materialized).
|
||||
#
|
||||
# Clamped to >= 0; a negative env value would turn upload validation into
|
||||
# always-fail and extraction into always-stop, which is never desired. 0
|
||||
# disables image extraction entirely, which is a valid (if aggressive) setting.
|
||||
MAX_EMBEDDED_IMAGES_PER_FILE = max(
|
||||
0, int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_FILE") or 500)
|
||||
)
|
||||
|
||||
# Maximum embedded images allowed across all files in a single upload batch.
|
||||
# Protects against the scenario where a user uploads many files that each
|
||||
# fall under MAX_EMBEDDED_IMAGES_PER_FILE but aggregate to enough work
|
||||
# (serial-ish celery fan-out plus per-image vision-LLM calls) to OOM the
|
||||
# worker under concurrency or run up surprise latency/cost. Also clamped
|
||||
# to >= 0.
|
||||
MAX_EMBEDDED_IMAGES_PER_UPLOAD = max(
|
||||
0, int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_UPLOAD") or 1000)
|
||||
)
|
||||
|
||||
# Use document summary for contextual rag
|
||||
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
|
||||
# Use chunk summary for contextual rag
|
||||
|
||||
@@ -3,6 +3,7 @@ from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
from urllib.parse import urljoin
|
||||
@@ -10,7 +11,6 @@ from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from dateutil.parser import parse
|
||||
from dateutil.parser import ParserError
|
||||
|
||||
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -56,18 +56,16 @@ def time_str_to_utc(datetime_str: str) -> datetime:
|
||||
if fixed not in candidates:
|
||||
candidates.append(fixed)
|
||||
|
||||
last_exception: Exception | None = None
|
||||
for candidate in candidates:
|
||||
try:
|
||||
dt = parse(candidate)
|
||||
return datetime_to_utc(dt)
|
||||
except (ValueError, ParserError) as exc:
|
||||
last_exception = exc
|
||||
# dateutil is the primary; the stdlib RFC 2822 parser is a fallback for
|
||||
# inputs dateutil rejects (e.g. headers concatenated without a CRLF —
|
||||
# TZ may be dropped, datetime_to_utc then assumes UTC).
|
||||
for parser in (parse, parsedate_to_datetime):
|
||||
for candidate in candidates:
|
||||
try:
|
||||
return datetime_to_utc(parser(candidate))
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
continue
|
||||
|
||||
if last_exception is not None:
|
||||
raise last_exception
|
||||
|
||||
# Fallback in case parsing failed without raising (should not happen)
|
||||
raise ValueError(f"Unable to parse datetime string: {datetime_str}")
|
||||
|
||||
|
||||
|
||||
@@ -253,7 +253,17 @@ def thread_to_document(
|
||||
|
||||
updated_at_datetime = None
|
||||
if updated_at:
|
||||
updated_at_datetime = time_str_to_utc(updated_at)
|
||||
try:
|
||||
updated_at_datetime = time_str_to_utc(updated_at)
|
||||
except (ValueError, OverflowError) as e:
|
||||
# Old mailboxes contain RFC-violating Date headers. Drop the
|
||||
# timestamp instead of aborting the indexing run.
|
||||
logger.warning(
|
||||
"Skipping unparseable Gmail Date header on thread %s: %r (%s)",
|
||||
full_thread.get("id"),
|
||||
updated_at,
|
||||
e,
|
||||
)
|
||||
|
||||
id = full_thread.get("id")
|
||||
if not id:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import copy
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
@@ -8,27 +9,58 @@ from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3.util import Retry
|
||||
|
||||
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from onyx.configs.app_configs import GONG_CONNECTOR_START_TIME
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class GongConnector(LoadConnector, PollConnector):
|
||||
class GongConnectorCheckpoint(ConnectorCheckpoint):
|
||||
# Resolved workspace IDs to iterate through.
|
||||
# None means "not yet resolved" — first checkpoint call resolves them.
|
||||
# Inner None means "no workspace filter" (fetch all).
|
||||
workspace_ids: list[str | None] | None = None
|
||||
# Index into workspace_ids for current workspace
|
||||
workspace_index: int = 0
|
||||
# Gong API cursor for current workspace's transcript pagination
|
||||
cursor: str | None = None
|
||||
# Cached time range — computed once, reused across checkpoint calls
|
||||
time_range: tuple[str, str] | None = None
|
||||
|
||||
|
||||
class _TranscriptPage(BaseModel):
|
||||
"""One page of transcripts from /v2/calls/transcript."""
|
||||
|
||||
transcripts: list[dict[str, Any]]
|
||||
next_cursor: str | None = None
|
||||
|
||||
|
||||
class _CursorExpiredError(Exception):
|
||||
"""Raised when Gong rejects a pagination cursor as expired.
|
||||
|
||||
Gong pagination cursors TTL is ~1 hour from the first request in a
|
||||
pagination sequence, not from the last cursor fetch. Since checkpointed
|
||||
connector runs can pause between invocations, a resumed run may encounter
|
||||
an expired cursor and must restart the current workspace from scratch.
|
||||
See https://visioneers.gong.io/integrations-77/pagination-cursor-expires-after-1-hours-even-for-a-new-cursor-1382
|
||||
"""
|
||||
|
||||
|
||||
class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
|
||||
BASE_URL = "https://api.gong.io"
|
||||
MAX_CALL_DETAILS_ATTEMPTS = 6
|
||||
CALL_DETAILS_DELAY = 30 # in seconds
|
||||
@@ -38,13 +70,9 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
workspaces: list[str] | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_fail: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
hide_user_info: bool = False,
|
||||
) -> None:
|
||||
self.workspaces = workspaces
|
||||
self.batch_size: int = batch_size
|
||||
self.continue_on_fail = continue_on_fail
|
||||
self.auth_token_basic: str | None = None
|
||||
self.hide_user_info = hide_user_info
|
||||
self._last_request_time: float = 0.0
|
||||
@@ -98,67 +126,50 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
# Then the user input is treated as the name
|
||||
return {**id_id_map, **name_id_map}
|
||||
|
||||
def _get_transcript_batches(
|
||||
self, start_datetime: str | None = None, end_datetime: str | None = None
|
||||
) -> Generator[list[dict[str, Any]], None, None]:
|
||||
body: dict[str, dict] = {"filter": {}}
|
||||
def _fetch_transcript_page(
|
||||
self,
|
||||
start_datetime: str | None,
|
||||
end_datetime: str | None,
|
||||
workspace_id: str | None,
|
||||
cursor: str | None,
|
||||
) -> _TranscriptPage:
|
||||
"""Fetch one page of transcripts from the Gong API.
|
||||
|
||||
Raises _CursorExpiredError if Gong reports the pagination cursor
|
||||
expired (TTL is ~1 hour from first request in the pagination sequence).
|
||||
"""
|
||||
body: dict[str, Any] = {"filter": {}}
|
||||
if start_datetime:
|
||||
body["filter"]["fromDateTime"] = start_datetime
|
||||
if end_datetime:
|
||||
body["filter"]["toDateTime"] = end_datetime
|
||||
if workspace_id:
|
||||
body["filter"]["workspaceId"] = workspace_id
|
||||
if cursor:
|
||||
body["cursor"] = cursor
|
||||
|
||||
# The batch_ids in the previous method appear to be batches of call_ids to process.
# In this method, we will retrieve transcripts for them in batches.
|
||||
transcripts: list[dict[str, Any]] = []
|
||||
workspace_list = self.workspaces or [None]
|
||||
workspace_map = self._get_workspace_id_map() if self.workspaces else {}
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
)
|
||||
# If no calls in the range, return empty
|
||||
if response.status_code == 404:
|
||||
return _TranscriptPage(transcripts=[])
|
||||
|
||||
for workspace in workspace_list:
|
||||
if workspace:
|
||||
logger.info(f"Updating Gong workspace: {workspace}")
|
||||
workspace_id = workspace_map.get(workspace)
|
||||
if not workspace_id:
|
||||
logger.error(f"Invalid Gong workspace: {workspace}")
|
||||
if not self.continue_on_fail:
|
||||
raise ValueError(f"Invalid workspace: {workspace}")
|
||||
continue
|
||||
body["filter"]["workspaceId"] = workspace_id
|
||||
else:
|
||||
if "workspaceId" in body["filter"]:
|
||||
del body["filter"]["workspaceId"]
|
||||
if not response.ok:
|
||||
# Cursor expiration comes back as a 4xx with this error message —
|
||||
# detect it before raise_for_status so callers can restart the workspace.
|
||||
if cursor and "cursor has expired" in response.text.lower():
|
||||
raise _CursorExpiredError(response.text)
|
||||
logger.error(f"Error fetching transcripts: {response.text}")
|
||||
response.raise_for_status()
|
||||
|
||||
while True:
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
)
|
||||
# If no calls in the range, just break out
|
||||
if response.status_code == 404:
|
||||
break
|
||||
data = response.json()
|
||||
return _TranscriptPage(
|
||||
transcripts=data.get("callTranscripts", []),
|
||||
next_cursor=data.get("records", {}).get("cursor"),
|
||||
)
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except Exception:
|
||||
logger.error(f"Error fetching transcripts: {response.text}")
|
||||
raise
|
||||
|
||||
data = response.json()
|
||||
call_transcripts = data.get("callTranscripts", [])
|
||||
transcripts.extend(call_transcripts)
|
||||
|
||||
while len(transcripts) >= self.batch_size:
|
||||
yield transcripts[: self.batch_size]
|
||||
transcripts = transcripts[self.batch_size :]
|
||||
|
||||
cursor = data.get("records", {}).get("cursor")
|
||||
if cursor:
|
||||
body["cursor"] = cursor
|
||||
else:
|
||||
break
|
||||
|
||||
if transcripts:
|
||||
yield transcripts
|
||||
|
||||
def _get_call_details_by_ids(self, call_ids: list[str]) -> dict:
|
||||
def _get_call_details_by_ids(self, call_ids: list[str]) -> dict[str, Any]:
|
||||
body = {
|
||||
"filter": {"callIds": call_ids},
|
||||
"contentSelector": {"exposedFields": {"parties": True}},
|
||||
@@ -176,6 +187,50 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
|
||||
return call_to_metadata
|
||||
|
||||
def _fetch_call_details_with_retry(self, call_ids: list[str]) -> dict[str, Any]:
|
||||
"""Fetch call details with retry for the Gong API race condition.
|
||||
|
||||
The Gong API has a known race where transcript call IDs don't immediately
|
||||
appear in /v2/calls/extensive. Retries with exponential backoff, only
|
||||
re-requesting the missing IDs on each attempt.
|
||||
"""
|
||||
call_details_map = self._get_call_details_by_ids(call_ids)
|
||||
if set(call_ids) == set(call_details_map.keys()):
|
||||
return call_details_map
|
||||
|
||||
for attempt in range(2, self.MAX_CALL_DETAILS_ATTEMPTS + 1):
|
||||
missing_ids = list(set(call_ids) - set(call_details_map.keys()))
|
||||
logger.warning(
|
||||
f"_get_call_details_by_ids is missing call id's: current_attempt={attempt - 1} missing_call_ids={missing_ids}"
|
||||
)
|
||||
|
||||
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, attempt - 2)
|
||||
logger.warning(
|
||||
f"_get_call_details_by_ids waiting to retry: "
|
||||
f"wait={wait_seconds}s "
|
||||
f"current_attempt={attempt - 1} "
|
||||
f"next_attempt={attempt} "
|
||||
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
|
||||
)
|
||||
time.sleep(wait_seconds)
|
||||
|
||||
# Only re-fetch the missing IDs, merge into existing results
|
||||
new_details = self._get_call_details_by_ids(missing_ids)
|
||||
call_details_map.update(new_details)
|
||||
|
||||
if set(call_ids) == set(call_details_map.keys()):
|
||||
return call_details_map
|
||||
|
||||
missing_ids = list(set(call_ids) - set(call_details_map.keys()))
|
||||
logger.error(
|
||||
f"Giving up on missing call id's after "
|
||||
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
|
||||
f"missing_call_ids={missing_ids} — "
|
||||
f"proceeding with {len(call_details_map)} of "
|
||||
f"{len(call_ids)} calls"
|
||||
)
|
||||
return call_details_map
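# Quick, illustrative check of the back-off above (not used by the connector):
# with the class defaults CALL_DETAILS_DELAY = 30 and
# MAX_CALL_DETAILS_ATTEMPTS = 6, the waits before attempts 2..6 are
# 30, 60, 120, 240 and 480 seconds, i.e. roughly 15.5 minutes of sleeping in
# the worst case before giving up.
example_waits = [30 * pow(2, attempt - 2) for attempt in range(2, 6 + 1)]
assert example_waits == [30, 60, 120, 240, 480]
assert sum(example_waits) == 930  # seconds
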
@staticmethod
|
||||
def _parse_parties(parties: list[dict]) -> dict[str, str]:
|
||||
id_mapping = {}
|
||||
@@ -196,186 +251,46 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
|
||||
return id_mapping
|
||||
|
||||
def _fetch_calls(
|
||||
self, start_datetime: str | None = None, end_datetime: str | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
num_calls = 0
|
||||
def _resolve_workspace_ids(self) -> list[str | None]:
|
||||
"""Resolve configured workspace names/IDs to actual workspace IDs.
|
||||
|
||||
for transcript_batch in self._get_transcript_batches(
|
||||
start_datetime, end_datetime
|
||||
):
|
||||
doc_batch: list[Document | HierarchyNode] = []
|
||||
Returns a list of workspace IDs. If no workspaces are configured,
|
||||
returns [None] to indicate "fetch all workspaces".
|
||||
|
||||
transcript_call_ids = cast(
|
||||
list[str],
|
||||
[t.get("callId") for t in transcript_batch if t.get("callId")],
|
||||
Raises ValueError if workspaces are configured but none resolve —
|
||||
we never silently widen scope to "fetch all" on misconfiguration,
|
||||
because that could ingest an entire Gong account by mistake.
|
||||
"""
|
||||
if not self.workspaces:
|
||||
return [None]
|
||||
|
||||
workspace_map = self._get_workspace_id_map()
|
||||
resolved: list[str | None] = []
|
||||
for workspace in self.workspaces:
|
||||
workspace_id = workspace_map.get(workspace)
|
||||
if not workspace_id:
|
||||
logger.error(f"Invalid Gong workspace: {workspace}")
|
||||
continue
|
||||
resolved.append(workspace_id)
|
||||
|
||||
if not resolved:
|
||||
raise ValueError(
|
||||
f"No valid Gong workspaces found — check workspace names/IDs in connector config. Configured: {self.workspaces}"
|
||||
)
|
||||
|
||||
call_details_map: dict[str, Any] = {}
|
||||
return resolved
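# Hedged illustration of the resolution rules above (workspace names and IDs
# are invented): valid names resolve to their IDs, unknown names are skipped
# with an error log, and if nothing at all resolves a ValueError is raised
# rather than silently widening the scope to "fetch all".
example_workspace_map = {"Sales": "123", "Support": "456"}  # shape of _get_workspace_id_map()
example_configured = ["Sales", "bogus"]
example_resolved = [example_workspace_map[w] for w in example_configured if w in example_workspace_map]
assert example_resolved == ["123"]  # "bogus" only logs an error; all-invalid would raise
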
|
||||
|
||||
# There's a likely race condition in the API where a transcript will have a
|
||||
# call id but the call to v2/calls/extensive will not return all of the id's
|
||||
# retry with exponential backoff has been observed to mitigate this
|
||||
# in ~2 minutes. After max attempts, proceed with whatever we have —
|
||||
# the per-call loop below will skip missing IDs gracefully.
|
||||
current_attempt = 0
|
||||
while True:
|
||||
current_attempt += 1
|
||||
call_details_map = self._get_call_details_by_ids(transcript_call_ids)
|
||||
if set(transcript_call_ids) == set(call_details_map.keys()):
|
||||
# we got all the id's we were expecting ... break and continue
|
||||
break
|
||||
|
||||
# we are missing some id's. Log and retry with exponential backoff
|
||||
missing_call_ids = set(transcript_call_ids) - set(
|
||||
call_details_map.keys()
|
||||
)
|
||||
logger.warning(
|
||||
f"_get_call_details_by_ids is missing call id's: "
|
||||
f"current_attempt={current_attempt} "
|
||||
f"missing_call_ids={missing_call_ids}"
|
||||
)
|
||||
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
|
||||
logger.error(
|
||||
f"Giving up on missing call id's after "
|
||||
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
|
||||
f"missing_call_ids={missing_call_ids} — "
|
||||
f"proceeding with {len(call_details_map)} of "
|
||||
f"{len(transcript_call_ids)} calls"
|
||||
)
|
||||
break
|
||||
|
||||
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
|
||||
logger.warning(
|
||||
f"_get_call_details_by_ids waiting to retry: "
|
||||
f"wait={wait_seconds}s "
|
||||
f"current_attempt={current_attempt} "
|
||||
f"next_attempt={current_attempt + 1} "
|
||||
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
|
||||
)
|
||||
time.sleep(wait_seconds)
|
||||
|
||||
# now we can iterate per call/transcript
|
||||
for transcript in transcript_batch:
|
||||
call_id = transcript.get("callId")
|
||||
|
||||
if not call_id or call_id not in call_details_map:
|
||||
# NOTE(rkuo): seeing odd behavior where call_ids from the transcript
|
||||
# don't have call details. adding error debugging logs to trace.
|
||||
logger.error(
|
||||
f"Couldn't get call information for Call ID: {call_id}"
|
||||
)
|
||||
if call_id:
|
||||
logger.error(
|
||||
f"Call debug info: call_id={call_id} "
|
||||
f"call_ids={transcript_call_ids} "
|
||||
f"call_details_map={call_details_map.keys()}"
|
||||
)
|
||||
if not self.continue_on_fail:
|
||||
raise RuntimeError(
|
||||
f"Couldn't get call information for Call ID: {call_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
call_details = call_details_map[call_id]
|
||||
call_metadata = call_details["metaData"]
|
||||
|
||||
call_time_str = call_metadata["started"]
|
||||
call_title = call_metadata["title"]
|
||||
logger.info(
|
||||
f"{num_calls + 1}: Indexing Gong call id {call_id} from {call_time_str.split('T', 1)[0]}: {call_title}"
|
||||
)
|
||||
|
||||
call_parties = cast(list[dict] | None, call_details.get("parties"))
|
||||
if call_parties is None:
|
||||
logger.error(f"Couldn't get parties for Call ID: {call_id}")
|
||||
call_parties = []
|
||||
|
||||
id_to_name_map = self._parse_parties(call_parties)
|
||||
|
||||
# Keeping a separate dict here in case the parties info is incomplete
|
||||
speaker_to_name: dict[str, str] = {}
|
||||
|
||||
transcript_text = ""
|
||||
call_purpose = call_metadata["purpose"]
|
||||
if call_purpose:
|
||||
transcript_text += f"Call Description: {call_purpose}\n\n"
|
||||
|
||||
contents = transcript["transcript"]
|
||||
for segment in contents:
|
||||
speaker_id = segment.get("speakerId", "")
|
||||
if speaker_id not in speaker_to_name:
|
||||
if self.hide_user_info:
|
||||
speaker_to_name[speaker_id] = (
|
||||
f"User {len(speaker_to_name) + 1}"
|
||||
)
|
||||
else:
|
||||
speaker_to_name[speaker_id] = id_to_name_map.get(
|
||||
speaker_id, "Unknown"
|
||||
)
|
||||
|
||||
speaker_name = speaker_to_name[speaker_id]
|
||||
|
||||
sentences = segment.get("sentences", {})
|
||||
monolog = " ".join(
|
||||
[sentence.get("text", "") for sentence in sentences]
|
||||
)
|
||||
transcript_text += f"{speaker_name}: {monolog}\n\n"
|
||||
|
||||
metadata = {}
|
||||
if call_metadata.get("system"):
|
||||
metadata["client"] = call_metadata.get("system")
|
||||
# TODO calls have a clientUniqueId field, can pull that in later
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=call_id,
|
||||
sections=[
|
||||
TextSection(link=call_metadata["url"], text=transcript_text)
|
||||
],
|
||||
source=DocumentSource.GONG,
|
||||
# Should not ever be Untitled as a call cannot be made without a Title
|
||||
semantic_identifier=call_title or "Untitled",
|
||||
doc_updated_at=datetime.fromisoformat(call_time_str).astimezone(
|
||||
timezone.utc
|
||||
),
|
||||
metadata={"client": call_metadata.get("system")},
|
||||
)
|
||||
)
|
||||
|
||||
num_calls += 1
|
||||
|
||||
yield doc_batch
|
||||
|
||||
logger.info(f"_fetch_calls finished: num_calls={num_calls}")
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
combined = (
|
||||
f"{credentials['gong_access_key']}:{credentials['gong_access_key_secret']}"
|
||||
)
|
||||
self.auth_token_basic = base64.b64encode(combined.encode("utf-8")).decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
if self.auth_token_basic is None:
|
||||
raise ConnectorMissingCredentialError("Gong")
|
||||
|
||||
self._session.headers.update(
|
||||
{"Authorization": f"Basic {self.auth_token_basic}"}
|
||||
)
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_calls()
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
@staticmethod
|
||||
def _compute_time_range(
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
) -> tuple[str, str]:
|
||||
"""Compute the start/end datetime strings for the Gong API filter,
|
||||
applying GONG_CONNECTOR_START_TIME and the 1-day offset."""
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
|
||||
# if this env variable is set, don't start from a timestamp before the specified
|
||||
# start time
|
||||
# TODO: remove this once this is globally available
|
||||
if GONG_CONNECTOR_START_TIME:
|
||||
special_start_datetime = datetime.fromisoformat(GONG_CONNECTOR_START_TIME)
|
||||
special_start_datetime = special_start_datetime.replace(tzinfo=timezone.utc)
|
||||
@@ -394,11 +309,186 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
# so adding a 1 day buffer and fetching by default till current time
|
||||
start_one_day_offset = start_datetime - timedelta(days=1)
|
||||
start_time = start_one_day_offset.isoformat()
|
||||
end_time = end_datetime.isoformat()
|
||||
|
||||
end_time = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
|
||||
return start_time, end_time
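# Illustrative check of the 1-day buffer above (not part of the connector):
# a poll window starting at 2024-01-02T00:00:00Z is widened back to
# 2024-01-01T00:00:00Z before being sent as the Gong fromDateTime filter.
# datetime, timedelta and timezone are already imported at the top of this module.
example_start = datetime.fromtimestamp(1704153600, tz=timezone.utc)  # 2024-01-02T00:00:00Z
assert (example_start - timedelta(days=1)).isoformat() == "2024-01-01T00:00:00+00:00"
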
|
||||
|
||||
logger.info(f"Fetching Gong calls between {start_time} and {end_time}")
|
||||
return self._fetch_calls(start_time, end_time)
|
||||
def _process_transcripts(
|
||||
self,
|
||||
transcripts: list[dict[str, Any]],
|
||||
) -> Generator[Document | ConnectorFailure, None, None]:
|
||||
"""Process a batch of transcripts into Documents or ConnectorFailures."""
|
||||
transcript_call_ids = cast(
|
||||
list[str],
|
||||
[t.get("callId") for t in transcripts if t.get("callId")],
|
||||
)
|
||||
|
||||
call_details_map = self._fetch_call_details_with_retry(transcript_call_ids)
|
||||
|
||||
for transcript in transcripts:
|
||||
call_id = transcript.get("callId")
|
||||
|
||||
if not call_id or call_id not in call_details_map:
|
||||
logger.error(f"Couldn't get call information for Call ID: {call_id}")
|
||||
if call_id:
|
||||
logger.error(
|
||||
f"Call debug info: call_id={call_id} "
|
||||
f"call_ids={transcript_call_ids} "
|
||||
f"call_details_map={call_details_map.keys()}"
|
||||
)
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=call_id or "unknown",
|
||||
),
|
||||
failure_message=f"Couldn't get call information for Call ID: {call_id}",
|
||||
)
|
||||
continue
|
||||
|
||||
call_details = call_details_map[call_id]
|
||||
call_metadata = call_details["metaData"]
|
||||
|
||||
call_time_str = call_metadata["started"]
|
||||
call_title = call_metadata["title"]
|
||||
logger.info(
|
||||
f"Indexing Gong call id {call_id} from {call_time_str.split('T', 1)[0]}: {call_title}"
|
||||
)
|
||||
|
||||
call_parties = cast(list[dict] | None, call_details.get("parties"))
|
||||
if call_parties is None:
|
||||
logger.error(f"Couldn't get parties for Call ID: {call_id}")
|
||||
call_parties = []
|
||||
|
||||
id_to_name_map = self._parse_parties(call_parties)
|
||||
|
||||
speaker_to_name: dict[str, str] = {}
|
||||
|
||||
transcript_text = ""
|
||||
call_purpose = call_metadata["purpose"]
|
||||
if call_purpose:
|
||||
transcript_text += f"Call Description: {call_purpose}\n\n"
|
||||
|
||||
contents = transcript["transcript"]
|
||||
for segment in contents:
|
||||
speaker_id = segment.get("speakerId", "")
|
||||
if speaker_id not in speaker_to_name:
|
||||
if self.hide_user_info:
|
||||
speaker_to_name[speaker_id] = f"User {len(speaker_to_name) + 1}"
|
||||
else:
|
||||
speaker_to_name[speaker_id] = id_to_name_map.get(
|
||||
speaker_id, "Unknown"
|
||||
)
|
||||
|
||||
speaker_name = speaker_to_name[speaker_id]
|
||||
|
||||
sentences = segment.get("sentences", {})
|
||||
monolog = " ".join([sentence.get("text", "") for sentence in sentences])
|
||||
transcript_text += f"{speaker_name}: {monolog}\n\n"
|
||||
|
||||
yield Document(
|
||||
id=call_id,
|
||||
sections=[TextSection(link=call_metadata["url"], text=transcript_text)],
|
||||
source=DocumentSource.GONG,
|
||||
semantic_identifier=call_title or "Untitled",
|
||||
doc_updated_at=datetime.fromisoformat(call_time_str).astimezone(
|
||||
timezone.utc
|
||||
),
|
||||
metadata={"client": call_metadata.get("system")},
|
||||
)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
combined = (
|
||||
f"{credentials['gong_access_key']}:{credentials['gong_access_key_secret']}"
|
||||
)
|
||||
self.auth_token_basic = base64.b64encode(combined.encode("utf-8")).decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
if self.auth_token_basic is None:
|
||||
raise ConnectorMissingCredentialError("Gong")
|
||||
|
||||
self._session.headers.update(
|
||||
{"Authorization": f"Basic {self.auth_token_basic}"}
|
||||
)
|
||||
return None
|
||||
|
||||
def build_dummy_checkpoint(self) -> GongConnectorCheckpoint:
|
||||
return GongConnectorCheckpoint(has_more=True)
|
||||
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> GongConnectorCheckpoint:
|
||||
return GongConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: GongConnectorCheckpoint,
|
||||
) -> CheckpointOutput[GongConnectorCheckpoint]:
|
||||
checkpoint = copy.deepcopy(checkpoint)
|
||||
|
||||
# Step 1: Resolve workspace IDs on first call
|
||||
if checkpoint.workspace_ids is None:
|
||||
checkpoint.workspace_ids = self._resolve_workspace_ids()
|
||||
checkpoint.time_range = self._compute_time_range(start, end)
|
||||
checkpoint.has_more = True
|
||||
return checkpoint
|
||||
|
||||
workspace_ids = checkpoint.workspace_ids
|
||||
|
||||
# If we've exhausted all workspaces, we're done
|
||||
if checkpoint.workspace_index >= len(workspace_ids):
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
|
||||
|
||||
# Use cached time range, falling back to computation if not cached
|
||||
start_time, end_time = checkpoint.time_range or self._compute_time_range(
|
||||
start, end
|
||||
)
|
||||
logger.info(
|
||||
f"Fetching Gong calls between {start_time} and {end_time} "
|
||||
f"(workspace {checkpoint.workspace_index + 1}/{len(workspace_ids)})"
|
||||
)
|
||||
|
||||
workspace_id = workspace_ids[checkpoint.workspace_index]
|
||||
|
||||
# Step 2: Fetch one page of transcripts
|
||||
try:
|
||||
page = self._fetch_transcript_page(
|
||||
start_datetime=start_time,
|
||||
end_datetime=end_time,
|
||||
workspace_id=workspace_id,
|
||||
cursor=checkpoint.cursor,
|
||||
)
|
||||
except _CursorExpiredError:
|
||||
# Gong cursors TTL ~1h from first request in the sequence. If the
|
||||
# checkpoint paused long enough for the cursor to expire, restart
|
||||
# the current workspace from the beginning of the time range.
|
||||
# Document upserts are idempotent (keyed by call_id) so
|
||||
# reprocessing is safe.
|
||||
logger.warning(
|
||||
f"Gong pagination cursor expired for workspace "
|
||||
f"{checkpoint.workspace_index + 1}/{len(workspace_ids)}; "
|
||||
f"restarting workspace from beginning of time range."
|
||||
)
|
||||
checkpoint.cursor = None
|
||||
checkpoint.has_more = True
|
||||
return checkpoint
|
||||
|
||||
# Step 3: Process transcripts into documents
|
||||
if page.transcripts:
|
||||
yield from self._process_transcripts(page.transcripts)
|
||||
|
||||
# Step 4: Update checkpoint state
|
||||
if page.next_cursor:
|
||||
# More pages in this workspace
|
||||
checkpoint.cursor = page.next_cursor
|
||||
checkpoint.has_more = True
|
||||
else:
|
||||
# This workspace is exhausted — advance to next
|
||||
checkpoint.workspace_index += 1
|
||||
checkpoint.cursor = None
|
||||
checkpoint.has_more = checkpoint.workspace_index < len(workspace_ids)
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -412,5 +502,13 @@ if __name__ == "__main__":
|
||||
}
|
||||
)
|
||||
|
||||
latest_docs = connector.load_from_state()
|
||||
print(next(latest_docs))
|
||||
checkpoint = connector.build_dummy_checkpoint()
|
||||
while checkpoint.has_more:
|
||||
doc_generator = connector.load_from_checkpoint(0, time.time(), checkpoint)
|
||||
try:
|
||||
while True:
|
||||
item = next(doc_generator)
|
||||
print(item)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value
|
||||
print(f"Checkpoint: {checkpoint}")
|
||||
|
||||
@@ -502,6 +502,9 @@ class GoogleDriveConnector(
|
||||
files: list[RetrievedDriveFile],
|
||||
seen_hierarchy_node_raw_ids: ThreadSafeSet[str],
|
||||
fully_walked_hierarchy_node_raw_ids: ThreadSafeSet[str],
|
||||
failed_folder_ids_by_email: (
|
||||
ThreadSafeDict[str, ThreadSafeSet[str]] | None
|
||||
) = None,
|
||||
permission_sync_context: PermissionSyncContext | None = None,
|
||||
add_prefix: bool = False,
|
||||
) -> list[HierarchyNode]:
|
||||
@@ -525,6 +528,9 @@ class GoogleDriveConnector(
|
||||
seen_hierarchy_node_raw_ids: Set of already-yielded node IDs (modified in place)
|
||||
fully_walked_hierarchy_node_raw_ids: Set of node IDs where the walk to root
|
||||
succeeded (modified in place)
|
||||
failed_folder_ids_by_email: Map of email → folder IDs where that email
|
||||
previously confirmed no accessible parent. Skips the API call if the same
|
||||
(folder, email) is encountered again (modified in place).
|
||||
permission_sync_context: If provided, permissions will be fetched for hierarchy nodes.
|
||||
Contains google_domain and primary_admin_email needed for permission syncing.
|
||||
add_prefix: When True, prefix group IDs with source type (for indexing path).
|
||||
@@ -569,7 +575,7 @@ class GoogleDriveConnector(
|
||||
|
||||
# Fetch folder metadata
|
||||
folder = self._get_folder_metadata(
|
||||
current_id, file.user_email, field_type
|
||||
current_id, file.user_email, field_type, failed_folder_ids_by_email
|
||||
)
|
||||
if not folder:
|
||||
# Can't access this folder - stop climbing
|
||||
@@ -653,7 +659,13 @@ class GoogleDriveConnector(
|
||||
return new_nodes
|
||||
|
||||
def _get_folder_metadata(
|
||||
self, folder_id: str, retriever_email: str, field_type: DriveFileFieldType
|
||||
self,
|
||||
folder_id: str,
|
||||
retriever_email: str,
|
||||
field_type: DriveFileFieldType,
|
||||
failed_folder_ids_by_email: (
|
||||
ThreadSafeDict[str, ThreadSafeSet[str]] | None
|
||||
) = None,
|
||||
) -> GoogleDriveFileType | None:
|
||||
"""
|
||||
Fetch metadata for a folder by ID.
|
||||
@@ -667,6 +679,17 @@ class GoogleDriveConnector(
|
||||
|
||||
# Use a set to deduplicate if retriever_email == primary_admin_email
|
||||
for email in {retriever_email, self.primary_admin_email}:
|
||||
failed_ids = (
|
||||
failed_folder_ids_by_email.get(email)
|
||||
if failed_folder_ids_by_email
|
||||
else None
|
||||
)
|
||||
if failed_ids and folder_id in failed_ids:
|
||||
logger.debug(
|
||||
f"Skipping folder {folder_id} using {email} (previously confirmed no parents)"
|
||||
)
|
||||
continue
|
||||
|
||||
service = get_drive_service(self.creds, email)
|
||||
folder = get_folder_metadata(service, folder_id, field_type)
|
||||
|
||||
@@ -682,6 +705,10 @@ class GoogleDriveConnector(
|
||||
|
||||
# Folder has no parents - could be a root OR user lacks access to parent
|
||||
# Keep this as a fallback but try admin to see if they can see parents
|
||||
if failed_folder_ids_by_email is not None:
|
||||
failed_folder_ids_by_email.setdefault(email, ThreadSafeSet()).add(
|
||||
folder_id
|
||||
)
|
||||
if best_folder is None:
|
||||
best_folder = folder
|
||||
logger.debug(
|
||||
@@ -1090,6 +1117,13 @@ class GoogleDriveConnector(
|
||||
]
|
||||
yield from parallel_yield(user_retrieval_gens, max_workers=MAX_DRIVE_WORKERS)
|
||||
|
||||
# Free per-user cache entries now that this batch is done.
|
||||
# Skip the admin email — it is shared across all user batches and must
|
||||
# persist for the duration of the run.
|
||||
for email in non_completed_org_emails:
|
||||
if email != self.primary_admin_email:
|
||||
checkpoint.failed_folder_ids_by_email.pop(email, None)
|
||||
|
||||
# if there are more emails to process, don't mark as complete
|
||||
if not email_batch_takes_us_to_completion:
|
||||
return
|
||||
@@ -1546,6 +1580,7 @@ class GoogleDriveConnector(
|
||||
files=files_batch,
|
||||
seen_hierarchy_node_raw_ids=checkpoint.seen_hierarchy_node_raw_ids,
|
||||
fully_walked_hierarchy_node_raw_ids=checkpoint.fully_walked_hierarchy_node_raw_ids,
|
||||
failed_folder_ids_by_email=checkpoint.failed_folder_ids_by_email,
|
||||
permission_sync_context=permission_sync_context,
|
||||
add_prefix=True,
|
||||
)
|
||||
@@ -1782,6 +1817,7 @@ class GoogleDriveConnector(
|
||||
files=files_batch,
|
||||
seen_hierarchy_node_raw_ids=checkpoint.seen_hierarchy_node_raw_ids,
|
||||
fully_walked_hierarchy_node_raw_ids=checkpoint.fully_walked_hierarchy_node_raw_ids,
|
||||
failed_folder_ids_by_email=checkpoint.failed_folder_ids_by_email,
|
||||
permission_sync_context=permission_sync_context,
|
||||
)
|
||||
|
||||
|
||||
@@ -379,10 +379,20 @@ def _download_and_extract_sections_basic(
|
||||
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
or is_tabular_file(file_name)
|
||||
):
|
||||
# Google Drive doesn't enforce file extensions, so the filename may not
|
||||
# end in .xlsx even when the mime type says it's one. Synthesize the
|
||||
# extension so tabular_file_to_sections dispatches correctly.
|
||||
tabular_file_name = file_name
|
||||
if (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
and not is_tabular_file(file_name)
|
||||
):
|
||||
tabular_file_name = f"{file_name}.xlsx"
|
||||
return list(
|
||||
tabular_file_to_sections(
|
||||
io.BytesIO(response_call()),
|
||||
file_name=file_name,
|
||||
file_name=tabular_file_name,
|
||||
link=link,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -167,6 +167,13 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
|
||||
default_factory=ThreadSafeSet
|
||||
)
|
||||
|
||||
# Maps email → set of IDs of folders where that email confirmed no accessible parent.
|
||||
# Avoids redundant API calls when the same (folder, email) pair is
|
||||
# encountered again within the same retrieval run.
|
||||
failed_folder_ids_by_email: ThreadSafeDict[str, ThreadSafeSet[str]] = Field(
|
||||
default_factory=ThreadSafeDict
|
||||
)
|
||||
|
||||
@field_serializer("completion_map")
|
||||
def serialize_completion_map(
|
||||
self, completion_map: ThreadSafeDict[str, StageCompletion], _info: Any
|
||||
@@ -211,3 +218,25 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
|
||||
if isinstance(v, list):
|
||||
return ThreadSafeSet(set(v)) # ty: ignore[invalid-return-type]
|
||||
return ThreadSafeSet()
|
||||
|
||||
@field_serializer("failed_folder_ids_by_email")
|
||||
def serialize_failed_folder_ids_by_email(
|
||||
self,
|
||||
failed_folder_ids_by_email: ThreadSafeDict[str, ThreadSafeSet[str]],
|
||||
_info: Any,
|
||||
) -> dict[str, set[str]]:
|
||||
return {
|
||||
k: inner.copy() for k, inner in failed_folder_ids_by_email.copy().items()
|
||||
}
|
||||
|
||||
@field_validator("failed_folder_ids_by_email", mode="before")
|
||||
def validate_failed_folder_ids_by_email(
|
||||
cls, v: Any
|
||||
) -> ThreadSafeDict[str, ThreadSafeSet[str]]:
|
||||
if isinstance(v, ThreadSafeDict):
|
||||
return v
|
||||
if isinstance(v, dict):
|
||||
return ThreadSafeDict(
|
||||
{k: ThreadSafeSet(set(vals)) for k, vals in v.items()}
|
||||
)
|
||||
return ThreadSafeDict()
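# Hedged sketch (illustrative only) of the round-trip performed by the
# serializer/validator pair above: JSON carries a plain dict of
# email -> collection of folder IDs, and load rebuilds the thread-safe
# wrappers. ThreadSafeDict/ThreadSafeSet are assumed to accept the same
# constructor arguments the validator passes them.
example_serialized = {"admin@example.com": {"folder-a", "folder-b"}}  # serializer output shape
example_rebuilt = ThreadSafeDict(
    {k: ThreadSafeSet(set(vals)) for k, vals in example_serialized.items()}
)  # mirrors what the validator does on load
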
|
||||
|
||||
@@ -62,17 +62,19 @@ def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
|
||||
def extract_text_from_adf(adf: dict | None) -> str:
|
||||
"""Extracts plain text from Atlassian Document Format:
|
||||
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
|
||||
|
||||
WARNING: This function is incomplete and will e.g. skip lists!
|
||||
"""
|
||||
# TODO: complete this function
|
||||
texts = []
|
||||
if adf is not None and "content" in adf:
|
||||
for block in adf["content"]:
|
||||
if "content" in block:
|
||||
for item in block["content"]:
|
||||
if item["type"] == "text":
|
||||
texts.append(item["text"])
|
||||
texts: list[str] = []
|
||||
|
||||
def _extract(node: dict) -> None:
|
||||
if node.get("type") == "text":
|
||||
text = node.get("text", "")
|
||||
if text:
|
||||
texts.append(text)
|
||||
for child in node.get("content", []):
|
||||
_extract(child)
|
||||
|
||||
if adf is not None:
|
||||
_extract(adf)
|
||||
return " ".join(texts)
|
||||
|
||||
|
||||
|
||||
@@ -1958,8 +1958,7 @@ class SharepointConnector(
|
||||
self._graph_client = GraphClient(
|
||||
_acquire_token_for_graph, environment=self._azure_environment
|
||||
)
|
||||
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
|
||||
self.sp_tenant_domain = self._resolve_tenant_domain()
|
||||
self.sp_tenant_domain = self._resolve_tenant_domain()
|
||||
return None
|
||||
|
||||
def _get_drive_names_for_site(self, site_url: str) -> list[str]:
|
||||
|
||||
@@ -32,11 +32,16 @@ from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.html_utils import web_html_cleanup
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.sitemap import list_pages_for_site
|
||||
from onyx.utils.web_content import extract_pdf_text
|
||||
@@ -438,7 +443,7 @@ def _handle_cookies(context: BrowserContext, url: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
class WebConnector(LoadConnector):
|
||||
class WebConnector(LoadConnector, SlimConnector):
|
||||
MAX_RETRIES = 3
|
||||
|
||||
def __init__(
|
||||
@@ -493,8 +498,14 @@ class WebConnector(LoadConnector):
|
||||
index: int,
|
||||
initial_url: str,
|
||||
session_ctx: ScrapeSessionContext,
|
||||
slim: bool = False,
|
||||
) -> ScrapeResult:
|
||||
"""Returns a ScrapeResult object with a doc and retry flag."""
|
||||
"""Returns a ScrapeResult object with a doc and retry flag.
|
||||
|
||||
When slim=True, skips all content extraction and render waits.
|
||||
result.url is set to the resolved URL; result.doc is always None.
|
||||
Link discovery via <a href> tags is performed in both modes.
|
||||
"""
|
||||
|
||||
if session_ctx.playwright is None:
|
||||
raise RuntimeError("scrape_context.playwright is None")
|
||||
@@ -516,6 +527,17 @@ class WebConnector(LoadConnector):
|
||||
|
||||
if is_pdf:
|
||||
# PDF files are not checked for links
|
||||
if slim:
|
||||
# No content needed; record the URL via a minimal Document
|
||||
result.doc = Document(
|
||||
id=initial_url,
|
||||
sections=[],
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=initial_url,
|
||||
metadata={},
|
||||
)
|
||||
return result
|
||||
|
||||
response = requests.get(initial_url, headers=DEFAULT_HEADERS)
|
||||
page_text, metadata = extract_pdf_text(response.content)
|
||||
last_modified = response.headers.get("Last-Modified")
|
||||
@@ -546,17 +568,21 @@ class WebConnector(LoadConnector):
|
||||
timeout=30000, # 30 seconds
|
||||
wait_until="commit", # Wait for navigation to commit
|
||||
)
|
||||
# Give the page a moment to start rendering after navigation commits.
|
||||
# Allows CloudFlare and other bot-detection challenges to complete.
|
||||
page.wait_for_timeout(PAGE_RENDER_TIMEOUT_MS)
|
||||
|
||||
# Wait for network activity to settle so SPAs that fetch content
|
||||
# asynchronously after the initial JS bundle have time to render.
|
||||
try:
|
||||
# A bit of extra time to account for long-polling, websockets, etc.
|
||||
page.wait_for_load_state("networkidle", timeout=PAGE_RENDER_TIMEOUT_MS)
|
||||
except TimeoutError:
|
||||
pass
|
||||
if not slim:
|
||||
# Give the page a moment to start rendering after navigation commits.
|
||||
# Allows CloudFlare and other bot-detection challenges to complete.
|
||||
page.wait_for_timeout(PAGE_RENDER_TIMEOUT_MS)
|
||||
|
||||
# Wait for network activity to settle so SPAs that fetch content
|
||||
# asynchronously after the initial JS bundle have time to render.
|
||||
try:
|
||||
# A bit of extra time to account for long-polling, websockets, etc.
|
||||
page.wait_for_load_state(
|
||||
"networkidle", timeout=PAGE_RENDER_TIMEOUT_MS
|
||||
)
|
||||
except TimeoutError:
|
||||
pass
|
||||
|
||||
last_modified = (
|
||||
page_response.header_value("Last-Modified") if page_response else None
|
||||
@@ -576,7 +602,7 @@ class WebConnector(LoadConnector):
|
||||
session_ctx.visited_links.add(initial_url)
|
||||
|
||||
# If we got here, the request was successful
|
||||
if self.scroll_before_scraping:
|
||||
if not slim and self.scroll_before_scraping:
|
||||
scroll_attempts = 0
|
||||
previous_height = page.evaluate("document.body.scrollHeight")
|
||||
while scroll_attempts < WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS:
|
||||
@@ -615,6 +641,16 @@ class WebConnector(LoadConnector):
|
||||
result.retry = True
|
||||
return result
|
||||
|
||||
if slim:
|
||||
result.doc = Document(
|
||||
id=initial_url,
|
||||
sections=[],
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=initial_url,
|
||||
metadata={},
|
||||
)
|
||||
return result
|
||||
|
||||
# after this point, we don't need the caller to retry
|
||||
parsed_html = web_html_cleanup(soup, self.mintlify_cleanup)
|
||||
|
||||
@@ -742,6 +778,93 @@ class WebConnector(LoadConnector):
|
||||
|
||||
session_ctx.stop()
|
||||
|
||||
def retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
|
||||
callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
"""Yields SlimDocuments for all pages reachable from the configured URLs.
|
||||
|
||||
Uses the same Playwright-based crawl as load_from_state for accurate
|
||||
JS-rendered link discovery, but skips all content extraction and render
|
||||
waits. The start/end parameters are ignored — WEB connector has no
|
||||
incremental path.
|
||||
"""
|
||||
if not self.to_visit_list:
|
||||
raise ValueError("No URLs to visit")
|
||||
|
||||
base_url = self.to_visit_list[0]
|
||||
check_internet_connection(base_url)
|
||||
|
||||
session_ctx = ScrapeSessionContext(base_url, self.to_visit_list)
|
||||
session_ctx.initialize()
|
||||
|
||||
slim_batch: list[SlimDocument | HierarchyNode] = []
|
||||
|
||||
while session_ctx.to_visit:
|
||||
initial_url = session_ctx.to_visit.pop()
|
||||
if initial_url in session_ctx.visited_links:
|
||||
continue
|
||||
session_ctx.visited_links.add(initial_url)
|
||||
|
||||
try:
|
||||
protected_url_check(initial_url)
|
||||
except Exception as e:
|
||||
session_ctx.last_error = f"Invalid URL {initial_url} due to {e}"
|
||||
logger.warning(session_ctx.last_error)
|
||||
continue
|
||||
|
||||
index = len(session_ctx.visited_links)
|
||||
logger.info(f"{index}: Slim-visiting {initial_url}")
|
||||
|
||||
retry_count = 0
|
||||
while retry_count < self.MAX_RETRIES:
|
||||
if retry_count > 0:
|
||||
delay = min(2**retry_count + random.uniform(0, 1), 10)
|
||||
logger.info(
|
||||
f"Retry {retry_count}/{self.MAX_RETRIES} for {initial_url} after {delay:.2f}s delay"
|
||||
)
|
||||
time.sleep(delay)
|
||||
|
||||
try:
|
||||
result = self._do_scrape(index, initial_url, session_ctx, slim=True)
|
||||
if result.retry:
|
||||
continue
|
||||
|
||||
if result.doc:
|
||||
slim_batch.append(SlimDocument(id=result.doc.id))
|
||||
except Exception as e:
|
||||
session_ctx.last_error = (
|
||||
f"Failed to slim-fetch '{initial_url}': {e}"
|
||||
)
|
||||
logger.exception(session_ctx.last_error)
|
||||
session_ctx.initialize()
|
||||
continue
|
||||
finally:
|
||||
retry_count += 1
|
||||
|
||||
break
|
||||
|
||||
if len(slim_batch) >= self.batch_size:
|
||||
session_ctx.initialize()
|
||||
session_ctx.at_least_one_doc = True
|
||||
yield slim_batch
|
||||
slim_batch = []
|
||||
|
||||
if slim_batch:
|
||||
session_ctx.stop()
|
||||
session_ctx.at_least_one_doc = True
|
||||
yield slim_batch
|
||||
|
||||
if not session_ctx.at_least_one_doc:
|
||||
session_ctx.stop()
|
||||
if session_ctx.last_error:
|
||||
raise RuntimeError(session_ctx.last_error)
|
||||
raise RuntimeError("No valid pages found.")
|
||||
|
||||
session_ctx.stop()
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
# Make sure we have at least one valid URL to check
|
||||
if not self.to_visit_list:
|
||||
|
||||
@@ -81,9 +81,7 @@ class ZulipConnector(LoadConnector, PollConnector):
|
||||
# zuliprc file. This reverts them back to newlines.
|
||||
contents_spaces_to_newlines = contents.replace(" ", "\n")
|
||||
# create a temporary zuliprc file
|
||||
tempdir = tempfile.tempdir
|
||||
if tempdir is None:
|
||||
raise Exception("Could not determine tempfile directory")
|
||||
tempdir = tempfile.gettempdir()
|
||||
config_file = os.path.join(tempdir, f"zuliprc-{self.realm_name}")
|
||||
with open(config_file, "w") as f:
|
||||
f.write(contents_spaces_to_newlines)
|
||||
|
||||
@@ -244,13 +244,21 @@ def fetch_latest_index_attempts_by_status(
|
||||
return query.all()
|
||||
|
||||
|
||||
_INTERNAL_ONLY_SOURCES = {
|
||||
# Used by the ingestion API, not a user-created connector.
|
||||
DocumentSource.INGESTION_API,
|
||||
# Backs the user library / build feature, not a connector users filter by.
|
||||
DocumentSource.CRAFT_FILE,
|
||||
}
|
||||
|
||||
|
||||
def fetch_unique_document_sources(db_session: Session) -> list[DocumentSource]:
|
||||
distinct_sources = db_session.query(Connector.source).distinct().all()
|
||||
|
||||
sources = [
|
||||
source[0]
|
||||
for source in distinct_sources
|
||||
if source[0] != DocumentSource.INGESTION_API
|
||||
if source[0] not in _INTERNAL_ONLY_SOURCES
|
||||
]
|
||||
|
||||
return sources
|
||||
|
||||
@@ -20,6 +20,7 @@ from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
|
||||
)
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
|
||||
from onyx.db.enums import OpenSearchDocumentMigrationStatus
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import OpenSearchDocumentMigrationRecord
|
||||
@@ -412,7 +413,11 @@ def get_opensearch_retrieval_state(
|
||||
|
||||
If the tenant migration record is not found, defaults to
|
||||
ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX.
|
||||
|
||||
If ONYX_DISABLE_VESPA is True, always returns True.
|
||||
"""
|
||||
if ONYX_DISABLE_VESPA:
|
||||
return True
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
return ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
|
||||
|
||||
@@ -3,6 +3,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.opensearch_migration import get_opensearch_retrieval_state
|
||||
from onyx.document_index.disabled import DisabledDocumentIndex
|
||||
@@ -48,6 +49,11 @@ def get_default_document_index(
|
||||
secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled
|
||||
|
||||
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
|
||||
if ONYX_DISABLE_VESPA:
|
||||
if not opensearch_retrieval_enabled:
|
||||
raise ValueError(
|
||||
"BUG: ONYX_DISABLE_VESPA is set but opensearch_retrieval_enabled is not set."
|
||||
)
|
||||
if opensearch_retrieval_enabled:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
secondary_indexing_setting = (
|
||||
@@ -119,21 +125,32 @@ def get_all_document_indices(
|
||||
)
|
||||
]
|
||||
|
||||
vespa_document_index = VespaIndex(
|
||||
index_name=search_settings.index_name,
|
||||
secondary_index_name=(
|
||||
secondary_search_settings.index_name if secondary_search_settings else None
|
||||
),
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=(
|
||||
secondary_search_settings.large_chunks_enabled
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
|
||||
result: list[DocumentIndex] = []
|
||||
|
||||
if ONYX_DISABLE_VESPA:
|
||||
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
raise ValueError(
|
||||
"ONYX_DISABLE_VESPA is set but ENABLE_OPENSEARCH_INDEXING_FOR_ONYX is not set."
|
||||
)
|
||||
else:
|
||||
vespa_document_index = VespaIndex(
|
||||
index_name=search_settings.index_name,
|
||||
secondary_index_name=(
|
||||
secondary_search_settings.index_name
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=(
|
||||
secondary_search_settings.large_chunks_enabled
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
result.append(vespa_document_index)
|
||||
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
secondary_indexing_setting = (
|
||||
@@ -169,7 +186,6 @@ def get_all_document_indices(
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
result: list[DocumentIndex] = [vespa_document_index]
|
||||
if opensearch_document_index:
|
||||
result.append(opensearch_document_index)
|
||||
|
||||
return result
|
||||
|
||||
@@ -17,6 +17,7 @@ from onyx.configs.app_configs import OPENSEARCH_ADMIN_PASSWORD
|
||||
from onyx.configs.app_configs import OPENSEARCH_ADMIN_USERNAME
|
||||
from onyx.configs.app_configs import OPENSEARCH_HOST
|
||||
from onyx.configs.app_configs import OPENSEARCH_REST_API_PORT
|
||||
from onyx.configs.app_configs import OPENSEARCH_USE_SSL
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import OpenSearchSearchType
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
@@ -132,7 +133,7 @@ class OpenSearchClient(AbstractContextManager):
|
||||
host: str = OPENSEARCH_HOST,
|
||||
port: int = OPENSEARCH_REST_API_PORT,
|
||||
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
|
||||
use_ssl: bool = True,
|
||||
use_ssl: bool = OPENSEARCH_USE_SSL,
|
||||
verify_certs: bool = False,
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
@@ -302,7 +303,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
host: str = OPENSEARCH_HOST,
|
||||
port: int = OPENSEARCH_REST_API_PORT,
|
||||
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
|
||||
use_ssl: bool = True,
|
||||
use_ssl: bool = OPENSEARCH_USE_SSL,
|
||||
verify_certs: bool = False,
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
|
||||
@@ -23,6 +23,7 @@ import openpyxl
|
||||
from openpyxl.worksheet.worksheet import Worksheet
|
||||
from PIL import Image
|
||||
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
from onyx.configs.constants import ONYX_METADATA_FILENAME
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
@@ -47,6 +48,7 @@ KNOWN_OPENPYXL_BUGS = [
|
||||
"File contains no valid workbook part",
|
||||
"Unable to read workbook: could not read stylesheet from None",
|
||||
"Colors must be aRGB hex values",
|
||||
"Max value is",
|
||||
]
|
||||
|
||||
|
||||
@@ -191,6 +193,56 @@ def read_text_file(
|
||||
return file_content_raw, metadata
|
||||
|
||||
|
||||
def count_pdf_embedded_images(file: IO[Any], cap: int) -> int:
|
||||
"""Return the number of embedded images in a PDF, short-circuiting at cap+1.
|
||||
|
||||
Used to reject PDFs whose image count would OOM the user-file-processing
|
||||
worker during indexing. Returns a value > cap as a sentinel once the count
|
||||
exceeds the cap, so callers do not iterate thousands of image objects just
|
||||
to report a number. Returns 0 if the PDF cannot be parsed.
|
||||
|
||||
Owner-password-only PDFs (permission restrictions but no open password) are
|
||||
counted normally — they decrypt with an empty string. Truly password-locked
|
||||
PDFs are skipped (return 0) since we can't inspect them; the caller should
|
||||
ensure the password-protected check runs first.
|
||||
|
||||
Always restores the file pointer to its original position before returning.
|
||||
"""
|
||||
from pypdf import PdfReader
|
||||
|
||||
try:
|
||||
start_pos = file.tell()
|
||||
except Exception:
|
||||
start_pos = None
|
||||
try:
|
||||
if start_pos is not None:
|
||||
file.seek(0)
|
||||
reader = PdfReader(file)
|
||||
if reader.is_encrypted:
|
||||
# Try empty password first (owner-password-only PDFs); give up if that fails.
|
||||
try:
|
||||
if reader.decrypt("") == 0:
|
||||
return 0
|
||||
except Exception:
|
||||
return 0
|
||||
count = 0
|
||||
for page in reader.pages:
|
||||
for _ in page.images:
|
||||
count += 1
|
||||
if count > cap:
|
||||
return count
|
||||
return count
|
||||
except Exception:
|
||||
logger.warning("Failed to count embedded images in PDF", exc_info=True)
|
||||
return 0
|
||||
finally:
|
||||
if start_pos is not None:
|
||||
try:
|
||||
file.seek(start_pos)
|
||||
except Exception:
|
||||
pass
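# Hedged usage sketch (hypothetical helper, not part of this module): how an
# upload-time validator might use count_pdf_embedded_images to reject
# image-heavy PDFs before they reach the indexing worker.
# MAX_EMBEDDED_IMAGES_PER_FILE is the cap imported above.
def _validate_pdf_image_count_example(file: IO[Any]) -> None:
    cap = MAX_EMBEDDED_IMAGES_PER_FILE
    if count_pdf_embedded_images(file, cap) > cap:
        # Any value above the cap is a sentinel, not an exact total.
        raise ValueError(f"PDF contains more than {cap} embedded images")
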
|
||||
|
||||
|
||||
def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str:
|
||||
"""
|
||||
Extract text from a PDF. For embedded images, a more complex approach is needed.
|
||||
@@ -254,8 +306,27 @@ def read_pdf_file(
|
||||
)
|
||||
|
||||
if extract_images:
|
||||
image_cap = MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
images_processed = 0
|
||||
cap_reached = False
|
||||
for page_num, page in enumerate(pdf_reader.pages):
|
||||
if cap_reached:
|
||||
break
|
||||
for image_file_object in page.images:
|
||||
if images_processed >= image_cap:
|
||||
# Defense-in-depth backstop. Upload-time validation
|
||||
# should have rejected files exceeding the cap, but
|
||||
# we also break here so a single oversized file can
|
||||
# never pin a worker.
|
||||
logger.warning(
|
||||
"PDF embedded image cap reached (%d). "
|
||||
"Skipping remaining images on page %d and beyond.",
|
||||
image_cap,
|
||||
page_num + 1,
|
||||
)
|
||||
cap_reached = True
|
||||
break
|
||||
|
||||
image = Image.open(io.BytesIO(image_file_object.data))
|
||||
img_byte_arr = io.BytesIO()
|
||||
image.save(img_byte_arr, format=image.format)
|
||||
@@ -268,6 +339,7 @@ def read_pdf_file(
|
||||
image_callback(img_bytes, image_name)
|
||||
else:
|
||||
extracted_images.append((img_bytes, image_name))
|
||||
images_processed += 1
|
||||
|
||||
return text, metadata, extracted_images
|
||||
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
from collections import Counter
|
||||
from datetime import date
|
||||
from itertools import zip_longest
|
||||
|
||||
from dateutil.parser import parse as parse_dt
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.utils.csv_utils import ParsedRow
|
||||
|
||||
|
||||
CATEGORICAL_DISTINCT_THRESHOLD = 20
|
||||
ID_NAME_TOKENS = {"id", "uuid", "uid", "guid", "key"}
|
||||
|
||||
|
||||
class SheetAnalysis(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
row_count: int
|
||||
num_cols: int
|
||||
numeric_cols: list[int] = Field(default_factory=list)
|
||||
categorical_cols: list[int] = Field(default_factory=list)
|
||||
numeric_values: dict[int, list[float]] = Field(default_factory=dict)
|
||||
categorical_counts: dict[int, Counter[str]] = Field(default_factory=dict)
|
||||
id_col: int | None = None
|
||||
date_min: date | None = None
|
||||
date_max: date | None = None
|
||||
|
||||
@property
|
||||
def categorical_values(self) -> dict[int, list[str]]:
|
||||
return {ci: list(c.keys()) for ci, c in self.categorical_counts.items()}
|
||||
|
||||
|
||||
def analyze_sheet(headers: list[str], parsed_rows: list[ParsedRow]) -> SheetAnalysis:
|
||||
a = SheetAnalysis(row_count=len(parsed_rows), num_cols=len(headers))
|
||||
columns = zip_longest(*(pr.row for pr in parsed_rows), fillvalue="")
|
||||
for idx, (header, raw_values) in enumerate(zip(headers, columns)):
|
||||
values = [v.strip() for v in raw_values if v.strip()]
|
||||
if not values:
|
||||
continue
|
||||
|
||||
# Identifier: id-named column whose values are all unique. Detected
|
||||
# before classification so a numeric `id` column still gets flagged.
|
||||
distinct = set(values)
|
||||
if a.id_col is None and len(distinct) == len(values) and _is_id_name(header):
|
||||
a.id_col = idx
|
||||
|
||||
# Numeric: every value parses as a number.
|
||||
nums = _try_all_numeric(values)
|
||||
if nums is not None:
|
||||
a.numeric_cols.append(idx)
|
||||
a.numeric_values[idx] = nums
|
||||
continue
|
||||
|
||||
# Date: every value parses as a date — fold into the sheet-wide range.
|
||||
dates = _try_all_dates(values)
|
||||
if dates:
|
||||
dmin = min(dates)
|
||||
dmax = max(dates)
|
||||
a.date_min = dmin if a.date_min is None else min(a.date_min, dmin)
|
||||
a.date_max = dmax if a.date_max is None else max(a.date_max, dmax)
|
||||
continue
|
||||
|
||||
# Categorical: low-cardinality column — keep counts for samples + top values.
|
||||
if len(distinct) <= max(CATEGORICAL_DISTINCT_THRESHOLD, len(values) // 2):
|
||||
a.categorical_cols.append(idx)
|
||||
a.categorical_counts[idx] = Counter(values)
|
||||
return a
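# Worked example of the categorical cut-off used above (illustrative only):
# a column qualifies when its distinct count is at most
# max(CATEGORICAL_DISTINCT_THRESHOLD, len(values) // 2), so long columns may
# keep up to half their values as categories while short columns are capped
# at the fixed threshold of 20.
assert max(CATEGORICAL_DISTINCT_THRESHOLD, 1000 // 2) == 500  # 1000 non-empty values
assert max(CATEGORICAL_DISTINCT_THRESHOLD, 30 // 2) == 20  # 30 non-empty values
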
|
||||
|
||||
|
||||
def _try_all_numeric(values: list[str]) -> list[float] | None:
|
||||
parsed: list[float] = []
|
||||
for v in values:
|
||||
n = _parse_num(v)
|
||||
if n is None:
|
||||
return None
|
||||
parsed.append(n)
|
||||
return parsed
|
||||
|
||||
|
||||
def _parse_num(value: str) -> float | None:
|
||||
try:
|
||||
return float(value.replace(",", ""))
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _try_all_dates(values: list[str]) -> list[date] | None:
|
||||
parsed: list[date] = []
|
||||
for v in values:
|
||||
d = _try_date(v)
|
||||
if d is None:
|
||||
return None
|
||||
parsed.append(d)
|
||||
return parsed
|
||||
|
||||
|
||||
def _try_date(value: str) -> date | None:
|
||||
if len(value) < 4 or not any(c in value for c in "-/T"):
|
||||
return None
|
||||
try:
|
||||
return parse_dt(value).date()
|
||||
except (ValueError, OverflowError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def _is_id_name(name: str) -> bool:
|
||||
lowered = name.lower().strip().replace("-", "_")
return lowered in ID_NAME_TOKENS or any(
lowered.endswith(f"_{t}") for t in ID_NAME_TOKENS
)

@@ -1,51 +1,29 @@
"""Per-section sheet descriptor chunk builder."""

from datetime import date
from itertools import zip_longest

from dateutil.parser import parse as parse_dt
from pydantic import BaseModel
from pydantic import Field

from onyx.connectors.models import Section
from onyx.indexing.chunking.tabular_section_chunker.analysis import SheetAnalysis
from onyx.indexing.chunking.tabular_section_chunker.util import label
from onyx.indexing.chunking.tabular_section_chunker.util import pack_lines
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import count_tokens
from onyx.utils.csv_utils import parse_csv_string
from onyx.utils.csv_utils import ParsedRow
from onyx.utils.csv_utils import read_csv_header

MAX_NUMERIC_COLS = 12
MAX_CATEGORICAL_COLS = 6
MAX_CATEGORICAL_WITH_SAMPLES = 4
MAX_DISTINCT_SAMPLES = 8
CATEGORICAL_DISTINCT_THRESHOLD = 20
ID_NAME_TOKENS = {"id", "uuid", "uid", "guid", "key"}

class SheetAnalysis(BaseModel):
row_count: int
num_cols: int
numeric_cols: list[int] = Field(default_factory=list)
categorical_cols: list[int] = Field(default_factory=list)
categorical_values: dict[int, list[str]] = Field(default_factory=dict)
id_col: int | None = None
date_min: date | None = None
date_max: date | None = None

def build_sheet_descriptor_chunks(
section: Section,
headers: list[str],
analysis: SheetAnalysis,
heading: str,
tokenizer: BaseTokenizer,
max_tokens: int,
) -> list[str]:
"""Build sheet descriptor chunk(s) from a parsed CSV section.
"""Build sheet descriptor chunk(s) from a pre-parsed sheet.

Output (lines joined by "\\n"; lines that overflow ``max_tokens`` on
their own are skipped; ``section.heading`` is prepended to every
emitted chunk so retrieval keeps sheet context after a split):
their own are skipped; ``heading`` is prepended to every emitted
chunk so retrieval keeps sheet context after a split):

{section.heading} # optional
{heading} # optional
Sheet overview.
This sheet has {N} rows and {M} columns.
Columns: {col1}, {col2}, ...

@@ -55,25 +33,21 @@ def build_sheet_descriptor_chunks(
Identifier column: {col}. # optional
Values seen in {col}: {v1}, {v2}, ... # optional, repeated
"""
text = section.text or ""
parsed_rows = list(parse_csv_string(text))
headers = parsed_rows[0].header if parsed_rows else read_csv_header(text)
if not headers:
return []

a = _analyze(headers, parsed_rows)
lines = [
_overview_line(a),
_overview_line(analysis),
_columns_line(headers),
_time_range_line(a),
_numeric_cols_line(headers, a),
_categorical_cols_line(headers, a),
_id_col_line(headers, a),
_values_seen_line(headers, a),
_time_range_line(analysis),
_numeric_cols_line(headers, analysis),
_categorical_cols_line(headers, analysis),
_id_col_line(headers, analysis),
_values_seen_line(headers, analysis),
]
return _pack_lines(
return pack_lines(
[line for line in lines if line],
prefix=section.heading or "",
prefix=heading,
tokenizer=tokenizer,
max_tokens=max_tokens,
)

@@ -87,7 +61,7 @@ def _overview_line(a: SheetAnalysis) -> str:

def _columns_line(headers: list[str]) -> str:
return "Columns: " + ", ".join(_label(h) for h in headers)
return "Columns: " + ", ".join(label(h) for h in headers)

def _time_range_line(a: SheetAnalysis) -> str:

@@ -99,7 +73,7 @@ def _time_range_line(a: SheetAnalysis) -> str:
def _numeric_cols_line(headers: list[str], a: SheetAnalysis) -> str:
if not a.numeric_cols:
return ""
names = ", ".join(_label(headers[i]) for i in a.numeric_cols[:MAX_NUMERIC_COLS])
names = ", ".join(label(headers[i]) for i in a.numeric_cols[:MAX_NUMERIC_COLS])
return f"Numeric columns (aggregatable by sum, average, min, max): {names}"

@@ -107,7 +81,7 @@ def _categorical_cols_line(headers: list[str], a: SheetAnalysis) -> str:
if not a.categorical_cols:
return ""
names = ", ".join(
_label(headers[i]) for i in a.categorical_cols[:MAX_CATEGORICAL_COLS]
label(headers[i]) for i in a.categorical_cols[:MAX_CATEGORICAL_COLS]
)
return f"Categorical columns (groupable, can be counted by value): {names}"

@@ -115,7 +89,7 @@ def _categorical_cols_line(headers: list[str], a: SheetAnalysis) -> str:
def _id_col_line(headers: list[str], a: SheetAnalysis) -> str:
if a.id_col is None:
return ""
return f"Identifier column: {_label(headers[a.id_col])}."
return f"Identifier column: {label(headers[a.id_col])}."

def _values_seen_line(headers: list[str], a: SheetAnalysis) -> str:

@@ -123,106 +97,5 @@ def _values_seen_line(headers: list[str], a: SheetAnalysis) -> str:
for ci in a.categorical_cols[:MAX_CATEGORICAL_WITH_SAMPLES]:
sample = sorted(a.categorical_values.get(ci, []))[:MAX_DISTINCT_SAMPLES]
if sample:
rows.append(f"Values seen in {_label(headers[ci])}: " + ", ".join(sample))
rows.append(f"Values seen in {label(headers[ci])}: " + ", ".join(sample))
return "\n".join(rows)

def _label(name: str) -> str:
return f"{name} ({name.replace('_', ' ')})" if "_" in name else name

def _is_numeric(value: str) -> bool:
try:
float(value.replace(",", ""))
return True
except ValueError:
return False

def _try_date(value: str) -> date | None:
if len(value) < 4 or not any(c in value for c in "-/T"):
return None
try:
return parse_dt(value).date()
except (ValueError, OverflowError, TypeError):
return None

def _is_id_name(name: str) -> bool:
lowered = name.lower().strip().replace("-", "_")
return lowered in ID_NAME_TOKENS or any(
lowered.endswith(f"_{t}") for t in ID_NAME_TOKENS
)

def _analyze(headers: list[str], parsed_rows: list[ParsedRow]) -> SheetAnalysis:
a = SheetAnalysis(row_count=len(parsed_rows), num_cols=len(headers))
columns = zip_longest(*(pr.row for pr in parsed_rows), fillvalue="")
for idx, (header, raw_values) in enumerate(zip(headers, columns)):
# Pull the column's non-empty values; skip if the column is blank.
values = [v.strip() for v in raw_values if v.strip()]
if not values:
continue

# Identifier: id-named column whose values are all unique. Detected
# before classification so a numeric `id` column still gets flagged.
distinct = set(values)
if a.id_col is None and len(distinct) == len(values) and _is_id_name(header):
a.id_col = idx

# Numeric: every value parses as a number.
if all(_is_numeric(v) for v in values):
a.numeric_cols.append(idx)
continue

# Date: every value parses as a date — fold into the sheet-wide range.
dates = [_try_date(v) for v in values]
if all(d is not None for d in dates):
dmin = min(filter(None, dates))
dmax = max(filter(None, dates))
a.date_min = dmin if a.date_min is None else min(a.date_min, dmin)
a.date_max = dmax if a.date_max is None else max(a.date_max, dmax)
continue

# Categorical: low-cardinality column — keep distinct values for samples.
if len(distinct) <= max(CATEGORICAL_DISTINCT_THRESHOLD, len(values) // 2):
a.categorical_cols.append(idx)
a.categorical_values[idx] = list(distinct)
return a

def _pack_lines(
lines: list[str],
prefix: str,
tokenizer: BaseTokenizer,
max_tokens: int,
) -> list[str]:
"""Greedily pack lines into chunks ≤ max_tokens. Lines that on
their own exceed max_tokens (after accounting for the prefix) are
skipped. ``prefix`` is prepended to every emitted chunk."""
prefix_tokens = count_tokens(prefix, tokenizer) + 1 if prefix else 0
budget = max_tokens - prefix_tokens

chunks: list[str] = []
current: list[str] = []
current_tokens = 0
for line in lines:
line_tokens = count_tokens(line, tokenizer)
if line_tokens > budget:
continue
sep = 1 if current else 0
if current_tokens + sep + line_tokens > budget:
chunks.append(_join_with_prefix(current, prefix))
current = [line]
current_tokens = line_tokens
else:
current.append(line)
current_tokens += sep + line_tokens
if current:
chunks.append(_join_with_prefix(current, prefix))
return chunks

def _join_with_prefix(lines: list[str], prefix: str) -> str:
body = "\n".join(lines)
return f"{prefix}\n{body}" if prefix else body
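For orientation, this is roughly the chunk the builder above emits for a small hypothetical sheet (two columns, heading "Q3 Incidents", the service column assumed low-cardinality; the exact overview wording comes from _overview_line, which is not shown in this hunk):

Q3 Incidents
Sheet overview.
This sheet has 42 rows and 2 columns.
Columns: service, MTTR_hours (MTTR hours)
Numeric columns (aggregatable by sum, average, min, max): MTTR_hours (MTTR hours)
Categorical columns (groupable, can be counted by value): service
Values seen in service: api, billing, web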
@@ -7,14 +7,19 @@ from onyx.indexing.chunking.section_chunker import AccumulatorState
from onyx.indexing.chunking.section_chunker import ChunkPayload
from onyx.indexing.chunking.section_chunker import SectionChunker
from onyx.indexing.chunking.section_chunker import SectionChunkerOutput
from onyx.indexing.chunking.tabular_section_chunker.analysis import analyze_sheet
from onyx.indexing.chunking.tabular_section_chunker.sheet_descriptor import (
build_sheet_descriptor_chunks,
)
from onyx.indexing.chunking.tabular_section_chunker.total_descriptor import (
build_total_descriptor_chunks,
)
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import count_tokens
from onyx.natural_language_processing.utils import split_text_by_tokens
from onyx.utils.csv_utils import parse_csv_string
from onyx.utils.csv_utils import ParsedRow
from onyx.utils.csv_utils import read_csv_header
from onyx.utils.logger import setup_logger

logger = setup_logger()

@@ -230,24 +235,38 @@ class TabularChunker(SectionChunker):
) -> SectionChunkerOutput:
payloads = accumulator.flush_to_list()

parsed_rows = list(parse_csv_string(section.text or ""))
sheet_header = section.heading or ""
text = section.text or ""
parsed_rows = list(parse_csv_string(text))
headers = parsed_rows[0].header if parsed_rows else read_csv_header(text)
heading = section.heading or ""

chunk_texts: list[str] = []
if parsed_rows:
chunk_texts.extend(
parse_to_chunks(
rows=parsed_rows,
sheet_header=sheet_header,
sheet_header=heading,
tokenizer=self.tokenizer,
max_tokens=content_token_limit,
)
)

if not self.ignore_metadata_chunks:
if not self.ignore_metadata_chunks and headers:
analysis = analyze_sheet(headers, parsed_rows)
chunk_texts.extend(
build_sheet_descriptor_chunks(
section=section,
headers=headers,
analysis=analysis,
heading=heading,
tokenizer=self.tokenizer,
max_tokens=content_token_limit,
)
)
chunk_texts.extend(
build_total_descriptor_chunks(
headers=headers,
analysis=analysis,
heading=heading,
tokenizer=self.tokenizer,
max_tokens=content_token_limit,
)
@@ -0,0 +1,70 @@
from collections import Counter

from onyx.indexing.chunking.tabular_section_chunker.analysis import SheetAnalysis
from onyx.indexing.chunking.tabular_section_chunker.util import label
from onyx.indexing.chunking.tabular_section_chunker.util import pack_lines
from onyx.natural_language_processing.utils import BaseTokenizer

TOTALS_HEADER = (
"Totals and overall aggregates across all rows. This sheet can answer "
"whole-dataset questions about total, overall, grand total, sum across "
"all, average, combined, mean, minimum, maximum, and count of values."
)

def build_total_descriptor_chunks(
headers: list[str],
analysis: SheetAnalysis,
heading: str,
tokenizer: BaseTokenizer,
max_tokens: int,
) -> list[str]:
if analysis.row_count == 0:
return []

lines: list[str] = []
for idx in analysis.numeric_cols:
lines.append(_numeric_totals_line(headers[idx], analysis.numeric_values[idx]))
for idx in analysis.categorical_cols:
line = _categorical_top_line(headers[idx], analysis.categorical_counts[idx])
if line:
lines.append(line)

# No meaningful information - leave early
if not lines:
return []

lines.append(f"Total row count: {analysis.row_count}.")

prefix = (f"{heading}\n" if heading else "") + TOTALS_HEADER
return pack_lines(
lines=lines,
prefix=prefix,
tokenizer=tokenizer,
max_tokens=max_tokens,
)

def _numeric_totals_line(name: str, values: list[float]) -> str:
total = sum(values)
avg = total / len(values)
return (
f"Column {label(name)}: total (sum across all rows) = {_fmt(total)}, "
f"average = {_fmt(avg)}, minimum = {_fmt(min(values))}, "
f"maximum = {_fmt(max(values))}, count = {len(values)}."
)

def _categorical_top_line(name: str, counts: Counter[str]) -> str:
top = counts.most_common(1)
if not top:
return ""
val, n = top[0]
return f"Column {label(name)} most frequent value: {val} ({n} occurrences)."

def _fmt(num: float) -> str:
if abs(num) < 1e15 and num == int(num):
return str(int(num))
return f"{num:.6g}"
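A quick worked example of the line formats above (values hypothetical):

_numeric_totals_line("MTTR_hours", [2.0, 4.0, 6.0])
# -> "Column MTTR_hours (MTTR hours): total (sum across all rows) = 12, average = 4, minimum = 2, maximum = 6, count = 3."
_fmt(1234.5678)
# -> "1234.57"  (non-integral values fall through to the 6-significant-digit format)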
@@ -0,0 +1,48 @@
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import count_tokens

def label(name: str) -> str:
"""Render a column name with a space-substituted friendly alias in
parens for underscored headers so retrieval matches either surface
form (e.g. ``MTTR_hours`` → ``MTTR_hours (MTTR hours)``)."""
return f"{name} ({name.replace('_', ' ')})" if "_" in name else name

def pack_lines(
lines: list[str],
prefix: str,
tokenizer: BaseTokenizer,
max_tokens: int,
) -> list[str]:
"""Greedily pack ``lines`` into chunks ≤ ``max_tokens``, prepending
``prefix`` (verbatim) to every emitted chunk. Lines whose own token
count exceeds the post-prefix budget are skipped. Callers assemble
the full prefix (heading, header text, etc.) before calling.
"""
prefix_tokens = count_tokens(prefix, tokenizer) + 1 if prefix else 0
budget = max_tokens - prefix_tokens

chunks: list[str] = []
current: list[str] = []
current_tokens = 0
for line in lines:
line_tokens = count_tokens(line, tokenizer)
if line_tokens > budget:
continue
sep = 1 if current else 0
if current_tokens + sep + line_tokens > budget:
chunks.append(_join_with_prefix(current, prefix))
current = [line]
current_tokens = line_tokens
else:
current.append(line)
current_tokens += sep + line_tokens
if current:
chunks.append(_join_with_prefix(current, prefix))
return chunks

def _join_with_prefix(lines: list[str], prefix: str) -> str:
body = "\n".join(lines)
return f"{prefix}\n{body}" if prefix else body
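To make the packing rule concrete, a small hand-traced sketch (token counts hypothetical; count_tokens would supply the real numbers):

lines  = ["Sheet overview.", "Columns: a, b", "Values seen in a: x, y"]  # say 5 tokens each
prefix = "Sales 2025"                                                    # say 2 tokens
# prefix_tokens = 2 + 1 = 3, so with max_tokens=14 the per-chunk budget is 11.
# Line 1 (5) plus line 2 (1 separator + 5) fit exactly: 5 + 1 + 5 = 11.
# Line 3 would overflow, so it starts a second chunk. Result:
#   "Sales 2025\nSheet overview.\nColumns: a, b"
#   "Sales 2025\nValues seen in a: x, y"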
@@ -1516,6 +1516,10 @@
"display_name": "Claude Opus 4.6",
"model_vendor": "anthropic"
},
"claude-opus-4-7": {
"display_name": "Claude Opus 4.7",
"model_vendor": "anthropic"
},
"claude-opus-4-5-20251101": {
"display_name": "Claude Opus 4.5",
"model_vendor": "anthropic",

@@ -46,6 +46,15 @@ ANTHROPIC_REASONING_EFFORT_BUDGET: dict[ReasoningEffort, int] = {
ReasoningEffort.HIGH: 4096,
}

# Newer Anthropic models (Claude Opus 4.7+) use adaptive thinking with
# output_config.effort instead of thinking.type.enabled + budget_tokens.
ANTHROPIC_ADAPTIVE_REASONING_EFFORT: dict[ReasoningEffort, str] = {
ReasoningEffort.AUTO: "medium",
ReasoningEffort.LOW: "low",
ReasoningEffort.MEDIUM: "medium",
ReasoningEffort.HIGH: "high",
}

# Content part structures for multimodal messages
# The classes in this mirror the OpenAI Chat Completions message types and work well with routers like LiteLLM

@@ -23,6 +23,7 @@ from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.model_response import ModelResponse
from onyx.llm.model_response import ModelResponseStream
from onyx.llm.model_response import Usage
from onyx.llm.models import ANTHROPIC_ADAPTIVE_REASONING_EFFORT
from onyx.llm.models import ANTHROPIC_REASONING_EFFORT_BUDGET
from onyx.llm.models import OPENAI_REASONING_EFFORT
from onyx.llm.request_context import get_llm_mock_response

@@ -67,8 +68,13 @@ STANDARD_MAX_TOKENS_KWARG = "max_completion_tokens"
_VERTEX_ANTHROPIC_MODELS_REJECTING_OUTPUT_CONFIG = (
"claude-opus-4-5",
"claude-opus-4-6",
"claude-opus-4-7",
)

# Anthropic models that require the adaptive thinking API (thinking.type.adaptive
# + output_config.effort) instead of the legacy thinking.type.enabled + budget_tokens.
_ANTHROPIC_ADAPTIVE_THINKING_MODELS = ("claude-opus-4-7",)

class LLMTimeoutError(Exception):
"""

@@ -230,6 +236,14 @@ def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
)

def _anthropic_uses_adaptive_thinking(model_name: str) -> bool:
normalized_model_name = model_name.lower()
return any(
adaptive_model in normalized_model_name
for adaptive_model in _ANTHROPIC_ADAPTIVE_THINKING_MODELS
)

class LitellmLLM(LLM):
"""Uses Litellm library to allow easy configuration to use a multitude of LLMs
See https://python.langchain.com/docs/integrations/chat/litellm"""

@@ -509,10 +523,6 @@ class LitellmLLM(LLM):
}

elif is_claude_model:
budget_tokens: int | None = ANTHROPIC_REASONING_EFFORT_BUDGET.get(
reasoning_effort
)

# Anthropic requires every assistant message with tool_use
# blocks to start with a thinking block that carries a
# cryptographic signature. We don't preserve those blocks

@@ -520,24 +530,35 @@ class LitellmLLM(LLM):
# contains tool-calling assistant messages. LiteLLM's
# modify_params workaround doesn't cover all providers
# (notably Bedrock).
can_enable_thinking = (
budget_tokens is not None
and not _prompt_contains_tool_call_history(prompt)
)
has_tool_call_history = _prompt_contains_tool_call_history(prompt)

if can_enable_thinking:
assert budget_tokens is not None  # mypy
if max_tokens is not None:
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
# and the minimum budget tokens is 1024
# Will note that overwriting a developer set max tokens is not ideal but is the best we can do for now
# It is better to allow the LLM to output more reasoning tokens even if it results in a fairly small tool
# call as compared to reducing the budget for reasoning.
max_tokens = max(budget_tokens + 1, max_tokens)
optional_kwargs["thinking"] = {
"type": "enabled",
"budget_tokens": budget_tokens,
}
if _anthropic_uses_adaptive_thinking(self.config.model_name):
# Newer Anthropic models (Claude Opus 4.7+) reject
# thinking.type.enabled — they require the adaptive
# thinking config with output_config.effort.
if not has_tool_call_history:
optional_kwargs["thinking"] = {"type": "adaptive"}
optional_kwargs["output_config"] = {
"effort": ANTHROPIC_ADAPTIVE_REASONING_EFFORT[
reasoning_effort
],
}
else:
budget_tokens: int | None = ANTHROPIC_REASONING_EFFORT_BUDGET.get(
reasoning_effort
)
if budget_tokens is not None and not has_tool_call_history:
if max_tokens is not None:
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
# and the minimum budget tokens is 1024
# Will note that overwriting a developer set max tokens is not ideal but is the best we can do for now
# It is better to allow the LLM to output more reasoning tokens even if it results in a fairly small tool
# call as compared to reducing the budget for reasoning.
max_tokens = max(budget_tokens + 1, max_tokens)
optional_kwargs["thinking"] = {
"type": "enabled",
"budget_tokens": budget_tokens,
}

# LiteLLM just does some mapping like this anyway but is incomplete for Anthropic
optional_kwargs.pop("reasoning_effort", None)
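In short, the branch above selects one of two request shapes; a sketch with illustrative values (the exact passthrough to Anthropic is handled by LiteLLM):

# Claude Opus 4.7+ (adaptive thinking):
#   optional_kwargs["thinking"]      = {"type": "adaptive"}
#   optional_kwargs["output_config"] = {"effort": "medium"}  # from ANTHROPIC_ADAPTIVE_REASONING_EFFORT
# Older Claude models (legacy budgeted thinking, e.g. HIGH -> 4096):
#   optional_kwargs["thinking"] = {"type": "enabled", "budget_tokens": 4096}
#   max_tokens is bumped to at least budget_tokens + 1
# Either form is withheld when the prompt already contains tool-call history.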
@@ -1,6 +1,6 @@
{
"version": "1.1",
"updated_at": "2026-03-05T00:00:00Z",
"version": "1.2",
"updated_at": "2026-04-16T00:00:00Z",
"providers": {
"openai": {
"default_model": { "name": "gpt-5.4" },

@@ -10,8 +10,12 @@
]
},
"anthropic": {
"default_model": "claude-opus-4-6",
"default_model": "claude-opus-4-7",
"additional_visible_models": [
{
"name": "claude-opus-4-7",
"display_name": "Claude Opus 4.7"
},
{
"name": "claude-opus-4-6",
"display_name": "Claude Opus 4.6"

@@ -19,9 +19,14 @@ from onyx.configs.app_configs import MCP_SERVER_CORS_ORIGINS
from onyx.mcp_server.auth import OnyxTokenVerifier
from onyx.mcp_server.utils import shutdown_http_client
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable

logger = setup_logger()

# Initialize EE flag at module import so it's set regardless of the entry point
# (python -m onyx.mcp_server_main, uvicorn onyx.mcp_server.api:mcp_app, etc.).
set_is_ee_based_on_env_variable()

logger.info("Creating Onyx MCP Server...")

mcp_server = FastMCP(

@@ -1,4 +1,5 @@
"""Resource registrations for the Onyx MCP server."""

# Import resource modules so decorators execute when the package loads.
from onyx.mcp_server.resources import document_sets  # noqa: F401
from onyx.mcp_server.resources import indexed_sources  # noqa: F401

backend/onyx/mcp_server/resources/document_sets.py (new file, 41 lines)
@@ -0,0 +1,41 @@
"""Resource exposing document sets available to the current user."""

from __future__ import annotations

import json

from onyx.mcp_server.api import mcp_server
from onyx.mcp_server.utils import get_accessible_document_sets
from onyx.mcp_server.utils import require_access_token
from onyx.utils.logger import setup_logger

logger = setup_logger()

@mcp_server.resource(
"resource://document_sets",
name="document_sets",
description=(
"Enumerate the Document Sets accessible to the current user. Use the "
"returned `name` values with the `document_set_names` filter of the "
"`search_indexed_documents` tool to scope searches to a specific set."
),
mime_type="application/json",
)
async def document_sets_resource() -> str:
"""Return the list of document sets the user can filter searches by."""

access_token = require_access_token()

document_sets = sorted(
await get_accessible_document_sets(access_token), key=lambda entry: entry.name
)

logger.info(
"Onyx MCP Server: document_sets resource returning %s entries",
len(document_sets),
)

# FastMCP 3.2+ requires str/bytes/list[ResourceContent] — it no longer
# auto-serializes; serialize to JSON ourselves.
return json.dumps([entry.model_dump(mode="json") for entry in document_sets])

@@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Any
import json

from onyx.mcp_server.api import mcp_server
from onyx.mcp_server.utils import get_indexed_sources

@@ -21,7 +21,7 @@ logger = setup_logger()
),
mime_type="application/json",
)
async def indexed_sources_resource() -> dict[str, Any]:
async def indexed_sources_resource() -> str:
"""Return the list of indexed source types for search filtering."""

access_token = require_access_token()

@@ -33,6 +33,6 @@ async def indexed_sources_resource() -> dict[str, Any]:
len(sources),
)

return {
"indexed_sources": sorted(sources),
}
# FastMCP 3.2+ requires str/bytes/list[ResourceContent] — it no longer
# auto-serializes; serialize to JSON ourselves.
return json.dumps(sorted(sources))
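The same return-a-JSON-string pattern applies to any future resource; a minimal hypothetical sketch (name and payload invented, mirroring the two resources above):

@mcp_server.resource(
    "resource://example",
    name="example",
    description="Hypothetical resource illustrating the pattern.",
    mime_type="application/json",
)
async def example_resource() -> str:
    # FastMCP 3.2+ no longer auto-serializes return values, so serialize explicitly.
    return json.dumps({"items": []})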
@@ -4,12 +4,23 @@ from datetime import datetime
from typing import Any

import httpx
from fastmcp.server.auth.auth import AccessToken
from pydantic import BaseModel

from onyx.chat.models import ChatFullResponse
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import SearchDoc
from onyx.mcp_server.api import mcp_server
from onyx.mcp_server.utils import get_http_client
from onyx.mcp_server.utils import get_indexed_sources
from onyx.mcp_server.utils import require_access_token
from onyx.server.features.web_search.models import OpenUrlsToolRequest
from onyx.server.features.web_search.models import OpenUrlsToolResponse
from onyx.server.features.web_search.models import WebSearchToolRequest
from onyx.server.features.web_search.models import WebSearchToolResponse
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests
from onyx.utils.variable_functionality import global_version

@@ -17,6 +28,43 @@ from onyx.utils.variable_functionality import global_version
logger = setup_logger()

# CE search falls through to the chat endpoint, which invokes an LLM — the
# default 60s client timeout is not enough for a real RAG-backed response.
_CE_SEARCH_TIMEOUT_SECONDS = 300.0

async def _post_model(
url: str,
body: BaseModel,
access_token: AccessToken,
timeout: float | None = None,
) -> httpx.Response:
"""POST a Pydantic model as JSON to the Onyx backend."""
return await get_http_client().post(
url,
content=body.model_dump_json(exclude_unset=True),
headers={
"Authorization": f"Bearer {access_token.token}",
"Content-Type": "application/json",
},
timeout=timeout if timeout is not None else httpx.USE_CLIENT_DEFAULT,
)

def _project_doc(doc: SearchDoc, content: str | None) -> dict[str, Any]:
"""Project a backend search doc into the MCP wire shape.

Accepts SearchDocWithContent (EE) too since it extends SearchDoc.
"""
return {
"semantic_identifier": doc.semantic_identifier,
"content": content,
"source_type": doc.source_type.value,
"link": doc.link,
"score": doc.score,
}

def _extract_error_detail(response: httpx.Response) -> str:
"""Extract a human-readable error message from a failed backend response.

@@ -36,6 +84,7 @@ def _extract_error_detail(response: httpx.Response) -> str:
async def search_indexed_documents(
query: str,
source_types: list[str] | None = None,
document_set_names: list[str] | None = None,
time_cutoff: str | None = None,
limit: int = 10,
) -> dict[str, Any]:

@@ -53,6 +102,10 @@ async def search_indexed_documents(
In EE mode, the dedicated search endpoint is used instead.

To find a list of available sources, use the `indexed_sources` resource.
`document_set_names` restricts results to documents belonging to the named
Document Sets — useful for scoping queries to a curated subset of the
knowledge base (e.g. to isolate knowledge between agents). Use the
`document_sets` resource to discover accessible set names.
Returns chunks of text as search results with snippets, scores, and metadata.

Example usage:

@@ -60,15 +113,23 @@
{
"query": "What is the latest status of PROJ-1234 and what is the next development item?",
"source_types": ["jira", "google_drive", "github"],
"document_set_names": ["Engineering Wiki"],
"time_cutoff": "2025-11-24T00:00:00Z",
"limit": 10,
}
```
"""
logger.info(
f"Onyx MCP Server: document search: query='{query}', sources={source_types}, limit={limit}"
f"Onyx MCP Server: document search: query='{query}', sources={source_types}, "
f"document_sets={document_set_names}, limit={limit}"
)

# Normalize empty list inputs to None so downstream filter construction is
# consistent — BaseFilters treats [] as "match zero" which differs from
# "no filter" (None).
source_types = source_types or None
document_set_names = document_set_names or None

# Parse time_cutoff string to datetime if provided
time_cutoff_dt: datetime | None = None
if time_cutoff:

@@ -81,9 +142,6 @@ async def search_indexed_documents(
# Continue with no time_cutoff instead of returning an error
time_cutoff_dt = None

# Initialize source_type_enums early to avoid UnboundLocalError
source_type_enums: list[DocumentSource] | None = None

# Get authenticated user from FastMCP's access token
access_token = require_access_token()

@@ -117,6 +175,7 @@ async def search_indexed_documents(

# Convert source_types strings to DocumentSource enums if provided
# Invalid values will be handled by the API server
source_type_enums: list[DocumentSource] | None = None
if source_types is not None:
source_type_enums = []
for src in source_types:

@@ -127,83 +186,83 @@ async def search_indexed_documents(
f"Onyx MCP Server: Invalid source type '{src}' - will be ignored by server"
)

# Build filters dict only with non-None values
filters: dict[str, Any] | None = None
if source_type_enums or time_cutoff_dt:
filters = {}
if source_type_enums:
filters["source_type"] = [src.value for src in source_type_enums]
if time_cutoff_dt:
filters["time_cutoff"] = time_cutoff_dt.isoformat()
filters: BaseFilters | None = None
if source_type_enums or document_set_names or time_cutoff_dt:
filters = BaseFilters(
source_type=source_type_enums,
document_set=document_set_names,
time_cutoff=time_cutoff_dt,
)
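The earlier empty-list normalization exists precisely for this constructor; an illustration of the semantics described in the comment above (values hypothetical):

# source_types=[]        -> normalized to None -> source_type filter omitted entirely
# source_types=["jira"]  -> BaseFilters(source_type=[DocumentSource.JIRA], ...)
# Passing [] through unchanged would build a filter that matches zero sources.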
is_ee = global_version.is_ee_version()
base_url = build_api_server_url_for_http_requests(respect_env_override_if_set=True)
auth_headers = {"Authorization": f"Bearer {access_token.token}"}
is_ee = global_version.is_ee_version()

search_request: dict[str, Any]
request: BaseModel
if is_ee:
# EE: use the dedicated search endpoint (no LLM invocation)
search_request = {
"search_query": query,
"filters": filters,
"num_docs_fed_to_llm_selection": limit,
"run_query_expansion": False,
"include_content": True,
"stream": False,
}
# EE: use the dedicated search endpoint (no LLM invocation).
# Lazy import so CE deployments that strip ee/ never load this module.
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest

request = SendSearchQueryRequest(
search_query=query,
filters=filters,
num_docs_fed_to_llm_selection=limit,
run_query_expansion=False,
include_content=True,
stream=False,
)
endpoint = f"{base_url}/search/send-search-message"
error_key = "error"
docs_key = "search_docs"
content_field = "content"
else:
# CE: fall back to the chat endpoint (invokes LLM, consumes tokens)
search_request = {
"message": query,
"stream": False,
"chat_session_info": {},
}
if filters:
search_request["internal_search_filters"] = filters
request = SendMessageRequest(
message=query,
stream=False,
chat_session_info=ChatSessionCreationRequest(),
internal_search_filters=filters,
)
endpoint = f"{base_url}/chat/send-chat-message"
error_key = "error_msg"
docs_key = "top_documents"
content_field = "blurb"

try:
response = await get_http_client().post(
response = await _post_model(
endpoint,
json=search_request,
headers=auth_headers,
request,
access_token,
timeout=None if is_ee else _CE_SEARCH_TIMEOUT_SECONDS,
)
if not response.is_success:
error_detail = _extract_error_detail(response)
return {
"documents": [],
"total_results": 0,
"query": query,
"error": error_detail,
}
result = response.json()

# Check for error in response
if result.get(error_key):
return {
"documents": [],
"total_results": 0,
"query": query,
"error": result.get(error_key),
"error": _extract_error_detail(response),
}

documents = [
{
"semantic_identifier": doc.get("semantic_identifier"),
"content": doc.get(content_field),
"source_type": doc.get("source_type"),
"link": doc.get("link"),
"score": doc.get("score"),
}
for doc in result.get(docs_key, [])
]
if is_ee:
from ee.onyx.server.query_and_chat.models import SearchFullResponse

ee_payload = SearchFullResponse.model_validate_json(response.content)
if ee_payload.error:
return {
"documents": [],
"total_results": 0,
"query": query,
"error": ee_payload.error,
}
documents = [
_project_doc(doc, doc.content) for doc in ee_payload.search_docs
]
else:
ce_payload = ChatFullResponse.model_validate_json(response.content)
if ce_payload.error_msg:
return {
"documents": [],
"total_results": 0,
"query": query,
"error": ce_payload.error_msg,
}
documents = [
_project_doc(doc, doc.blurb) for doc in ce_payload.top_documents
]

# NOTE: search depth is controlled by the backend persona defaults, not `limit`.
# `limit` only caps the returned list; fewer results may be returned if the

@@ -252,23 +311,20 @@ async def search_web(
access_token = require_access_token()

try:
request_payload = {"queries": [query], "max_results": limit}
response = await get_http_client().post(
response = await _post_model(
f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/web-search/search-lite",
json=request_payload,
headers={"Authorization": f"Bearer {access_token.token}"},
WebSearchToolRequest(queries=[query], max_results=limit),
access_token,
)
if not response.is_success:
error_detail = _extract_error_detail(response)
return {
"error": error_detail,
"error": _extract_error_detail(response),
"results": [],
"query": query,
}
response_payload = response.json()
results = response_payload.get("results", [])
payload = WebSearchToolResponse.model_validate_json(response.content)
return {
"results": results,
"results": [result.model_dump(mode="json") for result in payload.results],
"query": query,
}
except Exception as e:

@@ -305,21 +361,19 @@ async def open_urls(
access_token = require_access_token()

try:
response = await get_http_client().post(
response = await _post_model(
f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/web-search/open-urls",
json={"urls": urls},
headers={"Authorization": f"Bearer {access_token.token}"},
OpenUrlsToolRequest(urls=urls),
access_token,
)
if not response.is_success:
error_detail = _extract_error_detail(response)
return {
"error": error_detail,
"error": _extract_error_detail(response),
"results": [],
}
response_payload = response.json()
results = response_payload.get("results", [])
payload = OpenUrlsToolResponse.model_validate_json(response.content)
return {
"results": results,
"results": [result.model_dump(mode="json") for result in payload.results],
}
except Exception as e:
logger.error(f"Onyx MCP Server: URL fetch error: {e}", exc_info=True)
@@ -5,10 +5,24 @@ from __future__ import annotations
import httpx
from fastmcp.server.auth.auth import AccessToken
from fastmcp.server.dependencies import get_access_token
from pydantic import BaseModel
from pydantic import TypeAdapter

from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests

class DocumentSetEntry(BaseModel):
"""Minimal document-set shape surfaced to MCP clients.

Projected from the backend's DocumentSetSummary to avoid coupling MCP to
admin-only fields (cc-pair summaries, federated connectors, etc.).
"""

name: str
description: str | None = None

logger = setup_logger()

# Shared HTTP client reused across requests

@@ -84,3 +98,32 @@ async def get_indexed_sources(
exc_info=True,
)
raise RuntimeError(f"Failed to fetch indexed sources: {exc}") from exc

_DOCUMENT_SET_ENTRIES_ADAPTER = TypeAdapter(list[DocumentSetEntry])

async def get_accessible_document_sets(
access_token: AccessToken,
) -> list[DocumentSetEntry]:
"""Fetch document sets accessible to the current user."""
headers = {"Authorization": f"Bearer {access_token.token}"}
try:
response = await get_http_client().get(
f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/manage/document-set",
headers=headers,
)
response.raise_for_status()
return _DOCUMENT_SET_ENTRIES_ADAPTER.validate_json(response.content)
except (httpx.HTTPStatusError, httpx.RequestError, ValueError):
logger.error(
"Onyx MCP Server: Failed to fetch document sets",
exc_info=True,
)
raise
except Exception as exc:
logger.error(
"Onyx MCP Server: Unexpected error fetching document sets",
exc_info=True,
)
raise RuntimeError(f"Failed to fetch document sets: {exc}") from exc
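A small sketch of what the TypeAdapter call above does with a typical backend payload (payload hypothetical):

raw = b'[{"name": "Engineering Wiki", "description": null, "cc_pair_summaries": []}]'
entries = _DOCUMENT_SET_ENTRIES_ADAPTER.validate_json(raw)
# -> [DocumentSetEntry(name="Engineering Wiki", description=None)]
# Extra admin-only fields are ignored because DocumentSetEntry declares only name/description.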
@@ -5,6 +5,7 @@ import uvicorn
from onyx.configs.app_configs import MCP_SERVER_ENABLED
from onyx.configs.app_configs import MCP_SERVER_HOST
from onyx.configs.app_configs import MCP_SERVER_PORT
from onyx.tracing.setup import setup_tracing
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable

@@ -18,6 +19,7 @@ def main() -> None:
return

set_is_ee_based_on_env_variable()
setup_tracing()
logger.info(f"Starting MCP server on {MCP_SERVER_HOST}:{MCP_SERVER_PORT}")

from onyx.mcp_server.api import mcp_app

@@ -65,8 +65,9 @@ IMPORTANT: each call to this tool is independent. Variables from previous calls
GENERATE_IMAGE_GUIDANCE = """
## generate_image
NEVER use generate_image unless the user specifically requests an image.
For edits/variations of a previously generated image, pass `reference_image_file_ids` with
the `file_id` values returned by earlier `generate_image` tool results.
To edit, restyle, or vary an existing image, pass its file_id in `reference_image_file_ids`. \
File IDs come from `[attached image — file_id: <id>]` tags on user-attached images or from prior `generate_image` tool results — never invent one. \
Leave `reference_image_file_ids` unset for a fresh generation.
""".lstrip()

MEMORY_GUIDANCE = """

@@ -126,6 +126,8 @@ class TenantRedis(redis.Redis):
"srem",
"scard",
"zadd",
"zrange",
"zrevrange",
"zrangebyscore",
"zremrangebyscore",
"zscore",

@@ -11,6 +11,8 @@ All public functions no-op in single-tenant mode (`MULTI_TENANT=False`).
import time
from typing import cast

from prometheus_client import Counter
from prometheus_client import Gauge
from redis.client import Redis

from onyx.configs.constants import ONYX_CLOUD_TENANT_ID

@@ -26,6 +28,40 @@ logger = setup_logger()
_SET_KEY = "active_tenants"

# --- Prometheus metrics ---

_active_set_size = Gauge(
"onyx_tenant_work_gating_active_set_size",
"Current cardinality of the active_tenants sorted set (updated once per "
"generator invocation when the gate reads it).",
)

_marked_total = Counter(
"onyx_tenant_work_gating_marked_total",
"Writes into active_tenants, labelled by caller.",
["caller"],
)

_skipped_total = Counter(
"onyx_tenant_work_gating_skipped_total",
"Per-tenant fanouts skipped by the gate (enforce mode only), by task.",
["task"],
)

_would_skip_total = Counter(
"onyx_tenant_work_gating_would_skip_total",
"Per-tenant fanouts that would have been skipped if enforce were on "
"(shadow counter), by task.",
["task"],
)

_full_fanout_total = Counter(
"onyx_tenant_work_gating_full_fanout_total",
"Generator invocations that bypassed the gate for a full fanout cycle, by task.",
["task"],
)

def _now_ms() -> int:
return int(time.time() * 1000)

@@ -54,10 +90,14 @@ def mark_tenant_active(tenant_id: str) -> None:
logger.exception(f"mark_tenant_active failed: tenant_id={tenant_id}")

def maybe_mark_tenant_active(tenant_id: str) -> None:
def maybe_mark_tenant_active(tenant_id: str, caller: str = "unknown") -> None:
"""Convenience wrapper for writer call sites: records the tenant only
when the feature flag is on. Fully defensive — never raises, so a Redis
outage or flag-read failure can't abort the calling task."""
outage or flag-read failure can't abort the calling task.

`caller` labels the Prometheus counter so a dashboard can show which
consumer is firing the hook most.
"""
try:
# Local import to avoid a module-load cycle: OnyxRuntime imports
# onyx.redis.redis_pool, so a top-level import here would wedge on

@@ -67,10 +107,44 @@ def maybe_mark_tenant_active(tenant_id: str) -> None:
if not OnyxRuntime.get_tenant_work_gating_enabled():
return
mark_tenant_active(tenant_id)
_marked_total.labels(caller=caller).inc()
except Exception:
logger.exception(f"maybe_mark_tenant_active failed: tenant_id={tenant_id}")

def observe_active_set_size() -> int | None:
"""Return `ZCARD active_tenants` and update the Prometheus gauge. Call
from the gate generator once per invocation so the dashboard has a
live reading.

Returns `None` on Redis error or in single-tenant mode; callers can
tolerate that (gauge simply doesn't update)."""
if not MULTI_TENANT:
return None
try:
size = cast(int, _client().zcard(_SET_KEY))
_active_set_size.set(size)
return size
except Exception:
logger.exception("observe_active_set_size failed")
return None

def record_gate_decision(task_name: str, skipped: bool) -> None:
"""Increment skip counters from the gate generator. Called once per
tenant that the gate would skip. Always increments the shadow counter;
increments the enforced counter only when `skipped=True`."""
_would_skip_total.labels(task=task_name).inc()
if skipped:
_skipped_total.labels(task=task_name).inc()

def record_full_fanout_cycle(task_name: str) -> None:
"""Increment the full-fanout counter. Called once per generator
invocation where the gate is bypassed (interval elapsed OR fail-open)."""
_full_fanout_total.labels(task=task_name).inc()

def get_active_tenants(ttl_seconds: int) -> set[str] | None:
"""Return tenants whose last-seen timestamp is within `ttl_seconds` of
now.
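A hypothetical caller, just to show how the helpers above compose (the repository's actual gate generator is not part of this hunk):

def gated_tenants(task_name, all_tenants, ttl_seconds, full_fanout_due):
    observe_active_set_size()              # refresh the gauge once per invocation
    active = get_active_tenants(ttl_seconds)
    if full_fanout_due or active is None:  # interval elapsed OR fail-open on Redis error
        record_full_fanout_cycle(task_name)
        yield from all_tenants
        return
    for tenant_id in all_tenants:
        if tenant_id in active:
            yield tenant_id
        else:
            # Enforce mode skips the tenant; shadow mode would instead call
            # record_gate_decision(task_name, skipped=False) and still yield it.
            record_gate_decision(task_name, skipped=True)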
@@ -40,6 +40,8 @@ from sqlalchemy.orm import Session

from onyx.auth.permissions import require_permission
from onyx.background.celery.versioned_apps.client import app as celery_app
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_UPLOAD
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask

@@ -51,6 +53,9 @@ from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentMetadata
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
from onyx.server.features.build.configs import USER_LIBRARY_MAX_FILE_SIZE_BYTES
from onyx.server.features.build.configs import USER_LIBRARY_MAX_FILES_PER_UPLOAD
from onyx.server.features.build.configs import USER_LIBRARY_MAX_TOTAL_SIZE_BYTES

@@ -128,6 +133,49 @@ class DeleteFileResponse(BaseModel):
# =============================================================================

def _looks_like_pdf(filename: str, content_type: str | None) -> bool:
"""True if either the filename or the content-type indicates a PDF.

Client-supplied ``content_type`` can be spoofed (e.g. a PDF uploaded with
``Content-Type: application/octet-stream``), so we also fall back to
extension-based detection via ``mimetypes.guess_type`` on the filename.
"""
if content_type == "application/pdf":
return True
guessed, _ = mimetypes.guess_type(filename)
return guessed == "application/pdf"

def _check_pdf_image_caps(
filename: str, content: bytes, content_type: str | None, batch_total: int
) -> int:
"""Enforce per-file and per-batch embedded-image caps for PDFs.

Returns the number of embedded images in this file (0 for non-PDFs) so
callers can update their running batch total. Raises OnyxError(INVALID_INPUT)
if either cap is exceeded.
"""
if not _looks_like_pdf(filename, content_type):
return 0
file_cap = MAX_EMBEDDED_IMAGES_PER_FILE
batch_cap = MAX_EMBEDDED_IMAGES_PER_UPLOAD
# Short-circuit at the larger cap so we get a useful count for both checks.
count = count_pdf_embedded_images(BytesIO(content), max(file_cap, batch_cap))
if count > file_cap:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
f"PDF '{filename}' contains too many embedded images "
f"(more than {file_cap}). Try splitting the document into smaller files.",
)
if batch_total + count > batch_cap:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
f"Upload would exceed the {batch_cap}-image limit across all "
f"files in this batch. Try uploading fewer image-heavy files at once.",
)
return count

def _sanitize_path(path: str) -> str:
"""Sanitize a file path, removing traversal attempts and normalizing.

@@ -356,6 +404,7 @@ async def upload_files(

uploaded_entries: list[LibraryEntryResponse] = []
total_size = 0
batch_image_total = 0
now = datetime.now(timezone.utc)

# Sanitize the base path

@@ -375,6 +424,14 @@ async def upload_files(
detail=f"File '{file.filename}' exceeds maximum size of {USER_LIBRARY_MAX_FILE_SIZE_BYTES // (1024 * 1024)}MB",
)

# Reject PDFs with an unreasonable per-file or per-batch image count
batch_image_total += _check_pdf_image_caps(
filename=file.filename or "unnamed",
content=content,
content_type=file.content_type,
batch_total=batch_image_total,
)

# Validate cumulative storage (existing + this upload batch)
total_size += file_size
if existing_usage + total_size > USER_LIBRARY_MAX_TOTAL_SIZE_BYTES:

@@ -473,6 +530,7 @@ async def upload_zip(

uploaded_entries: list[LibraryEntryResponse] = []
total_size = 0
batch_image_total = 0

# Extract zip contents into a subfolder named after the zip file
zip_name = api_sanitize_filename(file.filename or "upload")

@@ -511,6 +569,36 @@ async def upload_zip(
logger.warning(f"Skipping '{zip_info.filename}' - exceeds max size")
continue

# Skip PDFs that would trip the per-file or per-batch image
# cap (would OOM the user-file-processing worker). Matches
# /upload behavior but uses skip-and-warn to stay consistent
# with the zip path's handling of oversized files.
zip_file_name = zip_info.filename.split("/")[-1]
zip_content_type, _ = mimetypes.guess_type(zip_file_name)
if zip_content_type == "application/pdf":
image_count = count_pdf_embedded_images(
BytesIO(file_content),
max(
MAX_EMBEDDED_IMAGES_PER_FILE,
MAX_EMBEDDED_IMAGES_PER_UPLOAD,
),
)
if image_count > MAX_EMBEDDED_IMAGES_PER_FILE:
logger.warning(
"Skipping '%s' - exceeds %d per-file embedded-image cap",
zip_info.filename,
MAX_EMBEDDED_IMAGES_PER_FILE,
)
continue
if batch_image_total + image_count > MAX_EMBEDDED_IMAGES_PER_UPLOAD:
logger.warning(
"Skipping '%s' - would exceed %d per-batch embedded-image cap",
zip_info.filename,
MAX_EMBEDDED_IMAGES_PER_UPLOAD,
)
continue
batch_image_total += image_count

total_size += file_size

# Validate cumulative storage

@@ -9,7 +9,10 @@ from pydantic import ConfigDict
from pydantic import Field
from sqlalchemy.orm import Session

from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_UPLOAD
from onyx.db.llm import fetch_default_llm_model
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions

@@ -190,6 +193,11 @@ def categorize_uploaded_files(
token_threshold_k * 1000 if token_threshold_k else None
) # 0 → None = no limit

# Running total of embedded images across PDFs in this batch. Once the
# aggregate cap is reached, subsequent PDFs in the same upload are
# rejected even if they'd individually fit under MAX_EMBEDDED_IMAGES_PER_FILE.
batch_image_total = 0

for upload in files:
try:
filename = get_safe_filename(upload)

@@ -252,6 +260,47 @@ def categorize_uploaded_files(
)
continue

# Reject PDFs with an unreasonable number of embedded images
# (either per-file or accumulated across this upload batch).
# A PDF with thousands of embedded images can OOM the
# user-file-processing celery worker because every image is
# decoded with PIL and then sent to the vision LLM.
if extension == ".pdf":
file_cap = MAX_EMBEDDED_IMAGES_PER_FILE
batch_cap = MAX_EMBEDDED_IMAGES_PER_UPLOAD
# Use the larger of the two caps as the short-circuit
# threshold so we get a useful count for both checks.
# count_pdf_embedded_images restores the stream position.
count = count_pdf_embedded_images(
upload.file, max(file_cap, batch_cap)
)
if count > file_cap:
results.rejected.append(
RejectedFile(
filename=filename,
reason=(
f"PDF contains too many embedded images "
f"(more than {file_cap}). Try splitting "
f"the document into smaller files."
),
)
)
continue
if batch_image_total + count > batch_cap:
results.rejected.append(
RejectedFile(
filename=filename,
reason=(
f"Upload would exceed the "
f"{batch_cap}-image limit across all "
f"files in this batch. Try uploading "
f"fewer image-heavy files at once."
),
)
)
continue
batch_image_total += count

text_content = extract_file_text(
file=upload.file,
file_name=filename,
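A worked example of the two caps interacting (cap values hypothetical; the real limits come from MAX_EMBEDDED_IMAGES_PER_FILE and MAX_EMBEDDED_IMAGES_PER_UPLOAD):

# file_cap=50, batch_cap=100, and a batch of PDFs with 30, 40, 40 embedded images:
#   PDF 1: 30 <= 50 and 0 + 30 <= 100   -> accepted, batch_image_total = 30
#   PDF 2: 40 <= 50 and 30 + 40 <= 100  -> accepted, batch_image_total = 70
#   PDF 3: 40 <= 50 but 70 + 40 > 100   -> rejected with the per-batch message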
@@ -3,6 +3,7 @@ from fastapi import Depends
from sqlalchemy.orm import Session

from onyx.auth.permissions import require_permission
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import Permission
from onyx.db.models import User

@@ -49,6 +50,7 @@ def get_opensearch_retrieval_status(
enable_opensearch_retrieval = get_opensearch_retrieval_state(db_session)
return OpenSearchRetrievalStatusResponse(
enable_opensearch_retrieval=enable_opensearch_retrieval,
toggling_retrieval_is_disabled=ONYX_DISABLE_VESPA,
)

@@ -63,4 +65,5 @@ def set_opensearch_retrieval_status(
)
return OpenSearchRetrievalStatusResponse(
enable_opensearch_retrieval=request.enable_opensearch_retrieval,
toggling_retrieval_is_disabled=ONYX_DISABLE_VESPA,
)

@@ -19,3 +19,4 @@ class OpenSearchRetrievalStatusRequest(BaseModel):
class OpenSearchRetrievalStatusResponse(BaseModel):
model_config = {"frozen": True}
enable_opensearch_retrieval: bool
toggling_retrieval_is_disabled: bool = False

@@ -395,6 +395,15 @@ class WorkerHealthCollector(_CachedCollector):

Reads worker status from ``WorkerHeartbeatMonitor`` which listens
to the Celery event stream via a single persistent connection.

TODO: every monitoring pod subscribes to the cluster-wide Celery event
stream, so each replica reports health for *all* workers in the cluster,
not just itself. Prometheus distinguishes the replicas via the ``instance``
label, so this doesn't break scraping, but it means N monitoring replicas
do N× the work and may emit slightly inconsistent snapshots of the same
cluster. The proper fix is to have each worker expose its own health (or
to elect a single monitoring replica as the reporter) rather than
broadcasting the full cluster view from every monitoring pod.
"""

def __init__(self, cache_ttl: float = 30.0) -> None:

@@ -413,10 +422,16 @@ class WorkerHealthCollector(_CachedCollector):
"onyx_celery_active_worker_count",
"Number of active Celery workers with recent heartbeats",
)
# Celery hostnames are ``{worker_type}@{nodename}`` (see supervisord.conf).
# Emitting only the worker_type as a label causes N replicas of the same
# type to collapse into identical timeseries within a single scrape,
# which Prometheus rejects as "duplicate sample for timestamp". Split
# the pieces into separate labels so each replica is distinct; callers
# can still ``sum by (worker_type)`` to recover the old aggregated view.
worker_up = GaugeMetricFamily(
"onyx_celery_worker_up",
"Whether a specific Celery worker is alive (1=up, 0=down)",
labels=["worker"],
labels=["worker_type", "hostname"],
)

try:

@@ -424,11 +439,15 @@ class WorkerHealthCollector(_CachedCollector):
alive_count = sum(1 for alive in status.values() if alive)
active_workers.add_metric([], alive_count)

for hostname in sorted(status):
# Use short name (before @) for single-host deployments,
# full hostname when multiple hosts share a worker type.
label = hostname.split("@")[0]
worker_up.add_metric([label], 1 if status[hostname] else 0)
for full_hostname in sorted(status):
worker_type, sep, host = full_hostname.partition("@")
if not sep:
# Hostname didn't contain "@" — fall back to using the
# whole string as the hostname with an empty type.
worker_type, host = "", full_hostname
worker_up.add_metric(
[worker_type, host], 1 if status[full_hostname] else 0
)
except Exception:
logger.debug("Failed to collect worker health metrics", exc_info=True)
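As the comment notes, the old aggregated view can still be recovered at query time with an expression along these lines:

sum by (worker_type) (onyx_celery_worker_up)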
@@ -7,6 +7,7 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import ONYX_DISABLE_VESPA
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.configs.embedding_configs import SUPPORTED_EMBEDDING_MODELS
@@ -126,10 +127,11 @@ def setup_onyx(
"DISABLE_VECTOR_DB is set — skipping document index setup and embedding model warm-up."
)
else:
# Ensure Vespa is setup correctly, this step is relatively near the end
# because Vespa takes a bit of time to start up
# Ensure the document indices are setup correctly. This step is
# relatively near the end because Vespa takes a bit of time to start up.
logger.notice("Verifying Document Index(s) is/are available.")
# This flow is for setting up the document index so we get all indices here.
# This flow is for setting up the document index so we get all indices
# here.
document_indices = get_all_document_indices(
search_settings,
secondary_search_settings,
@@ -335,7 +337,7 @@ def setup_multitenant_onyx() -> None:

# For Managed Vespa, the schema is sent over via the Vespa Console manually.
# NOTE: Pretty sure this code is never hit in any production environment.
if not MANAGED_VESPA:
if not MANAGED_VESPA and not ONYX_DISABLE_VESPA:
setup_vespa_multitenant(SUPPORTED_EMBEDDING_MODELS)

@@ -208,12 +208,6 @@ class PythonToolOverrideKwargs(BaseModel):
chat_files: list[ChatFile] = []

class ImageGenerationToolOverrideKwargs(BaseModel):
"""Override kwargs for image generation tool calls."""

recent_generated_image_file_ids: list[str] = []

class SearchToolRunContext(BaseModel):
emitter: Emitter
@@ -26,7 +26,6 @@ from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeart
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.interface import Tool
from onyx.tools.models import ImageGenerationToolOverrideKwargs
from onyx.tools.models import ToolCallException
from onyx.tools.models import ToolExecutionException
from onyx.tools.models import ToolResponse
@@ -48,7 +47,7 @@ PROMPT_FIELD = "prompt"
REFERENCE_IMAGE_FILE_IDS_FIELD = "reference_image_file_ids"

class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
class ImageGenerationTool(Tool[None]):
NAME = "generate_image"
DESCRIPTION = "Generate an image based on a prompt. Do not use unless the user specifically requests an image."
DISPLAY_NAME = "Image Generation"
@@ -142,8 +141,11 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
REFERENCE_IMAGE_FILE_IDS_FIELD: {
"type": "array",
"description": (
"Optional image file IDs to use as reference context for edits/variations. "
"Use the file_id values returned by previous generate_image calls."
"Optional file_ids of existing images to edit or use as reference;"
" the first is the primary edit source."
" Get file_ids from `[attached image — file_id: <id>]` tags on"
" user-attached images or from prior generate_image tool responses."
" Omit for a fresh, unrelated generation."
),
"items": {
"type": "string",
@@ -254,41 +256,31 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
def _resolve_reference_image_file_ids(
self,
llm_kwargs: dict[str, Any],
override_kwargs: ImageGenerationToolOverrideKwargs | None,
) -> list[str]:
raw_reference_ids = llm_kwargs.get(REFERENCE_IMAGE_FILE_IDS_FIELD)
if raw_reference_ids is not None:
if not isinstance(raw_reference_ids, list) or not all(
isinstance(file_id, str) for file_id in raw_reference_ids
):
raise ToolCallException(
message=(
f"Invalid {REFERENCE_IMAGE_FILE_IDS_FIELD}: expected array of strings, got {type(raw_reference_ids)}"
),
llm_facing_message=(
f"The '{REFERENCE_IMAGE_FILE_IDS_FIELD}' field must be an array of file_id strings."
),
)
reference_image_file_ids = [
file_id.strip() for file_id in raw_reference_ids if file_id.strip()
]
elif (
override_kwargs
and override_kwargs.recent_generated_image_file_ids
and self.img_provider.supports_reference_images
):
# If no explicit reference was provided, default to the most recently generated image.
reference_image_file_ids = [
override_kwargs.recent_generated_image_file_ids[-1]
]
else:
reference_image_file_ids = []
if raw_reference_ids is None:
# No references requested — plain generation.
return []

# Deduplicate while preserving order.
if not isinstance(raw_reference_ids, list) or not all(
isinstance(file_id, str) for file_id in raw_reference_ids
):
raise ToolCallException(
message=(
f"Invalid {REFERENCE_IMAGE_FILE_IDS_FIELD}: expected array of strings, got {type(raw_reference_ids)}"
),
llm_facing_message=(
f"The '{REFERENCE_IMAGE_FILE_IDS_FIELD}' field must be an array of file_id strings."
),
)

# Deduplicate while preserving order (first occurrence wins, so the
# LLM's intended "primary edit source" stays at index 0).
deduped_reference_image_ids: list[str] = []
seen_ids: set[str] = set()
for file_id in reference_image_file_ids:
if file_id in seen_ids:
for file_id in raw_reference_ids:
file_id = file_id.strip()
if not file_id or file_id in seen_ids:
continue
seen_ids.add(file_id)
deduped_reference_image_ids.append(file_id)
@@ -302,14 +294,14 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
f"Reference images requested but provider '{self.provider}' does not support image-editing context."
),
llm_facing_message=(
"This image provider does not support editing from previous image context. "
"This image provider does not support editing from existing images. "
"Try text-only generation, or switch to a provider/model that supports image edits."
),
)

max_reference_images = self.img_provider.max_reference_images
if max_reference_images > 0:
return deduped_reference_image_ids[-max_reference_images:]
return deduped_reference_image_ids[:max_reference_images]
return deduped_reference_image_ids

def _load_reference_images(
@@ -358,7 +350,7 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
def run(
self,
placement: Placement,
override_kwargs: ImageGenerationToolOverrideKwargs | None = None,
override_kwargs: None = None, # noqa: ARG002
**llm_kwargs: Any,
) -> ToolResponse:
if PROMPT_FIELD not in llm_kwargs:
@@ -373,7 +365,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
shape = ImageShape(llm_kwargs.get("shape", ImageShape.SQUARE.value))
reference_image_file_ids = self._resolve_reference_image_file_ids(
llm_kwargs=llm_kwargs,
override_kwargs=override_kwargs,
)
reference_images = self._load_reference_images(reference_image_file_ids)
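A self-contained sketch of the dedup-and-cap behaviour `_resolve_reference_image_file_ids` now implements (first occurrence wins, then truncate from the front to the provider limit); the helper below is local to this sketch, not the tool's actual method.

# Standalone sketch: keep the first occurrence of each id so the primary edit
# source stays at index 0, then cap to the provider's limit from the front.
def resolve_reference_ids(raw_ids: list[str], max_reference_images: int) -> list[str]:
    deduped: list[str] = []
    seen: set[str] = set()
    for file_id in raw_ids:
        file_id = file_id.strip()
        if not file_id or file_id in seen:
            continue
        seen.add(file_id)
        deduped.append(file_id)
    # [:max] keeps the earliest ids (the intended edit source); the previous
    # [-max:] slice kept the latest ones instead.
    return deduped[:max_reference_images] if max_reference_images > 0 else deduped


assert resolve_reference_ids(["a", "b", "a", " c "], 2) == ["a", "b"]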
@@ -1,4 +1,3 @@
import json
import traceback
from collections import defaultdict
from typing import Any
@@ -14,7 +13,6 @@ from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.interface import Tool
from onyx.tools.models import ChatFile
from onyx.tools.models import ChatMinimalTextMessage
from onyx.tools.models import ImageGenerationToolOverrideKwargs
from onyx.tools.models import OpenURLToolOverrideKwargs
from onyx.tools.models import ParallelToolCallResponse
from onyx.tools.models import PythonToolOverrideKwargs
@@ -24,9 +22,6 @@ from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolExecutionException
from onyx.tools.models import ToolResponse
from onyx.tools.models import WebSearchToolOverrideKwargs
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
from onyx.tools.tool_implementations.memory.memory_tool import MemoryToolOverrideKwargs
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
@@ -110,63 +105,6 @@ def _merge_tool_calls(tool_calls: list[ToolCallKickoff]) -> list[ToolCallKickoff
return merged_calls

def _extract_image_file_ids_from_tool_response_message(
message: str,
) -> list[str]:
try:
parsed_message = json.loads(message)
except json.JSONDecodeError:
return []

parsed_items: list[Any] = (
parsed_message if isinstance(parsed_message, list) else [parsed_message]
)
file_ids: list[str] = []
for item in parsed_items:
if not isinstance(item, dict):
continue

file_id = item.get("file_id")
if isinstance(file_id, str):
file_ids.append(file_id)

return file_ids

def _extract_recent_generated_image_file_ids(
message_history: list[ChatMessageSimple],
) -> list[str]:
tool_name_by_tool_call_id: dict[str, str] = {}
recent_image_file_ids: list[str] = []
seen_file_ids: set[str] = set()

for message in message_history:
if message.message_type == MessageType.ASSISTANT and message.tool_calls:
for tool_call in message.tool_calls:
tool_name_by_tool_call_id[tool_call.tool_call_id] = tool_call.tool_name
continue

if (
message.message_type != MessageType.TOOL_CALL_RESPONSE
or not message.tool_call_id
):
continue

tool_name = tool_name_by_tool_call_id.get(message.tool_call_id)
if tool_name != ImageGenerationTool.NAME:
continue

for file_id in _extract_image_file_ids_from_tool_response_message(
message.message
):
if file_id in seen_file_ids:
continue
seen_file_ids.add(file_id)
recent_image_file_ids.append(file_id)

return recent_image_file_ids

def _safe_run_single_tool(
tool: Tool,
tool_call: ToolCallKickoff,
@@ -386,9 +324,6 @@ def run_tool_calls(
url_to_citation: dict[str, int] = {
url: citation_num for citation_num, url in citation_mapping.items()
}
recent_generated_image_file_ids = _extract_recent_generated_image_file_ids(
message_history
)

# Prepare all tool calls with their override_kwargs
# Each tool gets a unique starting citation number to avoid conflicts when running in parallel
@@ -405,7 +340,6 @@ def run_tool_calls(
| WebSearchToolOverrideKwargs
| OpenURLToolOverrideKwargs
| PythonToolOverrideKwargs
| ImageGenerationToolOverrideKwargs
| MemoryToolOverrideKwargs
| None
) = None
@@ -454,10 +388,6 @@ def run_tool_calls(
override_kwargs = PythonToolOverrideKwargs(
chat_files=chat_files or [],
)
elif isinstance(tool, ImageGenerationTool):
override_kwargs = ImageGenerationToolOverrideKwargs(
recent_generated_image_file_ids=recent_generated_image_file_ids
)
elif isinstance(tool, MemoryTool):
override_kwargs = MemoryToolOverrideKwargs(
user_name=(
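A stripped-down sketch of the isinstance-based override-kwargs dispatch that remains after this hunk, with ImageGenerationTool no longer taking override kwargs; the tool and kwargs classes below are local stand-ins for the real ones in onyx.tools.

# Local stand-ins only: sketches the dispatch shape, not the real tool classes.
from dataclasses import dataclass, field


class Tool: ...
class PythonTool(Tool): ...
class MemoryTool(Tool): ...


@dataclass
class PythonToolOverrideKwargs:
    chat_files: list[str] = field(default_factory=list)


@dataclass
class MemoryToolOverrideKwargs:
    user_name: str | None = None


def build_override_kwargs(tool: Tool, chat_files: list[str], user_name: str | None):
    # Only tools that need extra context get override kwargs; everything else,
    # including image generation after this change, runs with None.
    if isinstance(tool, PythonTool):
        return PythonToolOverrideKwargs(chat_files=chat_files or [])
    if isinstance(tool, MemoryTool):
        return MemoryToolOverrideKwargs(user_name=user_name)
    return None


assert build_override_kwargs(PythonTool(), ["f1"], None).chat_files == ["f1"]
assert build_override_kwargs(Tool(), [], None) is None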
@@ -34,6 +34,7 @@ R = TypeVar("R")
KT = TypeVar("KT") # Key type
VT = TypeVar("VT") # Value type
_T = TypeVar("_T") # Default type
_MISSING: object = object()

class ThreadSafeDict(MutableMapping[KT, VT]):
@@ -117,10 +118,10 @@ class ThreadSafeDict(MutableMapping[KT, VT]):
with self.lock:
return self._dict.get(key, default)

def pop(self, key: KT, default: Any = None) -> Any:
def pop(self, key: KT, default: Any = _MISSING) -> Any:
"""Remove and return a value with optional default, atomically."""
with self.lock:
if default is None:
if default is _MISSING:
return self._dict.pop(key)
return self._dict.pop(key, default)
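A minimal sketch of the sentinel-default pattern adopted above, showing why `_MISSING` is needed: it lets `pop` distinguish "no default supplied" from an explicit default of `None`. The class below is a standalone toy, not the real ThreadSafeDict.

# Toy illustration of the sentinel pattern; mirrors the diff's pop() semantics.
import threading
from typing import Any

_MISSING: object = object()


class TinyThreadSafeDict:
    def __init__(self) -> None:
        self._dict: dict[Any, Any] = {}
        self.lock = threading.Lock()

    def pop(self, key: Any, default: Any = _MISSING) -> Any:
        with self.lock:
            if default is _MISSING:
                return self._dict.pop(key)  # raises KeyError, like dict.pop
            return self._dict.pop(key, default)


d = TinyThreadSafeDict()
d._dict["k"] = None
assert d.pop("k", "fallback") is None   # a stored None is returned, not the default
assert d.pop("missing", None) is None   # an explicit None default now works
try:
    d.pop("missing")                    # no default supplied -> KeyError, as before
except KeyError:
    pass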
@@ -60,7 +60,7 @@ attrs==25.4.0
|
||||
# jsonschema
|
||||
# referencing
|
||||
# zeep
|
||||
authlib==1.6.9
|
||||
authlib==1.6.11
|
||||
# via fastmcp
|
||||
azure-cognitiveservices-speech==1.38.0
|
||||
babel==2.17.0
|
||||
@@ -214,7 +214,9 @@ distro==1.9.0
|
||||
dnspython==2.8.0
|
||||
# via email-validator
|
||||
docstring-parser==0.17.0
|
||||
# via cyclopts
|
||||
# via
|
||||
# cyclopts
|
||||
# google-cloud-aiplatform
|
||||
docutils==0.22.3
|
||||
# via rich-rst
|
||||
dropbox==12.0.2
|
||||
@@ -270,7 +272,13 @@ gitdb==4.0.12
|
||||
gitpython==3.1.45
|
||||
# via braintrust
|
||||
google-api-core==2.28.1
|
||||
# via google-api-python-client
|
||||
# via
|
||||
# google-api-python-client
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
google-api-python-client==2.86.0
|
||||
google-auth==2.48.0
|
||||
# via
|
||||
@@ -278,21 +286,61 @@ google-auth==2.48.0
|
||||
# google-api-python-client
|
||||
# google-auth-httplib2
|
||||
# google-auth-oauthlib
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
google-auth-httplib2==0.1.0
|
||||
# via google-api-python-client
|
||||
google-auth-oauthlib==1.0.0
|
||||
google-cloud-aiplatform==1.133.0
|
||||
# via litellm
|
||||
google-cloud-bigquery==3.41.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-core==2.5.1
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
google-cloud-resource-manager==1.17.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-storage==3.10.1
|
||||
# via google-cloud-aiplatform
|
||||
google-crc32c==1.8.0
|
||||
# via
|
||||
# google-cloud-storage
|
||||
# google-resumable-media
|
||||
google-genai==1.52.0
|
||||
# via onyx
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# onyx
|
||||
google-resumable-media==2.8.2
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
googleapis-common-protos==1.72.0
|
||||
# via
|
||||
# google-api-core
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# opentelemetry-exporter-otlp-proto-http
|
||||
greenlet==3.2.4
|
||||
# via
|
||||
# playwright
|
||||
# sqlalchemy
|
||||
grpc-google-iam-v1==0.14.4
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.80.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.80.0
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
@@ -443,7 +491,7 @@ magika==0.6.3
|
||||
# via markitdown
|
||||
makefun==1.16.0
|
||||
# via fastapi-users
|
||||
mako==1.2.4
|
||||
mako==1.3.11
|
||||
# via alembic
|
||||
mammoth==1.11.0
|
||||
# via markitdown
|
||||
@@ -559,6 +607,8 @@ packaging==24.2
|
||||
# dask
|
||||
# distributed
|
||||
# fastmcp
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# huggingface-hub
|
||||
# jira
|
||||
# kombu
|
||||
@@ -605,12 +655,19 @@ propcache==0.4.1
|
||||
# aiohttp
|
||||
# yarl
|
||||
proto-plus==1.26.1
|
||||
# via google-api-core
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# ddtrace
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# onnxruntime
|
||||
# opentelemetry-proto
|
||||
# proto-plus
|
||||
@@ -643,6 +700,7 @@ pydantic==2.11.7
|
||||
# exa-py
|
||||
# fastapi
|
||||
# fastmcp
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# langchain-core
|
||||
# langfuse
|
||||
@@ -679,7 +737,7 @@ pynacl==1.6.2
|
||||
pypandoc-binary==1.16.2
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.10.0
|
||||
pypdf==6.10.2
|
||||
# via unstructured-client
|
||||
pyperclip==1.11.0
|
||||
# via fastmcp
|
||||
@@ -701,6 +759,7 @@ python-dateutil==2.8.2
|
||||
# botocore
|
||||
# celery
|
||||
# dateparser
|
||||
# google-cloud-bigquery
|
||||
# htmldate
|
||||
# hubspot-api-client
|
||||
# kubernetes
|
||||
@@ -779,6 +838,8 @@ requests==2.33.0
|
||||
# dropbox
|
||||
# exa-py
|
||||
# google-api-core
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# hubspot-api-client
|
||||
# jira
|
||||
@@ -951,7 +1012,9 @@ typing-extensions==4.15.0
|
||||
# exa-py
|
||||
# exceptiongroup
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# jira
|
||||
# langchain-core
|
||||
|
||||
@@ -114,6 +114,8 @@ distlib==0.4.0
|
||||
# via virtualenv
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
docstring-parser==0.17.0
|
||||
# via google-cloud-aiplatform
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
execnet==2.1.2
|
||||
@@ -141,14 +143,65 @@ frozenlist==1.8.0
|
||||
# aiosignal
|
||||
fsspec==2025.10.0
|
||||
# via huggingface-hub
|
||||
google-api-core==2.28.1
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
google-auth==2.48.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
google-cloud-aiplatform==1.133.0
|
||||
# via litellm
|
||||
google-cloud-bigquery==3.41.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-core==2.5.1
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
google-cloud-resource-manager==1.17.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-storage==3.10.1
|
||||
# via google-cloud-aiplatform
|
||||
google-crc32c==1.8.0
|
||||
# via
|
||||
# google-cloud-storage
|
||||
# google-resumable-media
|
||||
google-genai==1.52.0
|
||||
# via onyx
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# onyx
|
||||
google-resumable-media==2.8.2
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
googleapis-common-protos==1.72.0
|
||||
# via
|
||||
# google-api-core
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
greenlet==3.2.4 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
|
||||
# via sqlalchemy
|
||||
grpc-google-iam-v1==0.14.4
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.80.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.80.0
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
@@ -218,7 +271,7 @@ kubernetes==31.0.0
|
||||
# via onyx
|
||||
litellm==1.81.6
|
||||
# via onyx
|
||||
mako==1.2.4
|
||||
mako==1.3.11
|
||||
# via alembic
|
||||
manygo==0.2.0
|
||||
markdown-it-py==4.0.0
|
||||
@@ -267,6 +320,8 @@ openapi-generator-cli==7.17.0
|
||||
packaging==24.2
|
||||
# via
|
||||
# black
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# hatchling
|
||||
# huggingface-hub
|
||||
# ipykernel
|
||||
@@ -307,6 +362,20 @@ propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
proto-plus==1.26.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
psutil==7.1.3
|
||||
# via ipykernel
|
||||
ptyprocess==0.7.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
|
||||
@@ -328,6 +397,7 @@ pydantic==2.11.7
|
||||
# agent-client-protocol
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# litellm
|
||||
# mcp
|
||||
@@ -364,6 +434,7 @@ python-dateutil==2.8.2
|
||||
# via
|
||||
# aiobotocore
|
||||
# botocore
|
||||
# google-cloud-bigquery
|
||||
# jupyter-client
|
||||
# kubernetes
|
||||
# matplotlib
|
||||
@@ -398,6 +469,9 @@ reorder-python-imports-black==3.14.0
|
||||
requests==2.33.0
|
||||
# via
|
||||
# cohere
|
||||
# google-api-core
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
@@ -498,7 +572,9 @@ typing-extensions==4.15.0
|
||||
# celery-types
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# mcp
|
||||
|
||||
@@ -87,6 +87,8 @@ discord-py==2.4.0
|
||||
# via onyx
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
docstring-parser==0.17.0
|
||||
# via google-cloud-aiplatform
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
fastapi==0.133.1
|
||||
@@ -103,12 +105,63 @@ frozenlist==1.8.0
|
||||
# aiosignal
|
||||
fsspec==2025.10.0
|
||||
# via huggingface-hub
|
||||
google-api-core==2.28.1
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
google-auth==2.48.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
google-cloud-aiplatform==1.133.0
|
||||
# via litellm
|
||||
google-cloud-bigquery==3.41.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-core==2.5.1
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
google-cloud-resource-manager==1.17.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-storage==3.10.1
|
||||
# via google-cloud-aiplatform
|
||||
google-crc32c==1.8.0
|
||||
# via
|
||||
# google-cloud-storage
|
||||
# google-resumable-media
|
||||
google-genai==1.52.0
|
||||
# via onyx
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# onyx
|
||||
google-resumable-media==2.8.2
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
googleapis-common-protos==1.72.0
|
||||
# via
|
||||
# google-api-core
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpc-google-iam-v1==0.14.4
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.80.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.80.0
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
@@ -184,7 +237,10 @@ openai==2.14.0
|
||||
# litellm
|
||||
# onyx
|
||||
packaging==24.2
|
||||
# via huggingface-hub
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# huggingface-hub
|
||||
parameterized==0.9.0
|
||||
# via cohere
|
||||
posthog==3.7.4
|
||||
@@ -198,6 +254,20 @@ propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
proto-plus==1.26.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
py==1.11.0
|
||||
# via retry
|
||||
pyasn1==0.6.3
|
||||
@@ -213,6 +283,7 @@ pydantic==2.11.7
|
||||
# agent-client-protocol
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# litellm
|
||||
# mcp
|
||||
@@ -231,6 +302,7 @@ python-dateutil==2.8.2
|
||||
# via
|
||||
# aiobotocore
|
||||
# botocore
|
||||
# google-cloud-bigquery
|
||||
# kubernetes
|
||||
# posthog
|
||||
python-dotenv==1.1.1
|
||||
@@ -254,6 +326,9 @@ regex==2025.11.3
|
||||
requests==2.33.0
|
||||
# via
|
||||
# cohere
|
||||
# google-api-core
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
# posthog
|
||||
@@ -318,7 +393,9 @@ typing-extensions==4.15.0
|
||||
# anyio
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# mcp
|
||||
# openai
|
||||
|
||||
@@ -102,6 +102,8 @@ discord-py==2.4.0
|
||||
# via onyx
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
docstring-parser==0.17.0
|
||||
# via google-cloud-aiplatform
|
||||
durationpy==0.10
|
||||
# via kubernetes
|
||||
einops==0.8.1
|
||||
@@ -125,12 +127,63 @@ fsspec==2025.10.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
google-api-core==2.28.1
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
google-auth==2.48.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-core
|
||||
# google-cloud-resource-manager
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
google-cloud-aiplatform==1.133.0
|
||||
# via litellm
|
||||
google-cloud-bigquery==3.41.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-core==2.5.1
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
google-cloud-resource-manager==1.17.0
|
||||
# via google-cloud-aiplatform
|
||||
google-cloud-storage==3.10.1
|
||||
# via google-cloud-aiplatform
|
||||
google-crc32c==1.8.0
|
||||
# via
|
||||
# google-cloud-storage
|
||||
# google-resumable-media
|
||||
google-genai==1.52.0
|
||||
# via onyx
|
||||
# via
|
||||
# google-cloud-aiplatform
|
||||
# onyx
|
||||
google-resumable-media==2.8.2
|
||||
# via
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
googleapis-common-protos==1.72.0
|
||||
# via
|
||||
# google-api-core
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpc-google-iam-v1==0.14.4
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.80.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.80.0
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
@@ -265,6 +318,8 @@ openai==2.14.0
|
||||
packaging==24.2
|
||||
# via
|
||||
# accelerate
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-bigquery
|
||||
# huggingface-hub
|
||||
# kombu
|
||||
# transformers
|
||||
@@ -282,6 +337,20 @@ propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
proto-plus==1.26.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
psutil==7.1.3
|
||||
# via accelerate
|
||||
py==1.11.0
|
||||
@@ -299,6 +368,7 @@ pydantic==2.11.7
|
||||
# agent-client-protocol
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# litellm
|
||||
# mcp
|
||||
@@ -318,6 +388,7 @@ python-dateutil==2.8.2
|
||||
# aiobotocore
|
||||
# botocore
|
||||
# celery
|
||||
# google-cloud-bigquery
|
||||
# kubernetes
|
||||
python-dotenv==1.1.1
|
||||
# via
|
||||
@@ -344,6 +415,9 @@ regex==2025.11.3
|
||||
requests==2.33.0
|
||||
# via
|
||||
# cohere
|
||||
# google-api-core
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
@@ -437,7 +511,9 @@ typing-extensions==4.15.0
|
||||
# anyio
|
||||
# cohere
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# mcp
|
||||
# openai
|
||||
|
||||
@@ -46,7 +46,7 @@ stop_and_remove_containers
|
||||
# Start the PostgreSQL container with optional volume
|
||||
echo "Starting PostgreSQL container..."
|
||||
if [[ -n "$POSTGRES_VOLUME" ]]; then
|
||||
docker run -p 5432:5432 --name onyx_postgres -e POSTGRES_PASSWORD=password -d -v $POSTGRES_VOLUME:/var/lib/postgresql/data postgres -c max_connections=250
|
||||
docker run -p 5432:5432 --name onyx_postgres -e POSTGRES_PASSWORD=password -d -v "$POSTGRES_VOLUME":/var/lib/postgresql/data postgres -c max_connections=250
|
||||
else
|
||||
docker run -p 5432:5432 --name onyx_postgres -e POSTGRES_PASSWORD=password -d postgres -c max_connections=250
|
||||
fi
|
||||
@@ -54,7 +54,7 @@ fi
|
||||
# Start the Vespa container with optional volume
|
||||
echo "Starting Vespa container..."
|
||||
if [[ -n "$VESPA_VOLUME" ]]; then
|
||||
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 -v $VESPA_VOLUME:/opt/vespa/var vespaengine/vespa:8
|
||||
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 -v "$VESPA_VOLUME":/opt/vespa/var vespaengine/vespa:8
|
||||
else
|
||||
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 vespaengine/vespa:8
|
||||
fi
|
||||
@@ -85,7 +85,7 @@ docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-en
|
||||
# Start the Redis container with optional volume
|
||||
echo "Starting Redis container..."
|
||||
if [[ -n "$REDIS_VOLUME" ]]; then
|
||||
docker run --detach --name onyx_redis --publish 6379:6379 -v $REDIS_VOLUME:/data redis
|
||||
docker run --detach --name onyx_redis --publish 6379:6379 -v "$REDIS_VOLUME":/data redis
|
||||
else
|
||||
docker run --detach --name onyx_redis --publish 6379:6379 redis
|
||||
fi
|
||||
@@ -93,7 +93,7 @@ fi
|
||||
# Start the MinIO container with optional volume
|
||||
echo "Starting MinIO container..."
|
||||
if [[ -n "$MINIO_VOLUME" ]]; then
|
||||
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin -v $MINIO_VOLUME:/data minio/minio server /data --console-address ":9001"
|
||||
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin -v "$MINIO_VOLUME":/data minio/minio server /data --console-address ":9001"
|
||||
else
|
||||
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin minio/minio server /data --console-address ":9001"
|
||||
fi
|
||||
@@ -111,6 +111,7 @@ sleep 1
|
||||
|
||||
# Alembic should be configured in the virtualenv for this repo
|
||||
if [[ -f "../.venv/bin/activate" ]]; then
|
||||
# shellcheck source=/dev/null
|
||||
source ../.venv/bin/activate
|
||||
else
|
||||
echo "Warning: Python virtual environment not found at .venv/bin/activate; alembic may not work."
|
||||
|
||||
@@ -9,8 +9,10 @@ import pytest

from onyx.configs.constants import BlobType
from onyx.connectors.blob.connector import BlobStorageConnector
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import TabularSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
@@ -111,15 +113,18 @@ def test_blob_s3_connector(

for doc in all_docs:
section = doc.sections[0]
assert isinstance(section, TextSection)

file_extension = get_file_ext(doc.semantic_identifier)
if file_extension in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS:
if is_tabular_file(doc.semantic_identifier):
assert isinstance(section, TabularSection)
assert len(section.text) > 0
continue

# unknown extension
assert len(section.text) == 0
assert isinstance(section, TextSection)
file_extension = get_file_ext(doc.semantic_identifier)
if file_extension in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS:
assert len(section.text) > 0
else:
assert len(section.text) == 0

@patch(

@@ -7,7 +7,6 @@ import pytest

from onyx.connectors.gong.connector import GongConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode

@pytest.fixture
@@ -32,18 +31,20 @@ def test_gong_basic(
mock_get_api_key: MagicMock, # noqa: ARG001
gong_connector: GongConnector,
) -> None:
doc_batch_generator = gong_connector.poll_source(0, time.time())

doc_batch = next(doc_batch_generator)
with pytest.raises(StopIteration):
next(doc_batch_generator)

assert len(doc_batch) == 2
checkpoint = gong_connector.build_dummy_checkpoint()

docs: list[Document] = []
for doc in doc_batch:
if not isinstance(doc, HierarchyNode):
docs.append(doc)
while checkpoint.has_more:
generator = gong_connector.load_from_checkpoint(0, time.time(), checkpoint)
try:
while True:
item = next(generator)
if isinstance(item, Document):
docs.append(item)
except StopIteration as e:
checkpoint = e.value

assert len(docs) == 2

assert docs[0].semantic_identifier == "test with chris"
assert docs[1].semantic_identifier == "Testing Gong"
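A compact sketch of the drain-until-StopIteration pattern the updated Gong test uses: a checkpointed generator yields documents and returns the next checkpoint, which surfaces as `StopIteration.value`. The generator below is a toy, not the Gong connector.

# Toy version of the checkpoint-draining loop shown in the test above.
from dataclasses import dataclass
from typing import Generator


@dataclass
class Checkpoint:
    has_more: bool


def fake_load_from_checkpoint(cp: Checkpoint) -> Generator[str, None, Checkpoint]:
    yield "doc-1"
    yield "doc-2"
    return Checkpoint(has_more=False)  # the return value becomes StopIteration.value


docs: list[str] = []
checkpoint = Checkpoint(has_more=True)
while checkpoint.has_more:
    gen = fake_load_from_checkpoint(checkpoint)
    try:
        while True:
            docs.append(next(gen))
    except StopIteration as e:
        checkpoint = e.value  # pick up the next checkpoint here

assert docs == ["doc-1", "doc-2"]
assert checkpoint.has_more is False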
@@ -0,0 +1,210 @@
|
||||
"""Tests for `cloud_beat_task_generator`'s tenant work-gating logic.
|
||||
|
||||
Exercises the gate-read path end-to-end against real Redis. The Celery
|
||||
`.app.send_task` is mocked so we can count dispatches without actually
|
||||
sending messages.
|
||||
|
||||
Requires a running Redis instance. Run with::
|
||||
|
||||
python -m dotenv -f .vscode/.env run -- pytest \
|
||||
backend/tests/external_dependency_unit/tenant_work_gating/test_gate_generator.py
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ee.onyx.background.celery.tasks.cloud import tasks as cloud_tasks
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.redis import redis_tenant_work_gating as twg
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_tenant_work_gating import _SET_KEY
|
||||
from onyx.redis.redis_tenant_work_gating import mark_tenant_active
|
||||
|
||||
|
||||
_TENANT_A = "tenant_aaaa0000-0000-0000-0000-000000000001"
|
||||
_TENANT_B = "tenant_bbbb0000-0000-0000-0000-000000000002"
|
||||
_TENANT_C = "tenant_cccc0000-0000-0000-0000-000000000003"
|
||||
_ALL_TEST_TENANTS = [_TENANT_A, _TENANT_B, _TENANT_C]
|
||||
_FANOUT_KEY_PREFIX = cloud_tasks._FULL_FANOUT_TIMESTAMP_KEY_PREFIX
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _multi_tenant_true() -> Generator[None, None, None]:
|
||||
with patch.object(twg, "MULTI_TENANT", True):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_redis() -> Generator[None, None, None]:
|
||||
"""Clear the active set AND the per-task full-fanout timestamp so each
|
||||
test starts fresh."""
|
||||
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
r.delete(_SET_KEY)
|
||||
r.delete(f"{_FANOUT_KEY_PREFIX}:test_task")
|
||||
r.delete("runtime:tenant_work_gating:enabled")
|
||||
r.delete("runtime:tenant_work_gating:enforce")
|
||||
yield
|
||||
r.delete(_SET_KEY)
|
||||
r.delete(f"{_FANOUT_KEY_PREFIX}:test_task")
|
||||
r.delete("runtime:tenant_work_gating:enabled")
|
||||
r.delete("runtime:tenant_work_gating:enforce")
|
||||
|
||||
|
||||
def _invoke_generator(
|
||||
*,
|
||||
work_gated: bool,
|
||||
enabled: bool,
|
||||
enforce: bool,
|
||||
tenant_ids: list[str],
|
||||
full_fanout_interval_seconds: int = 1200,
|
||||
ttl_seconds: int = 1800,
|
||||
) -> MagicMock:
|
||||
"""Helper: call the generator with runtime flags fixed and the Celery
|
||||
app mocked. Returns the mock so callers can assert on send_task calls."""
|
||||
mock_app = MagicMock()
|
||||
# The task binds `self` = the task itself when invoked via `.run()`;
|
||||
# patch its `.app` so `self.app.send_task` routes to our mock.
|
||||
with (
|
||||
patch.object(cloud_tasks.cloud_beat_task_generator, "app", mock_app),
|
||||
patch.object(cloud_tasks, "get_all_tenant_ids", return_value=list(tenant_ids)),
|
||||
patch.object(cloud_tasks, "get_gated_tenants", return_value=set()),
|
||||
patch(
|
||||
"onyx.server.runtime.onyx_runtime.OnyxRuntime.get_tenant_work_gating_enabled",
|
||||
return_value=enabled,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.runtime.onyx_runtime.OnyxRuntime.get_tenant_work_gating_enforce",
|
||||
return_value=enforce,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.runtime.onyx_runtime.OnyxRuntime.get_tenant_work_gating_full_fanout_interval_seconds",
|
||||
return_value=full_fanout_interval_seconds,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.runtime.onyx_runtime.OnyxRuntime.get_tenant_work_gating_ttl_seconds",
|
||||
return_value=ttl_seconds,
|
||||
),
|
||||
):
|
||||
cloud_tasks.cloud_beat_task_generator.run(
|
||||
task_name="test_task",
|
||||
work_gated=work_gated,
|
||||
)
|
||||
return mock_app
|
||||
|
||||
|
||||
def _dispatched_tenants(mock_app: MagicMock) -> list[str]:
|
||||
"""Pull tenant_ids out of each send_task call for assertion."""
|
||||
return [c.kwargs["kwargs"]["tenant_id"] for c in mock_app.send_task.call_args_list]
|
||||
|
||||
|
||||
def _seed_recent_full_fanout_timestamp() -> None:
|
||||
"""Pre-seed the per-task timestamp so the interval-elapsed branch
|
||||
reports False, i.e. the gate enforces normally instead of going into
|
||||
full-fanout on first invocation."""
|
||||
import time as _t
|
||||
|
||||
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
r.set(f"{_FANOUT_KEY_PREFIX}:test_task", str(int(_t.time() * 1000)))
|
||||
|
||||
|
||||
def test_enforce_skips_unmarked_tenants() -> None:
|
||||
"""With enable+enforce on (interval NOT elapsed), only tenants in the
|
||||
active set get dispatched."""
|
||||
mark_tenant_active(_TENANT_A)
|
||||
_seed_recent_full_fanout_timestamp()
|
||||
|
||||
mock_app = _invoke_generator(
|
||||
work_gated=True,
|
||||
enabled=True,
|
||||
enforce=True,
|
||||
tenant_ids=_ALL_TEST_TENANTS,
|
||||
full_fanout_interval_seconds=3600,
|
||||
)
|
||||
|
||||
dispatched = _dispatched_tenants(mock_app)
|
||||
assert dispatched == [_TENANT_A]
|
||||
|
||||
|
||||
def test_shadow_mode_dispatches_all_tenants() -> None:
|
||||
"""enabled=True, enforce=False: gate computes skip but still dispatches."""
|
||||
mark_tenant_active(_TENANT_A)
|
||||
_seed_recent_full_fanout_timestamp()
|
||||
|
||||
mock_app = _invoke_generator(
|
||||
work_gated=True,
|
||||
enabled=True,
|
||||
enforce=False,
|
||||
tenant_ids=_ALL_TEST_TENANTS,
|
||||
full_fanout_interval_seconds=3600,
|
||||
)
|
||||
|
||||
dispatched = _dispatched_tenants(mock_app)
|
||||
assert set(dispatched) == set(_ALL_TEST_TENANTS)
|
||||
|
||||
|
||||
def test_full_fanout_cycle_dispatches_all_tenants() -> None:
|
||||
"""First invocation (no prior timestamp → interval considered elapsed)
|
||||
counts as full-fanout; every tenant gets dispatched even under enforce."""
|
||||
mark_tenant_active(_TENANT_A)
|
||||
|
||||
mock_app = _invoke_generator(
|
||||
work_gated=True,
|
||||
enabled=True,
|
||||
enforce=True,
|
||||
tenant_ids=_ALL_TEST_TENANTS,
|
||||
)
|
||||
|
||||
dispatched = _dispatched_tenants(mock_app)
|
||||
assert set(dispatched) == set(_ALL_TEST_TENANTS)
|
||||
|
||||
|
||||
def test_redis_unavailable_fails_open() -> None:
|
||||
"""When `get_active_tenants` returns None (simulated Redis outage) the
|
||||
gate treats the invocation as full-fanout and dispatches everyone —
|
||||
even when the interval hasn't elapsed and enforce is on."""
|
||||
mark_tenant_active(_TENANT_A)
|
||||
_seed_recent_full_fanout_timestamp()
|
||||
|
||||
with patch.object(cloud_tasks, "get_active_tenants", return_value=None):
|
||||
mock_app = _invoke_generator(
|
||||
work_gated=True,
|
||||
enabled=True,
|
||||
enforce=True,
|
||||
tenant_ids=_ALL_TEST_TENANTS,
|
||||
full_fanout_interval_seconds=3600,
|
||||
)
|
||||
|
||||
dispatched = _dispatched_tenants(mock_app)
|
||||
assert set(dispatched) == set(_ALL_TEST_TENANTS)
|
||||
|
||||
|
||||
def test_work_gated_false_bypasses_gate_entirely() -> None:
|
||||
"""Beat templates that don't opt in (`work_gated=False`) never consult
|
||||
the set — no matter the flag state."""
|
||||
# Even with enforce on and nothing in the set, all tenants dispatch.
|
||||
mock_app = _invoke_generator(
|
||||
work_gated=False,
|
||||
enabled=True,
|
||||
enforce=True,
|
||||
tenant_ids=_ALL_TEST_TENANTS,
|
||||
)
|
||||
|
||||
dispatched = _dispatched_tenants(mock_app)
|
||||
assert set(dispatched) == set(_ALL_TEST_TENANTS)
|
||||
|
||||
|
||||
def test_gate_disabled_dispatches_everyone_regardless_of_enforce() -> None:
|
||||
"""enabled=False means the gate isn't computed — dispatch is unchanged."""
|
||||
# Intentionally don't add anyone to the set.
|
||||
mock_app = _invoke_generator(
|
||||
work_gated=True,
|
||||
enabled=False,
|
||||
enforce=True,
|
||||
tenant_ids=_ALL_TEST_TENANTS,
|
||||
)
|
||||
|
||||
dispatched = _dispatched_tenants(mock_app)
|
||||
assert set(dispatched) == set(_ALL_TEST_TENANTS)
|
||||
@@ -16,12 +16,14 @@ from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.types import CallToolResult
|
||||
from mcp.types import TextContent
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from onyx.db.enums import AccessType
|
||||
from tests.integration.common_utils.constants import MCP_SERVER_URL
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.document_set import DocumentSetManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.pat import PATManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
@@ -34,6 +36,7 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
# Constants
|
||||
MCP_SEARCH_TOOL = "search_indexed_documents"
|
||||
INDEXED_SOURCES_RESOURCE_URI = "resource://indexed_sources"
|
||||
DOCUMENT_SETS_RESOURCE_URI = "resource://document_sets"
|
||||
DEFAULT_SEARCH_LIMIT = 5
|
||||
STREAMABLE_HTTP_URL = f"{MCP_SERVER_URL.rstrip('/')}/?transportType=streamable-http"
|
||||
|
||||
@@ -73,19 +76,22 @@ def _extract_tool_payload(result: CallToolResult) -> dict[str, Any]:
|
||||
|
||||
|
||||
def _call_search_tool(
|
||||
headers: dict[str, str], query: str, limit: int = DEFAULT_SEARCH_LIMIT
|
||||
headers: dict[str, str],
|
||||
query: str,
|
||||
limit: int = DEFAULT_SEARCH_LIMIT,
|
||||
document_set_names: list[str] | None = None,
|
||||
) -> CallToolResult:
|
||||
"""Call the search_indexed_documents tool via MCP."""
|
||||
|
||||
async def _action(session: ClientSession) -> CallToolResult:
|
||||
await session.initialize()
|
||||
return await session.call_tool(
|
||||
MCP_SEARCH_TOOL,
|
||||
{
|
||||
"query": query,
|
||||
"limit": limit,
|
||||
},
|
||||
)
|
||||
arguments: dict[str, Any] = {
|
||||
"query": query,
|
||||
"limit": limit,
|
||||
}
|
||||
if document_set_names is not None:
|
||||
arguments["document_set_names"] = document_set_names
|
||||
return await session.call_tool(MCP_SEARCH_TOOL, arguments)
|
||||
|
||||
return _run_with_mcp_session(headers, _action)
|
||||
|
||||
@@ -238,3 +244,106 @@ def test_mcp_search_respects_acl_filters(
|
||||
blocked_payload = _extract_tool_payload(blocked_result)
|
||||
assert blocked_payload["total_results"] == 0
|
||||
assert blocked_payload["documents"] == []
|
||||
|
||||
|
||||
def test_mcp_search_filters_by_document_set(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Passing document_set_names should scope results to the named set."""
|
||||
LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
api_key = APIKeyManager.create(user_performing_action=admin_user)
|
||||
cc_pair_in_set = CCPairManager.create_from_scratch(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
cc_pair_out_of_set = CCPairManager.create_from_scratch(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
shared_phrase = "document-set-filter-shared-phrase"
|
||||
in_set_content = f"{shared_phrase} inside curated set"
|
||||
out_of_set_content = f"{shared_phrase} outside curated set"
|
||||
|
||||
_seed_document_and_wait_for_indexing(
|
||||
cc_pair=cc_pair_in_set,
|
||||
content=in_set_content,
|
||||
api_key=api_key,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
_seed_document_and_wait_for_indexing(
|
||||
cc_pair=cc_pair_out_of_set,
|
||||
content=out_of_set_content,
|
||||
api_key=api_key,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
doc_set = DocumentSetManager.create(
|
||||
cc_pair_ids=[cc_pair_in_set.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
DocumentSetManager.wait_for_sync(
|
||||
user_performing_action=admin_user,
|
||||
document_sets_to_check=[doc_set],
|
||||
)
|
||||
|
||||
headers = _auth_headers(admin_user, name="mcp-doc-set-filter")
|
||||
|
||||
# The document_sets resource should surface the newly created set so MCP
|
||||
# clients can discover which values to pass to document_set_names.
|
||||
async def _list_resources(session: ClientSession) -> Any:
|
||||
await session.initialize()
|
||||
resources = await session.list_resources()
|
||||
contents = await session.read_resource(AnyUrl(DOCUMENT_SETS_RESOURCE_URI))
|
||||
return resources, contents
|
||||
|
||||
resources_result, doc_sets_contents = _run_with_mcp_session(
|
||||
headers, _list_resources
|
||||
)
|
||||
resource_uris = {str(resource.uri) for resource in resources_result.resources}
|
||||
assert DOCUMENT_SETS_RESOURCE_URI in resource_uris
|
||||
doc_sets_payload = json.loads(doc_sets_contents.contents[0].text)
|
||||
exposed_names = {entry["name"] for entry in doc_sets_payload}
|
||||
assert doc_set.name in exposed_names
|
||||
|
||||
# Without the filter both documents are visible.
|
||||
unfiltered_payload = _extract_tool_payload(
|
||||
_call_search_tool(headers, shared_phrase, limit=10)
|
||||
)
|
||||
unfiltered_contents = [
|
||||
doc.get("content") or "" for doc in unfiltered_payload["documents"]
|
||||
]
|
||||
assert any(in_set_content in content for content in unfiltered_contents)
|
||||
assert any(out_of_set_content in content for content in unfiltered_contents)
|
||||
|
||||
# With the document set filter only the in-set document is returned.
|
||||
filtered_payload = _extract_tool_payload(
|
||||
_call_search_tool(
|
||||
headers,
|
||||
shared_phrase,
|
||||
limit=10,
|
||||
document_set_names=[doc_set.name],
|
||||
)
|
||||
)
|
||||
filtered_contents = [
|
||||
doc.get("content") or "" for doc in filtered_payload["documents"]
|
||||
]
|
||||
assert filtered_payload["total_results"] >= 1
|
||||
assert any(in_set_content in content for content in filtered_contents)
|
||||
assert all(out_of_set_content not in content for content in filtered_contents)
|
||||
|
||||
# An empty document_set_names should behave like "no filter" (normalized
|
||||
# to None), not "match zero sets".
|
||||
empty_list_payload = _extract_tool_payload(
|
||||
_call_search_tool(
|
||||
headers,
|
||||
shared_phrase,
|
||||
limit=10,
|
||||
document_set_names=[],
|
||||
)
|
||||
)
|
||||
empty_list_contents = [
|
||||
doc.get("content") or "" for doc in empty_list_payload["documents"]
|
||||
]
|
||||
assert any(in_set_content in content for content in empty_list_contents)
|
||||
assert any(out_of_set_content in content for content in empty_list_contents)
|
||||
|
||||
@@ -0,0 +1,53 @@
import datetime

import pytest

from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc

def test_time_str_to_utc() -> None:
str_to_dt = {
"Tue, 5 Oct 2021 09:38:25 GMT": datetime.datetime(
2021, 10, 5, 9, 38, 25, tzinfo=datetime.timezone.utc
),
"Sat, 24 Jul 2021 09:21:20 +0000 (UTC)": datetime.datetime(
2021, 7, 24, 9, 21, 20, tzinfo=datetime.timezone.utc
),
"Thu, 29 Jul 2021 04:20:37 -0400 (EDT)": datetime.datetime(
2021, 7, 29, 8, 20, 37, tzinfo=datetime.timezone.utc
),
"30 Jun 2023 18:45:01 +0300": datetime.datetime(
2023, 6, 30, 15, 45, 1, tzinfo=datetime.timezone.utc
),
"22 Mar 2020 20:12:18 +0000 (GMT)": datetime.datetime(
2020, 3, 22, 20, 12, 18, tzinfo=datetime.timezone.utc
),
"Date: Wed, 27 Aug 2025 11:40:00 +0200": datetime.datetime(
2025, 8, 27, 9, 40, 0, tzinfo=datetime.timezone.utc
),
}
for strptime, expected_datetime in str_to_dt.items():
assert time_str_to_utc(strptime) == expected_datetime

def test_time_str_to_utc_recovers_from_concatenated_headers() -> None:
# TZ is dropped during recovery, so the expected result is UTC rather
# than the original offset.
assert time_str_to_utc(
'Sat, 3 Nov 2007 14:33:28 -0200To: "jason" <jason@example.net>'
) == datetime.datetime(2007, 11, 3, 14, 33, 28, tzinfo=datetime.timezone.utc)

assert time_str_to_utc(
"Fri, 20 Feb 2015 10:30:00 +0500Cc: someone@example.com"
) == datetime.datetime(2015, 2, 20, 10, 30, 0, tzinfo=datetime.timezone.utc)

def test_time_str_to_utc_raises_on_impossible_dates() -> None:
for bad in (
"Wed, 33 Sep 2007 13:42:59 +0100",
"Thu, 11 Oct 2007 31:50:55 +0900",
"not a date at all",
"",
):
with pytest.raises(ValueError):
time_str_to_utc(bad)
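For context on the fixtures above, a stdlib-only sketch of the kind of conversion they exercise; this is one plausible approach using email.utils and is not the actual time_str_to_utc implementation.

# Assumption: the fixtures are RFC 2822 style date strings; this local helper
# shows one way such a string can be normalized to an aware UTC datetime.
import datetime
from email.utils import parsedate_to_datetime


def to_utc(value: str) -> datetime.datetime:
    parsed = parsedate_to_datetime(value)
    if parsed.tzinfo is None:
        # Naive results (e.g. a "-0000" offset) are treated as already UTC.
        parsed = parsed.replace(tzinfo=datetime.timezone.utc)
    return parsed.astimezone(datetime.timezone.utc)


assert to_utc("Thu, 29 Jul 2021 04:20:37 -0400") == datetime.datetime(
    2021, 7, 29, 8, 20, 37, tzinfo=datetime.timezone.utc
)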
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
@@ -8,7 +9,6 @@ from unittest.mock import patch
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.connectors.gmail.connector import _build_time_range_query
|
||||
from onyx.connectors.gmail.connector import GmailCheckpoint
|
||||
from onyx.connectors.gmail.connector import GmailConnector
|
||||
@@ -51,29 +51,43 @@ def test_build_time_range_query() -> None:
|
||||
assert query is None
|
||||
|
||||
|
||||
def test_time_str_to_utc() -> None:
|
||||
str_to_dt = {
|
||||
"Tue, 5 Oct 2021 09:38:25 GMT": datetime.datetime(
|
||||
2021, 10, 5, 9, 38, 25, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
"Sat, 24 Jul 2021 09:21:20 +0000 (UTC)": datetime.datetime(
|
||||
2021, 7, 24, 9, 21, 20, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
"Thu, 29 Jul 2021 04:20:37 -0400 (EDT)": datetime.datetime(
|
||||
2021, 7, 29, 8, 20, 37, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
"30 Jun 2023 18:45:01 +0300": datetime.datetime(
|
||||
2023, 6, 30, 15, 45, 1, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
"22 Mar 2020 20:12:18 +0000 (GMT)": datetime.datetime(
|
||||
2020, 3, 22, 20, 12, 18, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
"Date: Wed, 27 Aug 2025 11:40:00 +0200": datetime.datetime(
|
||||
2025, 8, 27, 9, 40, 0, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
}
|
||||
for strptime, expected_datetime in str_to_dt.items():
|
||||
assert time_str_to_utc(strptime) == expected_datetime
|
||||
def _thread_with_date(date_header: str | None) -> dict[str, Any]:
|
||||
"""Load the fixture thread and replace (or strip, if None) its Date header."""
|
||||
json_path = os.path.join(os.path.dirname(__file__), "thread.json")
|
||||
with open(json_path, "r") as f:
|
||||
thread = cast(dict[str, Any], json.load(f))
|
||||
thread = copy.deepcopy(thread)
|
||||
|
||||
for message in thread["messages"]:
|
||||
headers: list[dict[str, str]] = message["payload"]["headers"]
|
||||
if date_header is None:
|
||||
message["payload"]["headers"] = [
|
||||
h for h in headers if h.get("name") != "Date"
|
||||
]
|
||||
continue
|
||||
|
||||
replaced = False
|
||||
for header in headers:
|
||||
if header.get("name") == "Date":
|
||||
header["value"] = date_header
|
||||
replaced = True
|
||||
break
|
||||
if not replaced:
|
||||
headers.append({"name": "Date", "value": date_header})
|
||||
|
||||
return thread
|
||||
|
||||
|
||||
def test_thread_to_document_skips_unparseable_dates() -> None:
|
||||
for bad_date in (
|
||||
"Wed, 33 Sep 2007 13:42:59 +0100",
|
||||
"Thu, 11 Oct 2007 31:50:55 +0900",
|
||||
"total garbage not even close to a date",
|
||||
):
|
||||
doc = thread_to_document(_thread_with_date(bad_date), "admin@example.com")
|
||||
assert isinstance(doc, Document), f"failed for {bad_date!r}"
|
||||
assert doc.doc_updated_at is None
|
||||
assert doc.id == "192edefb315737c3"
|
||||
|
||||
|
||||
def test_gmail_checkpoint_progression() -> None:
|
||||
|
||||
0
backend/tests/unit/onyx/connectors/gong/__init__.py
Normal file
0
backend/tests/unit/onyx/connectors/gong/__init__.py
Normal file
483
backend/tests/unit/onyx/connectors/gong/test_gong_checkpoint.py
Normal file
483
backend/tests/unit/onyx/connectors/gong/test_gong_checkpoint.py
Normal file
@@ -0,0 +1,483 @@
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.gong.connector import GongConnector
|
||||
from onyx.connectors.gong.connector import GongConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
|
||||
|
||||
def _make_transcript(call_id: str) -> dict[str, Any]:
|
||||
return {
|
||||
"callId": call_id,
|
||||
"transcript": [
|
||||
{
|
||||
"speakerId": "speaker1",
|
||||
"sentences": [{"text": "Hello world"}],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _make_call_detail(call_id: str, title: str) -> dict[str, Any]:
|
||||
return {
|
||||
"metaData": {
|
||||
"id": call_id,
|
||||
"started": "2026-01-15T10:00:00Z",
|
||||
"title": title,
|
||||
"purpose": "Test call",
|
||||
"url": f"https://app.gong.io/call?id={call_id}",
|
||||
"system": "test-system",
|
||||
},
|
||||
"parties": [
|
||||
{
|
||||
"speakerId": "speaker1",
|
||||
"name": "Alice",
|
||||
"emailAddress": "alice@test.com",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connector() -> GongConnector:
|
||||
connector = GongConnector()
|
||||
connector.load_credentials(
|
||||
{
|
||||
"gong_access_key": "test-key",
|
||||
"gong_access_key_secret": "test-secret",
|
||||
}
|
||||
)
|
||||
return connector
|
||||
|
||||
|
||||
class TestGongConnectorCheckpoint:
|
||||
def test_build_dummy_checkpoint(self, connector: GongConnector) -> None:
|
||||
checkpoint = connector.build_dummy_checkpoint()
|
||||
assert checkpoint.has_more is True
|
||||
assert checkpoint.workspace_ids is None
|
||||
assert checkpoint.workspace_index == 0
|
||||
assert checkpoint.cursor is None
|
||||
|
||||
def test_validate_checkpoint_json(self, connector: GongConnector) -> None:
|
||||
original = GongConnectorCheckpoint(
|
||||
has_more=True,
|
||||
workspace_ids=["ws1", None],
|
||||
workspace_index=1,
|
||||
cursor="abc123",
|
||||
)
|
||||
json_str = original.model_dump_json()
|
||||
restored = connector.validate_checkpoint_json(json_str)
|
||||
assert restored == original
|
||||
|
||||
@patch.object(GongConnector, "_throttled_request")
|
||||
def test_first_call_resolves_workspaces(
|
||||
self,
|
||||
mock_request: MagicMock,
|
||||
connector: GongConnector,
|
||||
) -> None:
|
||||
"""First checkpoint call should resolve workspaces and return without fetching."""
|
||||
# No workspaces configured — should resolve to [None]
|
||||
checkpoint = connector.build_dummy_checkpoint()
|
||||
generator = connector.load_from_checkpoint(0, time.time(), checkpoint)
|
||||
|
||||
# Should return immediately (no yields)
|
||||
with pytest.raises(StopIteration) as exc_info:
|
||||
next(generator)
|
||||
|
||||
new_checkpoint = exc_info.value.value
|
||||
assert new_checkpoint.workspace_ids == [None]
|
||||
assert new_checkpoint.has_more is True
|
||||
assert new_checkpoint.workspace_index == 0
|
||||
|
||||
# No API calls should have been made for workspace resolution
|
||||
# when no workspaces are configured
|
||||
mock_request.assert_not_called()
|
||||
|
||||
@patch.object(GongConnector, "_throttled_request")
|
||||
def test_single_page_no_cursor(
|
||||
self,
|
||||
mock_request: MagicMock,
|
||||
connector: GongConnector,
|
||||
) -> None:
|
||||
"""Single page of transcripts with no pagination cursor."""
|
||||
transcript_response = MagicMock()
|
||||
transcript_response.status_code = 200
|
||||
transcript_response.json.return_value = {
|
||||
"callTranscripts": [_make_transcript("call1")],
|
||||
"records": {},
|
||||
}
|
||||
|
||||
details_response = MagicMock()
|
||||
details_response.status_code = 200
|
||||
details_response.json.return_value = {
|
||||
"calls": [_make_call_detail("call1", "Test Call")]
|
||||
}
|
||||
|
||||
mock_request.side_effect = [transcript_response, details_response]
|
||||
|
||||
# Start from a checkpoint that already has workspaces resolved
|
||||
checkpoint = GongConnectorCheckpoint(
|
||||
has_more=True,
|
||||
workspace_ids=[None],
|
||||
workspace_index=0,
|
||||
)
|
||||
|
||||
docs: list[Document] = []
|
||||
failures: list[ConnectorFailure] = []
|
||||
generator = connector.load_from_checkpoint(0, time.time(), checkpoint)
|
||||
try:
|
||||
while True:
|
||||
item = next(generator)
|
||||
if isinstance(item, Document):
|
||||
docs.append(item)
|
||||
elif isinstance(item, ConnectorFailure):
|
||||
failures.append(item)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value
|
||||
|
||||
assert len(docs) == 1
|
||||
assert docs[0].semantic_identifier == "Test Call"
|
||||
assert len(failures) == 0
|
||||
assert checkpoint.has_more is False
|
||||
assert checkpoint.workspace_index == 1
|
||||
|
||||
@patch.object(GongConnector, "_throttled_request")
|
||||
def test_multi_page_with_cursor(
|
||||
self,
|
||||
mock_request: MagicMock,
|
||||
connector: GongConnector,
|
||||
) -> None:
|
||||
"""Two pages of transcripts — cursor advances between checkpoint calls."""
|
||||
# Page 1: returns cursor
|
||||
page1_response = MagicMock()
|
||||
page1_response.status_code = 200
|
||||
page1_response.json.return_value = {
|
||||
"callTranscripts": [_make_transcript("call1")],
|
||||
"records": {"cursor": "page2cursor"},
|
||||
}
|
||||
|
||||
details1_response = MagicMock()
|
||||
details1_response.status_code = 200
|
||||
details1_response.json.return_value = {
|
||||
"calls": [_make_call_detail("call1", "Call One")]
|
||||
}
|
||||
|
||||
# Page 2: no cursor (done)
|
||||
page2_response = MagicMock()
|
||||
page2_response.status_code = 200
|
||||
page2_response.json.return_value = {
|
||||
"callTranscripts": [_make_transcript("call2")],
|
||||
"records": {},
|
||||
}
|
||||
|
||||
details2_response = MagicMock()
|
||||
details2_response.status_code = 200
|
||||
details2_response.json.return_value = {
|
||||
"calls": [_make_call_detail("call2", "Call Two")]
|
||||
}
|
||||
|
||||
mock_request.side_effect = [
|
||||
page1_response,
|
||||
details1_response,
|
||||
page2_response,
|
||||
details2_response,
|
||||
]
|
||||
|
||||
checkpoint = GongConnectorCheckpoint(
|
||||
has_more=True,
|
||||
workspace_ids=[None],
|
||||
workspace_index=0,
|
||||
)
|
||||
|
||||
all_docs: list[Document] = []
|
||||
|
||||
# First checkpoint call — page 1
|
||||
generator = connector.load_from_checkpoint(0, time.time(), checkpoint)
|
||||
try:
|
||||
while True:
|
||||
item = next(generator)
|
||||
if isinstance(item, Document):
|
||||
all_docs.append(item)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value
|
||||
|
||||
assert len(all_docs) == 1
|
||||
assert checkpoint.cursor == "page2cursor"
|
||||
assert checkpoint.has_more is True
|
||||
|
||||
# Second checkpoint call — page 2
|
||||
generator = connector.load_from_checkpoint(0, time.time(), checkpoint)
|
||||
try:
|
||||
while True:
|
||||
item = next(generator)
|
||||
if isinstance(item, Document):
|
||||
all_docs.append(item)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value
|
||||
|
||||
assert len(all_docs) == 2
|
||||
assert all_docs[0].semantic_identifier == "Call One"
|
||||
assert all_docs[1].semantic_identifier == "Call Two"
|
||||
assert checkpoint.has_more is False
|
||||
|
||||
@patch.object(GongConnector, "_throttled_request")
|
||||
def test_missing_call_details_yields_failure(
|
||||
self,
|
||||
mock_request: MagicMock,
|
||||
connector: GongConnector,
|
||||
) -> None:
|
||||
"""When call details are missing after retries, yield ConnectorFailure."""
|
||||
transcript_response = MagicMock()
|
||||
transcript_response.status_code = 200
|
||||
transcript_response.json.return_value = {
|
||||
"callTranscripts": [_make_transcript("call1")],
|
||||
"records": {},
|
||||
}
|
||||
|
||||
# Return empty call details every time (simulating the race condition)
|
||||
empty_details = MagicMock()
|
||||
empty_details.status_code = 200
|
||||
empty_details.json.return_value = {"calls": []}
|
||||
|
||||
mock_request.side_effect = [transcript_response] + [
|
||||
empty_details
|
||||
] * GongConnector.MAX_CALL_DETAILS_ATTEMPTS
|
||||
|
||||
checkpoint = GongConnectorCheckpoint(
|
||||
has_more=True,
|
||||
workspace_ids=[None],
|
||||
workspace_index=0,
|
||||
)
|
||||
|
||||
failures: list[ConnectorFailure] = []
|
||||
docs: list[Document] = []
|
||||
|
||||
with patch("onyx.connectors.gong.connector.time.sleep"):
|
||||
generator = connector.load_from_checkpoint(0, time.time(), checkpoint)
|
||||
try:
|
||||
while True:
|
||||
item = next(generator)
|
||||
if isinstance(item, ConnectorFailure):
|
||||
failures.append(item)
|
||||
elif isinstance(item, Document):
|
||||
docs.append(item)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value
|
||||
|
||||
assert len(docs) == 0
|
||||
assert len(failures) == 1
|
||||
assert failures[0].failed_document is not None
|
||||
assert failures[0].failed_document.document_id == "call1"
|
||||
assert checkpoint.has_more is False
|
||||
|
||||
@patch.object(GongConnector, "_throttled_request")
|
||||
def test_multi_workspace_iteration(
|
||||
self,
|
||||
mock_request: MagicMock,
|
||||
connector: GongConnector,
|
||||
) -> None:
|
||||
"""Checkpoint iterates through multiple workspaces."""
|
||||
# Workspace 1: one call
|
||||
ws1_transcript = MagicMock()
|
||||
ws1_transcript.status_code = 200
|
||||
ws1_transcript.json.return_value = {
|
||||
"callTranscripts": [_make_transcript("call_ws1")],
|
||||
"records": {},
|
||||
}
|
||||
ws1_details = MagicMock()
|
||||
ws1_details.status_code = 200
|
||||
ws1_details.json.return_value = {
|
||||
"calls": [_make_call_detail("call_ws1", "WS1 Call")]
|
||||
}
|
||||
|
||||
# Workspace 2: one call
|
||||
ws2_transcript = MagicMock()
|
||||
ws2_transcript.status_code = 200
|
||||
ws2_transcript.json.return_value = {
|
||||
"callTranscripts": [_make_transcript("call_ws2")],
|
||||
"records": {},
|
||||
}
|
||||
ws2_details = MagicMock()
|
||||
ws2_details.status_code = 200
|
||||
ws2_details.json.return_value = {
|
||||
"calls": [_make_call_detail("call_ws2", "WS2 Call")]
|
||||
}
|
||||
|
||||
mock_request.side_effect = [
|
||||
ws1_transcript,
|
||||
ws1_details,
|
||||
ws2_transcript,
|
||||
ws2_details,
|
||||
]
|
||||
|
||||
checkpoint = GongConnectorCheckpoint(
|
||||
has_more=True,
|
||||
workspace_ids=["ws1_id", "ws2_id"],
|
||||
workspace_index=0,
|
||||
)
|
||||
|
||||
all_docs: list[Document] = []
|
||||
|
||||
# Checkpoint call 1 — workspace 1
|
||||
generator = connector.load_from_checkpoint(0, time.time(), checkpoint)
|
||||
try:
|
||||
while True:
|
||||
item = next(generator)
|
||||
if isinstance(item, Document):
|
||||
all_docs.append(item)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value
|
||||
|
||||
assert checkpoint.workspace_index == 1
|
||||
assert checkpoint.has_more is True
|
||||
|
||||
# Checkpoint call 2 — workspace 2
|
||||
generator = connector.load_from_checkpoint(0, time.time(), checkpoint)
|
||||
try:
|
||||
while True:
|
||||
item = next(generator)
|
||||
if isinstance(item, Document):
|
||||
all_docs.append(item)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value
|
||||
|
||||
assert len(all_docs) == 2
|
||||
assert all_docs[0].semantic_identifier == "WS1 Call"
|
||||
assert all_docs[1].semantic_identifier == "WS2 Call"
|
||||
assert checkpoint.has_more is False
|
||||
assert checkpoint.workspace_index == 2
|
||||
|
||||
@patch.object(GongConnector, "_throttled_request")
|
||||
def test_empty_workspace_404(
|
||||
self,
|
||||
mock_request: MagicMock,
|
||||
connector: GongConnector,
|
||||
) -> None:
|
||||
"""404 from transcript API means no calls — workspace exhausted."""
|
||||
response_404 = MagicMock()
|
||||
response_404.status_code = 404
|
||||
|
||||
mock_request.return_value = response_404
|
||||
|
||||
checkpoint = GongConnectorCheckpoint(
|
||||
has_more=True,
|
||||
workspace_ids=[None],
|
||||
workspace_index=0,
|
||||
)
|
||||
|
||||
generator = connector.load_from_checkpoint(0, time.time(), checkpoint)
|
||||
try:
|
||||
while True:
|
||||
next(generator)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value
|
||||
|
||||
assert checkpoint.has_more is False
|
||||
assert checkpoint.workspace_index == 1
|
||||
|
||||
@patch.object(GongConnector, "_throttled_request")
|
||||
def test_retry_only_fetches_missing_ids(
|
||||
self,
|
||||
mock_request: MagicMock,
|
||||
connector: GongConnector,
|
||||
) -> None:
|
||||
"""Retry for missing call details should only re-request the missing IDs."""
|
||||
transcript_response = MagicMock()
|
||||
transcript_response.status_code = 200
|
||||
transcript_response.json.return_value = {
|
||||
"callTranscripts": [
|
||||
_make_transcript("call1"),
|
||||
_make_transcript("call2"),
|
||||
],
|
||||
"records": {},
|
||||
}
|
||||
|
||||
# First fetch: returns call1 but not call2
|
||||
partial_details = MagicMock()
|
||||
partial_details.status_code = 200
|
||||
partial_details.json.return_value = {
|
||||
"calls": [_make_call_detail("call1", "Call One")]
|
||||
}
|
||||
|
||||
# Second fetch (retry): returns call2
|
||||
missing_details = MagicMock()
|
||||
missing_details.status_code = 200
|
||||
missing_details.json.return_value = {
|
||||
"calls": [_make_call_detail("call2", "Call Two")]
|
||||
}
|
||||
|
||||
mock_request.side_effect = [
|
||||
transcript_response,
|
||||
partial_details,
|
||||
missing_details,
|
||||
]
|
||||
|
||||
checkpoint = GongConnectorCheckpoint(
|
||||
has_more=True,
|
||||
workspace_ids=[None],
|
||||
workspace_index=0,
|
||||
)
|
||||
|
||||
docs: list[Document] = []
|
||||
with patch("onyx.connectors.gong.connector.time.sleep"):
|
||||
generator = connector.load_from_checkpoint(0, time.time(), checkpoint)
|
||||
try:
|
||||
while True:
|
||||
item = next(generator)
|
||||
if isinstance(item, Document):
|
||||
docs.append(item)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
assert len(docs) == 2
|
||||
assert docs[0].semantic_identifier == "Call One"
|
||||
assert docs[1].semantic_identifier == "Call Two"
|
||||
|
||||
# Verify: 3 API calls total (1 transcript + 1 full details + 1 retry for missing only)
|
||||
assert mock_request.call_count == 3
|
||||
# The retry call should only request call2, not both
|
||||
retry_call_body = mock_request.call_args_list[2][1]["json"]
|
||||
assert retry_call_body["filter"]["callIds"] == ["call2"]
|
||||
|
||||
@patch.object(GongConnector, "_throttled_request")
|
||||
def test_expired_cursor_restarts_workspace(
|
||||
self,
|
||||
mock_request: MagicMock,
|
||||
connector: GongConnector,
|
||||
) -> None:
|
||||
"""Expired pagination cursor resets checkpoint to restart the workspace."""
|
||||
expired_response = MagicMock()
|
||||
expired_response.status_code = 400
|
||||
expired_response.ok = False
|
||||
expired_response.text = '{"requestId":"abc","errors":["cursor has expired"]}'
|
||||
|
||||
mock_request.return_value = expired_response
|
||||
|
||||
# Checkpoint mid-pagination with a (now-expired) cursor
|
||||
checkpoint = GongConnectorCheckpoint(
|
||||
has_more=True,
|
||||
workspace_ids=[None],
|
||||
workspace_index=0,
|
||||
cursor="stale-cursor",
|
||||
)
|
||||
|
||||
docs: list[Document] = []
|
||||
generator = connector.load_from_checkpoint(0, time.time(), checkpoint)
|
||||
try:
|
||||
while True:
|
||||
item = next(generator)
|
||||
if isinstance(item, Document):
|
||||
docs.append(item)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value
|
||||
|
||||
assert len(docs) == 0
|
||||
# Cursor reset so next call restarts the workspace from scratch
|
||||
assert checkpoint.cursor is None
|
||||
assert checkpoint.workspace_index == 0
|
||||
assert checkpoint.has_more is True
|
||||
@@ -12,12 +12,14 @@ from unittest.mock import patch
|
||||
|
||||
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_drive.file_retrieval import DriveFileFieldType
|
||||
from onyx.connectors.google_drive.models import DriveRetrievalStage
|
||||
from onyx.connectors.google_drive.models import GoogleDriveCheckpoint
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.utils.threadpool_concurrency import ThreadSafeDict
|
||||
from onyx.utils.threadpool_concurrency import ThreadSafeSet
|
||||
|
||||
|
||||
def _make_done_checkpoint() -> GoogleDriveCheckpoint:
|
||||
@@ -198,3 +200,90 @@ class TestCeleryUtilsRouting:
|
||||
|
||||
mock_slim.assert_called_once()
|
||||
mock_perm_sync.assert_not_called()
|
||||
|
||||
|
||||
class TestFailedFolderIdsByEmail:
|
||||
def _make_failed_map(
|
||||
self, entries: dict[str, set[str]]
|
||||
) -> ThreadSafeDict[str, ThreadSafeSet[str]]:
|
||||
return ThreadSafeDict({k: ThreadSafeSet(v) for k, v in entries.items()})
|
||||
|
||||
def test_skips_api_call_for_known_failed_pair(self) -> None:
|
||||
"""_get_folder_metadata must skip the API call for a (folder, email) pair
|
||||
that previously confirmed no accessible parent."""
|
||||
connector = _make_connector()
|
||||
failed_map = self._make_failed_map(
|
||||
{
|
||||
"retriever@example.com": {"folder1"},
|
||||
"admin@example.com": {"folder1"},
|
||||
}
|
||||
)
|
||||
|
||||
with patch(
|
||||
"onyx.connectors.google_drive.connector.get_folder_metadata"
|
||||
) as mock_api:
|
||||
result = connector._get_folder_metadata(
|
||||
folder_id="folder1",
|
||||
retriever_email="retriever@example.com",
|
||||
field_type=DriveFileFieldType.SLIM,
|
||||
failed_folder_ids_by_email=failed_map,
|
||||
)
|
||||
|
||||
mock_api.assert_not_called()
|
||||
assert result is None
|
||||
|
||||
def test_records_failed_pair_when_no_parents(self) -> None:
|
||||
"""_get_folder_metadata must record (email → folder_id) in the map
|
||||
when the API returns a folder with no parents."""
|
||||
connector = _make_connector()
|
||||
failed_map: ThreadSafeDict[str, ThreadSafeSet[str]] = ThreadSafeDict()
|
||||
folder_no_parents: dict = {"id": "folder1", "name": "Orphaned"}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.connectors.google_drive.connector.get_drive_service",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"onyx.connectors.google_drive.connector.get_folder_metadata",
|
||||
return_value=folder_no_parents,
|
||||
),
|
||||
):
|
||||
connector._get_folder_metadata(
|
||||
folder_id="folder1",
|
||||
retriever_email="retriever@example.com",
|
||||
field_type=DriveFileFieldType.SLIM,
|
||||
failed_folder_ids_by_email=failed_map,
|
||||
)
|
||||
|
||||
assert "folder1" in failed_map.get("retriever@example.com", ThreadSafeSet())
|
||||
assert "folder1" in failed_map.get("admin@example.com", ThreadSafeSet())
|
||||
|
||||
def test_does_not_record_when_parents_found(self) -> None:
|
||||
"""_get_folder_metadata must NOT record a pair when parents are found."""
|
||||
connector = _make_connector()
|
||||
failed_map: ThreadSafeDict[str, ThreadSafeSet[str]] = ThreadSafeDict()
|
||||
folder_with_parents: dict = {
|
||||
"id": "folder1",
|
||||
"name": "Normal",
|
||||
"parents": ["root"],
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.connectors.google_drive.connector.get_drive_service",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"onyx.connectors.google_drive.connector.get_folder_metadata",
|
||||
return_value=folder_with_parents,
|
||||
),
|
||||
):
|
||||
connector._get_folder_metadata(
|
||||
folder_id="folder1",
|
||||
retriever_email="retriever@example.com",
|
||||
field_type=DriveFileFieldType.SLIM,
|
||||
failed_folder_ids_by_email=failed_map,
|
||||
)
|
||||
|
||||
assert len(failed_map) == 0
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
"""Unit tests for SharepointConnector.load_credentials sp_tenant_domain resolution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
|
||||
SITE_URL = "https://mytenant.sharepoint.com/sites/MySite"
|
||||
EXPECTED_TENANT_DOMAIN = "mytenant"
|
||||
|
||||
CLIENT_SECRET_CREDS = {
|
||||
"authentication_method": "client_secret",
|
||||
"sp_client_id": "fake-client-id",
|
||||
"sp_client_secret": "fake-client-secret",
|
||||
"sp_directory_id": "fake-directory-id",
|
||||
}
|
||||
|
||||
CERTIFICATE_CREDS = {
|
||||
"authentication_method": "certificate",
|
||||
"sp_client_id": "fake-client-id",
|
||||
"sp_directory_id": "fake-directory-id",
|
||||
"sp_private_key": base64.b64encode(b"fake-pfx-data").decode(),
|
||||
"sp_certificate_password": "fake-password",
|
||||
}
|
||||
|
||||
|
||||
def _make_mock_msal() -> MagicMock:
|
||||
mock_app = MagicMock()
|
||||
mock_app.acquire_token_for_client.return_value = {"access_token": "fake-token"}
|
||||
return mock_app
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.msal.ConfidentialClientApplication")
|
||||
@patch("onyx.connectors.sharepoint.connector.GraphClient")
|
||||
def test_client_secret_with_site_pages_sets_tenant_domain(
|
||||
_mock_graph_client: MagicMock,
|
||||
mock_msal_cls: MagicMock,
|
||||
) -> None:
|
||||
"""client_secret auth + include_site_pages=True must resolve sp_tenant_domain."""
|
||||
mock_msal_cls.return_value = _make_mock_msal()
|
||||
connector = SharepointConnector(sites=[SITE_URL], include_site_pages=True)
|
||||
|
||||
connector.load_credentials(CLIENT_SECRET_CREDS)
|
||||
|
||||
assert connector.sp_tenant_domain == EXPECTED_TENANT_DOMAIN
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.msal.ConfidentialClientApplication")
|
||||
@patch("onyx.connectors.sharepoint.connector.GraphClient")
|
||||
def test_client_secret_without_site_pages_still_sets_tenant_domain(
|
||||
_mock_graph_client: MagicMock,
|
||||
mock_msal_cls: MagicMock,
|
||||
) -> None:
|
||||
"""client_secret auth + include_site_pages=False must still resolve sp_tenant_domain
|
||||
because _create_rest_client_context is also called for drive items."""
|
||||
mock_msal_cls.return_value = _make_mock_msal()
|
||||
connector = SharepointConnector(sites=[SITE_URL], include_site_pages=False)
|
||||
|
||||
connector.load_credentials(CLIENT_SECRET_CREDS)
|
||||
|
||||
assert connector.sp_tenant_domain == EXPECTED_TENANT_DOMAIN
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.load_certificate_from_pfx")
|
||||
@patch("onyx.connectors.sharepoint.connector.msal.ConfidentialClientApplication")
|
||||
@patch("onyx.connectors.sharepoint.connector.GraphClient")
|
||||
def test_certificate_with_site_pages_sets_tenant_domain(
|
||||
_mock_graph_client: MagicMock,
|
||||
mock_msal_cls: MagicMock,
|
||||
mock_load_cert: MagicMock,
|
||||
) -> None:
|
||||
"""certificate auth + include_site_pages=True must resolve sp_tenant_domain."""
|
||||
mock_msal_cls.return_value = _make_mock_msal()
|
||||
mock_load_cert.return_value = MagicMock()
|
||||
connector = SharepointConnector(sites=[SITE_URL], include_site_pages=True)
|
||||
|
||||
connector.load_credentials(CERTIFICATE_CREDS)
|
||||
|
||||
assert connector.sp_tenant_domain == EXPECTED_TENANT_DOMAIN
|
||||
|
||||
|
||||
@patch("onyx.connectors.sharepoint.connector.load_certificate_from_pfx")
|
||||
@patch("onyx.connectors.sharepoint.connector.msal.ConfidentialClientApplication")
|
||||
@patch("onyx.connectors.sharepoint.connector.GraphClient")
|
||||
def test_certificate_without_site_pages_sets_tenant_domain(
|
||||
_mock_graph_client: MagicMock,
|
||||
mock_msal_cls: MagicMock,
|
||||
mock_load_cert: MagicMock,
|
||||
) -> None:
|
||||
"""certificate auth + include_site_pages=False must still resolve sp_tenant_domain
|
||||
because _create_rest_client_context is also called for drive items."""
|
||||
mock_msal_cls.return_value = _make_mock_msal()
|
||||
mock_load_cert.return_value = MagicMock()
|
||||
connector = SharepointConnector(sites=[SITE_URL], include_site_pages=False)
|
||||
|
||||
connector.load_credentials(CERTIFICATE_CREDS)
|
||||
|
||||
assert connector.sp_tenant_domain == EXPECTED_TENANT_DOMAIN
|
||||
0
backend/tests/unit/onyx/connectors/web/__init__.py
Normal file
188
backend/tests/unit/onyx/connectors/web/test_slim_retrieval.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""Unit tests for WebConnector.retrieve_all_slim_docs (slim pruning path)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.web.connector import WEB_CONNECTOR_VALID_SETTINGS
|
||||
from onyx.connectors.web.connector import WebConnector
|
||||
|
||||
BASE_URL = "http://example.com"
|
||||
|
||||
|
||||
def _make_context_mock(url_to_html: dict[str, str]) -> MagicMock:
|
||||
"""Return a BrowserContext mock whose pages respond based on the goto URL."""
|
||||
context = MagicMock()
|
||||
|
||||
def _new_page() -> MagicMock:
|
||||
page = MagicMock()
|
||||
visited: list[str] = []
|
||||
|
||||
def _goto(url: str, **kwargs: Any) -> MagicMock: # noqa: ARG001
|
||||
visited.append(url)
|
||||
page.url = url # no redirect
|
||||
response = MagicMock()
|
||||
response.status = 200
|
||||
response.header_value.return_value = None
|
||||
return response
|
||||
|
||||
def _content() -> str:
|
||||
return url_to_html.get(
|
||||
visited[-1] if visited else "", "<html><body></body></html>"
|
||||
)
|
||||
|
||||
page.goto.side_effect = _goto
|
||||
page.content.side_effect = _content
|
||||
return page
|
||||
|
||||
context.new_page.side_effect = _new_page
|
||||
return context
|
||||
|
||||
|
||||
def _make_playwright_mock(context: MagicMock) -> MagicMock: # noqa: ARG001
|
||||
playwright = MagicMock()
|
||||
playwright.stop = MagicMock()
|
||||
return playwright
|
||||
|
||||
|
||||
SINGLE_PAGE_HTML = (
|
||||
"<html><body><p>Content that should not appear in slim output</p></body></html>"
|
||||
)
|
||||
|
||||
RECURSIVE_ROOT_HTML = """
|
||||
<html><body>
|
||||
<a href="/page2">Page 2</a>
|
||||
<a href="/page3">Page 3</a>
|
||||
</body></html>
|
||||
"""
|
||||
|
||||
PAGE2_HTML = "<html><body><p>page 2</p></body></html>"
|
||||
PAGE3_HTML = "<html><body><p>page 3</p></body></html>"
|
||||
|
||||
|
||||
@patch("onyx.connectors.web.connector.check_internet_connection")
|
||||
@patch("onyx.connectors.web.connector.requests.head")
|
||||
@patch("onyx.connectors.web.connector.start_playwright")
|
||||
def test_slim_yields_slim_documents(
|
||||
mock_start_playwright: MagicMock,
|
||||
mock_head: MagicMock,
|
||||
_mock_check: MagicMock,
|
||||
) -> None:
|
||||
"""retrieve_all_slim_docs yields SlimDocuments with the correct URL as id."""
|
||||
context = _make_context_mock({BASE_URL + "/": SINGLE_PAGE_HTML})
|
||||
mock_start_playwright.return_value = (_make_playwright_mock(context), context)
|
||||
mock_head.return_value.headers = {"content-type": "text/html"}
|
||||
|
||||
connector = WebConnector(
|
||||
base_url=BASE_URL + "/",
|
||||
web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.SINGLE.value,
|
||||
)
|
||||
|
||||
docs = [doc for batch in connector.retrieve_all_slim_docs() for doc in batch]
|
||||
|
||||
assert len(docs) == 1
|
||||
assert isinstance(docs[0], SlimDocument)
|
||||
assert docs[0].id == BASE_URL + "/"
|
||||
|
||||
|
||||
@patch("onyx.connectors.web.connector.check_internet_connection")
|
||||
@patch("onyx.connectors.web.connector.requests.head")
|
||||
@patch("onyx.connectors.web.connector.start_playwright")
|
||||
def test_slim_skips_content_extraction(
|
||||
mock_start_playwright: MagicMock,
|
||||
mock_head: MagicMock,
|
||||
_mock_check: MagicMock,
|
||||
) -> None:
|
||||
"""web_html_cleanup is never called in slim mode."""
|
||||
context = _make_context_mock({BASE_URL + "/": SINGLE_PAGE_HTML})
|
||||
mock_start_playwright.return_value = (_make_playwright_mock(context), context)
|
||||
mock_head.return_value.headers = {"content-type": "text/html"}
|
||||
|
||||
connector = WebConnector(
|
||||
base_url=BASE_URL + "/",
|
||||
web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.SINGLE.value,
|
||||
)
|
||||
|
||||
with patch("onyx.connectors.web.connector.web_html_cleanup") as mock_cleanup:
|
||||
list(connector.retrieve_all_slim_docs())
|
||||
mock_cleanup.assert_not_called()
|
||||
|
||||
|
||||
@patch("onyx.connectors.web.connector.check_internet_connection")
|
||||
@patch("onyx.connectors.web.connector.requests.head")
|
||||
@patch("onyx.connectors.web.connector.start_playwright")
|
||||
def test_slim_discovers_links_recursively(
|
||||
mock_start_playwright: MagicMock,
|
||||
mock_head: MagicMock,
|
||||
_mock_check: MagicMock,
|
||||
) -> None:
|
||||
"""In RECURSIVE mode, internal <a href> links are followed and all URLs yielded."""
|
||||
url_to_html = {
|
||||
BASE_URL + "/": RECURSIVE_ROOT_HTML,
|
||||
BASE_URL + "/page2": PAGE2_HTML,
|
||||
BASE_URL + "/page3": PAGE3_HTML,
|
||||
}
|
||||
context = _make_context_mock(url_to_html)
|
||||
mock_start_playwright.return_value = (_make_playwright_mock(context), context)
|
||||
mock_head.return_value.headers = {"content-type": "text/html"}
|
||||
|
||||
connector = WebConnector(
|
||||
base_url=BASE_URL + "/",
|
||||
web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value,
|
||||
)
|
||||
|
||||
ids = {
|
||||
doc.id
|
||||
for batch in connector.retrieve_all_slim_docs()
|
||||
for doc in batch
|
||||
if isinstance(doc, SlimDocument)
|
||||
}
|
||||
|
||||
assert ids == {
|
||||
BASE_URL + "/",
|
||||
BASE_URL + "/page2",
|
||||
BASE_URL + "/page3",
|
||||
}
|
||||
|
||||
|
||||
@patch("onyx.connectors.web.connector.check_internet_connection")
|
||||
@patch("onyx.connectors.web.connector.requests.head")
|
||||
@patch("onyx.connectors.web.connector.start_playwright")
|
||||
def test_slim_render_wait_not_called_confirmed(
|
||||
mock_start_playwright: MagicMock,
|
||||
mock_head: MagicMock,
|
||||
_mock_check: MagicMock,
|
||||
) -> None:
|
||||
"""Confirm wait_for_timeout is called in full mode but not in slim mode."""
|
||||
pages_visited: list[MagicMock] = []
|
||||
|
||||
context = MagicMock()
|
||||
|
||||
def _new_page() -> MagicMock:
|
||||
page = MagicMock()
|
||||
page.url = BASE_URL + "/"
|
||||
response = MagicMock()
|
||||
response.status = 200
|
||||
response.header_value.return_value = None
|
||||
page.goto.return_value = response
|
||||
page.content.return_value = SINGLE_PAGE_HTML
|
||||
pages_visited.append(page)
|
||||
return page
|
||||
|
||||
context.new_page.side_effect = _new_page
|
||||
mock_start_playwright.return_value = (_make_playwright_mock(context), context)
|
||||
mock_head.return_value.headers = {"content-type": "text/html"}
|
||||
|
||||
connector = WebConnector(
|
||||
base_url=BASE_URL + "/",
|
||||
web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.SINGLE.value,
|
||||
)
|
||||
|
||||
pages_visited.clear()
|
||||
list(connector.retrieve_all_slim_docs())
|
||||
for page in pages_visited:
|
||||
page.wait_for_timeout.assert_not_called()
|
||||
page.wait_for_load_state.assert_not_called()
|
||||
@@ -12,6 +12,10 @@ dependency on pypdf internals (pypdf.generic).
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.file_processing import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
|
||||
from onyx.file_processing.extract_file_text import pdf_to_text
|
||||
from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.password_validation import is_pdf_protected
|
||||
@@ -96,6 +100,80 @@ class TestReadPdfFile:
|
||||
# Returned list is empty when callback is used
|
||||
assert images == []
|
||||
|
||||
def test_image_cap_skips_images_above_limit(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""When the embedded-image cap is exceeded, remaining images are skipped.
|
||||
|
||||
The cap protects the user-file-processing worker from OOMing on PDFs
|
||||
with thousands of embedded images. Setting the cap to 0 should yield
|
||||
zero extracted images even though the fixture has one.
|
||||
"""
|
||||
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 0)
|
||||
_, _, images = read_pdf_file(_load("with_image.pdf"), extract_images=True)
|
||||
assert images == []
|
||||
|
||||
def test_image_cap_at_limit_extracts_up_to_cap(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""A cap >= image count behaves identically to the uncapped path."""
|
||||
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 100)
|
||||
_, _, images = read_pdf_file(_load("with_image.pdf"), extract_images=True)
|
||||
assert len(images) == 1
|
||||
|
||||
def test_image_cap_with_callback_stops_streaming_at_limit(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""The cap also short-circuits the streaming callback path."""
|
||||
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 0)
|
||||
collected: list[tuple[bytes, str]] = []
|
||||
|
||||
def callback(data: bytes, name: str) -> None:
|
||||
collected.append((data, name))
|
||||
|
||||
read_pdf_file(
|
||||
_load("with_image.pdf"), extract_images=True, image_callback=callback
|
||||
)
|
||||
assert collected == []
|
||||
|
||||
|
||||
# ── count_pdf_embedded_images ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCountPdfEmbeddedImages:
|
||||
def test_returns_count_for_normal_pdf(self) -> None:
|
||||
assert count_pdf_embedded_images(_load("with_image.pdf"), cap=10) == 1
|
||||
|
||||
def test_short_circuits_above_cap(self) -> None:
|
||||
# with_image.pdf has 1 image. cap=0 means "anything > 0 is over cap" —
|
||||
# function returns on first increment as the over-cap sentinel.
|
||||
assert count_pdf_embedded_images(_load("with_image.pdf"), cap=0) == 1
|
||||
|
||||
def test_returns_zero_for_pdf_without_images(self) -> None:
|
||||
assert count_pdf_embedded_images(_load("simple.pdf"), cap=10) == 0
|
||||
|
||||
def test_returns_zero_for_invalid_pdf(self) -> None:
|
||||
assert count_pdf_embedded_images(BytesIO(b"not a pdf"), cap=10) == 0
|
||||
|
||||
def test_returns_zero_for_password_locked_pdf(self) -> None:
|
||||
# encrypted.pdf has an open password; we can't inspect without it, so
|
||||
# the helper returns 0 — callers rely on the password-protected check
|
||||
# that runs earlier in the upload pipeline.
|
||||
assert count_pdf_embedded_images(_load("encrypted.pdf"), cap=10) == 0
|
||||
|
||||
def test_inspects_owner_password_only_pdf(self) -> None:
|
||||
# owner_protected.pdf is encrypted but has no open password. It should
|
||||
# decrypt with an empty string and count images normally. The fixture
|
||||
# has zero images, so 0 is a real count (not the "bail on encrypted"
|
||||
# path).
|
||||
assert count_pdf_embedded_images(_load("owner_protected.pdf"), cap=10) == 0
|
||||
|
||||
def test_preserves_file_position(self) -> None:
|
||||
pdf = _load("with_image.pdf")
|
||||
pdf.seek(42)
|
||||
count_pdf_embedded_images(pdf, cap=10)
|
||||
assert pdf.tell() == 42
|
||||
|
||||
|
||||
# ── pdf_to_text ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import io
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import openpyxl
|
||||
from openpyxl.worksheet.worksheet import Worksheet
|
||||
@@ -321,6 +322,17 @@ class TestXlsxSheetExtraction:
|
||||
sheets = xlsx_sheet_extraction(bad_file, file_name="~$temp.xlsx")
|
||||
assert sheets == []
|
||||
|
||||
def test_known_openpyxl_bug_max_value_returns_empty(self) -> None:
|
||||
"""openpyxl's strict descriptor validation rejects font family
|
||||
values >14 with 'Max value is 14'. Treat as a known openpyxl bug
|
||||
and skip the file rather than fail the whole connector batch."""
|
||||
with patch(
|
||||
"onyx.file_processing.extract_file_text.openpyxl.load_workbook",
|
||||
side_effect=ValueError("Max value is 14"),
|
||||
):
|
||||
sheets = xlsx_sheet_extraction(io.BytesIO(b""), file_name="bad_font.xlsx")
|
||||
assert sheets == []
|
||||
|
||||
def test_csv_content_matches_xlsx_to_text_per_sheet(self) -> None:
|
||||
"""For a single-sheet workbook, xlsx_to_text output should equal
|
||||
the csv_text from xlsx_sheet_extraction — they share the same
|
||||
|
||||
@@ -15,10 +15,19 @@ from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import TabularSection
|
||||
from onyx.indexing.chunking.section_chunker import AccumulatorState
|
||||
from onyx.indexing.chunking.tabular_section_chunker import TabularChunker
|
||||
from onyx.indexing.chunking.tabular_section_chunker.analysis import analyze_sheet
|
||||
from onyx.indexing.chunking.tabular_section_chunker.sheet_descriptor import (
|
||||
build_sheet_descriptor_chunks,
|
||||
)
|
||||
from onyx.indexing.chunking.tabular_section_chunker.total_descriptor import (
|
||||
build_total_descriptor_chunks,
|
||||
)
|
||||
from onyx.indexing.chunking.tabular_section_chunker.total_descriptor import (
|
||||
TOTALS_HEADER,
|
||||
)
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.utils.csv_utils import parse_csv_string
|
||||
from onyx.utils.csv_utils import read_csv_header
|
||||
|
||||
|
||||
class CharTokenizer(BaseTokenizer):
|
||||
@@ -587,7 +596,7 @@ class TestTabularChunkerChunkSection:
|
||||
content_chunk = (
|
||||
"sheet:T\n" "Columns: Name, Age\n" "Name=Alice, Age=30\n" "Name=Bob, Age=25"
|
||||
)
|
||||
metadata_chunk = (
|
||||
descriptor_chunk = (
|
||||
"sheet:T\n"
|
||||
"Sheet overview.\n"
|
||||
"This sheet has 2 rows and 2 columns.\n"
|
||||
@@ -596,7 +605,17 @@ class TestTabularChunkerChunkSection:
|
||||
"Categorical columns (groupable, can be counted by value): Name\n"
|
||||
"Values seen in Name: Alice, Bob"
|
||||
)
|
||||
expected_texts = [content_chunk, metadata_chunk]
|
||||
totals_chunk = (
|
||||
"sheet:T\n"
|
||||
"Totals and overall aggregates across all rows. This sheet can answer "
|
||||
"whole-dataset questions about total, overall, grand total, sum across "
|
||||
"all, average, combined, mean, minimum, maximum, and count of values.\n"
|
||||
"Column Age: total (sum across all rows) = 55, average = 27.5, "
|
||||
"minimum = 25, maximum = 30, count = 2.\n"
|
||||
"Column Name most frequent value: Alice (1 occurrences).\n"
|
||||
"Total row count: 2."
|
||||
)
|
||||
expected_texts = [content_chunk, descriptor_chunk, totals_chunk]
|
||||
|
||||
# --- ACT -------------------------------------------------------
|
||||
out = _make_chunker_with_metadata().chunk_section(
|
||||
@@ -607,8 +626,8 @@ class TestTabularChunkerChunkSection:
|
||||
|
||||
# --- ASSERT ----------------------------------------------------
|
||||
assert [p.text for p in out.payloads] == expected_texts
|
||||
# Content first, metadata second — only the first chunk is fresh.
|
||||
assert [p.is_continuation for p in out.payloads] == [False, True]
|
||||
# Content first, metadata chunks follow as continuations.
|
||||
assert [p.is_continuation for p in out.payloads] == [False, True, True]
|
||||
|
||||
|
||||
class TestBuildSheetDescriptorChunks:
|
||||
@@ -627,9 +646,14 @@ class TestBuildSheetDescriptorChunks:
|
||||
heading: str | None = "sheet:T",
|
||||
max_tokens: int = 500,
|
||||
) -> list[str]:
|
||||
section = TabularSection(text=csv_text, link=_DEFAULT_LINK, heading=heading)
|
||||
parsed_rows = list(parse_csv_string(csv_text))
|
||||
headers = parsed_rows[0].header if parsed_rows else read_csv_header(csv_text)
|
||||
if not headers:
|
||||
return []
|
||||
return build_sheet_descriptor_chunks(
|
||||
section=section,
|
||||
headers=headers,
|
||||
analysis=analyze_sheet(headers, parsed_rows),
|
||||
heading=heading or "",
|
||||
tokenizer=CharTokenizer(),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
@@ -837,3 +861,174 @@ class TestBuildSheetDescriptorChunks:
|
||||
|
||||
# --- ACT / ASSERT ---------------------------------------------
|
||||
assert self._build(csv_text, heading="", max_tokens=30) == expected
|
||||
|
||||
|
||||
class TestBuildTotalDescriptorChunks:
|
||||
"""Direct tests of `build_total_descriptor_chunks` — emits the totals
|
||||
chunk that names aggregate vocabulary (total/sum/average/min/max/
|
||||
count/most frequent) plus per-column aggregates so whole-dataset
|
||||
questions retrieve a chunk whose text actually contains the answer.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _build(
|
||||
csv_text: str,
|
||||
heading: str | None = "sheet:T",
|
||||
max_tokens: int = 1000,
|
||||
) -> list[str]:
|
||||
parsed_rows = list(parse_csv_string(csv_text))
|
||||
headers = parsed_rows[0].header if parsed_rows else read_csv_header(csv_text)
|
||||
if not headers:
|
||||
return []
|
||||
return build_total_descriptor_chunks(
|
||||
headers=headers,
|
||||
analysis=analyze_sheet(headers, parsed_rows),
|
||||
heading=heading or "",
|
||||
tokenizer=CharTokenizer(),
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
def test_numeric_and_categorical_columns_emit_every_line(self) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# amount → numeric (total=600, avg=200, min=100, max=300, count=3)
|
||||
# region → categorical (US appears twice, EU once → top=US (2))
|
||||
csv_text = "amount,region\n100,US\n200,EU\n300,US\n"
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
expected = [
|
||||
"sheet:T\n"
|
||||
f"{TOTALS_HEADER}\n"
|
||||
"Column amount: total (sum across all rows) = 600, average = 200, "
|
||||
"minimum = 100, maximum = 300, count = 3.\n"
|
||||
"Column region most frequent value: US (2 occurrences).\n"
|
||||
"Total row count: 3."
|
||||
]
|
||||
|
||||
# --- ACT / ASSERT ---------------------------------------------
|
||||
assert self._build(csv_text) == expected
|
||||
|
||||
def test_numeric_only_sheet_has_no_categorical_line(self) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# Both columns are all-numeric → no "most frequent value" lines.
|
||||
csv_text = "x,y\n1,2\n3,4\n"
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
expected = [
|
||||
"sheet:T\n"
|
||||
f"{TOTALS_HEADER}\n"
|
||||
"Column x: total (sum across all rows) = 4, average = 2, "
|
||||
"minimum = 1, maximum = 3, count = 2.\n"
|
||||
"Column y: total (sum across all rows) = 6, average = 3, "
|
||||
"minimum = 2, maximum = 4, count = 2.\n"
|
||||
"Total row count: 2."
|
||||
]
|
||||
|
||||
# --- ACT / ASSERT ---------------------------------------------
|
||||
assert self._build(csv_text) == expected
|
||||
|
||||
def test_categorical_only_sheet_has_no_numeric_line(self) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# Non-numeric low-cardinality column → categorical only. "red"
|
||||
# wins over "blue" 2-to-1.
|
||||
csv_text = "color\nred\nblue\nred\n"
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
expected = [
|
||||
"sheet:T\n"
|
||||
f"{TOTALS_HEADER}\n"
|
||||
"Column color most frequent value: red (2 occurrences).\n"
|
||||
"Total row count: 3."
|
||||
]
|
||||
|
||||
# --- ACT / ASSERT ---------------------------------------------
|
||||
assert self._build(csv_text) == expected
|
||||
|
||||
def test_underscored_column_names_get_friendly_alias(self) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# Underscored headers get the same `name (name with spaces)` alias
|
||||
# used elsewhere so retrieval matches either surface form.
|
||||
csv_text = "total_cost\n100\n200\n"
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
expected = [
|
||||
"sheet:T\n"
|
||||
f"{TOTALS_HEADER}\n"
|
||||
"Column total_cost (total cost): total (sum across all rows) = 300, "
|
||||
"average = 150, minimum = 100, maximum = 200, count = 2.\n"
|
||||
"Total row count: 2."
|
||||
]
|
||||
|
||||
# --- ACT / ASSERT ---------------------------------------------
|
||||
assert self._build(csv_text) == expected
|
||||
|
||||
def test_non_integer_averages_format_with_decimals(self) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# Whole-number inputs but a fractional average. `_fmt` drops the
|
||||
# ".0" when the value is integral and falls back to `:.6g` when
|
||||
# it isn't — verify both surfaces on the same line.
|
||||
csv_text = "rate\n1\n2\n"
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
# total=3 (int), avg=1.5 (fractional), min=1, max=2, count=2.
|
||||
expected = [
|
||||
"sheet:T\n"
|
||||
f"{TOTALS_HEADER}\n"
|
||||
"Column rate: total (sum across all rows) = 3, average = 1.5, "
|
||||
"minimum = 1, maximum = 2, count = 2.\n"
|
||||
"Total row count: 2."
|
||||
]
|
||||
|
||||
# --- ACT / ASSERT ---------------------------------------------
|
||||
assert self._build(csv_text) == expected
|
||||
|
||||
def test_empty_section_returns_no_chunks(self) -> None:
|
||||
# No parsed rows → no totals to report; builder bails out early.
|
||||
assert self._build("") == []
|
||||
|
||||
def test_header_only_csv_returns_no_chunks(self) -> None:
|
||||
# Header-only CSV yields zero data rows → `parse_csv_string`
|
||||
# returns nothing, so the builder returns an empty list.
|
||||
assert self._build("col1,col2\n") == []
|
||||
|
||||
def test_no_heading_omits_prefix_line(self) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# heading=None → prefix is just TOTALS_HEADER, no leading heading
|
||||
# line in the emitted chunk.
|
||||
csv_text = "n\n5\n"
|
||||
|
||||
# --- EXPECTED --------------------------------------------------
|
||||
expected = [
|
||||
f"{TOTALS_HEADER}\n"
|
||||
"Column n: total (sum across all rows) = 5, average = 5, "
|
||||
"minimum = 5, maximum = 5, count = 1.\n"
|
||||
"Total row count: 1."
|
||||
]
|
||||
|
||||
# --- ACT / ASSERT ---------------------------------------------
|
||||
assert self._build(csv_text, heading=None) == expected
|
||||
|
||||
def test_tight_budget_splits_into_multiple_chunks_each_with_header(self) -> None:
|
||||
# --- INPUT -----------------------------------------------------
|
||||
# Three numeric columns under a tight budget force pack_lines to
|
||||
# split across multiple chunks. Every emitted chunk must still
|
||||
# start with `heading + TOTALS_HEADER` so retrieval keeps context
|
||||
# on whichever chunk wins.
|
||||
csv_text = "a,b,c\n1,2,3\n4,5,6\n"
|
||||
|
||||
# --- ACT -------------------------------------------------------
|
||||
# Budget chosen so the three aggregate lines can't all fit under
|
||||
# TOTALS_HEADER in a single chunk.
|
||||
out = self._build(csv_text, heading="S", max_tokens=len(TOTALS_HEADER) + 120)
|
||||
|
||||
# --- ASSERT ----------------------------------------------------
|
||||
# Split actually happened.
|
||||
assert len(out) > 1
|
||||
# Each chunk carries the full prefix (heading + totals header).
|
||||
assert all(c.startswith(f"S\n{TOTALS_HEADER}\n") for c in out)
|
||||
# Collectively, every per-column aggregate and the row count line
|
||||
# must appear somewhere in the output.
|
||||
body = "\n".join(out)
|
||||
assert "Column a: total (sum across all rows) = 5" in body
|
||||
assert "Column b: total (sum across all rows) = 7" in body
|
||||
assert "Column c: total (sum across all rows) = 9" in body
|
||||
assert "Total row count: 2." in body
|
||||
|
||||
@@ -29,6 +29,7 @@ from onyx.llm.utils import get_max_input_tokens
|
||||
VERTEX_OPUS_MODELS_REJECTING_OUTPUT_CONFIG = [
|
||||
"claude-opus-4-5@20251101",
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4-7",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
"""Tests for ``ImageGenerationTool._resolve_reference_image_file_ids``.
|
||||
|
||||
The resolver turns the LLM's ``reference_image_file_ids`` argument into a
|
||||
cleaned list of file IDs to hand to ``_load_reference_images``. It trusts
|
||||
the LLM's picks — the LLM can only see file IDs that actually appear in
|
||||
the conversation (via ``[attached image — file_id: <id>]`` tags on user
|
||||
messages and the JSON returned by prior generate_image calls), so we
|
||||
don't re-validate against an allow-list in the tool itself.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
REFERENCE_IMAGE_FILE_IDS_FIELD,
|
||||
)
|
||||
|
||||
|
||||
def _make_tool(
|
||||
supports_reference_images: bool = True,
|
||||
max_reference_images: int = 16,
|
||||
) -> ImageGenerationTool:
|
||||
"""Construct a tool with a mock provider so no credentials/network are needed."""
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.images.image_generation_tool.get_image_generation_provider"
|
||||
) as mock_get_provider:
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.supports_reference_images = supports_reference_images
|
||||
mock_provider.max_reference_images = max_reference_images
|
||||
mock_get_provider.return_value = mock_provider
|
||||
|
||||
return ImageGenerationTool(
|
||||
image_generation_credentials=MagicMock(),
|
||||
tool_id=1,
|
||||
emitter=MagicMock(),
|
||||
model="gpt-image-1",
|
||||
provider="openai",
|
||||
)
|
||||
|
||||
|
||||
class TestResolveReferenceImageFileIds:
|
||||
def test_unset_returns_empty_plain_generation(self) -> None:
|
||||
tool = _make_tool()
|
||||
assert tool._resolve_reference_image_file_ids(llm_kwargs={}) == []
|
||||
|
||||
def test_empty_list_is_treated_like_unset(self) -> None:
|
||||
tool = _make_tool()
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: []},
|
||||
)
|
||||
assert result == []
|
||||
|
||||
def test_passes_llm_supplied_ids_through(self) -> None:
|
||||
tool = _make_tool()
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["upload-1", "gen-1"]},
|
||||
)
|
||||
# Order preserved — first entry is the primary edit source.
|
||||
assert result == ["upload-1", "gen-1"]
|
||||
|
||||
def test_invalid_shape_raises(self) -> None:
|
||||
tool = _make_tool()
|
||||
with pytest.raises(ToolCallException):
|
||||
tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: "not-a-list"},
|
||||
)
|
||||
|
||||
def test_non_string_element_raises(self) -> None:
|
||||
tool = _make_tool()
|
||||
with pytest.raises(ToolCallException):
|
||||
tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["ok", 123]},
|
||||
)
|
||||
|
||||
def test_deduplicates_preserving_first_occurrence(self) -> None:
|
||||
tool = _make_tool()
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["gen-1", "gen-2", "gen-1"]},
|
||||
)
|
||||
assert result == ["gen-1", "gen-2"]
|
||||
|
||||
def test_strips_whitespace_and_skips_empty_strings(self) -> None:
|
||||
tool = _make_tool()
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: [" gen-1 ", "", " "]},
|
||||
)
|
||||
assert result == ["gen-1"]
|
||||
|
||||
def test_provider_without_reference_support_raises(self) -> None:
|
||||
tool = _make_tool(supports_reference_images=False)
|
||||
with pytest.raises(ToolCallException):
|
||||
tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["gen-1"]},
|
||||
)
|
||||
|
||||
def test_truncates_to_provider_max_preserving_head(self) -> None:
|
||||
"""When the LLM lists more images than the provider allows, keep the
|
||||
HEAD of the list (the primary edit source + earliest extras) rather
|
||||
than the tail, since the LLM put the most important one first."""
|
||||
tool = _make_tool(max_reference_images=2)
|
||||
result = tool._resolve_reference_image_file_ids(
|
||||
llm_kwargs={REFERENCE_IMAGE_FILE_IDS_FIELD: ["a", "b", "c", "d"]},
|
||||
)
|
||||
assert result == ["a", "b"]
|
||||
@@ -1,10 +1,5 @@
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_runner import _extract_image_file_ids_from_tool_response_message
|
||||
from onyx.tools.tool_runner import _extract_recent_generated_image_file_ids
|
||||
from onyx.tools.tool_runner import _merge_tool_calls
|
||||
|
||||
|
||||
@@ -312,62 +307,3 @@ class TestMergeToolCalls:
|
||||
assert len(result) == 1
|
||||
# String should be converted to list item
|
||||
assert result[0].tool_args["queries"] == ["single_query", "q2"]
|
||||
|
||||
|
||||
class TestImageHistoryExtraction:
|
||||
def test_extracts_image_file_ids_from_json_response(self) -> None:
|
||||
msg = '[{"file_id":"img-1","revised_prompt":"v1"},{"file_id":"img-2","revised_prompt":"v2"}]'
|
||||
assert _extract_image_file_ids_from_tool_response_message(msg) == [
|
||||
"img-1",
|
||||
"img-2",
|
||||
]
|
||||
|
||||
def test_extracts_recent_generated_image_ids_from_history(self) -> None:
|
||||
history = [
|
||||
ChatMessageSimple(
|
||||
message="",
|
||||
token_count=1,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
tool_calls=[
|
||||
ToolCallSimple(
|
||||
tool_call_id="call_1",
|
||||
tool_name="generate_image",
|
||||
tool_arguments={"prompt": "test"},
|
||||
token_count=1,
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatMessageSimple(
|
||||
message='[{"file_id":"img-1","revised_prompt":"r1"}]',
|
||||
token_count=1,
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id="call_1",
|
||||
),
|
||||
]
|
||||
|
||||
assert _extract_recent_generated_image_file_ids(history) == ["img-1"]
|
||||
|
||||
def test_ignores_non_image_tool_responses(self) -> None:
|
||||
history = [
|
||||
ChatMessageSimple(
|
||||
message="",
|
||||
token_count=1,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
tool_calls=[
|
||||
ToolCallSimple(
|
||||
tool_call_id="call_1",
|
||||
tool_name="web_search",
|
||||
tool_arguments={"queries": ["q"]},
|
||||
token_count=1,
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatMessageSimple(
|
||||
message='[{"file_id":"img-1","revised_prompt":"r1"}]',
|
||||
token_count=1,
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id="call_1",
|
||||
),
|
||||
]
|
||||
|
||||
assert _extract_recent_generated_image_file_ids(history) == []
|
||||
|
||||
@@ -129,12 +129,36 @@ class TestWorkerHealthCollector:
|
||||
up = families[1]
|
||||
assert up.name == "onyx_celery_worker_up"
|
||||
assert len(up.samples) == 3
|
||||
# Labels use short names (before @)
|
||||
labels = {s.labels["worker"] for s in up.samples}
|
||||
assert labels == {"primary", "docfetching", "monitoring"}
|
||||
label_pairs = {
|
||||
(s.labels["worker_type"], s.labels["hostname"]) for s in up.samples
|
||||
}
|
||||
assert label_pairs == {
|
||||
("primary", "host1"),
|
||||
("docfetching", "host1"),
|
||||
("monitoring", "host1"),
|
||||
}
|
||||
for sample in up.samples:
|
||||
assert sample.value == 1
|
||||
|
||||
def test_replicas_of_same_worker_type_are_distinct(self) -> None:
|
||||
"""Regression: ``docprocessing@pod-1`` and ``docprocessing@pod-2`` must
|
||||
produce separate samples, not collapse into one duplicate-timestamp
|
||||
series.
|
||||
"""
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "docprocessing@pod-1"})
|
||||
monitor._on_heartbeat({"hostname": "docprocessing@pod-2"})
|
||||
monitor._on_heartbeat({"hostname": "docprocessing@pod-3"})
|
||||
|
||||
collector = WorkerHealthCollector(cache_ttl=0)
|
||||
collector.set_monitor(monitor)
|
||||
|
||||
up = collector.collect()[1]
|
||||
assert len(up.samples) == 3
|
||||
hostnames = {s.labels["hostname"] for s in up.samples}
|
||||
assert hostnames == {"pod-1", "pod-2", "pod-3"}
|
||||
assert all(s.labels["worker_type"] == "docprocessing" for s in up.samples)
|
||||
|
||||
def test_reports_dead_worker(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
monitor._on_heartbeat({"hostname": "primary@host1"})
|
||||
@@ -151,9 +175,9 @@ class TestWorkerHealthCollector:
|
||||
assert active.samples[0].value == 1
|
||||
|
||||
up = families[1]
|
||||
samples_by_name = {s.labels["worker"]: s.value for s in up.samples}
|
||||
assert samples_by_name["primary"] == 1
|
||||
assert samples_by_name["monitoring"] == 0
|
||||
samples_by_type = {s.labels["worker_type"]: s.value for s in up.samples}
|
||||
assert samples_by_type["primary"] == 1
|
||||
assert samples_by_type["monitoring"] == 0
|
||||
|
||||
def test_empty_monitor_returns_zero(self) -> None:
|
||||
monitor = WorkerHeartbeatMonitor(MagicMock())
|
||||
|
||||
@@ -58,8 +58,7 @@ SERVICE_ORDER=(
validate_template() {
    local template_file=$1
    echo "Validating template: $template_file..."
    aws cloudformation validate-template --template-body file://"$template_file" --region "$AWS_REGION" > /dev/null
    if [ $? -ne 0 ]; then
    if ! aws cloudformation validate-template --template-body file://"$template_file" --region "$AWS_REGION" > /dev/null; then
        echo "Error: Validation failed for $template_file. Exiting."
        exit 1
    fi

@@ -108,13 +107,15 @@ deploy_stack() {
    fi

    # Create temporary parameters file for this template
    local temp_params_file=$(create_parameters_from_json "$template_file")
    local temp_params_file
    temp_params_file=$(create_parameters_from_json "$template_file")

    # Special handling for SubnetIDs parameter if needed
    if grep -q "SubnetIDs" "$template_file"; then
        echo "Template uses SubnetIDs parameter, ensuring it's properly formatted..."
        # Make sure we're passing SubnetIDs as a comma-separated list
        local subnet_ids=$(remove_comments "$CONFIG_FILE" | jq -r '.SubnetIDs // empty')
        local subnet_ids
        subnet_ids=$(remove_comments "$CONFIG_FILE" | jq -r '.SubnetIDs // empty')
        if [ -n "$subnet_ids" ]; then
            echo "Using SubnetIDs from config: $subnet_ids"
        else

@@ -123,15 +124,13 @@
    fi

    echo "Deploying stack: $stack_name with template: $template_file and generated config from: $CONFIG_FILE..."
    aws cloudformation deploy \
    if ! aws cloudformation deploy \
        --stack-name "$stack_name" \
        --template-file "$template_file" \
        --parameter-overrides file://"$temp_params_file" \
        --capabilities CAPABILITY_IAM CAPABILITY_NAMED_IAM CAPABILITY_AUTO_EXPAND \
        --region "$AWS_REGION" \
        --no-cli-auto-prompt > /dev/null

    if [ $? -ne 0 ]; then
        --no-cli-auto-prompt > /dev/null; then
        echo "Error: Deployment failed for $stack_name. Exiting."
        exit 1
    fi
@@ -52,11 +52,9 @@ delete_stack() {
        --region "$AWS_REGION"

    echo "Waiting for stack $stack_name to be deleted..."
    aws cloudformation wait stack-delete-complete \
    if aws cloudformation wait stack-delete-complete \
        --stack-name "$stack_name" \
        --region "$AWS_REGION"

    if [ $? -eq 0 ]; then
        --region "$AWS_REGION"; then
        echo "Stack $stack_name deleted successfully."
        sleep 10
    else
@@ -1,3 +1,4 @@
#!/bin/sh
# fill in the template
export ONYX_BACKEND_API_HOST="${ONYX_BACKEND_API_HOST:-api_server}"
export ONYX_WEB_SERVER_HOST="${ONYX_WEB_SERVER_HOST:-web_server}"
@@ -16,12 +17,15 @@ echo "Using web server host: $ONYX_WEB_SERVER_HOST"
echo "Using MCP server host: $ONYX_MCP_SERVER_HOST"
echo "Using nginx proxy timeouts - connect: ${NGINX_PROXY_CONNECT_TIMEOUT}s, send: ${NGINX_PROXY_SEND_TIMEOUT}s, read: ${NGINX_PROXY_READ_TIMEOUT}s"

# shellcheck disable=SC2016
envsubst '$DOMAIN $SSL_CERT_FILE_NAME $SSL_CERT_KEY_FILE_NAME $ONYX_BACKEND_API_HOST $ONYX_WEB_SERVER_HOST $ONYX_MCP_SERVER_HOST $NGINX_PROXY_CONNECT_TIMEOUT $NGINX_PROXY_SEND_TIMEOUT $NGINX_PROXY_READ_TIMEOUT' < "/etc/nginx/conf.d/$1" > /etc/nginx/conf.d/app.conf

# Conditionally create MCP server configuration
if [ "${MCP_SERVER_ENABLED}" = "True" ] || [ "${MCP_SERVER_ENABLED}" = "true" ]; then
    echo "MCP server is enabled, creating MCP configuration..."
    # shellcheck disable=SC2016
    envsubst '$ONYX_MCP_SERVER_HOST' < "/etc/nginx/conf.d/mcp_upstream.conf.inc.template" > /etc/nginx/conf.d/mcp_upstream.conf.inc
    # shellcheck disable=SC2016
    envsubst '$ONYX_MCP_SERVER_HOST' < "/etc/nginx/conf.d/mcp.conf.inc.template" > /etc/nginx/conf.d/mcp.conf.inc
else
    echo "MCP server is disabled, removing MCP configuration..."