mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-11 09:52:44 +00:00
Compare commits
1 Commits
jamison/sh
...
v2.0.0-bet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
488b27ba04 |
@@ -8,9 +8,9 @@ on:
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
|
||||
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
|
||||
|
||||
# don't tag cloud images with "latest"
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
|
||||
|
||||
# tag nightly builds with "edge"
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
@@ -33,7 +33,16 @@ jobs:
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
|
||||
- name: Check if stable release version
|
||||
id: check_version
|
||||
run: |
|
||||
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
|
||||
echo "is_stable=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "is_stable=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -46,7 +55,8 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -119,7 +129,8 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
|
||||
@@ -11,8 +11,8 @@ env:
|
||||
BUILDKIT_PROGRESS: plain
|
||||
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
|
||||
|
||||
# don't tag cloud images with "latest"
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
|
||||
# tag nightly builds with "edge"
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
|
||||
jobs:
|
||||
|
||||
@@ -145,6 +145,15 @@ jobs:
|
||||
if: needs.check_model_server_changes.outputs.changed == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check if stable release version
|
||||
id: check_version
|
||||
run: |
|
||||
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
|
||||
echo "is_stable=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "is_stable=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
@@ -157,11 +166,16 @@ jobs:
|
||||
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
if [[ "${{ env.LATEST_TAG }}" == "true" ]]; then
|
||||
if [[ "${{ steps.check_version.outputs.is_stable }}" == "true" ]]; then
|
||||
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:latest \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
fi
|
||||
if [[ "${{ env.EDGE_TAG }}" == "true" ]]; then
|
||||
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:edge \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
fi
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@v3
|
||||
|
||||
@@ -7,7 +7,10 @@ on:
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
# tag nightly builds with "edge"
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
|
||||
DEPLOYMENT: standalone
|
||||
|
||||
jobs:
|
||||
@@ -45,6 +48,15 @@ jobs:
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Check if stable release version
|
||||
id: check_version
|
||||
run: |
|
||||
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
echo "is_stable=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "is_stable=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -57,7 +69,8 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -126,7 +139,8 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
|
||||
6
.github/workflows/helm-chart-releases.yml
vendored
6
.github/workflows/helm-chart-releases.yml
vendored
@@ -25,9 +25,11 @@ jobs:
|
||||
|
||||
- name: Add required Helm repositories
|
||||
run: |
|
||||
helm repo add bitnami https://charts.bitnami.com/bitnami
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add keda https://kedacore.github.io/charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo update
|
||||
|
||||
- name: Build chart dependencies
|
||||
|
||||
@@ -20,6 +20,7 @@ env:
|
||||
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
|
||||
# LLMs
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
71
.github/workflows/pr-helm-chart-testing.yml
vendored
71
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -65,35 +65,45 @@ jobs:
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Adding Helm repositories ==="
|
||||
helm repo add bitnami https://charts.bitnami.com/bitnami
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo update
|
||||
|
||||
- name: Pre-pull critical images
|
||||
- name: Install Redis operator
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
echo "=== Installing redis-operator CRDs ==="
|
||||
helm upgrade --install redis-operator ot-container-kit/redis-operator \
|
||||
--namespace redis-operator --create-namespace --wait --timeout 300s
|
||||
|
||||
- name: Pre-pull required images
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-pulling critical images to avoid timeout ==="
|
||||
# Get kind cluster name
|
||||
echo "=== Pre-pulling required images to avoid timeout ==="
|
||||
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
|
||||
echo "Kind cluster: $KIND_CLUSTER"
|
||||
|
||||
# Pre-pull images that are likely to be used
|
||||
echo "Pre-pulling PostgreSQL image..."
|
||||
docker pull postgres:15-alpine || echo "Failed to pull postgres:15-alpine"
|
||||
kind load docker-image postgres:15-alpine --name $KIND_CLUSTER || echo "Failed to load postgres image"
|
||||
|
||||
echo "Pre-pulling Redis image..."
|
||||
docker pull redis:7-alpine || echo "Failed to pull redis:7-alpine"
|
||||
kind load docker-image redis:7-alpine --name $KIND_CLUSTER || echo "Failed to load redis image"
|
||||
|
||||
echo "Pre-pulling Onyx images..."
|
||||
docker pull docker.io/onyxdotapp/onyx-web-server:latest || echo "Failed to pull onyx web server"
|
||||
docker pull docker.io/onyxdotapp/onyx-backend:latest || echo "Failed to pull onyx backend"
|
||||
kind load docker-image docker.io/onyxdotapp/onyx-web-server:latest --name $KIND_CLUSTER || echo "Failed to load onyx web server"
|
||||
kind load docker-image docker.io/onyxdotapp/onyx-backend:latest --name $KIND_CLUSTER || echo "Failed to load onyx backend"
|
||||
|
||||
|
||||
IMAGES=(
|
||||
"ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0"
|
||||
"quay.io/opstree/redis:v7.0.15"
|
||||
"docker.io/onyxdotapp/onyx-web-server:latest"
|
||||
)
|
||||
|
||||
for image in "${IMAGES[@]}"; do
|
||||
echo "Pre-pulling $image"
|
||||
if docker pull "$image"; then
|
||||
kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind"
|
||||
else
|
||||
echo "Failed to pull $image"
|
||||
fi
|
||||
done
|
||||
|
||||
echo "=== Images loaded into Kind cluster ==="
|
||||
docker exec $KIND_CLUSTER-control-plane crictl images | grep -E "(postgres|redis|onyx)" || echo "Some images may still be loading..."
|
||||
docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..."
|
||||
|
||||
- name: Validate chart dependencies
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
@@ -149,6 +159,7 @@ jobs:
|
||||
|
||||
# Run the actual installation with detailed logging
|
||||
echo "=== Starting ct install ==="
|
||||
set +e
|
||||
ct install --all \
|
||||
--helm-extra-set-args="\
|
||||
--set=nginx.enabled=false \
|
||||
@@ -156,8 +167,10 @@ jobs:
|
||||
--set=vespa.enabled=false \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.primary.persistence.enabled=false \
|
||||
--set=postgresql.nameOverride=cloudnative-pg \
|
||||
--set=postgresql.cluster.storage.storageClass=standard \
|
||||
--set=redis.enabled=true \
|
||||
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
|
||||
--set=webserver.replicaCount=1 \
|
||||
--set=api.replicaCount=0 \
|
||||
--set=inferenceCapability.replicaCount=0 \
|
||||
@@ -173,8 +186,16 @@ jobs:
|
||||
--set=celery_worker_user_files_indexing.replicaCount=0" \
|
||||
--helm-extra-args="--timeout 900s --debug" \
|
||||
--debug --config ct.yaml
|
||||
|
||||
echo "=== Installation completed successfully ==="
|
||||
CT_EXIT=$?
|
||||
set -e
|
||||
|
||||
if [[ $CT_EXIT -ne 0 ]]; then
|
||||
echo "ct install failed with exit code $CT_EXIT"
|
||||
exit $CT_EXIT
|
||||
else
|
||||
echo "=== Installation completed successfully ==="
|
||||
fi
|
||||
|
||||
kubectl get pods --all-namespaces
|
||||
|
||||
- name: Post-install verification
|
||||
@@ -199,7 +220,7 @@ jobs:
|
||||
|
||||
echo "=== Recent logs for debugging ==="
|
||||
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
|
||||
|
||||
|
||||
echo "=== Helm releases ==="
|
||||
helm list --all-namespaces
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
|
||||
7
.github/workflows/pr-integration-tests.yml
vendored
7
.github/workflows/pr-integration-tests.yml
vendored
@@ -22,9 +22,11 @@ env:
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
@@ -131,6 +133,7 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
@@ -158,6 +161,7 @@ jobs:
|
||||
push: true
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
no-cache: true
|
||||
|
||||
build-integration-image:
|
||||
needs: prepare-build
|
||||
@@ -191,6 +195,7 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
|
||||
integration-tests:
|
||||
needs:
|
||||
@@ -337,9 +342,11 @@ jobs:
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
|
||||
@@ -19,9 +19,11 @@ env:
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
@@ -128,6 +130,7 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
@@ -155,6 +158,7 @@ jobs:
|
||||
push: true
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
no-cache: true
|
||||
|
||||
build-integration-image:
|
||||
needs: prepare-build
|
||||
@@ -188,6 +192,7 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
|
||||
integration-tests-mit:
|
||||
needs:
|
||||
@@ -334,9 +339,11 @@ jobs:
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
|
||||
@@ -20,11 +20,13 @@ env:
|
||||
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
|
||||
# Jira
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
|
||||
|
||||
# Gong
|
||||
GONG_ACCESS_KEY: ${{ secrets.GONG_ACCESS_KEY }}
|
||||
|
||||
@@ -105,6 +105,11 @@ pip install -r backend/requirements/ee.txt
|
||||
pip install -r backend/requirements/model_server.txt
|
||||
```
|
||||
|
||||
Fix vscode/cursor auto-imports:
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector)
|
||||
|
||||
In the activated Python virtualenv, install Playwright for Python by running:
|
||||
|
||||
@@ -59,7 +59,7 @@ def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
|
||||
|
||||
for raw_perm in permissions:
|
||||
if not hasattr(raw_perm, "raw"):
|
||||
logger.warn(f"Expected a 'raw' field, but none was found: {raw_perm=}")
|
||||
logger.warning(f"Expected a 'raw' field, but none was found: {raw_perm=}")
|
||||
continue
|
||||
|
||||
permission = Permission(**raw_perm.raw)
|
||||
@@ -71,14 +71,14 @@ def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
|
||||
# In order to associate this permission to some Atlassian entity, we need the "Holder".
|
||||
# If this doesn't exist, then we cannot associate this permission to anyone; just skip.
|
||||
if not permission.holder:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Expected to find a permission holder, but none was found: {permission=}"
|
||||
)
|
||||
continue
|
||||
|
||||
type = permission.holder.get("type")
|
||||
if not type:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Expected to find the type of permission holder, but none was found: {permission=}"
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Optional
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from litellm.exceptions import RateLimitError
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
|
||||
@@ -207,6 +206,8 @@ async def route_bi_encoder_embed(
|
||||
async def process_embed_request(
|
||||
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
|
||||
) -> EmbedResponse:
|
||||
from litellm.exceptions import RateLimitError
|
||||
|
||||
# Only local models should use this endpoint - API providers should make direct API calls
|
||||
if embed_request.provider_type is not None:
|
||||
raise ValueError(
|
||||
|
||||
@@ -8,8 +8,6 @@ from typing import TypeVar
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import StreamWriter
|
||||
from litellm import get_supported_openai_params
|
||||
from litellm import supports_response_schema
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
@@ -147,6 +145,7 @@ def invoke_llm_json(
|
||||
Invoke an LLM, forcing it to respond in a specified JSON format if possible,
|
||||
and return an object of that schema.
|
||||
"""
|
||||
from litellm.utils import get_supported_openai_params, supports_response_schema
|
||||
|
||||
# check if the model supports response_format: json_schema
|
||||
supports_json = "response_format" in (
|
||||
|
||||
@@ -41,7 +41,7 @@ beat_task_templates: list[dict] = [
|
||||
"task": OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
|
||||
"schedule": timedelta(minutes=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
},
|
||||
@@ -85,9 +85,9 @@ beat_task_templates: list[dict] = [
|
||||
{
|
||||
"name": "check-for-index-attempt-cleanup",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEX_ATTEMPT_CLEANUP,
|
||||
"schedule": timedelta(hours=1),
|
||||
"schedule": timedelta(minutes=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -89,6 +89,7 @@ from onyx.indexing.adapters.document_indexing_adapter import (
|
||||
DocumentIndexingBatchAdapter,
|
||||
)
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
@@ -1270,8 +1271,6 @@ def _docprocessing_task(
|
||||
tenant_id: str,
|
||||
batch_num: int,
|
||||
) -> None:
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
|
||||
start_time = time.monotonic()
|
||||
|
||||
if tenant_id:
|
||||
|
||||
@@ -579,6 +579,16 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
|
||||
search_doc_map[doc_id] = []
|
||||
search_doc_map[doc_id].append(sd)
|
||||
|
||||
task_logger.debug(
|
||||
f"Built search doc map with {len(search_doc_map)} entries"
|
||||
)
|
||||
ids_preview = list(search_doc_map.keys())[:5]
|
||||
task_logger.debug(
|
||||
f"First few search_doc_map ids: {ids_preview if ids_preview else 'No ids found'}"
|
||||
)
|
||||
task_logger.debug(
|
||||
f"search_doc_map total items: {sum(len(docs) for docs in search_doc_map.values())}"
|
||||
)
|
||||
# Process each UserFile and update matching SearchDocs
|
||||
updated_count = 0
|
||||
for uf in user_files:
|
||||
@@ -586,9 +596,18 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
|
||||
if doc_id.startswith("USER_FILE_CONNECTOR__"):
|
||||
doc_id = "FILE_CONNECTOR__" + doc_id[len("USER_FILE_CONNECTOR__") :]
|
||||
|
||||
task_logger.debug(f"Processing user file {uf.id} with doc_id {doc_id}")
|
||||
task_logger.debug(
|
||||
f"doc_id in search_doc_map: {doc_id in search_doc_map}"
|
||||
)
|
||||
|
||||
if doc_id in search_doc_map:
|
||||
search_docs = search_doc_map[doc_id]
|
||||
task_logger.debug(
|
||||
f"Found {len(search_docs)} search docs to update for user file {uf.id}"
|
||||
)
|
||||
# Update the SearchDoc to use the UserFile's UUID
|
||||
for search_doc in search_doc_map[doc_id]:
|
||||
for search_doc in search_docs:
|
||||
search_doc.document_id = str(uf.id)
|
||||
db_session.add(search_doc)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.constants import NUM_DAYS_TO_KEEP_INDEX_ATTEMPTS
|
||||
from onyx.db.engine.time_utils import get_db_current_time
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
|
||||
|
||||
def get_old_index_attempts(
|
||||
@@ -21,6 +22,10 @@ def get_old_index_attempts(
|
||||
|
||||
def cleanup_index_attempts(db_session: Session, index_attempt_ids: list[int]) -> None:
|
||||
"""Clean up multiple index attempts"""
|
||||
db_session.query(IndexAttemptError).filter(
|
||||
IndexAttemptError.index_attempt_id.in_(index_attempt_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
db_session.query(IndexAttempt).filter(
|
||||
IndexAttempt.id.in_(index_attempt_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
@@ -101,7 +102,6 @@ def _get_connector_runner(
|
||||
are the complete list of existing documents of the connector. If the task
|
||||
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
|
||||
"""
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
|
||||
task = attempt.connector_credential_pair.connector.input_type
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from atlassian.errors import ApiError # type: ignore
|
||||
from requests.exceptions import HTTPError
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -108,6 +109,7 @@ class ConfluenceConnector(
|
||||
# pages.
|
||||
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
|
||||
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
|
||||
scoped_token: bool = False,
|
||||
) -> None:
|
||||
self.wiki_base = wiki_base
|
||||
self.is_cloud = is_cloud
|
||||
@@ -118,6 +120,7 @@ class ConfluenceConnector(
|
||||
self.batch_size = batch_size
|
||||
self.labels_to_skip = labels_to_skip
|
||||
self.timezone_offset = timezone_offset
|
||||
self.scoped_token = scoped_token
|
||||
self._confluence_client: OnyxConfluence | None = None
|
||||
self._low_timeout_confluence_client: OnyxConfluence | None = None
|
||||
self._fetched_titles: set[str] = set()
|
||||
@@ -195,6 +198,7 @@ class ConfluenceConnector(
|
||||
is_cloud=self.is_cloud,
|
||||
url=self.wiki_base,
|
||||
credentials_provider=credentials_provider,
|
||||
scoped_token=self.scoped_token,
|
||||
)
|
||||
confluence_client._probe_connection(**self.probe_kwargs)
|
||||
confluence_client._initialize_connection(**self.final_kwargs)
|
||||
@@ -207,6 +211,7 @@ class ConfluenceConnector(
|
||||
url=self.wiki_base,
|
||||
credentials_provider=credentials_provider,
|
||||
timeout=3,
|
||||
scoped_token=self.scoped_token,
|
||||
)
|
||||
low_timeout_confluence_client._probe_connection(**self.probe_kwargs)
|
||||
low_timeout_confluence_client._initialize_connection(**self.final_kwargs)
|
||||
@@ -676,6 +681,14 @@ class ConfluenceConnector(
|
||||
f"Unexpected error while validating Confluence settings: {e}"
|
||||
)
|
||||
|
||||
if self.space:
|
||||
try:
|
||||
self.low_timeout_confluence_client.get_space(self.space)
|
||||
except ApiError as e:
|
||||
raise ConnectorValidationError(
|
||||
"Invalid Confluence space key provided"
|
||||
) from e
|
||||
|
||||
if not spaces or not spaces.get("results"):
|
||||
raise ConnectorValidationError(
|
||||
"No Confluence spaces found. Either your credentials lack permissions, or "
|
||||
|
||||
@@ -41,6 +41,7 @@ from onyx.connectors.confluence.utils import _handle_http_error
|
||||
from onyx.connectors.confluence.utils import confluence_refresh_tokens
|
||||
from onyx.connectors.confluence.utils import get_start_param_from_url
|
||||
from onyx.connectors.confluence.utils import update_param_in_path
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import scoped_url
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -87,16 +88,20 @@ class OnyxConfluence:
|
||||
url: str,
|
||||
credentials_provider: CredentialsProviderInterface,
|
||||
timeout: int | None = None,
|
||||
scoped_token: bool = False,
|
||||
# should generally not be passed in, but making it overridable for
|
||||
# easier testing
|
||||
confluence_user_profiles_override: list[dict[str, str]] | None = (
|
||||
CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE
|
||||
),
|
||||
) -> None:
|
||||
self.base_url = url #'/'.join(url.rstrip("/").split("/")[:-1])
|
||||
url = scoped_url(url, "confluence") if scoped_token else url
|
||||
|
||||
self._is_cloud = is_cloud
|
||||
self._url = url.rstrip("/")
|
||||
self._credentials_provider = credentials_provider
|
||||
|
||||
self.scoped_token = scoped_token
|
||||
self.redis_client: Redis | None = None
|
||||
self.static_credentials: dict[str, Any] | None = None
|
||||
if self._credentials_provider.is_dynamic():
|
||||
@@ -218,6 +223,34 @@ class OnyxConfluence:
|
||||
|
||||
with self._credentials_provider:
|
||||
credentials, _ = self._renew_credentials()
|
||||
if self.scoped_token:
|
||||
# v2 endpoint doesn't always work with scoped tokens, use v1
|
||||
token = credentials["confluence_access_token"]
|
||||
probe_url = f"{self.base_url}/rest/api/space?limit=1"
|
||||
import requests
|
||||
|
||||
logger.info(f"First and Last 5 of token: {token[:5]}...{token[-5:]}")
|
||||
|
||||
try:
|
||||
r = requests.get(
|
||||
probe_url,
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=10,
|
||||
)
|
||||
r.raise_for_status()
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 403:
|
||||
logger.warning(
|
||||
"scoped token authenticated but not valid for probe endpoint (spaces)"
|
||||
)
|
||||
else:
|
||||
if "WWW-Authenticate" in e.response.headers:
|
||||
logger.warning(
|
||||
f"WWW-Authenticate: {e.response.headers['WWW-Authenticate']}"
|
||||
)
|
||||
logger.warning(f"Full error: {e.response.text}")
|
||||
raise e
|
||||
return
|
||||
|
||||
# probe connection with direct client, no retries
|
||||
if "confluence_refresh_token" in credentials:
|
||||
@@ -236,6 +269,7 @@ class OnyxConfluence:
|
||||
logger.info("Probing Confluence with Personal Access Token.")
|
||||
url = self._url
|
||||
if self._is_cloud:
|
||||
logger.info("running with cloud client")
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
url=url,
|
||||
username=credentials["confluence_username"],
|
||||
@@ -304,7 +338,9 @@ class OnyxConfluence:
|
||||
url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
|
||||
confluence = Confluence(url=url, oauth2=oauth2_dict, **kwargs)
|
||||
else:
|
||||
logger.info("Connecting to Confluence with Personal Access Token.")
|
||||
logger.info(
|
||||
f"Connecting to Confluence with Personal Access Token as user: {credentials['confluence_username']}"
|
||||
)
|
||||
if self._is_cloud:
|
||||
confluence = Confluence(
|
||||
url=self._url,
|
||||
|
||||
@@ -5,7 +5,10 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
from urllib.parse import urljoin
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from dateutil.parser import parse
|
||||
|
||||
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
|
||||
@@ -148,3 +151,17 @@ def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
|
||||
|
||||
def is_atlassian_date_error(e: Exception) -> bool:
    """Return True when the exception text reports an invalid 'updated' field.

    The check is a plain substring match against ``str(e)``.
    """
    error_text = str(e)
    marker = "field 'updated' is invalid"
    return marker in error_text
|
||||
|
||||
|
||||
def get_cloudId(base_url: str) -> str:
    """Fetch the ``cloudId`` value from a site's ``/_edge/tenant_info`` endpoint.

    Args:
        base_url: Scheme and host of the site; the tenant-info path is
            joined onto it.

    Returns:
        The value stored under ``"cloudId"`` in the JSON response.

    Raises:
        requests.HTTPError: via ``raise_for_status`` on a non-success response.
        KeyError: if the JSON body has no ``"cloudId"`` entry.
    """
    resp = requests.get(urljoin(base_url, "/_edge/tenant_info"), timeout=10)
    resp.raise_for_status()
    payload = resp.json()
    return payload["cloudId"]
|
||||
|
||||
|
||||
def scoped_url(url: str, product: str) -> str:
    """Rewrite a site URL into its ``api.atlassian.com`` equivalent.

    The cloud id is looked up from the URL's scheme and host; only the path
    component of the original URL is carried over into the result.

    Args:
        url: The site URL to convert.
        product: Product segment placed after ``/ex/`` in the result.

    Returns:
        ``https://api.atlassian.com/ex/{product}/{cloud_id}{path}``.
    """
    parts = urlparse(url)
    origin = f"{parts.scheme}://{parts.netloc}"
    tenant_cloud_id = get_cloudId(origin)
    return f"https://api.atlassian.com/ex/{product}/{tenant_cloud_id}{parts.path}"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import importlib
|
||||
from typing import Any
|
||||
from typing import Type
|
||||
|
||||
@@ -6,60 +7,16 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
from onyx.connectors.axero.connector import AxeroConnector
|
||||
from onyx.connectors.bitbucket.connector import BitbucketConnector
|
||||
from onyx.connectors.blob.connector import BlobStorageConnector
|
||||
from onyx.connectors.bookstack.connector import BookstackConnector
|
||||
from onyx.connectors.clickup.connector import ClickupConnector
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.discord.connector import DiscordConnector
|
||||
from onyx.connectors.discourse.connector import DiscourseConnector
|
||||
from onyx.connectors.document360.connector import Document360Connector
|
||||
from onyx.connectors.dropbox.connector import DropboxConnector
|
||||
from onyx.connectors.egnyte.connector import EgnyteConnector
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.fireflies.connector import FirefliesConnector
|
||||
from onyx.connectors.freshdesk.connector import FreshdeskConnector
|
||||
from onyx.connectors.gitbook.connector import GitbookConnector
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from onyx.connectors.gitlab.connector import GitlabConnector
|
||||
from onyx.connectors.gmail.connector import GmailConnector
|
||||
from onyx.connectors.gong.connector import GongConnector
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_site.connector import GoogleSitesConnector
|
||||
from onyx.connectors.guru.connector import GuruConnector
|
||||
from onyx.connectors.highspot.connector import HighspotConnector
|
||||
from onyx.connectors.hubspot.connector import HubSpotConnector
|
||||
from onyx.connectors.imap.connector import ImapConnector
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import CredentialsConnector
|
||||
from onyx.connectors.interfaces import EventConnector
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.jira.connector import JiraConnector
|
||||
from onyx.connectors.linear.connector import LinearConnector
|
||||
from onyx.connectors.loopio.connector import LoopioConnector
|
||||
from onyx.connectors.mediawiki.wiki import MediaWikiConnector
|
||||
from onyx.connectors.mock_connector.connector import MockConnector
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.connectors.notion.connector import NotionConnector
|
||||
from onyx.connectors.outline.connector import OutlineConnector
|
||||
from onyx.connectors.productboard.connector import ProductboardConnector
|
||||
from onyx.connectors.salesforce.connector import SalesforceConnector
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.connectors.slab.connector import SlabConnector
|
||||
from onyx.connectors.slack.connector import SlackConnector
|
||||
from onyx.connectors.teams.connector import TeamsConnector
|
||||
from onyx.connectors.web.connector import WebConnector
|
||||
from onyx.connectors.wikipedia.connector import WikipediaConnector
|
||||
from onyx.connectors.xenforo.connector import XenforoConnector
|
||||
from onyx.connectors.zendesk.connector import ZendeskConnector
|
||||
from onyx.connectors.zulip.connector import ZulipConnector
|
||||
from onyx.connectors.registry import CONNECTOR_CLASS_MAP
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.credentials import backend_update_credential_json
|
||||
from onyx.db.credentials import fetch_credential_by_id
|
||||
@@ -72,101 +29,75 @@ class ConnectorMissingException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# Cache for already imported connector classes
|
||||
_connector_cache: dict[DocumentSource, Type[BaseConnector]] = {}
|
||||
|
||||
|
||||
def _load_connector_class(source: DocumentSource) -> Type[BaseConnector]:
    """Dynamically load and cache a connector class.

    Args:
        source: The document source whose connector class is wanted.

    Returns:
        The connector class named by ``CONNECTOR_CLASS_MAP[source]``. The
        result is stored in ``_connector_cache`` so later calls skip the import.

    Raises:
        ConnectorMissingException: if ``source`` has no entry in
            ``CONNECTOR_CLASS_MAP``, or if importing the module / resolving
            the class attribute fails.
    """
    if source in _connector_cache:
        return _connector_cache[source]

    if source not in CONNECTOR_CLASS_MAP:
        raise ConnectorMissingException(f"Connector not found for source={source}")

    mapping = CONNECTOR_CLASS_MAP[source]

    # Only the import and attribute lookup can raise the handled errors, so
    # keep the try body limited to those two calls.
    try:
        module = importlib.import_module(mapping.module_path)
        connector_class = getattr(module, mapping.class_name)
    except (ImportError, AttributeError) as e:
        # Chain explicitly so the original import failure is kept as __cause__.
        raise ConnectorMissingException(
            f"Failed to import {mapping.class_name} from {mapping.module_path}: {e}"
        ) from e

    _connector_cache[source] = connector_class
    return connector_class
|
||||
|
||||
|
||||
def _validate_connector_supports_input_type(
    connector: Type[BaseConnector],
    input_type: InputType | None,
    source: DocumentSource,
) -> None:
    """Raise if ``connector`` cannot serve the requested ``input_type``.

    A ``None`` input type is always accepted. Otherwise:
      * LOAD_STATE requires a ``LoadConnector`` subclass,
      * POLL requires a ``PollConnector`` or ``CheckpointedConnector`` subclass,
      * EVENT requires an ``EventConnector`` subclass,
      * any other input type is accepted without a check.

    Raises:
        ConnectorMissingException: when the requirement above is not met.
    """
    if input_type is None:
        return

    if input_type == InputType.LOAD_STATE:
        supported = issubclass(connector, LoadConnector)
    elif input_type == InputType.POLL:
        # Either poll or checkpoint works for this, in the future
        # all connectors should be checkpoint connectors
        supported = issubclass(connector, PollConnector) or issubclass(
            connector, CheckpointedConnector
        )
    elif input_type == InputType.EVENT:
        supported = issubclass(connector, EventConnector)
    else:
        supported = True

    if not supported:
        raise ConnectorMissingException(
            f"Connector for source={source} does not accept input_type={input_type}"
        )
|
||||
|
||||
|
||||
def identify_connector_class(
    source: DocumentSource,
    input_type: InputType | None = None,
) -> Type[BaseConnector]:
    """Resolve the connector class for ``source`` and check it fits ``input_type``.

    The class is obtained through ``_load_connector_class`` (lazy import with
    caching) rather than from an eagerly built source-to-class table, and the
    input-type check is delegated to ``_validate_connector_supports_input_type``
    instead of being repeated inline.

    Args:
        source: The document source to look up.
        input_type: Optional input type the caller intends to use; ``None``
            skips the compatibility check.

    Returns:
        The connector class for ``source``.

    Raises:
        ConnectorMissingException: if no connector can be loaded for
            ``source`` or it does not accept ``input_type``.
    """
    # Load the connector class using lazy loading
    connector = _load_connector_class(source)

    # Validate connector supports the requested input_type
    _validate_connector_supports_input_type(connector, input_type, source)

    return connector
|
||||
|
||||
|
||||
|
||||
@@ -219,12 +219,19 @@ def _get_batch_rate_limited(
|
||||
|
||||
|
||||
def _get_userinfo(user: NamedUser) -> dict[str, str]:
|
||||
def _safe_get(attr_name: str) -> str | None:
|
||||
try:
|
||||
return cast(str | None, getattr(user, attr_name))
|
||||
except GithubException:
|
||||
logger.debug(f"Error getting {attr_name} for user")
|
||||
return None
|
||||
|
||||
return {
|
||||
k: v
|
||||
for k, v in {
|
||||
"login": user.login,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
"login": _safe_get("login"),
|
||||
"name": _safe_get("name"),
|
||||
"email": _safe_get("email"),
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
import requests
|
||||
from hubspot import HubSpot # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.hubspot.rate_limit import HubSpotRateLimiter
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@@ -25,6 +29,10 @@ HUBSPOT_API_URL = "https://api.hubapi.com/integrations/v1/me"
|
||||
# Available HubSpot object types
|
||||
AVAILABLE_OBJECT_TYPES = {"tickets", "companies", "deals", "contacts"}
|
||||
|
||||
HUBSPOT_PAGE_SIZE = 100
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -38,6 +46,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
self.batch_size = batch_size
|
||||
self._access_token = access_token
|
||||
self._portal_id: str | None = None
|
||||
self._rate_limiter = HubSpotRateLimiter()
|
||||
|
||||
# Set object types to fetch, default to all available types
|
||||
if object_types is None:
|
||||
@@ -77,6 +86,37 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
"""Set the portal ID."""
|
||||
self._portal_id = value
|
||||
|
||||
def _call_hubspot(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
|
||||
return self._rate_limiter.call(func, *args, **kwargs)
|
||||
|
||||
def _paginated_results(
|
||||
self,
|
||||
fetch_page: Callable[..., Any],
|
||||
**kwargs: Any,
|
||||
) -> Generator[Any, None, None]:
|
||||
base_kwargs = dict(kwargs)
|
||||
base_kwargs.setdefault("limit", HUBSPOT_PAGE_SIZE)
|
||||
|
||||
after: str | None = None
|
||||
while True:
|
||||
page_kwargs = base_kwargs.copy()
|
||||
if after is not None:
|
||||
page_kwargs["after"] = after
|
||||
|
||||
page = self._call_hubspot(fetch_page, **page_kwargs)
|
||||
results = getattr(page, "results", [])
|
||||
for result in results:
|
||||
yield result
|
||||
|
||||
paging = getattr(page, "paging", None)
|
||||
next_page = getattr(paging, "next", None) if paging else None
|
||||
if next_page is None:
|
||||
break
|
||||
|
||||
after = getattr(next_page, "after", None)
|
||||
if after is None:
|
||||
break
|
||||
|
||||
def _clean_html_content(self, html_content: str) -> str:
|
||||
"""Clean HTML content and extract raw text"""
|
||||
if not html_content:
|
||||
@@ -150,78 +190,82 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get associated objects for a given object"""
|
||||
try:
|
||||
associations = api_client.crm.associations.v4.basic_api.get_page(
|
||||
associations_iter = self._paginated_results(
|
||||
api_client.crm.associations.v4.basic_api.get_page,
|
||||
object_type=from_object_type,
|
||||
object_id=object_id,
|
||||
to_object_type=to_object_type,
|
||||
)
|
||||
|
||||
associated_objects = []
|
||||
if associations.results:
|
||||
object_ids = [assoc.to_object_id for assoc in associations.results]
|
||||
object_ids = [assoc.to_object_id for assoc in associations_iter]
|
||||
|
||||
# Batch get the associated objects
|
||||
if to_object_type == "contacts":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = api_client.crm.contacts.basic_api.get_by_id(
|
||||
contact_id=obj_id,
|
||||
properties=[
|
||||
"firstname",
|
||||
"lastname",
|
||||
"email",
|
||||
"company",
|
||||
"jobtitle",
|
||||
],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch contact {obj_id}: {e}")
|
||||
associated_objects: list[dict[str, Any]] = []
|
||||
|
||||
elif to_object_type == "companies":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = api_client.crm.companies.basic_api.get_by_id(
|
||||
company_id=obj_id,
|
||||
properties=[
|
||||
"name",
|
||||
"domain",
|
||||
"industry",
|
||||
"city",
|
||||
"state",
|
||||
],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch company {obj_id}: {e}")
|
||||
if to_object_type == "contacts":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = self._call_hubspot(
|
||||
api_client.crm.contacts.basic_api.get_by_id,
|
||||
contact_id=obj_id,
|
||||
properties=[
|
||||
"firstname",
|
||||
"lastname",
|
||||
"email",
|
||||
"company",
|
||||
"jobtitle",
|
||||
],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch contact {obj_id}: {e}")
|
||||
|
||||
elif to_object_type == "deals":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = api_client.crm.deals.basic_api.get_by_id(
|
||||
deal_id=obj_id,
|
||||
properties=[
|
||||
"dealname",
|
||||
"amount",
|
||||
"dealstage",
|
||||
"closedate",
|
||||
"pipeline",
|
||||
],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch deal {obj_id}: {e}")
|
||||
elif to_object_type == "companies":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = self._call_hubspot(
|
||||
api_client.crm.companies.basic_api.get_by_id,
|
||||
company_id=obj_id,
|
||||
properties=[
|
||||
"name",
|
||||
"domain",
|
||||
"industry",
|
||||
"city",
|
||||
"state",
|
||||
],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch company {obj_id}: {e}")
|
||||
|
||||
elif to_object_type == "tickets":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = api_client.crm.tickets.basic_api.get_by_id(
|
||||
ticket_id=obj_id,
|
||||
properties=["subject", "content", "hs_ticket_priority"],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch ticket {obj_id}: {e}")
|
||||
elif to_object_type == "deals":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = self._call_hubspot(
|
||||
api_client.crm.deals.basic_api.get_by_id,
|
||||
deal_id=obj_id,
|
||||
properties=[
|
||||
"dealname",
|
||||
"amount",
|
||||
"dealstage",
|
||||
"closedate",
|
||||
"pipeline",
|
||||
],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch deal {obj_id}: {e}")
|
||||
|
||||
elif to_object_type == "tickets":
|
||||
for obj_id in object_ids:
|
||||
try:
|
||||
obj = self._call_hubspot(
|
||||
api_client.crm.tickets.basic_api.get_by_id,
|
||||
ticket_id=obj_id,
|
||||
properties=["subject", "content", "hs_ticket_priority"],
|
||||
)
|
||||
associated_objects.append(obj.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch ticket {obj_id}: {e}")
|
||||
|
||||
return associated_objects
|
||||
|
||||
@@ -239,33 +283,33 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get notes associated with a given object"""
|
||||
try:
|
||||
# Get associations to notes (engagement type)
|
||||
associations = api_client.crm.associations.v4.basic_api.get_page(
|
||||
associations_iter = self._paginated_results(
|
||||
api_client.crm.associations.v4.basic_api.get_page,
|
||||
object_type=object_type,
|
||||
object_id=object_id,
|
||||
to_object_type="notes",
|
||||
)
|
||||
|
||||
associated_notes = []
|
||||
if associations.results:
|
||||
note_ids = [assoc.to_object_id for assoc in associations.results]
|
||||
note_ids = [assoc.to_object_id for assoc in associations_iter]
|
||||
|
||||
# Batch get the associated notes
|
||||
for note_id in note_ids:
|
||||
try:
|
||||
# Notes are engagements in HubSpot, use the engagements API
|
||||
note = api_client.crm.objects.notes.basic_api.get_by_id(
|
||||
note_id=note_id,
|
||||
properties=[
|
||||
"hs_note_body",
|
||||
"hs_timestamp",
|
||||
"hs_created_by",
|
||||
"hubspot_owner_id",
|
||||
],
|
||||
)
|
||||
associated_notes.append(note.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch note {note_id}: {e}")
|
||||
associated_notes = []
|
||||
|
||||
for note_id in note_ids:
|
||||
try:
|
||||
# Notes are engagements in HubSpot, use the engagements API
|
||||
note = self._call_hubspot(
|
||||
api_client.crm.objects.notes.basic_api.get_by_id,
|
||||
note_id=note_id,
|
||||
properties=[
|
||||
"hs_note_body",
|
||||
"hs_timestamp",
|
||||
"hs_created_by",
|
||||
"hubspot_owner_id",
|
||||
],
|
||||
)
|
||||
associated_notes.append(note.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch note {note_id}: {e}")
|
||||
|
||||
return associated_notes
|
||||
|
||||
@@ -358,7 +402,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
api_client = HubSpot(access_token=self.access_token)
|
||||
all_tickets = api_client.crm.tickets.get_all(
|
||||
|
||||
tickets_iter = self._paginated_results(
|
||||
api_client.crm.tickets.basic_api.get_page,
|
||||
properties=[
|
||||
"subject",
|
||||
"content",
|
||||
@@ -371,7 +417,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
for ticket in all_tickets:
|
||||
for ticket in tickets_iter:
|
||||
updated_at = ticket.updated_at.replace(tzinfo=None)
|
||||
if start is not None and updated_at < start.replace(tzinfo=None):
|
||||
continue
|
||||
@@ -459,7 +505,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
api_client = HubSpot(access_token=self.access_token)
|
||||
all_companies = api_client.crm.companies.get_all(
|
||||
|
||||
companies_iter = self._paginated_results(
|
||||
api_client.crm.companies.basic_api.get_page,
|
||||
properties=[
|
||||
"name",
|
||||
"domain",
|
||||
@@ -475,7 +523,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
for company in all_companies:
|
||||
for company in companies_iter:
|
||||
updated_at = company.updated_at.replace(tzinfo=None)
|
||||
if start is not None and updated_at < start.replace(tzinfo=None):
|
||||
continue
|
||||
@@ -582,7 +630,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
api_client = HubSpot(access_token=self.access_token)
|
||||
all_deals = api_client.crm.deals.get_all(
|
||||
|
||||
deals_iter = self._paginated_results(
|
||||
api_client.crm.deals.basic_api.get_page,
|
||||
properties=[
|
||||
"dealname",
|
||||
"amount",
|
||||
@@ -598,7 +648,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
for deal in all_deals:
|
||||
for deal in deals_iter:
|
||||
updated_at = deal.updated_at.replace(tzinfo=None)
|
||||
if start is not None and updated_at < start.replace(tzinfo=None):
|
||||
continue
|
||||
@@ -703,7 +753,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
api_client = HubSpot(access_token=self.access_token)
|
||||
all_contacts = api_client.crm.contacts.get_all(
|
||||
|
||||
contacts_iter = self._paginated_results(
|
||||
api_client.crm.contacts.basic_api.get_page,
|
||||
properties=[
|
||||
"firstname",
|
||||
"lastname",
|
||||
@@ -721,7 +773,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
for contact in all_contacts:
|
||||
for contact in contacts_iter:
|
||||
updated_at = contact.updated_at.replace(tzinfo=None)
|
||||
if start is not None and updated_at < start.replace(tzinfo=None):
|
||||
continue
|
||||
|
||||
145
backend/onyx/connectors/hubspot/rate_limit.py
Normal file
145
backend/onyx/connectors/hubspot/rate_limit.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
RateLimitTriedTooManyTimesError,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# HubSpot exposes a ten second rolling window (x-hubspot-ratelimit-interval-milliseconds)
# with a maximum of 190 requests, and a per-second limit of 19 requests.
# These four values are the defaults for HubSpotRateLimiter's keyword arguments.
_HUBSPOT_TEN_SECOND_LIMIT = 190
_HUBSPOT_TEN_SECOND_PERIOD = 10 # seconds
_HUBSPOT_SECONDLY_LIMIT = 19
_HUBSPOT_SECONDLY_PERIOD = 1 # second
# Fallback wait used by get_rate_limit_retry_delay_seconds when no header
# yields a usable delay.
_DEFAULT_SLEEP_SECONDS = 10
# Added on top of every delay returned by get_rate_limit_retry_delay_seconds.
_SLEEP_PADDING_SECONDS = 1.0
# Default for HubSpotRateLimiter's max_retries argument.
_MAX_RATE_LIMIT_RETRIES = 5
|
||||
|
||||
|
||||
def _extract_header(headers: Any, key: str) -> str | None:
|
||||
if headers is None:
|
||||
return None
|
||||
|
||||
getter = getattr(headers, "get", None)
|
||||
if callable(getter):
|
||||
value = getter(key)
|
||||
if value is not None:
|
||||
return value
|
||||
|
||||
if isinstance(headers, dict):
|
||||
value = headers.get(key)
|
||||
if value is not None:
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def is_rate_limit_error(exception: Exception) -> bool:
    """Decide whether ``exception`` represents a HubSpot rate-limit response.

    Checks, in order:
      1. a ``status`` attribute equal to 429;
      2. a ``headers`` attribute where either remaining-quota header
         (``x-hubspot-ratelimit-remaining`` or
         ``x-hubspot-ratelimit-secondly-remaining``) is the string ``"0"``;
      3. the exception text containing ``RATE_LIMIT`` or ``Too Many Requests``.
    """
    if getattr(exception, "status", None) == 429:
        return True

    response_headers = getattr(exception, "headers", None)
    if response_headers is not None:
        quota_headers = (
            "x-hubspot-ratelimit-remaining",
            "x-hubspot-ratelimit-secondly-remaining",
        )
        for header_name in quota_headers:
            if _extract_header(response_headers, header_name) == "0":
                return True

    text = str(exception)
    return any(marker in text for marker in ("RATE_LIMIT", "Too Many Requests"))
|
||||
|
||||
|
||||
def get_rate_limit_retry_delay_seconds(exception: Exception) -> float:
    """Compute how long to wait before retrying after a rate-limit error.

    Headers on ``exception.headers`` are consulted in priority order; the
    first one that is present and parses as a float decides the delay:
      1. ``Retry-After``: used as seconds;
      2. ``x-hubspot-ratelimit-interval-milliseconds``: divided by 1000;
      3. ``x-hubspot-ratelimit-secondly``: ``1 / max(value, 1.0)``.
    If none applies, ``_DEFAULT_SLEEP_SECONDS`` is used. In every case
    ``_SLEEP_PADDING_SECONDS`` is added to the result.
    """
    headers = getattr(exception, "headers", None)

    def _parsed(header_name: str, failure_message: str) -> float | None:
        # Missing/empty header -> None; unparseable header -> debug log + None.
        raw = _extract_header(headers, header_name)
        if not raw:
            return None
        try:
            return float(raw)
        except ValueError:
            logger.debug(failure_message, raw)
            return None

    retry_after_seconds = _parsed(
        "Retry-After", "Failed to parse Retry-After header '%s' as float"
    )
    if retry_after_seconds is not None:
        return retry_after_seconds + _SLEEP_PADDING_SECONDS

    window_ms = _parsed(
        "x-hubspot-ratelimit-interval-milliseconds",
        "Failed to parse x-hubspot-ratelimit-interval-milliseconds '%s' as float",
    )
    if window_ms is not None:
        return window_ms / 1000.0 + _SLEEP_PADDING_SECONDS

    allowed_per_second = _parsed(
        "x-hubspot-ratelimit-secondly",
        "Failed to parse x-hubspot-ratelimit-secondly '%s' as float",
    )
    if allowed_per_second is not None:
        per_second = max(allowed_per_second, 1.0)
        return (1.0 / per_second) + _SLEEP_PADDING_SECONDS

    return _DEFAULT_SLEEP_SECONDS + _SLEEP_PADDING_SECONDS
|
||||
|
||||
|
||||
class HubSpotRateLimiter:
    """Runs callables under two stacked rate limits and retries on rate-limit errors.

    Every call goes through a ten-second-window limiter wrapped by a
    per-second limiter (both built with ``rate_limit_builder``). When the
    wrapped call raises an exception that ``is_rate_limit_error`` recognises,
    the limiter sleeps for the delay computed by
    ``get_rate_limit_retry_delay_seconds`` and tries again, up to
    ``max_retries`` retries.
    """

    def __init__(
        self,
        *,
        ten_second_limit: int = _HUBSPOT_TEN_SECOND_LIMIT,
        ten_second_period: int = _HUBSPOT_TEN_SECOND_PERIOD,
        secondly_limit: int = _HUBSPOT_SECONDLY_LIMIT,
        secondly_period: int = _HUBSPOT_SECONDLY_PERIOD,
        max_retries: int = _MAX_RATE_LIMIT_RETRIES,
    ) -> None:
        self._max_retries = max_retries

        def _execute(callable_: Callable[[], T]) -> T:
            return callable_()

        # Build the outer (per-second) guard first, then the inner
        # (ten-second) guard, and apply them inner-to-outer, which is the
        # same order stacked decorators would use.
        per_second_guard = rate_limit_builder(
            max_calls=secondly_limit, period=secondly_period
        )
        ten_second_guard = rate_limit_builder(
            max_calls=ten_second_limit, period=ten_second_period
        )
        self._execute = per_second_guard(ten_second_guard(_execute))

    def call(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
        """Call ``func(*args, **kwargs)`` under the rate limits, retrying on 429s.

        Raises:
            RateLimitTriedTooManyTimesError: once more than ``max_retries``
                rate-limit failures have occurred; chained from the last one.
            Exception: any non-rate-limit exception from ``func`` is re-raised
                unchanged.
        """
        failures = 0

        def _invoke() -> T:
            return func(*args, **kwargs)

        while True:
            try:
                return self._execute(_invoke)
            except Exception as err:  # pylint: disable=broad-except
                if not is_rate_limit_error(err):
                    raise

                failures += 1
                if failures > self._max_retries:
                    raise RateLimitTriedTooManyTimesError(
                        "Exceeded configured HubSpot rate limit retries"
                    ) from err

                delay = get_rate_limit_retry_delay_seconds(err)
                logger.notice(
                    "HubSpot rate limit reached. Sleeping %.2f seconds before retrying.",
                    delay,
                )
                time.sleep(delay)
|
||||
@@ -25,7 +25,7 @@ from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
@@ -247,7 +247,7 @@ def _perform_jql_search_v2(
|
||||
|
||||
|
||||
def process_jira_issue(
|
||||
jira_client: JIRA,
|
||||
jira_base_url: str,
|
||||
issue: Issue,
|
||||
comment_email_blacklist: tuple[str, ...] = (),
|
||||
labels_to_skip: set[str] | None = None,
|
||||
@@ -281,7 +281,7 @@ def process_jira_issue(
|
||||
)
|
||||
return None
|
||||
|
||||
page_url = build_jira_url(jira_client, issue.key)
|
||||
page_url = build_jira_url(jira_base_url, issue.key)
|
||||
|
||||
metadata_dict: dict[str, str | list[str]] = {}
|
||||
people = set()
|
||||
@@ -359,7 +359,9 @@ class JiraConnectorCheckpoint(ConnectorCheckpoint):
|
||||
offset: int | None = None
|
||||
|
||||
|
||||
class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnector):
|
||||
class JiraConnector(
|
||||
CheckpointedConnectorWithPermSync[JiraConnectorCheckpoint], SlimConnector
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
jira_base_url: str,
|
||||
@@ -372,15 +374,23 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
|
||||
# Custom JQL query to filter Jira issues
|
||||
jql_query: str | None = None,
|
||||
scoped_token: bool = False,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
|
||||
# dealing with scoped tokens is a bit tricky because we need to hit api.atlassian.com
|
||||
# when making jira requests but still want correct links to issues in the UI.
|
||||
# So, the user's base url is stored here, but converted to a scoped url when passed
|
||||
# to the jira client.
|
||||
self.jira_base = jira_base_url.rstrip("/") # Remove trailing slash if present
|
||||
self.jira_project = project_key
|
||||
self._comment_email_blacklist = comment_email_blacklist or []
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
self.jql_query = jql_query
|
||||
|
||||
self.scoped_token = scoped_token
|
||||
self._jira_client: JIRA | None = None
|
||||
# Cache project permissions to avoid fetching them repeatedly across runs
|
||||
self._project_permissions_cache: dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def comment_email_blacklist(self) -> tuple:
|
||||
@@ -399,10 +409,26 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
return ""
|
||||
return f'"{self.jira_project}"'
|
||||
|
||||
def _get_project_permissions(self, project_key: str) -> Any:
|
||||
"""Get project permissions with caching.
|
||||
|
||||
Args:
|
||||
project_key: The Jira project key
|
||||
|
||||
Returns:
|
||||
The external access permissions for the project
|
||||
"""
|
||||
if project_key not in self._project_permissions_cache:
|
||||
self._project_permissions_cache[project_key] = get_project_permissions(
|
||||
jira_client=self.jira_client, jira_project=project_key
|
||||
)
|
||||
return self._project_permissions_cache[project_key]
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._jira_client = build_jira_client(
|
||||
credentials=credentials,
|
||||
jira_base=self.jira_base,
|
||||
scoped_token=self.scoped_token,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -442,15 +468,37 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
) -> CheckpointOutput[JiraConnectorCheckpoint]:
|
||||
jql = self._get_jql_query(start, end)
|
||||
try:
|
||||
return self._load_from_checkpoint(jql, checkpoint)
|
||||
return self._load_from_checkpoint(
|
||||
jql, checkpoint, include_permissions=False
|
||||
)
|
||||
except Exception as e:
|
||||
if is_atlassian_date_error(e):
|
||||
jql = self._get_jql_query(start - ONE_HOUR, end)
|
||||
return self._load_from_checkpoint(jql, checkpoint)
|
||||
return self._load_from_checkpoint(
|
||||
jql, checkpoint, include_permissions=False
|
||||
)
|
||||
raise e
|
||||
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: JiraConnectorCheckpoint,
|
||||
) -> CheckpointOutput[JiraConnectorCheckpoint]:
|
||||
"""Load documents from checkpoint with permission information included."""
|
||||
jql = self._get_jql_query(start, end)
|
||||
try:
|
||||
return self._load_from_checkpoint(jql, checkpoint, include_permissions=True)
|
||||
except Exception as e:
|
||||
if is_atlassian_date_error(e):
|
||||
jql = self._get_jql_query(start - ONE_HOUR, end)
|
||||
return self._load_from_checkpoint(
|
||||
jql, checkpoint, include_permissions=True
|
||||
)
|
||||
raise e
|
||||
|
||||
def _load_from_checkpoint(
|
||||
self, jql: str, checkpoint: JiraConnectorCheckpoint
|
||||
self, jql: str, checkpoint: JiraConnectorCheckpoint, include_permissions: bool
|
||||
) -> CheckpointOutput[JiraConnectorCheckpoint]:
|
||||
# Get the current offset from checkpoint or start at 0
|
||||
starting_offset = checkpoint.offset or 0
|
||||
@@ -472,18 +520,25 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
issue_key = issue.key
|
||||
try:
|
||||
if document := process_jira_issue(
|
||||
jira_client=self.jira_client,
|
||||
jira_base_url=self.jira_base,
|
||||
issue=issue,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
):
|
||||
# Add permission information to the document if requested
|
||||
if include_permissions:
|
||||
project_key = get_jira_project_key_from_issue(issue=issue)
|
||||
if project_key:
|
||||
document.external_access = self._get_project_permissions(
|
||||
project_key
|
||||
)
|
||||
yield document
|
||||
|
||||
except Exception as e:
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=issue_key,
|
||||
document_link=build_jira_url(self.jira_client, issue_key),
|
||||
document_link=build_jira_url(self.jira_base, issue_key),
|
||||
),
|
||||
failure_message=f"Failed to process Jira issue: {str(e)}",
|
||||
exception=e,
|
||||
@@ -534,6 +589,7 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
prev_offset = 0
|
||||
current_offset = 0
|
||||
slim_doc_batch = []
|
||||
|
||||
while checkpoint.has_more:
|
||||
for issue in _perform_jql_search(
|
||||
jira_client=self.jira_client,
|
||||
@@ -550,13 +606,12 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
continue
|
||||
|
||||
issue_key = best_effort_get_field_from_issue(issue, _FIELD_KEY)
|
||||
id = build_jira_url(self.jira_client, issue_key)
|
||||
id = build_jira_url(self.jira_base, issue_key)
|
||||
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=id,
|
||||
external_access=get_project_permissions(
|
||||
jira_client=self.jira_client, jira_project=project_key
|
||||
),
|
||||
external_access=self._get_project_permissions(project_key),
|
||||
)
|
||||
)
|
||||
current_offset += 1
|
||||
|
||||
@@ -10,6 +10,7 @@ from jira.resources import CustomFieldOption
|
||||
from jira.resources import Issue
|
||||
from jira.resources import User
|
||||
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import scoped_url
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -74,11 +75,18 @@ def extract_text_from_adf(adf: dict | None) -> str:
|
||||
return " ".join(texts)
|
||||
|
||||
|
||||
def build_jira_url(jira_client: JIRA, issue_key: str) -> str:
|
||||
return f"{jira_client.client_info()}/browse/{issue_key}"
|
||||
def build_jira_url(jira_base_url: str, issue_key: str) -> str:
|
||||
"""
|
||||
Get the url used to access an issue in the UI.
|
||||
"""
|
||||
return f"{jira_base_url}/browse/{issue_key}"
|
||||
|
||||
|
||||
def build_jira_client(credentials: dict[str, Any], jira_base: str) -> JIRA:
|
||||
def build_jira_client(
|
||||
credentials: dict[str, Any], jira_base: str, scoped_token: bool = False
|
||||
) -> JIRA:
|
||||
|
||||
jira_base = scoped_url(jira_base, "jira") if scoped_token else jira_base
|
||||
api_token = credentials["jira_api_token"]
|
||||
# if user provide an email we assume it's cloud
|
||||
if "jira_user_email" in credentials:
|
||||
|
||||
208
backend/onyx/connectors/registry.py
Normal file
208
backend/onyx/connectors/registry.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Registry mapping for connector classes."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
|
||||
class ConnectorMapping(BaseModel):
|
||||
module_path: str
|
||||
class_name: str
|
||||
|
||||
|
||||
# Mapping of DocumentSource to connector details for lazy loading
|
||||
CONNECTOR_CLASS_MAP = {
|
||||
DocumentSource.WEB: ConnectorMapping(
|
||||
module_path="onyx.connectors.web.connector",
|
||||
class_name="WebConnector",
|
||||
),
|
||||
DocumentSource.FILE: ConnectorMapping(
|
||||
module_path="onyx.connectors.file.connector",
|
||||
class_name="LocalFileConnector",
|
||||
),
|
||||
DocumentSource.SLACK: ConnectorMapping(
|
||||
module_path="onyx.connectors.slack.connector",
|
||||
class_name="SlackConnector",
|
||||
),
|
||||
DocumentSource.GITHUB: ConnectorMapping(
|
||||
module_path="onyx.connectors.github.connector",
|
||||
class_name="GithubConnector",
|
||||
),
|
||||
DocumentSource.GMAIL: ConnectorMapping(
|
||||
module_path="onyx.connectors.gmail.connector",
|
||||
class_name="GmailConnector",
|
||||
),
|
||||
DocumentSource.GITLAB: ConnectorMapping(
|
||||
module_path="onyx.connectors.gitlab.connector",
|
||||
class_name="GitlabConnector",
|
||||
),
|
||||
DocumentSource.GITBOOK: ConnectorMapping(
|
||||
module_path="onyx.connectors.gitbook.connector",
|
||||
class_name="GitbookConnector",
|
||||
),
|
||||
DocumentSource.GOOGLE_DRIVE: ConnectorMapping(
|
||||
module_path="onyx.connectors.google_drive.connector",
|
||||
class_name="GoogleDriveConnector",
|
||||
),
|
||||
DocumentSource.BOOKSTACK: ConnectorMapping(
|
||||
module_path="onyx.connectors.bookstack.connector",
|
||||
class_name="BookstackConnector",
|
||||
),
|
||||
DocumentSource.OUTLINE: ConnectorMapping(
|
||||
module_path="onyx.connectors.outline.connector",
|
||||
class_name="OutlineConnector",
|
||||
),
|
||||
DocumentSource.CONFLUENCE: ConnectorMapping(
|
||||
module_path="onyx.connectors.confluence.connector",
|
||||
class_name="ConfluenceConnector",
|
||||
),
|
||||
DocumentSource.JIRA: ConnectorMapping(
|
||||
module_path="onyx.connectors.jira.connector",
|
||||
class_name="JiraConnector",
|
||||
),
|
||||
DocumentSource.PRODUCTBOARD: ConnectorMapping(
|
||||
module_path="onyx.connectors.productboard.connector",
|
||||
class_name="ProductboardConnector",
|
||||
),
|
||||
DocumentSource.SLAB: ConnectorMapping(
|
||||
module_path="onyx.connectors.slab.connector",
|
||||
class_name="SlabConnector",
|
||||
),
|
||||
DocumentSource.NOTION: ConnectorMapping(
|
||||
module_path="onyx.connectors.notion.connector",
|
||||
class_name="NotionConnector",
|
||||
),
|
||||
DocumentSource.ZULIP: ConnectorMapping(
|
||||
module_path="onyx.connectors.zulip.connector",
|
||||
class_name="ZulipConnector",
|
||||
),
|
||||
DocumentSource.GURU: ConnectorMapping(
|
||||
module_path="onyx.connectors.guru.connector",
|
||||
class_name="GuruConnector",
|
||||
),
|
||||
DocumentSource.LINEAR: ConnectorMapping(
|
||||
module_path="onyx.connectors.linear.connector",
|
||||
class_name="LinearConnector",
|
||||
),
|
||||
DocumentSource.HUBSPOT: ConnectorMapping(
|
||||
module_path="onyx.connectors.hubspot.connector",
|
||||
class_name="HubSpotConnector",
|
||||
),
|
||||
DocumentSource.DOCUMENT360: ConnectorMapping(
|
||||
module_path="onyx.connectors.document360.connector",
|
||||
class_name="Document360Connector",
|
||||
),
|
||||
DocumentSource.GONG: ConnectorMapping(
|
||||
module_path="onyx.connectors.gong.connector",
|
||||
class_name="GongConnector",
|
||||
),
|
||||
DocumentSource.GOOGLE_SITES: ConnectorMapping(
|
||||
module_path="onyx.connectors.google_site.connector",
|
||||
class_name="GoogleSitesConnector",
|
||||
),
|
||||
DocumentSource.ZENDESK: ConnectorMapping(
|
||||
module_path="onyx.connectors.zendesk.connector",
|
||||
class_name="ZendeskConnector",
|
||||
),
|
||||
DocumentSource.LOOPIO: ConnectorMapping(
|
||||
module_path="onyx.connectors.loopio.connector",
|
||||
class_name="LoopioConnector",
|
||||
),
|
||||
DocumentSource.DROPBOX: ConnectorMapping(
|
||||
module_path="onyx.connectors.dropbox.connector",
|
||||
class_name="DropboxConnector",
|
||||
),
|
||||
DocumentSource.SHAREPOINT: ConnectorMapping(
|
||||
module_path="onyx.connectors.sharepoint.connector",
|
||||
class_name="SharepointConnector",
|
||||
),
|
||||
DocumentSource.TEAMS: ConnectorMapping(
|
||||
module_path="onyx.connectors.teams.connector",
|
||||
class_name="TeamsConnector",
|
||||
),
|
||||
DocumentSource.SALESFORCE: ConnectorMapping(
|
||||
module_path="onyx.connectors.salesforce.connector",
|
||||
class_name="SalesforceConnector",
|
||||
),
|
||||
DocumentSource.DISCOURSE: ConnectorMapping(
|
||||
module_path="onyx.connectors.discourse.connector",
|
||||
class_name="DiscourseConnector",
|
||||
),
|
||||
DocumentSource.AXERO: ConnectorMapping(
|
||||
module_path="onyx.connectors.axero.connector",
|
||||
class_name="AxeroConnector",
|
||||
),
|
||||
DocumentSource.CLICKUP: ConnectorMapping(
|
||||
module_path="onyx.connectors.clickup.connector",
|
||||
class_name="ClickupConnector",
|
||||
),
|
||||
DocumentSource.MEDIAWIKI: ConnectorMapping(
|
||||
module_path="onyx.connectors.mediawiki.wiki",
|
||||
class_name="MediaWikiConnector",
|
||||
),
|
||||
DocumentSource.WIKIPEDIA: ConnectorMapping(
|
||||
module_path="onyx.connectors.wikipedia.connector",
|
||||
class_name="WikipediaConnector",
|
||||
),
|
||||
DocumentSource.ASANA: ConnectorMapping(
|
||||
module_path="onyx.connectors.asana.connector",
|
||||
class_name="AsanaConnector",
|
||||
),
|
||||
DocumentSource.S3: ConnectorMapping(
|
||||
module_path="onyx.connectors.blob.connector",
|
||||
class_name="BlobStorageConnector",
|
||||
),
|
||||
DocumentSource.R2: ConnectorMapping(
|
||||
module_path="onyx.connectors.blob.connector",
|
||||
class_name="BlobStorageConnector",
|
||||
),
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: ConnectorMapping(
|
||||
module_path="onyx.connectors.blob.connector",
|
||||
class_name="BlobStorageConnector",
|
||||
),
|
||||
DocumentSource.OCI_STORAGE: ConnectorMapping(
|
||||
module_path="onyx.connectors.blob.connector",
|
||||
class_name="BlobStorageConnector",
|
||||
),
|
||||
DocumentSource.XENFORO: ConnectorMapping(
|
||||
module_path="onyx.connectors.xenforo.connector",
|
||||
class_name="XenforoConnector",
|
||||
),
|
||||
DocumentSource.DISCORD: ConnectorMapping(
|
||||
module_path="onyx.connectors.discord.connector",
|
||||
class_name="DiscordConnector",
|
||||
),
|
||||
DocumentSource.FRESHDESK: ConnectorMapping(
|
||||
module_path="onyx.connectors.freshdesk.connector",
|
||||
class_name="FreshdeskConnector",
|
||||
),
|
||||
DocumentSource.FIREFLIES: ConnectorMapping(
|
||||
module_path="onyx.connectors.fireflies.connector",
|
||||
class_name="FirefliesConnector",
|
||||
),
|
||||
DocumentSource.EGNYTE: ConnectorMapping(
|
||||
module_path="onyx.connectors.egnyte.connector",
|
||||
class_name="EgnyteConnector",
|
||||
),
|
||||
DocumentSource.AIRTABLE: ConnectorMapping(
|
||||
module_path="onyx.connectors.airtable.airtable_connector",
|
||||
class_name="AirtableConnector",
|
||||
),
|
||||
DocumentSource.HIGHSPOT: ConnectorMapping(
|
||||
module_path="onyx.connectors.highspot.connector",
|
||||
class_name="HighspotConnector",
|
||||
),
|
||||
DocumentSource.IMAP: ConnectorMapping(
|
||||
module_path="onyx.connectors.imap.connector",
|
||||
class_name="ImapConnector",
|
||||
),
|
||||
DocumentSource.BITBUCKET: ConnectorMapping(
|
||||
module_path="onyx.connectors.bitbucket.connector",
|
||||
class_name="BitbucketConnector",
|
||||
),
|
||||
# just for integration tests
|
||||
DocumentSource.MOCK_CONNECTOR: ConnectorMapping(
|
||||
module_path="onyx.connectors.mock_connector.connector",
|
||||
class_name="MockConnector",
|
||||
),
|
||||
}
|
||||
@@ -219,6 +219,25 @@ def is_valid_url(url: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _same_site(base_url: str, candidate_url: str) -> bool:
|
||||
base, candidate = urlparse(base_url), urlparse(candidate_url)
|
||||
base_netloc = base.netloc.lower().removeprefix("www.")
|
||||
candidate_netloc = candidate.netloc.lower().removeprefix("www.")
|
||||
if base_netloc != candidate_netloc:
|
||||
return False
|
||||
|
||||
base_path = (base.path or "/").rstrip("/")
|
||||
if base_path in ("", "/"):
|
||||
return True
|
||||
|
||||
candidate_path = candidate.path or "/"
|
||||
if candidate_path == base_path:
|
||||
return True
|
||||
|
||||
boundary = f"{base_path}/"
|
||||
return candidate_path.startswith(boundary)
|
||||
|
||||
|
||||
def get_internal_links(
|
||||
base_url: str, url: str, soup: BeautifulSoup, should_ignore_pound: bool = True
|
||||
) -> set[str]:
|
||||
@@ -239,7 +258,7 @@ def get_internal_links(
|
||||
# Relative path handling
|
||||
href = urljoin(url, href)
|
||||
|
||||
if urlparse(href).netloc == urlparse(url).netloc and base_url in href:
|
||||
if _same_site(base_url, href):
|
||||
internal_links.add(href)
|
||||
return internal_links
|
||||
|
||||
|
||||
@@ -1,15 +1,52 @@
|
||||
"""Factory for creating federated connector instances."""
|
||||
|
||||
import importlib
|
||||
from typing import Any
|
||||
from typing import Type
|
||||
|
||||
from onyx.configs.constants import FederatedConnectorSource
|
||||
from onyx.federated_connectors.interfaces import FederatedConnector
|
||||
from onyx.federated_connectors.slack.federated_connector import SlackFederatedConnector
|
||||
from onyx.federated_connectors.registry import FEDERATED_CONNECTOR_CLASS_MAP
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class FederatedConnectorMissingException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# Cache for already imported federated connector classes
|
||||
_federated_connector_cache: dict[FederatedConnectorSource, Type[FederatedConnector]] = (
|
||||
{}
|
||||
)
|
||||
|
||||
|
||||
def _load_federated_connector_class(
|
||||
source: FederatedConnectorSource,
|
||||
) -> Type[FederatedConnector]:
|
||||
"""Dynamically load and cache a federated connector class."""
|
||||
if source in _federated_connector_cache:
|
||||
return _federated_connector_cache[source]
|
||||
|
||||
if source not in FEDERATED_CONNECTOR_CLASS_MAP:
|
||||
raise FederatedConnectorMissingException(
|
||||
f"Federated connector not found for source={source}"
|
||||
)
|
||||
|
||||
mapping = FEDERATED_CONNECTOR_CLASS_MAP[source]
|
||||
|
||||
try:
|
||||
module = importlib.import_module(mapping.module_path)
|
||||
connector_class = getattr(module, mapping.class_name)
|
||||
_federated_connector_cache[source] = connector_class
|
||||
return connector_class
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise FederatedConnectorMissingException(
|
||||
f"Failed to import {mapping.class_name} from {mapping.module_path}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def get_federated_connector(
|
||||
source: FederatedConnectorSource,
|
||||
credentials: dict[str, Any],
|
||||
@@ -21,9 +58,6 @@ def get_federated_connector(
|
||||
|
||||
def get_federated_connector_cls(
|
||||
source: FederatedConnectorSource,
|
||||
) -> type[FederatedConnector]:
|
||||
) -> Type[FederatedConnector]:
|
||||
"""Get the class of the appropriate federated connector."""
|
||||
if source == FederatedConnectorSource.FEDERATED_SLACK:
|
||||
return SlackFederatedConnector
|
||||
else:
|
||||
raise ValueError(f"Unsupported federated connector source: {source}")
|
||||
return _load_federated_connector_class(source)
|
||||
|
||||
19
backend/onyx/federated_connectors/registry.py
Normal file
19
backend/onyx/federated_connectors/registry.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Registry mapping for federated connector classes."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import FederatedConnectorSource
|
||||
|
||||
|
||||
class FederatedConnectorMapping(BaseModel):
|
||||
module_path: str
|
||||
class_name: str
|
||||
|
||||
|
||||
# Mapping of FederatedConnectorSource to connector details for lazy loading
|
||||
FEDERATED_CONNECTOR_CLASS_MAP = {
|
||||
FederatedConnectorSource.FEDERATED_SLACK: FederatedConnectorMapping(
|
||||
module_path="onyx.federated_connectors.slack.federated_connector",
|
||||
class_name="SlackFederatedConnector",
|
||||
),
|
||||
}
|
||||
@@ -2,7 +2,6 @@ from typing import Any
|
||||
from typing import cast
|
||||
from typing import IO
|
||||
|
||||
from unstructured.staging.base import dict_to_elements
|
||||
from unstructured_client import UnstructuredClient # type: ignore
|
||||
from unstructured_client.models import operations # type: ignore
|
||||
from unstructured_client.models import shared
|
||||
@@ -52,6 +51,8 @@ def _sdk_partition_request(
|
||||
|
||||
|
||||
def unstructured_to_text(file: IO[Any], file_name: str) -> str:
|
||||
from unstructured.staging.base import dict_to_elements
|
||||
|
||||
logger.debug(f"Starting to read file: {file_name}")
|
||||
req = _sdk_partition_request(file, file_name, strategy="fast")
|
||||
|
||||
|
||||
@@ -5,8 +5,9 @@ from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
import litellm # type: ignore
|
||||
from httpx import RemoteProtocolError
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
from langchain_core.messages import AIMessage
|
||||
@@ -24,9 +25,7 @@ from langchain_core.messages import SystemMessageChunk
|
||||
from langchain_core.messages.tool import ToolCallChunk
|
||||
from langchain_core.messages.tool import ToolMessage
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from litellm.utils import get_supported_openai_params
|
||||
|
||||
from onyx.configs.app_configs import BRAINTRUST_ENABLED
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
|
||||
from onyx.configs.chat_configs import QA_TIMEOUT
|
||||
@@ -45,13 +44,9 @@ from onyx.utils.long_term_log import LongTermLogger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# If a user configures a different model and it doesn't support all the same
|
||||
# parameters like frequency and presence, just ignore them
|
||||
litellm.drop_params = True
|
||||
litellm.telemetry = False
|
||||
if TYPE_CHECKING:
|
||||
from litellm import ModelResponse, CustomStreamWrapper, Message
|
||||
|
||||
if BRAINTRUST_ENABLED:
|
||||
litellm.callbacks = ["braintrust"]
|
||||
|
||||
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
|
||||
@@ -85,8 +80,10 @@ def _base_msg_to_role(msg: BaseMessage) -> str:
|
||||
|
||||
|
||||
def _convert_litellm_message_to_langchain_message(
|
||||
litellm_message: litellm.Message,
|
||||
litellm_message: "Message",
|
||||
) -> BaseMessage:
|
||||
from onyx.llm.litellm_singleton import litellm
|
||||
|
||||
# Extracting the basic attributes from the litellm message
|
||||
content = litellm_message.content or ""
|
||||
role = litellm_message.role
|
||||
@@ -176,15 +173,15 @@ def _convert_delta_to_message_chunk(
|
||||
curr_msg: BaseMessage | None,
|
||||
stop_reason: str | None = None,
|
||||
) -> BaseMessageChunk:
|
||||
from litellm.utils import ChatCompletionDeltaToolCall
|
||||
|
||||
"""Adapted from langchain_community.chat_models.litellm._convert_delta_to_message_chunk"""
|
||||
role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else "unknown")
|
||||
content = _dict.get("content") or ""
|
||||
additional_kwargs = {}
|
||||
if _dict.get("function_call"):
|
||||
additional_kwargs.update({"function_call": dict(_dict["function_call"])})
|
||||
tool_calls = cast(
|
||||
list[litellm.utils.ChatCompletionDeltaToolCall] | None, _dict.get("tool_calls")
|
||||
)
|
||||
tool_calls = cast(list[ChatCompletionDeltaToolCall] | None, _dict.get("tool_calls"))
|
||||
|
||||
if role == "user":
|
||||
return HumanMessageChunk(content=content)
|
||||
@@ -321,6 +318,8 @@ class DefaultMultiLLM(LLM):
|
||||
|
||||
self._max_token_param = LEGACY_MAX_TOKENS_KWARG
|
||||
try:
|
||||
from litellm.utils import get_supported_openai_params
|
||||
|
||||
params = get_supported_openai_params(model_name, model_provider)
|
||||
if STANDARD_MAX_TOKENS_KWARG in (params or []):
|
||||
self._max_token_param = STANDARD_MAX_TOKENS_KWARG
|
||||
@@ -388,11 +387,12 @@ class DefaultMultiLLM(LLM):
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
|
||||
) -> Union["ModelResponse", "CustomStreamWrapper"]:
|
||||
# litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
|
||||
# to a dict representation
|
||||
processed_prompt = _prompt_to_dict(prompt)
|
||||
self._record_call(processed_prompt)
|
||||
from onyx.llm.litellm_singleton import litellm
|
||||
|
||||
try:
|
||||
return litellm.completion(
|
||||
@@ -495,11 +495,13 @@ class DefaultMultiLLM(LLM):
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> BaseMessage:
|
||||
from litellm import ModelResponse
|
||||
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
|
||||
response = cast(
|
||||
litellm.ModelResponse,
|
||||
ModelResponse,
|
||||
self._completion(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
@@ -528,6 +530,8 @@ class DefaultMultiLLM(LLM):
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
from litellm import CustomStreamWrapper
|
||||
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
|
||||
@@ -544,7 +548,7 @@ class DefaultMultiLLM(LLM):
|
||||
|
||||
output = None
|
||||
response = cast(
|
||||
litellm.CustomStreamWrapper,
|
||||
CustomStreamWrapper,
|
||||
self._completion(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
|
||||
23
backend/onyx/llm/litellm_singleton.py
Normal file
23
backend/onyx/llm/litellm_singleton.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
Singleton module for litellm configuration.
|
||||
This ensures litellm is configured exactly once when first imported.
|
||||
All other modules should import litellm from here instead of directly.
|
||||
"""
|
||||
|
||||
import litellm
|
||||
|
||||
from onyx.configs.app_configs import BRAINTRUST_ENABLED
|
||||
|
||||
# Import litellm
|
||||
|
||||
# Configure litellm settings immediately on import
|
||||
# If a user configures a different model and it doesn't support all the same
|
||||
# parameters like frequency and presence, just ignore them
|
||||
litellm.drop_params = True
|
||||
litellm.telemetry = False
|
||||
|
||||
if BRAINTRUST_ENABLED:
|
||||
litellm.callbacks = ["braintrust"]
|
||||
|
||||
# Export the configured litellm module
|
||||
__all__ = ["litellm"]
|
||||
@@ -460,6 +460,7 @@ def get_llm_contextual_cost(
|
||||
this does not account for the cost of documents that fit within a single chunk
|
||||
which do not get contextualized.
|
||||
"""
|
||||
|
||||
import litellm
|
||||
|
||||
# calculate input costs
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Any
|
||||
from typing import cast
|
||||
from typing import List
|
||||
|
||||
from litellm import get_supported_openai_params
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.chat_configs import NUM_PERSONA_PROMPT_GENERATION_CHUNKS
|
||||
@@ -123,6 +122,8 @@ def generate_starter_messages(
|
||||
"""
|
||||
_, fast_llm = get_default_llms(temperature=0.5)
|
||||
|
||||
from litellm.utils import get_supported_openai_params
|
||||
|
||||
provider = fast_llm.config.model_provider
|
||||
model = fast_llm.config.model_name
|
||||
|
||||
|
||||
@@ -184,7 +184,7 @@ def seed_initial_documents(
|
||||
"base_url": "https://docs.onyx.app/",
|
||||
"web_connector_type": "recursive",
|
||||
},
|
||||
refresh_freq=None, # Never refresh by default
|
||||
refresh_freq=3600, # 1 hour
|
||||
prune_freq=None,
|
||||
indexing_start=None,
|
||||
)
|
||||
|
||||
@@ -22,7 +22,7 @@ from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.subclasses import find_all_subclasses_in_dir
|
||||
from onyx.utils.subclasses import find_all_subclasses_in_package
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -44,7 +44,8 @@ def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]:
|
||||
if _OAUTH_CONNECTORS: # Return cached connectors if already discovered
|
||||
return _OAUTH_CONNECTORS
|
||||
|
||||
oauth_connectors = find_all_subclasses_in_dir(
|
||||
# Import submodules using package-based discovery to avoid sys.path mutations
|
||||
oauth_connectors = find_all_subclasses_in_package(
|
||||
cast(type[OAuthConnector], OAuthConnector), "onyx.connectors"
|
||||
)
|
||||
|
||||
|
||||
@@ -1218,7 +1218,10 @@ def _upsert_mcp_server(
|
||||
|
||||
logger.info(f"Created new MCP server '{request.name}' with ID {mcp_server.id}")
|
||||
|
||||
if not changing_connection_config:
|
||||
if (
|
||||
not changing_connection_config
|
||||
or request.auth_type == MCPAuthenticationType.NONE
|
||||
):
|
||||
return mcp_server
|
||||
|
||||
# Create connection configs
|
||||
|
||||
@@ -162,7 +162,7 @@ def unlink_user_file_from_project(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
priority=OnyxCeleryPriority.HIGHEST,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered project sync for user_file_id={user_file.id} with task_id={task.id}"
|
||||
@@ -210,7 +210,7 @@ def link_user_file_to_project(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
priority=OnyxCeleryPriority.HIGHEST,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered project sync for user_file_id={user_file.id} with task_id={task.id}"
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from litellm import image_generation # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import override
|
||||
@@ -244,6 +243,8 @@ class ImageGenerationTool(Tool[None]):
|
||||
def _generate_image(
|
||||
self, prompt: str, shape: ImageShape, format: ImageFormat
|
||||
) -> ImageGenerationResponse:
|
||||
from litellm import image_generation # type: ignore
|
||||
|
||||
if shape == ImageShape.LANDSCAPE:
|
||||
if self.model == "gpt-image-1":
|
||||
size = "1536x1024"
|
||||
|
||||
@@ -5,7 +5,6 @@ This module provides a proper MCP client that follows the JSON-RPC 2.0 specifica
|
||||
and handles connection initialization, session management, and protocol communication.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
@@ -27,6 +26,7 @@ from pydantic import BaseModel
|
||||
|
||||
from onyx.db.enums import MCPTransport
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_async_sync
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -203,13 +203,7 @@ def _call_mcp_client_function_sync(
|
||||
function, server_url, connection_headers, transport, auth, **kwargs
|
||||
)
|
||||
try:
|
||||
# Run the async function in a new event loop
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(run_client_function())
|
||||
finally:
|
||||
loop.close()
|
||||
return run_async_sync(run_client_function())
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to call MCP client function: {e}")
|
||||
if isinstance(e, ExceptionGroup):
|
||||
|
||||
@@ -56,7 +56,7 @@ class MCPTool(BaseTool):
|
||||
self._tool_definition = tool_definition
|
||||
self._description = tool_description
|
||||
self._display_name = tool_definition.get("displayName", tool_name)
|
||||
self._llm_name = f"mcp_{mcp_server.name}_{tool_name}"
|
||||
self._llm_name = f"mcp:{mcp_server.name}:{tool_name}"
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
|
||||
@@ -35,6 +35,35 @@ def import_all_modules_from_dir(dir_path: str) -> List[ModuleType]:
|
||||
return imported_modules
|
||||
|
||||
|
||||
def import_all_submodules_from_package(package_name: str) -> List[ModuleType]:
|
||||
"""
|
||||
Imports all submodules of a given package WITHOUT mutating sys.path.
|
||||
Uses the package's __path__ and imports with fully-qualified names.
|
||||
"""
|
||||
imported_modules: List[ModuleType] = []
|
||||
|
||||
try:
|
||||
pkg = importlib.import_module(package_name)
|
||||
except Exception as e:
|
||||
print(f"Could not import package {package_name}: {e}")
|
||||
return imported_modules
|
||||
|
||||
pkg_paths = getattr(pkg, "__path__", None)
|
||||
if not pkg_paths:
|
||||
return imported_modules
|
||||
|
||||
for _, module_name, _ in pkgutil.walk_packages(
|
||||
pkg_paths, prefix=pkg.__name__ + "."
|
||||
):
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
imported_modules.append(module)
|
||||
except Exception as e:
|
||||
print(f"Could not import {module_name}: {e}")
|
||||
|
||||
return imported_modules
|
||||
|
||||
|
||||
def all_subclasses(cls: Type[T]) -> List[Type[T]]:
|
||||
"""
|
||||
Recursively find all subclasses of the given class.
|
||||
@@ -65,6 +94,18 @@ def find_all_subclasses_in_dir(parent_class: Type[T], directory: str) -> List[Ty
|
||||
return subclasses
|
||||
|
||||
|
||||
def find_all_subclasses_in_package(
    parent_class: Type[T], package_name: str
) -> List[Type[T]]:
    """
    Return the subclasses of ``parent_class`` known after loading a package.

    Every submodule of ``package_name`` is imported first, so that classes
    defined in those modules exist in memory before the subclass lookup runs.
    """
    # Called only for its import side effect; the module list is not needed.
    import_all_submodules_from_package(package_name)
    return all_subclasses(parent_class)
|
||||
|
||||
|
||||
# Example usage:
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import asyncio
|
||||
import collections.abc
|
||||
import concurrent
|
||||
import contextvars
|
||||
import copy
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import MutableMapping
|
||||
@@ -20,6 +23,7 @@ from typing import Protocol
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic.types import T
|
||||
from pydantic_core import core_schema
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -279,6 +283,19 @@ def run_functions_in_parallel(
|
||||
return results
|
||||
|
||||
|
||||
def run_async_sync(coro: Awaitable[T]) -> T:
    """
    Run an awaitable to completion from synchronous code and return its result.

    The awaitable is handed to ``asyncio.run`` on a dedicated single-worker
    thread, executed inside a copy of the caller's ``contextvars`` context.
    The calling thread blocks until the result (or the raised exception) is
    available. Spinning up a thread per call is not the most efficient
    approach, but it is acceptable for now.
    """
    caller_context = contextvars.copy_context()
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    with pool:
        pending: concurrent.futures.Future[T] = pool.submit(
            caller_context.run, asyncio.run, coro  # type: ignore[arg-type]
        )
        return pending.result()
|
||||
|
||||
|
||||
class TimeoutThread(threading.Thread, Generic[R]):
|
||||
def __init__(
|
||||
self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
|
||||
|
||||
@@ -53,7 +53,7 @@ oauthlib==3.2.2
|
||||
openai==1.99.5
|
||||
openpyxl==3.1.5
|
||||
passlib==1.7.4
|
||||
playwright==1.41.2
|
||||
playwright==1.55.0
|
||||
psutil==5.9.5
|
||||
psycopg2-binary==2.9.9
|
||||
puremagic==1.28
|
||||
|
||||
@@ -3,6 +3,7 @@ import re
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Set
|
||||
|
||||
@@ -15,7 +16,29 @@ logging.basicConfig(
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MODULES_TO_LAZY_IMPORT = {"vertexai", "openai", "markitdown", "tiktoken"}
|
||||
|
||||
@dataclass
|
||||
class LazyImportSettings:
|
||||
"""Settings for which files to ignore when checking for lazy imports."""
|
||||
|
||||
ignore_files: Set[str] | None = None
|
||||
|
||||
|
||||
# Map of modules to lazy import -> settings for what to ignore
|
||||
_LAZY_IMPORT_MODULES_TO_IGNORE_SETTINGS: Dict[str, LazyImportSettings] = {
|
||||
"vertexai": LazyImportSettings(),
|
||||
"openai": LazyImportSettings(),
|
||||
"markitdown": LazyImportSettings(),
|
||||
"tiktoken": LazyImportSettings(),
|
||||
"unstructured": LazyImportSettings(),
|
||||
"onyx.llm.litellm_singleton": LazyImportSettings(),
|
||||
"litellm": LazyImportSettings(
|
||||
ignore_files={
|
||||
"onyx/llm/llm_provider_options.py",
|
||||
"onyx/llm/litellm_singleton.py",
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -88,25 +111,19 @@ def find_eager_imports(
|
||||
)
|
||||
|
||||
|
||||
def find_python_files(
|
||||
backend_dir: Path, ignore_directories: Set[str] | None = None
|
||||
) -> List[Path]:
|
||||
def find_python_files(backend_dir: Path) -> List[Path]:
|
||||
"""
|
||||
Find all Python files in the backend directory, excluding test files and ignored directories.
|
||||
Find all Python files in the backend directory, excluding test files.
|
||||
|
||||
Args:
|
||||
backend_dir: Path to the backend directory to search
|
||||
ignore_directories: Set of directory names to ignore (e.g., {"model_server", "tests"})
|
||||
|
||||
Returns:
|
||||
List of Python file paths to check
|
||||
"""
|
||||
if ignore_directories is None:
|
||||
ignore_directories = set()
|
||||
|
||||
# Always ignore virtual environment directories
|
||||
venv_dirs = {".venv", "venv", ".env", "env", "__pycache__"}
|
||||
ignore_directories = ignore_directories.union(venv_dirs)
|
||||
ignore_directories = {".venv", "venv", ".env", "env", "__pycache__"}
|
||||
|
||||
python_files = []
|
||||
for file_path in backend_dir.glob("**/*.py"):
|
||||
@@ -119,8 +136,8 @@ def find_python_files(
|
||||
):
|
||||
continue
|
||||
|
||||
# Skip ignored directories (check directory names, not file names)
|
||||
if any(ignored_dir in path_parts[:-1] for ignored_dir in ignore_directories):
|
||||
# Skip ignored directories
|
||||
if any(ignored_dir in path_parts for ignored_dir in ignore_directories):
|
||||
continue
|
||||
|
||||
python_files.append(file_path)
|
||||
@@ -128,24 +145,58 @@ def find_python_files(
|
||||
return python_files
|
||||
|
||||
|
||||
def main(
|
||||
modules_to_lazy_import: Set[str], directories_to_ignore: Set[str] | None = None
|
||||
) -> None:
|
||||
def should_check_file_for_module(
    file_path: Path, backend_dir: Path, settings: LazyImportSettings
) -> bool:
    """
    Check if a file should be checked for a specific module's imports.

    Args:
        file_path: Path to the file to check; must be located under backend_dir
        backend_dir: Path to the backend directory
        settings: Settings containing files to ignore for this module

    Returns:
        True if the file should be checked, False if it should be ignored

    Raises:
        ValueError: If file_path is not located under backend_dir
            (raised by ``Path.relative_to``).
    """
    if not settings.ignore_files:
        # None or an empty set means check everywhere
        return True

    # Compare using forward slashes: a plain str() of the relative path uses
    # the OS separator, which would never match "a/b.py" style entries on
    # Windows.
    rel_path_str = file_path.relative_to(backend_dir).as_posix()

    # Check if this specific file path is in the ignore list
    return rel_path_str not in settings.ignore_files
|
||||
|
||||
|
||||
def main(modules_to_lazy_import: Dict[str, LazyImportSettings]) -> None:
|
||||
backend_dir = Path(__file__).parent.parent # Go up from scripts/ to backend/
|
||||
|
||||
logger.info(
|
||||
f"Checking for direct imports of lazy modules: {', '.join(modules_to_lazy_import)}"
|
||||
f"Checking for direct imports of lazy modules: {', '.join(modules_to_lazy_import.keys())}"
|
||||
)
|
||||
|
||||
# Find all Python files to check
|
||||
target_python_files = find_python_files(backend_dir, directories_to_ignore)
|
||||
target_python_files = find_python_files(backend_dir)
|
||||
|
||||
violations_found = False
|
||||
all_violated_modules = set()
|
||||
|
||||
# Check each Python file
|
||||
# Check each Python file for each module with its specific ignore directories
|
||||
for file_path in target_python_files:
|
||||
result = find_eager_imports(file_path, modules_to_lazy_import)
|
||||
# Determine which modules should be checked for this file
|
||||
modules_to_check = set()
|
||||
for module_name, settings in modules_to_lazy_import.items():
|
||||
if should_check_file_for_module(file_path, backend_dir, settings):
|
||||
modules_to_check.add(module_name)
|
||||
|
||||
if not modules_to_check:
|
||||
# This file is ignored for all modules
|
||||
continue
|
||||
|
||||
result = find_eager_imports(file_path, modules_to_check)
|
||||
|
||||
if result.violation_lines:
|
||||
violations_found = True
|
||||
@@ -159,7 +210,7 @@ def main(
|
||||
# Suggest fix only for violated modules
|
||||
if result.violated_modules:
|
||||
logger.error(
|
||||
f" 💡 You must import {', '.join(sorted(result.violated_modules))} only within functions when needed"
|
||||
f" 💡 You must lazy import {', '.join(sorted(result.violated_modules))} within functions when needed"
|
||||
)
|
||||
|
||||
if violations_found:
|
||||
@@ -173,7 +224,7 @@ def main(
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main(_MODULES_TO_LAZY_IMPORT)
|
||||
main(_LAZY_IMPORT_MODULES_TO_IGNORE_SETTINGS)
|
||||
sys.exit(0)
|
||||
except RuntimeError:
|
||||
sys.exit(1)
|
||||
|
||||
@@ -14,13 +14,15 @@ from onyx.connectors.models import Document
|
||||
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def confluence_connector(space: str) -> ConfluenceConnector:
|
||||
def _make_connector(
|
||||
space: str, access_token: str, scoped_token: bool = False
|
||||
) -> ConfluenceConnector:
|
||||
connector = ConfluenceConnector(
|
||||
wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"],
|
||||
space=space,
|
||||
is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == "true",
|
||||
page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""),
|
||||
scoped_token=scoped_token,
|
||||
)
|
||||
|
||||
credentials_provider = OnyxStaticCredentialsProvider(
|
||||
@@ -28,13 +30,25 @@ def confluence_connector(space: str) -> ConfluenceConnector:
|
||||
DocumentSource.CONFLUENCE,
|
||||
{
|
||||
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],
|
||||
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
|
||||
"confluence_access_token": access_token,
|
||||
},
|
||||
)
|
||||
connector.set_credentials_provider(credentials_provider)
|
||||
return connector
|
||||
|
||||
|
||||
@pytest.fixture
def confluence_connector(space: str) -> ConfluenceConnector:
    """Connector for ``space`` using the whitespace-stripped CONFLUENCE_ACCESS_TOKEN."""
    access_token = os.environ["CONFLUENCE_ACCESS_TOKEN"].strip()
    return _make_connector(space, access_token)
|
||||
|
||||
|
||||
@pytest.fixture
def confluence_connector_scoped(space: str) -> ConfluenceConnector:
    """Connector for ``space`` built with CONFLUENCE_ACCESS_TOKEN_SCOPED and scoped_token=True."""
    scoped_access_token = os.environ["CONFLUENCE_ACCESS_TOKEN_SCOPED"].strip()
    return _make_connector(space, scoped_access_token, scoped_token=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", [os.getenv("CONFLUENCE_TEST_SPACE") or "DailyConne"])
|
||||
@patch(
|
||||
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
|
||||
@@ -42,6 +56,25 @@ def confluence_connector(space: str) -> ConfluenceConnector:
|
||||
)
|
||||
def test_confluence_connector_basic(
|
||||
mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
|
||||
) -> None:
|
||||
_test_confluence_connector_basic(confluence_connector)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", [os.getenv("CONFLUENCE_TEST_SPACE") or "DailyConne"])
@patch(
    "onyx.file_processing.extract_file_text.get_unstructured_api_key",
    return_value=None,
)
def test_confluence_connector_basic_scoped(
    mock_get_api_key: MagicMock, confluence_connector_scoped: ConfluenceConnector
) -> None:
    """Run the shared basic checks with the scoped-token connector, without attachments."""
    connector_under_test = confluence_connector_scoped
    _test_confluence_connector_basic(connector_under_test, expect_attachments=False)
|
||||
|
||||
|
||||
def _test_confluence_connector_basic(
|
||||
confluence_connector: ConfluenceConnector, expect_attachments: bool = True
|
||||
) -> None:
|
||||
confluence_connector.set_allow_images(False)
|
||||
doc_batch = load_all_docs_from_checkpoint_connector(
|
||||
@@ -65,6 +98,10 @@ def test_confluence_connector_basic(
|
||||
assert page_within_a_page_doc.semantic_identifier == "Page Within A Page"
|
||||
assert page_within_a_page_doc.primary_owners
|
||||
assert page_within_a_page_doc.primary_owners[0].email == "hagen@danswer.ai"
|
||||
assert (
|
||||
page_within_a_page_doc.id
|
||||
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/pages/200769540/Page+Within+A+Page"
|
||||
)
|
||||
assert len(page_within_a_page_doc.sections) == 1
|
||||
|
||||
page_within_a_page_section = page_within_a_page_doc.sections[0]
|
||||
@@ -77,10 +114,15 @@ def test_confluence_connector_basic(
|
||||
|
||||
assert page_doc is not None
|
||||
assert page_doc.semantic_identifier == "DailyConnectorTestSpace Home"
|
||||
assert (
|
||||
page_doc.id == "https://danswerai.atlassian.net/wiki/spaces/DailyConne/overview"
|
||||
)
|
||||
assert page_doc.metadata["labels"] == ["testlabel"]
|
||||
assert page_doc.primary_owners
|
||||
assert page_doc.primary_owners[0].email == "hagen@danswer.ai"
|
||||
assert len(page_doc.sections) == 2 # page text + attachment text
|
||||
# page text + attachment text (only the page text when attachments are not expected).
# The conditional must be parenthesized: without the parentheses the expression
# parses as `(len(...) == 2) if expect_attachments else 1`, which is always
# truthy when expect_attachments is False, so nothing would be checked.
assert len(page_doc.sections) == (2 if expect_attachments else 1)
|
||||
|
||||
page_section = page_doc.sections[0]
|
||||
assert page_section.text == "test123 " + page_within_a_page_text
|
||||
@@ -89,10 +131,11 @@ def test_confluence_connector_basic(
|
||||
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/overview"
|
||||
)
|
||||
|
||||
text_attachment_section = page_doc.sections[1]
|
||||
assert text_attachment_section.text == "small"
|
||||
assert text_attachment_section.link
|
||||
assert text_attachment_section.link.endswith("small-file.txt")
|
||||
if expect_attachments:
|
||||
text_attachment_section = page_doc.sections[1]
|
||||
assert text_attachment_section.text == "small"
|
||||
assert text_attachment_section.link
|
||||
assert text_attachment_section.link.endswith("small-file.txt")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", ["MI"])
|
||||
|
||||
@@ -131,15 +131,15 @@ def test_confluence_connector_restriction_handling(
|
||||
}
|
||||
|
||||
# if restriction is applied, only should be visible to shared users / groups
|
||||
restricted_emails = {"chris@onyx.app", "hagen@danswer.ai"}
|
||||
restricted_emails = {"chris@onyx.app", "hagen@danswer.ai", "oauth@onyx.app"}
|
||||
restricted_user_groups = {"confluence-admins-danswerai"}
|
||||
|
||||
extra_restricted_emails = {"chris@onyx.app"}
|
||||
extra_restricted_emails = {"chris@onyx.app", "oauth@onyx.app"}
|
||||
extra_restricted_user_groups: set[str] = set()
|
||||
|
||||
# note that this is only allowed since yuhong@onyx.app is a member of the
|
||||
# confluence-admins-danswerai group
|
||||
special_restricted_emails = {"chris@onyx.app", "yuhong@onyx.app"}
|
||||
special_restricted_emails = {"chris@onyx.app", "yuhong@onyx.app", "oauth@onyx.app"}
|
||||
special_restricted_user_groups: set[str] = set()
|
||||
|
||||
# Check Root+Page+2 is public
|
||||
|
||||
@@ -10,22 +10,36 @@ from onyx.connectors.models import Document
|
||||
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def jira_connector() -> JiraConnector:
|
||||
def _make_connector(scoped_token: bool = False) -> JiraConnector:
|
||||
connector = JiraConnector(
|
||||
jira_base_url="https://danswerai.atlassian.net",
|
||||
project_key="AS",
|
||||
comment_email_blacklist=[],
|
||||
scoped_token=scoped_token,
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"jira_user_email": os.environ["JIRA_USER_EMAIL"],
|
||||
"jira_api_token": os.environ["JIRA_API_TOKEN"],
|
||||
"jira_api_token": (
|
||||
os.environ["JIRA_API_TOKEN_SCOPED"]
|
||||
if scoped_token
|
||||
else os.environ["JIRA_API_TOKEN"]
|
||||
),
|
||||
}
|
||||
)
|
||||
return connector
|
||||
|
||||
|
||||
@pytest.fixture
def jira_connector() -> JiraConnector:
    """Jira connector built by ``_make_connector`` with its default arguments."""
    connector = _make_connector()
    return connector
|
||||
|
||||
|
||||
@pytest.fixture
def jira_connector_scoped() -> JiraConnector:
    """Jira connector built by ``_make_connector`` with ``scoped_token=True``."""
    connector = _make_connector(scoped_token=True)
    return connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def jira_connector_with_jql() -> JiraConnector:
|
||||
connector = JiraConnector(
|
||||
@@ -47,6 +61,20 @@ def jira_connector_with_jql() -> JiraConnector:
|
||||
return_value=None,
|
||||
)
|
||||
def test_jira_connector_basic(reset: None, jira_connector: JiraConnector) -> None:
|
||||
_test_jira_connector_basic(jira_connector)
|
||||
|
||||
|
||||
@patch(
    "onyx.file_processing.extract_file_text.get_unstructured_api_key",
    return_value=None,
)
def test_jira_connector_basic_scoped(
    reset: None, jira_connector_scoped: JiraConnector
) -> None:
    """Run the shared basic Jira checks against the scoped-token connector."""
    connector_under_test = jira_connector_scoped
    _test_jira_connector_basic(connector_under_test)
|
||||
|
||||
|
||||
def _test_jira_connector_basic(jira_connector: JiraConnector) -> None:
|
||||
docs = load_all_docs_from_checkpoint_connector(
|
||||
connector=jira_connector,
|
||||
start=0,
|
||||
|
||||
@@ -71,3 +71,15 @@ def test_web_connector_bot_protection() -> None:
|
||||
doc = doc_batch[0]
|
||||
assert doc.sections[0].text is not None
|
||||
assert MERCURY_EXPECTED_QUOTE in doc.sections[0].text
|
||||
|
||||
|
||||
def test_web_connector_recursive_www_redirect() -> None:
    """A recursive crawl of onyx.app must still recurse when redirected to www.onyx.app."""
    connector = WebConnector(
        base_url="https://onyx.app",
        web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value,
    )

    # Flatten every yielded batch into one list of documents.
    documents = []
    for batch in connector.load_from_state():
        for doc in batch:
            documents.append(doc)

    # More than one document means the crawl went beyond the start page.
    assert len(documents) > 1
|
||||
|
||||
@@ -25,7 +25,12 @@ _EXPECTED_CONFLUENCE_GROUPS = [
|
||||
),
|
||||
ExternalUserGroupSet(
|
||||
id="org-admins",
|
||||
user_emails={"founders@onyx.app", "chris@onyx.app", "yuhong@onyx.app"},
|
||||
user_emails={
|
||||
"founders@onyx.app",
|
||||
"chris@onyx.app",
|
||||
"yuhong@onyx.app",
|
||||
"oauth@onyx.app",
|
||||
},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
ExternalUserGroupSet(
|
||||
@@ -36,6 +41,7 @@ _EXPECTED_CONFLUENCE_GROUPS = [
|
||||
"founders@onyx.app",
|
||||
"pablo@onyx.app",
|
||||
"yuhong@onyx.app",
|
||||
"oauth@onyx.app",
|
||||
},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
@@ -46,6 +52,7 @@ _EXPECTED_CONFLUENCE_GROUPS = [
|
||||
"founders@onyx.app",
|
||||
"pablo@onyx.app",
|
||||
"chris@onyx.app",
|
||||
"oauth@onyx.app",
|
||||
},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
@@ -73,13 +80,29 @@ _EXPECTED_CONFLUENCE_GROUPS = [
|
||||
id="All_Confluence_Users_Found_By_Onyx",
|
||||
user_emails={
|
||||
"chris@onyx.app",
|
||||
"hagen@danswer.ai",
|
||||
"founders@onyx.app",
|
||||
"hagen@danswer.ai",
|
||||
"pablo@onyx.app",
|
||||
"yuhong@onyx.app",
|
||||
"oauth@onyx.app",
|
||||
},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
ExternalUserGroupSet(
|
||||
id="bitbucket-users-onyxai",
|
||||
user_emails={"oauth@onyx.app"},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
ExternalUserGroupSet(
|
||||
id="bitbucket-admins-onyxai",
|
||||
user_emails={"oauth@onyx.app"},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
ExternalUserGroupSet(
|
||||
id="jira-servicemanagement-users-danswerai",
|
||||
user_emails={"oauth@onyx.app"},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import os
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from fastmcp.server.server import FunctionTool
|
||||
|
||||
@@ -28,4 +30,7 @@ def make_many_tools() -> list[FunctionTool]:
|
||||
if __name__ == "__main__":
|
||||
# Streamable HTTP transport (recommended)
|
||||
make_many_tools()
|
||||
mcp.run(transport="http", host="127.0.0.1", port=8000, path="/mcp")
|
||||
host = os.getenv("MCP_SERVER_BIND_HOST", "0.0.0.0")
|
||||
port = int(os.getenv("MCP_SERVER_PORT", "8000"))
|
||||
path = os.getenv("MCP_SERVER_PATH", "/mcp")
|
||||
mcp.run(transport="http", host=host, port=port, path=path)
|
||||
|
||||
@@ -37,7 +37,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
|
||||
|
||||
config = {
|
||||
"wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"],
|
||||
"space": "DailyConnectorTestSpace",
|
||||
"space": "DailyConne",
|
||||
"is_cloud": True,
|
||||
}
|
||||
|
||||
|
||||
148
backend/tests/integration/tests/mcp/test_mcp_no_auth_flow.py
Normal file
148
backend/tests/integration/tests/mcp/test_mcp_no_auth_flow.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.enums import MCPTransport
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
MCP_SERVER_HOST = os.getenv("TEST_WEB_HOSTNAME", "127.0.0.1")
|
||||
MCP_SERVER_PORT = int(os.getenv("MCP_SERVER_PORT", "8000"))
|
||||
MCP_SERVER_URL = f"http://{MCP_SERVER_HOST}:{MCP_SERVER_PORT}/mcp"
|
||||
MCP_HELLO_TOOL = "hello"
|
||||
|
||||
MCP_SERVER_SCRIPT = (
|
||||
Path(__file__).resolve().parents[2]
|
||||
/ "mock_services"
|
||||
/ "mcp_test_server"
|
||||
/ "run_mcp_server_no_auth.py"
|
||||
)
|
||||
|
||||
|
||||
def _wait_for_port(
|
||||
host: str,
|
||||
port: int,
|
||||
process: subprocess.Popen[bytes],
|
||||
timeout_seconds: float = 10.0,
|
||||
) -> None:
|
||||
start = time.monotonic()
|
||||
while time.monotonic() - start < timeout_seconds:
|
||||
if process.poll() is not None:
|
||||
raise RuntimeError("MCP server process exited unexpectedly during startup")
|
||||
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.settimeout(0.5)
|
||||
try:
|
||||
sock.connect((host, port))
|
||||
return
|
||||
except OSError:
|
||||
time.sleep(0.1)
|
||||
|
||||
raise TimeoutError("Timed out waiting for MCP server to accept connections")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def mcp_no_auth_server() -> Generator[None, None, None]:
    """
    Run the no-auth MCP test server as a subprocess for the duration of the module.

    The server script is launched with the current Python interpreter from the
    script's own directory. The fixture yields once the server port accepts
    connections. On teardown, or if startup fails, the process is terminated,
    and killed if it has not exited within 5 seconds.
    """
    process = subprocess.Popen(
        [sys.executable, str(MCP_SERVER_SCRIPT)],
        cwd=MCP_SERVER_SCRIPT.parent,
    )

    try:
        _wait_for_port(MCP_SERVER_HOST, MCP_SERVER_PORT, process)
        yield
    finally:
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            process.kill()
            # Reap the killed child so it does not linger as a zombie for the
            # remainder of the test session.
            process.wait()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
def ensure_mcp_server_exists() -> None:
    """Fail early with a clear error when the MCP server script is missing."""
    if MCP_SERVER_SCRIPT.exists():
        return

    raise FileNotFoundError(
        f"Expected MCP server script at {MCP_SERVER_SCRIPT}, but it was not found"
    )
|
||||
|
||||
|
||||
def test_mcp_no_auth_flow(
    mcp_no_auth_server: None,
    reset: None,
    admin_user: DATestUser,
    basic_user: DATestUser,
    llm_provider: DATestLLMProvider,
) -> None:
    """
    End-to-end check of registering a no-auth MCP server and using its tool.

    As admin: create the MCP server entry, select its "hello" tool and look up
    the resulting tool id. Then create a persona with that tool and verify a
    basic user sees the tool on the persona.
    """
    admin_auth = {"headers": admin_user.headers, "cookies": admin_user.cookies}

    # Step a) Create a no-auth MCP server via the admin API
    server_payload = {
        "name": "integration-mcp-no-auth",
        "description": "Integration test MCP server",
        "server_url": MCP_SERVER_URL,
        "transport": MCPTransport.STREAMABLE_HTTP.value,
        "auth_type": MCPAuthenticationType.NONE.value,
        "auth_performer": MCPAuthenticationPerformer.ADMIN.value,
    }
    create_response = requests.post(
        f"{API_SERVER_URL}/admin/mcp/servers/create",
        json=server_payload,
        **admin_auth,
    )
    create_response.raise_for_status()
    server_id = create_response.json()["server_id"]

    # Step b) Attach the "hello" tool from the MCP server
    selection_payload = {
        "server_id": server_id,
        "selected_tools": [MCP_HELLO_TOOL],
    }
    update_response = requests.post(
        f"{API_SERVER_URL}/admin/mcp/servers/update",
        json=selection_payload,
        **admin_auth,
    )
    update_response.raise_for_status()
    assert update_response.json()["updated_tools"] >= 1

    # Look up the database id of the tool that was just attached
    tools_response = requests.get(
        f"{API_SERVER_URL}/admin/mcp/server/{server_id}/db-tools",
        **admin_auth,
    )
    tools_response.raise_for_status()
    tool_entries = tools_response.json()["tools"]
    hello_tool_entry = next(
        entry for entry in tool_entries if entry["name"] == MCP_HELLO_TOOL
    )
    tool_id = hello_tool_entry["id"]

    # Step c) Create an assistant (persona) with the MCP tool attached
    persona = PersonaManager.create(
        name="integration-mcp-persona",
        description="Persona for MCP integration test",
        tool_ids=[tool_id],
        user_performing_action=admin_user,
    )

    # The basic user must see the persona with the MCP tool on it
    persona_tools_response = requests.get(
        f"{API_SERVER_URL}/persona",
        headers=basic_user.headers,
        cookies=basic_user.cookies,
    )
    persona_tools_response.raise_for_status()
    persona_entries = persona_tools_response.json()
    persona_entry = next(
        candidate for candidate in persona_entries if candidate["id"] == persona.id
    )
    persona_tool_ids = set()
    for attached_tool in persona_entry["tools"]:
        persona_tool_ids.add(attached_tool["id"])
    assert tool_id in persona_tool_ids
|
||||
@@ -109,7 +109,9 @@ def test_load_credentials(jira_connector: JiraConnector) -> None:
|
||||
result = jira_connector.load_credentials(credentials)
|
||||
|
||||
mock_build_client.assert_called_once_with(
|
||||
credentials=credentials, jira_base=jira_connector.jira_base
|
||||
credentials=credentials,
|
||||
jira_base=jira_connector.jira_base,
|
||||
scoped_token=False,
|
||||
)
|
||||
assert result is None
|
||||
assert jira_connector._jira_client == mock_build_client.return_value
|
||||
@@ -226,7 +228,7 @@ def test_load_from_checkpoint_with_issue_processing_error(
|
||||
|
||||
# Mock process_jira_issue to succeed for some issues and fail for others
|
||||
def mock_process_side_effect(
|
||||
jira_client: JIRA, issue: Issue, *args: Any, **kwargs: Any
|
||||
jira_base_url: str, issue: Issue, *args: Any, **kwargs: Any
|
||||
) -> Document | None:
|
||||
if issue.key in ["TEST-1", "TEST-3"]:
|
||||
return Document(
|
||||
|
||||
@@ -92,7 +92,7 @@ def test_fetch_jira_issues_batch_small_ticket(
|
||||
assert len(issues) == 1
|
||||
|
||||
# Then process each issue
|
||||
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
|
||||
docs = [process_jira_issue("test.com", issue) for issue in issues]
|
||||
docs = [doc for doc in docs if doc is not None] # Filter out None values
|
||||
|
||||
assert len(docs) == 1
|
||||
@@ -117,7 +117,7 @@ def test_fetch_jira_issues_batch_large_ticket(
|
||||
assert len(issues) == 1
|
||||
|
||||
# Then process each issue
|
||||
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
|
||||
docs = [process_jira_issue("test.com", issue) for issue in issues]
|
||||
docs = [doc for doc in docs if doc is not None] # Filter out None values
|
||||
|
||||
assert len(docs) == 0 # The large ticket should be skipped
|
||||
@@ -136,7 +136,7 @@ def test_fetch_jira_issues_batch_mixed_tickets(
|
||||
assert len(issues) == 2
|
||||
|
||||
# Then process each issue
|
||||
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
|
||||
docs = [process_jira_issue("test.com", issue) for issue in issues]
|
||||
docs = [doc for doc in docs if doc is not None] # Filter out None values
|
||||
|
||||
assert len(docs) == 1 # Only the small ticket should be included
|
||||
@@ -159,7 +159,7 @@ def test_fetch_jira_issues_batch_custom_size_limit(
|
||||
assert len(issues) == 2
|
||||
|
||||
# Then process each issue
|
||||
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
|
||||
docs = [process_jira_issue("test.com", issue) for issue in issues]
|
||||
docs = [doc for doc in docs if doc is not None] # Filter out None values
|
||||
|
||||
assert len(docs) == 0 # Both tickets should be skipped due to the low size limit
|
||||
|
||||
269
backend/tests/unit/onyx/connectors/test_connector_factory.py
Normal file
269
backend/tests/unit/onyx/connectors/test_connector_factory.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
Unit tests for lazy loading connector factory to validate:
|
||||
1. All connector mappings are correct
|
||||
2. Module paths and class names are valid
|
||||
3. Error handling works properly
|
||||
4. Caching functions correctly
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.factory import _connector_cache
|
||||
from onyx.connectors.factory import _load_connector_class
|
||||
from onyx.connectors.factory import ConnectorMissingException
|
||||
from onyx.connectors.factory import identify_connector_class
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.connectors.registry import CONNECTOR_CLASS_MAP
|
||||
from onyx.connectors.registry import ConnectorMapping
|
||||
|
||||
|
||||
class TestConnectorMappingValidation:
    """Checks that the entries of CONNECTOR_CLASS_MAP are valid."""

    def test_all_connector_mappings_exist(self) -> None:
        """Every mapped module imports and exposes a BaseConnector subclass."""
        problems = []

        for source, mapping in CONNECTOR_CLASS_MAP.items():
            try:
                # Import the module, then resolve the class on it
                module = importlib.import_module(mapping.module_path)
                connector_class = getattr(module, mapping.class_name)
                is_connector = issubclass(connector_class, BaseConnector)
            except ImportError as e:
                problems.append(
                    f"{source.value}: Failed to import {mapping.module_path} - {e}"
                )
            except AttributeError as e:
                problems.append(
                    f"{source.value}: Class {mapping.class_name} not found in {mapping.module_path} - {e}"
                )
            else:
                if not is_connector:
                    problems.append(
                        f"{source.value}: {mapping.class_name} is not a BaseConnector subclass"
                    )

        if problems:
            pytest.fail("Connector mapping validation failed:\n" + "\n".join(problems))

    def test_no_duplicate_mappings(self) -> None:
        """Each DocumentSource appears only once among the mapping keys."""
        sources = list(CONNECTOR_CLASS_MAP.keys())
        distinct_count = len(set(sources))

        assert len(sources) == distinct_count, "Duplicate DocumentSource entries found"

    def test_blob_storage_connectors_correct(self) -> None:
        """All blob storage sources map to the same BlobStorageConnector entry."""
        expected_mapping = ConnectorMapping(
            module_path="onyx.connectors.blob.connector",
            class_name="BlobStorageConnector",
        )

        blob_sources = [
            DocumentSource.S3,
            DocumentSource.R2,
            DocumentSource.GOOGLE_CLOUD_STORAGE,
            DocumentSource.OCI_STORAGE,
        ]

        for source in blob_sources:
            actual_mapping = CONNECTOR_CLASS_MAP[source]
            assert (
                actual_mapping == expected_mapping
            ), f"{source.value} should map to BlobStorageConnector"
|
||||
|
||||
|
||||
class TestConnectorClassLoading:
    """Tests for the lazy connector class loader and its cache."""

    def setup_method(self) -> None:
        """Empty the connector cache before every test."""
        _connector_cache.clear()

    def test_load_connector_class_success(self) -> None:
        """Loading the WEB source yields the WebConnector class."""
        # WEB is used as a connector that should always be available
        loaded = _load_connector_class(DocumentSource.WEB)

        assert loaded is not None
        assert issubclass(loaded, BaseConnector)
        assert loaded.__name__ == "WebConnector"

    def test_load_connector_class_caching(self) -> None:
        """A second load returns the cached class object."""
        assert len(_connector_cache) == 0

        # First load populates the cache
        first = _load_connector_class(DocumentSource.WEB)
        assert len(_connector_cache) == 1
        assert DocumentSource.WEB in _connector_cache

        # Second load must return the very same object and add nothing
        second = _load_connector_class(DocumentSource.WEB)
        assert first is second
        assert len(_connector_cache) == 1

    @patch("importlib.import_module")
    def test_load_connector_class_import_error(self, mock_import: Mock) -> None:
        """An ImportError from the module import surfaces as ConnectorMissingException."""
        mock_import.side_effect = ImportError("Module not found")

        with pytest.raises(ConnectorMissingException) as exc_info:
            _load_connector_class(DocumentSource.WEB)

        message = str(exc_info.value)
        assert (
            "Failed to import WebConnector from onyx.connectors.web.connector"
            in message
        )

    @patch("importlib.import_module")
    def test_load_connector_class_attribute_error(self, mock_import: Mock) -> None:
        """A module lacking the mapped class surfaces as ConnectorMissingException."""

        # Stand-in module: only the WebConnector lookup fails
        class MockModule:
            def __getattr__(self, name: str) -> MagicMock:
                if name != "WebConnector":
                    return MagicMock()
                raise AttributeError("Class not found")

        mock_import.return_value = MockModule()

        with pytest.raises(ConnectorMissingException) as exc_info:
            _load_connector_class(DocumentSource.WEB)

        message = str(exc_info.value)
        assert (
            "Failed to import WebConnector from onyx.connectors.web.connector"
            in message
        )
|
||||
|
||||
|
||||
class TestIdentifyConnectorClass:
    """Exercise `identify_connector_class` with and without an input type."""

    def setup_method(self) -> None:
        """Empty the connector cache so each test starts from a clean state."""
        _connector_cache.clear()

    def test_identify_connector_basic(self) -> None:
        """GITHUB with SLIM_RETRIEVAL resolves to GithubConnector."""
        resolved = identify_connector_class(
            DocumentSource.GITHUB, InputType.SLIM_RETRIEVAL
        )

        assert resolved is not None
        assert issubclass(resolved, BaseConnector)
        assert resolved.__name__ == "GithubConnector"

    def test_identify_connector_slack_special_case(self) -> None:
        """Slack resolves to the same SlackConnector class for POLL and SLIM_RETRIEVAL."""
        # POLL input type
        slack_poll = identify_connector_class(DocumentSource.SLACK, InputType.POLL)
        assert slack_poll.__name__ == "SlackConnector"

        # SLIM_RETRIEVAL input type
        slack_slim = identify_connector_class(
            DocumentSource.SLACK, InputType.SLIM_RETRIEVAL
        )
        assert slack_slim.__name__ == "SlackConnector"

        # Both lookups must yield the identical class object.
        assert slack_poll is slack_slim

    def test_identify_connector_without_input_type(self) -> None:
        """Omitting the input type still resolves GITHUB to GithubConnector."""
        resolved = identify_connector_class(DocumentSource.GITHUB)

        assert resolved is not None
        assert resolved.__name__ == "GithubConnector"
|
||||
|
||||
|
||||
class TestConnectorMappingIntegrity:
    """Sanity checks on the contents of CONNECTOR_CLASS_MAP."""

    def test_all_document_sources_mapped(self) -> None:
        """Every DocumentSource has a mapping, apart from an explicit allow-list."""
        # Sources that are intentionally absent from the mapping.
        expected_unmapped = {
            DocumentSource.INGESTION_API,  # This is handled differently
            DocumentSource.REQUESTTRACKER,  # Not yet implemented or special case
            DocumentSource.NOT_APPLICABLE,  # Special placeholder, no connector needed
            DocumentSource.USER_FILE,  # Special placeholder, no connector needed
            # Add other legitimately unmapped sources here if they exist
        }

        every_source = set(DocumentSource)
        registered = set(CONNECTOR_CLASS_MAP.keys())
        unmapped_sources = every_source - registered - expected_unmapped

        if not unmapped_sources:
            return

        pytest.fail(
            f"DocumentSource values without connector mappings: "
            f"{[s.value for s in unmapped_sources]}"
        )

    def test_mapping_format_consistency(self) -> None:
        """Each mapping is a ConnectorMapping with well-formed string fields."""
        for source, mapping in CONNECTOR_CLASS_MAP.items():
            # Type checks come first so the string-method checks below are safe.
            assert isinstance(
                mapping, ConnectorMapping
            ), f"{source.value} mapping is not a ConnectorMapping"

            module_path = mapping.module_path
            assert isinstance(
                module_path, str
            ), f"{source.value} module_path is not a string"

            class_name = mapping.class_name
            assert isinstance(
                class_name, str
            ), f"{source.value} class_name is not a string"

            # Naming conventions: package prefix and class-name suffix.
            assert module_path.startswith(
                "onyx.connectors."
            ), f"{source.value} module_path doesn't start with onyx.connectors."
            assert class_name.endswith(
                "Connector"
            ), f"{source.value} class_name doesn't end with Connector"
|
||||
|
||||
|
||||
class TestInstantiateConnectorIntegration:
    """Check that `instantiate_connector` goes through the lazy class loader."""

    def setup_method(self) -> None:
        """Empty the connector cache so each test starts from a clean state."""
        _connector_cache.clear()

    def test_instantiate_connector_loads_class_lazily(self) -> None:
        """Calling instantiate_connector leaves the WEB class in the cache."""
        # Stand-ins for the database session and the credential row.
        fake_session = MagicMock()
        fake_credential = MagicMock()
        fake_credential.id = 123
        fake_credential.credential_json = {"test": "data"}

        # The call is expected to fail during real instantiation because the
        # configuration is mock data; any exception type is accepted here.
        with pytest.raises(Exception):
            instantiate_connector(
                fake_session,
                DocumentSource.WEB,  # Simple connector
                InputType.SLIM_RETRIEVAL,
                {},  # Empty config
                fake_credential,
            )

        # Despite the failure, the class lookup should already be cached.
        assert DocumentSource.WEB in _connector_cache
        cached_cls = _connector_cache[DocumentSource.WEB]
        assert cached_cls.__name__ == "WebConnector"
|
||||
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Unit tests for federated connector lazy loading factory to validate:
|
||||
1. All federated connector mappings are correct
|
||||
2. Module paths and class names are valid
|
||||
3. Error handling works properly
|
||||
4. Caching functions correctly
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import FederatedConnectorSource
|
||||
from onyx.federated_connectors.factory import _federated_connector_cache
|
||||
from onyx.federated_connectors.factory import _load_federated_connector_class
|
||||
from onyx.federated_connectors.factory import FederatedConnectorMissingException
|
||||
from onyx.federated_connectors.factory import get_federated_connector_cls
|
||||
from onyx.federated_connectors.interfaces import FederatedConnector
|
||||
from onyx.federated_connectors.registry import FEDERATED_CONNECTOR_CLASS_MAP
|
||||
from onyx.federated_connectors.registry import FederatedConnectorMapping
|
||||
|
||||
|
||||
class TestFederatedConnectorMappingValidation:
    """Validate every entry of FEDERATED_CONNECTOR_CLASS_MAP."""

    def test_all_federated_connector_mappings_exist(self) -> None:
        """Each mapped module imports and exposes a FederatedConnector subclass."""
        problems = []

        for source, mapping in FEDERATED_CONNECTOR_CLASS_MAP.items():
            try:
                # Import the module, then look the class up by name.
                candidate = getattr(
                    importlib.import_module(mapping.module_path), mapping.class_name
                )

                # The resolved object must derive from FederatedConnector.
                if not issubclass(candidate, FederatedConnector):
                    problems.append(
                        f"{source.value}: {mapping.class_name} is not a FederatedConnector subclass"
                    )

            except ImportError as e:
                problems.append(
                    f"{source.value}: Failed to import {mapping.module_path} - {e}"
                )
            except AttributeError as e:
                problems.append(
                    f"{source.value}: Class {mapping.class_name} not found in {mapping.module_path} - {e}"
                )

        # Report all broken mappings together instead of stopping at the first.
        if problems:
            pytest.fail(
                "Federated connector mapping validation failed:\n" + "\n".join(problems)
            )

    def test_no_duplicate_mappings(self) -> None:
        """No FederatedConnectorSource key is repeated in the mapping."""
        all_keys = list(FEDERATED_CONNECTOR_CLASS_MAP.keys())
        distinct_keys = set(all_keys)

        assert len(all_keys) == len(
            distinct_keys
        ), "Duplicate FederatedConnectorSource entries found"

    def test_mapping_format_consistency(self) -> None:
        """Each mapping is a FederatedConnectorMapping with well-formed string fields."""
        for source, mapping in FEDERATED_CONNECTOR_CLASS_MAP.items():
            # Type checks come first so the string-method checks below are safe.
            assert isinstance(
                mapping, FederatedConnectorMapping
            ), f"{source.value} mapping is not a FederatedConnectorMapping"

            module_path = mapping.module_path
            assert isinstance(
                module_path, str
            ), f"{source.value} module_path is not a string"

            class_name = mapping.class_name
            assert isinstance(
                class_name, str
            ), f"{source.value} class_name is not a string"

            # Naming conventions: package prefix and class-name suffix.
            assert module_path.startswith(
                "onyx.federated_connectors."
            ), f"{source.value} module_path doesn't start with onyx.federated_connectors."
            assert class_name.endswith(
                "FederatedConnector"
            ), f"{source.value} class_name doesn't end with FederatedConnector"
|
||||
|
||||
|
||||
class TestFederatedConnectorClassLoading:
    """Exercise `_load_federated_connector_class`: loads, caching, and failure handling."""

    def setup_method(self) -> None:
        """Empty the federated connector cache so each test starts clean."""
        _federated_connector_cache.clear()

    def test_load_federated_connector_class_success(self) -> None:
        """Loading FEDERATED_SLACK returns the SlackFederatedConnector class."""
        loaded_cls = _load_federated_connector_class(
            FederatedConnectorSource.FEDERATED_SLACK
        )

        assert loaded_cls is not None
        assert issubclass(loaded_cls, FederatedConnector)
        assert loaded_cls.__name__ == "SlackFederatedConnector"

    def test_load_federated_connector_class_caching(self) -> None:
        """A second load of the same source is served from the cache."""
        slack_source = FederatedConnectorSource.FEDERATED_SLACK
        assert len(_federated_connector_cache) == 0

        # First load populates the cache with exactly one entry.
        first_load = _load_federated_connector_class(slack_source)
        assert len(_federated_connector_cache) == 1
        assert slack_source in _federated_connector_cache

        # Second load must hand back the very same class object
        # without growing the cache.
        second_load = _load_federated_connector_class(slack_source)
        assert first_load is second_load
        assert len(_federated_connector_cache) == 1

    @patch("importlib.import_module")
    def test_load_federated_connector_class_import_error(
        self, mock_import: Mock
    ) -> None:
        """An ImportError surfaces as FederatedConnectorMissingException."""
        mock_import.side_effect = ImportError("Module not found")

        with pytest.raises(FederatedConnectorMissingException) as exc_info:
            _load_federated_connector_class(FederatedConnectorSource.FEDERATED_SLACK)

        raised_message = str(exc_info.value)
        assert (
            "Failed to import SlackFederatedConnector from onyx.federated_connectors.slack.federated_connector"
            in raised_message
        )

    @patch("importlib.import_module")
    def test_load_federated_connector_class_attribute_error(
        self, mock_import: Mock
    ) -> None:
        """A module lacking the class surfaces as FederatedConnectorMissingException."""

        # Stand-in module: every attribute resolves to a MagicMock except the
        # connector class itself, which is reported as missing.
        class MockModule:
            def __getattr__(self, name: str) -> MagicMock:
                if name != "SlackFederatedConnector":
                    return MagicMock()
                raise AttributeError("Class not found")

        mock_import.return_value = MockModule()

        with pytest.raises(FederatedConnectorMissingException) as exc_info:
            _load_federated_connector_class(FederatedConnectorSource.FEDERATED_SLACK)

        raised_message = str(exc_info.value)
        assert (
            "Failed to import SlackFederatedConnector from onyx.federated_connectors.slack.federated_connector"
            in raised_message
        )
|
||||
|
||||
|
||||
class TestGetFederatedConnectorCls:
    """Exercise the public `get_federated_connector_cls` lookup."""

    def setup_method(self) -> None:
        """Empty the federated connector cache so each test starts clean."""
        _federated_connector_cache.clear()

    def test_get_federated_connector_cls_basic(self) -> None:
        """FEDERATED_SLACK resolves to the SlackFederatedConnector class."""
        resolved = get_federated_connector_cls(FederatedConnectorSource.FEDERATED_SLACK)

        assert resolved is not None
        assert issubclass(resolved, FederatedConnector)
        assert resolved.__name__ == "SlackFederatedConnector"
|
||||
|
||||
|
||||
class TestFederatedConnectorMappingIntegrity:
    """Sanity checks on the contents of FEDERATED_CONNECTOR_CLASS_MAP."""

    def test_all_federated_connector_sources_mapped(self) -> None:
        """Every FederatedConnectorSource member has a mapping; no exceptions allowed."""
        every_source = set(FederatedConnectorSource)
        registered = set(FEDERATED_CONNECTOR_CLASS_MAP.keys())

        # Members of the enum that the mapping does not cover.
        unmapped_sources = every_source - registered

        if not unmapped_sources:
            return

        pytest.fail(
            f"FederatedConnectorSource values without connector mappings: "
            f"{[s.value for s in unmapped_sources]}"
        )
|
||||
@@ -45,7 +45,7 @@ def default_multi_llm() -> DefaultMultiLLM:
|
||||
|
||||
def test_multiple_tool_calls(default_multi_llm: DefaultMultiLLM) -> None:
|
||||
# Mock the litellm.completion function
|
||||
with patch("onyx.llm.chat_llm.litellm.completion") as mock_completion:
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
# Create a mock response with multiple tool calls using litellm objects
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
@@ -158,7 +158,7 @@ def test_multiple_tool_calls(default_multi_llm: DefaultMultiLLM) -> None:
|
||||
|
||||
def test_multiple_tool_calls_streaming(default_multi_llm: DefaultMultiLLM) -> None:
|
||||
# Mock the litellm.completion function
|
||||
with patch("onyx.llm.chat_llm.litellm.completion") as mock_completion:
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
# Create a mock response with multiple tool calls using litellm objects
|
||||
mock_response = [
|
||||
litellm.ModelResponse(
|
||||
|
||||
@@ -7,6 +7,7 @@ import pytest
|
||||
from scripts.check_lazy_imports import EagerImportResult
|
||||
from scripts.check_lazy_imports import find_eager_imports
|
||||
from scripts.check_lazy_imports import find_python_files
|
||||
from scripts.check_lazy_imports import LazyImportSettings
|
||||
from scripts.check_lazy_imports import main
|
||||
|
||||
|
||||
@@ -196,6 +197,72 @@ def test_find_eager_imports_file_read_error() -> None:
|
||||
assert result.violated_modules == set()
|
||||
|
||||
|
||||
def test_litellm_singleton_eager_import_detection() -> None:
    """A module-level import of litellm_singleton is reported; the in-function one is not."""
    # Fixture source: one top-level import of the protected module (expected
    # to be flagged) and one import nested inside a function.
    test_content = """
import os
from onyx.llm.litellm_singleton import litellm  # Should be flagged as eager import
from typing import Dict

def some_function():
    # This would be OK - lazy import
    from onyx.llm.litellm_singleton import litellm
    return litellm.some_method()
"""

    # delete=False keeps the file on disk after the handle closes; it is
    # removed explicitly in the finally block below.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tmp_file:
        tmp_file.write(test_content)
        test_path = Path(tmp_file.name)

    try:
        result = find_eager_imports(test_path, {"onyx.llm.litellm_singleton"})

        # Exactly one violation is expected: the top-level import.
        assert len(result.violation_lines) == 1
        assert result.violated_modules == {"onyx.llm.litellm_singleton"}

        _, flagged_line = result.violation_lines[0]
        assert "from onyx.llm.litellm_singleton import litellm" in flagged_line

    finally:
        test_path.unlink()
|
||||
|
||||
|
||||
def test_litellm_singleton_lazy_import_ok() -> None:
    """Imports of litellm_singleton inside a function or method produce no violations."""
    # Fixture source: the protected module is imported only inside a function
    # body and inside a method body, never at module level.
    test_content = """
import os
from typing import Dict

def get_litellm():
    # This is OK - lazy import inside function
    from onyx.llm.litellm_singleton import litellm
    return litellm

class SomeClass:
    def method(self):
        # Also OK - lazy import inside method
        from onyx.llm.litellm_singleton import litellm
        return litellm.completion()
"""

    # delete=False keeps the file on disk after the handle closes; it is
    # removed explicitly in the finally block below.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as tmp_file:
        tmp_file.write(test_content)
        test_path = Path(tmp_file.name)

    try:
        result = find_eager_imports(test_path, {"onyx.llm.litellm_singleton"})

        # Nothing should be reported for nested imports.
        assert len(result.violation_lines) == 0
        assert result.violated_modules == set()

    finally:
        test_path.unlink()
|
||||
|
||||
|
||||
def test_find_eager_imports_return_type() -> None:
|
||||
"""Test that function returns correct EagerImportResult type."""
|
||||
test_content = """
|
||||
@@ -254,7 +321,13 @@ def use_nltk():
|
||||
|
||||
with patch("scripts.check_lazy_imports.__file__", str(script_path)):
|
||||
# Should not raise an exception since all imports are inside functions
|
||||
main({"vertexai", "playwright", "nltk"})
|
||||
main(
|
||||
{
|
||||
"vertexai": LazyImportSettings(),
|
||||
"playwright": LazyImportSettings(),
|
||||
"nltk": LazyImportSettings(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_main_function_with_violations(tmp_path: Path) -> None:
|
||||
@@ -283,7 +356,13 @@ from playwright.sync_api import sync_playwright
|
||||
RuntimeError,
|
||||
match="Found eager imports of .+\\. You must import them only when needed",
|
||||
):
|
||||
main({"vertexai", "playwright", "nltk"})
|
||||
main(
|
||||
{
|
||||
"vertexai": LazyImportSettings(),
|
||||
"playwright": LazyImportSettings(),
|
||||
"nltk": LazyImportSettings(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_main_function_specific_modules_only() -> None:
|
||||
@@ -396,8 +475,8 @@ def test_find_python_files_basic() -> None:
|
||||
tests_dir.mkdir()
|
||||
(tests_dir / "test_something.py").write_text("import os") # Should be excluded
|
||||
|
||||
# Test with no ignore directories
|
||||
files = find_python_files(backend_dir, set())
|
||||
# Test - find_python_files no longer takes ignore directories parameter
|
||||
files = find_python_files(backend_dir)
|
||||
file_names = [f.name for f in files]
|
||||
|
||||
assert "normal.py" in file_names
|
||||
@@ -409,58 +488,58 @@ def test_find_python_files_basic() -> None:
|
||||
) # One in root, one in subdir
|
||||
|
||||
|
||||
def test_find_python_files_ignore_directories() -> None:
|
||||
"""Test finding Python files with ignored directories."""
|
||||
def test_find_python_files_ignore_venv_directories() -> None:
|
||||
"""Test that find_python_files automatically ignores virtual environment directories."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
backend_dir = Path(tmp_dir)
|
||||
|
||||
# Create files in various directories
|
||||
(backend_dir / "normal.py").write_text("import os")
|
||||
|
||||
model_server_dir = backend_dir / "model_server"
|
||||
model_server_dir.mkdir()
|
||||
(model_server_dir / "model.py").write_text(
|
||||
# Create venv directory (should be automatically ignored)
|
||||
venv_dir = backend_dir / "venv"
|
||||
venv_dir.mkdir()
|
||||
(venv_dir / "venv_file.py").write_text(
|
||||
"import transformers"
|
||||
) # Should be excluded
|
||||
|
||||
ignored_dir = backend_dir / "ignored"
|
||||
ignored_dir.mkdir()
|
||||
(ignored_dir / "should_be_ignored.py").write_text(
|
||||
# Create .venv directory (should be automatically ignored)
|
||||
dot_venv_dir = backend_dir / ".venv"
|
||||
dot_venv_dir.mkdir()
|
||||
(dot_venv_dir / "should_be_ignored.py").write_text(
|
||||
"import vertexai"
|
||||
) # Should be excluded
|
||||
|
||||
# Create a file with ignored directory name in filename (should be included)
|
||||
(backend_dir / "model_server_utils.py").write_text("import os")
|
||||
# Create a file with venv in filename (should be included)
|
||||
(backend_dir / "venv_utils.py").write_text("import os")
|
||||
|
||||
# Test with ignore directories
|
||||
files = find_python_files(backend_dir, {"model_server", "ignored"})
|
||||
# Test - venv directories are automatically ignored
|
||||
files = find_python_files(backend_dir)
|
||||
file_names = [f.name for f in files]
|
||||
|
||||
assert "normal.py" in file_names
|
||||
assert (
|
||||
"model.py" not in file_names
|
||||
) # Excluded because in model_server directory
|
||||
assert "venv_file.py" not in file_names # Excluded because in venv directory
|
||||
assert (
|
||||
"should_be_ignored.py" not in file_names
|
||||
) # Excluded because in ignored directory
|
||||
) # Excluded because in .venv directory
|
||||
assert (
|
||||
"model_server_utils.py" in file_names
|
||||
"venv_utils.py" in file_names
|
||||
) # Included because not in directory, just filename
|
||||
|
||||
|
||||
def test_find_python_files_nested_ignore() -> None:
|
||||
"""Test that ignored directories work with nested paths."""
|
||||
def test_find_python_files_nested_venv() -> None:
|
||||
"""Test that venv directories are ignored even when nested."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
backend_dir = Path(tmp_dir)
|
||||
|
||||
# Create nested structure
|
||||
nested_path = backend_dir / "some" / "path" / "model_server" / "nested"
|
||||
# Create nested structure with venv
|
||||
nested_path = backend_dir / "some" / "path" / "venv" / "nested"
|
||||
nested_path.mkdir(parents=True)
|
||||
(nested_path / "deep_model.py").write_text("import transformers")
|
||||
(nested_path / "deep_venv.py").write_text("import transformers")
|
||||
|
||||
files = find_python_files(backend_dir, {"model_server"})
|
||||
files = find_python_files(backend_dir)
|
||||
|
||||
# Should exclude the deeply nested file
|
||||
# Should exclude the deeply nested file in venv
|
||||
assert len(files) == 0
|
||||
|
||||
|
||||
|
||||
5
ct.yaml
5
ct.yaml
@@ -7,7 +7,10 @@ chart-dirs:
|
||||
# must be kept in sync with Chart.yaml
|
||||
chart-repos:
|
||||
- vespa=https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
- postgresql=https://charts.bitnami.com/bitnami
|
||||
- ingress-nginx=https://kubernetes.github.io/ingress-nginx
|
||||
- postgresql=https://cloudnative-pg.github.io/charts
|
||||
- redis=https://ot-container-kit.github.io/helm-charts
|
||||
- minio=https://charts.min.io/
|
||||
|
||||
# Postgres has been seen to take ~10 min to pull, so use a 15 min timeout.
|
||||
helm-extra-args: --debug --timeout 900s
|
||||
|
||||
@@ -104,6 +104,9 @@ services:
|
||||
command: -c 'max_connections=250'
|
||||
restart: unless-stopped
|
||||
# POSTGRES_USER and POSTGRES_PASSWORD should be set in .env file
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
logging:
|
||||
|
||||
@@ -182,6 +182,9 @@ services:
|
||||
command: -c 'max_connections=250'
|
||||
restart: unless-stopped
|
||||
# POSTGRES_USER and POSTGRES_PASSWORD should be set in .env file
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
logging:
|
||||
|
||||
@@ -136,6 +136,9 @@ services:
|
||||
command: -c 'max_connections=250'
|
||||
restart: unless-stopped
|
||||
# POSTGRES_USER and POSTGRES_PASSWORD should be set in .env file
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
logging:
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
dependencies:
|
||||
- name: postgresql
|
||||
repository: https://charts.bitnami.com/bitnami
|
||||
version: 14.3.1
|
||||
- name: cloudnative-pg
|
||||
repository: https://cloudnative-pg.github.io/charts
|
||||
version: 0.26.0
|
||||
- name: vespa
|
||||
repository: https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
version: 0.2.24
|
||||
- name: nginx
|
||||
repository: oci://registry-1.docker.io/bitnamicharts
|
||||
version: 15.14.0
|
||||
- name: ingress-nginx
|
||||
repository: https://kubernetes.github.io/ingress-nginx
|
||||
version: 4.13.3
|
||||
- name: redis
|
||||
repository: https://charts.bitnami.com/bitnami
|
||||
version: 20.1.0
|
||||
repository: https://ot-container-kit.github.io/helm-charts
|
||||
version: 0.16.6
|
||||
- name: minio
|
||||
repository: oci://registry-1.docker.io/bitnamicharts
|
||||
version: 17.0.4
|
||||
digest: sha256:dddd687525764f5698adc339a11d268b0ee9c3ca81f8d46c9e65a6bf2c21cf25
|
||||
generated: "2025-09-24T12:16:33.661608-07:00"
|
||||
repository: https://charts.min.io/
|
||||
version: 5.4.0
|
||||
digest: sha256:c5604f05ed5b7cbe4d1b2167370e1e42cb3d97c879fbccc6ef881934ba4683e2
|
||||
generated: "2025-10-03T13:48:55.815175-07:00"
|
||||
|
||||
@@ -5,7 +5,7 @@ home: https://www.onyx.app/
|
||||
sources:
|
||||
- "https://github.com/onyx-dot-app/onyx"
|
||||
type: application
|
||||
version: 0.3.3
|
||||
version: 0.4.0
|
||||
appVersion: latest
|
||||
annotations:
|
||||
category: Productivity
|
||||
@@ -18,23 +18,25 @@ annotations:
|
||||
- name: vespa
|
||||
image: vespaengine/vespa:8.526.15
|
||||
dependencies:
|
||||
- name: postgresql
|
||||
version: 14.3.1
|
||||
repository: https://charts.bitnami.com/bitnami
|
||||
- name: cloudnative-pg
|
||||
version: 0.26.0
|
||||
repository: https://cloudnative-pg.github.io/charts
|
||||
condition: postgresql.enabled
|
||||
alias: postgresql
|
||||
- name: vespa
|
||||
version: 0.2.24
|
||||
repository: https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
condition: vespa.enabled
|
||||
- name: nginx
|
||||
version: 15.14.0
|
||||
repository: oci://registry-1.docker.io/bitnamicharts
|
||||
- name: ingress-nginx
|
||||
version: 4.13.3
|
||||
repository: https://kubernetes.github.io/ingress-nginx
|
||||
condition: nginx.enabled
|
||||
alias: nginx
|
||||
- name: redis
|
||||
version: 20.1.0
|
||||
repository: https://charts.bitnami.com/bitnami
|
||||
version: 0.16.6
|
||||
repository: https://ot-container-kit.github.io/helm-charts
|
||||
condition: redis.enabled
|
||||
- name: minio
|
||||
version: 17.0.4
|
||||
repository: oci://registry-1.docker.io/bitnamicharts
|
||||
version: 5.4.0
|
||||
repository: https://charts.min.io/
|
||||
condition: minio.enabled
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
data:
|
||||
INTERNAL_URL: "http://{{ include "onyx.fullname" . }}-api-service:{{ .Values.api.service.port | default 8080 }}"
|
||||
{{- if .Values.postgresql.enabled }}
|
||||
POSTGRES_HOST: {{ .Release.Name }}-postgresql
|
||||
POSTGRES_HOST: {{ .Release.Name }}-postgresql-rw
|
||||
{{- end }}
|
||||
{{- if .Values.vespa.enabled }}
|
||||
VESPA_HOST: {{ .Values.vespa.name }}.{{ .Values.vespa.service.name }}.{{ .Release.Namespace }}.svc.cluster.local
|
||||
@@ -23,5 +23,5 @@ data:
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{- if .Values.minio.enabled }}
|
||||
S3_ENDPOINT_URL: "http://{{ .Release.Name }}-minio:{{ .Values.minio.service.ports.api | default 9000 }}"
|
||||
S3_ENDPOINT_URL: "http://{{ .Release.Name }}-minio:{{ default 9000 .Values.minio.service.port }}"
|
||||
{{- end }}
|
||||
|
||||
@@ -3,7 +3,7 @@ kind: ConfigMap
|
||||
metadata:
|
||||
name: onyx-nginx-conf
|
||||
data:
|
||||
nginx.conf: |
|
||||
upstreams.conf: |
|
||||
upstream api_server {
|
||||
server {{ include "onyx.fullname" . }}-api-service:{{ .Values.api.service.servicePort }} fail_timeout=0;
|
||||
}
|
||||
@@ -12,33 +12,32 @@ data:
|
||||
server {{ include "onyx.fullname" . }}-webserver:{{ .Values.webserver.service.servicePort }} fail_timeout=0;
|
||||
}
|
||||
|
||||
server {
|
||||
listen 1024;
|
||||
server_name $$DOMAIN;
|
||||
server-snippet.conf: |
|
||||
listen 1024;
|
||||
server_name $$DOMAIN;
|
||||
|
||||
client_max_body_size 5G; # Maximum upload size
|
||||
client_max_body_size 5G;
|
||||
|
||||
location ~ ^/api(.*)$ {
|
||||
rewrite ^/api(/.*)$ $1 break;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_buffering off;
|
||||
proxy_redirect off;
|
||||
proxy_pass http://api_server;
|
||||
}
|
||||
|
||||
location / {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_redirect off;
|
||||
proxy_pass http://web_server;
|
||||
}
|
||||
location ~ ^/api(.*)$ {
|
||||
rewrite ^/api(/.*)$ $1 break;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_buffering off;
|
||||
proxy_redirect off;
|
||||
proxy_pass http://api_server;
|
||||
}
|
||||
|
||||
location / {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_redirect off;
|
||||
proxy_pass http://web_server;
|
||||
}
|
||||
|
||||
@@ -9,19 +9,15 @@ global:
|
||||
pullPolicy: "IfNotPresent"
|
||||
|
||||
postgresql:
|
||||
primary:
|
||||
persistence:
|
||||
enabled: true
|
||||
cluster:
|
||||
instances: 1
|
||||
storage:
|
||||
storageClass: ""
|
||||
size: 10Gi
|
||||
shmVolume:
|
||||
enabled: true
|
||||
sizeLimit: 2Gi
|
||||
enabled: true
|
||||
auth:
|
||||
existingSecret: onyx-postgresql
|
||||
secretKeys:
|
||||
# overwriting as postgres typically expects 'postgres-password'
|
||||
adminPasswordKey: postgres_password
|
||||
enableSuperuserAccess: true
|
||||
superuserSecret:
|
||||
name: onyx-postgresql # keep in sync with auth.postgresql
|
||||
|
||||
vespa:
|
||||
name: da-vespa-0
|
||||
@@ -157,21 +153,40 @@ serviceAccount:
|
||||
|
||||
nginx:
|
||||
enabled: true
|
||||
containerPorts:
|
||||
http: 1024
|
||||
extraEnvVars:
|
||||
- name: DOMAIN
|
||||
value: localhost
|
||||
service:
|
||||
type: LoadBalancer
|
||||
ports:
|
||||
http: 80
|
||||
onyx: 3000
|
||||
targetPort:
|
||||
http: http
|
||||
onyx: http
|
||||
controller:
|
||||
containerPort:
|
||||
http: 1024
|
||||
|
||||
existingServerBlockConfigmap: onyx-nginx-conf
|
||||
# Propagate DOMAIN into nginx so server_name continues to use the same env var
|
||||
extraEnvs:
|
||||
- name: DOMAIN
|
||||
value: localhost
|
||||
|
||||
config:
|
||||
# Expose DOMAIN to the nginx config and pull in our custom snippets
|
||||
main-snippet: |
|
||||
env DOMAIN;
|
||||
http-snippet: |
|
||||
include /etc/nginx/custom-snippets/upstreams.conf;
|
||||
server-snippet: |
|
||||
include /etc/nginx/custom-snippets/server-snippet.conf;
|
||||
|
||||
# Mount the existing nginx ConfigMap that holds the upstream and server snippets
|
||||
extraVolumes:
|
||||
- name: nginx-config
|
||||
configMap:
|
||||
name: onyx-nginx-conf
|
||||
extraVolumeMounts:
|
||||
- name: nginx-config
|
||||
mountPath: /etc/nginx/custom-snippets
|
||||
readOnly: true
|
||||
|
||||
service:
|
||||
type: LoadBalancer
|
||||
ports:
|
||||
http: 80
|
||||
targetPorts:
|
||||
http: http
|
||||
|
||||
webserver:
|
||||
replicaCount: 1
|
||||
@@ -707,47 +722,54 @@ celery_worker_docfetching:
|
||||
|
||||
redis:
|
||||
enabled: true
|
||||
architecture: standalone
|
||||
commonConfiguration: |-
|
||||
# Disable AOF persistence https://redis.io/topics/persistence#append-only-file
|
||||
appendonly no
|
||||
# Disable RDB persistence as well (AOF is turned off by "appendonly no").
|
||||
save ""
|
||||
master:
|
||||
replicaCount: 1
|
||||
image:
|
||||
registry: docker.io
|
||||
repository: bitnami/redis
|
||||
tag: "7.4.0"
|
||||
pullPolicy: IfNotPresent
|
||||
persistence:
|
||||
enabled: false
|
||||
service:
|
||||
type: ClusterIP
|
||||
port: 6379
|
||||
auth:
|
||||
existingSecret: onyx-redis
|
||||
existingSecretPasswordKey: redis_password
|
||||
redisStandalone:
|
||||
image: quay.io/opstree/redis
|
||||
tag: v7.0.15
|
||||
imagePullPolicy: IfNotPresent
|
||||
serviceType: ClusterIP
|
||||
resources:
|
||||
requests:
|
||||
cpu: 100m
|
||||
memory: 128Mi
|
||||
limits:
|
||||
cpu: 500m
|
||||
memory: 512Mi
|
||||
# Use existing secret for Redis password
|
||||
redisSecret:
|
||||
secretName: onyx-redis
|
||||
secretKey: redis_password
|
||||
# Redis configuration
|
||||
externalConfig:
|
||||
enabled: true
|
||||
data: |
|
||||
appendonly no
|
||||
save ""
|
||||
storageSpec:
|
||||
volumeClaimTemplate:
|
||||
spec:
|
||||
accessModes: ["ReadWriteOnce"]
|
||||
resources:
|
||||
requests:
|
||||
storage: 1Gi
|
||||
|
||||
minio:
|
||||
enabled: true
|
||||
auth:
|
||||
existingSecret: onyx-objectstorage
|
||||
rootUserSecretKey: s3_aws_access_key_id
|
||||
rootPasswordSecretKey: s3_aws_secret_access_key
|
||||
defaultBuckets: "onyx-file-store-bucket"
|
||||
mode: standalone
|
||||
replicas: 1
|
||||
drivesPerNode: 1
|
||||
existingSecret: onyx-objectstorage
|
||||
buckets:
|
||||
- name: onyx-file-store-bucket
|
||||
persistence:
|
||||
enabled: true
|
||||
size: 30Gi
|
||||
storageClass: ""
|
||||
service:
|
||||
type: ClusterIP
|
||||
ports:
|
||||
api: 9000
|
||||
console: 9001
|
||||
port: 9000
|
||||
consoleService:
|
||||
type: ClusterIP
|
||||
ports:
|
||||
http: 9001
|
||||
port: 9001
|
||||
|
||||
ingress:
|
||||
enabled: false
|
||||
@@ -772,10 +794,13 @@ auth:
|
||||
existingSecret: ""
|
||||
# -- Maps env var names to secret keys; each map key is always upper-cased when used as an env var
|
||||
secretKeys:
|
||||
POSTGRES_PASSWORD: "postgres_password"
|
||||
# CloudNativePG requires `username` and `password` keys for the superuser secret.
|
||||
POSTGRES_USER: username
|
||||
POSTGRES_PASSWORD: password
|
||||
# -- Secrets values IF existingSecret is empty. Key here must match the value in secretKeys to be used. Values will be base64 encoded in the k8s cluster.
|
||||
values:
|
||||
postgres_password: "postgres"
|
||||
username: "postgres"
|
||||
password: "postgres"
|
||||
redis:
|
||||
# -- Enable or disable this secret entirely. Will remove from env var configurations and remove any created secrets.
|
||||
enabled: true
|
||||
@@ -804,6 +829,8 @@ auth:
|
||||
values:
|
||||
s3_aws_access_key_id: "minioadmin"
|
||||
s3_aws_secret_access_key: "minioadmin"
|
||||
rootUser: "minioadmin"
|
||||
rootPassword: "minioadmin"
|
||||
oauth:
|
||||
# -- Enable or disable this secret entirely. Will remove from env var configurations and remove any created secrets.
|
||||
enabled: false
|
||||
|
||||
11
pyproject.toml
Normal file
11
pyproject.toml
Normal file
@@ -0,0 +1,11 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "onyx"
|
||||
version = "0.0.0"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["backend"]
|
||||
include = ["onyx*", "tests*"]
|
||||
@@ -9,6 +9,8 @@ export default defineConfig({
|
||||
expect: {
|
||||
timeout: 15000, // 15 seconds timeout for all assertions to reduce flakiness
|
||||
},
|
||||
retries: process.env.CI ? 2 : 0, // Retry failed tests 2 times in CI, 0 locally
|
||||
workers: process.env.CI ? 2 : undefined, // Limit to 2 parallel workers in CI to reduce flakiness
|
||||
reporter: [
|
||||
["list"],
|
||||
// Warning: uncommenting the html reporter may cause the chromatic-archives
|
||||
|
||||
@@ -423,9 +423,7 @@ export function CustomLLMProviderUpdateForm({
|
||||
{!existingLlmProvider?.deployment_name && (
|
||||
<TextFormField
|
||||
name="fast_default_model_name"
|
||||
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
|
||||
for this provider. If not set, will use
|
||||
the Default Model configured above.`}
|
||||
subtext="The model to use for lighter flows like `LLM Chunk Filter` for this provider. If not set, will use the Default Model configured above."
|
||||
label="[Optional] Fast Model"
|
||||
placeholder="E.g. gpt-4"
|
||||
/>
|
||||
|
||||
@@ -588,9 +588,7 @@ export function LLMProviderUpdateForm({
|
||||
(llmProviderDescriptor.model_configurations.length > 0 ? (
|
||||
<SelectorFormField
|
||||
name="fast_default_model_name"
|
||||
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
|
||||
for this provider. If \`Default\` is specified, will use
|
||||
the Default Model configured above.`}
|
||||
subtext="The model to use for lighter flows like `LLM Chunk Filter` for this provider. If not set, will use the Default Model configured above."
|
||||
label="[Optional] Fast Model"
|
||||
options={llmProviderDescriptor.model_configurations.map(
|
||||
(modelConfiguration) => ({
|
||||
@@ -606,9 +604,7 @@ export function LLMProviderUpdateForm({
|
||||
) : (
|
||||
<TextFormField
|
||||
name="fast_default_model_name"
|
||||
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
|
||||
for this provider. If \`Default\` is specified, will use
|
||||
the Default Model configured above.`}
|
||||
subtext="The model to use for lighter flows like `LLM Chunk Filter` for this provider. If not set, will use the Default Model configured above."
|
||||
label="[Optional] Fast Model"
|
||||
placeholder="E.g. gpt-4"
|
||||
/>
|
||||
|
||||
@@ -18,7 +18,7 @@ export function useReIndexModal(
|
||||
const [reIndexPopupVisible, setReIndexPopupVisible] = useState(false);
|
||||
|
||||
const showReIndexModal = () => {
|
||||
if (!connectorId || !credentialId || !ccPairId) {
|
||||
if (connectorId == null || credentialId == null || ccPairId == null) {
|
||||
return;
|
||||
}
|
||||
setReIndexPopupVisible(true);
|
||||
@@ -29,7 +29,7 @@ export function useReIndexModal(
|
||||
};
|
||||
|
||||
const triggerReIndex = async (fromBeginning: boolean) => {
|
||||
if (!connectorId || !credentialId || !ccPairId) {
|
||||
if (connectorId == null || credentialId == null || ccPairId == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -64,7 +64,10 @@ export function useReIndexModal(
|
||||
};
|
||||
|
||||
const FinalReIndexModal =
|
||||
reIndexPopupVisible && connectorId && credentialId && ccPairId ? (
|
||||
reIndexPopupVisible &&
|
||||
connectorId != null &&
|
||||
credentialId != null &&
|
||||
ccPairId != null ? (
|
||||
<ReIndexModal
|
||||
setPopup={setPopup}
|
||||
hide={hideReIndexModal}
|
||||
|
||||
@@ -121,8 +121,8 @@ function Main({ ccPairId }: { ccPairId: number }) {
|
||||
|
||||
// Initialize hooks at top level to avoid conditional hook calls
|
||||
const { showReIndexModal, ReIndexModal } = useReIndexModal(
|
||||
ccPair?.connector?.id || null,
|
||||
ccPair?.credential?.id || null,
|
||||
ccPair?.connector?.id ?? null,
|
||||
ccPair?.credential?.id ?? null,
|
||||
ccPairId,
|
||||
setPopup
|
||||
);
|
||||
|
||||
@@ -255,7 +255,7 @@ export function SettingsForm() {
|
||||
<Checkbox
|
||||
label="Deep Research"
|
||||
sublabel="If set, users will be able to use Deep Research."
|
||||
checked={settings.deep_research_enabled ?? false}
|
||||
checked={settings.deep_research_enabled ?? true}
|
||||
onChange={(e) =>
|
||||
handleToggleSettingsField("deep_research_enabled", e.target.checked)
|
||||
}
|
||||
|
||||
@@ -95,22 +95,21 @@ export function UserDropdown({
|
||||
}
|
||||
|
||||
const handleLogout = () => {
|
||||
logout().then((isSuccess) => {
|
||||
if (!isSuccess) {
|
||||
logout().then((response) => {
|
||||
if (!response?.ok) {
|
||||
alert("Failed to logout");
|
||||
return;
|
||||
}
|
||||
|
||||
// Construct the current URL
|
||||
const currentUrl = `${pathname}${
|
||||
searchParams?.toString() ? `?${searchParams.toString()}` : ""
|
||||
}`;
|
||||
|
||||
// Encode the current URL to use as a redirect parameter
|
||||
const encodedRedirect = encodeURIComponent(currentUrl);
|
||||
|
||||
// Redirect to login page with the current page as a redirect parameter
|
||||
router.push(`/auth/login?next=${encodedRedirect}`);
|
||||
router.push(
|
||||
`/auth/login?disableAutoRedirect=true&next=${encodedRedirect}`
|
||||
);
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
@@ -34,9 +34,6 @@ export function ToolList({
|
||||
const [searchTerm, setSearchTerm] = useState("");
|
||||
const [currentPage, setCurrentPage] = useState(1);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [showToolList, setShowToolList] = useState(
|
||||
searchParams.get("listing_tools") === "true"
|
||||
);
|
||||
const [currentServerId, setCurrentServerId] = useState<number | undefined>(
|
||||
serverId
|
||||
);
|
||||
@@ -46,14 +43,13 @@ export function ToolList({
|
||||
if (
|
||||
searchParams.get("listing_tools") === "true" &&
|
||||
serverId &&
|
||||
!showToolList &&
|
||||
values.name.trim() &&
|
||||
values.server_url.trim()
|
||||
) {
|
||||
// Only auto-trigger for servers that have required form values and a serverId
|
||||
handleListActions(values);
|
||||
}
|
||||
}, [searchParams, serverId, showToolList, values.name, values.server_url]);
|
||||
}, [searchParams, serverId, values.name, values.server_url]);
|
||||
|
||||
const handleListActions = async (values: MCPFormValues) => {
|
||||
// Check if OAuth needs connection first
|
||||
@@ -113,10 +109,6 @@ export function ToolList({
|
||||
// Update serverId for subsequent operations
|
||||
newServerId = serverResult.server_id;
|
||||
setCurrentServerId(newServerId);
|
||||
// Ensure URL reflects the created server and listing state to avoid duplicate creation (409)
|
||||
router.replace(
|
||||
`/admin/actions/edit-mcp?server_id=${newServerId}&listing_tools=true`
|
||||
);
|
||||
} else {
|
||||
// For OAuth servers, use the existing serverId
|
||||
if (!serverId) {
|
||||
@@ -129,6 +121,11 @@ export function ToolList({
|
||||
}
|
||||
newServerId = serverId;
|
||||
}
|
||||
// Ensure URL reflects the created server and listing state to avoid duplicate creation
|
||||
// and set listing_tools=true so the tool list is shown
|
||||
router.replace(
|
||||
`/admin/actions/edit-mcp?server_id=${newServerId}&listing_tools=true`
|
||||
);
|
||||
|
||||
// List available tools from the saved server
|
||||
const promises: Promise<Response>[] = [
|
||||
@@ -175,7 +172,6 @@ export function ToolList({
|
||||
return;
|
||||
}
|
||||
|
||||
setShowToolList(true);
|
||||
setCurrentPage(1);
|
||||
|
||||
// Process available tools
|
||||
@@ -315,7 +311,7 @@ export function ToolList({
|
||||
}
|
||||
};
|
||||
|
||||
return !showToolList ? (
|
||||
return listingTools || searchParams.get("listing_tools") !== "true" ? (
|
||||
<div className="flex gap-2">
|
||||
<Button
|
||||
type="button"
|
||||
@@ -490,7 +486,6 @@ export function ToolList({
|
||||
const currentUrl = new URL(window.location.href);
|
||||
currentUrl.searchParams.delete("listing_tools");
|
||||
router.replace(currentUrl.toString());
|
||||
setShowToolList(false);
|
||||
}}
|
||||
>
|
||||
Back
|
||||
|
||||
@@ -58,7 +58,7 @@ export async function fetchSettingsSS(): Promise<CombinedSettings | null> {
|
||||
notifications: [],
|
||||
needs_reindexing: false,
|
||||
anonymous_user_enabled: false,
|
||||
deep_research_enabled: false,
|
||||
deep_research_enabled: true,
|
||||
temperature_override_enabled: true,
|
||||
query_history_type: QueryHistoryType.NORMAL,
|
||||
};
|
||||
@@ -114,7 +114,7 @@ export async function fetchSettingsSS(): Promise<CombinedSettings | null> {
|
||||
}
|
||||
|
||||
if (settings.deep_research_enabled == null) {
|
||||
settings.deep_research_enabled = false;
|
||||
settings.deep_research_enabled = true;
|
||||
}
|
||||
|
||||
const webVersion = getWebVersion();
|
||||
|
||||
@@ -525,6 +525,14 @@ export const connectorConfigs: Record<
|
||||
description:
|
||||
"The base URL of your Confluence instance (e.g., https://your-domain.atlassian.net/wiki)",
|
||||
},
|
||||
{
|
||||
type: "checkbox",
|
||||
query: "Using scoped token?",
|
||||
label: "Using scoped token",
|
||||
name: "scoped_token",
|
||||
optional: true,
|
||||
default: false,
|
||||
},
|
||||
{
|
||||
type: "tab",
|
||||
name: "indexing_scope",
|
||||
@@ -616,6 +624,14 @@ export const connectorConfigs: Record<
|
||||
description:
|
||||
"The base URL of your Jira instance (e.g., https://your-domain.atlassian.net)",
|
||||
},
|
||||
{
|
||||
type: "checkbox",
|
||||
query: "Using scoped token?",
|
||||
label: "Using scoped token",
|
||||
name: "scoped_token",
|
||||
optional: true,
|
||||
default: false,
|
||||
},
|
||||
{
|
||||
type: "tab",
|
||||
name: "indexing_scope",
|
||||
|
||||
@@ -387,11 +387,12 @@ export const SOURCE_METADATA_MAP: SourceMap = {
|
||||
isPopular: true,
|
||||
},
|
||||
user_file: {
|
||||
// TODO: write docs for projects and link them here
|
||||
icon: FileIcon2,
|
||||
displayName: "File",
|
||||
category: SourceCategory.Other,
|
||||
docs: "https://docs.onyx.app/admin/connectors/official/file",
|
||||
isPopular: true,
|
||||
isPopular: false, // Needs to be false to hide from the Add Connector page
|
||||
},
|
||||
|
||||
// Other
|
||||
@@ -444,7 +445,9 @@ export function listSourceMetadata(): SourceMetadata[] {
|
||||
source !== "ingestion_api" &&
|
||||
source !== "mock_connector" &&
|
||||
// use the "regular" slack connector when listing
|
||||
source !== "federated_slack"
|
||||
source !== "federated_slack" &&
|
||||
// user_file is for internal use (projects), not the Add Connector page
|
||||
source !== "user_file"
|
||||
)
|
||||
.map(([source, metadata]) => {
|
||||
return fillSourceMetadata(metadata, source as ValidSources);
|
||||
|
||||
@@ -16,7 +16,7 @@ export const getCurrentUser = async (): Promise<User | null> => {
|
||||
};
|
||||
|
||||
export const logout = async (): Promise<Response> => {
|
||||
const response = await fetch("/api/auth/logout", {
|
||||
const response = await fetch("/auth/logout", {
|
||||
method: "POST",
|
||||
credentials: "include",
|
||||
});
|
||||
|
||||
@@ -3,7 +3,8 @@ import { dragElementAbove, dragElementBelow } from "../utils/dragUtils";
|
||||
import { loginAsRandomUser } from "../utils/auth";
|
||||
import { createAssistant, pinAssistantByName } from "../utils/assistantUtils";
|
||||
|
||||
test("Assistant Drag and Drop", async ({ page }) => {
|
||||
// TODO (chris): figure out why this test is flaky
|
||||
test.skip("Assistant Drag and Drop", async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await loginAsRandomUser(page);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user