Compare commits

..

2 Commits

Author SHA1 Message Date
Adam Serafin
8f5e685704 fix: initialize tracing in mcp 2026-04-17 12:16:27 +02:00
Adam Serafin
63cdd2c53a recursively extract all text from ADF 2026-04-15 10:20:18 +02:00
487 changed files with 4549 additions and 13849 deletions

View File

@@ -1,7 +1,6 @@
FROM ubuntu:26.04@sha256:cc925e589b7543b910fea57a240468940003fbfc0515245a495dd0ad8fe7cef1
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
curl \
default-jre \
fd-find \
@@ -62,11 +61,3 @@ RUN chsh -s /bin/zsh root && \
echo '[ -f /workspace/.devcontainer/zshrc ] && . /workspace/.devcontainer/zshrc' >> "$rc"; \
done && \
chown dev:dev /home/dev/.zshrc
# Pre-seed GitHub's SSH host keys so git-over-SSH never prompts. Keys are
# pinned in-repo (verified against the fingerprints GitHub publishes at
# https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/githubs-ssh-key-fingerprints)
# rather than fetched at build time, so a compromised build-time network can't
# inject a rogue key.
COPY github_known_hosts /etc/ssh/ssh_known_hosts
RUN chmod 644 /etc/ssh/ssh_known_hosts

View File

@@ -1,11 +1,7 @@
{
"name": "Onyx Dev Sandbox",
"image": "onyxdotapp/onyx-devcontainer@sha256:4986c9252289b660ce772b45f0488b938fe425d8114245e96ef64b273b3fcee4",
"runArgs": [
"--cap-add=NET_ADMIN",
"--cap-add=NET_RAW",
"--network=onyx_default"
],
"image": "onyxdotapp/onyx-devcontainer@sha256:0f02d9299928849c7b15f3b348dcfdcdcb64411ff7a4580cbc026a6ee7aa1554",
"runArgs": ["--cap-add=NET_ADMIN", "--cap-add=NET_RAW"],
"mounts": [
"source=${localEnv:HOME}/.claude,target=/home/dev/.claude,type=bind",
"source=${localEnv:HOME}/.claude.json,target=/home/dev/.claude.json,type=bind",
@@ -16,13 +12,10 @@
"source=onyx-devcontainer-local,target=/home/dev/.local,type=volume"
],
"containerEnv": {
"SSH_AUTH_SOCK": "/tmp/ssh-agent.sock",
"POSTGRES_HOST": "relational_db",
"REDIS_HOST": "cache"
"SSH_AUTH_SOCK": "/tmp/ssh-agent.sock"
},
"remoteUser": "${localEnv:DEVCONTAINER_REMOTE_USER:dev}",
"updateRemoteUserUID": false,
"initializeCommand": "docker network create onyx_default 2>/dev/null || true",
"workspaceMount": "source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=delegated",
"workspaceFolder": "/workspace",
"postStartCommand": "sudo bash /workspace/.devcontainer/init-dev-user.sh && sudo bash /workspace/.devcontainer/init-firewall.sh",

View File

@@ -1,3 +0,0 @@
github.com ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCj7ndNxQowgcQnjshcLrqPEiiphnt+VTTvDP6mHBL9j1aNUkY4Ue1gvwnGLVlOhGeYrnZaMgRK6+PKCUXaDbC7qtbW8gIkhL7aGCsOr/C56SJMy/BCZfxd1nWzAOxSDPgVsmerOBYfNqltV9/hWCqBywINIR+5dIg6JTJ72pcEpEjcYgXkE2YEFXV1JHnsKgbLWNlhScqb2UmyRkQyytRLtL+38TGxkxCflmO+5Z8CSSNY7GidjMIZ7Q4zMjA2n1nGrlTDkzwDCsw+wqFPGQA179cnfGWOWRVruj16z6XyvxvjJwbz0wQZ75XK5tKSb7FNyeIEs4TT4jk+S4dhPeAUC5y+bDYirYgM4GC7uEnztnZyaVWQ7B381AK4Qdrwt51ZqExKbQpTUNn+EjqoTwvqNj4kqx5QUCI0ThS/YkOxJCXmPUWZbhjpCg56i+2aB6CmK2JGhn57K5mj0MNdBXA4/WnwH6XoPWJzK5Nyu2zB3nAZp+S5hpQs+p1vN1/wsjk=
github.com ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBEmKSENjQEezOmxkZMy7opKgwFB9nkt5YRrYMjNuG5N87uRgg6CLrbo5wAdT/y6v0mKV0U2w0WZ2YB/++Tpockg=
github.com ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIOMqqnkVzrm0SdG6UOoqKLsabgH5C9okWi0dh2l9GKJl

View File

@@ -4,23 +4,22 @@ set -euo pipefail
echo "Setting up firewall..."
# Reset default policies to ACCEPT before flushing rules. On re-runs the
# previous invocation's DROP policies are still in effect; flushing rules while
# the default is DROP would block the DNS lookups below. Register a trap so
# that if the script exits before the DROP policies are re-applied at the end,
# we fail closed instead of leaving the container with an unrestricted
# firewall.
trap 'iptables -P INPUT DROP; iptables -P OUTPUT DROP; iptables -P FORWARD DROP' EXIT
iptables -P INPUT ACCEPT
iptables -P OUTPUT ACCEPT
iptables -P FORWARD ACCEPT
# Preserve docker dns resolution
DOCKER_DNS_RULES=$(iptables-save | grep -E "^-A.*-d 127.0.0.11/32" || true)
# Only flush the filter table. The nat and mangle tables are managed by Docker
# (DNS DNAT to 127.0.0.11, container networking, etc.) and must not be touched —
# flushing them breaks Docker's embedded DNS resolver.
# Flush all rules
iptables -t nat -F
iptables -t nat -X
iptables -t mangle -F
iptables -t mangle -X
iptables -F
iptables -X
# Restore docker dns rules
if [ -n "$DOCKER_DNS_RULES" ]; then
echo "$DOCKER_DNS_RULES" | iptables-restore -n
fi
# Create ipset for allowed destinations
ipset create allowed-domains hash:net || true
ipset flush allowed-domains
@@ -35,7 +34,6 @@ done
# Resolve allowed domains
ALLOWED_DOMAINS=(
"github.com"
"registry.npmjs.org"
"api.anthropic.com"
"api-staging.anthropic.com"
@@ -45,16 +43,8 @@ ALLOWED_DOMAINS=(
"pypi.org"
"files.pythonhosted.org"
"go.dev"
"proxy.golang.org"
"sum.golang.org"
"storage.googleapis.com"
"dl.google.com"
"static.rust-lang.org"
"index.crates.io"
"static.crates.io"
"archive.ubuntu.com"
"security.ubuntu.com"
"deb.nodesource.com"
)
for domain in "${ALLOWED_DOMAINS[@]}"; do
@@ -75,14 +65,6 @@ if [ -n "$DOCKER_GATEWAY" ]; then
fi
fi
# Allow traffic to all attached Docker network subnets so the container can
# reach sibling services (e.g. relational_db, cache) on shared compose networks.
for subnet in $(ip -4 -o addr show scope global | awk '{print $4}'); do
if ! ipset add allowed-domains "$subnet" -exist 2>&1; then
echo "warning: failed to add Docker subnet $subnet to allowlist" >&2
fi
done
# Set default policies to DROP
iptables -P FORWARD DROP
iptables -P INPUT DROP

View File

@@ -462,7 +462,7 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
@@ -472,7 +472,7 @@ jobs:
- name: Build and push AMD64
id: build
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./web
file: ./web/Dockerfile
@@ -536,7 +536,7 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
@@ -546,7 +546,7 @@ jobs:
- name: Build and push ARM64
id: build
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./web
file: ./web/Dockerfile
@@ -597,7 +597,7 @@ jobs:
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
@@ -676,7 +676,7 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
@@ -686,7 +686,7 @@ jobs:
- name: Build and push AMD64
id: build
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./web
file: ./web/Dockerfile
@@ -761,7 +761,7 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
@@ -771,7 +771,7 @@ jobs:
- name: Build and push ARM64
id: build
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./web
file: ./web/Dockerfile
@@ -833,7 +833,7 @@ jobs:
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
@@ -908,7 +908,7 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
@@ -918,7 +918,7 @@ jobs:
- name: Build and push AMD64
id: build
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
@@ -981,7 +981,7 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
@@ -991,7 +991,7 @@ jobs:
- name: Build and push ARM64
id: build
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
@@ -1041,7 +1041,7 @@ jobs:
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
@@ -1119,7 +1119,7 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
@@ -1129,7 +1129,7 @@ jobs:
- name: Build and push AMD64
id: build
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
@@ -1192,7 +1192,7 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
@@ -1202,7 +1202,7 @@ jobs:
- name: Build and push ARM64
id: build
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
@@ -1253,7 +1253,7 @@ jobs:
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
@@ -1329,7 +1329,7 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
with:
buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
@@ -1341,7 +1341,7 @@ jobs:
- name: Build and push AMD64
id: build
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
env:
DEBUG: ${{ vars.DOCKER_DEBUG == 'true' && 1 || 0 }}
with:
@@ -1409,7 +1409,7 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
with:
buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
@@ -1421,7 +1421,7 @@ jobs:
- name: Build and push ARM64
id: build
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
env:
DEBUG: ${{ vars.DOCKER_DEBUG == 'true' && 1 || 0 }}
with:
@@ -1475,7 +1475,7 @@ jobs:
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3

View File

@@ -21,7 +21,7 @@ jobs:
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3

View File

@@ -21,7 +21,7 @@ jobs:
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3

View File

@@ -115,7 +115,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -127,7 +127,7 @@ jobs:
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push Backend Docker image
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
@@ -175,7 +175,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -187,7 +187,7 @@ jobs:
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push Model Server Docker image
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile.model_server
@@ -220,7 +220,7 @@ jobs:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit

View File

@@ -94,7 +94,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
@@ -105,7 +105,7 @@ jobs:
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push Web Docker image
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./web
file: ./web/Dockerfile
@@ -155,7 +155,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
@@ -166,7 +166,7 @@ jobs:
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push Backend Docker image
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
@@ -216,7 +216,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
@@ -227,7 +227,7 @@ jobs:
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push Model Server Docker image
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile.model_server

View File

@@ -19,16 +19,16 @@ permissions:
jobs:
mypy-check:
# See https://runs-on.com/runners/linux/
# NOTE: This job is named mypy-check for branch protection compatibility,
# but it actually runs ty (astral-sh's Rust type checker).
# Note: Mypy seems quite optimized for x64 compared to arm64.
# Similarly, mypy is single-threaded and incremental, so 2cpu is sufficient.
runs-on:
[
runs-on,
runner=2cpu-linux-arm64,
runner=2cpu-linux-x64,
"run-id=${{ github.run_id }}-mypy-check",
"extras=s3-cache",
]
timeout-minutes: 15
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -46,7 +46,26 @@ jobs:
backend/requirements/model_server.txt
backend/requirements/ee.txt
- name: Run ty
- name: Generate OpenAPI schema and Python client
shell: bash
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
env:
LICENSE_ENFORCEMENT_ENABLED: "false"
run: |
ods openapi all
- name: Cache mypy cache
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
with:
path: .mypy_cache
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}
restore-keys: |
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
mypy-${{ runner.os }}-
- name: Run MyPy
env:
MYPY_FORCE_COLOR: 1
TERM: xterm-256color
run: ty check --output-format github
run: mypy .

View File

@@ -17,6 +17,8 @@ env:
# API keys for testing
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
LITELLM_API_KEY: ${{ secrets.LITELLM_API_KEY }}
LITELLM_API_URL: ${{ secrets.LITELLM_API_URL }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
AZURE_API_URL: ${{ vars.AZURE_API_URL }}
@@ -69,7 +71,7 @@ jobs:
password: ${{ secrets.DOCKER_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f
- name: Build and load
uses: docker/bake-action@82490499d2e5613fcead7e128237ef0b0ea210f7 # ratchet:docker/bake-action@v7.0.0

View File

@@ -132,7 +132,7 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
@@ -142,7 +142,7 @@ jobs:
- name: Build and push AMD64
id: build
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend/onyx/server/features/build/sandbox/kubernetes/docker
file: ./backend/onyx/server/features/build/sandbox/kubernetes/docker/Dockerfile
@@ -202,7 +202,7 @@ jobs:
latest=false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
@@ -212,7 +212,7 @@ jobs:
- name: Build and push ARM64
id: build
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # ratchet:docker/build-push-action@v7
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend/onyx/server/features/build/sandbox/kubernetes/docker
file: ./backend/onyx/server/features/build/sandbox/kubernetes/docker/Dockerfile
@@ -258,7 +258,7 @@ jobs:
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3

View File

@@ -67,11 +67,12 @@ repos:
args: ["--active", "--with=onyx-devtools", "ods", "check-lazy-imports"]
pass_filenames: true
files: ^backend/(?!\.venv/|scripts/).*\.py$
- id: uv-run
name: ty
args: ["ty", "check"]
pass_filenames: true
types_or: [python]
# NOTE: This takes ~6s on a single, large module which is prohibitively slow.
# - id: uv-run
# name: mypy
# args: ["--all-extras", "mypy"]
# pass_filenames: true
# files: ^backend/.*\.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
@@ -141,7 +142,6 @@ repos:
hooks:
- id: ripsecrets
args:
- --strict-ignore
- --additional-pattern
- ^sk-[A-Za-z0-9_\-]{20,}$

View File

@@ -1 +0,0 @@
.devcontainer/github_known_hosts

12
.vscode/launch.json vendored
View File

@@ -475,18 +475,6 @@
"order": 0
}
},
{
"name": "Start Monitoring Stack (Prometheus + Grafana)",
"type": "node",
"request": "launch",
"runtimeExecutable": "docker",
"runtimeArgs": ["compose", "up", "-d"],
"cwd": "${workspaceFolder}/profiling",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
},
{
"name": "Clear and Restart External Volumes and Containers",
"type": "node",

View File

@@ -63,13 +63,11 @@ Your features must pass all tests and all comments must be addressed prior to me
### Implicit agreements
If we approve an issue, we are promising you the following:
- Your work will receive timely attention and we will put aside other important items to ensure you are not blocked.
- You will receive necessary coaching on eng quality, system design, etc. to ensure the feature is completed well.
- The Onyx team will pull resources and bandwidth from design, PM, and engineering to ensure that you have all the resources to build the feature to the quality required for merging.
Because this is a large investment from our team, we ask that you:
- Thoroughly read all the requirements of the design docs, engineering best practices, and try to minimize overhead for the Onyx team.
- Complete the feature in a timely manner to reduce context switching and an ongoing resource pull from the Onyx team.
@@ -151,10 +149,10 @@ Set up pre-commit hooks (black / reorder-python-imports):
uv run pre-commit install
```
We also use `ty` for static type checking. Onyx is fully type-annotated, and we want to keep it that way! To run the ty checks manually:
We also use `mypy` for static type checking. Onyx is fully type-annotated, and we want to keep it that way! To run the mypy checks manually:
```bash
uv run ty check
uv run mypy . # from onyx/backend
```
#### Frontend
@@ -194,7 +192,6 @@ Before starting, make sure the Docker Daemon is running.
> **Note:** "Clear and Restart External Volumes and Containers" will reset your Postgres and OpenSearch (relational-db and index). Only run this if you are okay with wiping your data.
**Features:**
- Hot reload is enabled for the web server and API servers
- Python debugging is configured with debugpy
- Environment variables are loaded from `.vscode/.env`
@@ -347,16 +344,13 @@ sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
### Style and Maintainability
#### Comments and readability
Add clear comments:
- At logical boundaries (e.g., interfaces) so the reader doesn't need to dig 10 layers deeper.
- Wherever assumptions are made or something non-obvious/unexpected is done.
- For complicated flows/functions.
- Wherever it saves time (e.g., nontrivial regex patterns).
#### Errors and exceptions
- **Fail loudly** rather than silently skipping work.
- Example: raise and let exceptions propagate instead of silently dropping a document.
- **Don't overuse `try/except`.**
@@ -364,7 +358,6 @@ Add clear comments:
- Do not mask exceptions unless it is clearly appropriate.
#### Typing
- Everything should be **as strictly typed as possible**.
- Use `cast` for annoying/loose-typed interfaces (e.g., results of `run_functions_tuples_in_parallel`).
- Only `cast` when the type checker sees `Any` or types are too loose.
@@ -375,7 +368,6 @@ Add clear comments:
- `dict[EmbeddingModel, list[EmbeddingVector]]`
#### State, objects, and boundaries
- Keep **clear logical boundaries** for state containers and objects.
- A **config** object should never contain things like a `db_session`.
- Avoid state containers that are overly nested, or huge + flat (use judgment).
@@ -388,7 +380,6 @@ Add clear comments:
- Prefer **hash maps (dicts)** over tree structures unless there's a strong reason.
#### Naming
- Name variables carefully and intentionally.
- Prefer long, explicit names when undecided.
- Avoid single-character variables except for small, self-contained utilities (or not at all).
@@ -399,7 +390,6 @@ Add clear comments:
- IntelliSense can miss call sites; search works best with unique names.
#### Correctness by construction
- Prefer self-contained correctness — don't rely on callers to "use it right" if you can make misuse hard.
- Avoid redundancies: if a function takes an arg, it shouldn't also take a state object that contains that same arg.
- No dead code (unless there's a very good reason).
@@ -427,35 +417,29 @@ Add clear comments:
### Repository Conventions
#### Where code lives
- Pydantic + data models: `models.py` files.
- DB interface functions (excluding lazy loading): `db/` directory.
- LLM prompts: `prompts/` directory, roughly mirroring the code layout that uses them.
- API routes: `server/` directory.
#### Pydantic and modeling
- Prefer **Pydantic** over dataclasses.
- If absolutely required, use `allow_arbitrary_types`.
#### Data conventions
- Prefer explicit `None` over sentinel empty strings (usually; depends on intent).
- Prefer explicit identifiers: use string enums instead of integer codes.
- Avoid magic numbers (co-location is good when necessary). **Always avoid magic strings.**
#### Logging
- Log messages where they are created.
- Don't propagate log messages around just to log them elsewhere.
#### Encapsulation
- Don't use private attributes/methods/properties from other classes/modules.
- "Private" is private — respect that boundary.
#### SQLAlchemy guidance
- Lazy loading is often bad at scale, especially across multiple list relationships.
- Be careful when accessing SQLAlchemy object attributes:
- It can help avoid redundant DB queries,
@@ -464,7 +448,6 @@ Add clear comments:
- Reference: https://www.reddit.com/r/SQLAlchemy/comments/138f248/joinedload_vs_selectinload/
#### Trunk-based development and feature flags
- **PRs should contain no more than 500 lines of real change.**
- **Merge to main frequently.** Avoid long-lived feature branches — they create merge conflicts and integration pain.
- **Use feature flags for incremental rollout.**
@@ -475,7 +458,6 @@ Add clear comments:
- **Test both flag states.** Ensure the codebase works correctly with the flag on and off.
#### Miscellaneous
- Any TODOs you add in the code must be accompanied by either the name/username of the owner of that TODO, or an issue number for an issue referencing that piece of work.
- Avoid module-level logic that runs on import, which leads to import-time side effects. Essentially every piece of meaningful logic should exist within some function that has to be explicitly invoked. Acceptable exceptions may include loading environment variables or setting up loggers.
- If you find yourself needing something like this, you may want that logic to exist in a file dedicated for manual execution (contains `if __name__ == "__main__":`) which should not be imported by anything else.

View File

@@ -50,10 +50,6 @@ COPY ./onyx/utils/logger.py /app/onyx/utils/logger.py
COPY ./onyx/utils/middleware.py /app/onyx/utils/middleware.py
COPY ./onyx/utils/tenant.py /app/onyx/utils/tenant.py
# Sentry configuration (used when SENTRY_DSN is set)
COPY ./onyx/configs/__init__.py /app/onyx/configs/__init__.py
COPY ./onyx/configs/sentry.py /app/onyx/configs/sentry.py
# Place to fetch version information
COPY ./onyx/__init__.py /app/onyx/__init__.py

View File

@@ -26,9 +26,7 @@ from shared_configs.configs import (
TENANT_ID_PREFIX,
)
from onyx.db.models import Base
from celery.backends.database.session import ( # ty: ignore[unresolved-import]
ResultModelBase,
)
from celery.backends.database.session import ResultModelBase # type: ignore
from onyx.db.engine.sql_engine import SqlEngine
# Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be

View File

@@ -49,7 +49,7 @@ def upgrade() -> None:
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
server_onupdate=sa.text("now()"), # ty: ignore[invalid-argument-type]
server_onupdate=sa.text("now()"), # type: ignore
nullable=True,
),
sa.Column(

View File

@@ -68,7 +68,7 @@ def upgrade() -> None:
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
{"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]},
).fetchone()
tool_id = result[0] # ty: ignore[not-subscriptable]
tool_id = result[0] # type: ignore
# Associate the tool with all existing personas
# Get all persona IDs

View File

@@ -52,7 +52,7 @@ def upgrade() -> None:
sa.Column(
"created_at",
sa.DateTime(),
default=lambda: datetime.datetime.now(datetime.timezone.utc),
default=datetime.datetime.utcnow,
),
sa.Column(
"cc_pair_id",

View File

@@ -10,7 +10,7 @@ from alembic import op
import sqlalchemy as sa
from sqlalchemy.orm import Session
from sqlalchemy import text
from typing import cast
from typing import cast, Any
from botocore.exceptions import ClientError
@@ -255,7 +255,7 @@ def _migrate_files_to_external_storage() -> None:
continue
lobj_id = cast(int, file_record.lobj_oid)
file_metadata = file_record.file_metadata
file_metadata = cast(Any, file_record.file_metadata)
# Read file content from PostgreSQL
try:

View File

@@ -112,7 +112,7 @@ def _get_access_for_documents(
access_map[document_id] = DocumentAccess.build(
user_emails=list(non_ee_access.user_emails),
user_groups=user_group_info.get(document_id, []),
is_public=is_public_anywhere, # ty: ignore[invalid-argument-type]
is_public=is_public_anywhere,
external_user_emails=list(ext_u_emails),
external_user_group_ids=list(ext_u_groups),
)

View File

@@ -1,6 +1,5 @@
import os
from datetime import datetime
from datetime import timezone
import jwt
from fastapi import Depends
@@ -59,7 +58,7 @@ def generate_anonymous_user_jwt_token(tenant_id: str) -> str:
payload = {
"tenant_id": tenant_id,
# Token does not expire
"iat": datetime.now(timezone.utc), # Issued at time
"iat": datetime.utcnow(), # Issued at time
}
return jwt.encode(payload, USER_AUTH_SECRET, algorithm="HS256")

View File

@@ -80,7 +80,6 @@ from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyn
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_tenant_work_gating import maybe_mark_tenant_active
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.server.utils import make_short_id
from onyx.utils.logger import doc_permission_sync_ctx
@@ -209,11 +208,6 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
if _is_external_doc_permissions_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
# Tenant-work-gating hook: refresh this tenant's active-set membership
# whenever doc-permission sync has any due cc_pairs to dispatch.
if cc_pair_ids_to_sync:
maybe_mark_tenant_active(tenant_id)
lock_beat.reacquire()
for cc_pair_id in cc_pair_ids_to_sync:
payload_id = try_creating_permissions_sync_task(

View File

@@ -69,7 +69,6 @@ from onyx.redis.redis_connector_ext_group_sync import (
)
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_tenant_work_gating import maybe_mark_tenant_active
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.server.utils import make_short_id
from onyx.utils.logger import format_error_for_logging
@@ -203,11 +202,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
if _is_external_group_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
# Tenant-work-gating hook: refresh this tenant's active-set membership
# whenever external-group sync has any due cc_pairs to dispatch.
if cc_pair_ids_to_sync:
maybe_mark_tenant_active(tenant_id)
lock_beat.reacquire()
for cc_pair_id in cc_pair_ids_to_sync:
payload_id = try_creating_external_group_sync_task(

View File

@@ -53,7 +53,7 @@ def fetch_query_analytics(
.order_by(cast(ChatMessage.time_sent, Date))
)
return db_session.execute(stmt).all() # ty: ignore[invalid-return-type]
return db_session.execute(stmt).all() # type: ignore
def fetch_per_user_query_analytics(
@@ -92,7 +92,7 @@ def fetch_per_user_query_analytics(
.order_by(cast(ChatMessage.time_sent, Date), ChatSession.user_id)
)
return db_session.execute(stmt).all() # ty: ignore[invalid-return-type]
return db_session.execute(stmt).all() # type: ignore
def fetch_onyxbot_analytics(

View File

@@ -9,7 +9,7 @@ logger = setup_logger()
def fetch_sources_with_connectors(db_session: Session) -> list[DocumentSource]:
sources = db_session.query(distinct(Connector.source)).all()
sources = db_session.query(distinct(Connector.source)).all() # type: ignore
document_sources = [source[0] for source in sources]

View File

@@ -128,9 +128,9 @@ def get_used_seats(tenant_id: str | None = None) -> int:
select(func.count())
.select_from(User)
.where(
User.is_active == True, # noqa: E712
User.is_active == True, # type: ignore # noqa: E712
User.role != UserRole.EXT_PERM_USER,
User.email != ANONYMOUS_USER_EMAIL,
User.email != ANONYMOUS_USER_EMAIL, # type: ignore
User.account_type != AccountType.SERVICE_ACCOUNT,
)
)

View File

@@ -121,7 +121,7 @@ class ScimDAL(DAL):
"""Update the last_used_at timestamp for a token."""
token = self._session.get(ScimToken, token_id)
if token:
token.last_used_at = func.now()
token.last_used_at = func.now() # type: ignore[assignment]
# ------------------------------------------------------------------
# User mapping operations
@@ -229,7 +229,7 @@ class ScimDAL(DAL):
def get_user(self, user_id: UUID) -> User | None:
"""Fetch a user by ID."""
return self._session.scalar(
select(User).where(User.id == user_id) # ty: ignore[invalid-argument-type]
select(User).where(User.id == user_id) # type: ignore[arg-type]
)
def get_user_by_email(self, email: str) -> User | None:
@@ -293,22 +293,16 @@ class ScimDAL(DAL):
if attr == "username":
# arg-type: fastapi-users types User.email as str, not a column expression
# assignment: union return type widens but query is still Select[tuple[User]]
query = _apply_scim_string_op(
query, User.email, scim_filter # ty: ignore[invalid-argument-type]
)
query = _apply_scim_string_op(query, User.email, scim_filter) # type: ignore[arg-type, assignment]
elif attr == "active":
query = query.where(
User.is_active.is_( # ty: ignore[unresolved-attribute]
scim_filter.value.lower() == "true"
)
User.is_active.is_(scim_filter.value.lower() == "true") # type: ignore[attr-defined]
)
elif attr == "externalid":
mapping = self.get_user_mapping_by_external_id(scim_filter.value)
if not mapping:
return [], 0
query = query.where(
User.id == mapping.user_id # ty: ignore[invalid-argument-type]
)
query = query.where(User.id == mapping.user_id) # type: ignore[arg-type]
else:
raise ValueError(
f"Unsupported filter attribute: {scim_filter.attribute}"
@@ -324,9 +318,7 @@ class ScimDAL(DAL):
offset = max(start_index - 1, 0)
users = list(
self._session.scalars(
query.order_by(User.id) # ty: ignore[invalid-argument-type]
.offset(offset)
.limit(count)
query.order_by(User.id).offset(offset).limit(count) # type: ignore[arg-type]
)
.unique()
.all()
@@ -585,7 +577,7 @@ class ScimDAL(DAL):
attr = scim_filter.attribute.lower()
if attr == "displayname":
# assignment: union return type widens but query is still Select[tuple[UserGroup]]
query = _apply_scim_string_op(query, UserGroup.name, scim_filter)
query = _apply_scim_string_op(query, UserGroup.name, scim_filter) # type: ignore[assignment]
elif attr == "externalid":
mapping = self.get_group_mapping_by_external_id(scim_filter.value)
if not mapping:
@@ -623,9 +615,7 @@ class ScimDAL(DAL):
users = (
self._session.scalars(
select(User).where(
User.id.in_(user_ids) # ty: ignore[unresolved-attribute]
)
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
)
.unique()
.all()
@@ -650,9 +640,7 @@ class ScimDAL(DAL):
return []
existing_users = (
self._session.scalars(
select(User).where(
User.id.in_(uuids) # ty: ignore[unresolved-attribute]
)
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
)
.unique()
.all()

View File

@@ -300,11 +300,8 @@ def fetch_user_groups_for_user(
stmt = (
select(UserGroup)
.join(User__UserGroup, User__UserGroup.user_group_id == UserGroup.id)
.join(
User,
User.id == User__UserGroup.user_id, # ty: ignore[invalid-argument-type]
)
.where(User.id == user_id) # ty: ignore[invalid-argument-type]
.join(User, User.id == User__UserGroup.user_id) # type: ignore
.where(User.id == user_id) # type: ignore
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
@@ -433,7 +430,7 @@ def fetch_user_groups_for_documents(
.group_by(Document.id)
)
return db_session.execute(stmt).all() # ty: ignore[invalid-return-type]
return db_session.execute(stmt).all() # type: ignore
def _check_user_group_is_modifiable(user_group: UserGroup) -> None:
@@ -807,9 +804,7 @@ def update_user_group(
db_user_group.is_up_to_date = False
removed_users = db_session.scalars(
select(User).where(
User.id.in_(removed_user_ids) # ty: ignore[unresolved-attribute]
)
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
).unique()
# Filter out admin and global curator users before validating curator status

View File

@@ -1,6 +1,6 @@
from collections.abc import Iterator
from googleapiclient.discovery import Resource
from googleapiclient.discovery import Resource # type: ignore
from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
from ee.onyx.external_permissions.google_drive.permission_retrieval import (
@@ -38,7 +38,7 @@ def get_folder_permissions_by_ids(
A list of permissions matching the provided permission IDs
"""
return get_permissions_by_ids(
drive_service=service, # ty: ignore[invalid-argument-type]
drive_service=service,
doc_id=folder_id,
permission_ids=permission_ids,
)
@@ -68,7 +68,7 @@ def get_modified_folders(
# Retrieve and yield folders
for folder in execute_paginated_retrieval(
retrieval_function=service.files().list, # ty: ignore[unresolved-attribute]
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="allDrives",

View File

@@ -1,6 +1,6 @@
from collections.abc import Generator
from googleapiclient.errors import HttpError
from googleapiclient.errors import HttpError # type: ignore
from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
@@ -183,7 +183,7 @@ def _get_drive_members(
)
admin_user_info = (
admin_service.users() # ty: ignore[unresolved-attribute]
admin_service.users()
.get(userKey=google_drive_connector.primary_admin_email)
.execute()
)
@@ -197,7 +197,7 @@ def _get_drive_members(
try:
for permission in execute_paginated_retrieval(
drive_service.permissions().list, # ty: ignore[unresolved-attribute]
drive_service.permissions().list,
list_key="permissions",
fileId=drive_id,
fields="permissions(emailAddress, type),nextPageToken",
@@ -256,7 +256,7 @@ def _get_all_google_groups(
"""
group_emails: set[str] = set()
for group in execute_paginated_retrieval(
admin_service.groups().list, # ty: ignore[unresolved-attribute]
admin_service.groups().list,
list_key="groups",
domain=google_domain,
fields="groups(email),nextPageToken",
@@ -274,7 +274,7 @@ def _google_group_to_onyx_group(
"""
group_member_emails: set[str] = set()
for member in execute_paginated_retrieval(
admin_service.members().list, # ty: ignore[unresolved-attribute]
admin_service.members().list,
list_key="members",
groupKey=group_email,
fields="members(email),nextPageToken",
@@ -298,7 +298,7 @@ def _map_group_email_to_member_emails(
for group_email in group_emails:
group_member_emails: set[str] = set()
for member in execute_paginated_retrieval(
admin_service.members().list, # ty: ignore[unresolved-attribute]
admin_service.members().list,
list_key="members",
groupKey=group_email,
fields="members(email),nextPageToken",

View File

@@ -33,7 +33,7 @@ def get_permissions_by_ids(
# Fetch all permissions for the document
fetched_permissions = execute_paginated_retrieval(
retrieval_function=drive_service.permissions().list, # ty: ignore[unresolved-attribute]
retrieval_function=drive_service.permissions().list,
list_key="permissions",
fileId=doc_id,
fields="permissions(id, emailAddress, type, domain, allowFileDiscovery, permissionDetails),nextPageToken",

View File

@@ -68,7 +68,7 @@ def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
logger.warning(f"Expected a 'raw' field, but none was found: {raw_perm=}")
continue
permission = Permission(**raw_perm.raw) # ty: ignore[invalid-argument-type]
permission = Permission(**raw_perm.raw)
# We only care about ability to browse through projects + issues (not other permissions such as read/write).
if permission.permission != "BROWSE_PROJECTS":

View File

@@ -1,6 +1,6 @@
from collections.abc import Generator
from office365.sharepoint.client_context import ClientContext
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.sharepoint.permission_utils import (

View File

@@ -7,11 +7,11 @@ from typing import Any
from urllib.parse import urlparse
import requests as _requests
from office365.graph_client import GraphClient
from office365.onedrive.driveitems.driveItem import DriveItem
from office365.runtime.client_request import ClientRequestException
from office365.sharepoint.client_context import ClientContext
from office365.sharepoint.permissions.securable_object import RoleAssignmentCollection
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.runtime.client_request import ClientRequestException # type: ignore
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from office365.sharepoint.permissions.securable_object import RoleAssignmentCollection # type: ignore[import-untyped]
from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup

View File

@@ -46,10 +46,9 @@ def get_query_analytics(
daily_query_usage_info = fetch_query_analytics(
start=start
or (
datetime.datetime.now(tz=datetime.timezone.utc)
- datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
), # default is 30d lookback
end=end or datetime.datetime.now(tz=datetime.timezone.utc),
end=end or datetime.datetime.utcnow(),
db_session=db_session,
)
return [
@@ -78,10 +77,9 @@ def get_user_analytics(
daily_query_usage_info_per_user = fetch_per_user_query_analytics(
start=start
or (
datetime.datetime.now(tz=datetime.timezone.utc)
- datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
), # default is 30d lookback
end=end or datetime.datetime.now(tz=datetime.timezone.utc),
end=end or datetime.datetime.utcnow(),
db_session=db_session,
)
@@ -113,10 +111,9 @@ def get_onyxbot_analytics(
daily_onyxbot_info = fetch_onyxbot_analytics(
start=start
or (
datetime.datetime.now(tz=datetime.timezone.utc)
- datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
), # default is 30d lookback
end=end or datetime.datetime.now(tz=datetime.timezone.utc),
end=end or datetime.datetime.utcnow(),
db_session=db_session,
)
@@ -149,10 +146,9 @@ def get_persona_messages(
) -> list[PersonaMessageAnalyticsResponse]:
"""Fetch daily message counts for a single persona within the given time range."""
start = start or (
datetime.datetime.now(tz=datetime.timezone.utc)
- datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
)
end = end or datetime.datetime.now(tz=datetime.timezone.utc)
end = end or datetime.datetime.utcnow()
persona_message_counts = []
for count, date in fetch_persona_message_analytics(
@@ -230,10 +226,9 @@ def get_assistant_stats(
along with the overall total messages and total distinct users.
"""
start = start or (
datetime.datetime.now(tz=datetime.timezone.utc)
- datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
)
end = end or datetime.datetime.now(tz=datetime.timezone.utc)
end = end or datetime.datetime.utcnow()
if not user_can_view_assistant_stats(db_session, user, assistant_id):
raise HTTPException(

View File

@@ -287,10 +287,8 @@ def update_hook(
validated_is_reachable: bool | None = None
if endpoint_url_changing or api_key_changing or timeout_changing:
existing = _get_hook_or_404(db_session, hook_id)
effective_url: str = ( # ty: ignore[invalid-assignment]
req.endpoint_url
if endpoint_url_changing
else existing.endpoint_url # endpoint_url is required on create and cannot be cleared on update
effective_url: str = (
req.endpoint_url if endpoint_url_changing else existing.endpoint_url # type: ignore[assignment] # endpoint_url is required on create and cannot be cleared on update
)
effective_api_key: str | None = (
(api_key if not isinstance(api_key, UnsetType) else None)
@@ -301,10 +299,8 @@ def update_hook(
else None
)
)
effective_timeout: float = ( # ty: ignore[invalid-assignment]
req.timeout_seconds
if timeout_changing
else existing.timeout_seconds # req.timeout_seconds is non-None when timeout_changing (validated by HookUpdateRequest)
effective_timeout: float = (
req.timeout_seconds if timeout_changing else existing.timeout_seconds # type: ignore[assignment] # req.timeout_seconds is non-None when timeout_changing (validated by HookUpdateRequest)
)
validation = _validate_endpoint(
endpoint_url=effective_url,

View File

@@ -97,7 +97,7 @@ def fetch_and_process_chat_session_history(
break
paged_snapshots = parallel_yield(
[ # ty: ignore[invalid-argument-type]
[
yield_snapshot_from_chat_session(
db_session=db_session,
chat_session=chat_session,

View File

@@ -1,6 +1,5 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
import jwt
from fastapi import HTTPException
@@ -20,8 +19,8 @@ def generate_data_plane_token() -> str:
payload = {
"iss": "data_plane",
"exp": datetime.now(tz=timezone.utc) + timedelta(minutes=5),
"iat": datetime.now(tz=timezone.utc),
"exp": datetime.utcnow() + timedelta(minutes=5),
"iat": datetime.utcnow(),
"scope": "api_access",
}

View File

@@ -55,10 +55,8 @@ def run_alembic_migrations(schema_name: str) -> None:
alembic_cfg.attributes["configure_logger"] = False
# Mimic command-line options by adding 'cmd_opts' to the config
alembic_cfg.cmd_opts = SimpleNamespace() # ty: ignore[invalid-assignment]
alembic_cfg.cmd_opts.x = [ # ty: ignore[invalid-assignment]
f"schemas={schema_name}"
]
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
alembic_cfg.cmd_opts.x = [f"schemas={schema_name}"] # type: ignore
# Run migrations programmatically
command.upgrade(alembic_cfg, "head")

View File

@@ -349,9 +349,8 @@ def get_tenant_count(tenant_id: str) -> int:
user_count = (
db_session.query(User)
.filter(
User.email.in_(emails), # ty: ignore[unresolved-attribute]
User.is_active # noqa: E712 # ty: ignore[invalid-argument-type]
== True,
User.email.in_(emails), # type: ignore
User.is_active == True, # type: ignore # noqa: E712
)
.count()
)

View File

@@ -73,7 +73,7 @@ def capture_and_sync_with_alternate_posthog(
cloud_props.pop("onyx_cloud_user_id", None)
posthog.identify(
distinct_id=cloud_user_id, # ty: ignore[possibly-unresolved-reference]
distinct_id=cloud_user_id,
properties=cloud_props,
)
except Exception as e:
@@ -105,7 +105,7 @@ def get_anon_id_from_request(request: Any) -> str | None:
if (cookie_value := request.cookies.get(cookie_name)) and (
parsed := parse_posthog_cookie(cookie_value)
):
return parsed.get("distinct_id") # ty: ignore[possibly-unresolved-reference]
return parsed.get("distinct_id")
return None

View File

@@ -23,7 +23,7 @@
# from shared_configs.model_server_models import IntentResponse
# if TYPE_CHECKING:
# from setfit import SetFitModel
# from setfit import SetFitModel # type: ignore[import-untyped]
# from transformers import PreTrainedTokenizer, BatchEncoding
@@ -423,7 +423,7 @@
# def map_keywords(
# input_ids: torch.Tensor, tokenizer: "PreTrainedTokenizer", is_keyword: list[bool]
# ) -> list[str]:
# tokens = tokenizer.convert_ids_to_tokens(input_ids)
# tokens = tokenizer.convert_ids_to_tokens(input_ids) # type: ignore
# if not len(tokens) == len(is_keyword):
# raise ValueError("Length of tokens and keyword predictions must match")

View File

@@ -18,7 +18,7 @@
# super().__init__()
# config = DistilBertConfig()
# self.distilbert = DistilBertModel(config)
# config = self.distilbert.config
# config = self.distilbert.config # type: ignore
# # Keyword tokenwise binary classification layer
# self.keyword_classifier = nn.Linear(config.dim, 2)
@@ -85,7 +85,7 @@
# self.config = config
# self.distilbert = DistilBertModel(config)
# config = self.distilbert.config
# config = self.distilbert.config # type: ignore
# self.connector_global_classifier = nn.Linear(config.dim, 1)
# self.connector_match_classifier = nn.Linear(config.dim, 1)
# self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

View File

@@ -7,8 +7,8 @@ from email.mime.text import MIMEText
from email.utils import formatdate
from email.utils import make_msgid
import sendgrid
from sendgrid.helpers.mail import Attachment
import sendgrid # type: ignore
from sendgrid.helpers.mail import Attachment # type: ignore
from sendgrid.helpers.mail import Content
from sendgrid.helpers.mail import ContentId
from sendgrid.helpers.mail import Disposition

View File

@@ -10,7 +10,7 @@ from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from jwt import decode as jwt_decode
from jwt import InvalidTokenError
from jwt import PyJWTError
from jwt.algorithms import RSAAlgorithm # ty: ignore[possibly-missing-import]
from jwt.algorithms import RSAAlgorithm
from onyx.configs.app_configs import JWT_PUBLIC_KEY_URL
from onyx.utils.logger import setup_logger

View File

@@ -46,10 +46,8 @@ async def _test_expire_oauth_token(
updated_data: Dict[str, Any] = {"expires_at": new_expires_at}
await user_manager.user_db.update_oauth_account( # ty: ignore[invalid-argument-type]
user, # ty: ignore[invalid-argument-type]
cast(Any, oauth_account),
updated_data,
await user_manager.user_db.update_oauth_account(
user, cast(Any, oauth_account), updated_data
)
return True
@@ -134,10 +132,8 @@ async def refresh_oauth_token(
)
# Update the OAuth account
await user_manager.user_db.update_oauth_account( # ty: ignore[invalid-argument-type]
user, # ty: ignore[invalid-argument-type]
cast(Any, oauth_account),
updated_data,
await user_manager.user_db.update_oauth_account(
user, cast(Any, oauth_account), updated_data
)
logger.info(f"Successfully refreshed OAuth token for {user.email}")

View File

@@ -191,7 +191,7 @@ class OAuthTokenManager:
@staticmethod
def _unwrap_sensitive_str(value: SensitiveValue[str] | str) -> str:
if isinstance(value, SensitiveValue):
return value.get_value(apply_mask=False) # ty: ignore[invalid-return-type]
return value.get_value(apply_mask=False)
return value
@staticmethod
@@ -199,7 +199,5 @@ class OAuthTokenManager:
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],
) -> dict[str, Any]:
if isinstance(token_data, SensitiveValue):
return token_data.get_value( # ty: ignore[invalid-return-type]
apply_mask=False
)
return token_data.get_value(apply_mask=False)
return token_data

View File

@@ -121,7 +121,5 @@ def require_permission(
return user
dependency._is_require_permission = ( # ty: ignore[unresolved-attribute]
True # sentinel for auth_check detection
)
dependency._is_require_permission = True # type: ignore[attr-defined] # sentinel for auth_check detection
return dependency

View File

@@ -45,9 +45,7 @@ from fastapi_users import UUIDIDMixin
from fastapi_users.authentication import AuthenticationBackend
from fastapi_users.authentication import CookieTransport
from fastapi_users.authentication import JWTStrategy
from fastapi_users.authentication import (
RedisStrategy, # ty: ignore[possibly-missing-import]
)
from fastapi_users.authentication import RedisStrategy
from fastapi_users.authentication import Strategy
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
@@ -464,16 +462,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
self.user_db = tenant_user_db
if hasattr(user_create, "role"):
user_create.role = UserRole.BASIC # ty: ignore[invalid-assignment]
user_create.role = UserRole.BASIC
user_count = await get_user_count()
if (
user_count == 0
or user_create.email in get_default_admin_user_emails()
):
user_create.role = ( # ty: ignore[invalid-assignment]
UserRole.ADMIN
)
user_create.role = UserRole.ADMIN
# Check seat availability for new users (single-tenant only)
with get_session_with_current_tenant() as sync_db:
@@ -520,9 +516,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Expire so the async session re-fetches the row updated by
# the sync session above.
self.user_db.session.expire(user)
user = await self.user_db.get( # ty: ignore[invalid-assignment]
user_id
)
user = await self.user_db.get(user_id) # type: ignore[assignment]
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
@@ -550,9 +544,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Expire so the async session re-fetches the row updated by
# the sync session above.
self.user_db.session.expire(user)
user = await self.user_db.get( # ty: ignore[invalid-assignment]
user_id
)
user = await self.user_db.get(user_id) # type: ignore[assignment]
if user_created:
await self._assign_default_pinned_assistants(user, db_session)
remove_user_from_invited_users(user_create.email)
@@ -600,11 +592,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
update nor the group assignment is visible without the other.
"""
with get_session_with_current_tenant() as sync_db:
sync_user = (
sync_db.query(User)
.filter(User.id == user_id) # ty: ignore[invalid-argument-type]
.first()
)
sync_user = sync_db.query(User).filter(User.id == user_id).first() # type: ignore[arg-type]
if sync_user:
sync_user.hashed_password = self.password_helper.hash(
user_create.password
@@ -625,9 +613,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_id,
)
async def validate_password( # ty: ignore[invalid-method-override]
self, password: str, _: schemas.UC | models.UP
) -> None:
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
# Validate password according to configurable security policy (defined via environment variables)
if len(password) < PASSWORD_MIN_LENGTH:
raise exceptions.InvalidPasswordException(
@@ -658,7 +644,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
return
@log_function_time(print_only=True)
async def oauth_callback( # ty: ignore[invalid-method-override]
async def oauth_callback(
self,
oauth_name: str,
access_token: str,
@@ -768,7 +754,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account, # ty: ignore[invalid-argument-type]
existing_oauth_account, # type: ignore
oauth_account_dict,
)
@@ -802,11 +788,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# transaction so neither change is visible without the other.
was_inactive = not user.is_active
with get_session_with_current_tenant() as sync_db:
sync_user = (
sync_db.query(User)
.filter(User.id == user.id) # ty: ignore[invalid-argument-type]
.first()
)
sync_user = sync_db.query(User).filter(User.id == user.id).first() # type: ignore[arg-type]
if sync_user:
sync_user.is_verified = is_verified_by_default
sync_user.role = UserRole.BASIC
@@ -826,7 +808,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# otherwise, the oidc expiry will always be old, and the user will never be able to login
if user.oidc_expiry is not None and not TRACK_EXTERNAL_IDP_EXPIRY:
await self.user_db.update(user, {"oidc_expiry": None})
user.oidc_expiry = None # ty: ignore[invalid-assignment]
user.oidc_expiry = None # type: ignore
remove_user_from_invited_users(user.email)
if token:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@@ -943,11 +925,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
and (marketing_cookie_value := request.cookies.get(marketing_cookie_name))
and (parsed_cookie := parse_posthog_cookie(marketing_cookie_value))
):
marketing_anonymous_id = (
parsed_cookie[ # ty: ignore[possibly-unresolved-reference]
"distinct_id"
]
)
marketing_anonymous_id = parsed_cookie["distinct_id"]
# Technically, USER_SIGNED_UP is only fired from the cloud site when
# it is the first user in a tenant. However, it is semantically correct
@@ -964,10 +942,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
}
# Add all other values from the marketing cookie (featureFlags, etc.)
for (
key,
value,
) in parsed_cookie.items(): # ty: ignore[possibly-unresolved-reference]
for key, value in parsed_cookie.items():
if key != "distinct_id":
properties.setdefault(key, value)
@@ -1529,7 +1504,7 @@ async def _sync_jwt_oidc_expiry(
if user.oidc_expiry is not None:
await user_manager.user_db.update(user, {"oidc_expiry": None})
user.oidc_expiry = None # ty: ignore[invalid-assignment]
user.oidc_expiry = None # type: ignore
async def _get_or_create_user_from_jwt(
@@ -2257,7 +2232,7 @@ def get_oauth_router(
# Proceed to authenticate or create the user
try:
user = await user_manager.oauth_callback( # ty: ignore[invalid-argument-type]
user = await user_manager.oauth_callback(
oauth_client.name,
token["access_token"],
account_id,

View File

@@ -6,16 +6,16 @@ from typing import Any
from typing import cast
import sentry_sdk
from celery import bootsteps # ty: ignore[unresolved-import]
from celery import bootsteps # type: ignore
from celery import Task
from celery.app import trace # ty: ignore[unresolved-import]
from celery.app import trace
from celery.exceptions import WorkerShutdown
from celery.signals import before_task_publish
from celery.signals import task_postrun
from celery.signals import task_prerun
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from celery.worker import strategy # ty: ignore[unresolved-import]
from celery.worker import strategy # type: ignore
from redis.lock import Lock as RedisLock
from sentry_sdk.integrations.celery import CeleryIntegration
from sqlalchemy import text

View File

@@ -3,7 +3,7 @@ from typing import Any
from celery import Celery
from celery import signals
from celery.beat import PersistentScheduler # ty: ignore[unresolved-import]
from celery.beat import PersistentScheduler # type: ignore
from celery.signals import beat_init
from celery.utils.log import get_task_logger

View File

@@ -4,4 +4,4 @@ import onyx.background.celery.apps.app_base as app_base
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.client")
celery_app.Task = app_base.TenantAwareTask # ty: ignore[invalid-assignment]
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]

View File

@@ -29,7 +29,7 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.docfetching")
celery_app.Task = app_base.TenantAwareTask # ty: ignore[invalid-assignment]
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
@@ -100,7 +100,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME)
pool_size = cast(int, sender.concurrency) # ty: ignore[unresolved-attribute]
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)

View File

@@ -30,7 +30,7 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.docprocessing")
celery_app.Task = app_base.TenantAwareTask # ty: ignore[invalid-assignment]
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
@@ -106,7 +106,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# "SSL connection has been closed unexpectedly"
# actually setting the spawn method in the cloud fixes 95% of these.
# setting pre ping might help even more, but not worrying about that yet
pool_size = cast(int, sender.concurrency) # ty: ignore[unresolved-attribute]
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)

View File

@@ -27,7 +27,7 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.heavy")
celery_app.Task = app_base.TenantAwareTask # ty: ignore[invalid-assignment]
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
@@ -92,7 +92,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
pool_size = cast(int, sender.concurrency) # ty: ignore[unresolved-attribute]
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)

View File

@@ -29,7 +29,7 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.light")
celery_app.Task = app_base.TenantAwareTask # ty: ignore[invalid-assignment]
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
@@ -95,26 +95,19 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(
f"Concurrency: {sender.concurrency}" # ty: ignore[unresolved-attribute]
)
logger.info(f"Concurrency: {sender.concurrency}") # type: ignore
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(
pool_size=sender.concurrency, # ty: ignore[unresolved-attribute]
max_overflow=EXTRA_CONCURRENCY,
)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY) # type: ignore
if MANAGED_VESPA:
httpx_init_vespa_pool(
sender.concurrency + EXTRA_CONCURRENCY, # ty: ignore[unresolved-attribute]
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
ssl_cert=VESPA_CLOUD_CERT_PATH,
ssl_key=VESPA_CLOUD_KEY_PATH,
)
else:
httpx_init_vespa_pool(
sender.concurrency + EXTRA_CONCURRENCY # ty: ignore[unresolved-attribute]
)
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -20,7 +20,7 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.monitoring")
celery_app.Task = app_base.TenantAwareTask # ty: ignore[invalid-assignment]
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect

View File

@@ -3,7 +3,7 @@ import os
from typing import Any
from typing import cast
from celery import bootsteps # ty: ignore[unresolved-import]
from celery import bootsteps # type: ignore
from celery import Celery
from celery import signals
from celery import Task
@@ -38,12 +38,6 @@ from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
from onyx.server.metrics.metrics_server import start_metrics_server
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
@@ -52,7 +46,7 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.primary")
celery_app.Task = app_base.TenantAwareTask # ty: ignore[invalid-assignment]
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
@@ -65,7 +59,6 @@ def on_task_prerun(
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
on_celery_task_prerun(task_id, task)
@signals.task_postrun.connect
@@ -80,31 +73,6 @@ def on_task_postrun(
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
on_celery_task_postrun(task_id, task, state)
@signals.task_retry.connect
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
task_id = getattr(getattr(sender, "request", None), "id", None)
on_celery_task_retry(task_id, sender)
@signals.task_revoked.connect
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
task_name = getattr(sender, "name", None) or str(sender)
on_celery_task_revoked(kwargs.get("task_id"), task_name)
@signals.task_rejected.connect
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
message = kwargs.get("message")
task_name: str | None = None
if message is not None:
headers = getattr(message, "headers", None) or {}
task_name = headers.get("task")
if task_name is None:
task_name = "unknown"
on_celery_task_rejected(None, task_name)
@celeryd_init.connect
@@ -117,7 +85,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
pool_size = cast(int, sender.concurrency) # ty: ignore[unresolved-attribute]
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(
pool_size=pool_size, max_overflow=CELERY_WORKER_PRIMARY_POOL_OVERFLOW
)
@@ -177,7 +145,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
raise WorkerShutdown("Primary worker lock could not be acquired!")
# tacking on our own user data to the sender
sender.primary_worker_lock = lock # ty: ignore[unresolved-attribute]
sender.primary_worker_lock = lock # type: ignore
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
@@ -244,7 +212,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
start_metrics_server("primary")
app_base.on_worker_ready(sender, **kwargs)

View File

@@ -22,7 +22,7 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.user_file_processing")
celery_app.Task = app_base.TenantAwareTask # ty: ignore[invalid-assignment]
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
@@ -66,7 +66,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# "SSL connection has been closed unexpectedly"
# actually setting the spawn method in the cloud fixes 95% of these.
# setting pre ping might help even more, but not worrying about that yet
pool_size = cast(int, sender.concurrency) # ty: ignore[unresolved-attribute]
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)

View File

@@ -179,7 +179,7 @@ def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str
# filter for and create an indexing specific inspect object
inspect = app.control.inspect()
workers: dict[str, Any] = inspect.ping() # ty: ignore[invalid-assignment]
workers: dict[str, Any] = inspect.ping() # type: ignore
if workers:
for worker_name in list(workers.keys()):
# if the name filter not set, return all worker names
@@ -208,9 +208,7 @@ def celery_inspect_get_reserved(worker_names: list[str], app: Celery) -> set[str
inspect = app.control.inspect(destination=worker_names)
# get the list of reserved tasks
reserved_tasks: dict[str, list] | None = ( # ty: ignore[invalid-assignment]
inspect.reserved()
)
reserved_tasks: dict[str, list] | None = inspect.reserved() # type: ignore
if reserved_tasks:
for _, task_list in reserved_tasks.items():
for task in task_list:
@@ -231,9 +229,7 @@ def celery_inspect_get_active(worker_names: list[str], app: Celery) -> set[str]:
inspect = app.control.inspect(destination=worker_names)
# get the list of reserved tasks
active_tasks: dict[str, list] | None = ( # ty: ignore[invalid-assignment]
inspect.active()
)
active_tasks: dict[str, list] | None = inspect.active() # type: ignore
if active_tasks:
for _, task_list in active_tasks.items():
for task in task_list:

View File

@@ -6,7 +6,6 @@ from celery.schedules import crontab
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
from onyx.configs.app_configs import DISABLE_OPENSEARCH_MIGRATION_TASK
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
@@ -227,7 +226,7 @@ if SCHEDULED_EVAL_DATASET_NAMES:
)
# Add OpenSearch migration task if enabled.
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX and not DISABLE_OPENSEARCH_MIGRATION_TASK:
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
beat_task_templates.append(
{
"name": "migrate-chunks-from-vespa-to-opensearch",

View File

@@ -59,7 +59,6 @@ from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_tenant_work_gating import maybe_mark_tenant_active
from onyx.server.metrics.deletion_metrics import inc_deletion_blocked
from onyx.server.metrics.deletion_metrics import inc_deletion_completed
from onyx.server.metrics.deletion_metrics import inc_deletion_fence_reset
@@ -166,22 +165,12 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
r.set(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES, 1, ex=300)
# collect cc_pair_ids and note whether any are in DELETING status
# collect cc_pair_ids
cc_pair_ids: list[int] = []
has_deleting_cc_pair = False
with get_session_with_current_tenant() as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
cc_pair_ids.append(cc_pair.id)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
has_deleting_cc_pair = True
# Tenant-work-gating hook: mark only when at least one cc_pair is in
# DELETING status. Marking on bare cc_pair existence would keep
# nearly every tenant in the active set since most have cc_pairs
# but almost none are actively being deleted on any given cycle.
if has_deleting_cc_pair:
maybe_mark_tenant_active(tenant_id)
# try running cleanup on the cc_pair_ids
for cc_pair_id in cc_pair_ids:

View File

@@ -34,7 +34,6 @@ from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.redis.redis_connector import RedisConnector
from onyx.server.metrics.connector_health_metrics import on_index_attempt_status_change
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import SENTRY_DSN
@@ -471,15 +470,6 @@ def docfetching_proxy_task(
index_attempt.connector_credential_pair.connector.source.value
)
cc_pair = index_attempt.connector_credential_pair
on_index_attempt_status_change(
tenant_id=tenant_id,
source=result.connector_source,
cc_pair_id=cc_pair_id,
connector_name=cc_pair.connector.name or f"cc_pair_{cc_pair_id}",
status="in_progress",
)
while True:
sleep(5)

View File

@@ -108,11 +108,7 @@ from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_tenant_work_gating import maybe_mark_tenant_active
from onyx.redis.redis_utils import is_fence
from onyx.server.metrics.connector_health_metrics import on_connector_error_state_change
from onyx.server.metrics.connector_health_metrics import on_connector_indexing_success
from onyx.server.metrics.connector_health_metrics import on_index_attempt_status_change
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.logger import setup_logger
from onyx.utils.middleware import make_randomized_onyx_request_id
@@ -528,25 +524,13 @@ def check_indexing_completion(
# Update CC pair status if successful
cc_pair = get_connector_credential_pair_from_id(
db_session,
attempt.connector_credential_pair_id,
eager_load_connector=True,
db_session, attempt.connector_credential_pair_id
)
if cc_pair is None:
raise RuntimeError(
f"CC pair {attempt.connector_credential_pair_id} not found in database"
)
source = cc_pair.connector.source.value
connector_name = cc_pair.connector.name or f"cc_pair_{cc_pair.id}"
on_index_attempt_status_change(
tenant_id=tenant_id,
source=source,
cc_pair_id=cc_pair.id,
connector_name=connector_name,
status=attempt.status.value,
)
if attempt.status.is_successful():
# NOTE: we define the last successful index time as the time the last successful
# attempt finished. This is distinct from the poll_range_end of the last successful
@@ -567,15 +551,6 @@ def check_indexing_completion(
event=MilestoneRecordType.CONNECTOR_SUCCEEDED,
)
on_connector_indexing_success(
tenant_id=tenant_id,
source=source,
cc_pair_id=cc_pair.id,
connector_name=connector_name,
docs_indexed=attempt.new_docs_indexed or 0,
success_timestamp=attempt.time_updated.timestamp(),
)
# Clear repeated error state on success
if cc_pair.in_repeated_error_state:
cc_pair.in_repeated_error_state = False
@@ -595,13 +570,6 @@ def check_indexing_completion(
db_session.delete(notif)
db_session.commit()
on_connector_error_state_change(
tenant_id=tenant_id,
source=source,
cc_pair_id=cc_pair.id,
connector_name=connector_name,
in_error=False,
)
if attempt.status == IndexingStatus.SUCCESS:
logger.info(
@@ -811,7 +779,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
# redis_client_celery: Redis = self.app.broker_connection().channel().client
# redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = redis_client.lock(
OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK,
@@ -925,16 +893,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
cc_pair_id=cc_pair_id,
in_repeated_error_state=True,
)
error_connector_name = (
cc_pair.connector.name or f"cc_pair_{cc_pair.id}"
)
on_connector_error_state_change(
tenant_id=tenant_id,
source=cc_pair.connector.source.value,
cc_pair_id=cc_pair_id,
connector_name=error_connector_name,
in_error=True,
)
connector_name = (
cc_pair.name
@@ -962,6 +920,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
f"connector={cc_pair.connector.name} "
f"source={source}"
)
# When entering repeated error state, also pause the connector
# to prevent continued indexing retry attempts burning through embedding credits.
# NOTE: only for Cloud, since most self-hosted users use self-hosted embedding
@@ -1014,14 +973,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
f"Skipping secondary indexing: switchover_type=INSTANT for search_settings={secondary_search_settings.id}"
)
# Tenant-work-gating hook: refresh membership only when indexing
# actually dispatched at least one docfetching task. `_kickoff_indexing_tasks`
# internally calls `should_index()` to decide per-cc_pair; using
# `tasks_created > 0` here gives us a "real work was done" signal
# rather than just "tenant has a cc_pair somewhere."
if tasks_created > 0:
maybe_mark_tenant_active(tenant_id)
# 2/3: VALIDATE
# Check for inconsistent index attempts - active attempts without task IDs
# This can happen if attempt creation fails partway through

View File

@@ -51,6 +51,7 @@ from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.hierarchy import delete_orphaned_hierarchy_nodes
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
from onyx.db.hierarchy import remove_stale_hierarchy_node_cc_pair_entries
from onyx.db.hierarchy import reparent_orphaned_hierarchy_nodes
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
@@ -72,7 +73,6 @@ from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_tenant_work_gating import maybe_mark_tenant_active
from onyx.server.metrics.pruning_metrics import observe_pruning_diff_duration
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.server.utils import make_short_id
@@ -229,7 +229,6 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
prune_dispatched = False
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
with get_session_with_current_tenant() as db_session:
@@ -252,18 +251,9 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
logger.info(f"Pruning not created: {cc_pair_id}")
continue
prune_dispatched = True
task_logger.info(
f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}"
)
# Tenant-work-gating hook: mark only when at least one cc_pair
# was actually due for pruning AND a prune task was dispatched.
# Marking on bare cc_pair existence over-counts the population
# since most tenants have cc_pairs but almost none are due on
# any given cycle.
if prune_dispatched:
maybe_mark_tenant_active(tenant_id)
r.set(OnyxRedisSignals.BLOCK_PRUNING, 1, ex=_get_pruning_block_expiration())
# we want to run this less frequently than the overall task
@@ -653,6 +643,16 @@ def connector_pruning_generator_task(
raw_id_to_parent=all_connector_doc_ids,
)
# Link hierarchy nodes to documents for sources where pages can be
# both hierarchy nodes AND documents (e.g. Notion, Confluence)
all_doc_id_list = list(all_connector_doc_ids.keys())
link_hierarchy_nodes_to_documents(
db_session=db_session,
document_ids=all_doc_id_list,
source=source,
commit=True,
)
diff_start = time.monotonic()
try:
# a list of docs in our local index

View File

@@ -248,7 +248,6 @@ def document_by_cc_pair_cleanup_task(
),
)
mark_document_as_modified(document_id, db_session)
db_session.commit()
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)

View File

@@ -15,7 +15,6 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.document import construct_document_id_select_by_needs_sync
from onyx.db.document import count_documents_by_needs_sync
from onyx.redis.redis_tenant_work_gating import maybe_mark_tenant_active
from onyx.utils.logger import setup_logger
# Redis keys for document sync tracking
@@ -151,10 +150,6 @@ def try_generate_stale_document_sync_tasks(
logger.info("No stale documents found. Skipping sync tasks generation.")
return None
# Tenant-work-gating hook: refresh this tenant's active-set membership
# whenever vespa sync actually has stale docs to dispatch.
maybe_mark_tenant_active(tenant_id)
logger.info(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks in one batch."
)

View File

@@ -61,9 +61,7 @@ def load_checkpoint(
checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
checkpoint_data = checkpoint_io.read().decode("utf-8")
if isinstance(connector, CheckpointedConnector):
return connector.validate_checkpoint_json( # ty: ignore[invalid-return-type]
checkpoint_data
)
return connector.validate_checkpoint_json(checkpoint_data)
return ConnectorCheckpoint.model_validate_json(checkpoint_data)

View File

@@ -1164,10 +1164,7 @@ def run_llm_loop(
emitter.emit(
Packet(
placement=Placement(
turn_index=llm_cycle_count # ty: ignore[possibly-unresolved-reference]
+ reasoning_cycles
),
placement=Placement(turn_index=llm_cycle_count + reasoning_cycles),
obj=OverallStop(type="stop"),
)
)

View File

@@ -826,12 +826,6 @@ def translate_history_to_llm_format(
base64_data = img_file.to_base64()
image_url = f"data:{image_type};base64,{base64_data}"
content_parts.append(
TextContentPart(
type="text",
text=f"[attached image — file_id: {img_file.file_id}]",
)
)
image_part = ImageContentPart(
type="image_url",
image_url=ImageUrlDetail(

View File

@@ -324,9 +324,6 @@ ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
)
DISABLE_OPENSEARCH_MIGRATION_TASK = (
os.environ.get("DISABLE_OPENSEARCH_MIGRATION_TASK", "").lower() == "true"
)
# Whether we should check for and create an index if necessary every time we
# instantiate an OpenSearchDocumentIndex on multitenant cloud. Defaults to True.
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
@@ -843,29 +840,6 @@ MAX_FILE_SIZE_BYTES = int(
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
) # 2GB in bytes
# Maximum embedded images allowed in a single file. PDFs (and other formats)
# with thousands of embedded images can OOM the user-file-processing worker
# because every image is decoded with PIL and then sent to the vision LLM.
# Enforced both at upload time (rejects the file) and during extraction
# (defense-in-depth: caps the number of images materialized).
#
# Clamped to >= 0; a negative env value would turn upload validation into
# always-fail and extraction into always-stop, which is never desired. 0
# disables image extraction entirely, which is a valid (if aggressive) setting.
MAX_EMBEDDED_IMAGES_PER_FILE = max(
0, int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_FILE") or 500)
)
# Maximum embedded images allowed across all files in a single upload batch.
# Protects against the scenario where a user uploads many files that each
# fall under MAX_EMBEDDED_IMAGES_PER_FILE but aggregate to enough work
# (serial-ish celery fan-out plus per-image vision-LLM calls) to OOM the
# worker under concurrency or run up surprise latency/cost. Also clamped
# to >= 0.
MAX_EMBEDDED_IMAGES_PER_UPLOAD = max(
0, int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_UPLOAD") or 1000)
)
# Use document summary for contextual rag
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
# Use chunk summary for contextual rag
@@ -1151,32 +1125,6 @@ DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20
# Number of pre-provisioned tenants to maintain
TARGET_AVAILABLE_TENANTS = int(os.environ.get("TARGET_AVAILABLE_TENANTS", "5"))
# Master switch for the tenant work-gating feature. Controls the `enabled`
# axis only — flipping this True puts the feature in shadow mode (compute
# the gate, log skip counts, but do not actually skip). The `enforce` axis
# is Redis-only with a hard-coded default of False, so this env flag alone
# cannot cause real tenants to be skipped. Default off.
ENABLE_TENANT_WORK_GATING = (
os.environ.get("ENABLE_TENANT_WORK_GATING", "").lower() == "true"
)
# Membership TTL for the `active_tenants` sorted set. Members older than this
# are treated as inactive by the gate read path. Must be > the full-fanout
# interval so self-healing re-adds a genuinely-working tenant before their
# membership expires. Default 30 min.
TENANT_WORK_GATING_TTL_SECONDS = int(
os.environ.get("TENANT_WORK_GATING_TTL_SECONDS", 30 * 60)
)
# Minimum wall-clock interval between full-fanout cycles. When this many
# seconds have elapsed since the last bypass, the generator ignores the gate
# on the next invocation and dispatches to every non-gated tenant, letting
# consumers re-populate the active set. Schedule-independent so beat drift
# or backlog can't make the self-heal bursty or sparse. Default 20 min.
TENANT_WORK_GATING_FULL_FANOUT_INTERVAL_SECONDS = int(
os.environ.get("TENANT_WORK_GATING_FULL_FANOUT_INTERVAL_SECONDS", 20 * 60)
)
# Image summarization configuration
IMAGE_SUMMARIZATION_SYSTEM_PROMPT = os.environ.get(

View File

@@ -639,11 +639,9 @@ REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
if platform.system() == "Darwin":
REDIS_SOCKET_KEEPALIVE_OPTIONS[
socket.TCP_KEEPALIVE # ty: ignore[unresolved-attribute]
] = 60
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore[attr-defined,unused-ignore]
else:
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore[attr-defined,unused-ignore]
class OnyxCallTypes(str, Enum):

View File

@@ -547,7 +547,7 @@ class AirtableConnector(LoadConnector):
for record in batch_records:
# Capture the current context so that the thread gets the current tenant ID
current_context = contextvars.copy_context()
future_to_record[ # ty: ignore[invalid-assignment]
future_to_record[
executor.submit(
current_context.run,
self._process_record,

View File

@@ -3,7 +3,7 @@ from collections.abc import Iterator
from datetime import datetime
from typing import Dict
import asana
import asana # type: ignore
from onyx.utils.logger import setup_logger

View File

@@ -290,8 +290,8 @@ class AxeroConnector(PollConnector):
if not self.axero_key or not self.base_url:
raise ConnectorMissingCredentialError("Axero")
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
start_datetime = datetime.utcfromtimestamp(start).replace(tzinfo=timezone.utc)
end_datetime = datetime.utcfromtimestamp(end).replace(tzinfo=timezone.utc)
entity_types = []
if self.include_article:
@@ -327,7 +327,7 @@ class AxeroConnector(PollConnector):
)
all_axero_forums = _map_post_to_parent(
posts=forums_posts, # ty: ignore[invalid-argument-type]
posts=forums_posts,
api_key=self.axero_key,
axero_base_url=self.base_url,
)

View File

@@ -26,10 +26,6 @@ from onyx.configs.constants import FileOrigin
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
process_onyx_metadata,
)
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
tabular_file_to_sections,
)
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
@@ -42,7 +38,6 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TabularSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
@@ -76,9 +71,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
self.bucket_region: Optional[str] = None
self.european_residency: bool = european_residency
def set_allow_images( # ty: ignore[invalid-method-override]
self, allow_images: bool
) -> None:
def set_allow_images(self, allow_images: bool) -> None:
"""Set whether to process images in this connector."""
logger.info(f"Setting allow_images to {allow_images}.")
self._allow_images = allow_images
@@ -197,9 +190,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
method="sts-assume-role",
)
botocore_session = get_session()
botocore_session._credentials = ( # ty: ignore[unresolved-attribute]
refreshable
)
botocore_session._credentials = refreshable # type: ignore[attr-defined]
session = boto3.Session(botocore_session=botocore_session)
self.s3_client = session.client("s3")
elif authentication_method == "assume_role":
@@ -460,40 +451,6 @@ class BlobStorageConnector(LoadConnector, PollConnector):
logger.exception(f"Error processing image {key}")
continue
# Handle tabular files (xlsx, csv, tsv) — produce one
# TabularSection per sheet (or per file for csv/tsv)
# instead of a flat TextSection.
if is_tabular_file(file_name):
try:
downloaded_file = self._download_object(key)
if downloaded_file is None:
continue
tabular_sections = tabular_file_to_sections(
BytesIO(downloaded_file),
file_name=file_name,
link=link,
)
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
sections=(
tabular_sections
if tabular_sections
else [TabularSection(link=link, text="")]
),
source=DocumentSource(self.bucket_type.value),
semantic_identifier=file_name,
doc_updated_at=last_modified,
metadata={},
)
)
if len(batch) == self.batch_size:
yield batch
batch = []
except Exception:
logger.exception(f"Error processing tabular file {key}")
continue
# Handle text and document files
try:
downloaded_file = self._download_object(key)

View File

@@ -2,7 +2,6 @@ import html
import time
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from typing import Any
from onyx.configs.app_configs import INDEX_BATCH_SIZE
@@ -57,14 +56,14 @@ class BookstackConnector(LoadConnector, PollConnector):
}
if start:
params["filter[updated_at:gte]"] = datetime.fromtimestamp(
start, tz=timezone.utc
params["filter[updated_at:gte]"] = datetime.utcfromtimestamp(
start
).strftime("%Y-%m-%d")
if end:
params["filter[updated_at:lte]"] = datetime.fromtimestamp(
end, tz=timezone.utc
).strftime("%Y-%m-%d")
params["filter[updated_at:lte]"] = datetime.utcfromtimestamp(end).strftime(
"%Y-%m-%d"
)
batch = bookstack_client.get(endpoint, params=params).get("data", [])
doc_batch: list[Document | HierarchyNode] = [

View File

@@ -27,19 +27,16 @@ _STATUS_TO_ERROR_CODE: dict[int, OnyxErrorCode] = {
401: OnyxErrorCode.CREDENTIAL_EXPIRED,
403: OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
404: OnyxErrorCode.BAD_GATEWAY,
429: OnyxErrorCode.RATE_LIMITED,
}
def _error_code_for_status(status_code: int) -> OnyxErrorCode:
"""Map an HTTP status code to the appropriate OnyxErrorCode.
Expects a >= 400 status code. Known codes (401, 403, 404) are
Expects a >= 400 status code. Known codes (401, 403, 404, 429) are
mapped to specific error codes; all other codes (unrecognised 4xx
and 5xx) map to BAD_GATEWAY as unexpected upstream errors.
Note: 429 is intentionally omitted — the rl_requests wrapper
handles rate limits transparently at the HTTP layer, so 429
responses never reach this function.
"""
if status_code in _STATUS_TO_ERROR_CODE:
return _STATUS_TO_ERROR_CODE[status_code]

View File

@@ -1,9 +1,10 @@
from datetime import datetime
from datetime import timezone
from enum import StrEnum
from typing import Any
from typing import cast
from typing import Literal
from typing import NoReturn
from typing import TypeAlias
from pydantic import BaseModel
from retry import retry
@@ -24,11 +25,8 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.error_handling.exceptions import OnyxError
@@ -49,6 +47,10 @@ def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
raise InsufficientPermissionsError(
"Canvas API token does not have sufficient permissions (HTTP 403)."
)
elif e.status_code == 429:
raise ConnectorValidationError(
"Canvas rate-limit exceeded (HTTP 429). Please try again later."
)
elif e.status_code >= 500:
raise UnexpectedValidationError(
f"Unexpected Canvas HTTP error (status={e.status_code}): {e}"
@@ -59,60 +61,6 @@ def _handle_canvas_api_error(e: OnyxError) -> NoReturn:
)
class CanvasStage(StrEnum):
PAGES = "pages"
ASSIGNMENTS = "assignments"
ANNOUNCEMENTS = "announcements"
_STAGE_CONFIG: dict[CanvasStage, dict[str, Any]] = {
CanvasStage.PAGES: {
"endpoint": "courses/{course_id}/pages",
"params": {
"per_page": "100",
"include[]": "body",
"published": "true",
"sort": "updated_at",
"order": "desc",
},
},
CanvasStage.ASSIGNMENTS: {
"endpoint": "courses/{course_id}/assignments",
"params": {"per_page": "100", "published": "true"},
},
CanvasStage.ANNOUNCEMENTS: {
"endpoint": "announcements",
"params": {
"per_page": "100",
"context_codes[]": "course_{course_id}",
"active_only": "true",
},
},
}
def _parse_canvas_dt(timestamp_str: str) -> datetime:
"""Parse a Canvas ISO-8601 timestamp (e.g. '2025-06-15T12:00:00Z')
into a timezone-aware UTC datetime.
Canvas returns timestamps with a trailing 'Z' instead of '+00:00',
so we normalise before parsing.
"""
return datetime.fromisoformat(timestamp_str.replace("Z", "+00:00")).astimezone(
timezone.utc
)
def _unix_to_canvas_time(epoch: float) -> str:
"""Convert a Unix timestamp to Canvas ISO-8601 format (e.g. '2025-06-15T12:00:00Z')."""
return datetime.fromtimestamp(epoch, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
def _in_time_window(timestamp_str: str, start: float, end: float) -> bool:
"""Check whether a Canvas ISO-8601 timestamp falls within (start, end]."""
return start < _parse_canvas_dt(timestamp_str).timestamp() <= end
class CanvasCourse(BaseModel):
id: int
name: str | None = None
@@ -197,6 +145,9 @@ class CanvasAnnouncement(BaseModel):
)
CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"]
class CanvasConnectorCheckpoint(ConnectorCheckpoint):
"""Checkpoint state for resumable Canvas indexing.
@@ -214,30 +165,15 @@ class CanvasConnectorCheckpoint(ConnectorCheckpoint):
course_ids: list[int] = []
current_course_index: int = 0
stage: CanvasStage = CanvasStage.PAGES
stage: CanvasStage = "pages"
next_url: str | None = None
def advance_course(self) -> None:
"""Move to the next course and reset within-course state."""
self.current_course_index += 1
self.stage = CanvasStage.PAGES
self.stage = "pages"
self.next_url = None
def advance_stage(self) -> None:
"""Advance past the current stage.
Moves to the next stage within the same course, or to the next
course if the current stage is the last one. Resets next_url so
the next call starts fresh on the new stage.
"""
self.next_url = None
stages: list[CanvasStage] = list(CanvasStage)
next_idx = stages.index(self.stage) + 1
if next_idx < len(stages):
self.stage = stages[next_idx]
else:
self.advance_course()
class CanvasConnector(
CheckpointedConnectorWithPermSync[CanvasConnectorCheckpoint],
@@ -359,7 +295,13 @@ class CanvasConnector(
if body_text:
text_parts.append(body_text)
doc_updated_at = _parse_canvas_dt(page.updated_at) if page.updated_at else None
doc_updated_at = (
datetime.fromisoformat(page.updated_at.replace("Z", "+00:00")).astimezone(
timezone.utc
)
if page.updated_at
else None
)
document = self._build_document(
doc_id=f"canvas-page-{page.course_id}-{page.page_id}",
@@ -383,11 +325,17 @@ class CanvasConnector(
if desc_text:
text_parts.append(desc_text)
if assignment.due_at:
due_dt = _parse_canvas_dt(assignment.due_at)
due_dt = datetime.fromisoformat(
assignment.due_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
text_parts.append(f"Due: {due_dt.strftime('%B %d, %Y %H:%M UTC')}")
doc_updated_at = (
_parse_canvas_dt(assignment.updated_at) if assignment.updated_at else None
datetime.fromisoformat(
assignment.updated_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
if assignment.updated_at
else None
)
document = self._build_document(
@@ -413,7 +361,11 @@ class CanvasConnector(
text_parts.append(msg_text)
doc_updated_at = (
_parse_canvas_dt(announcement.posted_at) if announcement.posted_at else None
datetime.fromisoformat(
announcement.posted_at.replace("Z", "+00:00")
).astimezone(timezone.utc)
if announcement.posted_at
else None
)
document = self._build_document(
@@ -448,314 +400,6 @@ class CanvasConnector(
self._canvas_client = client
return None
def _fetch_stage_page(
self,
next_url: str | None,
endpoint: str,
params: dict[str, Any],
) -> tuple[list[Any], str | None]:
"""Fetch one page of API results for the current stage.
Returns (items, next_url). All error handling is done by the
caller (_load_from_checkpoint).
"""
if next_url:
# Resuming mid-pagination: the next_url from Canvas's
# Link header already contains endpoint + query params.
response, result_next_url = self.canvas_client.get(full_url=next_url)
else:
# First request for this stage: build from endpoint + params.
response, result_next_url = self.canvas_client.get(
endpoint=endpoint, params=params
)
return response or [], result_next_url
def _process_items(
self,
response: list[Any],
stage: CanvasStage,
course_id: int,
start: float,
end: float,
include_permissions: bool,
) -> tuple[list[Document | ConnectorFailure], bool]:
"""Process a page of API results into documents.
Returns (docs, early_exit). early_exit is True when pages
(sorted desc by updated_at) hit an item older than start,
signaling that pagination should stop.
"""
results: list[Document | ConnectorFailure] = []
early_exit = False
for item in response:
try:
if stage == CanvasStage.PAGES:
page = CanvasPage.from_api(item, course_id=course_id)
if not page.updated_at:
continue
# Pages are sorted by updated_at desc — once we see
# an item at or before `start`, all remaining items
# on this and subsequent pages are older too.
if not _in_time_window(page.updated_at, start, end):
if _parse_canvas_dt(page.updated_at).timestamp() <= start:
early_exit = True
break
# ts > end: page is newer than our window, skip it
continue
doc = self._convert_page_to_document(page)
results.append(
self._maybe_attach_permissions(
doc, course_id, include_permissions
)
)
elif stage == CanvasStage.ASSIGNMENTS:
assignment = CanvasAssignment.from_api(item, course_id=course_id)
if not assignment.updated_at or not _in_time_window(
assignment.updated_at, start, end
):
continue
doc = self._convert_assignment_to_document(assignment)
results.append(
self._maybe_attach_permissions(
doc, course_id, include_permissions
)
)
elif stage == CanvasStage.ANNOUNCEMENTS:
announcement = CanvasAnnouncement.from_api(
item, course_id=course_id
)
if not announcement.posted_at:
logger.debug(
f"Skipping announcement {announcement.id} in "
f"course {course_id}: no posted_at"
)
continue
if not _in_time_window(announcement.posted_at, start, end):
continue
doc = self._convert_announcement_to_document(announcement)
results.append(
self._maybe_attach_permissions(
doc, course_id, include_permissions
)
)
except Exception as e:
item_id = item.get("id") or item.get("page_id", "unknown")
if stage == CanvasStage.PAGES:
doc_link = (
f"{self.canvas_base_url}/courses/{course_id}"
f"/pages/{item.get('url', '')}"
)
else:
doc_link = item.get("html_url", "")
results.append(
ConnectorFailure(
failed_document=DocumentFailure(
document_id=f"canvas-{stage.removesuffix('s')}-{course_id}-{item_id}",
document_link=doc_link,
),
failure_message=f"Failed to process {stage.removesuffix('s')}: {e}",
exception=e,
)
)
return results, early_exit
def _maybe_attach_permissions(
self,
document: Document,
course_id: int,
include_permissions: bool,
) -> Document:
if include_permissions:
document.external_access = self._get_course_permissions(course_id)
return document
def _load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
include_permissions: bool = False,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
"""Shared implementation for load_from_checkpoint and load_from_checkpoint_with_perm_sync."""
new_checkpoint = checkpoint.model_copy(deep=True)
# First call: materialize the list of course IDs.
# On failure, let the exception propagate so the framework fails the
# attempt cleanly. Swallowing errors here would leave the checkpoint
# state unchanged and cause an infinite retry loop.
if not new_checkpoint.course_ids:
try:
courses = self._list_courses()
except OnyxError as e:
if e.status_code in (401, 403):
_handle_canvas_api_error(e) # NoReturn — always raises
raise
new_checkpoint.course_ids = [c.id for c in courses]
logger.info(f"Found {len(courses)} Canvas courses to process")
new_checkpoint.has_more = len(new_checkpoint.course_ids) > 0
return new_checkpoint
# All courses done.
if new_checkpoint.current_course_index >= len(new_checkpoint.course_ids):
new_checkpoint.has_more = False
return new_checkpoint
course_id = new_checkpoint.course_ids[new_checkpoint.current_course_index]
try:
stage = CanvasStage(new_checkpoint.stage)
except ValueError as e:
raise ValueError(
f"Invalid checkpoint stage: {new_checkpoint.stage!r}. "
f"Valid stages: {[s.value for s in CanvasStage]}"
) from e
# Build endpoint + params from the static template.
config = _STAGE_CONFIG[stage]
endpoint = config["endpoint"].format(course_id=course_id)
params = {k: v.format(course_id=course_id) for k, v in config["params"].items()}
# Only the announcements API supports server-side date filtering
# (start_date/end_date). Pages support server-side sorting
# (sort=updated_at desc) enabling early exit, but not date
# filtering. Assignments support neither. Both are filtered
# client-side via _in_time_window after fetching.
if stage == CanvasStage.ANNOUNCEMENTS:
params["start_date"] = _unix_to_canvas_time(start)
params["end_date"] = _unix_to_canvas_time(end)
try:
response, result_next_url = self._fetch_stage_page(
next_url=new_checkpoint.next_url,
endpoint=endpoint,
params=params,
)
except OnyxError as oe:
# Security errors from _parse_next_link (host/scheme
# mismatch on pagination URLs) have no status code override
# and must not be silenced.
is_api_error = oe._status_code_override is not None
if not is_api_error:
raise
if oe.status_code in (401, 403):
_handle_canvas_api_error(oe) # NoReturn — always raises
# 404 means the course itself is gone or inaccessible. The
# other stages on this course will hit the same 404, so skip
# the whole course rather than burning API calls on each stage.
if oe.status_code == 404:
logger.warning(
f"Canvas course {course_id} not found while fetching "
f"{stage} (HTTP 404). Skipping course."
)
yield ConnectorFailure(
failed_entity=EntityFailure(
entity_id=f"canvas-course-{course_id}",
),
failure_message=(f"Canvas course {course_id} not found: {oe}"),
exception=oe,
)
new_checkpoint.advance_course()
else:
logger.warning(
f"Failed to fetch {stage} for course {course_id}: {oe}. "
f"Skipping remainder of this stage."
)
yield ConnectorFailure(
failed_entity=EntityFailure(
entity_id=f"canvas-{stage}-{course_id}",
),
failure_message=(
f"Failed to fetch {stage} for course {course_id}: {oe}"
),
exception=oe,
)
new_checkpoint.advance_stage()
new_checkpoint.has_more = new_checkpoint.current_course_index < len(
new_checkpoint.course_ids
)
return new_checkpoint
except Exception as e:
# Unknown error — skip the stage and try to continue.
logger.warning(
f"Failed to fetch {stage} for course {course_id}: {e}. "
f"Skipping remainder of this stage."
)
yield ConnectorFailure(
failed_entity=EntityFailure(
entity_id=f"canvas-{stage}-{course_id}",
),
failure_message=(
f"Failed to fetch {stage} for course {course_id}: {e}"
),
exception=e,
)
new_checkpoint.advance_stage()
new_checkpoint.has_more = new_checkpoint.current_course_index < len(
new_checkpoint.course_ids
)
return new_checkpoint
# Process fetched items
results, early_exit = self._process_items(
response, stage, course_id, start, end, include_permissions
)
for result in results:
yield result
# If we hit an item older than our window (pages sorted desc),
# skip remaining pagination and advance to the next stage.
if early_exit:
result_next_url = None
# If there are more pages, save the cursor and return
if result_next_url:
new_checkpoint.next_url = result_next_url
else:
# Stage complete — advance to next stage (or next course if last).
new_checkpoint.advance_stage()
new_checkpoint.has_more = new_checkpoint.current_course_index < len(
new_checkpoint.course_ids
)
return new_checkpoint
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
return self._load_from_checkpoint(
start, end, checkpoint, include_permissions=False
)
@override
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
"""Load documents from checkpoint with permission information included."""
return self._load_from_checkpoint(
start, end, checkpoint, include_permissions=True
)
@override
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
return CanvasConnectorCheckpoint(has_more=True)
@override
def validate_checkpoint_json(
self, checkpoint_json: str
) -> CanvasConnectorCheckpoint:
return CanvasConnectorCheckpoint.model_validate_json(checkpoint_json)
@override
def validate_connector_settings(self) -> None:
"""Validate Canvas connector settings by testing API access."""
@@ -771,6 +415,38 @@ class CanvasConnector(
f"Unexpected error during Canvas settings validation: {exc}"
)
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: CanvasConnectorCheckpoint,
) -> CheckpointOutput[CanvasConnectorCheckpoint]:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def validate_checkpoint_json(
self, checkpoint_json: str
) -> CanvasConnectorCheckpoint:
# TODO(benwu408): implemented in PR3 (checkpoint)
raise NotImplementedError
@override
def retrieve_all_slim_docs_perm_sync(
self,

View File

@@ -95,13 +95,11 @@ class ClickupConnector(LoadConnector, PollConnector):
params["date_updated_lt"] = end
if self.connector_type == "list":
params["list_ids[]"] = self.connector_ids # ty: ignore[invalid-assignment]
params["list_ids[]"] = self.connector_ids
elif self.connector_type == "folder":
params["project_ids[]"] = ( # ty: ignore[invalid-assignment]
self.connector_ids
)
params["project_ids[]"] = self.connector_ids
elif self.connector_type == "space":
params["space_ids[]"] = self.connector_ids # ty: ignore[invalid-assignment]
params["space_ids[]"] = self.connector_ids
url_endpoint = f"/team/{self.team_id}/task"

View File

@@ -6,7 +6,7 @@ from datetime import timezone
from typing import Any
from urllib.parse import quote
from atlassian.errors import ApiError
from atlassian.errors import ApiError # type: ignore
from requests.exceptions import HTTPError
from typing_extensions import override

View File

@@ -26,7 +26,7 @@ from typing import TypeVar
from urllib.parse import quote
import bs4
from atlassian import Confluence
from atlassian import Confluence # type:ignore
from redis import Redis
from requests import HTTPError
@@ -971,7 +971,7 @@ class OnyxConfluence:
:return: Returns the user details
"""
from atlassian.errors import ApiPermissionError
from atlassian.errors import ApiPermissionError # type:ignore
url = "rest/api/user/current"
params = {}

View File

@@ -165,7 +165,7 @@ class ConnectorRunner(Generic[CT]):
checkpoint_connector_generator = load_from_checkpoint(
start=self.time_range[0].timestamp(),
end=self.time_range[1].timestamp(),
checkpoint=checkpoint, # ty: ignore[invalid-argument-type]
checkpoint=checkpoint,
)
next_checkpoint: CT | None = None
# this is guaranteed to always run at least once with next_checkpoint being non-None
@@ -174,9 +174,7 @@ class ConnectorRunner(Generic[CT]):
hierarchy_node,
failure,
next_checkpoint,
) in CheckpointOutputWrapper[CT]()(
checkpoint_connector_generator # ty: ignore[invalid-argument-type]
):
) in CheckpointOutputWrapper[CT]()(checkpoint_connector_generator):
if document is not None:
self.doc_batch.append(document)

View File

@@ -83,9 +83,7 @@ class OnyxDBCredentialsProvider(
f"No credential found: credential={self._credential_id}"
)
credential.credential_json = ( # ty: ignore[invalid-assignment]
credential_json
)
credential.credential_json = credential_json # type: ignore[assignment]
db_session.commit()
except Exception:
db_session.rollback()

View File

@@ -3,7 +3,6 @@ from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from email.utils import parsedate_to_datetime
from typing import Any
from typing import TypeVar
from urllib.parse import urljoin
@@ -11,6 +10,7 @@ from urllib.parse import urlparse
import requests
from dateutil.parser import parse
from dateutil.parser import ParserError
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
from onyx.configs.constants import DocumentSource
@@ -56,16 +56,18 @@ def time_str_to_utc(datetime_str: str) -> datetime:
if fixed not in candidates:
candidates.append(fixed)
# dateutil is the primary; the stdlib RFC 2822 parser is a fallback for
# inputs dateutil rejects (e.g. headers concatenated without a CRLF —
# TZ may be dropped, datetime_to_utc then assumes UTC).
for parser in (parse, parsedate_to_datetime):
for candidate in candidates:
try:
return datetime_to_utc(parser(candidate))
except (TypeError, ValueError, OverflowError):
continue
last_exception: Exception | None = None
for candidate in candidates:
try:
dt = parse(candidate)
return datetime_to_utc(dt)
except (ValueError, ParserError) as exc:
last_exception = exc
if last_exception is not None:
raise last_exception
# Fallback in case parsing failed without raising (should not happen)
raise ValueError(f"Unable to parse datetime string: {datetime_str}")

View File

@@ -41,13 +41,9 @@ def tabular_file_to_sections(
"""
lowered = file_name.lower()
if lowered.endswith(tuple(OnyxFileExtensions.SPREADSHEET_EXTENSIONS)):
if lowered.endswith(".xlsx"):
return [
TabularSection(
link=link or file_name,
text=csv_text,
heading=f"{file_name} :: {sheet_title}",
)
TabularSection(link=f"sheet:{sheet_title}", text=csv_text)
for csv_text, sheet_title in xlsx_sheet_extraction(
file, file_name=file_name
)

View File

@@ -53,10 +53,8 @@ def _convert_message_to_document(
if isinstance(message.channel, TextChannel) and (
channel_name := message.channel.name
):
metadata["Channel"] = channel_name # ty: ignore[possibly-unresolved-reference]
semantic_substring += (
f" in Channel: #{channel_name}" # ty: ignore[possibly-unresolved-reference]
)
metadata["Channel"] = channel_name
semantic_substring += f" in Channel: #{channel_name}"
# Single messages dont have a title
title = ""

View File

@@ -221,8 +221,8 @@ class DiscourseConnector(PollConnector):
if self.permissions is None:
raise ConnectorMissingCredentialError("Discourse")
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
start_datetime = datetime.utcfromtimestamp(start).replace(tzinfo=timezone.utc)
end_datetime = datetime.utcfromtimestamp(end).replace(tzinfo=timezone.utc)
self._get_categories_map()

View File

@@ -2,10 +2,10 @@ from datetime import timezone
from io import BytesIO
from typing import Any
from dropbox import Dropbox
from dropbox.exceptions import ApiError
from dropbox import Dropbox # type: ignore[import-untyped]
from dropbox.exceptions import ApiError # type: ignore[import-untyped]
from dropbox.exceptions import AuthError
from dropbox.files import FileMetadata
from dropbox.files import FileMetadata # type: ignore[import-untyped]
from dropbox.files import FolderMetadata
from onyx.configs.app_configs import INDEX_BATCH_SIZE

View File

@@ -15,10 +15,6 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
)
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import rate_limit_builder
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import rl_requests
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
tabular_file_to_sections,
)
from onyx.connectors.drupal_wiki.models import DrupalWikiCheckpoint
from onyx.connectors.drupal_wiki.models import DrupalWikiPage
from onyx.connectors.drupal_wiki.models import DrupalWikiPageResponse
@@ -37,7 +33,6 @@ from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TabularSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
@@ -218,7 +213,7 @@ class DrupalWikiConnector(
attachment: dict[str, Any],
page_id: int,
download_url: str,
) -> tuple[list[TextSection | ImageSection | TabularSection], str | None]:
) -> tuple[list[TextSection | ImageSection], str | None]:
"""
Process a single attachment and return generated sections.
@@ -231,7 +226,7 @@ class DrupalWikiConnector(
Tuple of (sections, error_message). If error_message is not None, the
sections list should be treated as invalid.
"""
sections: list[TextSection | ImageSection | TabularSection] = []
sections: list[TextSection | ImageSection] = []
try:
if not self._validate_attachment_filetype(attachment):
@@ -278,25 +273,6 @@ class DrupalWikiConnector(
return sections, None
# Tabular attachments (xlsx, csv, tsv) — produce
# TabularSections instead of a flat TextSection.
if is_tabular_file(file_name):
try:
sections.extend(
tabular_file_to_sections(
BytesIO(raw_bytes),
file_name=file_name,
link=download_url,
)
)
except Exception:
logger.exception(
f"Failed to extract tabular sections from {file_name}"
)
if not sections:
return [], f"No content extracted from tabular file {file_name}"
return sections, None
image_counter = 0
def _store_embedded_image(image_data: bytes, image_name: str) -> None:
@@ -521,7 +497,7 @@ class DrupalWikiConnector(
page_url = build_drupal_wiki_document_id(self.base_url, page.id)
# Create sections with just the page content
sections: list[TextSection | ImageSection | TabularSection] = [
sections: list[TextSection | ImageSection] = [
TextSection(text=text_content, link=page_url)
]

View File

@@ -2,7 +2,6 @@ import json
import os
from datetime import datetime
from datetime import timezone
from io import BytesIO
from pathlib import Path
from typing import Any
from typing import IO
@@ -13,16 +12,11 @@ from onyx.configs.constants import FileOrigin
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
process_onyx_metadata,
)
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
tabular_file_to_sections,
)
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TabularSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
@@ -185,32 +179,8 @@ def _process_file(
link = onyx_metadata.link or link
# Build sections: first the text as a single Section
sections: list[TextSection | ImageSection | TabularSection] = []
if is_tabular_file(file_name):
# Produce TabularSections
lowered_name = file_name.lower()
if lowered_name.endswith(tuple(OnyxFileExtensions.SPREADSHEET_EXTENSIONS)):
file.seek(0)
tabular_source: IO[bytes] = file
else:
tabular_source = BytesIO(
extraction_result.text_content.encode("utf-8", errors="replace")
)
try:
sections.extend(
tabular_file_to_sections(
file=tabular_source,
file_name=file_name,
link=link or "",
)
)
except Exception as e:
logger.error(f"Failed to process tabular file {file_name}: {e}")
return []
if not sections:
logger.warning(f"No content extracted from tabular file {file_name}")
return []
elif extraction_result.text_content.strip():
sections: list[TextSection | ImageSection] = []
if extraction_result.text_content.strip():
logger.debug(f"Creating TextSection for {file_name} with link: {link}")
sections.append(
TextSection(link=link, text=extraction_result.text_content.strip())

View File

@@ -22,7 +22,6 @@ from typing_extensions import override
from onyx.access.models import ExternalAccess
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
from onyx.configs.constants import DocumentSource
from onyx.connectors.connector_runner import CheckpointOutputWrapper
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
@@ -36,16 +35,10 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import ConnectorFailure
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import IndexingHeartbeatInterface
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
@@ -434,11 +427,7 @@ def make_cursor_url_callback(
return cursor_url_callback
class GithubConnector(
CheckpointedConnectorWithPermSync[GithubConnectorCheckpoint],
SlimConnector,
SlimConnectorWithPermSync,
):
class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoint]):
def __init__(
self,
repo_owner: str,
@@ -570,7 +559,6 @@ class GithubConnector(
start: datetime | None = None,
end: datetime | None = None,
include_permissions: bool = False,
is_slim: bool = False,
) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
@@ -626,46 +614,36 @@ class GithubConnector(
for pr in pr_batch:
num_prs += 1
if is_slim:
yield Document(
id=pr.html_url,
sections=[],
external_access=repo_external_access,
source=DocumentSource.GITHUB,
semantic_identifier="",
metadata={},
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) < start
):
done_with_prs = True
break
# Skip PRs updated after the end date
if (
end is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
try:
yield _convert_pr_to_document(
cast(PullRequest, pr), repo_external_access
)
else:
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) < start
):
done_with_prs = True
break
# Skip PRs updated after the end date
if (
end is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
try:
yield _convert_pr_to_document(
cast(PullRequest, pr), repo_external_access
)
except Exception as e:
error_msg = f"Error converting PR to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(pr.id), document_link=pr.html_url
),
failure_message=error_msg,
exception=e,
)
continue
except Exception as e:
error_msg = f"Error converting PR to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(pr.id), document_link=pr.html_url
),
failure_message=error_msg,
exception=e,
)
continue
# If we reach this point with a cursor url in the checkpoint, we were using
# the fallback cursor-based pagination strategy. That strategy tries to get all
@@ -711,47 +689,38 @@ class GithubConnector(
for issue in issue_batch:
num_issues += 1
issue = cast(Issue, issue)
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and issue.updated_at.replace(tzinfo=timezone.utc) < start
):
done_with_issues = True
break
# Skip PRs updated after the end date
if (
end is not None
and issue.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
if is_slim:
yield Document(
id=issue.html_url,
sections=[],
external_access=repo_external_access,
source=DocumentSource.GITHUB,
semantic_identifier="",
metadata={},
try:
yield _convert_issue_to_document(issue, repo_external_access)
except Exception as e:
error_msg = f"Error converting issue to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(issue.id),
document_link=issue.html_url,
),
failure_message=error_msg,
exception=e,
)
else:
# we iterate backwards in time, so at this point we stop processing issues
if (
start is not None
and issue.updated_at.replace(tzinfo=timezone.utc) < start
):
done_with_issues = True
break
# Skip issues updated after the end date
if (
end is not None
and issue.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
try:
yield _convert_issue_to_document(issue, repo_external_access)
except Exception as e:
error_msg = f"Error converting issue to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(issue.id),
document_link=issue.html_url,
),
failure_message=error_msg,
exception=e,
)
continue
continue
logger.info(f"Fetched {num_issues} issues for repo: {repo.name}")
# if we found any issues on the page, and we're not done, return the checkpoint.
@@ -834,60 +803,6 @@ class GithubConnector(
start, end, checkpoint, include_permissions=True
)
def _retrieve_slim_docs(
self,
include_permissions: bool,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
"""Iterate all PRs and issues across all configured repos as SlimDocuments.
Drives _fetch_from_github in a checkpoint loop — each call processes one
page and returns an updated checkpoint. CheckpointOutputWrapper handles
draining the generator and extracting the returned checkpoint. Rate
limiting and pagination are handled centrally by _fetch_from_github via
_get_batch_rate_limited.
"""
checkpoint = self.build_dummy_checkpoint()
while checkpoint.has_more:
batch: list[SlimDocument | HierarchyNode] = []
gen = self._fetch_from_github(
checkpoint, include_permissions=include_permissions, is_slim=True
)
wrapper: CheckpointOutputWrapper[GithubConnectorCheckpoint] = (
CheckpointOutputWrapper()
)
for document, _, _, next_checkpoint in wrapper(gen):
if document is not None:
batch.append(
SlimDocument(
id=document.id, external_access=document.external_access
)
)
if next_checkpoint is not None:
checkpoint = next_checkpoint
if batch:
yield batch
if callback and callback.should_stop():
raise RuntimeError("github_slim_docs: Stop signal detected")
@override
def retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
return self._retrieve_slim_docs(include_permissions=False, callback=callback)
@override
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
return self._retrieve_slim_docs(include_permissions=True, callback=callback)
def validate_connector_settings(self) -> None:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub credentials not loaded.")

View File

@@ -7,7 +7,7 @@ from typing import Dict
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from googleapiclient.errors import HttpError
from googleapiclient.errors import HttpError # type: ignore
from onyx.access.models import ExternalAccess
from onyx.configs.app_configs import INDEX_BATCH_SIZE
@@ -253,17 +253,7 @@ def thread_to_document(
updated_at_datetime = None
if updated_at:
try:
updated_at_datetime = time_str_to_utc(updated_at)
except (ValueError, OverflowError) as e:
# Old mailboxes contain RFC-violating Date headers. Drop the
# timestamp instead of aborting the indexing run.
logger.warning(
"Skipping unparseable Gmail Date header on thread %s: %r (%s)",
full_thread.get("id"),
updated_at,
e,
)
updated_at_datetime = time_str_to_utc(updated_at)
id = full_thread.get("id")
if not id:
@@ -306,9 +296,7 @@ def _full_thread_from_id(
try:
thread = next(
execute_single_retrieval(
retrieval_function=gmail_service.users() # ty: ignore[unresolved-attribute]
.threads()
.get,
retrieval_function=gmail_service.users().threads().get,
list_key=None,
userId=user_email,
fields=THREAD_FIELDS,
@@ -406,7 +394,7 @@ class GmailConnector(
admin_service = get_admin_service(self.creds, self.primary_admin_email)
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list, # ty: ignore[unresolved-attribute]
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
@@ -450,9 +438,7 @@ class GmailConnector(
try:
for thread in execute_paginated_retrieval_with_max_pages(
max_num_pages=PAGES_PER_CHECKPOINT,
retrieval_function=gmail_service.users() # ty: ignore[unresolved-attribute]
.threads()
.list,
retrieval_function=gmail_service.users().threads().list,
list_key="threads",
userId=user_email,
fields=THREAD_LIST_FIELDS,

View File

@@ -1,5 +1,4 @@
import base64
import copy
import time
from collections.abc import Generator
from datetime import datetime
@@ -9,58 +8,27 @@ from typing import Any
from typing import cast
import requests
from pydantic import BaseModel
from requests.adapters import HTTPAdapter
from urllib3.util import Retry
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.app_configs import GONG_CONNECTOR_START_TIME
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
class GongConnectorCheckpoint(ConnectorCheckpoint):
# Resolved workspace IDs to iterate through.
# None means "not yet resolved" — first checkpoint call resolves them.
# Inner None means "no workspace filter" (fetch all).
workspace_ids: list[str | None] | None = None
# Index into workspace_ids for current workspace
workspace_index: int = 0
# Gong API cursor for current workspace's transcript pagination
cursor: str | None = None
# Cached time range — computed once, reused across checkpoint calls
time_range: tuple[str, str] | None = None
class _TranscriptPage(BaseModel):
"""One page of transcripts from /v2/calls/transcript."""
transcripts: list[dict[str, Any]]
next_cursor: str | None = None
class _CursorExpiredError(Exception):
"""Raised when Gong rejects a pagination cursor as expired.
Gong pagination cursors TTL is ~1 hour from the first request in a
pagination sequence, not from the last cursor fetch. Since checkpointed
connector runs can pause between invocations, a resumed run may encounter
an expired cursor and must restart the current workspace from scratch.
See https://visioneers.gong.io/integrations-77/pagination-cursor-expires-after-1-hours-even-for-a-new-cursor-1382
"""
class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
class GongConnector(LoadConnector, PollConnector):
BASE_URL = "https://api.gong.io"
MAX_CALL_DETAILS_ATTEMPTS = 6
CALL_DETAILS_DELAY = 30 # in seconds
@@ -70,9 +38,13 @@ class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
def __init__(
self,
workspaces: list[str] | None = None,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_fail: bool = CONTINUE_ON_CONNECTOR_FAILURE,
hide_user_info: bool = False,
) -> None:
self.workspaces = workspaces
self.batch_size: int = batch_size
self.continue_on_fail = continue_on_fail
self.auth_token_basic: str | None = None
self.hide_user_info = hide_user_info
self._last_request_time: float = 0.0
@@ -126,50 +98,67 @@ class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
# Then the user input is treated as the name
return {**id_id_map, **name_id_map}
def _fetch_transcript_page(
self,
start_datetime: str | None,
end_datetime: str | None,
workspace_id: str | None,
cursor: str | None,
) -> _TranscriptPage:
"""Fetch one page of transcripts from the Gong API.
Raises _CursorExpiredError if Gong reports the pagination cursor
expired (TTL is ~1 hour from first request in the pagination sequence).
"""
body: dict[str, Any] = {"filter": {}}
def _get_transcript_batches(
self, start_datetime: str | None = None, end_datetime: str | None = None
) -> Generator[list[dict[str, Any]], None, None]:
body: dict[str, dict] = {"filter": {}}
if start_datetime:
body["filter"]["fromDateTime"] = start_datetime
if end_datetime:
body["filter"]["toDateTime"] = end_datetime
if workspace_id:
body["filter"]["workspaceId"] = workspace_id
if cursor:
body["cursor"] = cursor
response = self._throttled_request(
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
)
# If no calls in the range, return empty
if response.status_code == 404:
return _TranscriptPage(transcripts=[])
# The batch_ids in the previous method appears to be batches of call_ids to process
# In this method, we will retrieve transcripts for them in batches.
transcripts: list[dict[str, Any]] = []
workspace_list = self.workspaces or [None] # type: ignore
workspace_map = self._get_workspace_id_map() if self.workspaces else {}
if not response.ok:
# Cursor expiration comes back as a 4xx with this error message —
# detect it before raise_for_status so callers can restart the workspace.
if cursor and "cursor has expired" in response.text.lower():
raise _CursorExpiredError(response.text)
logger.error(f"Error fetching transcripts: {response.text}")
response.raise_for_status()
for workspace in workspace_list:
if workspace:
logger.info(f"Updating Gong workspace: {workspace}")
workspace_id = workspace_map.get(workspace)
if not workspace_id:
logger.error(f"Invalid Gong workspace: {workspace}")
if not self.continue_on_fail:
raise ValueError(f"Invalid workspace: {workspace}")
continue
body["filter"]["workspaceId"] = workspace_id
else:
if "workspaceId" in body["filter"]:
del body["filter"]["workspaceId"]
data = response.json()
return _TranscriptPage(
transcripts=data.get("callTranscripts", []),
next_cursor=data.get("records", {}).get("cursor"),
)
while True:
response = self._throttled_request(
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
)
# If no calls in the range, just break out
if response.status_code == 404:
break
def _get_call_details_by_ids(self, call_ids: list[str]) -> dict[str, Any]:
try:
response.raise_for_status()
except Exception:
logger.error(f"Error fetching transcripts: {response.text}")
raise
data = response.json()
call_transcripts = data.get("callTranscripts", [])
transcripts.extend(call_transcripts)
while len(transcripts) >= self.batch_size:
yield transcripts[: self.batch_size]
transcripts = transcripts[self.batch_size :]
cursor = data.get("records", {}).get("cursor")
if cursor:
body["cursor"] = cursor
else:
break
if transcripts:
yield transcripts
def _get_call_details_by_ids(self, call_ids: list[str]) -> dict:
body = {
"filter": {"callIds": call_ids},
"contentSelector": {"exposedFields": {"parties": True}},
@@ -187,50 +176,6 @@ class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
return call_to_metadata
def _fetch_call_details_with_retry(self, call_ids: list[str]) -> dict[str, Any]:
"""Fetch call details with retry for the Gong API race condition.
The Gong API has a known race where transcript call IDs don't immediately
appear in /v2/calls/extensive. Retries with exponential backoff, only
re-requesting the missing IDs on each attempt.
"""
call_details_map = self._get_call_details_by_ids(call_ids)
if set(call_ids) == set(call_details_map.keys()):
return call_details_map
for attempt in range(2, self.MAX_CALL_DETAILS_ATTEMPTS + 1):
missing_ids = list(set(call_ids) - set(call_details_map.keys()))
logger.warning(
f"_get_call_details_by_ids is missing call id's: current_attempt={attempt - 1} missing_call_ids={missing_ids}"
)
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, attempt - 2)
logger.warning(
f"_get_call_details_by_ids waiting to retry: "
f"wait={wait_seconds}s "
f"current_attempt={attempt - 1} "
f"next_attempt={attempt} "
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
)
time.sleep(wait_seconds)
# Only re-fetch the missing IDs, merge into existing results
new_details = self._get_call_details_by_ids(missing_ids)
call_details_map.update(new_details)
if set(call_ids) == set(call_details_map.keys()):
return call_details_map
missing_ids = list(set(call_ids) - set(call_details_map.keys()))
logger.error(
f"Giving up on missing call id's after "
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
f"missing_call_ids={missing_ids}"
f"proceeding with {len(call_details_map)} of "
f"{len(call_ids)} calls"
)
return call_details_map
@staticmethod
def _parse_parties(parties: list[dict]) -> dict[str, str]:
id_mapping = {}
@@ -251,46 +196,186 @@ class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
return id_mapping
def _resolve_workspace_ids(self) -> list[str | None]:
"""Resolve configured workspace names/IDs to actual workspace IDs.
def _fetch_calls(
self, start_datetime: str | None = None, end_datetime: str | None = None
) -> GenerateDocumentsOutput:
num_calls = 0
Returns a list of workspace IDs. If no workspaces are configured,
returns [None] to indicate "fetch all workspaces".
for transcript_batch in self._get_transcript_batches(
start_datetime, end_datetime
):
doc_batch: list[Document | HierarchyNode] = []
Raises ValueError if workspaces are configured but none resolve —
we never silently widen scope to "fetch all" on misconfiguration,
because that could ingest an entire Gong account by mistake.
"""
if not self.workspaces:
return [None]
workspace_map = self._get_workspace_id_map()
resolved: list[str | None] = []
for workspace in self.workspaces:
workspace_id = workspace_map.get(workspace)
if not workspace_id:
logger.error(f"Invalid Gong workspace: {workspace}")
continue
resolved.append(workspace_id)
if not resolved:
raise ValueError(
f"No valid Gong workspaces found — check workspace names/IDs in connector config. Configured: {self.workspaces}"
transcript_call_ids = cast(
list[str],
[t.get("callId") for t in transcript_batch if t.get("callId")],
)
return resolved
call_details_map: dict[str, Any] = {}
@staticmethod
def _compute_time_range(
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> tuple[str, str]:
"""Compute the start/end datetime strings for the Gong API filter,
applying GONG_CONNECTOR_START_TIME and the 1-day offset."""
# There's a likely race condition in the API where a transcript will have a
# call id but the call to v2/calls/extensive will not return all of the id's
# retry with exponential backoff has been observed to mitigate this
# in ~2 minutes. After max attempts, proceed with whatever we have —
# the per-call loop below will skip missing IDs gracefully.
current_attempt = 0
while True:
current_attempt += 1
call_details_map = self._get_call_details_by_ids(transcript_call_ids)
if set(transcript_call_ids) == set(call_details_map.keys()):
# we got all the id's we were expecting ... break and continue
break
# we are missing some id's. Log and retry with exponential backoff
missing_call_ids = set(transcript_call_ids) - set(
call_details_map.keys()
)
logger.warning(
f"_get_call_details_by_ids is missing call id's: "
f"current_attempt={current_attempt} "
f"missing_call_ids={missing_call_ids}"
)
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
logger.error(
f"Giving up on missing call id's after "
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
f"missing_call_ids={missing_call_ids}"
f"proceeding with {len(call_details_map)} of "
f"{len(transcript_call_ids)} calls"
)
break
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
logger.warning(
f"_get_call_details_by_ids waiting to retry: "
f"wait={wait_seconds}s "
f"current_attempt={current_attempt} "
f"next_attempt={current_attempt + 1} "
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
)
time.sleep(wait_seconds)
# now we can iterate per call/transcript
for transcript in transcript_batch:
call_id = transcript.get("callId")
if not call_id or call_id not in call_details_map:
# NOTE(rkuo): seeing odd behavior where call_ids from the transcript
# don't have call details. adding error debugging logs to trace.
logger.error(
f"Couldn't get call information for Call ID: {call_id}"
)
if call_id:
logger.error(
f"Call debug info: call_id={call_id} "
f"call_ids={transcript_call_ids} "
f"call_details_map={call_details_map.keys()}"
)
if not self.continue_on_fail:
raise RuntimeError(
f"Couldn't get call information for Call ID: {call_id}"
)
continue
call_details = call_details_map[call_id]
call_metadata = call_details["metaData"]
call_time_str = call_metadata["started"]
call_title = call_metadata["title"]
logger.info(
f"{num_calls + 1}: Indexing Gong call id {call_id} from {call_time_str.split('T', 1)[0]}: {call_title}"
)
call_parties = cast(list[dict] | None, call_details.get("parties"))
if call_parties is None:
logger.error(f"Couldn't get parties for Call ID: {call_id}")
call_parties = []
id_to_name_map = self._parse_parties(call_parties)
# Keeping a separate dict here in case the parties info is incomplete
speaker_to_name: dict[str, str] = {}
transcript_text = ""
call_purpose = call_metadata["purpose"]
if call_purpose:
transcript_text += f"Call Description: {call_purpose}\n\n"
contents = transcript["transcript"]
for segment in contents:
speaker_id = segment.get("speakerId", "")
if speaker_id not in speaker_to_name:
if self.hide_user_info:
speaker_to_name[speaker_id] = (
f"User {len(speaker_to_name) + 1}"
)
else:
speaker_to_name[speaker_id] = id_to_name_map.get(
speaker_id, "Unknown"
)
speaker_name = speaker_to_name[speaker_id]
sentences = segment.get("sentences", {})
monolog = " ".join(
[sentence.get("text", "") for sentence in sentences]
)
transcript_text += f"{speaker_name}: {monolog}\n\n"
metadata = {}
if call_metadata.get("system"):
metadata["client"] = call_metadata.get("system")
# TODO calls have a clientUniqueId field, can pull that in later
doc_batch.append(
Document(
id=call_id,
sections=[
TextSection(link=call_metadata["url"], text=transcript_text)
],
source=DocumentSource.GONG,
# Should not ever be Untitled as a call cannot be made without a Title
semantic_identifier=call_title or "Untitled",
doc_updated_at=datetime.fromisoformat(call_time_str).astimezone(
timezone.utc
),
metadata={"client": call_metadata.get("system")},
)
)
num_calls += 1
yield doc_batch
logger.info(f"_fetch_calls finished: num_calls={num_calls}")
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
combined = (
f"{credentials['gong_access_key']}:{credentials['gong_access_key_secret']}"
)
self.auth_token_basic = base64.b64encode(combined.encode("utf-8")).decode(
"utf-8"
)
if self.auth_token_basic is None:
raise ConnectorMissingCredentialError("Gong")
self._session.headers.update(
{"Authorization": f"Basic {self.auth_token_basic}"}
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_calls()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
# if this env variable is set, don't start from a timestamp before the specified
# start time
# TODO: remove this once this is globally available
if GONG_CONNECTOR_START_TIME:
special_start_datetime = datetime.fromisoformat(GONG_CONNECTOR_START_TIME)
special_start_datetime = special_start_datetime.replace(tzinfo=timezone.utc)
@@ -309,186 +394,11 @@ class GongConnector(CheckpointedConnector[GongConnectorCheckpoint]):
# so adding a 1 day buffer and fetching by default till current time
start_one_day_offset = start_datetime - timedelta(days=1)
start_time = start_one_day_offset.isoformat()
end_time = end_datetime.isoformat()
return start_time, end_time
end_time = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
def _process_transcripts(
self,
transcripts: list[dict[str, Any]],
) -> Generator[Document | ConnectorFailure, None, None]:
"""Process a batch of transcripts into Documents or ConnectorFailures."""
transcript_call_ids = cast(
list[str],
[t.get("callId") for t in transcripts if t.get("callId")],
)
call_details_map = self._fetch_call_details_with_retry(transcript_call_ids)
for transcript in transcripts:
call_id = transcript.get("callId")
if not call_id or call_id not in call_details_map:
logger.error(f"Couldn't get call information for Call ID: {call_id}")
if call_id:
logger.error(
f"Call debug info: call_id={call_id} "
f"call_ids={transcript_call_ids} "
f"call_details_map={call_details_map.keys()}"
)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=call_id or "unknown",
),
failure_message=f"Couldn't get call information for Call ID: {call_id}",
)
continue
call_details = call_details_map[call_id]
call_metadata = call_details["metaData"]
call_time_str = call_metadata["started"]
call_title = call_metadata["title"]
logger.info(
f"Indexing Gong call id {call_id} from {call_time_str.split('T', 1)[0]}: {call_title}"
)
call_parties = cast(list[dict] | None, call_details.get("parties"))
if call_parties is None:
logger.error(f"Couldn't get parties for Call ID: {call_id}")
call_parties = []
id_to_name_map = self._parse_parties(call_parties)
speaker_to_name: dict[str, str] = {}
transcript_text = ""
call_purpose = call_metadata["purpose"]
if call_purpose:
transcript_text += f"Call Description: {call_purpose}\n\n"
contents = transcript["transcript"]
for segment in contents:
speaker_id = segment.get("speakerId", "")
if speaker_id not in speaker_to_name:
if self.hide_user_info:
speaker_to_name[speaker_id] = f"User {len(speaker_to_name) + 1}"
else:
speaker_to_name[speaker_id] = id_to_name_map.get(
speaker_id, "Unknown"
)
speaker_name = speaker_to_name[speaker_id]
sentences = segment.get("sentences", {})
monolog = " ".join([sentence.get("text", "") for sentence in sentences])
transcript_text += f"{speaker_name}: {monolog}\n\n"
yield Document(
id=call_id,
sections=[TextSection(link=call_metadata["url"], text=transcript_text)],
source=DocumentSource.GONG,
semantic_identifier=call_title or "Untitled",
doc_updated_at=datetime.fromisoformat(call_time_str).astimezone(
timezone.utc
),
metadata={"client": call_metadata.get("system")},
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
combined = (
f"{credentials['gong_access_key']}:{credentials['gong_access_key_secret']}"
)
self.auth_token_basic = base64.b64encode(combined.encode("utf-8")).decode(
"utf-8"
)
if self.auth_token_basic is None:
raise ConnectorMissingCredentialError("Gong")
self._session.headers.update(
{"Authorization": f"Basic {self.auth_token_basic}"}
)
return None
def build_dummy_checkpoint(self) -> GongConnectorCheckpoint:
return GongConnectorCheckpoint(has_more=True)
def validate_checkpoint_json(self, checkpoint_json: str) -> GongConnectorCheckpoint:
return GongConnectorCheckpoint.model_validate_json(checkpoint_json)
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: GongConnectorCheckpoint,
) -> CheckpointOutput[GongConnectorCheckpoint]:
checkpoint = copy.deepcopy(checkpoint)
# Step 1: Resolve workspace IDs on first call
if checkpoint.workspace_ids is None:
checkpoint.workspace_ids = self._resolve_workspace_ids()
checkpoint.time_range = self._compute_time_range(start, end)
checkpoint.has_more = True
return checkpoint
workspace_ids = checkpoint.workspace_ids
# If we've exhausted all workspaces, we're done
if checkpoint.workspace_index >= len(workspace_ids):
checkpoint.has_more = False
return checkpoint
# Use cached time range, falling back to computation if not cached
start_time, end_time = checkpoint.time_range or self._compute_time_range(
start, end
)
logger.info(
f"Fetching Gong calls between {start_time} and {end_time} "
f"(workspace {checkpoint.workspace_index + 1}/{len(workspace_ids)})"
)
workspace_id = workspace_ids[checkpoint.workspace_index]
# Step 2: Fetch one page of transcripts
try:
page = self._fetch_transcript_page(
start_datetime=start_time,
end_datetime=end_time,
workspace_id=workspace_id,
cursor=checkpoint.cursor,
)
except _CursorExpiredError:
# Gong cursors TTL ~1h from first request in the sequence. If the
# checkpoint paused long enough for the cursor to expire, restart
# the current workspace from the beginning of the time range.
# Document upserts are idempotent (keyed by call_id) so
# reprocessing is safe.
logger.warning(
f"Gong pagination cursor expired for workspace "
f"{checkpoint.workspace_index + 1}/{len(workspace_ids)}; "
f"restarting workspace from beginning of time range."
)
checkpoint.cursor = None
checkpoint.has_more = True
return checkpoint
# Step 3: Process transcripts into documents
if page.transcripts:
yield from self._process_transcripts(page.transcripts)
# Step 4: Update checkpoint state
if page.next_cursor:
# More pages in this workspace
checkpoint.cursor = page.next_cursor
checkpoint.has_more = True
else:
# This workspace is exhausted — advance to next
checkpoint.workspace_index += 1
checkpoint.cursor = None
checkpoint.has_more = checkpoint.workspace_index < len(workspace_ids)
return checkpoint
logger.info(f"Fetching Gong calls between {start_time} and {end_time}")
return self._fetch_calls(start_time, end_time)
if __name__ == "__main__":
@@ -502,13 +412,5 @@ if __name__ == "__main__":
}
)
checkpoint = connector.build_dummy_checkpoint()
while checkpoint.has_more:
doc_generator = connector.load_from_checkpoint(0, time.time(), checkpoint)
try:
while True:
item = next(doc_generator)
print(item)
except StopIteration as e:
checkpoint = e.value
print(f"Checkpoint: {checkpoint}")
latest_docs = connector.load_from_state()
print(next(latest_docs))

View File

@@ -18,7 +18,7 @@ from urllib.parse import urlunparse
from google.auth.exceptions import RefreshError
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from googleapiclient.errors import HttpError
from googleapiclient.errors import HttpError # type: ignore
from typing_extensions import override
from onyx.access.models import ExternalAccess
@@ -434,7 +434,7 @@ class GoogleDriveConnector(
for is_admin in [True, False]:
query = "isAdmin=true" if is_admin else "isAdmin=false"
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list, # ty: ignore[unresolved-attribute]
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
@@ -502,9 +502,6 @@ class GoogleDriveConnector(
files: list[RetrievedDriveFile],
seen_hierarchy_node_raw_ids: ThreadSafeSet[str],
fully_walked_hierarchy_node_raw_ids: ThreadSafeSet[str],
failed_folder_ids_by_email: (
ThreadSafeDict[str, ThreadSafeSet[str]] | None
) = None,
permission_sync_context: PermissionSyncContext | None = None,
add_prefix: bool = False,
) -> list[HierarchyNode]:
@@ -528,9 +525,6 @@ class GoogleDriveConnector(
seen_hierarchy_node_raw_ids: Set of already-yielded node IDs (modified in place)
fully_walked_hierarchy_node_raw_ids: Set of node IDs where the walk to root
succeeded (modified in place)
failed_folder_ids_by_email: Map of email → folder IDs where that email
previously confirmed no accessible parent. Skips the API call if the same
(folder, email) is encountered again (modified in place).
permission_sync_context: If provided, permissions will be fetched for hierarchy nodes.
Contains google_domain and primary_admin_email needed for permission syncing.
add_prefix: When True, prefix group IDs with source type (for indexing path).
@@ -575,7 +569,7 @@ class GoogleDriveConnector(
# Fetch folder metadata
folder = self._get_folder_metadata(
current_id, file.user_email, field_type, failed_folder_ids_by_email
current_id, file.user_email, field_type
)
if not folder:
# Can't access this folder - stop climbing
@@ -659,13 +653,7 @@ class GoogleDriveConnector(
return new_nodes
def _get_folder_metadata(
self,
folder_id: str,
retriever_email: str,
field_type: DriveFileFieldType,
failed_folder_ids_by_email: (
ThreadSafeDict[str, ThreadSafeSet[str]] | None
) = None,
self, folder_id: str, retriever_email: str, field_type: DriveFileFieldType
) -> GoogleDriveFileType | None:
"""
Fetch metadata for a folder by ID.
@@ -679,17 +667,6 @@ class GoogleDriveConnector(
# Use a set to deduplicate if retriever_email == primary_admin_email
for email in {retriever_email, self.primary_admin_email}:
failed_ids = (
failed_folder_ids_by_email.get(email)
if failed_folder_ids_by_email
else None
)
if failed_ids and folder_id in failed_ids:
logger.debug(
f"Skipping folder {folder_id} using {email} (previously confirmed no parents)"
)
continue
service = get_drive_service(self.creds, email)
folder = get_folder_metadata(service, folder_id, field_type)
@@ -705,10 +682,6 @@ class GoogleDriveConnector(
# Folder has no parents - could be a root OR user lacks access to parent
# Keep this as a fallback but try admin to see if they can see parents
if failed_folder_ids_by_email is not None:
failed_folder_ids_by_email.setdefault(email, ThreadSafeSet()).add(
folder_id
)
if best_folder is None:
best_folder = folder
logger.debug(
@@ -746,7 +719,7 @@ class GoogleDriveConnector(
)
all_drive_ids: set[str] = set()
for drive in execute_paginated_retrieval(
retrieval_function=drive_service.drives().list, # ty: ignore[unresolved-attribute]
retrieval_function=drive_service.drives().list,
list_key="drives",
useDomainAdminAccess=is_service_account,
fields="drives(id),nextPageToken",
@@ -934,9 +907,7 @@ class GoogleDriveConnector(
# resume from a checkpoint
if resuming and (drive_id := curr_stage.current_folder_or_drive_id):
resume_start = curr_stage.completed_until
for file_or_token in _yield_from_drive(
drive_id, resume_start # ty: ignore[possibly-unresolved-reference]
):
for file_or_token in _yield_from_drive(drive_id, resume_start):
if isinstance(file_or_token, str):
checkpoint.completion_map[user_email].next_page_token = (
file_or_token
@@ -1117,13 +1088,6 @@ class GoogleDriveConnector(
]
yield from parallel_yield(user_retrieval_gens, max_workers=MAX_DRIVE_WORKERS)
# Free per-user cache entries now that this batch is done.
# Skip the admin email — it is shared across all user batches and must
# persist for the duration of the run.
for email in non_completed_org_emails:
if email != self.primary_admin_email:
checkpoint.failed_folder_ids_by_email.pop(email, None)
# if there are more emails to process, don't mark as complete
if not email_batch_takes_us_to_completion:
return
@@ -1338,9 +1302,7 @@ class GoogleDriveConnector(
resume_start = checkpoint.completion_map[
self.primary_admin_email
].completed_until
yield from _yield_from_folder_crawl(
folder_id, resume_start # ty: ignore[possibly-unresolved-reference]
)
yield from _yield_from_folder_crawl(folder_id, resume_start)
# the times stored in the completion_map aren't used due to the crawling behavior
# instead, the traversed_parent_ids are used to determine what we have left to retrieve
@@ -1580,7 +1542,6 @@ class GoogleDriveConnector(
files=files_batch,
seen_hierarchy_node_raw_ids=checkpoint.seen_hierarchy_node_raw_ids,
fully_walked_hierarchy_node_raw_ids=checkpoint.fully_walked_hierarchy_node_raw_ids,
failed_folder_ids_by_email=checkpoint.failed_folder_ids_by_email,
permission_sync_context=permission_sync_context,
add_prefix=True,
)
@@ -1817,7 +1778,6 @@ class GoogleDriveConnector(
files=files_batch,
seen_hierarchy_node_raw_ids=checkpoint.seen_hierarchy_node_raw_ids,
fully_walked_hierarchy_node_raw_ids=checkpoint.fully_walked_hierarchy_node_raw_ids,
failed_folder_ids_by_email=checkpoint.failed_folder_ids_by_email,
permission_sync_context=permission_sync_context,
)
@@ -1923,9 +1883,7 @@ class GoogleDriveConnector(
try:
drive_service = get_drive_service(self._creds, self._primary_admin_email)
drive_service.files().list( # ty: ignore[unresolved-attribute]
pageSize=1, fields="files(id)"
).execute()
drive_service.files().list(pageSize=1, fields="files(id)").execute()
if isinstance(self._creds, ServiceAccountCredentials):
# default is ~17mins of retries, don't do that here since this is called from

View File

@@ -6,17 +6,13 @@ from typing import cast
from urllib.parse import urlparse
from urllib.parse import urlunparse
from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseDownload
from googleapiclient.errors import HttpError # type: ignore
from googleapiclient.http import MediaIoBaseDownload # type: ignore
from pydantic import BaseModel
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.connectors.cross_connector_utils.tabular_section_utils import is_tabular_file
from onyx.connectors.cross_connector_utils.tabular_section_utils import (
tabular_file_to_sections,
)
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from onyx.connectors.google_drive.models import GDriveMimeType
@@ -32,16 +28,15 @@ from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TabularSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import pptx_to_text
from onyx.file_processing.extract_file_text import read_docx_file
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.extract_file_text import xlsx_to_text
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.file_types import OnyxMimeTypes
from onyx.file_processing.file_types import SPREADSHEET_MIME_TYPE
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import (
@@ -65,7 +60,7 @@ def _get_folder_info(
try:
folder = (
service.files() # ty: ignore[unresolved-attribute]
service.files()
.get(
fileId=folder_id,
fields="name, parents",
@@ -91,11 +86,7 @@ def _get_drive_name(service: GoogleDriveService, drive_id: str) -> str:
return _folder_cache[cache_key][0]
try:
drive = (
service.drives() # ty: ignore[unresolved-attribute]
.get(driveId=drive_id)
.execute()
)
drive = service.drives().get(driveId=drive_id).execute()
drive_name = drive.get("name", f"Shared Drive {drive_id}")
_folder_cache[cache_key] = (drive_name, None)
return drive_name
@@ -262,9 +253,7 @@ def download_request(
"""
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media( # ty: ignore[unresolved-attribute]
fileId=file_id
)
request = service.files().get_media(fileId=file_id)
return _download_request(request, file_id, size_threshold)
@@ -300,7 +289,7 @@ def _download_and_extract_sections_basic(
service: GoogleDriveService,
allow_images: bool,
size_threshold: int,
) -> list[TextSection | ImageSection | TabularSection]:
) -> list[TextSection | ImageSection]:
"""Extract text and images from a Google Drive file."""
file_id = file["id"]
file_name = file["name"]
@@ -319,7 +308,7 @@ def _download_and_extract_sections_basic(
return []
# Store images for later processing
sections: list[TextSection | ImageSection | TabularSection] = []
sections: list[TextSection | ImageSection] = []
try:
section, embedded_id = store_image_and_create_section(
image_data=response_call(),
@@ -334,10 +323,11 @@ def _download_and_extract_sections_basic(
logger.error(f"Failed to process image {file_name}: {e}")
return sections
# For Google Docs, Sheets, and Slides, export via the Drive API
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
request = service.files().export_media( # ty: ignore[unresolved-attribute]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
response = _download_request(request, file_id, size_threshold)
@@ -345,17 +335,6 @@ def _download_and_extract_sections_basic(
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
if export_mime_type in OnyxMimeTypes.TABULAR_MIME_TYPES:
# Synthesize an extension on the filename
ext = ".xlsx" if export_mime_type == SPREADSHEET_MIME_TYPE else ".csv"
return list(
tabular_file_to_sections(
io.BytesIO(response),
file_name=f"{file_name}{ext}",
link=link,
)
)
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
@@ -377,15 +356,9 @@ def _download_and_extract_sections_basic(
elif (
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
or is_tabular_file(file_name)
):
return list(
tabular_file_to_sections(
io.BytesIO(response_call()),
file_name=file_name,
link=link,
)
)
text = xlsx_to_text(io.BytesIO(response_call()), file_name=file_name)
return [TextSection(link=link, text=text)] if text else []
elif (
mime_type
@@ -396,7 +369,7 @@ def _download_and_extract_sections_basic(
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
pdf_sections: list[TextSection | ImageSection | TabularSection] = [
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]
@@ -437,9 +410,8 @@ def _find_nth(haystack: str, needle: str, n: int, start: int = 0) -> int:
def align_basic_advanced(
basic_sections: list[TextSection | ImageSection | TabularSection],
adv_sections: list[TextSection],
) -> list[TextSection | ImageSection | TabularSection]:
basic_sections: list[TextSection | ImageSection], adv_sections: list[TextSection]
) -> list[TextSection | ImageSection]:
"""Align the basic sections with the advanced sections.
In particular, the basic sections contain all content of the file,
including smart chips like dates and doc links. The advanced sections
@@ -456,14 +428,12 @@ def align_basic_advanced(
basic_full_text = "".join(
[section.text for section in basic_sections if isinstance(section, TextSection)]
)
new_sections: list[TextSection | ImageSection | TabularSection] = []
new_sections: list[TextSection | ImageSection] = []
heading_start = 0
for adv_ind in range(1, len(adv_sections)):
heading = adv_sections[adv_ind].text.split(HEADING_DELIMITER)[0]
# retrieve the longest part of the heading that is not a smart chip
heading_key = max( # ty: ignore[unresolved-attribute]
heading.split(SMART_CHIP_CHAR), key=len
).strip()
heading_key = max(heading.split(SMART_CHIP_CHAR), key=len).strip()
if heading_key == "":
logger.warning(
f"Cannot match heading: {heading}, its link will come from the following section"
@@ -629,7 +599,7 @@ def _convert_drive_item_to_document(
"""
Main entry point for converting a Google Drive file => Document object.
"""
sections: list[TextSection | ImageSection | TabularSection] = []
sections: list[TextSection | ImageSection] = []
# Only construct these services when needed
def _get_drive_service() -> GoogleDriveService:
@@ -669,9 +639,7 @@ def _convert_drive_item_to_document(
doc_id=file.get("id", ""),
)
if doc_sections:
sections = cast(
list[TextSection | ImageSection | TabularSection], doc_sections
)
sections = cast(list[TextSection | ImageSection], doc_sections)
if any(SMART_CHIP_CHAR in section.text for section in doc_sections):
logger.debug(
f"found smart chips in {file.get('name')}, aligning with basic sections"

View File

@@ -7,9 +7,9 @@ from typing import cast
from urllib.parse import parse_qs
from urllib.parse import urlparse
from googleapiclient.discovery import Resource
from googleapiclient.errors import HttpError
from googleapiclient.http import BatchHttpRequest
from googleapiclient.discovery import Resource # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from googleapiclient.http import BatchHttpRequest # type: ignore
from onyx.access.models import ExternalAccess
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
@@ -115,7 +115,7 @@ def _get_folders_in_parent(
query += f" and '{parent_id}' in parents"
for file in execute_paginated_retrieval(
retrieval_function=service.files().list, # ty: ignore[unresolved-attribute]
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="allDrives",
@@ -136,7 +136,7 @@ def get_folder_metadata(
fields = _get_hierarchy_fields_for_file_type(field_type)
try:
return (
service.files() # ty: ignore[unresolved-attribute]
service.files()
.get(
fileId=folder_id,
fields=fields,
@@ -169,11 +169,7 @@ def get_shared_drive_name(
folders. Only drives().get() returns the real user-assigned name.
"""
try:
drive = (
service.drives() # ty: ignore[unresolved-attribute]
.get(driveId=drive_id, fields="name")
.execute()
)
drive = service.drives().get(driveId=drive_id, fields="name").execute()
return drive.get("name")
except HttpError as e:
if e.resp.status in (403, 404):
@@ -265,7 +261,7 @@ def _get_files_in_parent(
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
for file in execute_paginated_retrieval(
retrieval_function=service.files().list, # ty: ignore[unresolved-attribute]
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="allDrives",
@@ -375,7 +371,7 @@ def get_files_in_shared_drive(
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
for folder in execute_paginated_retrieval(
retrieval_function=service.files().list, # ty: ignore[unresolved-attribute]
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
corpora="drive",
@@ -393,7 +389,7 @@ def get_files_in_shared_drive(
file_query += generate_time_range_filter(start, end)
for file in execute_paginated_retrieval_with_max_pages(
retrieval_function=service.files().list, # ty: ignore[unresolved-attribute]
retrieval_function=service.files().list,
max_num_pages=max_num_pages,
list_key="files",
continue_on_404_or_403=True,
@@ -440,7 +436,7 @@ def get_all_files_in_my_drive_and_shared(
folder_query += " and 'me' in owners"
found_folders = False
for folder in execute_paginated_retrieval(
retrieval_function=service.files().list, # ty: ignore[unresolved-attribute]
retrieval_function=service.files().list,
list_key="files",
corpora="user",
fields=_get_fields_for_file_type(field_type),
@@ -458,7 +454,7 @@ def get_all_files_in_my_drive_and_shared(
file_query += " and 'me' in owners"
file_query += generate_time_range_filter(start, end)
yield from execute_paginated_retrieval_with_max_pages(
retrieval_function=service.files().list, # ty: ignore[unresolved-attribute]
retrieval_function=service.files().list,
max_num_pages=max_num_pages,
list_key="files",
continue_on_404_or_403=False,
@@ -503,7 +499,7 @@ def get_all_files_for_oauth(
yield from execute_paginated_retrieval_with_max_pages(
max_num_pages=max_num_pages,
retrieval_function=service.files().list, # ty: ignore[unresolved-attribute]
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=False,
corpora=corpora,
@@ -520,7 +516,7 @@ def get_root_folder_id(service: Resource) -> str:
# we dont paginate here because there is only one root folder per user
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
return (
service.files() # ty: ignore[unresolved-attribute]
service.files()
.get(fileId="root", fields=GoogleFields.ID.value)
.execute()[GoogleFields.ID.value]
)
@@ -554,7 +550,7 @@ def get_file_by_web_view_link(
"""Retrieve a Google Drive file using its webViewLink."""
file_id = _extract_file_id_from_web_view_link(web_view_link)
return (
service.files() # ty: ignore[unresolved-attribute]
service.files()
.get(
fileId=file_id,
supportsAllDrives=True,
@@ -616,17 +612,12 @@ def _get_files_by_web_view_links_batch(
else:
result.files[request_id] = response
batch = cast(
BatchHttpRequest,
service.new_batch_http_request( # ty: ignore[unresolved-attribute]
callback=callback
),
)
batch = cast(BatchHttpRequest, service.new_batch_http_request(callback=callback))
for web_view_link in web_view_links:
try:
file_id = _extract_file_id_from_web_view_link(web_view_link)
request = service.files().get( # ty: ignore[unresolved-attribute]
request = service.files().get(
fileId=file_id,
supportsAllDrives=True,
fields=fields,

Some files were not shown because too many files have changed in this diff Show More