Compare commits

...

2 Commits

Author SHA1 Message Date
Wenxi
b36910240d chore: Hotfix v2.0.0-beta.2 (#5658)
Co-authored-by: Chris Weaver <chris@onyx.app>
Co-authored-by: Evan Lohn <evan@danswer.ai>
Co-authored-by: edwin-onyx <edwin@onyx.app>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: SubashMohan <subashmohan75@gmail.com>
Co-authored-by: trial-danswer <trial@danswer.ai>
Co-authored-by: Richard Guan <41275416+rguan72@users.noreply.github.com>
Co-authored-by: Justin Tahara <105671973+justin-tahara@users.noreply.github.com>
Co-authored-by: Nils <94993442+nsklei@users.noreply.github.com>
Co-authored-by: nsklei <nils.kleinrahm@pledoc.de>
Co-authored-by: Shahar Mazor <103638798+Django149@users.noreply.github.com>
Co-authored-by: Raunak Bhagat <r@rabh.io>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Paulius Klyvis <grafke@users.noreply.github.com>
2025-10-07 18:30:48 -07:00
Wenxi
488b27ba04 chore: hotfix v2.0.0 beta.1 (#5616)
Co-authored-by: Chris Weaver <chris@onyx.app>
Co-authored-by: Evan Lohn <evan@danswer.ai>
Co-authored-by: edwin-onyx <edwin@onyx.app>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: SubashMohan <subashmohan75@gmail.com>
Co-authored-by: Richard Guan <41275416+rguan72@users.noreply.github.com>
Co-authored-by: Justin Tahara <105671973+justin-tahara@users.noreply.github.com>
2025-10-07 17:08:17 -07:00
503 changed files with 15073 additions and 17650 deletions

View File

@@ -8,9 +8,9 @@ on:
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
# don't tag cloud images with "latest"
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
# tag nightly builds with "edge"
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
jobs:
build-and-push:
@@ -33,7 +33,16 @@ jobs:
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Check if stable release version
id: check_version
run: |
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
echo "is_stable=true" >> $GITHUB_OUTPUT
else
echo "is_stable=false" >> $GITHUB_OUTPUT
fi
- name: Checkout code
uses: actions/checkout@v4
@@ -46,7 +55,8 @@ jobs:
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -119,7 +129,8 @@ jobs:
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
- name: Login to Docker Hub
uses: docker/login-action@v3

View File

@@ -11,8 +11,8 @@ env:
BUILDKIT_PROGRESS: plain
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
# don't tag cloud images with "latest"
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
# tag nightly builds with "edge"
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
jobs:
@@ -145,6 +145,15 @@ jobs:
if: needs.check_model_server_changes.outputs.changed == 'true'
runs-on: ubuntu-latest
steps:
- name: Check if stable release version
id: check_version
run: |
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
echo "is_stable=true" >> $GITHUB_OUTPUT
else
echo "is_stable=false" >> $GITHUB_OUTPUT
fi
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
@@ -157,11 +166,16 @@ jobs:
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
if [[ "${{ env.LATEST_TAG }}" == "true" ]]; then
if [[ "${{ steps.check_version.outputs.is_stable }}" == "true" ]]; then
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:latest \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
fi
if [[ "${{ env.EDGE_TAG }}" == "true" ]]; then
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:edge \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
fi
- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@v3

View File

@@ -7,7 +7,10 @@ on:
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
# tag nightly builds with "edge"
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
DEPLOYMENT: standalone
jobs:
@@ -45,6 +48,15 @@ jobs:
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Check if stable release version
id: check_version
run: |
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "is_stable=true" >> $GITHUB_OUTPUT
else
echo "is_stable=false" >> $GITHUB_OUTPUT
fi
- name: Checkout
uses: actions/checkout@v4
@@ -57,7 +69,8 @@ jobs:
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -126,7 +139,8 @@ jobs:
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
- name: Login to Docker Hub
uses: docker/login-action@v3

View File

@@ -25,9 +25,11 @@ jobs:
- name: Add required Helm repositories
run: |
helm repo add bitnami https://charts.bitnami.com/bitnami
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add keda https://kedacore.github.io/charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo update
- name: Build chart dependencies

View File

@@ -20,6 +20,7 @@ env:
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
# LLMs
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

View File

@@ -65,35 +65,45 @@ jobs:
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Adding Helm repositories ==="
helm repo add bitnami https://charts.bitnami.com/bitnami
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo update
- name: Pre-pull critical images
- name: Install Redis operator
if: steps.list-changed.outputs.changed == 'true'
shell: bash
run: |
echo "=== Installing redis-operator CRDs ==="
helm upgrade --install redis-operator ot-container-kit/redis-operator \
--namespace redis-operator --create-namespace --wait --timeout 300s
- name: Pre-pull required images
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Pre-pulling critical images to avoid timeout ==="
# Get kind cluster name
echo "=== Pre-pulling required images to avoid timeout ==="
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
echo "Kind cluster: $KIND_CLUSTER"
# Pre-pull images that are likely to be used
echo "Pre-pulling PostgreSQL image..."
docker pull postgres:15-alpine || echo "Failed to pull postgres:15-alpine"
kind load docker-image postgres:15-alpine --name $KIND_CLUSTER || echo "Failed to load postgres image"
echo "Pre-pulling Redis image..."
docker pull redis:7-alpine || echo "Failed to pull redis:7-alpine"
kind load docker-image redis:7-alpine --name $KIND_CLUSTER || echo "Failed to load redis image"
echo "Pre-pulling Onyx images..."
docker pull docker.io/onyxdotapp/onyx-web-server:latest || echo "Failed to pull onyx web server"
docker pull docker.io/onyxdotapp/onyx-backend:latest || echo "Failed to pull onyx backend"
kind load docker-image docker.io/onyxdotapp/onyx-web-server:latest --name $KIND_CLUSTER || echo "Failed to load onyx web server"
kind load docker-image docker.io/onyxdotapp/onyx-backend:latest --name $KIND_CLUSTER || echo "Failed to load onyx backend"
IMAGES=(
"ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0"
"quay.io/opstree/redis:v7.0.15"
"docker.io/onyxdotapp/onyx-web-server:latest"
)
for image in "${IMAGES[@]}"; do
echo "Pre-pulling $image"
if docker pull "$image"; then
kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind"
else
echo "Failed to pull $image"
fi
done
echo "=== Images loaded into Kind cluster ==="
docker exec $KIND_CLUSTER-control-plane crictl images | grep -E "(postgres|redis|onyx)" || echo "Some images may still be loading..."
docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..."
- name: Validate chart dependencies
if: steps.list-changed.outputs.changed == 'true'
@@ -149,6 +159,7 @@ jobs:
# Run the actual installation with detailed logging
echo "=== Starting ct install ==="
set +e
ct install --all \
--helm-extra-set-args="\
--set=nginx.enabled=false \
@@ -156,8 +167,10 @@ jobs:
--set=vespa.enabled=false \
--set=slackbot.enabled=false \
--set=postgresql.enabled=true \
--set=postgresql.primary.persistence.enabled=false \
--set=postgresql.nameOverride=cloudnative-pg \
--set=postgresql.cluster.storage.storageClass=standard \
--set=redis.enabled=true \
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
--set=webserver.replicaCount=1 \
--set=api.replicaCount=0 \
--set=inferenceCapability.replicaCount=0 \
@@ -173,8 +186,16 @@ jobs:
--set=celery_worker_user_files_indexing.replicaCount=0" \
--helm-extra-args="--timeout 900s --debug" \
--debug --config ct.yaml
echo "=== Installation completed successfully ==="
CT_EXIT=$?
set -e
if [[ $CT_EXIT -ne 0 ]]; then
echo "ct install failed with exit code $CT_EXIT"
exit $CT_EXIT
else
echo "=== Installation completed successfully ==="
fi
kubectl get pods --all-namespaces
- name: Post-install verification
@@ -199,7 +220,7 @@ jobs:
echo "=== Recent logs for debugging ==="
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
echo "=== Helm releases ==="
helm list --all-namespaces
# the following would install only changed charts, but we only have one chart so

View File

@@ -22,9 +22,11 @@ env:
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
@@ -131,6 +133,7 @@ jobs:
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
push: true
outputs: type=registry
no-cache: true
build-model-server-image:
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
@@ -158,6 +161,7 @@ jobs:
push: true
outputs: type=registry
provenance: false
no-cache: true
build-integration-image:
needs: prepare-build
@@ -191,6 +195,7 @@ jobs:
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
push: true
outputs: type=registry
no-cache: true
integration-tests:
needs:
@@ -337,9 +342,11 @@ jobs:
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \

View File

@@ -19,9 +19,11 @@ env:
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
@@ -128,6 +130,7 @@ jobs:
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
push: true
outputs: type=registry
no-cache: true
build-model-server-image:
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
@@ -155,6 +158,7 @@ jobs:
push: true
outputs: type=registry
provenance: false
no-cache: true
build-integration-image:
needs: prepare-build
@@ -188,6 +192,7 @@ jobs:
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
push: true
outputs: type=registry
no-cache: true
integration-tests-mit:
needs:
@@ -334,9 +339,11 @@ jobs:
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \

View File

@@ -20,11 +20,13 @@ env:
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
# Jira
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
# Gong
GONG_ACCESS_KEY: ${{ secrets.GONG_ACCESS_KEY }}

View File

@@ -34,8 +34,6 @@ repos:
hooks:
- id: prettier
types_or: [html, css, javascript, ts, tsx]
additional_dependencies:
- prettier
- repo: local
hooks:

View File

@@ -13,8 +13,7 @@ As an open source project in a rapidly changing space, we welcome all contributi
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.
To ensure that your contribution is aligned with the project's direction, please reach out to any maintainer on the Onyx team
via [Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA) /
[Discord](https://discord.gg/TDJ59cGV2X) or [email](mailto:founders@onyx.app).
via [Discord](https://discord.gg/4NA5SbzrWb) or [email](mailto:hello@onyx.app).
Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
will be marked with the `approved by maintainers` label.
@@ -28,8 +27,7 @@ Your input is vital to making sure that Onyx moves in the right direction.
Before starting on implementation, please raise a GitHub issue.
Also, always feel free to message the founders (Chris Weaver / Yuhong Sun) on
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA) /
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
[Discord](https://discord.gg/4NA5SbzrWb) directly about anything at all.
### Contributing Code
@@ -46,9 +44,7 @@ Our goal is to make contributing as easy as possible. If you run into any issues
That way we can help future contributors and users can avoid the same issue.
We also have support channels and generally interesting discussions on our
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA)
and
[Discord](https://discord.gg/TDJ59cGV2X).
[Discord](https://discord.gg/4NA5SbzrWb).
We would love to see you there!
@@ -105,6 +101,11 @@ pip install -r backend/requirements/ee.txt
pip install -r backend/requirements/model_server.txt
```
Fix vscode/cursor auto-imports:
```bash
pip install -e .
```
Install Playwright for Python (headless browser required by the Web Connector)
In the activated Python virtualenv, install Playwright for Python by running:
@@ -117,8 +118,15 @@ You may have to deactivate and reactivate your virtualenv for `playwright` to ap
#### Frontend: Node dependencies
Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
Once the above is done, navigate to `onyx/web` run:
Onyx uses Node v22.20.0. We highly recommend you use [Node Version Manager (nvm)](https://github.com/nvm-sh/nvm)
to manage your Node installations. Once installed, you can run
```bash
nvm install 22 && nvm use 22`
node -v # verify your active version
```
Navigate to `onyx/web` and run:
```bash
npm i
@@ -129,8 +137,6 @@ npm i
### Backend
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
First, install pre-commit (if you don't have it already) following the instructions
[here](https://pre-commit.com/#installation).
With the virtual environment active, install the pre-commit library with:
@@ -150,15 +156,17 @@ To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend`
### Web
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `onyx/web` directory.
We use `prettier` for formatting. The desired version will be installed via a `npm i` from the `onyx/web` directory.
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
Please double check that prettier passes before creating a pull request.
Pre-commit will also run prettier automatically on files you've recently touched. If re-formatted, your commit will fail.
Re-stage your changes and commit again.
# Running the application for development
## Developing using VSCode Debugger (recommended)
We highly recommend using VSCode debugger for development.
**We highly recommend using VSCode debugger for development.**
See [CONTRIBUTING_VSCODE.md](./CONTRIBUTING_VSCODE.md) for more details.
Otherwise, you can follow the instructions below to run the application for development.

View File

@@ -21,6 +21,9 @@ Before starting, make sure the Docker Daemon is running.
5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
6. Use the debug toolbar to step through code, inspect variables, etc.
Note: Clear and Restart External Volumes and Containers will reset your postgres and Vespa (relational-db and index).
Only run this if you are okay with wiping your data.
## Features
- Hot reload is enabled for the web server and API servers

View File

@@ -0,0 +1,37 @@
"""Add image input support to model config
Revision ID: 64bd5677aeb6
Revises: b30353be4eec
Create Date: 2025-09-28 15:48:12.003612
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "64bd5677aeb6"
down_revision = "b30353be4eec"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"model_configuration",
sa.Column("supports_image_input", sa.Boolean(), nullable=True),
)
# Seems to be left over from when model visibility was introduced and a nullable field.
# Set any null is_visible values to False
connection = op.get_bind()
connection.execute(
sa.text(
"UPDATE model_configuration SET is_visible = false WHERE is_visible IS NULL"
)
)
def downgrade() -> None:
op.drop_column("model_configuration", "supports_image_input")

View File

@@ -124,9 +124,9 @@ def get_space_permission(
and not space_permissions.external_user_group_ids
):
logger.warning(
f"No permissions found for space '{space_key}'. This is very unlikely"
"to be correct and is more likely caused by an access token with"
"insufficient permissions. Make sure that the access token has Admin"
f"No permissions found for space '{space_key}'. This is very unlikely "
"to be correct and is more likely caused by an access token with "
"insufficient permissions. Make sure that the access token has Admin "
f"permissions for space '{space_key}'"
)

View File

@@ -26,7 +26,7 @@ def _get_slim_doc_generator(
else 0.0
)
return gmail_connector.retrieve_all_slim_documents(
return gmail_connector.retrieve_all_slim_docs_perm_sync(
start=start_time,
end=current_time.timestamp(),
callback=callback,

View File

@@ -34,7 +34,7 @@ def _get_slim_doc_generator(
else 0.0
)
return google_drive_connector.retrieve_all_slim_documents(
return google_drive_connector.retrieve_all_slim_docs_perm_sync(
start=start_time,
end=current_time.timestamp(),
callback=callback,

View File

@@ -59,7 +59,7 @@ def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
for raw_perm in permissions:
if not hasattr(raw_perm, "raw"):
logger.warn(f"Expected a 'raw' field, but none was found: {raw_perm=}")
logger.warning(f"Expected a 'raw' field, but none was found: {raw_perm=}")
continue
permission = Permission(**raw_perm.raw)
@@ -71,14 +71,14 @@ def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
# In order to associate this permission to some Atlassian entity, we need the "Holder".
# If this doesn't exist, then we cannot associate this permission to anyone; just skip.
if not permission.holder:
logger.warn(
logger.warning(
f"Expected to find a permission holder, but none was found: {permission=}"
)
continue
type = permission.holder.get("type")
if not type:
logger.warn(
logger.warning(
f"Expected to find the type of permission holder, but none was found: {permission=}"
)
continue

View File

@@ -105,7 +105,9 @@ def _get_slack_document_access(
channel_permissions: dict[str, ExternalAccess],
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
callback=callback
)
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:

View File

@@ -4,7 +4,7 @@ from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFun
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -17,7 +17,7 @@ def generic_doc_sync(
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
doc_source: DocumentSource,
slim_connector: SlimConnector,
slim_connector: SlimConnectorWithPermSync,
label: str,
) -> Generator[DocExternalAccess, None, None]:
"""
@@ -40,7 +40,7 @@ def generic_doc_sync(
newly_fetched_doc_ids: set[str] = set()
logger.info(f"Fetching all slim documents from {doc_source}")
for doc_batch in slim_connector.retrieve_all_slim_documents(callback=callback):
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
if callback:

View File

@@ -37,6 +37,19 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/auth/saml")
# Azure AD / Entra ID often returns the email attribute under different keys.
# Keep a list of common variations so we can fall back gracefully if the IdP
# does not send the plain "email" attribute name.
EMAIL_ATTRIBUTE_KEYS = {
"email",
"emailaddress",
"mail",
"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress",
"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/mail",
"http://schemas.microsoft.com/identity/claims/emailaddress",
}
EMAIL_ATTRIBUTE_KEYS_LOWER = {key.lower() for key in EMAIL_ATTRIBUTE_KEYS}
async def upsert_saml_user(email: str) -> User:
"""
@@ -204,16 +217,37 @@ async def _process_saml_callback(
detail=detail,
)
user_email = auth.get_attribute("email")
if not user_email:
detail = "SAML is not set up correctly, email attribute must be provided."
logger.error(detail)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=detail,
)
user_email: str | None = None
user_email = user_email[0]
# The OneLogin toolkit normalizes attribute keys, but still performs a
# case-sensitive lookup. Try the common keys first and then fall back to a
# case-insensitive scan of all returned attributes.
for attribute_key in EMAIL_ATTRIBUTE_KEYS:
attribute_values = auth.get_attribute(attribute_key)
if attribute_values:
user_email = attribute_values[0]
break
if not user_email:
# Fallback: perform a case-insensitive lookup across all attributes in
# case the IdP sent the email claim with a different capitalization.
attributes = auth.get_attributes()
for key, values in attributes.items():
if key.lower() in EMAIL_ATTRIBUTE_KEYS_LOWER:
if values:
user_email = values[0]
break
if not user_email:
detail = "SAML is not set up correctly, email attribute must be provided."
logger.error(detail)
logger.debug(
"Received SAML attributes without email: %s",
list(attributes.keys()),
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=detail,
)
user = await upsert_saml_user(email=user_email)

View File

@@ -37,9 +37,9 @@ from onyx.db.models import AvailableTenant
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.db.models import UserTenantMapping
from onyx.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from onyx.llm.llm_provider_options import ANTHROPIC_VISIBLE_MODEL_NAMES
from onyx.llm.llm_provider_options import get_anthropic_model_names
from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
from onyx.llm.llm_provider_options import OPEN_AI_VISIBLE_MODEL_NAMES
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
@@ -278,7 +278,7 @@ def configure_default_api_keys(db_session: Session) -> None:
is_visible=name in ANTHROPIC_VISIBLE_MODEL_NAMES,
max_input_tokens=None,
)
for name in ANTHROPIC_MODEL_NAMES
for name in get_anthropic_model_names()
],
api_key_changed=True,
)

View File

@@ -1,14 +1,12 @@
from typing import cast
from typing import Optional
from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import APIRouter
from huggingface_hub import snapshot_download # type: ignore
from setfit import SetFitModel # type: ignore[import]
from transformers import AutoTokenizer # type: ignore
from transformers import BatchEncoding # type: ignore
from transformers import PreTrainedTokenizer # type: ignore
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
from model_server.constants import MODEL_WARM_UP_STRING
@@ -37,23 +35,30 @@ from shared_configs.model_server_models import ContentClassificationPrediction
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
if TYPE_CHECKING:
from setfit import SetFitModel # type: ignore
from transformers import PreTrainedTokenizer, BatchEncoding # type: ignore
logger = setup_logger()
router = APIRouter(prefix="/custom")
_CONNECTOR_CLASSIFIER_TOKENIZER: PreTrainedTokenizer | None = None
_CONNECTOR_CLASSIFIER_TOKENIZER: Optional["PreTrainedTokenizer"] = None
_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
_INTENT_TOKENIZER: PreTrainedTokenizer | None = None
_INTENT_TOKENIZER: Optional["PreTrainedTokenizer"] = None
_INTENT_MODEL: HybridClassifier | None = None
_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
_INFORMATION_CONTENT_MODEL: Optional["SetFitModel"] = None
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
def get_connector_classifier_tokenizer() -> PreTrainedTokenizer:
def get_connector_classifier_tokenizer() -> "PreTrainedTokenizer":
global _CONNECTOR_CLASSIFIER_TOKENIZER
from transformers import AutoTokenizer, PreTrainedTokenizer
if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
# The tokenizer details are not uploaded to the HF hub since it's just the
# unmodified distilbert tokenizer.
@@ -95,7 +100,9 @@ def get_local_connector_classifier(
return _CONNECTOR_CLASSIFIER_MODEL
def get_intent_model_tokenizer() -> PreTrainedTokenizer:
def get_intent_model_tokenizer() -> "PreTrainedTokenizer":
from transformers import AutoTokenizer, PreTrainedTokenizer
global _INTENT_TOKENIZER
if _INTENT_TOKENIZER is None:
# The tokenizer details are not uploaded to the HF hub since it's just the
@@ -141,7 +148,9 @@ def get_local_intent_model(
def get_local_information_content_model(
model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
) -> SetFitModel:
) -> "SetFitModel":
from setfit import SetFitModel
global _INFORMATION_CONTENT_MODEL
if _INFORMATION_CONTENT_MODEL is None:
try:
@@ -179,7 +188,7 @@ def get_local_information_content_model(
def tokenize_connector_classification_query(
connectors: list[str],
query: str,
tokenizer: PreTrainedTokenizer,
tokenizer: "PreTrainedTokenizer",
connector_token_end_id: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -267,7 +276,7 @@ def warm_up_information_content_model() -> None:
@simple_log_function_time()
def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
def run_inference(tokens: "BatchEncoding") -> tuple[list[float], list[float]]:
intent_model = get_local_intent_model()
device = intent_model.device
@@ -401,7 +410,7 @@ def run_content_classification_inference(
def map_keywords(
input_ids: torch.Tensor, tokenizer: PreTrainedTokenizer, is_keyword: list[bool]
input_ids: torch.Tensor, tokenizer: "PreTrainedTokenizer", is_keyword: list[bool]
) -> list[str]:
tokens = tokenizer.convert_ids_to_tokens(input_ids) # type: ignore

View File

@@ -2,13 +2,11 @@ import asyncio
import time
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from fastapi import APIRouter
from fastapi import HTTPException
from fastapi import Request
from litellm.exceptions import RateLimitError
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
@@ -20,6 +18,9 @@ from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
if TYPE_CHECKING:
from sentence_transformers import CrossEncoder, SentenceTransformer
logger = setup_logger()
router = APIRouter(prefix="/encoder")
@@ -88,8 +89,10 @@ def get_embedding_model(
def get_local_reranking_model(
model_name: str,
) -> CrossEncoder:
) -> "CrossEncoder":
global _RERANK_MODEL
from sentence_transformers import CrossEncoder # type: ignore
if _RERANK_MODEL is None:
logger.notice(f"Loading {model_name}")
model = CrossEncoder(model_name)
@@ -207,6 +210,8 @@ async def route_bi_encoder_embed(
async def process_embed_request(
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
) -> EmbedResponse:
from litellm.exceptions import RateLimitError
# Only local models should use this endpoint - API providers should make direct API calls
if embed_request.provider_type is not None:
raise ValueError(

View File

@@ -30,6 +30,7 @@ from shared_configs.configs import MIN_THREADS_ML_MODELS
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SENTRY_DSN
from shared_configs.configs import SKIP_WARM_UP
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
@@ -91,16 +92,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
if not INDEXING_ONLY:
logger.notice(
"The intent model should run on the model server. The information content model should not run here."
)
warm_up_intent_model()
if not SKIP_WARM_UP:
if not INDEXING_ONLY:
logger.notice("Warming up intent model for inference model server")
warm_up_intent_model()
else:
logger.notice(
"Warming up content information model for indexing model server"
)
warm_up_information_content_model()
else:
logger.notice(
"The content information model should run on the indexing model server. The intent model should not run here."
)
warm_up_information_content_model()
logger.notice("Skipping model warmup due to SKIP_WARM_UP=true")
yield

View File

@@ -1,16 +1,20 @@
import json
import os
from typing import cast
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from transformers import DistilBertConfig # type: ignore
from transformers import DistilBertModel # type: ignore
from transformers import DistilBertTokenizer # type: ignore
if TYPE_CHECKING:
from transformers import DistilBertConfig # type: ignore
class HybridClassifier(nn.Module):
def __init__(self) -> None:
from transformers import DistilBertConfig, DistilBertModel
super().__init__()
config = DistilBertConfig()
self.distilbert = DistilBertModel(config)
@@ -74,7 +78,9 @@ class HybridClassifier(nn.Module):
class ConnectorClassifier(nn.Module):
def __init__(self, config: DistilBertConfig) -> None:
def __init__(self, config: "DistilBertConfig") -> None:
from transformers import DistilBertTokenizer, DistilBertModel
super().__init__()
self.config = config
@@ -115,6 +121,8 @@ class ConnectorClassifier(nn.Module):
@classmethod
def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
from transformers import DistilBertConfig
config = cast(
DistilBertConfig,
DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),

View File

@@ -3,6 +3,7 @@ from datetime import datetime
from typing import Any
from typing import cast
from braintrust import traced
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
@@ -22,6 +23,9 @@ from onyx.agents.agent_search.dr.models import DecisionResponse
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.agents.agent_search.dr.process_llm_stream import (
BasicSearchProcessedStreamResults,
)
from onyx.agents.agent_search.dr.process_llm_stream import process_llm_stream
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationSetup
@@ -666,28 +670,30 @@ def clarifier(
system_prompt_to_use = assistant_system_prompt
user_prompt_to_use = decision_prompt + assistant_task_prompt
stream = graph_config.tooling.primary_llm.stream(
prompt=create_question_prompt(
cast(str, system_prompt_to_use),
cast(str, user_prompt_to_use),
uploaded_image_context=uploaded_image_context,
),
tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
tool_choice=(None),
structured_response_format=graph_config.inputs.structured_response_format,
)
full_response = process_llm_stream(
messages=stream,
should_stream_answer=True,
writer=writer,
ind=0,
final_search_results=context_llm_docs,
displayed_search_results=context_llm_docs,
generate_final_answer=True,
chat_message_id=str(graph_config.persistence.chat_session_id),
)
@traced(name="clarifier stream and process", type="llm")
def stream_and_process() -> BasicSearchProcessedStreamResults:
stream = graph_config.tooling.primary_llm.stream(
prompt=create_question_prompt(
cast(str, system_prompt_to_use),
cast(str, user_prompt_to_use),
uploaded_image_context=uploaded_image_context,
),
tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
tool_choice=(None),
structured_response_format=graph_config.inputs.structured_response_format,
)
return process_llm_stream(
messages=stream,
should_stream_answer=True,
writer=writer,
ind=0,
final_search_results=context_llm_docs,
displayed_search_results=context_llm_docs,
generate_final_answer=True,
chat_message_id=str(graph_config.persistence.chat_session_id),
)
full_response = stream_and_process()
if len(full_response.ai_message_chunk.tool_calls) == 0:
if isinstance(full_response.full_answer, str):

View File

@@ -1,3 +1,4 @@
import json
from datetime import datetime
from typing import cast
@@ -28,6 +29,7 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -62,6 +64,29 @@ def image_generation(
image_tool_info = state.available_tools[state.tools_used[-1]]
image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)
image_prompt = branch_query
requested_shape: ImageShape | None = None
try:
parsed_query = json.loads(branch_query)
except json.JSONDecodeError:
parsed_query = None
if isinstance(parsed_query, dict):
prompt_from_llm = parsed_query.get("prompt")
if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
image_prompt = prompt_from_llm.strip()
raw_shape = parsed_query.get("shape")
if isinstance(raw_shape, str):
try:
requested_shape = ImageShape(raw_shape)
except ValueError:
logger.warning(
"Received unsupported image shape '%s' from LLM. Falling back to square.",
raw_shape,
)
logger.debug(
f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
@@ -69,7 +94,15 @@ def image_generation(
# Generate images using the image generation tool
image_generation_responses: list[ImageGenerationResponse] = []
for tool_response in image_tool.run(prompt=branch_query):
if requested_shape is not None:
tool_iterator = image_tool.run(
prompt=image_prompt,
shape=requested_shape.value,
)
else:
tool_iterator = image_tool.run(prompt=image_prompt)
for tool_response in tool_iterator:
if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
# Stream heartbeat to frontend
write_custom_event(
@@ -95,6 +128,7 @@ def image_generation(
file_id=file_id,
url=build_frontend_file_url(file_id),
revised_prompt=img.revised_prompt,
shape=(requested_shape or ImageShape.SQUARE).value,
)
for file_id, img in zip(file_ids, image_generation_responses)
]
@@ -107,15 +141,29 @@ def image_generation(
if final_generated_images:
image_descriptions = []
for i, img in enumerate(final_generated_images, 1):
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
if img.shape and img.shape != ImageShape.SQUARE.value:
image_descriptions.append(
f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
)
else:
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
answer_string = (
f"Generated {len(final_generated_images)} image(s) based on the request: {branch_query}\n\n"
f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
+ "\n".join(image_descriptions)
)
reasoning = f"Used image generation tool to create {len(final_generated_images)} image(s) based on the user's request."
if requested_shape:
reasoning = (
"Used image generation tool to create "
f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
)
else:
reasoning = (
"Used image generation tool to create "
f"{len(final_generated_images)} image(s) based on the user's request."
)
else:
answer_string = f"Failed to generate images for request: {branch_query}"
answer_string = f"Failed to generate images for request: {image_prompt}"
reasoning = "Image generation tool did not return any results."
return BranchUpdate(

View File

@@ -5,6 +5,7 @@ class GeneratedImage(BaseModel):
file_id: str
url: str
revised_prompt: str
shape: str | None = None
# Needed for PydanticType

View File

@@ -0,0 +1,147 @@
import json
from concurrent.futures import ThreadPoolExecutor
import requests
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchResult,
)
from onyx.configs.chat_configs import SERPER_API_KEY
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
SERPER_SEARCH_URL = "https://google.serper.dev/search"
SERPER_CONTENTS_URL = "https://scrape.serper.dev"
class SerperClient(InternetSearchProvider):
    """Internet search provider backed by the Serper.dev Google Search API.

    Uses two endpoints: ``/search`` for organic results and the scrape
    endpoint for fetching page contents.
    """

    def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
        # Serper authenticates via the X-API-KEY header on every request.
        self.headers = {
            "X-API-KEY": api_key,
            "Content-Type": "application/json",
        }

    @retry_builder(tries=3, delay=1, backoff=2)
    def search(self, query: str) -> list[InternetSearchResult]:
        """Run a web search for *query* and return the organic results.

        HTTP errors raise (and are retried by the decorator). Serper's search
        response does not include author/published-date info, so those fields
        are left as None.
        """
        payload = {
            "q": query,
        }
        response = requests.post(
            SERPER_SEARCH_URL,
            headers=self.headers,
            data=json.dumps(payload),
        )
        response.raise_for_status()
        results = response.json()

        # "organic" can be missing for queries with no organic hits; treat that
        # as an empty result set instead of raising a KeyError.
        organic_results = results.get("organic", [])
        return [
            InternetSearchResult(
                title=result["title"],
                link=result["link"],
                # snippet is not guaranteed on every organic entry — TODO confirm
                # against Serper's response schema.
                snippet=result.get("snippet", ""),
                author=None,
                published_date=None,
            )
            for result in organic_results
        ]

    def contents(self, urls: list[str]) -> list[InternetContent]:
        """Scrape each URL in parallel, returning one InternetContent per URL.

        Results are in the same order as *urls*.
        """
        if not urls:
            return []

        # Serper can respond with 500s regularly. We want to retry,
        # but in the event of repeated failure, return an unsuccessful scrape
        # instead of propagating the exception.
        def safe_get_webpage_content(url: str) -> InternetContent:
            try:
                return self._get_webpage_content(url)
            except Exception:
                return InternetContent(
                    title="",
                    link=url,
                    full_content="",
                    published_date=None,
                    scrape_successful=False,
                )

        with ThreadPoolExecutor(max_workers=min(8, len(urls))) as e:
            return list(e.map(safe_get_webpage_content, urls))

    @retry_builder(tries=3, delay=1, backoff=2)
    def _get_webpage_content(self, url: str) -> InternetContent:
        """Fetch one page via Serper's scrape endpoint.

        A 400 response means Serper cannot scrape the page; that is reported
        as an unsuccessful scrape rather than an error. Other HTTP errors
        raise and are retried by the decorator.
        """
        payload = {
            "url": url,
        }
        response = requests.post(
            SERPER_CONTENTS_URL,
            headers=self.headers,
            data=json.dumps(payload),
        )

        # 400 returned when serper cannot scrape
        if response.status_code == 400:
            return InternetContent(
                title="",
                link=url,
                full_content="",
                published_date=None,
                scrape_successful=False,
            )

        response.raise_for_status()
        response_json = response.json()

        # Response only guarantees text
        text = response_json["text"]

        # metadata & jsonld are not guaranteed to be present
        metadata = response_json.get("metadata", {})
        jsonld = response_json.get("jsonld", {})

        title = extract_title_from_metadata(metadata)
        # Serper does not provide a reliable mechanism to extract the url,
        # so echo back the requested one.
        response_url = url

        published_date_str = extract_published_date_from_jsonld(jsonld)
        published_date = None
        if published_date_str:
            try:
                published_date = time_str_to_utc(published_date_str)
            except Exception:
                # Unparseable dates are dropped rather than failing the scrape.
                published_date = None

        return InternetContent(
            title=title or "",
            link=response_url,
            full_content=text or "",
            published_date=published_date,
        )
def extract_title_from_metadata(metadata: dict[str, str]) -> str | None:
keys = ["title", "og:title"]
return extract_value_from_dict(metadata, keys)
def extract_published_date_from_jsonld(jsonld: dict[str, str]) -> str | None:
keys = ["dateModified"]
return extract_value_from_dict(jsonld, keys)
def extract_value_from_dict(data: dict[str, str], keys: list[str]) -> str | None:
for key in keys:
if key in data:
return data[key]
return None

View File

@@ -26,6 +26,7 @@ class InternetContent(BaseModel):
link: str
full_content: str
published_date: datetime | None = None
scrape_successful: bool = True
class InternetSearchProvider(ABC):

View File

@@ -1,13 +1,19 @@
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
ExaClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client import (
SerperClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchProvider,
)
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.configs.chat_configs import SERPER_API_KEY
def get_default_provider() -> InternetSearchProvider | None:
    """Return the first configured web-search provider (Exa takes priority
    over Serper), or None when neither API key is set."""
    for api_key, provider_cls in (
        (EXA_API_KEY, ExaClient),
        (SERPER_API_KEY, SerperClient),
    ):
        if api_key:
            return provider_cls()
    return None

View File

@@ -34,7 +34,7 @@ def dummy_inference_section_from_internet_content(
boost=1,
recency_bias=1.0,
score=1.0,
hidden=False,
hidden=(not result.scrape_successful),
metadata={},
match_highlights=[],
doc_summary=truncated_content,

View File

@@ -1,71 +0,0 @@
from typing import cast
from langchain_core.messages import AIMessageChunk
from langchain_core.messages.tool import ToolCall
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolCallOutput
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.tools.message import build_tool_message
from onyx.tools.message import ToolCallSummary
from onyx.tools.tool_runner import ToolRunner
from onyx.utils.logger import setup_logger
logger = setup_logger()
class ToolCallException(Exception):
    """Raised when a tool invocation fails while producing its responses or
    final result.

    The original error is attached as ``__cause__`` (raised with ``from``),
    so callers can distinguish tool-execution failures from orchestration
    errors and inspect the underlying exception.
    """
def call_tool(
    state: ToolChoiceUpdate,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> ToolCallUpdate:
    """Calls the tool specified in the state and updates the state with the result

    Runs the chosen tool via ToolRunner, collects all streamed responses and
    the final result, and packages everything (including a ToolCallSummary of
    the request/response messages) into a ToolCallUpdate.

    Raises:
        ValueError: if the state carries no tool choice.
        ToolCallException: if the tool raises during execution (original
            error chained as the cause).
    """
    # Result is intentionally discarded: the subscript lookup fails fast when
    # the graph config is missing from the runnable config.
    cast(GraphConfig, config["metadata"]["config"])

    tool_choice = state.tool_choice
    if tool_choice is None:
        raise ValueError("Cannot invoke tool call node without a tool choice")

    tool = tool_choice.tool
    tool_args = tool_choice.tool_args
    tool_id = tool_choice.id
    tool_runner = ToolRunner(
        tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
    )
    # Kickoff info (tool name + args) is captured before any responses stream.
    tool_kickoff = tool_runner.kickoff()

    try:
        # Drain the full response stream before asking for the final result.
        tool_responses = []
        for response in tool_runner.tool_responses():
            tool_responses.append(response)

        tool_final_result = tool_runner.tool_final_result()
    except Exception as e:
        # Wrap any tool failure so downstream handlers can treat tool errors
        # uniformly; the original exception stays available via __cause__.
        raise ToolCallException(
            f"Error during tool call for {tool.display_name}: {e}"
        ) from e

    tool_call = ToolCall(name=tool.name, args=tool_args, id=tool_id)
    # Summary pairs the AI's tool-call request with the tool's result message.
    tool_call_summary = ToolCallSummary(
        tool_call_request=AIMessageChunk(content="", tool_calls=[tool_call]),
        tool_call_result=build_tool_message(
            tool_call, tool_runner.tool_message_content()
        ),
    )

    tool_call_output = ToolCallOutput(
        tool_call_summary=tool_call_summary,
        tool_call_kickoff=tool_kickoff,
        tool_call_responses=tool_responses,
        tool_call_final_result=tool_final_result,
    )
    return ToolCallUpdate(tool_call_output=tool_call_output)

View File

@@ -1,354 +0,0 @@
from typing import cast
from uuid import uuid4
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import ToolCall
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.process_llm_stream import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolChoice
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.configs.chat_configs import USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.llm.factory import get_default_llms
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
from onyx.tools.models import QueryExpansions
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import TimeoutThread
from onyx.utils.threadpool_concurrency import wait_on_background
from onyx.utils.timing import log_function_time
from shared_configs.model_server_models import Embedding
logger = setup_logger()
def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
    """Render the prior chat turns as a plain-text transcript.

    Only human and assistant messages are included; any other message types
    in the history are skipped.
    """
    # TODO: Add trimming logic
    def _role_for(message: object) -> str | None:
        if isinstance(message, HumanMessage):
            return "User"
        if isinstance(message, AIMessage):
            return "Assistant"
        return None

    transcript_segments = [
        f"{role}:\n {message.content}\n\n"
        for message in prompt_builder.message_history
        if (role := _role_for(message)) is not None
    ]
    return "\n".join(transcript_segments)
def _expand_query(
    query: str,
    expansion_type: QueryExpansionType,
    prompt_builder: AnswerPromptBuilder,
) -> str:
    """Ask the default primary LLM to rewrite *query* for keyword or semantic
    retrieval, including chat history in the prompt when any exists."""
    history_str = _create_history_str(prompt_builder)
    wants_keywords = expansion_type == QueryExpansionType.KEYWORD

    if history_str:
        template = (
            QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
            if wants_keywords
            else QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
        )
        expansion_prompt = template.format(question=query, history=history_str)
    else:
        template = (
            QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
            if wants_keywords
            else QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
        )
        expansion_prompt = template.format(question=query)

    primary_llm, _ = get_default_llms()
    llm_response = primary_llm.invoke([HumanMessage(content=expansion_prompt)])
    return cast(str, llm_response.content)
def _expand_query_non_tool_calling_llm(
    expanded_keyword_thread: TimeoutThread[str],
    expanded_semantic_thread: TimeoutThread[str],
) -> QueryExpansions | None:
    """Join both background expansion threads and bundle their results.

    Returns None when either thread produced no result (e.g. timed out).
    """
    keyword_result: str | None = wait_on_background(expanded_keyword_thread)
    semantic_result: str | None = wait_on_background(expanded_semantic_thread)

    if keyword_result is None or semantic_result is None:
        return None

    return QueryExpansions(
        keywords_expansions=[keyword_result],
        semantic_expansions=[semantic_result],
    )
# TODO: break this out into an implementation function
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
@log_function_time(print_only=True)
def choose_tool(
    state: ToolChoiceState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> ToolChoiceUpdate:
    """
    This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
    The node MAY emit an answer, depending on whether state["should_stream_answer"] is set.

    Flow:
      1. Optionally kick off background work (query embedding, keyword
         analysis, and keyword/semantic query expansions) so it overlaps the
         LLM call; the results are attached to the search tool's override
         kwargs when the search tool is chosen.
      2. If a tool call is forced with explicit args, or a non-tool-calling
         LLM picks a tool, return that choice immediately.
      3. Otherwise stream the (tool-calling) LLM and select the first emitted
         tool call whose name matches a known tool.

    Returns a ToolChoiceUpdate whose tool_choice is None when no tool was
    chosen.

    Raises:
        ValueError: if the LLM emits no message, only unknown tool calls, or
            required query expansions are missing.
    """
    should_stream_answer = state.should_stream_answer

    agent_config = cast(GraphConfig, config["metadata"]["config"])

    force_use_tool = agent_config.tooling.force_use_tool

    # Handles for background work; None means the work was not started.
    embedding_thread: TimeoutThread[Embedding] | None = None
    keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
    expanded_keyword_thread: TimeoutThread[str] | None = None
    expanded_semantic_thread: TimeoutThread[str] | None = None

    # If we have override_kwargs, add them to the tool_args
    override_kwargs: SearchToolOverrideKwargs = (
        force_use_tool.override_kwargs or SearchToolOverrideKwargs()
    )
    override_kwargs.original_query = agent_config.inputs.prompt_builder.raw_user_query

    using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
    prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder

    llm = agent_config.tooling.primary_llm
    skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation

    # Precompute search inputs only when basic (non-agentic) search is possible:
    # a search tool exists and it is either not forced or the forced tool.
    if (
        not agent_config.behavior.use_agentic_search
        and agent_config.tooling.search_tool is not None
        and (
            not force_use_tool.force_use or force_use_tool.tool_name == SearchTool._NAME
        )
    ):
        # Run in a background thread to avoid blocking the main thread
        embedding_thread = run_in_background(
            get_query_embedding,
            agent_config.inputs.prompt_builder.raw_user_query,
            agent_config.persistence.db_session,
        )
        keyword_thread = run_in_background(
            query_analysis,
            agent_config.inputs.prompt_builder.raw_user_query,
        )
        if USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH:
            expanded_keyword_thread = run_in_background(
                _expand_query,
                agent_config.inputs.prompt_builder.raw_user_query,
                QueryExpansionType.KEYWORD,
                prompt_builder,
            )
            expanded_semantic_thread = run_in_background(
                _expand_query,
                agent_config.inputs.prompt_builder.raw_user_query,
                QueryExpansionType.SEMANTIC,
                prompt_builder,
            )

    structured_response_format = agent_config.inputs.structured_response_format
    # Restrict to the tools this node is allowed to use per the state.
    tools = [
        tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
    ]

    tool, tool_args = None, None
    if force_use_tool.force_use and force_use_tool.args is not None:
        tool_name, tool_args = (
            force_use_tool.tool_name,
            force_use_tool.args,
        )
        tool = get_tool_by_name(tools, tool_name)

    # special pre-logic for non-tool calling LLM case
    elif not using_tool_calling_llm and tools:
        chosen_tool_and_args = get_tool_call_for_non_tool_calling_llm_impl(
            force_use_tool=force_use_tool,
            tools=tools,
            prompt_builder=prompt_builder,
            llm=llm,
        )
        if chosen_tool_and_args:
            tool, tool_args = chosen_tool_and_args

    # If we have a tool and tool args, we are ready to request a tool call.
    # This only happens if the tool call was forced or we are using a non-tool calling LLM.
    if tool and tool_args:
        # Attach precomputed search inputs (embedding / keyword analysis /
        # expansions) so the search tool does not redo the work.
        if embedding_thread and tool.name == SearchTool._NAME:
            # Wait for the embedding thread to finish
            embedding = wait_on_background(embedding_thread)
            override_kwargs.precomputed_query_embedding = embedding
        if keyword_thread and tool.name == SearchTool._NAME:
            is_keyword, keywords = wait_on_background(keyword_thread)
            override_kwargs.precomputed_is_keyword = is_keyword
            override_kwargs.precomputed_keywords = keywords

        # dual keyword expansion needs to be added here for non-tool calling LLM case
        if (
            USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
            and expanded_keyword_thread
            and expanded_semantic_thread
            and tool.name == SearchTool._NAME
        ):
            override_kwargs.expanded_queries = _expand_query_non_tool_calling_llm(
                expanded_keyword_thread=expanded_keyword_thread,
                expanded_semantic_thread=expanded_semantic_thread,
            )
        # Sanity check: when expansions are enabled and present, both halves
        # must have produced results.
        if (
            USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
            and tool.name == SearchTool._NAME
            and override_kwargs.expanded_queries
        ):
            if (
                override_kwargs.expanded_queries.keywords_expansions is None
                or override_kwargs.expanded_queries.semantic_expansions is None
            ):
                raise ValueError("No expanded keyword or semantic threads found.")

        return ToolChoiceUpdate(
            tool_choice=ToolChoice(
                tool=tool,
                tool_args=tool_args,
                id=str(uuid4()),
                search_tool_override_kwargs=override_kwargs,
            ),
        )

    # if we're skipping gen ai answer generation, we should only
    # continue if we're forcing a tool call (which will be emitted by
    # the tool calling llm in the stream() below)
    if skip_gen_ai_answer_generation and not force_use_tool.force_use:
        return ToolChoiceUpdate(
            tool_choice=None,
        )

    built_prompt = (
        prompt_builder.build()
        if isinstance(prompt_builder, AnswerPromptBuilder)
        else prompt_builder.built_prompt
    )
    # At this point, we are either using a tool calling LLM or we are skipping the tool call.
    # DEBUG: good breakpoint
    stream = llm.stream(
        # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
        # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
        prompt=built_prompt,
        tools=(
            [tool.tool_definition() for tool in tools] or None
            if using_tool_calling_llm
            else None
        ),
        tool_choice=(
            "required"
            if tools and force_use_tool.force_use and using_tool_calling_llm
            else None
        ),
        structured_response_format=structured_response_format,
    )

    # Streams the answer to the user only when allowed; always returns the
    # accumulated AI message (including any tool calls).
    tool_message = process_llm_stream(
        stream,
        should_stream_answer
        and not agent_config.behavior.skip_gen_ai_answer_generation,
        writer,
        ind=0,
    ).ai_message_chunk
    if tool_message is None:
        raise ValueError("No tool message emitted by LLM")

    # If no tool calls are emitted by the LLM, we should not choose a tool
    if len(tool_message.tool_calls) == 0:
        logger.debug("No tool calls emitted by LLM")
        return ToolChoiceUpdate(
            tool_choice=None,
        )

    # TODO: here we could handle parallel tool calls. Right now
    # we just pick the first one that matches.
    selected_tool: Tool | None = None
    selected_tool_call_request: ToolCall | None = None
    for tool_call_request in tool_message.tool_calls:
        known_tools_by_name = [
            tool for tool in tools if tool.name == tool_call_request["name"]
        ]

        if known_tools_by_name:
            selected_tool = known_tools_by_name[0]
            selected_tool_call_request = tool_call_request
            break

        # NOTE: logs once per unmatched request until a match is found.
        logger.error(
            "Tool call requested with unknown name field. \n"
            f"tools: {tools}"
            f"tool_call_request: {tool_call_request}"
        )

    if not selected_tool or not selected_tool_call_request:
        raise ValueError(
            f"Tool call attempted with tool {selected_tool}, request {selected_tool_call_request}"
        )

    logger.debug(f"Selected tool: {selected_tool.name}")
    logger.debug(f"Selected tool call request: {selected_tool_call_request}")

    # Attach precomputed search inputs for the tool-calling-LLM path, mirroring
    # the forced/non-tool-calling path above.
    if embedding_thread and selected_tool.name == SearchTool._NAME:
        # Wait for the embedding thread to finish
        embedding = wait_on_background(embedding_thread)
        override_kwargs.precomputed_query_embedding = embedding
    if keyword_thread and selected_tool.name == SearchTool._NAME:
        is_keyword, keywords = wait_on_background(keyword_thread)
        override_kwargs.precomputed_is_keyword = is_keyword
        override_kwargs.precomputed_keywords = keywords
    if (
        selected_tool.name == SearchTool._NAME
        and expanded_keyword_thread
        and expanded_semantic_thread
    ):
        override_kwargs.expanded_queries = _expand_query_non_tool_calling_llm(
            expanded_keyword_thread=expanded_keyword_thread,
            expanded_semantic_thread=expanded_semantic_thread,
        )

    if (
        USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
        and selected_tool.name == SearchTool._NAME
        and override_kwargs.expanded_queries
    ):
        # TODO: this is a hack to handle the case where the expanded queries are not found.
        # We should refactor this to be more robust.
        if (
            override_kwargs.expanded_queries.keywords_expansions is None
            or override_kwargs.expanded_queries.semantic_expansions is None
        ):
            raise ValueError("No expanded keyword or semantic threads found.")

    return ToolChoiceUpdate(
        tool_choice=ToolChoice(
            tool=selected_tool,
            tool_args=selected_tool_call_request["args"],
            id=selected_tool_call_request["id"],
            search_tool_override_kwargs=override_kwargs,
        ),
    )

View File

@@ -1,17 +0,0 @@
from typing import Any
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
def prepare_tool_input(state: Any, config: RunnableConfig) -> ToolChoiceInput:
    """Build the initial ToolChoiceInput for the top-level agent graph,
    exposing every configured tool by name."""
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    available_tools = graph_config.tooling.tools or []
    return ToolChoiceInput(
        # NOTE: this node is used at the top level of the agent, so we always stream
        should_stream_answer=True,
        prompt_snapshot=None,  # uses default prompt builder
        tools=[tool.name for tool in available_tools],
    )

View File

@@ -5,11 +5,10 @@ from typing import Literal
from typing import Type
from typing import TypeVar
from braintrust import traced
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import HumanMessage
from langgraph.types import StreamWriter
from litellm import get_supported_openai_params
from litellm import supports_response_schema
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
@@ -29,6 +28,7 @@ SchemaType = TypeVar("SchemaType", bound=BaseModel)
JSON_PATTERN = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)
@traced(name="stream llm", type="llm")
def stream_llm_answer(
llm: LLM,
prompt: LanguageModelInput,
@@ -147,6 +147,7 @@ def invoke_llm_json(
Invoke an LLM, forcing it to respond in a specified JSON format if possible,
and return an object of that schema.
"""
from litellm.utils import get_supported_openai_params, supports_response_schema
# check if the model supports response_format: json_schema
supports_json = "response_format" in (

View File

@@ -29,7 +29,7 @@ from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
from onyx.configs.constants import ONYX_SLACK_URL
from onyx.configs.constants import ONYX_DISCORD_URL
from onyx.db.models import User
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.logger import setup_logger
@@ -145,7 +145,7 @@ HTML_EMAIL_TEMPLATE = """\
<tr>
<td class="footer">
© {year} {application_name}. All rights reserved.
{slack_fragment}
{community_link_fragment}
</td>
</tr>
</table>
@@ -161,9 +161,9 @@ def build_html_email(
cta_text: str | None = None,
cta_link: str | None = None,
) -> str:
slack_fragment = ""
community_link_fragment = ""
if application_name == ONYX_DEFAULT_APPLICATION_NAME:
slack_fragment = f'<br>Have questions? Join our Slack community <a href="{ONYX_SLACK_URL}">here</a>.'
community_link_fragment = f'<br>Have questions? Join our Discord community <a href="{ONYX_DISCORD_URL}">here</a>.'
if cta_text and cta_link:
cta_block = f'<a class="cta-button" href="{cta_link}">{cta_text}</a>'
@@ -175,7 +175,7 @@ def build_html_email(
heading=heading,
message=message,
cta_block=cta_block,
slack_fragment=slack_fragment,
community_link_fragment=community_link_fragment,
year=datetime.now().year,
)

View File

@@ -19,7 +19,9 @@ from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -30,7 +32,7 @@ PRUNING_CHECKPOINTED_BATCH_SIZE = 32
def document_batch_to_ids(
doc_batch: Iterator[list[Document]],
doc_batch: Iterator[list[Document]] | Iterator[list[SlimDocument]],
) -> Generator[set[str], None, None]:
for doc_list in doc_batch:
yield {doc.id for doc in doc_list}
@@ -41,20 +43,24 @@ def extract_ids_from_runnable_connector(
callback: IndexingHeartbeatInterface | None = None,
) -> set[str]:
"""
If the SlimConnector hasnt been implemented for the given connector, just pull
If the given connector is neither a SlimConnector nor a SlimConnectorWithPermSync, just pull
all docs using the load_from_state and grab out the IDs.
Optionally, a callback can be passed to handle the length of each document batch.
"""
all_connector_doc_ids: set[str] = set()
if isinstance(runnable_connector, SlimConnector):
for metadata_batch in runnable_connector.retrieve_all_slim_documents():
all_connector_doc_ids.update({doc.id for doc in metadata_batch})
doc_batch_id_generator = None
if isinstance(runnable_connector, LoadConnector):
if isinstance(runnable_connector, SlimConnector):
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.retrieve_all_slim_docs()
)
elif isinstance(runnable_connector, SlimConnectorWithPermSync):
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.retrieve_all_slim_docs_perm_sync()
)
# If the connector isn't slim, fall back to running it normally to get ids
elif isinstance(runnable_connector, LoadConnector):
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.load_from_state()
)
@@ -78,13 +84,14 @@ def extract_ids_from_runnable_connector(
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
# this function is called per batch for rate limiting
def doc_batch_processing_func(doc_batch_ids: set[str]) -> set[str]:
return doc_batch_ids
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
doc_batch_processing_func = rate_limit_builder(
doc_batch_processing_func = (
rate_limit_builder(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(lambda x: x)
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
else lambda x: x
)
for doc_batch_ids in doc_batch_id_generator:
if callback:
if callback.should_stop():

View File

@@ -41,7 +41,7 @@ beat_task_templates: list[dict] = [
"task": OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
"schedule": timedelta(minutes=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.USER_FILE_PROCESSING,
},
@@ -85,9 +85,9 @@ beat_task_templates: list[dict] = [
{
"name": "check-for-index-attempt-cleanup",
"task": OnyxCeleryTask.CHECK_FOR_INDEX_ATTEMPT_CLEANUP,
"schedule": timedelta(hours=1),
"schedule": timedelta(minutes=30),
"options": {
"priority": OnyxCeleryPriority.LOW,
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},

View File

@@ -89,6 +89,7 @@ from onyx.indexing.adapters.document_indexing_adapter import (
DocumentIndexingBatchAdapter,
)
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
@@ -1270,8 +1271,6 @@ def _docprocessing_task(
tenant_id: str,
batch_num: int,
) -> None:
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
start_time = time.monotonic()
if tenant_id:

View File

@@ -579,6 +579,16 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
search_doc_map[doc_id] = []
search_doc_map[doc_id].append(sd)
task_logger.debug(
f"Built search doc map with {len(search_doc_map)} entries"
)
ids_preview = list(search_doc_map.keys())[:5]
task_logger.debug(
f"First few search_doc_map ids: {ids_preview if ids_preview else 'No ids found'}"
)
task_logger.debug(
f"search_doc_map total items: {sum(len(docs) for docs in search_doc_map.values())}"
)
# Process each UserFile and update matching SearchDocs
updated_count = 0
for uf in user_files:
@@ -586,9 +596,18 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
if doc_id.startswith("USER_FILE_CONNECTOR__"):
doc_id = "FILE_CONNECTOR__" + doc_id[len("USER_FILE_CONNECTOR__") :]
task_logger.debug(f"Processing user file {uf.id} with doc_id {doc_id}")
task_logger.debug(
f"doc_id in search_doc_map: {doc_id in search_doc_map}"
)
if doc_id in search_doc_map:
search_docs = search_doc_map[doc_id]
task_logger.debug(
f"Found {len(search_docs)} search docs to update for user file {uf.id}"
)
# Update the SearchDoc to use the UserFile's UUID
for search_doc in search_doc_map[doc_id]:
for search_doc in search_docs:
search_doc.document_id = str(uf.id)
db_session.add(search_doc)

View File

@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import NUM_DAYS_TO_KEEP_INDEX_ATTEMPTS
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
def get_old_index_attempts(
@@ -21,6 +22,10 @@ def get_old_index_attempts(
def cleanup_index_attempts(db_session: Session, index_attempt_ids: list[int]) -> None:
"""Clean up multiple index attempts"""
db_session.query(IndexAttemptError).filter(
IndexAttemptError.index_attempt_id.in_(index_attempt_ids)
).delete(synchronize_session=False)
db_session.query(IndexAttempt).filter(
IndexAttempt.id.in_(index_attempt_ids)
).delete(synchronize_session=False)

View File

@@ -28,6 +28,7 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
@@ -101,7 +102,6 @@ def _get_connector_runner(
are the complete list of existing documents of the connector. If the task
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
"""
from onyx.connectors.factory import instantiate_connector
task = attempt.connector_credential_pair.connector.input_type

View File

@@ -9,7 +9,6 @@ from uuid import UUID
from sqlalchemy.orm import Session
from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
from onyx.chat.answer import Answer
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_temporary_persona
@@ -113,6 +112,10 @@ logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"
class ToolCallException(Exception):
"""Exception raised for errors during tool calls."""
class PartialResponse(Protocol):
def __call__(
self,

View File

@@ -24,8 +24,6 @@ APP_PORT = 8080
# prefix from requests directed towards the API server. In these cases, set this to `/api`
APP_API_PREFIX = os.environ.get("API_PREFIX", "")
SKIP_WARM_UP = os.environ.get("SKIP_WARM_UP", "").lower() == "true"
#####
# User Facing Features Configs
#####
@@ -756,7 +754,7 @@ MAX_FEDERATED_CHUNKS = int(
# NOTE: this should only be enabled if you have purchased an enterprise license.
# if you're interested in an enterprise license, please reach out to us at
# founders@onyx.app OR message Chris Weaver or Yuhong Sun in the Onyx
# Slack community (https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ)
# Discord community https://discord.gg/4NA5SbzrWb
ENTERPRISE_EDITION_ENABLED = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)

View File

@@ -90,6 +90,7 @@ HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
# Internet Search
EXA_API_KEY = os.environ.get("EXA_API_KEY") or None
SERPER_API_KEY = os.environ.get("SERPER_API_KEY") or None
NUM_INTERNET_SEARCH_RESULTS = int(os.environ.get("NUM_INTERNET_SEARCH_RESULTS") or 10)
NUM_INTERNET_SEARCH_CHUNKS = int(os.environ.get("NUM_INTERNET_SEARCH_CHUNKS") or 50)

View File

@@ -6,7 +6,7 @@ from enum import Enum
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
ONYX_SLACK_URL = "https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA"
ONYX_DISCORD_URL = "https://discord.gg/4NA5SbzrWb"
SLACK_USER_TOKEN_PREFIX = "xoxp-"
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
ONYX_EMAILABLE_LOGO_MAX_DIM = 512

View File

@@ -41,7 +41,7 @@ All new connectors should have tests added to the `backend/tests/daily/connector
#### Implementing the new Connector
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
The connector must subclass one or more of LoadConnector, PollConnector, CheckpointedConnector, or CheckpointedConnectorWithPermSync
The `__init__` should take arguments for configuring what documents the connector will and where it finds those
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of

View File

@@ -25,7 +25,7 @@ from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
@@ -56,7 +56,7 @@ class BitbucketConnectorCheckpoint(ConnectorCheckpoint):
class BitbucketConnector(
CheckpointedConnector[BitbucketConnectorCheckpoint],
SlimConnector,
SlimConnectorWithPermSync,
):
"""Connector for indexing Bitbucket Cloud pull requests.
@@ -266,7 +266,7 @@ class BitbucketConnector(
"""Validate and deserialize a checkpoint instance from JSON."""
return BitbucketConnectorCheckpoint.model_validate_json(checkpoint_json)
def retrieve_all_slim_documents(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -5,6 +5,7 @@ from datetime import timezone
from typing import Any
from urllib.parse import quote
from atlassian.errors import ApiError # type: ignore
from requests.exceptions import HTTPError
from typing_extensions import override
@@ -41,6 +42,7 @@ from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
@@ -91,6 +93,7 @@ class ConfluenceCheckpoint(ConnectorCheckpoint):
class ConfluenceConnector(
CheckpointedConnector[ConfluenceCheckpoint],
SlimConnector,
SlimConnectorWithPermSync,
CredentialsConnector,
):
def __init__(
@@ -108,6 +111,7 @@ class ConfluenceConnector(
# pages.
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
scoped_token: bool = False,
) -> None:
self.wiki_base = wiki_base
self.is_cloud = is_cloud
@@ -118,6 +122,7 @@ class ConfluenceConnector(
self.batch_size = batch_size
self.labels_to_skip = labels_to_skip
self.timezone_offset = timezone_offset
self.scoped_token = scoped_token
self._confluence_client: OnyxConfluence | None = None
self._low_timeout_confluence_client: OnyxConfluence | None = None
self._fetched_titles: set[str] = set()
@@ -195,6 +200,7 @@ class ConfluenceConnector(
is_cloud=self.is_cloud,
url=self.wiki_base,
credentials_provider=credentials_provider,
scoped_token=self.scoped_token,
)
confluence_client._probe_connection(**self.probe_kwargs)
confluence_client._initialize_connection(**self.final_kwargs)
@@ -207,6 +213,7 @@ class ConfluenceConnector(
url=self.wiki_base,
credentials_provider=credentials_provider,
timeout=3,
scoped_token=self.scoped_token,
)
low_timeout_confluence_client._probe_connection(**self.probe_kwargs)
low_timeout_confluence_client._initialize_connection(**self.final_kwargs)
@@ -558,7 +565,21 @@ class ConfluenceConnector(
def validate_checkpoint_json(self, checkpoint_json: str) -> ConfluenceCheckpoint:
return ConfluenceCheckpoint.model_validate_json(checkpoint_json)
def retrieve_all_slim_documents(
@override
def retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
return self._retrieve_all_slim_docs(
start=start,
end=end,
callback=callback,
include_permissions=False,
)
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
@@ -568,12 +589,28 @@ class ConfluenceConnector(
Return 'slim' docs (IDs + minimal permission data).
Does not fetch actual text. Used primarily for incremental permission sync.
"""
return self._retrieve_all_slim_docs(
start=start,
end=end,
callback=callback,
include_permissions=True,
)
def _retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
include_permissions: bool = True,
) -> GenerateSlimDocumentOutput:
doc_metadata_list: list[SlimDocument] = []
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
space_level_access_info = get_all_space_permissions(
self.confluence_client, self.is_cloud
)
space_level_access_info: dict[str, ExternalAccess] = {}
if include_permissions:
space_level_access_info = get_all_space_permissions(
self.confluence_client, self.is_cloud
)
def get_external_access(
doc_id: str, restrictions: dict[str, Any], ancestors: list[dict[str, Any]]
@@ -600,8 +637,10 @@ class ConfluenceConnector(
doc_metadata_list.append(
SlimDocument(
id=page_id,
external_access=get_external_access(
page_id, page_restrictions, page_ancestors
external_access=(
get_external_access(page_id, page_restrictions, page_ancestors)
if include_permissions
else None
),
)
)
@@ -636,8 +675,12 @@ class ConfluenceConnector(
doc_metadata_list.append(
SlimDocument(
id=attachment_id,
external_access=get_external_access(
attachment_id, attachment_restrictions, []
external_access=(
get_external_access(
attachment_id, attachment_restrictions, []
)
if include_permissions
else None
),
)
)
@@ -648,10 +691,10 @@ class ConfluenceConnector(
if callback and callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_documents: Stop signal detected"
"retrieve_all_slim_docs_perm_sync: Stop signal detected"
)
if callback:
callback.progress("retrieve_all_slim_documents", 1)
callback.progress("retrieve_all_slim_docs_perm_sync", 1)
yield doc_metadata_list
@@ -676,6 +719,14 @@ class ConfluenceConnector(
f"Unexpected error while validating Confluence settings: {e}"
)
if self.space:
try:
self.low_timeout_confluence_client.get_space(self.space)
except ApiError as e:
raise ConnectorValidationError(
"Invalid Confluence space key provided"
) from e
if not spaces or not spaces.get("results"):
raise ConnectorValidationError(
"No Confluence spaces found. Either your credentials lack permissions, or "
@@ -724,7 +775,7 @@ if __name__ == "__main__":
end = datetime.now().timestamp()
# Fetch all `SlimDocuments`.
for slim_doc in confluence_connector.retrieve_all_slim_documents():
for slim_doc in confluence_connector.retrieve_all_slim_docs_perm_sync():
print(slim_doc)
# Fetch all `Documents`.

View File

@@ -41,6 +41,7 @@ from onyx.connectors.confluence.utils import _handle_http_error
from onyx.connectors.confluence.utils import confluence_refresh_tokens
from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.cross_connector_utils.miscellaneous_utils import scoped_url
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.file_processing.html_utils import format_document_soup
from onyx.redis.redis_pool import get_redis_client
@@ -87,16 +88,20 @@ class OnyxConfluence:
url: str,
credentials_provider: CredentialsProviderInterface,
timeout: int | None = None,
scoped_token: bool = False,
# should generally not be passed in, but making it overridable for
# easier testing
confluence_user_profiles_override: list[dict[str, str]] | None = (
CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE
),
) -> None:
self.base_url = url #'/'.join(url.rstrip("/").split("/")[:-1])
url = scoped_url(url, "confluence") if scoped_token else url
self._is_cloud = is_cloud
self._url = url.rstrip("/")
self._credentials_provider = credentials_provider
self.scoped_token = scoped_token
self.redis_client: Redis | None = None
self.static_credentials: dict[str, Any] | None = None
if self._credentials_provider.is_dynamic():
@@ -218,6 +223,34 @@ class OnyxConfluence:
with self._credentials_provider:
credentials, _ = self._renew_credentials()
if self.scoped_token:
# v2 endpoint doesn't always work with scoped tokens, use v1
token = credentials["confluence_access_token"]
probe_url = f"{self.base_url}/rest/api/space?limit=1"
import requests
logger.info(f"First and Last 5 of token: {token[:5]}...{token[-5:]}")
try:
r = requests.get(
probe_url,
headers={"Authorization": f"Bearer {token}"},
timeout=10,
)
r.raise_for_status()
except HTTPError as e:
if e.response.status_code == 403:
logger.warning(
"scoped token authenticated but not valid for probe endpoint (spaces)"
)
else:
if "WWW-Authenticate" in e.response.headers:
logger.warning(
f"WWW-Authenticate: {e.response.headers['WWW-Authenticate']}"
)
logger.warning(f"Full error: {e.response.text}")
raise e
return
# probe connection with direct client, no retries
if "confluence_refresh_token" in credentials:
@@ -236,6 +269,7 @@ class OnyxConfluence:
logger.info("Probing Confluence with Personal Access Token.")
url = self._url
if self._is_cloud:
logger.info("running with cloud client")
confluence_client_with_minimal_retries = Confluence(
url=url,
username=credentials["confluence_username"],
@@ -304,7 +338,9 @@ class OnyxConfluence:
url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
confluence = Confluence(url=url, oauth2=oauth2_dict, **kwargs)
else:
logger.info("Connecting to Confluence with Personal Access Token.")
logger.info(
f"Connecting to Confluence with Personal Access Token as user: {credentials['confluence_username']}"
)
if self._is_cloud:
confluence = Confluence(
url=self._url,

View File

@@ -5,7 +5,10 @@ from datetime import datetime
from datetime import timezone
from typing import Any
from typing import TypeVar
from urllib.parse import urljoin
from urllib.parse import urlparse
import requests
from dateutil.parser import parse
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
@@ -148,3 +151,17 @@ def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
def is_atlassian_date_error(e: Exception) -> bool:
return "field 'updated' is invalid" in str(e)
def get_cloudId(base_url: str) -> str:
tenant_info_url = urljoin(base_url, "/_edge/tenant_info")
response = requests.get(tenant_info_url, timeout=10)
response.raise_for_status()
return response.json()["cloudId"]
def scoped_url(url: str, product: str) -> str:
parsed = urlparse(url)
base_url = parsed.scheme + "://" + parsed.netloc
cloud_id = get_cloudId(base_url)
return f"https://api.atlassian.com/ex/{product}/{cloud_id}{parsed.path}"

View File

@@ -1,3 +1,4 @@
import importlib
from typing import Any
from typing import Type
@@ -6,60 +7,16 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DocumentSource
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.asana.connector import AsanaConnector
from onyx.connectors.axero.connector import AxeroConnector
from onyx.connectors.bitbucket.connector import BitbucketConnector
from onyx.connectors.blob.connector import BlobStorageConnector
from onyx.connectors.bookstack.connector import BookstackConnector
from onyx.connectors.clickup.connector import ClickupConnector
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.discord.connector import DiscordConnector
from onyx.connectors.discourse.connector import DiscourseConnector
from onyx.connectors.document360.connector import Document360Connector
from onyx.connectors.dropbox.connector import DropboxConnector
from onyx.connectors.egnyte.connector import EgnyteConnector
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.fireflies.connector import FirefliesConnector
from onyx.connectors.freshdesk.connector import FreshdeskConnector
from onyx.connectors.gitbook.connector import GitbookConnector
from onyx.connectors.github.connector import GithubConnector
from onyx.connectors.gitlab.connector import GitlabConnector
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.gong.connector import GongConnector
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_site.connector import GoogleSitesConnector
from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.highspot.connector import HighspotConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.imap.connector import ImapConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import EventConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.jira.connector import JiraConnector
from onyx.connectors.linear.connector import LinearConnector
from onyx.connectors.loopio.connector import LoopioConnector
from onyx.connectors.mediawiki.wiki import MediaWikiConnector
from onyx.connectors.mock_connector.connector import MockConnector
from onyx.connectors.models import InputType
from onyx.connectors.notion.connector import NotionConnector
from onyx.connectors.outline.connector import OutlineConnector
from onyx.connectors.productboard.connector import ProductboardConnector
from onyx.connectors.salesforce.connector import SalesforceConnector
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.connectors.slab.connector import SlabConnector
from onyx.connectors.slack.connector import SlackConnector
from onyx.connectors.teams.connector import TeamsConnector
from onyx.connectors.web.connector import WebConnector
from onyx.connectors.wikipedia.connector import WikipediaConnector
from onyx.connectors.xenforo.connector import XenforoConnector
from onyx.connectors.zendesk.connector import ZendeskConnector
from onyx.connectors.zulip.connector import ZulipConnector
from onyx.connectors.registry import CONNECTOR_CLASS_MAP
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import backend_update_credential_json
from onyx.db.credentials import fetch_credential_by_id
@@ -72,101 +29,75 @@ class ConnectorMissingException(Exception):
pass
# Cache for already imported connector classes
_connector_cache: dict[DocumentSource, Type[BaseConnector]] = {}
def _load_connector_class(source: DocumentSource) -> Type[BaseConnector]:
"""Dynamically load and cache a connector class."""
if source in _connector_cache:
return _connector_cache[source]
if source not in CONNECTOR_CLASS_MAP:
raise ConnectorMissingException(f"Connector not found for source={source}")
mapping = CONNECTOR_CLASS_MAP[source]
try:
module = importlib.import_module(mapping.module_path)
connector_class = getattr(module, mapping.class_name)
_connector_cache[source] = connector_class
return connector_class
except (ImportError, AttributeError) as e:
raise ConnectorMissingException(
f"Failed to import {mapping.class_name} from {mapping.module_path}: {e}"
)
def _validate_connector_supports_input_type(
connector: Type[BaseConnector],
input_type: InputType | None,
source: DocumentSource,
) -> None:
"""Validate that a connector supports the requested input type."""
if input_type is None:
return
# Check each input type requirement separately for clarity
load_state_unsupported = input_type == InputType.LOAD_STATE and not issubclass(
connector, LoadConnector
)
poll_unsupported = (
input_type == InputType.POLL
# Either poll or checkpoint works for this, in the future
# all connectors should be checkpoint connectors
and (
not issubclass(connector, PollConnector)
and not issubclass(connector, CheckpointedConnector)
)
)
event_unsupported = input_type == InputType.EVENT and not issubclass(
connector, EventConnector
)
if any([load_state_unsupported, poll_unsupported, event_unsupported]):
raise ConnectorMissingException(
f"Connector for source={source} does not accept input_type={input_type}"
)
def identify_connector_class(
source: DocumentSource,
input_type: InputType | None = None,
) -> Type[BaseConnector]:
connector_map = {
DocumentSource.WEB: WebConnector,
DocumentSource.FILE: LocalFileConnector,
DocumentSource.SLACK: {
InputType.POLL: SlackConnector,
InputType.SLIM_RETRIEVAL: SlackConnector,
},
DocumentSource.GITHUB: GithubConnector,
DocumentSource.GMAIL: GmailConnector,
DocumentSource.GITLAB: GitlabConnector,
DocumentSource.GITBOOK: GitbookConnector,
DocumentSource.GOOGLE_DRIVE: GoogleDriveConnector,
DocumentSource.BOOKSTACK: BookstackConnector,
DocumentSource.OUTLINE: OutlineConnector,
DocumentSource.CONFLUENCE: ConfluenceConnector,
DocumentSource.JIRA: JiraConnector,
DocumentSource.PRODUCTBOARD: ProductboardConnector,
DocumentSource.SLAB: SlabConnector,
DocumentSource.NOTION: NotionConnector,
DocumentSource.ZULIP: ZulipConnector,
DocumentSource.GURU: GuruConnector,
DocumentSource.LINEAR: LinearConnector,
DocumentSource.HUBSPOT: HubSpotConnector,
DocumentSource.DOCUMENT360: Document360Connector,
DocumentSource.GONG: GongConnector,
DocumentSource.GOOGLE_SITES: GoogleSitesConnector,
DocumentSource.ZENDESK: ZendeskConnector,
DocumentSource.LOOPIO: LoopioConnector,
DocumentSource.DROPBOX: DropboxConnector,
DocumentSource.SHAREPOINT: SharepointConnector,
DocumentSource.TEAMS: TeamsConnector,
DocumentSource.SALESFORCE: SalesforceConnector,
DocumentSource.DISCOURSE: DiscourseConnector,
DocumentSource.AXERO: AxeroConnector,
DocumentSource.CLICKUP: ClickupConnector,
DocumentSource.MEDIAWIKI: MediaWikiConnector,
DocumentSource.WIKIPEDIA: WikipediaConnector,
DocumentSource.ASANA: AsanaConnector,
DocumentSource.S3: BlobStorageConnector,
DocumentSource.R2: BlobStorageConnector,
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
DocumentSource.OCI_STORAGE: BlobStorageConnector,
DocumentSource.XENFORO: XenforoConnector,
DocumentSource.DISCORD: DiscordConnector,
DocumentSource.FRESHDESK: FreshdeskConnector,
DocumentSource.FIREFLIES: FirefliesConnector,
DocumentSource.EGNYTE: EgnyteConnector,
DocumentSource.AIRTABLE: AirtableConnector,
DocumentSource.HIGHSPOT: HighspotConnector,
DocumentSource.IMAP: ImapConnector,
DocumentSource.BITBUCKET: BitbucketConnector,
# just for integration tests
DocumentSource.MOCK_CONNECTOR: MockConnector,
}
connector_by_source = connector_map.get(source, {})
# Load the connector class using lazy loading
connector = _load_connector_class(source)
if isinstance(connector_by_source, dict):
if input_type is None:
# If not specified, default to most exhaustive update
connector = connector_by_source.get(InputType.LOAD_STATE)
else:
connector = connector_by_source.get(input_type)
else:
connector = connector_by_source
if connector is None:
raise ConnectorMissingException(f"Connector not found for source={source}")
# Validate connector supports the requested input_type
_validate_connector_supports_input_type(connector, input_type, source)
if any(
[
(
input_type == InputType.LOAD_STATE
and not issubclass(connector, LoadConnector)
),
(
input_type == InputType.POLL
# either poll or checkpoint works for this, in the future
# all connectors should be checkpoint connectors
and (
not issubclass(connector, PollConnector)
and not issubclass(connector, CheckpointedConnector)
)
),
(
input_type == InputType.EVENT
and not issubclass(connector, EventConnector)
),
]
):
raise ConnectorMissingException(
f"Connector for source={source} does not accept input_type={input_type}"
)
return connector

View File

@@ -219,12 +219,19 @@ def _get_batch_rate_limited(
def _get_userinfo(user: NamedUser) -> dict[str, str]:
def _safe_get(attr_name: str) -> str | None:
try:
return cast(str | None, getattr(user, attr_name))
except GithubException:
logger.debug(f"Error getting {attr_name} for user")
return None
return {
k: v
for k, v in {
"login": user.login,
"name": user.name,
"email": user.email,
"login": _safe_get("login"),
"name": _safe_get("name"),
"email": _safe_get("email"),
}.items()
if v is not None
}

View File

@@ -28,7 +28,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
@@ -232,7 +232,7 @@ def thread_to_document(
)
class GmailConnector(LoadConnector, PollConnector, SlimConnector):
class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
@@ -397,10 +397,10 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
if callback:
if callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_documents: Stop signal detected"
"retrieve_all_slim_docs_perm_sync: Stop signal detected"
)
callback.progress("retrieve_all_slim_documents", 1)
callback.progress("retrieve_all_slim_docs_perm_sync", 1)
except HttpError as e:
if _is_mail_service_disabled_error(e):
logger.warning(
@@ -431,7 +431,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
def retrieve_all_slim_documents(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -64,7 +64,7 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
@@ -153,7 +153,7 @@ class DriveIdStatus(Enum):
class GoogleDriveConnector(
SlimConnector, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
):
def __init__(
self,
@@ -1296,7 +1296,7 @@ class GoogleDriveConnector(
callback.progress("_extract_slim_docs_from_google_drive", 1)
yield slim_batch
def retrieve_all_slim_documents(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -18,7 +18,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
@@ -38,7 +38,7 @@ class HighspotSpot(BaseModel):
name: str
class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
class HighspotConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
"""
Connector for loading data from Highspot.
@@ -362,7 +362,7 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
description = item_details.get("description", "")
return title, description
def retrieve_all_slim_documents(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -1,14 +1,18 @@
import re
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import TypeVar
import requests
from hubspot import HubSpot # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.hubspot.rate_limit import HubSpotRateLimiter
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -25,6 +29,10 @@ HUBSPOT_API_URL = "https://api.hubapi.com/integrations/v1/me"
# Available HubSpot object types
AVAILABLE_OBJECT_TYPES = {"tickets", "companies", "deals", "contacts"}
HUBSPOT_PAGE_SIZE = 100
T = TypeVar("T")
logger = setup_logger()
@@ -38,6 +46,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
self._access_token = access_token
self._portal_id: str | None = None
self._rate_limiter = HubSpotRateLimiter()
# Set object types to fetch, default to all available types
if object_types is None:
@@ -77,6 +86,37 @@ class HubSpotConnector(LoadConnector, PollConnector):
"""Set the portal ID."""
self._portal_id = value
def _call_hubspot(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
    """Invoke a HubSpot SDK callable through the shared rate limiter.

    All HubSpot API calls are routed through ``self._rate_limiter.call`` so
    the connector-wide throttle and retry-on-429 policy applies uniformly.
    """
    return self._rate_limiter.call(func, *args, **kwargs)
def _paginated_results(
    self,
    fetch_page: Callable[..., Any],
    **kwargs: Any,
) -> Generator[Any, None, None]:
    """Yield every result from a HubSpot paged endpoint.

    Repeatedly calls *fetch_page* (through the rate limiter), following the
    ``paging.next.after`` cursor until the API reports no further page.
    A ``limit`` of HUBSPOT_PAGE_SIZE is applied unless the caller overrides it.
    """
    request_kwargs = dict(kwargs)
    request_kwargs.setdefault("limit", HUBSPOT_PAGE_SIZE)

    cursor: str | None = None
    while True:
        call_kwargs = dict(request_kwargs)
        if cursor is not None:
            call_kwargs["after"] = cursor

        page = self._call_hubspot(fetch_page, **call_kwargs)
        yield from getattr(page, "results", [])

        paging = getattr(page, "paging", None)
        next_info = getattr(paging, "next", None) if paging else None
        cursor = getattr(next_info, "after", None) if next_info is not None else None
        if cursor is None:
            break
def _clean_html_content(self, html_content: str) -> str:
"""Clean HTML content and extract raw text"""
if not html_content:
@@ -150,78 +190,82 @@ class HubSpotConnector(LoadConnector, PollConnector):
) -> list[dict[str, Any]]:
"""Get associated objects for a given object"""
try:
associations = api_client.crm.associations.v4.basic_api.get_page(
associations_iter = self._paginated_results(
api_client.crm.associations.v4.basic_api.get_page,
object_type=from_object_type,
object_id=object_id,
to_object_type=to_object_type,
)
associated_objects = []
if associations.results:
object_ids = [assoc.to_object_id for assoc in associations.results]
object_ids = [assoc.to_object_id for assoc in associations_iter]
# Batch get the associated objects
if to_object_type == "contacts":
for obj_id in object_ids:
try:
obj = api_client.crm.contacts.basic_api.get_by_id(
contact_id=obj_id,
properties=[
"firstname",
"lastname",
"email",
"company",
"jobtitle",
],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch contact {obj_id}: {e}")
associated_objects: list[dict[str, Any]] = []
elif to_object_type == "companies":
for obj_id in object_ids:
try:
obj = api_client.crm.companies.basic_api.get_by_id(
company_id=obj_id,
properties=[
"name",
"domain",
"industry",
"city",
"state",
],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch company {obj_id}: {e}")
if to_object_type == "contacts":
for obj_id in object_ids:
try:
obj = self._call_hubspot(
api_client.crm.contacts.basic_api.get_by_id,
contact_id=obj_id,
properties=[
"firstname",
"lastname",
"email",
"company",
"jobtitle",
],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch contact {obj_id}: {e}")
elif to_object_type == "deals":
for obj_id in object_ids:
try:
obj = api_client.crm.deals.basic_api.get_by_id(
deal_id=obj_id,
properties=[
"dealname",
"amount",
"dealstage",
"closedate",
"pipeline",
],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch deal {obj_id}: {e}")
elif to_object_type == "companies":
for obj_id in object_ids:
try:
obj = self._call_hubspot(
api_client.crm.companies.basic_api.get_by_id,
company_id=obj_id,
properties=[
"name",
"domain",
"industry",
"city",
"state",
],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch company {obj_id}: {e}")
elif to_object_type == "tickets":
for obj_id in object_ids:
try:
obj = api_client.crm.tickets.basic_api.get_by_id(
ticket_id=obj_id,
properties=["subject", "content", "hs_ticket_priority"],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch ticket {obj_id}: {e}")
elif to_object_type == "deals":
for obj_id in object_ids:
try:
obj = self._call_hubspot(
api_client.crm.deals.basic_api.get_by_id,
deal_id=obj_id,
properties=[
"dealname",
"amount",
"dealstage",
"closedate",
"pipeline",
],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch deal {obj_id}: {e}")
elif to_object_type == "tickets":
for obj_id in object_ids:
try:
obj = self._call_hubspot(
api_client.crm.tickets.basic_api.get_by_id,
ticket_id=obj_id,
properties=["subject", "content", "hs_ticket_priority"],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch ticket {obj_id}: {e}")
return associated_objects
@@ -239,33 +283,33 @@ class HubSpotConnector(LoadConnector, PollConnector):
) -> list[dict[str, Any]]:
"""Get notes associated with a given object"""
try:
# Get associations to notes (engagement type)
associations = api_client.crm.associations.v4.basic_api.get_page(
associations_iter = self._paginated_results(
api_client.crm.associations.v4.basic_api.get_page,
object_type=object_type,
object_id=object_id,
to_object_type="notes",
)
associated_notes = []
if associations.results:
note_ids = [assoc.to_object_id for assoc in associations.results]
note_ids = [assoc.to_object_id for assoc in associations_iter]
# Batch get the associated notes
for note_id in note_ids:
try:
# Notes are engagements in HubSpot, use the engagements API
note = api_client.crm.objects.notes.basic_api.get_by_id(
note_id=note_id,
properties=[
"hs_note_body",
"hs_timestamp",
"hs_created_by",
"hubspot_owner_id",
],
)
associated_notes.append(note.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch note {note_id}: {e}")
associated_notes = []
for note_id in note_ids:
try:
# Notes are engagements in HubSpot, use the engagements API
note = self._call_hubspot(
api_client.crm.objects.notes.basic_api.get_by_id,
note_id=note_id,
properties=[
"hs_note_body",
"hs_timestamp",
"hs_created_by",
"hubspot_owner_id",
],
)
associated_notes.append(note.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch note {note_id}: {e}")
return associated_notes
@@ -358,7 +402,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
api_client = HubSpot(access_token=self.access_token)
all_tickets = api_client.crm.tickets.get_all(
tickets_iter = self._paginated_results(
api_client.crm.tickets.basic_api.get_page,
properties=[
"subject",
"content",
@@ -371,7 +417,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
doc_batch: list[Document] = []
for ticket in all_tickets:
for ticket in tickets_iter:
updated_at = ticket.updated_at.replace(tzinfo=None)
if start is not None and updated_at < start.replace(tzinfo=None):
continue
@@ -459,7 +505,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
api_client = HubSpot(access_token=self.access_token)
all_companies = api_client.crm.companies.get_all(
companies_iter = self._paginated_results(
api_client.crm.companies.basic_api.get_page,
properties=[
"name",
"domain",
@@ -475,7 +523,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
doc_batch: list[Document] = []
for company in all_companies:
for company in companies_iter:
updated_at = company.updated_at.replace(tzinfo=None)
if start is not None and updated_at < start.replace(tzinfo=None):
continue
@@ -582,7 +630,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
api_client = HubSpot(access_token=self.access_token)
all_deals = api_client.crm.deals.get_all(
deals_iter = self._paginated_results(
api_client.crm.deals.basic_api.get_page,
properties=[
"dealname",
"amount",
@@ -598,7 +648,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
doc_batch: list[Document] = []
for deal in all_deals:
for deal in deals_iter:
updated_at = deal.updated_at.replace(tzinfo=None)
if start is not None and updated_at < start.replace(tzinfo=None):
continue
@@ -703,7 +753,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
api_client = HubSpot(access_token=self.access_token)
all_contacts = api_client.crm.contacts.get_all(
contacts_iter = self._paginated_results(
api_client.crm.contacts.basic_api.get_page,
properties=[
"firstname",
"lastname",
@@ -721,7 +773,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
doc_batch: list[Document] = []
for contact in all_contacts:
for contact in contacts_iter:
updated_at = contact.updated_at.replace(tzinfo=None)
if start is not None and updated_at < start.replace(tzinfo=None):
continue

View File

@@ -0,0 +1,145 @@
from __future__ import annotations
import time
from collections.abc import Callable
from typing import Any
from typing import TypeVar
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
RateLimitTriedTooManyTimesError,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
T = TypeVar("T")
# HubSpot exposes a ten second rolling window (x-hubspot-ratelimit-interval-milliseconds)
# with a maximum of 190 requests, and a per-second limit of 19 requests.
_HUBSPOT_TEN_SECOND_LIMIT = 190
_HUBSPOT_TEN_SECOND_PERIOD = 10  # seconds
_HUBSPOT_SECONDLY_LIMIT = 19
_HUBSPOT_SECONDLY_PERIOD = 1  # second
# Fallback sleep when no rate-limit header on the response can be parsed.
_DEFAULT_SLEEP_SECONDS = 10
# Extra slack added on top of any computed wait to avoid immediately re-hitting the limit.
_SLEEP_PADDING_SECONDS = 1.0
# Maximum consecutive rate-limited attempts before giving up on a call.
_MAX_RATE_LIMIT_RETRIES = 5
def _extract_header(headers: Any, key: str) -> str | None:
if headers is None:
return None
getter = getattr(headers, "get", None)
if callable(getter):
value = getter(key)
if value is not None:
return value
if isinstance(headers, dict):
value = headers.get(key)
if value is not None:
return value
return None
def is_rate_limit_error(exception: Exception) -> bool:
    """Return True when *exception* looks like a HubSpot rate-limit failure.

    Checks, in order: an HTTP 429 ``status`` attribute, HubSpot's
    quota-remaining response headers reading "0", and finally well-known
    marker substrings in the exception text.
    """
    if getattr(exception, "status", None) == 429:
        return True

    headers = getattr(exception, "headers", None)
    if headers is not None:
        exhausted_headers = (
            "x-hubspot-ratelimit-remaining",
            "x-hubspot-ratelimit-secondly-remaining",
        )
        if any(_extract_header(headers, name) == "0" for name in exhausted_headers):
            return True

    text = str(exception)
    return "RATE_LIMIT" in text or "Too Many Requests" in text
def get_rate_limit_retry_delay_seconds(exception: Exception) -> float:
    """Derive how long to sleep before retrying a rate-limited HubSpot call.

    Preference order: the ``Retry-After`` header, then HubSpot's rolling
    window interval header, then the per-second quota header; falls back to
    a conservative default when nothing parseable is present. Every result
    includes a small padding to avoid re-hitting the limit immediately.
    """
    headers = getattr(exception, "headers", None)

    raw_retry_after = _extract_header(headers, "Retry-After")
    if raw_retry_after:
        try:
            return float(raw_retry_after) + _SLEEP_PADDING_SECONDS
        except ValueError:
            logger.debug(
                "Failed to parse Retry-After header '%s' as float", raw_retry_after
            )

    raw_interval_ms = _extract_header(
        headers, "x-hubspot-ratelimit-interval-milliseconds"
    )
    if raw_interval_ms:
        try:
            return float(raw_interval_ms) / 1000.0 + _SLEEP_PADDING_SECONDS
        except ValueError:
            logger.debug(
                "Failed to parse x-hubspot-ratelimit-interval-milliseconds '%s' as float",
                raw_interval_ms,
            )

    raw_secondly = _extract_header(headers, "x-hubspot-ratelimit-secondly")
    if raw_secondly:
        try:
            # Never divide by less than one call per second.
            calls_per_second = max(float(raw_secondly), 1.0)
            return (1.0 / calls_per_second) + _SLEEP_PADDING_SECONDS
        except ValueError:
            logger.debug(
                "Failed to parse x-hubspot-ratelimit-secondly '%s' as float",
                raw_secondly,
            )

    return _DEFAULT_SLEEP_SECONDS + _SLEEP_PADDING_SECONDS
class HubSpotRateLimiter:
    """Client-side throttle plus retry/backoff wrapper for HubSpot SDK calls.

    Proactively paces calls to stay inside HubSpot's per-second and
    ten-second budgets, and reactively sleeps + retries (up to
    ``max_retries`` times) when a call still fails with a rate-limit error.
    """

    def __init__(
        self,
        *,
        ten_second_limit: int = _HUBSPOT_TEN_SECOND_LIMIT,
        ten_second_period: int = _HUBSPOT_TEN_SECOND_PERIOD,
        secondly_limit: int = _HUBSPOT_SECONDLY_LIMIT,
        secondly_period: int = _HUBSPOT_SECONDLY_PERIOD,
        max_retries: int = _MAX_RATE_LIMIT_RETRIES,
    ) -> None:
        # How many rate-limited attempts `call` tolerates before raising.
        self._max_retries = max_retries

        # Stack both throttle windows on a single executor: the inner
        # decorator enforces the ten-second budget, the outer the
        # per-second budget, so every call must clear both.
        @rate_limit_builder(max_calls=secondly_limit, period=secondly_period)
        @rate_limit_builder(max_calls=ten_second_limit, period=ten_second_period)
        def _execute(callable_: Callable[[], T]) -> T:
            return callable_()

        self._execute = _execute

    def call(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
        """Run ``func(*args, **kwargs)`` under the throttle, retrying on rate limits.

        Non-rate-limit exceptions propagate immediately. Rate-limit errors
        trigger a sleep (duration derived from the response headers) and a
        retry; after ``self._max_retries`` failed attempts a
        RateLimitTriedTooManyTimesError is raised from the last error.
        """
        attempts = 0
        while True:
            try:
                return self._execute(lambda: func(*args, **kwargs))
            except Exception as exc:  # pylint: disable=broad-except
                if not is_rate_limit_error(exc):
                    raise
                attempts += 1
                if attempts > self._max_retries:
                    raise RateLimitTriedTooManyTimesError(
                        "Exceeded configured HubSpot rate limit retries"
                    ) from exc
                wait_time = get_rate_limit_retry_delay_seconds(exc)
                logger.notice(
                    "HubSpot rate limit reached. Sleeping %.2f seconds before retrying.",
                    wait_time,
                )
                time.sleep(wait_time)

View File

@@ -97,11 +97,20 @@ class PollConnector(BaseConnector):
raise NotImplementedError
# Slim connectors can retrieve just the ids and
# permission syncing information for connected documents
# Slim connectors retrieve just the ids of documents
class SlimConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_slim_documents(
def retrieve_all_slim_docs(
self,
) -> GenerateSlimDocumentOutput:
raise NotImplementedError
# Slim connectors retrieve both the ids AND
# permission syncing information for connected documents
class SlimConnectorWithPermSync(BaseConnector):
@abc.abstractmethod
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -25,11 +25,11 @@ from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.jira.access import get_project_permissions
from onyx.connectors.jira.utils import best_effort_basic_expert_info
from onyx.connectors.jira.utils import best_effort_get_field_from_issue
@@ -247,7 +247,7 @@ def _perform_jql_search_v2(
def process_jira_issue(
jira_client: JIRA,
jira_base_url: str,
issue: Issue,
comment_email_blacklist: tuple[str, ...] = (),
labels_to_skip: set[str] | None = None,
@@ -281,7 +281,7 @@ def process_jira_issue(
)
return None
page_url = build_jira_url(jira_client, issue.key)
page_url = build_jira_url(jira_base_url, issue.key)
metadata_dict: dict[str, str | list[str]] = {}
people = set()
@@ -359,7 +359,10 @@ class JiraConnectorCheckpoint(ConnectorCheckpoint):
offset: int | None = None
class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnector):
class JiraConnector(
CheckpointedConnectorWithPermSync[JiraConnectorCheckpoint],
SlimConnectorWithPermSync,
):
def __init__(
self,
jira_base_url: str,
@@ -372,15 +375,23 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
# Custom JQL query to filter Jira issues
jql_query: str | None = None,
scoped_token: bool = False,
) -> None:
self.batch_size = batch_size
# dealing with scoped tokens is a bit tricky because we need to hit api.atlassian.net
# when making jira requests but still want correct links to issues in the UI.
# So, the user's base url is stored here, but converted to a scoped url when passed
# to the jira client.
self.jira_base = jira_base_url.rstrip("/") # Remove trailing slash if present
self.jira_project = project_key
self._comment_email_blacklist = comment_email_blacklist or []
self.labels_to_skip = set(labels_to_skip)
self.jql_query = jql_query
self.scoped_token = scoped_token
self._jira_client: JIRA | None = None
# Cache project permissions to avoid fetching them repeatedly across runs
self._project_permissions_cache: dict[str, Any] = {}
@property
def comment_email_blacklist(self) -> tuple:
@@ -399,10 +410,26 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
return ""
return f'"{self.jira_project}"'
def _get_project_permissions(self, project_key: str) -> Any:
    """Return the external access permissions for *project_key*.

    Results are memoized in ``self._project_permissions_cache`` so repeated
    lookups for the same project do not refetch permissions from Jira.
    """
    try:
        return self._project_permissions_cache[project_key]
    except KeyError:
        permissions = get_project_permissions(
            jira_client=self.jira_client, jira_project=project_key
        )
        self._project_permissions_cache[project_key] = permissions
        return permissions
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._jira_client = build_jira_client(
credentials=credentials,
jira_base=self.jira_base,
scoped_token=self.scoped_token,
)
return None
@@ -442,15 +469,37 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
) -> CheckpointOutput[JiraConnectorCheckpoint]:
jql = self._get_jql_query(start, end)
try:
return self._load_from_checkpoint(jql, checkpoint)
return self._load_from_checkpoint(
jql, checkpoint, include_permissions=False
)
except Exception as e:
if is_atlassian_date_error(e):
jql = self._get_jql_query(start - ONE_HOUR, end)
return self._load_from_checkpoint(jql, checkpoint)
return self._load_from_checkpoint(
jql, checkpoint, include_permissions=False
)
raise e
def load_from_checkpoint_with_perm_sync(
    self,
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
    checkpoint: JiraConnectorCheckpoint,
) -> CheckpointOutput[JiraConnectorCheckpoint]:
    """Load documents from checkpoint with permission information included."""
    jql = self._get_jql_query(start, end)
    try:
        return self._load_from_checkpoint(jql, checkpoint, include_permissions=True)
    except Exception as e:
        if not is_atlassian_date_error(e):
            raise e
        # Atlassian occasionally rejects the date bound; widen the window
        # by an hour and retry once.
        fallback_jql = self._get_jql_query(start - ONE_HOUR, end)
        return self._load_from_checkpoint(
            fallback_jql, checkpoint, include_permissions=True
        )
def _load_from_checkpoint(
self, jql: str, checkpoint: JiraConnectorCheckpoint
self, jql: str, checkpoint: JiraConnectorCheckpoint, include_permissions: bool
) -> CheckpointOutput[JiraConnectorCheckpoint]:
# Get the current offset from checkpoint or start at 0
starting_offset = checkpoint.offset or 0
@@ -472,18 +521,25 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
issue_key = issue.key
try:
if document := process_jira_issue(
jira_client=self.jira_client,
jira_base_url=self.jira_base,
issue=issue,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
):
# Add permission information to the document if requested
if include_permissions:
project_key = get_jira_project_key_from_issue(issue=issue)
if project_key:
document.external_access = self._get_project_permissions(
project_key
)
yield document
except Exception as e:
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=issue_key,
document_link=build_jira_url(self.jira_client, issue_key),
document_link=build_jira_url(self.jira_base, issue_key),
),
failure_message=f"Failed to process Jira issue: {str(e)}",
exception=e,
@@ -515,7 +571,7 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
# if we didn't retrieve a full batch, we're done
checkpoint.has_more = current_offset - starting_offset == page_size
def retrieve_all_slim_documents(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
@@ -534,6 +590,7 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
prev_offset = 0
current_offset = 0
slim_doc_batch = []
while checkpoint.has_more:
for issue in _perform_jql_search(
jira_client=self.jira_client,
@@ -550,13 +607,12 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
continue
issue_key = best_effort_get_field_from_issue(issue, _FIELD_KEY)
id = build_jira_url(self.jira_client, issue_key)
id = build_jira_url(self.jira_base, issue_key)
slim_doc_batch.append(
SlimDocument(
id=id,
external_access=get_project_permissions(
jira_client=self.jira_client, jira_project=project_key
),
external_access=self._get_project_permissions(project_key),
)
)
current_offset += 1
@@ -701,7 +757,7 @@ if __name__ == "__main__":
start = 0
end = datetime.now().timestamp()
for slim_doc in connector.retrieve_all_slim_documents(
for slim_doc in connector.retrieve_all_slim_docs_perm_sync(
start=start,
end=end,
):

View File

@@ -10,6 +10,7 @@ from jira.resources import CustomFieldOption
from jira.resources import Issue
from jira.resources import User
from onyx.connectors.cross_connector_utils.miscellaneous_utils import scoped_url
from onyx.connectors.models import BasicExpertInfo
from onyx.utils.logger import setup_logger
@@ -74,11 +75,18 @@ def extract_text_from_adf(adf: dict | None) -> str:
return " ".join(texts)
def build_jira_url(jira_client: JIRA, issue_key: str) -> str:
return f"{jira_client.client_info()}/browse/{issue_key}"
def build_jira_url(jira_base_url: str, issue_key: str) -> str:
    """Return the browser-facing URL for a Jira issue under the given base URL."""
    return "/".join((jira_base_url, "browse", issue_key))
def build_jira_client(credentials: dict[str, Any], jira_base: str) -> JIRA:
def build_jira_client(
credentials: dict[str, Any], jira_base: str, scoped_token: bool = False
) -> JIRA:
jira_base = scoped_url(jira_base, "jira") if scoped_token else jira_base
api_token = credentials["jira_api_token"]
# if user provide an email we assume it's cloud
if "jira_user_email" in credentials:

View File

@@ -0,0 +1,208 @@
"""Registry mapping for connector classes."""
from pydantic import BaseModel
from onyx.configs.constants import DocumentSource
class ConnectorMapping(BaseModel):
    """Location of a connector class, stored as strings for lazy import.

    Keeping the dotted module path and class name (rather than the class
    object) means a connector's module is only imported when that connector
    is actually needed.
    """

    # Dotted import path of the module defining the connector class.
    module_path: str
    # Name of the connector class within that module.
    class_name: str
# Mapping of DocumentSource to connector details for lazy loading
CONNECTOR_CLASS_MAP = {
DocumentSource.WEB: ConnectorMapping(
module_path="onyx.connectors.web.connector",
class_name="WebConnector",
),
DocumentSource.FILE: ConnectorMapping(
module_path="onyx.connectors.file.connector",
class_name="LocalFileConnector",
),
DocumentSource.SLACK: ConnectorMapping(
module_path="onyx.connectors.slack.connector",
class_name="SlackConnector",
),
DocumentSource.GITHUB: ConnectorMapping(
module_path="onyx.connectors.github.connector",
class_name="GithubConnector",
),
DocumentSource.GMAIL: ConnectorMapping(
module_path="onyx.connectors.gmail.connector",
class_name="GmailConnector",
),
DocumentSource.GITLAB: ConnectorMapping(
module_path="onyx.connectors.gitlab.connector",
class_name="GitlabConnector",
),
DocumentSource.GITBOOK: ConnectorMapping(
module_path="onyx.connectors.gitbook.connector",
class_name="GitbookConnector",
),
DocumentSource.GOOGLE_DRIVE: ConnectorMapping(
module_path="onyx.connectors.google_drive.connector",
class_name="GoogleDriveConnector",
),
DocumentSource.BOOKSTACK: ConnectorMapping(
module_path="onyx.connectors.bookstack.connector",
class_name="BookstackConnector",
),
DocumentSource.OUTLINE: ConnectorMapping(
module_path="onyx.connectors.outline.connector",
class_name="OutlineConnector",
),
DocumentSource.CONFLUENCE: ConnectorMapping(
module_path="onyx.connectors.confluence.connector",
class_name="ConfluenceConnector",
),
DocumentSource.JIRA: ConnectorMapping(
module_path="onyx.connectors.jira.connector",
class_name="JiraConnector",
),
DocumentSource.PRODUCTBOARD: ConnectorMapping(
module_path="onyx.connectors.productboard.connector",
class_name="ProductboardConnector",
),
DocumentSource.SLAB: ConnectorMapping(
module_path="onyx.connectors.slab.connector",
class_name="SlabConnector",
),
DocumentSource.NOTION: ConnectorMapping(
module_path="onyx.connectors.notion.connector",
class_name="NotionConnector",
),
DocumentSource.ZULIP: ConnectorMapping(
module_path="onyx.connectors.zulip.connector",
class_name="ZulipConnector",
),
DocumentSource.GURU: ConnectorMapping(
module_path="onyx.connectors.guru.connector",
class_name="GuruConnector",
),
DocumentSource.LINEAR: ConnectorMapping(
module_path="onyx.connectors.linear.connector",
class_name="LinearConnector",
),
DocumentSource.HUBSPOT: ConnectorMapping(
module_path="onyx.connectors.hubspot.connector",
class_name="HubSpotConnector",
),
DocumentSource.DOCUMENT360: ConnectorMapping(
module_path="onyx.connectors.document360.connector",
class_name="Document360Connector",
),
DocumentSource.GONG: ConnectorMapping(
module_path="onyx.connectors.gong.connector",
class_name="GongConnector",
),
DocumentSource.GOOGLE_SITES: ConnectorMapping(
module_path="onyx.connectors.google_site.connector",
class_name="GoogleSitesConnector",
),
DocumentSource.ZENDESK: ConnectorMapping(
module_path="onyx.connectors.zendesk.connector",
class_name="ZendeskConnector",
),
DocumentSource.LOOPIO: ConnectorMapping(
module_path="onyx.connectors.loopio.connector",
class_name="LoopioConnector",
),
DocumentSource.DROPBOX: ConnectorMapping(
module_path="onyx.connectors.dropbox.connector",
class_name="DropboxConnector",
),
DocumentSource.SHAREPOINT: ConnectorMapping(
module_path="onyx.connectors.sharepoint.connector",
class_name="SharepointConnector",
),
DocumentSource.TEAMS: ConnectorMapping(
module_path="onyx.connectors.teams.connector",
class_name="TeamsConnector",
),
DocumentSource.SALESFORCE: ConnectorMapping(
module_path="onyx.connectors.salesforce.connector",
class_name="SalesforceConnector",
),
DocumentSource.DISCOURSE: ConnectorMapping(
module_path="onyx.connectors.discourse.connector",
class_name="DiscourseConnector",
),
DocumentSource.AXERO: ConnectorMapping(
module_path="onyx.connectors.axero.connector",
class_name="AxeroConnector",
),
DocumentSource.CLICKUP: ConnectorMapping(
module_path="onyx.connectors.clickup.connector",
class_name="ClickupConnector",
),
DocumentSource.MEDIAWIKI: ConnectorMapping(
module_path="onyx.connectors.mediawiki.wiki",
class_name="MediaWikiConnector",
),
DocumentSource.WIKIPEDIA: ConnectorMapping(
module_path="onyx.connectors.wikipedia.connector",
class_name="WikipediaConnector",
),
DocumentSource.ASANA: ConnectorMapping(
module_path="onyx.connectors.asana.connector",
class_name="AsanaConnector",
),
DocumentSource.S3: ConnectorMapping(
module_path="onyx.connectors.blob.connector",
class_name="BlobStorageConnector",
),
DocumentSource.R2: ConnectorMapping(
module_path="onyx.connectors.blob.connector",
class_name="BlobStorageConnector",
),
DocumentSource.GOOGLE_CLOUD_STORAGE: ConnectorMapping(
module_path="onyx.connectors.blob.connector",
class_name="BlobStorageConnector",
),
DocumentSource.OCI_STORAGE: ConnectorMapping(
module_path="onyx.connectors.blob.connector",
class_name="BlobStorageConnector",
),
DocumentSource.XENFORO: ConnectorMapping(
module_path="onyx.connectors.xenforo.connector",
class_name="XenforoConnector",
),
DocumentSource.DISCORD: ConnectorMapping(
module_path="onyx.connectors.discord.connector",
class_name="DiscordConnector",
),
DocumentSource.FRESHDESK: ConnectorMapping(
module_path="onyx.connectors.freshdesk.connector",
class_name="FreshdeskConnector",
),
DocumentSource.FIREFLIES: ConnectorMapping(
module_path="onyx.connectors.fireflies.connector",
class_name="FirefliesConnector",
),
DocumentSource.EGNYTE: ConnectorMapping(
module_path="onyx.connectors.egnyte.connector",
class_name="EgnyteConnector",
),
DocumentSource.AIRTABLE: ConnectorMapping(
module_path="onyx.connectors.airtable.airtable_connector",
class_name="AirtableConnector",
),
DocumentSource.HIGHSPOT: ConnectorMapping(
module_path="onyx.connectors.highspot.connector",
class_name="HighspotConnector",
),
DocumentSource.IMAP: ConnectorMapping(
module_path="onyx.connectors.imap.connector",
class_name="ImapConnector",
),
DocumentSource.BITBUCKET: ConnectorMapping(
module_path="onyx.connectors.bitbucket.connector",
class_name="BitbucketConnector",
),
# just for integration tests
DocumentSource.MOCK_CONNECTOR: ConnectorMapping(
module_path="onyx.connectors.mock_connector.connector",
class_name="MockConnector",
),
}

View File

@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorMissingCredentialError
@@ -151,7 +151,7 @@ def _validate_custom_query_config(config: dict[str, Any]) -> None:
)
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
"""Approach outline
Goal
@@ -1119,7 +1119,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
with tempfile.TemporaryDirectory() as temp_dir:
return self._delta_sync(temp_dir, start, end)
def retrieve_all_slim_documents(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -41,7 +41,7 @@ from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import IndexingHeartbeatInterface
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
@@ -73,7 +73,8 @@ class SiteDescriptor(BaseModel):
"""Data class for storing SharePoint site information.
Args:
url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests)
url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests
or https://danswerai.sharepoint.com/teams/team-name)
drive_name: The name of the drive to access (e.g. "Shared Documents", "Other Library")
If None, all drives will be accessed.
folder_path: The folder path within the drive to access (e.g. "test/nested with spaces")
@@ -672,7 +673,7 @@ def _convert_sitepage_to_slim_document(
class SharepointConnector(
SlimConnector,
SlimConnectorWithPermSync,
CheckpointedConnectorWithPermSync[SharepointConnectorCheckpoint],
):
def __init__(
@@ -703,9 +704,11 @@ class SharepointConnector(
# Ensure sites are sharepoint urls
for site_url in self.sites:
if not site_url.startswith("https://") or "/sites/" not in site_url:
if not site_url.startswith("https://") or not (
"/sites/" in site_url or "/teams/" in site_url
):
raise ConnectorValidationError(
"Site URLs must be full Sharepoint URLs (e.g. https://your-tenant.sharepoint.com/sites/your-site)"
"Site URLs must be full Sharepoint URLs (e.g. https://your-tenant.sharepoint.com/sites/your-site or https://your-tenant.sharepoint.com/teams/your-team)"
)
@property
@@ -720,10 +723,17 @@ class SharepointConnector(
site_data_list = []
for url in site_urls:
parts = url.strip().split("/")
site_type_index = None
if "sites" in parts:
sites_index = parts.index("sites")
site_url = "/".join(parts[: sites_index + 2])
remaining_parts = parts[sites_index + 2 :]
site_type_index = parts.index("sites")
elif "teams" in parts:
site_type_index = parts.index("teams")
if site_type_index is not None:
# Extract the base site URL (up to and including the site/team name)
site_url = "/".join(parts[: site_type_index + 2])
remaining_parts = parts[site_type_index + 2 :]
# Extract drive name and folder path
if remaining_parts:
@@ -745,7 +755,9 @@ class SharepointConnector(
)
)
else:
logger.warning(f"Site URL '{url}' is not a valid Sharepoint URL")
logger.warning(
f"Site URL '{url}' is not a valid Sharepoint URL (must contain /sites/ or /teams/)"
)
return site_data_list
def _get_drive_items_for_drive_name(
@@ -1597,7 +1609,7 @@ class SharepointConnector(
) -> SharepointConnectorCheckpoint:
return SharepointConnectorCheckpoint.model_validate_json(checkpoint_json)
def retrieve_all_slim_documents(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
@@ -164,7 +164,7 @@ def get_slab_url_from_title_id(base_url: str, title: str, page_id: str) -> str:
return urljoin(urljoin(base_url, "posts/"), url_id)
class SlabConnector(LoadConnector, PollConnector, SlimConnector):
class SlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
def __init__(
self,
base_url: str,
@@ -239,7 +239,7 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
time_filter=lambda t: start_time <= t <= end_time
)
def retrieve_all_slim_documents(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -42,7 +42,7 @@ from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
@@ -581,7 +581,7 @@ def _process_message(
class SlackConnector(
SlimConnector,
SlimConnectorWithPermSync,
CredentialsConnector,
CheckpointedConnectorWithPermSync[SlackCheckpoint],
):
@@ -732,7 +732,7 @@ class SlackConnector(
self.text_cleaner = SlackTextCleaner(client=self.client)
self.credentials_provider = credentials_provider
def retrieve_all_slim_documents(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -22,7 +22,7 @@ from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
@@ -51,7 +51,7 @@ class TeamsCheckpoint(ConnectorCheckpoint):
class TeamsConnector(
CheckpointedConnector[TeamsCheckpoint],
SlimConnector,
SlimConnectorWithPermSync,
):
MAX_WORKERS = 10
AUTHORITY_URL_PREFIX = "https://login.microsoftonline.com/"
@@ -228,9 +228,9 @@ class TeamsConnector(
has_more=bool(todos),
)
# impls for SlimConnector
# impls for SlimConnectorWithPermSync
def retrieve_all_slim_documents(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
@@ -572,7 +572,7 @@ if __name__ == "__main__":
)
teams_connector.validate_connector_settings()
for slim_doc in teams_connector.retrieve_all_slim_documents():
for slim_doc in teams_connector.retrieve_all_slim_docs_perm_sync():
...
for doc in load_everything_from_checkpoint_connector(

View File

@@ -219,6 +219,25 @@ def is_valid_url(url: str) -> bool:
return False
def _same_site(base_url: str, candidate_url: str) -> bool:
base, candidate = urlparse(base_url), urlparse(candidate_url)
base_netloc = base.netloc.lower().removeprefix("www.")
candidate_netloc = candidate.netloc.lower().removeprefix("www.")
if base_netloc != candidate_netloc:
return False
base_path = (base.path or "/").rstrip("/")
if base_path in ("", "/"):
return True
candidate_path = candidate.path or "/"
if candidate_path == base_path:
return True
boundary = f"{base_path}/"
return candidate_path.startswith(boundary)
def get_internal_links(
base_url: str, url: str, soup: BeautifulSoup, should_ignore_pound: bool = True
) -> set[str]:
@@ -239,7 +258,7 @@ def get_internal_links(
# Relative path handling
href = urljoin(url, href)
if urlparse(href).netloc == urlparse(url).netloc and base_url in href:
if _same_site(base_url, href):
internal_links.add(href)
return internal_links

View File

@@ -26,7 +26,7 @@ from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorFailure
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import Document
@@ -376,7 +376,7 @@ class ZendeskConnectorCheckpoint(ConnectorCheckpoint):
class ZendeskConnector(
SlimConnector, CheckpointedConnector[ZendeskConnectorCheckpoint]
SlimConnectorWithPermSync, CheckpointedConnector[ZendeskConnectorCheckpoint]
):
def __init__(
self,
@@ -565,7 +565,7 @@ class ZendeskConnector(
)
return checkpoint
def retrieve_all_slim_documents(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -2,7 +2,6 @@ import string
from collections.abc import Callable
from uuid import UUID
import nltk # type:ignore
from sqlalchemy.orm import Session
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
@@ -61,6 +60,8 @@ def _dedupe_chunks(
def download_nltk_data() -> None:
import nltk # type: ignore[import-untyped]
resources = {
"stopwords": "corpora/stopwords",
# "wordnet": "corpora/wordnet", # Not in use

View File

@@ -2,8 +2,6 @@ import string
from collections.abc import Sequence
from typing import TypeVar
from nltk.corpus import stopwords # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
from sqlalchemy.orm import Session
from onyx.chat.models import SectionRelevancePiece
@@ -119,6 +117,9 @@ def inference_section_from_chunks(
def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
from nltk.corpus import stopwords # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
try:
# Re-tokenize using the NLTK tokenizer for better matching
query = " ".join(keywords)

View File

@@ -112,6 +112,7 @@ def upsert_llm_provider(
name=model_configuration.name,
is_visible=model_configuration.is_visible,
max_input_tokens=model_configuration.max_input_tokens,
supports_image_input=model_configuration.supports_image_input,
)
.on_conflict_do_nothing()
)

View File

@@ -2353,6 +2353,8 @@ class ModelConfiguration(Base):
# - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
llm_provider: Mapped["LLMProvider"] = relationship(
"LLMProvider",
back_populates="model_configurations",

View File

@@ -1,15 +1,52 @@
"""Factory for creating federated connector instances."""
import importlib
from typing import Any
from typing import Type
from onyx.configs.constants import FederatedConnectorSource
from onyx.federated_connectors.interfaces import FederatedConnector
from onyx.federated_connectors.slack.federated_connector import SlackFederatedConnector
from onyx.federated_connectors.registry import FEDERATED_CONNECTOR_CLASS_MAP
from onyx.utils.logger import setup_logger
logger = setup_logger()
class FederatedConnectorMissingException(Exception):
    """Raised when no federated connector class can be resolved for a
    source — either the source has no registry entry, or the registered
    module/class fails to import."""

    pass


# Cache of already-imported federated connector classes, keyed by source,
# so each connector module is imported at most once per process.
_federated_connector_cache: dict[FederatedConnectorSource, Type[FederatedConnector]] = (
    {}
)
def _load_federated_connector_class(
    source: FederatedConnectorSource,
) -> Type[FederatedConnector]:
    """Dynamically load and cache a federated connector class.

    Looks up the module path / class name registered for ``source`` in
    FEDERATED_CONNECTOR_CLASS_MAP, imports the module on first use, and
    memoizes the class in ``_federated_connector_cache`` so later lookups
    skip the import machinery.

    Raises:
        FederatedConnectorMissingException: if the source has no registry
            entry, or the registered module/class cannot be imported.
    """
    if source in _federated_connector_cache:
        return _federated_connector_cache[source]

    if source not in FEDERATED_CONNECTOR_CLASS_MAP:
        raise FederatedConnectorMissingException(
            f"Federated connector not found for source={source}"
        )

    mapping = FEDERATED_CONNECTOR_CLASS_MAP[source]
    try:
        module = importlib.import_module(mapping.module_path)
        connector_class = getattr(module, mapping.class_name)
    except (ImportError, AttributeError) as e:
        # Chain the original exception so the underlying import failure's
        # traceback is preserved for debugging (PEP 3134).
        raise FederatedConnectorMissingException(
            f"Failed to import {mapping.class_name} from {mapping.module_path}: {e}"
        ) from e

    _federated_connector_cache[source] = connector_class
    return connector_class
def get_federated_connector(
source: FederatedConnectorSource,
credentials: dict[str, Any],
@@ -21,9 +58,6 @@ def get_federated_connector(
def get_federated_connector_cls(
source: FederatedConnectorSource,
) -> type[FederatedConnector]:
) -> Type[FederatedConnector]:
"""Get the class of the appropriate federated connector."""
if source == FederatedConnectorSource.FEDERATED_SLACK:
return SlackFederatedConnector
else:
raise ValueError(f"Unsupported federated connector source: {source}")
return _load_federated_connector_class(source)

View File

@@ -0,0 +1,19 @@
"""Registry mapping for federated connector classes."""
from pydantic import BaseModel
from onyx.configs.constants import FederatedConnectorSource
class FederatedConnectorMapping(BaseModel):
    """Import location (module path + class name) of a federated connector
    class, kept as strings so the class can be imported lazily."""

    # Dotted import path of the module defining the connector class,
    # e.g. "onyx.federated_connectors.slack.federated_connector".
    module_path: str
    # Name of the connector class within that module, e.g. "SlackFederatedConnector".
    class_name: str
# Mapping of FederatedConnectorSource to connector details for lazy loading.
# Stored as module-path/class-name strings (rather than the classes
# themselves) so a connector's module is only imported when first requested.
FEDERATED_CONNECTOR_CLASS_MAP = {
    FederatedConnectorSource.FEDERATED_SLACK: FederatedConnectorMapping(
        module_path="onyx.federated_connectors.slack.federated_connector",
        class_name="SlackFederatedConnector",
    ),
}

View File

@@ -22,8 +22,6 @@ from zipfile import BadZipFile
import chardet
import openpyxl
from PIL import Image
from pypdf import PdfReader
from pypdf.errors import PdfStreamError
from onyx.configs.constants import ONYX_METADATA_FILENAME
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
@@ -272,6 +270,9 @@ def read_pdf_file(
"""
Returns the text, basic PDF metadata, and optionally extracted images.
"""
from pypdf import PdfReader
from pypdf.errors import PdfStreamError
metadata: dict[str, Any] = {}
extracted_images: list[tuple[bytes, str]] = []
try:
@@ -313,10 +314,8 @@ def read_pdf_file(
image.save(img_byte_arr, format=image.format)
img_bytes = img_byte_arr.getvalue()
image_name = (
f"page_{page_num + 1}_image_{image_file_object.name}."
f"{image.format.lower() if image.format else 'png'}"
)
image_format = image.format.lower() if image.format else "png"
image_name = f"page_{page_num + 1}_image_{image_file_object.name}.{image_format}"
if image_callback is not None:
# Stream image out immediately
image_callback(img_bytes, image_name)
@@ -483,8 +482,7 @@ def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
if num_empty_consecutive_rows > 100:
# handle massive excel sheets with mostly empty cells
logger.warning(
f"Found {num_empty_consecutive_rows} empty rows in {file_name},"
" skipping rest of file"
f"Found {num_empty_consecutive_rows} empty rows in {file_name}, skipping rest of file"
)
break
sheet_str = "\n".join(rows)
@@ -556,8 +554,7 @@ def extract_file_text(
return unstructured_to_text(file, file_name)
except Exception as unstructured_error:
logger.error(
f"Failed to process with Unstructured: {str(unstructured_error)}. "
"Falling back to normal processing."
f"Failed to process with Unstructured: {str(unstructured_error)}. Falling back to normal processing."
)
if extension is None:
extension = get_file_ext(file_name)
@@ -643,8 +640,7 @@ def _extract_text_and_images(
)
except Exception as e:
logger.error(
f"Failed to process with Unstructured: {str(e)}. "
"Falling back to normal processing."
f"Failed to process with Unstructured: {str(e)}. Falling back to normal processing."
)
file.seek(0) # Reset file pointer just in case

View File

@@ -5,8 +5,6 @@ from io import BytesIO
from typing import IO
import bs4
import trafilatura # type: ignore
from trafilatura.settings import use_config # type: ignore
from onyx.configs.app_configs import HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY
from onyx.configs.app_configs import PARSE_WITH_TRAFILATURA
@@ -56,6 +54,9 @@ def format_element_text(element_text: str, link_href: str | None) -> str:
def parse_html_with_trafilatura(html_content: str) -> str:
"""Parse HTML content using trafilatura."""
import trafilatura # type: ignore
from trafilatura.settings import use_config # type: ignore
config = use_config()
config.set("DEFAULT", "include_links", "True")
config.set("DEFAULT", "include_tables", "True")

View File

@@ -1,17 +1,16 @@
from typing import Any
from typing import cast
from typing import IO
from unstructured.staging.base import dict_to_elements
from unstructured_client import UnstructuredClient # type: ignore
from unstructured_client.models import operations # type: ignore
from unstructured_client.models import shared
from typing import TYPE_CHECKING
from onyx.configs.constants import KV_UNSTRUCTURED_API_KEY
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.utils.logger import setup_logger
if TYPE_CHECKING:
from unstructured_client.models import operations # type: ignore
logger = setup_logger()
@@ -36,7 +35,10 @@ def delete_unstructured_api_key() -> None:
def _sdk_partition_request(
file: IO[Any], file_name: str, **kwargs: Any
) -> operations.PartitionRequest:
) -> "operations.PartitionRequest":
from unstructured_client.models import operations # type: ignore
from unstructured_client.models import shared
file.seek(0, 0)
try:
request = operations.PartitionRequest(
@@ -52,6 +54,9 @@ def _sdk_partition_request(
def unstructured_to_text(file: IO[Any], file_name: str) -> str:
from unstructured.staging.base import dict_to_elements
from unstructured_client import UnstructuredClient # type: ignore
logger.debug(f"Starting to read file: {file_name}")
req = _sdk_partition_request(file, file_name, strategy="fast")

View File

@@ -3,7 +3,6 @@ from collections import defaultdict
from typing import cast
import numpy as np
from nltk import ngrams # type: ignore
from rapidfuzz.distance.DamerauLevenshtein import normalized_similarity
from sqlalchemy import desc
from sqlalchemy import Float
@@ -59,6 +58,8 @@ def _normalize_one_entity(
attributes: dict[str, str],
allowed_docs_temp_view_name: str | None = None,
) -> str | None:
from nltk import ngrams # type: ignore
"""
Matches a single entity to the best matching entity of the same type.
"""

View File

@@ -5,8 +5,9 @@ from collections.abc import Iterator
from collections.abc import Sequence
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
from typing import Union
import litellm # type: ignore
from httpx import RemoteProtocolError
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessage
@@ -24,9 +25,7 @@ from langchain_core.messages import SystemMessageChunk
from langchain_core.messages.tool import ToolCallChunk
from langchain_core.messages.tool import ToolMessage
from langchain_core.prompt_values import PromptValue
from litellm.utils import get_supported_openai_params
from onyx.configs.app_configs import BRAINTRUST_ENABLED
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
from onyx.configs.chat_configs import QA_TIMEOUT
@@ -45,13 +44,9 @@ from onyx.utils.long_term_log import LongTermLogger
logger = setup_logger()
# If a user configures a different model and it doesn't support all the same
# parameters like frequency and presence, just ignore them
litellm.drop_params = True
litellm.telemetry = False
if TYPE_CHECKING:
from litellm import ModelResponse, CustomStreamWrapper, Message
if BRAINTRUST_ENABLED:
litellm.callbacks = ["braintrust"]
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
@@ -85,8 +80,10 @@ def _base_msg_to_role(msg: BaseMessage) -> str:
def _convert_litellm_message_to_langchain_message(
litellm_message: litellm.Message,
litellm_message: "Message",
) -> BaseMessage:
from onyx.llm.litellm_singleton import litellm
# Extracting the basic attributes from the litellm message
content = litellm_message.content or ""
role = litellm_message.role
@@ -176,15 +173,15 @@ def _convert_delta_to_message_chunk(
curr_msg: BaseMessage | None,
stop_reason: str | None = None,
) -> BaseMessageChunk:
from litellm.utils import ChatCompletionDeltaToolCall
"""Adapted from langchain_community.chat_models.litellm._convert_delta_to_message_chunk"""
role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else "unknown")
content = _dict.get("content") or ""
additional_kwargs = {}
if _dict.get("function_call"):
additional_kwargs.update({"function_call": dict(_dict["function_call"])})
tool_calls = cast(
list[litellm.utils.ChatCompletionDeltaToolCall] | None, _dict.get("tool_calls")
)
tool_calls = cast(list[ChatCompletionDeltaToolCall] | None, _dict.get("tool_calls"))
if role == "user":
return HumanMessageChunk(content=content)
@@ -321,6 +318,8 @@ class DefaultMultiLLM(LLM):
self._max_token_param = LEGACY_MAX_TOKENS_KWARG
try:
from litellm.utils import get_supported_openai_params
params = get_supported_openai_params(model_name, model_provider)
if STANDARD_MAX_TOKENS_KWARG in (params or []):
self._max_token_param = STANDARD_MAX_TOKENS_KWARG
@@ -388,11 +387,13 @@ class DefaultMultiLLM(LLM):
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
) -> Union["ModelResponse", "CustomStreamWrapper"]:
# litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
# to a dict representation
processed_prompt = _prompt_to_dict(prompt)
self._record_call(processed_prompt)
from onyx.llm.litellm_singleton import litellm
from litellm.exceptions import Timeout, RateLimitError
try:
return litellm.completion(
@@ -456,12 +457,13 @@ class DefaultMultiLLM(LLM):
**self._model_kwargs,
)
except Exception as e:
self._record_error(processed_prompt, e)
# for break pointing
if isinstance(e, litellm.Timeout):
if isinstance(e, Timeout):
raise LLMTimeoutError(e)
elif isinstance(e, litellm.RateLimitError):
elif isinstance(e, RateLimitError):
raise LLMRateLimitError(e)
raise e
@@ -495,11 +497,13 @@ class DefaultMultiLLM(LLM):
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> BaseMessage:
from litellm import ModelResponse
if LOG_ONYX_MODEL_INTERACTIONS:
self.log_model_configs()
response = cast(
litellm.ModelResponse,
ModelResponse,
self._completion(
prompt=prompt,
tools=tools,
@@ -528,6 +532,8 @@ class DefaultMultiLLM(LLM):
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> Iterator[BaseMessage]:
from litellm import CustomStreamWrapper
if LOG_ONYX_MODEL_INTERACTIONS:
self.log_model_configs()
@@ -544,7 +550,7 @@ class DefaultMultiLLM(LLM):
output = None
response = cast(
litellm.CustomStreamWrapper,
CustomStreamWrapper,
self._completion(
prompt=prompt,
tools=tools,

View File

@@ -1,8 +1,5 @@
from typing import Any
from onyx.chat.models import PersonaOverrideConfig
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.llm import fetch_default_provider
@@ -13,6 +10,8 @@ from onyx.db.models import Persona
from onyx.llm.chat_llm import DefaultMultiLLM
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.interfaces import LLM
from onyx.llm.llm_provider_options import OLLAMA_API_KEY_CONFIG_KEY
from onyx.llm.llm_provider_options import OLLAMA_PROVIDER_NAME
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import get_max_input_tokens_from_llm_provider
from onyx.llm.utils import model_supports_image_input
@@ -24,13 +23,22 @@ from onyx.utils.long_term_log import LongTermLogger
logger = setup_logger()
def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
"""Ollama requires us to specify the max context window.
def _build_provider_extra_headers(
provider: str, custom_config: dict[str, str] | None
) -> dict[str, str]:
if provider != OLLAMA_PROVIDER_NAME or not custom_config:
return {}
For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value.
TODO: allow model-specific values to be configured via the UI.
"""
return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}
raw_api_key = custom_config.get(OLLAMA_API_KEY_CONFIG_KEY)
api_key = raw_api_key.strip() if raw_api_key else None
if not api_key:
return {}
if not api_key.lower().startswith("bearer "):
api_key = f"Bearer {api_key}"
return {"Authorization": api_key}
def get_main_llm_from_tuple(
@@ -272,6 +280,16 @@ def get_llm(
) -> LLM:
if temperature is None:
temperature = GEN_AI_TEMPERATURE
extra_headers = build_llm_extra_headers(additional_headers)
# NOTE: this is needed since Ollama API key is optional
# User may access Ollama cloud via locally hosted instance (logged in)
# or just via the cloud API (not logged in, using API key)
provider_extra_headers = _build_provider_extra_headers(provider, custom_config)
if provider_extra_headers:
extra_headers.update(provider_extra_headers)
return DefaultMultiLLM(
model_provider=provider,
model_name=model,
@@ -282,8 +300,8 @@ def get_llm(
timeout=timeout,
temperature=temperature,
custom_config=custom_config,
extra_headers=build_llm_extra_headers(additional_headers),
model_kwargs=_build_extra_model_kwargs(provider),
extra_headers=extra_headers,
model_kwargs={},
long_term_logger=long_term_logger,
max_input_tokens=max_input_tokens,
)

View File

@@ -121,7 +121,6 @@ class LLM(abc.ABC):
) -> BaseMessage:
raise NotImplementedError
@traced(name="stream llm", type="llm")
def stream(
self,
prompt: LanguageModelInput,

View File

@@ -0,0 +1,23 @@
"""
Singleton module for litellm configuration.
This ensures litellm is configured exactly once when first imported.
All other modules should import litellm from here instead of directly.
"""
import litellm
from onyx.configs.app_configs import BRAINTRUST_ENABLED
# Import litellm
# Configure litellm settings immediately on import
# If a user configures a different model and it doesn't support all the same
# parameters like frequency and presence, just ignore them
litellm.drop_params = True
litellm.telemetry = False
if BRAINTRUST_ENABLED:
litellm.callbacks = ["braintrust"]
# Export the configured litellm module
__all__ = ["litellm"]

View File

@@ -1,6 +1,5 @@
from enum import Enum
import litellm # type: ignore
from pydantic import BaseModel
from onyx.llm.chat_llm import VERTEX_CREDENTIALS_FILE_KWARG
@@ -39,6 +38,7 @@ class WellKnownLLMProviderDescriptor(BaseModel):
model_configurations: list[ModelConfigurationView]
default_model: str | None = None
default_fast_model: str | None = None
default_api_base: str | None = None
# set for providers like Azure, which require a deployment name.
deployment_name_required: bool = False
# set for providers like Azure, which support a single model per deployment.
@@ -86,16 +86,23 @@ OPEN_AI_VISIBLE_MODEL_NAMES = [
]
BEDROCK_PROVIDER_NAME = "bedrock"
# need to remove all the weird "bedrock/eu-central-1/anthropic.claude-v1" named
# models
BEDROCK_MODEL_NAMES = [
model
BEDROCK_DEFAULT_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0"
OLLAMA_PROVIDER_NAME = "ollama"
OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"
def get_bedrock_model_names() -> list[str]:
import litellm
# bedrock_converse_models are just extensions of the bedrock_models, not sure why
# litellm has split them into two lists :(
for model in list(litellm.bedrock_models.union(litellm.bedrock_converse_models))
if "/" not in model and "embed" not in model
][::-1]
BEDROCK_DEFAULT_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0"
return [
model
for model in list(litellm.bedrock_models.union(litellm.bedrock_converse_models))
if "/" not in model and "embed" not in model
][::-1]
IGNORABLE_ANTHROPIC_MODELS = [
"claude-2",
@@ -103,11 +110,18 @@ IGNORABLE_ANTHROPIC_MODELS = [
"anthropic/claude-3-5-sonnet-20241022",
]
ANTHROPIC_PROVIDER_NAME = "anthropic"
ANTHROPIC_MODEL_NAMES = [
model
for model in litellm.anthropic_models
if model not in IGNORABLE_ANTHROPIC_MODELS
][::-1]
def get_anthropic_model_names() -> list[str]:
    """Return litellm's Anthropic model names in reversed registry order,
    excluding the entries listed in IGNORABLE_ANTHROPIC_MODELS."""
    # Imported lazily to keep litellm off this module's import path.
    import litellm

    usable_models = [
        name
        for name in litellm.anthropic_models
        if name not in IGNORABLE_ANTHROPIC_MODELS
    ]
    return list(reversed(usable_models))
ANTHROPIC_VISIBLE_MODEL_NAMES = [
"claude-sonnet-4-5-20250929",
"claude-sonnet-4-20250514",
@@ -155,18 +169,23 @@ VERTEXAI_VISIBLE_MODEL_NAMES = [
]
_PROVIDER_TO_MODELS_MAP = {
OPENAI_PROVIDER_NAME: OPEN_AI_MODEL_NAMES,
BEDROCK_PROVIDER_NAME: BEDROCK_MODEL_NAMES,
ANTHROPIC_PROVIDER_NAME: ANTHROPIC_MODEL_NAMES,
VERTEXAI_PROVIDER_NAME: VERTEXAI_MODEL_NAMES,
}
def _get_provider_to_models_map() -> dict[str, list[str]]:
    """Build the provider-name -> model-name-list mapping on demand.

    Computed lazily (rather than as a module-level constant) to avoid
    importing litellm at module level: the Bedrock and Anthropic lists
    consult litellm's model registries when called.
    """
    provider_models: dict[str, list[str]] = {}
    # Statically known model list.
    provider_models[OPENAI_PROVIDER_NAME] = OPEN_AI_MODEL_NAMES
    # These two read litellm's model registries at call time.
    provider_models[BEDROCK_PROVIDER_NAME] = get_bedrock_model_names()
    provider_models[ANTHROPIC_PROVIDER_NAME] = get_anthropic_model_names()
    provider_models[VERTEXAI_PROVIDER_NAME] = VERTEXAI_MODEL_NAMES
    # No preset model list is offered for Ollama.
    provider_models[OLLAMA_PROVIDER_NAME] = []
    return provider_models
_PROVIDER_TO_VISIBLE_MODELS_MAP = {
OPENAI_PROVIDER_NAME: OPEN_AI_VISIBLE_MODEL_NAMES,
BEDROCK_PROVIDER_NAME: [BEDROCK_DEFAULT_MODEL],
BEDROCK_PROVIDER_NAME: [],
ANTHROPIC_PROVIDER_NAME: ANTHROPIC_VISIBLE_MODEL_NAMES,
VERTEXAI_PROVIDER_NAME: VERTEXAI_VISIBLE_MODEL_NAMES,
OLLAMA_PROVIDER_NAME: [],
}
@@ -185,6 +204,28 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
default_model="gpt-4o",
default_fast_model="gpt-4o-mini",
),
WellKnownLLMProviderDescriptor(
name=OLLAMA_PROVIDER_NAME,
display_name="Ollama",
api_key_required=False,
api_base_required=True,
api_version_required=False,
custom_config_keys=[
CustomConfigKey(
name=OLLAMA_API_KEY_CONFIG_KEY,
display_name="Ollama API Key",
description="Optional API key used when connecting to Ollama Cloud (i.e. API base is https://ollama.com).",
is_required=False,
is_secret=True,
)
],
model_configurations=fetch_model_configurations_for_provider(
OLLAMA_PROVIDER_NAME
),
default_model=None,
default_fast_model=None,
default_api_base="http://127.0.0.1:11434",
),
WellKnownLLMProviderDescriptor(
name=ANTHROPIC_PROVIDER_NAME,
display_name="Anthropic",
@@ -248,7 +289,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
model_configurations=fetch_model_configurations_for_provider(
BEDROCK_PROVIDER_NAME
),
default_model=BEDROCK_DEFAULT_MODEL,
default_model=None,
default_fast_model=None,
),
WellKnownLLMProviderDescriptor(
@@ -287,7 +328,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
def fetch_models_for_provider(provider_name: str) -> list[str]:
return _PROVIDER_TO_MODELS_MAP.get(provider_name, [])
return _get_provider_to_models_map().get(provider_name, [])
def fetch_model_names_for_provider_as_set(provider_name: str) -> set[str] | None:

View File

@@ -16,6 +16,7 @@ from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from sqlalchemy import select
from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
@@ -26,6 +27,9 @@ from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import LLMProvider
from onyx.db.models import ModelConfiguration
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLM
@@ -460,6 +464,7 @@ def get_llm_contextual_cost(
this does not account for the cost of documents that fit within a single chunk
which do not get contextualized.
"""
import litellm
# calculate input costs
@@ -639,6 +644,30 @@ def get_max_input_tokens_from_llm_provider(
def model_supports_image_input(model_name: str, model_provider: str) -> bool:
# TODO: Add support to check model config for any provider
# TODO: Circular import means OLLAMA_PROVIDER_NAME is not available here
if model_provider == "ollama":
try:
with get_session_with_current_tenant() as db_session:
model_config = db_session.scalar(
select(ModelConfiguration)
.join(
LLMProvider,
ModelConfiguration.llm_provider_id == LLMProvider.id,
)
.where(
ModelConfiguration.name == model_name,
LLMProvider.provider == model_provider,
)
)
if model_config and model_config.supports_image_input is not None:
return model_config.supports_image_input
except Exception as e:
logger.warning(
f"Failed to query database for {model_provider} model {model_name} image support: {e}"
)
model_map = get_model_map()
try:
model_obj = find_model_obj(

View File

@@ -25,7 +25,6 @@ from retry import retry
from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
from onyx.configs.app_configs import SKIP_WARM_UP
from onyx.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from onyx.configs.model_configs import (
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
@@ -53,6 +52,7 @@ from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
from shared_configs.configs import SKIP_WARM_UP
from shared_configs.configs import VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
@@ -1115,6 +1115,9 @@ def warm_up_cross_encoder(
rerank_model_name: str,
non_blocking: bool = False,
) -> None:
if SKIP_WARM_UP:
return
logger.debug(f"Warming up reranking model: {rerank_model_name}")
reranking_model = RerankingModel(

View File

@@ -113,7 +113,7 @@ def _check_tokenizer_cache(
logger.info(
f"Falling back to default embedding model tokenizer: {DOCUMENT_ENCODER_MODEL}"
)
tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
tokenizer = _get_default_tokenizer()
_TOKENIZER_CACHE[id_tuple] = tokenizer
@@ -150,7 +150,15 @@ def _try_initialize_tokenizer(
return None
_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
_DEFAULT_TOKENIZER: BaseTokenizer | None = None
def _get_default_tokenizer() -> BaseTokenizer:
    """Return the shared default tokenizer, constructing it on first use.

    Construction is deferred so that importing this module stays cheap: the
    HuggingFace tokenizer for DOCUMENT_ENCODER_MODEL is only loaded when a
    caller actually needs it.
    """
    global _DEFAULT_TOKENIZER
    if _DEFAULT_TOKENIZER is not None:
        return _DEFAULT_TOKENIZER
    _DEFAULT_TOKENIZER = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
    return _DEFAULT_TOKENIZER
def get_tokenizer(
@@ -163,7 +171,7 @@ def get_tokenizer(
logger.debug(
f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
)
return _DEFAULT_TOKENIZER
return _get_default_tokenizer()
return _check_tokenizer_cache(provider_type, model_name)

View File

@@ -171,6 +171,14 @@ written as a list of one question.
""",
DRPath.CLOSER.value: f"""if the tool is {CLOSER}, the list of questions should simply be \
['Answer the original question with the information you have.'].
""",
DRPath.IMAGE_GENERATION.value: """
if the tool is Image Generation, respond with a list that contains exactly one JSON object
string describing the tool call. The JSON must include a "prompt" field with the text to
render. When the user specifies or implies an orientation, also include a "shape" field whose
value is one of "square", "landscape", or "portrait" (use "landscape" for wide/horizontal
requests and "portrait" for tall/vertical ones). Example: {"prompt": "Create a poster of a
coral reef", "shape": "landscape"}. Do not surround the JSON with backticks or narration.
""",
}

View File

@@ -1,136 +0,0 @@
# NOTE No longer used. This needs to be revisited later.
import re
from collections.abc import Iterator
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import StreamingError
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_default_llms
from onyx.llm.utils import dict_based_prompt_to_langchain_prompt
from onyx.llm.utils import message_generator_to_string_generator
from onyx.llm.utils import message_to_string
from onyx.prompts.constants import ANSWERABLE_PAT
from onyx.prompts.constants import THOUGHT_PAT
from onyx.prompts.query_validation import ANSWERABLE_PROMPT
from onyx.server.query_and_chat.models import QueryValidationResponse
from onyx.server.utils import get_json_line
from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_query_validation_messages(user_query: str) -> list[dict[str, str]]:
    """Build the single-turn chat payload used to validate *user_query*.

    Returns a one-element message list in the dict-based prompt format,
    with the answerability prompt filled in with the user's query.
    """
    prompt_text = ANSWERABLE_PROMPT.format(user_query=user_query)
    return [{"role": "user", "content": prompt_text}]
def extract_answerability_reasoning(model_raw: str) -> str:
    """Extract the free-text reasoning between the THOUGHT and ANSWERABLE markers.

    Returns an empty string when the markers are not both present.
    """
    pattern = f"{THOUGHT_PAT.upper()}(.*?){ANSWERABLE_PAT.upper()}"
    match = re.search(pattern, model_raw, re.DOTALL)
    if match is None:
        return ""
    return match.group(1).strip()
def extract_answerability_bool(model_raw: str) -> bool:
    """Return True iff the text after the ANSWERABLE marker is "true" or "yes".

    A missing marker or any other value (e.g. "false", "no", garbage) yields
    False, so the query is treated as not answerable by default.
    """
    answerable_match = re.search(f"{ANSWERABLE_PAT.upper()}(.+)", model_raw)
    answerable_text = answerable_match.group(1).strip() if answerable_match else ""
    # A membership test already produces a bool; the original
    # `True if ... else False` ternary (plus a second redundant .strip() on an
    # already-stripped string) added nothing.
    return answerable_text.lower() in ("true", "yes")
def get_query_answerability(
    user_query: str, skip_check: bool = False
) -> tuple[str, bool]:
    """Evaluate whether *user_query* can be answered.

    Returns a (reasoning, answerable) pair. When the check is skipped or
    generative AI is disabled, the query is optimistically reported as
    answerable with an explanatory reasoning string.
    """
    if skip_check:
        return "Query Answerability Evaluation feature is turned off", True

    try:
        primary_llm, _ = get_default_llms()
    except GenAIDisabledException:
        return "Generative AI is turned off - skipping check", True

    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(
        get_query_validation_messages(user_query)
    )
    raw_output = message_to_string(primary_llm.invoke(filled_llm_prompt))
    return (
        extract_answerability_reasoning(raw_output),
        extract_answerability_bool(raw_output),
    )
def stream_query_answerability(
    user_query: str, skip_check: bool = False
) -> Iterator[str]:
    """Stream the answerability evaluation of *user_query* as JSON lines.

    Yields OnyxAnswerPiece lines for the reasoning text as it is generated,
    followed by a final QueryValidationResponse line with the parsed
    (reasoning, answerable) result. When the check is skipped or generative
    AI is disabled, yields a single optimistic QueryValidationResponse.
    On any streaming failure, yields a StreamingError line instead.
    """
    if skip_check:
        yield get_json_line(
            QueryValidationResponse(
                reasoning="Query Answerability Evaluation feature is turned off",
                answerable=True,
            ).model_dump()
        )
        return
    try:
        llm, _ = get_default_llms()
    except GenAIDisabledException:
        yield get_json_line(
            QueryValidationResponse(
                reasoning="Generative AI is turned off - skipping check",
                answerable=True,
            ).model_dump()
        )
        return
    messages = get_query_validation_messages(user_query)
    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
    try:
        tokens = message_generator_to_string_generator(llm.stream(filled_llm_prompt))
        # State machine over the token stream:
        # - model_output accumulates everything seen so far
        # - reasoning_pat_found flips once the THOUGHT marker has appeared
        # - hold_answerable buffers tokens that might be the start of the
        #   ANSWERABLE marker, so the marker itself is never streamed out
        reasoning_pat_found = False
        model_output = ""
        hold_answerable = ""
        for token in tokens:
            model_output = model_output + token
            # Once the full ANSWERABLE marker has appeared, stop streaming
            # reasoning pieces; the remainder is the true/false verdict,
            # parsed after the loop.
            if ANSWERABLE_PAT.upper() in model_output:
                continue
            if not reasoning_pat_found and THOUGHT_PAT.upper() in model_output:
                # First time the THOUGHT marker completes: emit whatever
                # reasoning text already follows it in the buffer.
                reasoning_pat_found = True
                reason_ind = model_output.find(THOUGHT_PAT.upper())
                remaining = model_output[reason_ind + len(THOUGHT_PAT.upper()) :]
                if remaining:
                    yield get_json_line(
                        OnyxAnswerPiece(answer_piece=remaining).model_dump()
                    )
                continue
            if reasoning_pat_found:
                hold_answerable = hold_answerable + token
                # While the held text is still a prefix of the ANSWERABLE
                # marker, keep buffering — it may be the marker arriving
                # split across tokens.
                if hold_answerable == ANSWERABLE_PAT.upper()[: len(hold_answerable)]:
                    continue
                # Not the marker after all: flush the buffer as reasoning.
                yield get_json_line(
                    OnyxAnswerPiece(answer_piece=hold_answerable).model_dump()
                )
                hold_answerable = ""
        reasoning = extract_answerability_reasoning(model_output)
        answerable = extract_answerability_bool(model_output)
        yield get_json_line(
            QueryValidationResponse(
                reasoning=reasoning, answerable=answerable
            ).model_dump()
        )
    except Exception as e:
        # exception is logged in the answer_question method, no need to re-log
        error = StreamingError(error=str(e))
        yield get_json_line(error.model_dump())
        logger.exception("Failed to validate Query")
    return

View File

@@ -3,7 +3,6 @@ from typing import Any
from typing import cast
from typing import List
from litellm import get_supported_openai_params
from sqlalchemy.orm import Session
from onyx.configs.chat_configs import NUM_PERSONA_PROMPT_GENERATION_CHUNKS
@@ -123,6 +122,8 @@ def generate_starter_messages(
"""
_, fast_llm = get_default_llms(temperature=0.5)
from litellm.utils import get_supported_openai_params
provider = fast_llm.config.model_provider
model = fast_llm.config.model_name

View File

@@ -184,7 +184,7 @@ def seed_initial_documents(
"base_url": "https://docs.onyx.app/",
"web_connector_type": "recursive",
},
refresh_freq=None, # Never refresh by default
refresh_freq=3600, # 1 hour
prune_freq=None,
indexing_start=None,
)

View File

@@ -22,7 +22,7 @@ from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
from onyx.utils.subclasses import find_all_subclasses_in_dir
from onyx.utils.subclasses import find_all_subclasses_in_package
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -44,7 +44,8 @@ def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]:
if _OAUTH_CONNECTORS: # Return cached connectors if already discovered
return _OAUTH_CONNECTORS
oauth_connectors = find_all_subclasses_in_dir(
# Import submodules using package-based discovery to avoid sys.path mutations
oauth_connectors = find_all_subclasses_in_package(
cast(type[OAuthConnector], OAuthConnector), "onyx.connectors"
)

View File

@@ -1218,7 +1218,10 @@ def _upsert_mcp_server(
logger.info(f"Created new MCP server '{request.name}' with ID {mcp_server.id}")
if not changing_connection_config:
if (
not changing_connection_config
or request.auth_type == MCPAuthenticationType.NONE
):
return mcp_server
# Create connection configs

View File

@@ -162,7 +162,7 @@ def unlink_user_file_from_project(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=OnyxCeleryPriority.HIGH,
priority=OnyxCeleryPriority.HIGHEST,
)
logger.info(
f"Triggered project sync for user_file_id={user_file.id} with task_id={task.id}"
@@ -210,7 +210,7 @@ def link_user_file_to_project(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=OnyxCeleryPriority.HIGH,
priority=OnyxCeleryPriority.HIGHEST,
)
logger.info(
f"Triggered project sync for user_file_id={user_file.id} with task_id={task.id}"

View File

@@ -1,5 +1,6 @@
import math
from datetime import datetime
from datetime import timezone
from fastapi import APIRouter
from fastapi import Depends
@@ -23,8 +24,8 @@ router = APIRouter(prefix="/gpts")
def time_ago(dt: datetime) -> str:
# Calculate time difference
now = datetime.now()
diff = now - dt
now = datetime.now(timezone.utc)
diff = now - dt.astimezone(timezone.utc)
# Convert difference to minutes
minutes = diff.total_seconds() / 60

View File

@@ -4,6 +4,7 @@ from datetime import datetime
from datetime import timezone
import boto3
import httpx
from botocore.exceptions import BotoCoreError
from botocore.exceptions import ClientError
from botocore.exceptions import NoCredentialsError
@@ -11,10 +12,12 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.db.engine.sql_engine import get_session
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
@@ -27,8 +30,8 @@ from onyx.db.models import User
from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llm
from onyx.llm.factory import get_max_input_tokens_from_llm_provider
from onyx.llm.llm_provider_options import BEDROCK_MODEL_NAMES
from onyx.llm.llm_provider_options import fetch_available_well_known_llms
from onyx.llm.llm_provider_options import get_bedrock_model_names
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from onyx.llm.utils import get_llm_contextual_cost
from onyx.llm.utils import litellm_exception_to_error_msg
@@ -40,6 +43,9 @@ from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
from onyx.server.manage.llm.models import OllamaModelDetails
from onyx.server.manage.llm.models import OllamaModelsRequest
from onyx.server.manage.llm.models import TestLLMRequest
from onyx.server.manage.llm.models import VisionProviderResponse
from onyx.utils.logger import setup_logger
@@ -457,7 +463,7 @@ def get_bedrock_available_models(
# Keep only models we support (compatibility with litellm)
filtered = sorted(
[model for model in candidates if model in BEDROCK_MODEL_NAMES],
[model for model in candidates if model in get_bedrock_model_names()],
reverse=True,
)
@@ -474,3 +480,100 @@ def get_bedrock_available_models(
raise HTTPException(
status_code=500, detail=f"Unexpected error fetching Bedrock models: {e}"
)
def _get_ollama_available_model_names(api_base: str) -> set[str]:
    """Query the Ollama ``/api/tags`` endpoint and return the set of model names.

    Any failure (network error, non-2xx status, invalid JSON) is surfaced as
    an HTTP 400 to the caller.
    """
    try:
        tags_response = httpx.get(f"{api_base}/api/tags", timeout=5.0)
        tags_response.raise_for_status()
        payload = tags_response.json()
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to fetch Ollama models: {e}",
        )

    names: set[str] = set()
    for model in payload.get("models", []):
        model_name = model.get("name")
        if model_name:
            names.add(model_name)
    return names
@admin_router.post("/ollama/available-models")
def get_ollama_available_models(
    request: OllamaModelsRequest,
    _: User | None = Depends(current_admin_user),
) -> list[OllamaFinalModelResponse]:
    """Fetch the list of available models from an Ollama server.

    Admin-only endpoint. Lists the models via /api/tags, then queries
    /api/show per model to discover its context length and vision support.
    Models that do not support completion/chat are excluded. Raises HTTP 400
    when the API base is missing or the server reports no models.
    """
    cleaned_api_base = request.api_base.strip().rstrip("/")
    if not cleaned_api_base:
        raise HTTPException(
            status_code=400, detail="API base URL is required to fetch Ollama models."
        )

    model_names = _get_ollama_available_model_names(cleaned_api_base)
    if not model_names:
        raise HTTPException(
            status_code=400,
            detail="No models found from your Ollama server",
        )

    all_models_with_context_size_and_vision: list[OllamaFinalModelResponse] = []
    show_url = f"{cleaned_api_base}/api/show"
    for model_name in model_names:
        # None means "not yet determined"; filled in from /api/show if possible,
        # otherwise replaced by fallbacks below.
        context_limit: int | None = None
        supports_image_input: bool | None = None
        try:
            show_response = httpx.post(
                show_url,
                json={"model": model_name},
                timeout=5.0,
            )
            show_response.raise_for_status()
            show_response_json = show_response.json()

            # Parse the response into the expected format
            ollama_model_details = OllamaModelDetails.model_validate(show_response_json)

            # Check if this model supports completion/chat
            if not ollama_model_details.supports_completion():
                continue

            # Optimistically access. Context limit is stored as "model_architecture.context" = int
            architecture = ollama_model_details.model_info.get(
                "general.architecture", ""
            )
            context_limit = ollama_model_details.model_info.get(
                architecture + ".context_length", None
            )
            supports_image_input = ollama_model_details.supports_image_input()
        except ValidationError as e:
            # /api/show returned a shape we don't recognize; keep the model
            # with fallback values rather than dropping it.
            logger.warning(
                "Invalid model details from Ollama server",
                extra={"model": model_name, "validation_error": str(e)},
            )
        except Exception as e:
            # Network/HTTP failure for this one model; keep it with fallbacks.
            logger.warning(
                "Failed to fetch Ollama model details",
                extra={"model": model_name, "error": str(e)},
            )

        # If we fail at any point attempting to extract context limit,
        # still allow this model to be used with a fallback max context size
        if not context_limit:
            context_limit = GEN_AI_MODEL_FALLBACK_MAX_TOKENS
        if not supports_image_input:
            supports_image_input = False

        all_models_with_context_size_and_vision.append(
            OllamaFinalModelResponse(
                name=model_name,
                max_input_tokens=context_limit,
                supports_image_input=supports_image_input,
            )
        )
    return all_models_with_context_size_and_vision

View File

@@ -1,3 +1,4 @@
from typing import Any
from typing import TYPE_CHECKING
from pydantic import BaseModel
@@ -138,8 +139,9 @@ class LLMProviderView(LLMProvider):
class ModelConfigurationUpsertRequest(BaseModel):
name: str
is_visible: bool | None = False
is_visible: bool
max_input_tokens: int | None = None
supports_image_input: bool | None = None
@classmethod
def from_model(
@@ -149,12 +151,13 @@ class ModelConfigurationUpsertRequest(BaseModel):
name=model_configuration_model.name,
is_visible=model_configuration_model.is_visible,
max_input_tokens=model_configuration_model.max_input_tokens,
supports_image_input=model_configuration_model.supports_image_input,
)
class ModelConfigurationView(BaseModel):
name: str
is_visible: bool | None = False
is_visible: bool
max_input_tokens: int | None = None
supports_image_input: bool
@@ -196,3 +199,28 @@ class BedrockModelsRequest(BaseModel):
aws_secret_access_key: str | None = None
aws_bearer_token_bedrock: str | None = None
provider_name: str | None = None # Optional: to save models to existing provider
class OllamaModelsRequest(BaseModel):
    """Request body for listing the models available on an Ollama server."""

    # Base URL of the Ollama server, e.g. "http://localhost:11434"
    api_base: str
class OllamaFinalModelResponse(BaseModel):
    """One discovered Ollama model with its resolved capabilities."""

    # Model name as reported by the Ollama server (e.g. "llama3:8b")
    name: str
    # Maximum input context size in tokens (falls back to a default when
    # the server does not report one)
    max_input_tokens: int
    # Whether the model accepts image inputs (vision capability)
    supports_image_input: bool
class OllamaModelDetails(BaseModel):
    """Response model for Ollama /api/show endpoint"""

    # Raw key/value metadata from the server; context length is looked up as
    # "<architecture>.context_length" using the "general.architecture" key
    model_info: dict[str, Any]
    # Capability tags reported by the server (e.g. "completion", "vision")
    capabilities: list[str] = []

    def supports_completion(self) -> bool:
        """Check if this model supports completion/chat"""
        return "completion" in self.capabilities

    def supports_image_input(self) -> bool:
        """Check if this model supports image input"""
        return "vision" in self.capabilities

View File

@@ -1,3 +1,5 @@
import csv
import io
import re
from datetime import datetime
from datetime import timedelta
@@ -14,6 +16,7 @@ from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
@@ -297,6 +300,43 @@ def list_all_users(
)
@router.get("/manage/users/download")
def download_users_csv(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> StreamingResponse:
    """Download all users as a CSV file.

    Admin-only endpoint. Streams a CSV with one row per user containing
    email, role, and active status.
    """
    # Get all users from the database
    users = get_all_users(db_session)

    # Create CSV content in memory
    output = io.StringIO()
    writer = csv.writer(output)

    # Write CSV header
    writer.writerow(["Email", "Role", "Status"])

    # Write user data
    for user in users:
        writer.writerow(
            [
                user.email,
                user.role.value if user.role else "",
                "Active" if user.is_active else "Inactive",
            ]
        )

    # Prepare the CSV content for download
    csv_content = output.getvalue()
    output.close()

    return StreamingResponse(
        io.BytesIO(csv_content.encode("utf-8")),
        media_type="text/csv",
        # Supply a filename so browsers save the download as "users.csv";
        # the previous bare "attachment;" header left the browser to derive
        # a name from the request URL.
        headers={"Content-Disposition": 'attachment; filename="users.csv"'},
    )
@router.put("/manage/admin/users")
def bulk_invite_users(
emails: list[str] = Body(..., embed=True),

View File

@@ -56,8 +56,8 @@ def validate_user_token(user_token: str | None) -> None:
Raises:
HTTPException: If the token is invalid or missing required fields
"""
if user_token is None:
# user_token is optional, so None is valid
if not user_token:
# user_token is optional, so None or empty string is valid
return
if not user_token.startswith(SLACK_USER_TOKEN_PREFIX):

Some files were not shown because too many files have changed in this diff Show More