Compare commits

...

106 Commits

Author SHA1 Message Date
Raunak Bhagat
68e19383b9 Edit props passed to ActionToggle 2025-09-28 12:04:53 -07:00
Raunak Bhagat
f01f3d6e77 Add back assistants selector 2025-09-28 12:00:25 -07:00
Raunak Bhagat
42b57bf804 Add back FilePicker and DeepResearchToggle 2025-09-28 09:50:21 -07:00
Raunak Bhagat
d4b696d17d Clean up some more code 2025-09-27 23:34:13 -07:00
Raunak Bhagat
84b774f376 Edit ChatInputBar and fix build errors 2025-09-27 23:27:47 -07:00
Raunak Bhagat
921f90d013 Fix all build errors 2025-09-27 21:50:53 -07:00
Raunak Bhagat
f6c09572eb Finish more integrations of Projects into Refresh 2025-09-27 21:34:11 -07:00
Raunak Bhagat
b2f90db0b8 Merge in other files from main 2025-09-26 17:10:55 -07:00
Raunak Bhagat
ee7b33b382 Merge main in 2025-09-26 17:09:50 -07:00
Raunak Bhagat
db37c35030 Update imports 2025-09-26 12:11:08 -07:00
Raunak Bhagat
4f105d002c Finish refreshing HumanMessage 2025-09-26 11:01:42 -07:00
Raunak Bhagat
7e9095b976 Whole bunch of refactors 2025-09-25 22:58:33 -07:00
Raunak Bhagat
bce6741ee6 Fix build errors 2025-09-25 22:27:30 -07:00
Raunak Bhagat
b5a188ee5e Add new utility components 2025-09-25 22:22:34 -07:00
Raunak Bhagat
8578d89a24 Update MessageSwitcher 2025-09-25 22:20:27 -07:00
Raunak Bhagat
e42840775a Add tooltip to IconButton 2025-09-25 22:18:51 -07:00
Raunak Bhagat
c907971d5b Saving changes 2025-09-25 22:17:56 -07:00
Raunak Bhagat
44a565bbfc Add new icons 2025-09-25 22:13:18 -07:00
Raunak Bhagat
ea4219806d Edit main page 2025-09-25 15:08:49 -07:00
Raunak Bhagat
ccc76413c6 Fix build errors 2025-09-25 13:01:41 -07:00
Raunak Bhagat
12a2786ff6 Fix build errors 2025-09-25 12:53:12 -07:00
Raunak Bhagat
882c294e74 saving changes 2025-09-25 10:57:55 -07:00
Raunak Bhagat
0e2b6cf193 Remove more unused files 2025-09-24 19:54:52 -07:00
Raunak Bhagat
c7ae8bd783 Remove unused files 2025-09-24 19:51:19 -07:00
Raunak Bhagat
3b10dd4b22 Edit LLMPopover + dependent components 2025-09-24 19:33:00 -07:00
Raunak Bhagat
6ea886fc85 Add new SelectButton 2025-09-24 17:59:40 -07:00
Raunak Bhagat
f8c89fc750 Fix up ChatInputBar + lots of cleanup 2025-09-24 17:05:02 -07:00
Raunak Bhagat
cace80ffaa Fix sidebar folding 2025-09-24 13:57:18 -07:00
Raunak Bhagat
e628033885 Reattach handler 2025-09-24 13:37:11 -07:00
Raunak Bhagat
22bb4b6d98 Add emphasis 2025-09-24 13:33:02 -07:00
Raunak Bhagat
9afffc2de4 Implement proper grouping hierarchies for buttons 2025-09-24 13:28:26 -07:00
Raunak Bhagat
2c1193f975 Fix button variants + subvariants 2025-09-24 13:00:00 -07:00
Raunak Bhagat
b192542c85 Fix bug in which folded state would stil render text 2025-09-24 12:18:38 -07:00
Raunak Bhagat
d8821b8ccc Prevent click propagation 2025-09-24 12:14:58 -07:00
Raunak Bhagat
a007369bd5 Edit how renaming UI is rendered (using input instead of textarea) 2025-09-24 12:03:19 -07:00
Raunak Bhagat
2f65629f51 More edits to AppSidebar 2025-09-24 11:48:04 -07:00
Raunak Bhagat
7701ae2112 Re-implement buttons + clean up look 2025-09-24 11:15:38 -07:00
Raunak Bhagat
01a3a256e9 Update README 2025-09-23 15:57:28 -07:00
Raunak Bhagat
0d55febaa7 Remove strokeOpacity from icons 2025-09-23 15:56:16 -07:00
Raunak Bhagat
bdafbfe0e8 Edit AdminSidebar width 2025-09-22 11:22:25 -07:00
Raunak Bhagat
278fd0e153 Fix error in which message would not be updated after edit 2025-09-22 11:04:15 -07:00
Raunak Bhagat
a4bb97bc22 Fix editing modal 2025-09-22 10:53:55 -07:00
Raunak Bhagat
8063d9a75e Edit KG configuration page 2025-09-22 05:04:42 -07:00
Raunak Bhagat
1ffaba12f0 Fix height of search bar 2025-09-22 04:45:41 -07:00
Raunak Bhagat
26f8660663 Fix search bar 2025-09-22 04:40:12 -07:00
Raunak Bhagat
d6504ed578 Update search-settings page 2025-09-22 03:59:35 -07:00
Raunak Bhagat
7fcc2c9d35 Clean up more admin stuff 2025-09-22 03:21:48 -07:00
Raunak Bhagat
46e8f925fe Clean up AdminSidebar 2025-09-21 22:33:26 -07:00
Raunak Bhagat
5ec1f61839 Add user settings 2025-09-19 20:07:11 -07:00
Raunak Bhagat
df950963a7 Edit SvgMoreHorizontal SVG size 2025-09-19 19:47:06 -07:00
Raunak Bhagat
93208a66ac Edit settings popup state transitions 2025-09-19 19:30:39 -07:00
Raunak Bhagat
a4819e07e7 Small bug fixes 2025-09-19 19:19:37 -07:00
Raunak Bhagat
f642ace40c Implement logout 2025-09-19 19:16:24 -07:00
Raunak Bhagat
9b430ae2d5 Implement notifications 2025-09-19 19:03:48 -07:00
Raunak Bhagat
05f3f878b2 Edit edit/delete modals 2025-09-19 17:54:24 -07:00
Raunak Bhagat
df17c5352e Edit active colours 2025-09-19 16:46:08 -07:00
Raunak Bhagat
bcfb0f3cf3 Remove commented out state 2025-09-19 16:04:13 -07:00
Raunak Bhagat
38468c1dc4 Fix AgentsModal 2025-09-19 15:55:40 -07:00
Raunak Bhagat
8550a9c5e3 Cleanup sidebar a bit more 2025-09-19 14:33:36 -07:00
Raunak Bhagat
fe0c60e50d Fix UX around naming chats 2025-09-19 14:03:02 -07:00
Raunak Bhagat
4ecc151a02 Fix up chat-renaming 2025-09-19 13:49:08 -07:00
Raunak Bhagat
d08becead5 Saving changes 2025-09-19 13:34:54 -07:00
Raunak Bhagat
a429f852d5 Reduce height of buttons 2025-09-19 08:05:50 -07:00
Raunak Bhagat
a856f27fae Saving changes 2025-09-18 20:19:25 -07:00
Raunak Bhagat
d0d8027928 Edit popups in sidebar buttons 2025-09-18 19:53:33 -07:00
Raunak Bhagat
bd1671f1a1 Edit popovers and add new icons 2025-09-18 19:05:16 -07:00
Raunak Bhagat
e236c67678 Fix build errors 2025-09-18 17:28:19 -07:00
Raunak Bhagat
683956697a More UI fixes and tweaks 2025-09-18 16:57:47 -07:00
Raunak Bhagat
fb1e303ffc Fix ordering bug 2025-09-18 16:09:25 -07:00
Raunak Bhagat
729d4fafd1 Remove client directive 2025-09-18 15:52:16 -07:00
Raunak Bhagat
40c60282d0 Update agents modal and general structure of app 2025-09-18 15:19:35 -07:00
Raunak Bhagat
2141fd2c6e More edits to styling + colours 2025-09-18 12:45:50 -07:00
Raunak Bhagat
9aeba96043 Update state management 2025-09-16 19:34:48 -07:00
Raunak Bhagat
b431de5141 Update hover state for buttons 2025-09-16 19:06:05 -07:00
Raunak Bhagat
d1a6340cfc Add new chat handler 2025-09-16 17:45:21 -07:00
Raunak Bhagat
ccf382ef4f Edit spacing 2025-09-16 17:41:56 -07:00
Raunak Bhagat
c31997b9b2 Save folded state to localStorage 2025-09-16 17:40:11 -07:00
Raunak Bhagat
ab31795a46 Recenter icon when title is hidden 2025-09-16 17:35:40 -07:00
Raunak Bhagat
b3beca63dc Make headers sticky 2025-09-16 17:33:21 -07:00
Raunak Bhagat
cc6d54c1e6 Add loading state for Truncated component + fix spacings 2025-09-16 17:28:54 -07:00
Raunak Bhagat
ee12c0c5de Fix scrolling issue 2025-09-16 17:00:40 -07:00
Raunak Bhagat
d48912a05d Fix errors 2025-09-16 15:50:48 -07:00
Raunak Bhagat
c079072676 Remove unnecessary file + make HistorySidebar be smart 2025-09-16 15:48:15 -07:00
Raunak Bhagat
952f6bfb37 Delete unused files 2025-09-16 15:42:36 -07:00
Raunak Bhagat
0714e4bb4e Fix dnd 2025-09-16 15:39:47 -07:00
Raunak Bhagat
ae577f0f44 Add AgentsModal 2025-09-16 15:35:59 -07:00
Raunak Bhagat
0705d584d8 Update user hover-card 2025-09-16 14:52:54 -07:00
Raunak Bhagat
36e391e557 Add folded sidebar (+ shortcuts) 2025-09-16 13:50:49 -07:00
Raunak Bhagat
1efce594b5 Clean up truncation + buttons 2025-09-16 13:02:16 -07:00
Raunak Bhagat
67ac53f17d Add more styling for HistorySidebar + add README for working w/ icons 2025-09-16 11:21:13 -07:00
Raunak Bhagat
d5a222925a Add icons (as raw TSX) 2025-09-16 09:51:25 -07:00
Raunak Bhagat
d5ef928782 Add icons 2025-09-16 09:12:12 -07:00
Raunak Bhagat
6963d78f8e Fix more build errors? 2025-09-15 17:39:24 -07:00
Raunak Bhagat
d3ef2b8c17 Fix build errors 2025-09-15 17:28:34 -07:00
Raunak Bhagat
70f4162ea8 Update name 2025-09-15 17:17:09 -07:00
Raunak Bhagat
883f52d332 Update component names 2025-09-15 17:09:13 -07:00
Raunak Bhagat
f8fd83c883 Clean up sidebar 2025-09-15 15:56:45 -07:00
Raunak Bhagat
d2bf0c0c5f Update token-context bar 2025-09-15 11:43:54 -07:00
Raunak Bhagat
5d598c2d22 Add more colour fixes to Modal 2025-09-15 11:13:45 -07:00
Raunak Bhagat
9dc0e97302 Merge branch 'main' into colours 2025-09-15 09:45:32 -07:00
Raunak Bhagat
048b2a6b39 Edit LLMPopover and add border-radii 2025-09-15 09:43:06 -07:00
Raunak Bhagat
7dd3cecf67 Edit UserDropdown colours 2025-09-15 09:15:56 -07:00
Raunak Bhagat
82abe28986 Update more colours 2025-09-14 20:37:33 -07:00
Raunak Bhagat
a0575e6a00 Update colours for sidebar 2025-09-14 20:24:25 -07:00
Raunak Bhagat
0c5bf5b3ed Add all colours from Figma 2025-09-11 13:34:54 -07:00
Raunak Bhagat
492117d910 Edit .gitignore 2025-09-11 12:32:53 -07:00
468 changed files with 25785 additions and 8710 deletions

View File

@@ -6,9 +6,6 @@
[Describe the tests you ran to verify your changes]
## Backporting (check the box to trigger backport action)
## Additional Options
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
- [ ] [Optional] Override Linear Check

View File

@@ -0,0 +1,24 @@
name: Check Lazy Imports
on:
merge_group:
pull_request:
branches:
- main
- 'release/**'
jobs:
check-lazy-imports:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Check lazy imports
run: python3 backend/scripts/check_lazy_imports.py
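
The step above runs `backend/scripts/check_lazy_imports.py`, but the script itself is not part of this diff. As a rough, hypothetical sketch only (the module list, file layout, and output format below are assumptions, not the real script), a checker of this kind could walk the backend tree and fail when a module that is meant to be imported lazily appears as a top-level import:

```python
# Hypothetical sketch of a lazy-import checker; the real
# backend/scripts/check_lazy_imports.py is not shown in this diff.
import ast
import pathlib
import sys

# Assumption: placeholder module names, not the project's actual lazy list.
LAZY_MODULES = {"vertexai", "transformers"}


def top_level_imports(tree: ast.Module) -> set[str]:
    """Collect module names imported directly at the top level of a file."""
    found: set[str] = set()
    for node in tree.body:  # only top-level statements, not nested defs
        if isinstance(node, ast.Import):
            found.update(alias.name.split(".")[0] for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and node.module:
            found.add(node.module.split(".")[0])
    return found


def main() -> int:
    violations = []
    for path in pathlib.Path("backend").rglob("*.py"):
        tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
        bad = top_level_imports(tree) & LAZY_MODULES
        if bad:
            violations.append(f"{path}: {', '.join(sorted(bad))}")
    if violations:
        print("Direct imports of lazy modules found:")
        print("\n".join(violations))
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())
```

The pre-commit hook added later in this diff wires the same script in locally, so CI and local checks would stay in agreement.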

View File

@@ -27,6 +27,7 @@ jobs:
run: |
helm repo add bitnami https://charts.bitnami.com/bitnami
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add keda https://kedacore.github.io/charts
helm repo update
- name: Build chart dependencies
@@ -46,4 +47,4 @@ jobs:
charts_dir: deployment/helm/charts
branch: gh-pages
commit_username: ${{ github.actor }}
commit_email: ${{ github.actor }}@users.noreply.github.com
commit_email: ${{ github.actor }}@users.noreply.github.com

View File

@@ -23,6 +23,7 @@ env:
# LLMs
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
jobs:
discover-test-dirs:
@@ -42,8 +43,8 @@ jobs:
external-dependency-unit-tests:
needs: discover-test-dirs
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
# Use larger runner with more resources for Vespa
runs-on: [runs-on, runner=16cpu-linux-x64, "run-id=${{ github.run_id }}"]
strategy:
fail-fast: false
@@ -52,6 +53,7 @@ jobs:
env:
PYTHONPATH: ./backend
MODEL_SERVER_HOST: "disabled"
steps:
- name: Checkout code
@@ -77,19 +79,30 @@ jobs:
- name: Set up Standard Dependencies
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack up -d minio relational_db cache index
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d minio relational_db cache index
- name: Wait for services
run: |
echo "Waiting for services to be ready..."
sleep 30
# Wait for Vespa specifically
echo "Waiting for Vespa to be ready..."
timeout 300 bash -c 'until curl -f -s http://localhost:8081/ApplicationStatus > /dev/null 2>&1; do echo "Vespa not ready, waiting..."; sleep 10; done' || echo "Vespa timeout - continuing anyway"
echo "Services should be ready now"
- name: Run migrations
run: |
cd backend
# Run migrations to head
alembic upgrade head
alembic heads --verbose
- name: Run Tests for ${{ matrix.test-dir }}
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test \
-n 8 \
--dist loadfile \
--durations=8 \
-o junit_family=xunit2 \
-xv \

View File

@@ -169,6 +169,7 @@ jobs:
--set=celery_worker_light.replicaCount=0 \
--set=celery_worker_monitoring.replicaCount=0 \
--set=celery_worker_primary.replicaCount=0 \
--set=celery_worker_user_file_processing.replicaCount=0 \
--set=celery_worker_user_files_indexing.replicaCount=0" \
--helm-extra-args="--timeout 900s --debug" \
--debug --config ct.yaml

View File

@@ -130,6 +130,7 @@ jobs:
platforms: linux/arm64
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
push: true
outputs: type=registry
build-model-server-image:
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
@@ -189,6 +190,7 @@ jobs:
platforms: linux/arm64
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
push: true
outputs: type=registry
integration-tests:
needs:
@@ -230,9 +232,9 @@ jobs:
# Pull all images from registry in parallel
echo "Pulling Docker images in parallel..."
# Pull images from private registry
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
# Wait for all background jobs to complete
wait
@@ -257,7 +259,7 @@ jobs:
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
docker compose -f docker-compose.dev.yml -p onyx-stack up \
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
relational_db \
index \
cache \
@@ -273,7 +275,7 @@ jobs:
run: |
echo "Starting wait-for-service script..."
docker logs -f onyx-stack-api_server-1 &
docker logs -f onyx-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
@@ -317,7 +319,7 @@ jobs:
retry_wait_seconds: 10
command: |
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
docker run --rm --network onyx-stack_default \
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
@@ -354,13 +356,13 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
@@ -374,7 +376,7 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
docker compose down -v
multitenant-tests:
@@ -405,9 +407,9 @@ jobs:
- name: Pull Docker images
run: |
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
wait
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }} onyxdotapp/onyx-backend:test
docker tag ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }} onyxdotapp/onyx-model-server:test
@@ -423,7 +425,7 @@ jobs:
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
DEV_MODE=true \
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up \
docker compose -f docker-compose.multitenant-dev.yml up \
relational_db \
index \
cache \
@@ -438,7 +440,7 @@ jobs:
- name: Wait for service to be ready (multi-tenant)
run: |
echo "Starting wait-for-service script for multi-tenant..."
docker logs -f onyx-stack-api_server-1 &
docker logs -f onyx-api_server-1 &
start_time=$(date +%s)
timeout=300
while true; do
@@ -464,7 +466,7 @@ jobs:
- name: Run Multi-Tenant Integration Tests
run: |
echo "Running multi-tenant integration tests..."
docker run --rm --network onyx-stack_default \
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
@@ -493,13 +495,13 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server_multitenant.log || true
docker compose -f docker-compose.multitenant-dev.yml logs --no-color api_server > $GITHUB_WORKSPACE/api_server_multitenant.log || true
- name: Dump all-container logs (multi-tenant)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose-multitenant.log || true
docker compose -f docker-compose.multitenant-dev.yml logs --no-color > $GITHUB_WORKSPACE/docker-compose-multitenant.log || true
- name: Upload logs (multi-tenant)
if: always()
@@ -512,7 +514,7 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
docker compose -f docker-compose.multitenant-dev.yml down -v
required:
runs-on: blacksmith-2vcpu-ubuntu-2404-arm

View File

@@ -127,6 +127,7 @@ jobs:
platforms: linux/arm64
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
push: true
outputs: type=registry
build-model-server-image:
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
@@ -186,6 +187,7 @@ jobs:
platforms: linux/arm64
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
push: true
outputs: type=registry
integration-tests-mit:
needs:
@@ -228,9 +230,9 @@ jobs:
# Pull all images from registry in parallel
echo "Pulling Docker images in parallel..."
# Pull images from private registry
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
(docker pull ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-model-server:test-${{ github.run_id }}) &
(docker pull --platform linux/arm64 ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}) &
# Wait for all background jobs to complete
wait
@@ -253,7 +255,7 @@ jobs:
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.dev.yml -p onyx-stack up \
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
relational_db \
index \
cache \
@@ -269,7 +271,7 @@ jobs:
run: |
echo "Starting wait-for-service script..."
docker logs -f onyx-stack-api_server-1 &
docker logs -f onyx-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
@@ -314,7 +316,7 @@ jobs:
retry_wait_seconds: 10
command: |
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
docker run --rm --network onyx-stack_default \
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
@@ -351,13 +353,13 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
@@ -371,7 +373,7 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
docker compose down -v
required:

View File

@@ -189,14 +189,14 @@ jobs:
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f danswer-stack-api_server-1 &
docker logs -f onyx-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
@@ -244,7 +244,7 @@ jobs:
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
docker compose logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Upload logs
@@ -257,7 +257,7 @@ jobs:
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
docker compose down -v
# NOTE: Chromatic UI diff testing is currently disabled.
# We are using Playwright for local and CI testing without visual regression checks.

View File

@@ -96,6 +96,13 @@ env:
TEAMS_DIRECTORY_ID: ${{ secrets.TEAMS_DIRECTORY_ID }}
TEAMS_SECRET: ${{ secrets.TEAMS_SECRET }}
# Bitbucket
BITBUCKET_WORKSPACE: ${{ secrets.BITBUCKET_WORKSPACE }}
BITBUCKET_REPOSITORIES: ${{ secrets.BITBUCKET_REPOSITORIES }}
BITBUCKET_PROJECTS: ${{ secrets.BITBUCKET_PROJECTS }}
BITBUCKET_EMAIL: ${{ secrets.BITBUCKET_EMAIL }}
BITBUCKET_API_TOKEN: ${{ secrets.BITBUCKET_API_TOKEN }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/

View File

@@ -77,7 +77,7 @@ jobs:
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.model-server-test.yml -p onyx-stack up -d indexing_model_server
docker compose -f docker-compose.model-server-test.yml up -d indexing_model_server
id: start_docker
- name: Wait for service to be ready
@@ -132,7 +132,7 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
docker compose -f docker-compose.model-server-test.yml logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
@@ -145,5 +145,5 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack down -v
docker compose -f docker-compose.model-server-test.yml down -v

View File

@@ -31,12 +31,14 @@ jobs:
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"

.gitignore (vendored, 2 changed lines)
View File

@@ -17,6 +17,7 @@ backend/tests/regression/answer_quality/test_data.json
backend/tests/regression/search_quality/eval-*
backend/tests/regression/search_quality/search_eval_config.yaml
backend/tests/regression/search_quality/*.json
backend/onyx/evals/data/
*.log
# secret files
@@ -28,6 +29,7 @@ settings.json
/deployment/data/nginx/app.conf
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
*.egg-info
# Local .terraform directories
**/.terraform/*

.mcp.json.template (new file, 8 lines added)
View File

@@ -0,0 +1,8 @@
{
"mcpServers": {
"onyx-mcp": {
"type": "http",
"url": "http://localhost:8000/mcp"
}
}
}
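
As a quick sanity check, the template above can be parsed with nothing but the standard library. This sketch only reads the JSON shown here and makes no assumptions about how an MCP client actually connects to the endpoint:

```python
# Minimal sketch: read .mcp.json.template and list the configured MCP servers.
import json
from pathlib import Path

config = json.loads(Path(".mcp.json.template").read_text())
for name, server in config["mcpServers"].items():
    print(f"{name}: {server['type']} -> {server['url']}")
    # Expected output for the template above:
    # onyx-mcp: http -> http://localhost:8000/mcp
```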

View File

@@ -37,6 +37,15 @@ repos:
additional_dependencies:
- prettier
- repo: local
hooks:
- id: check-lazy-imports
name: Check lazy imports are not directly imported
entry: python3 backend/scripts/check_lazy_imports.py
language: system
files: ^backend/.*\.py$
pass_filenames: false
# We would like to have a mypy pre-commit hook, but due to the fact that
# pre-commit runs in it's own isolated environment, we would need to install
# and keep in sync all dependencies so mypy has access to the appropriate type

View File

@@ -10,7 +10,7 @@ SKIP_WARM_UP=True
# Always keep these on for Dev
# Logs all model prompts to stdout
LOG_DANSWER_MODEL_INTERACTIONS=True
LOG_ONYX_MODEL_INTERACTIONS=True
# More verbose logging
LOG_LEVEL=debug
@@ -39,8 +39,8 @@ FAST_GEN_AI_MODEL_VERSION=gpt-4o
# For Danswer Slack Bot, overrides the UI values so no need to set this up via UI every time
# Only needed if using DanswerBot
#DANSWER_BOT_SLACK_APP_TOKEN=<REPLACE THIS>
#DANSWER_BOT_SLACK_BOT_TOKEN=<REPLACE THIS>
#ONYX_BOT_SLACK_APP_TOKEN=<REPLACE THIS>
#ONYX_BOT_SLACK_BOT_TOKEN=<REPLACE THIS>
# Python stuff

File diff suppressed because it is too large.

View File

@@ -4,14 +4,14 @@ This file provides guidance to Codex when working with code in this repository.
## KEY NOTES
- If you run into any missing python dependency errors, try running your command with `workon onyx &&` in front
- If you run into any missing python dependency errors, try running your command with `source backend/.venv/bin/activate` \
to assume the python venv.
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
`a`. The app can be accessed at `http://localhost:3000`.
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
make sure we see logs coming out from the relevant service.
- To connect to the Postgres database, use: `docker exec -it onyx-stack-relational_db-1 psql -U postgres -c "<SQL>"`
- To connect to the Postgres database, use: `docker exec -it onyx-relational_db-1 psql -U postgres -c "<SQL>"`
- When making calls to the backend, always go through the frontend. E.g. make a call to `http://localhost:3000/api/persona` not `http://localhost:8080/api/persona`
- Put ALL db operations under the `backend/onyx/db` / `backend/ee/onyx/db` directories. Don't run queries
outside of those directories.
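
Following the note above about always routing API calls through the frontend, a minimal sketch (assuming the dev stack is running at localhost:3000 and ignoring authentication, which a real session would need, e.g. the `a@test.com` login mentioned above) could look like this:

```python
# Minimal sketch: call the backend via the Next.js frontend proxy (port 3000),
# not the API server on port 8080, per the guidance above.
import json
import urllib.request

with urllib.request.urlopen("http://localhost:3000/api/persona") as resp:
    personas = json.load(resp)  # assumes the endpoint returns a JSON list

print(f"Fetched {len(personas)} personas")
```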

View File

@@ -4,14 +4,14 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
## KEY NOTES
- If you run into any missing python dependency errors, try running your command with `workon onyx &&` in front
- If you run into any missing python dependency errors, try running your command with `source backend/.venv/bin/activate` \
to assume the python venv.
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
`a`. The app can be accessed at `http://localhost:3000`.
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
make sure we see logs coming out from the relevant service.
- To connect to the Postgres database, use: `docker exec -it onyx-stack-relational_db-1 psql -U postgres -c "<SQL>"`
- To connect to the Postgres database, use: `docker exec -it onyx-relational_db-1 psql -U postgres -c "<SQL>"`
- When making calls to the backend, always go through the frontend. E.g. make a call to `http://localhost:3000/api/persona` not `http://localhost:8080/api/persona`
- Put ALL db operations under the `backend/onyx/db` / `backend/ee/onyx/db` directories. Don't run queries
outside of those directories.

View File

@@ -84,10 +84,6 @@ python -m venv .venv
source .venv/bin/activate
```
> **Note:**
> This virtual environment MUST NOT be set up WITHIN the onyx directory if you plan on using mypy within certain IDEs.
> For simplicity, we recommend setting up the virtual environment outside of the onyx directory.
_For Windows, activate the virtual environment using Command Prompt:_
```bash
@@ -175,7 +171,7 @@ You will need Docker installed to run these containers.
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis/MinIO with:
```bash
docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db cache minio
docker compose up -d index relational_db cache minio
```
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
@@ -257,7 +253,7 @@ You can run the full Onyx application stack from pre-built images including all
Navigate to `onyx/deployment/docker_compose` and run:
```bash
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
docker compose up -d
```
After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Onyx.
@@ -265,7 +261,7 @@ After Docker pulls and starts these containers, navigate to `http://localhost:30
If you want to make changes to Onyx and run those changes in Docker, you can also build a local version of the Onyx container images that incorporates your changes like so:
```bash
docker compose -f docker-compose.dev.yml -p onyx-stack up -d --build
docker compose up -d --build
```

README.md (134 changed lines)
View File

@@ -1,117 +1,103 @@
<!-- ONYX_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/README.md"} -->
<a name="readme-top"></a>
<h2 align="center">
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/OnyxLogoCropped.jpg?raw=true)" /></a>
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/OnyxLogoCropped.jpg?raw=true)" /></a>
</h2>
<p align="center">
<p align="center">Open Source Gen-AI + Enterprise Search.</p>
<p align="center">Open Source AI Platform</p>
<p align="center">
<a href="https://docs.onyx.app/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
<img src="https://img.shields.io/badge/discord-join-blue.svg?logo=discord&logoColor=white" alt="Discord">
</a>
<a href="https://github.com/onyx-dot-app/onyx/blob/main/README.md" target="_blank">
<img src="https://img.shields.io/static/v1?label=license&message=MIT&color=blue" alt="License">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
<img src="https://img.shields.io/badge/discord-join-blue.svg?logo=discord&logoColor=white" alt="Discord">
</a>
<a href="https://docs.onyx.app/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://docs.onyx.app/" target="_blank">
<img src="https://img.shields.io/website?url=https://www.onyx.app&up_message=visit&up_color=blue" alt="Documentation">
</a>
<a href="https://github.com/onyx-dot-app/onyx/blob/main/LICENSE" target="_blank">
<img src="https://img.shields.io/static/v1?label=license&message=MIT&color=blue" alt="License">
</a>
</p>
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI platform connected to your company's docs, apps, and people.
Onyx provides a feature rich Chat interface and plugs into any LLM of your choice.
Keep knowledge and access controls sync-ed across over 40 connectors like Google Drive, Slack, Confluence, Salesforce, etc.
Create custom AI agents with unique prompts, knowledge, and actions that the agents can take.
Onyx can be deployed securely anywhere and for any scale - on a laptop, on-premise, or to cloud.
<h3>Feature Highlights</h3>
**[Onyx](https://www.onyx.app/)** is a feature-rich, self-hostable Chat UI that works with any LLM. It is easy to deploy and can run in a completely airgapped environment.
**Deep research over your team's knowledge:**
Onyx comes loaded with advanced features like Agents, Web Search, RAG, MCP, Deep Research, Connectors to 40+ knowledge sources, and more.
https://private-user-images.githubusercontent.com/32520769/414509312-48392e83-95d0-4fb5-8650-a396e05e0a32.mp4?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk5Mjg2MzYsIm5iZiI6MTczOTkyODMzNiwicGF0aCI6Ii8zMjUyMDc2OS80MTQ1MDkzMTItNDgzOTJlODMtOTVkMC00ZmI1LTg2NTAtYTM5NmUwNWUwYTMyLm1wND9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMTklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjE5VDAxMjUzNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWFhMzk5Njg2Y2Y5YjFmNDNiYTQ2YzM5ZTg5YWJiYTU2NWMyY2YwNmUyODE2NWUxMDRiMWQxZWJmODI4YTA0MTUmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.a9D8A0sgKE9AoaoE-mfFbJ6_OKYeqaf7TZ4Han2JfW8
> [!TIP]
> Run Onyx with one command (or see deployment section below):
> ```
> curl -fsSL https://raw.githubusercontent.com/onyx-dot-app/onyx/main/deployment/docker_compose/install.sh > install.sh && chmod +x install.sh && ./install.sh
> ```
**Use Onyx as a secure AI Chat with any LLM:**
![Onyx Chat Silent Demo](https://github.com/onyx-dot-app/onyx/releases/download/v0.21.1/OnyxChatSilentDemo.gif)
**Easily set up connectors to your apps:**
![Onyx Connector Silent Demo](https://github.com/onyx-dot-app/onyx/releases/download/v0.21.1/OnyxConnectorSilentDemo.gif)
## ⭐ Features
- **🤖 Custom Agents:** Build AI Agents with unique instructions, knowledge and actions.
- **🌍 Web Search:** Browse the web with Google PSE, Exa, and Serper as well as an in-house scraper or Firecrawl.
- **🔍 RAG:** Best in class hybrid-search + knowledge graph for uploaded files and ingested documents from connectors.
- **🔄 Connectors:** Pull knowledge, metadata, and access information from over 40 applications.
- **🔬 Deep Research:** Get in depth answers with an agentic multi-step search.
- **▶️ Actions & MCP:** Give AI Agents the ability to interact with external systems.
- **💻 Code Interpreter:** Execute code to analyze data, render graphs and create files.
- **🎨 Image Generation:** Generate images based on user prompts.
- **👥 Collaboration:** Chat sharing, feedback gathering, user management, usage analytics, and more.
Onyx works with all LLMs (like OpenAI, Anthropic, Gemini, etc.) and self-hosted LLMs (like Ollama, vLLM, etc.)
To learn more about the features, check out our [documentation](https://docs.onyx.app/welcome)!
**Access Onyx where your team already works:**
![Onyx Bot Demo](https://github.com/onyx-dot-app/onyx/releases/download/v0.21.1/OnyxBot.png)
## 🚀 Deployment
Onyx supports deployments in Docker, Kubernetes, Terraform, along with guides for major cloud providers.
See guides below:
- [Docker](https://docs.onyx.app/deployment/local/docker) or [Quickstart](https://docs.onyx.app/deployment/getting_started/quickstart) (best for most users)
- [Kubernetes](https://docs.onyx.app/deployment/local/kubernetes) (best for large teams)
- [Terraform](https://docs.onyx.app/deployment/local/terraform) (best for teams already using Terraform)
- Cloud specific guides (best if specifically using [AWS EKS](https://docs.onyx.app/deployment/cloud/aws/eks), [Azure VMs](https://docs.onyx.app/deployment/cloud/azure), etc.)
> [!TIP]
> **To try Onyx for free without deploying, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.
## Deployment
**To try it out for free and get started in seconds, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.
Onyx can also be run locally (even on a laptop) or deployed on a virtual machine with a single
`docker compose` command. Checkout our [docs](https://docs.onyx.app/deployment/getting_started/quickstart) to learn more.
## 🔍 Other Notable Benefits
Onyx is built for teams of all sizes, from individual users to the largest global enterprises.
We also have built-in support for high-availability/scalable deployment on Kubernetes.
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).
- **Enterprise Search**: far more than simple RAG, Onyx has custom indexing and retrieval that remains performant and accurate for scales of up to tens of millions of documents.
- **Security**: SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- **Management UI**: different user roles such as basic, curator, and admin.
- **Document Permissioning**: mirrors user access from external apps for RAG use cases.
## 🔍 Other Notable Benefits of Onyx
- Custom deep learning models for indexing and inference time, only through Onyx + learning from user feedback.
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- Knowledge curation features like document-sets, query history, usage analytics, etc.
- Scalable deployment options tested up to many tens of thousands of users and hundreds of millions of documents.
## 🚧 Roadmap
- New methods in information retrieval (StructRAG, LightGraphRAG, etc.)
- Personalized Search
- Organizational understanding and ability to locate and suggest experts from your team.
- Code Search
- SQL and Structured Query Language
To see ongoing and upcoming projects, check out our [roadmap](https://github.com/orgs/onyx-dot-app/projects/2)!
## 🔌 Connectors
Keep knowledge and access in sync across 40+ connectors:
- Google Drive
- Confluence
- Slack
- Gmail
- Salesforce
- Microsoft Sharepoint
- Github
- Jira
- Zendesk
- Gong
- Microsoft Teams
- Dropbox
- Local Files
- Websites
- And more ...
See the full list [here](https://docs.onyx.app/admin/connectors/overview).
## 📚 Licensing
There are two editions of Onyx:
- Onyx Community Edition (CE) is available freely under the MIT Expat license. Simply follow the Deployment guide above.
- Onyx Community Edition (CE) is available freely under the MIT license.
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations.
For feature details, check out [our website](https://www.onyx.app/pricing).
To try the Onyx Enterprise Edition:
1. Checkout [Onyx Cloud](https://cloud.onyx.app/signup).
2. For self-hosting the Enterprise Edition, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).
## 👪 Community
Join our open source community on **[Discord](https://discord.gg/TDJ59cGV2X)**!
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

View File

@@ -0,0 +1,389 @@
"""Migration 2: User file data preparation and backfill
Revision ID: 0cd424f32b1d
Revises: 9b66d3156fc6
Create Date: 2025-09-22 09:44:42.727034
This migration populates the new columns added in migration 1.
It prepares data for the UUID transition and relationship migration.
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import text
import logging
logger = logging.getLogger("alembic.runtime.migration")
# revision identifiers, used by Alembic.
revision = "0cd424f32b1d"
down_revision = "9b66d3156fc6"
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Populate new columns with data."""
bind = op.get_bind()
inspector = sa.inspect(bind)
# === Step 1: Populate user_file.new_id ===
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
has_new_id = "new_id" in user_file_columns
if has_new_id:
logger.info("Populating user_file.new_id with UUIDs...")
# Count rows needing UUIDs
null_count = bind.execute(
text("SELECT COUNT(*) FROM user_file WHERE new_id IS NULL")
).scalar_one()
if null_count > 0:
logger.info(f"Generating UUIDs for {null_count} user_file records...")
# Populate in batches to avoid long locks
batch_size = 10000
total_updated = 0
while True:
result = bind.execute(
text(
"""
UPDATE user_file
SET new_id = gen_random_uuid()
WHERE new_id IS NULL
AND id IN (
SELECT id FROM user_file
WHERE new_id IS NULL
LIMIT :batch_size
)
"""
),
{"batch_size": batch_size},
)
updated = result.rowcount
total_updated += updated
if updated < batch_size:
break
logger.info(f" Updated {total_updated}/{null_count} records...")
logger.info(f"Generated UUIDs for {total_updated} user_file records")
# Verify all records have UUIDs
remaining_null = bind.execute(
text("SELECT COUNT(*) FROM user_file WHERE new_id IS NULL")
).scalar_one()
if remaining_null > 0:
raise Exception(
f"Failed to populate all user_file.new_id values ({remaining_null} NULL)"
)
# Lock down the column
op.alter_column("user_file", "new_id", nullable=False)
op.alter_column("user_file", "new_id", server_default=None)
logger.info("Locked down user_file.new_id column")
# === Step 2: Populate persona__user_file.user_file_id_uuid ===
persona_user_file_columns = [
col["name"] for col in inspector.get_columns("persona__user_file")
]
if has_new_id and "user_file_id_uuid" in persona_user_file_columns:
logger.info("Populating persona__user_file.user_file_id_uuid...")
# Count rows needing update
null_count = bind.execute(
text(
"""
SELECT COUNT(*) FROM persona__user_file
WHERE user_file_id IS NOT NULL AND user_file_id_uuid IS NULL
"""
)
).scalar_one()
if null_count > 0:
logger.info(f"Updating {null_count} persona__user_file records...")
# Update in batches
batch_size = 10000
total_updated = 0
while True:
result = bind.execute(
text(
"""
UPDATE persona__user_file p
SET user_file_id_uuid = uf.new_id
FROM user_file uf
WHERE p.user_file_id = uf.id
AND p.user_file_id_uuid IS NULL
AND p.persona_id IN (
SELECT persona_id
FROM persona__user_file
WHERE user_file_id_uuid IS NULL
LIMIT :batch_size
)
"""
),
{"batch_size": batch_size},
)
updated = result.rowcount
total_updated += updated
if updated < batch_size:
break
logger.info(f" Updated {total_updated}/{null_count} records...")
logger.info(f"Updated {total_updated} persona__user_file records")
# Verify all records are populated
remaining_null = bind.execute(
text(
"""
SELECT COUNT(*) FROM persona__user_file
WHERE user_file_id IS NOT NULL AND user_file_id_uuid IS NULL
"""
)
).scalar_one()
if remaining_null > 0:
raise Exception(
f"Failed to populate all persona__user_file.user_file_id_uuid values ({remaining_null} NULL)"
)
op.alter_column("persona__user_file", "user_file_id_uuid", nullable=False)
logger.info("Locked down persona__user_file.user_file_id_uuid column")
# === Step 3: Create user_project records from chat_folder ===
if "chat_folder" in inspector.get_table_names():
logger.info("Creating user_project records from chat_folder...")
result = bind.execute(
text(
"""
INSERT INTO user_project (user_id, name)
SELECT cf.user_id, cf.name
FROM chat_folder cf
WHERE NOT EXISTS (
SELECT 1
FROM user_project up
WHERE up.user_id = cf.user_id AND up.name = cf.name
)
"""
)
)
logger.info(f"Created {result.rowcount} user_project records from chat_folder")
# === Step 4: Populate chat_session.project_id ===
chat_session_columns = [
col["name"] for col in inspector.get_columns("chat_session")
]
if "folder_id" in chat_session_columns and "project_id" in chat_session_columns:
logger.info("Populating chat_session.project_id...")
# Count sessions needing update
null_count = bind.execute(
text(
"""
SELECT COUNT(*) FROM chat_session
WHERE project_id IS NULL AND folder_id IS NOT NULL
"""
)
).scalar_one()
if null_count > 0:
logger.info(f"Updating {null_count} chat_session records...")
result = bind.execute(
text(
"""
UPDATE chat_session cs
SET project_id = up.id
FROM chat_folder cf
JOIN user_project up ON up.user_id = cf.user_id AND up.name = cf.name
WHERE cs.folder_id = cf.id AND cs.project_id IS NULL
"""
)
)
logger.info(f"Updated {result.rowcount} chat_session records")
# Verify all records are populated
remaining_null = bind.execute(
text(
"""
SELECT COUNT(*) FROM chat_session
WHERE project_id IS NULL AND folder_id IS NOT NULL
"""
)
).scalar_one()
if remaining_null > 0:
logger.warning(
f"Warning: {remaining_null} chat_session records could not be mapped to projects"
)
# === Step 5: Update plaintext FileRecord IDs/display names to UUID scheme ===
# Prior to UUID migration, plaintext cache files were stored with file_id like 'plain_text_<int_id>'.
# After migration, we use 'plaintext_<uuid>' (note the name change to 'plaintext_').
# This step remaps existing FileRecord rows to the new naming while preserving object_key/bucket.
logger.info("Updating plaintext FileRecord ids and display names to UUID scheme...")
# Count legacy plaintext records that can be mapped to UUID user_file ids
count_query = text(
"""
SELECT COUNT(*)
FROM file_record fr
JOIN user_file uf ON fr.file_id = CONCAT('plaintext_', uf.id::text)
WHERE LOWER(fr.file_origin::text) = 'plaintext_cache'
"""
)
legacy_count = bind.execute(count_query).scalar_one()
if legacy_count and legacy_count > 0:
logger.info(f"Found {legacy_count} legacy plaintext file records to update")
# Update display_name first for readability (safe regardless of rename)
bind.execute(
text(
"""
UPDATE file_record fr
SET display_name = CONCAT('Plaintext for user file ', uf.new_id::text)
FROM user_file uf
WHERE LOWER(fr.file_origin::text) = 'plaintext_cache'
AND fr.file_id = CONCAT('plaintext_', uf.id::text)
"""
)
)
# Remap file_id from 'plaintext_<int>' -> 'plaintext_<uuid>' using transitional new_id
# Use a single UPDATE ... WHERE file_id LIKE 'plain_text_%'
# and ensure it aligns to existing user_file ids to avoid renaming unrelated rows
result = bind.execute(
text(
"""
UPDATE file_record fr
SET file_id = CONCAT('plaintext_', uf.new_id::text)
FROM user_file uf
WHERE LOWER(fr.file_origin::text) = 'plaintext_cache'
AND fr.file_id = CONCAT('plaintext_', uf.id::text)
"""
)
)
logger.info(
f"Updated {result.rowcount} plaintext file_record ids to UUID scheme"
)
# === Step 6: Ensure document_id_migrated default TRUE and backfill existing FALSE ===
# New records should default to migrated=True so the migration task won't run for them.
# Existing rows that had a legacy document_id should be marked as not migrated to be processed.
# Backfill existing records: if document_id is not null, set to FALSE
bind.execute(
text(
"""
UPDATE user_file
SET document_id_migrated = FALSE
WHERE document_id IS NOT NULL
"""
)
)
# === Step 7: Backfill user_file.status from index_attempt ===
logger.info("Backfilling user_file.status from index_attempt...")
# Update user_file status based on latest index attempt
# Using CTEs instead of temp tables for asyncpg compatibility
result = bind.execute(
text(
"""
WITH latest_attempt AS (
SELECT DISTINCT ON (ia.connector_credential_pair_id)
ia.connector_credential_pair_id,
ia.status
FROM index_attempt ia
ORDER BY ia.connector_credential_pair_id, ia.time_updated DESC
),
uf_to_ccp AS (
SELECT DISTINCT uf.id AS uf_id, ccp.id AS cc_pair_id
FROM user_file uf
JOIN document_by_connector_credential_pair dcc
ON dcc.id = REPLACE(uf.document_id, 'USER_FILE_CONNECTOR__', 'FILE_CONNECTOR__')
JOIN connector_credential_pair ccp
ON ccp.connector_id = dcc.connector_id
AND ccp.credential_id = dcc.credential_id
)
UPDATE user_file uf
SET status = CASE
WHEN la.status IN ('NOT_STARTED', 'IN_PROGRESS') THEN 'PROCESSING'
WHEN la.status = 'SUCCESS' THEN 'COMPLETED'
ELSE 'FAILED'
END
FROM uf_to_ccp ufc
LEFT JOIN latest_attempt la
ON la.connector_credential_pair_id = ufc.cc_pair_id
WHERE uf.id = ufc.uf_id
AND uf.status = 'PROCESSING'
"""
)
)
logger.info(f"Updated status for {result.rowcount} user_file records")
logger.info("Migration 2 (data preparation) completed successfully")
def downgrade() -> None:
"""Reset populated data to allow clean downgrade of schema."""
bind = op.get_bind()
inspector = sa.inspect(bind)
logger.info("Starting downgrade of data preparation...")
# Reset user_file columns to allow nulls before data removal
if "user_file" in inspector.get_table_names():
columns = [col["name"] for col in inspector.get_columns("user_file")]
if "new_id" in columns:
op.alter_column(
"user_file",
"new_id",
nullable=True,
server_default=sa.text("gen_random_uuid()"),
)
# Optionally clear the data
# bind.execute(text("UPDATE user_file SET new_id = NULL"))
logger.info("Reset user_file.new_id to nullable")
# Reset persona__user_file.user_file_id_uuid
if "persona__user_file" in inspector.get_table_names():
columns = [col["name"] for col in inspector.get_columns("persona__user_file")]
if "user_file_id_uuid" in columns:
op.alter_column("persona__user_file", "user_file_id_uuid", nullable=True)
# Optionally clear the data
# bind.execute(text("UPDATE persona__user_file SET user_file_id_uuid = NULL"))
logger.info("Reset persona__user_file.user_file_id_uuid to nullable")
# Note: We don't delete user_project records or reset chat_session.project_id
# as these might be in use and can be handled by the schema downgrade
# Reset user_file.status to default
if "user_file" in inspector.get_table_names():
columns = [col["name"] for col in inspector.get_columns("user_file")]
if "status" in columns:
bind.execute(text("UPDATE user_file SET status = 'PROCESSING'"))
logger.info("Reset user_file.status to default")
logger.info("Downgrade completed successfully")

View File

@@ -0,0 +1,261 @@
"""Migration 3: User file relationship migration
Revision ID: 16c37a30adf2
Revises: 0cd424f32b1d
Create Date: 2025-09-22 09:47:34.175596
This migration converts folder-based relationships to project-based relationships.
It migrates persona__user_folder to persona__user_file and populates project__user_file.
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import text
import logging
logger = logging.getLogger("alembic.runtime.migration")
# revision identifiers, used by Alembic.
revision = "16c37a30adf2"
down_revision = "0cd424f32b1d"
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Migrate folder-based relationships to project-based relationships."""
bind = op.get_bind()
inspector = sa.inspect(bind)
# === Step 1: Migrate persona__user_folder to persona__user_file ===
table_names = inspector.get_table_names()
if "persona__user_folder" in table_names and "user_file" in table_names:
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
has_new_id = "new_id" in user_file_columns
if has_new_id and "folder_id" in user_file_columns:
logger.info(
"Migrating persona__user_folder relationships to persona__user_file..."
)
# Count relationships to migrate (asyncpg-compatible)
count_query = text(
"""
SELECT COUNT(*)
FROM (
SELECT DISTINCT puf.persona_id, uf.id
FROM persona__user_folder puf
JOIN user_file uf ON uf.folder_id = puf.user_folder_id
WHERE NOT EXISTS (
SELECT 1
FROM persona__user_file p2
WHERE p2.persona_id = puf.persona_id
AND p2.user_file_id = uf.id
)
) AS distinct_pairs
"""
)
to_migrate = bind.execute(count_query).scalar_one()
if to_migrate > 0:
logger.info(f"Creating {to_migrate} persona-file relationships...")
# Migrate in batches to avoid memory issues
batch_size = 10000
total_inserted = 0
while True:
# Insert batch directly using subquery (asyncpg compatible)
result = bind.execute(
text(
"""
INSERT INTO persona__user_file (persona_id, user_file_id, user_file_id_uuid)
SELECT DISTINCT puf.persona_id, uf.id as file_id, uf.new_id
FROM persona__user_folder puf
JOIN user_file uf ON uf.folder_id = puf.user_folder_id
WHERE NOT EXISTS (
SELECT 1
FROM persona__user_file p2
WHERE p2.persona_id = puf.persona_id
AND p2.user_file_id = uf.id
)
LIMIT :batch_size
"""
),
{"batch_size": batch_size},
)
inserted = result.rowcount
total_inserted += inserted
if inserted < batch_size:
break
logger.info(
f" Migrated {total_inserted}/{to_migrate} relationships..."
)
logger.info(
f"Created {total_inserted} persona__user_file relationships"
)
# === Step 2: Add foreign key for chat_session.project_id ===
chat_session_fks = inspector.get_foreign_keys("chat_session")
fk_exists = any(
fk["name"] == "fk_chat_session_project_id" for fk in chat_session_fks
)
if not fk_exists:
logger.info("Adding foreign key constraint for chat_session.project_id...")
op.create_foreign_key(
"fk_chat_session_project_id",
"chat_session",
"user_project",
["project_id"],
["id"],
)
logger.info("Added foreign key constraint")
# === Step 3: Populate project__user_file from user_file.folder_id ===
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
has_new_id = "new_id" in user_file_columns
if has_new_id and "folder_id" in user_file_columns:
logger.info("Populating project__user_file from folder relationships...")
# Count relationships to create
count_query = text(
"""
SELECT COUNT(*)
FROM user_file uf
WHERE uf.folder_id IS NOT NULL
AND NOT EXISTS (
SELECT 1
FROM project__user_file puf
WHERE puf.project_id = uf.folder_id
AND puf.user_file_id = uf.new_id
)
"""
)
to_create = bind.execute(count_query).scalar_one()
if to_create > 0:
logger.info(f"Creating {to_create} project-file relationships...")
# Insert in batches
batch_size = 10000
total_inserted = 0
while True:
result = bind.execute(
text(
"""
INSERT INTO project__user_file (project_id, user_file_id)
SELECT uf.folder_id, uf.new_id
FROM user_file uf
WHERE uf.folder_id IS NOT NULL
AND NOT EXISTS (
SELECT 1
FROM project__user_file puf
WHERE puf.project_id = uf.folder_id
AND puf.user_file_id = uf.new_id
)
LIMIT :batch_size
ON CONFLICT (project_id, user_file_id) DO NOTHING
"""
),
{"batch_size": batch_size},
)
inserted = result.rowcount
total_inserted += inserted
if inserted < batch_size:
break
logger.info(f" Created {total_inserted}/{to_create} relationships...")
logger.info(f"Created {total_inserted} project__user_file relationships")
# === Step 4: Create index on chat_session.project_id ===
try:
indexes = [ix.get("name") for ix in inspector.get_indexes("chat_session")]
except Exception:
indexes = []
if "ix_chat_session_project_id" not in indexes:
logger.info("Creating index on chat_session.project_id...")
op.create_index(
"ix_chat_session_project_id", "chat_session", ["project_id"], unique=False
)
logger.info("Created index")
logger.info("Migration 3 (relationship migration) completed successfully")
def downgrade() -> None:
"""Remove migrated relationships and constraints."""
bind = op.get_bind()
inspector = sa.inspect(bind)
logger.info("Starting downgrade of relationship migration...")
# Drop index on chat_session.project_id
try:
indexes = [ix.get("name") for ix in inspector.get_indexes("chat_session")]
if "ix_chat_session_project_id" in indexes:
op.drop_index("ix_chat_session_project_id", "chat_session")
logger.info("Dropped index on chat_session.project_id")
except Exception:
pass
# Drop foreign key constraint
try:
chat_session_fks = inspector.get_foreign_keys("chat_session")
fk_exists = any(
fk["name"] == "fk_chat_session_project_id" for fk in chat_session_fks
)
if fk_exists:
op.drop_constraint(
"fk_chat_session_project_id", "chat_session", type_="foreignkey"
)
logger.info("Dropped foreign key constraint on chat_session.project_id")
except Exception:
pass
# Clear project__user_file relationships (but keep the table for migration 1 to handle)
if "project__user_file" in inspector.get_table_names():
result = bind.execute(text("DELETE FROM project__user_file"))
logger.info(f"Cleared {result.rowcount} records from project__user_file")
# Remove migrated persona__user_file relationships
# Only remove those that came from folder relationships
if all(
table in inspector.get_table_names()
for table in ["persona__user_file", "persona__user_folder", "user_file"]
):
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
if "folder_id" in user_file_columns:
result = bind.execute(
text(
"""
DELETE FROM persona__user_file puf
WHERE EXISTS (
SELECT 1
FROM user_file uf
JOIN persona__user_folder puf2
ON puf2.user_folder_id = uf.folder_id
WHERE puf.persona_id = puf2.persona_id
AND puf.user_file_id = uf.id
)
"""
)
)
logger.info(
f"Removed {result.rowcount} migrated persona__user_file relationships"
)
logger.info("Downgrade completed successfully")

View File

@@ -0,0 +1,218 @@
"""Migration 6: User file schema cleanup
Revision ID: 2b75d0a8ffcb
Revises: 3a78dba1080a
Create Date: 2025-09-22 10:09:26.375377
This migration removes legacy columns and tables after data migration is complete.
It should only be run after verifying all data has been successfully migrated.
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import text
import logging
logger = logging.getLogger("alembic.runtime.migration")
# revision identifiers, used by Alembic.
revision = "2b75d0a8ffcb"
down_revision = "3a78dba1080a"
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Remove legacy columns and tables."""
bind = op.get_bind()
inspector = sa.inspect(bind)
logger.info("Starting schema cleanup...")
# === Step 1: Verify data migration is complete ===
logger.info("Verifying data migration completion...")
# Check if any chat sessions still have folder_id references
chat_session_columns = [
col["name"] for col in inspector.get_columns("chat_session")
]
if "folder_id" in chat_session_columns:
orphaned_count = bind.execute(
text(
"""
SELECT COUNT(*) FROM chat_session
WHERE folder_id IS NOT NULL AND project_id IS NULL
"""
)
).scalar_one()
if orphaned_count > 0:
logger.warning(
f"WARNING: {orphaned_count} chat_session records still have "
f"folder_id without project_id. Proceeding anyway."
)
# === Step 2: Drop chat_session.folder_id ===
if "folder_id" in chat_session_columns:
logger.info("Dropping chat_session.folder_id...")
# Drop foreign key constraint first
op.execute(
"ALTER TABLE chat_session DROP CONSTRAINT IF EXISTS chat_session_folder_fk"
)
# Drop the column
op.drop_column("chat_session", "folder_id")
logger.info("Dropped chat_session.folder_id")
# === Step 3: Drop persona__user_folder table ===
if "persona__user_folder" in inspector.get_table_names():
logger.info("Dropping persona__user_folder table...")
# Check for any remaining data
remaining = bind.execute(
text("SELECT COUNT(*) FROM persona__user_folder")
).scalar_one()
if remaining > 0:
logger.warning(
f"WARNING: Dropping persona__user_folder with {remaining} records"
)
op.drop_table("persona__user_folder")
logger.info("Dropped persona__user_folder table")
# === Step 4: Drop chat_folder table ===
if "chat_folder" in inspector.get_table_names():
logger.info("Dropping chat_folder table...")
# Check for any remaining data
remaining = bind.execute(text("SELECT COUNT(*) FROM chat_folder")).scalar_one()
if remaining > 0:
logger.warning(f"WARNING: Dropping chat_folder with {remaining} records")
op.drop_table("chat_folder")
logger.info("Dropped chat_folder table")
# === Step 5: Drop user_file legacy columns ===
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
# Drop folder_id
if "folder_id" in user_file_columns:
logger.info("Dropping user_file.folder_id...")
op.drop_column("user_file", "folder_id")
logger.info("Dropped user_file.folder_id")
# Drop cc_pair_id (already handled in migration 5, but double-check here)
if "cc_pair_id" in user_file_columns:
logger.info("Dropping user_file.cc_pair_id...")
# Drop any remaining foreign key constraints
bind.execute(
text(
"""
DO $$
DECLARE r RECORD;
BEGIN
FOR r IN (
SELECT conname
FROM pg_constraint c
JOIN pg_class t ON c.conrelid = t.oid
WHERE c.contype = 'f'
AND t.relname = 'user_file'
AND EXISTS (
SELECT 1 FROM pg_attribute a
WHERE a.attrelid = t.oid
AND a.attname = 'cc_pair_id'
)
) LOOP
EXECUTE format('ALTER TABLE user_file DROP CONSTRAINT %I', r.conname);
END LOOP;
END$$;
"""
)
)
op.drop_column("user_file", "cc_pair_id")
logger.info("Dropped user_file.cc_pair_id")
# === Step 6: Clean up any remaining constraints ===
logger.info("Cleaning up remaining constraints...")
# Drop any unique constraints on removed columns
op.execute(
"ALTER TABLE user_file DROP CONSTRAINT IF EXISTS user_file_cc_pair_id_key"
)
logger.info("Migration 6 (schema cleanup) completed successfully")
logger.info("Legacy schema has been fully removed")
def downgrade() -> None:
"""Recreate dropped columns and tables (structure only, no data)."""
bind = op.get_bind()
inspector = sa.inspect(bind)
logger.warning("Downgrading schema cleanup - recreating structure only, no data!")
# Recreate user_file columns
if "user_file" in inspector.get_table_names():
columns = [col["name"] for col in inspector.get_columns("user_file")]
if "cc_pair_id" not in columns:
op.add_column(
"user_file", sa.Column("cc_pair_id", sa.Integer(), nullable=True)
)
if "folder_id" not in columns:
op.add_column(
"user_file", sa.Column("folder_id", sa.Integer(), nullable=True)
)
# Recreate chat_folder table
if "chat_folder" not in inspector.get_table_names():
op.create_table(
"chat_folder",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["user_id"], ["user.id"], name="chat_folder_user_fk"
),
)
# Recreate persona__user_folder table
if "persona__user_folder" not in inspector.get_table_names():
op.create_table(
"persona__user_folder",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column("user_folder_id", sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint("persona_id", "user_folder_id"),
sa.ForeignKeyConstraint(["persona_id"], ["persona.id"]),
sa.ForeignKeyConstraint(["user_folder_id"], ["user_project.id"]),
)
# Add folder_id back to chat_session
if "chat_session" in inspector.get_table_names():
columns = [col["name"] for col in inspector.get_columns("chat_session")]
if "folder_id" not in columns:
op.add_column(
"chat_session", sa.Column("folder_id", sa.Integer(), nullable=True)
)
# Add foreign key if chat_folder exists
if "chat_folder" in inspector.get_table_names():
op.create_foreign_key(
"chat_session_folder_fk",
"chat_session",
"chat_folder",
["folder_id"],
["id"],
)
logger.info("Downgrade completed - structure recreated but data is lost")

View File

@@ -0,0 +1,295 @@
"""Migration 5: User file legacy data cleanup
Revision ID: 3a78dba1080a
Revises: 7cc3fcc116c1
Create Date: 2025-09-22 10:04:27.986294
This migration removes legacy user-file documents and connector_credential_pairs.
It performs bulk deletions of obsolete data after the UUID migration.
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as psql
from sqlalchemy import text
import logging
from typing import List
import uuid
logger = logging.getLogger("alembic.runtime.migration")
# revision identifiers, used by Alembic.
revision = "3a78dba1080a"
down_revision = "7cc3fcc116c1"
branch_labels = None
depends_on = None
def batch_delete(
bind: sa.engine.Connection,
table_name: str,
id_column: str,
ids: List[str | int | uuid.UUID],
batch_size: int = 1000,
id_type: str = "int",
) -> int:
"""Delete records in batches to avoid memory issues and timeouts."""
total_count = len(ids)
if total_count == 0:
return 0
logger.info(
f"Starting batch deletion of {total_count} records from {table_name}..."
)
# Determine appropriate ARRAY type
if id_type == "uuid":
array_type = psql.ARRAY(psql.UUID(as_uuid=True))
elif id_type == "int":
array_type = psql.ARRAY(sa.Integer())
else:
array_type = psql.ARRAY(sa.String())
total_deleted = 0
failed_batches = []
for i in range(0, total_count, batch_size):
batch_ids = ids[i : i + batch_size]
try:
stmt = text(
f"DELETE FROM {table_name} WHERE {id_column} = ANY(:ids)"
).bindparams(sa.bindparam("ids", value=batch_ids, type_=array_type))
result = bind.execute(stmt)
total_deleted += result.rowcount
# Log progress every 10 batches or at completion
batch_num = (i // batch_size) + 1
if batch_num % 10 == 0 or i + batch_size >= total_count:
logger.info(
f" Deleted {min(i + batch_size, total_count)}/{total_count} records "
f"({total_deleted} actual) from {table_name}"
)
except Exception as e:
logger.error(f"Failed to delete batch {(i // batch_size) + 1}: {e}")
failed_batches.append((i, min(i + batch_size, total_count)))
if failed_batches:
logger.warning(
f"Failed to delete {len(failed_batches)} batches from {table_name}. "
f"Total deleted: {total_deleted}/{total_count}"
)
# Fail the migration to avoid silently succeeding on partial cleanup
raise RuntimeError(
f"Batch deletion failed for {table_name}: "
f"{len(failed_batches)} failed batches out of "
f"{(total_count + batch_size - 1) // batch_size}."
)
return total_deleted
def upgrade() -> None:
"""Remove legacy user-file documents and connector_credential_pairs."""
bind = op.get_bind()
inspector = sa.inspect(bind)
logger.info("Starting legacy data cleanup...")
# === Step 1: Identify and delete user-file documents ===
logger.info("Identifying user-file documents to delete...")
# Get document IDs to delete
doc_rows = bind.execute(
text(
"""
SELECT DISTINCT dcc.id AS document_id
FROM document_by_connector_credential_pair dcc
JOIN connector_credential_pair u
ON u.connector_id = dcc.connector_id
AND u.credential_id = dcc.credential_id
WHERE u.is_user_file IS TRUE
"""
)
).fetchall()
doc_ids = [r[0] for r in doc_rows]
if doc_ids:
logger.info(f"Found {len(doc_ids)} user-file documents to delete")
# Delete dependent rows first
tables_to_clean = [
("document_retrieval_feedback", "document_id"),
("document__tag", "document_id"),
("chunk_stats", "document_id"),
]
for table_name, column_name in tables_to_clean:
if table_name in inspector.get_table_names():
# document_id is a string in these tables
deleted = batch_delete(
bind, table_name, column_name, doc_ids, id_type="str"
)
logger.info(f"Deleted {deleted} records from {table_name}")
# Delete document_by_connector_credential_pair entries
deleted = batch_delete(
bind, "document_by_connector_credential_pair", "id", doc_ids, id_type="str"
)
logger.info(f"Deleted {deleted} document_by_connector_credential_pair records")
# Delete documents themselves
deleted = batch_delete(bind, "document", "id", doc_ids, id_type="str")
logger.info(f"Deleted {deleted} document records")
else:
logger.info("No user-file documents found to delete")
# === Step 2: Clean up user-file connector_credential_pairs ===
logger.info("Cleaning up user-file connector_credential_pairs...")
# Get cc_pair IDs
cc_pair_rows = bind.execute(
text(
"""
SELECT id AS cc_pair_id
FROM connector_credential_pair
WHERE is_user_file IS TRUE
"""
)
).fetchall()
cc_pair_ids = [r[0] for r in cc_pair_rows]
if cc_pair_ids:
logger.info(
f"Found {len(cc_pair_ids)} user-file connector_credential_pairs to clean up"
)
# Delete related records
tables_to_clean = [
("index_attempt", "connector_credential_pair_id"),
("background_error", "cc_pair_id"),
("document_set__connector_credential_pair", "connector_credential_pair_id"),
("user_group__connector_credential_pair", "cc_pair_id"),
]
for table_name, column_name in tables_to_clean:
if table_name in inspector.get_table_names():
deleted = batch_delete(
bind, table_name, column_name, cc_pair_ids, id_type="int"
)
logger.info(f"Deleted {deleted} records from {table_name}")
# === Step 3: Identify connectors and credentials to delete ===
logger.info("Identifying orphaned connectors and credentials...")
# Get connectors used only by user-file cc_pairs
connector_rows = bind.execute(
text(
"""
SELECT DISTINCT ccp.connector_id
FROM connector_credential_pair ccp
WHERE ccp.is_user_file IS TRUE
AND ccp.connector_id != 0 -- Exclude system default
AND NOT EXISTS (
SELECT 1
FROM connector_credential_pair c2
WHERE c2.connector_id = ccp.connector_id
AND c2.is_user_file IS NOT TRUE
)
"""
)
).fetchall()
userfile_only_connector_ids = [r[0] for r in connector_rows]
# Get credentials used only by user-file cc_pairs
credential_rows = bind.execute(
text(
"""
SELECT DISTINCT ccp.credential_id
FROM connector_credential_pair ccp
WHERE ccp.is_user_file IS TRUE
AND ccp.credential_id != 0 -- Exclude public/default
AND NOT EXISTS (
SELECT 1
FROM connector_credential_pair c2
WHERE c2.credential_id = ccp.credential_id
AND c2.is_user_file IS NOT TRUE
)
"""
)
).fetchall()
userfile_only_credential_ids = [r[0] for r in credential_rows]
# === Step 4: Delete the cc_pairs themselves ===
if cc_pair_ids:
# Remove FK dependency from user_file first
bind.execute(
text(
"""
DO $$
DECLARE r RECORD;
BEGIN
FOR r IN (
SELECT conname
FROM pg_constraint c
JOIN pg_class t ON c.conrelid = t.oid
JOIN pg_class ft ON c.confrelid = ft.oid
WHERE c.contype = 'f'
AND t.relname = 'user_file'
AND ft.relname = 'connector_credential_pair'
) LOOP
EXECUTE format('ALTER TABLE user_file DROP CONSTRAINT %I', r.conname);
END LOOP;
END$$;
"""
)
)
# Delete cc_pairs
deleted = batch_delete(
bind, "connector_credential_pair", "id", cc_pair_ids, id_type="int"
)
logger.info(f"Deleted {deleted} connector_credential_pair records")
# === Step 5: Delete orphaned connectors ===
if userfile_only_connector_ids:
deleted = batch_delete(
bind, "connector", "id", userfile_only_connector_ids, id_type="int"
)
logger.info(f"Deleted {deleted} orphaned connector records")
# === Step 6: Delete orphaned credentials ===
if userfile_only_credential_ids:
# Clean up credential__user_group mappings first
deleted = batch_delete(
bind,
"credential__user_group",
"credential_id",
userfile_only_credential_ids,
id_type="int",
)
logger.info(f"Deleted {deleted} credential__user_group records")
# Delete credentials
deleted = batch_delete(
bind, "credential", "id", userfile_only_credential_ids, id_type="int"
)
logger.info(f"Deleted {deleted} orphaned credential records")
logger.info("Migration 5 (legacy data cleanup) completed successfully")
def downgrade() -> None:
"""Cannot restore deleted data - requires backup restoration."""
logger.error("CRITICAL: Downgrading data cleanup cannot restore deleted data!")
logger.error("Data restoration requires backup files or database backup.")
raise NotImplementedError(
"Downgrade of legacy data cleanup is not supported. "
"Deleted data must be restored from backups."
)
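
The core trick in batch_delete above is binding a Python list as a typed Postgres ARRAY so each batch goes through a single DELETE ... WHERE id = ANY(:ids) statement. A tiny illustration (table name and IDs are placeholders):

import sqlalchemy as sa
from sqlalchemy import text
from sqlalchemy.dialects import postgresql as psql

doc_ids = ["doc-1", "doc-2", "doc-3"]  # placeholder string IDs
stmt = text("DELETE FROM document WHERE id = ANY(:ids)").bindparams(
    sa.bindparam("ids", value=doc_ids, type_=psql.ARRAY(sa.String()))
)
# Executed against a real connection this returns the deleted rowcount:
# with engine.begin() as conn:
#     deleted = conn.execute(stmt).rowcount
print(stmt)  # renders as: DELETE FROM document WHERE id = ANY(:ids)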

View File

@@ -0,0 +1,193 @@
"""Migration 4: User file UUID primary key swap
Revision ID: 7cc3fcc116c1
Revises: 16c37a30adf2
Create Date: 2025-09-22 09:54:38.292952
This migration performs the critical UUID primary key swap on user_file table.
It updates all foreign key references to use UUIDs instead of integers.
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as psql
import logging
logger = logging.getLogger("alembic.runtime.migration")
# revision identifiers, used by Alembic.
revision = "7cc3fcc116c1"
down_revision = "16c37a30adf2"
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Swap user_file primary key from integer to UUID."""
bind = op.get_bind()
inspector = sa.inspect(bind)
# Verify we're in the expected state
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
if "new_id" not in user_file_columns:
logger.warning(
"user_file.new_id not found - migration may have already been applied"
)
return
logger.info("Starting UUID primary key swap...")
# === Step 1: Update persona__user_file foreign key to UUID ===
logger.info("Updating persona__user_file foreign key...")
# Drop existing foreign key constraints
op.execute(
"ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_user_file_id_uuid_fkey"
)
op.execute(
"ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_user_file_id_fkey"
)
# Create new foreign key to user_file.new_id
op.create_foreign_key(
"persona__user_file_user_file_id_fkey",
"persona__user_file",
"user_file",
local_cols=["user_file_id_uuid"],
remote_cols=["new_id"],
)
# Drop the old integer column and rename UUID column
op.execute("ALTER TABLE persona__user_file DROP COLUMN IF EXISTS user_file_id")
op.alter_column(
"persona__user_file",
"user_file_id_uuid",
new_column_name="user_file_id",
existing_type=psql.UUID(as_uuid=True),
nullable=False,
)
# Recreate composite primary key
op.execute(
"ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_pkey"
)
op.execute(
"ALTER TABLE persona__user_file ADD PRIMARY KEY (persona_id, user_file_id)"
)
logger.info("Updated persona__user_file to use UUID foreign key")
# === Step 2: Perform the primary key swap on user_file ===
logger.info("Swapping user_file primary key to UUID...")
# Drop the primary key constraint
op.execute("ALTER TABLE user_file DROP CONSTRAINT IF EXISTS user_file_pkey")
# Drop the old id column and rename new_id to id
op.execute("ALTER TABLE user_file DROP COLUMN IF EXISTS id")
op.alter_column(
"user_file",
"new_id",
new_column_name="id",
existing_type=psql.UUID(as_uuid=True),
nullable=False,
)
# Set default for new inserts
op.alter_column(
"user_file",
"id",
existing_type=psql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
)
# Create new primary key
op.execute("ALTER TABLE user_file ADD PRIMARY KEY (id)")
logger.info("Swapped user_file primary key to UUID")
# === Step 3: Update foreign key constraints ===
logger.info("Updating foreign key constraints...")
# Recreate persona__user_file foreign key to point to user_file.id
# Drop existing FK first to break dependency on the unique constraint
op.execute(
"ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_user_file_id_fkey"
)
# Drop the unique constraint on (formerly) new_id BEFORE recreating the FK,
# so the FK will bind to the primary key instead of the unique index.
op.execute("ALTER TABLE user_file DROP CONSTRAINT IF EXISTS uq_user_file_new_id")
# Now recreate FK to the primary key column
op.create_foreign_key(
"persona__user_file_user_file_id_fkey",
"persona__user_file",
"user_file",
local_cols=["user_file_id"],
remote_cols=["id"],
)
# Add foreign keys for project__user_file
existing_fks = inspector.get_foreign_keys("project__user_file")
has_user_file_fk = any(
fk.get("referred_table") == "user_file"
and fk.get("constrained_columns") == ["user_file_id"]
for fk in existing_fks
)
if not has_user_file_fk:
op.create_foreign_key(
"fk_project__user_file_user_file_id",
"project__user_file",
"user_file",
["user_file_id"],
["id"],
)
logger.info("Added project__user_file -> user_file foreign key")
has_project_fk = any(
fk.get("referred_table") == "user_project"
and fk.get("constrained_columns") == ["project_id"]
for fk in existing_fks
)
if not has_project_fk:
op.create_foreign_key(
"fk_project__user_file_project_id",
"project__user_file",
"user_project",
["project_id"],
["id"],
)
logger.info("Added project__user_file -> user_project foreign key")
# === Step 4: Mark files for document_id migration ===
logger.info("Marking files for background document_id migration...")
logger.info("Migration 4 (UUID primary key swap) completed successfully")
logger.info(
"NOTE: Background task will update document IDs in Vespa and search_doc"
)
def downgrade() -> None:
"""Revert UUID primary key back to integer (data destructive!)."""
logger.error("CRITICAL: Downgrading UUID primary key swap is data destructive!")
logger.error(
"This will break all UUID-based references created after the migration."
)
logger.error("Only proceed if absolutely necessary and have backups.")
# The downgrade would need to:
# 1. Add back integer columns
# 2. Generate new sequential IDs
# 3. Update all foreign key references
# 4. Swap primary keys back
# This is complex and risky, so we raise an error instead
raise NotImplementedError(
"Downgrade of UUID primary key swap is not supported due to data loss risk. "
"Manual intervention with data backup/restore is required."
)
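
For reference, the int-to-UUID swap carried out above follows a fairly standard shape. A condensed sketch on a throwaway item/item_child pair; constraint names such as item_pkey are the Postgres defaults and an assumption here, and the real change is split across revisions 9b66d3156fc6 and 7cc3fcc116c1 rather than run in one go:

from sqlalchemy import create_engine, text

steps = [
    # 1. add a transitional UUID column with a server-side default
    "ALTER TABLE item ADD COLUMN new_id UUID NOT NULL DEFAULT gen_random_uuid()",
    # 2. repoint the child table at the UUID values
    "ALTER TABLE item_child ADD COLUMN item_id_uuid UUID",
    "UPDATE item_child c SET item_id_uuid = i.new_id FROM item i WHERE i.id = c.item_id",
    "ALTER TABLE item_child DROP CONSTRAINT IF EXISTS item_child_item_id_fkey",
    "ALTER TABLE item_child DROP COLUMN item_id",
    "ALTER TABLE item_child RENAME COLUMN item_id_uuid TO item_id",
    # 3. swap the parent primary key
    "ALTER TABLE item DROP CONSTRAINT item_pkey",
    "ALTER TABLE item DROP COLUMN id",
    "ALTER TABLE item RENAME COLUMN new_id TO id",
    "ALTER TABLE item ADD PRIMARY KEY (id)",
    # 4. recreate the foreign key against the new UUID key
    "ALTER TABLE item_child ADD CONSTRAINT item_child_item_id_fkey "
    "FOREIGN KEY (item_id) REFERENCES item (id)",
]

engine = create_engine("postgresql+psycopg2://localhost/scratch")  # placeholder DSN
with engine.begin() as conn:
    for ddl in steps:
        conn.execute(text(ddl))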

View File

@@ -0,0 +1,257 @@
"""Migration 1: User file schema additions
Revision ID: 9b66d3156fc6
Revises: b4ef3ae0bf6e
Create Date: 2025-09-22 09:42:06.086732
This migration adds new columns and tables without modifying existing data.
It is safe to run and can be easily rolled back.
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as psql
import logging
logger = logging.getLogger("alembic.runtime.migration")
# revision identifiers, used by Alembic.
revision = "9b66d3156fc6"
down_revision = "b4ef3ae0bf6e"
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Add new columns and tables without modifying existing data."""
# Enable pgcrypto for UUID generation
op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto")
bind = op.get_bind()
inspector = sa.inspect(bind)
# === USER_FILE: Add new columns ===
logger.info("Adding new columns to user_file table...")
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
# Check if ID is already UUID (in case of re-run after partial migration)
id_is_uuid = any(
col["name"] == "id" and "uuid" in str(col["type"]).lower()
for col in inspector.get_columns("user_file")
)
# Add transitional UUID column only if ID is not already UUID
if "new_id" not in user_file_columns and not id_is_uuid:
op.add_column(
"user_file",
sa.Column(
"new_id",
psql.UUID(as_uuid=True),
nullable=True,
server_default=sa.text("gen_random_uuid()"),
),
)
op.create_unique_constraint("uq_user_file_new_id", "user_file", ["new_id"])
logger.info("Added new_id column to user_file")
# Add status column
if "status" not in user_file_columns:
op.add_column(
"user_file",
sa.Column(
"status",
sa.Enum(
"PROCESSING",
"COMPLETED",
"FAILED",
"CANCELED",
name="userfilestatus",
native_enum=False,
),
nullable=False,
server_default="PROCESSING",
),
)
logger.info("Added status column to user_file")
# Add other tracking columns
if "chunk_count" not in user_file_columns:
op.add_column(
"user_file", sa.Column("chunk_count", sa.Integer(), nullable=True)
)
logger.info("Added chunk_count column to user_file")
if "last_accessed_at" not in user_file_columns:
op.add_column(
"user_file",
sa.Column("last_accessed_at", sa.DateTime(timezone=True), nullable=True),
)
logger.info("Added last_accessed_at column to user_file")
if "needs_project_sync" not in user_file_columns:
op.add_column(
"user_file",
sa.Column(
"needs_project_sync",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)
logger.info("Added needs_project_sync column to user_file")
if "last_project_sync_at" not in user_file_columns:
op.add_column(
"user_file",
sa.Column(
"last_project_sync_at", sa.DateTime(timezone=True), nullable=True
),
)
logger.info("Added last_project_sync_at column to user_file")
if "document_id_migrated" not in user_file_columns:
op.add_column(
"user_file",
sa.Column(
"document_id_migrated",
sa.Boolean(),
nullable=False,
server_default=sa.text("true"),
),
)
logger.info("Added document_id_migrated column to user_file")
# === USER_FOLDER -> USER_PROJECT rename ===
table_names = set(inspector.get_table_names())
if "user_folder" in table_names:
logger.info("Updating user_folder table...")
# Make description nullable first
op.alter_column("user_folder", "description", nullable=True)
# Rename table if user_project doesn't exist
if "user_project" not in table_names:
op.execute("ALTER TABLE user_folder RENAME TO user_project")
logger.info("Renamed user_folder to user_project")
elif "user_project" in table_names:
# If already renamed, ensure column nullability
project_cols = [col["name"] for col in inspector.get_columns("user_project")]
if "description" in project_cols:
op.alter_column("user_project", "description", nullable=True)
# Add instructions column to user_project
inspector = sa.inspect(bind) # Refresh after rename
if "user_project" in inspector.get_table_names():
project_columns = [col["name"] for col in inspector.get_columns("user_project")]
if "instructions" not in project_columns:
op.add_column(
"user_project",
sa.Column("instructions", sa.String(), nullable=True),
)
logger.info("Added instructions column to user_project")
# === CHAT_SESSION: Add project_id ===
chat_session_columns = [
col["name"] for col in inspector.get_columns("chat_session")
]
if "project_id" not in chat_session_columns:
op.add_column(
"chat_session",
sa.Column("project_id", sa.Integer(), nullable=True),
)
logger.info("Added project_id column to chat_session")
# === PERSONA__USER_FILE: Add UUID column ===
persona_user_file_columns = [
col["name"] for col in inspector.get_columns("persona__user_file")
]
if "user_file_id_uuid" not in persona_user_file_columns:
op.add_column(
"persona__user_file",
sa.Column("user_file_id_uuid", psql.UUID(as_uuid=True), nullable=True),
)
logger.info("Added user_file_id_uuid column to persona__user_file")
# === PROJECT__USER_FILE: Create new table ===
if "project__user_file" not in inspector.get_table_names():
op.create_table(
"project__user_file",
sa.Column("project_id", sa.Integer(), nullable=False),
sa.Column("user_file_id", psql.UUID(as_uuid=True), nullable=False),
sa.PrimaryKeyConstraint("project_id", "user_file_id"),
)
op.create_index(
"idx_project__user_file_user_file_id",
"project__user_file",
["user_file_id"],
)
logger.info("Created project__user_file table")
logger.info("Migration 1 (schema additions) completed successfully")
def downgrade() -> None:
"""Remove added columns and tables."""
bind = op.get_bind()
inspector = sa.inspect(bind)
logger.info("Starting downgrade of schema additions...")
# Drop project__user_file table
if "project__user_file" in inspector.get_table_names():
op.drop_index("idx_project__user_file_user_file_id", "project__user_file")
op.drop_table("project__user_file")
logger.info("Dropped project__user_file table")
# Remove columns from persona__user_file
if "persona__user_file" in inspector.get_table_names():
columns = [col["name"] for col in inspector.get_columns("persona__user_file")]
if "user_file_id_uuid" in columns:
op.drop_column("persona__user_file", "user_file_id_uuid")
logger.info("Dropped user_file_id_uuid from persona__user_file")
# Remove columns from chat_session
if "chat_session" in inspector.get_table_names():
columns = [col["name"] for col in inspector.get_columns("chat_session")]
if "project_id" in columns:
op.drop_column("chat_session", "project_id")
logger.info("Dropped project_id from chat_session")
# Rename user_project back to user_folder and remove instructions
if "user_project" in inspector.get_table_names():
columns = [col["name"] for col in inspector.get_columns("user_project")]
if "instructions" in columns:
op.drop_column("user_project", "instructions")
op.execute("ALTER TABLE user_project RENAME TO user_folder")
op.alter_column("user_folder", "description", nullable=False)
logger.info("Renamed user_project back to user_folder")
# Remove columns from user_file
if "user_file" in inspector.get_table_names():
columns = [col["name"] for col in inspector.get_columns("user_file")]
columns_to_drop = [
"document_id_migrated",
"last_project_sync_at",
"needs_project_sync",
"last_accessed_at",
"chunk_count",
"status",
]
for col in columns_to_drop:
if col in columns:
op.drop_column("user_file", col)
logger.info(f"Dropped {col} from user_file")
if "new_id" in columns:
op.drop_constraint("uq_user_file_new_id", "user_file", type_="unique")
op.drop_column("user_file", "new_id")
logger.info("Dropped new_id from user_file")
# Drop enum type if no columns use it
bind.execute(sa.text("DROP TYPE IF EXISTS userfilestatus"))
logger.info("Downgrade completed successfully")

View File

@@ -0,0 +1,27 @@
"""add_user_oauth_token_to_slack_bot
Revision ID: b4ef3ae0bf6e
Revises: 505c488f6662
Create Date: 2025-08-26 17:47:41.788462
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "b4ef3ae0bf6e"
down_revision = "505c488f6662"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add user_token column to slack_bot table
op.add_column("slack_bot", sa.Column("user_token", sa.LargeBinary(), nullable=True))
def downgrade() -> None:
# Remove user_token column from slack_bot table
op.drop_column("slack_bot", "user_token")

View File

@@ -84,7 +84,34 @@ def upgrade() -> None:
# Insert or update built-in tools
for tool in BUILT_IN_TOOLS:
if tool["in_code_tool_id"] in existing_tool_ids:
in_code_id = tool["in_code_tool_id"]
# Handle historical rename: InternetSearchTool -> WebSearchTool
if (
in_code_id == "WebSearchTool"
and "WebSearchTool" not in existing_tool_ids
and "InternetSearchTool" in existing_tool_ids
):
# Rename the existing InternetSearchTool row in place and update fields
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description,
in_code_tool_id = :in_code_tool_id
WHERE in_code_tool_id = 'InternetSearchTool'
"""
),
tool,
)
# Keep the local view of existing ids in sync to avoid duplicate insert
existing_tool_ids.discard("InternetSearchTool")
existing_tool_ids.add("WebSearchTool")
continue
if in_code_id in existing_tool_ids:
# Update existing tool
conn.execute(
sa.text(

View File

@@ -93,7 +93,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
if cc_pair.access_type != AccessType.SYNC:
task_logger.error(
f"Recieved non-sync CC Pair {cc_pair.id} for external "
f"Received non-sync CC Pair {cc_pair.id} for external "
f"group sync. Actual access type: {cc_pair.access_type}"
)
return False

View File

@@ -17,6 +17,7 @@ from ee.onyx.server.enterprise_settings.api import (
from ee.onyx.server.enterprise_settings.api import (
basic_router as enterprise_settings_router,
)
from ee.onyx.server.evals.api import router as evals_router
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.tenant_tracking import (
add_api_server_tenant_id_middleware,
@@ -170,6 +171,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, standard_answer_router)
include_router_with_global_prefix_prepended(application, ee_oauth_router)
include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)
include_router_with_global_prefix_prepended(application, evals_router)
# Enterprise-only global settings
include_router_with_global_prefix_prepended(

View File

@@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
from ee.onyx.db.standard_answer import fetch_standard_answer_categories_by_names
from ee.onyx.db.standard_answer import find_matching_standard_answers
from onyx.configs.constants import MessageType
- from onyx.configs.onyxbot_configs import DANSWER_REACT_EMOJI
+ from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_messages_by_sessions
@@ -193,7 +193,7 @@ def _handle_standard_answers(
db_session.commit()
update_emote_react(
- emoji=DANSWER_REACT_EMOJI,
+ emoji=ONYX_BOT_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,

View File

@@ -0,0 +1,32 @@
from fastapi import APIRouter
from fastapi import Depends
from ee.onyx.auth.users import current_cloud_superuser
from onyx.background.celery.apps.client import celery_app as client_app
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.models import User
from onyx.evals.models import EvalConfigurationOptions
from onyx.server.evals.models import EvalRunAck
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/evals")
@router.post("/eval_run", response_model=EvalRunAck)
def eval_run(
request: EvalConfigurationOptions,
user: User = Depends(current_cloud_superuser),
) -> EvalRunAck:
"""
Run an evaluation with the given message and optional dataset.
This endpoint requires a valid API key for authentication.
"""
client_app.send_task(
OnyxCeleryTask.EVAL_RUN_TASK,
kwargs={
"configuration_dict": request.model_dump(),
},
)
return EvalRunAck(success=True)
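
A hedged example of invoking the new endpoint from a client. The route itself is /evals/eval_run as registered above; the host, any global API prefix, the auth header, and the EvalConfigurationOptions payload fields are assumptions for illustration:

import requests

resp = requests.post(
    "http://localhost:8080/evals/eval_run",  # adjust host/prefix for your deployment
    json={"dataset_name": "smoke-test"},  # hypothetical EvalConfigurationOptions fields
    headers={"Authorization": "Bearer <api-key>"},
    timeout=30,
)
resp.raise_for_status()
print(resp.json())  # expected: {"success": true}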

View File

@@ -182,7 +182,6 @@ def admin_get_chat_sessions(
time_created=chat.time_created.isoformat(),
time_updated=chat.time_updated.isoformat(),
shared_status=chat.shared_status,
- folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,
)
for chat in chat_sessions

View File

@@ -37,24 +37,51 @@ def get_embedding_model(
model_name: str,
max_context_length: int,
) -> "SentenceTransformer":
"""
Loads or returns a cached SentenceTransformer, sets max_seq_length, pins device,
pre-warms rotary caches once, and wraps encode() with a lock to avoid cache races.
"""
from sentence_transformers import SentenceTransformer # type: ignore
global _GLOBAL_MODELS_DICT # A dictionary to store models
def _prewarm_rope(st_model: "SentenceTransformer", target_len: int) -> None:
"""
Build RoPE cos/sin caches once on the final device/dtype so later forwards only read.
Works by running a single dummy encode() pass through the model so the caches are materialized.
"""
try:
# Ensure the dummy text tokenizes to more than max_seq_length.
# Ideally we would use the model's saved tokenizer here, but a rough
# assumption about tokenization is good enough for a pre-warm pass.
long_text = "x " * (target_len * 2)
_ = st_model.encode(
[long_text],
batch_size=1,
convert_to_tensor=True,
show_progress_bar=False,
normalize_embeddings=False,
)
logger.info("RoPE pre-warm successful")
except Exception as e:
logger.warning(f"RoPE pre-warm skipped/failed: {e}")
global _GLOBAL_MODELS_DICT
if model_name not in _GLOBAL_MODELS_DICT:
logger.notice(f"Loading {model_name}")
# Some model architectures that aren't built into the Transformers or Sentence
# Transformer need to be downloaded to be loaded locally. This does not mean
# data is sent to remote servers for inference, however the remote code can
# be fairly arbitrary so only use trusted models
model = SentenceTransformer(
model_name_or_path=model_name,
trust_remote_code=True,
)
model.max_seq_length = max_context_length
_prewarm_rope(model, max_context_length)
_GLOBAL_MODELS_DICT[model_name] = model
elif max_context_length != _GLOBAL_MODELS_DICT[model_name].max_seq_length:
_GLOBAL_MODELS_DICT[model_name].max_seq_length = max_context_length
else:
model = _GLOBAL_MODELS_DICT[model_name]
if max_context_length != model.max_seq_length:
model.max_seq_length = max_context_length
prev = getattr(model, "_rope_prewarmed_to", 0)
if max_context_length > int(prev or 0):
_prewarm_rope(model, max_context_length)
return _GLOBAL_MODELS_DICT[model_name]
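
A hedged usage sketch: the cache should hand back the same SentenceTransformer instance and update max_seq_length in place when a longer context is requested, re-running the pre-warm for the larger length. The import path and the model name are assumptions for illustration:

from model_server.encoders import get_embedding_model  # assumed import path

model_a = get_embedding_model("nomic-ai/nomic-embed-text-v1", max_context_length=512)
model_b = get_embedding_model("nomic-ai/nomic-embed-text-v1", max_context_length=8192)

assert model_a is model_b  # cached instance is reused
assert model_b.max_seq_length == 8192  # longer context updates the cached model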

View File

@@ -1,6 +1,7 @@
from collections.abc import Callable
from typing import cast
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.access.models import DocumentAccess
@@ -10,6 +11,7 @@ from onyx.configs.constants import PUBLIC_DOC_PAT
from onyx.db.document import get_access_info_for_document
from onyx.db.document import get_access_info_for_documents
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -124,3 +126,25 @@ def source_should_fetch_permissions_during_indexing(source: DocumentSource) -> b
),
)
return _source_should_fetch_permissions_during_indexing_func(source)
def get_access_for_user_files(
user_file_ids: list[str],
db_session: Session,
) -> dict[str, DocumentAccess]:
user_files = (
db_session.query(UserFile)
.options(joinedload(UserFile.user)) # Eager load the user relationship
.filter(UserFile.id.in_(user_file_ids))
.all()
)
return {
str(user_file.id): DocumentAccess.build(
user_emails=[user_file.user.email] if user_file.user else [],
user_groups=[],
is_public=True if user_file.user is None else False,
external_user_emails=[],
external_user_group_ids=[],
)
for user_file in user_files
}
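
A short usage sketch for the helper above; the module path for the import is an assumption, and any SQLAlchemy Session bound to the application database works:

from onyx.access.access import get_access_for_user_files  # assumed module path
from onyx.db.engine.sql_engine import get_session_with_current_tenant

with get_session_with_current_tenant() as db_session:
    access_by_file = get_access_for_user_files(
        user_file_ids=["00000000-0000-0000-0000-000000000000"],  # placeholder UUID strings
        db_session=db_session,
    )
    for file_id, access in access_by_file.items():
        print(file_id, access)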

View File

@@ -35,14 +35,24 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.agents.agent_search.shared_graph_utils.utils import run_with_timeout
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.chat_utils import build_citation_map_from_numbers
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.citations_prompt import build_citations_system_message
from onyx.chat.prompt_builder.citations_prompt import build_citations_user_message
from onyx.chat.stream_processing.citation_processing import (
normalize_square_bracket_citations_to_double_with_links,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import DocumentSourceDescription
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.db.chat import create_search_doc_from_saved_search_doc
from onyx.db.chat import update_db_session_with_messages
from onyx.db.connector import fetch_unique_document_sources
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.models import SearchDoc
from onyx.db.models import Tool
from onyx.db.tools import get_tools
from onyx.file_store.models import ChatFileType
@@ -52,12 +62,14 @@ from onyx.kg.utils.extraction_utils import get_relationship_types_str
from onyx.llm.utils import check_number_of_tokens
from onyx.llm.utils import get_max_input_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
from onyx.prompts.dr_prompts import ANSWER_PROMPT_WO_TOOL_CALLING
from onyx.prompts.dr_prompts import DECISION_PROMPT_W_TOOL_CALLING
from onyx.prompts.dr_prompts import DECISION_PROMPT_WO_TOOL_CALLING
from onyx.prompts.dr_prompts import DEFAULT_DR_SYSTEM_PROMPT
from onyx.prompts.dr_prompts import REPEAT_PROMPT
from onyx.prompts.dr_prompts import TOOL_DESCRIPTION
from onyx.prompts.prompt_template import PromptTemplate
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import SectionEnd
@@ -309,6 +321,52 @@ def _get_existing_clarification_request(
return clarification, original_question, chat_history_string
def _persist_final_docs_and_citations(
db_session: Session,
context_llm_docs: list[Any] | None,
full_answer: str | None,
) -> tuple[list[SearchDoc], dict[int, int] | None]:
"""Persist final documents from in-context docs and derive citation mapping.
Returns the list of persisted `SearchDoc` records and an optional
citation map translating inline [[n]] references to DB doc indices.
"""
final_documents_db: list[SearchDoc] = []
citations_map: dict[int, int] | None = None
if not context_llm_docs:
return final_documents_db, citations_map
saved_search_docs = saved_search_docs_from_llm_docs(context_llm_docs)
for saved_doc in saved_search_docs:
db_doc = create_search_doc_from_saved_search_doc(saved_doc)
db_session.add(db_doc)
final_documents_db.append(db_doc)
db_session.flush()
cited_numbers: set[int] = set()
try:
# Match [[1]] or [[1, 2]] optionally followed by a link like ([[1]](http...))
matches = re.findall(
r"\[\[(\d+(?:,\s*\d+)*)\]\](?:\([^)]*\))?", full_answer or ""
)
for match in matches:
for num_str in match.split(","):
num = int(num_str.strip())
cited_numbers.add(num)
except Exception:
cited_numbers = set()
if cited_numbers and final_documents_db:
translations = build_citation_map_from_numbers(
cited_numbers=cited_numbers,
db_docs=final_documents_db,
)
citations_map = translations or None
return final_documents_db, citations_map
_ARTIFICIAL_ALL_ENCOMPASSING_TOOL = {
"type": "function",
"function": {
@@ -405,21 +463,28 @@ def clarifier(
active_source_type_descriptions_str = ""
if graph_config.inputs.persona:
- assistant_system_prompt = (
+ assistant_system_prompt = PromptTemplate(
      graph_config.inputs.persona.system_prompt or DEFAULT_DR_SYSTEM_PROMPT
- ) + "\n\n"
+ ).build()
if graph_config.inputs.persona.task_prompt:
assistant_task_prompt = (
"\n\nHere are more specifications from the user:\n\n"
- + (graph_config.inputs.persona.task_prompt)
+ + PromptTemplate(graph_config.inputs.persona.task_prompt).build()
)
else:
assistant_task_prompt = ""
else:
- assistant_system_prompt = DEFAULT_DR_SYSTEM_PROMPT + "\n\n"
+ assistant_system_prompt = PromptTemplate(DEFAULT_DR_SYSTEM_PROMPT).build()
assistant_task_prompt = ""
if graph_config.inputs.project_instructions:
assistant_system_prompt = (
assistant_system_prompt
+ PROJECT_INSTRUCTIONS_SEPARATOR
+ graph_config.inputs.project_instructions
)
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history,
@@ -448,6 +513,11 @@ def clarifier(
graph_config.inputs.files
)
# Use project/search context docs if available to enable citation mapping
context_llm_docs = getattr(
graph_config.inputs.prompt_builder, "context_llm_docs", None
)
if not (force_use_tool and force_use_tool.force_use):
if not use_tool_calling_llm or len(available_tools) == 1:
@@ -562,10 +632,37 @@ def clarifier(
active_source_type_descriptions_str=active_source_type_descriptions_str,
)
if context_llm_docs:
persona = graph_config.inputs.persona
if persona is not None:
prompt_config = PromptConfig.from_model(persona)
else:
prompt_config = PromptConfig(
system_prompt=assistant_system_prompt,
task_prompt="",
datetime_aware=True,
)
system_prompt_to_use = build_citations_system_message(
prompt_config
).content
user_prompt_to_use = build_citations_user_message(
user_query=original_question,
files=[],
prompt_config=prompt_config,
context_docs=context_llm_docs,
all_doc_useful=False,
history_message=chat_history_string,
context_type="user files",
).content
else:
system_prompt_to_use = assistant_system_prompt
user_prompt_to_use = decision_prompt + assistant_task_prompt
stream = graph_config.tooling.primary_llm.stream(
prompt=create_question_prompt(
- assistant_system_prompt,
- decision_prompt + assistant_task_prompt,
+ cast(str, system_prompt_to_use),
+ cast(str, user_prompt_to_use),
uploaded_image_context=uploaded_image_context,
),
tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
@@ -578,6 +675,8 @@ def clarifier(
should_stream_answer=True,
writer=writer,
ind=0,
final_search_results=context_llm_docs,
displayed_search_results=context_llm_docs,
generate_final_answer=True,
chat_message_id=str(graph_config.persistence.chat_session_id),
)
@@ -585,19 +684,32 @@ def clarifier(
if len(full_response.ai_message_chunk.tool_calls) == 0:
if isinstance(full_response.full_answer, str):
- full_answer = full_response.full_answer
+ full_answer = (
+     normalize_square_bracket_citations_to_double_with_links(
+         full_response.full_answer
+     )
+ )
else:
full_answer = None
# Persist final documents and derive citations when using in-context docs
final_documents_db, citations_map = _persist_final_docs_and_citations(
db_session=db_session,
context_llm_docs=context_llm_docs,
full_answer=full_answer,
)
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=graph_config.persistence.chat_session_id,
is_agentic=graph_config.behavior.use_agentic_search,
message=full_answer,
token_count=len(llm_tokenizer.encode(full_answer or "")),
citations=citations_map,
final_documents=final_documents_db or None,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
token_count=len(llm_tokenizer.encode(full_answer or "")),
)
db_session.commit()
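
A quick, self-contained check of the [[n]] citation pattern that _persist_final_docs_and_citations scans for; the sample answer text is made up:

import re

answer = (
    "Quarterly revenue grew 12% [[1]](https://docs.example.com/q3) and "
    "headcount stayed flat [[2, 3]]."
)
cited: set[int] = set()
for match in re.findall(r"\[\[(\d+(?:,\s*\d+)*)\]\](?:\([^)]*\))?", answer):
    for num_str in match.split(","):
        cited.add(int(num_str.strip()))

print(sorted(cited))  # [1, 2, 3]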

View File

@@ -42,6 +42,7 @@ from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.llm.utils import check_number_of_tokens
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
@@ -225,7 +226,7 @@ def closer(
research_type = graph_config.behavior.research_type
- assistant_system_prompt = state.assistant_system_prompt
+ assistant_system_prompt: str = state.assistant_system_prompt or ""
assistant_task_prompt = state.assistant_task_prompt
uploaded_context = state.uploaded_test_context or ""
@@ -349,6 +350,13 @@ def closer(
uploaded_context=uploaded_context,
)
if graph_config.inputs.project_instructions:
assistant_system_prompt = (
assistant_system_prompt
+ PROJECT_INSTRUCTIONS_SEPARATOR
+ (graph_config.inputs.project_instructions or "")
)
all_context_llmdocs = [
llm_doc_from_inference_section(inference_section)
for inference_section in all_cited_documents

View File

@@ -9,6 +9,7 @@ from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
@@ -18,6 +19,8 @@ from onyx.chat.stream_processing.answer_response_handler import (
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.context.search.models import InferenceSection
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
@@ -56,6 +59,9 @@ def process_llm_stream(
full_answer = ""
start_final_answer_streaming_set = False
# Accumulate citation infos if handler emits them
collected_citation_infos: list[CitationInfo] = []
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for message in messages:
@@ -102,6 +108,9 @@ def process_llm_stream(
MessageDelta(content=response_part.answer_piece),
writer,
)
# collect citation info objects
elif isinstance(response_part, CitationInfo):
collected_citation_infos.append(response_part)
if generate_final_answer and start_final_answer_streaming_set:
# start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
@@ -111,6 +120,14 @@ def process_llm_stream(
writer,
)
# Emit citations section if any were collected
if collected_citation_infos:
write_custom_event(ind, CitationStart(), writer)
write_custom_event(
ind, CitationDelta(citations=collected_citation_infos), writer
)
write_custom_event(ind, SectionEnd(), writer)
logger.debug(f"Full answer: {full_answer}")
return BasicSearchProcessedStreamResults(
ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer

View File

@@ -1,6 +1,7 @@
import re
from datetime import datetime
from typing import cast
from uuid import UUID
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
@@ -73,6 +74,7 @@ def basic_search(
search_tool_info = state.available_tools[state.tools_used[-1]]
search_tool = cast(SearchTool, search_tool_info.tool_object)
force_use_tool = graph_config.tooling.force_use_tool
# sanity check
if search_tool != graph_config.tooling.search_tool:
@@ -141,6 +143,15 @@ def basic_search(
retrieved_docs: list[InferenceSection] = []
callback_container: list[list[InferenceSection]] = []
user_file_ids: list[UUID] | None = None
project_id: int | None = None
if force_use_tool.override_kwargs and isinstance(
force_use_tool.override_kwargs, SearchToolOverrideKwargs
):
override_kwargs = force_use_tool.override_kwargs
user_file_ids = override_kwargs.user_file_ids
project_id = override_kwargs.project_id
# new db session to avoid concurrency issues
with get_session_with_current_tenant() as search_db_session:
for tool_response in search_tool.run(
@@ -153,6 +164,8 @@ def basic_search(
retrieved_sections_callback=callback_container.append,
skip_query_analysis=True,
original_query=rewritten_query,
user_file_ids=user_file_ids,
project_id=project_id,
),
):
# get retrieved docs to send to the rest of the graph

View File

@@ -5,12 +5,12 @@ from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
- from onyx.agents.agent_search.dr.utils import chunks_or_sections_to_search_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import SavedSearchDoc
+ from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
@@ -47,7 +47,7 @@ def is_reducer(
doc_list.append(x)
# Convert InferenceSections to SavedSearchDocs
- search_docs = chunks_or_sections_to_search_docs(doc_list)
+ search_docs = SearchDoc.from_chunks_or_sections(doc_list)
retrieved_saved_search_docs = [
SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
for search_doc in search_docs

View File

@@ -13,7 +13,7 @@ from onyx.agents.agent_search.shared_graph_utils.operators import (
)
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SavedSearchDoc
- from onyx.context.search.utils import chunks_or_sections_to_search_docs
+ from onyx.context.search.models import SearchDoc
from onyx.tools.tool_implementations.web_search.web_search_tool import (
WebSearchTool,
)
@@ -266,7 +266,7 @@ def convert_inference_sections_to_search_docs(
is_internet: bool = False,
) -> list[SavedSearchDoc]:
# Convert InferenceSections to SavedSearchDocs
- search_docs = chunks_or_sections_to_search_docs(inference_sections)
+ search_docs = SearchDoc.from_chunks_or_sections(inference_sections)
for search_doc in search_docs:
search_doc.is_internet = is_internet

View File

@@ -24,6 +24,7 @@ class GraphInputs(BaseModel):
prompt_builder: AnswerPromptBuilder
files: list[InMemoryChatFile] | None = None
structured_response_format: dict | None = None
project_instructions: str | None = None
class Config:
arbitrary_types_allowed = True

View File

@@ -1,6 +1,6 @@
from pydantic import BaseModel
- from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
+ from onyx.chat.prompt_builder.schemas import PromptSnapshot
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolCallFinalResult

View File

@@ -1,4 +1,3 @@
import os
import re
from collections.abc import Callable
from collections.abc import Iterator
@@ -8,18 +7,10 @@ from typing import Any
from typing import cast
from typing import Literal
from typing import TypedDict
from uuid import UUID
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.models import GraphInputs
from onyx.agents.agent_search.models import GraphPersistence
from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
@@ -32,9 +23,6 @@ from onyx.agents.agent_search.shared_graph_utils.models import SubQuestionAnswer
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
@@ -42,25 +30,16 @@ from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_HISTORY_SUMMARY
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
)
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import DISPATCH_SEP_CHAR
from onyx.configs.constants import FORMAT_DOCS_SEPARATOR
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import Persona
from onyx.db.tools import get_tool_by_name
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.llm.interfaces import LLM
@@ -77,15 +56,12 @@ from onyx.prompts.agent_search import (
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketObj
from onyx.tools.force import ForceUseTool
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
@@ -156,120 +132,6 @@ def format_entity_term_extraction(
return "\n".join(entity_strs + relationship_strs + term_strs)
def get_test_config(
db_session: Session,
primary_llm: LLM,
fast_llm: LLM,
search_request: SearchRequest,
use_agentic_search: bool = True,
) -> GraphConfig:
persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
document_pruning_config = DocumentPruningConfig(
max_chunks=int(
persona.num_chunks
if persona.num_chunks is not None
else MAX_CHUNKS_FED_TO_CHAT
),
max_window_percentage=CHAT_TARGET_CHUNK_PERCENTAGE,
)
answer_style_config = AnswerStyleConfig(
citation_config=CitationConfig(
# The docs retrieved by this flow are already relevance-filtered
all_docs_useful=True
),
structured_response_format=None,
)
search_tool_config = SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=RetrievalDetails(), # may want to set dedupe_docs=True
rerank_settings=None, # Can use this to change reranking model
selected_sections=None,
latest_query_files=None,
bypass_acl=False,
)
prompt_config = PromptConfig.from_model(persona)
search_tool = SearchTool(
tool_id=get_tool_by_name(SearchTool._NAME, db_session).id,
db_session=db_session,
user=None,
persona=persona,
retrieval_options=search_tool_config.retrieval_options,
prompt_config=prompt_config,
llm=primary_llm,
fast_llm=fast_llm,
document_pruning_config=search_tool_config.document_pruning_config,
answer_style_config=search_tool_config.answer_style_config,
selected_sections=search_tool_config.selected_sections,
chunks_above=search_tool_config.chunks_above,
chunks_below=search_tool_config.chunks_below,
full_doc=search_tool_config.full_doc,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
rerank_settings=search_tool_config.rerank_settings,
bypass_acl=search_tool_config.bypass_acl,
)
graph_inputs = GraphInputs(
persona=search_request.persona,
rerank_settings=search_tool_config.rerank_settings,
prompt_builder=AnswerPromptBuilder(
user_message=HumanMessage(content=search_request.query),
message_history=[],
llm_config=primary_llm.config,
raw_user_query=search_request.query,
raw_user_uploaded_files=[],
),
structured_response_format=answer_style_config.structured_response_format,
)
using_tool_calling_llm = explicit_tool_calling_supported(
primary_llm.config.model_provider, primary_llm.config.model_name
)
graph_tooling = GraphTooling(
primary_llm=primary_llm,
fast_llm=fast_llm,
search_tool=search_tool,
tools=[search_tool],
force_use_tool=ForceUseTool(force_use=False, tool_name=""),
using_tool_calling_llm=using_tool_calling_llm,
)
chat_session_id = (
os.environ.get("ONYX_AS_CHAT_SESSION_ID")
or "00000000-0000-0000-0000-000000000000"
)
assert (
chat_session_id is not None
), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
graph_persistence = GraphPersistence(
db_session=db_session,
chat_session_id=UUID(chat_session_id),
message_id=1,
)
search_behavior_config = GraphSearchConfig(
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=False,
allow_refinement=True,
)
graph_config = GraphConfig(
inputs=graph_inputs,
tooling=graph_tooling,
persistence=graph_persistence,
behavior=search_behavior_config,
)
return graph_config
def get_persona_agent_prompt_expressions(
persona: Persona | None,
) -> PersonaPromptExpressions:

View File

@@ -115,7 +115,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.user_file_folder_sync",
"onyx.background.celery.tasks.docprocessing",
]
)

View File

@@ -20,6 +20,7 @@ import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.tasks.vespa.document_sync import reset_document_sync
from onyx.configs.app_configs import CELERY_WORKER_PRIMARY_POOL_OVERFLOW
from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
@@ -83,11 +84,11 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
EXTRA_CONCURRENCY = 4 # small extra fudge factor for connection limits
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)
SqlEngine.init_engine(
pool_size=pool_size, max_overflow=CELERY_WORKER_PRIMARY_POOL_OVERFLOW
)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
@@ -316,12 +317,12 @@ celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.evals",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.user_file_folder_sync",
"onyx.background.celery.tasks.kg_processing",
]
)

View File

@@ -0,0 +1,113 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.user_file_processing")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME)
# rkuo: Transient errors keep happening in the indexing watchdog threads.
# "SSL connection has been closed unexpectedly"
# Actually, setting the spawn method in the cloud fixes 95% of these.
# Setting pre-ping might help even more, but we're not worrying about that yet.
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Fewer startup checks in the multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:
SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.user_file_processing",
]
)
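# A hedged usage sketch (not part of this module): a worker for this app could be
# started with the standard Celery CLI, pointing -A at the `celery_app` defined above
# and consuming the user file processing queue; the exact invocation used in
# deployment may differ.
#
#   celery -A onyx.background.celery.apps.user_file_processing:celery_app worker \
#       --pool=threads --loglevel=INFO -Q user_file_processing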

View File

@@ -1,4 +1,5 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_PRIMARY_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
@@ -15,6 +16,6 @@ result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
worker_concurrency = 4
worker_concurrency = CELERY_WORKER_PRIMARY_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -0,0 +1,22 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
# User file processing worker configuration
worker_concurrency = CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -26,6 +26,26 @@ CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT = 1.0
# tasks that run in either self-hosted or cloud
beat_task_templates: list[dict] = [
{
"name": "check-for-user-file-processing",
"task": OnyxCeleryTask.CHECK_FOR_USER_FILE_PROCESSING,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.USER_FILE_PROCESSING,
},
},
{
"name": "user-file-docid-migration",
"task": OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
"schedule": timedelta(minutes=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.USER_FILE_PROCESSING,
},
},
{
"name": "check-for-kg-processing",
"task": OnyxCeleryTask.CHECK_KG_PROCESSING,
@@ -89,17 +109,6 @@ beat_task_templates: list[dict] = [
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-user-file-folder-sync",
"task": OnyxCeleryTask.CHECK_FOR_USER_FILE_FOLDER_SYNC,
"schedule": timedelta(
days=1
), # This should essentially always be triggered manually for user folder updates.
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-pruning",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,

View File

@@ -28,9 +28,6 @@ from onyx.db.connector_credential_pair import add_deletion_failure_message
from onyx.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from onyx.db.connector_credential_pair import (
delete_userfiles_for_cc_pair__no_commit,
)
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import (
@@ -484,12 +481,6 @@ def monitor_connector_deletion_taskset(
# related to the deleted DocumentByConnectorCredentialPair during commit
db_session.expire(cc_pair)
# delete all userfiles for the cc_pair
delete_userfiles_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,

View File

@@ -85,8 +85,10 @@ from onyx.document_index.factory import get_default_document_index
from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.adapters.document_indexing_adapter import (
DocumentIndexingBatchAdapter,
)
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
@@ -1268,6 +1270,8 @@ def _docprocessing_task(
tenant_id: str,
batch_num: int,
) -> None:
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
start_time = time.monotonic()
if tenant_id:
@@ -1369,6 +1373,14 @@ def _docprocessing_task(
f"Processing {len(documents)} documents through indexing pipeline"
)
adapter = DocumentIndexingBatchAdapter(
db_session=db_session,
connector_id=index_attempt.connector_credential_pair.connector.id,
credential_id=index_attempt.connector_credential_pair.credential.id,
tenant_id=tenant_id,
index_attempt_metadata=index_attempt_metadata,
)
# real work happens here!
index_pipeline_result = run_indexing_pipeline(
embedder=embedding_model,
@@ -1378,7 +1390,8 @@ def _docprocessing_task(
db_session=db_session,
tenant_id=tenant_id,
document_batch=documents,
index_attempt_metadata=index_attempt_metadata,
request_id=index_attempt_metadata.request_id,
adapter=adapter,
)
# Update batch completion and document counts atomically using database coordination

View File

@@ -0,0 +1,35 @@
from typing import Any
from celery import shared_task
from celery import Task
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.evals.eval import run_eval
from onyx.evals.models import EvalConfigurationOptions
from onyx.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
name=OnyxCeleryTask.EVAL_RUN_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
trail=False,
)
def eval_run_task(
self: Task,
*,
configuration_dict: dict[str, Any],
) -> None:
"""Background task to run an evaluation with the given configuration"""
try:
configuration = EvalConfigurationOptions.model_validate(configuration_dict)
run_eval(configuration, remote_dataset_name=configuration.dataset_name)
logger.info("Successfully completed eval run task")
except Exception:
logger.error("Failed to run eval task")
raise
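# A hedged dispatch sketch (the actual call site is not shown in this diff): the task
# is enqueued by name with a serialized configuration, along the lines of
#   celery_app.send_task(
#       OnyxCeleryTask.EVAL_RUN_TASK,
#       kwargs={"configuration_dict": configuration.model_dump()},
#   )
# where `configuration` is an EvalConfigurationOptions instance.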

View File

@@ -889,6 +889,12 @@ def monitor_celery_queues_helper(
n_user_files_indexing = celery_get_queue_length(
OnyxCeleryQueues.USER_FILES_INDEXING, r_celery
)
n_user_file_processing = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
n_user_file_project_sync = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, r_celery
)
n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery)
@@ -916,6 +922,8 @@ def monitor_celery_queues_helper(
f"docprocessing={n_docprocessing} "
f"docprocessing_prefetched={len(n_docprocessing_prefetched)} "
f"user_files_indexing={n_user_files_indexing} "
f"user_file_processing={n_user_file_processing} "
f"user_file_project_sync={n_user_file_project_sync} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "

View File

@@ -0,0 +1,656 @@
import datetime
import time
from collections.abc import Sequence
from typing import Any
from uuid import UUID
import httpx
import sqlalchemy as sa
from celery import shared_task
from celery import Task
from redis.lock import Lock as RedisLock
from sqlalchemy import select
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import UserFileStatus
from onyx.db.models import FileRecord
from onyx.db.models import SearchDoc
from onyx.db.models import UserFile
from onyx.db.search_settings import get_active_search_settings
from onyx.db.search_settings import get_active_search_settings_list
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import USER_PROJECT
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.file_store import S3BackedFileStore
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.redis.redis_pool import get_redis_client
def _as_uuid(value: str | UUID) -> UUID:
"""Return a UUID, accepting either a UUID or a string-like value."""
return value if isinstance(value, UUID) else UUID(str(value))
def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_PROCESSING,
soft_time_limit=300,
bind=True,
ignore_result=True,
)
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
Uses direct Redis locks to avoid overlapping runs.
"""
task_logger.info("check_user_file_processing - Starting")
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
OnyxRedisLocks.USER_FILE_PROCESSING_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# Do not overlap generator runs
if not lock.acquire(blocking=False):
return None
enqueued = 0
try:
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
select(UserFile.id).where(
UserFile.status == UserFileStatus.PROCESSING
)
)
.scalars()
.all()
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
)
enqueued += 1
finally:
if lock.owned():
lock.release()
task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
)
return None
@shared_task(
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
bind=True,
ignore_result=True,
)
def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -> None:
task_logger.info(f"process_single_user_file - Starting id={user_file_id}")
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id), timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT
)
if not file_lock.acquire(blocking=False):
task_logger.info(
f"process_single_user_file - Lock held, skipping user_file_id={user_file_id}"
)
return None
documents: list[Document] = []
try:
with get_session_with_current_tenant() as db_session:
uf = db_session.get(UserFile, _as_uuid(user_file_id))
if not uf:
task_logger.warning(
f"process_single_user_file - UserFile not found id={user_file_id}"
)
return None
if uf.status != UserFileStatus.PROCESSING:
task_logger.info(
f"process_single_user_file - Skipping id={user_file_id} status={uf.status}"
)
return None
connector = LocalFileConnector(
file_locations=[uf.file_id],
file_names=[uf.name] if uf.name else None,
zip_metadata={},
)
connector.load_credentials({})
# 20 is the documented default for httpx max_keepalive_connections
if MANAGED_VESPA:
httpx_init_vespa_pool(
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
)
else:
httpx_init_vespa_pool(20)
search_settings_list = get_active_search_settings_list(db_session)
current_search_settings = next(
(
search_settings_instance
for search_settings_instance in search_settings_list
if search_settings_instance.status.is_current()
),
None,
)
if current_search_settings is None:
raise RuntimeError(
f"process_single_user_file - No current search settings found for tenant={tenant_id}"
)
try:
for batch in connector.load_from_state():
documents.extend(batch)
adapter = UserFileIndexingAdapter(
tenant_id=tenant_id,
db_session=db_session,
)
# Set up indexing pipeline components
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=current_search_settings,
)
information_content_classification_model = (
InformationContentClassificationModel()
)
document_index = get_default_document_index(
current_search_settings,
None,
httpx_client=HttpxPool.get("vespa"),
)
# update the document id to the userfile id in the documents
for document in documents:
document.id = str(user_file_id)
document.source = DocumentSource.USER_FILE
# real work happens here!
index_pipeline_result = run_indexing_pipeline(
embedder=embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=True,
db_session=db_session,
tenant_id=tenant_id,
document_batch=documents,
request_id=None,
adapter=adapter,
)
task_logger.info(
f"process_single_user_file - Indexing pipeline completed ={index_pipeline_result}"
)
if index_pipeline_result.failures:
task_logger.error(
f"process_single_user_file - Indexing pipeline failed id={user_file_id}"
)
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
return None
except Exception as e:
task_logger.exception(
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
return None
elapsed = time.monotonic() - start
task_logger.info(
f"process_single_user_file - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
)
return None
except Exception as e:
# Attempt to mark the file as failed
with get_session_with_current_tenant() as db_session:
uf = db_session.get(UserFile, _as_uuid(user_file_id))
if uf:
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
task_logger.exception(
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
return None
finally:
if file_lock.owned():
file_lock.release()
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_PROJECT_SYNC,
soft_time_limit=300,
bind=True,
ignore_result=True,
)
def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROJECT_SYNC status and enqueue per-file tasks."""
task_logger.info("check_for_user_file_project_sync - Starting")
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
OnyxRedisLocks.USER_FILE_PROJECT_SYNC_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
if not lock.acquire(blocking=False):
return None
enqueued = 0
try:
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
select(UserFile.id).where(
UserFile.needs_project_sync.is_(True),
UserFile.status == UserFileStatus.COMPLETED,
)
)
.scalars()
.all()
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=OnyxCeleryPriority.HIGH,
)
enqueued += 1
finally:
if lock.owned():
lock.release()
task_logger.info(
f"check_for_user_file_project_sync - Enqueued {enqueued} tasks for tenant={tenant_id}"
)
return None
@shared_task(
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
bind=True,
ignore_result=True,
)
def process_single_user_file_project_sync(
self: Task, *, user_file_id: str, tenant_id: str
) -> None:
"""Process a single user file project sync."""
task_logger.info(
f"process_single_user_file_project_sync - Starting id={user_file_id}"
)
redis_client = get_redis_client(tenant_id=tenant_id)
file_lock: RedisLock = redis_client.lock(
_user_file_project_sync_lock_key(user_file_id),
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
if not file_lock.acquire(blocking=False):
task_logger.info(
f"process_single_user_file_project_sync - Lock held, skipping user_file_id={user_file_id}"
)
return None
try:
with get_session_with_current_tenant() as db_session:
active_search_settings = get_active_search_settings(db_session)
doc_index = get_default_document_index(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
retry_index = RetryDocumentIndex(doc_index)
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
if not user_file:
task_logger.info(
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
)
return None
project_ids = [project.id for project in user_file.projects]
chunks_affected = retry_index.update_single(
doc_id=str(user_file.id),
tenant_id=tenant_id,
chunk_count=user_file.chunk_count,
fields=None,
user_fields=VespaDocumentUserFields(user_projects=project_ids),
)
task_logger.info(
f"process_single_user_file_project_sync - Chunks affected id={user_file_id} chunks={chunks_affected}"
)
user_file.needs_project_sync = False
user_file.last_project_sync_at = datetime.datetime.now(
datetime.timezone.utc
)
db_session.add(user_file)
db_session.commit()
except Exception as e:
task_logger.exception(
f"process_single_user_file_project_sync - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
)
return None
finally:
if file_lock.owned():
file_lock.release()
return None
def _normalize_legacy_user_file_doc_id(old_id: str) -> str:
# Convert USER_FILE_CONNECTOR__<uuid> -> FILE_CONNECTOR__<uuid> for legacy values
user_prefix = "USER_FILE_CONNECTOR__"
file_prefix = "FILE_CONNECTOR__"
if old_id.startswith(user_prefix):
remainder = old_id[len(user_prefix) :]
return file_prefix + remainder
return old_id
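# Hedged example (the UUID below is illustrative only):
#   _normalize_legacy_user_file_doc_id("USER_FILE_CONNECTOR__123e4567-e89b-12d3-a456-426614174000")
#   -> "FILE_CONNECTOR__123e4567-e89b-12d3-a456-426614174000"
# Any id without the legacy prefix is returned unchanged.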
def _visit_chunks(
*,
http_client: httpx.Client,
index_name: str,
selection: str,
continuation: str | None = None,
) -> tuple[list[dict[str, Any]], str | None]:
base_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
params: dict[str, str] = {
"selection": selection,
"wantedDocumentCount": "1000",
}
if continuation:
params["continuation"] = continuation
resp = http_client.get(base_url, params=params, timeout=None)
resp.raise_for_status()
payload = resp.json()
return payload.get("documents", []), payload.get("continuation")
def _update_document_id_in_vespa(
*,
index_name: str,
old_doc_id: str,
new_doc_id: str,
user_project_ids: list[int] | None = None,
) -> None:
clean_new_doc_id = replace_invalid_doc_id_characters(new_doc_id)
normalized_old = _normalize_legacy_user_file_doc_id(old_doc_id)
clean_old_doc_id = replace_invalid_doc_id_characters(normalized_old)
selection = f"{index_name}.document_id=='{clean_old_doc_id}'"
task_logger.debug(f"Vespa selection: {selection}")
with get_vespa_http_client() as http_client:
continuation: str | None = None
while True:
docs, continuation = _visit_chunks(
http_client=http_client,
index_name=index_name,
selection=selection,
continuation=continuation,
)
if not docs:
break
for doc in docs:
vespa_full_id = doc.get("id")
if not vespa_full_id:
continue
vespa_doc_uuid = vespa_full_id.split("::")[-1]
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}"
update_request: dict[str, Any] = {
"fields": {"document_id": {"assign": clean_new_doc_id}}
}
if user_project_ids is not None:
update_request["fields"][USER_PROJECT] = {
"assign": user_project_ids
}
r = http_client.put(vespa_url, json=update_request)
r.raise_for_status()
if not continuation:
break
@shared_task(
name=OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
ignore_result=True,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
bind=True,
)
def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
"""Per-tenant job to update Vespa and search_doc document_id values for user files.
- For each user_file with a legacy document_id, set Vespa `document_id` to the UUID `user_file.id`.
- Update `search_doc.document_id` to the same UUID string.
"""
try:
with get_session_with_current_tenant() as db_session:
active_settings = get_active_search_settings(db_session)
document_index = get_default_document_index(
active_settings.primary,
active_settings.secondary,
)
if hasattr(document_index, "index_name"):
index_name = document_index.index_name
else:
index_name = "danswer_index"
# Fetch mappings of legacy -> new ids
rows = db_session.execute(
sa.select(
UserFile.document_id.label("document_id"),
UserFile.id.label("id"),
).where(
UserFile.document_id.is_not(None),
UserFile.document_id_migrated.is_(False),
)
).all()
# dedupe by old document_id
seen: set[str] = set()
for row in rows:
old_doc_id = str(row.document_id)
new_uuid = str(row.id)
if not old_doc_id or not new_uuid or old_doc_id in seen:
continue
seen.add(old_doc_id)
# collect user project ids for a combined Vespa update
user_project_ids: list[int] | None = None
try:
uf = db_session.get(UserFile, UUID(new_uuid))
if uf is not None:
user_project_ids = [project.id for project in uf.projects]
except Exception as e:
task_logger.warning(
f"Tenant={tenant_id} failed fetching projects for doc_id={new_uuid} - {e.__class__.__name__}"
)
try:
_update_document_id_in_vespa(
index_name=index_name,
old_doc_id=old_doc_id,
new_doc_id=new_uuid,
user_project_ids=user_project_ids,
)
except Exception as e:
task_logger.warning(
f"Tenant={tenant_id} failed Vespa update for doc_id={new_uuid} - {e.__class__.__name__}"
)
# Update search_doc records to refer to the UUID string
uf_id_subq = (
sa.select(sa.cast(UserFile.id, sa.String))
.where(
UserFile.document_id.is_not(None),
UserFile.document_id_migrated.is_(False),
SearchDoc.document_id == UserFile.document_id,
)
.correlate(SearchDoc)
.scalar_subquery()
)
db_session.execute(
sa.update(SearchDoc)
.where(
sa.exists(
sa.select(sa.literal(1)).where(
UserFile.document_id.is_not(None),
UserFile.document_id_migrated.is_(False),
SearchDoc.document_id == UserFile.document_id,
)
)
)
.values(document_id=uf_id_subq)
)
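# Roughly equivalent SQL (a sketch; assumes the tables are named user_file and
# search_doc and standard SQLAlchemy/Postgres rendering):
#   UPDATE search_doc
#   SET document_id = (
#       SELECT CAST(user_file.id AS VARCHAR) FROM user_file
#       WHERE user_file.document_id IS NOT NULL
#         AND user_file.document_id_migrated IS FALSE
#         AND search_doc.document_id = user_file.document_id
#   )
#   WHERE EXISTS (
#       SELECT 1 FROM user_file
#       WHERE user_file.document_id IS NOT NULL
#         AND user_file.document_id_migrated IS FALSE
#         AND search_doc.document_id = user_file.document_id
#   );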
# Mark all processed user_files as migrated
db_session.execute(
sa.update(UserFile)
.where(
UserFile.document_id.is_not(None),
UserFile.document_id_migrated.is_(False),
)
.values(document_id_migrated=True)
)
db_session.commit()
# Normalize plaintext FileRecord blobs: ensure S3 object key aligns with current file_id
try:
store = get_default_file_store()
# Only supported for S3-backed stores where we can manipulate object keys
if isinstance(store, S3BackedFileStore):
s3_client = store._get_s3_client()
bucket_name = store._get_bucket_name()
plaintext_records: Sequence[FileRecord] = (
db_session.execute(
sa.select(FileRecord).where(
FileRecord.file_origin == FileOrigin.PLAINTEXT_CACHE,
FileRecord.file_id.like("plaintext_%"),
)
)
.scalars()
.all()
)
normalized = 0
for fr in plaintext_records:
try:
expected_key = store._get_s3_key(fr.file_id)
if fr.object_key == expected_key:
continue
# Copy old object to new key
copy_source = f"{fr.bucket_name}/{fr.object_key}"
s3_client.copy_object(
CopySource=copy_source,
Bucket=bucket_name,
Key=expected_key,
MetadataDirective="COPY",
)
# Delete old object (best-effort)
try:
s3_client.delete_object(
Bucket=fr.bucket_name, Key=fr.object_key
)
except Exception:
pass
# Update DB record with new key
fr.object_key = expected_key
db_session.add(fr)
normalized += 1
except Exception as e:
task_logger.warning(
f"Tenant={tenant_id} failed plaintext object normalize for "
f"id={fr.file_id} - {e.__class__.__name__}"
)
if normalized:
db_session.commit()
task_logger.info(
f"user_file_docid_migration_task normalized {normalized} plaintext objects for tenant={tenant_id}"
)
else:
task_logger.info(
"user_file_docid_migration_task skipping plaintext object normalization (non-S3 store)"
)
except Exception:
task_logger.exception(
f"user_file_docid_migration_task - Error during plaintext normalization for tenant={tenant_id}"
)
task_logger.info(
f"user_file_docid_migration_task completed for tenant={tenant_id} (rows={len(rows)})"
)
return True
except Exception:
task_logger.exception(
f"user_file_docid_migration_task - Error during execution for tenant={tenant_id}"
)
return False

View File

@@ -414,8 +414,14 @@ def monitor_document_set_taskset(
get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
) # casting since we "know" a document set with this ID exists
if document_set:
if not document_set.connector_credential_pairs:
# if there are no connectors, then delete the document set.
has_connector_pairs = bool(document_set.connector_credential_pairs)
# Federated connectors should keep a document set alive even without cc pairs.
has_federated_connectors = bool(
getattr(document_set, "federated_connectors", [])
)
if not has_connector_pairs and not has_federated_connectors:
# If there are no connectors of any kind, delete the document set.
delete_document_set(document_set_row=document_set, db_session=db_session)
task_logger.info(
f"Successfully deleted document set: document_set={document_set_id}"

View File

@@ -0,0 +1,16 @@
"""Factory stub for running the user file processing Celery worker."""
from celery import Celery
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from onyx.background.celery.apps.user_file_processing import celery_app
return celery_app
app = get_app()

View File

@@ -28,7 +28,6 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
@@ -64,9 +63,11 @@ from onyx.document_index.factory import get_default_document_index
from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.adapters.document_indexing_adapter import (
DocumentIndexingBatchAdapter,
)
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
@@ -100,6 +101,8 @@ def _get_connector_runner(
are the complete list of existing documents of the connector. If the task
is of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
"""
from onyx.connectors.factory import instantiate_connector
task = attempt.connector_credential_pair.connector.input_type
try:
@@ -283,6 +286,8 @@ def _run_indexing(
2. Embed and index these documents into the chosen datastore (vespa)
3. Updates Postgres to record the indexed documents + the outcome of this run
"""
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
start_time = time.monotonic() # just used for logging
with get_session_with_current_tenant() as db_session_temp:
@@ -567,6 +572,13 @@ def _run_indexing(
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
# real work happens here!
adapter = DocumentIndexingBatchAdapter(
db_session=db_session,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
tenant_id=tenant_id,
index_attempt_metadata=index_attempt_md,
)
index_pipeline_result = run_indexing_pipeline(
embedder=embedding_model,
information_content_classification_model=information_content_classification_model,
@@ -578,7 +590,8 @@ def _run_indexing(
db_session=db_session,
tenant_id=tenant_id,
document_batch=doc_batch_cleaned,
index_attempt_metadata=index_attempt_md,
request_id=index_attempt_md.request_id,
adapter=adapter,
)
batch_num += 1

View File

@@ -62,6 +62,7 @@ class Answer:
use_agentic_search: bool = False,
research_type: ResearchType | None = None,
research_plan: dict[str, Any] | None = None,
project_instructions: str | None = None,
) -> None:
self.is_connected: Callable[[], bool] | None = is_connected
self._processed_stream: list[AnswerStreamPart] | None = None
@@ -97,6 +98,7 @@ class Answer:
prompt_builder=prompt_builder,
files=latest_query_files,
structured_response_format=answer_style_config.structured_response_format,
project_instructions=project_instructions,
)
self.graph_tooling = GraphTooling(
primary_llm=llm,

View File

@@ -32,6 +32,7 @@ from onyx.db.llm import fetch_existing_doc_sets
from onyx.db.llm import fetch_existing_tools
from onyx.db.models import ChatMessage
from onyx.db.models import Persona
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.search_settings import get_current_search_settings
@@ -40,7 +41,9 @@ from onyx.kg.setup.kg_default_entity_definitions import (
populate_missing_default_entity_types__commit,
)
from onyx.llm.models import PreviousMessage
from onyx.llm.override_models import LLMOverride
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.onyxbot.slack.models import SlackContext
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.tools.tool_implementations.custom.custom_tool import (
@@ -63,6 +66,9 @@ def prepare_chat_message_request(
db_session: Session,
use_agentic_search: bool = False,
skip_gen_ai_answer_generation: bool = False,
llm_override: LLMOverride | None = None,
allowed_tool_ids: list[int] | None = None,
slack_context: SlackContext | None = None,
) -> CreateChatMessageRequest:
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
new_chat_session = create_chat_session(
@@ -88,6 +94,9 @@ def prepare_chat_message_request(
rerank_settings=rerank_settings,
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
llm_override=llm_override,
allowed_tool_ids=allowed_tool_ids,
slack_context=slack_context, # Pass Slack context
)
@@ -335,6 +344,45 @@ def reorganize_citations(
return new_answer, list(new_citation_info.values())
def build_citation_map_from_infos(
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
) -> dict[int, int]:
"""Translate a list of streaming CitationInfo objects into a mapping of
citation number -> saved search doc DB id.
Always cites the first instance of a document_id and assumes db_docs are
ordered as shown to the user (display order).
"""
doc_id_to_saved_doc_id_map: dict[str, int] = {}
for db_doc in db_docs:
if db_doc.document_id not in doc_id_to_saved_doc_id_map:
doc_id_to_saved_doc_id_map[db_doc.document_id] = db_doc.id
citation_to_saved_doc_id_map: dict[int, int] = {}
for citation in citations_list:
if citation.citation_num not in citation_to_saved_doc_id_map:
saved_id = doc_id_to_saved_doc_id_map.get(citation.document_id)
if saved_id is not None:
citation_to_saved_doc_id_map[citation.citation_num] = saved_id
return citation_to_saved_doc_id_map
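# Hedged illustration (ids and citation numbers are hypothetical): with
# db_docs = [<doc id=7, document_id="a">, <doc id=9, document_id="b">] and
# citations_list = [CitationInfo(citation_num=1, document_id="b"),
#                   CitationInfo(citation_num=2, document_id="a")],
# the result is {1: 9, 2: 7}: each citation number maps to the DB id of the first
# saved doc carrying that document_id.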
def build_citation_map_from_numbers(
cited_numbers: list[int] | set[int], db_docs: list[DbSearchDoc]
) -> dict[int, int]:
"""Translate parsed citation numbers (e.g., from [[n]]) into a mapping of
citation number -> saved search doc DB id by positional index.
"""
citation_to_saved_doc_id_map: dict[int, int] = {}
for num in sorted(set(cited_numbers)):
idx = num - 1
if 0 <= idx < len(db_docs):
citation_to_saved_doc_id_map[num] = db_docs[idx].id
return citation_to_saved_doc_id_map
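# Hedged illustration: with two db_docs (ids 7 and 9, in display order) and
# cited_numbers = {1, 3}, only index 0 is in range, so the result is {1: 7};
# citation 3 is dropped because there is no third displayed document.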
def extract_headers(
headers: dict[str, str] | Headers, pass_through_headers: list[str] | None
) -> dict[str, str]:

View File

@@ -5,6 +5,7 @@ from collections.abc import Callable
from collections.abc import Iterator
from typing import cast
from typing import Protocol
from uuid import UUID
from sqlalchemy.orm import Session
@@ -18,6 +19,7 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ChatBasicResponse
from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import MessageSpecificCitations
from onyx.chat.models import PromptConfig
@@ -35,6 +37,7 @@ from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.chat_configs import SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NO_AUTH_USER_ID
@@ -63,9 +66,13 @@ from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.models import ToolCall
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.db.projects import get_project_instructions
from onyx.db.projects import get_user_files_from_project
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.file_store.models import FileDescriptor
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import build_frontend_file_url
from onyx.file_store.utils import load_all_chat_files
from onyx.kg.models import KGException
from onyx.llm.exceptions import GenAIDisabledException
@@ -101,6 +108,7 @@ from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"
@@ -119,6 +127,55 @@ class PartialResponse(Protocol):
) -> ChatMessage: ...
def _build_project_llm_docs(
project_file_ids: list[str] | None,
in_memory_user_files: list[InMemoryChatFile] | None,
) -> list[LlmDoc]:
"""Construct `LlmDoc` objects for project-scoped user files for citation flow."""
project_llm_docs: list[LlmDoc] = []
if not project_file_ids or not in_memory_user_files:
return project_llm_docs
project_file_id_set = set(project_file_ids)
for f in in_memory_user_files:
# Only include files that belong to the project (not ad-hoc uploads)
if project_file_id_set and (f.file_id in project_file_id_set):
try:
text_content = f.content.decode("utf-8", errors="ignore")
except Exception:
text_content = ""
# Build a short blurb from the file content for better UI display
blurb = (
(text_content[:200] + "...")
if len(text_content) > 200
else text_content
)
# Provide basic metadata to improve SavedSearchDoc display
file_metadata: dict[str, str | list[str]] = {
"filename": f.filename or str(f.file_id),
"file_type": f.file_type.value,
}
project_llm_docs.append(
LlmDoc(
document_id=str(f.file_id),
content=text_content,
blurb=blurb,
semantic_identifier=f.filename or str(f.file_id),
source_type=DocumentSource.USER_FILE,
metadata=file_metadata,
updated_at=None,
link=build_frontend_file_url(str(f.file_id)),
source_links=None,
match_highlights=None,
)
)
return project_llm_docs
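# Hedged illustration: an in-memory project file "notes.txt" whose file_id is in
# project_file_ids becomes an LlmDoc with document_id=str(file_id),
# source_type=DocumentSource.USER_FILE, a blurb truncated to ~200 characters, and a
# link from build_frontend_file_url, so it can flow through citation handling like a
# retrieved document.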
def _translate_citations(
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
) -> MessageSpecificCitations:
@@ -436,26 +493,28 @@ def stream_chat_message_objects(
files = load_all_chat_files(history_msgs, new_msg_req.file_descriptors)
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
latest_query_files = [file for file in files if file.file_id in req_file_ids]
user_file_ids = new_msg_req.user_file_ids or []
user_folder_ids = new_msg_req.user_folder_ids or []
user_file_ids: list[UUID] = []
if persona.user_files:
for file in persona.user_files:
user_file_ids.append(file.id)
if persona.user_folders:
for folder in persona.user_folders:
user_folder_ids.append(folder.id)
for uf in persona.user_files:
user_file_ids.append(uf.id)
if new_msg_req.current_message_files:
for fd in new_msg_req.current_message_files:
uid = fd.get("user_file_id")
if uid is not None:
user_file_ids.append(uid)
# Load user files into memory and create search tool override kwargs if needed
# if we have enough tokens and no folders, we don't need to use search
# if we have enough tokens, we don't need to use search
# we can just pass them into the prompt directly
(
in_memory_user_files,
user_file_models,
search_tool_override_kwargs_for_user_files,
) = parse_user_files(
user_file_ids=user_file_ids,
user_folder_ids=user_folder_ids,
user_file_ids=user_file_ids or [],
project_id=chat_session.project_id,
db_session=db_session,
persona=persona,
actual_user_input=message_text,
@@ -464,16 +523,37 @@ def stream_chat_message_objects(
if not search_tool_override_kwargs_for_user_files:
latest_query_files.extend(in_memory_user_files)
project_file_ids = []
if chat_session.project_id:
project_file_ids.extend(
[
file.file_id
for file in get_user_files_from_project(
chat_session.project_id, user_id, db_session
)
]
)
# we don't want to attach project files to the user message
if user_message:
attach_files_to_chat_message(
chat_message=user_message,
files=[
new_file.to_file_descriptor() for new_file in latest_query_files
new_file.to_file_descriptor()
for new_file in latest_query_files
if project_file_ids is not None
and (new_file.file_id not in project_file_ids)
],
db_session=db_session,
commit=False,
)
# Build project context docs for citation flow if project files are present
project_llm_docs: list[LlmDoc] = _build_project_llm_docs(
project_file_ids=project_file_ids,
in_memory_user_files=in_memory_user_files,
)
selected_db_search_docs = None
selected_sections: list[InferenceSection] | None = None
if reference_doc_ids:
@@ -559,12 +639,22 @@ def stream_chat_message_objects(
else:
prompt_config = PromptConfig.from_model(persona)
# Retrieve project-specific instructions if this chat session is associated with a project.
project_instructions: str | None = (
get_project_instructions(
db_session=db_session, project_id=chat_session.project_id
)
if persona.is_default_persona
else None
) # if the persona is not default, we don't want to use the project instructions
answer_style_config = AnswerStyleConfig(
citation_config=CitationConfig(
all_docs_useful=selected_db_search_docs is not None
),
structured_response_format=new_msg_req.structured_response_format,
)
has_project_files = project_file_ids is not None and len(project_file_ids) > 0
tool_dict = construct_tools(
persona=persona,
@@ -574,9 +664,17 @@ def stream_chat_message_objects(
llm=llm,
fast_llm=fast_llm,
run_search_setting=(
retrieval_options.run_search
if retrieval_options
else OptionalSearchSetting.AUTO
OptionalSearchSetting.NEVER
if (
chat_session.project_id
and not has_project_files
and persona.is_default_persona
)
else (
retrieval_options.run_search
if retrieval_options
else OptionalSearchSetting.AUTO
)
),
search_tool_config=SearchToolConfig(
answer_style_config=answer_style_config,
@@ -603,6 +701,7 @@ def stream_chat_message_objects(
additional_headers=custom_tool_additional_headers,
),
allowed_tool_ids=new_msg_req.allowed_tool_ids,
slack_context=new_msg_req.slack_context, # Pass Slack context from request
)
tools: list[Tool] = []
@@ -617,6 +716,7 @@ def stream_chat_message_objects(
message_history = [
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
]
if not search_tool_override_kwargs_for_user_files and in_memory_user_files:
yield UserKnowledgeFilePacket(
user_files=[
@@ -624,6 +724,8 @@ def stream_chat_message_objects(
id=str(file.file_id), type=file.file_type, name=file.filename
)
for file in in_memory_user_files
if project_file_ids is not None
and (file.file_id not in project_file_ids)
]
)
@@ -642,6 +744,10 @@ def stream_chat_message_objects(
single_message_history=single_message_history,
)
if project_llm_docs and not search_tool_override_kwargs_for_user_files:
# Store for downstream streaming to wire citations and final_documents
prompt_builder.context_llm_docs = project_llm_docs
# LLM prompt building, response capturing, etc.
answer = Answer(
prompt_builder=prompt_builder,
@@ -670,6 +776,7 @@ def stream_chat_message_objects(
db_session=db_session,
use_agentic_search=new_msg_req.use_agentic_search,
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
project_instructions=project_instructions,
)
# Process streamed packets using the new packet processing module

View File

@@ -4,9 +4,9 @@ from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModel__v1
from onyx.chat.models import LlmDoc
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from onyx.chat.prompt_builder.utils import translate_history_to_basemessages
@@ -76,6 +76,7 @@ def default_build_user_message(
if prompt_config.task_prompt
else user_query
)
user_prompt = user_prompt.strip()
tag_handled_prompt = handle_onyx_date_awareness(user_prompt, prompt_config)
user_msg = HumanMessage(
@@ -132,6 +133,10 @@ class AnswerPromptBuilder:
self.raw_user_uploaded_files = raw_user_uploaded_files
self.single_message_history = single_message_history
# Optional: if the prompt includes explicit context documents (e.g., project files),
# store them here so downstream streaming can reference them for citation mapping.
self.context_llm_docs: list[LlmDoc] | None = None
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
if not system_message:
self.system_message_and_token_cnt = None
@@ -196,10 +201,6 @@ class AnswerPromptBuilder:
# Stores some parts of a prompt builder as needed for tool calls
class PromptSnapshot(BaseModel):
raw_message_history: list[PreviousMessage]
raw_user_query: str
built_prompt: list[BaseMessage]
# TODO: rename this? AnswerConfig maybe?

View File

@@ -0,0 +1,10 @@
from langchain_core.messages import BaseMessage
from pydantic import BaseModel
from onyx.llm.models import PreviousMessage
class PromptSnapshot(BaseModel):
raw_message_history: list[PreviousMessage]
raw_user_query: str
built_prompt: list[BaseMessage]

View File

@@ -12,6 +12,35 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def normalize_square_bracket_citations_to_double_with_links(text: str) -> str:
"""
Normalize citation markers in the text:
- Convert bare double-bracket citations without links `[[n]]` to `[[n]]()`
- Convert single-bracket citations `[n]` to `[[n]]()`
Leaves existing linked citations like `[[n]](http...)` unchanged.
"""
if not text:
return ""
# Add empty parens to bare double-bracket citations without a link: [[n]] -> [[n]]()
pattern_double_no_link = re.compile(r"\[\[(\d+)\]\](?!\()")
def _repl_double(match: re.Match[str]) -> str:
num = match.group(1)
return f"[[{num}]]()"
text = pattern_double_no_link.sub(_repl_double, text)
# Convert single [n] not already [[n]] to [[n]]()
pattern_single = re.compile(r"(?<!\[)\[(\d+)\](?!\])")
def _repl_single(match: re.Match[str]) -> str:
num = match.group(1)
return f"[[{num}]]()"
return pattern_single.sub(_repl_single, text)
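# Hedged usage sketch (the input string is illustrative):
#   normalize_square_bracket_citations_to_double_with_links(
#       "See [1] and [[2]] and [[3]](https://example.com)"
#   )
#   -> "See [[1]]() and [[2]]() and [[3]](https://example.com)"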
def in_code_block(llm_text: str) -> bool:
count = llm_text.count(TRIPLE_BACKTICK)
return count % 2 != 0

View File

@@ -7,7 +7,7 @@ from langchain_core.messages import ToolCall
from onyx.chat.models import ResponsePart
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.chat.prompt_builder.schemas import PromptSnapshot
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.message import build_tool_message

View File

@@ -4,6 +4,8 @@ from sqlalchemy.orm import Session
from onyx.db.models import Persona
from onyx.db.models import UserFile
from onyx.db.projects import get_user_files_from_project
from onyx.db.user_file import update_last_accessed_at_for_user_files
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import get_user_files_as_user
from onyx.file_store.utils import load_in_memory_chat_files
@@ -15,24 +17,24 @@ logger = setup_logger()
def parse_user_files(
user_file_ids: list[int],
user_folder_ids: list[int],
user_file_ids: list[UUID],
db_session: Session,
persona: Persona,
actual_user_input: str,
project_id: int | None,
# should only be None if auth is disabled
user_id: UUID | None,
) -> tuple[list[InMemoryChatFile], list[UserFile], SearchToolOverrideKwargs | None]:
"""
Parse user files and folders into in-memory chat files and create search tool override kwargs.
Only creates SearchToolOverrideKwargs if token overflow occurs or folders are present.
Parse user files and project into in-memory chat files and create search tool override kwargs.
Only creates SearchToolOverrideKwargs if token overflow occurs.
Args:
user_file_ids: List of user file IDs to load
user_folder_ids: List of user folder IDs to load
db_session: Database session
persona: Persona to calculate available tokens
actual_user_input: User's input message for token calculation
project_id: Project ID to validate file ownership
user_id: User ID to validate file ownership
Returns:
@@ -40,37 +42,56 @@ def parse_user_files(
loaded user files,
user file models,
search tool override kwargs if token
overflow or folders present
overflow
)
"""
# Return empty results if no files or folders specified
if not user_file_ids and not user_folder_ids:
# Return empty results if no files or project specified
if not user_file_ids and not project_id:
return [], [], None
project_user_file_ids = []
if project_id:
project_user_file_ids.extend(
[
file.id
for file in get_user_files_from_project(project_id, user_id, db_session)
]
)
# Combine user-provided and project-derived user file IDs
combined_user_file_ids = user_file_ids + project_user_file_ids or []
# Load user files from the database into memory
user_files = load_in_memory_chat_files(
user_file_ids or [],
user_folder_ids or [],
combined_user_file_ids,
db_session,
)
user_file_models = get_user_files_as_user(
user_file_ids or [],
user_folder_ids or [],
combined_user_file_ids,
user_id,
db_session,
)
# Update last_accessed_at for the user files used in this chat
if user_file_ids or project_user_file_ids:
# update_last_accessed_at_for_user_files expects list[UUID]
update_last_accessed_at_for_user_files(
combined_user_file_ids,
db_session,
)
# Calculate token count for the files, need to import here to avoid circular import
# TODO: fix this
from onyx.db.user_documents import calculate_user_files_token_count
from onyx.db.user_file import calculate_user_files_token_count
from onyx.chat.prompt_builder.citations_prompt import (
compute_max_document_tokens_for_persona,
)
# calculate_user_files_token_count now expects list[UUID]
total_tokens = calculate_user_files_token_count(
user_file_ids or [],
user_folder_ids or [],
combined_user_file_ids,
db_session,
)
@@ -86,20 +107,22 @@ def parse_user_files(
have_enough_tokens = total_tokens <= available_tokens
# If we have enough tokens and no folders, we don't need search
# If we have enough tokens, we don't need search
# we can just pass them into the prompt directly
if have_enough_tokens and not user_folder_ids:
if have_enough_tokens:
# No search tool override needed - files can be passed directly
return user_files, user_file_models, None
# Token overflow or folders present - need to use search tool
# Token overflow - need to use search tool
override_kwargs = SearchToolOverrideKwargs(
force_no_rerank=have_enough_tokens,
alternate_db_session=None,
retrieved_sections_callback=None,
skip_query_analysis=have_enough_tokens,
user_file_ids=user_file_ids,
user_folder_ids=user_folder_ids,
user_file_ids=user_file_ids or [],
project_id=(
project_id if persona.is_default_persona else None
), # if the persona is not default, we don't want to use the project files
)
return user_files, user_file_models, override_kwargs

View File

@@ -65,19 +65,19 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
PASSWORD_MIN_LENGTH = int(os.getenv("PASSWORD_MIN_LENGTH", 12))
PASSWORD_MIN_LENGTH = int(os.getenv("PASSWORD_MIN_LENGTH", 8))
PASSWORD_MAX_LENGTH = int(os.getenv("PASSWORD_MAX_LENGTH", 64))
PASSWORD_REQUIRE_UPPERCASE = (
os.environ.get("PASSWORD_REQUIRE_UPPERCASE", "true").lower() == "true"
os.environ.get("PASSWORD_REQUIRE_UPPERCASE", "false").lower() == "true"
)
PASSWORD_REQUIRE_LOWERCASE = (
os.environ.get("PASSWORD_REQUIRE_LOWERCASE", "true").lower() == "true"
os.environ.get("PASSWORD_REQUIRE_LOWERCASE", "false").lower() == "true"
)
PASSWORD_REQUIRE_DIGIT = (
os.environ.get("PASSWORD_REQUIRE_DIGIT", "true").lower() == "true"
os.environ.get("PASSWORD_REQUIRE_DIGIT", "false").lower() == "true"
)
PASSWORD_REQUIRE_SPECIAL_CHAR = (
os.environ.get("PASSWORD_REQUIRE_SPECIAL_CHAR", "true").lower() == "true"
os.environ.get("PASSWORD_REQUIRE_SPECIAL_CHAR", "false").lower() == "true"
)
# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
@@ -355,6 +355,26 @@ CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 4
)
CELERY_WORKER_PRIMARY_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_PRIMARY_CONCURRENCY") or 4
)
CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
)
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT = 4
try:
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = int(
os.environ.get(
"CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY",
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT,
)
)
except ValueError:
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = (
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT
)
# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
VESPA_SYNC_MAX_TASKS = 8192
@@ -658,8 +678,8 @@ LOG_ALL_MODEL_INTERACTIONS = (
os.environ.get("LOG_ALL_MODEL_INTERACTIONS", "").lower() == "true"
)
# Logs Onyx only model interactions like prompts, responses, messages etc.
LOG_DANSWER_MODEL_INTERACTIONS = (
os.environ.get("LOG_DANSWER_MODEL_INTERACTIONS", "").lower() == "true"
LOG_ONYX_MODEL_INTERACTIONS = (
os.environ.get("LOG_ONYX_MODEL_INTERACTIONS", "").lower() == "true"
)
LOG_INDIVIDUAL_MODEL_TOKENS = (
os.environ.get("LOG_INDIVIDUAL_MODEL_TOKENS", "").lower() == "true"
@@ -677,6 +697,17 @@ LOG_POSTGRES_CONN_COUNTS = (
# Anonymous usage telemetry
DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true"
#####
# Braintrust Configuration
#####
# Enable Braintrust tracing for LangGraph/LangChain applications
BRAINTRUST_ENABLED = os.environ.get("BRAINTRUST_ENABLED", "").lower() == "true"
# Braintrust project name
BRAINTRUST_PROJECT = os.environ.get("BRAINTRUST_PROJECT", "Onyx")
BRAINTRUST_API_KEY = os.environ.get("BRAINTRUST_API_KEY") or ""
# Maximum concurrency for Braintrust evaluations
BRAINTRUST_MAX_CONCURRENCY = int(os.environ.get("BRAINTRUST_MAX_CONCURRENCY") or 5)
TOKEN_BUDGET_GLOBALLY_ENABLED = (
os.environ.get("TOKEN_BUDGET_GLOBALLY_ENABLED", "").lower() == "true"
)
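
The try/except added above keeps a malformed CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY value from crashing config import by falling back to the default of 4. A minimal standalone sketch of the same parsing pattern (the helper name below is hypothetical, not part of this change):

import os

def _int_env_with_default(name: str, default: int) -> int:
    # Mirrors the fallback parsing above: a missing or malformed value
    # yields the default instead of raising at import time.
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        return default

# e.g. CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY="abc" -> 4
concurrency = _int_env_with_default(
    "CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY", 4
)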

View File

@@ -3,7 +3,6 @@ import os
INPUT_PROMPT_YAML = "./onyx/seeding/input_prompts.yaml"
PROMPTS_YAML = "./onyx/seeding/prompts.yaml"
PERSONAS_YAML = "./onyx/seeding/personas.yaml"
USER_FOLDERS_YAML = "./onyx/seeding/user_folders.yaml"
NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking
# We want this to be approximately the number of results we want to show on the first page

View File

@@ -7,6 +7,8 @@ from enum import Enum
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
ONYX_SLACK_URL = "https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA"
SLACK_USER_TOKEN_PREFIX = "xoxp-"
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
SOURCE_TYPE = "source_type"
@@ -76,6 +78,9 @@ POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (
"celery_worker_user_file_processing"
)
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
@@ -112,7 +117,6 @@ CELERY_GENERIC_BEAT_LOCK_TIMEOUT = 120
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT = 120
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
@@ -203,6 +207,8 @@ class DocumentSource(str, Enum):
# Special case just for integration tests
MOCK_CONNECTOR = "mock_connector"
# Special case for user files
USER_FILE = "user_file"
class FederatedConnectorSource(str, Enum):
@@ -298,6 +304,7 @@ class FileOrigin(str, Enum):
PLAINTEXT_CACHE = "plaintext_cache"
OTHER = "other"
QUERY_HISTORY_CSV = "query_history_csv"
USER_FILE = "user_file"
class FileType(str, Enum):
@@ -343,6 +350,9 @@ class OnyxCeleryQueues:
# Indexing queue
USER_FILES_INDEXING = "user_files_indexing"
# User file processing queue
USER_FILE_PROCESSING = "user_file_processing"
USER_FILE_PROJECT_SYNC = "user_file_project_sync"
# Document processing pipeline queue
DOCPROCESSING = "docprocessing"
CONNECTOR_DOC_FETCHING = "connector_doc_fetching"
@@ -368,7 +378,7 @@ class OnyxRedisLocks:
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
"da_lock:check_connector_external_group_sync_beat"
)
CHECK_USER_FILE_FOLDER_SYNC_BEAT_LOCK = "da_lock:check_user_file_folder_sync_beat"
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
CLOUD_PRE_PROVISION_TENANT_LOCK = "da_lock:pre_provision_tenant"
@@ -390,6 +400,12 @@ class OnyxRedisLocks:
# KG processing
KG_PROCESSING_LOCK = "da_lock:kg_processing"
# User file processing
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
class OnyxRedisSignals:
BLOCK_VALIDATE_INDEXING_FENCES = "signal:block_validate_indexing_fences"
@@ -448,8 +464,6 @@ class OnyxCeleryTask:
f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_pidbox"
)
UPDATE_USER_FILE_FOLDER_METADATA = "update_user_file_folder_metadata"
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
CHECK_FOR_INDEXING = "check_for_indexing"
@@ -457,7 +471,12 @@ class OnyxCeleryTask:
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
CHECK_FOR_USER_FILE_FOLDER_SYNC = "check_for_user_file_folder_sync"
# User file processing
CHECK_FOR_USER_FILE_PROCESSING = "check_for_user_file_processing"
PROCESS_SINGLE_USER_FILE = "process_single_user_file"
CHECK_FOR_USER_FILE_PROJECT_SYNC = "check_for_user_file_project_sync"
PROCESS_SINGLE_USER_FILE_PROJECT_SYNC = "process_single_user_file_project_sync"
# Connector checkpoint cleanup
CHECK_FOR_CHECKPOINT_CLEANUP = "check_for_checkpoint_cleanup"
@@ -490,6 +509,7 @@ class OnyxCeleryTask:
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
USER_FILE_DOCID_MIGRATION = "user_file_docid_migration"
# chat retention
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
@@ -497,6 +517,8 @@ class OnyxCeleryTask:
GENERATE_USAGE_REPORT_TASK = "generate_usage_report_task"
EVAL_RUN_TASK = "eval_run_task"
EXPORT_QUERY_HISTORY_TASK = "export_query_history_task"
EXPORT_QUERY_HISTORY_CLEANUP_TASK = "export_query_history_cleanup_task"

View File

@@ -3,28 +3,26 @@ import os
#####
# Onyx Slack Bot Configs
#####
DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
ONYX_BOT_NUM_RETRIES = int(os.environ.get("ONYX_BOT_NUM_RETRIES", "5"))
# How much of the available input context can be used for thread context
MAX_THREAD_CONTEXT_PERCENTAGE = 512 * 2 / 3072
# Number of docs to display in "Reference Documents"
DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
os.environ.get("DANSWER_BOT_NUM_DOCS_TO_DISPLAY", "5")
)
ONYX_BOT_NUM_DOCS_TO_DISPLAY = int(os.environ.get("ONYX_BOT_NUM_DOCS_TO_DISPLAY", "5"))
# If the LLM fails to answer, Onyx can still show the "Reference Documents"
DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER = os.environ.get(
"DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER", ""
ONYX_BOT_DISABLE_DOCS_ONLY_ANSWER = os.environ.get(
"ONYX_BOT_DISABLE_DOCS_ONLY_ANSWER", ""
).lower() not in ["false", ""]
# When Onyx is considering a message, what emoji does it react with
DANSWER_REACT_EMOJI = os.environ.get("DANSWER_REACT_EMOJI") or "eyes"
ONYX_BOT_REACT_EMOJI = os.environ.get("ONYX_BOT_REACT_EMOJI") or "eyes"
# When User needs more help, what should the emoji be
DANSWER_FOLLOWUP_EMOJI = os.environ.get("DANSWER_FOLLOWUP_EMOJI") or "sos"
ONYX_BOT_FOLLOWUP_EMOJI = os.environ.get("ONYX_BOT_FOLLOWUP_EMOJI") or "sos"
# What kind of message should be shown when someone gives an AI answer feedback to OnyxBot
# Defaults to Private if not provided or invalid
# Private: Only visible to user clicking the feedback
# Anonymous: Public but anonymous
# Public: Visible with the user name who submitted the feedback
DANSWER_BOT_FEEDBACK_VISIBILITY = (
os.environ.get("DANSWER_BOT_FEEDBACK_VISIBILITY") or "private"
ONYX_BOT_FEEDBACK_VISIBILITY = (
os.environ.get("ONYX_BOT_FEEDBACK_VISIBILITY") or "private"
)
# Should OnyxBot send an apology message if it's not able to find an answer
# That way the user isn't confused as to why OnyxBot reacted but then said nothing
@@ -34,40 +32,38 @@ NOTIFY_SLACKBOT_NO_ANSWER = (
)
# Mostly for debugging purposes but it's for explaining what went wrong
# if OnyxBot couldn't find an answer
DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
"DANSWER_BOT_DISPLAY_ERROR_MSGS", ""
ONYX_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
"ONYX_BOT_DISPLAY_ERROR_MSGS", ""
).lower() not in [
"false",
"",
]
# Default is to only respond in channels that are included by a Slack config set in the UI
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
ONYX_BOT_RESPOND_EVERY_CHANNEL = (
os.environ.get("ONYX_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
)
# Maximum Questions Per Minute, Default Uncapped
DANSWER_BOT_MAX_QPM = int(os.environ.get("DANSWER_BOT_MAX_QPM") or 0) or None
ONYX_BOT_MAX_QPM = int(os.environ.get("ONYX_BOT_MAX_QPM") or 0) or None
# Maximum time to wait when a question is queued
DANSWER_BOT_MAX_WAIT_TIME = int(os.environ.get("DANSWER_BOT_MAX_WAIT_TIME") or 180)
ONYX_BOT_MAX_WAIT_TIME = int(os.environ.get("ONYX_BOT_MAX_WAIT_TIME") or 180)
# Time (in minutes) after which a Slack message is sent to the user to remind him to give feedback.
# Set to 0 to disable it (default)
DANSWER_BOT_FEEDBACK_REMINDER = int(
os.environ.get("DANSWER_BOT_FEEDBACK_REMINDER") or 0
)
ONYX_BOT_FEEDBACK_REMINDER = int(os.environ.get("ONYX_BOT_FEEDBACK_REMINDER") or 0)
# Set to True to rephrase the Slack users messages
DANSWER_BOT_REPHRASE_MESSAGE = (
os.environ.get("DANSWER_BOT_REPHRASE_MESSAGE", "").lower() == "true"
ONYX_BOT_REPHRASE_MESSAGE = (
os.environ.get("ONYX_BOT_REPHRASE_MESSAGE", "").lower() == "true"
)
# DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD is the number of
# ONYX_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD is the number of
# responses OnyxBot can send in a given time period.
# Set to 0 to disable the limit.
DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD = int(
os.environ.get("DANSWER_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD", "5000")
ONYX_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD = int(
os.environ.get("ONYX_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD", "5000")
)
# DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS is the number
# ONYX_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS is the number
# of seconds until the response limit is reset.
DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS = int(
os.environ.get("DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS", "86400")
ONYX_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS = int(
os.environ.get("ONYX_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS", "86400")
)

View File

@@ -32,6 +32,7 @@ def _create_image_section(
image_data: bytes,
parent_file_name: str,
display_name: str,
media_type: str | None = None,
link: str | None = None,
idx: int = 0,
) -> tuple[ImageSection, str | None]:
@@ -58,6 +59,9 @@ def _create_image_section(
image_data=image_data,
file_id=file_id,
display_name=display_name,
media_type=(
media_type if media_type is not None else "application/octet-stream"
),
link=link,
file_origin=FileOrigin.CONNECTOR,
)
@@ -123,6 +127,7 @@ def _process_file(
image_data=image_data,
parent_file_name=file_id,
display_name=title,
media_type=file_type,
)
return [
@@ -194,6 +199,7 @@ def _process_file(
image_data=img_data,
parent_file_name=file_id,
display_name=f"{title} - image {idx}",
media_type="application/octet-stream", # Default media type for embedded images
idx=idx,
)
sections.append(image_section)

View File

@@ -168,8 +168,8 @@ class FirefliesConnector(PollConnector, LoadConnector):
if response.status_code == 204:
break
recieved_transcripts = response.json()
parsed_transcripts = recieved_transcripts.get("data", {}).get(
received_transcripts = response.json()
parsed_transcripts = received_transcripts.get("data", {}).get(
"transcripts", []
)

View File

@@ -61,6 +61,19 @@ EMAIL_FIELDS = [
add_retries = retry_builder(tries=50, max_delay=30)
def _is_mail_service_disabled_error(error: HttpError) -> bool:
"""Detect if the Gmail API is telling us the mailbox is not provisioned."""
if error.resp.status != 400:
return False
error_message = str(error)
return (
"Mail service not enabled" in error_message
or "failedPrecondition" in error_message
)
def _build_time_range_query(
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
@@ -307,33 +320,42 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
doc_batch = []
for user_email in self._get_all_user_emails():
gmail_service = get_gmail_service(self.creds, user_email)
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,
list_key="threads",
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
continue_on_404_or_403=True,
):
full_threads = execute_single_retrieval(
retrieval_function=gmail_service.users().threads().get,
list_key=None,
try:
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,
list_key="threads",
userId=user_email,
fields=THREAD_FIELDS,
id=thread["id"],
fields=THREAD_LIST_FIELDS,
q=query,
continue_on_404_or_403=True,
)
# full_threads is an iterator containing a single thread
# so we need to convert it to a list and grab the first element
full_thread = list(full_threads)[0]
doc = thread_to_document(full_thread, user_email)
if doc is None:
continue
):
full_threads = execute_single_retrieval(
retrieval_function=gmail_service.users().threads().get,
list_key=None,
userId=user_email,
fields=THREAD_FIELDS,
id=thread["id"],
continue_on_404_or_403=True,
)
# full_threads is an iterator containing a single thread
# so we need to convert it to a list and grab the first element
full_thread = list(full_threads)[0]
doc = thread_to_document(full_thread, user_email)
if doc is None:
continue
doc_batch.append(doc)
if len(doc_batch) > self.batch_size:
yield doc_batch
doc_batch = []
doc_batch.append(doc)
if len(doc_batch) > self.batch_size:
yield doc_batch
doc_batch = []
except HttpError as e:
if _is_mail_service_disabled_error(e):
logger.warning(
"Skipping Gmail sync for %s because the mailbox is disabled.",
user_email,
)
continue
raise
if doc_batch:
yield doc_batch
@@ -349,35 +371,44 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
for user_email in self._get_all_user_emails():
logger.info(f"Fetching slim threads for user: {user_email}")
gmail_service = get_gmail_service(self.creds, user_email)
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,
list_key="threads",
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
continue_on_404_or_403=True,
):
doc_batch.append(
SlimDocument(
id=thread["id"],
external_access=ExternalAccess(
external_user_emails={user_email},
external_user_group_ids=set(),
is_public=False,
),
try:
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,
list_key="threads",
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
continue_on_404_or_403=True,
):
doc_batch.append(
SlimDocument(
id=thread["id"],
external_access=ExternalAccess(
external_user_emails={user_email},
external_user_group_ids=set(),
is_public=False,
),
)
)
)
if len(doc_batch) > SLIM_BATCH_SIZE:
yield doc_batch
doc_batch = []
if len(doc_batch) > SLIM_BATCH_SIZE:
yield doc_batch
doc_batch = []
if callback:
if callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_documents: Stop signal detected"
)
if callback:
if callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_documents: Stop signal detected"
)
callback.progress("retrieve_all_slim_documents", 1)
callback.progress("retrieve_all_slim_documents", 1)
except HttpError as e:
if _is_mail_service_disabled_error(e):
logger.warning(
"Skipping slim Gmail sync for %s because the mailbox is disabled.",
user_email,
)
continue
raise
if doc_batch:
yield doc_batch
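
The new _is_mail_service_disabled_error check, combined with the try/except around each user's thread retrieval, lets the connector skip mailboxes that Google reports as not provisioned instead of failing the whole sync. A rough sketch of that control flow, using a stand-in error type rather than googleapiclient's real HttpError:

class FakeHttpError(Exception):
    # Stand-in for googleapiclient.errors.HttpError, for illustration only.
    def __init__(self, status: int, message: str) -> None:
        super().__init__(message)
        self.status = status

def _looks_like_disabled_mailbox(error: FakeHttpError) -> bool:
    # Mirrors the real check: HTTP 400 plus a "Mail service not enabled"
    # or "failedPrecondition" marker in the error text.
    if error.status != 400:
        return False
    text = str(error)
    return "Mail service not enabled" in text or "failedPrecondition" in text

def sync_all(users: list[str]) -> list[str]:
    synced: list[str] = []
    for user_email in users:
        try:
            if user_email == "disabled@example.com":  # simulated failure
                raise FakeHttpError(400, "Mail service not enabled or disabled")
            synced.append(user_email)
        except FakeHttpError as e:
            if _looks_like_disabled_mailbox(e):
                continue  # skip just this mailbox, keep syncing the rest
            raise
    return synced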

View File

@@ -57,7 +57,9 @@ from onyx.connectors.sharepoint.connector_utils import get_sharepoint_external_a
from onyx.file_processing.extract_file_text import ACCEPTED_IMAGE_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_validation import EXCLUDED_IMAGE_TYPES
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -141,6 +143,10 @@ class SharepointAuthMethod(Enum):
CERTIFICATE = "certificate"
class SizeCapExceeded(Exception):
"""Exception raised when the size cap is exceeded."""
def load_certificate_from_pfx(pfx_data: bytes, password: str) -> CertificateData | None:
"""Load certificate from .pfx file for MSAL authentication"""
try:
@@ -239,7 +245,7 @@ def _download_with_cap(url: str, timeout: int, cap: int) -> bytes:
Behavior:
- Checks `Content-Length` first and aborts early if it exceeds `cap`.
- Otherwise streams the body in chunks and stops once `cap` is surpassed.
- Raises `RuntimeError('size_cap_exceeded')` when the cap would be exceeded.
- Raises `SizeCapExceeded` when the cap would be exceeded.
- Returns the full bytes if the content fits within `cap`.
"""
with requests.get(url, stream=True, timeout=timeout) as resp:
@@ -253,7 +259,7 @@ def _download_with_cap(url: str, timeout: int, cap: int) -> bytes:
logger.warning(
f"Content-Length {content_len} exceeds cap {cap}; skipping download."
)
raise RuntimeError("size_cap_exceeded")
raise SizeCapExceeded("pre_download")
buf = io.BytesIO()
# Stream in 64KB chunks; adjust if needed for slower networks.
@@ -266,11 +272,32 @@ def _download_with_cap(url: str, timeout: int, cap: int) -> bytes:
logger.warning(
f"Streaming download exceeded cap {cap} bytes; aborting early."
)
raise RuntimeError("size_cap_exceeded")
raise SizeCapExceeded("during_download")
return buf.getvalue()
def _download_via_sdk_with_cap(
driveitem: DriveItem, bytes_allowed: int, chunk_size: int = 64 * 1024
) -> bytes:
"""Use the Office365 SDK streaming download with a hard byte cap.
Raises SizeCapExceeded("during_sdk_download") if the cap would be exceeded.
"""
buf = io.BytesIO()
def on_chunk(bytes_read: int) -> None:
# bytes_read is total bytes seen so far per SDK contract
if bytes_read > bytes_allowed:
raise SizeCapExceeded("during_sdk_download")
# modifies the driveitem to change its download behavior
driveitem.download_session(buf, chunk_downloaded=on_chunk, chunk_size=chunk_size)
# Execute the configured request with retries using existing helper
sleep_and_retry(driveitem.context, "download_session")
return buf.getvalue()
def _convert_driveitem_to_document_with_permissions(
driveitem: DriveItem,
drive_name: str,
@@ -289,6 +316,16 @@ def _convert_driveitem_to_document_with_permissions(
file_size: int | None = None
try:
item_json = driveitem.to_json()
mime_type = item_json.get("file", {}).get("mimeType")
if not mime_type or mime_type in EXCLUDED_IMAGE_TYPES:
# NOTE: this function should be refactored to look like Drive doc_conversion.py pattern
# for now, this skip must happen before we download the file
# Similar to Google Drive, we'll just semi-silently skip excluded image types
logger.debug(
f"Skipping malformed or excluded mime type {mime_type} for {driveitem.name}"
)
return None
size_value = item_json.get("size")
if size_value is not None:
file_size = int(size_value)
@@ -311,19 +348,16 @@ def _convert_driveitem_to_document_with_permissions(
content_bytes: bytes | None = None
if download_url:
try:
# Use this to test the sdk size cap
# raise requests.RequestException("test")
content_bytes = _download_with_cap(
download_url,
REQUEST_TIMEOUT_SECONDS,
SHAREPOINT_CONNECTOR_SIZE_THRESHOLD,
)
except RuntimeError as e:
if "size_cap_exceeded" in str(e):
logger.warning(
f"Skipping '{driveitem.name}' exceeded size cap during streaming."
)
return None
else:
raise
except SizeCapExceeded as e:
logger.warning(f"Skipping '{driveitem.name}' exceeded size cap: {str(e)}")
return None
except requests.RequestException as e:
status = e.response.status_code if e.response is not None else -1
logger.warning(
@@ -332,13 +366,15 @@ def _convert_driveitem_to_document_with_permissions(
# Fallback to SDK content if needed
if content_bytes is None:
content = sleep_and_retry(driveitem.get_content(), "get_content")
if content is None or not isinstance(
getattr(content, "value", None), (bytes, bytearray)
):
logger.warning(f"Could not access content for '{driveitem.name}'")
raise ValueError(f"Could not access content for '{driveitem.name}'")
content_bytes = bytes(content.value)
try:
content_bytes = _download_via_sdk_with_cap(
driveitem, SHAREPOINT_CONNECTOR_SIZE_THRESHOLD
)
except SizeCapExceeded:
logger.warning(
f"Skipping '{driveitem.name}' exceeded size cap during SDK streaming."
)
return None
sections: list[TextSection | ImageSection] = []
file_ext = driveitem.name.split(".")[-1]
@@ -348,6 +384,7 @@ def _convert_driveitem_to_document_with_permissions(
f"Zero-length content for '{driveitem.name}'. Skipping text/image extraction."
)
elif "." + file_ext in ACCEPTED_IMAGE_FILE_EXTENSIONS:
# NOTE: this check should use is_valid_image_type with mime_type instead
image_section, _ = store_image_and_create_section(
image_data=content_bytes,
file_id=driveitem.id,
@@ -358,23 +395,45 @@ def _convert_driveitem_to_document_with_permissions(
sections.append(image_section)
else:
# Note: we don't process Onyx metadata for connectors like Drive & Sharepoint, but could
def _store_embedded_image(img_data: bytes, img_name: str) -> None:
try:
mime_type = get_image_type_from_bytes(img_data)
except ValueError:
logger.debug(
"Skipping embedded image with unknown format for %s",
driveitem.name,
)
return
# The only mime type returned by get_image_type_from_bytes that is also in
# EXCLUDED_IMAGE_TYPES is image/gif.
if mime_type in EXCLUDED_IMAGE_TYPES:
logger.debug(
"Skipping embedded image of excluded type %s for %s",
mime_type,
driveitem.name,
)
return
image_section, _ = store_image_and_create_section(
image_data=img_data,
file_id=f"{driveitem.id}_img_{len(sections)}",
display_name=img_name or f"{driveitem.name} - image {len(sections)}",
file_origin=FileOrigin.CONNECTOR,
)
image_section.link = driveitem.web_url
sections.append(image_section)
extraction_result = extract_text_and_images(
file=io.BytesIO(content_bytes), file_name=driveitem.name
file=io.BytesIO(content_bytes),
file_name=driveitem.name,
image_callback=_store_embedded_image,
)
if extraction_result.text_content:
sections.append(
TextSection(link=driveitem.web_url, text=extraction_result.text_content)
)
for idx, (img_data, img_name) in enumerate(extraction_result.embedded_images):
image_section, _ = store_image_and_create_section(
image_data=img_data,
file_id=f"{driveitem.id}_img_{idx}",
display_name=img_name or f"{driveitem.name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
image_section.link = driveitem.web_url
sections.append(image_section)
# Any embedded images were stored via the callback; the returned list may be empty.
if include_permissions and ctx is not None:
logger.info(f"Getting external access for {driveitem.name}")
@@ -717,6 +776,7 @@ class SharepointConnector(
for folder_part in site_descriptor.folder_path.split("/"):
root_folder = root_folder.get_by_path(folder_part)
# TODO: consider ways to avoid materializing the entire list of files in memory
query = root_folder.get_files(
recursive=True,
page_size=1000,
@@ -825,6 +885,7 @@ class SharepointConnector(
root_folder = root_folder.get_by_path(folder_part)
# Get all items recursively
# TODO: consider ways to avoid materializing the entire list of files in memory
query = root_folder.get_files(
recursive=True,
page_size=1000,
@@ -973,6 +1034,8 @@ class SharepointConnector(
all_pages = pages_data.get("value", [])
# Handle pagination if there are more pages
# TODO: This accumulates all pages in memory and can be heavy on large tenants.
# We should process each page incrementally to avoid unbounded growth.
while "@odata.nextLink" in pages_data:
next_url = pages_data["@odata.nextLink"]
response = requests.get(
@@ -986,7 +1049,7 @@ class SharepointConnector(
# Filter pages based on time window if specified
if start is not None or end is not None:
filtered_pages = []
filtered_pages: list[dict[str, Any]] = []
for page in all_pages:
page_modified = page.get("lastModifiedDateTime")
if page_modified:
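
Both download paths above now signal an oversized file with the dedicated SizeCapExceeded exception rather than string-matching a RuntimeError. A simplified sketch of the streaming-with-cap idea using plain requests (the function name is illustrative, not the connector's):

import io
import requests

class SizeCapExceeded(Exception):
    # Raised when a download would exceed the configured byte cap.
    pass

def download_capped(url: str, cap: int, timeout: int = 30) -> bytes:
    with requests.get(url, stream=True, timeout=timeout) as resp:
        resp.raise_for_status()
        # Fast path: trust Content-Length when present and bail out early.
        content_len = resp.headers.get("Content-Length")
        if content_len is not None and int(content_len) > cap:
            raise SizeCapExceeded("pre_download")
        buf = io.BytesIO()
        for chunk in resp.iter_content(chunk_size=64 * 1024):
            buf.write(chunk)
            if buf.tell() > cap:
                raise SizeCapExceeded("during_download")
        return buf.getvalue()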

View File

@@ -761,7 +761,7 @@ class SlackConnector(
Step 2: Loop through each channel. For each channel:
Step 2.1: Get messages within the time range.
Step 2.2: Process messages in parallel, yield back docs.
Step 2.3: Update checkpoint with new_latest, seen_thread_ts, and current_channel.
Step 2.3: Update checkpoint with new_oldest, seen_thread_ts, and current_channel.
Slack returns messages from newest to oldest, so we need to keep track of
the latest message we've seen in each channel.
Step 2.4: If there are no more messages in the channel, switch the current
@@ -837,7 +837,8 @@ class SlackConnector(
channel_message_ts = checkpoint.channel_completion_map.get(channel_id)
if channel_message_ts:
latest = channel_message_ts
# Set oldest to the checkpoint timestamp to resume from where we left off
oldest = channel_message_ts
logger.debug(
f"Getting messages for channel {channel} within range {oldest} - {latest}"
@@ -855,7 +856,8 @@ class SlackConnector(
f"{latest=}"
)
new_latest = message_batch[-1]["ts"] if message_batch else latest
# message_batch[0] is the newest message (Slack returns newest to oldest)
new_oldest = message_batch[0]["ts"] if message_batch else latest
num_threads_start = len(seen_thread_ts)
@@ -906,15 +908,14 @@ class SlackConnector(
num_threads_processed = len(seen_thread_ts) - num_threads_start
# calculate a percentage progress for the current channel by determining
# our viable range start and end, and the latest timestamp we are querying
# up to
new_latest_seconds_epoch = SecondsSinceUnixEpoch(new_latest)
if new_latest_seconds_epoch > end:
# how much of the time range we've processed so far
new_oldest_seconds_epoch = SecondsSinceUnixEpoch(new_oldest)
range_start = start if start else max(0, channel_created)
if new_oldest_seconds_epoch < range_start:
range_complete = 0.0
else:
range_complete = end - new_latest_seconds_epoch
range_complete = new_oldest_seconds_epoch - range_start
range_start = max(0, channel_created)
range_total = end - range_start
if range_total <= 0:
range_total = 1
@@ -935,7 +936,7 @@ class SlackConnector(
)
checkpoint.seen_thread_ts = list(seen_thread_ts)
checkpoint.channel_completion_map[channel["id"]] = new_latest
checkpoint.channel_completion_map[channel["id"]] = new_oldest
# bypass channels where the first set of messages seen are all bots
# check at least MIN_BOT_MESSAGE_THRESHOLD messages are in the batch

View File

@@ -75,7 +75,7 @@ class OnyxRedisSlackRetryHandler(RetryHandler):
"""
ttl_ms: int | None = None
retry_after_value: list[str] | None = None
retry_after_value: str | None = None
retry_after_header_name: Optional[str] = None
duration_s: float = 1.0 # seconds
@@ -103,14 +103,21 @@ class OnyxRedisSlackRetryHandler(RetryHandler):
"OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header name is None"
)
retry_after_value = response.headers.get(retry_after_header_name)
if not retry_after_value:
retry_after_header_value = response.headers.get(retry_after_header_name)
if not retry_after_header_value:
raise ValueError(
"OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header value is None"
)
# Handle case where header value might be a list
retry_after_value = (
retry_after_header_value[0]
if isinstance(retry_after_header_value, list)
else retry_after_header_value
)
retry_after_value_int = int(
retry_after_value[0]
retry_after_value
) # will raise ValueError if somehow we can't convert to int
jitter = retry_after_value_int * 0.25 * random.random()
duration_s = retry_after_value_int + jitter
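
The header handling above now tolerates a Retry-After value delivered as a list before converting it to an int and adding up to 25% random jitter. A small standalone sketch of that normalization and delay calculation (the function name is illustrative):

import random

def retry_delay_seconds(header_value: str | list[str]) -> float:
    # Normalize a value that may arrive as a list, then add up to 25%
    # jitter so many workers don't retry at the same instant.
    raw = header_value[0] if isinstance(header_value, list) else header_value
    retry_after = int(raw)  # raises ValueError on a malformed header
    return retry_after + retry_after * 0.25 * random.random()

print(retry_delay_seconds("30"))
print(retry_delay_seconds(["30"]))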

View File

@@ -1,5 +1,6 @@
import copy
import time
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast
@@ -14,6 +15,9 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
time_str_to_utc,
)
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
@@ -47,14 +51,30 @@ class ZendeskCredentialsNotSetUpError(PermissionError):
class ZendeskClient:
def __init__(self, subdomain: str, email: str, token: str):
def __init__(
self,
subdomain: str,
email: str,
token: str,
calls_per_minute: int | None = None,
):
self.base_url = f"https://{subdomain}.zendesk.com/api/v2"
self.auth = (f"{email}/token", token)
self.make_request = request_with_rate_limit(self, calls_per_minute)
def request_with_rate_limit(
client: ZendeskClient, max_calls_per_minute: int | None = None
) -> Callable[[str, dict[str, Any]], dict[str, Any]]:
@retry_builder()
def make_request(self, endpoint: str, params: dict[str, Any]) -> dict[str, Any]:
@(
rate_limit_builder(max_calls=max_calls_per_minute, period=60)
if max_calls_per_minute
else lambda x: x
)
def make_request(endpoint: str, params: dict[str, Any]) -> dict[str, Any]:
response = requests.get(
f"{self.base_url}/{endpoint}", auth=self.auth, params=params
f"{client.base_url}/{endpoint}", auth=client.auth, params=params
)
if response.status_code == 429:
@@ -72,6 +92,8 @@ class ZendeskClient:
response.raise_for_status()
return response.json()
return make_request
class ZendeskPageResponse(BaseModel):
data: list[dict[str, Any]]
@@ -359,11 +381,13 @@ class ZendeskConnector(
def __init__(
self,
content_type: str = "articles",
calls_per_minute: int | None = None,
) -> None:
self.content_type = content_type
self.subdomain = ""
# Fetch all tags ahead of time
self.content_tags: dict[str, str] = {}
self.calls_per_minute = calls_per_minute
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# Subdomain is actually the whole URL
@@ -375,7 +399,10 @@ class ZendeskConnector(
self.subdomain = subdomain
self.client = ZendeskClient(
subdomain, credentials["zendesk_email"], credentials["zendesk_token"]
subdomain,
credentials["zendesk_email"],
credentials["zendesk_token"],
calls_per_minute=self.calls_per_minute,
)
return None
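
ZendeskClient now builds its request function through request_with_rate_limit, which applies rate_limit_builder(max_calls=..., period=60) only when calls_per_minute is set and otherwise falls back to an identity lambda. A generic sketch of that conditionally-applied-decorator shape with a simple stand-in limiter (not Onyx's rate_limit_builder):

import time
from collections.abc import Callable
from typing import Any

def simple_rate_limit(
    max_calls: int, period: float
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    # Illustrative decorator factory: sleep when more than max_calls would
    # otherwise happen within `period` seconds.
    timestamps: list[float] = []

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            now = time.monotonic()
            while timestamps and now - timestamps[0] > period:
                timestamps.pop(0)
            if len(timestamps) >= max_calls:
                time.sleep(period - (now - timestamps[0]))
            timestamps.append(time.monotonic())
            return func(*args, **kwargs)

        return wrapper

    return decorator

def build_request(calls_per_minute: int | None) -> Callable[[str], str]:
    # Same shape as the connector change: apply the limiter only when a
    # cap is configured, otherwise use an identity "decorator".
    @(
        simple_rate_limit(max_calls=calls_per_minute, period=60)
        if calls_per_minute
        else (lambda f: f)
    )
    def make_request(endpoint: str) -> str:
        return f"GET {endpoint}"

    return make_request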

View File

@@ -28,6 +28,8 @@ from onyx.indexing.models import DocAwareChunk
from onyx.llm.factory import get_default_llms
from onyx.llm.interfaces import LLM
from onyx.llm.utils import message_to_string
from onyx.onyxbot.slack.models import ChannelType
from onyx.onyxbot.slack.models import SlackContext
from onyx.prompts.federated_search import SLACK_QUERY_EXPANSION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -39,6 +41,42 @@ HIGHLIGHT_START_CHAR = "\ue000"
HIGHLIGHT_END_CHAR = "\ue001"
def _should_skip_channel(
channel_id: str,
allowed_private_channel: str | None,
bot_token: str | None,
access_token: str,
include_dm: bool,
) -> bool:
"""
Determine whether a channel should be skipped when running in a bot context. When an
allowed_private_channel is passed in, every other private channel is filtered out.
"""
if bot_token and not include_dm:
try:
# Use bot token if available (has full permissions), otherwise fall back to user token
token_to_use = bot_token or access_token
channel_client = WebClient(token=token_to_use)
channel_info = channel_client.conversations_info(channel=channel_id)
if isinstance(channel_info.data, dict) and not _is_public_channel(
channel_info.data
):
# This is a private channel - filter it out
if channel_id != allowed_private_channel:
logger.debug(
f"Skipping message from private channel {channel_id} "
f"(not the allowed private channel: {allowed_private_channel})"
)
return True
except Exception as e:
logger.warning(
f"Could not determine channel type for {channel_id}, filtering out: {e}"
)
return True
return False
def build_slack_queries(query: SearchQuery, llm: LLM) -> list[str]:
# get time filter
time_filter = ""
@@ -64,11 +102,33 @@ def build_slack_queries(query: SearchQuery, llm: LLM) -> list[str]:
]
def _is_public_channel(channel_info: dict[str, Any]) -> bool:
"""Check if a channel is public based on its info"""
# The channel_info structure has a nested 'channel' object
channel = channel_info.get("channel", {})
is_channel = channel.get("is_channel", False)
is_private = channel.get("is_private", False)
is_group = channel.get("is_group", False)
is_mpim = channel.get("is_mpim", False)
is_im = channel.get("is_im", False)
# A public channel is: a channel that is NOT private, NOT a group, NOT mpim, NOT im
is_public = (
is_channel and not is_private and not is_group and not is_mpim and not is_im
)
return is_public
def query_slack(
query_string: str,
original_query: SearchQuery,
access_token: str,
limit: int | None = None,
allowed_private_channel: str | None = None,
bot_token: str | None = None,
include_dm: bool = False,
) -> list[SlackMessage]:
# query slack
slack_client = WebClient(token=access_token)
@@ -79,12 +139,22 @@ def query_slack(
response.validate()
messages: dict[str, Any] = response.get("messages", {})
matches: list[dict[str, Any]] = messages.get("matches", [])
except SlackApiError as e:
logger.error(f"Slack API error in query_slack: {e}")
logger.info(f"Successfully used search_messages, found {len(matches)} messages")
except SlackApiError as slack_error:
logger.error(f"Slack API error in search_messages: {slack_error}")
logger.error(
f"Slack API error details: status={slack_error.response.status_code}, "
f"error={slack_error.response.get('error')}"
)
if "not_allowed_token_type" in str(slack_error):
# Log token type prefix
token_prefix = access_token[:4] if len(access_token) >= 4 else "unknown"
logger.error(f"TOKEN TYPE ERROR: access_token type: {token_prefix}...")
return []
# convert matches to slack messages
slack_messages: list[SlackMessage] = []
filtered_count = 0
for match in matches:
text: str | None = match.get("text")
permalink: str | None = match.get("permalink")
@@ -92,6 +162,13 @@ def query_slack(
channel_id: str | None = match.get("channel", {}).get("id")
channel_name: str | None = match.get("channel", {}).get("name")
username: str | None = match.get("username")
if not username:
# Fallback: try to get from user field if username is missing
user_info = match.get("user", "")
if isinstance(user_info, str) and user_info:
username = user_info # Use user ID as fallback
else:
username = "unknown_user"
score: float = match.get("score", 0.0)
if ( # can't use any() because of type checking :(
not text
@@ -103,6 +180,13 @@ def query_slack(
):
continue
# Apply channel filtering if needed
if _should_skip_channel(
channel_id, allowed_private_channel, bot_token, access_token, include_dm
):
filtered_count += 1
continue
# generate thread id and document id
thread_id = (
permalink.split("?thread_ts=", 1)[1] if "?thread_ts=" in permalink else None
@@ -155,6 +239,11 @@ def query_slack(
)
)
if filtered_count > 0:
logger.info(
f"Channel filtering applied: {filtered_count} messages filtered out, {len(slack_messages)} messages kept"
)
return slack_messages
@@ -291,14 +380,40 @@ def slack_retrieval(
access_token: str,
db_session: Session,
limit: int | None = None,
slack_event_context: SlackContext | None = None,
bot_token: str | None = None, # Add bot token parameter
) -> list[InferenceChunk]:
# query slack
_, fast_llm = get_default_llms()
query_strings = build_slack_queries(query, fast_llm)
results: list[list[SlackMessage]] = run_functions_tuples_in_parallel(
include_dm = False
allowed_private_channel = None
if slack_event_context:
channel_type = slack_event_context.channel_type
if channel_type == ChannelType.IM: # DM with user
include_dm = True
if channel_type == ChannelType.PRIVATE_CHANNEL:
allowed_private_channel = slack_event_context.channel_id
logger.info(
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
)
results = run_functions_tuples_in_parallel(
[
(query_slack, (query_string, query, access_token, limit))
(
query_slack,
(
query_string,
query,
access_token,
limit,
allowed_private_channel,
bot_token,
include_dm,
),
)
for query_string in query_strings
]
)
@@ -307,7 +422,6 @@ def slack_retrieval(
if not slack_messages:
return []
# contextualize the slack messages
thread_texts: list[str] = run_functions_tuples_in_parallel(
[
(get_contextualized_thread_text, (slack_message, access_token))
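
_is_public_channel above reads the nested "channel" object from conversations.info and treats a channel as public only when it is a regular channel and not private, a group, an MPIM, or a DM. A condensed sketch of the same predicate on a plain dict, with the payload shape assumed from this diff:

from typing import Any

def is_public_channel(channel_info: dict[str, Any]) -> bool:
    channel = channel_info.get("channel", {})
    return (
        channel.get("is_channel", False)
        and not channel.get("is_private", False)
        and not channel.get("is_group", False)
        and not channel.get("is_mpim", False)
        and not channel.get("is_im", False)
    )

print(is_public_channel({"channel": {"is_channel": True}}))                      # True
print(is_public_channel({"channel": {"is_channel": True, "is_private": True}}))  # False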

View File

@@ -1,5 +1,7 @@
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from uuid import UUID
from pydantic import BaseModel
from pydantic import ConfigDict
@@ -119,8 +121,8 @@ class BaseFilters(BaseModel):
class UserFileFilters(BaseModel):
user_file_ids: list[int] | None = None
user_folder_ids: list[int] | None = None
user_file_ids: list[UUID] | None = None
project_id: int | None = None
class IndexFilters(BaseFilters, UserFileFilters):
@@ -355,6 +357,44 @@ class SearchDoc(BaseModel):
secondary_owners: list[str] | None = None
is_internet: bool = False
@classmethod
def from_chunks_or_sections(
cls,
items: "Sequence[InferenceChunk | InferenceSection] | None",
) -> list["SearchDoc"]:
"""Convert a sequence of InferenceChunk or InferenceSection objects to SearchDoc objects."""
if not items:
return []
search_docs = [
cls(
document_id=(
chunk := (
item.center_chunk
if isinstance(item, InferenceSection)
else item
)
).document_id,
chunk_ind=chunk.chunk_id,
semantic_identifier=chunk.semantic_identifier or "Unknown",
link=chunk.source_links[0] if chunk.source_links else None,
blurb=chunk.blurb,
source_type=chunk.source_type,
boost=chunk.boost,
hidden=chunk.hidden,
metadata=chunk.metadata,
score=chunk.score,
match_highlights=chunk.match_highlights,
updated_at=chunk.updated_at,
primary_owners=chunk.primary_owners,
secondary_owners=chunk.secondary_owners,
is_internet=False,
)
for item in items
]
return search_docs
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(*args, **kwargs) # type: ignore
initial_dict["updated_at"] = (

View File

@@ -36,6 +36,7 @@ from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.llm.interfaces import LLM
from onyx.onyxbot.slack.models import SlackContext
from onyx.secondary_llm_flows.agentic_evaluation import evaluate_inference_section
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import FunctionCall
@@ -65,6 +66,7 @@ class SearchPipeline:
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
prompt_config: PromptConfig | None = None,
contextual_pruning_config: ContextualPruningConfig | None = None,
slack_context: SlackContext | None = None,
):
# NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
# and typically are None. The preprocessing will fetch default values to replace these empty overrides.
@@ -84,6 +86,7 @@ class SearchPipeline:
self.contextual_pruning_config: ContextualPruningConfig | None = (
contextual_pruning_config
)
self.slack_context: SlackContext | None = slack_context
# Preprocessing steps generate this
self._search_query: SearchQuery | None = None
@@ -162,6 +165,7 @@ class SearchPipeline:
document_index=self.document_index,
db_session=self.db_session,
retrieval_metrics_callback=self.retrieval_metrics_callback,
slack_context=self.slack_context, # Pass Slack context
)
return cast(list[InferenceChunk], self._retrieved_chunks)

View File

@@ -166,9 +166,6 @@ def retrieval_preprocessing(
)
user_file_filters = search_request.user_file_filters
user_file_ids = (user_file_filters.user_file_ids or []) if user_file_filters else []
user_folder_ids = (
(user_file_filters.user_folder_ids or []) if user_file_filters else []
)
if persona and persona.user_files:
user_file_ids = list(
set(user_file_ids) | set([file.id for file in persona.user_files])
@@ -176,7 +173,7 @@ def retrieval_preprocessing(
final_filters = IndexFilters(
user_file_ids=user_file_ids,
user_folder_ids=user_folder_ids,
project_id=user_file_filters.project_id if user_file_filters else None,
source_type=preset_filters.source_type or predicted_source_filters,
document_set=preset_filters.document_set,
time_cutoff=time_filter or predicted_time_cutoff,

View File

@@ -30,6 +30,7 @@ from onyx.document_index.vespa.shared_utils.utils import (
from onyx.federated_connectors.federated_retrieval import (
get_federated_retrieval_functions,
)
from onyx.onyxbot.slack.models import SlackContext
from onyx.secondary_llm_flows.query_expansion import multilingual_query_expansion
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -329,6 +330,7 @@ def retrieve_chunks(
retrieval_metrics_callback: (
Callable[[RetrievalMetricsContainer], None] | None
) = None,
slack_context: SlackContext | None = None,
) -> list[InferenceChunk]:
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
@@ -341,7 +343,11 @@ def retrieve_chunks(
# Federated retrieval
federated_retrieval_infos = get_federated_retrieval_functions(
db_session, user_id, query.filters.source_type, query.filters.document_set
db_session,
user_id,
query.filters.source_type,
query.filters.document_set,
slack_context,
)
federated_sources = set(
federated_retrieval_info.source.to_non_federated_source()

View File

@@ -118,40 +118,6 @@ def inference_section_from_chunks(
)
def chunks_or_sections_to_search_docs(
items: Sequence[InferenceChunk | InferenceSection] | None,
) -> list[SearchDoc]:
if not items:
return []
search_docs = [
SearchDoc(
document_id=(
chunk := (
item.center_chunk if isinstance(item, InferenceSection) else item
)
).document_id,
chunk_ind=chunk.chunk_id,
semantic_identifier=chunk.semantic_identifier or "Unknown",
link=chunk.source_links[0] if chunk.source_links else None,
blurb=chunk.blurb,
source_type=chunk.source_type,
boost=chunk.boost,
hidden=chunk.hidden,
metadata=chunk.metadata,
score=chunk.score,
match_highlights=chunk.match_highlights,
updated_at=chunk.updated_at,
primary_owners=chunk.primary_owners,
secondary_owners=chunk.secondary_owners,
is_internet=False,
)
for item in items
]
return search_docs
def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
try:
# Re-tokenize using the NLTK tokenizer for better matching
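
This removal pairs with the new SearchDoc.from_chunks_or_sections classmethod added earlier in this diff; call sites move from the module-level helper to the classmethod. A sketch of the before/after call shape (import paths assumed from the imports visible in this diff):

from collections.abc import Sequence

from onyx.context.search.models import InferenceChunk, InferenceSection, SearchDoc

def to_search_docs(
    items: Sequence[InferenceChunk | InferenceSection] | None,
) -> list[SearchDoc]:
    # Previously: chunks_or_sections_to_search_docs(items)
    return SearchDoc.from_chunks_or_sections(items)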

View File

@@ -28,13 +28,11 @@ from onyx.agents.agent_search.shared_graph_utils.models import (
from onyx.agents.agent_search.utils import create_citation_format_list
from onyx.chat.models import DocumentRelevance
from onyx.configs.chat_configs import HARD_DELETE_CHATS
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDocs
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc as ServerSearchDoc
from onyx.context.search.utils import chunks_or_sections_to_search_docs
from onyx.db.models import AgentSearchMetrics
from onyx.db.models import AgentSubQuery
from onyx.db.models import AgentSubQuestion
@@ -47,17 +45,15 @@ from onyx.db.models import SearchDoc
from onyx.db.models import SearchDoc as DBSearchDoc
from onyx.db.models import ToolCall
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.persona import get_best_persona_id_for_user
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import FileDescriptor
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.models import SubQueryDetail
from onyx.server.query_and_chat.models import SubQuestionDetail
from onyx.tools.tool_runner import ToolCallFinalResult
from onyx.tools.models import ToolCallFinalResult
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
@@ -173,6 +169,8 @@ def get_chat_sessions_by_user(
db_session: Session,
include_onyxbot_flows: bool = False,
limit: int = 50,
project_id: int | None = None,
only_non_project_chats: bool = False,
) -> list[ChatSession]:
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
@@ -187,6 +185,11 @@ def get_chat_sessions_by_user(
if limit:
stmt = stmt.limit(limit)
if project_id is not None:
stmt = stmt.where(ChatSession.project_id == project_id)
elif only_non_project_chats:
stmt = stmt.where(ChatSession.project_id.is_(None))
result = db_session.execute(stmt)
chat_sessions = result.scalars().all()
@@ -260,6 +263,7 @@ def create_chat_session(
prompt_override: PromptOverride | None = None,
onyxbot_flow: bool = False,
slack_thread_id: str | None = None,
project_id: int | None = None,
) -> ChatSession:
chat_session = ChatSession(
user_id=user_id,
@@ -269,6 +273,7 @@ def create_chat_session(
prompt_override=prompt_override,
onyxbot_flow=onyxbot_flow,
slack_thread_id=slack_thread_id,
project_id=project_id,
)
db_session.add(chat_session)
@@ -859,90 +864,6 @@ def get_db_search_doc_by_document_id(
return search_doc
def create_search_doc_from_user_file(
db_user_file: UserFile, associated_chat_file: InMemoryChatFile, db_session: Session
) -> SearchDoc:
"""Create a SearchDoc in the database from a UserFile and return it.
This ensures proper ID generation by SQLAlchemy and prevents duplicate key errors.
"""
blurb = ""
if associated_chat_file and associated_chat_file.content:
try:
# Try to decode as UTF-8, but handle errors gracefully
content_sample = associated_chat_file.content[:100]
# Remove null bytes which can cause SQL errors
content_sample = content_sample.replace(b"\x00", b"")
# NOTE(rkuo): this used to be "replace" instead of strict, but
# that would bypass the binary handling below
blurb = content_sample.decode("utf-8", errors="strict")
except Exception:
# If decoding fails completely, provide a generic description
blurb = f"[Binary file: {db_user_file.name}]"
db_search_doc = SearchDoc(
document_id=db_user_file.document_id,
chunk_ind=0, # Default to 0 for user files
semantic_id=db_user_file.name,
link=db_user_file.link_url,
blurb=blurb,
source_type=DocumentSource.FILE, # Assuming internal source for user files
boost=0, # Default boost
hidden=False, # Default visibility
doc_metadata={}, # Empty metadata
score=0.0, # Default score of 0.0 instead of None
is_relevant=None, # No relevance initially
relevance_explanation=None, # No explanation initially
match_highlights=[], # No highlights initially
updated_at=db_user_file.created_at, # Use created_at as updated_at
primary_owners=[], # Empty list instead of None
secondary_owners=[], # Empty list instead of None
is_internet=False, # Not from internet
)
db_session.add(db_search_doc)
db_session.flush() # Get the ID but don't commit yet
return db_search_doc
def translate_db_user_file_to_search_doc(
db_user_file: UserFile, associated_chat_file: InMemoryChatFile
) -> SearchDoc:
blurb = ""
if associated_chat_file and associated_chat_file.content:
try:
# Try to decode as UTF-8, but handle errors gracefully
content_sample = associated_chat_file.content[:100]
# Remove null bytes which can cause SQL errors
content_sample = content_sample.replace(b"\x00", b"")
blurb = content_sample.decode("utf-8", errors="replace")
except Exception:
# If decoding fails completely, provide a generic description
blurb = f"[Binary file: {db_user_file.name}]"
return SearchDoc(
# Don't set ID - let SQLAlchemy auto-generate it
document_id=db_user_file.document_id,
chunk_ind=0, # Default to 0 for user files
semantic_id=db_user_file.name,
link=db_user_file.link_url,
blurb=blurb,
source_type=DocumentSource.FILE, # Assuming internal source for user files
boost=0, # Default boost
hidden=False, # Default visibility
doc_metadata={}, # Empty metadata
score=0.0, # Default score of 0.0 instead of None
is_relevant=None, # No relevance initially
relevance_explanation=None, # No explanation initially
match_highlights=[], # No highlights initially
updated_at=db_user_file.created_at, # Use created_at as updated_at
primary_owners=[], # Empty list instead of None
secondary_owners=[], # Empty list instead of None
is_internet=False, # Not from internet
)
def translate_db_search_doc_to_server_search_doc(
db_search_doc: SearchDoc,
remove_doc_content: bool = False,
@@ -1147,7 +1068,7 @@ def log_agent_sub_question_results(
db_session.add(sub_query_object)
db_session.commit()
search_docs = chunks_or_sections_to_search_docs(
search_docs = ServerSearchDoc.from_chunks_or_sections(
sub_query.retrieved_documents
)
for doc in search_docs:
@@ -1223,12 +1144,29 @@ def create_search_doc_from_inference_section(
def create_search_doc_from_saved_search_doc(
saved_search_doc: SavedSearchDoc,
) -> SearchDoc:
"""Convert SavedSearchDoc to SearchDoc by excluding the additional fields"""
data = saved_search_doc.model_dump()
# Remove the fields that are specific to SavedSearchDoc
data.pop("db_doc_id", None)
# Keep score since SearchDoc has it as an optional field
return SearchDoc(**data)
"""Convert SavedSearchDoc (server model) into DB SearchDoc with correct field mapping."""
return SearchDoc(
document_id=saved_search_doc.document_id,
chunk_ind=saved_search_doc.chunk_ind,
# Map Pydantic semantic_identifier -> DB semantic_id; ensure non-null
semantic_id=saved_search_doc.semantic_identifier or "Unknown",
link=saved_search_doc.link,
blurb=saved_search_doc.blurb,
source_type=saved_search_doc.source_type,
boost=saved_search_doc.boost,
hidden=saved_search_doc.hidden,
# Map metadata -> doc_metadata (DB column name)
doc_metadata=saved_search_doc.metadata,
# SavedSearchDoc.score exists and defaults to 0.0
score=saved_search_doc.score or 0.0,
match_highlights=saved_search_doc.match_highlights,
updated_at=saved_search_doc.updated_at,
primary_owners=saved_search_doc.primary_owners,
secondary_owners=saved_search_doc.secondary_owners,
is_internet=saved_search_doc.is_internet,
is_relevant=saved_search_doc.is_relevant,
relevance_explanation=saved_search_doc.relevance_explanation,
)
def update_db_session_with_messages(
@@ -1251,7 +1189,6 @@ def update_db_session_with_messages(
research_answer_purpose: ResearchAnswerPurpose | None = None,
commit: bool = False,
) -> ChatMessage:
chat_message = (
db_session.query(ChatMessage)
.filter(

View File

@@ -34,7 +34,6 @@ from onyx.db.models import IndexingStatus
from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserFile
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.db.models import UserRole
from onyx.server.models import StatusResponse
@@ -805,31 +804,3 @@ def resync_cc_pair(
)
db_session.commit()
def get_connector_credential_pairs_with_user_files(
db_session: Session,
) -> list[ConnectorCredentialPair]:
"""
Get all connector credential pairs that have associated user files.
Args:
db_session: Database session
Returns:
List of ConnectorCredentialPair objects that have user files
"""
return (
db_session.query(ConnectorCredentialPair)
.join(UserFile, UserFile.cc_pair_id == ConnectorCredentialPair.id)
.distinct()
.all()
)
def delete_userfiles_for_cc_pair__no_commit(
db_session: Session,
cc_pair_id: int,
) -> None:
stmt = delete(UserFile).where(UserFile.cc_pair_id == cc_pair_id)
db_session.execute(stmt)

View File

@@ -55,6 +55,7 @@ from onyx.document_index.interfaces import DocumentMetadata
from onyx.kg.models import KGStage
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
@@ -259,31 +260,48 @@ def get_document_counts_for_cc_pairs(
) -> Sequence[tuple[int, int, int]]:
"""Returns a sequence of tuples of (connector_id, credential_id, document count)"""
if not cc_pairs:
return []
# Prepare a list of (connector_id, credential_id) tuples
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pairs]
stmt = (
select(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
func.count(),
)
.where(
and_(
tuple_(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
).in_(cc_ids),
DocumentByConnectorCredentialPair.has_been_indexed.is_(True),
# Batch to avoid generating extremely large IN clauses that can blow Postgres stack depth
batch_size = 1000
aggregated_counts: dict[tuple[int, int], int] = {}
for start_idx in range(0, len(cc_ids), batch_size):
batch = cc_ids[start_idx : start_idx + batch_size]
stmt = (
select(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
func.count(),
)
.where(
and_(
tuple_(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
).in_(batch),
DocumentByConnectorCredentialPair.has_been_indexed.is_(True),
)
)
.group_by(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
)
)
.group_by(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
)
)
return db_session.execute(stmt).all() # type: ignore
for connector_id, credential_id, cnt in db_session.execute(stmt).all(): # type: ignore
aggregated_counts[(connector_id, credential_id)] = cnt
# Convert aggregated results back to the expected sequence of tuples
return [
(connector_id, credential_id, cnt)
for (connector_id, credential_id), cnt in aggregated_counts.items()
]
# For use with our thread-level parallelism utils. Note that any relationships
@@ -296,6 +314,72 @@ def get_document_counts_for_cc_pairs_parallel(
return get_document_counts_for_cc_pairs(db_session, cc_pairs)
def _get_document_counts_for_cc_pairs_batch(
batch: list[tuple[int, int]],
) -> list[tuple[int, int, int]]:
"""Worker for parallel execution: opens its own session per batch."""
if not batch:
return []
with get_session_with_current_tenant() as db_session:
stmt = (
select(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
func.count(),
)
.where(
and_(
tuple_(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
).in_(batch),
DocumentByConnectorCredentialPair.has_been_indexed.is_(True),
)
)
.group_by(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
)
)
return db_session.execute(stmt).all() # type: ignore
def get_document_counts_for_cc_pairs_batched_parallel(
cc_pairs: list[ConnectorCredentialPairIdentifier],
batch_size: int = 1000,
max_workers: int | None = None,
) -> Sequence[tuple[int, int, int]]:
"""Parallel variant that batches the IN-clause and runs batches concurrently.
Opens an isolated DB session per batch to avoid sharing a session across threads.
"""
if not cc_pairs:
return []
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pairs]
batches: list[list[tuple[int, int]]] = [
cc_ids[i : i + batch_size] for i in range(0, len(cc_ids), batch_size)
]
funcs = [(_get_document_counts_for_cc_pairs_batch, (batch,)) for batch in batches]
results = run_functions_tuples_in_parallel(
functions_with_args=funcs, max_workers=max_workers
)
aggregated_counts: dict[tuple[int, int], int] = {}
for batch_result in results:
if not batch_result:
continue
for connector_id, credential_id, cnt in batch_result:
aggregated_counts[(connector_id, credential_id)] = cnt
return [
(connector_id, credential_id, cnt)
for (connector_id, credential_id), cnt in aggregated_counts.items()
]
def get_access_info_for_document(
db_session: Session,
document_id: str,
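
get_document_counts_for_cc_pairs now splits the (connector_id, credential_id) pairs into batches of at most 1000, runs the count query per batch, and merges rows keyed by pair, keeping the IN clause small; the new batched-parallel variant does the same with one session per batch. A generic sketch of the batch-and-merge shape, independent of SQLAlchemy (names are illustrative):

from collections.abc import Callable, Iterable, Sequence

Key = tuple[int, int]

def batched(items: Sequence[Key], batch_size: int) -> Iterable[Sequence[Key]]:
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]

def count_in_batches(
    keys: Sequence[Key],
    run_query: Callable[[Sequence[Key]], list[tuple[int, int, int]]],
    batch_size: int = 1000,
) -> list[tuple[int, int, int]]:
    # Run the count query per batch and merge rows keyed by
    # (connector_id, credential_id).
    aggregated: dict[Key, int] = {}
    for batch in batched(keys, batch_size):
        for connector_id, credential_id, cnt in run_query(batch):
            aggregated[(connector_id, credential_id)] = cnt
    return [(c, cr, n) for (c, cr), n in aggregated.items()]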

View File

@@ -131,3 +131,10 @@ class EmbeddingPrecision(str, PyEnum):
# good reason to specify anything else
BFLOAT16 = "bfloat16"
FLOAT = "float"
class UserFileStatus(str, PyEnum):
PROCESSING = "PROCESSING"
COMPLETED = "COMPLETED"
FAILED = "FAILED"
CANCELED = "CANCELED"

View File

@@ -62,6 +62,7 @@ from onyx.db.enums import (
SyncType,
SyncStatus,
MCPAuthenticationType,
UserFileStatus,
)
from onyx.configs.constants import NotificationType
from onyx.configs.constants import SearchFeedbackType
@@ -209,9 +210,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
chat_sessions: Mapped[list["ChatSession"]] = relationship(
"ChatSession", back_populates="user"
)
chat_folders: Mapped[list["ChatFolder"]] = relationship(
"ChatFolder", back_populates="user"
)
input_prompts: Mapped[list["InputPrompt"]] = relationship(
"InputPrompt", back_populates="user"
@@ -229,8 +227,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
back_populates="creator",
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
)
folders: Mapped[list["UserFolder"]] = relationship(
"UserFolder", back_populates="user"
projects: Mapped[list["UserProject"]] = relationship(
"UserProject", back_populates="user"
)
files: Mapped[list["UserFile"]] = relationship("UserFile", back_populates="user")
# MCP servers accessible to this user
@@ -539,10 +537,6 @@ class ConnectorCredentialPair(Base):
primaryjoin="foreign(ConnectorCredentialPair.creator_id) == remote(User.id)",
)
user_file: Mapped["UserFile"] = relationship(
"UserFile", back_populates="cc_pair", uselist=False
)
background_errors: Mapped[list["BackgroundError"]] = relationship(
"BackgroundError", back_populates="cc_pair", cascade="all, delete-orphan"
)
@@ -1459,9 +1453,6 @@ class FederatedConnector(Base):
class FederatedConnectorOAuthToken(Base):
"""NOTE: in the future, can be made more general to support OAuth tokens
for actions."""
__tablename__ = "federated_connector_oauth_token"
id: Mapped[int] = mapped_column(primary_key=True)
@@ -2037,9 +2028,6 @@ class ChatSession(Base):
Enum(ChatSessionSharedStatus, native_enum=False),
default=ChatSessionSharedStatus.PRIVATE,
)
folder_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_folder.id"), nullable=True
)
current_alternate_model: Mapped[str | None] = mapped_column(String, default=None)
@@ -2047,6 +2035,14 @@ class ChatSession(Base):
String, nullable=True, default=None
)
project_id: Mapped[int | None] = mapped_column(
ForeignKey("user_project.id"), nullable=True
)
project: Mapped["UserProject"] = relationship(
"UserProject", back_populates="chat_sessions", foreign_keys=[project_id]
)
# the latest "overrides" specified by the user. These take precedence over
# the attached persona. However, overrides specified directly in the
# `send-message` call will take precedence over these.
@@ -2072,9 +2068,6 @@ class ChatSession(Base):
DateTime(timezone=True), server_default=func.now()
)
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
folder: Mapped["ChatFolder"] = relationship(
"ChatFolder", back_populates="chat_sessions"
)
messages: Mapped[list["ChatMessage"]] = relationship(
"ChatMessage", back_populates="chat_session", cascade="all, delete-orphan"
)
@@ -2178,33 +2171,6 @@ class ChatMessage(Base):
)
class ChatFolder(Base):
"""For organizing chat sessions"""
__tablename__ = "chat_folder"
id: Mapped[int] = mapped_column(primary_key=True)
# Only null if auth is off
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
name: Mapped[str | None] = mapped_column(String, nullable=True)
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=0)
user: Mapped[User] = relationship("User", back_populates="chat_folders")
chat_sessions: Mapped[list["ChatSession"]] = relationship(
"ChatSession", back_populates="folder"
)
def __lt__(self, other: Any) -> bool:
if not isinstance(other, ChatFolder):
return NotImplemented
if self.display_priority == other.display_priority:
# Bigger ID (created later) show earlier
return self.id > other.id
return self.display_priority < other.display_priority
class AgentSubQuestion(Base):
"""
A sub-question is a question that is asked of the LLM to gather supporting
@@ -2636,11 +2602,6 @@ class Persona(Base):
secondary="persona__user_file",
back_populates="assistants",
)
user_folders: Mapped[list["UserFolder"]] = relationship(
"UserFolder",
secondary="persona__user_folder",
back_populates="assistants",
)
labels: Mapped[list["PersonaLabel"]] = relationship(
"PersonaLabel",
secondary=Persona__PersonaLabel.__table__,
@@ -2658,20 +2619,11 @@ class Persona(Base):
)
class Persona__UserFolder(Base):
__tablename__ = "persona__user_folder"
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
user_folder_id: Mapped[int] = mapped_column(
ForeignKey("user_folder.id"), primary_key=True
)
class Persona__UserFile(Base):
__tablename__ = "persona__user_file"
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
user_file_id: Mapped[int] = mapped_column(
user_file_id: Mapped[UUID] = mapped_column(
ForeignKey("user_file.id"), primary_key=True
)
@@ -2783,6 +2735,7 @@ class SlackBot(Base):
bot_token: Mapped[str] = mapped_column(EncryptedString(), unique=True)
app_token: Mapped[str] = mapped_column(EncryptedString(), unique=True)
user_token: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
slack_channel_configs: Mapped[list[SlackChannelConfig]] = relationship(
"SlackChannelConfig",
@@ -3276,23 +3229,37 @@ class InputPrompt__User(Base):
disabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
class UserFolder(Base):
__tablename__ = "user_folder"
class Project__UserFile(Base):
__tablename__ = "project__user_file"
project_id: Mapped[int] = mapped_column(
ForeignKey("user_project.id"), primary_key=True
)
user_file_id: Mapped[UUID] = mapped_column(
ForeignKey("user_file.id"), primary_key=True
)
class UserProject(Base):
__tablename__ = "user_project"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=False)
name: Mapped[str] = mapped_column(nullable=False)
description: Mapped[str] = mapped_column(nullable=False)
description: Mapped[str] = mapped_column(nullable=True)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
user: Mapped["User"] = relationship(back_populates="folders")
files: Mapped[list["UserFile"]] = relationship(back_populates="folder")
assistants: Mapped[list["Persona"]] = relationship(
"Persona",
secondary=Persona__UserFolder.__table__,
back_populates="user_folders",
user: Mapped["User"] = relationship(back_populates="projects")
user_files: Mapped[list["UserFile"]] = relationship(
"UserFile",
secondary=Project__UserFile.__table__,
back_populates="projects",
)
chat_sessions: Mapped[list["ChatSession"]] = relationship(
"ChatSession", back_populates="project", lazy="selectin"
)
instructions: Mapped[str] = mapped_column(String)
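A hedged sketch of how the new project tables relate, not taken from the diff: the session helper, and the `user`, `existing_file`, and `chat_session` objects are assumed to exist for illustration.
# Illustrative only: create a project, attach an existing UserFile via the
# association table, and point a ChatSession at the project.
from onyx.db.engine import get_session_with_current_tenant  # assumed import path

with get_session_with_current_tenant() as db_session:
    project = UserProject(
        user_id=user.id,                # hypothetical user in scope
        name="Q4 launch research",
        description=None,               # nullable per the change above
        instructions="Answer using only the attached launch documents.",
    )
    db_session.add(project)
    db_session.flush()                  # obtain project.id for the link row

    db_session.add(
        Project__UserFile(project_id=project.id, user_file_id=existing_file.id)
    )
    chat_session.project_id = project.id  # hypothetical existing ChatSession
    db_session.commit()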
class UserDocument(str, Enum):
@@ -3304,17 +3271,13 @@ class UserDocument(str, Enum):
class UserFile(Base):
__tablename__ = "user_file"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=False)
assistants: Mapped[list["Persona"]] = relationship(
"Persona",
secondary=Persona__UserFile.__table__,
back_populates="user_files",
)
folder_id: Mapped[int | None] = mapped_column(
ForeignKey("user_folder.id"), nullable=True
)
file_id: Mapped[str] = mapped_column(nullable=False)
document_id: Mapped[str] = mapped_column(nullable=False)
name: Mapped[str] = mapped_column(nullable=False)
@@ -3322,17 +3285,38 @@ class UserFile(Base):
default=datetime.datetime.utcnow
)
user: Mapped["User"] = relationship(back_populates="files")
folder: Mapped["UserFolder"] = relationship(back_populates="files")
token_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
cc_pair_id: Mapped[int | None] = mapped_column(
ForeignKey("connector_credential_pair.id"), nullable=True, unique=True
file_type: Mapped[str] = mapped_column(String, nullable=False)
status: Mapped[UserFileStatus] = mapped_column(
Enum(UserFileStatus, native_enum=False, name="userfilestatus"),
nullable=False,
default=UserFileStatus.PROCESSING,
)
cc_pair: Mapped["ConnectorCredentialPair"] = relationship(
"ConnectorCredentialPair", back_populates="user_file"
needs_project_sync: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
last_project_sync_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
chunk_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
last_accessed_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
link_url: Mapped[str | None] = mapped_column(String, nullable=True)
content_type: Mapped[str | None] = mapped_column(String, nullable=True)
document_id_migrated: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=True
)
projects: Mapped[list["UserProject"]] = relationship(
"UserProject",
secondary=Project__UserFile.__table__,
back_populates="user_files",
lazy="selectin",
)
"""

View File

@@ -33,7 +33,6 @@ from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserFile
from onyx.db.models import UserFolder
from onyx.db.models import UserGroup
from onyx.db.notification import create_notification
from onyx.server.features.persona.models import FullPersonaSnapshot
@@ -237,6 +236,16 @@ def create_update_persona(
elif user.role != UserRole.ADMIN:
raise ValueError("Only admins can make a default persona")
# Convert incoming string UUIDs to UUID objects for DB operations
converted_user_file_ids = None
if create_persona_request.user_file_ids is not None:
try:
converted_user_file_ids = [
UUID(str_id) for str_id in create_persona_request.user_file_ids
]
except Exception:
raise ValueError("Invalid user_file_ids; must be UUID strings")
persona = upsert_persona(
persona_id=persona_id,
user=user,
@@ -264,8 +273,7 @@ def create_update_persona(
llm_relevance_filter=create_persona_request.llm_relevance_filter,
llm_filter_extraction=create_persona_request.llm_filter_extraction,
is_default_persona=create_persona_request.is_default_persona,
user_file_ids=create_persona_request.user_file_ids,
user_folder_ids=create_persona_request.user_folder_ids,
user_file_ids=converted_user_file_ids,
)
versioned_make_persona_private = fetch_versioned_implementation(
@@ -504,8 +512,7 @@ def upsert_persona(
builtin_persona: bool = False,
is_default_persona: bool | None = None,
label_ids: list[int] | None = None,
user_file_ids: list[int] | None = None,
user_folder_ids: list[int] | None = None,
user_file_ids: list[UUID] | None = None,
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
chunks_below: int = CONTEXT_CHUNKS_BELOW,
) -> Persona:
@@ -566,17 +573,6 @@ def upsert_persona(
if not user_files and user_file_ids:
raise ValueError("user_files not found")
# Fetch and attach user_folders by IDs
user_folders = None
if user_folder_ids is not None:
user_folders = (
db_session.query(UserFolder)
.filter(UserFolder.id.in_(user_folder_ids))
.all()
)
if not user_folders and user_folder_ids:
raise ValueError("user_folders not found")
labels = None
if label_ids is not None:
labels = (
@@ -644,10 +640,6 @@ def upsert_persona(
existing_persona.user_files.clear()
existing_persona.user_files = user_files or []
if user_folder_ids is not None:
existing_persona.user_folders.clear()
existing_persona.user_folders = user_folders or []
# We should only update display priority if it is not already set
if existing_persona.display_priority is None:
existing_persona.display_priority = display_priority
@@ -686,7 +678,6 @@ def upsert_persona(
is_default_persona=(
is_default_persona if is_default_persona is not None else False
),
user_folders=user_folders or [],
user_files=user_files or [],
labels=labels or [],
)

backend/onyx/db/projects.py (new file, 177 lines)
View File

@@ -0,0 +1,177 @@
import datetime
import uuid
from typing import List
from uuid import UUID
from fastapi import HTTPException
from fastapi import UploadFile
from pydantic import BaseModel
from pydantic import ConfigDict
from sqlalchemy.orm import Session
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.models import Project__UserFile
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.models import UserProject
from onyx.server.documents.connector import upload_files
from onyx.server.features.projects.projects_file_utils import categorize_uploaded_files
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
class CategorizedFilesResult(BaseModel):
user_files: list[UserFile]
non_accepted_files: list[str]
unsupported_files: list[str]
# Allow SQLAlchemy ORM models inside this result container
model_config = ConfigDict(arbitrary_types_allowed=True)
def create_user_files(
files: List[UploadFile],
project_id: int | None,
user: User | None,
db_session: Session,
link_url: str | None = None,
) -> CategorizedFilesResult:
# Categorize the files
categorized_files = categorize_uploaded_files(files)
# NOTE: At the moment, zip metadata is not used for user files.
# Should revisit to decide whether this should be a feature.
upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE)
user_files = []
non_accepted_files = categorized_files.non_accepted
unsupported_files = categorized_files.unsupported
# Pair returned storage paths with the same set of acceptable files we uploaded
for file_path, file in zip(
upload_response.file_paths, categorized_files.acceptable
):
new_id = uuid.uuid4()
new_file = UserFile(
id=new_id,
user_id=user.id if user else None,
file_id=file_path,
document_id=str(new_id),
name=file.filename,
token_count=categorized_files.acceptable_file_to_token_count[
file.filename or ""
],
link_url=link_url,
content_type=file.content_type,
file_type=file.content_type,
last_accessed_at=datetime.datetime.now(datetime.timezone.utc),
)
# Persist the UserFile first to satisfy FK constraints for association table
# (The flush below makes the new id visible so the association-table row can satisfy its FK constraint.)
db_session.add(new_file)
db_session.flush()
if project_id:
project_to_user_file = Project__UserFile(
project_id=project_id,
user_file_id=new_file.id,
)
db_session.add(project_to_user_file)
user_files.append(new_file)
db_session.commit()
return CategorizedFilesResult(
user_files=user_files,
non_accepted_files=non_accepted_files,
unsupported_files=unsupported_files,
)
def upload_files_to_user_files_with_indexing(
files: List[UploadFile],
project_id: int | None,
user: User | None,
db_session: Session,
) -> CategorizedFilesResult:
# Validate project ownership if a project_id is provided
if project_id is not None and user is not None:
if not check_project_ownership(project_id, user.id, db_session):
raise HTTPException(status_code=404, detail="Project not found")
categorized_files_result = create_user_files(files, project_id, user, db_session)
user_files = categorized_files_result.user_files
non_accepted_files = categorized_files_result.non_accepted_files
unsupported_files = categorized_files_result.unsupported_files
# Trigger per-file processing immediately for the current tenant
tenant_id = get_current_tenant_id()
if non_accepted_files:
for filename in non_accepted_files:
logger.warning(f"Non-accepted file: {filename}")
if unsupported_files:
for filename in unsupported_files:
logger.warning(f"Unsupported file: {filename}")
for user_file in user_files:
task = client_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
)
logger.info(
f"Triggered indexing for user_file_id={user_file.id} with task_id={task.id}"
)
return CategorizedFilesResult(
user_files=user_files,
non_accepted_files=non_accepted_files,
unsupported_files=unsupported_files,
)
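A hypothetical FastAPI route wiring the helper above into an upload endpoint; the route path, router prefix, and dependency imports are assumptions for illustration and are not part of this diff.
# Hypothetical endpoint sketch.
from fastapi import APIRouter, Depends, File, UploadFile
from sqlalchemy.orm import Session

from onyx.auth.users import current_user   # assumed dependency
from onyx.db.engine import get_session      # assumed dependency
from onyx.db.models import User

router = APIRouter(prefix="/user/projects")  # assumed prefix

@router.post("/{project_id}/files")
def upload_project_files(
    project_id: int,
    files: list[UploadFile] = File(...),
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> dict[str, list[str]]:
    result = upload_files_to_user_files_with_indexing(
        files=files, project_id=project_id, user=user, db_session=db_session
    )
    return {
        "file_ids": [str(f.id) for f in result.user_files],
        "non_accepted_files": result.non_accepted_files,
        "unsupported_files": result.unsupported_files,
    }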
def check_project_ownership(
project_id: int, user_id: UUID | None, db_session: Session
) -> bool:
return (
db_session.query(UserProject)
.filter(UserProject.id == project_id, UserProject.user_id == user_id)
.first()
is not None
)
def get_user_files_from_project(
project_id: int, user_id: UUID | None, db_session: Session
) -> list[UserFile]:
# First check if the user owns the project
if not check_project_ownership(project_id, user_id, db_session):
return []
return (
db_session.query(UserFile)
.join(Project__UserFile)
.filter(Project__UserFile.project_id == project_id)
.all()
)
def get_project_instructions(db_session: Session, project_id: int | None) -> str | None:
"""Return the project's instruction text from the project, else None.
Safe helper that swallows DB errors and returns None on any failure.
"""
if not project_id:
return None
try:
project = (
db_session.query(UserProject)
.filter(UserProject.id == project_id)
.one_or_none()
)
if not project or not project.instructions:
return None
instructions = project.instructions.strip()
return instructions or None
except Exception:
return None
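One plausible consumer of this helper, sketched for illustration only: prepend the project's instructions to a system prompt when a chat session belongs to a project. The prompt-assembly shape is an assumption.
def build_system_prompt(base_prompt: str, db_session: Session, project_id: int | None) -> str:
    # Returns the base prompt unchanged when there is no project or no instructions.
    instructions = get_project_instructions(db_session, project_id)
    if not instructions:
        return base_prompt
    return f"{instructions}\n\n{base_prompt}"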

View File

@@ -12,12 +12,14 @@ def insert_slack_bot(
enabled: bool,
bot_token: str,
app_token: str,
user_token: str | None = None,
) -> SlackBot:
slack_bot = SlackBot(
name=name,
enabled=enabled,
bot_token=bot_token,
app_token=app_token,
user_token=user_token,
)
db_session.add(slack_bot)
db_session.commit()
@@ -32,6 +34,7 @@ def update_slack_bot(
enabled: bool,
bot_token: str,
app_token: str,
user_token: str | None = None,
) -> SlackBot:
slack_bot = db_session.scalar(select(SlackBot).where(SlackBot.id == slack_bot_id))
if slack_bot is None:
@@ -42,6 +45,7 @@ def update_slack_bot(
slack_bot.enabled = enabled
slack_bot.bot_token = bot_token
slack_bot.app_token = app_token
slack_bot.user_token = user_token
db_session.commit()
@@ -74,15 +78,3 @@ def remove_slack_bot(
def fetch_slack_bots(db_session: Session) -> Sequence[SlackBot]:
return db_session.scalars(select(SlackBot)).all()
def fetch_slack_bot_tokens(
db_session: Session, slack_bot_id: int
) -> dict[str, str] | None:
slack_bot = db_session.scalar(select(SlackBot).where(SlackBot.id == slack_bot_id))
if not slack_bot:
return None
return {
"app_token": slack_bot.app_token,
"bot_token": slack_bot.bot_token,
}
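Grounded in the updated signatures above, a minimal sketch of persisting a bot with the new optional user token. Parameters not visible in these hunks (db_session, name, slack_bot_id) are assumed from the function bodies; token values are placeholders.
# Illustrative only: user_token is optional and defaults to None.
bot = insert_slack_bot(
    db_session=db_session,
    name="support-bot",
    enabled=True,
    bot_token="xoxb-placeholder",
    app_token="xapp-placeholder",
    user_token="xoxp-placeholder",   # new optional argument
)
update_slack_bot(
    db_session=db_session,
    slack_bot_id=bot.id,
    name=bot.name,
    enabled=True,
    bot_token=bot.bot_token,
    app_token=bot.app_token,
    user_token=None,                 # clearing the token is allowed
)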

View File

@@ -2,7 +2,6 @@ from typing import Any
from typing import cast
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
from uuid import UUID
from sqlalchemy import select
@@ -10,21 +9,12 @@ from sqlalchemy.orm import Session
from onyx.db.models import Tool
from onyx.server.features.tool.models import Header
from onyx.tools.built_in_tools import BUILT_IN_TOOL_TYPES
from onyx.utils.headers import HeaderItemDict
from onyx.utils.logger import setup_logger
if TYPE_CHECKING:
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
KnowledgeGraphTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
OktaProfileTool,
)
pass
logger = setup_logger()
@@ -124,15 +114,7 @@ def delete_tool__no_commit(tool_id: int, db_session: Session) -> None:
def get_builtin_tool(
db_session: Session,
tool_type: Type[
Union[
"SearchTool",
"ImageGenerationTool",
"WebSearchTool",
"KnowledgeGraphTool",
"OktaProfileTool",
]
],
tool_type: Type[BUILT_IN_TOOL_TYPES],
) -> Tool:
"""
Retrieves a built-in tool from the database based on the tool type.
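For illustration (not part of the diff), a call site can now pass any member of BUILT_IN_TOOL_TYPES directly; SearchTool is used as an example, with its import path taken from the TYPE_CHECKING block this diff removes, and db_session assumed to be in scope.
from onyx.tools.tool_implementations.search.search_tool import SearchTool

search_tool_row = get_builtin_tool(db_session=db_session, tool_type=SearchTool)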

View File

@@ -0,0 +1,87 @@
import datetime
from uuid import UUID
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.models import UserFile
def fetch_chunk_counts_for_user_files(
user_file_ids: list[str],
db_session: Session,
) -> list[tuple[str, int]]:
"""
Return a list of (user_file_id, chunk_count) tuples.
If a user_file_id is not found in the database, it will be returned with a chunk_count of 0.
"""
stmt = select(UserFile.id, UserFile.chunk_count).where(
UserFile.id.in_(user_file_ids)
)
results = db_session.execute(stmt).all()
# Create a dictionary of user_file_id to chunk_count
chunk_counts = {str(row.id): row.chunk_count or 0 for row in results}
# Return a list of (user_file_id, chunk_count) tuples. IDs that are missing
# from the database, or whose stored chunk count is NULL, are reported with a
# chunk_count of 0, matching the docstring above.
return [
(user_file_id, chunk_counts.get(user_file_id, 0))
for user_file_id in user_file_ids
]
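A small consumption sketch for the helper above; flagging zero-count files for a follow-up existence check is an assumption about how callers might use it, and `user_files` is hypothetical.
# Illustrative only: separate files with known chunks from those reported as 0.
pairs = fetch_chunk_counts_for_user_files(
    user_file_ids=[str(f.id) for f in user_files], db_session=db_session
)
missing_or_empty = [file_id for file_id, count in pairs if count == 0]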
def calculate_user_files_token_count(file_ids: list[UUID], db_session: Session) -> int:
"""Calculate total token count for specified files"""
total_tokens = 0
# Get tokens from individual files
if file_ids:
file_tokens = (
db_session.query(func.sum(UserFile.token_count))
.filter(UserFile.id.in_(file_ids))
.scalar()
or 0
)
total_tokens += file_tokens
return total_tokens
def fetch_user_project_ids_for_user_files(
user_file_ids: list[str],
db_session: Session,
) -> dict[str, list[int]]:
"""Fetch user project ids for specified user files"""
stmt = select(UserFile).where(UserFile.id.in_(user_file_ids))
results = db_session.execute(stmt).scalars().all()
return {
str(user_file.id): [project.id for project in user_file.projects]
for user_file in results
}
def update_last_accessed_at_for_user_files(
user_file_ids: list[UUID],
db_session: Session,
) -> None:
"""Update `last_accessed_at` to now (UTC) for the given user files."""
if not user_file_ids:
return
now = datetime.datetime.now(datetime.timezone.utc)
(
db_session.query(UserFile)
.filter(UserFile.id.in_(user_file_ids))
.update({UserFile.last_accessed_at: now}, synchronize_session=False)
)
db_session.commit()
def get_file_id_by_user_file_id(user_file_id: str, db_session: Session) -> str | None:
user_file = db_session.query(UserFile).filter(UserFile.id == user_file_id).first()
if user_file:
return user_file.file_id
return None

View File

@@ -116,8 +116,7 @@ class VespaDocumentUserFields:
Fields that are specific to the user who is indexing the document.
"""
user_file_id: str | None = None
user_folder_id: str | None = None
user_projects: list[int] | None = None
@dataclass

View File

@@ -176,6 +176,11 @@ schema {{ schema_name }} {
rank: filter
attribute: fast-search
}
field user_project type array<int> {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
}
# If using different tokenization settings, the fieldset has to be removed, and the field must

View File

@@ -68,8 +68,7 @@ from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import DOCUMENT_SETS
from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import USER_FILE
from onyx.document_index.vespa_constants import USER_FOLDER
from onyx.document_index.vespa_constants import USER_PROJECT
from onyx.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
from onyx.document_index.vespa_constants import VESPA_TIMEOUT
from onyx.document_index.vespa_constants import YQL_BASE
@@ -766,12 +765,9 @@ class VespaIndex(DocumentIndex):
update_dict["fields"][HIDDEN] = {"assign": fields.hidden}
if user_fields is not None:
if user_fields.user_file_id is not None:
update_dict["fields"][USER_FILE] = {"assign": user_fields.user_file_id}
if user_fields.user_folder_id is not None:
update_dict["fields"][USER_FOLDER] = {
"assign": user_fields.user_folder_id
if user_fields.user_projects is not None:
update_dict["fields"][USER_PROJECT] = {
"assign": user_fields.user_projects
}
if not update_dict["fields"]:

View File

@@ -51,8 +51,7 @@ from onyx.document_index.vespa_constants import SOURCE_TYPE
from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import TITLE
from onyx.document_index.vespa_constants import TITLE_EMBEDDING
from onyx.document_index.vespa_constants import USER_FILE
from onyx.document_index.vespa_constants import USER_FOLDER
from onyx.document_index.vespa_constants import USER_PROJECT
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.utils.logger import setup_logger
@@ -208,8 +207,7 @@ def _index_vespa_chunk(
DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
# still called `image_file_name` in Vespa for backwards compatibility
IMAGE_FILE_NAME: chunk.image_file_id,
USER_FILE: chunk.user_file if chunk.user_file is not None else None,
USER_FOLDER: chunk.user_folder if chunk.user_folder is not None else None,
USER_PROJECT: chunk.user_project if chunk.user_project is not None else [],
BOOST: chunk.boost,
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
}

View File

@@ -14,8 +14,7 @@ from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import SOURCE_TYPE
from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import USER_FILE
from onyx.document_index.vespa_constants import USER_FOLDER
from onyx.document_index.vespa_constants import USER_PROJECT
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -138,6 +137,18 @@ def build_vespa_filters(
return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
def _build_user_project_filter(
project_id: int | None,
) -> str:
if project_id is None:
return ""
try:
pid = int(project_id)
except Exception:
return ""
# Vespa YQL 'contains' expects a string literal; quote the integer
return f'({USER_PROJECT} contains "{pid}") and '
# Start building the filter string
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
@@ -172,10 +183,14 @@ def build_vespa_filters(
# Document sets
filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
# New: user_file_ids as integer filters
filter_str += _build_int_or_filters(USER_FILE, filters.user_file_ids)
# Convert UUIDs to strings for user_file_ids
user_file_ids_str = (
[str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
)
filter_str += _build_or_filters(DOCUMENT_ID, user_file_ids_str)
filter_str += _build_int_or_filters(USER_FOLDER, filters.user_folder_ids)
# User project filter (array<int> attribute membership)
filter_str += _build_user_project_filter(filters.project_id)
# Time filter
filter_str += _build_time_filter(filters.time_cutoff)
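To make the new project clause concrete: with the USER_PROJECT constant added at the end of this diff ("user_project"), the helper above produces fragments like the following. These expected values are derived directly from the code shown and are included only as a quick sanity check.
# Illustrative only.
assert _build_user_project_filter(42) == '(user_project contains "42") and '
assert _build_user_project_filter(None) == ""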

View File

@@ -55,6 +55,7 @@ ACCESS_CONTROL_LIST = "access_control_list"
DOCUMENT_SETS = "document_sets"
USER_FILE = "user_file"
USER_FOLDER = "user_folder"
USER_PROJECT = "user_project"
LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
METADATA = "metadata"
METADATA_LIST = "metadata_list"

Some files were not shown because too many files have changed in this diff.