Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-16 23:35:46 +00:00)

Compare commits: v2.12.0-cl...colours (106 commits)
Commits (SHA1):

68e19383b9, f01f3d6e77, 42b57bf804, d4b696d17d, 84b774f376, 921f90d013,
f6c09572eb, b2f90db0b8, ee7b33b382, db37c35030, 4f105d002c, 7e9095b976,
bce6741ee6, b5a188ee5e, 8578d89a24, e42840775a, c907971d5b, 44a565bbfc,
ea4219806d, ccc76413c6, 12a2786ff6, 882c294e74, 0e2b6cf193, c7ae8bd783,
3b10dd4b22, 6ea886fc85, f8c89fc750, cace80ffaa, e628033885, 22bb4b6d98,
9afffc2de4, 2c1193f975, b192542c85, d8821b8ccc, a007369bd5, 2f65629f51,
7701ae2112, 01a3a256e9, 0d55febaa7, bdafbfe0e8, 278fd0e153, a4bb97bc22,
8063d9a75e, 1ffaba12f0, 26f8660663, d6504ed578, 7fcc2c9d35, 46e8f925fe,
5ec1f61839, df950963a7, 93208a66ac, a4819e07e7, f642ace40c, 9b430ae2d5,
05f3f878b2, df17c5352e, bcfb0f3cf3, 38468c1dc4, 8550a9c5e3, fe0c60e50d,
4ecc151a02, d08becead5, a429f852d5, a856f27fae, d0d8027928, bd1671f1a1,
e236c67678, 683956697a, fb1e303ffc, 729d4fafd1, 40c60282d0, 2141fd2c6e,
9aeba96043, b431de5141, d1a6340cfc, ccf382ef4f, c31997b9b2, ab31795a46,
b3beca63dc, cc6d54c1e6, ee12c0c5de, d48912a05d, c079072676, 952f6bfb37,
0714e4bb4e, ae577f0f44, 0705d584d8, 36e391e557, 1efce594b5, 67ac53f17d,
d5a222925a, d5ef928782, 6963d78f8e, d3ef2b8c17, 70f4162ea8, 883f52d332,
f8fd83c883, d2bf0c0c5f, 5d598c2d22, 9dc0e97302, 048b2a6b39, 7dd3cecf67,
82abe28986, a0575e6a00, 0c5bf5b3ed, 492117d910
.github/pull_request_template.md (5 lines changed)

@@ -6,9 +6,6 @@

[Describe the tests you ran to verify your changes]

## Backporting (check the box to trigger backport action)
## Additional Options

Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.

- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
- [ ] [Optional] Override Linear Check
.github/workflows/check-lazy-imports.yml (new file, 24 lines)

@@ -0,0 +1,24 @@
name: Check Lazy Imports

on:
  merge_group:
  pull_request:
    branches:
      - main
      - 'release/**'

jobs:
  check-lazy-imports:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.11'

      - name: Check lazy imports
        run: python3 backend/scripts/check_lazy_imports.py
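The job above only invokes `backend/scripts/check_lazy_imports.py`; the script itself is not part of this diff (the pre-commit hook further down wires up the same entry point). As a rough, hypothetical sketch of the kind of guard such a script might implement, with the module list, paths, and matching heuristic all assumptions rather than the repository's actual code:

```python
"""Hypothetical sketch of a lazy-import check; not the real check_lazy_imports.py."""
import pathlib
import re
import sys

# Assumption: modules that should only be imported lazily (inside functions).
LAZY_ONLY_MODULES = ["vertexai", "transformers"]


def main() -> int:
    # Heuristic: a module-level import starts at column 0; indented imports
    # (inside functions) are treated as lazy and allowed.
    pattern = re.compile(
        r"^(?:import|from)\s+(" + "|".join(map(re.escape, LAZY_ONLY_MODULES)) + r")\b"
    )
    failures = []
    for path in pathlib.Path("backend").rglob("*.py"):
        text = path.read_text(encoding="utf-8", errors="ignore")
        for lineno, line in enumerate(text.splitlines(), start=1):
            if pattern.match(line):
                failures.append(f"{path}:{lineno}: module-level import of a lazy-only module")
    for failure in failures:
        print(failure)
    return 1 if failures else 0


if __name__ == "__main__":
    sys.exit(main())
```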
.github/workflows/helm-chart-releases.yml (3 lines changed)

@@ -27,6 +27,7 @@ jobs:
run: |
helm repo add bitnami https://charts.bitnami.com/bitnami
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add keda https://kedacore.github.io/charts
helm repo update

- name: Build chart dependencies
@@ -46,4 +47,4 @@
charts_dir: deployment/helm/charts
branch: gh-pages
commit_username: ${{ github.actor }}
commit_email: ${{ github.actor }}@users.noreply.github.com
commit_email: ${{ github.actor }}@users.noreply.github.com
@@ -23,6 +23,7 @@ env:

# LLMs
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}

jobs:
discover-test-dirs:
@@ -42,8 +43,8 @@ jobs:

external-dependency-unit-tests:
needs: discover-test-dirs
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
# Use larger runner with more resources for Vespa
runs-on: [runs-on, runner=16cpu-linux-x64, "run-id=${{ github.run_id }}"]

strategy:
fail-fast: false
@@ -52,6 +53,7 @@ jobs:

env:
PYTHONPATH: ./backend
MODEL_SERVER_HOST: "disabled"

steps:
- name: Checkout code
@@ -77,19 +79,30 @@ jobs:
- name: Set up Standard Dependencies
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack up -d minio relational_db cache index
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d minio relational_db cache index

- name: Wait for services
run: |
echo "Waiting for services to be ready..."
sleep 30

# Wait for Vespa specifically
echo "Waiting for Vespa to be ready..."
timeout 300 bash -c 'until curl -f -s http://localhost:8081/ApplicationStatus > /dev/null 2>&1; do echo "Vespa not ready, waiting..."; sleep 10; done' || echo "Vespa timeout - continuing anyway"

echo "Services should be ready now"

- name: Run migrations
run: |
cd backend
# Run migrations to head
alembic upgrade head
alembic heads --verbose

- name: Run Tests for ${{ matrix.test-dir }}
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test \
-n 8 \
--dist loadfile \
--durations=8 \
-o junit_family=xunit2 \
-xv \
.github/workflows/pr-helm-chart-testing.yml (1 line changed)

@@ -169,6 +169,7 @@ jobs:
--set=celery_worker_light.replicaCount=0 \
--set=celery_worker_monitoring.replicaCount=0 \
--set=celery_worker_primary.replicaCount=0 \
--set=celery_worker_user_file_processing.replicaCount=0 \
--set=celery_worker_user_files_indexing.replicaCount=0" \
--helm-extra-args="--timeout 900s --debug" \
--debug --config ct.yaml
.github/workflows/pr-integration-tests.yml (38 lines changed)

@@ -130,6 +130,7 @@ jobs:
platforms: linux/arm64
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
push: true
outputs: type=registry

build-model-server-image:
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
@@ -189,6 +190,7 @@ jobs:
platforms: linux/arm64
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
push: true
outputs: type=registry

integration-tests:
needs:
@@ -230,9 +232,9 @@ jobs:
# Pull all images from registry in parallel
echo "Pulling Docker images in parallel..."
# Pull images from private registry
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &

# Wait for all background jobs to complete
wait
@@ -257,7 +259,7 @@ jobs:
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
docker compose -f docker-compose.dev.yml -p onyx-stack up \
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
relational_db \
index \
cache \
@@ -273,7 +275,7 @@ jobs:
run: |
echo "Starting wait-for-service script..."

docker logs -f onyx-stack-api_server-1 &
docker logs -f onyx-api_server-1 &

start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
@@ -317,7 +319,7 @@ jobs:
retry_wait_seconds: 10
command: |
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
docker run --rm --network onyx-stack_default \
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
@@ -354,13 +356,13 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true

- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true

- name: Upload logs
if: always()
@@ -374,7 +376,7 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
docker compose down -v

multitenant-tests:
@@ -405,9 +407,9 @@ jobs:

- name: Pull Docker images
run: |
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
wait
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }} onyxdotapp/onyx-backend:test
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
@@ -423,7 +425,7 @@ jobs:
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
DEV_MODE=true \
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up \
docker compose -f docker-compose.multitenant-dev.yml up \
relational_db \
index \
cache \
@@ -438,7 +440,7 @@ jobs:
- name: Wait for service to be ready (multi-tenant)
run: |
echo "Starting wait-for-service script for multi-tenant..."
docker logs -f onyx-stack-api_server-1 &
docker logs -f onyx-api_server-1 &
start_time=$(date +%s)
timeout=300
while true; do
@@ -464,7 +466,7 @@ jobs:
- name: Run Multi-Tenant Integration Tests
run: |
echo "Running multi-tenant integration tests..."
docker run --rm --network onyx-stack_default \
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
@@ -493,13 +495,13 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server_multitenant.log || true
docker compose -f docker-compose.multitenant-dev.yml logs --no-color api_server > $GITHUB_WORKSPACE/api_server_multitenant.log || true

- name: Dump all-container logs (multi-tenant)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose-multitenant.log || true
docker compose -f docker-compose.multitenant-dev.yml logs --no-color > $GITHUB_WORKSPACE/docker-compose-multitenant.log || true

- name: Upload logs (multi-tenant)
if: always()
@@ -512,7 +514,7 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
docker compose -f docker-compose.multitenant-dev.yml down -v

required:
runs-on: blacksmith-2vcpu-ubuntu-2404-arm
.github/workflows/pr-mit-integration-tests.yml (20 lines changed)

@@ -127,6 +127,7 @@ jobs:
platforms: linux/arm64
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
push: true
outputs: type=registry

build-model-server-image:
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
@@ -186,6 +187,7 @@ jobs:
platforms: linux/arm64
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
push: true
outputs: type=registry

integration-tests-mit:
needs:
@@ -228,9 +230,9 @@ jobs:
# Pull all images from registry in parallel
echo "Pulling Docker images in parallel..."
# Pull images from private registry
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &

# Wait for all background jobs to complete
wait
@@ -253,7 +255,7 @@ jobs:
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.dev.yml -p onyx-stack up \
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
relational_db \
index \
cache \
@@ -269,7 +271,7 @@ jobs:
run: |
echo "Starting wait-for-service script..."

docker logs -f onyx-stack-api_server-1 &
docker logs -f onyx-api_server-1 &

start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
@@ -314,7 +316,7 @@ jobs:
retry_wait_seconds: 10
command: |
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
docker run --rm --network onyx-stack_default \
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
@@ -351,13 +353,13 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true

- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true

- name: Upload logs
if: always()
@@ -371,7 +373,7 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
docker compose down -v

required:
.github/workflows/pr-playwright-tests.yml (8 lines changed)

@@ -189,14 +189,14 @@ jobs:
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
id: start_docker

- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."

docker logs -f danswer-stack-api_server-1 &
docker logs -f onyx-api_server-1 &

start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
@@ -244,7 +244,7 @@ jobs:
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
docker compose logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log

- name: Upload logs
@@ -257,7 +257,7 @@ jobs:
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
docker compose down -v

# NOTE: Chromatic UI diff testing is currently disabled.
# We are using Playwright for local and CI testing without visual regression checks.

@@ -96,6 +96,13 @@ env:
TEAMS_DIRECTORY_ID: ${{ secrets.TEAMS_DIRECTORY_ID }}
TEAMS_SECRET: ${{ secrets.TEAMS_SECRET }}

# Bitbucket
BITBUCKET_WORKSPACE: ${{ secrets.BITBUCKET_WORKSPACE }}
BITBUCKET_REPOSITORIES: ${{ secrets.BITBUCKET_REPOSITORIES }}
BITBUCKET_PROJECTS: ${{ secrets.BITBUCKET_PROJECTS }}
BITBUCKET_EMAIL: ${{ secrets.BITBUCKET_EMAIL }}
BITBUCKET_API_TOKEN: ${{ secrets.BITBUCKET_API_TOKEN }}

jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
.github/workflows/pr-python-model-tests.yml (6 lines changed)

@@ -77,7 +77,7 @@ jobs:
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.model-server-test.yml -p onyx-stack up -d indexing_model_server
docker compose -f docker-compose.model-server-test.yml up -d indexing_model_server
id: start_docker

- name: Wait for service to be ready
@@ -132,7 +132,7 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
docker compose -f docker-compose.model-server-test.yml logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true

- name: Upload logs
if: always()
@@ -145,5 +145,5 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack down -v
docker compose -f docker-compose.model-server-test.yml down -v
.github/workflows/pr-python-tests.yml (2 lines changed)

@@ -31,12 +31,14 @@ jobs:
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt

- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt

- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
.gitignore (2 lines changed)

@@ -17,6 +17,7 @@ backend/tests/regression/answer_quality/test_data.json
backend/tests/regression/search_quality/eval-*
backend/tests/regression/search_quality/search_eval_config.yaml
backend/tests/regression/search_quality/*.json
backend/onyx/evals/data/
*.log

# secret files
@@ -28,6 +29,7 @@ settings.json
/deployment/data/nginx/app.conf
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
*.egg-info

# Local .terraform directories
**/.terraform/*
.mcp.json.template (new file, 8 lines)

@@ -0,0 +1,8 @@
{
  "mcpServers": {
    "onyx-mcp": {
      "type": "http",
      "url": "http://localhost:8000/mcp"
    }
  }
}
@@ -37,6 +37,15 @@ repos:
additional_dependencies:
- prettier

- repo: local
  hooks:
    - id: check-lazy-imports
      name: Check lazy imports are not directly imported
      entry: python3 backend/scripts/check_lazy_imports.py
      language: system
      files: ^backend/.*\.py$
      pass_filenames: false

# We would like to have a mypy pre-commit hook, but due to the fact that
# pre-commit runs in it's own isolated environment, we would need to install
# and keep in sync all dependencies so mypy has access to the appropriate type
.vscode/env_template.txt (6 lines changed)

@@ -10,7 +10,7 @@ SKIP_WARM_UP=True

# Always keep these on for Dev
# Logs all model prompts to stdout
LOG_DANSWER_MODEL_INTERACTIONS=True
LOG_ONYX_MODEL_INTERACTIONS=True
# More verbose logging
LOG_LEVEL=debug

@@ -39,8 +39,8 @@ FAST_GEN_AI_MODEL_VERSION=gpt-4o

# For Danswer Slack Bot, overrides the UI values so no need to set this up via UI every time
# Only needed if using DanswerBot
#DANSWER_BOT_SLACK_APP_TOKEN=<REPLACE THIS>
#DANSWER_BOT_SLACK_BOT_TOKEN=<REPLACE THIS>
#ONYX_BOT_SLACK_APP_TOKEN=<REPLACE THIS>
#ONYX_BOT_SLACK_BOT_TOKEN=<REPLACE THIS>

# Python stuff
.vscode/launch.template.jsonc (1143 lines changed)

File diff suppressed because it is too large.
@@ -4,14 +4,14 @@ This file provides guidance to Codex when working with code in this repository.

## KEY NOTES

- If you run into any missing python dependency errors, try running your command with `workon onyx &&` in front
- If you run into any missing python dependency errors, try running your command with `source backend/.venv/bin/activate` \
  to assume the python venv.
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
  `a`. The app can be accessed at `http://localhost:3000`.
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
  make sure we see logs coming out from the relevant service.
- To connect to the Postgres database, use: `docker exec -it onyx-stack-relational_db-1 psql -U postgres -c "<SQL>"`
- To connect to the Postgres database, use: `docker exec -it onyx-relational_db-1 psql -U postgres -c "<SQL>"`
- When making calls to the backend, always go through the frontend. E.g. make a call to `http://localhost:3000/api/persona` not `http://localhost:8080/api/persona`
- Put ALL db operations under the `backend/onyx/db` / `backend/ee/onyx/db` directories. Don't run queries
  outside of those directories.
@@ -4,14 +4,14 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co

## KEY NOTES

- If you run into any missing python dependency errors, try running your command with `workon onyx &&` in front
- If you run into any missing python dependency errors, try running your command with `source backend/.venv/bin/activate` \
  to assume the python venv.
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
  `a`. The app can be accessed at `http://localhost:3000`.
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
  make sure we see logs coming out from the relevant service.
- To connect to the Postgres database, use: `docker exec -it onyx-stack-relational_db-1 psql -U postgres -c "<SQL>"`
- To connect to the Postgres database, use: `docker exec -it onyx-relational_db-1 psql -U postgres -c "<SQL>"`
- When making calls to the backend, always go through the frontend. E.g. make a call to `http://localhost:3000/api/persona` not `http://localhost:8080/api/persona`
- Put ALL db operations under the `backend/onyx/db` / `backend/ee/onyx/db` directories. Don't run queries
  outside of those directories.
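As a small illustration of the "always go through the frontend" note above (a sketch only; session handling and the response shape are assumptions, not something these guidance files specify):

```python
import requests  # assumes the `requests` package and a running local Onyx dev stack

# Correct: route API calls through the frontend proxy on port 3000...
resp = requests.get("http://localhost:3000/api/persona")
# ...rather than calling the backend directly on port 8080.
print(resp.status_code)
```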
@@ -84,10 +84,6 @@ python -m venv .venv
source .venv/bin/activate
```

> **Note:**
> This virtual environment MUST NOT be set up WITHIN the onyx directory if you plan on using mypy within certain IDEs.
> For simplicity, we recommend setting up the virtual environment outside of the onyx directory.

_For Windows, activate the virtual environment using Command Prompt:_

```bash
@@ -175,7 +171,7 @@ You will need Docker installed to run these containers.
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis/MinIO with:

```bash
docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db cache minio
docker compose up -d index relational_db cache minio
```

(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
@@ -257,7 +253,7 @@ You can run the full Onyx application stack from pre-built images including all
Navigate to `onyx/deployment/docker_compose` and run:

```bash
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
docker compose up -d
```

After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Onyx.
@@ -265,7 +261,7 @@ After Docker pulls and starts these containers, navigate to `http://localhost:30
If you want to make changes to Onyx and run those changes in Docker, you can also build a local version of the Onyx container images that incorporates your changes like so:

```bash
docker compose -f docker-compose.dev.yml -p onyx-stack up -d --build
docker compose up -d --build
```
README.md (134 lines changed)

@@ -1,117 +1,103 @@
<!-- ONYX_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/README.md"} -->

<a name="readme-top"></a>

<h2 align="center">
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/OnyxLogoCropped.jpg?raw=true)" /></a>
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/OnyxLogoCropped.jpg?raw=true)" /></a>
</h2>

<p align="center">
<p align="center">Open Source Gen-AI + Enterprise Search.</p>
<p align="center">Open Source AI Platform</p>

<p align="center">
<a href="https://docs.onyx.app/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
<img src="https://img.shields.io/badge/discord-join-blue.svg?logo=discord&logoColor=white" alt="Discord">
</a>
<a href="https://github.com/onyx-dot-app/onyx/blob/main/README.md" target="_blank">
<img src="https://img.shields.io/static/v1?label=license&message=MIT&color=blue" alt="License">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
<img src="https://img.shields.io/badge/discord-join-blue.svg?logo=discord&logoColor=white" alt="Discord">
</a>
<a href="https://docs.onyx.app/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://docs.onyx.app/" target="_blank">
<img src="https://img.shields.io/website?url=https://www.onyx.app&up_message=visit&up_color=blue" alt="Documentation">
</a>
<a href="https://github.com/onyx-dot-app/onyx/blob/main/LICENSE" target="_blank">
<img src="https://img.shields.io/static/v1?label=license&message=MIT&color=blue" alt="License">
</a>
</p>

<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI platform connected to your company's docs, apps, and people.
Onyx provides a feature rich Chat interface and plugs into any LLM of your choice.
Keep knowledge and access controls sync-ed across over 40 connectors like Google Drive, Slack, Confluence, Salesforce, etc.
Create custom AI agents with unique prompts, knowledge, and actions that the agents can take.
Onyx can be deployed securely anywhere and for any scale - on a laptop, on-premise, or to cloud.

<h3>Feature Highlights</h3>
**[Onyx](https://www.onyx.app/)** is a feature-rich, self-hostable Chat UI that works with any LLM. It is easy to deploy and can run in a completely airgapped environment.

**Deep research over your team's knowledge:**
Onyx comes loaded with advanced features like Agents, Web Search, RAG, MCP, Deep Research, Connectors to 40+ knowledge sources, and more.

https://private-user-images.githubusercontent.com/32520769/414509312-48392e83-95d0-4fb5-8650-a396e05e0a32.mp4?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk5Mjg2MzYsIm5iZiI6MTczOTkyODMzNiwicGF0aCI6Ii8zMjUyMDc2OS80MTQ1MDkzMTItNDgzOTJlODMtOTVkMC00ZmI1LTg2NTAtYTM5NmUwNWUwYTMyLm1wND9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMTklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjE5VDAxMjUzNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWFhMzk5Njg2Y2Y5YjFmNDNiYTQ2YzM5ZTg5YWJiYTU2NWMyY2YwNmUyODE2NWUxMDRiMWQxZWJmODI4YTA0MTUmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.a9D8A0sgKE9AoaoE-mfFbJ6_OKYeqaf7TZ4Han2JfW8
> [!TIP]
> Run Onyx with one command (or see deployment section below):
> ```
> curl -fsSL https://raw.githubusercontent.com/onyx-dot-app/onyx/main/deployment/docker_compose/install.sh > install.sh && chmod +x install.sh && ./install.sh
> ```

**Use Onyx as a secure AI Chat with any LLM:**

**Easily set up connectors to your apps:**

## ⭐ Features
- **🤖 Custom Agents:** Build AI Agents with unique instructions, knowledge and actions.
- **🌍 Web Search:** Browse the web with Google PSE, Exa, and Serper as well as an in-house scraper or Firecrawl.
- **🔍 RAG:** Best in class hybrid-search + knowledge graph for uploaded files and ingested documents from connectors.
- **🔄 Connectors:** Pull knowledge, metadata, and access information from over 40 applications.
- **🔬 Deep Research:** Get in depth answers with an agentic multi-step search.
- **▶️ Actions & MCP:** Give AI Agents the ability to interact with external systems.
- **💻 Code Interpreter:** Execute code to analyze data, render graphs and create files.
- **🎨 Image Generation:** Generate images based on user prompts.
- **👥 Collaboration:** Chat sharing, feedback gathering, user management, usage analytics, and more.

Onyx works with all LLMs (like OpenAI, Anthropic, Gemini, etc.) and self-hosted LLMs (like Ollama, vLLM, etc.)

To learn more about the features, check out our [documentation](https://docs.onyx.app/welcome)!

**Access Onyx where your team already works:**

## 🚀 Deployment
Onyx supports deployments in Docker, Kubernetes, Terraform, along with guides for major cloud providers.

See guides below:
- [Docker](https://docs.onyx.app/deployment/local/docker) or [Quickstart](https://docs.onyx.app/deployment/getting_started/quickstart) (best for most users)
- [Kubernetes](https://docs.onyx.app/deployment/local/kubernetes) (best for large teams)
- [Terraform](https://docs.onyx.app/deployment/local/terraform) (best for teams already using Terraform)
- Cloud specific guides (best if specifically using [AWS EKS](https://docs.onyx.app/deployment/cloud/aws/eks), [Azure VMs](https://docs.onyx.app/deployment/cloud/azure), etc.)

> [!TIP]
> **To try Onyx for free without deploying, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.

## Deployment
**To try it out for free and get started in seconds, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.

Onyx can also be run locally (even on a laptop) or deployed on a virtual machine with a single
`docker compose` command. Checkout our [docs](https://docs.onyx.app/deployment/getting_started/quickstart) to learn more.
## 🔍 Other Notable Benefits
Onyx is built for teams of all sizes, from individual users to the largest global enterprises.

We also have built-in support for high-availability/scalable deployment on Kubernetes.
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).
- **Enterprise Search**: far more than simple RAG, Onyx has custom indexing and retrieval that remains performant and accurate for scales of up to tens of millions of documents.
- **Security**: SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- **Management UI**: different user roles such as basic, curator, and admin.
- **Document Permissioning**: mirrors user access from external apps for RAG use cases.

## 🔍 Other Notable Benefits of Onyx
- Custom deep learning models for indexing and inference time, only through Onyx + learning from user feedback.
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- Knowledge curation features like document-sets, query history, usage analytics, etc.
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.

## 🚧 Roadmap
- New methods in information retrieval (StructRAG, LightGraphRAG, etc.)
- Personalized Search
- Organizational understanding and ability to locate and suggest experts from your team.
- Code Search
- SQL and Structured Query Language
To see ongoing and upcoming projects, check out our [roadmap](https://github.com/orgs/onyx-dot-app/projects/2)!

## 🔌 Connectors
Keep knowledge and access up to sync across 40+ connectors:

- Google Drive
- Confluence
- Slack
- Gmail
- Salesforce
- Microsoft Sharepoint
- Github
- Jira
- Zendesk
- Gong
- Microsoft Teams
- Dropbox
- Local Files
- Websites
- And more ...

See the full list [here](https://docs.onyx.app/admin/connectors/overview).

## 📚 Licensing
There are two editions of Onyx:

- Onyx Community Edition (CE) is available freely under the MIT Expat license. Simply follow the Deployment guide above.
- Onyx Community Edition (CE) is available freely under the MIT license.
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations.
For feature details, check out [our website](https://www.onyx.app/pricing).

To try the Onyx Enterprise Edition:
1. Checkout [Onyx Cloud](https://cloud.onyx.app/signup).
2. For self-hosting the Enterprise Edition, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).

## 👪 Community
Join our open source community on **[Discord](https://discord.gg/TDJ59cGV2X)**!

## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
@@ -0,0 +1,389 @@
"""Migration 2: User file data preparation and backfill

Revision ID: 0cd424f32b1d
Revises: 9b66d3156fc6
Create Date: 2025-09-22 09:44:42.727034

This migration populates the new columns added in migration 1.
It prepares data for the UUID transition and relationship migration.
"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy import text
import logging

logger = logging.getLogger("alembic.runtime.migration")

# revision identifiers, used by Alembic.
revision = "0cd424f32b1d"
down_revision = "9b66d3156fc6"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Populate new columns with data."""

    bind = op.get_bind()
    inspector = sa.inspect(bind)

    # === Step 1: Populate user_file.new_id ===
    user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
    has_new_id = "new_id" in user_file_columns

    if has_new_id:
        logger.info("Populating user_file.new_id with UUIDs...")

        # Count rows needing UUIDs
        null_count = bind.execute(
            text("SELECT COUNT(*) FROM user_file WHERE new_id IS NULL")
        ).scalar_one()

        if null_count > 0:
            logger.info(f"Generating UUIDs for {null_count} user_file records...")

            # Populate in batches to avoid long locks
            batch_size = 10000
            total_updated = 0

            while True:
                result = bind.execute(
                    text(
                        """
                        UPDATE user_file
                        SET new_id = gen_random_uuid()
                        WHERE new_id IS NULL
                        AND id IN (
                            SELECT id FROM user_file
                            WHERE new_id IS NULL
                            LIMIT :batch_size
                        )
                        """
                    ),
                    {"batch_size": batch_size},
                )

                updated = result.rowcount
                total_updated += updated

                if updated < batch_size:
                    break

                logger.info(f"  Updated {total_updated}/{null_count} records...")

            logger.info(f"Generated UUIDs for {total_updated} user_file records")

        # Verify all records have UUIDs
        remaining_null = bind.execute(
            text("SELECT COUNT(*) FROM user_file WHERE new_id IS NULL")
        ).scalar_one()

        if remaining_null > 0:
            raise Exception(
                f"Failed to populate all user_file.new_id values ({remaining_null} NULL)"
            )

        # Lock down the column
        op.alter_column("user_file", "new_id", nullable=False)
        op.alter_column("user_file", "new_id", server_default=None)
        logger.info("Locked down user_file.new_id column")

    # === Step 2: Populate persona__user_file.user_file_id_uuid ===
    persona_user_file_columns = [
        col["name"] for col in inspector.get_columns("persona__user_file")
    ]

    if has_new_id and "user_file_id_uuid" in persona_user_file_columns:
        logger.info("Populating persona__user_file.user_file_id_uuid...")

        # Count rows needing update
        null_count = bind.execute(
            text(
                """
                SELECT COUNT(*) FROM persona__user_file
                WHERE user_file_id IS NOT NULL AND user_file_id_uuid IS NULL
                """
            )
        ).scalar_one()

        if null_count > 0:
            logger.info(f"Updating {null_count} persona__user_file records...")

            # Update in batches
            batch_size = 10000
            total_updated = 0

            while True:
                result = bind.execute(
                    text(
                        """
                        UPDATE persona__user_file p
                        SET user_file_id_uuid = uf.new_id
                        FROM user_file uf
                        WHERE p.user_file_id = uf.id
                        AND p.user_file_id_uuid IS NULL
                        AND p.persona_id IN (
                            SELECT persona_id
                            FROM persona__user_file
                            WHERE user_file_id_uuid IS NULL
                            LIMIT :batch_size
                        )
                        """
                    ),
                    {"batch_size": batch_size},
                )

                updated = result.rowcount
                total_updated += updated

                if updated < batch_size:
                    break

                logger.info(f"  Updated {total_updated}/{null_count} records...")

            logger.info(f"Updated {total_updated} persona__user_file records")

        # Verify all records are populated
        remaining_null = bind.execute(
            text(
                """
                SELECT COUNT(*) FROM persona__user_file
                WHERE user_file_id IS NOT NULL AND user_file_id_uuid IS NULL
                """
            )
        ).scalar_one()

        if remaining_null > 0:
            raise Exception(
                f"Failed to populate all persona__user_file.user_file_id_uuid values ({remaining_null} NULL)"
            )

        op.alter_column("persona__user_file", "user_file_id_uuid", nullable=False)
        logger.info("Locked down persona__user_file.user_file_id_uuid column")

    # === Step 3: Create user_project records from chat_folder ===
    if "chat_folder" in inspector.get_table_names():
        logger.info("Creating user_project records from chat_folder...")

        result = bind.execute(
            text(
                """
                INSERT INTO user_project (user_id, name)
                SELECT cf.user_id, cf.name
                FROM chat_folder cf
                WHERE NOT EXISTS (
                    SELECT 1
                    FROM user_project up
                    WHERE up.user_id = cf.user_id AND up.name = cf.name
                )
                """
            )
        )

        logger.info(f"Created {result.rowcount} user_project records from chat_folder")

    # === Step 4: Populate chat_session.project_id ===
    chat_session_columns = [
        col["name"] for col in inspector.get_columns("chat_session")
    ]

    if "folder_id" in chat_session_columns and "project_id" in chat_session_columns:
        logger.info("Populating chat_session.project_id...")

        # Count sessions needing update
        null_count = bind.execute(
            text(
                """
                SELECT COUNT(*) FROM chat_session
                WHERE project_id IS NULL AND folder_id IS NOT NULL
                """
            )
        ).scalar_one()

        if null_count > 0:
            logger.info(f"Updating {null_count} chat_session records...")

            result = bind.execute(
                text(
                    """
                    UPDATE chat_session cs
                    SET project_id = up.id
                    FROM chat_folder cf
                    JOIN user_project up ON up.user_id = cf.user_id AND up.name = cf.name
                    WHERE cs.folder_id = cf.id AND cs.project_id IS NULL
                    """
                )
            )

            logger.info(f"Updated {result.rowcount} chat_session records")

        # Verify all records are populated
        remaining_null = bind.execute(
            text(
                """
                SELECT COUNT(*) FROM chat_session
                WHERE project_id IS NULL AND folder_id IS NOT NULL
                """
            )
        ).scalar_one()

        if remaining_null > 0:
            logger.warning(
                f"Warning: {remaining_null} chat_session records could not be mapped to projects"
            )

    # === Step 5: Update plaintext FileRecord IDs/display names to UUID scheme ===
    # Prior to UUID migration, plaintext cache files were stored with file_id like 'plain_text_<int_id>'.
    # After migration, we use 'plaintext_<uuid>' (note the name change to 'plaintext_').
    # This step remaps existing FileRecord rows to the new naming while preserving object_key/bucket.
    logger.info("Updating plaintext FileRecord ids and display names to UUID scheme...")

    # Count legacy plaintext records that can be mapped to UUID user_file ids
    count_query = text(
        """
        SELECT COUNT(*)
        FROM file_record fr
        JOIN user_file uf ON fr.file_id = CONCAT('plaintext_', uf.id::text)
        WHERE LOWER(fr.file_origin::text) = 'plaintext_cache'
        """
    )
    legacy_count = bind.execute(count_query).scalar_one()

    if legacy_count and legacy_count > 0:
        logger.info(f"Found {legacy_count} legacy plaintext file records to update")

        # Update display_name first for readability (safe regardless of rename)
        bind.execute(
            text(
                """
                UPDATE file_record fr
                SET display_name = CONCAT('Plaintext for user file ', uf.new_id::text)
                FROM user_file uf
                WHERE LOWER(fr.file_origin::text) = 'plaintext_cache'
                AND fr.file_id = CONCAT('plaintext_', uf.id::text)
                """
            )
        )

        # Remap file_id from 'plaintext_<int>' -> 'plaintext_<uuid>' using transitional new_id
        # Use a single UPDATE ... WHERE file_id LIKE 'plain_text_%'
        # and ensure it aligns to existing user_file ids to avoid renaming unrelated rows
        result = bind.execute(
            text(
                """
                UPDATE file_record fr
                SET file_id = CONCAT('plaintext_', uf.new_id::text)
                FROM user_file uf
                WHERE LOWER(fr.file_origin::text) = 'plaintext_cache'
                AND fr.file_id = CONCAT('plaintext_', uf.id::text)
                """
            )
        )
        logger.info(
            f"Updated {result.rowcount} plaintext file_record ids to UUID scheme"
        )

    # === Step 6: Ensure document_id_migrated default TRUE and backfill existing FALSE ===
    # New records should default to migrated=True so the migration task won't run for them.
    # Existing rows that had a legacy document_id should be marked as not migrated to be processed.

    # Backfill existing records: if document_id is not null, set to FALSE
    bind.execute(
        text(
            """
            UPDATE user_file
            SET document_id_migrated = FALSE
            WHERE document_id IS NOT NULL
            """
        )
    )

    # === Step 7: Backfill user_file.status from index_attempt ===
    logger.info("Backfilling user_file.status from index_attempt...")

    # Update user_file status based on latest index attempt
    # Using CTEs instead of temp tables for asyncpg compatibility
    result = bind.execute(
        text(
            """
            WITH latest_attempt AS (
                SELECT DISTINCT ON (ia.connector_credential_pair_id)
                    ia.connector_credential_pair_id,
                    ia.status
                FROM index_attempt ia
                ORDER BY ia.connector_credential_pair_id, ia.time_updated DESC
            ),
            uf_to_ccp AS (
                SELECT DISTINCT uf.id AS uf_id, ccp.id AS cc_pair_id
                FROM user_file uf
                JOIN document_by_connector_credential_pair dcc
                    ON dcc.id = REPLACE(uf.document_id, 'USER_FILE_CONNECTOR__', 'FILE_CONNECTOR__')
                JOIN connector_credential_pair ccp
                    ON ccp.connector_id = dcc.connector_id
                    AND ccp.credential_id = dcc.credential_id
            )
            UPDATE user_file uf
            SET status = CASE
                WHEN la.status IN ('NOT_STARTED', 'IN_PROGRESS') THEN 'PROCESSING'
                WHEN la.status = 'SUCCESS' THEN 'COMPLETED'
                ELSE 'FAILED'
            END
            FROM uf_to_ccp ufc
            LEFT JOIN latest_attempt la
                ON la.connector_credential_pair_id = ufc.cc_pair_id
            WHERE uf.id = ufc.uf_id
            AND uf.status = 'PROCESSING'
            """
        )
    )

    logger.info(f"Updated status for {result.rowcount} user_file records")

    logger.info("Migration 2 (data preparation) completed successfully")


def downgrade() -> None:
    """Reset populated data to allow clean downgrade of schema."""

    bind = op.get_bind()
    inspector = sa.inspect(bind)

    logger.info("Starting downgrade of data preparation...")

    # Reset user_file columns to allow nulls before data removal
    if "user_file" in inspector.get_table_names():
        columns = [col["name"] for col in inspector.get_columns("user_file")]

        if "new_id" in columns:
            op.alter_column(
                "user_file",
                "new_id",
                nullable=True,
                server_default=sa.text("gen_random_uuid()"),
            )
            # Optionally clear the data
            # bind.execute(text("UPDATE user_file SET new_id = NULL"))
            logger.info("Reset user_file.new_id to nullable")

    # Reset persona__user_file.user_file_id_uuid
    if "persona__user_file" in inspector.get_table_names():
        columns = [col["name"] for col in inspector.get_columns("persona__user_file")]

        if "user_file_id_uuid" in columns:
            op.alter_column("persona__user_file", "user_file_id_uuid", nullable=True)
            # Optionally clear the data
            # bind.execute(text("UPDATE persona__user_file SET user_file_id_uuid = NULL"))
            logger.info("Reset persona__user_file.user_file_id_uuid to nullable")

    # Note: We don't delete user_project records or reset chat_session.project_id
    # as these might be in use and can be handled by the schema downgrade

    # Reset user_file.status to default
    if "user_file" in inspector.get_table_names():
        columns = [col["name"] for col in inspector.get_columns("user_file")]
        if "status" in columns:
            bind.execute(text("UPDATE user_file SET status = 'PROCESSING'"))
            logger.info("Reset user_file.status to default")

    logger.info("Downgrade completed successfully")
@@ -0,0 +1,261 @@
"""Migration 3: User file relationship migration

Revision ID: 16c37a30adf2
Revises: 0cd424f32b1d
Create Date: 2025-09-22 09:47:34.175596

This migration converts folder-based relationships to project-based relationships.
It migrates persona__user_folder to persona__user_file and populates project__user_file.
"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy import text
import logging

logger = logging.getLogger("alembic.runtime.migration")

# revision identifiers, used by Alembic.
revision = "16c37a30adf2"
down_revision = "0cd424f32b1d"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Migrate folder-based relationships to project-based relationships."""

    bind = op.get_bind()
    inspector = sa.inspect(bind)

    # === Step 1: Migrate persona__user_folder to persona__user_file ===
    table_names = inspector.get_table_names()

    if "persona__user_folder" in table_names and "user_file" in table_names:
        user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
        has_new_id = "new_id" in user_file_columns

        if has_new_id and "folder_id" in user_file_columns:
            logger.info(
                "Migrating persona__user_folder relationships to persona__user_file..."
            )

            # Count relationships to migrate (asyncpg-compatible)
            count_query = text(
                """
                SELECT COUNT(*)
                FROM (
                    SELECT DISTINCT puf.persona_id, uf.id
                    FROM persona__user_folder puf
                    JOIN user_file uf ON uf.folder_id = puf.user_folder_id
                    WHERE NOT EXISTS (
                        SELECT 1
                        FROM persona__user_file p2
                        WHERE p2.persona_id = puf.persona_id
                        AND p2.user_file_id = uf.id
                    )
                ) AS distinct_pairs
                """
            )
            to_migrate = bind.execute(count_query).scalar_one()

            if to_migrate > 0:
                logger.info(f"Creating {to_migrate} persona-file relationships...")

                # Migrate in batches to avoid memory issues
                batch_size = 10000
                total_inserted = 0

                while True:
                    # Insert batch directly using subquery (asyncpg compatible)
                    result = bind.execute(
                        text(
                            """
                            INSERT INTO persona__user_file (persona_id, user_file_id, user_file_id_uuid)
                            SELECT DISTINCT puf.persona_id, uf.id as file_id, uf.new_id
                            FROM persona__user_folder puf
                            JOIN user_file uf ON uf.folder_id = puf.user_folder_id
                            WHERE NOT EXISTS (
                                SELECT 1
                                FROM persona__user_file p2
                                WHERE p2.persona_id = puf.persona_id
                                AND p2.user_file_id = uf.id
                            )
                            LIMIT :batch_size
                            """
                        ),
                        {"batch_size": batch_size},
                    )

                    inserted = result.rowcount
                    total_inserted += inserted

                    if inserted < batch_size:
                        break

                    logger.info(
                        f"  Migrated {total_inserted}/{to_migrate} relationships..."
                    )

                logger.info(
                    f"Created {total_inserted} persona__user_file relationships"
                )

    # === Step 2: Add foreign key for chat_session.project_id ===
    chat_session_fks = inspector.get_foreign_keys("chat_session")
    fk_exists = any(
        fk["name"] == "fk_chat_session_project_id" for fk in chat_session_fks
    )

    if not fk_exists:
        logger.info("Adding foreign key constraint for chat_session.project_id...")
        op.create_foreign_key(
            "fk_chat_session_project_id",
            "chat_session",
            "user_project",
            ["project_id"],
            ["id"],
        )
        logger.info("Added foreign key constraint")

    # === Step 3: Populate project__user_file from user_file.folder_id ===
    user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
    has_new_id = "new_id" in user_file_columns

    if has_new_id and "folder_id" in user_file_columns:
        logger.info("Populating project__user_file from folder relationships...")

        # Count relationships to create
        count_query = text(
            """
            SELECT COUNT(*)
            FROM user_file uf
            WHERE uf.folder_id IS NOT NULL
            AND NOT EXISTS (
                SELECT 1
                FROM project__user_file puf
                WHERE puf.project_id = uf.folder_id
                AND puf.user_file_id = uf.new_id
            )
            """
        )
        to_create = bind.execute(count_query).scalar_one()

        if to_create > 0:
            logger.info(f"Creating {to_create} project-file relationships...")

            # Insert in batches
            batch_size = 10000
            total_inserted = 0

            while True:
                result = bind.execute(
                    text(
                        """
                        INSERT INTO project__user_file (project_id, user_file_id)
                        SELECT uf.folder_id, uf.new_id
                        FROM user_file uf
                        WHERE uf.folder_id IS NOT NULL
                        AND NOT EXISTS (
                            SELECT 1
                            FROM project__user_file puf
                            WHERE puf.project_id = uf.folder_id
                            AND puf.user_file_id = uf.new_id
                        )
                        LIMIT :batch_size
                        ON CONFLICT (project_id, user_file_id) DO NOTHING
                        """
                    ),
                    {"batch_size": batch_size},
                )

                inserted = result.rowcount
                total_inserted += inserted

                if inserted < batch_size:
                    break

                logger.info(f"  Created {total_inserted}/{to_create} relationships...")

            logger.info(f"Created {total_inserted} project__user_file relationships")

    # === Step 4: Create index on chat_session.project_id ===
    try:
        indexes = [ix.get("name") for ix in inspector.get_indexes("chat_session")]
    except Exception:
        indexes = []

    if "ix_chat_session_project_id" not in indexes:
        logger.info("Creating index on chat_session.project_id...")
        op.create_index(
            "ix_chat_session_project_id", "chat_session", ["project_id"], unique=False
        )
        logger.info("Created index")

    logger.info("Migration 3 (relationship migration) completed successfully")


def downgrade() -> None:
    """Remove migrated relationships and constraints."""

    bind = op.get_bind()
    inspector = sa.inspect(bind)

    logger.info("Starting downgrade of relationship migration...")

    # Drop index on chat_session.project_id
    try:
        indexes = [ix.get("name") for ix in inspector.get_indexes("chat_session")]
        if "ix_chat_session_project_id" in indexes:
            op.drop_index("ix_chat_session_project_id", "chat_session")
            logger.info("Dropped index on chat_session.project_id")
    except Exception:
        pass

    # Drop foreign key constraint
    try:
        chat_session_fks = inspector.get_foreign_keys("chat_session")
        fk_exists = any(
            fk["name"] == "fk_chat_session_project_id" for fk in chat_session_fks
        )
        if fk_exists:
            op.drop_constraint(
                "fk_chat_session_project_id", "chat_session", type_="foreignkey"
            )
            logger.info("Dropped foreign key constraint on chat_session.project_id")
    except Exception:
        pass

    # Clear project__user_file relationships (but keep the table for migration 1 to handle)
    if "project__user_file" in inspector.get_table_names():
        result = bind.execute(text("DELETE FROM project__user_file"))
        logger.info(f"Cleared {result.rowcount} records from project__user_file")

    # Remove migrated persona__user_file relationships
    # Only remove those that came from folder relationships
    if all(
        table in inspector.get_table_names()
        for table in ["persona__user_file", "persona__user_folder", "user_file"]
    ):
        user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
        if "folder_id" in user_file_columns:
            result = bind.execute(
                text(
                    """
                    DELETE FROM persona__user_file puf
                    WHERE EXISTS (
                        SELECT 1
                        FROM user_file uf
                        JOIN persona__user_folder puf2
                            ON puf2.user_folder_id = uf.folder_id
                        WHERE puf.persona_id = puf2.persona_id
                        AND puf.user_file_id = uf.id
                    )
                    """
                )
            )
            logger.info(
                f"Removed {result.rowcount} migrated persona__user_file relationships"
            )

    logger.info("Downgrade completed successfully")
@@ -0,0 +1,218 @@
|
||||
"""Migration 6: User file schema cleanup
|
||||
|
||||
Revision ID: 2b75d0a8ffcb
|
||||
Revises: 3a78dba1080a
|
||||
Create Date: 2025-09-22 10:09:26.375377
|
||||
|
||||
This migration removes legacy columns and tables after data migration is complete.
|
||||
It should only be run after verifying all data has been successfully migrated.
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import text
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2b75d0a8ffcb"
|
||||
down_revision = "3a78dba1080a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Remove legacy columns and tables."""
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
logger.info("Starting schema cleanup...")
|
||||
|
||||
# === Step 1: Verify data migration is complete ===
|
||||
logger.info("Verifying data migration completion...")
|
||||
|
||||
# Check if any chat sessions still have folder_id references
|
||||
chat_session_columns = [
|
||||
col["name"] for col in inspector.get_columns("chat_session")
|
||||
]
|
||||
if "folder_id" in chat_session_columns:
|
||||
orphaned_count = bind.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT COUNT(*) FROM chat_session
|
||||
WHERE folder_id IS NOT NULL AND project_id IS NULL
|
||||
"""
|
||||
)
|
||||
).scalar_one()
|
||||
|
||||
if orphaned_count > 0:
|
||||
logger.warning(
|
||||
f"WARNING: {orphaned_count} chat_session records still have "
|
||||
f"folder_id without project_id. Proceeding anyway."
|
||||
)
|
||||
|
||||
# === Step 2: Drop chat_session.folder_id ===
|
||||
if "folder_id" in chat_session_columns:
|
||||
logger.info("Dropping chat_session.folder_id...")
|
||||
|
||||
# Drop foreign key constraint first
|
||||
op.execute(
|
||||
"ALTER TABLE chat_session DROP CONSTRAINT IF EXISTS chat_session_folder_fk"
|
||||
)
|
||||
|
||||
# Drop the column
|
||||
op.drop_column("chat_session", "folder_id")
|
||||
logger.info("Dropped chat_session.folder_id")
|
||||
|
||||
# === Step 3: Drop persona__user_folder table ===
|
||||
if "persona__user_folder" in inspector.get_table_names():
|
||||
logger.info("Dropping persona__user_folder table...")
|
||||
|
||||
# Check for any remaining data
|
||||
remaining = bind.execute(
|
||||
text("SELECT COUNT(*) FROM persona__user_folder")
|
||||
).scalar_one()
|
||||
|
||||
if remaining > 0:
|
||||
logger.warning(
|
||||
f"WARNING: Dropping persona__user_folder with {remaining} records"
|
||||
)
|
||||
|
||||
op.drop_table("persona__user_folder")
|
||||
logger.info("Dropped persona__user_folder table")
|
||||
|
||||
# === Step 4: Drop chat_folder table ===
|
||||
if "chat_folder" in inspector.get_table_names():
|
||||
logger.info("Dropping chat_folder table...")
|
||||
|
||||
# Check for any remaining data
|
||||
remaining = bind.execute(text("SELECT COUNT(*) FROM chat_folder")).scalar_one()
|
||||
|
||||
if remaining > 0:
|
||||
logger.warning(f"WARNING: Dropping chat_folder with {remaining} records")
|
||||
|
||||
op.drop_table("chat_folder")
|
||||
logger.info("Dropped chat_folder table")
|
||||
|
||||
# === Step 5: Drop user_file legacy columns ===
|
||||
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
|
||||
|
||||
# Drop folder_id
|
||||
if "folder_id" in user_file_columns:
|
||||
logger.info("Dropping user_file.folder_id...")
|
||||
op.drop_column("user_file", "folder_id")
|
||||
logger.info("Dropped user_file.folder_id")
|
||||
|
||||
# Drop cc_pair_id (should already be gone after migration 5, but drop it here if it still exists)
|
||||
if "cc_pair_id" in user_file_columns:
|
||||
logger.info("Dropping user_file.cc_pair_id...")
|
||||
|
||||
# Drop any remaining foreign key constraints
|
||||
bind.execute(
|
||||
text(
|
||||
"""
|
||||
DO $$
|
||||
DECLARE r RECORD;
|
||||
BEGIN
|
||||
FOR r IN (
|
||||
SELECT conname
|
||||
FROM pg_constraint c
|
||||
JOIN pg_class t ON c.conrelid = t.oid
|
||||
WHERE c.contype = 'f'
|
||||
AND t.relname = 'user_file'
|
||||
AND EXISTS (
|
||||
SELECT 1 FROM pg_attribute a
|
||||
WHERE a.attrelid = t.oid
|
||||
AND a.attname = 'cc_pair_id'
|
||||
)
|
||||
) LOOP
|
||||
EXECUTE format('ALTER TABLE user_file DROP CONSTRAINT %I', r.conname);
|
||||
END LOOP;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
op.drop_column("user_file", "cc_pair_id")
|
||||
logger.info("Dropped user_file.cc_pair_id")
|
||||
|
||||
# === Step 6: Clean up any remaining constraints ===
|
||||
logger.info("Cleaning up remaining constraints...")
|
||||
|
||||
# Drop any unique constraints on removed columns
|
||||
op.execute(
|
||||
"ALTER TABLE user_file DROP CONSTRAINT IF EXISTS user_file_cc_pair_id_key"
|
||||
)
|
||||
|
||||
logger.info("Migration 6 (schema cleanup) completed successfully")
|
||||
logger.info("Legacy schema has been fully removed")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Recreate dropped columns and tables (structure only, no data)."""
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
logger.warning("Downgrading schema cleanup - recreating structure only, no data!")
|
||||
|
||||
# Recreate user_file columns
|
||||
if "user_file" in inspector.get_table_names():
|
||||
columns = [col["name"] for col in inspector.get_columns("user_file")]
|
||||
|
||||
if "cc_pair_id" not in columns:
|
||||
op.add_column(
|
||||
"user_file", sa.Column("cc_pair_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
if "folder_id" not in columns:
|
||||
op.add_column(
|
||||
"user_file", sa.Column("folder_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Recreate chat_folder table
|
||||
if "chat_folder" not in inspector.get_table_names():
|
||||
op.create_table(
|
||||
"chat_folder",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"], ["user.id"], name="chat_folder_user_fk"
|
||||
),
|
||||
)
|
||||
|
||||
# Recreate persona__user_folder table
|
||||
if "persona__user_folder" not in inspector.get_table_names():
|
||||
op.create_table(
|
||||
"persona__user_folder",
|
||||
sa.Column("persona_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_folder_id", sa.Integer(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("persona_id", "user_folder_id"),
|
||||
sa.ForeignKeyConstraint(["persona_id"], ["persona.id"]),
|
||||
sa.ForeignKeyConstraint(["user_folder_id"], ["user_project.id"]),
|
||||
)
|
||||
|
||||
# Add folder_id back to chat_session
|
||||
if "chat_session" in inspector.get_table_names():
|
||||
columns = [col["name"] for col in inspector.get_columns("chat_session")]
|
||||
if "folder_id" not in columns:
|
||||
op.add_column(
|
||||
"chat_session", sa.Column("folder_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Add foreign key if chat_folder exists
|
||||
if "chat_folder" in inspector.get_table_names():
|
||||
op.create_foreign_key(
|
||||
"chat_session_folder_fk",
|
||||
"chat_session",
|
||||
"chat_folder",
|
||||
["folder_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
logger.info("Downgrade completed - structure recreated but data is lost")
|
||||
@@ -0,0 +1,295 @@
|
||||
"""Migration 5: User file legacy data cleanup
|
||||
|
||||
Revision ID: 3a78dba1080a
|
||||
Revises: 7cc3fcc116c1
|
||||
Create Date: 2025-09-22 10:04:27.986294
|
||||
|
||||
This migration removes legacy user-file documents and connector_credential_pairs.
|
||||
It performs bulk deletions of obsolete data after the UUID migration.
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql as psql
|
||||
from sqlalchemy import text
|
||||
import logging
|
||||
from typing import List
|
||||
import uuid
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3a78dba1080a"
|
||||
down_revision = "7cc3fcc116c1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def batch_delete(
|
||||
bind: sa.engine.Connection,
|
||||
table_name: str,
|
||||
id_column: str,
|
||||
ids: List[str | int | uuid.UUID],
|
||||
batch_size: int = 1000,
|
||||
id_type: str = "int",
|
||||
) -> int:
|
||||
"""Delete records in batches to avoid memory issues and timeouts."""
|
||||
total_count = len(ids)
|
||||
if total_count == 0:
|
||||
return 0
|
||||
|
||||
logger.info(
|
||||
f"Starting batch deletion of {total_count} records from {table_name}..."
|
||||
)
|
||||
|
||||
# Determine appropriate ARRAY type
|
||||
if id_type == "uuid":
|
||||
array_type = psql.ARRAY(psql.UUID(as_uuid=True))
|
||||
elif id_type == "int":
|
||||
array_type = psql.ARRAY(sa.Integer())
|
||||
else:
|
||||
array_type = psql.ARRAY(sa.String())
|
||||
|
||||
total_deleted = 0
|
||||
failed_batches = []
|
||||
|
||||
for i in range(0, total_count, batch_size):
|
||||
batch_ids = ids[i : i + batch_size]
|
||||
try:
|
||||
stmt = text(
|
||||
f"DELETE FROM {table_name} WHERE {id_column} = ANY(:ids)"
|
||||
).bindparams(sa.bindparam("ids", value=batch_ids, type_=array_type))
|
||||
result = bind.execute(stmt)
|
||||
total_deleted += result.rowcount
|
||||
|
||||
# Log progress every 10 batches or at completion
|
||||
batch_num = (i // batch_size) + 1
|
||||
if batch_num % 10 == 0 or i + batch_size >= total_count:
|
||||
logger.info(
|
||||
f" Deleted {min(i + batch_size, total_count)}/{total_count} records "
|
||||
f"({total_deleted} actual) from {table_name}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete batch {(i // batch_size) + 1}: {e}")
|
||||
failed_batches.append((i, min(i + batch_size, total_count)))
|
||||
|
||||
if failed_batches:
|
||||
logger.warning(
|
||||
f"Failed to delete {len(failed_batches)} batches from {table_name}. "
|
||||
f"Total deleted: {total_deleted}/{total_count}"
|
||||
)
|
||||
# Fail the migration to avoid silently succeeding on partial cleanup
|
||||
raise RuntimeError(
|
||||
f"Batch deletion failed for {table_name}: "
|
||||
f"{len(failed_batches)} failed batches out of "
|
||||
f"{(total_count + batch_size - 1) // batch_size}."
|
||||
)
|
||||
|
||||
return total_deleted
|
||||
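A minimal usage sketch of the batch_delete helper above; the table name and IDs are hypothetical and not part of this migration, they only illustrate the chunked-deletion pattern:

# Illustrative only (hypothetical table and IDs), assuming batch_delete and
# logger from this module are in scope.
from alembic import op

def _example_cleanup() -> None:
    bind = op.get_bind()
    orphaned_ids = [101, 102, 103, 104, 105]  # made-up integer IDs
    deleted = batch_delete(
        bind,
        table_name="example_table",  # hypothetical table
        id_column="id",
        ids=orphaned_ids,
        batch_size=2,  # tiny batch size just to show the chunking
        id_type="int",
    )
    logger.info(f"Deleted {deleted} example_table rows")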
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Remove legacy user-file documents and connector_credential_pairs."""
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
logger.info("Starting legacy data cleanup...")
|
||||
|
||||
# === Step 1: Identify and delete user-file documents ===
|
||||
logger.info("Identifying user-file documents to delete...")
|
||||
|
||||
# Get document IDs to delete
|
||||
doc_rows = bind.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT DISTINCT dcc.id AS document_id
|
||||
FROM document_by_connector_credential_pair dcc
|
||||
JOIN connector_credential_pair u
|
||||
ON u.connector_id = dcc.connector_id
|
||||
AND u.credential_id = dcc.credential_id
|
||||
WHERE u.is_user_file IS TRUE
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
doc_ids = [r[0] for r in doc_rows]
|
||||
|
||||
if doc_ids:
|
||||
logger.info(f"Found {len(doc_ids)} user-file documents to delete")
|
||||
|
||||
# Delete dependent rows first
|
||||
tables_to_clean = [
|
||||
("document_retrieval_feedback", "document_id"),
|
||||
("document__tag", "document_id"),
|
||||
("chunk_stats", "document_id"),
|
||||
]
|
||||
|
||||
for table_name, column_name in tables_to_clean:
|
||||
if table_name in inspector.get_table_names():
|
||||
# document_id is a string in these tables
|
||||
deleted = batch_delete(
|
||||
bind, table_name, column_name, doc_ids, id_type="str"
|
||||
)
|
||||
logger.info(f"Deleted {deleted} records from {table_name}")
|
||||
|
||||
# Delete document_by_connector_credential_pair entries
|
||||
deleted = batch_delete(
|
||||
bind, "document_by_connector_credential_pair", "id", doc_ids, id_type="str"
|
||||
)
|
||||
logger.info(f"Deleted {deleted} document_by_connector_credential_pair records")
|
||||
|
||||
# Delete documents themselves
|
||||
deleted = batch_delete(bind, "document", "id", doc_ids, id_type="str")
|
||||
logger.info(f"Deleted {deleted} document records")
|
||||
else:
|
||||
logger.info("No user-file documents found to delete")
|
||||
|
||||
# === Step 2: Clean up user-file connector_credential_pairs ===
|
||||
logger.info("Cleaning up user-file connector_credential_pairs...")
|
||||
|
||||
# Get cc_pair IDs
|
||||
cc_pair_rows = bind.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT id AS cc_pair_id
|
||||
FROM connector_credential_pair
|
||||
WHERE is_user_file IS TRUE
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
cc_pair_ids = [r[0] for r in cc_pair_rows]
|
||||
|
||||
if cc_pair_ids:
|
||||
logger.info(
|
||||
f"Found {len(cc_pair_ids)} user-file connector_credential_pairs to clean up"
|
||||
)
|
||||
|
||||
# Delete related records
|
||||
tables_to_clean = [
|
||||
("index_attempt", "connector_credential_pair_id"),
|
||||
("background_error", "cc_pair_id"),
|
||||
("document_set__connector_credential_pair", "connector_credential_pair_id"),
|
||||
("user_group__connector_credential_pair", "cc_pair_id"),
|
||||
]
|
||||
|
||||
for table_name, column_name in tables_to_clean:
|
||||
if table_name in inspector.get_table_names():
|
||||
deleted = batch_delete(
|
||||
bind, table_name, column_name, cc_pair_ids, id_type="int"
|
||||
)
|
||||
logger.info(f"Deleted {deleted} records from {table_name}")
|
||||
|
||||
# === Step 3: Identify connectors and credentials to delete ===
|
||||
logger.info("Identifying orphaned connectors and credentials...")
|
||||
|
||||
# Get connectors used only by user-file cc_pairs
|
||||
connector_rows = bind.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT DISTINCT ccp.connector_id
|
||||
FROM connector_credential_pair ccp
|
||||
WHERE ccp.is_user_file IS TRUE
|
||||
AND ccp.connector_id != 0 -- Exclude system default
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM connector_credential_pair c2
|
||||
WHERE c2.connector_id = ccp.connector_id
|
||||
AND c2.is_user_file IS NOT TRUE
|
||||
)
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
userfile_only_connector_ids = [r[0] for r in connector_rows]
|
||||
|
||||
# Get credentials used only by user-file cc_pairs
|
||||
credential_rows = bind.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT DISTINCT ccp.credential_id
|
||||
FROM connector_credential_pair ccp
|
||||
WHERE ccp.is_user_file IS TRUE
|
||||
AND ccp.credential_id != 0 -- Exclude public/default
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM connector_credential_pair c2
|
||||
WHERE c2.credential_id = ccp.credential_id
|
||||
AND c2.is_user_file IS NOT TRUE
|
||||
)
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
userfile_only_credential_ids = [r[0] for r in credential_rows]
|
||||
|
||||
# === Step 4: Delete the cc_pairs themselves ===
|
||||
if cc_pair_ids:
|
||||
# Remove FK dependency from user_file first
|
||||
bind.execute(
|
||||
text(
|
||||
"""
|
||||
DO $$
|
||||
DECLARE r RECORD;
|
||||
BEGIN
|
||||
FOR r IN (
|
||||
SELECT conname
|
||||
FROM pg_constraint c
|
||||
JOIN pg_class t ON c.conrelid = t.oid
|
||||
JOIN pg_class ft ON c.confrelid = ft.oid
|
||||
WHERE c.contype = 'f'
|
||||
AND t.relname = 'user_file'
|
||||
AND ft.relname = 'connector_credential_pair'
|
||||
) LOOP
|
||||
EXECUTE format('ALTER TABLE user_file DROP CONSTRAINT %I', r.conname);
|
||||
END LOOP;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Delete cc_pairs
|
||||
deleted = batch_delete(
|
||||
bind, "connector_credential_pair", "id", cc_pair_ids, id_type="int"
|
||||
)
|
||||
logger.info(f"Deleted {deleted} connector_credential_pair records")
|
||||
|
||||
# === Step 5: Delete orphaned connectors ===
|
||||
if userfile_only_connector_ids:
|
||||
deleted = batch_delete(
|
||||
bind, "connector", "id", userfile_only_connector_ids, id_type="int"
|
||||
)
|
||||
logger.info(f"Deleted {deleted} orphaned connector records")
|
||||
|
||||
# === Step 6: Delete orphaned credentials ===
|
||||
if userfile_only_credential_ids:
|
||||
# Clean up credential__user_group mappings first
|
||||
deleted = batch_delete(
|
||||
bind,
|
||||
"credential__user_group",
|
||||
"credential_id",
|
||||
userfile_only_credential_ids,
|
||||
id_type="int",
|
||||
)
|
||||
logger.info(f"Deleted {deleted} credential__user_group records")
|
||||
|
||||
# Delete credentials
|
||||
deleted = batch_delete(
|
||||
bind, "credential", "id", userfile_only_credential_ids, id_type="int"
|
||||
)
|
||||
logger.info(f"Deleted {deleted} orphaned credential records")
|
||||
|
||||
logger.info("Migration 5 (legacy data cleanup) completed successfully")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Cannot restore deleted data - requires backup restoration."""
|
||||
|
||||
logger.error("CRITICAL: Downgrading data cleanup cannot restore deleted data!")
|
||||
logger.error("Data restoration requires backup files or database backup.")
|
||||
|
||||
raise NotImplementedError(
|
||||
"Downgrade of legacy data cleanup is not supported. "
|
||||
"Deleted data must be restored from backups."
|
||||
)
|
||||
@@ -0,0 +1,193 @@
|
||||
"""Migration 4: User file UUID primary key swap
|
||||
|
||||
Revision ID: 7cc3fcc116c1
|
||||
Revises: 16c37a30adf2
|
||||
Create Date: 2025-09-22 09:54:38.292952
|
||||
|
||||
This migration performs the critical UUID primary key swap on user_file table.
|
||||
It updates all foreign key references to use UUIDs instead of integers.
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql as psql
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7cc3fcc116c1"
|
||||
down_revision = "16c37a30adf2"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Swap user_file primary key from integer to UUID."""
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
# Verify we're in the expected state
|
||||
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
|
||||
if "new_id" not in user_file_columns:
|
||||
logger.warning(
|
||||
"user_file.new_id not found - migration may have already been applied"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info("Starting UUID primary key swap...")
|
||||
|
||||
# === Step 1: Update persona__user_file foreign key to UUID ===
|
||||
logger.info("Updating persona__user_file foreign key...")
|
||||
|
||||
# Drop existing foreign key constraints
|
||||
op.execute(
|
||||
"ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_user_file_id_uuid_fkey"
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_user_file_id_fkey"
|
||||
)
|
||||
|
||||
# Create new foreign key to user_file.new_id
|
||||
op.create_foreign_key(
|
||||
"persona__user_file_user_file_id_fkey",
|
||||
"persona__user_file",
|
||||
"user_file",
|
||||
local_cols=["user_file_id_uuid"],
|
||||
remote_cols=["new_id"],
|
||||
)
|
||||
|
||||
# Drop the old integer column and rename UUID column
|
||||
op.execute("ALTER TABLE persona__user_file DROP COLUMN IF EXISTS user_file_id")
|
||||
op.alter_column(
|
||||
"persona__user_file",
|
||||
"user_file_id_uuid",
|
||||
new_column_name="user_file_id",
|
||||
existing_type=psql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Recreate composite primary key
|
||||
op.execute(
|
||||
"ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_pkey"
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE persona__user_file ADD PRIMARY KEY (persona_id, user_file_id)"
|
||||
)
|
||||
|
||||
logger.info("Updated persona__user_file to use UUID foreign key")
|
||||
|
||||
# === Step 2: Perform the primary key swap on user_file ===
|
||||
logger.info("Swapping user_file primary key to UUID...")
|
||||
|
||||
# Drop the primary key constraint
|
||||
op.execute("ALTER TABLE user_file DROP CONSTRAINT IF EXISTS user_file_pkey")
|
||||
|
||||
# Drop the old id column and rename new_id to id
|
||||
op.execute("ALTER TABLE user_file DROP COLUMN IF EXISTS id")
|
||||
op.alter_column(
|
||||
"user_file",
|
||||
"new_id",
|
||||
new_column_name="id",
|
||||
existing_type=psql.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Set default for new inserts
|
||||
op.alter_column(
|
||||
"user_file",
|
||||
"id",
|
||||
existing_type=psql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
)
|
||||
|
||||
# Create new primary key
|
||||
op.execute("ALTER TABLE user_file ADD PRIMARY KEY (id)")
|
||||
|
||||
logger.info("Swapped user_file primary key to UUID")
|
||||
|
||||
# === Step 3: Update foreign key constraints ===
|
||||
logger.info("Updating foreign key constraints...")
|
||||
|
||||
# Recreate persona__user_file foreign key to point to user_file.id
|
||||
# Drop existing FK first to break dependency on the unique constraint
|
||||
op.execute(
|
||||
"ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_user_file_id_fkey"
|
||||
)
|
||||
# Drop the unique constraint on (formerly) new_id BEFORE recreating the FK,
|
||||
# so the FK will bind to the primary key instead of the unique index.
|
||||
op.execute("ALTER TABLE user_file DROP CONSTRAINT IF EXISTS uq_user_file_new_id")
|
||||
# Now recreate FK to the primary key column
|
||||
op.create_foreign_key(
|
||||
"persona__user_file_user_file_id_fkey",
|
||||
"persona__user_file",
|
||||
"user_file",
|
||||
local_cols=["user_file_id"],
|
||||
remote_cols=["id"],
|
||||
)
|
||||
|
||||
# Add foreign keys for project__user_file
|
||||
existing_fks = inspector.get_foreign_keys("project__user_file")
|
||||
|
||||
has_user_file_fk = any(
|
||||
fk.get("referred_table") == "user_file"
|
||||
and fk.get("constrained_columns") == ["user_file_id"]
|
||||
for fk in existing_fks
|
||||
)
|
||||
|
||||
if not has_user_file_fk:
|
||||
op.create_foreign_key(
|
||||
"fk_project__user_file_user_file_id",
|
||||
"project__user_file",
|
||||
"user_file",
|
||||
["user_file_id"],
|
||||
["id"],
|
||||
)
|
||||
logger.info("Added project__user_file -> user_file foreign key")
|
||||
|
||||
has_project_fk = any(
|
||||
fk.get("referred_table") == "user_project"
|
||||
and fk.get("constrained_columns") == ["project_id"]
|
||||
for fk in existing_fks
|
||||
)
|
||||
|
||||
if not has_project_fk:
|
||||
op.create_foreign_key(
|
||||
"fk_project__user_file_project_id",
|
||||
"project__user_file",
|
||||
"user_project",
|
||||
["project_id"],
|
||||
["id"],
|
||||
)
|
||||
logger.info("Added project__user_file -> user_project foreign key")
|
||||
|
||||
# === Step 4: Mark files for document_id migration ===
|
||||
logger.info("Marking files for background document_id migration...")
|
||||
|
||||
logger.info("Migration 4 (UUID primary key swap) completed successfully")
|
||||
logger.info(
|
||||
"NOTE: Background task will update document IDs in Vespa and search_doc"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Revert UUID primary key back to integer (data destructive!)."""
|
||||
|
||||
logger.error("CRITICAL: Downgrading UUID primary key swap is data destructive!")
|
||||
logger.error(
|
||||
"This will break all UUID-based references created after the migration."
|
||||
)
|
||||
logger.error("Only proceed if absolutely necessary and have backups.")
|
||||
|
||||
# The downgrade would need to:
|
||||
# 1. Add back integer columns
|
||||
# 2. Generate new sequential IDs
|
||||
# 3. Update all foreign key references
|
||||
# 4. Swap primary keys back
|
||||
# This is complex and risky, so we raise an error instead
|
||||
|
||||
raise NotImplementedError(
|
||||
"Downgrade of UUID primary key swap is not supported due to data loss risk. "
|
||||
"Manual intervention with data backup/restore is required."
|
||||
)
|
||||
@@ -0,0 +1,257 @@
|
||||
"""Migration 1: User file schema additions
|
||||
|
||||
Revision ID: 9b66d3156fc6
|
||||
Revises: b4ef3ae0bf6e
|
||||
Create Date: 2025-09-22 09:42:06.086732
|
||||
|
||||
This migration adds new columns and tables without modifying existing data.
|
||||
It is safe to run and can be easily rolled back.
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql as psql
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9b66d3156fc6"
|
||||
down_revision = "b4ef3ae0bf6e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add new columns and tables without modifying existing data."""
|
||||
|
||||
# Enable pgcrypto for UUID generation
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto")
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
# === USER_FILE: Add new columns ===
|
||||
logger.info("Adding new columns to user_file table...")
|
||||
|
||||
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
|
||||
|
||||
# Check if ID is already UUID (in case of re-run after partial migration)
|
||||
id_is_uuid = any(
|
||||
col["name"] == "id" and "uuid" in str(col["type"]).lower()
|
||||
for col in inspector.get_columns("user_file")
|
||||
)
|
||||
|
||||
# Add transitional UUID column only if ID is not already UUID
|
||||
if "new_id" not in user_file_columns and not id_is_uuid:
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"new_id",
|
||||
psql.UUID(as_uuid=True),
|
||||
nullable=True,
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
),
|
||||
)
|
||||
op.create_unique_constraint("uq_user_file_new_id", "user_file", ["new_id"])
|
||||
logger.info("Added new_id column to user_file")
|
||||
|
||||
# Add status column
|
||||
if "status" not in user_file_columns:
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum(
|
||||
"PROCESSING",
|
||||
"COMPLETED",
|
||||
"FAILED",
|
||||
"CANCELED",
|
||||
name="userfilestatus",
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=False,
|
||||
server_default="PROCESSING",
|
||||
),
|
||||
)
|
||||
logger.info("Added status column to user_file")
|
||||
|
||||
# Add other tracking columns
|
||||
if "chunk_count" not in user_file_columns:
|
||||
op.add_column(
|
||||
"user_file", sa.Column("chunk_count", sa.Integer(), nullable=True)
|
||||
)
|
||||
logger.info("Added chunk_count column to user_file")
|
||||
|
||||
if "last_accessed_at" not in user_file_columns:
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column("last_accessed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
logger.info("Added last_accessed_at column to user_file")
|
||||
|
||||
if "needs_project_sync" not in user_file_columns:
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"needs_project_sync",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
)
|
||||
logger.info("Added needs_project_sync column to user_file")
|
||||
|
||||
if "last_project_sync_at" not in user_file_columns:
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"last_project_sync_at", sa.DateTime(timezone=True), nullable=True
|
||||
),
|
||||
)
|
||||
logger.info("Added last_project_sync_at column to user_file")
|
||||
|
||||
if "document_id_migrated" not in user_file_columns:
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"document_id_migrated",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("true"),
|
||||
),
|
||||
)
|
||||
logger.info("Added document_id_migrated column to user_file")
|
||||
|
||||
# === USER_FOLDER -> USER_PROJECT rename ===
|
||||
table_names = set(inspector.get_table_names())
|
||||
|
||||
if "user_folder" in table_names:
|
||||
logger.info("Updating user_folder table...")
|
||||
# Make description nullable first
|
||||
op.alter_column("user_folder", "description", nullable=True)
|
||||
|
||||
# Rename table if user_project doesn't exist
|
||||
if "user_project" not in table_names:
|
||||
op.execute("ALTER TABLE user_folder RENAME TO user_project")
|
||||
logger.info("Renamed user_folder to user_project")
|
||||
elif "user_project" in table_names:
|
||||
# If already renamed, ensure column nullability
|
||||
project_cols = [col["name"] for col in inspector.get_columns("user_project")]
|
||||
if "description" in project_cols:
|
||||
op.alter_column("user_project", "description", nullable=True)
|
||||
|
||||
# Add instructions column to user_project
|
||||
inspector = sa.inspect(bind) # Refresh after rename
|
||||
if "user_project" in inspector.get_table_names():
|
||||
project_columns = [col["name"] for col in inspector.get_columns("user_project")]
|
||||
if "instructions" not in project_columns:
|
||||
op.add_column(
|
||||
"user_project",
|
||||
sa.Column("instructions", sa.String(), nullable=True),
|
||||
)
|
||||
logger.info("Added instructions column to user_project")
|
||||
|
||||
# === CHAT_SESSION: Add project_id ===
|
||||
chat_session_columns = [
|
||||
col["name"] for col in inspector.get_columns("chat_session")
|
||||
]
|
||||
if "project_id" not in chat_session_columns:
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column("project_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
logger.info("Added project_id column to chat_session")
|
||||
|
||||
# === PERSONA__USER_FILE: Add UUID column ===
|
||||
persona_user_file_columns = [
|
||||
col["name"] for col in inspector.get_columns("persona__user_file")
|
||||
]
|
||||
if "user_file_id_uuid" not in persona_user_file_columns:
|
||||
op.add_column(
|
||||
"persona__user_file",
|
||||
sa.Column("user_file_id_uuid", psql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
logger.info("Added user_file_id_uuid column to persona__user_file")
|
||||
|
||||
# === PROJECT__USER_FILE: Create new table ===
|
||||
if "project__user_file" not in inspector.get_table_names():
|
||||
op.create_table(
|
||||
"project__user_file",
|
||||
sa.Column("project_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_file_id", psql.UUID(as_uuid=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("project_id", "user_file_id"),
|
||||
)
|
||||
op.create_index(
|
||||
"idx_project__user_file_user_file_id",
|
||||
"project__user_file",
|
||||
["user_file_id"],
|
||||
)
|
||||
logger.info("Created project__user_file table")
|
||||
|
||||
logger.info("Migration 1 (schema additions) completed successfully")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove added columns and tables."""
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
|
||||
logger.info("Starting downgrade of schema additions...")
|
||||
|
||||
# Drop project__user_file table
|
||||
if "project__user_file" in inspector.get_table_names():
|
||||
op.drop_index("idx_project__user_file_user_file_id", "project__user_file")
|
||||
op.drop_table("project__user_file")
|
||||
logger.info("Dropped project__user_file table")
|
||||
|
||||
# Remove columns from persona__user_file
|
||||
if "persona__user_file" in inspector.get_table_names():
|
||||
columns = [col["name"] for col in inspector.get_columns("persona__user_file")]
|
||||
if "user_file_id_uuid" in columns:
|
||||
op.drop_column("persona__user_file", "user_file_id_uuid")
|
||||
logger.info("Dropped user_file_id_uuid from persona__user_file")
|
||||
|
||||
# Remove columns from chat_session
|
||||
if "chat_session" in inspector.get_table_names():
|
||||
columns = [col["name"] for col in inspector.get_columns("chat_session")]
|
||||
if "project_id" in columns:
|
||||
op.drop_column("chat_session", "project_id")
|
||||
logger.info("Dropped project_id from chat_session")
|
||||
|
||||
# Rename user_project back to user_folder and remove instructions
|
||||
if "user_project" in inspector.get_table_names():
|
||||
columns = [col["name"] for col in inspector.get_columns("user_project")]
|
||||
if "instructions" in columns:
|
||||
op.drop_column("user_project", "instructions")
|
||||
op.execute("ALTER TABLE user_project RENAME TO user_folder")
|
||||
op.alter_column("user_folder", "description", nullable=False)
|
||||
logger.info("Renamed user_project back to user_folder")
|
||||
|
||||
# Remove columns from user_file
|
||||
if "user_file" in inspector.get_table_names():
|
||||
columns = [col["name"] for col in inspector.get_columns("user_file")]
|
||||
|
||||
columns_to_drop = [
|
||||
"document_id_migrated",
|
||||
"last_project_sync_at",
|
||||
"needs_project_sync",
|
||||
"last_accessed_at",
|
||||
"chunk_count",
|
||||
"status",
|
||||
]
|
||||
|
||||
for col in columns_to_drop:
|
||||
if col in columns:
|
||||
op.drop_column("user_file", col)
|
||||
logger.info(f"Dropped {col} from user_file")
|
||||
|
||||
if "new_id" in columns:
|
||||
op.drop_constraint("uq_user_file_new_id", "user_file", type_="unique")
|
||||
op.drop_column("user_file", "new_id")
|
||||
logger.info("Dropped new_id from user_file")
|
||||
|
||||
# Drop enum type if no columns use it
|
||||
bind.execute(sa.text("DROP TYPE IF EXISTS userfilestatus"))
|
||||
|
||||
logger.info("Downgrade completed successfully")
|
||||
@@ -0,0 +1,27 @@
|
||||
"""add_user_oauth_token_to_slack_bot
|
||||
|
||||
Revision ID: b4ef3ae0bf6e
|
||||
Revises: 505c488f6662
|
||||
Create Date: 2025-08-26 17:47:41.788462
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b4ef3ae0bf6e"
|
||||
down_revision = "505c488f6662"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add user_token column to slack_bot table
|
||||
op.add_column("slack_bot", sa.Column("user_token", sa.LargeBinary(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove user_token column from slack_bot table
|
||||
op.drop_column("slack_bot", "user_token")
|
||||
@@ -84,7 +84,34 @@ def upgrade() -> None:
|
||||
|
||||
# Insert or update built-in tools
|
||||
for tool in BUILT_IN_TOOLS:
|
||||
- if tool["in_code_tool_id"] in existing_tool_ids:
+ in_code_id = tool["in_code_tool_id"]
|
||||
|
||||
# Handle historical rename: InternetSearchTool -> WebSearchTool
|
||||
if (
|
||||
in_code_id == "WebSearchTool"
|
||||
and "WebSearchTool" not in existing_tool_ids
|
||||
and "InternetSearchTool" in existing_tool_ids
|
||||
):
|
||||
# Rename the existing InternetSearchTool row in place and update fields
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description,
|
||||
in_code_tool_id = :in_code_tool_id
|
||||
WHERE in_code_tool_id = 'InternetSearchTool'
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
# Keep the local view of existing ids in sync to avoid duplicate insert
|
||||
existing_tool_ids.discard("InternetSearchTool")
|
||||
existing_tool_ids.add("WebSearchTool")
|
||||
continue
|
||||
|
||||
if in_code_id in existing_tool_ids:
|
||||
# Update existing tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
|
||||
@@ -93,7 +93,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
task_logger.error(
|
||||
- f"Recieved non-sync CC Pair {cc_pair.id} for external "
+ f"Received non-sync CC Pair {cc_pair.id} for external "
|
||||
f"group sync. Actual access type: {cc_pair.access_type}"
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -17,6 +17,7 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
from ee.onyx.server.enterprise_settings.api import (
|
||||
basic_router as enterprise_settings_router,
|
||||
)
|
||||
from ee.onyx.server.evals.api import router as evals_router
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.tenant_tracking import (
|
||||
add_api_server_tenant_id_middleware,
|
||||
@@ -170,6 +171,7 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, standard_answer_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_oauth_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)
|
||||
include_router_with_global_prefix_prepended(application, evals_router)
|
||||
|
||||
# Enterprise-only global settings
|
||||
include_router_with_global_prefix_prepended(
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.standard_answer import fetch_standard_answer_categories_by_names
|
||||
from ee.onyx.db.standard_answer import find_matching_standard_answers
|
||||
from onyx.configs.constants import MessageType
|
||||
- from onyx.configs.onyxbot_configs import DANSWER_REACT_EMOJI
+ from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_chat_messages_by_sessions
|
||||
@@ -193,7 +193,7 @@ def _handle_standard_answers(
|
||||
db_session.commit()
|
||||
|
||||
update_emote_react(
|
||||
- emoji=DANSWER_REACT_EMOJI,
+ emoji=ONYX_BOT_REACT_EMOJI,
|
||||
channel=message_info.channel_to_respond,
|
||||
message_ts=message_info.msg_to_respond,
|
||||
remove=True,
|
||||
|
||||
0
backend/ee/onyx/server/evals/__init__.py
Normal file
32
backend/ee/onyx/server/evals/api.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
|
||||
from ee.onyx.auth.users import current_cloud_superuser
|
||||
from onyx.background.celery.apps.client import celery_app as client_app
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.models import User
|
||||
from onyx.evals.models import EvalConfigurationOptions
|
||||
from onyx.server.evals.models import EvalRunAck
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/evals")
|
||||
|
||||
|
||||
@router.post("/eval_run", response_model=EvalRunAck)
|
||||
def eval_run(
|
||||
request: EvalConfigurationOptions,
|
||||
user: User = Depends(current_cloud_superuser),
|
||||
) -> EvalRunAck:
|
||||
"""
|
||||
Run an evaluation with the given message and optional dataset.
|
||||
This endpoint requires a valid API key for authentication.
|
||||
"""
|
||||
client_app.send_task(
|
||||
OnyxCeleryTask.EVAL_RUN_TASK,
|
||||
kwargs={
|
||||
"configuration_dict": request.model_dump(),
|
||||
},
|
||||
)
|
||||
return EvalRunAck(success=True)
|
||||
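A rough sketch of how a client might call the eval_run endpoint above; the host, the global "/api" prefix, and the Authorization header format are assumptions for illustration, not confirmed by this diff:

# Hypothetical client call; URL prefix and auth header are assumptions.
import requests

payload = {"message": "What changed in the last release?"}  # shape depends on EvalConfigurationOptions
resp = requests.post(
    "http://localhost:8080/api/evals/eval_run",  # assumed host and global prefix
    json=payload,
    headers={"Authorization": "Bearer <api-key>"},  # assumed auth scheme
    timeout=30,
)
resp.raise_for_status()
print(resp.json())  # expected: {"success": true}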
@@ -182,7 +182,6 @@ def admin_get_chat_sessions(
|
||||
time_created=chat.time_created.isoformat(),
|
||||
time_updated=chat.time_updated.isoformat(),
|
||||
shared_status=chat.shared_status,
|
||||
- folder_id=chat.folder_id,
|
||||
current_alternate_model=chat.current_alternate_model,
|
||||
)
|
||||
for chat in chat_sessions
|
||||
|
||||
@@ -37,24 +37,51 @@ def get_embedding_model(
|
||||
model_name: str,
|
||||
max_context_length: int,
|
||||
) -> "SentenceTransformer":
|
||||
"""
|
||||
Loads or returns a cached SentenceTransformer, sets max_seq_length, pins device,
|
||||
pre-warms rotary caches once, and wraps encode() with a lock to avoid cache races.
|
||||
"""
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
|
||||
- global _GLOBAL_MODELS_DICT  # A dictionary to store models
|
||||
def _prewarm_rope(st_model: "SentenceTransformer", target_len: int) -> None:
|
||||
"""
|
||||
Build RoPE cos/sin caches once on the final device/dtype so later forwards only read.
|
||||
Works by running one dummy encode pass with an over-long input so the RoPE caches are built up front.
|
||||
"""
|
||||
try:
|
||||
# Ensure the dummy input exceeds max_seq_length after tokenization.
# Ideally we would use the model's own tokenizer to size this precisely,
# but a rough assumption about tokenization is sufficient here.
|
||||
long_text = "x " * (target_len * 2)
|
||||
_ = st_model.encode(
|
||||
[long_text],
|
||||
batch_size=1,
|
||||
convert_to_tensor=True,
|
||||
show_progress_bar=False,
|
||||
normalize_embeddings=False,
|
||||
)
|
||||
logger.info("RoPE pre-warm successful")
|
||||
except Exception as e:
|
||||
logger.warning(f"RoPE pre-warm skipped/failed: {e}")
|
||||
|
||||
+ global _GLOBAL_MODELS_DICT
|
||||
|
||||
if model_name not in _GLOBAL_MODELS_DICT:
|
||||
logger.notice(f"Loading {model_name}")
|
||||
# Some model architectures that aren't built into the Transformers or Sentence
|
||||
# Transformer need to be downloaded to be loaded locally. This does not mean
|
||||
# data is sent to remote servers for inference, however the remote code can
|
||||
# be fairly arbitrary so only use trusted models
|
||||
model = SentenceTransformer(
|
||||
model_name_or_path=model_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
model.max_seq_length = max_context_length
|
||||
_prewarm_rope(model, max_context_length)
|
||||
_GLOBAL_MODELS_DICT[model_name] = model
|
||||
- elif max_context_length != _GLOBAL_MODELS_DICT[model_name].max_seq_length:
- _GLOBAL_MODELS_DICT[model_name].max_seq_length = max_context_length
+ else:
+ model = _GLOBAL_MODELS_DICT[model_name]
+ if max_context_length != model.max_seq_length:
+ model.max_seq_length = max_context_length
+ prev = getattr(model, "_rope_prewarmed_to", 0)
+ if max_context_length > int(prev or 0):
+ _prewarm_rope(model, max_context_length)
|
||||
|
||||
return _GLOBAL_MODELS_DICT[model_name]
|
||||
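A small usage sketch of the cached loader above; the model name is only an example, and the caching/pre-warm behavior described is what the new code path aims for:

# Illustrative only: repeated calls with the same model name return the same
# cached SentenceTransformer; a larger max_context_length bumps max_seq_length
# and triggers another RoPE pre-warm.
model = get_embedding_model("nomic-ai/nomic-embed-text-v1", max_context_length=512)
same_model = get_embedding_model("nomic-ai/nomic-embed-text-v1", max_context_length=8192)
assert model is same_model  # same cached instance, now with the larger limit
embeddings = same_model.encode(["hello world"], normalize_embeddings=True)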
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
@@ -10,6 +11,7 @@ from onyx.configs.constants import PUBLIC_DOC_PAT
|
||||
from onyx.db.document import get_access_info_for_document
|
||||
from onyx.db.document import get_access_info_for_documents
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
@@ -124,3 +126,25 @@ def source_should_fetch_permissions_during_indexing(source: DocumentSource) -> b
|
||||
),
|
||||
)
|
||||
return _source_should_fetch_permissions_during_indexing_func(source)
|
||||
|
||||
|
||||
def get_access_for_user_files(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
user_files = (
|
||||
db_session.query(UserFile)
|
||||
.options(joinedload(UserFile.user)) # Eager load the user relationship
|
||||
.filter(UserFile.id.in_(user_file_ids))
|
||||
.all()
|
||||
)
|
||||
return {
|
||||
str(user_file.id): DocumentAccess.build(
|
||||
user_emails=[user_file.user.email] if user_file.user else [],
|
||||
user_groups=[],
|
||||
is_public=True if user_file.user is None else False,
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
for user_file in user_files
|
||||
}
|
||||
|
||||
@@ -35,14 +35,24 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import run_with_timeout
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.chat.chat_utils import build_citation_map_from_numbers
|
||||
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.prompt_builder.citations_prompt import build_citations_system_message
|
||||
from onyx.chat.prompt_builder.citations_prompt import build_citations_user_message
|
||||
from onyx.chat.stream_processing.citation_processing import (
|
||||
normalize_square_bracket_citations_to_double_with_links,
|
||||
)
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSourceDescription
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.db.chat import create_search_doc_from_saved_search_doc
|
||||
from onyx.db.chat import update_db_session_with_messages
|
||||
from onyx.db.connector import fetch_unique_document_sources
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.models import SearchDoc
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.tools import get_tools
|
||||
from onyx.file_store.models import ChatFileType
|
||||
@@ -52,12 +62,14 @@ from onyx.kg.utils.extraction_utils import get_relationship_types_str
|
||||
from onyx.llm.utils import check_number_of_tokens
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
from onyx.prompts.dr_prompts import ANSWER_PROMPT_WO_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import DECISION_PROMPT_W_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import DECISION_PROMPT_WO_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import DEFAULT_DR_SYSTEM_PROMPT
|
||||
from onyx.prompts.dr_prompts import REPEAT_PROMPT
|
||||
from onyx.prompts.dr_prompts import TOOL_DESCRIPTION
|
||||
from onyx.prompts.prompt_template import PromptTemplate
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
@@ -309,6 +321,52 @@ def _get_existing_clarification_request(
|
||||
return clarification, original_question, chat_history_string
|
||||
|
||||
|
||||
def _persist_final_docs_and_citations(
|
||||
db_session: Session,
|
||||
context_llm_docs: list[Any] | None,
|
||||
full_answer: str | None,
|
||||
) -> tuple[list[SearchDoc], dict[int, int] | None]:
|
||||
"""Persist final documents from in-context docs and derive citation mapping.
|
||||
|
||||
Returns the list of persisted `SearchDoc` records and an optional
|
||||
citation map translating inline [[n]] references to DB doc indices.
|
||||
"""
|
||||
final_documents_db: list[SearchDoc] = []
|
||||
citations_map: dict[int, int] | None = None
|
||||
|
||||
if not context_llm_docs:
|
||||
return final_documents_db, citations_map
|
||||
|
||||
saved_search_docs = saved_search_docs_from_llm_docs(context_llm_docs)
|
||||
for saved_doc in saved_search_docs:
|
||||
db_doc = create_search_doc_from_saved_search_doc(saved_doc)
|
||||
db_session.add(db_doc)
|
||||
final_documents_db.append(db_doc)
|
||||
db_session.flush()
|
||||
|
||||
cited_numbers: set[int] = set()
|
||||
try:
|
||||
# Match [[1]] or [[1, 2]] optionally followed by a link like ([[1]](http...))
|
||||
matches = re.findall(
|
||||
r"\[\[(\d+(?:,\s*\d+)*)\]\](?:\([^)]*\))?", full_answer or ""
|
||||
)
|
||||
for match in matches:
|
||||
for num_str in match.split(","):
|
||||
num = int(num_str.strip())
|
||||
cited_numbers.add(num)
|
||||
except Exception:
|
||||
cited_numbers = set()
|
||||
|
||||
if cited_numbers and final_documents_db:
|
||||
translations = build_citation_map_from_numbers(
|
||||
cited_numbers=cited_numbers,
|
||||
db_docs=final_documents_db,
|
||||
)
|
||||
citations_map = translations or None
|
||||
|
||||
return final_documents_db, citations_map
|
||||
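As a standalone illustration of the citation extraction above, the same regex applied to a made-up answer string (not part of the node code):

# Demonstrates the [[n]] / [[n, m]](link) citation pattern matched above.
import re

sample_answer = "See [[1]] and [[2, 3]](https://example.com/doc) for details."
cited: set[int] = set()
for match in re.findall(r"\[\[(\d+(?:,\s*\d+)*)\]\](?:\([^)]*\))?", sample_answer):
    for num_str in match.split(","):
        cited.add(int(num_str.strip()))
print(cited)  # {1, 2, 3}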
|
||||
|
||||
_ARTIFICIAL_ALL_ENCOMPASSING_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
@@ -405,21 +463,28 @@ def clarifier(
|
||||
active_source_type_descriptions_str = ""
|
||||
|
||||
if graph_config.inputs.persona:
|
||||
- assistant_system_prompt = (
+ assistant_system_prompt = PromptTemplate(
graph_config.inputs.persona.system_prompt or DEFAULT_DR_SYSTEM_PROMPT
- ) + "\n\n"
+ ).build()
|
||||
if graph_config.inputs.persona.task_prompt:
|
||||
assistant_task_prompt = (
|
||||
"\n\nHere are more specifications from the user:\n\n"
|
||||
- + (graph_config.inputs.persona.task_prompt)
+ + PromptTemplate(graph_config.inputs.persona.task_prompt).build()
|
||||
)
|
||||
else:
|
||||
assistant_task_prompt = ""
|
||||
|
||||
else:
|
||||
- assistant_system_prompt = DEFAULT_DR_SYSTEM_PROMPT + "\n\n"
+ assistant_system_prompt = PromptTemplate(DEFAULT_DR_SYSTEM_PROMPT).build()
|
||||
assistant_task_prompt = ""
|
||||
|
||||
if graph_config.inputs.project_instructions:
|
||||
assistant_system_prompt = (
|
||||
assistant_system_prompt
|
||||
+ PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
+ graph_config.inputs.project_instructions
|
||||
)
|
||||
|
||||
chat_history_string = (
|
||||
get_chat_history_string(
|
||||
graph_config.inputs.prompt_builder.message_history,
|
||||
@@ -448,6 +513,11 @@ def clarifier(
|
||||
graph_config.inputs.files
|
||||
)
|
||||
|
||||
# Use project/search context docs if available to enable citation mapping
|
||||
context_llm_docs = getattr(
|
||||
graph_config.inputs.prompt_builder, "context_llm_docs", None
|
||||
)
|
||||
|
||||
if not (force_use_tool and force_use_tool.force_use):
|
||||
|
||||
if not use_tool_calling_llm or len(available_tools) == 1:
|
||||
@@ -562,10 +632,37 @@ def clarifier(
|
||||
active_source_type_descriptions_str=active_source_type_descriptions_str,
|
||||
)
|
||||
|
||||
if context_llm_docs:
|
||||
persona = graph_config.inputs.persona
|
||||
if persona is not None:
|
||||
prompt_config = PromptConfig.from_model(persona)
|
||||
else:
|
||||
prompt_config = PromptConfig(
|
||||
system_prompt=assistant_system_prompt,
|
||||
task_prompt="",
|
||||
datetime_aware=True,
|
||||
)
|
||||
|
||||
system_prompt_to_use = build_citations_system_message(
|
||||
prompt_config
|
||||
).content
|
||||
user_prompt_to_use = build_citations_user_message(
|
||||
user_query=original_question,
|
||||
files=[],
|
||||
prompt_config=prompt_config,
|
||||
context_docs=context_llm_docs,
|
||||
all_doc_useful=False,
|
||||
history_message=chat_history_string,
|
||||
context_type="user files",
|
||||
).content
|
||||
else:
|
||||
system_prompt_to_use = assistant_system_prompt
|
||||
user_prompt_to_use = decision_prompt + assistant_task_prompt
|
||||
|
||||
stream = graph_config.tooling.primary_llm.stream(
|
||||
prompt=create_question_prompt(
|
||||
- assistant_system_prompt,
- decision_prompt + assistant_task_prompt,
+ cast(str, system_prompt_to_use),
+ cast(str, user_prompt_to_use),
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
|
||||
@@ -578,6 +675,8 @@ def clarifier(
|
||||
should_stream_answer=True,
|
||||
writer=writer,
|
||||
ind=0,
|
||||
final_search_results=context_llm_docs,
|
||||
displayed_search_results=context_llm_docs,
|
||||
generate_final_answer=True,
|
||||
chat_message_id=str(graph_config.persistence.chat_session_id),
|
||||
)
|
||||
@@ -585,19 +684,32 @@ def clarifier(
|
||||
if len(full_response.ai_message_chunk.tool_calls) == 0:
|
||||
|
||||
if isinstance(full_response.full_answer, str):
|
||||
- full_answer = full_response.full_answer
+ full_answer = (
+ normalize_square_bracket_citations_to_double_with_links(
+ full_response.full_answer
+ )
+ )
|
||||
else:
|
||||
full_answer = None
|
||||
|
||||
# Persist final documents and derive citations when using in-context docs
|
||||
final_documents_db, citations_map = _persist_final_docs_and_citations(
|
||||
db_session=db_session,
|
||||
context_llm_docs=context_llm_docs,
|
||||
full_answer=full_answer,
|
||||
)
|
||||
|
||||
update_db_session_with_messages(
|
||||
db_session=db_session,
|
||||
chat_message_id=message_id,
|
||||
chat_session_id=graph_config.persistence.chat_session_id,
|
||||
is_agentic=graph_config.behavior.use_agentic_search,
|
||||
message=full_answer,
|
||||
+ token_count=len(llm_tokenizer.encode(full_answer or "")),
+ citations=citations_map,
+ final_documents=final_documents_db or None,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
- token_count=len(llm_tokenizer.encode(full_answer or "")),
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@@ -42,6 +42,7 @@ from onyx.db.models import ResearchAgentIteration
|
||||
from onyx.db.models import ResearchAgentIterationSubStep
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.llm.utils import check_number_of_tokens
|
||||
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
|
||||
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
|
||||
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
|
||||
@@ -225,7 +226,7 @@ def closer(
|
||||
|
||||
research_type = graph_config.behavior.research_type
|
||||
|
||||
assistant_system_prompt = state.assistant_system_prompt
|
||||
assistant_system_prompt: str = state.assistant_system_prompt or ""
|
||||
assistant_task_prompt = state.assistant_task_prompt
|
||||
|
||||
uploaded_context = state.uploaded_test_context or ""
|
||||
@@ -349,6 +350,13 @@ def closer(
|
||||
uploaded_context=uploaded_context,
|
||||
)
|
||||
|
||||
if graph_config.inputs.project_instructions:
|
||||
assistant_system_prompt = (
|
||||
assistant_system_prompt
|
||||
+ PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
+ (graph_config.inputs.project_instructions or "")
|
||||
)
|
||||
|
||||
all_context_llmdocs = [
|
||||
llm_doc_from_inference_section(inference_section)
|
||||
for inference_section in all_cited_documents
|
||||
|
||||
@@ -9,6 +9,7 @@ from pydantic import BaseModel
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
@@ -18,6 +19,8 @@ from onyx.chat.stream_processing.answer_response_handler import (
|
||||
)
|
||||
from onyx.chat.stream_processing.utils import map_document_id_order
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CitationStart
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
@@ -56,6 +59,9 @@ def process_llm_stream(
|
||||
|
||||
full_answer = ""
|
||||
start_final_answer_streaming_set = False
|
||||
# Accumulate citation infos if handler emits them
|
||||
collected_citation_infos: list[CitationInfo] = []
|
||||
|
||||
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
|
||||
# the stream will contain AIMessageChunks with tool call information.
|
||||
for message in messages:
|
||||
@@ -102,6 +108,9 @@ def process_llm_stream(
|
||||
MessageDelta(content=response_part.answer_piece),
|
||||
writer,
|
||||
)
|
||||
# collect citation info objects
|
||||
elif isinstance(response_part, CitationInfo):
|
||||
collected_citation_infos.append(response_part)
|
||||
|
||||
if generate_final_answer and start_final_answer_streaming_set:
|
||||
# start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
|
||||
@@ -111,6 +120,14 @@ def process_llm_stream(
|
||||
writer,
|
||||
)
|
||||
|
||||
# Emit citations section if any were collected
|
||||
if collected_citation_infos:
|
||||
write_custom_event(ind, CitationStart(), writer)
|
||||
write_custom_event(
|
||||
ind, CitationDelta(citations=collected_citation_infos), writer
|
||||
)
|
||||
write_custom_event(ind, SectionEnd(), writer)
|
||||
|
||||
logger.debug(f"Full answer: {full_answer}")
|
||||
return BasicSearchProcessedStreamResults(
|
||||
ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
@@ -73,6 +74,7 @@ def basic_search(
|
||||
|
||||
search_tool_info = state.available_tools[state.tools_used[-1]]
|
||||
search_tool = cast(SearchTool, search_tool_info.tool_object)
|
||||
force_use_tool = graph_config.tooling.force_use_tool
|
||||
|
||||
# sanity check
|
||||
if search_tool != graph_config.tooling.search_tool:
|
||||
@@ -141,6 +143,15 @@ def basic_search(
|
||||
retrieved_docs: list[InferenceSection] = []
|
||||
callback_container: list[list[InferenceSection]] = []
|
||||
|
||||
user_file_ids: list[UUID] | None = None
|
||||
project_id: int | None = None
|
||||
if force_use_tool.override_kwargs and isinstance(
|
||||
force_use_tool.override_kwargs, SearchToolOverrideKwargs
|
||||
):
|
||||
override_kwargs = force_use_tool.override_kwargs
|
||||
user_file_ids = override_kwargs.user_file_ids
|
||||
project_id = override_kwargs.project_id
|
||||
|
||||
# new db session to avoid concurrency issues
|
||||
with get_session_with_current_tenant() as search_db_session:
|
||||
for tool_response in search_tool.run(
|
||||
@@ -153,6 +164,8 @@ def basic_search(
|
||||
retrieved_sections_callback=callback_container.append,
|
||||
skip_query_analysis=True,
|
||||
original_query=rewritten_query,
|
||||
user_file_ids=user_file_ids,
|
||||
project_id=project_id,
|
||||
),
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
|
||||
@@ -5,12 +5,12 @@ from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.dr.utils import chunks_or_sections_to_search_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -47,7 +47,7 @@ def is_reducer(
|
||||
doc_list.append(x)
|
||||
|
||||
# Convert InferenceSections to SavedSearchDocs
|
||||
search_docs = chunks_or_sections_to_search_docs(doc_list)
|
||||
search_docs = SearchDoc.from_chunks_or_sections(doc_list)
|
||||
retrieved_saved_search_docs = [
|
||||
SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
|
||||
for search_doc in search_docs
|
||||
|
||||
@@ -13,7 +13,7 @@ from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.utils import chunks_or_sections_to_search_docs
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
WebSearchTool,
|
||||
)
|
||||
@@ -266,7 +266,7 @@ def convert_inference_sections_to_search_docs(
|
||||
is_internet: bool = False,
|
||||
) -> list[SavedSearchDoc]:
|
||||
# Convert InferenceSections to SavedSearchDocs
|
||||
search_docs = chunks_or_sections_to_search_docs(inference_sections)
|
||||
search_docs = SearchDoc.from_chunks_or_sections(inference_sections)
|
||||
for search_doc in search_docs:
|
||||
search_doc.is_internet = is_internet
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ class GraphInputs(BaseModel):
|
||||
prompt_builder: AnswerPromptBuilder
|
||||
files: list[InMemoryChatFile] | None = None
|
||||
structured_response_format: dict | None = None
|
||||
project_instructions: str | None = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
|
||||
from onyx.chat.prompt_builder.schemas import PromptSnapshot
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.models import ToolCallFinalResult
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
@@ -8,18 +7,10 @@ from typing import Any
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
from typing import TypedDict
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import StreamWriter
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.models import GraphInputs
|
||||
from onyx.agents.agent_search.models import GraphPersistence
|
||||
from onyx.agents.agent_search.models import GraphSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphTooling
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
EntityRelationshipTermExtraction,
|
||||
@@ -32,9 +23,6 @@ from onyx.agents.agent_search.shared_graph_utils.models import SubQuestionAnswer
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import CitationConfig
|
||||
from onyx.chat.models import DocumentPruningConfig
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.models import SectionRelevancePiece
|
||||
@@ -42,25 +30,16 @@ from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_HISTORY_SUMMARY
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import DISPATCH_SEP_CHAR
|
||||
from onyx.configs.constants import FORMAT_DOCS_SEPARATOR
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.persona import Persona
|
||||
from onyx.db.tools import get_tool_by_name
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -77,15 +56,12 @@ from onyx.prompts.agent_search import (
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketObj
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
@@ -156,120 +132,6 @@ def format_entity_term_extraction(
|
||||
return "\n".join(entity_strs + relationship_strs + term_strs)
|
||||
|
||||
|
||||
def get_test_config(
|
||||
db_session: Session,
|
||||
primary_llm: LLM,
|
||||
fast_llm: LLM,
|
||||
search_request: SearchRequest,
|
||||
use_agentic_search: bool = True,
|
||||
) -> GraphConfig:
|
||||
persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
|
||||
document_pruning_config = DocumentPruningConfig(
|
||||
max_chunks=int(
|
||||
persona.num_chunks
|
||||
if persona.num_chunks is not None
|
||||
else MAX_CHUNKS_FED_TO_CHAT
|
||||
),
|
||||
max_window_percentage=CHAT_TARGET_CHUNK_PERCENTAGE,
|
||||
)
|
||||
|
||||
answer_style_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
# The docs retrieved by this flow are already relevance-filtered
|
||||
all_docs_useful=True
|
||||
),
|
||||
structured_response_format=None,
|
||||
)
|
||||
|
||||
search_tool_config = SearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=RetrievalDetails(), # may want to set dedupe_docs=True
|
||||
rerank_settings=None, # Can use this to change reranking model
|
||||
selected_sections=None,
|
||||
latest_query_files=None,
|
||||
bypass_acl=False,
|
||||
)
|
||||
|
||||
prompt_config = PromptConfig.from_model(persona)
|
||||
|
||||
search_tool = SearchTool(
|
||||
tool_id=get_tool_by_name(SearchTool._NAME, db_session).id,
|
||||
db_session=db_session,
|
||||
user=None,
|
||||
persona=persona,
|
||||
retrieval_options=search_tool_config.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
document_pruning_config=search_tool_config.document_pruning_config,
|
||||
answer_style_config=search_tool_config.answer_style_config,
|
||||
selected_sections=search_tool_config.selected_sections,
|
||||
chunks_above=search_tool_config.chunks_above,
|
||||
chunks_below=search_tool_config.chunks_below,
|
||||
full_doc=search_tool_config.full_doc,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
rerank_settings=search_tool_config.rerank_settings,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
)
|
||||
|
||||
graph_inputs = GraphInputs(
|
||||
persona=search_request.persona,
|
||||
rerank_settings=search_tool_config.rerank_settings,
|
||||
prompt_builder=AnswerPromptBuilder(
|
||||
user_message=HumanMessage(content=search_request.query),
|
||||
message_history=[],
|
||||
llm_config=primary_llm.config,
|
||||
raw_user_query=search_request.query,
|
||||
raw_user_uploaded_files=[],
|
||||
),
|
||||
structured_response_format=answer_style_config.structured_response_format,
|
||||
)
|
||||
|
||||
using_tool_calling_llm = explicit_tool_calling_supported(
|
||||
primary_llm.config.model_provider, primary_llm.config.model_name
|
||||
)
|
||||
graph_tooling = GraphTooling(
|
||||
primary_llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool=search_tool,
|
||||
tools=[search_tool],
|
||||
force_use_tool=ForceUseTool(force_use=False, tool_name=""),
|
||||
using_tool_calling_llm=using_tool_calling_llm,
|
||||
)
|
||||
|
||||
chat_session_id = (
|
||||
os.environ.get("ONYX_AS_CHAT_SESSION_ID")
|
||||
or "00000000-0000-0000-0000-000000000000"
|
||||
)
|
||||
assert (
|
||||
chat_session_id is not None
|
||||
), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
|
||||
graph_persistence = GraphPersistence(
|
||||
db_session=db_session,
|
||||
chat_session_id=UUID(chat_session_id),
|
||||
message_id=1,
|
||||
)
|
||||
|
||||
search_behavior_config = GraphSearchConfig(
|
||||
use_agentic_search=use_agentic_search,
|
||||
skip_gen_ai_answer_generation=False,
|
||||
allow_refinement=True,
|
||||
)
|
||||
graph_config = GraphConfig(
|
||||
inputs=graph_inputs,
|
||||
tooling=graph_tooling,
|
||||
persistence=graph_persistence,
|
||||
behavior=search_behavior_config,
|
||||
)
|
||||
|
||||
return graph_config
|
||||
|
||||
|
||||
def get_persona_agent_prompt_expressions(
|
||||
persona: Persona | None,
|
||||
) -> PersonaPromptExpressions:
|
||||
|
||||
@@ -115,7 +115,6 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.user_file_folder_sync",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -20,6 +20,7 @@ import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
from onyx.background.celery.tasks.vespa.document_sync import reset_document_sync
|
||||
from onyx.configs.app_configs import CELERY_WORKER_PRIMARY_POOL_OVERFLOW
|
||||
from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
@@ -83,11 +84,11 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
|
||||
EXTRA_CONCURRENCY = 4 # small extra fudge factor for connection limits
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
|
||||
pool_size = cast(int, sender.concurrency) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)
|
||||
SqlEngine.init_engine(
|
||||
pool_size=pool_size, max_overflow=CELERY_WORKER_PRIMARY_POOL_OVERFLOW
|
||||
)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
@@ -316,12 +317,12 @@ celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.evals",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
"onyx.background.celery.tasks.user_file_folder_sync",
|
||||
"onyx.background.celery.tasks.kg_processing",
|
||||
]
|
||||
)
backend/onyx/background/celery/apps/user_file_processing.py (new file, 113 lines)
@@ -0,0 +1,113 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_process_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.user_file_processing")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME)
|
||||
|
||||
# rkuo: Transient errors keep happening in the indexing watchdog threads.
|
||||
# "SSL connection has been closed unexpectedly"
|
||||
# actually setting the spawn method in the cloud fixes 95% of these.
|
||||
# setting pre ping might help even more, but not worrying about that yet
|
||||
pool_size = cast(int, sender.concurrency) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
# Fewer startup checks in the multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None:
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
base_bootsteps = app_base.get_bootsteps()
|
||||
for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
]
|
||||
)
|
||||
@@ -1,4 +1,5 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_PRIMARY_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
@@ -15,6 +16,6 @@ result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = 4
|
||||
worker_concurrency = CELERY_WORKER_PRIMARY_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
# User file processing worker configuration
|
||||
worker_concurrency = CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
@@ -26,6 +26,26 @@ CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT = 1.0
|
||||
|
||||
# tasks that run in either self-hosted or cloud deployments
|
||||
beat_task_templates: list[dict] = [
|
||||
{
|
||||
"name": "check-for-user-file-processing",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_USER_FILE_PROCESSING,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "user-file-docid-migration",
|
||||
"task": OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
|
||||
"schedule": timedelta(minutes=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-kg-processing",
|
||||
"task": OnyxCeleryTask.CHECK_KG_PROCESSING,
|
||||
@@ -89,17 +109,6 @@ beat_task_templates: list[dict] = [
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-user-file-folder-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_USER_FILE_FOLDER_SYNC,
|
||||
"schedule": timedelta(
|
||||
days=1
|
||||
), # This should essentially always be triggered manually for user folder updates.
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-pruning",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
|
||||
@@ -28,9 +28,6 @@ from onyx.db.connector_credential_pair import add_deletion_failure_message
|
||||
from onyx.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import (
|
||||
delete_userfiles_for_cc_pair__no_commit,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import (
|
||||
@@ -484,12 +481,6 @@ def monitor_connector_deletion_taskset(
|
||||
# related to the deleted DocumentByConnectorCredentialPair during commit
|
||||
db_session.expire(cc_pair)
|
||||
|
||||
# delete all userfiles for the cc_pair
|
||||
delete_userfiles_for_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
|
||||
@@ -85,8 +85,10 @@ from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.adapters.document_indexing_adapter import (
|
||||
DocumentIndexingBatchAdapter,
|
||||
)
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
@@ -1268,6 +1270,8 @@ def _docprocessing_task(
|
||||
tenant_id: str,
|
||||
batch_num: int,
|
||||
) -> None:
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
|
||||
start_time = time.monotonic()
|
||||
|
||||
if tenant_id:
|
||||
@@ -1369,6 +1373,14 @@ def _docprocessing_task(
|
||||
f"Processing {len(documents)} documents through indexing pipeline"
|
||||
)
|
||||
|
||||
adapter = DocumentIndexingBatchAdapter(
|
||||
db_session=db_session,
|
||||
connector_id=index_attempt.connector_credential_pair.connector.id,
|
||||
credential_id=index_attempt.connector_credential_pair.credential.id,
|
||||
tenant_id=tenant_id,
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
)
|
||||
|
||||
# real work happens here!
|
||||
index_pipeline_result = run_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
@@ -1378,7 +1390,8 @@ def _docprocessing_task(
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
document_batch=documents,
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
request_id=index_attempt_metadata.request_id,
|
||||
adapter=adapter,
|
||||
)
|
||||
|
||||
# Update batch completion and document counts atomically using database coordination
backend/onyx/background/celery/tasks/evals/tasks.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from typing import Any

from celery import shared_task
from celery import Task

from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.evals.eval import run_eval
from onyx.evals.models import EvalConfigurationOptions
from onyx.utils.logger import setup_logger

logger = setup_logger()


@shared_task(
    name=OnyxCeleryTask.EVAL_RUN_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
    trail=False,
)
def eval_run_task(
    self: Task,
    *,
    configuration_dict: dict[str, Any],
) -> None:
    """Background task to run an evaluation with the given configuration"""
    try:
        configuration = EvalConfigurationOptions.model_validate(configuration_dict)
        run_eval(configuration, remote_dataset_name=configuration.dataset_name)
        logger.info("Successfully completed eval run task")

    except Exception:
        logger.error("Failed to run eval task")
        raise
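One way the eval task above might be kicked off (a sketch only; the calling endpoint is not part of this diff, and celery_app stands in for whichever configured app object the caller holds):

from typing import Any

from onyx.configs.constants import OnyxCeleryTask

def kick_off_eval(celery_app: Any, configuration_dict: dict[str, Any]) -> None:
    # ignore_result=True on the task, so this is fire-and-forget
    celery_app.send_task(
        OnyxCeleryTask.EVAL_RUN_TASK,
        kwargs={"configuration_dict": configuration_dict},
    )

The dict is validated against EvalConfigurationOptions inside the task, so schema errors surface in the worker logs rather than at enqueue time.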
@@ -889,6 +889,12 @@ def monitor_celery_queues_helper(
|
||||
n_user_files_indexing = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILES_INDEXING, r_celery
|
||||
)
|
||||
n_user_file_processing = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
n_user_file_project_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, r_celery
|
||||
)
|
||||
n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
|
||||
n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
|
||||
n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery)
|
||||
@@ -916,6 +922,8 @@ def monitor_celery_queues_helper(
|
||||
f"docprocessing={n_docprocessing} "
|
||||
f"docprocessing_prefetched={len(n_docprocessing_prefetched)} "
|
||||
f"user_files_indexing={n_user_files_indexing} "
|
||||
f"user_file_processing={n_user_file_processing} "
|
||||
f"user_file_project_sync={n_user_file_project_sync} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning} "
|
||||
|
||||
@@ -0,0 +1,656 @@
|
||||
import datetime
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
import sqlalchemy as sa
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import FileRecord
|
||||
from onyx.db.models import SearchDoc
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
)
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.document_index.vespa_constants import USER_PROJECT
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.file_store import S3BackedFileStore
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
def _as_uuid(value: str | UUID) -> UUID:
|
||||
"""Return a UUID, accepting either a UUID or a string-like value."""
|
||||
return value if isinstance(value, UUID) else UUID(str(value))
|
||||
|
||||
|
||||
def _user_file_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_PROCESSING,
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
|
||||
|
||||
Uses direct Redis locks to avoid overlapping runs.
|
||||
"""
|
||||
task_logger.info("check_user_file_processing - Starting")
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
lock: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.USER_FILE_PROCESSING_BEAT_LOCK,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# Do not overlap generator runs
|
||||
if not lock.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
select(UserFile.id).where(
|
||||
UserFile.status == UserFileStatus.PROCESSING
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
for user_file_id in user_file_ids:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
enqueued += 1
|
||||
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
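The beat task above relies on a non-blocking Redis lock so that overlapping scans are skipped rather than queued. Stripped down to its essentials, the pattern looks roughly like this (function and argument names are illustrative):

from collections.abc import Callable

from redis import Redis
from redis.lock import Lock as RedisLock

def run_exclusive(redis_client: Redis, lock_name: str, timeout: int, work: Callable[[], None]) -> bool:
    lock: RedisLock = redis_client.lock(lock_name, timeout=timeout)
    if not lock.acquire(blocking=False):
        return False  # another worker holds the lock; skip this run entirely
    try:
        work()
        return True
    finally:
        if lock.owned():
            lock.release()

The per-file tasks below use the same pattern with a per-file lock key, which is why a file that is already being processed is skipped instead of being processed twice.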
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -> None:
|
||||
task_logger.info(f"process_single_user_file - Starting id={user_file_id}")
|
||||
start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id), timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
documents: list[Document] = []
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
uf = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not uf:
|
||||
task_logger.warning(
|
||||
f"process_single_user_file - UserFile not found id={user_file_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
if uf.status != UserFileStatus.PROCESSING:
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Skipping id={user_file_id} status={uf.status}"
|
||||
)
|
||||
return None
|
||||
|
||||
connector = LocalFileConnector(
|
||||
file_locations=[uf.file_id],
|
||||
file_names=[uf.name] if uf.name else None,
|
||||
zip_metadata={},
|
||||
)
|
||||
connector.load_credentials({})
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
search_settings_list = get_active_search_settings_list(db_session)
|
||||
|
||||
current_search_settings = next(
|
||||
(
|
||||
search_settings_instance
|
||||
for search_settings_instance in search_settings_list
|
||||
if search_settings_instance.status.is_current()
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if current_search_settings is None:
|
||||
raise RuntimeError(
|
||||
f"process_single_user_file - No current search settings found for tenant={tenant_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
for batch in connector.load_from_state():
|
||||
documents.extend(batch)
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=tenant_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set up indexing pipeline components
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=current_search_settings,
|
||||
)
|
||||
|
||||
information_content_classification_model = (
|
||||
InformationContentClassificationModel()
|
||||
)
|
||||
|
||||
document_index = get_default_document_index(
|
||||
current_search_settings,
|
||||
None,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
# update the document id to the user file id in the documents
|
||||
for document in documents:
|
||||
document.id = str(user_file_id)
|
||||
document.source = DocumentSource.USER_FILE
|
||||
|
||||
# real work happens here!
|
||||
index_pipeline_result = run_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=True,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
document_batch=documents,
|
||||
request_id=None,
|
||||
adapter=adapter,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Indexing pipeline completed ={index_pipeline_result}"
|
||||
)
|
||||
|
||||
if index_pipeline_result.failures:
|
||||
task_logger.error(
|
||||
f"process_single_user_file - Indexing pipeline failed id={user_file_id}"
|
||||
)
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
return None
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
# Attempt to mark the file as failed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
uf = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if uf:
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
|
||||
task_logger.exception(
|
||||
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_PROJECT_SYNC,
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files with PROJECT_SYNC status and enqueue per-file tasks."""
|
||||
task_logger.info("check_for_user_file_project_sync - Starting")
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
lock: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.USER_FILE_PROJECT_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not lock.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
select(UserFile.id).where(
|
||||
UserFile.needs_project_sync.is_(True),
UserFile.status == UserFileStatus.COMPLETED,
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
for user_file_id in user_file_ids:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
enqueued += 1
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_for_user_file_project_sync - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_project_sync(
|
||||
self: Task, *, user_file_id: str, tenant_id: str
|
||||
) -> None:
|
||||
"""Process a single user file project sync."""
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Starting id={user_file_id}"
|
||||
)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
chunks_affected = retry_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(user_projects=project_ids),
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Chunks affected id={user_file_id} chunks={chunks_affected}"
|
||||
)
|
||||
|
||||
user_file.needs_project_sync = False
|
||||
user_file.last_project_sync_at = datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
)
|
||||
db_session.add(user_file)
|
||||
db_session.commit()
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"process_single_user_file_project_sync - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_legacy_user_file_doc_id(old_id: str) -> str:
|
||||
# Convert USER_FILE_CONNECTOR__<uuid> -> FILE_CONNECTOR__<uuid> for legacy values
|
||||
user_prefix = "USER_FILE_CONNECTOR__"
|
||||
file_prefix = "FILE_CONNECTOR__"
|
||||
if old_id.startswith(user_prefix):
|
||||
remainder = old_id[len(user_prefix) :]
|
||||
return file_prefix + remainder
|
||||
return old_id
|
||||
|
||||
|
||||
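A quick illustration of what _normalize_legacy_user_file_doc_id does with the two cases it can see (the UUID value is made up):

assert _normalize_legacy_user_file_doc_id(
    "USER_FILE_CONNECTOR__123e4567-e89b-12d3-a456-426614174000"
) == "FILE_CONNECTOR__123e4567-e89b-12d3-a456-426614174000"

# Anything without the legacy prefix passes through unchanged.
assert _normalize_legacy_user_file_doc_id("FILE_CONNECTOR__abc") == "FILE_CONNECTOR__abc"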
def _visit_chunks(
|
||||
*,
|
||||
http_client: httpx.Client,
|
||||
index_name: str,
|
||||
selection: str,
|
||||
continuation: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
base_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
|
||||
params: dict[str, str] = {
|
||||
"selection": selection,
|
||||
"wantedDocumentCount": "1000",
|
||||
}
|
||||
if continuation:
|
||||
params["continuation"] = continuation
|
||||
resp = http_client.get(base_url, params=params, timeout=None)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
return payload.get("documents", []), payload.get("continuation")
|
||||
|
||||
|
||||
def _update_document_id_in_vespa(
|
||||
*,
|
||||
index_name: str,
|
||||
old_doc_id: str,
|
||||
new_doc_id: str,
|
||||
user_project_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
clean_new_doc_id = replace_invalid_doc_id_characters(new_doc_id)
|
||||
normalized_old = _normalize_legacy_user_file_doc_id(old_doc_id)
|
||||
clean_old_doc_id = replace_invalid_doc_id_characters(normalized_old)
|
||||
|
||||
selection = f"{index_name}.document_id=='{clean_old_doc_id}'"
|
||||
task_logger.debug(f"Vespa selection: {selection}")
|
||||
|
||||
with get_vespa_http_client() as http_client:
|
||||
continuation: str | None = None
|
||||
while True:
|
||||
docs, continuation = _visit_chunks(
|
||||
http_client=http_client,
|
||||
index_name=index_name,
|
||||
selection=selection,
|
||||
continuation=continuation,
|
||||
)
|
||||
if not docs:
|
||||
break
|
||||
for doc in docs:
|
||||
vespa_full_id = doc.get("id")
|
||||
if not vespa_full_id:
|
||||
continue
|
||||
vespa_doc_uuid = vespa_full_id.split("::")[-1]
|
||||
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}"
|
||||
update_request: dict[str, Any] = {
|
||||
"fields": {"document_id": {"assign": clean_new_doc_id}}
|
||||
}
|
||||
if user_project_ids is not None:
|
||||
update_request["fields"][USER_PROJECT] = {
|
||||
"assign": user_project_ids
|
||||
}
|
||||
r = http_client.put(vespa_url, json=update_request)
|
||||
r.raise_for_status()
|
||||
if not continuation:
|
||||
break
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
|
||||
ignore_result=True,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
bind=True,
|
||||
)
|
||||
def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
|
||||
"""Per-tenant job to update Vespa and search_doc document_id values for user files.
|
||||
|
||||
- For each user_file with a legacy document_id, set Vespa `document_id` to the UUID `user_file.id`.
|
||||
- Update `search_doc.document_id` to the same UUID string.
|
||||
"""
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
active_settings = get_active_search_settings(db_session)
|
||||
document_index = get_default_document_index(
|
||||
active_settings.primary,
|
||||
active_settings.secondary,
|
||||
)
|
||||
if hasattr(document_index, "index_name"):
|
||||
index_name = document_index.index_name
|
||||
else:
|
||||
index_name = "danswer_index"
|
||||
|
||||
# Fetch mappings of legacy -> new ids
|
||||
rows = db_session.execute(
|
||||
sa.select(
|
||||
UserFile.document_id.label("document_id"),
|
||||
UserFile.id.label("id"),
|
||||
).where(
|
||||
UserFile.document_id.is_not(None),
|
||||
UserFile.document_id_migrated.is_(False),
|
||||
)
|
||||
).all()
|
||||
|
||||
# dedupe by old document_id
|
||||
seen: set[str] = set()
|
||||
for row in rows:
|
||||
old_doc_id = str(row.document_id)
|
||||
new_uuid = str(row.id)
|
||||
if not old_doc_id or not new_uuid or old_doc_id in seen:
|
||||
continue
|
||||
seen.add(old_doc_id)
|
||||
# collect user project ids for a combined Vespa update
|
||||
user_project_ids: list[int] | None = None
|
||||
try:
|
||||
uf = db_session.get(UserFile, UUID(new_uuid))
|
||||
if uf is not None:
|
||||
user_project_ids = [project.id for project in uf.projects]
|
||||
except Exception as e:
|
||||
task_logger.warning(
|
||||
f"Tenant={tenant_id} failed fetching projects for doc_id={new_uuid} - {e.__class__.__name__}"
|
||||
)
|
||||
try:
|
||||
_update_document_id_in_vespa(
|
||||
index_name=index_name,
|
||||
old_doc_id=old_doc_id,
|
||||
new_doc_id=new_uuid,
|
||||
user_project_ids=user_project_ids,
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.warning(
|
||||
f"Tenant={tenant_id} failed Vespa update for doc_id={new_uuid} - {e.__class__.__name__}"
|
||||
)
|
||||
|
||||
# Update search_doc records to refer to the UUID string
|
||||
uf_id_subq = (
|
||||
sa.select(sa.cast(UserFile.id, sa.String))
|
||||
.where(
|
||||
UserFile.document_id.is_not(None),
|
||||
UserFile.document_id_migrated.is_(False),
|
||||
SearchDoc.document_id == UserFile.document_id,
|
||||
)
|
||||
.correlate(SearchDoc)
|
||||
.scalar_subquery()
|
||||
)
|
||||
db_session.execute(
|
||||
sa.update(SearchDoc)
|
||||
.where(
|
||||
sa.exists(
|
||||
sa.select(sa.literal(1)).where(
|
||||
UserFile.document_id.is_not(None),
|
||||
UserFile.document_id_migrated.is_(False),
|
||||
SearchDoc.document_id == UserFile.document_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
.values(document_id=uf_id_subq)
|
||||
)
|
||||
# Mark all processed user_files as migrated
|
||||
db_session.execute(
|
||||
sa.update(UserFile)
|
||||
.where(
|
||||
UserFile.document_id.is_not(None),
|
||||
UserFile.document_id_migrated.is_(False),
|
||||
)
|
||||
.values(document_id_migrated=True)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
# Normalize plaintext FileRecord blobs: ensure S3 object key aligns with current file_id
|
||||
try:
|
||||
store = get_default_file_store()
|
||||
# Only supported for S3-backed stores where we can manipulate object keys
|
||||
if isinstance(store, S3BackedFileStore):
|
||||
s3_client = store._get_s3_client()
|
||||
bucket_name = store._get_bucket_name()
|
||||
|
||||
plaintext_records: Sequence[FileRecord] = (
|
||||
db_session.execute(
|
||||
sa.select(FileRecord).where(
|
||||
FileRecord.file_origin == FileOrigin.PLAINTEXT_CACHE,
|
||||
FileRecord.file_id.like("plaintext_%"),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
normalized = 0
|
||||
for fr in plaintext_records:
|
||||
try:
|
||||
expected_key = store._get_s3_key(fr.file_id)
|
||||
if fr.object_key == expected_key:
|
||||
continue
|
||||
|
||||
# Copy old object to new key
|
||||
copy_source = f"{fr.bucket_name}/{fr.object_key}"
|
||||
s3_client.copy_object(
|
||||
CopySource=copy_source,
|
||||
Bucket=bucket_name,
|
||||
Key=expected_key,
|
||||
MetadataDirective="COPY",
|
||||
)
|
||||
|
||||
# Delete old object (best-effort)
|
||||
try:
|
||||
s3_client.delete_object(
|
||||
Bucket=fr.bucket_name, Key=fr.object_key
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Update DB record with new key
|
||||
fr.object_key = expected_key
|
||||
db_session.add(fr)
|
||||
normalized += 1
|
||||
except Exception as e:
|
||||
task_logger.warning(
|
||||
f"Tenant={tenant_id} failed plaintext object normalize for "
|
||||
f"id={fr.file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
|
||||
if normalized:
|
||||
db_session.commit()
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task normalized {normalized} plaintext objects for tenant={tenant_id}"
|
||||
)
|
||||
else:
|
||||
task_logger.info(
|
||||
"user_file_docid_migration_task skipping plaintext object normalization (non-S3 store)"
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"user_file_docid_migration_task - Error during plaintext normalization for tenant={tenant_id}"
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task completed for tenant={tenant_id} (rows={len(rows)})"
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"user_file_docid_migration_task - Error during execution for tenant={tenant_id}"
|
||||
)
|
||||
return False
|
||||
@@ -414,8 +414,14 @@ def monitor_document_set_taskset(
|
||||
get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
|
||||
) # casting since we "know" a document set with this ID exists
|
||||
if document_set:
|
||||
if not document_set.connector_credential_pairs:
|
||||
# if there are no connectors, then delete the document set.
|
||||
has_connector_pairs = bool(document_set.connector_credential_pairs)
|
||||
# Federated connectors should keep a document set alive even without cc pairs.
|
||||
has_federated_connectors = bool(
|
||||
getattr(document_set, "federated_connectors", [])
|
||||
)
|
||||
|
||||
if not has_connector_pairs and not has_federated_connectors:
|
||||
# If there are no connectors of any kind, delete the document set.
|
||||
delete_document_set(document_set_row=document_set, db_session=db_session)
|
||||
task_logger.info(
|
||||
f"Successfully deleted document set: document_set={document_set_id}"
|
||||
|
||||
@@ -0,0 +1,16 @@
"""Factory stub for running the user file processing Celery worker."""

from celery import Celery

from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable

set_is_ee_based_on_env_variable()


def get_app() -> Celery:
    from onyx.background.celery.apps.user_file_processing import celery_app

    return celery_app


app = get_app()
@@ -28,7 +28,6 @@ from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
@@ -64,9 +63,11 @@ from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.adapters.document_indexing_adapter import (
|
||||
DocumentIndexingBatchAdapter,
|
||||
)
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
@@ -100,6 +101,8 @@ def _get_connector_runner(
|
||||
are the complete list of existing documents of the connector. If the task is
|
||||
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
|
||||
"""
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
|
||||
task = attempt.connector_credential_pair.connector.input_type
|
||||
|
||||
try:
|
||||
@@ -283,6 +286,8 @@ def _run_indexing(
|
||||
2. Embed and index these documents into the chosen datastore (vespa)
|
||||
3. Updates Postgres to record the indexed documents + the outcome of this run
|
||||
"""
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
|
||||
start_time = time.monotonic() # just used for logging
|
||||
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
@@ -567,6 +572,13 @@ def _run_indexing(
|
||||
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
|
||||
|
||||
# real work happens here!
|
||||
adapter = DocumentIndexingBatchAdapter(
|
||||
db_session=db_session,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
tenant_id=tenant_id,
|
||||
index_attempt_metadata=index_attempt_md,
|
||||
)
|
||||
index_pipeline_result = run_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
@@ -578,7 +590,8 @@ def _run_indexing(
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
document_batch=doc_batch_cleaned,
|
||||
index_attempt_metadata=index_attempt_md,
|
||||
request_id=index_attempt_md.request_id,
|
||||
adapter=adapter,
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
|
||||
@@ -62,6 +62,7 @@ class Answer:
|
||||
use_agentic_search: bool = False,
|
||||
research_type: ResearchType | None = None,
|
||||
research_plan: dict[str, Any] | None = None,
|
||||
project_instructions: str | None = None,
|
||||
) -> None:
|
||||
self.is_connected: Callable[[], bool] | None = is_connected
|
||||
self._processed_stream: list[AnswerStreamPart] | None = None
|
||||
@@ -97,6 +98,7 @@ class Answer:
|
||||
prompt_builder=prompt_builder,
|
||||
files=latest_query_files,
|
||||
structured_response_format=answer_style_config.structured_response_format,
|
||||
project_instructions=project_instructions,
|
||||
)
|
||||
self.graph_tooling = GraphTooling(
|
||||
primary_llm=llm,
|
||||
|
||||
@@ -32,6 +32,7 @@ from onyx.db.llm import fetch_existing_doc_sets
|
||||
from onyx.db.llm import fetch_existing_tools
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
@@ -40,7 +41,9 @@ from onyx.kg.setup.kg_default_entity_definitions import (
|
||||
populate_missing_default_entity_types__commit,
|
||||
)
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import (
|
||||
@@ -63,6 +66,9 @@ def prepare_chat_message_request(
|
||||
db_session: Session,
|
||||
use_agentic_search: bool = False,
|
||||
skip_gen_ai_answer_generation: bool = False,
|
||||
llm_override: LLMOverride | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
slack_context: SlackContext | None = None,
|
||||
) -> CreateChatMessageRequest:
|
||||
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
|
||||
new_chat_session = create_chat_session(
|
||||
@@ -88,6 +94,9 @@ def prepare_chat_message_request(
|
||||
rerank_settings=rerank_settings,
|
||||
use_agentic_search=use_agentic_search,
|
||||
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
|
||||
llm_override=llm_override,
|
||||
allowed_tool_ids=allowed_tool_ids,
|
||||
slack_context=slack_context, # Pass Slack context
|
||||
)
|
||||
|
||||
|
||||
@@ -335,6 +344,45 @@ def reorganize_citations(
|
||||
return new_answer, list(new_citation_info.values())
|
||||
|
||||
|
||||
def build_citation_map_from_infos(
|
||||
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
|
||||
) -> dict[int, int]:
|
||||
"""Translate a list of streaming CitationInfo objects into a mapping of
|
||||
citation number -> saved search doc DB id.
|
||||
|
||||
Always cites the first instance of a document_id and assumes db_docs are
|
||||
ordered as shown to the user (display order).
|
||||
"""
|
||||
doc_id_to_saved_doc_id_map: dict[str, int] = {}
|
||||
for db_doc in db_docs:
|
||||
if db_doc.document_id not in doc_id_to_saved_doc_id_map:
|
||||
doc_id_to_saved_doc_id_map[db_doc.document_id] = db_doc.id
|
||||
|
||||
citation_to_saved_doc_id_map: dict[int, int] = {}
|
||||
for citation in citations_list:
|
||||
if citation.citation_num not in citation_to_saved_doc_id_map:
|
||||
saved_id = doc_id_to_saved_doc_id_map.get(citation.document_id)
|
||||
if saved_id is not None:
|
||||
citation_to_saved_doc_id_map[citation.citation_num] = saved_id
|
||||
|
||||
return citation_to_saved_doc_id_map
|
||||
|
||||
|
||||
def build_citation_map_from_numbers(
|
||||
cited_numbers: list[int] | set[int], db_docs: list[DbSearchDoc]
|
||||
) -> dict[int, int]:
|
||||
"""Translate parsed citation numbers (e.g., from [[n]]) into a mapping of
|
||||
citation number -> saved search doc DB id by positional index.
|
||||
"""
|
||||
citation_to_saved_doc_id_map: dict[int, int] = {}
|
||||
for num in sorted(set(cited_numbers)):
|
||||
idx = num - 1
|
||||
if 0 <= idx < len(db_docs):
|
||||
citation_to_saved_doc_id_map[num] = db_docs[idx].id
|
||||
|
||||
return citation_to_saved_doc_id_map
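To make the positional contract of build_citation_map_from_numbers concrete, here is a minimal sketch; the doc objects are stand-ins carrying only the attributes the helpers read, not real DbSearchDoc rows:

    from dataclasses import dataclass

    @dataclass
    class _StubDoc:
        id: int            # primary key of the saved SearchDoc row
        document_id: str   # source document identifier

    docs = [_StubDoc(id=101, document_id="doc-a"), _StubDoc(id=202, document_id="doc-b")]

    # Citation numbers resolve by display position; out-of-range numbers are dropped,
    # so build_citation_map_from_numbers([2, 1, 2, 9], docs) -> {1: 101, 2: 202}.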
|
||||
|
||||
|
||||
def extract_headers(
|
||||
headers: dict[str, str] | Headers, pass_through_headers: list[str] | None
|
||||
) -> dict[str, str]:
|
||||
|
||||
@@ -5,6 +5,7 @@ from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
from typing import Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -18,6 +19,7 @@ from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import ChatBasicResponse
|
||||
from onyx.chat.models import CitationConfig
|
||||
from onyx.chat.models import DocumentPruningConfig
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.chat.models import MessageSpecificCitations
|
||||
from onyx.chat.models import PromptConfig
|
||||
@@ -35,6 +37,7 @@ from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.configs.chat_configs import SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import NO_AUTH_USER_ID
|
||||
@@ -63,9 +66,13 @@ from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.projects import get_project_instructions
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import build_frontend_file_url
|
||||
from onyx.file_store.utils import load_all_chat_files
|
||||
from onyx.kg.models import KGException
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
@@ -101,6 +108,7 @@ from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.timing import log_generator_function_time
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
ERROR_TYPE_CANCELLED = "cancelled"
|
||||
|
||||
@@ -119,6 +127,55 @@ class PartialResponse(Protocol):
|
||||
) -> ChatMessage: ...
|
||||
|
||||
|
||||
def _build_project_llm_docs(
|
||||
project_file_ids: list[str] | None,
|
||||
in_memory_user_files: list[InMemoryChatFile] | None,
|
||||
) -> list[LlmDoc]:
|
||||
"""Construct `LlmDoc` objects for project-scoped user files for citation flow."""
|
||||
project_llm_docs: list[LlmDoc] = []
|
||||
if not project_file_ids or not in_memory_user_files:
|
||||
return project_llm_docs
|
||||
|
||||
project_file_id_set = set(project_file_ids)
|
||||
for f in in_memory_user_files:
|
||||
# Only include files that belong to the project (not ad-hoc uploads)
|
||||
if project_file_id_set and (f.file_id in project_file_id_set):
|
||||
try:
|
||||
text_content = f.content.decode("utf-8", errors="ignore")
|
||||
except Exception:
|
||||
text_content = ""
|
||||
|
||||
# Build a short blurb from the file content for better UI display
|
||||
blurb = (
|
||||
(text_content[:200] + "...")
|
||||
if len(text_content) > 200
|
||||
else text_content
|
||||
)
|
||||
|
||||
# Provide basic metadata to improve SavedSearchDoc display
|
||||
file_metadata: dict[str, str | list[str]] = {
|
||||
"filename": f.filename or str(f.file_id),
|
||||
"file_type": f.file_type.value,
|
||||
}
|
||||
|
||||
project_llm_docs.append(
|
||||
LlmDoc(
|
||||
document_id=str(f.file_id),
|
||||
content=text_content,
|
||||
blurb=blurb,
|
||||
semantic_identifier=f.filename or str(f.file_id),
|
||||
source_type=DocumentSource.USER_FILE,
|
||||
metadata=file_metadata,
|
||||
updated_at=None,
|
||||
link=build_frontend_file_url(str(f.file_id)),
|
||||
source_links=None,
|
||||
match_highlights=None,
|
||||
)
|
||||
)
|
||||
|
||||
return project_llm_docs
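A quick aside on the blurb logic above: file content is decoded permissively and truncated to 200 characters for display. A standalone illustration of just that slice (plain Python, no Onyx types):

    raw = b"Quarterly revenue grew 14 percent. " * 10  # > 200 chars once decoded
    text_content = raw.decode("utf-8", errors="ignore")
    blurb = (text_content[:200] + "...") if len(text_content) > 200 else text_content
    assert len(blurb) == 203 and blurb.endswith("...")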
|
||||
|
||||
|
||||
def _translate_citations(
|
||||
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
|
||||
) -> MessageSpecificCitations:
|
||||
@@ -436,26 +493,28 @@ def stream_chat_message_objects(
|
||||
files = load_all_chat_files(history_msgs, new_msg_req.file_descriptors)
|
||||
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
|
||||
latest_query_files = [file for file in files if file.file_id in req_file_ids]
|
||||
user_file_ids = new_msg_req.user_file_ids or []
|
||||
user_folder_ids = new_msg_req.user_folder_ids or []
|
||||
user_file_ids: list[UUID] = []
|
||||
|
||||
if persona.user_files:
|
||||
for file in persona.user_files:
|
||||
user_file_ids.append(file.id)
|
||||
if persona.user_folders:
|
||||
for folder in persona.user_folders:
|
||||
user_folder_ids.append(folder.id)
|
||||
for uf in persona.user_files:
|
||||
user_file_ids.append(uf.id)
|
||||
|
||||
if new_msg_req.current_message_files:
|
||||
for fd in new_msg_req.current_message_files:
|
||||
uid = fd.get("user_file_id")
|
||||
if uid is not None:
|
||||
user_file_ids.append(uid)
|
||||
|
||||
# Load in user files into memory and create search tool override kwargs if needed
|
||||
# if we have enough tokens and no folders, we don't need to use search
|
||||
# if we have enough tokens, we don't need to use search
|
||||
# we can just pass them into the prompt directly
|
||||
(
|
||||
in_memory_user_files,
|
||||
user_file_models,
|
||||
search_tool_override_kwargs_for_user_files,
|
||||
) = parse_user_files(
|
||||
user_file_ids=user_file_ids,
|
||||
user_folder_ids=user_folder_ids,
|
||||
user_file_ids=user_file_ids or [],
|
||||
project_id=chat_session.project_id,
|
||||
db_session=db_session,
|
||||
persona=persona,
|
||||
actual_user_input=message_text,
|
||||
@@ -464,16 +523,37 @@ def stream_chat_message_objects(
|
||||
if not search_tool_override_kwargs_for_user_files:
|
||||
latest_query_files.extend(in_memory_user_files)
|
||||
|
||||
project_file_ids = []
|
||||
if chat_session.project_id:
|
||||
project_file_ids.extend(
|
||||
[
|
||||
file.file_id
|
||||
for file in get_user_files_from_project(
|
||||
chat_session.project_id, user_id, db_session
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# we don't want to attach project files to the user message
|
||||
if user_message:
|
||||
attach_files_to_chat_message(
|
||||
chat_message=user_message,
|
||||
files=[
|
||||
new_file.to_file_descriptor() for new_file in latest_query_files
|
||||
new_file.to_file_descriptor()
|
||||
for new_file in latest_query_files
|
||||
if project_file_ids is not None
|
||||
and (new_file.file_id not in project_file_ids)
|
||||
],
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
# Build project context docs for citation flow if project files are present
|
||||
project_llm_docs: list[LlmDoc] = _build_project_llm_docs(
|
||||
project_file_ids=project_file_ids,
|
||||
in_memory_user_files=in_memory_user_files,
|
||||
)
|
||||
|
||||
selected_db_search_docs = None
|
||||
selected_sections: list[InferenceSection] | None = None
|
||||
if reference_doc_ids:
|
||||
@@ -559,12 +639,22 @@ def stream_chat_message_objects(
|
||||
else:
|
||||
prompt_config = PromptConfig.from_model(persona)
|
||||
|
||||
# Retrieve project-specific instructions if this chat session is associated with a project.
|
||||
project_instructions: str | None = (
|
||||
get_project_instructions(
|
||||
db_session=db_session, project_id=chat_session.project_id
|
||||
)
|
||||
if persona.is_default_persona
|
||||
else None
|
||||
) # if the persona is not default, we don't want to use the project instructions
|
||||
|
||||
answer_style_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
all_docs_useful=selected_db_search_docs is not None
|
||||
),
|
||||
structured_response_format=new_msg_req.structured_response_format,
|
||||
)
|
||||
has_project_files = project_file_ids is not None and len(project_file_ids) > 0
|
||||
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
@@ -574,9 +664,17 @@ def stream_chat_message_objects(
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
run_search_setting=(
|
||||
retrieval_options.run_search
|
||||
if retrieval_options
|
||||
else OptionalSearchSetting.AUTO
|
||||
OptionalSearchSetting.NEVER
|
||||
if (
|
||||
chat_session.project_id
|
||||
and not has_project_files
|
||||
and persona.is_default_persona
|
||||
)
|
||||
else (
|
||||
retrieval_options.run_search
|
||||
if retrieval_options
|
||||
else OptionalSearchSetting.AUTO
|
||||
)
|
||||
),
|
||||
search_tool_config=SearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
@@ -603,6 +701,7 @@ def stream_chat_message_objects(
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
slack_context=new_msg_req.slack_context, # Pass Slack context from request
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
@@ -617,6 +716,7 @@ def stream_chat_message_objects(
|
||||
message_history = [
|
||||
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
|
||||
]
|
||||
|
||||
if not search_tool_override_kwargs_for_user_files and in_memory_user_files:
|
||||
yield UserKnowledgeFilePacket(
|
||||
user_files=[
|
||||
@@ -624,6 +724,8 @@ def stream_chat_message_objects(
|
||||
id=str(file.file_id), type=file.file_type, name=file.filename
|
||||
)
|
||||
for file in in_memory_user_files
|
||||
if project_file_ids is not None
|
||||
and (file.file_id not in project_file_ids)
|
||||
]
|
||||
)
|
||||
|
||||
@@ -642,6 +744,10 @@ def stream_chat_message_objects(
|
||||
single_message_history=single_message_history,
|
||||
)
|
||||
|
||||
if project_llm_docs and not search_tool_override_kwargs_for_user_files:
|
||||
# Store for downstream streaming to wire citations and final_documents
|
||||
prompt_builder.context_llm_docs = project_llm_docs
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
answer = Answer(
|
||||
prompt_builder=prompt_builder,
|
||||
@@ -670,6 +776,7 @@ def stream_chat_message_objects(
|
||||
db_session=db_session,
|
||||
use_agentic_search=new_msg_req.use_agentic_search,
|
||||
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
|
||||
project_instructions=project_instructions,
|
||||
)
|
||||
|
||||
# Process streamed packets using the new packet processing module
|
||||
|
||||
@@ -4,9 +4,9 @@ from typing import cast
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from pydantic import BaseModel
|
||||
from pydantic.v1 import BaseModel as BaseModel__v1
|
||||
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
|
||||
from onyx.chat.prompt_builder.utils import translate_history_to_basemessages
|
||||
@@ -76,6 +76,7 @@ def default_build_user_message(
|
||||
if prompt_config.task_prompt
|
||||
else user_query
|
||||
)
|
||||
|
||||
user_prompt = user_prompt.strip()
|
||||
tag_handled_prompt = handle_onyx_date_awareness(user_prompt, prompt_config)
|
||||
user_msg = HumanMessage(
|
||||
@@ -132,6 +133,10 @@ class AnswerPromptBuilder:
|
||||
self.raw_user_uploaded_files = raw_user_uploaded_files
|
||||
self.single_message_history = single_message_history
|
||||
|
||||
# Optional: if the prompt includes explicit context documents (e.g., project files),
|
||||
# store them here so downstream streaming can reference them for citation mapping.
|
||||
self.context_llm_docs: list[LlmDoc] | None = None
|
||||
|
||||
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
|
||||
if not system_message:
|
||||
self.system_message_and_token_cnt = None
|
||||
@@ -196,10 +201,6 @@ class AnswerPromptBuilder:
|
||||
|
||||
|
||||
# Stores some parts of a prompt builder as needed for tool calls
|
||||
class PromptSnapshot(BaseModel):
|
||||
raw_message_history: list[PreviousMessage]
|
||||
raw_user_query: str
|
||||
built_prompt: list[BaseMessage]
|
||||
|
||||
|
||||
# TODO: rename this? AnswerConfig maybe?
|
||||
|
||||
10
backend/onyx/chat/prompt_builder/schemas.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.llm.models import PreviousMessage
|
||||
|
||||
|
||||
class PromptSnapshot(BaseModel):
|
||||
raw_message_history: list[PreviousMessage]
|
||||
raw_user_query: str
|
||||
built_prompt: list[BaseMessage]
|
||||
@@ -12,6 +12,35 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def normalize_square_bracket_citations_to_double_with_links(text: str) -> str:
|
||||
"""
|
||||
Normalize citation markers in the text:
|
||||
- Convert bare double-bracket citations without links `[[n]]` to `[[n]]()`
|
||||
- Convert single-bracket citations `[n]` to `[[n]]()`
|
||||
Leaves existing linked citations like `[[n]](http...)` unchanged.
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
# Add empty parens to bare double-bracket citations without a link: [[n]] -> [[n]]()
|
||||
pattern_double_no_link = re.compile(r"\[\[(\d+)\]\](?!\()")
|
||||
|
||||
def _repl_double(match: re.Match[str]) -> str:
|
||||
num = match.group(1)
|
||||
return f"[[{num}]]()"
|
||||
|
||||
text = pattern_double_no_link.sub(_repl_double, text)
|
||||
|
||||
# Convert single [n] not already [[n]] to [[n]]()
|
||||
pattern_single = re.compile(r"(?<!\[)\[(\d+)\](?!\])")
|
||||
|
||||
def _repl_single(match: re.Match[str]) -> str:
|
||||
num = match.group(1)
|
||||
return f"[[{num}]]()"
|
||||
|
||||
return pattern_single.sub(_repl_single, text)
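A before/after sketch of the normalization above (the input string is made up; the behavior follows directly from the two regexes):

    text = "See [1] and [[2]], details in [[3]](https://example.com)"
    # normalize_square_bracket_citations_to_double_with_links(text) would return:
    #   "See [[1]]() and [[2]](), details in [[3]](https://example.com)"
    # Bare [n] and [[n]] gain empty parens; already-linked citations are left untouched.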
|
||||
|
||||
|
||||
def in_code_block(llm_text: str) -> bool:
|
||||
count = llm_text.count(TRIPLE_BACKTICK)
|
||||
return count % 2 != 0
|
||||
|
||||
@@ -7,7 +7,7 @@ from langchain_core.messages import ToolCall
|
||||
from onyx.chat.models import ResponsePart
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
|
||||
from onyx.chat.prompt_builder.schemas import PromptSnapshot
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.message import build_tool_message
|
||||
|
||||
@@ -4,6 +4,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.user_file import update_last_accessed_at_for_user_files
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import get_user_files_as_user
|
||||
from onyx.file_store.utils import load_in_memory_chat_files
|
||||
@@ -15,24 +17,24 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def parse_user_files(
|
||||
user_file_ids: list[int],
|
||||
user_folder_ids: list[int],
|
||||
user_file_ids: list[UUID],
|
||||
db_session: Session,
|
||||
persona: Persona,
|
||||
actual_user_input: str,
|
||||
project_id: int | None,
|
||||
# should only be None if auth is disabled
|
||||
user_id: UUID | None,
|
||||
) -> tuple[list[InMemoryChatFile], list[UserFile], SearchToolOverrideKwargs | None]:
|
||||
"""
|
||||
Parse user files and folders into in-memory chat files and create search tool override kwargs.
|
||||
Only creates SearchToolOverrideKwargs if token overflow occurs or folders are present.
|
||||
Parse user files and project into in-memory chat files and create search tool override kwargs.
|
||||
Only creates SearchToolOverrideKwargs if token overflow occurs.
|
||||
|
||||
Args:
|
||||
user_file_ids: List of user file IDs to load
|
||||
user_folder_ids: List of user folder IDs to load
|
||||
db_session: Database session
|
||||
persona: Persona to calculate available tokens
|
||||
actual_user_input: User's input message for token calculation
|
||||
project_id: Project ID to validate file ownership
|
||||
user_id: User ID to validate file ownership
|
||||
|
||||
Returns:
|
||||
@@ -40,37 +42,56 @@ def parse_user_files(
|
||||
loaded user files,
|
||||
user file models,
|
||||
search tool override kwargs if token
|
||||
overflow or folders present
|
||||
overflow
|
||||
)
|
||||
"""
|
||||
# Return empty results if no files or folders specified
|
||||
if not user_file_ids and not user_folder_ids:
|
||||
# Return empty results if no files or project specified
|
||||
if not user_file_ids and not project_id:
|
||||
return [], [], None
|
||||
|
||||
project_user_file_ids = []
|
||||
|
||||
if project_id:
|
||||
project_user_file_ids.extend(
|
||||
[
|
||||
file.id
|
||||
for file in get_user_files_from_project(project_id, user_id, db_session)
|
||||
]
|
||||
)
|
||||
|
||||
# Combine user-provided and project-derived user file IDs
|
||||
combined_user_file_ids = user_file_ids + project_user_file_ids or []
|
||||
|
||||
# Load user files from the database into memory
|
||||
user_files = load_in_memory_chat_files(
|
||||
user_file_ids or [],
|
||||
user_folder_ids or [],
|
||||
combined_user_file_ids,
|
||||
db_session,
|
||||
)
|
||||
|
||||
user_file_models = get_user_files_as_user(
|
||||
user_file_ids or [],
|
||||
user_folder_ids or [],
|
||||
combined_user_file_ids,
|
||||
user_id,
|
||||
db_session,
|
||||
)
|
||||
|
||||
# Update last accessed at for the user files which are used in the chat
|
||||
if user_file_ids or project_user_file_ids:
|
||||
# update_last_accessed_at_for_user_files expects list[UUID]
|
||||
update_last_accessed_at_for_user_files(
|
||||
combined_user_file_ids,
|
||||
db_session,
|
||||
)
|
||||
|
||||
# Calculate token count for the files, need to import here to avoid circular import
|
||||
# TODO: fix this
|
||||
from onyx.db.user_documents import calculate_user_files_token_count
|
||||
from onyx.db.user_file import calculate_user_files_token_count
|
||||
from onyx.chat.prompt_builder.citations_prompt import (
|
||||
compute_max_document_tokens_for_persona,
|
||||
)
|
||||
|
||||
# calculate_user_files_token_count now expects list[UUID]
|
||||
total_tokens = calculate_user_files_token_count(
|
||||
user_file_ids or [],
|
||||
user_folder_ids or [],
|
||||
combined_user_file_ids,
|
||||
db_session,
|
||||
)
|
||||
|
||||
@@ -86,20 +107,22 @@ def parse_user_files(
|
||||
|
||||
have_enough_tokens = total_tokens <= available_tokens
|
||||
|
||||
# If we have enough tokens and no folders, we don't need search
|
||||
# If we have enough tokens, we don't need search
|
||||
# we can just pass them into the prompt directly
|
||||
if have_enough_tokens and not user_folder_ids:
|
||||
if have_enough_tokens:
|
||||
# No search tool override needed - files can be passed directly
|
||||
return user_files, user_file_models, None
|
||||
|
||||
# Token overflow or folders present - need to use search tool
|
||||
# Token overflow - need to use search tool
|
||||
override_kwargs = SearchToolOverrideKwargs(
|
||||
force_no_rerank=have_enough_tokens,
|
||||
alternate_db_session=None,
|
||||
retrieved_sections_callback=None,
|
||||
skip_query_analysis=have_enough_tokens,
|
||||
user_file_ids=user_file_ids,
|
||||
user_folder_ids=user_folder_ids,
|
||||
user_file_ids=user_file_ids or [],
|
||||
project_id=(
|
||||
project_id if persona.is_default_persona else None
|
||||
), # if the persona is not default, we don't want to use the project files
|
||||
)
|
||||
|
||||
return user_files, user_file_models, override_kwargs
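For reference, the call shape after this refactor looks roughly like the sketch below; the session, persona, chat_session, and user objects are assumed to come from the surrounding request context and are not defined here:

    in_memory_files, user_file_models, override_kwargs = parse_user_files(
        user_file_ids=[some_file_uuid],      # UUIDs now, not ints (hypothetical value)
        db_session=db_session,
        persona=persona,
        actual_user_input="summarize the attached report",
        project_id=chat_session.project_id,  # replaces the old user_folder_ids argument
        user_id=user.id if user else None,
    )
    # override_kwargs is None when the files fit the prompt token budget; otherwise it
    # carries SearchToolOverrideKwargs so the search tool handles the files instead of
    # inlining them.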
|
||||
|
||||
@@ -65,19 +65,19 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
|
||||
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
|
||||
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
|
||||
|
||||
PASSWORD_MIN_LENGTH = int(os.getenv("PASSWORD_MIN_LENGTH", 12))
|
||||
PASSWORD_MIN_LENGTH = int(os.getenv("PASSWORD_MIN_LENGTH", 8))
|
||||
PASSWORD_MAX_LENGTH = int(os.getenv("PASSWORD_MAX_LENGTH", 64))
|
||||
PASSWORD_REQUIRE_UPPERCASE = (
|
||||
os.environ.get("PASSWORD_REQUIRE_UPPERCASE", "true").lower() == "true"
|
||||
os.environ.get("PASSWORD_REQUIRE_UPPERCASE", "false").lower() == "true"
|
||||
)
|
||||
PASSWORD_REQUIRE_LOWERCASE = (
|
||||
os.environ.get("PASSWORD_REQUIRE_LOWERCASE", "true").lower() == "true"
|
||||
os.environ.get("PASSWORD_REQUIRE_LOWERCASE", "false").lower() == "true"
|
||||
)
|
||||
PASSWORD_REQUIRE_DIGIT = (
|
||||
os.environ.get("PASSWORD_REQUIRE_DIGIT", "true").lower() == "true"
|
||||
os.environ.get("PASSWORD_REQUIRE_DIGIT", "false").lower() == "true"
|
||||
)
|
||||
PASSWORD_REQUIRE_SPECIAL_CHAR = (
|
||||
os.environ.get("PASSWORD_REQUIRE_SPECIAL_CHAR", "true").lower() == "true"
|
||||
os.environ.get("PASSWORD_REQUIRE_SPECIAL_CHAR", "false").lower() == "true"
|
||||
)
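To spell out what the relaxed defaults above mean in practice, a small standalone sketch of the boolean parsing (illustrative only):

    import os

    # Unset -> requirement disabled under the new "false" default.
    os.environ.pop("PASSWORD_REQUIRE_DIGIT", None)
    assert (os.environ.get("PASSWORD_REQUIRE_DIGIT", "false").lower() == "true") is False

    # Deployments that want the stricter behavior opt back in explicitly.
    os.environ["PASSWORD_REQUIRE_DIGIT"] = "true"
    assert (os.environ.get("PASSWORD_REQUIRE_DIGIT", "false").lower() == "true") is True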
|
||||
|
||||
# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
|
||||
@@ -355,6 +355,26 @@ CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 4
|
||||
)
|
||||
|
||||
CELERY_WORKER_PRIMARY_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_PRIMARY_CONCURRENCY") or 4
|
||||
)
|
||||
|
||||
CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
|
||||
os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
|
||||
)
|
||||
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT = 4
|
||||
try:
|
||||
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = int(
|
||||
os.environ.get(
|
||||
"CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY",
|
||||
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT,
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = (
|
||||
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT
|
||||
)
|
||||
|
||||
# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
|
||||
VESPA_SYNC_MAX_TASKS = 8192
|
||||
|
||||
@@ -658,8 +678,8 @@ LOG_ALL_MODEL_INTERACTIONS = (
|
||||
os.environ.get("LOG_ALL_MODEL_INTERACTIONS", "").lower() == "true"
|
||||
)
|
||||
# Logs Onyx only model interactions like prompts, responses, messages etc.
|
||||
LOG_DANSWER_MODEL_INTERACTIONS = (
|
||||
os.environ.get("LOG_DANSWER_MODEL_INTERACTIONS", "").lower() == "true"
|
||||
LOG_ONYX_MODEL_INTERACTIONS = (
|
||||
os.environ.get("LOG_ONYX_MODEL_INTERACTIONS", "").lower() == "true"
|
||||
)
|
||||
LOG_INDIVIDUAL_MODEL_TOKENS = (
|
||||
os.environ.get("LOG_INDIVIDUAL_MODEL_TOKENS", "").lower() == "true"
|
||||
@@ -677,6 +697,17 @@ LOG_POSTGRES_CONN_COUNTS = (
|
||||
# Anonymous usage telemetry
|
||||
DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true"
|
||||
|
||||
#####
|
||||
# Braintrust Configuration
|
||||
#####
|
||||
# Enable Braintrust tracing for LangGraph/LangChain applications
|
||||
BRAINTRUST_ENABLED = os.environ.get("BRAINTRUST_ENABLED", "").lower() == "true"
|
||||
# Braintrust project name
|
||||
BRAINTRUST_PROJECT = os.environ.get("BRAINTRUST_PROJECT", "Onyx")
|
||||
BRAINTRUST_API_KEY = os.environ.get("BRAINTRUST_API_KEY") or ""
|
||||
# Maximum concurrency for Braintrust evaluations
|
||||
BRAINTRUST_MAX_CONCURRENCY = int(os.environ.get("BRAINTRUST_MAX_CONCURRENCY") or 5)
|
||||
|
||||
TOKEN_BUDGET_GLOBALLY_ENABLED = (
|
||||
os.environ.get("TOKEN_BUDGET_GLOBALLY_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
@@ -3,7 +3,6 @@ import os
|
||||
INPUT_PROMPT_YAML = "./onyx/seeding/input_prompts.yaml"
|
||||
PROMPTS_YAML = "./onyx/seeding/prompts.yaml"
|
||||
PERSONAS_YAML = "./onyx/seeding/personas.yaml"
|
||||
USER_FOLDERS_YAML = "./onyx/seeding/user_folders.yaml"
|
||||
NUM_RETURNED_HITS = 50
|
||||
# Used for LLM filtering and reranking
|
||||
# We want this to be approximately the number of results we want to show on the first page
|
||||
|
||||
@@ -7,6 +7,8 @@ from enum import Enum
|
||||
|
||||
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
|
||||
ONYX_SLACK_URL = "https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA"
|
||||
SLACK_USER_TOKEN_PREFIX = "xoxp-"
|
||||
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
|
||||
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
|
||||
|
||||
SOURCE_TYPE = "source_type"
|
||||
@@ -76,6 +78,9 @@ POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
|
||||
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
|
||||
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (
|
||||
"celery_worker_user_file_processing"
|
||||
)
|
||||
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
||||
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
||||
|
||||
@@ -112,7 +117,6 @@ CELERY_GENERIC_BEAT_LOCK_TIMEOUT = 120
|
||||
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
|
||||
|
||||
CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT = 120
|
||||
|
||||
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
|
||||
|
||||
@@ -203,6 +207,8 @@ class DocumentSource(str, Enum):
|
||||
|
||||
# Special case just for integration tests
|
||||
MOCK_CONNECTOR = "mock_connector"
|
||||
# Special case for user files
|
||||
USER_FILE = "user_file"
|
||||
|
||||
|
||||
class FederatedConnectorSource(str, Enum):
|
||||
@@ -298,6 +304,7 @@ class FileOrigin(str, Enum):
|
||||
PLAINTEXT_CACHE = "plaintext_cache"
|
||||
OTHER = "other"
|
||||
QUERY_HISTORY_CSV = "query_history_csv"
|
||||
USER_FILE = "user_file"
|
||||
|
||||
|
||||
class FileType(str, Enum):
|
||||
@@ -343,6 +350,9 @@ class OnyxCeleryQueues:
|
||||
# Indexing queue
|
||||
USER_FILES_INDEXING = "user_files_indexing"
|
||||
|
||||
# User file processing queue
|
||||
USER_FILE_PROCESSING = "user_file_processing"
|
||||
USER_FILE_PROJECT_SYNC = "user_file_project_sync"
|
||||
# Document processing pipeline queue
|
||||
DOCPROCESSING = "docprocessing"
|
||||
CONNECTOR_DOC_FETCHING = "connector_doc_fetching"
|
||||
@@ -368,7 +378,7 @@ class OnyxRedisLocks:
|
||||
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
|
||||
"da_lock:check_connector_external_group_sync_beat"
|
||||
)
|
||||
CHECK_USER_FILE_FOLDER_SYNC_BEAT_LOCK = "da_lock:check_user_file_folder_sync_beat"
|
||||
|
||||
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
|
||||
CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
|
||||
CLOUD_PRE_PROVISION_TENANT_LOCK = "da_lock:pre_provision_tenant"
|
||||
@@ -390,6 +400,12 @@ class OnyxRedisLocks:
|
||||
# KG processing
|
||||
KG_PROCESSING_LOCK = "da_lock:kg_processing"
|
||||
|
||||
# User file processing
|
||||
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
|
||||
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
|
||||
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
BLOCK_VALIDATE_INDEXING_FENCES = "signal:block_validate_indexing_fences"
|
||||
@@ -448,8 +464,6 @@ class OnyxCeleryTask:
|
||||
f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_pidbox"
|
||||
)
|
||||
|
||||
UPDATE_USER_FILE_FOLDER_METADATA = "update_user_file_folder_metadata"
|
||||
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
CHECK_FOR_INDEXING = "check_for_indexing"
|
||||
@@ -457,7 +471,12 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
|
||||
CHECK_FOR_USER_FILE_FOLDER_SYNC = "check_for_user_file_folder_sync"
|
||||
|
||||
# User file processing
|
||||
CHECK_FOR_USER_FILE_PROCESSING = "check_for_user_file_processing"
|
||||
PROCESS_SINGLE_USER_FILE = "process_single_user_file"
|
||||
CHECK_FOR_USER_FILE_PROJECT_SYNC = "check_for_user_file_project_sync"
|
||||
PROCESS_SINGLE_USER_FILE_PROJECT_SYNC = "process_single_user_file_project_sync"
|
||||
|
||||
# Connector checkpoint cleanup
|
||||
CHECK_FOR_CHECKPOINT_CLEANUP = "check_for_checkpoint_cleanup"
|
||||
@@ -490,6 +509,7 @@ class OnyxCeleryTask:
|
||||
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
|
||||
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
|
||||
USER_FILE_DOCID_MIGRATION = "user_file_docid_migration"
|
||||
|
||||
# chat retention
|
||||
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
|
||||
@@ -497,6 +517,8 @@ class OnyxCeleryTask:
|
||||
|
||||
GENERATE_USAGE_REPORT_TASK = "generate_usage_report_task"
|
||||
|
||||
EVAL_RUN_TASK = "eval_run_task"
|
||||
|
||||
EXPORT_QUERY_HISTORY_TASK = "export_query_history_task"
|
||||
EXPORT_QUERY_HISTORY_CLEANUP_TASK = "export_query_history_cleanup_task"
|
||||
|
||||
|
||||
@@ -3,28 +3,26 @@ import os
|
||||
#####
|
||||
# Onyx Slack Bot Configs
|
||||
#####
|
||||
DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
|
||||
ONYX_BOT_NUM_RETRIES = int(os.environ.get("ONYX_BOT_NUM_RETRIES", "5"))
|
||||
# How much of the available input context can be used for thread context
|
||||
MAX_THREAD_CONTEXT_PERCENTAGE = 512 * 2 / 3072
|
||||
# Number of docs to display in "Reference Documents"
|
||||
DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
|
||||
os.environ.get("DANSWER_BOT_NUM_DOCS_TO_DISPLAY", "5")
|
||||
)
|
||||
ONYX_BOT_NUM_DOCS_TO_DISPLAY = int(os.environ.get("ONYX_BOT_NUM_DOCS_TO_DISPLAY", "5"))
|
||||
# If the LLM fails to answer, Onyx can still show the "Reference Documents"
|
||||
DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER = os.environ.get(
|
||||
"DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER", ""
|
||||
ONYX_BOT_DISABLE_DOCS_ONLY_ANSWER = os.environ.get(
|
||||
"ONYX_BOT_DISABLE_DOCS_ONLY_ANSWER", ""
|
||||
).lower() not in ["false", ""]
|
||||
# When Onyx is considering a message, what emoji does it react with
|
||||
DANSWER_REACT_EMOJI = os.environ.get("DANSWER_REACT_EMOJI") or "eyes"
|
||||
ONYX_BOT_REACT_EMOJI = os.environ.get("ONYX_BOT_REACT_EMOJI") or "eyes"
|
||||
# When User needs more help, what should the emoji be
|
||||
DANSWER_FOLLOWUP_EMOJI = os.environ.get("DANSWER_FOLLOWUP_EMOJI") or "sos"
|
||||
ONYX_BOT_FOLLOWUP_EMOJI = os.environ.get("ONYX_BOT_FOLLOWUP_EMOJI") or "sos"
|
||||
# What kind of message should be shown when someone gives an AI answer feedback to OnyxBot
|
||||
# Defaults to Private if not provided or invalid
|
||||
# Private: Only visible to user clicking the feedback
|
||||
# Anonymous: Public but anonymous
|
||||
# Public: Visible with the user name who submitted the feedback
|
||||
DANSWER_BOT_FEEDBACK_VISIBILITY = (
|
||||
os.environ.get("DANSWER_BOT_FEEDBACK_VISIBILITY") or "private"
|
||||
ONYX_BOT_FEEDBACK_VISIBILITY = (
|
||||
os.environ.get("ONYX_BOT_FEEDBACK_VISIBILITY") or "private"
|
||||
)
|
||||
# Should OnyxBot send an apology message if it's not able to find an answer
|
||||
# That way the user isn't confused as to why OnyxBot reacted but then said nothing
|
||||
@@ -34,40 +32,38 @@ NOTIFY_SLACKBOT_NO_ANSWER = (
|
||||
)
|
||||
# Mostly for debugging purposes but it's for explaining what went wrong
|
||||
# if OnyxBot couldn't find an answer
|
||||
DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
|
||||
"DANSWER_BOT_DISPLAY_ERROR_MSGS", ""
|
||||
ONYX_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
|
||||
"ONYX_BOT_DISPLAY_ERROR_MSGS", ""
|
||||
).lower() not in [
|
||||
"false",
|
||||
"",
|
||||
]
|
||||
# Default is only respond in channels that are included by a slack config set in the UI
|
||||
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
|
||||
os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
|
||||
ONYX_BOT_RESPOND_EVERY_CHANNEL = (
|
||||
os.environ.get("ONYX_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Maximum Questions Per Minute, Default Uncapped
|
||||
DANSWER_BOT_MAX_QPM = int(os.environ.get("DANSWER_BOT_MAX_QPM") or 0) or None
|
||||
ONYX_BOT_MAX_QPM = int(os.environ.get("ONYX_BOT_MAX_QPM") or 0) or None
|
||||
# Maximum time to wait when a question is queued
|
||||
DANSWER_BOT_MAX_WAIT_TIME = int(os.environ.get("DANSWER_BOT_MAX_WAIT_TIME") or 180)
|
||||
ONYX_BOT_MAX_WAIT_TIME = int(os.environ.get("ONYX_BOT_MAX_WAIT_TIME") or 180)
|
||||
|
||||
# Time (in minutes) after which a Slack message is sent to the user to remind him to give feedback.
|
||||
# Set to 0 to disable it (default)
|
||||
DANSWER_BOT_FEEDBACK_REMINDER = int(
|
||||
os.environ.get("DANSWER_BOT_FEEDBACK_REMINDER") or 0
|
||||
)
|
||||
ONYX_BOT_FEEDBACK_REMINDER = int(os.environ.get("ONYX_BOT_FEEDBACK_REMINDER") or 0)
|
||||
# Set to True to rephrase the Slack users messages
|
||||
DANSWER_BOT_REPHRASE_MESSAGE = (
|
||||
os.environ.get("DANSWER_BOT_REPHRASE_MESSAGE", "").lower() == "true"
|
||||
ONYX_BOT_REPHRASE_MESSAGE = (
|
||||
os.environ.get("ONYX_BOT_REPHRASE_MESSAGE", "").lower() == "true"
|
||||
)
|
||||
|
||||
# DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD is the number of
|
||||
# ONYX_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD is the number of
|
||||
# responses OnyxBot can send in a given time period.
|
||||
# Set to 0 to disable the limit.
|
||||
DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD = int(
|
||||
os.environ.get("DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD", "5000")
|
||||
ONYX_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD = int(
|
||||
os.environ.get("ONYX_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD", "5000")
|
||||
)
|
||||
# DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS is the number
|
||||
# ONYX_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS is the number
|
||||
# of seconds until the response limit is reset.
|
||||
DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS = int(
|
||||
os.environ.get("DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS", "86400")
|
||||
ONYX_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS = int(
|
||||
os.environ.get("ONYX_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS", "86400")
|
||||
)
|
||||
|
||||
@@ -32,6 +32,7 @@ def _create_image_section(
|
||||
image_data: bytes,
|
||||
parent_file_name: str,
|
||||
display_name: str,
|
||||
media_type: str | None = None,
|
||||
link: str | None = None,
|
||||
idx: int = 0,
|
||||
) -> tuple[ImageSection, str | None]:
|
||||
@@ -58,6 +59,9 @@ def _create_image_section(
|
||||
image_data=image_data,
|
||||
file_id=file_id,
|
||||
display_name=display_name,
|
||||
media_type=(
|
||||
media_type if media_type is not None else "application/octet-stream"
|
||||
),
|
||||
link=link,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
@@ -123,6 +127,7 @@ def _process_file(
|
||||
image_data=image_data,
|
||||
parent_file_name=file_id,
|
||||
display_name=title,
|
||||
media_type=file_type,
|
||||
)
|
||||
|
||||
return [
|
||||
@@ -194,6 +199,7 @@ def _process_file(
|
||||
image_data=img_data,
|
||||
parent_file_name=file_id,
|
||||
display_name=f"{title} - image {idx}",
|
||||
media_type="application/octet-stream", # Default media type for embedded images
|
||||
idx=idx,
|
||||
)
|
||||
sections.append(image_section)
|
||||
|
||||
@@ -168,8 +168,8 @@ class FirefliesConnector(PollConnector, LoadConnector):
|
||||
if response.status_code == 204:
|
||||
break
|
||||
|
||||
recieved_transcripts = response.json()
|
||||
parsed_transcripts = recieved_transcripts.get("data", {}).get(
|
||||
received_transcripts = response.json()
|
||||
parsed_transcripts = received_transcripts.get("data", {}).get(
|
||||
"transcripts", []
|
||||
)
|
||||
|
||||
|
||||
@@ -61,6 +61,19 @@ EMAIL_FIELDS = [
|
||||
add_retries = retry_builder(tries=50, max_delay=30)
|
||||
|
||||
|
||||
def _is_mail_service_disabled_error(error: HttpError) -> bool:
|
||||
"""Detect if the Gmail API is telling us the mailbox is not provisioned."""
|
||||
|
||||
if error.resp.status != 400:
|
||||
return False
|
||||
|
||||
error_message = str(error)
|
||||
return (
|
||||
"Mail service not enabled" in error_message
|
||||
or "failedPrecondition" in error_message
|
||||
)
|
||||
|
||||
|
||||
def _build_time_range_query(
|
||||
time_range_start: SecondsSinceUnixEpoch | None = None,
|
||||
time_range_end: SecondsSinceUnixEpoch | None = None,
|
||||
@@ -307,33 +320,42 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
doc_batch = []
|
||||
for user_email in self._get_all_user_emails():
|
||||
gmail_service = get_gmail_service(self.creds, user_email)
|
||||
for thread in execute_paginated_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().list,
|
||||
list_key="threads",
|
||||
userId=user_email,
|
||||
fields=THREAD_LIST_FIELDS,
|
||||
q=query,
|
||||
continue_on_404_or_403=True,
|
||||
):
|
||||
full_threads = execute_single_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().get,
|
||||
list_key=None,
|
||||
try:
|
||||
for thread in execute_paginated_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().list,
|
||||
list_key="threads",
|
||||
userId=user_email,
|
||||
fields=THREAD_FIELDS,
|
||||
id=thread["id"],
|
||||
fields=THREAD_LIST_FIELDS,
|
||||
q=query,
|
||||
continue_on_404_or_403=True,
|
||||
)
|
||||
# full_threads is an iterator containing a single thread
|
||||
# so we need to convert it to a list and grab the first element
|
||||
full_thread = list(full_threads)[0]
|
||||
doc = thread_to_document(full_thread, user_email)
|
||||
if doc is None:
|
||||
continue
|
||||
):
|
||||
full_threads = execute_single_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().get,
|
||||
list_key=None,
|
||||
userId=user_email,
|
||||
fields=THREAD_FIELDS,
|
||||
id=thread["id"],
|
||||
continue_on_404_or_403=True,
|
||||
)
|
||||
# full_threads is an iterator containing a single thread
|
||||
# so we need to convert it to a list and grab the first element
|
||||
full_thread = list(full_threads)[0]
|
||||
doc = thread_to_document(full_thread, user_email)
|
||||
if doc is None:
|
||||
continue
|
||||
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) > self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) > self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
except HttpError as e:
|
||||
if _is_mail_service_disabled_error(e):
|
||||
logger.warning(
|
||||
"Skipping Gmail sync for %s because the mailbox is disabled.",
|
||||
user_email,
|
||||
)
|
||||
continue
|
||||
raise
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
@@ -349,35 +371,44 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
for user_email in self._get_all_user_emails():
|
||||
logger.info(f"Fetching slim threads for user: {user_email}")
|
||||
gmail_service = get_gmail_service(self.creds, user_email)
|
||||
for thread in execute_paginated_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().list,
|
||||
list_key="threads",
|
||||
userId=user_email,
|
||||
fields=THREAD_LIST_FIELDS,
|
||||
q=query,
|
||||
continue_on_404_or_403=True,
|
||||
):
|
||||
doc_batch.append(
|
||||
SlimDocument(
|
||||
id=thread["id"],
|
||||
external_access=ExternalAccess(
|
||||
external_user_emails={user_email},
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
),
|
||||
try:
|
||||
for thread in execute_paginated_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().list,
|
||||
list_key="threads",
|
||||
userId=user_email,
|
||||
fields=THREAD_LIST_FIELDS,
|
||||
q=query,
|
||||
continue_on_404_or_403=True,
|
||||
):
|
||||
doc_batch.append(
|
||||
SlimDocument(
|
||||
id=thread["id"],
|
||||
external_access=ExternalAccess(
|
||||
external_user_emails={user_email},
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
if len(doc_batch) > SLIM_BATCH_SIZE:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
if len(doc_batch) > SLIM_BATCH_SIZE:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"retrieve_all_slim_documents: Stop signal detected"
|
||||
)
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"retrieve_all_slim_documents: Stop signal detected"
|
||||
)
|
||||
|
||||
callback.progress("retrieve_all_slim_documents", 1)
|
||||
callback.progress("retrieve_all_slim_documents", 1)
|
||||
except HttpError as e:
|
||||
if _is_mail_service_disabled_error(e):
|
||||
logger.warning(
|
||||
"Skipping slim Gmail sync for %s because the mailbox is disabled.",
|
||||
user_email,
|
||||
)
|
||||
continue
|
||||
raise
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
@@ -57,7 +57,9 @@ from onyx.connectors.sharepoint.connector_utils import get_sharepoint_external_a
|
||||
from onyx.file_processing.extract_file_text import ACCEPTED_IMAGE_FILE_EXTENSIONS
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_validation import EXCLUDED_IMAGE_TYPES
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -141,6 +143,10 @@ class SharepointAuthMethod(Enum):
|
||||
CERTIFICATE = "certificate"
|
||||
|
||||
|
||||
class SizeCapExceeded(Exception):
|
||||
"""Exception raised when the size cap is exceeded."""
|
||||
|
||||
|
||||
def load_certificate_from_pfx(pfx_data: bytes, password: str) -> CertificateData | None:
|
||||
"""Load certificate from .pfx file for MSAL authentication"""
|
||||
try:
|
||||
@@ -239,7 +245,7 @@ def _download_with_cap(url: str, timeout: int, cap: int) -> bytes:
|
||||
Behavior:
|
||||
- Checks `Content-Length` first and aborts early if it exceeds `cap`.
|
||||
- Otherwise streams the body in chunks and stops once `cap` is surpassed.
|
||||
- Raises `RuntimeError('size_cap_exceeded')` when the cap would be exceeded.
|
||||
- Raises `SizeCapExceeded` when the cap would be exceeded.
|
||||
- Returns the full bytes if the content fits within `cap`.
|
||||
"""
|
||||
with requests.get(url, stream=True, timeout=timeout) as resp:
|
||||
@@ -253,7 +259,7 @@ def _download_with_cap(url: str, timeout: int, cap: int) -> bytes:
|
||||
logger.warning(
|
||||
f"Content-Length {content_len} exceeds cap {cap}; skipping download."
|
||||
)
|
||||
raise RuntimeError("size_cap_exceeded")
|
||||
raise SizeCapExceeded("pre_download")
|
||||
|
||||
buf = io.BytesIO()
|
||||
# Stream in 64KB chunks; adjust if needed for slower networks.
|
||||
@@ -266,11 +272,32 @@ def _download_with_cap(url: str, timeout: int, cap: int) -> bytes:
|
||||
logger.warning(
|
||||
f"Streaming download exceeded cap {cap} bytes; aborting early."
|
||||
)
|
||||
raise RuntimeError("size_cap_exceeded")
|
||||
raise SizeCapExceeded("during_download")
|
||||
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def _download_via_sdk_with_cap(
|
||||
driveitem: DriveItem, bytes_allowed: int, chunk_size: int = 64 * 1024
|
||||
) -> bytes:
|
||||
"""Use the Office365 SDK streaming download with a hard byte cap.
|
||||
|
||||
Raises SizeCapExceeded("during_sdk_download") if the cap would be exceeded.
|
||||
"""
|
||||
buf = io.BytesIO()
|
||||
|
||||
def on_chunk(bytes_read: int) -> None:
|
||||
# bytes_read is total bytes seen so far per SDK contract
|
||||
if bytes_read > bytes_allowed:
|
||||
raise SizeCapExceeded("during_sdk_download")
|
||||
|
||||
# modifies the driveitem to change its download behavior
|
||||
driveitem.download_session(buf, chunk_downloaded=on_chunk, chunk_size=chunk_size)
|
||||
# Execute the configured request with retries using existing helper
|
||||
sleep_and_retry(driveitem.context, "download_session")
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def _convert_driveitem_to_document_with_permissions(
|
||||
driveitem: DriveItem,
|
||||
drive_name: str,
|
||||
@@ -289,6 +316,16 @@ def _convert_driveitem_to_document_with_permissions(
|
||||
file_size: int | None = None
|
||||
try:
|
||||
item_json = driveitem.to_json()
|
||||
mime_type = item_json.get("file", {}).get("mimeType")
|
||||
if not mime_type or mime_type in EXCLUDED_IMAGE_TYPES:
|
||||
# NOTE: this function should be refactored to look like Drive doc_conversion.py pattern
|
||||
# for now, this skip must happen before we download the file
|
||||
# Similar to Google Drive, we'll just semi-silently skip excluded image types
|
||||
logger.debug(
|
||||
f"Skipping malformed or excluded mime type {mime_type} for {driveitem.name}"
|
||||
)
|
||||
return None
|
||||
|
||||
size_value = item_json.get("size")
|
||||
if size_value is not None:
|
||||
file_size = int(size_value)
|
||||
@@ -311,19 +348,16 @@ def _convert_driveitem_to_document_with_permissions(
|
||||
content_bytes: bytes | None = None
|
||||
if download_url:
|
||||
try:
|
||||
# Use this to test the sdk size cap
|
||||
# raise requests.RequestException("test")
|
||||
content_bytes = _download_with_cap(
|
||||
download_url,
|
||||
REQUEST_TIMEOUT_SECONDS,
|
||||
SHAREPOINT_CONNECTOR_SIZE_THRESHOLD,
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "size_cap_exceeded" in str(e):
|
||||
logger.warning(
|
||||
f"Skipping '{driveitem.name}' exceeded size cap during streaming."
|
||||
)
|
||||
return None
|
||||
else:
|
||||
raise
|
||||
except SizeCapExceeded as e:
|
||||
logger.warning(f"Skipping '{driveitem.name}': exceeded size cap ({str(e)})")
|
||||
return None
|
||||
except requests.RequestException as e:
|
||||
status = e.response.status_code if e.response is not None else -1
|
||||
logger.warning(
|
||||
@@ -332,13 +366,15 @@ def _convert_driveitem_to_document_with_permissions(
|
||||
|
||||
# Fallback to SDK content if needed
|
||||
if content_bytes is None:
|
||||
content = sleep_and_retry(driveitem.get_content(), "get_content")
|
||||
if content is None or not isinstance(
|
||||
getattr(content, "value", None), (bytes, bytearray)
|
||||
):
|
||||
logger.warning(f"Could not access content for '{driveitem.name}'")
|
||||
raise ValueError(f"Could not access content for '{driveitem.name}'")
|
||||
content_bytes = bytes(content.value)
|
||||
try:
|
||||
content_bytes = _download_via_sdk_with_cap(
|
||||
driveitem, SHAREPOINT_CONNECTOR_SIZE_THRESHOLD
|
||||
)
|
||||
except SizeCapExceeded:
|
||||
logger.warning(
|
||||
f"Skipping '{driveitem.name}' exceeded size cap during SDK streaming."
|
||||
)
|
||||
return None
|
||||
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
file_ext = driveitem.name.split(".")[-1]
|
||||
@@ -348,6 +384,7 @@ def _convert_driveitem_to_document_with_permissions(
|
||||
f"Zero-length content for '{driveitem.name}'. Skipping text/image extraction."
|
||||
)
|
||||
elif "." + file_ext in ACCEPTED_IMAGE_FILE_EXTENSIONS:
|
||||
# NOTE: this if should use is_valid_image_type instead with mime_type
|
||||
image_section, _ = store_image_and_create_section(
|
||||
image_data=content_bytes,
|
||||
file_id=driveitem.id,
|
||||
@@ -358,23 +395,45 @@ def _convert_driveitem_to_document_with_permissions(
|
||||
sections.append(image_section)
|
||||
else:
|
||||
# Note: we don't process Onyx metadata for connectors like Drive & Sharepoint, but could
|
||||
def _store_embedded_image(img_data: bytes, img_name: str) -> None:
|
||||
try:
|
||||
mime_type = get_image_type_from_bytes(img_data)
|
||||
except ValueError:
|
||||
logger.debug(
|
||||
"Skipping embedded image with unknown format for %s",
|
||||
driveitem.name,
|
||||
)
|
||||
return
|
||||
|
||||
# The only mime type that would be returned by get_image_type_from_bytes that is in
|
||||
# EXCLUDED_IMAGE_TYPES is image/gif.
|
||||
if mime_type in EXCLUDED_IMAGE_TYPES:
|
||||
logger.debug(
|
||||
"Skipping embedded image of excluded type %s for %s",
|
||||
mime_type,
|
||||
driveitem.name,
|
||||
)
|
||||
return
|
||||
|
||||
image_section, _ = store_image_and_create_section(
|
||||
image_data=img_data,
|
||||
file_id=f"{driveitem.id}_img_{len(sections)}",
|
||||
display_name=img_name or f"{driveitem.name} - image {len(sections)}",
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
image_section.link = driveitem.web_url
|
||||
sections.append(image_section)
|
||||
|
||||
extraction_result = extract_text_and_images(
|
||||
file=io.BytesIO(content_bytes), file_name=driveitem.name
|
||||
file=io.BytesIO(content_bytes),
|
||||
file_name=driveitem.name,
|
||||
image_callback=_store_embedded_image,
|
||||
)
|
||||
if extraction_result.text_content:
|
||||
sections.append(
|
||||
TextSection(link=driveitem.web_url, text=extraction_result.text_content)
|
||||
)
|
||||
|
||||
for idx, (img_data, img_name) in enumerate(extraction_result.embedded_images):
|
||||
image_section, _ = store_image_and_create_section(
|
||||
image_data=img_data,
|
||||
file_id=f"{driveitem.id}_img_{idx}",
|
||||
display_name=img_name or f"{driveitem.name} - image {idx}",
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
image_section.link = driveitem.web_url
|
||||
sections.append(image_section)
|
||||
# Any embedded images were stored via the callback; the returned list may be empty.
|
||||
|
||||
if include_permissions and ctx is not None:
|
||||
logger.info(f"Getting external access for {driveitem.name}")
|
||||
@@ -717,6 +776,7 @@ class SharepointConnector(
|
||||
for folder_part in site_descriptor.folder_path.split("/"):
|
||||
root_folder = root_folder.get_by_path(folder_part)
|
||||
|
||||
# TODO: consider ways to avoid materializing the entire list of files in memory
|
||||
query = root_folder.get_files(
|
||||
recursive=True,
|
||||
page_size=1000,
|
||||
@@ -825,6 +885,7 @@ class SharepointConnector(
|
||||
root_folder = root_folder.get_by_path(folder_part)
|
||||
|
||||
# Get all items recursively
|
||||
# TODO: consider ways to avoid materializing the entire list of files in memory
|
||||
query = root_folder.get_files(
|
||||
recursive=True,
|
||||
page_size=1000,
|
||||
@@ -973,6 +1034,8 @@ class SharepointConnector(
|
||||
all_pages = pages_data.get("value", [])
|
||||
|
||||
# Handle pagination if there are more pages
|
||||
# TODO: This accumulates all pages in memory and can be heavy on large tenants.
|
||||
# We should process each page incrementally to avoid unbounded growth.
|
||||
while "@odata.nextLink" in pages_data:
|
||||
next_url = pages_data["@odata.nextLink"]
|
||||
response = requests.get(
|
||||
@@ -986,7 +1049,7 @@ class SharepointConnector(
|
||||
|
||||
# Filter pages based on time window if specified
|
||||
if start is not None or end is not None:
|
||||
filtered_pages = []
|
||||
filtered_pages: list[dict[str, Any]] = []
|
||||
for page in all_pages:
|
||||
page_modified = page.get("lastModifiedDateTime")
|
||||
if page_modified:
|
||||
|
||||
@@ -761,7 +761,7 @@ class SlackConnector(
|
||||
Step 2: Loop through each channel. For each channel:
|
||||
Step 2.1: Get messages within the time range.
|
||||
Step 2.2: Process messages in parallel, yield back docs.
|
||||
Step 2.3: Update checkpoint with new_latest, seen_thread_ts, and current_channel.
|
||||
Step 2.3: Update checkpoint with new_oldest, seen_thread_ts, and current_channel.
|
||||
Slack returns messages from newest to oldest, so we need to keep track of
|
||||
the latest message we've seen in each channel.
|
||||
Step 2.4: If there are no more messages in the channel, switch the current
|
||||
@@ -837,7 +837,8 @@ class SlackConnector(
|
||||
|
||||
channel_message_ts = checkpoint.channel_completion_map.get(channel_id)
|
||||
if channel_message_ts:
|
||||
latest = channel_message_ts
|
||||
# Set oldest to the checkpoint timestamp to resume from where we left off
|
||||
oldest = channel_message_ts
|
||||
|
||||
logger.debug(
|
||||
f"Getting messages for channel {channel} within range {oldest} - {latest}"
|
||||
@@ -855,7 +856,8 @@ class SlackConnector(
|
||||
f"{latest=}"
|
||||
)
|
||||
|
||||
new_latest = message_batch[-1]["ts"] if message_batch else latest
|
||||
# message_batch[0] is the newest message (Slack returns newest to oldest)
|
||||
new_oldest = message_batch[0]["ts"] if message_batch else latest
|
||||
|
||||
num_threads_start = len(seen_thread_ts)
|
||||
|
||||
@@ -906,15 +908,14 @@ class SlackConnector(
|
||||
num_threads_processed = len(seen_thread_ts) - num_threads_start
|
||||
|
||||
# calculate a percentage progress for the current channel by determining
|
||||
# our viable range start and end, and the latest timestamp we are querying
|
||||
# up to
|
||||
new_latest_seconds_epoch = SecondsSinceUnixEpoch(new_latest)
|
||||
if new_latest_seconds_epoch > end:
|
||||
# how much of the time range we've processed so far
|
||||
new_oldest_seconds_epoch = SecondsSinceUnixEpoch(new_oldest)
|
||||
range_start = start if start else max(0, channel_created)
|
||||
if new_oldest_seconds_epoch < range_start:
|
||||
range_complete = 0.0
|
||||
else:
|
||||
range_complete = end - new_latest_seconds_epoch
|
||||
range_complete = new_oldest_seconds_epoch - range_start
|
||||
|
||||
range_start = max(0, channel_created)
|
||||
range_total = end - range_start
|
||||
if range_total <= 0:
|
||||
range_total = 1
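
The progress ratio assembled here (range_complete over range_total) reduces to a small pure function. A direction-agnostic sketch, under the assumption that the final fraction is clamped to [0, 1]:

def window_fraction(progress_ts: float, range_start: float, range_end: float) -> float:
    """Fraction of [range_start, range_end] covered once processing has reached
    progress_ts, clamped to [0, 1]. Mirrors range_complete / range_total above."""
    total = max(range_end - range_start, 1.0)
    covered = min(max(progress_ts - range_start, 0.0), total)
    return covered / total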
|
||||
@@ -935,7 +936,7 @@ class SlackConnector(
|
||||
)
|
||||
|
||||
checkpoint.seen_thread_ts = list(seen_thread_ts)
|
||||
checkpoint.channel_completion_map[channel["id"]] = new_latest
|
||||
checkpoint.channel_completion_map[channel["id"]] = new_oldest
|
||||
|
||||
# bypass channels where the first set of messages seen are all bots
|
||||
# check at least MIN_BOT_MESSAGE_THRESHOLD messages are in the batch
|
||||
|
||||
@@ -75,7 +75,7 @@ class OnyxRedisSlackRetryHandler(RetryHandler):
|
||||
"""
|
||||
ttl_ms: int | None = None
|
||||
|
||||
retry_after_value: list[str] | None = None
|
||||
retry_after_value: str | None = None
|
||||
retry_after_header_name: Optional[str] = None
|
||||
duration_s: float = 1.0 # seconds
|
||||
|
||||
@@ -103,14 +103,21 @@ class OnyxRedisSlackRetryHandler(RetryHandler):
|
||||
"OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header name is None"
|
||||
)
|
||||
|
||||
retry_after_value = response.headers.get(retry_after_header_name)
|
||||
if not retry_after_value:
|
||||
retry_after_header_value = response.headers.get(retry_after_header_name)
|
||||
if not retry_after_header_value:
|
||||
raise ValueError(
|
||||
"OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header value is None"
|
||||
)
|
||||
|
||||
# Handle case where header value might be a list
|
||||
retry_after_value = (
|
||||
retry_after_header_value[0]
|
||||
if isinstance(retry_after_header_value, list)
|
||||
else retry_after_header_value
|
||||
)
|
||||
|
||||
retry_after_value_int = int(
|
||||
retry_after_value[0]
|
||||
retry_after_value
|
||||
) # will raise ValueError if somehow we can't convert to int
|
||||
jitter = retry_after_value_int * 0.25 * random.random()
|
||||
duration_s = retry_after_value_int + jitter
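
The header normalization and jitter math above fit in a standalone helper. This is only a sketch of the parsing shown in the hunk (a list-or-string value coerced to int, plus up to 25% random jitter); the Redis TTL handling of the real handler is omitted:

import random


def parse_retry_after(header_value: str | list[str]) -> float:
    """Coerce a Retry-After value (string or list) into a jittered sleep duration."""
    raw = header_value[0] if isinstance(header_value, list) else header_value
    seconds = int(raw)  # mirrors the diff: raises ValueError if non-numeric
    jitter = seconds * 0.25 * random.random()
    return seconds + jitter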
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -14,6 +15,9 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
time_str_to_utc,
|
||||
)
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
@@ -47,14 +51,30 @@ class ZendeskCredentialsNotSetUpError(PermissionError):
|
||||
|
||||
|
||||
class ZendeskClient:
|
||||
def __init__(self, subdomain: str, email: str, token: str):
|
||||
def __init__(
|
||||
self,
|
||||
subdomain: str,
|
||||
email: str,
|
||||
token: str,
|
||||
calls_per_minute: int | None = None,
|
||||
):
|
||||
self.base_url = f"https://{subdomain}.zendesk.com/api/v2"
|
||||
self.auth = (f"{email}/token", token)
|
||||
self.make_request = request_with_rate_limit(self, calls_per_minute)
|
||||
|
||||
|
||||
def request_with_rate_limit(
|
||||
client: ZendeskClient, max_calls_per_minute: int | None = None
|
||||
) -> Callable[[str, dict[str, Any]], dict[str, Any]]:
|
||||
@retry_builder()
|
||||
def make_request(self, endpoint: str, params: dict[str, Any]) -> dict[str, Any]:
|
||||
@(
|
||||
rate_limit_builder(max_calls=max_calls_per_minute, period=60)
|
||||
if max_calls_per_minute
|
||||
else lambda x: x
|
||||
)
|
||||
def make_request(endpoint: str, params: dict[str, Any]) -> dict[str, Any]:
|
||||
response = requests.get(
|
||||
f"{self.base_url}/{endpoint}", auth=self.auth, params=params
|
||||
f"{client.base_url}/{endpoint}", auth=client.auth, params=params
|
||||
)
|
||||
|
||||
if response.status_code == 429:
|
||||
@@ -72,6 +92,8 @@ class ZendeskClient:
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
return make_request
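
The conditional decorator application above (only wrapping make_request in a rate limiter when calls_per_minute is set) can be reproduced as a standalone pattern. simple_rate_limit below is an assumed stand-in; the real rate_limit_builder may behave differently:

import time
from collections.abc import Callable
from typing import Any


def simple_rate_limit(
    max_calls: int, period: float = 60.0
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        call_times: list[float] = []

        def wrapped(*args: Any, **kwargs: Any) -> Any:
            now = time.monotonic()
            # Drop timestamps that have aged out of the rolling window.
            while call_times and now - call_times[0] > period:
                call_times.pop(0)
            if len(call_times) >= max_calls:
                time.sleep(period - (now - call_times[0]))
            call_times.append(time.monotonic())
            return func(*args, **kwargs)

        return wrapped

    return decorator


def build_request_fn(calls_per_minute: int | None) -> Callable[[str], str]:
    # Same pattern as the diff: only wrap with the limiter when a limit is configured.
    @(simple_rate_limit(calls_per_minute) if calls_per_minute else (lambda f: f))
    def make_request(endpoint: str) -> str:
        return f"GET {endpoint}"

    return make_request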
|
||||
|
||||
|
||||
class ZendeskPageResponse(BaseModel):
|
||||
data: list[dict[str, Any]]
|
||||
@@ -359,11 +381,13 @@ class ZendeskConnector(
|
||||
def __init__(
|
||||
self,
|
||||
content_type: str = "articles",
|
||||
calls_per_minute: int | None = None,
|
||||
) -> None:
|
||||
self.content_type = content_type
|
||||
self.subdomain = ""
|
||||
# Fetch all tags ahead of time
|
||||
self.content_tags: dict[str, str] = {}
|
||||
self.calls_per_minute = calls_per_minute
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# Subdomain is actually the whole URL
|
||||
@@ -375,7 +399,10 @@ class ZendeskConnector(
|
||||
self.subdomain = subdomain
|
||||
|
||||
self.client = ZendeskClient(
|
||||
subdomain, credentials["zendesk_email"], credentials["zendesk_token"]
|
||||
subdomain,
|
||||
credentials["zendesk_email"],
|
||||
credentials["zendesk_token"],
|
||||
calls_per_minute=self.calls_per_minute,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@@ -28,6 +28,8 @@ from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.onyxbot.slack.models import ChannelType
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.prompts.federated_search import SLACK_QUERY_EXPANSION_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
@@ -39,6 +41,42 @@ HIGHLIGHT_START_CHAR = "\ue000"
|
||||
HIGHLIGHT_END_CHAR = "\ue001"
|
||||
|
||||
|
||||
def _should_skip_channel(
|
||||
channel_id: str,
|
||||
allowed_private_channel: str | None,
|
||||
bot_token: str | None,
|
||||
access_token: str,
|
||||
include_dm: bool,
|
||||
) -> bool:
|
||||
"""
|
||||
Determine whether a channel should be skipped when running in a bot context. When an allowed_private_channel is passed in,
all private channels other than that specific one are filtered out.
|
||||
"""
|
||||
if bot_token and not include_dm:
|
||||
try:
|
||||
# Use bot token if available (has full permissions), otherwise fall back to user token
|
||||
token_to_use = bot_token or access_token
|
||||
channel_client = WebClient(token=token_to_use)
|
||||
channel_info = channel_client.conversations_info(channel=channel_id)
|
||||
|
||||
if isinstance(channel_info.data, dict) and not _is_public_channel(
|
||||
channel_info.data
|
||||
):
|
||||
# This is a private channel - filter it out
|
||||
if channel_id != allowed_private_channel:
|
||||
logger.debug(
|
||||
f"Skipping message from private channel {channel_id} "
|
||||
f"(not the allowed private channel: {allowed_private_channel})"
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not determine channel type for {channel_id}, filtering out: {e}"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def build_slack_queries(query: SearchQuery, llm: LLM) -> list[str]:
|
||||
# get time filter
|
||||
time_filter = ""
|
||||
@@ -64,11 +102,33 @@ def build_slack_queries(query: SearchQuery, llm: LLM) -> list[str]:
|
||||
]
|
||||
|
||||
|
||||
def _is_public_channel(channel_info: dict[str, Any]) -> bool:
|
||||
"""Check if a channel is public based on its info"""
|
||||
# The channel_info structure has a nested 'channel' object
|
||||
channel = channel_info.get("channel", {})
|
||||
|
||||
is_channel = channel.get("is_channel", False)
|
||||
is_private = channel.get("is_private", False)
|
||||
is_group = channel.get("is_group", False)
|
||||
is_mpim = channel.get("is_mpim", False)
|
||||
is_im = channel.get("is_im", False)
|
||||
|
||||
# A public channel is: a channel that is NOT private, NOT a group, NOT mpim, NOT im
|
||||
is_public = (
|
||||
is_channel and not is_private and not is_group and not is_mpim and not is_im
|
||||
)
|
||||
|
||||
return is_public
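
For reference, a payload shaped like the conversations_info data that _is_public_channel reads (the same fields checked above) can be used to exercise the helper directly:

# Illustrative payload shaped like the conversations_info data the helper reads.
example_info = {
    "channel": {
        "is_channel": True,
        "is_private": False,
        "is_group": False,
        "is_mpim": False,
        "is_im": False,
    }
}
assert _is_public_channel(example_info) is True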
|
||||
|
||||
|
||||
def query_slack(
|
||||
query_string: str,
|
||||
original_query: SearchQuery,
|
||||
access_token: str,
|
||||
limit: int | None = None,
|
||||
allowed_private_channel: str | None = None,
|
||||
bot_token: str | None = None,
|
||||
include_dm: bool = False,
|
||||
) -> list[SlackMessage]:
|
||||
# query slack
|
||||
slack_client = WebClient(token=access_token)
|
||||
@@ -79,12 +139,22 @@ def query_slack(
|
||||
response.validate()
|
||||
messages: dict[str, Any] = response.get("messages", {})
|
||||
matches: list[dict[str, Any]] = messages.get("matches", [])
|
||||
except SlackApiError as e:
|
||||
logger.error(f"Slack API error in query_slack: {e}")
|
||||
logger.info(f"Successfully used search_messages, found {len(matches)} messages")
|
||||
except SlackApiError as slack_error:
|
||||
logger.error(f"Slack API error in search_messages: {slack_error}")
|
||||
logger.error(
|
||||
f"Slack API error details: status={slack_error.response.status_code}, "
|
||||
f"error={slack_error.response.get('error')}"
|
||||
)
|
||||
if "not_allowed_token_type" in str(slack_error):
|
||||
# Log token type prefix
|
||||
token_prefix = access_token[:4] if len(access_token) >= 4 else "unknown"
|
||||
logger.error(f"TOKEN TYPE ERROR: access_token type: {token_prefix}...")
|
||||
return []
|
||||
|
||||
# convert matches to slack messages
|
||||
slack_messages: list[SlackMessage] = []
|
||||
filtered_count = 0
|
||||
for match in matches:
|
||||
text: str | None = match.get("text")
|
||||
permalink: str | None = match.get("permalink")
|
||||
@@ -92,6 +162,13 @@ def query_slack(
|
||||
channel_id: str | None = match.get("channel", {}).get("id")
|
||||
channel_name: str | None = match.get("channel", {}).get("name")
|
||||
username: str | None = match.get("username")
|
||||
if not username:
|
||||
# Fallback: try to get from user field if username is missing
|
||||
user_info = match.get("user", "")
|
||||
if isinstance(user_info, str) and user_info:
|
||||
username = user_info # Use user ID as fallback
|
||||
else:
|
||||
username = "unknown_user"
|
||||
score: float = match.get("score", 0.0)
|
||||
if ( # can't use any() because of type checking :(
|
||||
not text
|
||||
@@ -103,6 +180,13 @@ def query_slack(
|
||||
):
|
||||
continue
|
||||
|
||||
# Apply channel filtering if needed
|
||||
if _should_skip_channel(
|
||||
channel_id, allowed_private_channel, bot_token, access_token, include_dm
|
||||
):
|
||||
filtered_count += 1
|
||||
continue
|
||||
|
||||
# generate thread id and document id
|
||||
thread_id = (
|
||||
permalink.split("?thread_ts=", 1)[1] if "?thread_ts=" in permalink else None
|
||||
@@ -155,6 +239,11 @@ def query_slack(
|
||||
)
|
||||
)
|
||||
|
||||
if filtered_count > 0:
|
||||
logger.info(
|
||||
f"Channel filtering applied: {filtered_count} messages filtered out, {len(slack_messages)} messages kept"
|
||||
)
|
||||
|
||||
return slack_messages
|
||||
|
||||
|
||||
@@ -291,14 +380,40 @@ def slack_retrieval(
|
||||
access_token: str,
|
||||
db_session: Session,
|
||||
limit: int | None = None,
|
||||
slack_event_context: SlackContext | None = None,
|
||||
bot_token: str | None = None, # Add bot token parameter
|
||||
) -> list[InferenceChunk]:
|
||||
# query slack
|
||||
_, fast_llm = get_default_llms()
|
||||
query_strings = build_slack_queries(query, fast_llm)
|
||||
|
||||
results: list[list[SlackMessage]] = run_functions_tuples_in_parallel(
|
||||
include_dm = False
|
||||
allowed_private_channel = None
|
||||
|
||||
if slack_event_context:
|
||||
channel_type = slack_event_context.channel_type
|
||||
if channel_type == ChannelType.IM: # DM with user
|
||||
include_dm = True
|
||||
if channel_type == ChannelType.PRIVATE_CHANNEL:
|
||||
allowed_private_channel = slack_event_context.channel_id
|
||||
logger.info(
|
||||
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
|
||||
)
|
||||
|
||||
results = run_functions_tuples_in_parallel(
|
||||
[
|
||||
(query_slack, (query_string, query, access_token, limit))
|
||||
(
|
||||
query_slack,
|
||||
(
|
||||
query_string,
|
||||
query,
|
||||
access_token,
|
||||
limit,
|
||||
allowed_private_channel,
|
||||
bot_token,
|
||||
include_dm,
|
||||
),
|
||||
)
|
||||
for query_string in query_strings
|
||||
]
|
||||
)
|
||||
@@ -307,7 +422,6 @@ def slack_retrieval(
|
||||
if not slack_messages:
|
||||
return []
|
||||
|
||||
# contextualize the slack messages
|
||||
thread_texts: list[str] = run_functions_tuples_in_parallel(
|
||||
[
|
||||
(get_contextualized_thread_text, (slack_message, access_token))
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
@@ -119,8 +121,8 @@ class BaseFilters(BaseModel):
|
||||
|
||||
|
||||
class UserFileFilters(BaseModel):
|
||||
user_file_ids: list[int] | None = None
|
||||
user_folder_ids: list[int] | None = None
|
||||
user_file_ids: list[UUID] | None = None
|
||||
project_id: int | None = None
|
||||
|
||||
|
||||
class IndexFilters(BaseFilters, UserFileFilters):
|
||||
@@ -355,6 +357,44 @@ class SearchDoc(BaseModel):
|
||||
secondary_owners: list[str] | None = None
|
||||
is_internet: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_chunks_or_sections(
|
||||
cls,
|
||||
items: "Sequence[InferenceChunk | InferenceSection] | None",
|
||||
) -> list["SearchDoc"]:
|
||||
"""Convert a sequence of InferenceChunk or InferenceSection objects to SearchDoc objects."""
|
||||
if not items:
|
||||
return []
|
||||
|
||||
search_docs = [
|
||||
cls(
|
||||
document_id=(
|
||||
chunk := (
|
||||
item.center_chunk
|
||||
if isinstance(item, InferenceSection)
|
||||
else item
|
||||
)
|
||||
).document_id,
|
||||
chunk_ind=chunk.chunk_id,
|
||||
semantic_identifier=chunk.semantic_identifier or "Unknown",
|
||||
link=chunk.source_links[0] if chunk.source_links else None,
|
||||
blurb=chunk.blurb,
|
||||
source_type=chunk.source_type,
|
||||
boost=chunk.boost,
|
||||
hidden=chunk.hidden,
|
||||
metadata=chunk.metadata,
|
||||
score=chunk.score,
|
||||
match_highlights=chunk.match_highlights,
|
||||
updated_at=chunk.updated_at,
|
||||
primary_owners=chunk.primary_owners,
|
||||
secondary_owners=chunk.secondary_owners,
|
||||
is_internet=False,
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
return search_docs
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(*args, **kwargs) # type: ignore
|
||||
initial_dict["updated_at"] = (
|
||||
|
||||
@@ -36,6 +36,7 @@ from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.secondary_llm_flows.agentic_evaluation import evaluate_inference_section
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import FunctionCall
|
||||
@@ -65,6 +66,7 @@ class SearchPipeline:
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
prompt_config: PromptConfig | None = None,
|
||||
contextual_pruning_config: ContextualPruningConfig | None = None,
|
||||
slack_context: SlackContext | None = None,
|
||||
):
|
||||
# NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
|
||||
# and typically are None. The preprocessing will fetch default values to replace these empty overrides.
|
||||
@@ -84,6 +86,7 @@ class SearchPipeline:
|
||||
self.contextual_pruning_config: ContextualPruningConfig | None = (
|
||||
contextual_pruning_config
|
||||
)
|
||||
self.slack_context: SlackContext | None = slack_context
|
||||
|
||||
# Preprocessing steps generate this
|
||||
self._search_query: SearchQuery | None = None
|
||||
@@ -162,6 +165,7 @@ class SearchPipeline:
|
||||
document_index=self.document_index,
|
||||
db_session=self.db_session,
|
||||
retrieval_metrics_callback=self.retrieval_metrics_callback,
|
||||
slack_context=self.slack_context, # Pass Slack context
|
||||
)
|
||||
|
||||
return cast(list[InferenceChunk], self._retrieved_chunks)
|
||||
|
||||
@@ -166,9 +166,6 @@ def retrieval_preprocessing(
|
||||
)
|
||||
user_file_filters = search_request.user_file_filters
|
||||
user_file_ids = (user_file_filters.user_file_ids or []) if user_file_filters else []
|
||||
user_folder_ids = (
|
||||
(user_file_filters.user_folder_ids or []) if user_file_filters else []
|
||||
)
|
||||
if persona and persona.user_files:
|
||||
user_file_ids = list(
|
||||
set(user_file_ids) | set([file.id for file in persona.user_files])
|
||||
@@ -176,7 +173,7 @@ def retrieval_preprocessing(
|
||||
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
user_folder_ids=user_folder_ids,
|
||||
project_id=user_file_filters.project_id if user_file_filters else None,
|
||||
source_type=preset_filters.source_type or predicted_source_filters,
|
||||
document_set=preset_filters.document_set,
|
||||
time_cutoff=time_filter or predicted_time_cutoff,
|
||||
|
||||
@@ -30,6 +30,7 @@ from onyx.document_index.vespa.shared_utils.utils import (
|
||||
from onyx.federated_connectors.federated_retrieval import (
|
||||
get_federated_retrieval_functions,
|
||||
)
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.secondary_llm_flows.query_expansion import multilingual_query_expansion
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
@@ -329,6 +330,7 @@ def retrieve_chunks(
|
||||
retrieval_metrics_callback: (
|
||||
Callable[[RetrievalMetricsContainer], None] | None
|
||||
) = None,
|
||||
slack_context: SlackContext | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
|
||||
|
||||
@@ -341,7 +343,11 @@ def retrieve_chunks(
|
||||
|
||||
# Federated retrieval
|
||||
federated_retrieval_infos = get_federated_retrieval_functions(
|
||||
db_session, user_id, query.filters.source_type, query.filters.document_set
|
||||
db_session,
|
||||
user_id,
|
||||
query.filters.source_type,
|
||||
query.filters.document_set,
|
||||
slack_context,
|
||||
)
|
||||
federated_sources = set(
|
||||
federated_retrieval_info.source.to_non_federated_source()
|
||||
|
||||
@@ -118,40 +118,6 @@ def inference_section_from_chunks(
|
||||
)
|
||||
|
||||
|
||||
def chunks_or_sections_to_search_docs(
|
||||
items: Sequence[InferenceChunk | InferenceSection] | None,
|
||||
) -> list[SearchDoc]:
|
||||
if not items:
|
||||
return []
|
||||
|
||||
search_docs = [
|
||||
SearchDoc(
|
||||
document_id=(
|
||||
chunk := (
|
||||
item.center_chunk if isinstance(item, InferenceSection) else item
|
||||
)
|
||||
).document_id,
|
||||
chunk_ind=chunk.chunk_id,
|
||||
semantic_identifier=chunk.semantic_identifier or "Unknown",
|
||||
link=chunk.source_links[0] if chunk.source_links else None,
|
||||
blurb=chunk.blurb,
|
||||
source_type=chunk.source_type,
|
||||
boost=chunk.boost,
|
||||
hidden=chunk.hidden,
|
||||
metadata=chunk.metadata,
|
||||
score=chunk.score,
|
||||
match_highlights=chunk.match_highlights,
|
||||
updated_at=chunk.updated_at,
|
||||
primary_owners=chunk.primary_owners,
|
||||
secondary_owners=chunk.secondary_owners,
|
||||
is_internet=False,
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
return search_docs
|
||||
|
||||
|
||||
def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
|
||||
try:
|
||||
# Re-tokenize using the NLTK tokenizer for better matching
|
||||
|
||||
@@ -28,13 +28,11 @@ from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
from onyx.agents.agent_search.utils import create_citation_format_list
|
||||
from onyx.chat.models import DocumentRelevance
|
||||
from onyx.configs.chat_configs import HARD_DELETE_CHATS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RetrievalDocs
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SearchDoc as ServerSearchDoc
|
||||
from onyx.context.search.utils import chunks_or_sections_to_search_docs
|
||||
from onyx.db.models import AgentSearchMetrics
|
||||
from onyx.db.models import AgentSubQuery
|
||||
from onyx.db.models import AgentSubQuestion
|
||||
@@ -47,17 +45,15 @@ from onyx.db.models import SearchDoc
|
||||
from onyx.db.models import SearchDoc as DBSearchDoc
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.persona import get_best_persona_id_for_user
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.override_models import PromptOverride
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.server.query_and_chat.models import SubQueryDetail
|
||||
from onyx.server.query_and_chat.models import SubQuestionDetail
|
||||
from onyx.tools.tool_runner import ToolCallFinalResult
|
||||
from onyx.tools.models import ToolCallFinalResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
@@ -173,6 +169,8 @@ def get_chat_sessions_by_user(
|
||||
db_session: Session,
|
||||
include_onyxbot_flows: bool = False,
|
||||
limit: int = 50,
|
||||
project_id: int | None = None,
|
||||
only_non_project_chats: bool = False,
|
||||
) -> list[ChatSession]:
|
||||
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
|
||||
|
||||
@@ -187,6 +185,11 @@ def get_chat_sessions_by_user(
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(ChatSession.project_id == project_id)
|
||||
elif only_non_project_chats:
|
||||
stmt = stmt.where(ChatSession.project_id.is_(None))
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
chat_sessions = result.scalars().all()
|
||||
|
||||
@@ -260,6 +263,7 @@ def create_chat_session(
|
||||
prompt_override: PromptOverride | None = None,
|
||||
onyxbot_flow: bool = False,
|
||||
slack_thread_id: str | None = None,
|
||||
project_id: int | None = None,
|
||||
) -> ChatSession:
|
||||
chat_session = ChatSession(
|
||||
user_id=user_id,
|
||||
@@ -269,6 +273,7 @@ def create_chat_session(
|
||||
prompt_override=prompt_override,
|
||||
onyxbot_flow=onyxbot_flow,
|
||||
slack_thread_id=slack_thread_id,
|
||||
project_id=project_id,
|
||||
)
|
||||
|
||||
db_session.add(chat_session)
|
||||
@@ -859,90 +864,6 @@ def get_db_search_doc_by_document_id(
|
||||
return search_doc
|
||||
|
||||
|
||||
def create_search_doc_from_user_file(
|
||||
db_user_file: UserFile, associated_chat_file: InMemoryChatFile, db_session: Session
|
||||
) -> SearchDoc:
|
||||
"""Create a SearchDoc in the database from a UserFile and return it.
|
||||
This ensures proper ID generation by SQLAlchemy and prevents duplicate key errors.
|
||||
"""
|
||||
blurb = ""
|
||||
if associated_chat_file and associated_chat_file.content:
|
||||
try:
|
||||
# Try to decode as UTF-8, but handle errors gracefully
|
||||
content_sample = associated_chat_file.content[:100]
|
||||
# Remove null bytes which can cause SQL errors
|
||||
content_sample = content_sample.replace(b"\x00", b"")
|
||||
|
||||
# NOTE(rkuo): this used to be "replace" instead of strict, but
|
||||
# that would bypass the binary handling below
|
||||
blurb = content_sample.decode("utf-8", errors="strict")
|
||||
except Exception:
|
||||
# If decoding fails completely, provide a generic description
|
||||
blurb = f"[Binary file: {db_user_file.name}]"
|
||||
|
||||
db_search_doc = SearchDoc(
|
||||
document_id=db_user_file.document_id,
|
||||
chunk_ind=0, # Default to 0 for user files
|
||||
semantic_id=db_user_file.name,
|
||||
link=db_user_file.link_url,
|
||||
blurb=blurb,
|
||||
source_type=DocumentSource.FILE, # Assuming internal source for user files
|
||||
boost=0, # Default boost
|
||||
hidden=False, # Default visibility
|
||||
doc_metadata={}, # Empty metadata
|
||||
score=0.0, # Default score of 0.0 instead of None
|
||||
is_relevant=None, # No relevance initially
|
||||
relevance_explanation=None, # No explanation initially
|
||||
match_highlights=[], # No highlights initially
|
||||
updated_at=db_user_file.created_at, # Use created_at as updated_at
|
||||
primary_owners=[], # Empty list instead of None
|
||||
secondary_owners=[], # Empty list instead of None
|
||||
is_internet=False, # Not from internet
|
||||
)
|
||||
|
||||
db_session.add(db_search_doc)
|
||||
db_session.flush() # Get the ID but don't commit yet
|
||||
|
||||
return db_search_doc
|
||||
|
||||
|
||||
def translate_db_user_file_to_search_doc(
|
||||
db_user_file: UserFile, associated_chat_file: InMemoryChatFile
|
||||
) -> SearchDoc:
|
||||
blurb = ""
|
||||
if associated_chat_file and associated_chat_file.content:
|
||||
try:
|
||||
# Try to decode as UTF-8, but handle errors gracefully
|
||||
content_sample = associated_chat_file.content[:100]
|
||||
# Remove null bytes which can cause SQL errors
|
||||
content_sample = content_sample.replace(b"\x00", b"")
|
||||
blurb = content_sample.decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
# If decoding fails completely, provide a generic description
|
||||
blurb = f"[Binary file: {db_user_file.name}]"
|
||||
|
||||
return SearchDoc(
|
||||
# Don't set ID - let SQLAlchemy auto-generate it
|
||||
document_id=db_user_file.document_id,
|
||||
chunk_ind=0, # Default to 0 for user files
|
||||
semantic_id=db_user_file.name,
|
||||
link=db_user_file.link_url,
|
||||
blurb=blurb,
|
||||
source_type=DocumentSource.FILE, # Assuming internal source for user files
|
||||
boost=0, # Default boost
|
||||
hidden=False, # Default visibility
|
||||
doc_metadata={}, # Empty metadata
|
||||
score=0.0, # Default score of 0.0 instead of None
|
||||
is_relevant=None, # No relevance initially
|
||||
relevance_explanation=None, # No explanation initially
|
||||
match_highlights=[], # No highlights initially
|
||||
updated_at=db_user_file.created_at, # Use created_at as updated_at
|
||||
primary_owners=[], # Empty list instead of None
|
||||
secondary_owners=[], # Empty list instead of None
|
||||
is_internet=False, # Not from internet
|
||||
)
|
||||
|
||||
|
||||
def translate_db_search_doc_to_server_search_doc(
|
||||
db_search_doc: SearchDoc,
|
||||
remove_doc_content: bool = False,
|
||||
@@ -1147,7 +1068,7 @@ def log_agent_sub_question_results(
|
||||
db_session.add(sub_query_object)
|
||||
db_session.commit()
|
||||
|
||||
search_docs = chunks_or_sections_to_search_docs(
|
||||
search_docs = ServerSearchDoc.from_chunks_or_sections(
|
||||
sub_query.retrieved_documents
|
||||
)
|
||||
for doc in search_docs:
|
||||
@@ -1223,12 +1144,29 @@ def create_search_doc_from_inference_section(
|
||||
def create_search_doc_from_saved_search_doc(
|
||||
saved_search_doc: SavedSearchDoc,
|
||||
) -> SearchDoc:
|
||||
"""Convert SavedSearchDoc to SearchDoc by excluding the additional fields"""
|
||||
data = saved_search_doc.model_dump()
|
||||
# Remove the fields that are specific to SavedSearchDoc
|
||||
data.pop("db_doc_id", None)
|
||||
# Keep score since SearchDoc has it as an optional field
|
||||
return SearchDoc(**data)
|
||||
"""Convert SavedSearchDoc (server model) into DB SearchDoc with correct field mapping."""
|
||||
return SearchDoc(
|
||||
document_id=saved_search_doc.document_id,
|
||||
chunk_ind=saved_search_doc.chunk_ind,
|
||||
# Map Pydantic semantic_identifier -> DB semantic_id; ensure non-null
|
||||
semantic_id=saved_search_doc.semantic_identifier or "Unknown",
|
||||
link=saved_search_doc.link,
|
||||
blurb=saved_search_doc.blurb,
|
||||
source_type=saved_search_doc.source_type,
|
||||
boost=saved_search_doc.boost,
|
||||
hidden=saved_search_doc.hidden,
|
||||
# Map metadata -> doc_metadata (DB column name)
|
||||
doc_metadata=saved_search_doc.metadata,
|
||||
# SavedSearchDoc.score exists and defaults to 0.0
|
||||
score=saved_search_doc.score or 0.0,
|
||||
match_highlights=saved_search_doc.match_highlights,
|
||||
updated_at=saved_search_doc.updated_at,
|
||||
primary_owners=saved_search_doc.primary_owners,
|
||||
secondary_owners=saved_search_doc.secondary_owners,
|
||||
is_internet=saved_search_doc.is_internet,
|
||||
is_relevant=saved_search_doc.is_relevant,
|
||||
relevance_explanation=saved_search_doc.relevance_explanation,
|
||||
)
|
||||
|
||||
|
||||
def update_db_session_with_messages(
|
||||
@@ -1251,7 +1189,6 @@ def update_db_session_with_messages(
|
||||
research_answer_purpose: ResearchAnswerPurpose | None = None,
|
||||
commit: bool = False,
|
||||
) -> ChatMessage:
|
||||
|
||||
chat_message = (
|
||||
db_session.query(ChatMessage)
|
||||
.filter(
|
||||
|
||||
@@ -34,7 +34,6 @@ from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.server.models import StatusResponse
|
||||
@@ -805,31 +804,3 @@ def resync_cc_pair(
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_connector_credential_pairs_with_user_files(
|
||||
db_session: Session,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
"""
|
||||
Get all connector credential pairs that have associated user files.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
List of ConnectorCredentialPair objects that have user files
|
||||
"""
|
||||
return (
|
||||
db_session.query(ConnectorCredentialPair)
|
||||
.join(UserFile, UserFile.cc_pair_id == ConnectorCredentialPair.id)
|
||||
.distinct()
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def delete_userfiles_for_cc_pair__no_commit(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
) -> None:
|
||||
stmt = delete(UserFile).where(UserFile.cc_pair_id == cc_pair_id)
|
||||
db_session.execute(stmt)
|
||||
|
||||
@@ -55,6 +55,7 @@ from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.kg.models import KGStage
|
||||
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -259,31 +260,48 @@ def get_document_counts_for_cc_pairs(
|
||||
) -> Sequence[tuple[int, int, int]]:
|
||||
"""Returns a sequence of tuples of (connector_id, credential_id, document count)"""
|
||||
|
||||
if not cc_pairs:
|
||||
return []
|
||||
|
||||
# Prepare a list of (connector_id, credential_id) tuples
|
||||
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pairs]
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
DocumentByConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id,
|
||||
func.count(),
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
tuple_(
|
||||
DocumentByConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id,
|
||||
).in_(cc_ids),
|
||||
DocumentByConnectorCredentialPair.has_been_indexed.is_(True),
|
||||
# Batch to avoid generating extremely large IN clauses that can blow Postgres stack depth
|
||||
batch_size = 1000
|
||||
aggregated_counts: dict[tuple[int, int], int] = {}
|
||||
|
||||
for start_idx in range(0, len(cc_ids), batch_size):
|
||||
batch = cc_ids[start_idx : start_idx + batch_size]
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
DocumentByConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id,
|
||||
func.count(),
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
tuple_(
|
||||
DocumentByConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id,
|
||||
).in_(batch),
|
||||
DocumentByConnectorCredentialPair.has_been_indexed.is_(True),
|
||||
)
|
||||
)
|
||||
.group_by(
|
||||
DocumentByConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id,
|
||||
)
|
||||
)
|
||||
.group_by(
|
||||
DocumentByConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id,
|
||||
)
|
||||
)
|
||||
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
for connector_id, credential_id, cnt in db_session.execute(stmt).all(): # type: ignore
|
||||
aggregated_counts[(connector_id, credential_id)] = cnt
|
||||
|
||||
# Convert aggregated results back to the expected sequence of tuples
|
||||
return [
|
||||
(connector_id, credential_id, cnt)
|
||||
for (connector_id, credential_id), cnt in aggregated_counts.items()
|
||||
]
|
||||
|
||||
|
||||
# For use with our thread-level parallelism utils. Note that any relationships
|
||||
@@ -296,6 +314,72 @@ def get_document_counts_for_cc_pairs_parallel(
|
||||
return get_document_counts_for_cc_pairs(db_session, cc_pairs)
|
||||
|
||||
|
||||
def _get_document_counts_for_cc_pairs_batch(
|
||||
batch: list[tuple[int, int]],
|
||||
) -> list[tuple[int, int, int]]:
|
||||
"""Worker for parallel execution: opens its own session per batch."""
|
||||
if not batch:
|
||||
return []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
stmt = (
|
||||
select(
|
||||
DocumentByConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id,
|
||||
func.count(),
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
tuple_(
|
||||
DocumentByConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id,
|
||||
).in_(batch),
|
||||
DocumentByConnectorCredentialPair.has_been_indexed.is_(True),
|
||||
)
|
||||
)
|
||||
.group_by(
|
||||
DocumentByConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id,
|
||||
)
|
||||
)
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
|
||||
|
||||
def get_document_counts_for_cc_pairs_batched_parallel(
|
||||
cc_pairs: list[ConnectorCredentialPairIdentifier],
|
||||
batch_size: int = 1000,
|
||||
max_workers: int | None = None,
|
||||
) -> Sequence[tuple[int, int, int]]:
|
||||
"""Parallel variant that batches the IN-clause and runs batches concurrently.
|
||||
|
||||
Opens an isolated DB session per batch to avoid sharing a session across threads.
|
||||
"""
|
||||
if not cc_pairs:
|
||||
return []
|
||||
|
||||
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pairs]
|
||||
|
||||
batches: list[list[tuple[int, int]]] = [
|
||||
cc_ids[i : i + batch_size] for i in range(0, len(cc_ids), batch_size)
|
||||
]
|
||||
|
||||
funcs = [(_get_document_counts_for_cc_pairs_batch, (batch,)) for batch in batches]
|
||||
results = run_functions_tuples_in_parallel(
|
||||
functions_with_args=funcs, max_workers=max_workers
|
||||
)
|
||||
|
||||
aggregated_counts: dict[tuple[int, int], int] = {}
|
||||
for batch_result in results:
|
||||
if not batch_result:
|
||||
continue
|
||||
for connector_id, credential_id, cnt in batch_result:
|
||||
aggregated_counts[(connector_id, credential_id)] = cnt
|
||||
|
||||
return [
|
||||
(connector_id, credential_id, cnt)
|
||||
for (connector_id, credential_id), cnt in aggregated_counts.items()
|
||||
]
|
||||
|
||||
|
||||
def get_access_info_for_document(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
|
||||
@@ -131,3 +131,10 @@ class EmbeddingPrecision(str, PyEnum):
|
||||
# good reason to specify anything else
|
||||
BFLOAT16 = "bfloat16"
|
||||
FLOAT = "float"
|
||||
|
||||
|
||||
class UserFileStatus(str, PyEnum):
|
||||
PROCESSING = "PROCESSING"
|
||||
COMPLETED = "COMPLETED"
|
||||
FAILED = "FAILED"
|
||||
CANCELED = "CANCELED"
|
||||
|
||||
@@ -62,6 +62,7 @@ from onyx.db.enums import (
|
||||
SyncType,
|
||||
SyncStatus,
|
||||
MCPAuthenticationType,
|
||||
UserFileStatus,
|
||||
)
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import SearchFeedbackType
|
||||
@@ -209,9 +210,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
chat_sessions: Mapped[list["ChatSession"]] = relationship(
|
||||
"ChatSession", back_populates="user"
|
||||
)
|
||||
chat_folders: Mapped[list["ChatFolder"]] = relationship(
|
||||
"ChatFolder", back_populates="user"
|
||||
)
|
||||
|
||||
input_prompts: Mapped[list["InputPrompt"]] = relationship(
|
||||
"InputPrompt", back_populates="user"
|
||||
@@ -229,8 +227,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
back_populates="creator",
|
||||
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
|
||||
)
|
||||
folders: Mapped[list["UserFolder"]] = relationship(
|
||||
"UserFolder", back_populates="user"
|
||||
projects: Mapped[list["UserProject"]] = relationship(
|
||||
"UserProject", back_populates="user"
|
||||
)
|
||||
files: Mapped[list["UserFile"]] = relationship("UserFile", back_populates="user")
|
||||
# MCP servers accessible to this user
|
||||
@@ -539,10 +537,6 @@ class ConnectorCredentialPair(Base):
|
||||
primaryjoin="foreign(ConnectorCredentialPair.creator_id) == remote(User.id)",
|
||||
)
|
||||
|
||||
user_file: Mapped["UserFile"] = relationship(
|
||||
"UserFile", back_populates="cc_pair", uselist=False
|
||||
)
|
||||
|
||||
background_errors: Mapped[list["BackgroundError"]] = relationship(
|
||||
"BackgroundError", back_populates="cc_pair", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -1459,9 +1453,6 @@ class FederatedConnector(Base):
|
||||
|
||||
|
||||
class FederatedConnectorOAuthToken(Base):
|
||||
"""NOTE: in the future, can be made more general to support OAuth tokens
|
||||
for actions."""
|
||||
|
||||
__tablename__ = "federated_connector_oauth_token"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
@@ -2037,9 +2028,6 @@ class ChatSession(Base):
|
||||
Enum(ChatSessionSharedStatus, native_enum=False),
|
||||
default=ChatSessionSharedStatus.PRIVATE,
|
||||
)
|
||||
folder_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("chat_folder.id"), nullable=True
|
||||
)
|
||||
|
||||
current_alternate_model: Mapped[str | None] = mapped_column(String, default=None)
|
||||
|
||||
@@ -2047,6 +2035,14 @@ class ChatSession(Base):
|
||||
String, nullable=True, default=None
|
||||
)
|
||||
|
||||
project_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("user_project.id"), nullable=True
|
||||
)
|
||||
|
||||
project: Mapped["UserProject"] = relationship(
|
||||
"UserProject", back_populates="chat_sessions", foreign_keys=[project_id]
|
||||
)
|
||||
|
||||
# the latest "overrides" specified by the user. These take precedence over
|
||||
# the attached persona. However, overrides specified directly in the
|
||||
# `send-message` call will take precedence over these.
|
||||
@@ -2072,9 +2068,6 @@ class ChatSession(Base):
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
|
||||
folder: Mapped["ChatFolder"] = relationship(
|
||||
"ChatFolder", back_populates="chat_sessions"
|
||||
)
|
||||
messages: Mapped[list["ChatMessage"]] = relationship(
|
||||
"ChatMessage", back_populates="chat_session", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -2178,33 +2171,6 @@ class ChatMessage(Base):
|
||||
)
|
||||
|
||||
|
||||
class ChatFolder(Base):
|
||||
"""For organizing chat sessions"""
|
||||
|
||||
__tablename__ = "chat_folder"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
# Only null if auth is off
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=0)
|
||||
|
||||
user: Mapped[User] = relationship("User", back_populates="chat_folders")
|
||||
chat_sessions: Mapped[list["ChatSession"]] = relationship(
|
||||
"ChatSession", back_populates="folder"
|
||||
)
|
||||
|
||||
def __lt__(self, other: Any) -> bool:
|
||||
if not isinstance(other, ChatFolder):
|
||||
return NotImplemented
|
||||
if self.display_priority == other.display_priority:
|
||||
# Bigger ID (created later) show earlier
|
||||
return self.id > other.id
|
||||
return self.display_priority < other.display_priority
|
||||
|
||||
|
||||
class AgentSubQuestion(Base):
|
||||
"""
|
||||
A sub-question is a question that is asked of the LLM to gather supporting
|
||||
@@ -2636,11 +2602,6 @@ class Persona(Base):
|
||||
secondary="persona__user_file",
|
||||
back_populates="assistants",
|
||||
)
|
||||
user_folders: Mapped[list["UserFolder"]] = relationship(
|
||||
"UserFolder",
|
||||
secondary="persona__user_folder",
|
||||
back_populates="assistants",
|
||||
)
|
||||
labels: Mapped[list["PersonaLabel"]] = relationship(
|
||||
"PersonaLabel",
|
||||
secondary=Persona__PersonaLabel.__table__,
|
||||
@@ -2658,20 +2619,11 @@ class Persona(Base):
|
||||
)
|
||||
|
||||
|
||||
class Persona__UserFolder(Base):
|
||||
__tablename__ = "persona__user_folder"
|
||||
|
||||
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
|
||||
user_folder_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_folder.id"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class Persona__UserFile(Base):
|
||||
__tablename__ = "persona__user_file"
|
||||
|
||||
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
|
||||
user_file_id: Mapped[int] = mapped_column(
|
||||
user_file_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("user_file.id"), primary_key=True
|
||||
)
|
||||
|
||||
@@ -2783,6 +2735,7 @@ class SlackBot(Base):
|
||||
|
||||
bot_token: Mapped[str] = mapped_column(EncryptedString(), unique=True)
|
||||
app_token: Mapped[str] = mapped_column(EncryptedString(), unique=True)
|
||||
user_token: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
|
||||
|
||||
slack_channel_configs: Mapped[list[SlackChannelConfig]] = relationship(
|
||||
"SlackChannelConfig",
|
||||
@@ -3276,23 +3229,37 @@ class InputPrompt__User(Base):
|
||||
disabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
|
||||
class UserFolder(Base):
|
||||
__tablename__ = "user_folder"
|
||||
class Project__UserFile(Base):
|
||||
__tablename__ = "project__user_file"
|
||||
|
||||
project_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_project.id"), primary_key=True
|
||||
)
|
||||
user_file_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("user_file.id"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class UserProject(Base):
|
||||
__tablename__ = "user_project"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=False)
|
||||
name: Mapped[str] = mapped_column(nullable=False)
|
||||
description: Mapped[str] = mapped_column(nullable=False)
|
||||
description: Mapped[str] = mapped_column(nullable=True)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
user: Mapped["User"] = relationship(back_populates="folders")
|
||||
files: Mapped[list["UserFile"]] = relationship(back_populates="folder")
|
||||
assistants: Mapped[list["Persona"]] = relationship(
|
||||
"Persona",
|
||||
secondary=Persona__UserFolder.__table__,
|
||||
back_populates="user_folders",
|
||||
user: Mapped["User"] = relationship(back_populates="projects")
|
||||
user_files: Mapped[list["UserFile"]] = relationship(
|
||||
"UserFile",
|
||||
secondary=Project__UserFile.__table__,
|
||||
back_populates="projects",
|
||||
)
|
||||
chat_sessions: Mapped[list["ChatSession"]] = relationship(
|
||||
"ChatSession", back_populates="project", lazy="selectin"
|
||||
)
|
||||
instructions: Mapped[str] = mapped_column(String)
|
||||
|
||||
|
||||
class UserDocument(str, Enum):
|
||||
@@ -3304,17 +3271,13 @@ class UserDocument(str, Enum):
|
||||
class UserFile(Base):
|
||||
__tablename__ = "user_file"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=False)
|
||||
assistants: Mapped[list["Persona"]] = relationship(
|
||||
"Persona",
|
||||
secondary=Persona__UserFile.__table__,
|
||||
back_populates="user_files",
|
||||
)
|
||||
folder_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("user_folder.id"), nullable=True
|
||||
)
|
||||
|
||||
file_id: Mapped[str] = mapped_column(nullable=False)
|
||||
document_id: Mapped[str] = mapped_column(nullable=False)
|
||||
name: Mapped[str] = mapped_column(nullable=False)
|
||||
@@ -3322,17 +3285,38 @@ class UserFile(Base):
|
||||
default=datetime.datetime.utcnow
|
||||
)
|
||||
user: Mapped["User"] = relationship(back_populates="files")
|
||||
folder: Mapped["UserFolder"] = relationship(back_populates="files")
|
||||
token_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
cc_pair_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("connector_credential_pair.id"), nullable=True, unique=True
|
||||
file_type: Mapped[str] = mapped_column(String, nullable=False)
|
||||
|
||||
status: Mapped[UserFileStatus] = mapped_column(
|
||||
Enum(UserFileStatus, native_enum=False, name="userfilestatus"),
|
||||
nullable=False,
|
||||
default=UserFileStatus.PROCESSING,
|
||||
)
|
||||
cc_pair: Mapped["ConnectorCredentialPair"] = relationship(
|
||||
"ConnectorCredentialPair", back_populates="user_file"
|
||||
needs_project_sync: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
last_project_sync_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
chunk_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
last_accessed_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
link_url: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
content_type: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
document_id_migrated: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=True
|
||||
)
|
||||
|
||||
projects: Mapped[list["UserProject"]] = relationship(
|
||||
"UserProject",
|
||||
secondary=Project__UserFile.__table__,
|
||||
back_populates="user_files",
|
||||
lazy="selectin",
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
@@ -33,7 +33,6 @@ from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserFolder
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.server.features.persona.models import FullPersonaSnapshot
|
||||
@@ -237,6 +236,16 @@ def create_update_persona(
|
||||
elif user.role != UserRole.ADMIN:
|
||||
raise ValueError("Only admins can make a default persona")
|
||||
|
||||
# Convert incoming string UUIDs to UUID objects for DB operations
|
||||
converted_user_file_ids = None
|
||||
if create_persona_request.user_file_ids is not None:
|
||||
try:
|
||||
converted_user_file_ids = [
|
||||
UUID(str_id) for str_id in create_persona_request.user_file_ids
|
||||
]
|
||||
except Exception:
|
||||
raise ValueError("Invalid user_file_ids; must be UUID strings")
|
||||
|
||||
persona = upsert_persona(
|
||||
persona_id=persona_id,
|
||||
user=user,
|
||||
@@ -264,8 +273,7 @@ def create_update_persona(
|
||||
llm_relevance_filter=create_persona_request.llm_relevance_filter,
|
||||
llm_filter_extraction=create_persona_request.llm_filter_extraction,
|
||||
is_default_persona=create_persona_request.is_default_persona,
|
||||
user_file_ids=create_persona_request.user_file_ids,
|
||||
user_folder_ids=create_persona_request.user_folder_ids,
|
||||
user_file_ids=converted_user_file_ids,
|
||||
)
|
||||
|
||||
versioned_make_persona_private = fetch_versioned_implementation(
|
||||
@@ -504,8 +512,7 @@ def upsert_persona(
|
||||
builtin_persona: bool = False,
|
||||
is_default_persona: bool | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
user_file_ids: list[int] | None = None,
|
||||
user_folder_ids: list[int] | None = None,
|
||||
user_file_ids: list[UUID] | None = None,
|
||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||
) -> Persona:
|
||||
@@ -566,17 +573,6 @@ def upsert_persona(
|
||||
if not user_files and user_file_ids:
|
||||
raise ValueError("user_files not found")
|
||||
|
||||
# Fetch and attach user_folders by IDs
|
||||
user_folders = None
|
||||
if user_folder_ids is not None:
|
||||
user_folders = (
|
||||
db_session.query(UserFolder)
|
||||
.filter(UserFolder.id.in_(user_folder_ids))
|
||||
.all()
|
||||
)
|
||||
if not user_folders and user_folder_ids:
|
||||
raise ValueError("user_folders not found")
|
||||
|
||||
labels = None
|
||||
if label_ids is not None:
|
||||
labels = (
|
||||
@@ -644,10 +640,6 @@ def upsert_persona(
|
||||
existing_persona.user_files.clear()
|
||||
existing_persona.user_files = user_files or []
|
||||
|
||||
if user_folder_ids is not None:
|
||||
existing_persona.user_folders.clear()
|
||||
existing_persona.user_folders = user_folders or []
|
||||
|
||||
# We should only update display priority if it is not already set
|
||||
if existing_persona.display_priority is None:
|
||||
existing_persona.display_priority = display_priority
|
||||
@@ -686,7 +678,6 @@ def upsert_persona(
|
||||
is_default_persona=(
|
||||
is_default_persona if is_default_persona is not None else False
|
||||
),
|
||||
user_folders=user_folders or [],
|
||||
user_files=user_files or [],
|
||||
labels=labels or [],
|
||||
)
|
||||
|
||||
177
backend/onyx/db/projects.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserProject
|
||||
from onyx.server.documents.connector import upload_files
|
||||
from onyx.server.features.projects.projects_file_utils import categorize_uploaded_files
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class CategorizedFilesResult(BaseModel):
|
||||
user_files: list[UserFile]
|
||||
non_accepted_files: list[str]
|
||||
unsupported_files: list[str]
|
||||
# Allow SQLAlchemy ORM models inside this result container
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
def create_user_files(
    files: List[UploadFile],
    project_id: int | None,
    user: User | None,
    db_session: Session,
    link_url: str | None = None,
) -> CategorizedFilesResult:

    # Categorize the files
    categorized_files = categorize_uploaded_files(files)
    # NOTE: At the moment, zip metadata is not used for user files.
    # Should revisit to decide whether this should be a feature.
    upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE)
    user_files = []
    non_accepted_files = categorized_files.non_accepted
    unsupported_files = categorized_files.unsupported

    # Pair returned storage paths with the same set of acceptable files we uploaded
    for file_path, file in zip(
        upload_response.file_paths, categorized_files.acceptable
    ):
        new_id = uuid.uuid4()
        new_file = UserFile(
            id=new_id,
            user_id=user.id if user else None,
            file_id=file_path,
            document_id=str(new_id),
            name=file.filename,
            token_count=categorized_files.acceptable_file_to_token_count[
                file.filename or ""
            ],
            link_url=link_url,
            content_type=file.content_type,
            file_type=file.content_type,
            last_accessed_at=datetime.datetime.now(datetime.timezone.utc),
        )
        # Persist the UserFile first to satisfy FK constraints for association table
        db_session.add(new_file)
        db_session.flush()
        if project_id:
            project_to_user_file = Project__UserFile(
                project_id=project_id,
                user_file_id=new_file.id,
            )
            db_session.add(project_to_user_file)
        user_files.append(new_file)
    db_session.commit()
    return CategorizedFilesResult(
        user_files=user_files,
        non_accepted_files=non_accepted_files,
        unsupported_files=unsupported_files,
    )


def upload_files_to_user_files_with_indexing(
    files: List[UploadFile],
    project_id: int | None,
    user: User | None,
    db_session: Session,
) -> CategorizedFilesResult:
    # Validate project ownership if a project_id is provided
    if project_id is not None and user is not None:
        if not check_project_ownership(project_id, user.id, db_session):
            raise HTTPException(status_code=404, detail="Project not found")

    categorized_files_result = create_user_files(files, project_id, user, db_session)
    user_files = categorized_files_result.user_files
    non_accepted_files = categorized_files_result.non_accepted_files
    unsupported_files = categorized_files_result.unsupported_files

    # Trigger per-file processing immediately for the current tenant
    tenant_id = get_current_tenant_id()
    if non_accepted_files:
        for filename in non_accepted_files:
            logger.warning(f"Non-accepted file: {filename}")
    if unsupported_files:
        for filename in unsupported_files:
            logger.warning(f"Unsupported file: {filename}")
    for user_file in user_files:
        task = client_app.send_task(
            OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
            kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
            queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
            priority=OnyxCeleryPriority.HIGH,
        )
        logger.info(
            f"Triggered indexing for user_file_id={user_file.id} with task_id={task.id}"
        )

    return CategorizedFilesResult(
        user_files=user_files,
        non_accepted_files=non_accepted_files,
        unsupported_files=unsupported_files,
    )


def check_project_ownership(
    project_id: int, user_id: UUID | None, db_session: Session
) -> bool:
    return (
        db_session.query(UserProject)
        .filter(UserProject.id == project_id, UserProject.user_id == user_id)
        .first()
        is not None
    )


def get_user_files_from_project(
    project_id: int, user_id: UUID | None, db_session: Session
) -> list[UserFile]:
    # First check if the user owns the project
    if not check_project_ownership(project_id, user_id, db_session):
        return []

    return (
        db_session.query(UserFile)
        .join(Project__UserFile)
        .filter(Project__UserFile.project_id == project_id)
        .all()
    )


def get_project_instructions(db_session: Session, project_id: int | None) -> str | None:
    """Return the project's instruction text, or None if there is none.

    Safe helper that swallows DB errors and returns None on any failure.
    """
    if not project_id:
        return None
    try:
        project = (
            db_session.query(UserProject)
            .filter(UserProject.id == project_id)
            .one_or_none()
        )
        if not project or not project.instructions:
            return None
        instructions = project.instructions.strip()
        return instructions or None
    except Exception:
        return None
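
For context, here is a hedged sketch of how the new upload_files_to_user_files_with_indexing helper might be called from a FastAPI route. The route path, the get_session/current_user dependencies, and the response shape are assumptions for illustration only; they are not part of this diff.

# Hypothetical endpoint wiring for the helpers added in backend/onyx/db/projects.py.
# get_session and current_user are placeholder dependencies, not code from this PR.
from fastapi import APIRouter, Depends, UploadFile
from sqlalchemy.orm import Session

from onyx.db.projects import upload_files_to_user_files_with_indexing

router = APIRouter()


def get_session() -> Session:  # placeholder; the real app supplies a scoped Session
    raise NotImplementedError


def current_user():  # placeholder; the real app resolves the authenticated User
    raise NotImplementedError


@router.post("/projects/{project_id}/files")
def upload_project_files(
    project_id: int,
    files: list[UploadFile],
    user=Depends(current_user),
    db_session: Session = Depends(get_session),
) -> dict:
    result = upload_files_to_user_files_with_indexing(
        files=files, project_id=project_id, user=user, db_session=db_session
    )
    return {
        "user_file_ids": [str(user_file.id) for user_file in result.user_files],
        "non_accepted_files": result.non_accepted_files,
        "unsupported_files": result.unsupported_files,
    }
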
@@ -12,12 +12,14 @@ def insert_slack_bot(
    enabled: bool,
    bot_token: str,
    app_token: str,
    user_token: str | None = None,
) -> SlackBot:
    slack_bot = SlackBot(
        name=name,
        enabled=enabled,
        bot_token=bot_token,
        app_token=app_token,
        user_token=user_token,
    )
    db_session.add(slack_bot)
    db_session.commit()
@@ -32,6 +34,7 @@ def update_slack_bot(
    enabled: bool,
    bot_token: str,
    app_token: str,
    user_token: str | None = None,
) -> SlackBot:
    slack_bot = db_session.scalar(select(SlackBot).where(SlackBot.id == slack_bot_id))
    if slack_bot is None:
@@ -42,6 +45,7 @@ def update_slack_bot(
    slack_bot.enabled = enabled
    slack_bot.bot_token = bot_token
    slack_bot.app_token = app_token
    slack_bot.user_token = user_token

    db_session.commit()

@@ -74,15 +78,3 @@ def remove_slack_bot(

def fetch_slack_bots(db_session: Session) -> Sequence[SlackBot]:
    return db_session.scalars(select(SlackBot)).all()


def fetch_slack_bot_tokens(
    db_session: Session, slack_bot_id: int
) -> dict[str, str] | None:
    slack_bot = db_session.scalar(select(SlackBot).where(SlackBot.id == slack_bot_id))
    if not slack_bot:
        return None
    return {
        "app_token": slack_bot.app_token,
        "bot_token": slack_bot.bot_token,
    }

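
The insert_slack_bot/update_slack_bot hunks above add an optional user_token alongside the existing bot and app tokens. A usage sketch follows; the module path onyx.db.slack_bot, the leading db_session/name/slack_bot_id parameters (not visible in these hunks), and the token strings are all assumptions.

# Usage sketch; parameter names outside the visible hunks are assumed, and the
# token strings are placeholders for values from the Slack app configuration.
from sqlalchemy.orm import Session

from onyx.db.slack_bot import insert_slack_bot, update_slack_bot


def rotate_user_token(db_session: Session) -> None:
    bot = insert_slack_bot(
        db_session=db_session,
        name="support-bot",
        enabled=True,
        bot_token="xoxb-placeholder",
        app_token="xapp-placeholder",
        user_token="xoxp-placeholder",  # new optional argument; defaults to None
    )
    # Later, swap only the user-level token and keep everything else as-is.
    update_slack_bot(
        db_session=db_session,
        slack_bot_id=bot.id,
        name=bot.name,
        enabled=bot.enabled,
        bot_token=bot.bot_token,
        app_token=bot.app_token,
        user_token="xoxp-rotated-placeholder",
    )
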
@@ -2,7 +2,6 @@ from typing import Any
from typing import cast
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
from uuid import UUID

from sqlalchemy import select
@@ -10,21 +9,12 @@ from sqlalchemy.orm import Session

from onyx.db.models import Tool
from onyx.server.features.tool.models import Header
from onyx.tools.built_in_tools import BUILT_IN_TOOL_TYPES
from onyx.utils.headers import HeaderItemDict
from onyx.utils.logger import setup_logger

if TYPE_CHECKING:
    from onyx.tools.tool_implementations.images.image_generation_tool import (
        ImageGenerationTool,
    )
    from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
        KnowledgeGraphTool,
    )
    from onyx.tools.tool_implementations.search.search_tool import SearchTool
    from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
    from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
        OktaProfileTool,
    )
    pass

logger = setup_logger()

@@ -124,15 +114,7 @@ def delete_tool__no_commit(tool_id: int, db_session: Session) -> None:

def get_builtin_tool(
    db_session: Session,
    tool_type: Type[
        Union[
            "SearchTool",
            "ImageGenerationTool",
            "WebSearchTool",
            "KnowledgeGraphTool",
            "OktaProfileTool",
        ]
    ],
    tool_type: Type[BUILT_IN_TOOL_TYPES],
) -> Tool:
    """
    Retrieves a built-in tool from the database based on the tool type.

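
The get_builtin_tool hunk replaces the inline Type[Union[...]] annotation with a shared BUILT_IN_TOOL_TYPES alias imported from onyx.tools.built_in_tools. The actual alias is not shown in this diff; the sketch below is only one way such an alias could be written, trimmed to two tool classes for brevity.

# One way a shared alias like BUILT_IN_TOOL_TYPES could be written; the real
# definition lives in onyx/tools/built_in_tools.py and is not shown in this diff.
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    from onyx.tools.tool_implementations.images.image_generation_tool import (
        ImageGenerationTool,
    )
    from onyx.tools.tool_implementations.search.search_tool import SearchTool

# Forward references keep the module importable without the heavy tool imports.
BUILT_IN_TOOL_TYPES = Union["SearchTool", "ImageGenerationTool"]

With an alias like this in place, the signature collapses to tool_type: Type[BUILT_IN_TOOL_TYPES], as the hunk shows, and new built-in tools only need to be registered in one place.
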
backend/onyx/db/user_file.py (new file, 87 lines)
@@ -0,0 +1,87 @@
import datetime
from uuid import UUID

from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.db.models import UserFile


def fetch_chunk_counts_for_user_files(
    user_file_ids: list[str],
    db_session: Session,
) -> list[tuple[str, int]]:
    """
    Return a list of (user_file_id, chunk_count) tuples.
    If a user_file_id is not found in the database, it will be returned with a chunk_count of 0.
    """
    stmt = select(UserFile.id, UserFile.chunk_count).where(
        UserFile.id.in_(user_file_ids)
    )

    results = db_session.execute(stmt).all()

    # Create a dictionary of user_file_id to chunk_count
    chunk_counts = {str(row.id): row.chunk_count or 0 for row in results}

    # Return a list of tuples; user files that are missing or have an unknown
    # chunk count fall back to 0, matching the docstring above.
    return [
        (user_file_id, chunk_counts.get(user_file_id, 0))
        for user_file_id in user_file_ids
    ]


def calculate_user_files_token_count(file_ids: list[UUID], db_session: Session) -> int:
    """Calculate total token count for specified files"""
    total_tokens = 0

    # Get tokens from individual files
    if file_ids:
        file_tokens = (
            db_session.query(func.sum(UserFile.token_count))
            .filter(UserFile.id.in_(file_ids))
            .scalar()
            or 0
        )
        total_tokens += file_tokens

    return total_tokens


def fetch_user_project_ids_for_user_files(
    user_file_ids: list[str],
    db_session: Session,
) -> dict[str, list[int]]:
    """Fetch user project ids for specified user files"""
    stmt = select(UserFile).where(UserFile.id.in_(user_file_ids))
    results = db_session.execute(stmt).scalars().all()
    return {
        str(user_file.id): [project.id for project in user_file.projects]
        for user_file in results
    }


def update_last_accessed_at_for_user_files(
    user_file_ids: list[UUID],
    db_session: Session,
) -> None:
    """Update `last_accessed_at` to now (UTC) for the given user files."""
    if not user_file_ids:
        return
    now = datetime.datetime.now(datetime.timezone.utc)
    (
        db_session.query(UserFile)
        .filter(UserFile.id.in_(user_file_ids))
        .update({UserFile.last_accessed_at: now}, synchronize_session=False)
    )
    db_session.commit()


def get_file_id_by_user_file_id(user_file_id: str, db_session: Session) -> str | None:
    user_file = db_session.query(UserFile).filter(UserFile.id == user_file_id).first()
    if user_file:
        return user_file.file_id
    return None
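
A short usage sketch for the new user_file helpers; db_session is assumed to be an open SQLAlchemy Session and the ids are placeholders.

# Sketch usage of the new user_file helpers; nothing here is from the PR itself.
import uuid

from sqlalchemy.orm import Session

from onyx.db.user_file import (
    fetch_chunk_counts_for_user_files,
    update_last_accessed_at_for_user_files,
)


def touch_and_report(db_session: Session, user_file_ids: list[uuid.UUID]) -> None:
    ids_as_str = [str(file_id) for file_id in user_file_ids]

    # Ids that are missing from the table come back with a chunk_count of 0.
    for user_file_id, chunk_count in fetch_chunk_counts_for_user_files(
        ids_as_str, db_session
    ):
        print(f"{user_file_id}: {chunk_count} chunks")

    # Record the access time for everything we just looked at.
    update_last_accessed_at_for_user_files(user_file_ids, db_session)
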
@@ -116,8 +116,7 @@ class VespaDocumentUserFields:
    Fields that are specific to the user who is indexing the document.
    """

    user_file_id: str | None = None
    user_folder_id: str | None = None
    user_projects: list[int] | None = None


@dataclass

@@ -176,6 +176,11 @@ schema {{ schema_name }} {
        rank: filter
        attribute: fast-search
    }
    field user_project type array<int> {
        indexing: summary | attribute
        rank: filter
        attribute: fast-search
    }
}

# If using different tokenization settings, the fieldset has to be removed, and the field must

@@ -68,8 +68,7 @@ from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import DOCUMENT_SETS
from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import USER_FILE
from onyx.document_index.vespa_constants import USER_FOLDER
from onyx.document_index.vespa_constants import USER_PROJECT
from onyx.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
from onyx.document_index.vespa_constants import VESPA_TIMEOUT
from onyx.document_index.vespa_constants import YQL_BASE
@@ -766,12 +765,9 @@ class VespaIndex(DocumentIndex):
            update_dict["fields"][HIDDEN] = {"assign": fields.hidden}

        if user_fields is not None:
            if user_fields.user_file_id is not None:
                update_dict["fields"][USER_FILE] = {"assign": user_fields.user_file_id}

            if user_fields.user_folder_id is not None:
                update_dict["fields"][USER_FOLDER] = {
                    "assign": user_fields.user_folder_id
            if user_fields.user_projects is not None:
                update_dict["fields"][USER_PROJECT] = {
                    "assign": user_fields.user_projects
                }

        if not update_dict["fields"]:

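
The VespaIndex hunk above switches the partial-update path from the per-file/per-folder fields to a single user_project assignment. Roughly, the update body sent to Vespa's Document V1 API would look like the sketch below; the project ids and the document id in the URL comment are placeholders, not values from this PR.

# Roughly the partial-update body the hunk above produces when only project
# membership changes; the values below are placeholders.
import json

user_projects = [12, 34]

update_dict: dict = {"fields": {}}
if user_projects is not None:
    update_dict["fields"]["user_project"] = {"assign": user_projects}

# Sent as a PUT to Vespa's Document V1 API, e.g.
# {endpoint}/document/v1/{namespace}/{doctype}/docid/{document-id}
print(json.dumps(update_dict, indent=2))
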
@@ -51,8 +51,7 @@ from onyx.document_index.vespa_constants import SOURCE_TYPE
from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import TITLE
from onyx.document_index.vespa_constants import TITLE_EMBEDDING
from onyx.document_index.vespa_constants import USER_FILE
from onyx.document_index.vespa_constants import USER_FOLDER
from onyx.document_index.vespa_constants import USER_PROJECT
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.utils.logger import setup_logger

@@ -208,8 +207,7 @@ def _index_vespa_chunk(
        DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
        # still called `image_file_name` in Vespa for backwards compatibility
        IMAGE_FILE_NAME: chunk.image_file_id,
        USER_FILE: chunk.user_file if chunk.user_file is not None else None,
        USER_FOLDER: chunk.user_folder if chunk.user_folder is not None else None,
        USER_PROJECT: chunk.user_project if chunk.user_project is not None else [],
        BOOST: chunk.boost,
        AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
    }

@@ -14,8 +14,7 @@ from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import SOURCE_TYPE
from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import USER_FILE
from onyx.document_index.vespa_constants import USER_FOLDER
from onyx.document_index.vespa_constants import USER_PROJECT
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -138,6 +137,18 @@ def build_vespa_filters(
        return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
        return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "

    def _build_user_project_filter(
        project_id: int | None,
    ) -> str:
        if project_id is None:
            return ""
        try:
            pid = int(project_id)
        except Exception:
            return ""
        # Vespa YQL 'contains' expects a string literal; quote the integer
        return f'({USER_PROJECT} contains "{pid}") and '

    # Start building the filter string
    filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""

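
To make the new project filter concrete, here is _build_user_project_filter reproduced standalone (with the USER_PROJECT constant inlined) together with assertions on the YQL fragment it produces.

# The helper from the hunk above, reproduced standalone so its output is visible.
USER_PROJECT = "user_project"


def _build_user_project_filter(project_id: int | None) -> str:
    if project_id is None:
        return ""
    try:
        pid = int(project_id)
    except Exception:
        return ""
    # Vespa YQL 'contains' expects a string literal; quote the integer
    return f'({USER_PROJECT} contains "{pid}") and '


assert _build_user_project_filter(None) == ""
assert _build_user_project_filter(42) == '(user_project contains "42") and '
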
@@ -172,10 +183,14 @@ def build_vespa_filters(
    # Document sets
    filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)

    # New: user_file_ids as integer filters
    filter_str += _build_int_or_filters(USER_FILE, filters.user_file_ids)
    # Convert UUIDs to strings for user_file_ids
    user_file_ids_str = (
        [str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
    )
    filter_str += _build_or_filters(DOCUMENT_ID, user_file_ids_str)

    filter_str += _build_int_or_filters(USER_FOLDER, filters.user_folder_ids)
    # User project filter (array<int> attribute membership)
    filter_str += _build_user_project_filter(filters.project_id)

    # Time filter
    filter_str += _build_time_filter(filters.time_cutoff)

@@ -55,6 +55,7 @@ ACCESS_CONTROL_LIST = "access_control_list"
DOCUMENT_SETS = "document_sets"
USER_FILE = "user_file"
USER_FOLDER = "user_folder"
USER_PROJECT = "user_project"
LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
METADATA = "metadata"
METADATA_LIST = "metadata_list"

Some files were not shown because too many files have changed in this diff.