mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-01 13:45:44 +00:00
Compare commits
133 Commits
v1.6.0-clo
...
unified
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
91eaf23162 | ||
|
|
bdafbfe0e8 | ||
|
|
278fd0e153 | ||
|
|
a4bb97bc22 | ||
|
|
8063d9a75e | ||
|
|
1ffaba12f0 | ||
|
|
26f8660663 | ||
|
|
d6504ed578 | ||
|
|
7fcc2c9d35 | ||
|
|
46e8f925fe | ||
|
|
5ec1f61839 | ||
|
|
df950963a7 | ||
|
|
93208a66ac | ||
|
|
a4819e07e7 | ||
|
|
f642ace40c | ||
|
|
9b430ae2d5 | ||
|
|
05f3f878b2 | ||
|
|
df17c5352e | ||
|
|
bcfb0f3cf3 | ||
|
|
38468c1dc4 | ||
|
|
8550a9c5e3 | ||
|
|
fe0c60e50d | ||
|
|
4ecc151a02 | ||
|
|
d08becead5 | ||
|
|
a429f852d5 | ||
|
|
a856f27fae | ||
|
|
d0d8027928 | ||
|
|
bd1671f1a1 | ||
|
|
e236c67678 | ||
|
|
683956697a | ||
|
|
fb1e303ffc | ||
|
|
729d4fafd1 | ||
|
|
40c60282d0 | ||
|
|
2141fd2c6e | ||
|
|
9aeba96043 | ||
|
|
b431de5141 | ||
|
|
d1a6340cfc | ||
|
|
ccf382ef4f | ||
|
|
c31997b9b2 | ||
|
|
ab31795a46 | ||
|
|
b3beca63dc | ||
|
|
cc6d54c1e6 | ||
|
|
ee12c0c5de | ||
|
|
d48912a05d | ||
|
|
c079072676 | ||
|
|
952f6bfb37 | ||
|
|
0714e4bb4e | ||
|
|
ae577f0f44 | ||
|
|
0705d584d8 | ||
|
|
36e391e557 | ||
|
|
1efce594b5 | ||
|
|
67ac53f17d | ||
|
|
d5a222925a | ||
|
|
d5ef928782 | ||
|
|
6963d78f8e | ||
|
|
d3ef2b8c17 | ||
|
|
70f4162ea8 | ||
|
|
883f52d332 | ||
|
|
f8fd83c883 | ||
|
|
d2bf0c0c5f | ||
|
|
5d598c2d22 | ||
|
|
9dc0e97302 | ||
|
|
048b2a6b39 | ||
|
|
7dd3cecf67 | ||
|
|
82abe28986 | ||
|
|
a0575e6a00 | ||
|
|
c3702b76b6 | ||
|
|
bb239d574c | ||
|
|
172e5f0e24 | ||
|
|
26b026fb88 | ||
|
|
870629e8a9 | ||
|
|
a547112321 | ||
|
|
da5a94815e | ||
|
|
e024472b74 | ||
|
|
e74855e633 | ||
|
|
e4c26a933d | ||
|
|
36c96f2d98 | ||
|
|
0c5bf5b3ed | ||
|
|
492117d910 | ||
|
|
1ea94dcd8d | ||
|
|
2b1c5a0755 | ||
|
|
82b5f806ab | ||
|
|
6340c517d1 | ||
|
|
3baae2d4f0 | ||
|
|
d7c223ddd4 | ||
|
|
df4917243b | ||
|
|
a79ab713ce | ||
|
|
d1f7cee959 | ||
|
|
a3f41e20da | ||
|
|
458ed93da0 | ||
|
|
273d073bd7 | ||
|
|
9455c8e5ae | ||
|
|
d45d4389a0 | ||
|
|
bd901c0da1 | ||
|
|
2192605c95 | ||
|
|
d248d2f4e9 | ||
|
|
331c53871a | ||
|
|
f62d0d9144 | ||
|
|
427945e757 | ||
|
|
e55cdc6250 | ||
|
|
6a01db9ff2 | ||
|
|
82e9df5c22 | ||
|
|
16c2ef2852 | ||
|
|
224a70eea9 | ||
|
|
c457982120 | ||
|
|
0649748da2 | ||
|
|
ddceddaa28 | ||
|
|
c6733a5026 | ||
|
|
7db744a5de | ||
|
|
cd2a8b0def | ||
|
|
f15bc26cd6 | ||
|
|
65f35f0293 | ||
|
|
4e3e608249 | ||
|
|
719a092a12 | ||
|
|
6a8fde7eb1 | ||
|
|
4fdd0812a0 | ||
|
|
4913dc1e85 | ||
|
|
4a43a9642e | ||
|
|
cc48a0c38e | ||
|
|
01ccfd2df7 | ||
|
|
36d75786ee | ||
|
|
f9bc38ba65 | ||
|
|
3da283221d | ||
|
|
90568d3bbb | ||
|
|
7955ca938c | ||
|
|
f5d357eb28 | ||
|
|
d83f616214 | ||
|
|
275c1bec3d | ||
|
|
7d1ef912e8 | ||
|
|
2fe1d4c373 | ||
|
|
2396ad309e | ||
|
|
0b13ef963a | ||
|
|
83073f3ded |
19
.github/actions/custom-build-and-push/action.yml
vendored
19
.github/actions/custom-build-and-push/action.yml
vendored
@@ -35,6 +35,16 @@ inputs:
|
||||
cache-to:
|
||||
description: 'Cache destinations'
|
||||
required: false
|
||||
outputs:
|
||||
description: 'Output destinations'
|
||||
required: false
|
||||
provenance:
|
||||
description: 'Generate provenance attestation'
|
||||
required: false
|
||||
default: 'false'
|
||||
build-args:
|
||||
description: 'Build arguments'
|
||||
required: false
|
||||
retry-wait-time:
|
||||
description: 'Time to wait before attempt 2 in seconds'
|
||||
required: false
|
||||
@@ -62,6 +72,9 @@ runs:
|
||||
no-cache: ${{ inputs.no-cache }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
outputs: ${{ inputs.outputs }}
|
||||
provenance: ${{ inputs.provenance }}
|
||||
build-args: ${{ inputs.build-args }}
|
||||
|
||||
- name: Wait before attempt 2
|
||||
if: steps.buildx1.outcome != 'success'
|
||||
@@ -85,6 +98,9 @@ runs:
|
||||
no-cache: ${{ inputs.no-cache }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
outputs: ${{ inputs.outputs }}
|
||||
provenance: ${{ inputs.provenance }}
|
||||
build-args: ${{ inputs.build-args }}
|
||||
|
||||
- name: Wait before attempt 3
|
||||
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success'
|
||||
@@ -108,6 +124,9 @@ runs:
|
||||
no-cache: ${{ inputs.no-cache }}
|
||||
cache-from: ${{ inputs.cache-from }}
|
||||
cache-to: ${{ inputs.cache-to }}
|
||||
outputs: ${{ inputs.outputs }}
|
||||
provenance: ${{ inputs.provenance }}
|
||||
build-args: ${{ inputs.build-args }}
|
||||
|
||||
- name: Report failure
|
||||
if: steps.buildx1.outcome != 'success' && steps.buildx2.outcome != 'success' && steps.buildx3.outcome != 'success'
|
||||
|
||||
@@ -142,15 +142,25 @@ jobs:
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
# Security: Using pinned digest (0.65.0@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436)
|
||||
# Security: No Docker socket mount needed for remote registry scanning
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL onyxdotapp/onyx-backend
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
trivyignores: ./backend/.trivyignore
|
||||
timeout_minutes: 30
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-v ${{ github.workspace }}/backend/.trivyignore:/tmp/.trivyignore:ro \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
--timeout 20m \
|
||||
--severity CRITICAL,HIGH \
|
||||
--ignorefile /tmp/.trivyignore \
|
||||
docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
|
||||
@@ -139,12 +139,20 @@ jobs:
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
timeout_minutes: 30
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
--timeout 20m \
|
||||
--severity CRITICAL,HIGH \
|
||||
docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
|
||||
@@ -99,7 +99,7 @@ jobs:
|
||||
needs: [check_model_server_changes]
|
||||
if: needs.check_model_server_changes.outputs.changed == 'true'
|
||||
runs-on:
|
||||
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
|
||||
[runs-on, runner=8cpu-linux-arm64, "run-id=${{ github.run_id }}-arm64"]
|
||||
env:
|
||||
PLATFORM_PAIR: linux-arm64
|
||||
steps:
|
||||
@@ -164,13 +164,20 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
timeout: "10m"
|
||||
timeout_minutes: 30
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
--timeout 20m \
|
||||
--severity CRITICAL,HIGH \
|
||||
docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
|
||||
@@ -150,12 +150,20 @@ jobs:
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
timeout_minutes: 30
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
--timeout 20m \
|
||||
--severity CRITICAL,HIGH \
|
||||
docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
|
||||
@@ -21,6 +21,9 @@ env:
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
|
||||
# LLMs
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
jobs:
|
||||
discover-test-dirs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
167
.github/workflows/pr-helm-chart-testing.yml
vendored
167
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -53,27 +53,154 @@ jobs:
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.12.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
- name: Pre-install cluster status check
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all \
|
||||
--helm-extra-set-args="\
|
||||
--set=nginx.enabled=false \
|
||||
--set=postgresql.enabled=false \
|
||||
--set=redis.enabled=false \
|
||||
--set=minio.enabled=false \
|
||||
--set=vespa.enabled=false \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=api.replicaCount=0 \
|
||||
--set=inferenceCapability.replicaCount=0 \
|
||||
--set=indexCapability.replicaCount=0 \
|
||||
--set=celery_beat.replicaCount=0 \
|
||||
--set=celery_worker_heavy.replicaCount=0 \
|
||||
--set=celery_worker_docprocessing.replicaCount=0 \
|
||||
--set=celery_worker_light.replicaCount=0 \
|
||||
--set=celery_worker_monitoring.replicaCount=0 \
|
||||
--set=celery_worker_primary.replicaCount=0 \
|
||||
--set=celery_worker_user_files_indexing.replicaCount=0" \
|
||||
--debug --config ct.yaml
|
||||
run: |
|
||||
echo "=== Pre-install Cluster Status ==="
|
||||
kubectl get nodes -o wide
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get storageclass
|
||||
|
||||
- name: Add Helm repositories and update
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Adding Helm repositories ==="
|
||||
helm repo add bitnami https://charts.bitnami.com/bitnami
|
||||
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo update
|
||||
|
||||
- name: Pre-pull critical images
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-pulling critical images to avoid timeout ==="
|
||||
# Get kind cluster name
|
||||
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
|
||||
echo "Kind cluster: $KIND_CLUSTER"
|
||||
|
||||
# Pre-pull images that are likely to be used
|
||||
echo "Pre-pulling PostgreSQL image..."
|
||||
docker pull postgres:15-alpine || echo "Failed to pull postgres:15-alpine"
|
||||
kind load docker-image postgres:15-alpine --name $KIND_CLUSTER || echo "Failed to load postgres image"
|
||||
|
||||
echo "Pre-pulling Redis image..."
|
||||
docker pull redis:7-alpine || echo "Failed to pull redis:7-alpine"
|
||||
kind load docker-image redis:7-alpine --name $KIND_CLUSTER || echo "Failed to load redis image"
|
||||
|
||||
echo "Pre-pulling Onyx images..."
|
||||
docker pull docker.io/onyxdotapp/onyx-web-server:latest || echo "Failed to pull onyx web server"
|
||||
docker pull docker.io/onyxdotapp/onyx-backend:latest || echo "Failed to pull onyx backend"
|
||||
kind load docker-image docker.io/onyxdotapp/onyx-web-server:latest --name $KIND_CLUSTER || echo "Failed to load onyx web server"
|
||||
kind load docker-image docker.io/onyxdotapp/onyx-backend:latest --name $KIND_CLUSTER || echo "Failed to load onyx backend"
|
||||
|
||||
echo "=== Images loaded into Kind cluster ==="
|
||||
docker exec $KIND_CLUSTER-control-plane crictl images | grep -E "(postgres|redis|onyx)" || echo "Some images may still be loading..."
|
||||
|
||||
- name: Validate chart dependencies
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Validating chart dependencies ==="
|
||||
cd deployment/helm/charts/onyx
|
||||
helm dependency update
|
||||
helm lint .
|
||||
|
||||
- name: Run chart-testing (install) with enhanced monitoring
|
||||
timeout-minutes: 25
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Starting chart installation with monitoring ==="
|
||||
|
||||
# Function to monitor cluster state
|
||||
monitor_cluster() {
|
||||
while true; do
|
||||
echo "=== Cluster Status Check at $(date) ==="
|
||||
# Only show non-running pods to reduce noise
|
||||
NON_RUNNING_PODS=$(kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded --no-headers 2>/dev/null | wc -l)
|
||||
if [ "$NON_RUNNING_PODS" -gt 0 ]; then
|
||||
echo "Non-running pods:"
|
||||
kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded
|
||||
else
|
||||
echo "All pods running successfully"
|
||||
fi
|
||||
# Only show recent events if there are issues
|
||||
RECENT_EVENTS=$(kubectl get events --sort-by=.lastTimestamp --all-namespaces --field-selector=type!=Normal 2>/dev/null | tail -5)
|
||||
if [ -n "$RECENT_EVENTS" ]; then
|
||||
echo "Recent warnings/errors:"
|
||||
echo "$RECENT_EVENTS"
|
||||
fi
|
||||
sleep 60
|
||||
done
|
||||
}
|
||||
|
||||
# Start monitoring in background
|
||||
monitor_cluster &
|
||||
MONITOR_PID=$!
|
||||
|
||||
# Set up cleanup
|
||||
cleanup() {
|
||||
echo "=== Cleaning up monitoring process ==="
|
||||
kill $MONITOR_PID 2>/dev/null || true
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -20
|
||||
}
|
||||
|
||||
# Trap cleanup on exit
|
||||
trap cleanup EXIT
|
||||
|
||||
# Run the actual installation with detailed logging
|
||||
echo "=== Starting ct install ==="
|
||||
ct install --all \
|
||||
--helm-extra-set-args="\
|
||||
--set=nginx.enabled=false \
|
||||
--set=minio.enabled=false \
|
||||
--set=vespa.enabled=false \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.primary.persistence.enabled=false \
|
||||
--set=redis.enabled=true \
|
||||
--set=webserver.replicaCount=1 \
|
||||
--set=api.replicaCount=0 \
|
||||
--set=inferenceCapability.replicaCount=0 \
|
||||
--set=indexCapability.replicaCount=0 \
|
||||
--set=celery_beat.replicaCount=0 \
|
||||
--set=celery_worker_heavy.replicaCount=0 \
|
||||
--set=celery_worker_docfetching.replicaCount=0 \
|
||||
--set=celery_worker_docprocessing.replicaCount=0 \
|
||||
--set=celery_worker_light.replicaCount=0 \
|
||||
--set=celery_worker_monitoring.replicaCount=0 \
|
||||
--set=celery_worker_primary.replicaCount=0 \
|
||||
--set=celery_worker_user_files_indexing.replicaCount=0" \
|
||||
--helm-extra-args="--timeout 900s --debug" \
|
||||
--debug --config ct.yaml
|
||||
|
||||
echo "=== Installation completed successfully ==="
|
||||
kubectl get pods --all-namespaces
|
||||
|
||||
- name: Post-install verification
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Post-install verification ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get services --all-namespaces
|
||||
# Only show issues if they exist
|
||||
kubectl describe pods --all-namespaces | grep -A 5 -B 2 "Failed\|Error\|Warning" || echo "No pod issues found"
|
||||
|
||||
- name: Cleanup on failure
|
||||
if: failure() && steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Cleanup on failure ==="
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -10
|
||||
|
||||
echo "=== Pod descriptions for debugging ==="
|
||||
kubectl describe pods --all-namespaces | grep -A 10 -B 3 "Failed\|Error\|Warning\|Pending" || echo "No problematic pods found"
|
||||
|
||||
echo "=== Recent logs for debugging ==="
|
||||
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
|
||||
|
||||
echo "=== Helm releases ==="
|
||||
helm list --all-namespaces
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
548
.github/workflows/pr-integration-tests.yml
vendored
548
.github/workflows/pr-integration-tests.yml
vendored
@@ -11,6 +11,12 @@ on:
|
||||
- "release/**"
|
||||
|
||||
env:
|
||||
# Private Registry Configuration
|
||||
PRIVATE_REGISTRY: experimental-registry.blacksmith.sh:5000
|
||||
PRIVATE_REGISTRY_USERNAME: ${{ secrets.PRIVATE_REGISTRY_USERNAME }}
|
||||
PRIVATE_REGISTRY_PASSWORD: ${{ secrets.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
@@ -23,18 +29,38 @@ env:
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
PLATFORM_PAIR: linux-amd64
|
||||
|
||||
jobs:
|
||||
integration-tests:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=32cpu-linux-x64,
|
||||
disk=large,
|
||||
"run-id=${{ github.run_id }}",
|
||||
]
|
||||
discover-test-dirs:
|
||||
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
|
||||
outputs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Discover test directories
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
|
||||
all_dirs=""
|
||||
for dir in $tests_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"tests/$dir\",\"name\":\"tests-$dir\"},"
|
||||
done
|
||||
for dir in $connector_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"connector_job_tests/$dir\",\"name\":\"connector-$dir\"},"
|
||||
done
|
||||
|
||||
# Remove trailing comma and wrap in array
|
||||
all_dirs="[${all_dirs%,}]"
|
||||
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
prepare-build:
|
||||
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
@@ -47,12 +73,12 @@ jobs:
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/ee.txt
|
||||
- run: |
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/ee.txt
|
||||
|
||||
- name: Generate OpenAPI schema
|
||||
working-directory: ./backend
|
||||
@@ -74,130 +100,151 @@ jobs:
|
||||
--skip-validate-spec \
|
||||
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Upload OpenAPI artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: openapi-artifacts
|
||||
path: backend/generated/
|
||||
|
||||
build-backend-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
|
||||
push: true
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
|
||||
build-integration-image:
|
||||
needs: prepare-build
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Download OpenAPI artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: openapi-artifacts
|
||||
path: backend/generated/
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push integration test Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
|
||||
push: true
|
||||
|
||||
integration-tests:
|
||||
needs:
|
||||
[
|
||||
discover-test-dirs,
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# We don't need to build the Web Docker image since it's not yet used
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# successfully.
|
||||
- name: Pull Web Docker image
|
||||
- name: Pull Docker images
|
||||
run: |
|
||||
docker pull onyxdotapp/onyx-web-server:latest
|
||||
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
|
||||
# Pull all images from registry in parallel
|
||||
echo "Pulling Docker images in parallel..."
|
||||
# Pull images from private registry
|
||||
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
|
||||
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
|
||||
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
# Wait for all background jobs to complete
|
||||
wait
|
||||
echo "All Docker images pulled successfully"
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-integration:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
# Start containers for multi-tenant tests
|
||||
- name: Start Docker containers for multi-tenant tests
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
MULTI_TENANT=true \
|
||||
AUTH_TYPE=cloud \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
DEV_MODE=true \
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up -d
|
||||
id: start_docker_multi_tenant
|
||||
|
||||
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
|
||||
- name: Run Multi-Tenant Integration Tests
|
||||
run: |
|
||||
echo "Waiting for 3 minutes to ensure API server is ready..."
|
||||
sleep 180
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AUTH_TYPE=cloud \
|
||||
-e MULTI_TENANT=true \
|
||||
-e REQUIRE_EMAIL_VERIFICATION=false \
|
||||
-e DISABLE_TELEMETRY=true \
|
||||
-e IMAGE_TAG=test \
|
||||
-e DEV_MODE=true \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/multitenant_tests
|
||||
continue-on-error: true
|
||||
id: run_multitenant_tests
|
||||
|
||||
- name: Check multi-tenant test results
|
||||
run: |
|
||||
if [ ${{ steps.run_multitenant_tests.outcome }} == 'failure' ]; then
|
||||
echo "Multi-tenant integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All multi-tenant integration tests passed successfully."
|
||||
fi
|
||||
|
||||
- name: Stop multi-tenant Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
|
||||
# Re-tag to remove registry prefix for docker-compose
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }} onyxdotapp/onyx-backend:test
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }} onyxdotapp/onyx-integration:test
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
# NOTE: don't need web server for integration tests
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
@@ -210,7 +257,16 @@ jobs:
|
||||
IMAGE_TAG=test \
|
||||
INTEGRATION_TESTS_MODE=true \
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up \
|
||||
relational_db \
|
||||
index \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
@@ -253,52 +309,44 @@ jobs:
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
|
||||
- name: Run Standard Integration Tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/tests \
|
||||
/app/tests/integration/connector_job_tests
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
- name: Check test results
|
||||
run: |
|
||||
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
|
||||
echo "Integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/${{ matrix.test-dir.path }}
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Always gather logs BEFORE "down":
|
||||
@@ -318,7 +366,7 @@ jobs:
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
@@ -327,3 +375,157 @@ jobs:
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
|
||||
|
||||
|
||||
multitenant-tests:
|
||||
needs:
|
||||
[
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Pull Docker images
|
||||
run: |
|
||||
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
|
||||
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
|
||||
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
|
||||
wait
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }} onyxdotapp/onyx-backend:test
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }} onyxdotapp/onyx-integration:test
|
||||
|
||||
- name: Start Docker containers for multi-tenant tests
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
MULTI_TENANT=true \
|
||||
AUTH_TYPE=cloud \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
DEV_MODE=true \
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up \
|
||||
relational_db \
|
||||
index \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker_multi_tenant
|
||||
|
||||
- name: Wait for service to be ready (multi-tenant)
|
||||
run: |
|
||||
echo "Starting wait-for-service script for multi-tenant..."
|
||||
docker logs -f onyx-stack-api_server-1 &
|
||||
start_time=$(date +%s)
|
||||
timeout=300
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in 5 minutes."
|
||||
exit 1
|
||||
fi
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "Service is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error; retrying..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run Multi-Tenant Integration Tests
|
||||
run: |
|
||||
echo "Running multi-tenant integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AUTH_TYPE=cloud \
|
||||
-e MULTI_TENANT=true \
|
||||
-e SKIP_RESET=true \
|
||||
-e REQUIRE_EMAIL_VERIFICATION=false \
|
||||
-e DISABLE_TELEMETRY=true \
|
||||
-e IMAGE_TAG=test \
|
||||
-e DEV_MODE=true \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/multitenant_tests
|
||||
|
||||
- name: Dump API server logs (multi-tenant)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server_multitenant.log || true
|
||||
|
||||
- name: Dump all-container logs (multi-tenant)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose-multitenant.log || true
|
||||
|
||||
- name: Upload logs (multi-tenant)
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs-multitenant
|
||||
path: ${{ github.workspace }}/docker-compose-multitenant.log
|
||||
|
||||
- name: Stop multi-tenant Docker containers
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
|
||||
|
||||
required:
|
||||
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
|
||||
needs: [integration-tests, multitenant-tests]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const needs = ${{ toJSON(needs) }};
|
||||
const failed = Object.values(needs).some(n => n.result !== 'success');
|
||||
if (failed) {
|
||||
core.setFailed('One or more upstream jobs failed or were cancelled.');
|
||||
} else {
|
||||
core.notice('All required jobs succeeded.');
|
||||
}
|
||||
|
||||
359
.github/workflows/pr-mit-integration-tests.yml
vendored
359
.github/workflows/pr-mit-integration-tests.yml
vendored
@@ -5,12 +5,15 @@ concurrency:
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
types: [checks_requested]
|
||||
|
||||
env:
|
||||
# Private Registry Configuration
|
||||
PRIVATE_REGISTRY: experimental-registry.blacksmith.sh:5000
|
||||
PRIVATE_REGISTRY_USERNAME: ${{ secrets.PRIVATE_REGISTRY_USERNAME }}
|
||||
PRIVATE_REGISTRY_PASSWORD: ${{ secrets.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
@@ -23,21 +26,42 @@ env:
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
PLATFORM_PAIR: linux-amd64
|
||||
|
||||
jobs:
|
||||
integration-tests-mit:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=32cpu-linux-x64,
|
||||
disk=large,
|
||||
"run-id=${{ github.run_id }}",
|
||||
]
|
||||
discover-test-dirs:
|
||||
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
|
||||
outputs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
|
||||
- name: Discover test directories
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
|
||||
all_dirs=""
|
||||
for dir in $tests_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"tests/$dir\",\"name\":\"tests-$dir\"},"
|
||||
done
|
||||
for dir in $connector_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"connector_job_tests/$dir\",\"name\":\"connector-$dir\"},"
|
||||
done
|
||||
|
||||
# Remove trailing comma and wrap in array
|
||||
all_dirs="[${all_dirs%,}]"
|
||||
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
prepare-build:
|
||||
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
@@ -46,7 +70,9 @@ jobs:
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
- run: |
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
@@ -70,71 +96,153 @@ jobs:
|
||||
--package-name onyx_openapi_client \
|
||||
--skip-validate-spec \
|
||||
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Upload OpenAPI artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: openapi-artifacts
|
||||
path: backend/generated/
|
||||
|
||||
build-backend-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
|
||||
push: true
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
|
||||
build-integration-image:
|
||||
needs: prepare-build
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
- name: Download OpenAPI artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: openapi-artifacts
|
||||
path: backend/generated/
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push integration test Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
|
||||
push: true
|
||||
|
||||
integration-tests-mit:
|
||||
needs:
|
||||
[
|
||||
discover-test-dirs,
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
# See https://docs.blacksmith.sh/blacksmith-runners/overview
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Private Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.PRIVATE_REGISTRY }}
|
||||
username: ${{ env.PRIVATE_REGISTRY_USERNAME }}
|
||||
password: ${{ env.PRIVATE_REGISTRY_PASSWORD }}
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# We don't need to build the Web Docker image since it's not yet used
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# successfully.
|
||||
- name: Pull Web Docker image
|
||||
- name: Pull Docker images
|
||||
run: |
|
||||
docker pull onyxdotapp/onyx-web-server:latest
|
||||
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
|
||||
# Pull all images from registry in parallel
|
||||
echo "Pulling Docker images in parallel..."
|
||||
# Pull images from private registry
|
||||
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
|
||||
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
|
||||
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
# Wait for all background jobs to complete
|
||||
wait
|
||||
echo "All Docker images pulled successfully"
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/backend-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/model-server-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-integration:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/mit-integration-tests/integration-${{ env.PLATFORM_PAIR }}/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
# Re-tag to remove registry prefix for docker-compose
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }} onyxdotapp/onyx-backend:test
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
|
||||
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }} onyxdotapp/onyx-integration:test
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
# NOTE: don't need web server for integration tests
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
@@ -145,7 +253,16 @@ jobs:
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
INTEGRATION_TESTS_MODE=true \
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up \
|
||||
relational_db \
|
||||
index \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
@@ -189,51 +306,44 @@ jobs:
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
|
||||
- name: Run Standard Integration Tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/tests \
|
||||
/app/tests/integration/connector_job_tests
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
- name: Check test results
|
||||
run: |
|
||||
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
|
||||
echo "Integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/${{ matrix.test-dir.path }}
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Always gather logs BEFORE "down":
|
||||
@@ -253,7 +363,7 @@ jobs:
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
@@ -262,3 +372,20 @@ jobs:
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
|
||||
|
||||
|
||||
required:
|
||||
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
|
||||
needs: [integration-tests-mit]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const needs = ${{ toJSON(needs) }};
|
||||
const failed = Object.values(needs).some(n => n.result !== 'success');
|
||||
if (failed) {
|
||||
core.setFailed('One or more upstream jobs failed or were cancelled.');
|
||||
} else {
|
||||
core.notice('All required jobs succeeded.');
|
||||
}
|
||||
|
||||
234
.github/workflows/pr-playwright-tests.yml
vendored
234
.github/workflows/pr-playwright-tests.yml
vendored
@@ -6,44 +6,165 @@ concurrency:
|
||||
on: push
|
||||
|
||||
env:
|
||||
# AWS ECR Configuration
|
||||
AWS_REGION: ${{ secrets.AWS_REGION || 'us-west-2' }}
|
||||
ECR_REGISTRY: ${{ secrets.ECR_REGISTRY }}
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID_ECR }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_ECR }}
|
||||
BUILDX_NO_DEFAULT_ATTESTATIONS: 1
|
||||
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
|
||||
|
||||
# for federated slack tests
|
||||
SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }}
|
||||
SLACK_CLIENT_SECRET: ${{ secrets.SLACK_CLIENT_SECRET }}
|
||||
|
||||
MOCK_LLM_RESPONSE: true
|
||||
PYTEST_PLAYWRIGHT_SKIP_INITIAL_RESET: true
|
||||
|
||||
jobs:
|
||||
playwright-tests:
|
||||
name: Playwright Tests
|
||||
build-web-image:
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=32cpu-linux-x64,
|
||||
disk=large,
|
||||
"run-id=${{ github.run_id }}",
|
||||
]
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-access-key-id: ${{ env.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ env.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
uses: aws-actions/amazon-ecr-login@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Web Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.ECR_REGISTRY }}/integration-test-onyx-web-server:playwright-test-${{ github.run_id }}
|
||||
provenance: false
|
||||
sbom: false
|
||||
push: true
|
||||
|
||||
build-backend-image:
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-access-key-id: ${{ env.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ env.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
uses: aws-actions/amazon-ecr-login@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.ECR_REGISTRY }}/integration-test-onyx-backend:playwright-test-${{ github.run_id }}
|
||||
provenance: false
|
||||
sbom: false
|
||||
push: true
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-access-key-id: ${{ env.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ env.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
uses: aws-actions/amazon-ecr-login@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/arm64
|
||||
tags: ${{ env.ECR_REGISTRY }}/integration-test-onyx-model-server:playwright-test-${{ github.run_id }}
|
||||
provenance: false
|
||||
sbom: false
|
||||
push: true
|
||||
|
||||
playwright-tests:
|
||||
needs: [build-web-image, build-backend-image, build-model-server-image]
|
||||
name: Playwright Tests
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
aws-access-key-id: ${{ env.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ env.AWS_SECRET_ACCESS_KEY }}
|
||||
aws-region: ${{ env.AWS_REGION }}
|
||||
|
||||
- name: Login to Amazon ECR
|
||||
id: login-ecr
|
||||
uses: aws-actions/amazon-ecr-login@v2
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Pull Docker images
|
||||
run: |
|
||||
# Pull all images from ECR in parallel
|
||||
echo "Pulling Docker images in parallel..."
|
||||
(docker pull ${{ env.ECR_REGISTRY }}/integration-test-onyx-web-server:playwright-test-${{ github.run_id }}) &
|
||||
(docker pull ${{ env.ECR_REGISTRY }}/integration-test-onyx-backend:playwright-test-${{ github.run_id }}) &
|
||||
(docker pull ${{ env.ECR_REGISTRY }}/integration-test-onyx-model-server:playwright-test-${{ github.run_id }}) &
|
||||
|
||||
# Wait for all background jobs to complete
|
||||
wait
|
||||
echo "All Docker images pulled successfully"
|
||||
|
||||
# Re-tag with expected names for docker-compose
|
||||
docker tag ${{ env.ECR_REGISTRY }}/integration-test-onyx-web-server:playwright-test-${{ github.run_id }} onyxdotapp/onyx-web-server:test
|
||||
docker tag ${{ env.ECR_REGISTRY }}/integration-test-onyx-backend:playwright-test-${{ github.run_id }} onyxdotapp/onyx-backend:test
|
||||
docker tag ${{ env.ECR_REGISTRY }}/integration-test-onyx-model-server:playwright-test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@v4
|
||||
@@ -58,68 +179,13 @@ jobs:
|
||||
working-directory: ./web
|
||||
run: npx playwright install --with-deps
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
|
||||
- name: Build Web Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./web
|
||||
file: ./web/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-web-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
GEN_AI_API_KEY=${{ secrets.OPENAI_API_KEY }} \
|
||||
GEN_AI_API_KEY=${{ env.OPENAI_API_KEY }} \
|
||||
EXA_API_KEY=${{ env.EXA_API_KEY }} \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
@@ -160,12 +226,6 @@ jobs:
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run pytest playwright test init
|
||||
working-directory: ./backend
|
||||
env:
|
||||
PYTEST_IGNORE_SKIP: true
|
||||
run: pytest -s tests/integration/tests/playwright/test_playwright.py
|
||||
|
||||
- name: Run Playwright tests
|
||||
working-directory: ./web
|
||||
run: npx playwright test
|
||||
|
||||
8
.vscode/env_template.txt
vendored
8
.vscode/env_template.txt
vendored
@@ -69,4 +69,10 @@ S3_AWS_ACCESS_KEY_ID=minioadmin
|
||||
S3_AWS_SECRET_ACCESS_KEY=minioadmin
|
||||
|
||||
# Show extra/uncommon connectors
|
||||
SHOW_EXTRA_CONNECTORS=True
|
||||
SHOW_EXTRA_CONNECTORS=True
|
||||
|
||||
# Local langsmith tracing
|
||||
LANGSMITH_TRACING="true"
|
||||
LANGSMITH_ENDPOINT="https://api.smith.langchain.com"
|
||||
LANGSMITH_API_KEY=<REPLACE_THIS>
|
||||
LANGSMITH_PROJECT=<REPLACE_THIS>
|
||||
295
AGENTS.md
Normal file
295
AGENTS.md
Normal file
@@ -0,0 +1,295 @@
|
||||
# AGENTS.md
|
||||
|
||||
This file provides guidance to Codex when working with code in this repository.
|
||||
|
||||
## KEY NOTES
|
||||
|
||||
- If you run into any missing python dependency errors, try running your command with `workon onyx &&` in front
|
||||
to assume the python venv.
|
||||
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
|
||||
`a`. The app can be accessed at `http://localhost:3000`.
|
||||
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
|
||||
make sure we see logs coming out from the relevant service.
|
||||
- To connect to the Postgres database, use: `docker exec -it onyx-stack-relational_db-1 psql -U postgres -c "<SQL>"`
|
||||
- When making calls to the backend, always go through the frontend. E.g. make a call to `http://localhost:3000/api/persona` not `http://localhost:8080/api/persona`
|
||||
- Put ALL db operations under the `backend/onyx/db` / `backend/ee/onyx/db` directories. Don't run queries
|
||||
outside of those directories.
|
||||
|
||||
## Project Overview
|
||||
|
||||
**Onyx** (formerly Danswer) is an open-source Gen-AI and Enterprise Search platform that connects to company documents, apps, and people. It features a modular architecture with both Community Edition (MIT licensed) and Enterprise Edition offerings.
|
||||
|
||||
|
||||
### Background Workers (Celery)
|
||||
|
||||
Onyx uses Celery for asynchronous task processing with multiple specialized workers:
|
||||
|
||||
#### Worker Types
|
||||
|
||||
1. **Primary Worker** (`celery_app.py`)
|
||||
- Coordinates core background tasks and system-wide operations
|
||||
- Handles connector management, document sync, pruning, and periodic checks
|
||||
- Runs with 4 threads concurrency
|
||||
- Tasks: connector deletion, vespa sync, pruning, LLM model updates, user file sync
|
||||
|
||||
2. **Docfetching Worker** (`docfetching`)
|
||||
- Fetches documents from external data sources (connectors)
|
||||
- Spawns docprocessing tasks for each document batch
|
||||
- Implements watchdog monitoring for stuck connectors
|
||||
- Configurable concurrency (default from env)
|
||||
|
||||
3. **Docprocessing Worker** (`docprocessing`)
|
||||
- Processes fetched documents through the indexing pipeline:
|
||||
- Upserts documents to PostgreSQL
|
||||
- Chunks documents and adds contextual information
|
||||
- Embeds chunks via model server
|
||||
- Writes chunks to Vespa vector database
|
||||
- Updates document metadata
|
||||
- Configurable concurrency (default from env)
|
||||
|
||||
4. **Light Worker** (`light`)
|
||||
- Handles lightweight, fast operations
|
||||
- Tasks: vespa operations, document permissions sync, external group sync
|
||||
- Higher concurrency for quick tasks
|
||||
|
||||
5. **Heavy Worker** (`heavy`)
|
||||
- Handles resource-intensive operations
|
||||
- Primary task: document pruning operations
|
||||
- Runs with 4 threads concurrency
|
||||
|
||||
6. **KG Processing Worker** (`kg_processing`)
|
||||
- Handles Knowledge Graph processing and clustering
|
||||
- Builds relationships between documents
|
||||
- Runs clustering algorithms
|
||||
- Configurable concurrency
|
||||
|
||||
7. **Monitoring Worker** (`monitoring`)
|
||||
- System health monitoring and metrics collection
|
||||
- Monitors Celery queues, process memory, and system status
|
||||
- Single thread (monitoring doesn't need parallelism)
|
||||
- Cloud-specific monitoring tasks
|
||||
|
||||
8. **Beat Worker** (`beat`)
|
||||
- Celery's scheduler for periodic tasks
|
||||
- Uses DynamicTenantScheduler for multi-tenant support
|
||||
- Schedules tasks like:
|
||||
- Indexing checks (every 15 seconds)
|
||||
- Connector deletion checks (every 20 seconds)
|
||||
- Vespa sync checks (every 20 seconds)
|
||||
- Pruning checks (every 20 seconds)
|
||||
- KG processing (every 60 seconds)
|
||||
- Monitoring tasks (every 5 minutes)
|
||||
- Cleanup tasks (hourly)
|
||||
|
||||
#### Key Features
|
||||
|
||||
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
|
||||
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
|
||||
middleware layer that automatically finds the appropriate tenant ID when sending tasks
|
||||
via Celery Beat.
|
||||
- **Task Prioritization**: High, Medium, Low priority queues
|
||||
- **Monitoring**: Built-in heartbeat and liveness checking
|
||||
- **Failure Handling**: Automatic retry and failure recovery mechanisms
|
||||
- **Redis Coordination**: Inter-process communication via Redis
|
||||
- **PostgreSQL State**: Task state and metadata stored in PostgreSQL
|
||||
|
||||
|
||||
#### Important Notes
|
||||
|
||||
**Defining Tasks**:
|
||||
- Always use `@shared_task` rather than `@celery_app`
|
||||
- Put tasks under `background/celery/tasks/` or `ee/background/celery/tasks`
|
||||
|
||||
**Defining APIs**:
|
||||
When creating new FastAPI APIs, do NOT use the `response_model` field. Instead, just type the
|
||||
function.
|
||||
|
||||
**Testing Updates**:
|
||||
If you make any updates to a celery worker and you want to test these changes, you will need
|
||||
to ask me to restart the celery worker. There is no auto-restart on code-change mechanism.
|
||||
|
||||
### Code Quality
|
||||
```bash
|
||||
# Install and run pre-commit hooks
|
||||
pre-commit install
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
NOTE: Always make sure everything is strictly typed (both in Python and TypeScript).
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Technology Stack
|
||||
- **Backend**: Python 3.11, FastAPI, SQLAlchemy, Alembic, Celery
|
||||
- **Frontend**: Next.js 15+, React 18, TypeScript, Tailwind CSS
|
||||
- **Database**: PostgreSQL with Redis caching
|
||||
- **Search**: Vespa vector database
|
||||
- **Auth**: OAuth2, SAML, multi-provider support
|
||||
- **AI/ML**: LangChain, LiteLLM, multiple embedding models
|
||||
|
||||
### Directory Structure
|
||||
|
||||
```
|
||||
backend/
|
||||
├── onyx/
|
||||
│ ├── auth/ # Authentication & authorization
|
||||
│ ├── chat/ # Chat functionality & LLM interactions
|
||||
│ ├── connectors/ # Data source connectors
|
||||
│ ├── db/ # Database models & operations
|
||||
│ ├── document_index/ # Vespa integration
|
||||
│ ├── federated_connectors/ # External search connectors
|
||||
│ ├── llm/ # LLM provider integrations
|
||||
│ └── server/ # API endpoints & routers
|
||||
├── ee/ # Enterprise Edition features
|
||||
├── alembic/ # Database migrations
|
||||
└── tests/ # Test suites
|
||||
|
||||
web/
|
||||
├── src/app/ # Next.js app router pages
|
||||
├── src/components/ # Reusable React components
|
||||
└── src/lib/ # Utilities & business logic
|
||||
```
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
### Running Migrations
|
||||
```bash
|
||||
# Standard migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Multi-tenant (Enterprise)
|
||||
alembic -n schema_private upgrade head
|
||||
```
|
||||
|
||||
### Creating Migrations
|
||||
```bash
|
||||
# Auto-generate migration
|
||||
alembic revision --autogenerate -m "description"
|
||||
|
||||
# Multi-tenant migration
|
||||
alembic -n schema_private revision --autogenerate -m "description"
|
||||
```
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
There are 4 main types of tests within Onyx:
|
||||
|
||||
### Unit Tests
|
||||
These should not assume any Onyx/external services are available to be called.
|
||||
Interactions with the outside world should be mocked using `unittest.mock`. Generally, only
|
||||
write these for complex, isolated modules e.g. `citation_processing.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest -xv backend/tests/unit
|
||||
```
|
||||
|
||||
### External Dependency Unit Tests
|
||||
These tests assume that all external dependencies of Onyx are available and callable (e.g. Postgres, Redis,
|
||||
MinIO/S3, Vespa are running + OpenAI can be called + any request to the internet is fine + etc.).
|
||||
|
||||
However, the actual Onyx containers are not running and with these tests we call the function to test directly.
|
||||
We can also mock components/calls at will.
|
||||
|
||||
The goal of these tests is to minimize mocking while giving some flexibility to mock things that are flaky,
|
||||
need strictly controlled behavior, or need to have their internal behavior validated (e.g. verify a function is called
|
||||
with certain args, something that would be impossible with proper integration tests).
|
||||
|
||||
A great example of this type of test is `backend/tests/external_dependency_unit/connectors/confluence/test_confluence_group_sync.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/external_dependency_unit
|
||||
```
|
||||
|
||||
### Integration Tests
|
||||
Standard integration tests. Every test in `backend/tests/integration` runs against a real Onyx deployment. We cannot
|
||||
mock anything in these tests. Prefer writing integration tests (or External Dependency Unit Tests if mocking/internal
|
||||
verification is necessary) over any other type of test.
|
||||
|
||||
Tests are parallelized at a directory level.
|
||||
|
||||
When writing integration tests, make sure to check the root `conftest.py` for useful fixtures + the `backend/tests/integration/common_utils` directory for utilities. Prefer (if one exists) calling the appropriate Manager
|
||||
class in the utils over directly calling the APIs with a library like `requests`. Prefer using fixtures rather than
|
||||
calling the utilities directly (e.g. do NOT create admin users with
|
||||
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).
|
||||
|
||||
A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/integration
|
||||
```
|
||||
|
||||
### Playwright (E2E) Tests
|
||||
These tests are an even more complete version of the Integration Tests mentioned above. They have all services of Onyx
|
||||
running, *including* the Web Server.
|
||||
|
||||
Use these tests for anything that requires significant frontend <-> backend coordination.
|
||||
|
||||
Tests are located at `web/tests/e2e`. Tests are written in TypeScript.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
npx playwright test <TEST_NAME>
|
||||
```
|
||||
|
||||
|
||||
## Logs
|
||||
|
||||
When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access
|
||||
to logs via the `backend/log/<service_name>_debug.log` file. All Onyx services (api_server, web_server, celery_X)
|
||||
will be tailing their logs to this file.
|
||||
|
||||
|
||||
## Security Considerations
|
||||
|
||||
- Never commit API keys or secrets to repository
|
||||
- Use encrypted credential storage for connector credentials
|
||||
- Follow RBAC patterns for new features
|
||||
- Implement proper input validation with Pydantic models
|
||||
- Use parameterized queries to prevent SQL injection
|
||||
|
||||
## AI/LLM Integration
|
||||
|
||||
- Multiple LLM providers supported via LiteLLM
|
||||
- Configurable models per feature (chat, search, embeddings)
|
||||
- Streaming support for real-time responses
|
||||
- Token management and rate limiting
|
||||
- Custom prompts and agent actions
|
||||
|
||||
## UI/UX Patterns
|
||||
|
||||
- Tailwind CSS with design system in `web/src/components/ui/`
|
||||
- Radix UI and Headless UI for accessible components
|
||||
- SWR for data fetching and caching
|
||||
- Form validation with react-hook-form
|
||||
- Error handling with popup notifications
|
||||
|
||||
## Creating a Plan
|
||||
When creating a plan in the `plans` directory, make sure to include at least these elements:
|
||||
|
||||
**Issues to Address**
|
||||
What the change is meant to do.
|
||||
|
||||
**Important Notes**
|
||||
Things you come across in your research that are important to the implementation.
|
||||
|
||||
**Implementation strategy**
|
||||
How you are going to make the changes happen. High level approach.
|
||||
|
||||
**Tests**
|
||||
What unit (use rarely), external dependency unit, integration, and playwright tests you plan to write to
|
||||
verify the correct behavior. Don't overtest. Usually, a given change only needs one type of test.
|
||||
|
||||
Do NOT include these: *Timeline*, *Rollback plan*
|
||||
|
||||
This is a minimal list - feel free to include more. Do NOT write code as part of your plan.
|
||||
Keep it high level. You can reference certain files or functions though.
|
||||
|
||||
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
|
||||
295
CLAUDE.md
Normal file
295
CLAUDE.md
Normal file
@@ -0,0 +1,295 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## KEY NOTES
|
||||
|
||||
- If you run into any missing python dependency errors, try running your command with `workon onyx &&` in front
|
||||
to assume the python venv.
|
||||
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
|
||||
`a`. The app can be accessed at `http://localhost:3000`.
|
||||
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
|
||||
make sure we see logs coming out from the relevant service.
|
||||
- To connect to the Postgres database, use: `docker exec -it onyx-stack-relational_db-1 psql -U postgres -c "<SQL>"`
|
||||
- When making calls to the backend, always go through the frontend. E.g. make a call to `http://localhost:3000/api/persona` not `http://localhost:8080/api/persona`
|
||||
- Put ALL db operations under the `backend/onyx/db` / `backend/ee/onyx/db` directories. Don't run queries
|
||||
outside of those directories.
|
||||
|
||||
## Project Overview
|
||||
|
||||
**Onyx** (formerly Danswer) is an open-source Gen-AI and Enterprise Search platform that connects to company documents, apps, and people. It features a modular architecture with both Community Edition (MIT licensed) and Enterprise Edition offerings.
|
||||
|
||||
|
||||
### Background Workers (Celery)
|
||||
|
||||
Onyx uses Celery for asynchronous task processing with multiple specialized workers:
|
||||
|
||||
#### Worker Types
|
||||
|
||||
1. **Primary Worker** (`celery_app.py`)
|
||||
- Coordinates core background tasks and system-wide operations
|
||||
- Handles connector management, document sync, pruning, and periodic checks
|
||||
- Runs with 4 threads concurrency
|
||||
- Tasks: connector deletion, vespa sync, pruning, LLM model updates, user file sync
|
||||
|
||||
2. **Docfetching Worker** (`docfetching`)
|
||||
- Fetches documents from external data sources (connectors)
|
||||
- Spawns docprocessing tasks for each document batch
|
||||
- Implements watchdog monitoring for stuck connectors
|
||||
- Configurable concurrency (default from env)
|
||||
|
||||
3. **Docprocessing Worker** (`docprocessing`)
|
||||
- Processes fetched documents through the indexing pipeline:
|
||||
- Upserts documents to PostgreSQL
|
||||
- Chunks documents and adds contextual information
|
||||
- Embeds chunks via model server
|
||||
- Writes chunks to Vespa vector database
|
||||
- Updates document metadata
|
||||
- Configurable concurrency (default from env)
|
||||
|
||||
4. **Light Worker** (`light`)
|
||||
- Handles lightweight, fast operations
|
||||
- Tasks: vespa operations, document permissions sync, external group sync
|
||||
- Higher concurrency for quick tasks
|
||||
|
||||
5. **Heavy Worker** (`heavy`)
|
||||
- Handles resource-intensive operations
|
||||
- Primary task: document pruning operations
|
||||
- Runs with 4 threads concurrency
|
||||
|
||||
6. **KG Processing Worker** (`kg_processing`)
|
||||
- Handles Knowledge Graph processing and clustering
|
||||
- Builds relationships between documents
|
||||
- Runs clustering algorithms
|
||||
- Configurable concurrency
|
||||
|
||||
7. **Monitoring Worker** (`monitoring`)
|
||||
- System health monitoring and metrics collection
|
||||
- Monitors Celery queues, process memory, and system status
|
||||
- Single thread (monitoring doesn't need parallelism)
|
||||
- Cloud-specific monitoring tasks
|
||||
|
||||
8. **Beat Worker** (`beat`)
|
||||
- Celery's scheduler for periodic tasks
|
||||
- Uses DynamicTenantScheduler for multi-tenant support
|
||||
- Schedules tasks like:
|
||||
- Indexing checks (every 15 seconds)
|
||||
- Connector deletion checks (every 20 seconds)
|
||||
- Vespa sync checks (every 20 seconds)
|
||||
- Pruning checks (every 20 seconds)
|
||||
- KG processing (every 60 seconds)
|
||||
- Monitoring tasks (every 5 minutes)
|
||||
- Cleanup tasks (hourly)
|
||||
|
||||
#### Key Features
|
||||
|
||||
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
|
||||
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
|
||||
middleware layer that automatically finds the appropriate tenant ID when sending tasks
|
||||
via Celery Beat.
|
||||
- **Task Prioritization**: High, Medium, Low priority queues
|
||||
- **Monitoring**: Built-in heartbeat and liveness checking
|
||||
- **Failure Handling**: Automatic retry and failure recovery mechanisms
|
||||
- **Redis Coordination**: Inter-process communication via Redis
|
||||
- **PostgreSQL State**: Task state and metadata stored in PostgreSQL
|
||||
|
||||
|
||||
#### Important Notes
|
||||
|
||||
**Defining Tasks**:
|
||||
- Always use `@shared_task` rather than `@celery_app`
|
||||
- Put tasks under `background/celery/tasks/` or `ee/background/celery/tasks`
|
||||
|
||||
**Defining APIs**:
|
||||
When creating new FastAPI APIs, do NOT use the `response_model` field. Instead, just type the
|
||||
function.
|
||||
|
||||
**Testing Updates**:
|
||||
If you make any updates to a celery worker and you want to test these changes, you will need
|
||||
to ask me to restart the celery worker. There is no auto-restart on code-change mechanism.
|
||||
|
||||
### Code Quality
|
||||
```bash
|
||||
# Install and run pre-commit hooks
|
||||
pre-commit install
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
NOTE: Always make sure everything is strictly typed (both in Python and TypeScript).
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Technology Stack
|
||||
- **Backend**: Python 3.11, FastAPI, SQLAlchemy, Alembic, Celery
|
||||
- **Frontend**: Next.js 15+, React 18, TypeScript, Tailwind CSS
|
||||
- **Database**: PostgreSQL with Redis caching
|
||||
- **Search**: Vespa vector database
|
||||
- **Auth**: OAuth2, SAML, multi-provider support
|
||||
- **AI/ML**: LangChain, LiteLLM, multiple embedding models
|
||||
|
||||
### Directory Structure
|
||||
|
||||
```
|
||||
backend/
|
||||
├── onyx/
|
||||
│ ├── auth/ # Authentication & authorization
|
||||
│ ├── chat/ # Chat functionality & LLM interactions
|
||||
│ ├── connectors/ # Data source connectors
|
||||
│ ├── db/ # Database models & operations
|
||||
│ ├── document_index/ # Vespa integration
|
||||
│ ├── federated_connectors/ # External search connectors
|
||||
│ ├── llm/ # LLM provider integrations
|
||||
│ └── server/ # API endpoints & routers
|
||||
├── ee/ # Enterprise Edition features
|
||||
├── alembic/ # Database migrations
|
||||
└── tests/ # Test suites
|
||||
|
||||
web/
|
||||
├── src/app/ # Next.js app router pages
|
||||
├── src/components/ # Reusable React components
|
||||
└── src/lib/ # Utilities & business logic
|
||||
```
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
### Running Migrations
|
||||
```bash
|
||||
# Standard migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Multi-tenant (Enterprise)
|
||||
alembic -n schema_private upgrade head
|
||||
```
|
||||
|
||||
### Creating Migrations
|
||||
```bash
|
||||
# Auto-generate migration
|
||||
alembic revision --autogenerate -m "description"
|
||||
|
||||
# Multi-tenant migration
|
||||
alembic -n schema_private revision --autogenerate -m "description"
|
||||
```
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
There are 4 main types of tests within Onyx:
|
||||
|
||||
### Unit Tests
|
||||
These should not assume any Onyx/external services are available to be called.
|
||||
Interactions with the outside world should be mocked using `unittest.mock`. Generally, only
|
||||
write these for complex, isolated modules e.g. `citation_processing.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest -xv backend/tests/unit
|
||||
```
|
||||
|
||||
### External Dependency Unit Tests
|
||||
These tests assume that all external dependencies of Onyx are available and callable (e.g. Postgres, Redis,
|
||||
MinIO/S3, Vespa are running + OpenAI can be called + any request to the internet is fine + etc.).
|
||||
|
||||
However, the actual Onyx containers are not running and with these tests we call the function to test directly.
|
||||
We can also mock components/calls at will.
|
||||
|
||||
The goal of these tests is to minimize mocking while giving some flexibility to mock things that are flaky,
|
||||
need strictly controlled behavior, or need to have their internal behavior validated (e.g. verify a function is called
|
||||
with certain args, something that would be impossible with proper integration tests).
|
||||
|
||||
A great example of this type of test is `backend/tests/external_dependency_unit/connectors/confluence/test_confluence_group_sync.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/external_dependency_unit
|
||||
```
|
||||
|
||||
### Integration Tests
|
||||
Standard integration tests. Every test in `backend/tests/integration` runs against a real Onyx deployment. We cannot
|
||||
mock anything in these tests. Prefer writing integration tests (or External Dependency Unit Tests if mocking/internal
|
||||
verification is necessary) over any other type of test.
|
||||
|
||||
Tests are parallelized at a directory level.
|
||||
|
||||
When writing integration tests, make sure to check the root `conftest.py` for useful fixtures + the `backend/tests/integration/common_utils` directory for utilities. Prefer (if one exists) calling the appropriate Manager
|
||||
class in the utils over directly calling the APIs with a library like `requests`. Prefer using fixtures rather than
|
||||
calling the utilities directly (e.g. do NOT create admin users with
|
||||
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).
|
||||
|
||||
A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/integration
|
||||
```
|
||||
|
||||
### Playwright (E2E) Tests
|
||||
These tests are an even more complete version of the Integration Tests mentioned above. They have all services of Onyx
|
||||
running, *including* the Web Server.
|
||||
|
||||
Use these tests for anything that requires significant frontend <-> backend coordination.
|
||||
|
||||
Tests are located at `web/tests/e2e`. Tests are written in TypeScript.
|
||||
|
||||
To run them:
|
||||
|
||||
```bash
|
||||
npx playwright test <TEST_NAME>
|
||||
```
|
||||
|
||||
|
||||
## Logs
|
||||
|
||||
When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access
|
||||
to logs via the `backend/log/<service_name>_debug.log` file. All Onyx services (api_server, web_server, celery_X)
|
||||
will be tailing their logs to this file.
|
||||
|
||||
|
||||
## Security Considerations
|
||||
|
||||
- Never commit API keys or secrets to repository
|
||||
- Use encrypted credential storage for connector credentials
|
||||
- Follow RBAC patterns for new features
|
||||
- Implement proper input validation with Pydantic models
|
||||
- Use parameterized queries to prevent SQL injection
|
||||
|
||||
## AI/LLM Integration
|
||||
|
||||
- Multiple LLM providers supported via LiteLLM
|
||||
- Configurable models per feature (chat, search, embeddings)
|
||||
- Streaming support for real-time responses
|
||||
- Token management and rate limiting
|
||||
- Custom prompts and agent actions
|
||||
|
||||
## UI/UX Patterns
|
||||
|
||||
- Tailwind CSS with design system in `web/src/components/ui/`
|
||||
- Radix UI and Headless UI for accessible components
|
||||
- SWR for data fetching and caching
|
||||
- Form validation with react-hook-form
|
||||
- Error handling with popup notifications
|
||||
|
||||
## Creating a Plan
|
||||
When creating a plan in the `plans` directory, make sure to include at least these elements:
|
||||
|
||||
**Issues to Address**
|
||||
What the change is meant to do.
|
||||
|
||||
**Important Notes**
|
||||
Things you come across in your research that are important to the implementation.
|
||||
|
||||
**Implementation strategy**
|
||||
How you are going to make the changes happen. High level approach.
|
||||
|
||||
**Tests**
|
||||
What unit (use rarely), external dependency unit, integration, and playwright tests you plan to write to
|
||||
verify the correct behavior. Don't overtest. Usually, a given change only needs one type of test.
|
||||
|
||||
Do NOT include these: *Timeline*, *Rollback plan*
|
||||
|
||||
This is a minimal list - feel free to include more. Do NOT write code as part of your plan.
|
||||
Keep it high level. You can reference certain files or functions though.
|
||||
|
||||
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
|
||||
@@ -5,7 +5,7 @@ This guide explains how to set up and use VSCode's debugging capabilities with t
|
||||
## Initial Setup
|
||||
|
||||
1. **Environment Setup**:
|
||||
- Copy `.vscode/.env.template` to `.vscode/.env`
|
||||
- Copy `.vscode/env_template.txt` to `.vscode/.env`
|
||||
- Fill in the necessary environment variables in `.vscode/.env`
|
||||
2. **launch.json**:
|
||||
- Copy `.vscode/launch.template.jsonc` to `.vscode/launch.json`
|
||||
@@ -17,10 +17,9 @@ Before starting, make sure the Docker Daemon is running.
|
||||
1. Open the Debug view in VSCode (Cmd+Shift+D on macOS)
|
||||
2. From the dropdown at the top, select "Clear and Restart External Volumes and Containers" and press the green play button
|
||||
3. From the dropdown at the top, select "Run All Onyx Services" and press the green play button
|
||||
4. CD into web, run "npm i" followed by npm run dev.
|
||||
5. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
|
||||
6. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
|
||||
7. Use the debug toolbar to step through code, inspect variables, etc.
|
||||
4. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
|
||||
5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
|
||||
6. Use the debug toolbar to step through code, inspect variables, etc.
|
||||
|
||||
## Features
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ https://private-user-images.githubusercontent.com/32520769/414509312-48392e83-95
|
||||
**To try it out for free and get started in seconds, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.
|
||||
|
||||
Onyx can also be run locally (even on a laptop) or deployed on a virtual machine with a single
|
||||
`docker compose` command. Check out our [docs](https://docs.onyx.app/quickstart) to learn more.
|
||||
`docker compose` command. Check out our [docs](https://docs.onyx.app/deployment/getting_started/quickstart) to learn more.
|
||||
|
||||
We also have built-in support for high-availability/scalable deployment on Kubernetes.
|
||||
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).
|
||||
@@ -97,7 +97,7 @@ Keep knowledge and access up to sync across 40+ connectors:
|
||||
- Websites
|
||||
- And more ...
|
||||
|
||||
See the full list [here](https://docs.onyx.app/connectors).
|
||||
See the full list [here](https://docs.onyx.app/admin/connectors/overview).
|
||||
|
||||
|
||||
## 📚 Licensing
|
||||
|
||||
@@ -12,7 +12,8 @@ ARG ONYX_VERSION=0.0.0-dev
|
||||
# DO_NOT_TRACK is used to disable telemetry for Unstructured
|
||||
ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true" \
|
||||
DO_NOT_TRACK="true"
|
||||
DO_NOT_TRACK="true" \
|
||||
PLAYWRIGHT_BROWSERS_PATH="/app/.cache/ms-playwright"
|
||||
|
||||
|
||||
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
|
||||
|
||||
@@ -23,6 +23,22 @@ RUN mkdir -p /app && \
|
||||
chmod 755 /var/log/onyx && \
|
||||
chown onyx:onyx /var/log/onyx
|
||||
|
||||
# --- add toolchain needed for Rust/Python builds (fastuuid) ---
|
||||
ENV RUSTUP_HOME=/usr/local/rustup \
|
||||
CARGO_HOME=/usr/local/cargo \
|
||||
PATH=/usr/local/cargo/bin:$PATH
|
||||
|
||||
RUN set -eux; \
|
||||
apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
pkg-config \
|
||||
curl \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
# Install latest stable Rust (supports Cargo.lock v4)
|
||||
&& curl -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal --default-toolchain stable \
|
||||
&& rustc --version && cargo --version
|
||||
|
||||
COPY ./requirements/model_server.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade \
|
||||
--retries 5 \
|
||||
|
||||
@@ -0,0 +1,380 @@
|
||||
"""merge_default_assistants_into_unified
|
||||
|
||||
Revision ID: 505c488f6662
|
||||
Revises: d09fc20a3c66
|
||||
Create Date: 2025-09-09 19:00:56.816626
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import NamedTuple
|
||||
from uuid import UUID
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "505c488f6662"
|
||||
down_revision = "d09fc20a3c66"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# Constants for the unified assistant
|
||||
UNIFIED_ASSISTANT_NAME = "Assistant"
|
||||
UNIFIED_ASSISTANT_DESCRIPTION = (
|
||||
"Your AI assistant with search, web browsing, and image generation capabilities."
|
||||
)
|
||||
UNIFIED_ASSISTANT_NUM_CHUNKS = 25
|
||||
UNIFIED_ASSISTANT_DISPLAY_PRIORITY = 0
|
||||
UNIFIED_ASSISTANT_LLM_FILTER_EXTRACTION = True
|
||||
UNIFIED_ASSISTANT_LLM_RELEVANCE_FILTER = False
|
||||
UNIFIED_ASSISTANT_RECENCY_BIAS = "AUTO" # NOTE: needs to be capitalized
|
||||
UNIFIED_ASSISTANT_CHUNKS_ABOVE = 0
|
||||
UNIFIED_ASSISTANT_CHUNKS_BELOW = 0
|
||||
UNIFIED_ASSISTANT_DATETIME_AWARE = True
|
||||
|
||||
# NOTE: tool specific prompts are handled on the fly and automatically injected
|
||||
# into the prompt before passing to the LLM.
|
||||
DEFAULT_SYSTEM_PROMPT = """
|
||||
You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the \
|
||||
user's intent, ask clarifying questions when needed, think step-by-step through complex problems, \
|
||||
provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always \
|
||||
prioritize being truthful, nuanced, insightful, and efficient.
|
||||
The current date is [[CURRENT_DATETIME]]
|
||||
|
||||
You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make \
|
||||
your responses more readable and engaging.
|
||||
You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, \
|
||||
symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline.
|
||||
For code you prefer to use Markdown and specify the language.
|
||||
You can use Markdown horizontal rules (---) to separate sections of your responses.
|
||||
You can use Markdown tables to format your responses for data, lists, and other structured information.
|
||||
""".strip()
|
||||
|
||||
|
||||
INSERT_DICT: dict[str, Any] = {
|
||||
"name": UNIFIED_ASSISTANT_NAME,
|
||||
"description": UNIFIED_ASSISTANT_DESCRIPTION,
|
||||
"system_prompt": DEFAULT_SYSTEM_PROMPT,
|
||||
"num_chunks": UNIFIED_ASSISTANT_NUM_CHUNKS,
|
||||
"display_priority": UNIFIED_ASSISTANT_DISPLAY_PRIORITY,
|
||||
"llm_filter_extraction": UNIFIED_ASSISTANT_LLM_FILTER_EXTRACTION,
|
||||
"llm_relevance_filter": UNIFIED_ASSISTANT_LLM_RELEVANCE_FILTER,
|
||||
"recency_bias": UNIFIED_ASSISTANT_RECENCY_BIAS,
|
||||
"chunks_above": UNIFIED_ASSISTANT_CHUNKS_ABOVE,
|
||||
"chunks_below": UNIFIED_ASSISTANT_CHUNKS_BELOW,
|
||||
"datetime_aware": UNIFIED_ASSISTANT_DATETIME_AWARE,
|
||||
}
|
||||
|
||||
GENERAL_ASSISTANT_ID = -1
|
||||
ART_ASSISTANT_ID = -3
|
||||
|
||||
|
||||
class UserRow(NamedTuple):
|
||||
"""Typed representation of user row from database query."""
|
||||
|
||||
id: UUID
|
||||
chosen_assistants: list[int] | None
|
||||
visible_assistants: list[int] | None
|
||||
hidden_assistants: list[int] | None
|
||||
pinned_assistants: list[int] | None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Fold every built-in assistant into one unified assistant (persona 0).

    Runs between an explicitly issued BEGIN and COMMIT and issues ROLLBACK
    before re-raising on any error. The work is done in five steps:

    1. Create persona 0, or overwrite the existing row, using ``INSERT_DICT``.
    2. Soft-delete every other persona flagged ``builtin_persona``.
    3. Replace persona 0's tool links with SearchTool, ImageGenerationTool
       and, when a row for it exists, WebSearchTool.
    4. Re-point chat sessions of the other built-in personas to persona 0.
    5. Remove the other built-in persona IDs from each user's chosen /
       visible / pinned lists and add them to the hidden list.

    Raises:
        ValueError: SearchTool or ImageGenerationTool is missing from ``tool``.
    """
    bind = op.get_bind()

    # NOTE(review): Alembic usually manages the migration transaction itself;
    # confirm that the explicit BEGIN/COMMIT below is intended.
    bind.execute(sa.text("BEGIN"))

    try:
        # Step 1: create or update the unified assistant (ID 0).
        existing_persona = bind.execute(
            sa.text("SELECT * FROM persona WHERE id = 0")
        ).fetchone()

        if existing_persona:
            # The existing Search assistant becomes the unified assistant.
            persona_sql = """
                UPDATE persona
                SET name = :name,
                description = :description,
                system_prompt = :system_prompt,
                num_chunks = :num_chunks,
                is_default_persona = true,
                is_visible = true,
                deleted = false,
                display_priority = :display_priority,
                llm_filter_extraction = :llm_filter_extraction,
                llm_relevance_filter = :llm_relevance_filter,
                recency_bias = :recency_bias,
                chunks_above = :chunks_above,
                chunks_below = :chunks_below,
                datetime_aware = :datetime_aware,
                starter_messages = null
                WHERE id = 0
                """
        else:
            # No persona 0 yet: insert the unified assistant under that ID.
            persona_sql = """
                INSERT INTO persona (
                id, name, description, system_prompt, num_chunks,
                is_default_persona, is_visible, deleted, display_priority,
                llm_filter_extraction, llm_relevance_filter, recency_bias,
                chunks_above, chunks_below, datetime_aware, starter_messages,
                builtin_persona
                ) VALUES (
                0, :name, :description, :system_prompt, :num_chunks,
                true, true, false, :display_priority, :llm_filter_extraction,
                :llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
                :datetime_aware, null, true
                )
                """
        bind.execute(sa.text(persona_sql), INSERT_DICT)

        # Step 2: mark ALL builtin assistants as deleted (except persona 0).
        bind.execute(
            sa.text(
                """
                UPDATE persona
                SET deleted = true, is_visible = false, is_default_persona = false
                WHERE builtin_persona = true AND id != 0
                """
            )
        )

        # Step 3: attach the built-in tools to the unified assistant.
        # SearchTool and ImageGenerationTool are mandatory.
        search_tool = bind.execute(
            sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
        ).fetchone()
        if not search_tool:
            raise ValueError(
                "SearchTool not found in database. Ensure tools migration has run first."
            )

        image_gen_tool = bind.execute(
            sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
        ).fetchone()
        if not image_gen_tool:
            raise ValueError(
                "ImageGenerationTool not found in database. Ensure tools migration has run first."
            )

        # WebSearchTool is optional - may not be configured.
        web_search_tool = bind.execute(
            sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
        ).fetchone()

        # Start from a clean slate for persona 0, then link each tool in turn.
        bind.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))

        tool_rows = [search_tool, image_gen_tool]
        if web_search_tool:
            tool_rows.append(web_search_tool)

        for tool_row in tool_rows:
            bind.execute(
                sa.text(
                    """
                    INSERT INTO persona__tool (persona_id, tool_id)
                    VALUES (0, :tool_id)
                    ON CONFLICT DO NOTHING
                    """
                ),
                {"tool_id": tool_row[0]},
            )

        # Step 4: move chat sessions of the other builtin assistants to persona 0.
        bind.execute(
            sa.text(
                """
                UPDATE chat_session
                SET persona_id = 0
                WHERE persona_id IN (
                SELECT id FROM persona WHERE builtin_persona = true AND id != 0
                )
                """
            )
        )

        # Step 5: scrub the other builtin assistants from user preferences.
        builtin_rows = bind.execute(
            sa.text(
                """
                SELECT id FROM persona
                WHERE builtin_persona = true AND id != 0
                """
            )
        ).fetchall()
        builtin_assistant_ids = [row[0] for row in builtin_rows]

        user_rows = bind.execute(
            sa.text(
                """
                SELECT id, chosen_assistants, visible_assistants,
                hidden_assistants, pinned_assistants
                FROM "user"
                """
            )
        ).fetchall()

        # Column order here fixes the order of the generated SET clause.
        preference_columns = (
            "chosen_assistants",
            "visible_assistants",
            "hidden_assistants",
            "pinned_assistants",
        )

        for raw_row in user_rows:
            user = UserRow(*raw_row)
            updates: dict[str, Any] = {}

            for column in preference_columns:
                current: list[int] | None = getattr(user, column)

                if column == "hidden_assistants":
                    # The hidden list gains every builtin ID instead of losing them.
                    if not current:
                        updates[column] = json.dumps(builtin_assistant_ids)
                        continue
                    merged: list[int] = list(current)
                    for builtin_id in builtin_assistant_ids:
                        if builtin_id not in merged:
                            merged.append(builtin_id)
                    if merged != current:
                        updates[column] = json.dumps(merged)
                    continue

                # chosen / visible / pinned: drop builtin IDs, write only on change.
                if not current:
                    continue
                kept: list[int] = [
                    assistant_id
                    for assistant_id in current
                    if assistant_id not in builtin_assistant_ids
                ]
                if kept != current:
                    updates[column] = json.dumps(kept)

            if not updates:
                continue

            # Build the SET clause before the WHERE parameter joins the dict.
            set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
            updates["user_id"] = str(user.id)  # Convert UUID to string for SQL
            bind.execute(
                sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
                updates,
            )

        bind.execute(sa.text("COMMIT"))

    except Exception:
        # Undo everything issued since BEGIN, then surface the original error.
        bind.execute(sa.text("ROLLBACK"))
        raise
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Partially undo the unified-assistant merge.

    Re-enables persona 0 plus the General (``GENERAL_ASSISTANT_ID``) and Art
    (``ART_ASSISTANT_ID``) assistants by clearing ``deleted`` and setting
    ``is_visible`` / ``is_default_persona``.

    Not restored: original tool associations, names, descriptions, and the
    original ``chat_session.persona_id`` values (no mapping was kept).
    Other builtin assistants stay deleted.
    """
    bind = op.get_bind()

    # Each entry is (SQL, bind parameters or None), executed in this order.
    statements: list[tuple[str, dict[str, Any] | None]] = [
        # Step 1: keep the Search assistant (ID 0) as an active default.
        (
            """
            UPDATE persona
            SET is_default_persona = true,
            is_visible = true,
            deleted = false
            WHERE id = 0
            """,
            None,
        ),
        # Step 2: restore the General assistant.
        (
            """
            UPDATE persona
            SET deleted = false,
            is_visible = true,
            is_default_persona = true
            WHERE id = :general_assistant_id
            """,
            {"general_assistant_id": GENERAL_ASSISTANT_ID},
        ),
        # Step 3: restore the Art assistant.
        (
            """
            UPDATE persona
            SET deleted = false,
            is_visible = true,
            is_default_persona = true
            WHERE id = :art_assistant_id
            """,
            {"art_assistant_id": ART_ASSISTANT_ID},
        ),
    ]

    bind.execute(sa.text("BEGIN"))

    try:
        for sql, params in statements:
            bind.execute(sa.text(sql), params)

        bind.execute(sa.text("COMMIT"))

    except Exception:
        # Undo everything issued since BEGIN, then surface the original error.
        bind.execute(sa.text("ROLLBACK"))
        raise
|
||||
@@ -0,0 +1,38 @@
|
||||
"""drop include citations
|
||||
|
||||
Revision ID: 8818cf73fa1a
|
||||
Revises: 7ed603b64d5a
|
||||
Create Date: 2025-09-02 19:43:50.060680
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8818cf73fa1a"
|
||||
down_revision = "7ed603b64d5a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Remove the ``include_citations`` column from the ``prompt`` table."""
    table_name = "prompt"
    column_name = "include_citations"
    op.drop_column(table_name, column_name)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Re-add ``prompt.include_citations`` and backfill it.

    The column comes back as a nullable boolean. Every row is then set to
    FALSE when the prompt is named 'ImageGeneration' and TRUE otherwise.
    """
    include_citations = sa.Column(
        "include_citations",
        sa.BOOLEAN(),
        autoincrement=False,
        nullable=True,
    )
    op.add_column("prompt", include_citations)

    # Backfill by prompt name: FALSE for ImageGeneration, TRUE for all others.
    backfill = sa.text(
        "UPDATE prompt SET include_citations = CASE WHEN name = 'ImageGeneration' THEN FALSE ELSE TRUE END"
    )
    op.execute(backfill)
|
||||
@@ -0,0 +1,225 @@
|
||||
"""merge prompt into persona
|
||||
|
||||
Revision ID: abbfec3a5ac5
|
||||
Revises: 8818cf73fa1a
|
||||
Create Date: 2024-12-19 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "abbfec3a5ac5"
|
||||
down_revision = "8818cf73fa1a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
MAX_PROMPT_LENGTH = 5_000_000
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Merge the ``prompt`` table into ``persona``.

    NOTE: Prompts without any Personas will just be lost.

    Steps:
      1. Add ``system_prompt``, ``task_prompt`` and ``datetime_aware`` to
         ``persona`` when they are not already present.
      2. Copy one prompt per persona into those columns. When a persona has
         several prompts, the one with the lowest ``prompt.id`` is used.
      3. Drop ``chat_message.prompt_id`` (and its FK) if the column exists.
      4. Replace NULL prompt text on ``persona`` with empty strings.
      5. Drop ``persona__prompt`` if it exists.
      6. Drop ``prompt`` if it exists.
      7. Make ``system_prompt`` and ``task_prompt`` NOT NULL.
    """
    connection = op.get_bind()
    inspector = sa.inspect(connection)

    # Step 1: add the new persona columns, skipping any that already exist.
    existing_columns = [col["name"] for col in inspector.get_columns("persona")]
    new_columns = [
        sa.Column("system_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=True),
        sa.Column("task_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=True),
        sa.Column(
            "datetime_aware", sa.Boolean(), nullable=False, server_default="true"
        ),
    ]
    for column in new_columns:
        if column.name not in existing_columns:
            op.add_column("persona", column)

    # Step 2: migrate data from prompt to persona (only if both tables exist).
    existing_tables = inspector.get_table_names()

    if "prompt" in existing_tables and "persona__prompt" in existing_tables:
        # DISTINCT ON keeps the first row of each persona_id group. Without an
        # ORDER BY that "first" row is unpredictable, so order by prompt id to
        # make the chosen prompt deterministic.
        op.execute(
            """
            UPDATE persona
            SET
                system_prompt = p.system_prompt,
                task_prompt = p.task_prompt,
                datetime_aware = p.datetime_aware
            FROM (
                -- Get the first prompt for each persona (in case there are multiple)
                SELECT DISTINCT ON (pp.persona_id)
                    pp.persona_id,
                    pr.system_prompt,
                    pr.task_prompt,
                    pr.datetime_aware
                FROM persona__prompt pp
                JOIN prompt pr ON pp.prompt_id = pr.id
                ORDER BY pp.persona_id, pr.id
            ) p
            WHERE persona.id = p.persona_id
            """
        )

        # Step 3: chat messages referenced prompt_id; drop that reference.
        chat_message_columns = [
            col["name"] for col in inspector.get_columns("chat_message")
        ]
        if "prompt_id" in chat_message_columns:
            op.execute(
                """
                ALTER TABLE chat_message
                DROP CONSTRAINT IF EXISTS chat_message__prompt_fk
                """
            )
            op.drop_column("chat_message", "prompt_id")

    # Step 4: personas without prompts get empty strings (always runs), so the
    # NOT NULL constraint in Step 7 can be applied.
    op.execute(
        """
        UPDATE persona
        SET
            system_prompt = COALESCE(system_prompt, ''),
            task_prompt = COALESCE(task_prompt, '')
        WHERE system_prompt IS NULL OR task_prompt IS NULL
        """
    )

    # Step 5: drop the persona__prompt association table (if it exists).
    if "persona__prompt" in existing_tables:
        op.drop_table("persona__prompt")

    # Step 6: drop the prompt table (if it exists).
    if "prompt" in existing_tables:
        op.drop_table("prompt")

    # Step 7: make the prompt columns non-nullable (runs unconditionally).
    for column_name in ("system_prompt", "task_prompt"):
        op.alter_column(
            "persona",
            column_name,
            existing_type=sa.String(length=MAX_PROMPT_LENGTH),
            nullable=False,
            server_default=None,
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Split prompt data back out of ``persona`` into ``prompt`` tables.

    Recreates ``prompt`` and ``persona__prompt``, inserts one prompt named
    'Prompt for <persona name>' for every persona with a non-empty
    ``system_prompt``, links the two by that generated name, restores the
    nullable ``chat_message.prompt_id`` column with its foreign key, and
    finally drops the three prompt columns from ``persona``.
    """
    prompt_text = sa.String(length=MAX_PROMPT_LENGTH)

    # Step 1: recreate the prompt table.
    op.create_table(
        "prompt",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("description", sa.String(), nullable=False),
        sa.Column("system_prompt", prompt_text, nullable=False),
        sa.Column("task_prompt", prompt_text, nullable=False),
        sa.Column(
            "datetime_aware", sa.Boolean(), nullable=False, server_default="true"
        ),
        sa.Column(
            "default_prompt", sa.Boolean(), nullable=False, server_default="false"
        ),
        sa.Column("deleted", sa.Boolean(), nullable=False, server_default="false"),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )

    # Step 2: recreate the persona__prompt association table.
    op.create_table(
        "persona__prompt",
        sa.Column("persona_id", sa.Integer(), nullable=False),
        sa.Column("prompt_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(["persona_id"], ["persona.id"]),
        sa.ForeignKeyConstraint(["prompt_id"], ["prompt.id"]),
        sa.PrimaryKeyConstraint("persona_id", "prompt_id"),
    )

    # Step 3: copy prompt data from persona back into prompt.
    copy_prompts_sql = """
        INSERT INTO prompt (
            name,
            description,
            system_prompt,
            task_prompt,
            datetime_aware,
            default_prompt,
            deleted,
            user_id
        )
        SELECT
            CONCAT('Prompt for ', name),
            description,
            system_prompt,
            task_prompt,
            datetime_aware,
            is_default_persona,
            deleted,
            user_id
        FROM persona
        WHERE system_prompt IS NOT NULL AND system_prompt != ''
        RETURNING id, name
        """
    op.execute(copy_prompts_sql)

    # Step 4: re-link personas to the prompts created for them.
    # NOTE(review): the link is made by generated name, so personas sharing a
    # name would each match every prompt of that name — confirm acceptable.
    relink_sql = """
        INSERT INTO persona__prompt (persona_id, prompt_id)
        SELECT
            p.id as persona_id,
            pr.id as prompt_id
        FROM persona p
        JOIN prompt pr ON pr.name = CONCAT('Prompt for ', p.name)
        WHERE p.system_prompt IS NOT NULL AND p.system_prompt != ''
        """
    op.execute(relink_sql)

    # Step 5: add prompt_id back to chat_message.
    op.add_column("chat_message", sa.Column("prompt_id", sa.Integer(), nullable=True))

    # Step 6: re-establish the foreign key constraint.
    op.create_foreign_key(
        "chat_message__prompt_fk", "chat_message", "prompt", ["prompt_id"], ["id"]
    )

    # Step 7: remove the merged columns from persona.
    for column_name in ("datetime_aware", "task_prompt", "system_prompt"):
        op.drop_column("persona", column_name)
|
||||
@@ -0,0 +1,43 @@
|
||||
"""adjust prompt length
|
||||
|
||||
Revision ID: b7ec9b5b505f
|
||||
Revises: abbfec3a5ac5
|
||||
Create Date: 2025-09-10 18:51:15.629197
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b7ec9b5b505f"
|
||||
down_revision = "abbfec3a5ac5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
MAX_PROMPT_LENGTH = 5_000_000
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Widen ``persona.system_prompt`` and ``persona.task_prompt``.

    Both columns are changed from ``String(8000)`` to
    ``String(MAX_PROMPT_LENGTH)`` and remain NOT NULL.
    """
    # NOTE: needed because an earlier version of the previous migration
    # created these columns with length 8000.
    old_type = sa.String(length=8000)
    new_type = sa.String(length=MAX_PROMPT_LENGTH)

    for column_name in ("system_prompt", "task_prompt"):
        op.alter_column(
            "persona",
            column_name,
            existing_type=old_type,
            type_=new_type,
            existing_nullable=False,
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Do nothing: a downgrade is not necessary for this migration."""
    return None
|
||||
125
backend/alembic/versions/d09fc20a3c66_seed_builtin_tools.py
Normal file
125
backend/alembic/versions/d09fc20a3c66_seed_builtin_tools.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""seed_builtin_tools
|
||||
|
||||
Revision ID: d09fc20a3c66
|
||||
Revises: b7ec9b5b505f
|
||||
Create Date: 2025-09-09 19:32:16.824373
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d09fc20a3c66"
|
||||
down_revision = "b7ec9b5b505f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
# Tool definitions - core tools that should always be seeded
|
||||
# Names/in_code_tool_id are the same as the class names in the tool_implementations package
|
||||
BUILT_IN_TOOLS = [
|
||||
{
|
||||
"name": "SearchTool",
|
||||
"display_name": "Internal Search",
|
||||
"description": "The Search Action allows the Assistant to search through connected knowledge to help build an answer.",
|
||||
"in_code_tool_id": "SearchTool",
|
||||
},
|
||||
{
|
||||
"name": "ImageGenerationTool",
|
||||
"display_name": "Image Generation",
|
||||
"description": (
|
||||
"The Image Generation Action allows the assistant to use DALL-E 3 or GPT-IMAGE-1 to generate images. "
|
||||
"The action will be used when the user asks the assistant to generate an image."
|
||||
),
|
||||
"in_code_tool_id": "ImageGenerationTool",
|
||||
},
|
||||
{
|
||||
"name": "WebSearchTool",
|
||||
"display_name": "Web Search",
|
||||
"description": (
|
||||
"The Web Search Action allows the assistant "
|
||||
"to perform internet searches for up-to-date information."
|
||||
),
|
||||
"in_code_tool_id": "WebSearchTool",
|
||||
},
|
||||
{
|
||||
"name": "KnowledgeGraphTool",
|
||||
"display_name": "Knowledge Graph Search",
|
||||
"description": (
|
||||
"The Knowledge Graph Search Action allows the assistant to search the "
|
||||
"Knowledge Graph for information. This tool can (for now) only be active in the KG Beta Assistant, "
|
||||
"and it requires the Knowledge Graph to be enabled."
|
||||
),
|
||||
"in_code_tool_id": "KnowledgeGraphTool",
|
||||
},
|
||||
{
|
||||
"name": "OktaProfileTool",
|
||||
"display_name": "Okta Profile",
|
||||
"description": (
|
||||
"The Okta Profile Action allows the assistant to fetch the current user's information from Okta. "
|
||||
"This may include the user's name, email, phone number, address, and other details such as their "
|
||||
"manager and direct reports."
|
||||
),
|
||||
"in_code_tool_id": "OktaProfileTool",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Seed the ``tool`` table with every entry of ``BUILT_IN_TOOLS``.

    A tool whose ``in_code_tool_id`` already exists has its name, display
    name and description overwritten; any other tool is inserted. All of it
    runs between an explicit BEGIN and COMMIT, with ROLLBACK on error.
    """
    bind = op.get_bind()

    refresh_tool = sa.text(
        """
        UPDATE tool
        SET name = :name,
        display_name = :display_name,
        description = :description
        WHERE in_code_tool_id = :in_code_tool_id
        """
    )
    create_tool = sa.text(
        """
        INSERT INTO tool (name, display_name, description, in_code_tool_id)
        VALUES (:name, :display_name, :description, :in_code_tool_id)
        """
    )

    bind.execute(sa.text("BEGIN"))

    try:
        # Snapshot which in-code tools are already present.
        present_rows = bind.execute(
            sa.text(
                "SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL"
            )
        ).fetchall()
        present_ids = {row[0] for row in present_rows}

        for tool in BUILT_IN_TOOLS:
            already_seeded = tool["in_code_tool_id"] in present_ids
            bind.execute(refresh_tool if already_seeded else create_tool, tool)

        bind.execute(sa.text("COMMIT"))

    except Exception:
        # Undo everything issued since BEGIN, then surface the original error.
        bind.execute(sa.text("ROLLBACK"))
        raise
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Do nothing: seeded tools are deliberately left in place.

    Keeping them is harmless, and re-running the upgrade afterwards is
    effectively a no-op.
    """
    return None
|
||||
@@ -58,11 +58,11 @@ def downgrade() -> None:
|
||||
),
|
||||
)
|
||||
|
||||
# Restore the foreign key constraint pointing to research_agent_iteration.id
|
||||
# Restore the foreign key constraint pointing to research_agent_iteration_sub_step.id
|
||||
op.create_foreign_key(
|
||||
"research_agent_iteration_sub_step_parent_question_id_fkey",
|
||||
"research_agent_iteration_sub_step",
|
||||
"research_agent_iteration",
|
||||
"research_agent_iteration_sub_step",
|
||||
["parent_question_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
|
||||
@@ -1,133 +1,4 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
|
||||
from ee.onyx.background.task_name_builders import name_chat_ttl_task
|
||||
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
|
||||
from onyx.background.celery.apps.primary import celery_app
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.chat import delete_chat_session
|
||||
from onyx.db.chat import get_chat_sessions_older_than
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.tasks import mark_task_as_finished_with_id
|
||||
from onyx.db.tasks import register_task
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# mark as EE for all tasks in this file
|
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.PERFORM_TTL_MANAGEMENT_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
    trail=False,
)
def perform_ttl_management_task(
    self: Task, retention_limit_days: int, *, tenant_id: str
) -> None:
    """Hard-delete chat sessions older than ``retention_limit_days``.

    Registers the run as STARTED, deletes each expired session in its own
    DB session, and marks the run finished. On any exception the failure is
    logged, the run is marked unsuccessful, and the exception is re-raised.

    Raises:
        RuntimeError: the Celery request carries no task id.
    """
    task_id = self.request.id
    if not task_id:
        raise RuntimeError("No task id defined for this task; cannot identify it")

    def _record_outcome(success: bool) -> None:
        # Persist the final status in a fresh DB session.
        with get_session_with_current_tenant() as db_session:
            mark_task_as_finished_with_id(
                db_session=db_session,
                task_id=task_id,
                success=success,
            )

    start_time = datetime.now(tz=timezone.utc)

    # Tracked outside the try so the failure log can name the last session.
    user_id: UUID | None = None
    session_id: UUID | None = None
    try:
        with get_session_with_current_tenant() as db_session:
            # we generally want to move off this, but keeping for now
            register_task(
                db_session=db_session,
                task_name=name_chat_ttl_task(retention_limit_days, tenant_id),
                task_id=task_id,
                status=TaskStatus.STARTED,
                start_time=start_time,
            )

            old_chat_sessions = get_chat_sessions_older_than(
                retention_limit_days, db_session
            )

        # NOTE(review): the extracted source lost its indentation; this loop is
        # assumed to run after the registration session is closed — confirm.
        for user_id, session_id in old_chat_sessions:
            # one session per delete so that we don't blow up if a deletion fails.
            with get_session_with_current_tenant() as db_session:
                delete_chat_session(
                    user_id,
                    session_id,
                    db_session,
                    include_deleted=True,
                    hard_delete=True,
                )

        _record_outcome(True)

    except Exception:
        logger.exception(
            "delete_chat_session exceptioned. "
            f"user_id={user_id} session_id={session_id}"
        )
        _record_outcome(False)
        raise
|
||||
|
||||
|
||||
#####
|
||||
# Periodic Tasks
|
||||
#####
|
||||
|
||||
|
||||
@celery_app.task(
    name=OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str) -> None:
    """Periodic check that enqueues the chat TTL task when one is due.

    Reads ``maximum_chat_retention_days`` from the loaded settings, asks
    ``should_perform_chat_ttl_check`` whether a run is needed, and if so
    queues ``perform_ttl_management_task`` with that retention limit and
    the given ``tenant_id``.
    """
    retention_limit_days = load_settings().maximum_chat_retention_days

    with get_session_with_current_tenant() as session:
        # Nothing to enqueue when the check says no run is needed.
        if not should_perform_chat_ttl_check(retention_limit_days, session):
            return

        perform_ttl_management_task.apply_async(
            kwargs={
                "retention_limit_days": retention_limit_days,
                "tenant_id": tenant_id,
            },
        )
|
||||
|
||||
|
||||
@celery_app.task(
    name=OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str) -> None:
    """Create a usage report with no requesting user and no explicit period.

    ``tenant_id`` is accepted as a keyword argument but is not read in
    this body; the database session comes from
    ``get_session_with_current_tenant``.
    """
    # Both the requesting user and the reporting period are left unset.
    report_args = {"user_id": None, "period": None}

    with get_session_with_current_tenant() as session:
        create_new_usage_report(db_session=session, **report_args)
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
@@ -135,5 +6,7 @@ celery_app.autodiscover_tasks(
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cloud",
|
||||
"ee.onyx.background.celery.tasks.ttl_management",
|
||||
"ee.onyx.background.celery.tasks.usage_reporting",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -23,7 +23,7 @@ ee_beat_system_tasks: list[dict] = []
|
||||
ee_beat_task_templates: list[dict] = [
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"task": OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
@@ -57,7 +57,7 @@ if not MULTI_TENANT:
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"task": OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30), # TODO: change this to config flag
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
|
||||
106
backend/ee/onyx/background/celery/tasks/ttl_management/tasks.py
Normal file
106
backend/ee/onyx/background/celery/tasks/ttl_management/tasks.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
|
||||
from ee.onyx.background.task_name_builders import name_chat_ttl_task
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.chat import delete_chat_session
|
||||
from onyx.db.chat import get_chat_sessions_older_than
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.tasks import mark_task_as_finished_with_id
|
||||
from onyx.db.tasks import register_task
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.PERFORM_TTL_MANAGEMENT_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
    trail=False,
)
def perform_ttl_management_task(
    self: Task, retention_limit_days: int, *, tenant_id: str
) -> None:
    """Hard-delete the chat sessions selected by the retention limit.

    The run is registered as STARTED, every session returned by
    ``get_chat_sessions_older_than`` is hard-deleted (soft-deleted ones
    included), and the task record is finally marked as succeeded. On any
    exception the failure is logged, the task record is marked as failed,
    and the exception is re-raised.

    Raises:
        RuntimeError: if the bound Celery request carries no task id.
    """
    task_id = self.request.id
    if not task_id:
        raise RuntimeError("No task id defined for this task; cannot identify it")

    def _record_outcome(succeeded: bool) -> None:
        # Uses its own short-lived session so the bookkeeping write does not
        # share a session with the deletions.
        with get_session_with_current_tenant() as outcome_session:
            mark_task_as_finished_with_id(
                db_session=outcome_session,
                task_id=task_id,
                success=succeeded,
            )

    start_time = datetime.now(tz=timezone.utc)

    # Kept outside the try block so the failure log can report the last
    # (user, session) pair that was being processed.
    user_id: UUID | None = None
    session_id: UUID | None = None

    try:
        with get_session_with_current_tenant() as lookup_session:
            # we generally want to move off this, but keeping for now
            register_task(
                db_session=lookup_session,
                task_name=name_chat_ttl_task(retention_limit_days, tenant_id),
                task_id=task_id,
                status=TaskStatus.STARTED,
                start_time=start_time,
            )
            old_chat_sessions = get_chat_sessions_older_than(
                retention_limit_days, lookup_session
            )

        for user_id, session_id in old_chat_sessions:
            # one session per delete so that we don't blow up if a deletion fails.
            with get_session_with_current_tenant() as delete_session:
                delete_chat_session(
                    user_id,
                    session_id,
                    delete_session,
                    include_deleted=True,
                    hard_delete=True,
                )

        _record_outcome(True)

    except Exception:
        logger.exception(
            "delete_chat_session exceptioned. "
            f"user_id={user_id} session_id={session_id}"
        )
        _record_outcome(False)
        raise
|
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str) -> None:
    """Periodic check that enqueues the chat TTL task when one is due.

    Reads ``maximum_chat_retention_days`` from the loaded settings, asks
    ``should_perform_chat_ttl_check`` whether a run is needed, and if so
    queues ``perform_ttl_management_task`` with that retention limit and
    the given ``tenant_id``.
    """
    retention_limit_days = load_settings().maximum_chat_retention_days

    with get_session_with_current_tenant() as session:
        # Nothing to enqueue when the check says no run is needed.
        if not should_perform_chat_ttl_check(retention_limit_days, session):
            return

        perform_ttl_management_task.apply_async(
            kwargs={
                "retention_limit_days": retention_limit_days,
                "tenant_id": tenant_id,
            },
        )
|
||||
@@ -0,0 +1,46 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
    trail=False,
)
def generate_usage_report_task(
    self: Task,
    *,
    tenant_id: str,
    user_id: str | None = None,
    period_from: str | None = None,
    period_to: str | None = None,
) -> None:
    """Generate a usage report on behalf of a user-initiated request.

    Args:
        tenant_id: Accepted as a keyword argument; not read in this body.
        user_id: String form of the requesting user's UUID, or None.
        period_from: ISO-format start of the reporting period, or None.
        period_to: ISO-format end of the reporting period, or None.

    A period is only passed on when BOTH bounds are provided; otherwise
    the report is created with ``period=None``.
    """
    # Both bounds are required; a single bound is treated as "no period".
    period = (
        (datetime.fromisoformat(period_from), datetime.fromisoformat(period_to))
        if period_from and period_to
        else None
    )

    with get_session_with_current_tenant() as session:
        # The id arrives as a string and is converted back to a UUID here.
        requester = UUID(user_id) if user_id else None
        create_new_usage_report(
            db_session=session,
            user_id=requester,
            period=period,
        )
|
||||
@@ -1,38 +0,0 @@
|
||||
from ee.onyx.server.query_and_chat.models import OneShotQAResponse
|
||||
from onyx.chat.models import AllCitations
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import LLMRelevanceFilterResponse
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@log_function_time()
def gather_stream_for_answer_api(
    packets: AnswerStream,
) -> OneShotQAResponse:
    """Collapse a stream of answer packets into one ``OneShotQAResponse``.

    Answer pieces are concatenated in arrival order; every other
    recognised packet type fills in its corresponding response field.
    Packets of any other type are ignored. ``response.answer`` is only
    assigned when at least one non-empty answer piece was seen.
    """
    response = OneShotQAResponse()
    pieces: list[str] = []

    for packet in packets:
        if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
            pieces.append(packet.answer_piece)
            continue
        if isinstance(packet, QADocsResponse):
            response.docs = packet
            # Extraneous, provided for backwards compatibility
            response.rephrase = packet.rephrased_query
            continue
        if isinstance(packet, StreamingError):
            response.error_msg = packet.error
            continue
        if isinstance(packet, ChatMessageDetail):
            response.chat_message_id = packet.message_id
            continue
        if isinstance(packet, LLMRelevanceFilterResponse):
            response.llm_selected_doc_indices = packet.llm_selected_doc_indices
            continue
        if isinstance(packet, AllCitations):
            response.citations = packet.citations

    answer = "".join(pieces)
    if answer:
        response.answer = answer

    return response
|
||||
@@ -14,7 +14,6 @@ from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_chat_messages_by_sessions
|
||||
from onyx.db.chat import get_chat_sessions_by_slack_thread_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.models import Prompt
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.models import StandardAnswer as StandardAnswerModel
|
||||
from onyx.onyxbot.slack.blocks import get_restate_blocks
|
||||
@@ -81,7 +80,6 @@ def _handle_standard_answers(
|
||||
message_info: SlackMessageInfo,
|
||||
receiver_ids: list[str] | None,
|
||||
slack_channel_config: SlackChannelConfig,
|
||||
prompt: Prompt | None,
|
||||
logger: OnyxLoggingAdapter,
|
||||
client: WebClient,
|
||||
db_session: Session,
|
||||
@@ -161,7 +159,6 @@ def _handle_standard_answers(
|
||||
new_user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=root_message,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
message=query_msg.message,
|
||||
token_count=0,
|
||||
message_type=MessageType.USER,
|
||||
@@ -182,7 +179,6 @@ def _handle_standard_answers(
|
||||
chat_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=new_user_message,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
message=answer_message,
|
||||
token_count=0,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
|
||||
@@ -93,7 +93,6 @@ def handle_simplified_chat_message(
|
||||
parent_message_id=parent_message.id,
|
||||
message=chat_message_req.message,
|
||||
file_descriptors=[],
|
||||
prompt_id=None,
|
||||
search_doc_ids=chat_message_req.search_doc_ids,
|
||||
retrieval_options=retrieval_options,
|
||||
# Simple API does not support reranking, hide complexity from user
|
||||
@@ -181,7 +180,6 @@ def handle_send_message_simple_with_history(
|
||||
chat_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=chat_message,
|
||||
prompt_id=req.prompt_id,
|
||||
message=msg.message,
|
||||
token_count=len(llm_tokenizer.encode(msg.message)),
|
||||
message_type=msg.role,
|
||||
@@ -214,7 +212,6 @@ def handle_send_message_simple_with_history(
|
||||
parent_message_id=chat_message.id,
|
||||
message=query,
|
||||
file_descriptors=[],
|
||||
prompt_id=req.prompt_id,
|
||||
search_doc_ids=req.search_doc_ids,
|
||||
retrieval_options=retrieval_options,
|
||||
# Simple API does not support reranking, hide complexity from user
|
||||
|
||||
@@ -73,7 +73,6 @@ class BasicCreateChatMessageRequest(ChunkContext):
|
||||
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
|
||||
# Last element is the new query. All previous elements are historical context
|
||||
messages: list[ThreadMessage]
|
||||
prompt_id: int | None
|
||||
persona_id: int
|
||||
retrieval_options: RetrievalDetails | None = None
|
||||
query_override: str | None = None
|
||||
@@ -162,7 +161,6 @@ class OneShotQARequest(ChunkContext):
|
||||
persona_id: int | None = None
|
||||
|
||||
messages: list[ThreadMessage]
|
||||
prompt_id: int | None = None
|
||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
|
||||
@@ -181,11 +179,9 @@ class OneShotQARequest(ChunkContext):
|
||||
def check_persona_fields(self) -> "OneShotQARequest":
|
||||
if self.persona_override_config is None and self.persona_id is None:
|
||||
raise ValueError("Exactly one of persona_config or persona_id must be set")
|
||||
elif self.persona_override_config is not None and (
|
||||
self.persona_id is not None or self.prompt_id is not None
|
||||
):
|
||||
elif self.persona_override_config is not None and (self.persona_id is not None):
|
||||
raise ValueError(
|
||||
"If persona_override_config is set, persona_id and prompt_id cannot be set"
|
||||
"If persona_override_config is set, persona_id cannot be set"
|
||||
)
|
||||
return self
|
||||
|
||||
@@ -196,6 +192,5 @@ class OneShotQAResponse(BaseModel):
|
||||
rephrase: str | None = None
|
||||
citations: list[CitationInfo] | None = None
|
||||
docs: QADocsResponse | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
chat_message_id: int | None = None
|
||||
|
||||
@@ -8,7 +8,6 @@ from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.chat.process_message import gather_stream_for_answer_api
|
||||
from ee.onyx.onyxbot.slack.handlers.handle_standard_answers import (
|
||||
oneoff_standard_answers,
|
||||
)
|
||||
@@ -22,6 +21,8 @@ from onyx.chat.chat_utils import combine_message_thread
|
||||
from onyx.chat.chat_utils import prepare_chat_message_request
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.process_message import gather_stream
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.configs.onyxbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
|
||||
from onyx.context.search.models import SavedSearchDocWithContent
|
||||
@@ -30,7 +31,6 @@ from onyx.context.search.pipeline import SearchPipeline
|
||||
from onyx.context.search.utils import dedupe_documents
|
||||
from onyx.context.search.utils import drop_llm_indices
|
||||
from onyx.context.search.utils import relevant_sections_to_indices
|
||||
from onyx.db.chat import get_prompt_by_id
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
@@ -39,6 +39,7 @@ from onyx.llm.factory import get_default_llms
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.llm.factory import get_main_llm_from_tuple
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -150,14 +151,6 @@ def get_answer_stream(
|
||||
):
|
||||
raise KeyError("Must provide persona ID or Persona Config")
|
||||
|
||||
prompt = None
|
||||
if query_request.prompt_id is not None:
|
||||
prompt = get_prompt_by_id(
|
||||
prompt_id=query_request.prompt_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
persona_info: Persona | PersonaOverrideConfig | None = None
|
||||
if query_request.persona_override_config is not None:
|
||||
persona_info = query_request.persona_override_config
|
||||
@@ -192,7 +185,6 @@ def get_answer_stream(
|
||||
user=user,
|
||||
persona_id=query_request.persona_id,
|
||||
persona_override_config=query_request.persona_override_config,
|
||||
prompt=prompt,
|
||||
message_ts_to_respond_to=None,
|
||||
retrieval_details=query_request.retrieval_options,
|
||||
rerank_settings=query_request.rerank_settings,
|
||||
@@ -218,12 +210,28 @@ def get_answer_with_citation(
|
||||
) -> OneShotQAResponse:
|
||||
try:
|
||||
packets = get_answer_stream(request, user, db_session)
|
||||
answer = gather_stream_for_answer_api(packets)
|
||||
answer = gather_stream(packets)
|
||||
|
||||
if answer.error_msg:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
|
||||
return answer
|
||||
return OneShotQAResponse(
|
||||
answer=answer.answer,
|
||||
chat_message_id=answer.message_id,
|
||||
error_msg=answer.error_msg,
|
||||
citations=[
|
||||
CitationInfo(citation_num=i, document_id=doc_id)
|
||||
for i, doc_id in answer.cited_documents.items()
|
||||
],
|
||||
docs=QADocsResponse(
|
||||
top_documents=answer.top_documents,
|
||||
predicted_flow=None,
|
||||
predicted_search=None,
|
||||
applied_source_filters=None,
|
||||
applied_time_cutoff=None,
|
||||
recency_bias_multiplier=0.0,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_answer_with_citation: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="An internal server error occurred")
|
||||
|
||||
@@ -12,11 +12,13 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.usage_export import get_all_usage_reports
|
||||
from ee.onyx.db.usage_export import get_usage_report_data
|
||||
from ee.onyx.db.usage_export import UsageReportMetadata
|
||||
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.file_store.constants import STANDARD_CHUNK_SIZE
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -26,24 +28,31 @@ class GenerateUsageReportParams(BaseModel):
|
||||
period_to: str | None = None
|
||||
|
||||
|
||||
@router.post("/admin/generate-usage-report")
|
||||
@router.post("/admin/usage-report", status_code=204)
|
||||
def generate_report(
|
||||
params: GenerateUsageReportParams,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UsageReportMetadata:
|
||||
period = None
|
||||
) -> None:
|
||||
# Validate period parameters
|
||||
if params.period_from and params.period_to:
|
||||
try:
|
||||
period = (
|
||||
datetime.fromisoformat(params.period_from),
|
||||
datetime.fromisoformat(params.period_to),
|
||||
)
|
||||
datetime.fromisoformat(params.period_from)
|
||||
datetime.fromisoformat(params.period_to)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
new_report = create_new_usage_report(db_session, user.id if user else None, period)
|
||||
return new_report
|
||||
tenant_id = get_current_tenant_id()
|
||||
client_app.send_task(
|
||||
OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
|
||||
kwargs={
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": str(user.id) if user else None,
|
||||
"period_from": params.period_from,
|
||||
"period_to": params.period_to,
|
||||
},
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/admin/usage-report/{report_name}")
|
||||
@@ -54,7 +63,7 @@ def read_usage_report(
|
||||
) -> Response:
|
||||
try:
|
||||
file = get_usage_report_data(report_name)
|
||||
except ValueError as e:
|
||||
except (ValueError, RuntimeError) as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
def iterfile() -> Generator[bytes, None, None]:
|
||||
|
||||
@@ -131,32 +131,35 @@ def _seed_llms(
|
||||
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
|
||||
if personas:
|
||||
logger.notice("Seeding Personas")
|
||||
for persona in personas:
|
||||
if not persona.prompt_ids:
|
||||
raise ValueError(
|
||||
f"Invalid Persona with name {persona.name}; no prompts exist"
|
||||
try:
|
||||
for persona in personas:
|
||||
upsert_persona(
|
||||
user=None, # Seeding is done as admin
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=(
|
||||
persona.num_chunks if persona.num_chunks is not None else 0.0
|
||||
),
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
starter_messages=persona.starter_messages,
|
||||
is_public=persona.is_public,
|
||||
db_session=db_session,
|
||||
tool_ids=persona.tool_ids,
|
||||
display_priority=persona.display_priority,
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=persona.datetime_aware,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
upsert_persona(
|
||||
user=None, # Seeding is done as admin
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=(
|
||||
persona.num_chunks if persona.num_chunks is not None else 0.0
|
||||
),
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
prompt_ids=persona.prompt_ids,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
starter_messages=persona.starter_messages,
|
||||
is_public=persona.is_public,
|
||||
db_session=db_session,
|
||||
tool_ids=persona.tool_ids,
|
||||
display_priority=persona.display_priority,
|
||||
)
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
logger.exception("Failed to seed personas.")
|
||||
raise
|
||||
|
||||
|
||||
def _seed_settings(settings: Settings) -> None:
|
||||
|
||||
@@ -39,7 +39,7 @@ def search_objects(
|
||||
raise ValueError("Search tool and persona must be provided for DivCon search")
|
||||
|
||||
try:
|
||||
instructions = graph_config.inputs.persona.prompts[0].system_prompt
|
||||
instructions = graph_config.inputs.persona.system_prompt or ""
|
||||
|
||||
agent_1_instructions = extract_section(
|
||||
instructions, "Agent Step 1:", "Agent Step 2:"
|
||||
|
||||
@@ -43,7 +43,7 @@ def research_object_source(
|
||||
raise ValueError("Search tool and persona must be provided for DivCon search")
|
||||
|
||||
try:
|
||||
instructions = graph_config.inputs.persona.prompts[0].system_prompt
|
||||
instructions = graph_config.inputs.persona.system_prompt or ""
|
||||
|
||||
agent_2_instructions = extract_section(
|
||||
instructions, "Agent Step 2:", "Agent Step 3:"
|
||||
|
||||
@@ -33,7 +33,7 @@ def consolidate_object_research(
|
||||
if search_tool is None or graph_config.inputs.persona is None:
|
||||
raise ValueError("Search tool and persona must be provided for DivCon search")
|
||||
|
||||
instructions = graph_config.inputs.persona.prompts[0].system_prompt
|
||||
instructions = graph_config.inputs.persona.system_prompt or ""
|
||||
|
||||
agent_4_instructions = extract_section(
|
||||
instructions, "Agent Step 4:", "Agent Step 5:"
|
||||
|
||||
@@ -35,7 +35,7 @@ def consolidate_research(
|
||||
raise ValueError("Search tool and persona must be provided for DivCon search")
|
||||
|
||||
# Populate prompt
|
||||
instructions = graph_config.inputs.persona.prompts[0].system_prompt
|
||||
instructions = graph_config.inputs.persona.system_prompt or ""
|
||||
|
||||
try:
|
||||
agent_5_instructions = extract_section(
|
||||
|
||||
@@ -36,7 +36,7 @@ def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
|
||||
next_tool_path
|
||||
in (
|
||||
DRPath.INTERNAL_SEARCH,
|
||||
DRPath.INTERNET_SEARCH,
|
||||
DRPath.WEB_SEARCH,
|
||||
DRPath.KNOWLEDGE_GRAPH,
|
||||
DRPath.IMAGE_GENERATION,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ HIGH_LEVEL_PLAN_PREFIX = "The Plan:"
|
||||
AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
|
||||
DRPath.INTERNAL_SEARCH: 1.0,
|
||||
DRPath.KNOWLEDGE_GRAPH: 2.0,
|
||||
DRPath.INTERNET_SEARCH: 1.5,
|
||||
DRPath.WEB_SEARCH: 1.5,
|
||||
DRPath.IMAGE_GENERATION: 3.0,
|
||||
DRPath.GENERIC_TOOL: 1.5, # TODO: see todo in OrchestratorTool
|
||||
DRPath.CLOSER: 0.0,
|
||||
@@ -26,5 +26,6 @@ AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
|
||||
|
||||
DR_TIME_BUDGET_BY_TYPE = {
|
||||
ResearchType.THOUGHTFUL: 3.0,
|
||||
ResearchType.DEEP: 6.0,
|
||||
ResearchType.DEEP: 12.0,
|
||||
ResearchType.FAST: 0.5,
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ class ResearchType(str, Enum):
|
||||
LEGACY_AGENTIC = "LEGACY_AGENTIC" # only used for legacy agentic search migrations
|
||||
THOUGHTFUL = "THOUGHTFUL"
|
||||
DEEP = "DEEP"
|
||||
FAST = "FAST"
|
||||
|
||||
|
||||
class ResearchAnswerPurpose(str, Enum):
|
||||
@@ -20,10 +21,10 @@ class ResearchAnswerPurpose(str, Enum):
|
||||
class DRPath(str, Enum):
|
||||
CLARIFIER = "Clarifier"
|
||||
ORCHESTRATOR = "Orchestrator"
|
||||
INTERNAL_SEARCH = "Search Tool"
|
||||
INTERNAL_SEARCH = "Internal Search"
|
||||
GENERIC_TOOL = "Generic Tool"
|
||||
KNOWLEDGE_GRAPH = "Knowledge Graph Search"
|
||||
INTERNET_SEARCH = "Internet Search"
|
||||
WEB_SEARCH = "Web Search"
|
||||
IMAGE_GENERATION = "Image Generation"
|
||||
GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
|
||||
CLOSER = "Closer"
|
||||
|
||||
@@ -23,12 +23,12 @@ from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_int
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
|
||||
dr_image_generation_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_graph_builder import (
|
||||
dr_is_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
|
||||
dr_kg_search_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
|
||||
dr_ws_graph_builder,
|
||||
)
|
||||
|
||||
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search
|
||||
|
||||
@@ -52,8 +52,8 @@ def dr_graph_builder() -> StateGraph:
|
||||
kg_search_graph = dr_kg_search_graph_builder().compile()
|
||||
graph.add_node(DRPath.KNOWLEDGE_GRAPH, kg_search_graph)
|
||||
|
||||
internet_search_graph = dr_is_graph_builder().compile()
|
||||
graph.add_node(DRPath.INTERNET_SEARCH, internet_search_graph)
|
||||
internet_search_graph = dr_ws_graph_builder().compile()
|
||||
graph.add_node(DRPath.WEB_SEARCH, internet_search_graph)
|
||||
|
||||
image_generation_graph = dr_image_generation_graph_builder().compile()
|
||||
graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)
|
||||
@@ -77,7 +77,7 @@ def dr_graph_builder() -> StateGraph:
|
||||
|
||||
graph.add_edge(start_key=DRPath.INTERNAL_SEARCH, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.KNOWLEDGE_GRAPH, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.INTERNET_SEARCH, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.WEB_SEARCH, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
|
||||
|
||||
@@ -26,7 +26,6 @@ from onyx.agents.agent_search.dr.process_llm_stream import process_llm_stream
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.states import OrchestrationSetup
|
||||
from onyx.agents.agent_search.dr.utils import get_chat_history_string
|
||||
from onyx.agents.agent_search.dr.utils import update_db_session_with_messages
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
@@ -36,9 +35,12 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import run_with_timeout
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSourceDescription
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.db.chat import update_db_session_with_messages
|
||||
from onyx.db.connector import fetch_unique_document_sources
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.models import Tool
|
||||
@@ -54,8 +56,6 @@ from onyx.prompts.dr_prompts import ANSWER_PROMPT_WO_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import DECISION_PROMPT_W_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import DECISION_PROMPT_WO_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import DEFAULT_DR_SYSTEM_PROMPT
|
||||
from onyx.prompts.dr_prompts import EVAL_SYSTEM_PROMPT_W_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import EVAL_SYSTEM_PROMPT_WO_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import REPEAT_PROMPT
|
||||
from onyx.prompts.dr_prompts import TOOL_DESCRIPTION
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
@@ -65,13 +65,13 @@ from onyx.server.query_and_chat.streaming_models import StreamingType
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
|
||||
KnowledgeGraphTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
WebSearchTool,
|
||||
)
|
||||
from onyx.utils.b64 import get_image_type
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -108,19 +108,24 @@ def _get_available_tools(
|
||||
|
||||
for tool in graph_config.tooling.tools:
|
||||
|
||||
if not tool.is_available(db_session):
|
||||
logger.info(f"Tool {tool.name} is not available, skipping")
|
||||
continue
|
||||
|
||||
tool_db_info = tool_dict.get(tool.id)
|
||||
if tool_db_info:
|
||||
incode_tool_id = tool_db_info.in_code_tool_id
|
||||
else:
|
||||
raise ValueError(f"Tool {tool.name} is not found in the database")
|
||||
|
||||
if isinstance(tool, InternetSearchTool):
|
||||
llm_path = DRPath.INTERNET_SEARCH.value
|
||||
path = DRPath.INTERNET_SEARCH
|
||||
if isinstance(tool, WebSearchTool):
|
||||
llm_path = DRPath.WEB_SEARCH.value
|
||||
path = DRPath.WEB_SEARCH
|
||||
elif isinstance(tool, SearchTool):
|
||||
llm_path = DRPath.INTERNAL_SEARCH.value
|
||||
path = DRPath.INTERNAL_SEARCH
|
||||
elif isinstance(tool, KnowledgeGraphTool) and include_kg:
|
||||
# TODO (chris): move this into the `is_available` check
|
||||
if len(active_source_types) == 0:
|
||||
logger.error(
|
||||
"No active source types found, skipping Knowledge Graph tool"
|
||||
@@ -310,7 +315,7 @@ _ARTIFICIAL_ALL_ENCOMPASSING_TOOL = {
|
||||
"name": "run_any_knowledge_retrieval_and_any_action_tool",
|
||||
"description": "Use this tool to get ANY external information \
|
||||
that is relevant to the question, or for any action to be taken, including image generation. In fact, \
|
||||
ANY tool mentioned can be accessed through this generic tool.",
|
||||
ANY tool mentioned can be accessed through this generic tool. If in doubt, use this tool.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -399,15 +404,14 @@ def clarifier(
|
||||
else:
|
||||
active_source_type_descriptions_str = ""
|
||||
|
||||
if graph_config.inputs.persona and len(graph_config.inputs.persona.prompts) > 0:
|
||||
if graph_config.inputs.persona:
|
||||
assistant_system_prompt = (
|
||||
graph_config.inputs.persona.prompts[0].system_prompt
|
||||
or DEFAULT_DR_SYSTEM_PROMPT
|
||||
graph_config.inputs.persona.system_prompt or DEFAULT_DR_SYSTEM_PROMPT
|
||||
) + "\n\n"
|
||||
if graph_config.inputs.persona.prompts[0].task_prompt:
|
||||
if graph_config.inputs.persona.task_prompt:
|
||||
assistant_task_prompt = (
|
||||
"\n\nHere are more specifications from the user:\n\n"
|
||||
+ graph_config.inputs.persona.prompts[0].task_prompt
|
||||
+ (graph_config.inputs.persona.task_prompt)
|
||||
)
|
||||
else:
|
||||
assistant_task_prompt = ""
|
||||
@@ -459,8 +463,9 @@ def clarifier(
|
||||
llm_decision = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
EVAL_SYSTEM_PROMPT_WO_TOOL_CALLING,
|
||||
assistant_system_prompt,
|
||||
decision_prompt,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
schema=DecisionResponse,
|
||||
)
|
||||
@@ -488,12 +493,13 @@ def clarifier(
|
||||
)
|
||||
|
||||
answer_tokens, _, _ = run_with_timeout(
|
||||
80,
|
||||
TF_DR_TIMEOUT_LONG,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt,
|
||||
answer_prompt + assistant_task_prompt,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
event_name="basic_response",
|
||||
writer=writer,
|
||||
@@ -501,7 +507,7 @@ def clarifier(
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=60,
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
ind=current_step_nr,
|
||||
context_docs=None,
|
||||
replace_citations=True,
|
||||
@@ -529,7 +535,7 @@ def clarifier(
|
||||
update_db_session_with_messages(
|
||||
db_session=db_session,
|
||||
chat_message_id=message_id,
|
||||
chat_session_id=str(graph_config.persistence.chat_session_id),
|
||||
chat_session_id=graph_config.persistence.chat_session_id,
|
||||
is_agentic=graph_config.behavior.use_agentic_search,
|
||||
message=answer_str,
|
||||
update_parent_message=True,
|
||||
@@ -558,7 +564,7 @@ def clarifier(
|
||||
|
||||
stream = graph_config.tooling.primary_llm.stream(
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt + EVAL_SYSTEM_PROMPT_W_TOOL_CALLING,
|
||||
assistant_system_prompt,
|
||||
decision_prompt + assistant_task_prompt,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
@@ -586,11 +592,12 @@ def clarifier(
|
||||
update_db_session_with_messages(
|
||||
db_session=db_session,
|
||||
chat_message_id=message_id,
|
||||
chat_session_id=str(graph_config.persistence.chat_session_id),
|
||||
chat_session_id=graph_config.persistence.chat_session_id,
|
||||
is_agentic=graph_config.behavior.use_agentic_search,
|
||||
message=full_answer,
|
||||
update_parent_message=True,
|
||||
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
|
||||
token_count=len(llm_tokenizer.encode(full_answer or "")),
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
@@ -611,7 +618,7 @@ def clarifier(
|
||||
|
||||
clarification = None
|
||||
|
||||
if research_type != ResearchType.THOUGHTFUL:
|
||||
if research_type == ResearchType.DEEP:
|
||||
result = _get_existing_clarification_request(graph_config)
|
||||
if result is not None:
|
||||
clarification, original_question, chat_history_string = result
|
||||
@@ -641,10 +648,12 @@ def clarifier(
|
||||
clarification_response = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt, clarification_prompt
|
||||
assistant_system_prompt,
|
||||
clarification_prompt,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
schema=ClarificationGenerationResponse,
|
||||
timeout_override=25,
|
||||
timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
# max_tokens=1500,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -673,7 +682,7 @@ def clarifier(
|
||||
)
|
||||
|
||||
_, _, _ = run_with_timeout(
|
||||
80,
|
||||
TF_DR_TIMEOUT_LONG,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=repeat_prompt,
|
||||
@@ -682,7 +691,7 @@ def clarifier(
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=60,
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
answer_piece=StreamingType.MESSAGE_DELTA.value,
|
||||
ind=current_step_nr,
|
||||
# max_tokens=None,
|
||||
@@ -706,7 +715,7 @@ def clarifier(
|
||||
update_db_session_with_messages(
|
||||
db_session=db_session,
|
||||
chat_message_id=message_id,
|
||||
chat_session_id=str(graph_config.persistence.chat_session_id),
|
||||
chat_session_id=graph_config.persistence.chat_session_id,
|
||||
is_agentic=graph_config.behavior.use_agentic_search,
|
||||
message=clarification_response.clarification_question,
|
||||
update_parent_message=True,
|
||||
@@ -734,7 +743,7 @@ def clarifier(
|
||||
update_db_session_with_messages(
|
||||
db_session=db_session,
|
||||
chat_message_id=message_id,
|
||||
chat_session_id=str(graph_config.persistence.chat_session_id),
|
||||
chat_session_id=graph_config.persistence.chat_session_id,
|
||||
is_agentic=graph_config.behavior.use_agentic_search,
|
||||
message=clarification.clarification_question,
|
||||
update_parent_message=True,
|
||||
|
||||
@@ -30,6 +30,8 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import run_with_timeout
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.kg.utils.extraction_utils import get_entity_types_str
|
||||
from onyx.kg.utils.extraction_utils import get_relationship_types_str
|
||||
from onyx.prompts.dr_prompts import DEFAULLT_DECISION_PROMPT
|
||||
@@ -139,6 +141,7 @@ def orchestrator(
|
||||
available_tools = state.available_tools or {}
|
||||
|
||||
uploaded_context = state.uploaded_test_context or ""
|
||||
uploaded_image_context = state.uploaded_image_context or []
|
||||
|
||||
questions = [
|
||||
f"{iteration_response.tool}: {iteration_response.question}"
|
||||
@@ -170,11 +173,39 @@ def orchestrator(
|
||||
reasoning_result = "(No reasoning result provided yet.)"
|
||||
tool_calls_string = "(No tool calls provided yet.)"
|
||||
|
||||
if research_type == ResearchType.THOUGHTFUL:
|
||||
if iteration_nr == 1:
|
||||
remaining_time_budget = DR_TIME_BUDGET_BY_TYPE[ResearchType.THOUGHTFUL]
|
||||
if research_type not in ResearchType:
|
||||
raise ValueError(f"Invalid research type: {research_type}")
|
||||
|
||||
elif iteration_nr > 1:
|
||||
if research_type in [ResearchType.THOUGHTFUL, ResearchType.FAST]:
|
||||
if iteration_nr == 1:
|
||||
remaining_time_budget = DR_TIME_BUDGET_BY_TYPE[research_type]
|
||||
|
||||
elif remaining_time_budget <= 0:
|
||||
return OrchestrationUpdate(
|
||||
tools_used=[DRPath.CLOSER.value],
|
||||
current_step_nr=current_step_nr,
|
||||
query_list=[],
|
||||
iteration_nr=iteration_nr,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="orchestrator",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
plan_of_record=plan_of_record,
|
||||
remaining_time_budget=remaining_time_budget,
|
||||
iteration_instructions=[
|
||||
IterationInstructions(
|
||||
iteration_nr=iteration_nr,
|
||||
plan=None,
|
||||
reasoning="Time to wrap up.",
|
||||
purpose="",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
elif iteration_nr > 1 and remaining_time_budget > 0:
|
||||
# for each iteration past the first one, we need to see whether we
|
||||
# have enough information to answer the question.
|
||||
# if we do, we can stop the iteration and return the answer.
|
||||
@@ -200,18 +231,20 @@ def orchestrator(
|
||||
reasoning_tokens: list[str] = [""]
|
||||
|
||||
reasoning_tokens, _, _ = run_with_timeout(
|
||||
80,
|
||||
TF_DR_TIMEOUT_LONG,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
decision_system_prompt, reasoning_prompt
|
||||
decision_system_prompt,
|
||||
reasoning_prompt,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
event_name="basic_response",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=60,
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
answer_piece=StreamingType.REASONING_DELTA.value,
|
||||
ind=current_step_nr,
|
||||
# max_tokens=None,
|
||||
@@ -295,9 +328,10 @@ def orchestrator(
|
||||
prompt=create_question_prompt(
|
||||
decision_system_prompt,
|
||||
decision_prompt,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
schema=OrchestratorDecisonsNoPlan,
|
||||
timeout_override=35,
|
||||
timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
# max_tokens=2500,
|
||||
)
|
||||
next_step = orchestrator_action.next_step
|
||||
@@ -320,7 +354,7 @@ def orchestrator(
|
||||
reasoning_result = "Time to wrap up."
|
||||
next_tool_name = DRPath.CLOSER.value
|
||||
|
||||
else:
|
||||
elif research_type == ResearchType.DEEP:
|
||||
if iteration_nr == 1 and not plan_of_record:
|
||||
# by default, we start a new iteration, but if there is a feedback request,
|
||||
# we start a new iteration 0 again (set a bit later)
|
||||
@@ -346,9 +380,10 @@ def orchestrator(
|
||||
prompt=create_question_prompt(
|
||||
decision_system_prompt,
|
||||
plan_generation_prompt,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
schema=OrchestrationPlan,
|
||||
timeout_override=25,
|
||||
timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
# max_tokens=3000,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -368,7 +403,7 @@ def orchestrator(
|
||||
)
|
||||
|
||||
_, _, _ = run_with_timeout(
|
||||
80,
|
||||
TF_DR_TIMEOUT_LONG,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=repeat_plan_prompt,
|
||||
@@ -377,7 +412,7 @@ def orchestrator(
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=60,
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
answer_piece=StreamingType.REASONING_DELTA.value,
|
||||
ind=current_step_nr,
|
||||
),
|
||||
@@ -424,9 +459,10 @@ def orchestrator(
|
||||
prompt=create_question_prompt(
|
||||
decision_system_prompt,
|
||||
decision_prompt,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
schema=OrchestratorDecisonsNoPlan,
|
||||
timeout_override=15,
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
# max_tokens=1500,
|
||||
)
|
||||
next_step = orchestrator_action.next_step
|
||||
@@ -460,7 +496,7 @@ def orchestrator(
|
||||
)
|
||||
|
||||
_, _, _ = run_with_timeout(
|
||||
80,
|
||||
TF_DR_TIMEOUT_LONG,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=repeat_reasoning_prompt,
|
||||
@@ -469,7 +505,7 @@ def orchestrator(
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=60,
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
answer_piece=StreamingType.REASONING_DELTA.value,
|
||||
ind=current_step_nr,
|
||||
# max_tokens=None,
|
||||
@@ -484,6 +520,9 @@ def orchestrator(
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Research type {research_type} is not implemented.")
|
||||
|
||||
base_next_step_purpose_prompt = get_dr_prompt_orchestration_templates(
|
||||
DRPromptPurpose.NEXT_STEP_PURPOSE,
|
||||
ResearchType.DEEP,
|
||||
@@ -498,48 +537,55 @@ def orchestrator(
|
||||
)
|
||||
|
||||
purpose_tokens: list[str] = [""]
|
||||
purpose = ""
|
||||
|
||||
try:
|
||||
if research_type in [ResearchType.THOUGHTFUL, ResearchType.DEEP]:
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
ReasoningStart(),
|
||||
writer,
|
||||
)
|
||||
try:
|
||||
|
||||
purpose_tokens, _, _ = run_with_timeout(
|
||||
80,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
decision_system_prompt,
|
||||
orchestration_next_step_purpose_prompt,
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
ReasoningStart(),
|
||||
writer,
|
||||
)
|
||||
|
||||
purpose_tokens, _, _ = run_with_timeout(
|
||||
TF_DR_TIMEOUT_LONG,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
decision_system_prompt,
|
||||
orchestration_next_step_purpose_prompt,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
event_name="basic_response",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
answer_piece=StreamingType.REASONING_DELTA.value,
|
||||
ind=current_step_nr,
|
||||
# max_tokens=None,
|
||||
),
|
||||
event_name="basic_response",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=60,
|
||||
answer_piece=StreamingType.REASONING_DELTA.value,
|
||||
ind=current_step_nr,
|
||||
# max_tokens=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
|
||||
current_step_nr += 1
|
||||
current_step_nr += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in orchestration next step purpose: {e}")
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error("Error in orchestration next step purpose.")
|
||||
raise e
|
||||
|
||||
purpose = cast(str, merge_content(*purpose_tokens))
|
||||
purpose = cast(str, merge_content(*purpose_tokens))
|
||||
|
||||
elif research_type == ResearchType.FAST:
|
||||
purpose = f"Answering the question using the {next_tool_name}"
|
||||
|
||||
if not next_tool_name:
|
||||
raise ValueError("The next step has not been defined. This should not happen.")
|
||||
|
||||
@@ -24,7 +24,6 @@ from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_sear
|
||||
from onyx.agents.agent_search.dr.utils import get_chat_history_string
|
||||
from onyx.agents.agent_search.dr.utils import get_prompt_question
|
||||
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
|
||||
from onyx.agents.agent_search.dr.utils import update_db_session_with_messages
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
@@ -34,12 +33,15 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.chat.chat_utils import llm_doc_from_inference_section
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.chat import create_search_doc_from_inference_section
|
||||
from onyx.db.chat import update_db_session_with_messages
|
||||
from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ResearchAgentIteration
|
||||
from onyx.db.models import ResearchAgentIterationSubStep
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.llm.utils import check_number_of_tokens
|
||||
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
|
||||
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
|
||||
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
|
||||
@@ -151,7 +153,7 @@ def save_iteration(
|
||||
update_db_session_with_messages(
|
||||
db_session=db_session,
|
||||
chat_message_id=message_id,
|
||||
chat_session_id=str(graph_config.persistence.chat_session_id),
|
||||
chat_session_id=graph_config.persistence.chat_session_id,
|
||||
is_agentic=graph_config.behavior.use_agentic_search,
|
||||
message=final_answer,
|
||||
citations=citation_dict,
|
||||
@@ -239,14 +241,17 @@ def closer(
|
||||
or "(No chat history yet available)"
|
||||
)
|
||||
|
||||
aggregated_context = aggregate_context(
|
||||
aggregated_context_w_docs = aggregate_context(
|
||||
state.iteration_responses, include_documents=True
|
||||
)
|
||||
|
||||
iteration_responses_string = aggregated_context.context
|
||||
all_cited_documents = aggregated_context.cited_documents
|
||||
aggregated_context_wo_docs = aggregate_context(
|
||||
state.iteration_responses, include_documents=False
|
||||
)
|
||||
|
||||
aggregated_context.is_internet_marker_dict
|
||||
iteration_responses_w_docs_string = aggregated_context_w_docs.context
|
||||
iteration_responses_wo_docs_string = aggregated_context_wo_docs.context
|
||||
all_cited_documents = aggregated_context_w_docs.cited_documents
|
||||
|
||||
num_closer_suggestions = state.num_closer_suggestions
|
||||
|
||||
@@ -256,7 +261,7 @@ def closer(
|
||||
):
|
||||
test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
|
||||
base_question=prompt_question,
|
||||
questions_answers_claims=iteration_responses_string,
|
||||
questions_answers_claims=iteration_responses_wo_docs_string,
|
||||
chat_history_string=chat_history_string,
|
||||
high_level_plan=(
|
||||
state.plan_of_record.plan
|
||||
@@ -272,7 +277,7 @@ def closer(
|
||||
test_info_complete_prompt + (assistant_task_prompt or ""),
|
||||
),
|
||||
schema=TestInfoCompleteResponse,
|
||||
timeout_override=40,
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
# max_tokens=1000,
|
||||
)
|
||||
|
||||
@@ -307,10 +312,35 @@ def closer(
|
||||
writer,
|
||||
)
|
||||
|
||||
if research_type == ResearchType.THOUGHTFUL:
|
||||
if research_type in [ResearchType.THOUGHTFUL, ResearchType.FAST]:
|
||||
final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
|
||||
else:
|
||||
elif research_type == ResearchType.DEEP:
|
||||
final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
|
||||
else:
|
||||
raise ValueError(f"Invalid research type: {research_type}")
|
||||
|
||||
estimated_final_answer_prompt_tokens = check_number_of_tokens(
|
||||
final_answer_base_prompt.build(
|
||||
base_question=prompt_question,
|
||||
iteration_responses_string=iteration_responses_w_docs_string,
|
||||
chat_history_string=chat_history_string,
|
||||
uploaded_context=uploaded_context,
|
||||
)
|
||||
)
|
||||
|
||||
# for DR, rely only on sub-answers and claims to save tokens if context is too long
|
||||
# TODO: consider compression step for Thoughtful mode if context is too long.
|
||||
# Should generally not be the case though.
|
||||
|
||||
max_allowed_input_tokens = graph_config.tooling.primary_llm.config.max_input_tokens
|
||||
|
||||
if (
|
||||
estimated_final_answer_prompt_tokens > 0.8 * max_allowed_input_tokens
|
||||
and research_type == ResearchType.DEEP
|
||||
):
|
||||
iteration_responses_string = iteration_responses_wo_docs_string
|
||||
else:
|
||||
iteration_responses_string = iteration_responses_w_docs_string
|
||||
|
||||
final_answer_prompt = final_answer_base_prompt.build(
|
||||
base_question=prompt_question,
|
||||
@@ -326,7 +356,7 @@ def closer(
|
||||
|
||||
try:
|
||||
streamed_output, _, citation_infos = run_with_timeout(
|
||||
240,
|
||||
int(3 * TF_DR_TIMEOUT_LONG),
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
@@ -338,7 +368,7 @@ def closer(
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=60,
|
||||
timeout_override=int(2 * TF_DR_TIMEOUT_LONG),
|
||||
answer_piece=StreamingType.MESSAGE_DELTA.value,
|
||||
ind=current_step_nr,
|
||||
context_docs=all_context_llmdocs,
|
||||
|
||||
@@ -16,7 +16,6 @@ from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
|
||||
from onyx.agents.agent_search.dr.utils import aggregate_context
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
|
||||
from onyx.agents.agent_search.dr.utils import update_db_session_with_messages
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
@@ -24,10 +23,12 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.chat import create_search_doc_from_inference_section
|
||||
from onyx.db.chat import update_db_session_with_messages
|
||||
from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ResearchAgentIteration
|
||||
from onyx.db.models import ResearchAgentIterationSubStep
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -93,6 +94,7 @@ def save_iteration(
|
||||
final_answer: str,
|
||||
all_cited_documents: list[InferenceSection],
|
||||
is_internet_marker_dict: dict[str, bool],
|
||||
num_tokens: int,
|
||||
) -> None:
|
||||
db_session = graph_config.persistence.db_session
|
||||
message_id = graph_config.persistence.message_id
|
||||
@@ -132,7 +134,7 @@ def save_iteration(
|
||||
update_db_session_with_messages(
|
||||
db_session=db_session,
|
||||
chat_message_id=message_id,
|
||||
chat_session_id=str(graph_config.persistence.chat_session_id),
|
||||
chat_session_id=graph_config.persistence.chat_session_id,
|
||||
is_agentic=graph_config.behavior.use_agentic_search,
|
||||
message=final_answer,
|
||||
citations=citation_dict,
|
||||
@@ -141,6 +143,7 @@ def save_iteration(
|
||||
final_documents=search_docs,
|
||||
update_parent_message=True,
|
||||
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
|
||||
token_count=num_tokens,
|
||||
)
|
||||
|
||||
for iteration_preparation in state.iteration_instructions:
|
||||
@@ -211,6 +214,14 @@ def logging(
|
||||
is_internet_marker_dict = aggregated_context.is_internet_marker_dict
|
||||
|
||||
final_answer = state.final_answer or ""
|
||||
llm_provider = graph_config.tooling.primary_llm.config.model_provider
|
||||
llm_model_name = graph_config.tooling.primary_llm.config.model_name
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm_model_name,
|
||||
provider_type=llm_provider,
|
||||
)
|
||||
num_tokens = len(llm_tokenizer.encode(final_answer or ""))
|
||||
|
||||
write_custom_event(current_step_nr, OverallStop(), writer)
|
||||
|
||||
@@ -222,6 +233,7 @@ def logging(
|
||||
final_answer,
|
||||
all_cited_documents,
|
||||
is_internet_marker_dict,
|
||||
num_tokens,
|
||||
)
|
||||
|
||||
return LoggerUpdate(
|
||||
|
||||
@@ -22,6 +22,8 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.connector import DocumentSource
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
@@ -94,7 +96,7 @@ def basic_search(
|
||||
assistant_system_prompt, base_search_processing_prompt
|
||||
),
|
||||
schema=BaseSearchProcessingResponse,
|
||||
timeout_override=15,
|
||||
timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
# max_tokens=100,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -150,6 +152,7 @@ def basic_search(
|
||||
alternate_db_session=search_db_session,
|
||||
retrieved_sections_callback=callback_container.append,
|
||||
skip_query_analysis=True,
|
||||
original_query=rewritten_query,
|
||||
),
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
@@ -203,7 +206,7 @@ def basic_search(
|
||||
assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
|
||||
),
|
||||
schema=SearchAnswer,
|
||||
timeout_override=40,
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
# max_tokens=1500,
|
||||
)
|
||||
|
||||
@@ -224,7 +227,9 @@ def basic_search(
|
||||
claims,
|
||||
) = extract_document_citations(answer_string, claims)
|
||||
|
||||
if citation_numbers and max(citation_numbers) > len(retrieved_docs):
|
||||
if citation_numbers and (
|
||||
(max(citation_numbers) > len(retrieved_docs)) or min(citation_numbers) < 1
|
||||
):
|
||||
raise ValueError("Citation numbers are out of range for retrieved docs.")
|
||||
|
||||
cited_documents = {
|
||||
|
||||
@@ -24,7 +24,7 @@ logger = setup_logger()
|
||||
|
||||
def dr_basic_search_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for Internet Search Sub-Agent
|
||||
LangGraph graph builder for Web Search Sub-Agent
|
||||
"""
|
||||
|
||||
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
|
||||
|
||||
@@ -13,6 +13,8 @@ from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
|
||||
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
@@ -68,7 +70,7 @@ def custom_tool_act(
|
||||
tool_use_prompt,
|
||||
tools=[custom_tool.tool_definition()],
|
||||
tool_choice="required",
|
||||
timeout_override=40,
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
)
|
||||
|
||||
# make sure we got a tool call
|
||||
@@ -124,7 +126,7 @@ def custom_tool_act(
|
||||
)
|
||||
answer_string = str(
|
||||
graph_config.tooling.primary_llm.invoke(
|
||||
tool_summary_prompt, timeout_override=40
|
||||
tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
|
||||
).content
|
||||
).strip()
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
|
||||
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
|
||||
from onyx.prompts.dr_prompts import OKTA_TOOL_USE_SPECIAL_PROMPT
|
||||
@@ -68,7 +69,7 @@ def generic_internal_tool_act(
|
||||
tool_use_prompt,
|
||||
tools=[generic_internal_tool.tool_definition()],
|
||||
tool_choice="required",
|
||||
timeout_override=40,
|
||||
timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
)
|
||||
|
||||
# make sure we got a tool call
|
||||
@@ -113,7 +114,7 @@ def generic_internal_tool_act(
|
||||
)
|
||||
answer_string = str(
|
||||
graph_config.tooling.primary_llm.invoke(
|
||||
tool_summary_prompt, timeout_override=40
|
||||
tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
|
||||
).content
|
||||
).strip()
|
||||
|
||||
|
||||
@@ -12,8 +12,13 @@ from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.file_store.utils import build_frontend_file_url
|
||||
from onyx.file_store.utils import save_files
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
IMAGE_GENERATION_HEARTBEAT_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
IMAGE_GENERATION_RESPONSE_ID,
|
||||
)
|
||||
@@ -65,7 +70,14 @@ def image_generation(
|
||||
image_generation_responses: list[ImageGenerationResponse] = []
|
||||
|
||||
for tool_response in image_tool.run(prompt=branch_query):
|
||||
if tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
|
||||
# Stream heartbeat to frontend
|
||||
write_custom_event(
|
||||
state.current_step_nr,
|
||||
ImageGenerationToolHeartbeat(),
|
||||
writer,
|
||||
)
|
||||
elif tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
response = cast(list[ImageGenerationResponse], tool_response.response)
|
||||
image_generation_responses = response
|
||||
break
|
||||
|
||||
@@ -24,7 +24,7 @@ logger = setup_logger()
|
||||
|
||||
def dr_image_generation_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for Internet Search Sub-Agent
|
||||
LangGraph graph builder for Image Generation Sub-Agent
|
||||
"""
|
||||
|
||||
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
|
||||
|
||||
@@ -1,215 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import SearchAnswer
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.models import (
|
||||
InternetContent,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.providers import (
|
||||
get_default_provider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.agents.agent_search.dr.utils import extract_document_citations
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _truncate_search_result_content(content: str, max_chars: int = 10000) -> str:
|
||||
"""Truncate search result content to a maximum number of characters"""
|
||||
if len(content) <= max_chars:
|
||||
return content
|
||||
return content[:max_chars] + "..."
|
||||
|
||||
|
||||
def _dummy_inference_section_from_internet_search_result(
    result: InternetContent,
) -> InferenceSection:
    """Wrap fetched web content in a single-chunk ``InferenceSection``.

    The (truncated) page text is used for the chunk content, summary, chunk
    context and the section's combined content; the page link doubles as the
    semantic identifier and is embedded in the document id.
    """
    page_text = _truncate_search_result_content(result.full_content)
    page_title = result.title
    page_link = result.link

    placeholder_chunk = InferenceChunk(
        chunk_id=0,
        blurb=page_title,
        content=page_text,
        source_links={0: page_link},
        section_continuation=False,
        document_id="INTERNET_SEARCH_DOC_" + page_link,
        source_type=DocumentSource.WEB,
        semantic_identifier=page_link,
        title=page_title,
        boost=1,
        recency_bias=1.0,
        score=1.0,
        hidden=False,
        metadata={},
        match_highlights=[],
        doc_summary=page_text,
        chunk_context=page_text,
        updated_at=result.published_date,
        image_file_id=None,
    )

    # No surrounding chunks exist for a web page; the section is just the
    # placeholder chunk itself.
    return InferenceSection(
        center_chunk=placeholder_chunk,
        chunks=[],
        combined_content=page_text,
    )
|
||||
|
||||
|
||||
def _dummy_internet_content_from_url(url: str) -> InternetContent:
    """Build a placeholder ``InternetContent`` whose title, link and body are all the URL."""
    placeholder_fields = dict.fromkeys(("title", "link", "full_content"), url)
    return InternetContent(**placeholder_fields)
|
||||
|
||||
|
||||
def web_fetch(
    state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BranchUpdate:
    """
    LangGraph node to fetch content from URLs and process the results.

    Streams placeholder documents for the URLs about to be opened, fetches
    their content through the default provider and, for DEEP research, asks
    the primary LLM to answer the branch question from the fetched text.
    For any other research type no LLM call is made and the first 15 fetched
    documents are returned as the cited documents.

    Raises:
        ValueError: if ``available_tools`` or the persona is not set, no
            provider is configured, or a fetched document has an unexpected
            type.
    """

    node_start_time = datetime.now()
    iteration_nr = state.iteration_nr
    parallelization_nr = state.parallelization_nr
    current_step_nr = state.current_step_nr

    assistant_system_prompt = state.assistant_system_prompt
    assistant_task_prompt = state.assistant_task_prompt
    urls_to_open = state.urls_to_open

    # Stream URL-only placeholder documents before the actual fetch happens.
    dummy_docs = [_dummy_internet_content_from_url(url) for url in urls_to_open]
    dummy_docs_inference_sections = [
        _dummy_inference_section_from_internet_search_result(doc) for doc in dummy_docs
    ]
    write_custom_event(
        current_step_nr,
        SearchToolDelta(
            queries=[],
            documents=convert_inference_sections_to_search_docs(
                dummy_docs_inference_sections, is_internet=True
            ),
        ),
        writer,
    )

    if not state.available_tools:
        raise ValueError("available_tools is not set")
    is_tool_info = state.available_tools[state.tools_used[-1]]

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    base_question = graph_config.inputs.prompt_builder.raw_user_query
    research_type = graph_config.behavior.research_type

    if graph_config.inputs.persona is None:
        raise ValueError("persona is not set")

    provider = get_default_provider()
    if provider is None:
        raise ValueError("No internet search provider found")

    # Fetch content from URLs. Failures are best-effort: they are logged and
    # the node continues with whatever was retrieved (possibly nothing).
    retrieved_docs: list[InferenceSection] = []
    try:
        retrieved_docs = [
            _dummy_inference_section_from_internet_search_result(result)
            for result in provider.contents(urls_to_open)
        ]
    except Exception as e:
        logger.error(f"Error fetching URLs: {e}")

    if not retrieved_docs:
        logger.warning("No content retrieved from URLs")

    # Process documents and build context (documents are numbered from 1).
    document_texts_list = []
    for doc_num, retrieved_doc in enumerate(retrieved_docs):
        if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
            raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
        chunk_text = build_document_context(retrieved_doc, doc_num + 1)
        document_texts_list.append(chunk_text)

    document_texts = "\n\n".join(document_texts_list)

    if research_type == ResearchType.DEEP:
        search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
            search_query=state.branch_question or "",
            base_question=base_question,
            document_text=document_texts,
        )

        search_answer_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
            ),
            schema=SearchAnswer,
            timeout_override=40,
        )

        answer_string = search_answer_json.answer
        claims = search_answer_json.claims or []
        reasoning = search_answer_json.reasoning or ""

        (
            citation_numbers,
            answer_string,
            claims,
        ) = extract_document_citations(answer_string, claims)

        # Citation numbers come from LLM output. Keep only those that refer
        # to a fetched document: an out-of-range number would raise
        # IndexError, and 0 would silently pick the last document via
        # index -1.
        cited_documents = {
            citation_number: retrieved_docs[citation_number - 1]
            for citation_number in citation_numbers
            if 1 <= citation_number <= len(retrieved_docs)
        }

    else:
        answer_string = ""
        claims = []
        reasoning = ""
        cited_documents = {
            doc_num + 1: retrieved_doc
            for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
        }

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=is_tool_info.llm_path,
                tool_id=is_tool_info.tool_id,
                iteration_nr=iteration_nr,
                parallelization_nr=parallelization_nr,
                question=state.branch_question or "",
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=reasoning,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="fetching",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
@@ -1,60 +0,0 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan out one "search" branch per query, capped at ``MAX_DR_PARALLEL_SEARCH``."""
    capped_queries = state.query_list[:MAX_DR_PARALLEL_SEARCH]

    sends: list[Send | Hashable] = []
    for parallelization_nr, query in enumerate(capped_queries):
        branch_state = BranchInput(
            iteration_nr=state.iteration_nr,
            parallelization_nr=parallelization_nr,
            current_step_nr=state.current_step_nr,
            branch_question=query,
            context="",
            tools_used=state.tools_used,
            available_tools=state.available_tools,
            assistant_system_prompt=state.assistant_system_prompt,
            assistant_task_prompt=state.assistant_task_prompt,
        )
        sends.append(Send("search", branch_state))
    return sends
|
||||
|
||||
|
||||
def fetch_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan out "fetch" branches, each handling up to two of the URLs to open.

    Returns an empty list when there are no URLs to process.
    """
    pending_urls = state.urls_to_open

    # If no URLs to process, return empty list to go directly to reducer
    if not pending_urls:
        return []

    # Consecutive batches of two; an odd count leaves a final single-URL batch.
    url_batches = [
        list(pending_urls[start : start + 2])
        for start in range(0, len(pending_urls), 2)
    ]

    sends: list[Send | Hashable] = []
    for parallelization_nr, url_pair in enumerate(url_batches):
        fetch_state = BranchInput(
            iteration_nr=state.iteration_nr,
            parallelization_nr=parallelization_nr,
            urls_to_open=url_pair,
            tools_used=state.tools_used,
            available_tools=state.available_tools,
            assistant_system_prompt=state.assistant_system_prompt,
            assistant_task_prompt=state.assistant_task_prompt,
            current_step_nr=state.current_step_nr,
        )
        sends.append(Send("fetch", fetch_state))
    return sends
|
||||
@@ -1,63 +0,0 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_1_branch import (
|
||||
is_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_2_search import (
|
||||
web_search,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_3_fetch import (
|
||||
web_fetch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_4_reduce import (
|
||||
is_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.dr_is_conditional_edges import (
|
||||
fetch_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_is_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for Internet Search Sub-Agent

    Registers the branch, search, fetch and reducer nodes and wires them
    together; the returned graph is not compiled here.
    """

    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    ### Add nodes ###

    node_functions = {
        "branch": is_branch,
        "search": web_search,
        "fetch": web_fetch,
        "reducer": is_reducer,
    }
    for node_name, node_function in node_functions.items():
        graph.add_node(node_name, node_function)

    ### Add edges ###

    # Each entry is (source, target, router); a router marks a conditional edge.
    wiring = (
        (START, "branch", None),
        ("branch", None, branching_router),
        ("search", None, fetch_router),
        # Fallback edge from search to reducer when no URLs are found
        ("search", "reducer", None),
        ("fetch", "reducer", None),
        ("reducer", END, None),
    )
    for source, target, router in wiring:
        if router is None:
            graph.add_edge(start_key=source, end_key=target)
        else:
            graph.add_conditional_edges(source, router)

    return graph
|
||||
@@ -14,14 +14,12 @@ class SubAgentUpdate(LoggerUpdate):
|
||||
|
||||
class BranchUpdate(LoggerUpdate):
|
||||
branch_iteration_responses: Annotated[list[IterationAnswer], add] = []
|
||||
urls_to_open: Annotated[list[str], add] = []
|
||||
|
||||
|
||||
class SubAgentInput(LoggerUpdate):
|
||||
iteration_nr: int = 0
|
||||
current_step_nr: int = 1
|
||||
query_list: list[str] = []
|
||||
urls_to_open: Annotated[list[str], add] = []
|
||||
context: str | None = None
|
||||
active_source_types: list[DocumentSource] | None = None
|
||||
tools_used: Annotated[list[str], add] = []
|
||||
@@ -41,7 +39,7 @@ class SubAgentMainState(
|
||||
|
||||
class BranchInput(SubAgentInput):
|
||||
parallelization_nr: int = 0
|
||||
branch_question: str | None = None
|
||||
branch_question: str
|
||||
|
||||
|
||||
class CustomToolBranchInput(LoggerUpdate):
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from exa_py import Exa
|
||||
from exa_py.api import HighlightsContentsOptions
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.models import (
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetContent,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.models import (
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchProvider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.models import (
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
)
|
||||
from onyx.configs.chat_configs import EXA_API_KEY
|
||||
@@ -19,14 +19,14 @@ def is_branch(
|
||||
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> LoggerUpdate:
|
||||
"""
|
||||
LangGraph node to perform a internet search as part of the DR process.
|
||||
LangGraph node to perform a web search as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
current_step_nr = state.current_step_nr
|
||||
|
||||
logger.debug(f"Search start for Internet Search {iteration_nr} at {datetime.now()}")
|
||||
logger.debug(f"Search start for Web Search {iteration_nr} at {datetime.now()}")
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
@@ -5,16 +5,19 @@ from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
from langsmith import traceable
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import WebSearchAnswer
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.models import (
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.providers import (
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
|
||||
get_default_provider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
|
||||
InternetSearchInput,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
|
||||
InternetSearchUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
@@ -22,7 +25,8 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.prompts.dr_prompts import INTERNET_SEARCH_URL_SELECTION_PROMPT
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.prompts.dr_prompts import WEB_SEARCH_URL_SELECTION_PROMPT
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -30,15 +34,15 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def web_search(
|
||||
state: BranchInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> BranchUpdate:
|
||||
state: InternetSearchInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> InternetSearchUpdate:
|
||||
"""
|
||||
LangGraph node to perform internet search and decide which URLs to fetch.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
parallelization_nr = state.parallelization_nr
|
||||
current_step_nr = state.current_step_nr
|
||||
|
||||
if not current_step_nr:
|
||||
@@ -49,11 +53,7 @@ def web_search(
|
||||
|
||||
if not state.available_tools:
|
||||
raise ValueError("available_tools is not set")
|
||||
is_tool_info = state.available_tools[state.tools_used[-1]]
|
||||
|
||||
search_query = state.branch_question
|
||||
if not search_query:
|
||||
raise ValueError("search_query is not set")
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
@@ -97,7 +97,7 @@ def web_search(
|
||||
for i, result in enumerate(search_results)
|
||||
]
|
||||
)
|
||||
agent_decision_prompt = INTERNET_SEARCH_URL_SELECTION_PROMPT.build(
|
||||
agent_decision_prompt = WEB_SEARCH_URL_SELECTION_PROMPT.build(
|
||||
search_query=search_query,
|
||||
base_question=base_question,
|
||||
search_results_text=search_results_text,
|
||||
@@ -109,28 +109,15 @@ def web_search(
|
||||
agent_decision_prompt + (assistant_task_prompt or ""),
|
||||
),
|
||||
schema=WebSearchAnswer,
|
||||
timeout_override=30,
|
||||
timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
)
|
||||
urls_to_open = [
|
||||
search_results[i].link
|
||||
results_to_open = [
|
||||
(search_query, search_results[i])
|
||||
for i in agent_decision.urls_to_open_indices
|
||||
if i < len(search_results) and i >= 0
|
||||
]
|
||||
return BranchUpdate(
|
||||
branch_iteration_responses=[
|
||||
IterationAnswer(
|
||||
tool=is_tool_info.llm_path,
|
||||
tool_id=is_tool_info.tool_id,
|
||||
iteration_nr=iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
question=search_query,
|
||||
answer="",
|
||||
claims=[],
|
||||
cited_documents={},
|
||||
reasoning="",
|
||||
additional_data=None,
|
||||
)
|
||||
],
|
||||
return InternetSearchUpdate(
|
||||
results_to_open=results_to_open,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="internet_search",
|
||||
@@ -138,7 +125,4 @@ def web_search(
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
# TODO: Pass through IterationAnswer instead of BranchUpdate
|
||||
# There's some tricky langgraph magic needed to make this work
|
||||
urls_to_open=urls_to_open,
|
||||
)
|
||||
@@ -0,0 +1,52 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
|
||||
InternetSearchInput,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
|
||||
dummy_inference_section_from_internet_search_result,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
|
||||
|
||||
def dedup_urls(
    state: InternetSearchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> InternetSearchInput:
    """Group selected result links by query and stream the unique results.

    Every (query, result) pair in ``state.results_to_open`` contributes its
    link to that query's URL list. Only the first result seen for each link
    is kept for the streamed ``SearchToolDelta``.
    """
    urls_by_question: dict[str, list[str]] = defaultdict(list)
    first_result_by_link: dict[str, InternetSearchResult] = {}

    for question, search_result in state.results_to_open:
        link = search_result.link
        urls_by_question[question].append(link)
        # setdefault keeps the earliest result for a link and ignores repeats.
        first_result_by_link.setdefault(link, search_result)

    placeholder_sections = list(
        map(
            dummy_inference_section_from_internet_search_result,
            first_result_by_link.values(),
        )
    )
    streamed_delta = SearchToolDelta(
        queries=[],
        documents=convert_inference_sections_to_search_docs(
            placeholder_sections, is_internet=True
        ),
    )
    write_custom_event(state.current_step_nr, streamed_delta, writer)

    return InternetSearchInput(
        results_to_open=[],
        branch_questions_to_urls=urls_by_question,
    )
|
||||
@@ -0,0 +1,69 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
|
||||
get_default_provider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
|
||||
dummy_inference_section_from_internet_content,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def web_fetch(
    state: FetchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> FetchUpdate:
    """
    LangGraph node that fetches the content behind ``state.urls_to_open``.

    Fetch errors are logged and swallowed, in which case no documents are
    returned for this branch.

    Raises:
        ValueError: if ``available_tools`` or the persona is not set, or no
            web search provider is configured.
    """

    started_at = datetime.now()

    if not state.available_tools:
        raise ValueError("available_tools is not set")

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    persona = graph_config.inputs.persona
    if persona is None:
        raise ValueError("persona is not set")

    provider = get_default_provider()
    if provider is None:
        raise ValueError("No web search provider found")

    fetched_sections: list[InferenceSection] = []
    try:
        # Assigned only once every result has been converted, so a failure
        # part-way through leaves the list empty rather than partial.
        fetched_sections = list(
            map(
                dummy_inference_section_from_internet_content,
                provider.contents(state.urls_to_open),
            )
        )
    except Exception as e:
        logger.error(f"Error fetching URLs: {e}")

    if not fetched_sections:
        logger.warning("No content retrieved from URLs")

    timing_message = get_langgraph_node_log_string(
        graph_component="internet_search",
        node_name="fetching",
        node_start_time=started_at,
    )
    return FetchUpdate(
        raw_documents=fetched_sections,
        log_messages=[timing_message],
    )
|
||||
@@ -0,0 +1,19 @@
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
|
||||
InternetSearchInput,
|
||||
)
|
||||
|
||||
|
||||
def collect_raw_docs(
    state: FetchInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> InternetSearchInput:
    """Pass the fetched raw documents on as an ``InternetSearchInput`` update."""
    return InternetSearchInput(raw_documents=state.raw_documents)
|
||||
@@ -0,0 +1,127 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import SearchAnswer
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import SummarizeInput
|
||||
from onyx.agents.agent_search.dr.utils import extract_document_citations
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_summarize(
    state: SummarizeInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> BranchUpdate:
    """
    LangGraph node to summarize the fetched web documents for one branch
    question as part of the DR process.

    For DEEP research the primary LLM answers the branch question from the
    documents fetched for that question's URLs, and only documents it cites
    are returned. For any other research type no LLM call is made and every
    available document for the question is returned as cited.

    Raises:
        ValueError: if ``available_tools`` is not set.
    """

    node_start_time = datetime.now()

    # build branch iterations from fetch inputs
    # (raw documents are looked up by their semantic identifier)
    url_to_raw_document: dict[str, InferenceSection] = {}
    for raw_document in state.raw_documents:
        url_to_raw_document[raw_document.center_chunk.semantic_identifier] = (
            raw_document
        )
    urls = state.branch_questions_to_urls[state.branch_question]
    current_iteration = state.iteration_nr
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    research_type = graph_config.behavior.research_type
    if not state.available_tools:
        raise ValueError("available_tools is not set")
    is_tool_info = state.available_tools[state.tools_used[-1]]

    # A URL has no raw document when nothing was fetched for it (e.g. the
    # fetch failed). Skip such URLs instead of failing the whole branch with
    # a KeyError.
    missing_urls = [url for url in urls if url not in url_to_raw_document]
    if missing_urls:
        logger.warning(f"No fetched content for URLs: {missing_urls}")
    cited_raw_documents = [
        url_to_raw_document[url] for url in urls if url in url_to_raw_document
    ]

    if research_type == ResearchType.DEEP:
        document_texts = _create_document_texts(cited_raw_documents)
        search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
            search_query=state.branch_question,
            base_question=graph_config.inputs.prompt_builder.raw_user_query,
            document_text=document_texts,
        )
        assistant_system_prompt = state.assistant_system_prompt
        assistant_task_prompt = state.assistant_task_prompt
        search_answer_json = invoke_llm_json(
            llm=graph_config.tooling.primary_llm,
            prompt=create_question_prompt(
                assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
            ),
            schema=SearchAnswer,
            timeout_override=TF_DR_TIMEOUT_SHORT,
        )
        answer_string = search_answer_json.answer
        claims = search_answer_json.claims or []
        reasoning = search_answer_json.reasoning or ""
        (
            citation_numbers,
            answer_string,
            claims,
        ) = extract_document_citations(answer_string, claims)
        # Citation numbers come from LLM output; ignore any that do not map
        # to a document (out-of-range would raise IndexError, 0 would pick
        # the last document via index -1).
        cited_documents = {
            citation_number: cited_raw_documents[citation_number - 1]
            for citation_number in citation_numbers
            if 1 <= citation_number <= len(cited_raw_documents)
        }

    else:
        answer_string = ""
        reasoning = ""
        claims = []
        cited_documents = {
            doc_num + 1: retrieved_doc
            for doc_num, retrieved_doc in enumerate(cited_raw_documents)
        }

    return BranchUpdate(
        branch_iteration_responses=[
            IterationAnswer(
                tool=is_tool_info.llm_path,
                tool_id=is_tool_info.tool_id,
                iteration_nr=current_iteration,
                parallelization_nr=0,
                question=state.branch_question,
                answer=answer_string,
                claims=claims,
                cited_documents=cited_documents,
                reasoning=reasoning,
                additional_data=None,
            )
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="internet_search",
                node_name="summarizing",
                node_start_time=node_start_time,
            )
        ],
    )
|
||||
|
||||
|
||||
def _create_document_texts(raw_documents: list[InferenceSection]) -> str:
    """Render each document as numbered context (starting at 1), joined by blank lines.

    Raises:
        ValueError: if an entry is not an ``InferenceSection``.
    """

    def _render(position: int, document: InferenceSection) -> str:
        if not isinstance(document, InferenceSection):
            raise ValueError(f"Unexpected document type: {type(document)}")
        return build_document_context(document, position)

    return "\n\n".join(
        _render(position, document)
        for position, document in enumerate(raw_documents, start=1)
    )
|
||||
@@ -0,0 +1,79 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import FetchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
|
||||
InternetSearchInput,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import SummarizeInput
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
    """Fan out one "search" branch per query, capped at ``MAX_DR_PARALLEL_SEARCH``.

    Each branch receives its own single-item ``query_list`` and an empty
    ``results_to_open``.
    """
    capped_queries = state.query_list[:MAX_DR_PARALLEL_SEARCH]

    sends: list[Send | Hashable] = []
    for parallelization_nr, query in enumerate(capped_queries):
        search_state = InternetSearchInput(
            iteration_nr=state.iteration_nr,
            current_step_nr=state.current_step_nr,
            parallelization_nr=parallelization_nr,
            query_list=[query],
            branch_question=query,
            context="",
            tools_used=state.tools_used,
            available_tools=state.available_tools,
            assistant_system_prompt=state.assistant_system_prompt,
            assistant_task_prompt=state.assistant_task_prompt,
            results_to_open=[],
        )
        sends.append(Send("search", search_state))
    return sends
|
||||
|
||||
|
||||
def fetch_router(state: InternetSearchInput) -> list[Send | Hashable]:
    """Fan out one "fetch" branch per distinct URL across all branch questions."""
    branch_questions_to_urls = state.branch_questions_to_urls

    # A URL requested by several questions is fetched only once.
    distinct_urls = {
        url for urls in branch_questions_to_urls.values() for url in urls
    }

    sends: list[Send | Hashable] = []
    for url in distinct_urls:
        fetch_state = FetchInput(
            iteration_nr=state.iteration_nr,
            urls_to_open=[url],
            tools_used=state.tools_used,
            available_tools=state.available_tools,
            assistant_system_prompt=state.assistant_system_prompt,
            assistant_task_prompt=state.assistant_task_prompt,
            current_step_nr=state.current_step_nr,
            branch_questions_to_urls=branch_questions_to_urls,
            raw_documents=state.raw_documents,
        )
        sends.append(Send("fetch", fetch_state))
    return sends
|
||||
|
||||
|
||||
def summarize_router(state: InternetSearchInput) -> list[Send | Hashable]:
    """Fan out one "summarize" branch per branch question that has URLs recorded."""
    branch_questions_to_urls = state.branch_questions_to_urls

    sends: list[Send | Hashable] = []
    for branch_question in branch_questions_to_urls:
        summarize_state = SummarizeInput(
            iteration_nr=state.iteration_nr,
            raw_documents=state.raw_documents,
            branch_questions_to_urls=branch_questions_to_urls,
            branch_question=branch_question,
            tools_used=state.tools_used,
            available_tools=state.available_tools,
            assistant_system_prompt=state.assistant_system_prompt,
            assistant_task_prompt=state.assistant_task_prompt,
            current_step_nr=state.current_step_nr,
        )
        sends.append(Send("summarize", summarize_state))
    return sends
|
||||
@@ -0,0 +1,84 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_1_branch import (
|
||||
is_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_2_search import (
|
||||
web_search,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_3_dedup_urls import (
|
||||
dedup_urls,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_4_fetch import (
|
||||
web_fetch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_5_collect_raw_docs import (
|
||||
collect_raw_docs,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_6_summarize import (
|
||||
is_summarize,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_7_reduce import (
|
||||
is_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
|
||||
fetch_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_conditional_edges import (
|
||||
summarize_router,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_ws_graph_builder() -> StateGraph:
    """
    Build the (uncompiled) LangGraph graph for the web search sub-agent.

    Fixed edges: START -> branch, search -> dedup_urls,
    fetch -> collect_raw_docs, summarize -> reducer, reducer -> END.
    What follows "branch", "dedup_urls" and "collect_raw_docs" is decided at
    run time by branching_router, fetch_router and summarize_router.
    """

    graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)

    ### Add nodes ###

    # Registered in pipeline order.
    node_table = (
        ("branch", is_branch),
        ("search", web_search),
        ("dedup_urls", dedup_urls),
        ("fetch", web_fetch),
        ("collect_raw_docs", collect_raw_docs),
        ("summarize", is_summarize),
        ("reducer", is_reducer),
    )
    for node_name, node_action in node_table:
        graph.add_node(node_name, node_action)

    ### Add edges ###

    graph.add_edge(START, "branch")
    graph.add_conditional_edges("branch", branching_router)

    graph.add_edge("search", "dedup_urls")
    graph.add_conditional_edges("dedup_urls", fetch_router)

    graph.add_edge("fetch", "collect_raw_docs")
    graph.add_conditional_edges("collect_raw_docs", summarize_router)

    graph.add_edge("summarize", "reducer")
    graph.add_edge("reducer", END)

    return graph
|
||||
@@ -1,7 +1,7 @@
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.clients.exa_client import (
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
|
||||
ExaClient,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.internet_search.models import (
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchProvider,
|
||||
)
|
||||
from onyx.configs.chat_configs import EXA_API_KEY
|
||||
@@ -0,0 +1,37 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
# State consumed by the web search sub-agent nodes.
# Fields annotated with `add` carry list concatenation as their reducer;
# fields annotated with `lambda x, y: y` carry a reducer that keeps the second
# (newer) value.
class InternetSearchInput(SubAgentInput):
    # (str, search result) pairs accumulated across updates
    results_to_open: Annotated[list[tuple[str, InternetSearchResult]], add] = []
    parallelization_nr: int = 0
    # newest value replaces the previous one
    branch_question: Annotated[str, lambda x, y: y] = ""
    # branch question -> list of URLs; newest mapping replaces the previous one
    branch_questions_to_urls: Annotated[dict[str, list[str]], lambda x, y: y] = {}
    # document sections accumulated across updates
    raw_documents: Annotated[list[InferenceSection], add] = []
|
||||
|
||||
|
||||
# Update model carrying newly found search results.
class InternetSearchUpdate(LoggerUpdate):
    # (str, search result) pairs; `add` concatenates them onto the existing list
    results_to_open: Annotated[list[tuple[str, InternetSearchResult]], add] = []
|
||||
|
||||
|
||||
# Input model for the fetch step.
class FetchInput(SubAgentInput):
    # URLs to open; `add` concatenates lists
    urls_to_open: Annotated[list[str], add] = []
    # branch question -> list of URLs (required, no default)
    branch_questions_to_urls: dict[str, list[str]]
    # document sections; `add` concatenates lists
    raw_documents: Annotated[list[InferenceSection], add] = []
|
||||
|
||||
|
||||
# Update model carrying newly fetched document sections.
class FetchUpdate(LoggerUpdate):
    # `add` concatenates these onto the existing list
    raw_documents: Annotated[list[InferenceSection], add] = []
|
||||
|
||||
|
||||
# Input model for the summarize step (one instance per branch question).
class SummarizeInput(SubAgentInput):
    # document sections; `add` concatenates lists
    raw_documents: Annotated[list[InferenceSection], add] = []
    # branch question -> list of URLs (required, no default)
    branch_questions_to_urls: dict[str, list[str]]
    # the branch question this instance is for (required, no default)
    branch_question: str
|
||||
@@ -0,0 +1,77 @@
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetContent,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
)
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
def truncate_search_result_content(content: str, max_chars: int = 10000) -> str:
    """Truncate search result content to a maximum number of characters.

    Content that already fits within `max_chars` is returned unchanged.
    Longer content is cut to its first `max_chars` characters and "..." is
    appended as a truncation marker, so a truncated result is three
    characters longer than `max_chars`.

    A negative `max_chars` is treated as 0. Without the clamp, the slice
    `content[:max_chars]` would count from the end of the string and return
    almost the entire content instead of truncating it.
    """
    limit = max(max_chars, 0)
    if len(content) <= limit:
        return content
    return content[:limit] + "..."
|
||||
|
||||
|
||||
def dummy_inference_section_from_internet_content(
    result: InternetContent,
) -> InferenceSection:
    """Wrap fetched web content in a single-chunk `InferenceSection`.

    The page's full content is truncated with
    `truncate_search_result_content` and that same text is used for the
    chunk's content, doc summary and chunk context, and for the section's
    combined content. The remaining chunk fields are fixed placeholder
    values; the document id is "INTERNET_SEARCH_DOC_" followed by the link.
    """
    text = truncate_search_result_content(result.full_content)
    url = result.link
    page_title = result.title

    chunk = InferenceChunk(
        chunk_id=0,
        blurb=page_title,
        content=text,
        source_links={0: url},
        section_continuation=False,
        document_id="INTERNET_SEARCH_DOC_" + url,
        source_type=DocumentSource.WEB,
        semantic_identifier=url,
        title=page_title,
        boost=1,
        recency_bias=1.0,
        score=1.0,
        hidden=False,
        metadata={},
        match_highlights=[],
        doc_summary=text,
        chunk_context=text,
        updated_at=result.published_date,
        image_file_id=None,
    )

    return InferenceSection(
        center_chunk=chunk,
        chunks=[],
        combined_content=text,
    )
|
||||
|
||||
|
||||
def dummy_inference_section_from_internet_search_result(
    result: InternetSearchResult,
) -> InferenceSection:
    """Wrap a search hit (no page content) in a single-chunk `InferenceSection`.

    Only the title, link and published date come from `result`; every text
    field (content, doc summary, chunk context, combined content) is the
    empty string. The document id is "INTERNET_SEARCH_DOC_" followed by the
    link.
    """
    url = result.link
    hit_title = result.title

    chunk = InferenceChunk(
        chunk_id=0,
        blurb=hit_title,
        content="",
        source_links={0: url},
        section_continuation=False,
        document_id="INTERNET_SEARCH_DOC_" + url,
        source_type=DocumentSource.WEB,
        semantic_identifier=url,
        title=hit_title,
        boost=1,
        recency_bias=1.0,
        score=1.0,
        hidden=False,
        metadata={},
        match_highlights=[],
        doc_summary="",
        chunk_context="",
        updated_at=result.published_date,
        image_file_id=None,
    )

    return InferenceSection(
        center_chunk=chunk,
        chunks=[],
        combined_content="",
    )
|
||||
@@ -1,11 +1,9 @@
|
||||
import copy
|
||||
import re
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import AggregatedDRContext
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
|
||||
@@ -13,14 +11,11 @@ from onyx.agents.agent_search.kb_search.graph_utils import build_document_contex
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.utils import chunks_or_sections_to_search_docs
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import SearchDoc
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
WebSearchTool,
|
||||
)
|
||||
|
||||
|
||||
@@ -78,7 +73,7 @@ def aggregate_context(
|
||||
):
|
||||
|
||||
iteration_tool = iteration_response.tool
|
||||
is_internet = iteration_tool == InternetSearchTool._NAME
|
||||
is_internet = iteration_tool == WebSearchTool._NAME
|
||||
|
||||
for cited_doc in iteration_response.cited_documents.values():
|
||||
unrolled_inference_sections.append(cited_doc)
|
||||
@@ -185,11 +180,40 @@ def get_chat_history_string(chat_history: list[BaseMessage], max_messages: int)
|
||||
Get the chat history (up to max_messages) as a string.
|
||||
"""
|
||||
# get past max_messages USER, ASSISTANT message pairs
|
||||
|
||||
past_messages = chat_history[-max_messages * 2 :]
|
||||
return ("...\n" if len(chat_history) > len(past_messages) else "") + "\n".join(
|
||||
filtered_past_messages = copy.deepcopy(past_messages)
|
||||
|
||||
for past_message_number, past_message in enumerate(past_messages):
|
||||
|
||||
if isinstance(past_message.content, list):
|
||||
removal_indices = []
|
||||
for content_piece_number, content_piece in enumerate(past_message.content):
|
||||
if (
|
||||
isinstance(content_piece, dict)
|
||||
and content_piece.get("type") != "text"
|
||||
):
|
||||
removal_indices.append(content_piece_number)
|
||||
|
||||
# Only rebuild the content list if there are items to remove
|
||||
if removal_indices:
|
||||
filtered_past_messages[past_message_number].content = [
|
||||
content_piece
|
||||
for content_piece_number, content_piece in enumerate(
|
||||
past_message.content
|
||||
)
|
||||
if content_piece_number not in removal_indices
|
||||
]
|
||||
|
||||
else:
|
||||
continue
|
||||
|
||||
return (
|
||||
"...\n" if len(chat_history) > len(filtered_past_messages) else ""
|
||||
) + "\n".join(
|
||||
("user" if isinstance(msg, HumanMessage) else "you")
|
||||
+ f": {str(msg.content).strip()}"
|
||||
for msg in past_messages
|
||||
for msg in filtered_past_messages
|
||||
)
|
||||
|
||||
|
||||
@@ -251,78 +275,3 @@ def convert_inference_sections_to_search_docs(
|
||||
for search_doc in search_docs
|
||||
]
|
||||
return retrieved_saved_search_docs
|
||||
|
||||
|
||||
def update_db_session_with_messages(
    db_session: Session,
    chat_message_id: int,
    chat_session_id: str,
    is_agentic: bool | None,
    message: str | None = None,
    message_type: str | None = None,
    token_count: int | None = None,
    rephrased_query: str | None = None,
    prompt_id: int | None = None,
    citations: dict[int, int] | None = None,
    error: str | None = None,
    alternate_assistant_id: int | None = None,
    overridden_model: str | None = None,
    research_type: str | None = None,
    research_plan: dict[str, str] | None = None,
    final_documents: list[SearchDoc] | None = None,
    update_parent_message: bool = True,
    research_answer_purpose: ResearchAnswerPurpose | None = None,
) -> None:
    """Copy the supplied values onto an existing `ChatMessage` row.

    The message is looked up by `chat_message_id` within `chat_session_id`.
    Only truthy arguments are written, so falsy values such as `0`, `""`,
    `False` or an empty collection leave the stored value untouched.
    `message_type` and `research_type` are converted to their enums and the
    keys of `citations` are converted to `int` before assignment. When
    `update_parent_message` is set and the parent message exists, its
    `latest_child_message` is pointed at this message. This function does
    not commit the session.

    Raises:
        ValueError: if no matching chat message exists.
    """
    chat_message = (
        db_session.query(ChatMessage)
        .filter(
            ChatMessage.id == chat_message_id,
            ChatMessage.chat_session_id == chat_session_id,
        )
        .first()
    )
    if not chat_message:
        raise ValueError("Chat message with id not found")  # should never happen

    # (attribute name, supplied value, optional converter), in assignment order.
    field_updates = (
        ("message", message, None),
        ("message_type", message_type, MessageType),
        ("token_count", token_count, None),
        ("rephrased_query", rephrased_query, None),
        ("prompt_id", prompt_id, None),
        # Convert string keys to integers to match database field type
        ("citations", citations, lambda raw: {int(k): v for k, v in raw.items()}),
        ("error", error, None),
        ("alternate_assistant_id", alternate_assistant_id, None),
        ("overridden_model", overridden_model, None),
        ("research_type", research_type, ResearchType),
        ("research_plan", research_plan, None),
        ("search_docs", final_documents, None),
        ("is_agentic", is_agentic, None),
        ("research_answer_purpose", research_answer_purpose, None),
    )
    for attribute, value, convert in field_updates:
        if not value:
            continue
        setattr(chat_message, attribute, convert(value) if convert else value)

    if not update_parent_message:
        return

    parent_chat_message = (
        db_session.query(ChatMessage)
        .filter(ChatMessage.id == chat_message.parent_message)
        .first()
    )
    if parent_chat_message:
        parent_chat_message.latest_child_message = chat_message.id
|
||||
|
||||
@@ -17,6 +17,7 @@ from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.configs.kg_configs import KG_MAX_DEEP_SEARCH_RESULTS
|
||||
from onyx.configs.kg_configs import KG_SQL_GENERATION_MAX_TOKENS
|
||||
from onyx.configs.kg_configs import KG_SQL_GENERATION_TIMEOUT
|
||||
@@ -34,6 +35,8 @@ from onyx.prompts.kg_prompts import SIMPLE_ENTITY_SQL_PROMPT
|
||||
from onyx.prompts.kg_prompts import SIMPLE_SQL_ERROR_FIX_PROMPT
|
||||
from onyx.prompts.kg_prompts import SIMPLE_SQL_PROMPT
|
||||
from onyx.prompts.kg_prompts import SOURCE_DETECTION_PROMPT
|
||||
from onyx.prompts.kg_prompts import SQL_INSTRUCTIONS_ENTITY_PROMPT
|
||||
from onyx.prompts.kg_prompts import SQL_INSTRUCTIONS_RELATIONSHIP_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
@@ -287,6 +290,8 @@ def generate_simple_sql(
|
||||
.replace("---today_date---", datetime.now().strftime("%Y-%m-%d"))
|
||||
.replace("---user_name---", f"EMPLOYEE:{user_name}")
|
||||
)
|
||||
|
||||
additional_sql_instruction_prompt = SQL_INSTRUCTIONS_ENTITY_PROMPT
|
||||
else:
|
||||
simple_sql_prompt = (
|
||||
SIMPLE_SQL_PROMPT.replace("---entity_types---", entities_types_str)
|
||||
@@ -305,12 +310,12 @@ def generate_simple_sql(
|
||||
.replace("---user_name---", f"EMPLOYEE:{user_name}")
|
||||
)
|
||||
|
||||
# generate initial sql statement
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=simple_sql_prompt,
|
||||
)
|
||||
]
|
||||
additional_sql_instruction_prompt = SQL_INSTRUCTIONS_RELATIONSHIP_PROMPT
|
||||
|
||||
msg = create_question_prompt(
|
||||
additional_sql_instruction_prompt,
|
||||
simple_sql_prompt,
|
||||
)
|
||||
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
try:
|
||||
@@ -417,6 +422,7 @@ def generate_simple_sql(
|
||||
query_generation_error = str(e)
|
||||
logger.warning(f"Error executing SQL query: {e}, retrying...")
|
||||
|
||||
# TODO: exclude the case where the verification failed
|
||||
# fix sql and try one more time if sql query didn't work out
|
||||
# if the result is still empty after this, the kg probably doesn't have the answer,
|
||||
# so we update the strategy to simple and address this in the answer generation
|
||||
@@ -486,9 +492,19 @@ def generate_simple_sql(
|
||||
source_document_results = None
|
||||
if source_documents_sql is not None and source_documents_sql != sql_statement:
|
||||
# check source document sql, just in case
|
||||
_raise_error_if_sql_fails_problem_test(
|
||||
source_documents_sql, rel_temp_view, ent_temp_view
|
||||
)
|
||||
try:
|
||||
_raise_error_if_sql_fails_problem_test(
|
||||
source_documents_sql, rel_temp_view, ent_temp_view
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Error in source document sql: {e}")
|
||||
# TODO: raise error on frontend
|
||||
drop_views(
|
||||
allowed_docs_view_name=doc_temp_view,
|
||||
kg_relationships_view_name=rel_temp_view,
|
||||
kg_entity_view_name=ent_temp_view,
|
||||
)
|
||||
raise
|
||||
|
||||
with get_db_readonly_user_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
|
||||
@@ -191,7 +191,7 @@ def get_test_config(
|
||||
bypass_acl=False,
|
||||
)
|
||||
|
||||
prompt_config = PromptConfig.from_model(persona.prompts[0])
|
||||
prompt_config = PromptConfig.from_model(persona)
|
||||
|
||||
search_tool = SearchTool(
|
||||
tool_id=get_tool_by_name(SearchTool._NAME, db_session).id,
|
||||
@@ -273,19 +273,17 @@ def get_test_config(
|
||||
def get_persona_agent_prompt_expressions(
|
||||
persona: Persona | None,
|
||||
) -> PersonaPromptExpressions:
|
||||
if persona is None or len(persona.prompts) == 0:
|
||||
# TODO base_prompt should be None, but no time to properly fix
|
||||
if persona is None:
|
||||
return PersonaPromptExpressions(
|
||||
contextualized_prompt=ASSISTANT_SYSTEM_PROMPT_DEFAULT, base_prompt=""
|
||||
)
|
||||
|
||||
# Only a 1:1 mapping between personas and prompts currently
|
||||
prompt = persona.prompts[0]
|
||||
prompt_config = PromptConfig.from_model(prompt)
|
||||
# Prompts are now embedded directly on the Persona model
|
||||
prompt_config = PromptConfig.from_model(persona)
|
||||
datetime_aware_system_prompt = handle_onyx_date_awareness(
|
||||
prompt_str=prompt_config.system_prompt,
|
||||
prompt_config=prompt_config,
|
||||
add_additional_info_if_no_tag=prompt.datetime_aware,
|
||||
add_additional_info_if_no_tag=persona.datetime_aware,
|
||||
)
|
||||
|
||||
return PersonaPromptExpressions(
|
||||
|
||||
@@ -43,6 +43,7 @@ from onyx.utils.logger import ColoredFormatter
|
||||
from onyx.utils.logger import LoggerContextVars
|
||||
from onyx.utils.logger import PlainFormatter
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import DEV_LOGGING_ENABLED
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
@@ -421,6 +422,13 @@ def on_setup_logging(
|
||||
root_logger.addHandler(root_handler)
|
||||
|
||||
if logfile:
|
||||
# Truncate log file if DEV_LOGGING_ENABLED (for clean dev experience)
|
||||
if DEV_LOGGING_ENABLED and os.path.exists(logfile):
|
||||
try:
|
||||
open(logfile, "w").close() # Truncate the file
|
||||
except Exception:
|
||||
pass # Ignore errors, just proceed with normal logging
|
||||
|
||||
root_file_handler = logging.FileHandler(logfile)
|
||||
root_file_formatter = PlainFormatter(
|
||||
log_format,
|
||||
@@ -444,6 +452,7 @@ def on_setup_logging(
|
||||
task_logger.addHandler(task_handler)
|
||||
|
||||
if logfile:
|
||||
# No need to truncate again, already done above for root logger
|
||||
task_file_handler = logging.FileHandler(logfile)
|
||||
task_file_handler.addFilter(TenantContextFilter())
|
||||
task_file_formatter = CeleryTaskPlainFormatter(
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.agent_configs import AGENT_ALLOW_REFINEMENT
|
||||
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
|
||||
from onyx.configs.agent_configs import TF_DR_DEFAULT_FAST
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.models import Persona
|
||||
@@ -110,6 +111,14 @@ class Answer:
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=current_agent_message_id,
|
||||
)
|
||||
|
||||
if use_agentic_search:
|
||||
research_type = ResearchType.DEEP
|
||||
elif TF_DR_DEFAULT_FAST:
|
||||
research_type = ResearchType.FAST
|
||||
else:
|
||||
research_type = ResearchType.THOUGHTFUL
|
||||
|
||||
self.search_behavior_config = GraphSearchConfig(
|
||||
use_agentic_search=use_agentic_search,
|
||||
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
|
||||
@@ -117,9 +126,7 @@ class Answer:
|
||||
allow_agent_reranking=allow_agent_reranking,
|
||||
perform_initial_search_decomposition=INITIAL_SEARCH_DECOMPOSITION_ENABLED,
|
||||
kg_config_settings=get_kg_config_settings(),
|
||||
research_type=(
|
||||
ResearchType.DEEP if use_agentic_search else ResearchType.THOUGHTFUL
|
||||
),
|
||||
research_type=research_type,
|
||||
)
|
||||
self.graph_config = GraphConfig(
|
||||
inputs=self.graph_inputs,
|
||||
|
||||
@@ -32,10 +32,8 @@ from onyx.db.llm import fetch_existing_doc_sets
|
||||
from onyx.db.llm import fetch_existing_tools
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Prompt
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.db.prompts import get_prompts_by_ids
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.kg.models import KGException
|
||||
from onyx.kg.setup.kg_default_entity_definitions import (
|
||||
@@ -59,7 +57,6 @@ def prepare_chat_message_request(
|
||||
persona_id: int | None,
|
||||
# Does the question need to have a persona override
|
||||
persona_override_config: PersonaOverrideConfig | None,
|
||||
prompt: Prompt | None,
|
||||
message_ts_to_respond_to: str | None,
|
||||
retrieval_details: RetrievalDetails | None,
|
||||
rerank_settings: RerankingDetails | None,
|
||||
@@ -83,7 +80,6 @@ def prepare_chat_message_request(
|
||||
parent_message_id=None, # It's a standalone chat session each time
|
||||
message=message_text,
|
||||
file_descriptors=[], # Currently SlackBot/answer api do not support files in the context
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
# Can always override the persona for the single query, if it's a normal persona
|
||||
# then it will be treated the same
|
||||
persona_override_config=persona_override_config,
|
||||
@@ -389,21 +385,11 @@ def create_temporary_persona(
|
||||
)
|
||||
|
||||
if persona_config.prompts:
|
||||
persona.prompts = [
|
||||
Prompt(
|
||||
name=p.name,
|
||||
description=p.description,
|
||||
system_prompt=p.system_prompt,
|
||||
task_prompt=p.task_prompt,
|
||||
include_citations=p.include_citations,
|
||||
datetime_aware=p.datetime_aware,
|
||||
)
|
||||
for p in persona_config.prompts
|
||||
]
|
||||
elif persona_config.prompt_ids:
|
||||
persona.prompts = get_prompts_by_ids(
|
||||
db_session=db_session, prompt_ids=persona_config.prompt_ids
|
||||
)
|
||||
# Use the first prompt from the override config for embedded prompt fields
|
||||
first_prompt = persona_config.prompts[0]
|
||||
persona.system_prompt = first_prompt.system_prompt
|
||||
persona.task_prompt = first_prompt.task_prompt
|
||||
persona.datetime_aware = first_prompt.datetime_aware
|
||||
|
||||
persona.tools = []
|
||||
if persona_config.custom_tools_openapi:
|
||||
|
||||
@@ -30,7 +30,7 @@ from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.db.models import Prompt
|
||||
from onyx.db.models import Persona
|
||||
|
||||
|
||||
class LlmDoc(BaseModel):
|
||||
@@ -170,8 +170,8 @@ class PromptOverrideConfig(BaseModel):
|
||||
description: str = ""
|
||||
system_prompt: str
|
||||
task_prompt: str = ""
|
||||
include_citations: bool = True
|
||||
datetime_aware: bool = True
|
||||
include_citations: bool = True
|
||||
|
||||
|
||||
class PersonaOverrideConfig(BaseModel):
|
||||
@@ -186,7 +186,7 @@ class PersonaOverrideConfig(BaseModel):
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
|
||||
prompt_ids: list[int] = Field(default_factory=list)
|
||||
# Note: prompt_ids removed - prompts are now embedded in personas
|
||||
|
||||
document_set_ids: list[int] = Field(default_factory=list)
|
||||
tools: list[ToolConfig] = Field(default_factory=list)
|
||||
@@ -268,11 +268,10 @@ class PromptConfig(BaseModel):
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
datetime_aware: bool
|
||||
include_citations: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, model: "Prompt", prompt_override: PromptOverride | None = None
|
||||
cls, model: "Persona", prompt_override: PromptOverride | None = None
|
||||
) -> "PromptConfig":
|
||||
override_system_prompt = (
|
||||
prompt_override.system_prompt if prompt_override else None
|
||||
@@ -280,10 +279,9 @@ class PromptConfig(BaseModel):
|
||||
override_task_prompt = prompt_override.task_prompt if prompt_override else None
|
||||
|
||||
return cls(
|
||||
system_prompt=override_system_prompt or model.system_prompt,
|
||||
task_prompt=override_task_prompt or model.task_prompt,
|
||||
system_prompt=override_system_prompt or model.system_prompt or "",
|
||||
task_prompt=override_task_prompt or model.task_prompt or "",
|
||||
datetime_aware=model.datetime_aware,
|
||||
include_citations=model.include_citations,
|
||||
)
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
@@ -88,12 +88,12 @@ from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
from onyx.tools.tool_constructor import CustomToolConfig
|
||||
from onyx.tools.tool_constructor import ImageGenerationToolConfig
|
||||
from onyx.tools.tool_constructor import InternetSearchToolConfig
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from onyx.tools.tool_constructor import WebSearchToolConfig
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
WebSearchTool,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
@@ -159,12 +159,10 @@ def _get_force_search_settings(
|
||||
override_kwargs=search_tool_override_kwargs,
|
||||
)
|
||||
|
||||
internet_search_available = any(
|
||||
isinstance(tool, InternetSearchTool) for tool in tools
|
||||
)
|
||||
web_search_available = any(isinstance(tool, WebSearchTool) for tool in tools)
|
||||
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
|
||||
|
||||
if not internet_search_available and not search_tool_available:
|
||||
if not web_search_available and not search_tool_available:
|
||||
# Does not matter much which tool is set here as force is false and neither tool is available
|
||||
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
|
||||
# Currently, the internet search tool does not support query override
|
||||
@@ -199,9 +197,7 @@ def _get_force_search_settings(
|
||||
|
||||
return ForceUseTool(
|
||||
force_use=False,
|
||||
tool_name=(
|
||||
SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
|
||||
),
|
||||
tool_name=(SearchTool._NAME if search_tool_available else WebSearchTool._NAME),
|
||||
args=args,
|
||||
override_kwargs=None,
|
||||
)
|
||||
@@ -335,11 +331,8 @@ def stream_chat_message_objects(
|
||||
properties=None,
|
||||
)
|
||||
|
||||
# If a prompt override is specified via the API, use that with highest priority
|
||||
# but for saving it, we are just mapping it to an existing prompt
|
||||
prompt_id = new_msg_req.prompt_id
|
||||
if prompt_id is None and persona.prompts:
|
||||
prompt_id = sorted(persona.prompts, key=lambda x: x.id)[-1].id
|
||||
# Note: prompt configuration is now embedded in the persona
|
||||
# No need for separate prompt_id handling
|
||||
|
||||
if reference_doc_ids is None and retrieval_options is None:
|
||||
raise RuntimeError(
|
||||
@@ -399,7 +392,6 @@ def stream_chat_message_objects(
|
||||
user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=parent_message,
|
||||
prompt_id=prompt_id,
|
||||
message=message_text,
|
||||
token_count=len(llm_tokenizer_encode_func(message_text)),
|
||||
message_type=MessageType.USER,
|
||||
@@ -557,23 +549,15 @@ def stream_chat_message_objects(
|
||||
datetime_aware=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].datetime_aware,
|
||||
include_citations=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].include_citations,
|
||||
)
|
||||
elif prompt_override:
|
||||
if not final_msg.prompt:
|
||||
raise ValueError(
|
||||
"Prompt override cannot be applied, no base prompt found."
|
||||
)
|
||||
# Apply prompt override on top of persona-embedded prompt
|
||||
prompt_config = PromptConfig.from_model(
|
||||
final_msg.prompt,
|
||||
persona,
|
||||
prompt_override=prompt_override,
|
||||
)
|
||||
else:
|
||||
prompt_config = PromptConfig.from_model(
|
||||
final_msg.prompt or persona.prompts[0]
|
||||
)
|
||||
prompt_config = PromptConfig.from_model(persona)
|
||||
|
||||
answer_style_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
@@ -606,7 +590,7 @@ def stream_chat_message_objects(
|
||||
latest_query_files=latest_query_files,
|
||||
bypass_acl=bypass_acl,
|
||||
),
|
||||
internet_search_tool_config=InternetSearchToolConfig(
|
||||
internet_search_tool_config=WebSearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
),
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.prompts import get_default_prompt
|
||||
from onyx.db.search_settings import get_multilingual_expansion
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
@@ -89,13 +87,12 @@ def compute_max_document_tokens(
|
||||
|
||||
|
||||
def compute_max_document_tokens_for_persona(
|
||||
db_session: Session,
|
||||
persona: Persona,
|
||||
actual_user_input: str | None = None,
|
||||
) -> int:
|
||||
prompt = persona.prompts[0] if persona.prompts else get_default_prompt(db_session)
|
||||
# Use the persona directly since prompts are now embedded
|
||||
return compute_max_document_tokens(
|
||||
prompt_config=PromptConfig.from_model(prompt),
|
||||
prompt_config=PromptConfig.from_model(persona),
|
||||
llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config,
|
||||
actual_user_input=actual_user_input,
|
||||
)
|
||||
@@ -110,8 +107,8 @@ def build_citations_system_message(
|
||||
prompt_config: PromptConfig,
|
||||
) -> SystemMessage:
|
||||
system_prompt = prompt_config.system_prompt.strip()
|
||||
if prompt_config.include_citations:
|
||||
system_prompt += REQUIRE_CITATION_STATEMENT
|
||||
# Citations are always enabled
|
||||
system_prompt += REQUIRE_CITATION_STATEMENT
|
||||
tag_handled_prompt = handle_onyx_date_awareness(
|
||||
system_prompt, prompt_config, add_additional_info_if_no_tag=True
|
||||
)
|
||||
|
||||
@@ -76,7 +76,6 @@ def parse_user_files(
|
||||
|
||||
# Calculate available tokens for documents based on prompt, user input, etc.
|
||||
available_tokens = compute_max_document_tokens_for_persona(
|
||||
db_session=db_session,
|
||||
persona=persona,
|
||||
actual_user_input=actual_user_input,
|
||||
)
|
||||
|
||||
@@ -379,4 +379,11 @@ AGENT_MAX_TOKENS_HISTORY_SUMMARY = int(
|
||||
or AGENT_DEFAULT_MAX_TOKENS_HISTORY_SUMMARY
|
||||
)
|
||||
|
||||
# Parameters for the Thoughtful/Deep Research flows
|
||||
TF_DR_TIMEOUT_LONG = int(os.environ.get("TF_DR_TIMEOUT_LONG") or 120)
|
||||
TF_DR_TIMEOUT_SHORT = int(os.environ.get("TF_DR_TIMEOUT_SHORT") or 60)
|
||||
|
||||
|
||||
TF_DR_DEFAULT_FAST = (os.environ.get("TF_DR_DEFAULT_FAST") or "False").lower() == "true"
|
||||
|
||||
GRAPH_VERSION_NAME: str = "a"
|
||||
|
||||
@@ -199,6 +199,7 @@ class DocumentSource(str, Enum):
|
||||
HIGHSPOT = "highspot"
|
||||
|
||||
IMAP = "imap"
|
||||
BITBUCKET = "bitbucket"
|
||||
|
||||
# Special case just for integration tests
|
||||
MOCK_CONNECTOR = "mock_connector"
|
||||
@@ -494,7 +495,7 @@ class OnyxCeleryTask:
|
||||
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
|
||||
PERFORM_TTL_MANAGEMENT_TASK = "perform_ttl_management_task"
|
||||
|
||||
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
|
||||
GENERATE_USAGE_REPORT_TASK = "generate_usage_report_task"
|
||||
|
||||
EXPORT_QUERY_HISTORY_TASK = "export_query_history_task"
|
||||
EXPORT_QUERY_HISTORY_CLEANUP_TASK = "export_query_history_cleanup_task"
|
||||
@@ -541,6 +542,7 @@ DocumentSourceDescription: dict[DocumentSource, str] = {
|
||||
DocumentSource.GITHUB: "github data (issues, PRs)",
|
||||
DocumentSource.GITBOOK: "gitbook data",
|
||||
DocumentSource.GITLAB: "gitlab data",
|
||||
DocumentSource.BITBUCKET: "bitbucket data",
|
||||
DocumentSource.GURU: "guru data",
|
||||
DocumentSource.BOOKSTACK: "bookstack data",
|
||||
DocumentSource.OUTLINE: "outline data",
|
||||
|
||||
0
backend/onyx/connectors/bitbucket/__init__.py
Normal file
0
backend/onyx/connectors/bitbucket/__init__.py
Normal file
345
backend/onyx/connectors/bitbucket/connector.py
Normal file
345
backend/onyx/connectors/bitbucket/connector.py
Normal file
@@ -0,0 +1,345 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.bitbucket.utils import build_auth_client
|
||||
from onyx.connectors.bitbucket.utils import list_repositories
|
||||
from onyx.connectors.bitbucket.utils import map_pr_to_document
|
||||
from onyx.connectors.bitbucket.utils import paginate
|
||||
from onyx.connectors.bitbucket.utils import PR_LIST_RESPONSE_FIELDS
|
||||
from onyx.connectors.bitbucket.utils import SLIM_PR_LIST_RESPONSE_FIELDS
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import httpx
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BitbucketConnectorCheckpoint(ConnectorCheckpoint):
    """Resumable position for Bitbucket pull request indexing.

    The connector fills ``repos_queue`` on the first run and then walks it one
    repository per ``load_from_checkpoint`` call.
    """

    # Repository slugs still to be processed, materialized once per indexing run.
    repos_queue: list[str] = []
    # Position in ``repos_queue`` of the repository currently being processed.
    current_repo_index: int = 0
    # Bitbucket "next" page URL inside the current repository; None means start
    # from the first page.
    next_url: str | None = None
|
||||
|
||||
|
||||
class BitbucketConnector(
    CheckpointedConnector[BitbucketConnectorCheckpoint],
    SlimConnector,
):
    """Connector for indexing Bitbucket Cloud pull requests.

    Args:
        workspace: Bitbucket workspace ID.
        repositories: Comma-separated list of repository slugs to index.
        projects: Comma-separated list of project keys to index all repositories within.
        batch_size: Max number of documents to yield per batch.
    """

    def __init__(
        self,
        workspace: str,
        repositories: str | None = None,
        projects: str | None = None,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.workspace = workspace
        self._repositories: list[str] | None = self._split_csv(repositories)
        self._projects: list[str] | None = self._split_csv(projects)
        self.batch_size = batch_size
        # Populated by load_credentials().
        self.email: str | None = None
        self.api_token: str | None = None

    @staticmethod
    def _split_csv(raw: str | None) -> list[str] | None:
        """Split a comma-separated setting into stripped, non-empty items.

        Returns None when the setting is missing or empty. A value made only of
        separators (e.g. ",") yields an empty list.
        """
        if not raw:
            return None
        return [item.strip() for item in raw.split(",") if item.strip()]

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Load API token-based credentials.

        Expects a dict with keys: `bitbucket_email`, `bitbucket_api_token`.

        Raises:
            ConnectorMissingCredentialError: if either value is missing or empty.
        """
        self.email = credentials.get("bitbucket_email")
        self.api_token = credentials.get("bitbucket_api_token")
        if not self.email or not self.api_token:
            raise ConnectorMissingCredentialError("Bitbucket")
        return None

    def _client(self) -> httpx.Client:
        """Build an authenticated HTTP client or raise if credentials missing."""
        if not self.email or not self.api_token:
            raise ConnectorMissingCredentialError("Bitbucket")
        return build_auth_client(self.email, self.api_token)

    def _iter_pull_requests_for_repo(
        self,
        client: httpx.Client,
        repo_slug: str,
        params: dict[str, Any] | None = None,
        start_url: str | None = None,
        on_page: Callable[[str | None], None] | None = None,
    ) -> Iterator[dict[str, Any]]:
        """Yield raw pull request dicts for one repository of the workspace.

        ``params``, ``start_url`` and ``on_page`` are forwarded unchanged to
        ``paginate``.
        """
        base = f"https://api.bitbucket.org/2.0/repositories/{self.workspace}/{repo_slug}/pullrequests"
        yield from paginate(
            client,
            base,
            params,
            start_url=start_url,
            on_page=on_page,
        )

    def _build_params(
        self,
        fields: str = PR_LIST_RESPONSE_FIELDS,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> dict[str, Any]:
        """Build Bitbucket fetch params.

        Always include OPEN, MERGED, and DECLINED PRs. If both ``start`` and
        ``end`` are provided, apply a single updated_on time window; if either
        is missing no time filter is added.
        """

        def _iso(ts: SecondsSinceUnixEpoch) -> str:
            return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()

        q = '(state = "OPEN" OR state = "MERGED" OR state = "DECLINED")'
        if start is not None and end is not None:
            time_clause = (
                f'(updated_on >= "{_iso(start)}" AND updated_on <= "{_iso(end)}")'
            )
            q = f"{q} AND {time_clause}"
        return {"fields": fields, "pagelen": 50, "q": q}

    def _iter_target_repositories(self, client: httpx.Client) -> Iterator[str]:
        """Yield repository slugs based on configuration.

        Priority:
        - repositories list
        - projects list (list repos by project key)
        - workspace (all repos)
        """
        if self._repositories:
            yield from self._repositories
            return

        # A single ``None`` project key asks list_repositories for the whole
        # workspace, so both remaining cases share one loop.
        project_keys: list[str | None] = (
            list(self._projects) if self._projects else [None]
        )
        for project_key in project_keys:
            for repo in list_repositories(client, self.workspace, project_key):
                slug_val = repo.get("slug")
                if isinstance(slug_val, str) and slug_val:
                    yield slug_val

    @override
    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: BitbucketConnectorCheckpoint,
    ) -> CheckpointOutput[BitbucketConnectorCheckpoint]:
        """Resumable PR ingestion across repos and pages within a time window.

        Processes one repository per call. Yields Documents (or ConnectorFailure
        for per-PR mapping failures) and returns an updated checkpoint that
        records repo position and next page URL.
        """
        new_checkpoint = copy.deepcopy(checkpoint)

        with self._client() as client:
            # Materialize target repositories once
            if not new_checkpoint.repos_queue:
                # Deduplicate and sort so the queue is deterministic; note that
                # this does not preserve the order of an explicit repository list.
                repos_list = list(self._iter_target_repositories(client))
                new_checkpoint.repos_queue = sorted(set(repos_list))
                new_checkpoint.current_repo_index = 0
                new_checkpoint.next_url = None

            repos = new_checkpoint.repos_queue
            if not repos or new_checkpoint.current_repo_index >= len(repos):
                new_checkpoint.has_more = False
                return new_checkpoint

            repo_slug = repos[new_checkpoint.current_repo_index]

            first_page_params = self._build_params(
                fields=PR_LIST_RESPONSE_FIELDS,
                start=start,
                end=end,
            )

            def _on_page(next_url: str | None) -> None:
                new_checkpoint.next_url = next_url

            for pr in self._iter_pull_requests_for_repo(
                client,
                repo_slug,
                params=first_page_params,
                start_url=new_checkpoint.next_url,
                on_page=_on_page,
            ):
                # Only the mapping is guarded: an exception thrown into this
                # generator at the document yield must not be reported as a
                # failure to process the PR.
                try:
                    document = map_pr_to_document(pr, self.workspace, repo_slug)
                except Exception as e:
                    pr_id = pr.get("id")
                    pr_ref = pr_id if pr_id is not None else "unknown"
                    pr_link = (
                        f"https://bitbucket.org/{self.workspace}/{repo_slug}/pull-requests/{pr_id}"
                        if pr_id is not None
                        else None
                    )
                    yield ConnectorFailure(
                        failed_document=DocumentFailure(
                            document_id=(
                                f"{DocumentSource.BITBUCKET.value}:{self.workspace}:{repo_slug}:pr:{pr_ref}"
                            ),
                            document_link=pr_link,
                        ),
                        failure_message=f"Failed to process Bitbucket PR: {e}",
                        exception=e,
                    )
                    continue
                yield document

            # Advance to next repository (if any) and set has_more accordingly
            new_checkpoint.current_repo_index += 1
            new_checkpoint.next_url = None
            new_checkpoint.has_more = new_checkpoint.current_repo_index < len(repos)

        return new_checkpoint

    @override
    def build_dummy_checkpoint(self) -> BitbucketConnectorCheckpoint:
        """Create an initial checkpoint with work remaining."""
        return BitbucketConnectorCheckpoint(has_more=True)

    @override
    def validate_checkpoint_json(
        self, checkpoint_json: str
    ) -> BitbucketConnectorCheckpoint:
        """Validate and deserialize a checkpoint instance from JSON."""
        return BitbucketConnectorCheckpoint.model_validate_json(checkpoint_json)

    def retrieve_all_slim_documents(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> Iterator[list[SlimDocument]]:
        """Return only document IDs for all existing pull requests.

        Yields batches of at most ``batch_size`` SlimDocuments. After each full
        batch the optional ``callback`` is checked for a stop signal and told
        how many documents were just emitted.

        Raises:
            RuntimeError: if ``callback.should_stop()`` returns True.
        """
        batch: list[SlimDocument] = []
        params = self._build_params(
            fields=SLIM_PR_LIST_RESPONSE_FIELDS,
            start=start,
            end=end,
        )
        with self._client() as client:
            for slug in self._iter_target_repositories(client):
                for pr in self._iter_pull_requests_for_repo(
                    client, slug, params=params
                ):
                    pr_id = pr["id"]
                    doc_id = f"{DocumentSource.BITBUCKET.value}:{self.workspace}:{slug}:pr:{pr_id}"
                    batch.append(SlimDocument(id=doc_id))
                    if len(batch) >= self.batch_size:
                        # Capture the size before resetting the batch so that
                        # the progress report is not always zero.
                        emitted = len(batch)
                        yield batch
                        batch = []
                        if callback:
                            if callback.should_stop():
                                # Note: this is not actually used for permission sync yet, just pruning
                                raise RuntimeError(
                                    "bitbucket_pr_sync: Stop signal detected"
                                )
                            callback.progress("bitbucket_pr_sync", emitted)
        if batch:
            yield batch

    def validate_connector_settings(self) -> None:
        """Validate Bitbucket credentials and workspace access by probing a lightweight endpoint.

        Raises:
            ConnectorMissingCredentialError: if credentials were never loaded
            CredentialExpiredError: on HTTP 401
            InsufficientPermissionsError: on HTTP 403
            UnexpectedValidationError: on any other failure
        """
        try:
            with self._client() as client:
                url = f"https://api.bitbucket.org/2.0/repositories/{self.workspace}"
                resp = client.get(
                    url,
                    params={"pagelen": 1, "fields": "pagelen"},
                    timeout=REQUEST_TIMEOUT_SECONDS,
                )
        except ConnectorMissingCredentialError:
            raise
        except Exception as e:
            # Network or other unexpected errors
            raise UnexpectedValidationError(
                f"Unexpected error while validating Bitbucket settings: {e}"
            ) from e

        if resp.status_code == 401:
            raise CredentialExpiredError(
                "Invalid or expired Bitbucket credentials (HTTP 401)."
            )
        if resp.status_code == 403:
            raise InsufficientPermissionsError(
                "Insufficient permissions to access Bitbucket workspace (HTTP 403)."
            )
        if not 200 <= resp.status_code < 300:
            raise UnexpectedValidationError(
                f"Unexpected Bitbucket error (status={resp.status_code})."
            )
|
||||
294
backend/onyx/connectors/bitbucket/utils.py
Normal file
294
backend/onyx/connectors/bitbucket/utils.py
Normal file
@@ -0,0 +1,294 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Pagination envelope fields shared by every list request below.
_PAGINATION_FIELDS = ("next", "page", "pagelen")

# Fields requested from Bitbucket PR list endpoint to ensure rich PR data
PR_LIST_RESPONSE_FIELDS: str = ",".join(
    _PAGINATION_FIELDS
    + tuple(
        f"values.{name}"
        for name in (
            "author",
            "close_source_branch",
            "closed_by",
            "comment_count",
            "created_on",
            "description",
            "destination",
            "draft",
            "id",
            "links",
            "merge_commit",
            "participants",
            "reason",
            "rendered",
            "reviewers",
            "source",
            "state",
            "summary",
            "task_count",
            "title",
            "type",
            "updated_on",
        )
    )
)

# Minimal fields for slim retrieval (IDs only)
SLIM_PR_LIST_RESPONSE_FIELDS: str = ",".join(_PAGINATION_FIELDS + ("values.id",))

# Minimal fields for repository list calls
REPO_LIST_RESPONSE_FIELDS: str = ",".join(
    _PAGINATION_FIELDS
    + (
        "values.slug",
        "values.full_name",
        "values.project.key",
    )
)
|
||||
|
||||
|
||||
class BitbucketRetriableError(Exception):
    """Bitbucket request failure that is worth retrying (HTTP 429 or 5xx)."""
|
||||
|
||||
|
||||
class BitbucketNonRetriableError(Exception):
    """Bitbucket client error that must not be retried (any 4xx other than 429)."""
|
||||
|
||||
|
||||
@retry_builder(
    tries=6,
    delay=1,
    backoff=2,
    max_delay=30,
    exceptions=(BitbucketRetriableError, httpx.RequestError),
)
@rate_limit_builder(max_calls=60, period=60)
def bitbucket_get(
    client: httpx.Client, url: str, params: dict[str, Any] | None = None
) -> httpx.Response:
    """Issue a GET request to Bitbucket and classify failures for the retry wrapper.

    Args:
        client: HTTP client used to send the request.
        url: Absolute URL to fetch.
        params: Optional query parameters.

    Returns:
        The response, once ``raise_for_status`` has accepted it.

    Raises:
        BitbucketRetriableError: on HTTP 429 (after sleeping for an integer
            ``Retry-After`` value when one is present) and on 5xx responses.
        BitbucketNonRetriableError: on any other 4xx response.
        httpx.RequestError: on transport failures, left untouched so that the
            retry decorator can handle them.
        httpx.HTTPStatusError: for any other status ``raise_for_status`` rejects.
    """
    # Transport errors propagate unchanged; retry_builder is configured for them.
    response = client.get(url, params=params, timeout=REQUEST_TIMEOUT_SECONDS)

    try:
        response.raise_for_status()
    except httpx.HTTPStatusError as err:
        failed = err.response
        code = failed.status_code if failed is not None else None

        if code is None:
            # Unknown status, propagate
            raise

        if code == 429:
            wait_hint = failed.headers.get("Retry-After") if failed else None
            if wait_hint is not None:
                try:
                    time.sleep(int(wait_hint))
                except (TypeError, ValueError):
                    # Not a usable number of seconds; skip the extra sleep.
                    pass
            raise BitbucketRetriableError("Bitbucket rate limit exceeded (429)") from err

        if 500 <= code < 600:
            raise BitbucketRetriableError(f"Bitbucket server error: {code}") from err

        if 400 <= code < 500:
            raise BitbucketNonRetriableError(f"Bitbucket client error: {code}") from err

        # Unknown status, propagate
        raise

    return response
|
||||
|
||||
|
||||
def build_auth_client(email: str, api_token: str) -> httpx.Client:
    """Return an HTTP/2-enabled httpx client authenticated for the Bitbucket Cloud API.

    The email and API token are passed to httpx as an ``auth`` tuple.
    """
    credentials = (email, api_token)
    return httpx.Client(auth=credentials, http2=True)
|
||||
|
||||
|
||||
def paginate(
    client: httpx.Client,
    url: str,
    params: dict[str, Any] | None = None,
    start_url: str | None = None,
    on_page: Callable[[str | None], None] | None = None,
) -> Iterator[dict[str, Any]]:
    """Walk a paginated Bitbucket collection, yielding each entry of ``values``.

    Args:
        client: Authenticated HTTP client.
        url: Collection URL used for the first page when ``start_url`` is None.
        params: Query parameters sent with the first request only.
        start_url: Absolute URL to resume from; when given, ``params`` is ignored.
        on_page: Optional callback invoked once a page's items have been
            yielded, receiving the URL of the following page (None on the last).
    """
    page_url = start_url or url
    # A resume URL is sent without params; otherwise work on a copy of the
    # caller's dict.
    page_params = None if start_url or not params else params.copy()

    while page_url:
        payload = bitbucket_get(client, page_url, params=page_params).json()
        yield from payload.get("values", [])

        page_url = payload.get("next")
        if on_page is not None:
            on_page(page_url)
        # only include params on first call, next_url will contain all necessary params
        page_params = None
|
||||
|
||||
|
||||
def list_repositories(
    client: httpx.Client, workspace: str, project_key: str | None = None
) -> Iterator[dict[str, Any]]:
    """Yield repository records of ``workspace``, sorted by ``full_name``.

    When ``project_key`` is given, only repositories whose project key matches
    it are requested.
    """
    query: dict[str, Any] = {
        "fields": REPO_LIST_RESPONSE_FIELDS,
        "pagelen": 100,
        # Ensure deterministic ordering
        "sort": "full_name",
    }
    if project_key:
        query["q"] = f'project.key="{project_key}"'

    for repo in paginate(
        client, f"https://api.bitbucket.org/2.0/repositories/{workspace}", query
    ):
        yield repo
|
||||
|
||||
|
||||
def map_pr_to_document(pr: dict[str, Any], workspace: str, repo_slug: str) -> Document:
    """Map a Bitbucket pull request JSON to Onyx Document.

    Optional fields are read defensively: a key that is absent and a key that
    is present with a JSON ``null`` value are treated the same way, so a sparse
    payload does not abort the mapping.

    Args:
        pr: Pull request object as returned by the Bitbucket PR list endpoint.
        workspace: Workspace the repository belongs to.
        repo_slug: Slug of the repository that owns the pull request.

    Returns:
        A Document with one text section summarizing the PR, plus metadata.

    Raises:
        KeyError: if ``pr`` has no ``id``.
    """

    def _branch_name(side: str) -> str:
        # "source" / "destination" -> branch name, tolerating null at any level.
        ref = pr.get(side) or {}
        return (ref.get("branch") or {}).get("name") or ""

    def _count(key: str) -> str:
        # Render counters as text, keeping a missing/null value empty rather
        # than the literal string "None".
        value = pr.get(key)
        return "" if value is None else str(value)

    pr_id = pr["id"]
    title = pr.get("title") or f"PR {pr_id}"
    description = pr.get("description") or ""
    state = pr.get("state")
    draft = pr.get("draft", False)
    # `or` fallbacks (instead of .get defaults) also cover explicit nulls.
    author = pr.get("author") or {}
    reviewers = pr.get("reviewers") or []
    participants = pr.get("participants") or []
    closed_by = pr.get("closed_by")

    link = ((pr.get("links") or {}).get("html") or {}).get("href") or (
        f"https://bitbucket.org/{workspace}/{repo_slug}/pull-requests/{pr_id}"
    )

    created_on = pr.get("created_on")
    updated_on = pr.get("updated_on")
    updated_dt = (
        datetime.fromisoformat(updated_on.replace("Z", "+00:00")).astimezone(
            timezone.utc
        )
        if isinstance(updated_on, str)
        else None
    )

    source_branch = _branch_name("source")
    destination_branch = _branch_name("destination")

    approved_by = [
        _get_user_name(p.get("user") or {}) for p in participants if p.get("approved")
    ]
    reviewer_names = [_get_user_name(r) for r in reviewers]
    author_name = _get_user_name(author) if author else ""

    primary_owner = BasicExpertInfo(display_name=author_name) if author else None
    secondary_owners = [
        BasicExpertInfo(display_name=name) for name in reviewer_names
    ] or None

    # Create a concise summary of key PR info
    created_date = created_on.split("T")[0] if created_on else "N/A"
    updated_date = updated_on.split("T")[0] if updated_on else "N/A"
    content_text = (
        "Pull Request Information:\n"
        f"- Pull Request ID: {pr_id}\n"
        f"- Title: {title}\n"
        f"- State: {state or 'N/A'} {'(Draft)' if draft else ''}\n"
    )
    if state == "DECLINED":
        # An empty or null reason is shown as N/A instead of a blank value.
        content_text += f"- Reason: {pr.get('reason') or 'N/A'}\n"
    content_text += (
        f"- Author: {author_name if author else 'N/A'}\n"
        f"- Reviewers: {', '.join(reviewer_names) if reviewer_names else 'N/A'}\n"
        f"- Branch: {source_branch} -> {destination_branch}\n"
        f"- Created: {created_date}\n"
        f"- Updated: {updated_date}"
    )
    if description:
        content_text += f"\n\nDescription:\n{description}"

    sections: list[TextSection | ImageSection] = [
        TextSection(link=link, text=content_text)
    ]

    metadata: dict[str, str | list[str]] = {
        "object_type": "PullRequest",
        "workspace": workspace,
        "repository": repo_slug,
        "pr_key": f"{workspace}/{repo_slug}#{pr_id}",
        "id": str(pr_id),
        "title": title,
        "state": state or "",
        "draft": str(bool(draft)),
        "link": link,
        "author": author_name,
        "reviewers": reviewer_names,
        "approved_by": approved_by,
        "comment_count": _count("comment_count"),
        "task_count": _count("task_count"),
        "created_on": created_on or "",
        "updated_on": updated_on or "",
        "source_branch": source_branch,
        "destination_branch": destination_branch,
        "closed_by": _get_user_name(closed_by) if closed_by else "",
        "close_source_branch": str(bool(pr.get("close_source_branch", False))),
    }

    return Document(
        id=f"{DocumentSource.BITBUCKET.value}:{workspace}:{repo_slug}:pr:{pr_id}",
        sections=sections,
        source=DocumentSource.BITBUCKET,
        semantic_identifier=f"#{pr_id}: {title}",
        title=title,
        doc_updated_at=updated_dt,
        primary_owners=[primary_owner] if primary_owner else None,
        secondary_owners=secondary_owners,
        metadata=metadata,
    )
|
||||
|
||||
|
||||
def _get_user_name(user: dict[str, Any]) -> str:
|
||||
return user.get("display_name") or user.get("nickname") or "unknown"
|
||||
@@ -87,7 +87,7 @@ def process_onyx_metadata(
|
||||
metadata: dict[str, Any],
|
||||
) -> tuple[OnyxMetadata, dict[str, Any]]:
|
||||
"""
|
||||
Users may set Onyx metadata and custom tags in text files. https://docs.onyx.app/connectors/file
|
||||
Users may set Onyx metadata and custom tags in text files. https://docs.onyx.app/admin/connectors/official/file
|
||||
Any unrecognized fields are treated as custom tags.
|
||||
"""
|
||||
p_owner_names = metadata.get("primary_owners")
|
||||
|
||||
@@ -9,6 +9,7 @@ from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
from onyx.connectors.axero.connector import AxeroConnector
|
||||
from onyx.connectors.bitbucket.connector import BitbucketConnector
|
||||
from onyx.connectors.blob.connector import BlobStorageConnector
|
||||
from onyx.connectors.bookstack.connector import BookstackConnector
|
||||
from onyx.connectors.clickup.connector import ClickupConnector
|
||||
@@ -125,6 +126,7 @@ def identify_connector_class(
|
||||
DocumentSource.AIRTABLE: AirtableConnector,
|
||||
DocumentSource.HIGHSPOT: HighspotConnector,
|
||||
DocumentSource.IMAP: ImapConnector,
|
||||
DocumentSource.BITBUCKET: BitbucketConnector,
|
||||
# just for integration tests
|
||||
DocumentSource.MOCK_CONNECTOR: MockConnector,
|
||||
}
|
||||
|
||||
@@ -153,7 +153,7 @@ def _process_file(
|
||||
content_type=file_type,
|
||||
)
|
||||
|
||||
# Each file may have file-specific ONYX_METADATA https://docs.onyx.app/connectors/file
|
||||
# Each file may have file-specific ONYX_METADATA https://docs.onyx.app/admin/connectors/official/file
|
||||
# If so, we should add it to any metadata processed so far
|
||||
if extraction_result.metadata:
|
||||
logger.debug(
|
||||
|
||||
@@ -5,9 +5,13 @@ from datetime import timezone
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rl_requests,
|
||||
)
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@@ -66,6 +70,14 @@ _STATUS_NUMBER_TYPE_MAP: dict[int, str] = {
|
||||
}
|
||||
|
||||
|
||||
# TODO: unify this with other generic rate limited requests with retries (e.g. Axero, Notion?)
|
||||
@retry(tries=3, delay=1, backoff=2)
def _rate_limited_freshdesk_get(
    url: str, auth: tuple, params: dict
) -> requests.Response:
    """GET ``url`` through the shared ``rl_requests`` wrapper, under the ``retry`` decorator.

    The response is returned as-is; status checking is left to the caller.
    """
    response = rl_requests.get(url, auth=auth, params=params)
    return response
|
||||
|
||||
|
||||
def _create_metadata_from_ticket(ticket: dict) -> dict:
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
# Combine all emails into a list so there are no repeated emails
|
||||
@@ -154,16 +166,32 @@ class FreshdeskConnector(PollConnector, LoadConnector):
|
||||
def load_credentials(self, credentials: dict[str, str | int]) -> None:
|
||||
api_key = credentials.get("freshdesk_api_key")
|
||||
domain = credentials.get("freshdesk_domain")
|
||||
password = credentials.get("freshdesk_password")
|
||||
|
||||
if not all(isinstance(cred, str) for cred in [domain, api_key, password]):
|
||||
if not all(isinstance(cred, str) for cred in [domain, api_key]):
|
||||
raise ConnectorMissingCredentialError(
|
||||
"All Freshdesk credentials must be strings"
|
||||
)
|
||||
|
||||
# TODO: Move the domain to the connector-specific configuration instead of part of the credential
|
||||
# Then apply normalization and validation against the config
|
||||
# Clean and normalize the domain URL
|
||||
domain = str(domain).strip().lower()
|
||||
|
||||
# Remove any trailing slashes
|
||||
domain = domain.rstrip("/")
|
||||
|
||||
# Remove protocol if present
|
||||
if domain.startswith(("http://", "https://")):
|
||||
domain = domain.replace("http://", "").replace("https://", "")
|
||||
|
||||
# Remove .freshdesk.com suffix and any API paths if present
|
||||
if ".freshdesk.com" in domain:
|
||||
domain = domain.split(".freshdesk.com")[0]
|
||||
|
||||
if not domain:
|
||||
raise ConnectorMissingCredentialError("Freshdesk domain cannot be empty")
|
||||
|
||||
self.api_key = str(api_key)
|
||||
self.domain = str(domain)
|
||||
self.password = str(password)
|
||||
self.domain = domain
|
||||
|
||||
def _fetch_tickets(
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
@@ -177,7 +205,7 @@ class FreshdeskConnector(PollConnector, LoadConnector):
|
||||
'include' field available for this endpoint:
|
||||
https://developers.freshdesk.com/api/#filter_tickets
|
||||
"""
|
||||
if self.api_key is None or self.domain is None or self.password is None:
|
||||
if self.api_key is None or self.domain is None:
|
||||
raise ConnectorMissingCredentialError("freshdesk")
|
||||
|
||||
base_url = f"https://{self.domain}.freshdesk.com/api/v2/tickets"
|
||||
@@ -191,8 +219,11 @@ class FreshdeskConnector(PollConnector, LoadConnector):
|
||||
params["updated_since"] = start.isoformat()
|
||||
|
||||
while True:
|
||||
response = requests.get(
|
||||
base_url, auth=(self.api_key, self.password), params=params
|
||||
# Freshdesk API uses API key as the username and any value as the password.
|
||||
response = _rate_limited_freshdesk_get(
|
||||
base_url,
|
||||
auth=(self.api_key, "CanYouBelieveFreshdeskDoesThis"),
|
||||
params=params,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ USER_FIELDS = "nextPageToken, users(primaryEmail)"
|
||||
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
|
||||
|
||||
# Documentation and error messages
|
||||
SCOPE_DOC_URL = "https://docs.onyx.app/connectors/google_drive/overview"
|
||||
SCOPE_DOC_URL = "https://docs.onyx.app/admin/connectors/official/google_drive/overview"
|
||||
ONYX_SCOPE_INSTRUCTIONS = (
|
||||
"You have upgraded Onyx without updating the Google Auth scopes. "
|
||||
f"Please refer to the documentation to learn how to update the scopes: {SCOPE_DOC_URL}"
|
||||
|
||||
@@ -423,7 +423,7 @@ def _sanitize_mailbox_names(mailboxes: list[str]) -> list[str]:
|
||||
|
||||
def _parse_addrs(raw_header: str) -> list[tuple[str, str]]:
|
||||
addrs = raw_header.split(",")
|
||||
name_addr_pairs = [parseaddr(addr=addr, strict=True) for addr in addrs if addr]
|
||||
name_addr_pairs = [parseaddr(addr=addr) for addr in addrs if addr]
|
||||
return [(name, addr) for name, addr in name_addr_pairs if addr]
|
||||
|
||||
|
||||
|
||||
@@ -35,10 +35,10 @@ class EmailHeaders(BaseModel):
|
||||
if not value:
|
||||
return None
|
||||
|
||||
decoded_value, _encoding = email.header.decode_header(value)[0]
|
||||
|
||||
decoded_value, encoding = email.header.decode_header(value)[0]
|
||||
if isinstance(decoded_value, bytes):
|
||||
return decoded_value.decode()
|
||||
encoding = encoding or "utf-8"
|
||||
return decoded_value.decode(encoding, errors="replace")
|
||||
elif isinstance(decoded_value, str):
|
||||
return decoded_value
|
||||
else:
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import copy
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -7,6 +10,7 @@ from typing import Any
|
||||
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
from more_itertools import chunked
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
@@ -34,6 +38,7 @@ from onyx.connectors.jira.utils import build_jira_url
|
||||
from onyx.connectors.jira.utils import extract_text_from_adf
|
||||
from onyx.connectors.jira.utils import get_comment_strs
|
||||
from onyx.connectors.jira.utils import get_jira_project_key_from_issue
|
||||
from onyx.connectors.jira.utils import JIRA_CLOUD_API_VERSION
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
@@ -49,6 +54,7 @@ logger = setup_logger()
|
||||
|
||||
ONE_HOUR = 3600
|
||||
|
||||
_MAX_RESULTS_FETCH_IDS = 5000 # 5000
|
||||
_JIRA_SLIM_PAGE_SIZE = 500
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
|
||||
@@ -73,13 +79,155 @@ _FIELD_RESOLUTION_DATE = "resolutiondate"
|
||||
_FIELD_RESOLUTION_DATE_KEY = "resolution_date"
|
||||
|
||||
|
||||
def _is_cloud_client(jira_client: JIRA) -> bool:
    """Report whether the client is configured with the cloud REST API version.

    The decision is made by comparing the client's ``rest_api_version``
    option against ``JIRA_CLOUD_API_VERSION``.
    """
    configured_version = jira_client._options["rest_api_version"]
    return configured_version == JIRA_CLOUD_API_VERSION
|
||||
|
||||
|
||||
def _perform_jql_search(
    jira_client: JIRA,
    jql: str,
    start: int,
    max_results: int,
    fields: str | None = None,
    all_issue_ids: list[list[str]] | None = None,
    checkpoint_callback: (
        Callable[[Iterator[list[str]], str | None], None] | None
    ) = None,
    nextPageToken: str | None = None,
    ids_done: bool = False,
) -> Iterable[Issue]:
    """
    Run a JQL search using whichever strategy matches the client's API version.

    Callers can rely on the following:

    a) The returned iterable holds 0 < len(issues) <= max_results issues.
       - caveat: when all_issue_ids is provided, the iterable is the size of
         one of its sub-lists.
       - that size only falls outside the bound above if a recent deployment
         changed max_results.

    When the v3 API is in use (i.e. the jira instance is a cloud instance):

    b) checkpoint_callback is called ONCE after at least one of these happened:
       - a new batch of ids was fetched via enhanced search
       - a batch of issues was bulk-fetched
    c) checkpoint_callback receives the new id batches and the pageToken of the
       enhanced search request. A pageToken of None is passed once all the
       issue ids have been fetched.

    Note: nextPageToken is valid for 7 days according to a post from a year ago,
    so for now there is no handling for restarting (just re-index, since there's
    no easy way to recover from this).

    Raises:
        ValueError: if the client is a cloud client and all_issue_ids is None.
    """
    # It would be preferable to use one approach for both versions, but
    # v2 doesn't have the bulk fetch api and v3 has fully deprecated the search
    # api that v2 uses.
    if not _is_cloud_client(jira_client):
        return _perform_jql_search_v2(jira_client, jql, start, max_results, fields)

    if all_issue_ids is None:
        raise ValueError("all_issue_ids is required for v3")

    return _perform_jql_search_v3(
        jira_client,
        jql,
        max_results,
        all_issue_ids,
        fields=fields,
        checkpoint_callback=checkpoint_callback,
        nextPageToken=nextPageToken,
        ids_done=ids_done,
    )
|
||||
|
||||
|
||||
def enhanced_search_ids(
    jira_client: JIRA, jql: str, nextPageToken: str | None = None
) -> tuple[list[str], str | None]:
    """Fetch one page of issue ids through the enhanced JQL search endpoint.

    Returns a tuple of (issue ids as strings, the response's ``nextPageToken``
    or None when the response carries none).
    """
    # https://community.atlassian.com/forums/Jira-articles/
    # Avoiding-Pitfalls-A-Guide-to-Smooth-Migration-to-Enhanced-JQL/ba-p/2985433
    # For cloud, the recommendation is to fetch all ids first and then use the
    # bulk fetch API. Our python library doesn't currently support the enhanced
    # search, so we go through the client's session directly.
    search_url = jira_client._get_url("search/jql")
    query: dict[str, str | int | None] = {
        "jql": jql,
        "maxResults": _MAX_RESULTS_FETCH_IDS,
        "nextPageToken": nextPageToken,
        "fields": "id",
    }
    body = jira_client._session.get(search_url, params=query).json()
    issue_ids = [str(entry["id"]) for entry in body["issues"]]
    return issue_ids, body.get("nextPageToken")
|
||||
|
||||
|
||||
def bulk_fetch_issues(
    jira_client: JIRA, issue_ids: list[str], fields: str | None = None
) -> list[Issue]:
    """Bulk fetch the given issues through the ``issue/bulkfetch`` endpoint.

    Args:
        jira_client: client whose session and options are used for the request
            and for building the returned ``Issue`` objects.
        issue_ids: ids (or keys) of the issues to fetch. An empty list returns
            ``[]`` without sending a request.
        fields: optional comma-separated field names. Surrounding whitespace
            and empty entries are dropped; when nothing remains (or ``fields``
            is not given) all fields are requested.

    Returns:
        One ``Issue`` per entry in the response's ``issues`` list.

    Raises:
        Whatever the underlying request / JSON decoding raises; the error is
        logged and re-raised unchanged.
    """
    # Nothing to look up, so don't spend a request on it.
    if not issue_ids:
        return []

    # TODO: move away from this jira library if they continue to not support
    # the endpoints we need. Using private fields is not ideal, but
    # is likely fine for now since we pin the library version
    bulk_fetch_path = jira_client._get_url("issue/bulkfetch")

    # Normalize "a, b" style input so padded or empty names are never sent.
    requested_fields = (
        [name.strip() for name in fields.split(",") if name.strip()] if fields else []
    )

    # Prepare the payload according to Jira API v3 specification.
    # Only restrict fields if specified, might want to explicitly do this in
    # the future to avoid reading unnecessary data.
    payload: dict[str, Any] = {
        "issueIdsOrKeys": issue_ids,
        "fields": requested_fields or ["*all"],
    }

    try:
        response = jira_client._session.post(bulk_fetch_path, json=payload).json()
    except Exception as e:
        logger.error(f"Error fetching issues: {e}")
        # bare raise keeps the original traceback intact
        raise

    return [
        Issue(jira_client._options, jira_client._session, raw=issue)
        for issue in response["issues"]
    ]
|
||||
|
||||
|
||||
def _perform_jql_search_v3(
|
||||
jira_client: JIRA,
|
||||
jql: str,
|
||||
max_results: int,
|
||||
all_issue_ids: list[list[str]],
|
||||
fields: str | None = None,
|
||||
checkpoint_callback: (
|
||||
Callable[[Iterator[list[str]], str | None], None] | None
|
||||
) = None,
|
||||
nextPageToken: str | None = None,
|
||||
ids_done: bool = False,
|
||||
) -> Iterable[Issue]:
|
||||
"""
|
||||
The way this works is we get all the issue ids and bulk fetch them in batches.
|
||||
However, for really large deployments we can't do these operations sequentially,
|
||||
as it might take several hours to fetch all the issue ids.
|
||||
|
||||
So, each run of this function does at least one of:
|
||||
- fetch a batch of issue ids
|
||||
- bulk fetch a batch of issues
|
||||
|
||||
If all_issue_ids is not None, we use it to bulk fetch issues.
|
||||
"""
|
||||
|
||||
# with some careful synchronization these steps can be done in parallel,
|
||||
# leaving that out for now to avoid rate limit issues
|
||||
if not ids_done:
|
||||
new_ids, pageToken = enhanced_search_ids(jira_client, jql, nextPageToken)
|
||||
if checkpoint_callback is not None:
|
||||
checkpoint_callback(chunked(new_ids, max_results), pageToken)
|
||||
|
||||
# bulk fetch issues from ids. Note that the above callback MAY mutate all_issue_ids,
|
||||
# but this fetch always just takes the last id batch.
|
||||
if all_issue_ids:
|
||||
yield from bulk_fetch_issues(jira_client, all_issue_ids.pop(), fields)
|
||||
|
||||
|
||||
def _perform_jql_search_v2(
|
||||
jira_client: JIRA,
|
||||
jql: str,
|
||||
start: int,
|
||||
max_results: int,
|
||||
fields: str | None = None,
|
||||
) -> Iterable[Issue]:
|
||||
"""
|
||||
Unfortunately, jira server/data center will forever use the v2 APIs that are now deprecated.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Fetching Jira issues with JQL: {jql}, "
|
||||
f"starting at {start}, max results: {max_results}"
|
||||
@@ -202,6 +350,12 @@ def process_jira_issue(
|
||||
|
||||
|
||||
class JiraConnectorCheckpoint(ConnectorCheckpoint):
|
||||
# used for v3 (cloud) endpoint
|
||||
all_issue_ids: list[list[str]] = []
|
||||
ids_done: bool = False
|
||||
cursor: str | None = None
|
||||
# deprecated
|
||||
# Used for v2 endpoint (server/data center)
|
||||
offset: int | None = None
|
||||
|
||||
|
||||
@@ -301,12 +455,19 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
# Get the current offset from checkpoint or start at 0
|
||||
starting_offset = checkpoint.offset or 0
|
||||
current_offset = starting_offset
|
||||
new_checkpoint = copy.deepcopy(checkpoint)
|
||||
|
||||
checkpoint_callback = make_checkpoint_callback(new_checkpoint)
|
||||
|
||||
for issue in _perform_jql_search(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
start=current_offset,
|
||||
max_results=_JIRA_FULL_PAGE_SIZE,
|
||||
all_issue_ids=new_checkpoint.all_issue_ids,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
nextPageToken=new_checkpoint.cursor,
|
||||
ids_done=new_checkpoint.ids_done,
|
||||
):
|
||||
issue_key = issue.key
|
||||
try:
|
||||
@@ -331,12 +492,28 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
current_offset += 1
|
||||
|
||||
# Update checkpoint
|
||||
checkpoint = JiraConnectorCheckpoint(
|
||||
offset=current_offset,
|
||||
# if we didn't retrieve a full batch, we're done
|
||||
has_more=current_offset - starting_offset == _JIRA_FULL_PAGE_SIZE,
|
||||
self.update_checkpoint_for_next_run(
|
||||
new_checkpoint, current_offset, starting_offset, _JIRA_FULL_PAGE_SIZE
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
return new_checkpoint
|
||||
|
||||
def update_checkpoint_for_next_run(
    self,
    checkpoint: JiraConnectorCheckpoint,
    current_offset: int,
    starting_offset: int,
    page_size: int,
) -> None:
    """Update ``checkpoint`` in place so the next run knows whether to continue.

    For non-cloud clients the offset is stored and ``has_more`` is True only
    when a full page was retrieved. For cloud clients only ``has_more`` is set
    here: it stays True while id batches remain or the id fetch isn't done.
    """
    if not _is_cloud_client(self.jira_client):
        checkpoint.offset = current_offset
        # if we didn't retrieve a full batch, we're done
        retrieved_count = current_offset - starting_offset
        checkpoint.has_more = retrieved_count == page_size
        return

    # other updates done in the checkpoint callback
    batches_remaining = len(checkpoint.all_issue_ids) > 0
    checkpoint.has_more = batches_remaining or not checkpoint.ids_done
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
@@ -352,33 +529,47 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
) # we add one day to account for any potential timezone issues
|
||||
|
||||
jql = self._get_jql_query(start, end)
|
||||
|
||||
checkpoint = self.build_dummy_checkpoint()
|
||||
checkpoint_callback = make_checkpoint_callback(checkpoint)
|
||||
prev_offset = 0
|
||||
current_offset = 0
|
||||
slim_doc_batch = []
|
||||
for issue in _perform_jql_search(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
start=int(start),
|
||||
max_results=_JIRA_SLIM_PAGE_SIZE,
|
||||
):
|
||||
project_key = get_jira_project_key_from_issue(issue=issue)
|
||||
if not project_key:
|
||||
continue
|
||||
while checkpoint.has_more:
|
||||
for issue in _perform_jql_search(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
start=current_offset,
|
||||
max_results=_JIRA_SLIM_PAGE_SIZE,
|
||||
all_issue_ids=checkpoint.all_issue_ids,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
nextPageToken=checkpoint.cursor,
|
||||
ids_done=checkpoint.ids_done,
|
||||
):
|
||||
project_key = get_jira_project_key_from_issue(issue=issue)
|
||||
if not project_key:
|
||||
continue
|
||||
|
||||
issue_key = best_effort_get_field_from_issue(issue, _FIELD_KEY)
|
||||
id = build_jira_url(self.jira_client, issue_key)
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=id,
|
||||
external_access=get_project_permissions(
|
||||
jira_client=self.jira_client, jira_project=project_key
|
||||
),
|
||||
issue_key = best_effort_get_field_from_issue(issue, _FIELD_KEY)
|
||||
id = build_jira_url(self.jira_client, issue_key)
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=id,
|
||||
external_access=get_project_permissions(
|
||||
jira_client=self.jira_client, jira_project=project_key
|
||||
),
|
||||
)
|
||||
)
|
||||
current_offset += 1
|
||||
if len(slim_doc_batch) >= _JIRA_SLIM_PAGE_SIZE:
|
||||
yield slim_doc_batch
|
||||
slim_doc_batch = []
|
||||
self.update_checkpoint_for_next_run(
|
||||
checkpoint, current_offset, prev_offset, _JIRA_SLIM_PAGE_SIZE
|
||||
)
|
||||
if len(slim_doc_batch) >= _JIRA_SLIM_PAGE_SIZE:
|
||||
yield slim_doc_batch
|
||||
slim_doc_batch = []
|
||||
prev_offset = current_offset
|
||||
|
||||
yield slim_doc_batch
|
||||
if slim_doc_batch:
|
||||
yield slim_doc_batch
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self._jira_client is None:
|
||||
@@ -471,6 +662,21 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
|
||||
)
|
||||
|
||||
|
||||
def make_checkpoint_callback(
    checkpoint: JiraConnectorCheckpoint,
) -> Callable[[Iterator[list[str]], str | None], None]:
    """Build a callback that records fetched id batches on ``checkpoint``.

    The returned callable appends every id batch to
    ``checkpoint.all_issue_ids``, stores the page token as the cursor, and
    flags ``ids_done`` when that token is None.
    """

    def on_ids_fetched(issue_ids: Iterator[list[str]], pageToken: str | None) -> None:
        checkpoint.all_issue_ids.extend(issue_ids)
        checkpoint.cursor = pageToken
        # pageToken starts out as None and is only None once we've fetched all the issue ids
        checkpoint.ids_done = pageToken is None

    return on_ids_fetched
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
@@ -17,7 +17,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
PROJECT_URL_PAT = "projects"
|
||||
JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
|
||||
JIRA_SERVER_API_VERSION = os.environ.get("JIRA_SERVER_API_VERSION") or "2"
|
||||
JIRA_CLOUD_API_VERSION = os.environ.get("JIRA_CLOUD_API_VERSION") or "3"
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ def build_jira_client(credentials: dict[str, Any], jira_base: str) -> JIRA:
|
||||
return JIRA(
|
||||
token_auth=api_token,
|
||||
server=jira_base,
|
||||
options={"rest_api_version": JIRA_API_VERSION},
|
||||
options={"rest_api_version": JIRA_SERVER_API_VERSION},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -196,7 +196,16 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
try:
|
||||
res.raise_for_status()
|
||||
except Exception as e:
|
||||
if res.json().get("code") == "object_not_found":
|
||||
json_data = res.json()
|
||||
code = json_data.get("code")
|
||||
# Sep 3 2025 backend changed the error message for this case
|
||||
# TODO: it is also now possible for there to be multiple data sources per database; at present we
|
||||
# just don't handle that. We will need to upgrade the API to the current version + query the
|
||||
# new data sources endpoint to handle that case correctly.
|
||||
if code == "object_not_found" or (
|
||||
code == "validation_error"
|
||||
and "does not contain any data sources" in json_data.get("message", "")
|
||||
):
|
||||
# this happens when a database is not shared with the integration
|
||||
# in this case, we should just ignore the database
|
||||
logger.error(
|
||||
@@ -213,42 +222,49 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
def _properties_to_str(properties: dict[str, Any]) -> str:
|
||||
"""Converts Notion properties to a string"""
|
||||
|
||||
def _recurse_list_properties(inner_list: list[Any]) -> str | None:
|
||||
list_properties: list[str | None] = []
|
||||
for item in inner_list:
|
||||
if item and isinstance(item, dict):
|
||||
list_properties.append(_recurse_properties(item))
|
||||
elif item and isinstance(item, list):
|
||||
list_properties.append(_recurse_list_properties(item))
|
||||
else:
|
||||
list_properties.append(str(item))
|
||||
return (
|
||||
", ".join(
|
||||
[
|
||||
list_property
|
||||
for list_property in list_properties
|
||||
if list_property
|
||||
]
|
||||
)
|
||||
or None
|
||||
)
|
||||
|
||||
def _recurse_properties(inner_dict: dict[str, Any]) -> str | None:
|
||||
while "type" in inner_dict:
|
||||
type_name = inner_dict["type"]
|
||||
inner_dict = inner_dict[type_name]
|
||||
sub_inner_dict: dict[str, Any] | list[Any] | str = inner_dict
|
||||
while isinstance(sub_inner_dict, dict) and "type" in sub_inner_dict:
|
||||
type_name = sub_inner_dict["type"]
|
||||
sub_inner_dict = sub_inner_dict[type_name]
|
||||
|
||||
# If the innermost layer is None, the value is not set
|
||||
if not inner_dict:
|
||||
if not sub_inner_dict:
|
||||
return None
|
||||
|
||||
if isinstance(inner_dict, list):
|
||||
list_properties = [
|
||||
_recurse_properties(item) for item in inner_dict if item
|
||||
]
|
||||
return (
|
||||
", ".join(
|
||||
[
|
||||
list_property
|
||||
for list_property in list_properties
|
||||
if list_property
|
||||
]
|
||||
)
|
||||
or None
|
||||
)
|
||||
|
||||
# TODO there may be more types to handle here
|
||||
if isinstance(inner_dict, str):
|
||||
if isinstance(sub_inner_dict, list):
|
||||
return _recurse_list_properties(sub_inner_dict)
|
||||
elif isinstance(sub_inner_dict, str):
|
||||
# For some objects the innermost value could just be a string, not sure what causes this
|
||||
return inner_dict
|
||||
|
||||
elif isinstance(inner_dict, dict):
|
||||
if "name" in inner_dict:
|
||||
return inner_dict["name"]
|
||||
if "content" in inner_dict:
|
||||
return inner_dict["content"]
|
||||
start = inner_dict.get("start")
|
||||
end = inner_dict.get("end")
|
||||
return sub_inner_dict
|
||||
elif isinstance(sub_inner_dict, dict):
|
||||
if "name" in sub_inner_dict:
|
||||
return sub_inner_dict["name"]
|
||||
if "content" in sub_inner_dict:
|
||||
return sub_inner_dict["content"]
|
||||
start = sub_inner_dict.get("start")
|
||||
end = sub_inner_dict.get("end")
|
||||
if start is not None:
|
||||
if end is not None:
|
||||
return f"{start} - {end}"
|
||||
@@ -256,13 +272,13 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
elif end is not None:
|
||||
return f"Until {end}"
|
||||
|
||||
if "id" in inner_dict:
|
||||
if "id" in sub_inner_dict:
|
||||
# This is not useful to index, it's a reference to another Notion object
|
||||
# and this ID value in plaintext is useless outside of the Notion context
|
||||
logger.debug("Skipping Notion object id field property")
|
||||
return None
|
||||
|
||||
logger.debug(f"Unreadable property from innermost prop: {inner_dict}")
|
||||
logger.debug(f"Unreadable property from innermost prop: {sub_inner_dict}")
|
||||
return None
|
||||
|
||||
result = ""
|
||||
|
||||
@@ -32,6 +32,7 @@ from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
|
||||
from onyx.connectors.salesforce.utils import get_sqlite_db_path
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
|
||||
@@ -920,8 +921,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
# Get custom fields for parent type
|
||||
field_set = set(custom_fields)
|
||||
# these are expected and used during doc conversion
|
||||
field_set.add(NAME_FIELD)
|
||||
# used during doc conversion
|
||||
# field_set.add(NAME_FIELD) # does not always exist
|
||||
field_set.add(ID_FIELD)
|
||||
field_set.add(MODIFIED_FIELD)
|
||||
|
||||
# Use only the specified fields
|
||||
@@ -985,7 +987,8 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
):
|
||||
field_set = set(config_fields)
|
||||
# these are expected and used during doc conversion
|
||||
field_set.add(NAME_FIELD)
|
||||
# field_set.add(NAME_FIELD) # does not always exist
|
||||
field_set.add(ID_FIELD)
|
||||
field_set.add(MODIFIED_FIELD)
|
||||
queryable_fields = field_set
|
||||
else:
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
|
||||
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
@@ -169,7 +170,9 @@ def convert_sf_query_result_to_doc(
|
||||
|
||||
base_url = f"https://{sf_client.sf_instance}"
|
||||
extracted_doc_updated_at = time_str_to_utc(record[MODIFIED_FIELD])
|
||||
extracted_semantic_identifier = record.get(NAME_FIELD, "Unknown Object")
|
||||
extracted_semantic_identifier = record.get(NAME_FIELD) or record.get(
|
||||
ID_FIELD, "Unknown Object"
|
||||
)
|
||||
|
||||
sections = [_extract_section(record, f"{base_url}/{record_id}")]
|
||||
for child_record_key, child_record in child_records.items():
|
||||
@@ -204,11 +207,13 @@ def convert_sf_object_to_doc(
|
||||
) -> Document:
|
||||
"""Would be nice if this function was documented"""
|
||||
object_dict = sf_object.data
|
||||
salesforce_id = object_dict["Id"]
|
||||
salesforce_id = object_dict[ID_FIELD]
|
||||
onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
|
||||
base_url = f"https://{sf_instance}"
|
||||
extracted_doc_updated_at = time_str_to_utc(object_dict[MODIFIED_FIELD])
|
||||
extracted_semantic_identifier = object_dict.get(NAME_FIELD, "Unknown Object")
|
||||
extracted_semantic_identifier = object_dict.get(NAME_FIELD) or object_dict.get(
|
||||
ID_FIELD, "Unknown Object"
|
||||
)
|
||||
|
||||
sections = [_extract_section(sf_object.data, f"{base_url}/{sf_object.id}")]
|
||||
for id in sf_db.get_child_ids(sf_object.id):
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_OBJECTS
|
||||
from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_PREFIXES
|
||||
from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_SUFFIXES
|
||||
from onyx.connectors.salesforce.salesforce_calls import get_object_by_id_query
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
@@ -66,7 +67,7 @@ class OnyxSalesforce(Salesforce):
|
||||
return False
|
||||
|
||||
@retry_builder(
|
||||
tries=5,
|
||||
tries=6,
|
||||
delay=20,
|
||||
backoff=1.5,
|
||||
max_delay=60,
|
||||
@@ -218,7 +219,7 @@ class OnyxSalesforce(Salesforce):
|
||||
continue
|
||||
|
||||
for child_record in child_result["records"]:
|
||||
child_record_id = child_record["Id"]
|
||||
child_record_id = child_record[ID_FIELD]
|
||||
if not child_record_id:
|
||||
logger.warning("Child record has no id")
|
||||
continue
|
||||
|
||||
@@ -6,9 +6,11 @@ import time
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
|
||||
@@ -408,7 +410,7 @@ class OnyxSalesforceSQLite:
|
||||
# if this id is a parent type, yield it directly
|
||||
if changed_type in parent_types:
|
||||
yield changed_id, changed_type, num_examined
|
||||
changed_parent_ids.update(changed_id)
|
||||
changed_parent_ids.add(changed_id)
|
||||
continue
|
||||
|
||||
# if this id is a child type, then check the columns
|
||||
@@ -431,7 +433,7 @@ class OnyxSalesforceSQLite:
|
||||
logger.warning(f"{field_name=} not in data for {changed_type=}!")
|
||||
continue
|
||||
|
||||
parent_id = sf_object.data[field_name]
|
||||
parent_id = cast(str, sf_object.data[field_name])
|
||||
parent_id_prefix = parent_id[:3]
|
||||
|
||||
if parent_id_prefix not in prefix_to_type:
|
||||
@@ -445,7 +447,7 @@ class OnyxSalesforceSQLite:
|
||||
continue
|
||||
|
||||
yield parent_id, parent_type, num_examined
|
||||
changed_parent_ids.update(parent_id)
|
||||
changed_parent_ids.add(parent_id)
|
||||
break
|
||||
|
||||
def object_type_count(self, object_type: str) -> int:
|
||||
@@ -498,7 +500,7 @@ class OnyxSalesforceSQLite:
|
||||
|
||||
# remove salesforce id's (and add to parent id set)
|
||||
if (
|
||||
field != "Id"
|
||||
field != ID_FIELD
|
||||
and isinstance(value, str)
|
||||
and validate_salesforce_id(value)
|
||||
):
|
||||
@@ -535,13 +537,13 @@ class OnyxSalesforceSQLite:
|
||||
reader = csv.DictReader(f)
|
||||
uncommitted_rows = 0
|
||||
for row in reader:
|
||||
if "Id" not in row:
|
||||
if ID_FIELD not in row:
|
||||
logger.warning(
|
||||
f"Row {row} does not have an Id field in {csv_download_path}"
|
||||
f"Row {row} does not have an {ID_FIELD} field in {csv_download_path}"
|
||||
)
|
||||
continue
|
||||
|
||||
row_id = row["Id"]
|
||||
row_id = row[ID_FIELD]
|
||||
|
||||
normalized_record, parent_ids = (
|
||||
OnyxSalesforceSQLite.normalize_record(row, remove_ids)
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any
|
||||
|
||||
NAME_FIELD = "Name"
|
||||
MODIFIED_FIELD = "LastModifiedDate"
|
||||
ID_FIELD = "Id"
|
||||
ACCOUNT_OBJECT_TYPE = "Account"
|
||||
USER_OBJECT_TYPE = "User"
|
||||
|
||||
@@ -24,7 +25,7 @@ class SalesforceObject:
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SalesforceObject":
|
||||
return cls(
|
||||
id=data["Id"],
|
||||
id=data[ID_FIELD],
|
||||
type=data["Type"],
|
||||
data=data,
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user