mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-13 11:42:40 +00:00
Compare commits
2 Commits
bo/custom_
...
v2.0.0-bet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b36910240d | ||
|
|
488b27ba04 |
@@ -8,9 +8,9 @@ on:
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
|
||||
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
|
||||
|
||||
# don't tag cloud images with "latest"
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
|
||||
|
||||
# tag nightly builds with "edge"
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
@@ -33,7 +33,16 @@ jobs:
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
|
||||
- name: Check if stable release version
|
||||
id: check_version
|
||||
run: |
|
||||
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
|
||||
echo "is_stable=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "is_stable=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -46,7 +55,8 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -119,7 +129,8 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
|
||||
@@ -11,8 +11,8 @@ env:
|
||||
BUILDKIT_PROGRESS: plain
|
||||
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
|
||||
|
||||
# don't tag cloud images with "latest"
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
|
||||
# tag nightly builds with "edge"
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
|
||||
jobs:
|
||||
|
||||
@@ -145,6 +145,15 @@ jobs:
|
||||
if: needs.check_model_server_changes.outputs.changed == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check if stable release version
|
||||
id: check_version
|
||||
run: |
|
||||
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
|
||||
echo "is_stable=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "is_stable=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
@@ -157,11 +166,16 @@ jobs:
|
||||
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
if [[ "${{ env.LATEST_TAG }}" == "true" ]]; then
|
||||
if [[ "${{ steps.check_version.outputs.is_stable }}" == "true" ]]; then
|
||||
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:latest \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
fi
|
||||
if [[ "${{ env.EDGE_TAG }}" == "true" ]]; then
|
||||
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:edge \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
fi
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@v3
|
||||
|
||||
@@ -7,7 +7,10 @@ on:
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
# tag nightly builds with "edge"
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
|
||||
DEPLOYMENT: standalone
|
||||
|
||||
jobs:
|
||||
@@ -45,6 +48,15 @@ jobs:
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Check if stable release version
|
||||
id: check_version
|
||||
run: |
|
||||
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
echo "is_stable=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "is_stable=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -57,7 +69,8 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -126,7 +139,8 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
|
||||
6
.github/workflows/helm-chart-releases.yml
vendored
6
.github/workflows/helm-chart-releases.yml
vendored
@@ -25,9 +25,11 @@ jobs:
|
||||
|
||||
- name: Add required Helm repositories
|
||||
run: |
|
||||
helm repo add bitnami https://charts.bitnami.com/bitnami
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add keda https://kedacore.github.io/charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo update
|
||||
|
||||
- name: Build chart dependencies
|
||||
|
||||
@@ -20,6 +20,7 @@ env:
|
||||
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
|
||||
# LLMs
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
71
.github/workflows/pr-helm-chart-testing.yml
vendored
71
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -65,35 +65,45 @@ jobs:
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Adding Helm repositories ==="
|
||||
helm repo add bitnami https://charts.bitnami.com/bitnami
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo update
|
||||
|
||||
- name: Pre-pull critical images
|
||||
- name: Install Redis operator
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
echo "=== Installing redis-operator CRDs ==="
|
||||
helm upgrade --install redis-operator ot-container-kit/redis-operator \
|
||||
--namespace redis-operator --create-namespace --wait --timeout 300s
|
||||
|
||||
- name: Pre-pull required images
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-pulling critical images to avoid timeout ==="
|
||||
# Get kind cluster name
|
||||
echo "=== Pre-pulling required images to avoid timeout ==="
|
||||
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
|
||||
echo "Kind cluster: $KIND_CLUSTER"
|
||||
|
||||
# Pre-pull images that are likely to be used
|
||||
echo "Pre-pulling PostgreSQL image..."
|
||||
docker pull postgres:15-alpine || echo "Failed to pull postgres:15-alpine"
|
||||
kind load docker-image postgres:15-alpine --name $KIND_CLUSTER || echo "Failed to load postgres image"
|
||||
|
||||
echo "Pre-pulling Redis image..."
|
||||
docker pull redis:7-alpine || echo "Failed to pull redis:7-alpine"
|
||||
kind load docker-image redis:7-alpine --name $KIND_CLUSTER || echo "Failed to load redis image"
|
||||
|
||||
echo "Pre-pulling Onyx images..."
|
||||
docker pull docker.io/onyxdotapp/onyx-web-server:latest || echo "Failed to pull onyx web server"
|
||||
docker pull docker.io/onyxdotapp/onyx-backend:latest || echo "Failed to pull onyx backend"
|
||||
kind load docker-image docker.io/onyxdotapp/onyx-web-server:latest --name $KIND_CLUSTER || echo "Failed to load onyx web server"
|
||||
kind load docker-image docker.io/onyxdotapp/onyx-backend:latest --name $KIND_CLUSTER || echo "Failed to load onyx backend"
|
||||
|
||||
|
||||
IMAGES=(
|
||||
"ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0"
|
||||
"quay.io/opstree/redis:v7.0.15"
|
||||
"docker.io/onyxdotapp/onyx-web-server:latest"
|
||||
)
|
||||
|
||||
for image in "${IMAGES[@]}"; do
|
||||
echo "Pre-pulling $image"
|
||||
if docker pull "$image"; then
|
||||
kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind"
|
||||
else
|
||||
echo "Failed to pull $image"
|
||||
fi
|
||||
done
|
||||
|
||||
echo "=== Images loaded into Kind cluster ==="
|
||||
docker exec $KIND_CLUSTER-control-plane crictl images | grep -E "(postgres|redis|onyx)" || echo "Some images may still be loading..."
|
||||
docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..."
|
||||
|
||||
- name: Validate chart dependencies
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
@@ -149,6 +159,7 @@ jobs:
|
||||
|
||||
# Run the actual installation with detailed logging
|
||||
echo "=== Starting ct install ==="
|
||||
set +e
|
||||
ct install --all \
|
||||
--helm-extra-set-args="\
|
||||
--set=nginx.enabled=false \
|
||||
@@ -156,8 +167,10 @@ jobs:
|
||||
--set=vespa.enabled=false \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.primary.persistence.enabled=false \
|
||||
--set=postgresql.nameOverride=cloudnative-pg \
|
||||
--set=postgresql.cluster.storage.storageClass=standard \
|
||||
--set=redis.enabled=true \
|
||||
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
|
||||
--set=webserver.replicaCount=1 \
|
||||
--set=api.replicaCount=0 \
|
||||
--set=inferenceCapability.replicaCount=0 \
|
||||
@@ -173,8 +186,16 @@ jobs:
|
||||
--set=celery_worker_user_files_indexing.replicaCount=0" \
|
||||
--helm-extra-args="--timeout 900s --debug" \
|
||||
--debug --config ct.yaml
|
||||
|
||||
echo "=== Installation completed successfully ==="
|
||||
CT_EXIT=$?
|
||||
set -e
|
||||
|
||||
if [[ $CT_EXIT -ne 0 ]]; then
|
||||
echo "ct install failed with exit code $CT_EXIT"
|
||||
exit $CT_EXIT
|
||||
else
|
||||
echo "=== Installation completed successfully ==="
|
||||
fi
|
||||
|
||||
kubectl get pods --all-namespaces
|
||||
|
||||
- name: Post-install verification
|
||||
@@ -199,7 +220,7 @@ jobs:
|
||||
|
||||
echo "=== Recent logs for debugging ==="
|
||||
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
|
||||
|
||||
|
||||
echo "=== Helm releases ==="
|
||||
helm list --all-namespaces
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
|
||||
7
.github/workflows/pr-integration-tests.yml
vendored
7
.github/workflows/pr-integration-tests.yml
vendored
@@ -22,9 +22,11 @@ env:
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
@@ -131,6 +133,7 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
@@ -158,6 +161,7 @@ jobs:
|
||||
push: true
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
no-cache: true
|
||||
|
||||
build-integration-image:
|
||||
needs: prepare-build
|
||||
@@ -191,6 +195,7 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
|
||||
integration-tests:
|
||||
needs:
|
||||
@@ -337,9 +342,11 @@ jobs:
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
|
||||
@@ -19,9 +19,11 @@ env:
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
@@ -128,6 +130,7 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
@@ -155,6 +158,7 @@ jobs:
|
||||
push: true
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
no-cache: true
|
||||
|
||||
build-integration-image:
|
||||
needs: prepare-build
|
||||
@@ -188,6 +192,7 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
|
||||
integration-tests-mit:
|
||||
needs:
|
||||
@@ -334,9 +339,11 @@ jobs:
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
|
||||
@@ -20,11 +20,13 @@ env:
|
||||
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
|
||||
# Jira
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
|
||||
|
||||
# Gong
|
||||
GONG_ACCESS_KEY: ${{ secrets.GONG_ACCESS_KEY }}
|
||||
|
||||
@@ -34,8 +34,6 @@ repos:
|
||||
hooks:
|
||||
- id: prettier
|
||||
types_or: [html, css, javascript, ts, tsx]
|
||||
additional_dependencies:
|
||||
- prettier
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
|
||||
@@ -13,8 +13,7 @@ As an open source project in a rapidly changing space, we welcome all contributi
|
||||
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.
|
||||
|
||||
To ensure that your contribution is aligned with the project's direction, please reach out to any maintainer on the Onyx team
|
||||
via [Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA) /
|
||||
[Discord](https://discord.gg/TDJ59cGV2X) or [email](mailto:founders@onyx.app).
|
||||
via [Discord](https://discord.gg/4NA5SbzrWb) or [email](mailto:hello@onyx.app).
|
||||
|
||||
Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
|
||||
will be marked with the `approved by maintainers` label.
|
||||
@@ -28,8 +27,7 @@ Your input is vital to making sure that Onyx moves in the right direction.
|
||||
Before starting on implementation, please raise a GitHub issue.
|
||||
|
||||
Also, always feel free to message the founders (Chris Weaver / Yuhong Sun) on
|
||||
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA) /
|
||||
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
|
||||
[Discord](https://discord.gg/4NA5SbzrWb) directly about anything at all.
|
||||
|
||||
### Contributing Code
|
||||
|
||||
@@ -46,9 +44,7 @@ Our goal is to make contributing as easy as possible. If you run into any issues
|
||||
That way we can help future contributors and users can avoid the same issue.
|
||||
|
||||
We also have support channels and generally interesting discussions on our
|
||||
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA)
|
||||
and
|
||||
[Discord](https://discord.gg/TDJ59cGV2X).
|
||||
[Discord](https://discord.gg/4NA5SbzrWb).
|
||||
|
||||
We would love to see you there!
|
||||
|
||||
@@ -105,6 +101,11 @@ pip install -r backend/requirements/ee.txt
|
||||
pip install -r backend/requirements/model_server.txt
|
||||
```
|
||||
|
||||
Fix vscode/cursor auto-imports:
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector)
|
||||
|
||||
In the activated Python virtualenv, install Playwright for Python by running:
|
||||
@@ -117,8 +118,15 @@ You may have to deactivate and reactivate your virtualenv for `playwright` to ap
|
||||
|
||||
#### Frontend: Node dependencies
|
||||
|
||||
Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
|
||||
Once the above is done, navigate to `onyx/web` run:
|
||||
Onyx uses Node v22.20.0. We highly recommend you use [Node Version Manager (nvm)](https://github.com/nvm-sh/nvm)
|
||||
to manage your Node installations. Once installed, you can run
|
||||
|
||||
```bash
|
||||
nvm install 22 && nvm use 22`
|
||||
node -v # verify your active version
|
||||
```
|
||||
|
||||
Navigate to `onyx/web` and run:
|
||||
|
||||
```bash
|
||||
npm i
|
||||
@@ -129,8 +137,6 @@ npm i
|
||||
### Backend
|
||||
|
||||
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
|
||||
First, install pre-commit (if you don't have it already) following the instructions
|
||||
[here](https://pre-commit.com/#installation).
|
||||
|
||||
With the virtual environment active, install the pre-commit library with:
|
||||
|
||||
@@ -150,15 +156,17 @@ To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend`
|
||||
|
||||
### Web
|
||||
|
||||
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `onyx/web` directory.
|
||||
We use `prettier` for formatting. The desired version will be installed via a `npm i` from the `onyx/web` directory.
|
||||
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
|
||||
Please double check that prettier passes before creating a pull request.
|
||||
|
||||
Pre-commit will also run prettier automatically on files you've recently touched. If re-formatted, your commit will fail.
|
||||
Re-stage your changes and commit again.
|
||||
|
||||
# Running the application for development
|
||||
|
||||
## Developing using VSCode Debugger (recommended)
|
||||
|
||||
We highly recommend using VSCode debugger for development.
|
||||
**We highly recommend using VSCode debugger for development.**
|
||||
See [CONTRIBUTING_VSCODE.md](./CONTRIBUTING_VSCODE.md) for more details.
|
||||
|
||||
Otherwise, you can follow the instructions below to run the application for development.
|
||||
|
||||
@@ -21,6 +21,9 @@ Before starting, make sure the Docker Daemon is running.
|
||||
5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
|
||||
6. Use the debug toolbar to step through code, inspect variables, etc.
|
||||
|
||||
Note: Clear and Restart External Volumes and Containers will reset your postgres and Vespa (relational-db and index).
|
||||
Only run this if you are okay with wiping your data.
|
||||
|
||||
## Features
|
||||
|
||||
- Hot reload is enabled for the web server and API servers
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Add image input support to model config
|
||||
|
||||
Revision ID: 64bd5677aeb6
|
||||
Revises: b30353be4eec
|
||||
Create Date: 2025-09-28 15:48:12.003612
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "64bd5677aeb6"
|
||||
down_revision = "b30353be4eec"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"model_configuration",
|
||||
sa.Column("supports_image_input", sa.Boolean(), nullable=True),
|
||||
)
|
||||
|
||||
# Seems to be left over from when model visibility was introduced and a nullable field.
|
||||
# Set any null is_visible values to False
|
||||
connection = op.get_bind()
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"UPDATE model_configuration SET is_visible = false WHERE is_visible IS NULL"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("model_configuration", "supports_image_input")
|
||||
@@ -124,9 +124,9 @@ def get_space_permission(
|
||||
and not space_permissions.external_user_group_ids
|
||||
):
|
||||
logger.warning(
|
||||
f"No permissions found for space '{space_key}'. This is very unlikely"
|
||||
"to be correct and is more likely caused by an access token with"
|
||||
"insufficient permissions. Make sure that the access token has Admin"
|
||||
f"No permissions found for space '{space_key}'. This is very unlikely "
|
||||
"to be correct and is more likely caused by an access token with "
|
||||
"insufficient permissions. Make sure that the access token has Admin "
|
||||
f"permissions for space '{space_key}'"
|
||||
)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ def _get_slim_doc_generator(
|
||||
else 0.0
|
||||
)
|
||||
|
||||
return gmail_connector.retrieve_all_slim_documents(
|
||||
return gmail_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=start_time,
|
||||
end=current_time.timestamp(),
|
||||
callback=callback,
|
||||
|
||||
@@ -34,7 +34,7 @@ def _get_slim_doc_generator(
|
||||
else 0.0
|
||||
)
|
||||
|
||||
return google_drive_connector.retrieve_all_slim_documents(
|
||||
return google_drive_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=start_time,
|
||||
end=current_time.timestamp(),
|
||||
callback=callback,
|
||||
|
||||
@@ -59,7 +59,7 @@ def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
|
||||
|
||||
for raw_perm in permissions:
|
||||
if not hasattr(raw_perm, "raw"):
|
||||
logger.warn(f"Expected a 'raw' field, but none was found: {raw_perm=}")
|
||||
logger.warning(f"Expected a 'raw' field, but none was found: {raw_perm=}")
|
||||
continue
|
||||
|
||||
permission = Permission(**raw_perm.raw)
|
||||
@@ -71,14 +71,14 @@ def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
|
||||
# In order to associate this permission to some Atlassian entity, we need the "Holder".
|
||||
# If this doesn't exist, then we cannot associate this permission to anyone; just skip.
|
||||
if not permission.holder:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Expected to find a permission holder, but none was found: {permission=}"
|
||||
)
|
||||
continue
|
||||
|
||||
type = permission.holder.get("type")
|
||||
if not type:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Expected to find the type of permission holder, but none was found: {permission=}"
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -105,7 +105,9 @@ def _get_slack_document_access(
|
||||
channel_permissions: dict[str, ExternalAccess],
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
|
||||
callback=callback
|
||||
)
|
||||
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
for doc_metadata in doc_metadata_batch:
|
||||
|
||||
@@ -4,7 +4,7 @@ from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFun
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -17,7 +17,7 @@ def generic_doc_sync(
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
doc_source: DocumentSource,
|
||||
slim_connector: SlimConnector,
|
||||
slim_connector: SlimConnectorWithPermSync,
|
||||
label: str,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
@@ -40,7 +40,7 @@ def generic_doc_sync(
|
||||
newly_fetched_doc_ids: set[str] = set()
|
||||
|
||||
logger.info(f"Fetching all slim documents from {doc_source}")
|
||||
for doc_batch in slim_connector.retrieve_all_slim_documents(callback=callback):
|
||||
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
|
||||
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
|
||||
|
||||
if callback:
|
||||
|
||||
@@ -37,6 +37,19 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
router = APIRouter(prefix="/auth/saml")
|
||||
|
||||
# Azure AD / Entra ID often returns the email attribute under different keys.
|
||||
# Keep a list of common variations so we can fall back gracefully if the IdP
|
||||
# does not send the plain "email" attribute name.
|
||||
EMAIL_ATTRIBUTE_KEYS = {
|
||||
"email",
|
||||
"emailaddress",
|
||||
"mail",
|
||||
"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress",
|
||||
"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/mail",
|
||||
"http://schemas.microsoft.com/identity/claims/emailaddress",
|
||||
}
|
||||
EMAIL_ATTRIBUTE_KEYS_LOWER = {key.lower() for key in EMAIL_ATTRIBUTE_KEYS}
|
||||
|
||||
|
||||
async def upsert_saml_user(email: str) -> User:
|
||||
"""
|
||||
@@ -204,16 +217,37 @@ async def _process_saml_callback(
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
user_email = auth.get_attribute("email")
|
||||
if not user_email:
|
||||
detail = "SAML is not set up correctly, email attribute must be provided."
|
||||
logger.error(detail)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=detail,
|
||||
)
|
||||
user_email: str | None = None
|
||||
|
||||
user_email = user_email[0]
|
||||
# The OneLogin toolkit normalizes attribute keys, but still performs a
|
||||
# case-sensitive lookup. Try the common keys first and then fall back to a
|
||||
# case-insensitive scan of all returned attributes.
|
||||
for attribute_key in EMAIL_ATTRIBUTE_KEYS:
|
||||
attribute_values = auth.get_attribute(attribute_key)
|
||||
if attribute_values:
|
||||
user_email = attribute_values[0]
|
||||
break
|
||||
|
||||
if not user_email:
|
||||
# Fallback: perform a case-insensitive lookup across all attributes in
|
||||
# case the IdP sent the email claim with a different capitalization.
|
||||
attributes = auth.get_attributes()
|
||||
for key, values in attributes.items():
|
||||
if key.lower() in EMAIL_ATTRIBUTE_KEYS_LOWER:
|
||||
if values:
|
||||
user_email = values[0]
|
||||
break
|
||||
if not user_email:
|
||||
detail = "SAML is not set up correctly, email attribute must be provided."
|
||||
logger.error(detail)
|
||||
logger.debug(
|
||||
"Received SAML attributes without email: %s",
|
||||
list(attributes.keys()),
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
user = await upsert_saml_user(email=user_email)
|
||||
|
||||
|
||||
@@ -37,9 +37,9 @@ from onyx.db.models import AvailableTenant
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import UserTenantMapping
|
||||
from onyx.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
|
||||
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.llm_provider_options import ANTHROPIC_VISIBLE_MODEL_NAMES
|
||||
from onyx.llm.llm_provider_options import get_anthropic_model_names
|
||||
from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
|
||||
from onyx.llm.llm_provider_options import OPEN_AI_VISIBLE_MODEL_NAMES
|
||||
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
|
||||
@@ -278,7 +278,7 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
is_visible=name in ANTHROPIC_VISIBLE_MODEL_NAMES,
|
||||
max_input_tokens=None,
|
||||
)
|
||||
for name in ANTHROPIC_MODEL_NAMES
|
||||
for name in get_anthropic_model_names()
|
||||
],
|
||||
api_key_changed=True,
|
||||
)
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fastapi import APIRouter
|
||||
from huggingface_hub import snapshot_download # type: ignore
|
||||
from setfit import SetFitModel # type: ignore[import]
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
from transformers import BatchEncoding # type: ignore
|
||||
from transformers import PreTrainedTokenizer # type: ignore
|
||||
|
||||
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
|
||||
from model_server.constants import MODEL_WARM_UP_STRING
|
||||
@@ -37,23 +35,30 @@ from shared_configs.model_server_models import ContentClassificationPrediction
|
||||
from shared_configs.model_server_models import IntentRequest
|
||||
from shared_configs.model_server_models import IntentResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from setfit import SetFitModel # type: ignore
|
||||
from transformers import PreTrainedTokenizer, BatchEncoding # type: ignore
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/custom")
|
||||
|
||||
_CONNECTOR_CLASSIFIER_TOKENIZER: PreTrainedTokenizer | None = None
|
||||
_CONNECTOR_CLASSIFIER_TOKENIZER: Optional["PreTrainedTokenizer"] = None
|
||||
_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
|
||||
|
||||
_INTENT_TOKENIZER: PreTrainedTokenizer | None = None
|
||||
_INTENT_TOKENIZER: Optional["PreTrainedTokenizer"] = None
|
||||
_INTENT_MODEL: HybridClassifier | None = None
|
||||
|
||||
_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
|
||||
_INFORMATION_CONTENT_MODEL: Optional["SetFitModel"] = None
|
||||
|
||||
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
|
||||
|
||||
|
||||
def get_connector_classifier_tokenizer() -> PreTrainedTokenizer:
|
||||
def get_connector_classifier_tokenizer() -> "PreTrainedTokenizer":
|
||||
global _CONNECTOR_CLASSIFIER_TOKENIZER
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
|
||||
# The tokenizer details are not uploaded to the HF hub since it's just the
|
||||
# unmodified distilbert tokenizer.
|
||||
@@ -95,7 +100,9 @@ def get_local_connector_classifier(
|
||||
return _CONNECTOR_CLASSIFIER_MODEL
|
||||
|
||||
|
||||
def get_intent_model_tokenizer() -> PreTrainedTokenizer:
|
||||
def get_intent_model_tokenizer() -> "PreTrainedTokenizer":
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
global _INTENT_TOKENIZER
|
||||
if _INTENT_TOKENIZER is None:
|
||||
# The tokenizer details are not uploaded to the HF hub since it's just the
|
||||
@@ -141,7 +148,9 @@ def get_local_intent_model(
|
||||
def get_local_information_content_model(
|
||||
model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
|
||||
tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
|
||||
) -> SetFitModel:
|
||||
) -> "SetFitModel":
|
||||
from setfit import SetFitModel
|
||||
|
||||
global _INFORMATION_CONTENT_MODEL
|
||||
if _INFORMATION_CONTENT_MODEL is None:
|
||||
try:
|
||||
@@ -179,7 +188,7 @@ def get_local_information_content_model(
|
||||
def tokenize_connector_classification_query(
|
||||
connectors: list[str],
|
||||
query: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
connector_token_end_id: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@@ -267,7 +276,7 @@ def warm_up_information_content_model() -> None:
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
|
||||
def run_inference(tokens: "BatchEncoding") -> tuple[list[float], list[float]]:
|
||||
intent_model = get_local_intent_model()
|
||||
device = intent_model.device
|
||||
|
||||
@@ -401,7 +410,7 @@ def run_content_classification_inference(
|
||||
|
||||
|
||||
def map_keywords(
|
||||
input_ids: torch.Tensor, tokenizer: PreTrainedTokenizer, is_keyword: list[bool]
|
||||
input_ids: torch.Tensor, tokenizer: "PreTrainedTokenizer", is_keyword: list[bool]
|
||||
) -> list[str]:
|
||||
tokens = tokenizer.convert_ids_to_tokens(input_ids) # type: ignore
|
||||
|
||||
|
||||
@@ -2,13 +2,11 @@ import asyncio
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from litellm.exceptions import RateLimitError
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
|
||||
from model_server.utils import simple_log_function_time
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -20,6 +18,9 @@ from shared_configs.model_server_models import EmbedResponse
|
||||
from shared_configs.model_server_models import RerankRequest
|
||||
from shared_configs.model_server_models import RerankResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sentence_transformers import CrossEncoder, SentenceTransformer
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/encoder")
|
||||
@@ -88,8 +89,10 @@ def get_embedding_model(
|
||||
|
||||
def get_local_reranking_model(
|
||||
model_name: str,
|
||||
) -> CrossEncoder:
|
||||
) -> "CrossEncoder":
|
||||
global _RERANK_MODEL
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
|
||||
if _RERANK_MODEL is None:
|
||||
logger.notice(f"Loading {model_name}")
|
||||
model = CrossEncoder(model_name)
|
||||
@@ -207,6 +210,8 @@ async def route_bi_encoder_embed(
|
||||
async def process_embed_request(
|
||||
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
|
||||
) -> EmbedResponse:
|
||||
from litellm.exceptions import RateLimitError
|
||||
|
||||
# Only local models should use this endpoint - API providers should make direct API calls
|
||||
if embed_request.provider_type is not None:
|
||||
raise ValueError(
|
||||
|
||||
@@ -30,6 +30,7 @@ from shared_configs.configs import MIN_THREADS_ML_MODELS
|
||||
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
from shared_configs.configs import SKIP_WARM_UP
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
@@ -91,16 +92,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
|
||||
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
|
||||
|
||||
if not INDEXING_ONLY:
|
||||
logger.notice(
|
||||
"The intent model should run on the model server. The information content model should not run here."
|
||||
)
|
||||
warm_up_intent_model()
|
||||
if not SKIP_WARM_UP:
|
||||
if not INDEXING_ONLY:
|
||||
logger.notice("Warming up intent model for inference model server")
|
||||
warm_up_intent_model()
|
||||
else:
|
||||
logger.notice(
|
||||
"Warming up content information model for indexing model server"
|
||||
)
|
||||
warm_up_information_content_model()
|
||||
else:
|
||||
logger.notice(
|
||||
"The content information model should run on the indexing model server. The intent model should not run here."
|
||||
)
|
||||
warm_up_information_content_model()
|
||||
logger.notice("Skipping model warmup due to SKIP_WARM_UP=true")
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
import json
|
||||
import os
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import DistilBertConfig # type: ignore
|
||||
from transformers import DistilBertModel # type: ignore
|
||||
from transformers import DistilBertTokenizer # type: ignore
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import DistilBertConfig # type: ignore
|
||||
|
||||
|
||||
class HybridClassifier(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
from transformers import DistilBertConfig, DistilBertModel
|
||||
|
||||
super().__init__()
|
||||
config = DistilBertConfig()
|
||||
self.distilbert = DistilBertModel(config)
|
||||
@@ -74,7 +78,9 @@ class HybridClassifier(nn.Module):
|
||||
|
||||
|
||||
class ConnectorClassifier(nn.Module):
|
||||
def __init__(self, config: DistilBertConfig) -> None:
|
||||
def __init__(self, config: "DistilBertConfig") -> None:
|
||||
from transformers import DistilBertTokenizer, DistilBertModel
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
@@ -115,6 +121,8 @@ class ConnectorClassifier(nn.Module):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
|
||||
from transformers import DistilBertConfig
|
||||
|
||||
config = cast(
|
||||
DistilBertConfig,
|
||||
DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from braintrust import traced
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
@@ -22,6 +23,9 @@ from onyx.agents.agent_search.dr.models import DecisionResponse
|
||||
from onyx.agents.agent_search.dr.models import DRPromptPurpose
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorTool
|
||||
from onyx.agents.agent_search.dr.process_llm_stream import (
|
||||
BasicSearchProcessedStreamResults,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.process_llm_stream import process_llm_stream
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.states import OrchestrationSetup
|
||||
@@ -666,28 +670,30 @@ def clarifier(
|
||||
system_prompt_to_use = assistant_system_prompt
|
||||
user_prompt_to_use = decision_prompt + assistant_task_prompt
|
||||
|
||||
stream = graph_config.tooling.primary_llm.stream(
|
||||
prompt=create_question_prompt(
|
||||
cast(str, system_prompt_to_use),
|
||||
cast(str, user_prompt_to_use),
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
|
||||
tool_choice=(None),
|
||||
structured_response_format=graph_config.inputs.structured_response_format,
|
||||
)
|
||||
|
||||
full_response = process_llm_stream(
|
||||
messages=stream,
|
||||
should_stream_answer=True,
|
||||
writer=writer,
|
||||
ind=0,
|
||||
final_search_results=context_llm_docs,
|
||||
displayed_search_results=context_llm_docs,
|
||||
generate_final_answer=True,
|
||||
chat_message_id=str(graph_config.persistence.chat_session_id),
|
||||
)
|
||||
@traced(name="clarifier stream and process", type="llm")
|
||||
def stream_and_process() -> BasicSearchProcessedStreamResults:
|
||||
stream = graph_config.tooling.primary_llm.stream(
|
||||
prompt=create_question_prompt(
|
||||
cast(str, system_prompt_to_use),
|
||||
cast(str, user_prompt_to_use),
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
|
||||
tool_choice=(None),
|
||||
structured_response_format=graph_config.inputs.structured_response_format,
|
||||
)
|
||||
return process_llm_stream(
|
||||
messages=stream,
|
||||
should_stream_answer=True,
|
||||
writer=writer,
|
||||
ind=0,
|
||||
final_search_results=context_llm_docs,
|
||||
displayed_search_results=context_llm_docs,
|
||||
generate_final_answer=True,
|
||||
chat_message_id=str(graph_config.persistence.chat_session_id),
|
||||
)
|
||||
|
||||
full_response = stream_and_process()
|
||||
if len(full_response.ai_message_chunk.tool_calls) == 0:
|
||||
|
||||
if isinstance(full_response.full_answer, str):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
@@ -28,6 +29,7 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -62,6 +64,29 @@ def image_generation(
|
||||
image_tool_info = state.available_tools[state.tools_used[-1]]
|
||||
image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)
|
||||
|
||||
image_prompt = branch_query
|
||||
requested_shape: ImageShape | None = None
|
||||
|
||||
try:
|
||||
parsed_query = json.loads(branch_query)
|
||||
except json.JSONDecodeError:
|
||||
parsed_query = None
|
||||
|
||||
if isinstance(parsed_query, dict):
|
||||
prompt_from_llm = parsed_query.get("prompt")
|
||||
if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
|
||||
image_prompt = prompt_from_llm.strip()
|
||||
|
||||
raw_shape = parsed_query.get("shape")
|
||||
if isinstance(raw_shape, str):
|
||||
try:
|
||||
requested_shape = ImageShape(raw_shape)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"Received unsupported image shape '%s' from LLM. Falling back to square.",
|
||||
raw_shape,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
@@ -69,7 +94,15 @@ def image_generation(
|
||||
# Generate images using the image generation tool
|
||||
image_generation_responses: list[ImageGenerationResponse] = []
|
||||
|
||||
for tool_response in image_tool.run(prompt=branch_query):
|
||||
if requested_shape is not None:
|
||||
tool_iterator = image_tool.run(
|
||||
prompt=image_prompt,
|
||||
shape=requested_shape.value,
|
||||
)
|
||||
else:
|
||||
tool_iterator = image_tool.run(prompt=image_prompt)
|
||||
|
||||
for tool_response in tool_iterator:
|
||||
if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
|
||||
# Stream heartbeat to frontend
|
||||
write_custom_event(
|
||||
@@ -95,6 +128,7 @@ def image_generation(
|
||||
file_id=file_id,
|
||||
url=build_frontend_file_url(file_id),
|
||||
revised_prompt=img.revised_prompt,
|
||||
shape=(requested_shape or ImageShape.SQUARE).value,
|
||||
)
|
||||
for file_id, img in zip(file_ids, image_generation_responses)
|
||||
]
|
||||
@@ -107,15 +141,29 @@ def image_generation(
|
||||
if final_generated_images:
|
||||
image_descriptions = []
|
||||
for i, img in enumerate(final_generated_images, 1):
|
||||
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
|
||||
if img.shape and img.shape != ImageShape.SQUARE.value:
|
||||
image_descriptions.append(
|
||||
f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
|
||||
)
|
||||
else:
|
||||
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
|
||||
|
||||
answer_string = (
|
||||
f"Generated {len(final_generated_images)} image(s) based on the request: {branch_query}\n\n"
|
||||
f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
|
||||
+ "\n".join(image_descriptions)
|
||||
)
|
||||
reasoning = f"Used image generation tool to create {len(final_generated_images)} image(s) based on the user's request."
|
||||
if requested_shape:
|
||||
reasoning = (
|
||||
"Used image generation tool to create "
|
||||
f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
|
||||
)
|
||||
else:
|
||||
reasoning = (
|
||||
"Used image generation tool to create "
|
||||
f"{len(final_generated_images)} image(s) based on the user's request."
|
||||
)
|
||||
else:
|
||||
answer_string = f"Failed to generate images for request: {branch_query}"
|
||||
answer_string = f"Failed to generate images for request: {image_prompt}"
|
||||
reasoning = "Image generation tool did not return any results."
|
||||
|
||||
return BranchUpdate(
|
||||
|
||||
@@ -5,6 +5,7 @@ class GeneratedImage(BaseModel):
|
||||
file_id: str
|
||||
url: str
|
||||
revised_prompt: str
|
||||
shape: str | None = None
|
||||
|
||||
|
||||
# Needed for PydanticType
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetContent,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchProvider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
)
|
||||
from onyx.configs.chat_configs import SERPER_API_KEY
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
SERPER_SEARCH_URL = "https://google.serper.dev/search"
|
||||
SERPER_CONTENTS_URL = "https://scrape.serper.dev"
|
||||
|
||||
|
||||
class SerperClient(InternetSearchProvider):
|
||||
def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
|
||||
self.headers = {
|
||||
"X-API-KEY": api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def search(self, query: str) -> list[InternetSearchResult]:
|
||||
payload = {
|
||||
"q": query,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
SERPER_SEARCH_URL,
|
||||
headers=self.headers,
|
||||
data=json.dumps(payload),
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
results = response.json()
|
||||
organic_results = results["organic"]
|
||||
|
||||
return [
|
||||
InternetSearchResult(
|
||||
title=result["title"],
|
||||
link=result["link"],
|
||||
snippet=result["snippet"],
|
||||
author=None,
|
||||
published_date=None,
|
||||
)
|
||||
for result in organic_results
|
||||
]
|
||||
|
||||
def contents(self, urls: list[str]) -> list[InternetContent]:
|
||||
if not urls:
|
||||
return []
|
||||
|
||||
# Serper can responds with 500s regularly. We want to retry,
|
||||
# but in the event of failure, return an unsuccesful scrape.
|
||||
def safe_get_webpage_content(url: str) -> InternetContent:
|
||||
try:
|
||||
return self._get_webpage_content(url)
|
||||
except Exception:
|
||||
return InternetContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=min(8, len(urls))) as e:
|
||||
return list(e.map(safe_get_webpage_content, urls))
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def _get_webpage_content(self, url: str) -> InternetContent:
|
||||
payload = {
|
||||
"url": url,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
SERPER_CONTENTS_URL,
|
||||
headers=self.headers,
|
||||
data=json.dumps(payload),
|
||||
)
|
||||
|
||||
# 400 returned when serper cannot scrape
|
||||
if response.status_code == 400:
|
||||
return InternetContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
response_json = response.json()
|
||||
|
||||
# Response only guarantees text
|
||||
text = response_json["text"]
|
||||
|
||||
# metadata & jsonld is not guaranteed to be present
|
||||
metadata = response_json.get("metadata", {})
|
||||
jsonld = response_json.get("jsonld", {})
|
||||
|
||||
title = extract_title_from_metadata(metadata)
|
||||
|
||||
# Serper does not provide a reliable mechanism to extract the url
|
||||
response_url = url
|
||||
published_date_str = extract_published_date_from_jsonld(jsonld)
|
||||
published_date = None
|
||||
|
||||
if published_date_str:
|
||||
try:
|
||||
published_date = time_str_to_utc(published_date_str)
|
||||
except Exception:
|
||||
published_date = None
|
||||
|
||||
return InternetContent(
|
||||
title=title or "",
|
||||
link=response_url,
|
||||
full_content=text or "",
|
||||
published_date=published_date,
|
||||
)
|
||||
|
||||
|
||||
def extract_title_from_metadata(metadata: dict[str, str]) -> str | None:
|
||||
keys = ["title", "og:title"]
|
||||
return extract_value_from_dict(metadata, keys)
|
||||
|
||||
|
||||
def extract_published_date_from_jsonld(jsonld: dict[str, str]) -> str | None:
|
||||
keys = ["dateModified"]
|
||||
return extract_value_from_dict(jsonld, keys)
|
||||
|
||||
|
||||
def extract_value_from_dict(data: dict[str, str], keys: list[str]) -> str | None:
|
||||
for key in keys:
|
||||
if key in data:
|
||||
return data[key]
|
||||
return None
|
||||
@@ -26,6 +26,7 @@ class InternetContent(BaseModel):
|
||||
link: str
|
||||
full_content: str
|
||||
published_date: datetime | None = None
|
||||
scrape_successful: bool = True
|
||||
|
||||
|
||||
class InternetSearchProvider(ABC):
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
|
||||
ExaClient,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client import (
|
||||
SerperClient,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchProvider,
|
||||
)
|
||||
from onyx.configs.chat_configs import EXA_API_KEY
|
||||
from onyx.configs.chat_configs import SERPER_API_KEY
|
||||
|
||||
|
||||
def get_default_provider() -> InternetSearchProvider | None:
|
||||
if EXA_API_KEY:
|
||||
return ExaClient()
|
||||
if SERPER_API_KEY:
|
||||
return SerperClient()
|
||||
return None
|
||||
|
||||
@@ -34,7 +34,7 @@ def dummy_inference_section_from_internet_content(
|
||||
boost=1,
|
||||
recency_bias=1.0,
|
||||
score=1.0,
|
||||
hidden=False,
|
||||
hidden=(not result.scrape_successful),
|
||||
metadata={},
|
||||
match_highlights=[],
|
||||
doc_summary=truncated_content,
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages.tool import ToolCall
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallOutput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.tools.message import build_tool_message
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.tool_runner import ToolRunner
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ToolCallException(Exception):
|
||||
"""Exception raised for errors during tool calls."""
|
||||
|
||||
|
||||
def call_tool(
|
||||
state: ToolChoiceUpdate,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> ToolCallUpdate:
|
||||
"""Calls the tool specified in the state and updates the state with the result"""
|
||||
|
||||
cast(GraphConfig, config["metadata"]["config"])
|
||||
|
||||
tool_choice = state.tool_choice
|
||||
if tool_choice is None:
|
||||
raise ValueError("Cannot invoke tool call node without a tool choice")
|
||||
|
||||
tool = tool_choice.tool
|
||||
tool_args = tool_choice.tool_args
|
||||
tool_id = tool_choice.id
|
||||
tool_runner = ToolRunner(
|
||||
tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
|
||||
)
|
||||
tool_kickoff = tool_runner.kickoff()
|
||||
|
||||
try:
|
||||
tool_responses = []
|
||||
for response in tool_runner.tool_responses():
|
||||
tool_responses.append(response)
|
||||
|
||||
tool_final_result = tool_runner.tool_final_result()
|
||||
except Exception as e:
|
||||
raise ToolCallException(
|
||||
f"Error during tool call for {tool.display_name}: {e}"
|
||||
) from e
|
||||
|
||||
tool_call = ToolCall(name=tool.name, args=tool_args, id=tool_id)
|
||||
tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=AIMessageChunk(content="", tool_calls=[tool_call]),
|
||||
tool_call_result=build_tool_message(
|
||||
tool_call, tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
tool_call_output = ToolCallOutput(
|
||||
tool_call_summary=tool_call_summary,
|
||||
tool_call_kickoff=tool_kickoff,
|
||||
tool_call_responses=tool_responses,
|
||||
tool_call_final_result=tool_final_result,
|
||||
)
|
||||
return ToolCallUpdate(tool_call_output=tool_call_output)
|
||||
@@ -1,354 +0,0 @@
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import ToolCall
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.process_llm_stream import process_llm_stream
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoice
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
|
||||
from onyx.chat.tool_handling.tool_response_handler import (
|
||||
get_tool_call_for_non_tool_calling_llm_impl,
|
||||
)
|
||||
from onyx.configs.chat_configs import USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
|
||||
from onyx.context.search.preprocessing.preprocessing import query_analysis
|
||||
from onyx.context.search.retrieval.search_runner import get_query_embedding
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
|
||||
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
|
||||
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
|
||||
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
|
||||
from onyx.tools.models import QueryExpansions
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import TimeoutThread
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
|
||||
# TODO: Add trimming logic
|
||||
history_segments = []
|
||||
for msg in prompt_builder.message_history:
|
||||
if isinstance(msg, HumanMessage):
|
||||
role = "User"
|
||||
elif isinstance(msg, AIMessage):
|
||||
role = "Assistant"
|
||||
else:
|
||||
continue
|
||||
history_segments.append(f"{role}:\n {msg.content}\n\n")
|
||||
return "\n".join(history_segments)
|
||||
|
||||
|
||||
def _expand_query(
|
||||
query: str,
|
||||
expansion_type: QueryExpansionType,
|
||||
prompt_builder: AnswerPromptBuilder,
|
||||
) -> str:
|
||||
|
||||
history_str = _create_history_str(prompt_builder)
|
||||
|
||||
if history_str:
|
||||
if expansion_type == QueryExpansionType.KEYWORD:
|
||||
base_prompt = QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
|
||||
else:
|
||||
base_prompt = QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
|
||||
expansion_prompt = base_prompt.format(question=query, history=history_str)
|
||||
else:
|
||||
if expansion_type == QueryExpansionType.KEYWORD:
|
||||
base_prompt = QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
|
||||
else:
|
||||
base_prompt = QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
|
||||
expansion_prompt = base_prompt.format(question=query)
|
||||
|
||||
msg = HumanMessage(content=expansion_prompt)
|
||||
primary_llm, _ = get_default_llms()
|
||||
response = primary_llm.invoke([msg])
|
||||
rephrased_query: str = cast(str, response.content)
|
||||
|
||||
return rephrased_query
|
||||
|
||||
|
||||
def _expand_query_non_tool_calling_llm(
|
||||
expanded_keyword_thread: TimeoutThread[str],
|
||||
expanded_semantic_thread: TimeoutThread[str],
|
||||
) -> QueryExpansions | None:
|
||||
keyword_expansion: str | None = wait_on_background(expanded_keyword_thread)
|
||||
semantic_expansion: str | None = wait_on_background(expanded_semantic_thread)
|
||||
|
||||
if keyword_expansion is None or semantic_expansion is None:
|
||||
return None
|
||||
|
||||
return QueryExpansions(
|
||||
keywords_expansions=[keyword_expansion],
|
||||
semantic_expansions=[semantic_expansion],
|
||||
)
|
||||
|
||||
|
||||
# TODO: break this out into an implementation function
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
@log_function_time(print_only=True)
def choose_tool(
    state: ToolChoiceState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> ToolChoiceUpdate:
    """
    This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
    the node MAY emit an answer, depending on whether state["should_stream_answer"] is set.

    Three exit paths are visible below:
    1. Forced tool call / non-tool-calling LLM with a chosen tool: return a
       ToolChoiceUpdate immediately without streaming the LLM.
    2. Skipping answer generation with no forced tool: return tool_choice=None.
    3. Otherwise stream the LLM; return either tool_choice=None (no tool calls
       emitted) or the first tool call whose name matches a known tool.
    """
    should_stream_answer = state.should_stream_answer

    agent_config = cast(GraphConfig, config["metadata"]["config"])

    force_use_tool = agent_config.tooling.force_use_tool

    # Background precomputation handles (started only for the basic-search path
    # below); each is awaited later only if the search tool ends up selected.
    embedding_thread: TimeoutThread[Embedding] | None = None
    keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
    expanded_keyword_thread: TimeoutThread[str] | None = None
    expanded_semantic_thread: TimeoutThread[str] | None = None
    # If we have override_kwargs, add them to the tool_args
    override_kwargs: SearchToolOverrideKwargs = (
        force_use_tool.override_kwargs or SearchToolOverrideKwargs()
    )
    override_kwargs.original_query = agent_config.inputs.prompt_builder.raw_user_query

    using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
    # Prefer the snapshot captured in state; fall back to the config's builder.
    prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder

    llm = agent_config.tooling.primary_llm
    skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation

    # Kick off search precomputation only when basic (non-agentic) search is
    # possible: a search tool exists and it is either unforced or the forced tool.
    if (
        not agent_config.behavior.use_agentic_search
        and agent_config.tooling.search_tool is not None
        and (
            not force_use_tool.force_use or force_use_tool.tool_name == SearchTool._NAME
        )
    ):
        # Run in a background thread to avoid blocking the main thread
        embedding_thread = run_in_background(
            get_query_embedding,
            agent_config.inputs.prompt_builder.raw_user_query,
            agent_config.persistence.db_session,
        )
        keyword_thread = run_in_background(
            query_analysis,
            agent_config.inputs.prompt_builder.raw_user_query,
        )

        if USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH:
            # Both expansion flavors run concurrently; each makes its own LLM call.
            expanded_keyword_thread = run_in_background(
                _expand_query,
                agent_config.inputs.prompt_builder.raw_user_query,
                QueryExpansionType.KEYWORD,
                prompt_builder,
            )
            expanded_semantic_thread = run_in_background(
                _expand_query,
                agent_config.inputs.prompt_builder.raw_user_query,
                QueryExpansionType.SEMANTIC,
                prompt_builder,
            )

    structured_response_format = agent_config.inputs.structured_response_format
    # Restrict to the tools this node was told are available via state.tools.
    tools = [
        tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
    ]

    tool, tool_args = None, None
    if force_use_tool.force_use and force_use_tool.args is not None:
        tool_name, tool_args = (
            force_use_tool.tool_name,
            force_use_tool.args,
        )
        tool = get_tool_by_name(tools, tool_name)

    # special pre-logic for non-tool calling LLM case
    elif not using_tool_calling_llm and tools:
        chosen_tool_and_args = get_tool_call_for_non_tool_calling_llm_impl(
            force_use_tool=force_use_tool,
            tools=tools,
            prompt_builder=prompt_builder,
            llm=llm,
        )
        if chosen_tool_and_args:
            tool, tool_args = chosen_tool_and_args

    # If we have a tool and tool args, we are ready to request a tool call.
    # This only happens if the tool call was forced or we are using a non-tool calling LLM.
    if tool and tool_args:
        if embedding_thread and tool.name == SearchTool._NAME:
            # Wait for the embedding thread to finish
            embedding = wait_on_background(embedding_thread)
            override_kwargs.precomputed_query_embedding = embedding
        if keyword_thread and tool.name == SearchTool._NAME:
            is_keyword, keywords = wait_on_background(keyword_thread)
            override_kwargs.precomputed_is_keyword = is_keyword
            override_kwargs.precomputed_keywords = keywords
        # dual keyword expansion needs to be added here for non-tool calling LLM case
        if (
            USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
            and expanded_keyword_thread
            and expanded_semantic_thread
            and tool.name == SearchTool._NAME
        ):
            override_kwargs.expanded_queries = _expand_query_non_tool_calling_llm(
                expanded_keyword_thread=expanded_keyword_thread,
                expanded_semantic_thread=expanded_semantic_thread,
            )
        # Sanity check: if expansions were produced, both halves must be present.
        if (
            USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
            and tool.name == SearchTool._NAME
            and override_kwargs.expanded_queries
        ):
            if (
                override_kwargs.expanded_queries.keywords_expansions is None
                or override_kwargs.expanded_queries.semantic_expansions is None
            ):
                raise ValueError("No expanded keyword or semantic threads found.")

        # Exit path 1: forced / non-tool-calling selection, no LLM streaming needed.
        return ToolChoiceUpdate(
            tool_choice=ToolChoice(
                tool=tool,
                tool_args=tool_args,
                id=str(uuid4()),
                search_tool_override_kwargs=override_kwargs,
            ),
        )

    # if we're skipping gen ai answer generation, we should only
    # continue if we're forcing a tool call (which will be emitted by
    # the tool calling llm in the stream() below)
    if skip_gen_ai_answer_generation and not force_use_tool.force_use:
        # Exit path 2: nothing to do — no tool and no answer wanted.
        return ToolChoiceUpdate(
            tool_choice=None,
        )

    # AnswerPromptBuilder builds lazily; other builders expose a prebuilt prompt.
    built_prompt = (
        prompt_builder.build()
        if isinstance(prompt_builder, AnswerPromptBuilder)
        else prompt_builder.built_prompt
    )
    # At this point, we are either using a tool calling LLM or we are skipping the tool call.
    # DEBUG: good breakpoint
    stream = llm.stream(
        # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
        # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
        prompt=built_prompt,
        tools=(
            [tool.tool_definition() for tool in tools] or None
            if using_tool_calling_llm
            else None
        ),
        # "required" forces the LLM to pick some tool when a tool call is forced.
        tool_choice=(
            "required"
            if tools and force_use_tool.force_use and using_tool_calling_llm
            else None
        ),
        structured_response_format=structured_response_format,
    )

    # Consume the stream (optionally relaying answer tokens to the writer) and
    # collect the accumulated AI message chunk that may carry tool calls.
    tool_message = process_llm_stream(
        stream,
        should_stream_answer
        and not agent_config.behavior.skip_gen_ai_answer_generation,
        writer,
        ind=0,
    ).ai_message_chunk

    if tool_message is None:
        raise ValueError("No tool message emitted by LLM")

    # If no tool calls are emitted by the LLM, we should not choose a tool
    if len(tool_message.tool_calls) == 0:
        logger.debug("No tool calls emitted by LLM")
        return ToolChoiceUpdate(
            tool_choice=None,
        )

    # TODO: here we could handle parallel tool calls. Right now
    # we just pick the first one that matches.
    selected_tool: Tool | None = None
    selected_tool_call_request: ToolCall | None = None
    for tool_call_request in tool_message.tool_calls:
        known_tools_by_name = [
            tool for tool in tools if tool.name == tool_call_request["name"]
        ]

        if known_tools_by_name:
            selected_tool = known_tools_by_name[0]
            selected_tool_call_request = tool_call_request
            break

        # Only reached for requests whose name matched no known tool.
        logger.error(
            "Tool call requested with unknown name field. \n"
            f"tools: {tools}"
            f"tool_call_request: {tool_call_request}"
        )

    if not selected_tool or not selected_tool_call_request:
        raise ValueError(
            f"Tool call attempted with tool {selected_tool}, request {selected_tool_call_request}"
        )

    logger.debug(f"Selected tool: {selected_tool.name}")
    logger.debug(f"Selected tool call request: {selected_tool_call_request}")

    # Harvest the background precomputation for the search tool, mirroring the
    # forced/non-tool-calling branch above.
    if embedding_thread and selected_tool.name == SearchTool._NAME:
        # Wait for the embedding thread to finish
        embedding = wait_on_background(embedding_thread)
        override_kwargs.precomputed_query_embedding = embedding
    if keyword_thread and selected_tool.name == SearchTool._NAME:
        is_keyword, keywords = wait_on_background(keyword_thread)
        override_kwargs.precomputed_is_keyword = is_keyword
        override_kwargs.precomputed_keywords = keywords

    # NOTE(review): unlike the earlier branch, this condition does not gate on
    # USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH — the threads only exist when
    # that flag is on, so behavior matches; confirm intentional asymmetry.
    if (
        selected_tool.name == SearchTool._NAME
        and expanded_keyword_thread
        and expanded_semantic_thread
    ):

        override_kwargs.expanded_queries = _expand_query_non_tool_calling_llm(
            expanded_keyword_thread=expanded_keyword_thread,
            expanded_semantic_thread=expanded_semantic_thread,
        )
    if (
        USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
        and selected_tool.name == SearchTool._NAME
        and override_kwargs.expanded_queries
    ):
        # TODO: this is a hack to handle the case where the expanded queries are not found.
        # We should refactor this to be more robust.
        if (
            override_kwargs.expanded_queries.keywords_expansions is None
            or override_kwargs.expanded_queries.semantic_expansions is None
        ):
            raise ValueError("No expanded keyword or semantic threads found.")

    # Exit path 3: tool call chosen by the tool-calling LLM.
    return ToolChoiceUpdate(
        tool_choice=ToolChoice(
            tool=selected_tool,
            tool_args=selected_tool_call_request["args"],
            id=selected_tool_call_request["id"],
            search_tool_override_kwargs=override_kwargs,
        ),
    )
|
||||
@@ -1,17 +0,0 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
|
||||
|
||||
def prepare_tool_input(state: Any, config: RunnableConfig) -> ToolChoiceInput:
    """Build the initial ToolChoiceInput for the tool-choice node from the graph config."""
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    available_tools = graph_config.tooling.tools or []
    return ToolChoiceInput(
        # NOTE: this node is used at the top level of the agent, so we always stream
        should_stream_answer=True,
        # None means the default prompt builder is used downstream
        prompt_snapshot=None,
        tools=[tool.name for tool in available_tools],
    )
|
||||
@@ -5,11 +5,10 @@ from typing import Literal
|
||||
from typing import Type
|
||||
from typing import TypeVar
|
||||
|
||||
from braintrust import traced
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import StreamWriter
|
||||
from litellm import get_supported_openai_params
|
||||
from litellm import supports_response_schema
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
@@ -29,6 +28,7 @@ SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||
JSON_PATTERN = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)
|
||||
|
||||
|
||||
@traced(name="stream llm", type="llm")
|
||||
def stream_llm_answer(
|
||||
llm: LLM,
|
||||
prompt: LanguageModelInput,
|
||||
@@ -147,6 +147,7 @@ def invoke_llm_json(
|
||||
Invoke an LLM, forcing it to respond in a specified JSON format if possible,
|
||||
and return an object of that schema.
|
||||
"""
|
||||
from litellm.utils import get_supported_openai_params, supports_response_schema
|
||||
|
||||
# check if the model supports response_format: json_schema
|
||||
supports_json = "response_format" in (
|
||||
|
||||
@@ -29,7 +29,7 @@ from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
|
||||
from onyx.configs.constants import ONYX_SLACK_URL
|
||||
from onyx.configs.constants import ONYX_DISCORD_URL
|
||||
from onyx.db.models import User
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -145,7 +145,7 @@ HTML_EMAIL_TEMPLATE = """\
|
||||
<tr>
|
||||
<td class="footer">
|
||||
© {year} {application_name}. All rights reserved.
|
||||
{slack_fragment}
|
||||
{community_link_fragment}
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
@@ -161,9 +161,9 @@ def build_html_email(
|
||||
cta_text: str | None = None,
|
||||
cta_link: str | None = None,
|
||||
) -> str:
|
||||
slack_fragment = ""
|
||||
community_link_fragment = ""
|
||||
if application_name == ONYX_DEFAULT_APPLICATION_NAME:
|
||||
slack_fragment = f'<br>Have questions? Join our Slack community <a href="{ONYX_SLACK_URL}">here</a>.'
|
||||
community_link_fragment = f'<br>Have questions? Join our Discord community <a href="{ONYX_DISCORD_URL}">here</a>.'
|
||||
|
||||
if cta_text and cta_link:
|
||||
cta_block = f'<a class="cta-button" href="{cta_link}">{cta_text}</a>'
|
||||
@@ -175,7 +175,7 @@ def build_html_email(
|
||||
heading=heading,
|
||||
message=message,
|
||||
cta_block=cta_block,
|
||||
slack_fragment=slack_fragment,
|
||||
community_link_fragment=community_link_fragment,
|
||||
year=datetime.now().year,
|
||||
)
|
||||
|
||||
|
||||
@@ -19,7 +19,9 @@ from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -30,7 +32,7 @@ PRUNING_CHECKPOINTED_BATCH_SIZE = 32
|
||||
|
||||
|
||||
def document_batch_to_ids(
|
||||
doc_batch: Iterator[list[Document]],
|
||||
doc_batch: Iterator[list[Document]] | Iterator[list[SlimDocument]],
|
||||
) -> Generator[set[str], None, None]:
|
||||
for doc_list in doc_batch:
|
||||
yield {doc.id for doc in doc_list}
|
||||
@@ -41,20 +43,24 @@ def extract_ids_from_runnable_connector(
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
If the SlimConnector hasnt been implemented for the given connector, just pull
|
||||
If the given connector is neither a SlimConnector nor a SlimConnectorWithPermSync, just pull
|
||||
all docs using the load_from_state and grab out the IDs.
|
||||
|
||||
Optionally, a callback can be passed to handle the length of each document batch.
|
||||
"""
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
|
||||
if isinstance(runnable_connector, SlimConnector):
|
||||
for metadata_batch in runnable_connector.retrieve_all_slim_documents():
|
||||
all_connector_doc_ids.update({doc.id for doc in metadata_batch})
|
||||
|
||||
doc_batch_id_generator = None
|
||||
|
||||
if isinstance(runnable_connector, LoadConnector):
|
||||
if isinstance(runnable_connector, SlimConnector):
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.retrieve_all_slim_docs()
|
||||
)
|
||||
elif isinstance(runnable_connector, SlimConnectorWithPermSync):
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.retrieve_all_slim_docs_perm_sync()
|
||||
)
|
||||
# If the connector isn't slim, fall back to running it normally to get ids
|
||||
elif isinstance(runnable_connector, LoadConnector):
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.load_from_state()
|
||||
)
|
||||
@@ -78,13 +84,14 @@ def extract_ids_from_runnable_connector(
|
||||
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
|
||||
|
||||
# this function is called per batch for rate limiting
|
||||
def doc_batch_processing_func(doc_batch_ids: set[str]) -> set[str]:
|
||||
return doc_batch_ids
|
||||
|
||||
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
|
||||
doc_batch_processing_func = rate_limit_builder(
|
||||
doc_batch_processing_func = (
|
||||
rate_limit_builder(
|
||||
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
|
||||
)(lambda x: x)
|
||||
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
for doc_batch_ids in doc_batch_id_generator:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
|
||||
@@ -41,7 +41,7 @@ beat_task_templates: list[dict] = [
|
||||
"task": OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
|
||||
"schedule": timedelta(minutes=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
},
|
||||
@@ -85,9 +85,9 @@ beat_task_templates: list[dict] = [
|
||||
{
|
||||
"name": "check-for-index-attempt-cleanup",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEX_ATTEMPT_CLEANUP,
|
||||
"schedule": timedelta(hours=1),
|
||||
"schedule": timedelta(minutes=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -89,6 +89,7 @@ from onyx.indexing.adapters.document_indexing_adapter import (
|
||||
DocumentIndexingBatchAdapter,
|
||||
)
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
@@ -1270,8 +1271,6 @@ def _docprocessing_task(
|
||||
tenant_id: str,
|
||||
batch_num: int,
|
||||
) -> None:
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
|
||||
start_time = time.monotonic()
|
||||
|
||||
if tenant_id:
|
||||
|
||||
@@ -579,6 +579,16 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
|
||||
search_doc_map[doc_id] = []
|
||||
search_doc_map[doc_id].append(sd)
|
||||
|
||||
task_logger.debug(
|
||||
f"Built search doc map with {len(search_doc_map)} entries"
|
||||
)
|
||||
ids_preview = list(search_doc_map.keys())[:5]
|
||||
task_logger.debug(
|
||||
f"First few search_doc_map ids: {ids_preview if ids_preview else 'No ids found'}"
|
||||
)
|
||||
task_logger.debug(
|
||||
f"search_doc_map total items: {sum(len(docs) for docs in search_doc_map.values())}"
|
||||
)
|
||||
# Process each UserFile and update matching SearchDocs
|
||||
updated_count = 0
|
||||
for uf in user_files:
|
||||
@@ -586,9 +596,18 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
|
||||
if doc_id.startswith("USER_FILE_CONNECTOR__"):
|
||||
doc_id = "FILE_CONNECTOR__" + doc_id[len("USER_FILE_CONNECTOR__") :]
|
||||
|
||||
task_logger.debug(f"Processing user file {uf.id} with doc_id {doc_id}")
|
||||
task_logger.debug(
|
||||
f"doc_id in search_doc_map: {doc_id in search_doc_map}"
|
||||
)
|
||||
|
||||
if doc_id in search_doc_map:
|
||||
search_docs = search_doc_map[doc_id]
|
||||
task_logger.debug(
|
||||
f"Found {len(search_docs)} search docs to update for user file {uf.id}"
|
||||
)
|
||||
# Update the SearchDoc to use the UserFile's UUID
|
||||
for search_doc in search_doc_map[doc_id]:
|
||||
for search_doc in search_docs:
|
||||
search_doc.document_id = str(uf.id)
|
||||
db_session.add(search_doc)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.constants import NUM_DAYS_TO_KEEP_INDEX_ATTEMPTS
|
||||
from onyx.db.engine.time_utils import get_db_current_time
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
|
||||
|
||||
def get_old_index_attempts(
|
||||
@@ -21,6 +22,10 @@ def get_old_index_attempts(
|
||||
|
||||
def cleanup_index_attempts(db_session: Session, index_attempt_ids: list[int]) -> None:
|
||||
"""Clean up multiple index attempts"""
|
||||
db_session.query(IndexAttemptError).filter(
|
||||
IndexAttemptError.index_attempt_id.in_(index_attempt_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
db_session.query(IndexAttempt).filter(
|
||||
IndexAttempt.id.in_(index_attempt_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
@@ -101,7 +102,6 @@ def _get_connector_runner(
|
||||
are the complete list of existing documents of the connector. If the task
|
||||
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
|
||||
"""
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
|
||||
task = attempt.connector_credential_pair.connector.input_type
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
|
||||
from onyx.chat.answer import Answer
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.chat_utils import create_temporary_persona
|
||||
@@ -113,6 +112,10 @@ logger = setup_logger()
|
||||
ERROR_TYPE_CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class ToolCallException(Exception):
|
||||
"""Exception raised for errors during tool calls."""
|
||||
|
||||
|
||||
class PartialResponse(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
|
||||
@@ -24,8 +24,6 @@ APP_PORT = 8080
|
||||
# prefix from requests directed towards the API server. In these cases, set this to `/api`
|
||||
APP_API_PREFIX = os.environ.get("API_PREFIX", "")
|
||||
|
||||
SKIP_WARM_UP = os.environ.get("SKIP_WARM_UP", "").lower() == "true"
|
||||
|
||||
#####
|
||||
# User Facing Features Configs
|
||||
#####
|
||||
@@ -756,7 +754,7 @@ MAX_FEDERATED_CHUNKS = int(
|
||||
# NOTE: this should only be enabled if you have purchased an enterprise license.
|
||||
# if you're interested in an enterprise license, please reach out to us at
|
||||
# founders@onyx.app OR message Chris Weaver or Yuhong Sun in the Onyx
|
||||
# Slack community (https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ)
|
||||
# Discord community https://discord.gg/4NA5SbzrWb
|
||||
ENTERPRISE_EDITION_ENABLED = (
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
|
||||
)
|
||||
|
||||
@@ -90,6 +90,7 @@ HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
|
||||
|
||||
# Internet Search
|
||||
EXA_API_KEY = os.environ.get("EXA_API_KEY") or None
|
||||
SERPER_API_KEY = os.environ.get("SERPER_API_KEY") or None
|
||||
|
||||
NUM_INTERNET_SEARCH_RESULTS = int(os.environ.get("NUM_INTERNET_SEARCH_RESULTS") or 10)
|
||||
NUM_INTERNET_SEARCH_CHUNKS = int(os.environ.get("NUM_INTERNET_SEARCH_CHUNKS") or 50)
|
||||
|
||||
@@ -6,7 +6,7 @@ from enum import Enum
|
||||
|
||||
|
||||
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
|
||||
ONYX_SLACK_URL = "https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA"
|
||||
ONYX_DISCORD_URL = "https://discord.gg/4NA5SbzrWb"
|
||||
SLACK_USER_TOKEN_PREFIX = "xoxp-"
|
||||
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
|
||||
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
|
||||
|
||||
@@ -41,7 +41,7 @@ All new connectors should have tests added to the `backend/tests/daily/connector
|
||||
|
||||
#### Implementing the new Connector
|
||||
|
||||
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
|
||||
The connector must subclass one or more of LoadConnector, PollConnector, CheckpointedConnector, or CheckpointedConnectorWithPermSync
|
||||
|
||||
The `__init__` should take arguments for configuring what documents the connector will and where it finds those
|
||||
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of
|
||||
|
||||
@@ -25,7 +25,7 @@ from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
@@ -56,7 +56,7 @@ class BitbucketConnectorCheckpoint(ConnectorCheckpoint):
|
||||
|
||||
class BitbucketConnector(
|
||||
CheckpointedConnector[BitbucketConnectorCheckpoint],
|
||||
SlimConnector,
|
||||
SlimConnectorWithPermSync,
|
||||
):
|
||||
"""Connector for indexing Bitbucket Cloud pull requests.
|
||||
|
||||
@@ -266,7 +266,7 @@ class BitbucketConnector(
|
||||
"""Validate and deserialize a checkpoint instance from JSON."""
|
||||
return BitbucketConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -5,6 +5,7 @@ from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from atlassian.errors import ApiError # type: ignore
|
||||
from requests.exceptions import HTTPError
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -41,6 +42,7 @@ from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
@@ -91,6 +93,7 @@ class ConfluenceCheckpoint(ConnectorCheckpoint):
|
||||
class ConfluenceConnector(
|
||||
CheckpointedConnector[ConfluenceCheckpoint],
|
||||
SlimConnector,
|
||||
SlimConnectorWithPermSync,
|
||||
CredentialsConnector,
|
||||
):
|
||||
def __init__(
|
||||
@@ -108,6 +111,7 @@ class ConfluenceConnector(
|
||||
# pages.
|
||||
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
|
||||
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
|
||||
scoped_token: bool = False,
|
||||
) -> None:
|
||||
self.wiki_base = wiki_base
|
||||
self.is_cloud = is_cloud
|
||||
@@ -118,6 +122,7 @@ class ConfluenceConnector(
|
||||
self.batch_size = batch_size
|
||||
self.labels_to_skip = labels_to_skip
|
||||
self.timezone_offset = timezone_offset
|
||||
self.scoped_token = scoped_token
|
||||
self._confluence_client: OnyxConfluence | None = None
|
||||
self._low_timeout_confluence_client: OnyxConfluence | None = None
|
||||
self._fetched_titles: set[str] = set()
|
||||
@@ -195,6 +200,7 @@ class ConfluenceConnector(
|
||||
is_cloud=self.is_cloud,
|
||||
url=self.wiki_base,
|
||||
credentials_provider=credentials_provider,
|
||||
scoped_token=self.scoped_token,
|
||||
)
|
||||
confluence_client._probe_connection(**self.probe_kwargs)
|
||||
confluence_client._initialize_connection(**self.final_kwargs)
|
||||
@@ -207,6 +213,7 @@ class ConfluenceConnector(
|
||||
url=self.wiki_base,
|
||||
credentials_provider=credentials_provider,
|
||||
timeout=3,
|
||||
scoped_token=self.scoped_token,
|
||||
)
|
||||
low_timeout_confluence_client._probe_connection(**self.probe_kwargs)
|
||||
low_timeout_confluence_client._initialize_connection(**self.final_kwargs)
|
||||
@@ -558,7 +565,21 @@ class ConfluenceConnector(
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> ConfluenceCheckpoint:
|
||||
return ConfluenceCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
@override
|
||||
def retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
return self._retrieve_all_slim_docs(
|
||||
start=start,
|
||||
end=end,
|
||||
callback=callback,
|
||||
include_permissions=False,
|
||||
)
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
@@ -568,12 +589,28 @@ class ConfluenceConnector(
|
||||
Return 'slim' docs (IDs + minimal permission data).
|
||||
Does not fetch actual text. Used primarily for incremental permission sync.
|
||||
"""
|
||||
return self._retrieve_all_slim_docs(
|
||||
start=start,
|
||||
end=end,
|
||||
callback=callback,
|
||||
include_permissions=True,
|
||||
)
|
||||
|
||||
def _retrieve_all_slim_docs(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
include_permissions: bool = True,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
doc_metadata_list: list[SlimDocument] = []
|
||||
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
|
||||
|
||||
space_level_access_info = get_all_space_permissions(
|
||||
self.confluence_client, self.is_cloud
|
||||
)
|
||||
space_level_access_info: dict[str, ExternalAccess] = {}
|
||||
if include_permissions:
|
||||
space_level_access_info = get_all_space_permissions(
|
||||
self.confluence_client, self.is_cloud
|
||||
)
|
||||
|
||||
def get_external_access(
|
||||
doc_id: str, restrictions: dict[str, Any], ancestors: list[dict[str, Any]]
|
||||
@@ -600,8 +637,10 @@ class ConfluenceConnector(
|
||||
doc_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=page_id,
|
||||
external_access=get_external_access(
|
||||
page_id, page_restrictions, page_ancestors
|
||||
external_access=(
|
||||
get_external_access(page_id, page_restrictions, page_ancestors)
|
||||
if include_permissions
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -636,8 +675,12 @@ class ConfluenceConnector(
|
||||
doc_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=attachment_id,
|
||||
external_access=get_external_access(
|
||||
attachment_id, attachment_restrictions, []
|
||||
external_access=(
|
||||
get_external_access(
|
||||
attachment_id, attachment_restrictions, []
|
||||
)
|
||||
if include_permissions
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -648,10 +691,10 @@ class ConfluenceConnector(
|
||||
|
||||
if callback and callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"retrieve_all_slim_documents: Stop signal detected"
|
||||
"retrieve_all_slim_docs_perm_sync: Stop signal detected"
|
||||
)
|
||||
if callback:
|
||||
callback.progress("retrieve_all_slim_documents", 1)
|
||||
callback.progress("retrieve_all_slim_docs_perm_sync", 1)
|
||||
|
||||
yield doc_metadata_list
|
||||
|
||||
@@ -676,6 +719,14 @@ class ConfluenceConnector(
|
||||
f"Unexpected error while validating Confluence settings: {e}"
|
||||
)
|
||||
|
||||
if self.space:
|
||||
try:
|
||||
self.low_timeout_confluence_client.get_space(self.space)
|
||||
except ApiError as e:
|
||||
raise ConnectorValidationError(
|
||||
"Invalid Confluence space key provided"
|
||||
) from e
|
||||
|
||||
if not spaces or not spaces.get("results"):
|
||||
raise ConnectorValidationError(
|
||||
"No Confluence spaces found. Either your credentials lack permissions, or "
|
||||
@@ -724,7 +775,7 @@ if __name__ == "__main__":
|
||||
end = datetime.now().timestamp()
|
||||
|
||||
# Fetch all `SlimDocuments`.
|
||||
for slim_doc in confluence_connector.retrieve_all_slim_documents():
|
||||
for slim_doc in confluence_connector.retrieve_all_slim_docs_perm_sync():
|
||||
print(slim_doc)
|
||||
|
||||
# Fetch all `Documents`.
|
||||
|
||||
@@ -41,6 +41,7 @@ from onyx.connectors.confluence.utils import _handle_http_error
|
||||
from onyx.connectors.confluence.utils import confluence_refresh_tokens
|
||||
from onyx.connectors.confluence.utils import get_start_param_from_url
|
||||
from onyx.connectors.confluence.utils import update_param_in_path
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import scoped_url
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -87,16 +88,20 @@ class OnyxConfluence:
|
||||
url: str,
|
||||
credentials_provider: CredentialsProviderInterface,
|
||||
timeout: int | None = None,
|
||||
scoped_token: bool = False,
|
||||
# should generally not be passed in, but making it overridable for
|
||||
# easier testing
|
||||
confluence_user_profiles_override: list[dict[str, str]] | None = (
|
||||
CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE
|
||||
),
|
||||
) -> None:
|
||||
self.base_url = url #'/'.join(url.rstrip("/").split("/")[:-1])
|
||||
url = scoped_url(url, "confluence") if scoped_token else url
|
||||
|
||||
self._is_cloud = is_cloud
|
||||
self._url = url.rstrip("/")
|
||||
self._credentials_provider = credentials_provider
|
||||
|
||||
self.scoped_token = scoped_token
|
||||
self.redis_client: Redis | None = None
|
||||
self.static_credentials: dict[str, Any] | None = None
|
||||
if self._credentials_provider.is_dynamic():
|
||||
@@ -218,6 +223,34 @@ class OnyxConfluence:
|
||||
|
||||
with self._credentials_provider:
|
||||
credentials, _ = self._renew_credentials()
|
||||
if self.scoped_token:
|
||||
# v2 endpoint doesn't always work with scoped tokens, use v1
|
||||
token = credentials["confluence_access_token"]
|
||||
probe_url = f"{self.base_url}/rest/api/space?limit=1"
|
||||
import requests
|
||||
|
||||
logger.info(f"First and Last 5 of token: {token[:5]}...{token[-5:]}")
|
||||
|
||||
try:
|
||||
r = requests.get(
|
||||
probe_url,
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=10,
|
||||
)
|
||||
r.raise_for_status()
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 403:
|
||||
logger.warning(
|
||||
"scoped token authenticated but not valid for probe endpoint (spaces)"
|
||||
)
|
||||
else:
|
||||
if "WWW-Authenticate" in e.response.headers:
|
||||
logger.warning(
|
||||
f"WWW-Authenticate: {e.response.headers['WWW-Authenticate']}"
|
||||
)
|
||||
logger.warning(f"Full error: {e.response.text}")
|
||||
raise e
|
||||
return
|
||||
|
||||
# probe connection with direct client, no retries
|
||||
if "confluence_refresh_token" in credentials:
|
||||
@@ -236,6 +269,7 @@ class OnyxConfluence:
|
||||
logger.info("Probing Confluence with Personal Access Token.")
|
||||
url = self._url
|
||||
if self._is_cloud:
|
||||
logger.info("running with cloud client")
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
url=url,
|
||||
username=credentials["confluence_username"],
|
||||
@@ -304,7 +338,9 @@ class OnyxConfluence:
|
||||
url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
|
||||
confluence = Confluence(url=url, oauth2=oauth2_dict, **kwargs)
|
||||
else:
|
||||
logger.info("Connecting to Confluence with Personal Access Token.")
|
||||
logger.info(
|
||||
f"Connecting to Confluence with Personal Access Token as user: {credentials['confluence_username']}"
|
||||
)
|
||||
if self._is_cloud:
|
||||
confluence = Confluence(
|
||||
url=self._url,
|
||||
|
||||
@@ -5,7 +5,10 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
from urllib.parse import urljoin
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from dateutil.parser import parse
|
||||
|
||||
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
|
||||
@@ -148,3 +151,17 @@ def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
|
||||
|
||||
def is_atlassian_date_error(e: Exception) -> bool:
|
||||
return "field 'updated' is invalid" in str(e)
|
||||
|
||||
|
||||
def get_cloudId(base_url: str) -> str:
|
||||
tenant_info_url = urljoin(base_url, "/_edge/tenant_info")
|
||||
response = requests.get(tenant_info_url, timeout=10)
|
||||
response.raise_for_status()
|
||||
return response.json()["cloudId"]
|
||||
|
||||
|
||||
def scoped_url(url: str, product: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
base_url = parsed.scheme + "://" + parsed.netloc
|
||||
cloud_id = get_cloudId(base_url)
|
||||
return f"https://api.atlassian.com/ex/{product}/{cloud_id}{parsed.path}"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import importlib
|
||||
from typing import Any
|
||||
from typing import Type
|
||||
|
||||
@@ -6,60 +7,16 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
from onyx.connectors.axero.connector import AxeroConnector
|
||||
from onyx.connectors.bitbucket.connector import BitbucketConnector
|
||||
from onyx.connectors.blob.connector import BlobStorageConnector
|
||||
from onyx.connectors.bookstack.connector import BookstackConnector
|
||||
from onyx.connectors.clickup.connector import ClickupConnector
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.discord.connector import DiscordConnector
|
||||
from onyx.connectors.discourse.connector import DiscourseConnector
|
||||
from onyx.connectors.document360.connector import Document360Connector
|
||||
from onyx.connectors.dropbox.connector import DropboxConnector
|
||||
from onyx.connectors.egnyte.connector import EgnyteConnector
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.fireflies.connector import FirefliesConnector
|
||||
from onyx.connectors.freshdesk.connector import FreshdeskConnector
|
||||
from onyx.connectors.gitbook.connector import GitbookConnector
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from onyx.connectors.gitlab.connector import GitlabConnector
|
||||
from onyx.connectors.gmail.connector import GmailConnector
|
||||
from onyx.connectors.gong.connector import GongConnector
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_site.connector import GoogleSitesConnector
|
||||
from onyx.connectors.guru.connector import GuruConnector
|
||||
from onyx.connectors.highspot.connector import HighspotConnector
|
||||
from onyx.connectors.hubspot.connector import HubSpotConnector
|
||||
from onyx.connectors.imap.connector import ImapConnector
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import CredentialsConnector
|
||||
from onyx.connectors.interfaces import EventConnector
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.jira.connector import JiraConnector
|
||||
from onyx.connectors.linear.connector import LinearConnector
|
||||
from onyx.connectors.loopio.connector import LoopioConnector
|
||||
from onyx.connectors.mediawiki.wiki import MediaWikiConnector
|
||||
from onyx.connectors.mock_connector.connector import MockConnector
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.connectors.notion.connector import NotionConnector
|
||||
from onyx.connectors.outline.connector import OutlineConnector
|
||||
from onyx.connectors.productboard.connector import ProductboardConnector
|
||||
from onyx.connectors.salesforce.connector import SalesforceConnector
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.connectors.slab.connector import SlabConnector
|
||||
from onyx.connectors.slack.connector import SlackConnector
|
||||
from onyx.connectors.teams.connector import TeamsConnector
|
||||
from onyx.connectors.web.connector import WebConnector
|
||||
from onyx.connectors.wikipedia.connector import WikipediaConnector
|
||||
from onyx.connectors.xenforo.connector import XenforoConnector
|
||||
from onyx.connectors.zendesk.connector import ZendeskConnector
|
||||
from onyx.connectors.zulip.connector import ZulipConnector
|
||||
from onyx.connectors.registry import CONNECTOR_CLASS_MAP
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.credentials import backend_update_credential_json
|
||||
from onyx.db.credentials import fetch_credential_by_id
|
||||
@@ -72,101 +29,75 @@ class ConnectorMissingException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# Cache for already imported connector classes
|
||||
_connector_cache: dict[DocumentSource, Type[BaseConnector]] = {}
|
||||
|
||||
|
||||
def _load_connector_class(source: DocumentSource) -> Type[BaseConnector]:
|
||||
"""Dynamically load and cache a connector class."""
|
||||
if source in _connector_cache:
|
||||
return _connector_cache[source]
|
||||
|
||||
if source not in CONNECTOR_CLASS_MAP:
|
||||
raise ConnectorMissingException(f"Connector not found for source={source}")
|
||||
|
||||
mapping = CONNECTOR_CLASS_MAP[source]
|
||||
|
||||
try:
|
||||
module = importlib.import_module(mapping.module_path)
|
||||
connector_class = getattr(module, mapping.class_name)
|
||||
_connector_cache[source] = connector_class
|
||||
return connector_class
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ConnectorMissingException(
|
||||
f"Failed to import {mapping.class_name} from {mapping.module_path}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def _validate_connector_supports_input_type(
|
||||
connector: Type[BaseConnector],
|
||||
input_type: InputType | None,
|
||||
source: DocumentSource,
|
||||
) -> None:
|
||||
"""Validate that a connector supports the requested input type."""
|
||||
if input_type is None:
|
||||
return
|
||||
|
||||
# Check each input type requirement separately for clarity
|
||||
load_state_unsupported = input_type == InputType.LOAD_STATE and not issubclass(
|
||||
connector, LoadConnector
|
||||
)
|
||||
|
||||
poll_unsupported = (
|
||||
input_type == InputType.POLL
|
||||
# Either poll or checkpoint works for this, in the future
|
||||
# all connectors should be checkpoint connectors
|
||||
and (
|
||||
not issubclass(connector, PollConnector)
|
||||
and not issubclass(connector, CheckpointedConnector)
|
||||
)
|
||||
)
|
||||
|
||||
event_unsupported = input_type == InputType.EVENT and not issubclass(
|
||||
connector, EventConnector
|
||||
)
|
||||
|
||||
if any([load_state_unsupported, poll_unsupported, event_unsupported]):
|
||||
raise ConnectorMissingException(
|
||||
f"Connector for source={source} does not accept input_type={input_type}"
|
||||
)
|
||||
|
||||
|
||||
def identify_connector_class(
|
||||
source: DocumentSource,
|
||||
input_type: InputType | None = None,
|
||||
) -> Type[BaseConnector]:
|
||||
connector_map = {
|
||||
DocumentSource.WEB: WebConnector,
|
||||
DocumentSource.FILE: LocalFileConnector,
|
||||
DocumentSource.SLACK: {
|
||||
InputType.POLL: SlackConnector,
|
||||
InputType.SLIM_RETRIEVAL: SlackConnector,
|
||||
},
|
||||
DocumentSource.GITHUB: GithubConnector,
|
||||
DocumentSource.GMAIL: GmailConnector,
|
||||
DocumentSource.GITLAB: GitlabConnector,
|
||||
DocumentSource.GITBOOK: GitbookConnector,
|
||||
DocumentSource.GOOGLE_DRIVE: GoogleDriveConnector,
|
||||
DocumentSource.BOOKSTACK: BookstackConnector,
|
||||
DocumentSource.OUTLINE: OutlineConnector,
|
||||
DocumentSource.CONFLUENCE: ConfluenceConnector,
|
||||
DocumentSource.JIRA: JiraConnector,
|
||||
DocumentSource.PRODUCTBOARD: ProductboardConnector,
|
||||
DocumentSource.SLAB: SlabConnector,
|
||||
DocumentSource.NOTION: NotionConnector,
|
||||
DocumentSource.ZULIP: ZulipConnector,
|
||||
DocumentSource.GURU: GuruConnector,
|
||||
DocumentSource.LINEAR: LinearConnector,
|
||||
DocumentSource.HUBSPOT: HubSpotConnector,
|
||||
DocumentSource.DOCUMENT360: Document360Connector,
|
||||
DocumentSource.GONG: GongConnector,
|
||||
DocumentSource.GOOGLE_SITES: GoogleSitesConnector,
|
||||
DocumentSource.ZENDESK: ZendeskConnector,
|
||||
DocumentSource.LOOPIO: LoopioConnector,
|
||||
DocumentSource.DROPBOX: DropboxConnector,
|
||||
DocumentSource.SHAREPOINT: SharepointConnector,
|
||||
DocumentSource.TEAMS: TeamsConnector,
|
||||
DocumentSource.SALESFORCE: SalesforceConnector,
|
||||
DocumentSource.DISCOURSE: DiscourseConnector,
|
||||
DocumentSource.AXERO: AxeroConnector,
|
||||
DocumentSource.CLICKUP: ClickupConnector,
|
||||
DocumentSource.MEDIAWIKI: MediaWikiConnector,
|
||||
DocumentSource.WIKIPEDIA: WikipediaConnector,
|
||||
DocumentSource.ASANA: AsanaConnector,
|
||||
DocumentSource.S3: BlobStorageConnector,
|
||||
DocumentSource.R2: BlobStorageConnector,
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.OCI_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.XENFORO: XenforoConnector,
|
||||
DocumentSource.DISCORD: DiscordConnector,
|
||||
DocumentSource.FRESHDESK: FreshdeskConnector,
|
||||
DocumentSource.FIREFLIES: FirefliesConnector,
|
||||
DocumentSource.EGNYTE: EgnyteConnector,
|
||||
DocumentSource.AIRTABLE: AirtableConnector,
|
||||
DocumentSource.HIGHSPOT: HighspotConnector,
|
||||
DocumentSource.IMAP: ImapConnector,
|
||||
DocumentSource.BITBUCKET: BitbucketConnector,
|
||||
# just for integration tests
|
||||
DocumentSource.MOCK_CONNECTOR: MockConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
# Load the connector class using lazy loading
|
||||
connector = _load_connector_class(source)
|
||||
|
||||
if isinstance(connector_by_source, dict):
|
||||
if input_type is None:
|
||||
# If not specified, default to most exhaustive update
|
||||
connector = connector_by_source.get(InputType.LOAD_STATE)
|
||||
else:
|
||||
connector = connector_by_source.get(input_type)
|
||||
else:
|
||||
connector = connector_by_source
|
||||
if connector is None:
|
||||
raise ConnectorMissingException(f"Connector not found for source={source}")
|
||||
# Validate connector supports the requested input_type
|
||||
_validate_connector_supports_input_type(connector, input_type, source)
|
||||
|
||||
if any(
|
||||
[
|
||||
(
|
||||
input_type == InputType.LOAD_STATE
|
||||
and not issubclass(connector, LoadConnector)
|
||||
),
|
||||
(
|
||||
input_type == InputType.POLL
|
||||
# either poll or checkpoint works for this, in the future
|
||||
# all connectors should be checkpoint connectors
|
||||
and (
|
||||
not issubclass(connector, PollConnector)
|
||||
and not issubclass(connector, CheckpointedConnector)
|
||||
)
|
||||
),
|
||||
(
|
||||
input_type == InputType.EVENT
|
||||
and not issubclass(connector, EventConnector)
|
||||
),
|
||||
]
|
||||
):
|
||||
raise ConnectorMissingException(
|
||||
f"Connector for source={source} does not accept input_type={input_type}"
|
||||
)
|
||||
return connector
|
||||
|
||||
|
||||
|
||||
@@ -219,12 +219,19 @@ def _get_batch_rate_limited(
|
||||
|
||||
|
||||
def _get_userinfo(user: NamedUser) -> dict[str, str]:
|
||||
def _safe_get(attr_name: str) -> str | None:
|
||||
try:
|
||||
return cast(str | None, getattr(user, attr_name))
|
||||
except GithubException:
|
||||
logger.debug(f"Error getting {attr_name} for user")
|
||||
return None
|
||||
|
||||
return {
|
||||
k: v
|
||||
for k, v in {
|
||||
"login": user.login,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
"login": _safe_get("login"),
|
||||
"name": _safe_get("name"),
|
||||
"email": _safe_get("email"),
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
@@ -232,7 +232,7 @@ def thread_to_document(
|
||||
)
|
||||
|
||||
|
||||
class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
|
||||
@@ -397,10 +397,10 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"retrieve_all_slim_documents: Stop signal detected"
|
||||
"retrieve_all_slim_docs_perm_sync: Stop signal detected"
|
||||
)
|
||||
|
||||
callback.progress("retrieve_all_slim_documents", 1)
|
||||
callback.progress("retrieve_all_slim_docs_perm_sync", 1)
|
||||
except HttpError as e:
|
||||
if _is_mail_service_disabled_error(e):
|
||||
logger.warning(
|
||||
@@ -431,7 +431,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -64,7 +64,7 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
@@ -153,7 +153,7 @@ class DriveIdStatus(Enum):
|
||||
|
||||
|
||||
class GoogleDriveConnector(
|
||||
SlimConnector, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
|
||||
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1296,7 +1296,7 @@ class GoogleDriveConnector(
|
||||
callback.progress("_extract_slim_docs_from_google_drive", 1)
|
||||
yield slim_batch
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -18,7 +18,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
@@ -38,7 +38,7 @@ class HighspotSpot(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class HighspotConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
"""
|
||||
Connector for loading data from Highspot.
|
||||
|
||||
@@ -362,7 +362,7 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
description = item_details.get("description", "")
|
||||
return title, description
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
import requests
|
||||
from hubspot import HubSpot # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.hubspot.rate_limit import HubSpotRateLimiter
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@@ -25,6 +29,10 @@ HUBSPOT_API_URL = "https://api.hubapi.com/integrations/v1/me"
|
||||
# Available HubSpot object types
|
||||
AVAILABLE_OBJECT_TYPES = {"tickets", "companies", "deals", "contacts"}
|
||||
|
||||
HUBSPOT_PAGE_SIZE = 100
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -38,6 +46,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
self.batch_size = batch_size
|
||||
self._access_token = access_token
|
||||
self._portal_id: str | None = None
|
||||
self._rate_limiter = HubSpotRateLimiter()
|
||||
|
||||
# Set object types to fetch, default to all available types
|
||||
if object_types is None:
|
||||
@@ -77,6 +86,37 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
"""Set the portal ID."""
|
||||
self._portal_id = value
|
||||
|
||||
def _call_hubspot(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
|
||||
return self._rate_limiter.call(func, *args, **kwargs)
|
||||
|
||||
def _paginated_results(
|
||||
self,
|
||||
fetch_page: Callable[..., Any],
|
||||
**kwargs: Any,
|
||||
) -> Generator[Any, None, None]:
|
||||
base_kwargs = dict(kwargs)
|
||||
base_kwargs.setdefault("limit", HUBSPOT_PAGE_SIZE)
|
||||
|
||||
after: str | None = None
|
||||
while True:
|
||||
page_kwargs = base_kwargs.copy()
|
||||
if after is not None:
|
||||
page_kwargs["after"] = after
|
||||
|
||||
page = self._call_hubspot(fetch_page, **page_kwargs)
|
||||
results = getattr(page, "results", [])
|
||||
for result in results:
|
||||
yield result
|
||||
|
||||
paging = getattr(page, "paging", None)
|
||||
next_page = getattr(paging, "next", None) if paging else None
|
||||
if next_page is None:
|
||||
break
|
||||
|
||||
after = getattr(next_page, "after", None)
|
||||
if after is None:
|
||||
break
|
||||
|
||||
def _clean_html_content(self, html_content: str) -> str:
|
||||
"""Clean HTML content and extract raw text"""
|
||||
if not html_content:
|
||||
@@ -150,78 +190,82 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get associated objects for a given object"""
|
||||
try:
|
||||
associations = api_client.crm.associations.v4.basic_api.get_page(
|
||||
associations_iter = self._paginated_results(
|
||||
api_client.crm.associations.v4.basic_api.get_page,
|
||||
object_type=from_object_type,
|
||||
object_id=object_id,
|
||||
to_object_type=to_object_type,
|
||||
)
|
||||
|
||||
associated_objects = []
|
||||
if associations.results:
|
||||
object_ids = [assoc.to_object_id for assoc in associations.results]
|
||||
object_ids = [assoc.to_object_id for assoc in associations_iter]
|
||||
|
||||
# Batch get the associated objects
|
||||
if to_object_type == "contacts":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = api_client.crm.contacts.basic_api.get_by_id(
|
||||
contact_id=obj_id,
|
||||
properties=[
|
||||
"firstname",
|
||||
"lastname",
|
||||
"email",
|
||||
"company",
|
||||
"jobtitle",
|
||||
],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch contact {obj_id}: {e}")
|
||||
associated_objects: list[dict[str, Any]] = []
|
||||
|
||||
elif to_object_type == "companies":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = api_client.crm.companies.basic_api.get_by_id(
|
||||
company_id=obj_id,
|
||||
properties=[
|
||||
"name",
|
||||
"domain",
|
||||
"industry",
|
||||
"city",
|
||||
"state",
|
||||
],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch company {obj_id}: {e}")
|
||||
if to_object_type == "contacts":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = self._call_hubspot(
|
||||
api_client.crm.contacts.basic_api.get_by_id,
|
||||
contact_id=obj_id,
|
||||
properties=[
|
||||
"firstname",
|
||||
"lastname",
|
||||
"email",
|
||||
"company",
|
||||
"jobtitle",
|
||||
],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch contact {obj_id}: {e}")
|
||||
|
||||
elif to_object_type == "deals":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = api_client.crm.deals.basic_api.get_by_id(
|
||||
deal_id=obj_id,
|
||||
properties=[
|
||||
"dealname",
|
||||
"amount",
|
||||
"dealstage",
|
||||
"closedate",
|
||||
"pipeline",
|
||||
],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch deal {obj_id}: {e}")
|
||||
elif to_object_type == "companies":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = self._call_hubspot(
|
||||
api_client.crm.companies.basic_api.get_by_id,
|
||||
company_id=obj_id,
|
||||
properties=[
|
||||
"name",
|
||||
"domain",
|
||||
"industry",
|
||||
"city",
|
||||
"state",
|
||||
],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch company {obj_id}: {e}")
|
||||
|
||||
elif to_object_type == "tickets":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = api_client.crm.tickets.basic_api.get_by_id(
|
||||
ticket_id=obj_id,
|
||||
properties=["subject", "content", "hs_ticket_priority"],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch ticket {obj_id}: {e}")
|
||||
elif to_object_type == "deals":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = self._call_hubspot(
|
||||
api_client.crm.deals.basic_api.get_by_id,
|
||||
deal_id=obj_id,
|
||||
properties=[
|
||||
"dealname",
|
||||
"amount",
|
||||
"dealstage",
|
||||
"closedate",
|
||||
"pipeline",
|
||||
],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch deal {obj_id}: {e}")
|
||||
|
||||
elif to_object_type == "tickets":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = self._call_hubspot(
|
||||
api_client.crm.tickets.basic_api.get_by_id,
|
||||
ticket_id=obj_id,
|
||||
properties=["subject", "content", "hs_ticket_priority"],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch ticket {obj_id}: {e}")
|
||||
|
||||
return associated_objects
|
||||
|
||||
@@ -239,33 +283,33 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get notes associated with a given object"""
|
||||
try:
|
||||
# Get associations to notes (engagement type)
|
||||
associations = api_client.crm.associations.v4.basic_api.get_page(
|
||||
associations_iter = self._paginated_results(
|
||||
api_client.crm.associations.v4.basic_api.get_page,
|
||||
object_type=object_type,
|
||||
object_id=object_id,
|
||||
to_object_type="notes",
|
||||
)
|
||||
|
||||
associated_notes = []
|
||||
if associations.results:
|
||||
note_ids = [assoc.to_object_id for assoc in associations.results]
|
||||
note_ids = [assoc.to_object_id for assoc in associations_iter]
|
||||
|
||||
# Batch get the associated notes
|
||||
for note_id in note_ids:
|
||||
try:
|
||||
# Notes are engagements in HubSpot, use the engagements API
|
||||
note = api_client.crm.objects.notes.basic_api.get_by_id(
|
||||
note_id=note_id,
|
||||
properties=[
|
||||
"hs_note_body",
|
||||
"hs_timestamp",
|
||||
"hs_created_by",
|
||||
"hubspot_owner_id",
|
||||
],
|
||||
)
|
||||
associated_notes.append(note.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch note {note_id}: {e}")
|
||||
associated_notes = []
|
||||
|
||||
for note_id in note_ids:
|
||||
try:
|
||||
# Notes are engagements in HubSpot, use the engagements API
|
||||
note = self._call_hubspot(
|
||||
api_client.crm.objects.notes.basic_api.get_by_id,
|
||||
note_id=note_id,
|
||||
properties=[
|
||||
"hs_note_body",
|
||||
"hs_timestamp",
|
||||
"hs_created_by",
|
||||
"hubspot_owner_id",
|
||||
],
|
||||
)
|
||||
associated_notes.append(note.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch note {note_id}: {e}")
|
||||
|
||||
return associated_notes
|
||||
|
||||
@@ -358,7 +402,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
api_client = HubSpot(access_token=self.access_token)
|
||||
all_tickets = api_client.crm.tickets.get_all(
|
||||
|
||||
tickets_iter = self._paginated_results(
|
||||
api_client.crm.tickets.basic_api.get_page,
|
||||
properties=[
|
||||
"subject",
|
||||
"content",
|
||||
@@ -371,7 +417,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
for ticket in all_tickets:
|
||||
for ticket in tickets_iter:
|
||||
updated_at = ticket.updated_at.replace(tzinfo=None)
|
||||
if start is not None and updated_at < start.replace(tzinfo=None):
|
||||
continue
|
||||
@@ -459,7 +505,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
api_client = HubSpot(access_token=self.access_token)
|
||||
all_companies = api_client.crm.companies.get_all(
|
||||
|
||||
companies_iter = self._paginated_results(
|
||||
api_client.crm.companies.basic_api.get_page,
|
||||
properties=[
|
||||
"name",
|
||||
"domain",
|
||||
@@ -475,7 +523,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
for company in all_companies:
|
||||
for company in companies_iter:
|
||||
updated_at = company.updated_at.replace(tzinfo=None)
|
||||
if start is not None and updated_at < start.replace(tzinfo=None):
|
||||
continue
|
||||
@@ -582,7 +630,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
api_client = HubSpot(access_token=self.access_token)
|
||||
all_deals = api_client.crm.deals.get_all(
|
||||
|
||||
deals_iter = self._paginated_results(
|
||||
api_client.crm.deals.basic_api.get_page,
|
||||
properties=[
|
||||
"dealname",
|
||||
"amount",
|
||||
@@ -598,7 +648,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
for deal in all_deals:
|
||||
for deal in deals_iter:
|
||||
updated_at = deal.updated_at.replace(tzinfo=None)
|
||||
if start is not None and updated_at < start.replace(tzinfo=None):
|
||||
continue
|
||||
@@ -703,7 +753,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
api_client = HubSpot(access_token=self.access_token)
|
||||
all_contacts = api_client.crm.contacts.get_all(
|
||||
|
||||
contacts_iter = self._paginated_results(
|
||||
api_client.crm.contacts.basic_api.get_page,
|
||||
properties=[
|
||||
"firstname",
|
||||
"lastname",
|
||||
@@ -721,7 +773,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
for contact in all_contacts:
|
||||
for contact in contacts_iter:
|
||||
updated_at = contact.updated_at.replace(tzinfo=None)
|
||||
if start is not None and updated_at < start.replace(tzinfo=None):
|
||||
continue
|
||||
|
||||
145
backend/onyx/connectors/hubspot/rate_limit.py
Normal file
145
backend/onyx/connectors/hubspot/rate_limit.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
RateLimitTriedTooManyTimesError,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# HubSpot exposes a ten second rolling window (x-hubspot-ratelimit-interval-milliseconds)
|
||||
# with a maximum of 190 requests, and a per-second limit of 19 requests.
|
||||
_HUBSPOT_TEN_SECOND_LIMIT = 190
|
||||
_HUBSPOT_TEN_SECOND_PERIOD = 10 # seconds
|
||||
_HUBSPOT_SECONDLY_LIMIT = 19
|
||||
_HUBSPOT_SECONDLY_PERIOD = 1 # second
|
||||
_DEFAULT_SLEEP_SECONDS = 10
|
||||
_SLEEP_PADDING_SECONDS = 1.0
|
||||
_MAX_RATE_LIMIT_RETRIES = 5
|
||||
|
||||
|
||||
def _extract_header(headers: Any, key: str) -> str | None:
|
||||
if headers is None:
|
||||
return None
|
||||
|
||||
getter = getattr(headers, "get", None)
|
||||
if callable(getter):
|
||||
value = getter(key)
|
||||
if value is not None:
|
||||
return value
|
||||
|
||||
if isinstance(headers, dict):
|
||||
value = headers.get(key)
|
||||
if value is not None:
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def is_rate_limit_error(exception: Exception) -> bool:
|
||||
status = getattr(exception, "status", None)
|
||||
if status == 429:
|
||||
return True
|
||||
|
||||
headers = getattr(exception, "headers", None)
|
||||
if headers is not None:
|
||||
remaining = _extract_header(headers, "x-hubspot-ratelimit-remaining")
|
||||
if remaining == "0":
|
||||
return True
|
||||
secondly_remaining = _extract_header(
|
||||
headers, "x-hubspot-ratelimit-secondly-remaining"
|
||||
)
|
||||
if secondly_remaining == "0":
|
||||
return True
|
||||
|
||||
message = str(exception)
|
||||
return "RATE_LIMIT" in message or "Too Many Requests" in message
|
||||
|
||||
|
||||
def get_rate_limit_retry_delay_seconds(exception: Exception) -> float:
|
||||
headers = getattr(exception, "headers", None)
|
||||
|
||||
retry_after = _extract_header(headers, "Retry-After")
|
||||
if retry_after:
|
||||
try:
|
||||
return float(retry_after) + _SLEEP_PADDING_SECONDS
|
||||
except ValueError:
|
||||
logger.debug(
|
||||
"Failed to parse Retry-After header '%s' as float", retry_after
|
||||
)
|
||||
|
||||
interval_ms = _extract_header(headers, "x-hubspot-ratelimit-interval-milliseconds")
|
||||
if interval_ms:
|
||||
try:
|
||||
return float(interval_ms) / 1000.0 + _SLEEP_PADDING_SECONDS
|
||||
except ValueError:
|
||||
logger.debug(
|
||||
"Failed to parse x-hubspot-ratelimit-interval-milliseconds '%s' as float",
|
||||
interval_ms,
|
||||
)
|
||||
|
||||
secondly_limit = _extract_header(headers, "x-hubspot-ratelimit-secondly")
|
||||
if secondly_limit:
|
||||
try:
|
||||
per_second = max(float(secondly_limit), 1.0)
|
||||
return (1.0 / per_second) + _SLEEP_PADDING_SECONDS
|
||||
except ValueError:
|
||||
logger.debug(
|
||||
"Failed to parse x-hubspot-ratelimit-secondly '%s' as float",
|
||||
secondly_limit,
|
||||
)
|
||||
|
||||
return _DEFAULT_SLEEP_SECONDS + _SLEEP_PADDING_SECONDS
|
||||
|
||||
|
||||
class HubSpotRateLimiter:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ten_second_limit: int = _HUBSPOT_TEN_SECOND_LIMIT,
|
||||
ten_second_period: int = _HUBSPOT_TEN_SECOND_PERIOD,
|
||||
secondly_limit: int = _HUBSPOT_SECONDLY_LIMIT,
|
||||
secondly_period: int = _HUBSPOT_SECONDLY_PERIOD,
|
||||
max_retries: int = _MAX_RATE_LIMIT_RETRIES,
|
||||
) -> None:
|
||||
self._max_retries = max_retries
|
||||
|
||||
@rate_limit_builder(max_calls=secondly_limit, period=secondly_period)
|
||||
@rate_limit_builder(max_calls=ten_second_limit, period=ten_second_period)
|
||||
def _execute(callable_: Callable[[], T]) -> T:
|
||||
return callable_()
|
||||
|
||||
self._execute = _execute
|
||||
|
||||
def call(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
|
||||
attempts = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
return self._execute(lambda: func(*args, **kwargs))
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
if not is_rate_limit_error(exc):
|
||||
raise
|
||||
|
||||
attempts += 1
|
||||
if attempts > self._max_retries:
|
||||
raise RateLimitTriedTooManyTimesError(
|
||||
"Exceeded configured HubSpot rate limit retries"
|
||||
) from exc
|
||||
|
||||
wait_time = get_rate_limit_retry_delay_seconds(exc)
|
||||
logger.notice(
|
||||
"HubSpot rate limit reached. Sleeping %.2f seconds before retrying.",
|
||||
wait_time,
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
@@ -97,11 +97,20 @@ class PollConnector(BaseConnector):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Slim connectors can retrieve just the ids and
|
||||
# permission syncing information for connected documents
|
||||
# Slim connectors retrieve just the ids of documents
|
||||
class SlimConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs(
|
||||
self,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Slim connectors retrieve both the ids AND
|
||||
# permission syncing information for connected documents
|
||||
class SlimConnectorWithPermSync(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -25,11 +25,11 @@ from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.jira.access import get_project_permissions
|
||||
from onyx.connectors.jira.utils import best_effort_basic_expert_info
|
||||
from onyx.connectors.jira.utils import best_effort_get_field_from_issue
|
||||
@@ -247,7 +247,7 @@ def _perform_jql_search_v2(
|
||||
|
||||
|
||||
def process_jira_issue(
|
||||
jira_client: JIRA,
|
||||
jira_base_url: str,
|
||||
issue: Issue,
|
||||
comment_email_blacklist: tuple[str, ...] = (),
|
||||
labels_to_skip: set[str] | None = None,
|
||||
@@ -281,7 +281,7 @@ def process_jira_issue(
|
||||
)
|
||||
return None
|
||||
|
||||
page_url = build_jira_url(jira_client, issue.key)
|
||||
page_url = build_jira_url(jira_base_url, issue.key)
|
||||
|
||||
metadata_dict: dict[str, str | list[str]] = {}
|
||||
people = set()
|
||||
@@ -359,7 +359,10 @@ class JiraConnectorCheckpoint(ConnectorCheckpoint):
|
||||
offset: int | None = None
|
||||
|
||||
|
||||
class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnector):
|
||||
class JiraConnector(
|
||||
CheckpointedConnectorWithPermSync[JiraConnectorCheckpoint],
|
||||
SlimConnectorWithPermSync,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
jira_base_url: str,
|
||||
@@ -372,15 +375,23 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
|
||||
# Custom JQL query to filter Jira issues
|
||||
jql_query: str | None = None,
|
||||
scoped_token: bool = False,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
|
||||
# dealing with scoped tokens is a bit tricky becasue we need to hit api.atlassian.net
|
||||
# when making jira requests but still want correct links to issues in the UI.
|
||||
# So, the user's base url is stored here, but converted to a scoped url when passed
|
||||
# to the jira client.
|
||||
self.jira_base = jira_base_url.rstrip("/") # Remove trailing slash if present
|
||||
self.jira_project = project_key
|
||||
self._comment_email_blacklist = comment_email_blacklist or []
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
self.jql_query = jql_query
|
||||
|
||||
self.scoped_token = scoped_token
|
||||
self._jira_client: JIRA | None = None
|
||||
# Cache project permissions to avoid fetching them repeatedly across runs
|
||||
self._project_permissions_cache: dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def comment_email_blacklist(self) -> tuple:
|
||||
@@ -399,10 +410,26 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
return ""
|
||||
return f'"{self.jira_project}"'
|
||||
|
||||
def _get_project_permissions(self, project_key: str) -> Any:
|
||||
"""Get project permissions with caching.
|
||||
|
||||
Args:
|
||||
project_key: The Jira project key
|
||||
|
||||
Returns:
|
||||
The external access permissions for the project
|
||||
"""
|
||||
if project_key not in self._project_permissions_cache:
|
||||
self._project_permissions_cache[project_key] = get_project_permissions(
|
||||
jira_client=self.jira_client, jira_project=project_key
|
||||
)
|
||||
return self._project_permissions_cache[project_key]
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._jira_client = build_jira_client(
|
||||
credentials=credentials,
|
||||
jira_base=self.jira_base,
|
||||
scoped_token=self.scoped_token,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -442,15 +469,37 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
) -> CheckpointOutput[JiraConnectorCheckpoint]:
|
||||
jql = self._get_jql_query(start, end)
|
||||
try:
|
||||
return self._load_from_checkpoint(jql, checkpoint)
|
||||
return self._load_from_checkpoint(
|
||||
jql, checkpoint, include_permissions=False
|
||||
)
|
||||
except Exception as e:
|
||||
if is_atlassian_date_error(e):
|
||||
jql = self._get_jql_query(start - ONE_HOUR, end)
|
||||
return self._load_from_checkpoint(jql, checkpoint)
|
||||
return self._load_from_checkpoint(
|
||||
jql, checkpoint, include_permissions=False
|
||||
)
|
||||
raise e
|
||||
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: JiraConnectorCheckpoint,
|
||||
) -> CheckpointOutput[JiraConnectorCheckpoint]:
|
||||
"""Load documents from checkpoint with permission information included."""
|
||||
jql = self._get_jql_query(start, end)
|
||||
try:
|
||||
return self._load_from_checkpoint(jql, checkpoint, include_permissions=True)
|
||||
except Exception as e:
|
||||
if is_atlassian_date_error(e):
|
||||
jql = self._get_jql_query(start - ONE_HOUR, end)
|
||||
return self._load_from_checkpoint(
|
||||
jql, checkpoint, include_permissions=True
|
||||
)
|
||||
raise e
|
||||
|
||||
def _load_from_checkpoint(
|
||||
self, jql: str, checkpoint: JiraConnectorCheckpoint
|
||||
self, jql: str, checkpoint: JiraConnectorCheckpoint, include_permissions: bool
|
||||
) -> CheckpointOutput[JiraConnectorCheckpoint]:
|
||||
# Get the current offset from checkpoint or start at 0
|
||||
starting_offset = checkpoint.offset or 0
|
||||
@@ -472,18 +521,25 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
issue_key = issue.key
|
||||
try:
|
||||
if document := process_jira_issue(
|
||||
jira_client=self.jira_client,
|
||||
jira_base_url=self.jira_base,
|
||||
issue=issue,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
):
|
||||
# Add permission information to the document if requested
|
||||
if include_permissions:
|
||||
project_key = get_jira_project_key_from_issue(issue=issue)
|
||||
if project_key:
|
||||
document.external_access = self._get_project_permissions(
|
||||
project_key
|
||||
)
|
||||
yield document
|
||||
|
||||
except Exception as e:
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=issue_key,
|
||||
document_link=build_jira_url(self.jira_client, issue_key),
|
||||
document_link=build_jira_url(self.jira_base, issue_key),
|
||||
),
|
||||
failure_message=f"Failed to process Jira issue: {str(e)}",
|
||||
exception=e,
|
||||
@@ -515,7 +571,7 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
# if we didn't retrieve a full batch, we're done
|
||||
checkpoint.has_more = current_offset - starting_offset == page_size
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
@@ -534,6 +590,7 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
prev_offset = 0
|
||||
current_offset = 0
|
||||
slim_doc_batch = []
|
||||
|
||||
while checkpoint.has_more:
|
||||
for issue in _perform_jql_search(
|
||||
jira_client=self.jira_client,
|
||||
@@ -550,13 +607,12 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
continue
|
||||
|
||||
issue_key = best_effort_get_field_from_issue(issue, _FIELD_KEY)
|
||||
id = build_jira_url(self.jira_client, issue_key)
|
||||
id = build_jira_url(self.jira_base, issue_key)
|
||||
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=id,
|
||||
external_access=get_project_permissions(
|
||||
jira_client=self.jira_client, jira_project=project_key
|
||||
),
|
||||
external_access=self._get_project_permissions(project_key),
|
||||
)
|
||||
)
|
||||
current_offset += 1
|
||||
@@ -701,7 +757,7 @@ if __name__ == "__main__":
|
||||
start = 0
|
||||
end = datetime.now().timestamp()
|
||||
|
||||
for slim_doc in connector.retrieve_all_slim_documents(
|
||||
for slim_doc in connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
|
||||
@@ -10,6 +10,7 @@ from jira.resources import CustomFieldOption
|
||||
from jira.resources import Issue
|
||||
from jira.resources import User
|
||||
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import scoped_url
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -74,11 +75,18 @@ def extract_text_from_adf(adf: dict | None) -> str:
|
||||
return " ".join(texts)
|
||||
|
||||
|
||||
def build_jira_url(jira_client: JIRA, issue_key: str) -> str:
|
||||
return f"{jira_client.client_info()}/browse/{issue_key}"
|
||||
def build_jira_url(jira_base_url: str, issue_key: str) -> str:
|
||||
"""
|
||||
Get the url used to access an issue in the UI.
|
||||
"""
|
||||
return f"{jira_base_url}/browse/{issue_key}"
|
||||
|
||||
|
||||
def build_jira_client(credentials: dict[str, Any], jira_base: str) -> JIRA:
|
||||
def build_jira_client(
|
||||
credentials: dict[str, Any], jira_base: str, scoped_token: bool = False
|
||||
) -> JIRA:
|
||||
|
||||
jira_base = scoped_url(jira_base, "jira") if scoped_token else jira_base
|
||||
api_token = credentials["jira_api_token"]
|
||||
# if user provide an email we assume it's cloud
|
||||
if "jira_user_email" in credentials:
|
||||
|
||||
208
backend/onyx/connectors/registry.py
Normal file
208
backend/onyx/connectors/registry.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Registry mapping for connector classes."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
|
||||
class ConnectorMapping(BaseModel):
|
||||
module_path: str
|
||||
class_name: str
|
||||
|
||||
|
||||
# Mapping of DocumentSource to connector details for lazy loading
|
||||
CONNECTOR_CLASS_MAP = {
|
||||
DocumentSource.WEB: ConnectorMapping(
|
||||
module_path="onyx.connectors.web.connector",
|
||||
class_name="WebConnector",
|
||||
),
|
||||
DocumentSource.FILE: ConnectorMapping(
|
||||
module_path="onyx.connectors.file.connector",
|
||||
class_name="LocalFileConnector",
|
||||
),
|
||||
DocumentSource.SLACK: ConnectorMapping(
|
||||
module_path="onyx.connectors.slack.connector",
|
||||
class_name="SlackConnector",
|
||||
),
|
||||
DocumentSource.GITHUB: ConnectorMapping(
|
||||
module_path="onyx.connectors.github.connector",
|
||||
class_name="GithubConnector",
|
||||
),
|
||||
DocumentSource.GMAIL: ConnectorMapping(
|
||||
module_path="onyx.connectors.gmail.connector",
|
||||
class_name="GmailConnector",
|
||||
),
|
||||
DocumentSource.GITLAB: ConnectorMapping(
|
||||
module_path="onyx.connectors.gitlab.connector",
|
||||
class_name="GitlabConnector",
|
||||
),
|
||||
DocumentSource.GITBOOK: ConnectorMapping(
|
||||
module_path="onyx.connectors.gitbook.connector",
|
||||
class_name="GitbookConnector",
|
||||
),
|
||||
DocumentSource.GOOGLE_DRIVE: ConnectorMapping(
|
||||
module_path="onyx.connectors.google_drive.connector",
|
||||
class_name="GoogleDriveConnector",
|
||||
),
|
||||
DocumentSource.BOOKSTACK: ConnectorMapping(
|
||||
module_path="onyx.connectors.bookstack.connector",
|
||||
class_name="BookstackConnector",
|
||||
),
|
||||
DocumentSource.OUTLINE: ConnectorMapping(
|
||||
module_path="onyx.connectors.outline.connector",
|
||||
class_name="OutlineConnector",
|
||||
),
|
||||
DocumentSource.CONFLUENCE: ConnectorMapping(
|
||||
module_path="onyx.connectors.confluence.connector",
|
||||
class_name="ConfluenceConnector",
|
||||
),
|
||||
DocumentSource.JIRA: ConnectorMapping(
|
||||
module_path="onyx.connectors.jira.connector",
|
||||
class_name="JiraConnector",
|
||||
),
|
||||
DocumentSource.PRODUCTBOARD: ConnectorMapping(
|
||||
module_path="onyx.connectors.productboard.connector",
|
||||
class_name="ProductboardConnector",
|
||||
),
|
||||
DocumentSource.SLAB: ConnectorMapping(
|
||||
module_path="onyx.connectors.slab.connector",
|
||||
class_name="SlabConnector",
|
||||
),
|
||||
DocumentSource.NOTION: ConnectorMapping(
|
||||
module_path="onyx.connectors.notion.connector",
|
||||
class_name="NotionConnector",
|
||||
),
|
||||
DocumentSource.ZULIP: ConnectorMapping(
|
||||
module_path="onyx.connectors.zulip.connector",
|
||||
class_name="ZulipConnector",
|
||||
),
|
||||
DocumentSource.GURU: ConnectorMapping(
|
||||
module_path="onyx.connectors.guru.connector",
|
||||
class_name="GuruConnector",
|
||||
),
|
||||
DocumentSource.LINEAR: ConnectorMapping(
|
||||
module_path="onyx.connectors.linear.connector",
|
||||
class_name="LinearConnector",
|
||||
),
|
||||
DocumentSource.HUBSPOT: ConnectorMapping(
|
||||
module_path="onyx.connectors.hubspot.connector",
|
||||
class_name="HubSpotConnector",
|
||||
),
|
||||
DocumentSource.DOCUMENT360: ConnectorMapping(
|
||||
module_path="onyx.connectors.document360.connector",
|
||||
class_name="Document360Connector",
|
||||
),
|
||||
DocumentSource.GONG: ConnectorMapping(
|
||||
module_path="onyx.connectors.gong.connector",
|
||||
class_name="GongConnector",
|
||||
),
|
||||
DocumentSource.GOOGLE_SITES: ConnectorMapping(
|
||||
module_path="onyx.connectors.google_site.connector",
|
||||
class_name="GoogleSitesConnector",
|
||||
),
|
||||
DocumentSource.ZENDESK: ConnectorMapping(
|
||||
module_path="onyx.connectors.zendesk.connector",
|
||||
class_name="ZendeskConnector",
|
||||
),
|
||||
DocumentSource.LOOPIO: ConnectorMapping(
|
||||
module_path="onyx.connectors.loopio.connector",
|
||||
class_name="LoopioConnector",
|
||||
),
|
||||
DocumentSource.DROPBOX: ConnectorMapping(
|
||||
module_path="onyx.connectors.dropbox.connector",
|
||||
class_name="DropboxConnector",
|
||||
),
|
||||
DocumentSource.SHAREPOINT: ConnectorMapping(
|
||||
module_path="onyx.connectors.sharepoint.connector",
|
||||
class_name="SharepointConnector",
|
||||
),
|
||||
DocumentSource.TEAMS: ConnectorMapping(
|
||||
module_path="onyx.connectors.teams.connector",
|
||||
class_name="TeamsConnector",
|
||||
),
|
||||
DocumentSource.SALESFORCE: ConnectorMapping(
|
||||
module_path="onyx.connectors.salesforce.connector",
|
||||
class_name="SalesforceConnector",
|
||||
),
|
||||
DocumentSource.DISCOURSE: ConnectorMapping(
|
||||
module_path="onyx.connectors.discourse.connector",
|
||||
class_name="DiscourseConnector",
|
||||
),
|
||||
DocumentSource.AXERO: ConnectorMapping(
|
||||
module_path="onyx.connectors.axero.connector",
|
||||
class_name="AxeroConnector",
|
||||
),
|
||||
DocumentSource.CLICKUP: ConnectorMapping(
|
||||
module_path="onyx.connectors.clickup.connector",
|
||||
class_name="ClickupConnector",
|
||||
),
|
||||
DocumentSource.MEDIAWIKI: ConnectorMapping(
|
||||
module_path="onyx.connectors.mediawiki.wiki",
|
||||
class_name="MediaWikiConnector",
|
||||
),
|
||||
DocumentSource.WIKIPEDIA: ConnectorMapping(
|
||||
module_path="onyx.connectors.wikipedia.connector",
|
||||
class_name="WikipediaConnector",
|
||||
),
|
||||
DocumentSource.ASANA: ConnectorMapping(
|
||||
module_path="onyx.connectors.asana.connector",
|
||||
class_name="AsanaConnector",
|
||||
),
|
||||
DocumentSource.S3: ConnectorMapping(
|
||||
module_path="onyx.connectors.blob.connector",
|
||||
class_name="BlobStorageConnector",
|
||||
),
|
||||
DocumentSource.R2: ConnectorMapping(
|
||||
module_path="onyx.connectors.blob.connector",
|
||||
class_name="BlobStorageConnector",
|
||||
),
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: ConnectorMapping(
|
||||
module_path="onyx.connectors.blob.connector",
|
||||
class_name="BlobStorageConnector",
|
||||
),
|
||||
DocumentSource.OCI_STORAGE: ConnectorMapping(
|
||||
module_path="onyx.connectors.blob.connector",
|
||||
class_name="BlobStorageConnector",
|
||||
),
|
||||
DocumentSource.XENFORO: ConnectorMapping(
|
||||
module_path="onyx.connectors.xenforo.connector",
|
||||
class_name="XenforoConnector",
|
||||
),
|
||||
DocumentSource.DISCORD: ConnectorMapping(
|
||||
module_path="onyx.connectors.discord.connector",
|
||||
class_name="DiscordConnector",
|
||||
),
|
||||
DocumentSource.FRESHDESK: ConnectorMapping(
|
||||
module_path="onyx.connectors.freshdesk.connector",
|
||||
class_name="FreshdeskConnector",
|
||||
),
|
||||
DocumentSource.FIREFLIES: ConnectorMapping(
|
||||
module_path="onyx.connectors.fireflies.connector",
|
||||
class_name="FirefliesConnector",
|
||||
),
|
||||
DocumentSource.EGNYTE: ConnectorMapping(
|
||||
module_path="onyx.connectors.egnyte.connector",
|
||||
class_name="EgnyteConnector",
|
||||
),
|
||||
DocumentSource.AIRTABLE: ConnectorMapping(
|
||||
module_path="onyx.connectors.airtable.airtable_connector",
|
||||
class_name="AirtableConnector",
|
||||
),
|
||||
DocumentSource.HIGHSPOT: ConnectorMapping(
|
||||
module_path="onyx.connectors.highspot.connector",
|
||||
class_name="HighspotConnector",
|
||||
),
|
||||
DocumentSource.IMAP: ConnectorMapping(
|
||||
module_path="onyx.connectors.imap.connector",
|
||||
class_name="ImapConnector",
|
||||
),
|
||||
DocumentSource.BITBUCKET: ConnectorMapping(
|
||||
module_path="onyx.connectors.bitbucket.connector",
|
||||
class_name="BitbucketConnector",
|
||||
),
|
||||
# just for integration tests
|
||||
DocumentSource.MOCK_CONNECTOR: ConnectorMapping(
|
||||
module_path="onyx.connectors.mock_connector.connector",
|
||||
class_name="MockConnector",
|
||||
),
|
||||
}
|
||||
@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
@@ -151,7 +151,7 @@ def _validate_custom_query_config(config: dict[str, Any]) -> None:
|
||||
)
|
||||
|
||||
|
||||
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
"""Approach outline
|
||||
|
||||
Goal
|
||||
@@ -1119,7 +1119,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
return self._delta_sync(temp_dir, start, end)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -41,7 +41,7 @@ from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import IndexingHeartbeatInterface
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
@@ -73,7 +73,8 @@ class SiteDescriptor(BaseModel):
|
||||
"""Data class for storing SharePoint site information.
|
||||
|
||||
Args:
|
||||
url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests)
|
||||
url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests
|
||||
or https://danswerai.sharepoint.com/teams/team-name)
|
||||
drive_name: The name of the drive to access (e.g. "Shared Documents", "Other Library")
|
||||
If None, all drives will be accessed.
|
||||
folder_path: The folder path within the drive to access (e.g. "test/nested with spaces")
|
||||
@@ -672,7 +673,7 @@ def _convert_sitepage_to_slim_document(
|
||||
|
||||
|
||||
class SharepointConnector(
|
||||
SlimConnector,
|
||||
SlimConnectorWithPermSync,
|
||||
CheckpointedConnectorWithPermSync[SharepointConnectorCheckpoint],
|
||||
):
|
||||
def __init__(
|
||||
@@ -703,9 +704,11 @@ class SharepointConnector(
|
||||
|
||||
# Ensure sites are sharepoint urls
|
||||
for site_url in self.sites:
|
||||
if not site_url.startswith("https://") or "/sites/" not in site_url:
|
||||
if not site_url.startswith("https://") or not (
|
||||
"/sites/" in site_url or "/teams/" in site_url
|
||||
):
|
||||
raise ConnectorValidationError(
|
||||
"Site URLs must be full Sharepoint URLs (e.g. https://your-tenant.sharepoint.com/sites/your-site)"
|
||||
"Site URLs must be full Sharepoint URLs (e.g. https://your-tenant.sharepoint.com/sites/your-site or https://your-tenant.sharepoint.com/teams/your-team)"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -720,10 +723,17 @@ class SharepointConnector(
|
||||
site_data_list = []
|
||||
for url in site_urls:
|
||||
parts = url.strip().split("/")
|
||||
|
||||
site_type_index = None
|
||||
if "sites" in parts:
|
||||
sites_index = parts.index("sites")
|
||||
site_url = "/".join(parts[: sites_index + 2])
|
||||
remaining_parts = parts[sites_index + 2 :]
|
||||
site_type_index = parts.index("sites")
|
||||
elif "teams" in parts:
|
||||
site_type_index = parts.index("teams")
|
||||
|
||||
if site_type_index is not None:
|
||||
# Extract the base site URL (up to and including the site/team name)
|
||||
site_url = "/".join(parts[: site_type_index + 2])
|
||||
remaining_parts = parts[site_type_index + 2 :]
|
||||
|
||||
# Extract drive name and folder path
|
||||
if remaining_parts:
|
||||
@@ -745,7 +755,9 @@ class SharepointConnector(
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Site URL '{url}' is not a valid Sharepoint URL")
|
||||
logger.warning(
|
||||
f"Site URL '{url}' is not a valid Sharepoint URL (must contain /sites/ or /teams/)"
|
||||
)
|
||||
return site_data_list
|
||||
|
||||
def _get_drive_items_for_drive_name(
|
||||
@@ -1597,7 +1609,7 @@ class SharepointConnector(
|
||||
) -> SharepointConnectorCheckpoint:
|
||||
return SharepointConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
@@ -164,7 +164,7 @@ def get_slab_url_from_title_id(base_url: str, title: str, page_id: str) -> str:
|
||||
return urljoin(urljoin(base_url, "posts/"), url_id)
|
||||
|
||||
|
||||
class SlabConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class SlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
@@ -239,7 +239,7 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
time_filter=lambda t: start_time <= t <= end_time
|
||||
)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -42,7 +42,7 @@ from onyx.connectors.interfaces import CredentialsConnector
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
@@ -581,7 +581,7 @@ def _process_message(
|
||||
|
||||
|
||||
class SlackConnector(
|
||||
SlimConnector,
|
||||
SlimConnectorWithPermSync,
|
||||
CredentialsConnector,
|
||||
CheckpointedConnectorWithPermSync[SlackCheckpoint],
|
||||
):
|
||||
@@ -732,7 +732,7 @@ class SlackConnector(
|
||||
self.text_cleaner = SlackTextCleaner(client=self.client)
|
||||
self.credentials_provider = credentials_provider
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -22,7 +22,7 @@ from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
@@ -51,7 +51,7 @@ class TeamsCheckpoint(ConnectorCheckpoint):
|
||||
|
||||
class TeamsConnector(
|
||||
CheckpointedConnector[TeamsCheckpoint],
|
||||
SlimConnector,
|
||||
SlimConnectorWithPermSync,
|
||||
):
|
||||
MAX_WORKERS = 10
|
||||
AUTHORITY_URL_PREFIX = "https://login.microsoftonline.com/"
|
||||
@@ -228,9 +228,9 @@ class TeamsConnector(
|
||||
has_more=bool(todos),
|
||||
)
|
||||
|
||||
# impls for SlimConnector
|
||||
# impls for SlimConnectorWithPermSync
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
@@ -572,7 +572,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
teams_connector.validate_connector_settings()
|
||||
|
||||
for slim_doc in teams_connector.retrieve_all_slim_documents():
|
||||
for slim_doc in teams_connector.retrieve_all_slim_docs_perm_sync():
|
||||
...
|
||||
|
||||
for doc in load_everything_from_checkpoint_connector(
|
||||
|
||||
@@ -219,6 +219,25 @@ def is_valid_url(url: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _same_site(base_url: str, candidate_url: str) -> bool:
|
||||
base, candidate = urlparse(base_url), urlparse(candidate_url)
|
||||
base_netloc = base.netloc.lower().removeprefix("www.")
|
||||
candidate_netloc = candidate.netloc.lower().removeprefix("www.")
|
||||
if base_netloc != candidate_netloc:
|
||||
return False
|
||||
|
||||
base_path = (base.path or "/").rstrip("/")
|
||||
if base_path in ("", "/"):
|
||||
return True
|
||||
|
||||
candidate_path = candidate.path or "/"
|
||||
if candidate_path == base_path:
|
||||
return True
|
||||
|
||||
boundary = f"{base_path}/"
|
||||
return candidate_path.startswith(boundary)
|
||||
|
||||
|
||||
def get_internal_links(
|
||||
base_url: str, url: str, soup: BeautifulSoup, should_ignore_pound: bool = True
|
||||
) -> set[str]:
|
||||
@@ -239,7 +258,7 @@ def get_internal_links(
|
||||
# Relative path handling
|
||||
href = urljoin(url, href)
|
||||
|
||||
if urlparse(href).netloc == urlparse(url).netloc and base_url in href:
|
||||
if _same_site(base_url, href):
|
||||
internal_links.add(href)
|
||||
return internal_links
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import ConnectorFailure
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import Document
|
||||
@@ -376,7 +376,7 @@ class ZendeskConnectorCheckpoint(ConnectorCheckpoint):
|
||||
|
||||
|
||||
class ZendeskConnector(
|
||||
SlimConnector, CheckpointedConnector[ZendeskConnectorCheckpoint]
|
||||
SlimConnectorWithPermSync, CheckpointedConnector[ZendeskConnectorCheckpoint]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -565,7 +565,7 @@ class ZendeskConnector(
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -2,7 +2,6 @@ import string
|
||||
from collections.abc import Callable
|
||||
from uuid import UUID
|
||||
|
||||
import nltk # type:ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
|
||||
@@ -61,6 +60,8 @@ def _dedupe_chunks(
|
||||
|
||||
|
||||
def download_nltk_data() -> None:
|
||||
import nltk # type: ignore[import-untyped]
|
||||
|
||||
resources = {
|
||||
"stopwords": "corpora/stopwords",
|
||||
# "wordnet": "corpora/wordnet", # Not in use
|
||||
|
||||
@@ -2,8 +2,6 @@ import string
|
||||
from collections.abc import Sequence
|
||||
from typing import TypeVar
|
||||
|
||||
from nltk.corpus import stopwords # type:ignore
|
||||
from nltk.tokenize import word_tokenize # type:ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import SectionRelevancePiece
|
||||
@@ -119,6 +117,9 @@ def inference_section_from_chunks(
|
||||
|
||||
|
||||
def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
|
||||
from nltk.corpus import stopwords # type:ignore
|
||||
from nltk.tokenize import word_tokenize # type:ignore
|
||||
|
||||
try:
|
||||
# Re-tokenize using the NLTK tokenizer for better matching
|
||||
query = " ".join(keywords)
|
||||
|
||||
@@ -112,6 +112,7 @@ def upsert_llm_provider(
|
||||
name=model_configuration.name,
|
||||
is_visible=model_configuration.is_visible,
|
||||
max_input_tokens=model_configuration.max_input_tokens,
|
||||
supports_image_input=model_configuration.supports_image_input,
|
||||
)
|
||||
.on_conflict_do_nothing()
|
||||
)
|
||||
|
||||
@@ -2353,6 +2353,8 @@ class ModelConfiguration(Base):
|
||||
# - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
|
||||
max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
|
||||
llm_provider: Mapped["LLMProvider"] = relationship(
|
||||
"LLMProvider",
|
||||
back_populates="model_configurations",
|
||||
|
||||
@@ -1,15 +1,52 @@
|
||||
"""Factory for creating federated connector instances."""
|
||||
|
||||
import importlib
|
||||
from typing import Any
|
||||
from typing import Type
|
||||
|
||||
from onyx.configs.constants import FederatedConnectorSource
|
||||
from onyx.federated_connectors.interfaces import FederatedConnector
|
||||
from onyx.federated_connectors.slack.federated_connector import SlackFederatedConnector
|
||||
from onyx.federated_connectors.registry import FEDERATED_CONNECTOR_CLASS_MAP
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class FederatedConnectorMissingException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# Cache for already imported federated connector classes
|
||||
_federated_connector_cache: dict[FederatedConnectorSource, Type[FederatedConnector]] = (
|
||||
{}
|
||||
)
|
||||
|
||||
|
||||
def _load_federated_connector_class(
|
||||
source: FederatedConnectorSource,
|
||||
) -> Type[FederatedConnector]:
|
||||
"""Dynamically load and cache a federated connector class."""
|
||||
if source in _federated_connector_cache:
|
||||
return _federated_connector_cache[source]
|
||||
|
||||
if source not in FEDERATED_CONNECTOR_CLASS_MAP:
|
||||
raise FederatedConnectorMissingException(
|
||||
f"Federated connector not found for source={source}"
|
||||
)
|
||||
|
||||
mapping = FEDERATED_CONNECTOR_CLASS_MAP[source]
|
||||
|
||||
try:
|
||||
module = importlib.import_module(mapping.module_path)
|
||||
connector_class = getattr(module, mapping.class_name)
|
||||
_federated_connector_cache[source] = connector_class
|
||||
return connector_class
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise FederatedConnectorMissingException(
|
||||
f"Failed to import {mapping.class_name} from {mapping.module_path}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def get_federated_connector(
|
||||
source: FederatedConnectorSource,
|
||||
credentials: dict[str, Any],
|
||||
@@ -21,9 +58,6 @@ def get_federated_connector(
|
||||
|
||||
def get_federated_connector_cls(
|
||||
source: FederatedConnectorSource,
|
||||
) -> type[FederatedConnector]:
|
||||
) -> Type[FederatedConnector]:
|
||||
"""Get the class of the appropriate federated connector."""
|
||||
if source == FederatedConnectorSource.FEDERATED_SLACK:
|
||||
return SlackFederatedConnector
|
||||
else:
|
||||
raise ValueError(f"Unsupported federated connector source: {source}")
|
||||
return _load_federated_connector_class(source)
|
||||
|
||||
19
backend/onyx/federated_connectors/registry.py
Normal file
19
backend/onyx/federated_connectors/registry.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Registry mapping for federated connector classes."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import FederatedConnectorSource
|
||||
|
||||
|
||||
class FederatedConnectorMapping(BaseModel):
|
||||
module_path: str
|
||||
class_name: str
|
||||
|
||||
|
||||
# Mapping of FederatedConnectorSource to connector details for lazy loading
|
||||
FEDERATED_CONNECTOR_CLASS_MAP = {
|
||||
FederatedConnectorSource.FEDERATED_SLACK: FederatedConnectorMapping(
|
||||
module_path="onyx.federated_connectors.slack.federated_connector",
|
||||
class_name="SlackFederatedConnector",
|
||||
),
|
||||
}
|
||||
@@ -22,8 +22,6 @@ from zipfile import BadZipFile
|
||||
import chardet
|
||||
import openpyxl
|
||||
from PIL import Image
|
||||
from pypdf import PdfReader
|
||||
from pypdf.errors import PdfStreamError
|
||||
|
||||
from onyx.configs.constants import ONYX_METADATA_FILENAME
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
@@ -272,6 +270,9 @@ def read_pdf_file(
|
||||
"""
|
||||
Returns the text, basic PDF metadata, and optionally extracted images.
|
||||
"""
|
||||
from pypdf import PdfReader
|
||||
from pypdf.errors import PdfStreamError
|
||||
|
||||
metadata: dict[str, Any] = {}
|
||||
extracted_images: list[tuple[bytes, str]] = []
|
||||
try:
|
||||
@@ -313,10 +314,8 @@ def read_pdf_file(
|
||||
image.save(img_byte_arr, format=image.format)
|
||||
img_bytes = img_byte_arr.getvalue()
|
||||
|
||||
image_name = (
|
||||
f"page_{page_num + 1}_image_{image_file_object.name}."
|
||||
f"{image.format.lower() if image.format else 'png'}"
|
||||
)
|
||||
image_format = image.format.lower() if image.format else "png"
|
||||
image_name = f"page_{page_num + 1}_image_{image_file_object.name}.{image_format}"
|
||||
if image_callback is not None:
|
||||
# Stream image out immediately
|
||||
image_callback(img_bytes, image_name)
|
||||
@@ -483,8 +482,7 @@ def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
if num_empty_consecutive_rows > 100:
|
||||
# handle massive excel sheets with mostly empty cells
|
||||
logger.warning(
|
||||
f"Found {num_empty_consecutive_rows} empty rows in {file_name},"
|
||||
" skipping rest of file"
|
||||
f"Found {num_empty_consecutive_rows} empty rows in {file_name}, skipping rest of file"
|
||||
)
|
||||
break
|
||||
sheet_str = "\n".join(rows)
|
||||
@@ -556,8 +554,7 @@ def extract_file_text(
|
||||
return unstructured_to_text(file, file_name)
|
||||
except Exception as unstructured_error:
|
||||
logger.error(
|
||||
f"Failed to process with Unstructured: {str(unstructured_error)}. "
|
||||
"Falling back to normal processing."
|
||||
f"Failed to process with Unstructured: {str(unstructured_error)}. Falling back to normal processing."
|
||||
)
|
||||
if extension is None:
|
||||
extension = get_file_ext(file_name)
|
||||
@@ -643,8 +640,7 @@ def _extract_text_and_images(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to process with Unstructured: {str(e)}. "
|
||||
"Falling back to normal processing."
|
||||
f"Failed to process with Unstructured: {str(e)}. Falling back to normal processing."
|
||||
)
|
||||
file.seek(0) # Reset file pointer just in case
|
||||
|
||||
|
||||
@@ -5,8 +5,6 @@ from io import BytesIO
|
||||
from typing import IO
|
||||
|
||||
import bs4
|
||||
import trafilatura # type: ignore
|
||||
from trafilatura.settings import use_config # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY
|
||||
from onyx.configs.app_configs import PARSE_WITH_TRAFILATURA
|
||||
@@ -56,6 +54,9 @@ def format_element_text(element_text: str, link_href: str | None) -> str:
|
||||
|
||||
def parse_html_with_trafilatura(html_content: str) -> str:
|
||||
"""Parse HTML content using trafilatura."""
|
||||
import trafilatura # type: ignore
|
||||
from trafilatura.settings import use_config # type: ignore
|
||||
|
||||
config = use_config()
|
||||
config.set("DEFAULT", "include_links", "True")
|
||||
config.set("DEFAULT", "include_tables", "True")
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import IO
|
||||
|
||||
from unstructured.staging.base import dict_to_elements
|
||||
from unstructured_client import UnstructuredClient # type: ignore
|
||||
from unstructured_client.models import operations # type: ignore
|
||||
from unstructured_client.models import shared
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from onyx.configs.constants import KV_UNSTRUCTURED_API_KEY
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from unstructured_client.models import operations # type: ignore
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -36,7 +35,10 @@ def delete_unstructured_api_key() -> None:
|
||||
|
||||
def _sdk_partition_request(
|
||||
file: IO[Any], file_name: str, **kwargs: Any
|
||||
) -> operations.PartitionRequest:
|
||||
) -> "operations.PartitionRequest":
|
||||
from unstructured_client.models import operations # type: ignore
|
||||
from unstructured_client.models import shared
|
||||
|
||||
file.seek(0, 0)
|
||||
try:
|
||||
request = operations.PartitionRequest(
|
||||
@@ -52,6 +54,9 @@ def _sdk_partition_request(
|
||||
|
||||
|
||||
def unstructured_to_text(file: IO[Any], file_name: str) -> str:
|
||||
from unstructured.staging.base import dict_to_elements
|
||||
from unstructured_client import UnstructuredClient # type: ignore
|
||||
|
||||
logger.debug(f"Starting to read file: {file_name}")
|
||||
req = _sdk_partition_request(file, file_name, strategy="fast")
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from collections import defaultdict
|
||||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
from nltk import ngrams # type: ignore
|
||||
from rapidfuzz.distance.DamerauLevenshtein import normalized_similarity
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import Float
|
||||
@@ -59,6 +58,8 @@ def _normalize_one_entity(
|
||||
attributes: dict[str, str],
|
||||
allowed_docs_temp_view_name: str | None = None,
|
||||
) -> str | None:
|
||||
from nltk import ngrams # type: ignore
|
||||
|
||||
"""
|
||||
Matches a single entity to the best matching entity of the same type.
|
||||
"""
|
||||
|
||||
@@ -5,8 +5,9 @@ from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
import litellm # type: ignore
|
||||
from httpx import RemoteProtocolError
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
from langchain_core.messages import AIMessage
|
||||
@@ -24,9 +25,7 @@ from langchain_core.messages import SystemMessageChunk
|
||||
from langchain_core.messages.tool import ToolCallChunk
|
||||
from langchain_core.messages.tool import ToolMessage
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from litellm.utils import get_supported_openai_params
|
||||
|
||||
from onyx.configs.app_configs import BRAINTRUST_ENABLED
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
|
||||
from onyx.configs.chat_configs import QA_TIMEOUT
|
||||
@@ -45,13 +44,9 @@ from onyx.utils.long_term_log import LongTermLogger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# If a user configures a different model and it doesn't support all the same
|
||||
# parameters like frequency and presence, just ignore them
|
||||
litellm.drop_params = True
|
||||
litellm.telemetry = False
|
||||
if TYPE_CHECKING:
|
||||
from litellm import ModelResponse, CustomStreamWrapper, Message
|
||||
|
||||
if BRAINTRUST_ENABLED:
|
||||
litellm.callbacks = ["braintrust"]
|
||||
|
||||
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
|
||||
@@ -85,8 +80,10 @@ def _base_msg_to_role(msg: BaseMessage) -> str:
|
||||
|
||||
|
||||
def _convert_litellm_message_to_langchain_message(
|
||||
litellm_message: litellm.Message,
|
||||
litellm_message: "Message",
|
||||
) -> BaseMessage:
|
||||
from onyx.llm.litellm_singleton import litellm
|
||||
|
||||
# Extracting the basic attributes from the litellm message
|
||||
content = litellm_message.content or ""
|
||||
role = litellm_message.role
|
||||
@@ -176,15 +173,15 @@ def _convert_delta_to_message_chunk(
|
||||
curr_msg: BaseMessage | None,
|
||||
stop_reason: str | None = None,
|
||||
) -> BaseMessageChunk:
|
||||
from litellm.utils import ChatCompletionDeltaToolCall
|
||||
|
||||
"""Adapted from langchain_community.chat_models.litellm._convert_delta_to_message_chunk"""
|
||||
role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else "unknown")
|
||||
content = _dict.get("content") or ""
|
||||
additional_kwargs = {}
|
||||
if _dict.get("function_call"):
|
||||
additional_kwargs.update({"function_call": dict(_dict["function_call"])})
|
||||
tool_calls = cast(
|
||||
list[litellm.utils.ChatCompletionDeltaToolCall] | None, _dict.get("tool_calls")
|
||||
)
|
||||
tool_calls = cast(list[ChatCompletionDeltaToolCall] | None, _dict.get("tool_calls"))
|
||||
|
||||
if role == "user":
|
||||
return HumanMessageChunk(content=content)
|
||||
@@ -321,6 +318,8 @@ class DefaultMultiLLM(LLM):
|
||||
|
||||
self._max_token_param = LEGACY_MAX_TOKENS_KWARG
|
||||
try:
|
||||
from litellm.utils import get_supported_openai_params
|
||||
|
||||
params = get_supported_openai_params(model_name, model_provider)
|
||||
if STANDARD_MAX_TOKENS_KWARG in (params or []):
|
||||
self._max_token_param = STANDARD_MAX_TOKENS_KWARG
|
||||
@@ -388,11 +387,13 @@ class DefaultMultiLLM(LLM):
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
|
||||
) -> Union["ModelResponse", "CustomStreamWrapper"]:
|
||||
# litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
|
||||
# to a dict representation
|
||||
processed_prompt = _prompt_to_dict(prompt)
|
||||
self._record_call(processed_prompt)
|
||||
from onyx.llm.litellm_singleton import litellm
|
||||
from litellm.exceptions import Timeout, RateLimitError
|
||||
|
||||
try:
|
||||
return litellm.completion(
|
||||
@@ -456,12 +457,13 @@ class DefaultMultiLLM(LLM):
|
||||
**self._model_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
self._record_error(processed_prompt, e)
|
||||
# for break pointing
|
||||
if isinstance(e, litellm.Timeout):
|
||||
if isinstance(e, Timeout):
|
||||
raise LLMTimeoutError(e)
|
||||
|
||||
elif isinstance(e, litellm.RateLimitError):
|
||||
elif isinstance(e, RateLimitError):
|
||||
raise LLMRateLimitError(e)
|
||||
|
||||
raise e
|
||||
@@ -495,11 +497,13 @@ class DefaultMultiLLM(LLM):
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> BaseMessage:
|
||||
from litellm import ModelResponse
|
||||
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
|
||||
response = cast(
|
||||
litellm.ModelResponse,
|
||||
ModelResponse,
|
||||
self._completion(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
@@ -528,6 +532,8 @@ class DefaultMultiLLM(LLM):
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
from litellm import CustomStreamWrapper
|
||||
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
|
||||
@@ -544,7 +550,7 @@ class DefaultMultiLLM(LLM):
|
||||
|
||||
output = None
|
||||
response = cast(
|
||||
litellm.CustomStreamWrapper,
|
||||
CustomStreamWrapper,
|
||||
self._completion(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.llm import fetch_default_provider
|
||||
@@ -13,6 +10,8 @@ from onyx.db.models import Persona
|
||||
from onyx.llm.chat_llm import DefaultMultiLLM
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.llm_provider_options import OLLAMA_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.llm_provider_options import OLLAMA_PROVIDER_NAME
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.utils import get_max_input_tokens_from_llm_provider
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
@@ -24,13 +23,22 @@ from onyx.utils.long_term_log import LongTermLogger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
|
||||
"""Ollama requires us to specify the max context window.
|
||||
def _build_provider_extra_headers(
|
||||
provider: str, custom_config: dict[str, str] | None
|
||||
) -> dict[str, str]:
|
||||
if provider != OLLAMA_PROVIDER_NAME or not custom_config:
|
||||
return {}
|
||||
|
||||
For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value.
|
||||
TODO: allow model-specific values to be configured via the UI.
|
||||
"""
|
||||
return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}
|
||||
raw_api_key = custom_config.get(OLLAMA_API_KEY_CONFIG_KEY)
|
||||
|
||||
api_key = raw_api_key.strip() if raw_api_key else None
|
||||
if not api_key:
|
||||
return {}
|
||||
|
||||
if not api_key.lower().startswith("bearer "):
|
||||
api_key = f"Bearer {api_key}"
|
||||
|
||||
return {"Authorization": api_key}
|
||||
|
||||
|
||||
def get_main_llm_from_tuple(
|
||||
@@ -272,6 +280,16 @@ def get_llm(
|
||||
) -> LLM:
|
||||
if temperature is None:
|
||||
temperature = GEN_AI_TEMPERATURE
|
||||
|
||||
extra_headers = build_llm_extra_headers(additional_headers)
|
||||
|
||||
# NOTE: this is needed since Ollama API key is optional
|
||||
# User may access Ollama cloud via locally hosted instance (logged in)
|
||||
# or just via the cloud API (not logged in, using API key)
|
||||
provider_extra_headers = _build_provider_extra_headers(provider, custom_config)
|
||||
if provider_extra_headers:
|
||||
extra_headers.update(provider_extra_headers)
|
||||
|
||||
return DefaultMultiLLM(
|
||||
model_provider=provider,
|
||||
model_name=model,
|
||||
@@ -282,8 +300,8 @@ def get_llm(
|
||||
timeout=timeout,
|
||||
temperature=temperature,
|
||||
custom_config=custom_config,
|
||||
extra_headers=build_llm_extra_headers(additional_headers),
|
||||
model_kwargs=_build_extra_model_kwargs(provider),
|
||||
extra_headers=extra_headers,
|
||||
model_kwargs={},
|
||||
long_term_logger=long_term_logger,
|
||||
max_input_tokens=max_input_tokens,
|
||||
)
|
||||
|
||||
@@ -121,7 +121,6 @@ class LLM(abc.ABC):
|
||||
) -> BaseMessage:
|
||||
raise NotImplementedError
|
||||
|
||||
@traced(name="stream llm", type="llm")
|
||||
def stream(
|
||||
self,
|
||||
prompt: LanguageModelInput,
|
||||
|
||||
23
backend/onyx/llm/litellm_singleton.py
Normal file
23
backend/onyx/llm/litellm_singleton.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
Singleton module for litellm configuration.
|
||||
This ensures litellm is configured exactly once when first imported.
|
||||
All other modules should import litellm from here instead of directly.
|
||||
"""
|
||||
|
||||
import litellm
|
||||
|
||||
from onyx.configs.app_configs import BRAINTRUST_ENABLED
|
||||
|
||||
# Import litellm
|
||||
|
||||
# Configure litellm settings immediately on import
|
||||
# If a user configures a different model and it doesn't support all the same
|
||||
# parameters like frequency and presence, just ignore them
|
||||
litellm.drop_params = True
|
||||
litellm.telemetry = False
|
||||
|
||||
if BRAINTRUST_ENABLED:
|
||||
litellm.callbacks = ["braintrust"]
|
||||
|
||||
# Export the configured litellm module
|
||||
__all__ = ["litellm"]
|
||||
@@ -1,6 +1,5 @@
|
||||
from enum import Enum
|
||||
|
||||
import litellm # type: ignore
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.llm.chat_llm import VERTEX_CREDENTIALS_FILE_KWARG
|
||||
@@ -39,6 +38,7 @@ class WellKnownLLMProviderDescriptor(BaseModel):
|
||||
model_configurations: list[ModelConfigurationView]
|
||||
default_model: str | None = None
|
||||
default_fast_model: str | None = None
|
||||
default_api_base: str | None = None
|
||||
# set for providers like Azure, which require a deployment name.
|
||||
deployment_name_required: bool = False
|
||||
# set for providers like Azure, which support a single model per deployment.
|
||||
@@ -86,16 +86,23 @@ OPEN_AI_VISIBLE_MODEL_NAMES = [
|
||||
]
|
||||
|
||||
BEDROCK_PROVIDER_NAME = "bedrock"
|
||||
# need to remove all the weird "bedrock/eu-central-1/anthropic.claude-v1" named
|
||||
# models
|
||||
BEDROCK_MODEL_NAMES = [
|
||||
model
|
||||
BEDROCK_DEFAULT_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
|
||||
OLLAMA_PROVIDER_NAME = "ollama"
|
||||
OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"
|
||||
|
||||
|
||||
def get_bedrock_model_names() -> list[str]:
|
||||
import litellm
|
||||
|
||||
# bedrock_converse_models are just extensions of the bedrock_models, not sure why
|
||||
# litellm has split them into two lists :(
|
||||
for model in list(litellm.bedrock_models.union(litellm.bedrock_converse_models))
|
||||
if "/" not in model and "embed" not in model
|
||||
][::-1]
|
||||
BEDROCK_DEFAULT_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
return [
|
||||
model
|
||||
for model in list(litellm.bedrock_models.union(litellm.bedrock_converse_models))
|
||||
if "/" not in model and "embed" not in model
|
||||
][::-1]
|
||||
|
||||
|
||||
IGNORABLE_ANTHROPIC_MODELS = [
|
||||
"claude-2",
|
||||
@@ -103,11 +110,18 @@ IGNORABLE_ANTHROPIC_MODELS = [
|
||||
"anthropic/claude-3-5-sonnet-20241022",
|
||||
]
|
||||
ANTHROPIC_PROVIDER_NAME = "anthropic"
|
||||
ANTHROPIC_MODEL_NAMES = [
|
||||
model
|
||||
for model in litellm.anthropic_models
|
||||
if model not in IGNORABLE_ANTHROPIC_MODELS
|
||||
][::-1]
|
||||
|
||||
|
||||
def get_anthropic_model_names() -> list[str]:
|
||||
import litellm
|
||||
|
||||
return [
|
||||
model
|
||||
for model in litellm.anthropic_models
|
||||
if model not in IGNORABLE_ANTHROPIC_MODELS
|
||||
][::-1]
|
||||
|
||||
|
||||
ANTHROPIC_VISIBLE_MODEL_NAMES = [
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-sonnet-4-20250514",
|
||||
@@ -155,18 +169,23 @@ VERTEXAI_VISIBLE_MODEL_NAMES = [
|
||||
]
|
||||
|
||||
|
||||
_PROVIDER_TO_MODELS_MAP = {
|
||||
OPENAI_PROVIDER_NAME: OPEN_AI_MODEL_NAMES,
|
||||
BEDROCK_PROVIDER_NAME: BEDROCK_MODEL_NAMES,
|
||||
ANTHROPIC_PROVIDER_NAME: ANTHROPIC_MODEL_NAMES,
|
||||
VERTEXAI_PROVIDER_NAME: VERTEXAI_MODEL_NAMES,
|
||||
}
|
||||
def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
"""Lazy-load provider model mappings to avoid importing litellm at module level."""
|
||||
return {
|
||||
OPENAI_PROVIDER_NAME: OPEN_AI_MODEL_NAMES,
|
||||
BEDROCK_PROVIDER_NAME: get_bedrock_model_names(),
|
||||
ANTHROPIC_PROVIDER_NAME: get_anthropic_model_names(),
|
||||
VERTEXAI_PROVIDER_NAME: VERTEXAI_MODEL_NAMES,
|
||||
OLLAMA_PROVIDER_NAME: [],
|
||||
}
|
||||
|
||||
|
||||
_PROVIDER_TO_VISIBLE_MODELS_MAP = {
|
||||
OPENAI_PROVIDER_NAME: OPEN_AI_VISIBLE_MODEL_NAMES,
|
||||
BEDROCK_PROVIDER_NAME: [BEDROCK_DEFAULT_MODEL],
|
||||
BEDROCK_PROVIDER_NAME: [],
|
||||
ANTHROPIC_PROVIDER_NAME: ANTHROPIC_VISIBLE_MODEL_NAMES,
|
||||
VERTEXAI_PROVIDER_NAME: VERTEXAI_VISIBLE_MODEL_NAMES,
|
||||
OLLAMA_PROVIDER_NAME: [],
|
||||
}
|
||||
|
||||
|
||||
@@ -185,6 +204,28 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
|
||||
default_model="gpt-4o",
|
||||
default_fast_model="gpt-4o-mini",
|
||||
),
|
||||
WellKnownLLMProviderDescriptor(
|
||||
name=OLLAMA_PROVIDER_NAME,
|
||||
display_name="Ollama",
|
||||
api_key_required=False,
|
||||
api_base_required=True,
|
||||
api_version_required=False,
|
||||
custom_config_keys=[
|
||||
CustomConfigKey(
|
||||
name=OLLAMA_API_KEY_CONFIG_KEY,
|
||||
display_name="Ollama API Key",
|
||||
description="Optional API key used when connecting to Ollama Cloud (i.e. API base is https://ollama.com).",
|
||||
is_required=False,
|
||||
is_secret=True,
|
||||
)
|
||||
],
|
||||
model_configurations=fetch_model_configurations_for_provider(
|
||||
OLLAMA_PROVIDER_NAME
|
||||
),
|
||||
default_model=None,
|
||||
default_fast_model=None,
|
||||
default_api_base="http://127.0.0.1:11434",
|
||||
),
|
||||
WellKnownLLMProviderDescriptor(
|
||||
name=ANTHROPIC_PROVIDER_NAME,
|
||||
display_name="Anthropic",
|
||||
@@ -248,7 +289,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
|
||||
model_configurations=fetch_model_configurations_for_provider(
|
||||
BEDROCK_PROVIDER_NAME
|
||||
),
|
||||
default_model=BEDROCK_DEFAULT_MODEL,
|
||||
default_model=None,
|
||||
default_fast_model=None,
|
||||
),
|
||||
WellKnownLLMProviderDescriptor(
|
||||
@@ -287,7 +328,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
|
||||
|
||||
|
||||
def fetch_models_for_provider(provider_name: str) -> list[str]:
|
||||
return _PROVIDER_TO_MODELS_MAP.get(provider_name, [])
|
||||
return _get_provider_to_models_map().get(provider_name, [])
|
||||
|
||||
|
||||
def fetch_model_names_for_provider_as_set(provider_name: str) -> set[str] | None:
|
||||
|
||||
@@ -16,6 +16,7 @@ from langchain.schema.messages import AIMessage
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS
|
||||
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
|
||||
@@ -26,6 +27,9 @@ from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
|
||||
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import LLMProvider
|
||||
from onyx.db.models import ModelConfiguration
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -460,6 +464,7 @@ def get_llm_contextual_cost(
|
||||
this does not account for the cost of documents that fit within a single chunk
|
||||
which do not get contextualized.
|
||||
"""
|
||||
|
||||
import litellm
|
||||
|
||||
# calculate input costs
|
||||
@@ -639,6 +644,30 @@ def get_max_input_tokens_from_llm_provider(
|
||||
|
||||
|
||||
def model_supports_image_input(model_name: str, model_provider: str) -> bool:
|
||||
# TODO: Add support to check model config for any provider
|
||||
# TODO: Circular import means OLLAMA_PROVIDER_NAME is not available here
|
||||
|
||||
if model_provider == "ollama":
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
model_config = db_session.scalar(
|
||||
select(ModelConfiguration)
|
||||
.join(
|
||||
LLMProvider,
|
||||
ModelConfiguration.llm_provider_id == LLMProvider.id,
|
||||
)
|
||||
.where(
|
||||
ModelConfiguration.name == model_name,
|
||||
LLMProvider.provider == model_provider,
|
||||
)
|
||||
)
|
||||
if model_config and model_config.supports_image_input is not None:
|
||||
return model_config.supports_image_input
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to query database for {model_provider} model {model_name} image support: {e}"
|
||||
)
|
||||
|
||||
model_map = get_model_map()
|
||||
try:
|
||||
model_obj = find_model_obj(
|
||||
|
||||
@@ -25,7 +25,6 @@ from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS
|
||||
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
|
||||
from onyx.configs.app_configs import SKIP_WARM_UP
|
||||
from onyx.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
|
||||
from onyx.configs.model_configs import (
|
||||
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
|
||||
@@ -53,6 +52,7 @@ from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
|
||||
from shared_configs.configs import SKIP_WARM_UP
|
||||
from shared_configs.configs import VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
@@ -1115,6 +1115,9 @@ def warm_up_cross_encoder(
|
||||
rerank_model_name: str,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
if SKIP_WARM_UP:
|
||||
return
|
||||
|
||||
logger.debug(f"Warming up reranking model: {rerank_model_name}")
|
||||
|
||||
reranking_model = RerankingModel(
|
||||
|
||||
@@ -113,7 +113,7 @@ def _check_tokenizer_cache(
|
||||
logger.info(
|
||||
f"Falling back to default embedding model tokenizer: {DOCUMENT_ENCODER_MODEL}"
|
||||
)
|
||||
tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
|
||||
tokenizer = _get_default_tokenizer()
|
||||
|
||||
_TOKENIZER_CACHE[id_tuple] = tokenizer
|
||||
|
||||
@@ -150,7 +150,15 @@ def _try_initialize_tokenizer(
|
||||
return None
|
||||
|
||||
|
||||
_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
|
||||
_DEFAULT_TOKENIZER: BaseTokenizer | None = None
|
||||
|
||||
|
||||
def _get_default_tokenizer() -> BaseTokenizer:
|
||||
"""Lazy-load the default tokenizer to avoid loading it at module import time."""
|
||||
global _DEFAULT_TOKENIZER
|
||||
if _DEFAULT_TOKENIZER is None:
|
||||
_DEFAULT_TOKENIZER = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
|
||||
return _DEFAULT_TOKENIZER
|
||||
|
||||
|
||||
def get_tokenizer(
|
||||
@@ -163,7 +171,7 @@ def get_tokenizer(
|
||||
logger.debug(
|
||||
f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
|
||||
)
|
||||
return _DEFAULT_TOKENIZER
|
||||
return _get_default_tokenizer()
|
||||
return _check_tokenizer_cache(provider_type, model_name)
|
||||
|
||||
|
||||
|
||||
@@ -171,6 +171,14 @@ written as a list of one question.
|
||||
""",
|
||||
DRPath.CLOSER.value: f"""if the tool is {CLOSER}, the list of questions should simply be \
|
||||
['Answer the original question with the information you have.'].
|
||||
""",
|
||||
DRPath.IMAGE_GENERATION.value: """
|
||||
if the tool is Image Generation, respond with a list that contains exactly one JSON object
|
||||
string describing the tool call. The JSON must include a "prompt" field with the text to
|
||||
render. When the user specifies or implies an orientation, also include a "shape" field whose
|
||||
value is one of "square", "landscape", or "portrait" (use "landscape" for wide/horizontal
|
||||
requests and "portrait" for tall/vertical ones). Example: {"prompt": "Create a poster of a
|
||||
coral reef", "shape": "landscape"}. Do not surround the JSON with backticks or narration.
|
||||
""",
|
||||
}
|
||||
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
# NOTE No longer used. This needs to be revisited later.
|
||||
import re
|
||||
from collections.abc import Iterator
|
||||
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from onyx.llm.utils import message_generator_to_string_generator
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.prompts.constants import ANSWERABLE_PAT
|
||||
from onyx.prompts.constants import THOUGHT_PAT
|
||||
from onyx.prompts.query_validation import ANSWERABLE_PROMPT
|
||||
from onyx.server.query_and_chat.models import QueryValidationResponse
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_query_validation_messages(user_query: str) -> list[dict[str, str]]:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": ANSWERABLE_PROMPT.format(user_query=user_query),
|
||||
},
|
||||
]
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def extract_answerability_reasoning(model_raw: str) -> str:
|
||||
reasoning_match = re.search(
|
||||
f"{THOUGHT_PAT.upper()}(.*?){ANSWERABLE_PAT.upper()}", model_raw, re.DOTALL
|
||||
)
|
||||
reasoning_text = reasoning_match.group(1).strip() if reasoning_match else ""
|
||||
return reasoning_text
|
||||
|
||||
|
||||
def extract_answerability_bool(model_raw: str) -> bool:
|
||||
answerable_match = re.search(f"{ANSWERABLE_PAT.upper()}(.+)", model_raw)
|
||||
answerable_text = answerable_match.group(1).strip() if answerable_match else ""
|
||||
answerable = True if answerable_text.strip().lower() in ["true", "yes"] else False
|
||||
return answerable
|
||||
|
||||
|
||||
def get_query_answerability(
|
||||
user_query: str, skip_check: bool = False
|
||||
) -> tuple[str, bool]:
|
||||
if skip_check:
|
||||
return "Query Answerability Evaluation feature is turned off", True
|
||||
|
||||
try:
|
||||
llm, _ = get_default_llms()
|
||||
except GenAIDisabledException:
|
||||
return "Generative AI is turned off - skipping check", True
|
||||
|
||||
messages = get_query_validation_messages(user_query)
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
|
||||
model_output = message_to_string(llm.invoke(filled_llm_prompt))
|
||||
|
||||
reasoning = extract_answerability_reasoning(model_output)
|
||||
answerable = extract_answerability_bool(model_output)
|
||||
|
||||
return reasoning, answerable
|
||||
|
||||
|
||||
def stream_query_answerability(
|
||||
user_query: str, skip_check: bool = False
|
||||
) -> Iterator[str]:
|
||||
if skip_check:
|
||||
yield get_json_line(
|
||||
QueryValidationResponse(
|
||||
reasoning="Query Answerability Evaluation feature is turned off",
|
||||
answerable=True,
|
||||
).model_dump()
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
llm, _ = get_default_llms()
|
||||
except GenAIDisabledException:
|
||||
yield get_json_line(
|
||||
QueryValidationResponse(
|
||||
reasoning="Generative AI is turned off - skipping check",
|
||||
answerable=True,
|
||||
).model_dump()
|
||||
)
|
||||
return
|
||||
messages = get_query_validation_messages(user_query)
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
|
||||
try:
|
||||
tokens = message_generator_to_string_generator(llm.stream(filled_llm_prompt))
|
||||
reasoning_pat_found = False
|
||||
model_output = ""
|
||||
hold_answerable = ""
|
||||
for token in tokens:
|
||||
model_output = model_output + token
|
||||
|
||||
if ANSWERABLE_PAT.upper() in model_output:
|
||||
continue
|
||||
|
||||
if not reasoning_pat_found and THOUGHT_PAT.upper() in model_output:
|
||||
reasoning_pat_found = True
|
||||
reason_ind = model_output.find(THOUGHT_PAT.upper())
|
||||
remaining = model_output[reason_ind + len(THOUGHT_PAT.upper()) :]
|
||||
if remaining:
|
||||
yield get_json_line(
|
||||
OnyxAnswerPiece(answer_piece=remaining).model_dump()
|
||||
)
|
||||
continue
|
||||
|
||||
if reasoning_pat_found:
|
||||
hold_answerable = hold_answerable + token
|
||||
if hold_answerable == ANSWERABLE_PAT.upper()[: len(hold_answerable)]:
|
||||
continue
|
||||
yield get_json_line(
|
||||
OnyxAnswerPiece(answer_piece=hold_answerable).model_dump()
|
||||
)
|
||||
hold_answerable = ""
|
||||
|
||||
reasoning = extract_answerability_reasoning(model_output)
|
||||
answerable = extract_answerability_bool(model_output)
|
||||
|
||||
yield get_json_line(
|
||||
QueryValidationResponse(
|
||||
reasoning=reasoning, answerable=answerable
|
||||
).model_dump()
|
||||
)
|
||||
except Exception as e:
|
||||
# exception is logged in the answer_question method, no need to re-log
|
||||
error = StreamingError(error=str(e))
|
||||
yield get_json_line(error.model_dump())
|
||||
logger.exception("Failed to validate Query")
|
||||
return
|
||||
@@ -3,7 +3,6 @@ from typing import Any
|
||||
from typing import cast
|
||||
from typing import List
|
||||
|
||||
from litellm import get_supported_openai_params
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.chat_configs import NUM_PERSONA_PROMPT_GENERATION_CHUNKS
|
||||
@@ -123,6 +122,8 @@ def generate_starter_messages(
|
||||
"""
|
||||
_, fast_llm = get_default_llms(temperature=0.5)
|
||||
|
||||
from litellm.utils import get_supported_openai_params
|
||||
|
||||
provider = fast_llm.config.model_provider
|
||||
model = fast_llm.config.model_name
|
||||
|
||||
|
||||
@@ -184,7 +184,7 @@ def seed_initial_documents(
|
||||
"base_url": "https://docs.onyx.app/",
|
||||
"web_connector_type": "recursive",
|
||||
},
|
||||
refresh_freq=None, # Never refresh by default
|
||||
refresh_freq=3600, # 1 hour
|
||||
prune_freq=None,
|
||||
indexing_start=None,
|
||||
)
|
||||
|
||||
@@ -22,7 +22,7 @@ from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.subclasses import find_all_subclasses_in_dir
|
||||
from onyx.utils.subclasses import find_all_subclasses_in_package
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -44,7 +44,8 @@ def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]:
|
||||
if _OAUTH_CONNECTORS: # Return cached connectors if already discovered
|
||||
return _OAUTH_CONNECTORS
|
||||
|
||||
oauth_connectors = find_all_subclasses_in_dir(
|
||||
# Import submodules using package-based discovery to avoid sys.path mutations
|
||||
oauth_connectors = find_all_subclasses_in_package(
|
||||
cast(type[OAuthConnector], OAuthConnector), "onyx.connectors"
|
||||
)
|
||||
|
||||
|
||||
@@ -1218,7 +1218,10 @@ def _upsert_mcp_server(
|
||||
|
||||
logger.info(f"Created new MCP server '{request.name}' with ID {mcp_server.id}")
|
||||
|
||||
if not changing_connection_config:
|
||||
if (
|
||||
not changing_connection_config
|
||||
or request.auth_type == MCPAuthenticationType.NONE
|
||||
):
|
||||
return mcp_server
|
||||
|
||||
# Create connection configs
|
||||
|
||||
@@ -162,7 +162,7 @@ def unlink_user_file_from_project(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
priority=OnyxCeleryPriority.HIGHEST,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered project sync for user_file_id={user_file.id} with task_id={task.id}"
|
||||
@@ -210,7 +210,7 @@ def link_user_file_to_project(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
priority=OnyxCeleryPriority.HIGHEST,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered project sync for user_file_id={user_file.id} with task_id={task.id}"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import math
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
@@ -23,8 +24,8 @@ router = APIRouter(prefix="/gpts")
|
||||
|
||||
def time_ago(dt: datetime) -> str:
|
||||
# Calculate time difference
|
||||
now = datetime.now()
|
||||
diff = now - dt
|
||||
now = datetime.now(timezone.utc)
|
||||
diff = now - dt.astimezone(timezone.utc)
|
||||
|
||||
# Convert difference to minutes
|
||||
minutes = diff.total_seconds() / 60
|
||||
|
||||
@@ -4,6 +4,7 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import boto3
|
||||
import httpx
|
||||
from botocore.exceptions import BotoCoreError
|
||||
from botocore.exceptions import ClientError
|
||||
from botocore.exceptions import NoCredentialsError
|
||||
@@ -11,10 +12,12 @@ from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
@@ -27,8 +30,8 @@ from onyx.db.models import User
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.llm.factory import get_llm
|
||||
from onyx.llm.factory import get_max_input_tokens_from_llm_provider
|
||||
from onyx.llm.llm_provider_options import BEDROCK_MODEL_NAMES
|
||||
from onyx.llm.llm_provider_options import fetch_available_well_known_llms
|
||||
from onyx.llm.llm_provider_options import get_bedrock_model_names
|
||||
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
|
||||
from onyx.llm.utils import get_llm_contextual_cost
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
@@ -40,6 +43,9 @@ from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
from onyx.server.manage.llm.models import TestLLMRequest
|
||||
from onyx.server.manage.llm.models import VisionProviderResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -457,7 +463,7 @@ def get_bedrock_available_models(
|
||||
|
||||
# Keep only models we support (compatibility with litellm)
|
||||
filtered = sorted(
|
||||
[model for model in candidates if model in BEDROCK_MODEL_NAMES],
|
||||
[model for model in candidates if model in get_bedrock_model_names()],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
@@ -474,3 +480,100 @@ def get_bedrock_available_models(
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Unexpected error fetching Bedrock models: {e}"
|
||||
)
|
||||
|
||||
|
||||
def _get_ollama_available_model_names(api_base: str) -> set[str]:
|
||||
"""Fetch available model names from Ollama server."""
|
||||
tags_url = f"{api_base}/api/tags"
|
||||
try:
|
||||
response = httpx.get(tags_url, timeout=5.0)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to fetch Ollama models: {e}",
|
||||
)
|
||||
|
||||
models = response_json.get("models", [])
|
||||
return {model.get("name") for model in models if model.get("name")}
|
||||
|
||||
|
||||
@admin_router.post("/ollama/available-models")
|
||||
def get_ollama_available_models(
|
||||
request: OllamaModelsRequest,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> list[OllamaFinalModelResponse]:
|
||||
"""Fetch the list of available models from an Ollama server."""
|
||||
|
||||
cleaned_api_base = request.api_base.strip().rstrip("/")
|
||||
if not cleaned_api_base:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="API base URL is required to fetch Ollama models."
|
||||
)
|
||||
|
||||
model_names = _get_ollama_available_model_names(cleaned_api_base)
|
||||
if not model_names:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No models found from your Ollama server",
|
||||
)
|
||||
|
||||
all_models_with_context_size_and_vision: list[OllamaFinalModelResponse] = []
|
||||
show_url = f"{cleaned_api_base}/api/show"
|
||||
|
||||
for model_name in model_names:
|
||||
context_limit: int | None = None
|
||||
supports_image_input: bool | None = None
|
||||
try:
|
||||
show_response = httpx.post(
|
||||
show_url,
|
||||
json={"model": model_name},
|
||||
timeout=5.0,
|
||||
)
|
||||
show_response.raise_for_status()
|
||||
show_response_json = show_response.json()
|
||||
|
||||
# Parse the response into the expected format
|
||||
ollama_model_details = OllamaModelDetails.model_validate(show_response_json)
|
||||
|
||||
# Check if this model supports completion/chat
|
||||
if not ollama_model_details.supports_completion():
|
||||
continue
|
||||
|
||||
# Optimistically access. Context limit is stored as "model_architecture.context" = int
|
||||
architecture = ollama_model_details.model_info.get(
|
||||
"general.architecture", ""
|
||||
)
|
||||
context_limit = ollama_model_details.model_info.get(
|
||||
architecture + ".context_length", None
|
||||
)
|
||||
supports_image_input = ollama_model_details.supports_image_input()
|
||||
except ValidationError as e:
|
||||
logger.warning(
|
||||
"Invalid model details from Ollama server",
|
||||
extra={"model": model_name, "validation_error": str(e)},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to fetch Ollama model details",
|
||||
extra={"model": model_name, "error": str(e)},
|
||||
)
|
||||
|
||||
# If we fail at any point attempting to extract context limit,
|
||||
# still allow this model to be used with a fallback max context size
|
||||
if not context_limit:
|
||||
context_limit = GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
|
||||
if not supports_image_input:
|
||||
supports_image_input = False
|
||||
|
||||
all_models_with_context_size_and_vision.append(
|
||||
OllamaFinalModelResponse(
|
||||
name=model_name,
|
||||
max_input_tokens=context_limit,
|
||||
supports_image_input=supports_image_input,
|
||||
)
|
||||
)
|
||||
|
||||
return all_models_with_context_size_and_vision
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -138,8 +139,9 @@ class LLMProviderView(LLMProvider):
|
||||
|
||||
class ModelConfigurationUpsertRequest(BaseModel):
|
||||
name: str
|
||||
is_visible: bool | None = False
|
||||
is_visible: bool
|
||||
max_input_tokens: int | None = None
|
||||
supports_image_input: bool | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
@@ -149,12 +151,13 @@ class ModelConfigurationUpsertRequest(BaseModel):
|
||||
name=model_configuration_model.name,
|
||||
is_visible=model_configuration_model.is_visible,
|
||||
max_input_tokens=model_configuration_model.max_input_tokens,
|
||||
supports_image_input=model_configuration_model.supports_image_input,
|
||||
)
|
||||
|
||||
|
||||
class ModelConfigurationView(BaseModel):
|
||||
name: str
|
||||
is_visible: bool | None = False
|
||||
is_visible: bool
|
||||
max_input_tokens: int | None = None
|
||||
supports_image_input: bool
|
||||
|
||||
@@ -196,3 +199,28 @@ class BedrockModelsRequest(BaseModel):
|
||||
aws_secret_access_key: str | None = None
|
||||
aws_bearer_token_bedrock: str | None = None
|
||||
provider_name: str | None = None # Optional: to save models to existing provider
|
||||
|
||||
|
||||
class OllamaModelsRequest(BaseModel):
|
||||
api_base: str
|
||||
|
||||
|
||||
class OllamaFinalModelResponse(BaseModel):
|
||||
name: str
|
||||
max_input_tokens: int
|
||||
supports_image_input: bool
|
||||
|
||||
|
||||
class OllamaModelDetails(BaseModel):
|
||||
"""Response model for Ollama /api/show endpoint"""
|
||||
|
||||
model_info: dict[str, Any]
|
||||
capabilities: list[str] = []
|
||||
|
||||
def supports_completion(self) -> bool:
|
||||
"""Check if this model supports completion/chat"""
|
||||
return "completion" in self.capabilities
|
||||
|
||||
def supports_image_input(self) -> bool:
|
||||
"""Check if this model supports image input"""
|
||||
return "vision" in self.capabilities
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import csv
|
||||
import io
|
||||
import re
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
@@ -14,6 +16,7 @@ from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -297,6 +300,43 @@ def list_all_users(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/manage/users/download")
|
||||
def download_users_csv(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse:
|
||||
"""Download all users as a CSV file."""
|
||||
# Get all users from the database
|
||||
users = get_all_users(db_session)
|
||||
|
||||
# Create CSV content in memory
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
|
||||
# Write CSV header
|
||||
writer.writerow(["Email", "Role", "Status"])
|
||||
|
||||
# Write user data
|
||||
for user in users:
|
||||
writer.writerow(
|
||||
[
|
||||
user.email,
|
||||
user.role.value if user.role else "",
|
||||
"Active" if user.is_active else "Inactive",
|
||||
]
|
||||
)
|
||||
|
||||
# Prepare the CSV content for download
|
||||
csv_content = output.getvalue()
|
||||
output.close()
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(csv_content.encode("utf-8")),
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": "attachment;"},
|
||||
)
|
||||
|
||||
|
||||
@router.put("/manage/admin/users")
|
||||
def bulk_invite_users(
|
||||
emails: list[str] = Body(..., embed=True),
|
||||
|
||||
@@ -56,8 +56,8 @@ def validate_user_token(user_token: str | None) -> None:
|
||||
Raises:
|
||||
HTTPException: If the token is invalid or missing required fields
|
||||
"""
|
||||
if user_token is None:
|
||||
# user_token is optional, so None is valid
|
||||
if not user_token:
|
||||
# user_token is optional, so None or empty string is valid
|
||||
return
|
||||
|
||||
if not user_token.startswith(SLACK_USER_TOKEN_PREFIX):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user