Compare commits


36 Commits

Author SHA1 Message Date
pablonyx
39f98e288c quick nit 2025-03-31 16:28:38 -07:00
rkuo-danswer
9e2784e6e6 also set permission upsert to medium priority (#4405)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-31 16:28:38 -07:00
rkuo-danswer
80ca1f90c2 Bugfix/slack rate limiting (#4386)
* use slack's built in rate limit handler for the bot

* WIP

* fix the slack rate limit handler

* change default to 8

* cleanup

* try catch int conversion just in case

* linearize this logic better

* code review comments

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-31 16:28:38 -07:00
evan-danswer
cff7a854bf minor improvement to fireflies connector (#4383)
* minor improvement to fireflies connector

* reduce time diff
2025-03-31 16:28:38 -07:00
evan-danswer
ab99f9d3c0 ensure bedrock model contains API key (#4396)
* ensure bedrock model contains API key

* fix storing bug
2025-03-31 16:28:38 -07:00
pablonyx
a49b4dd531 add user files 2025-03-31 13:23:28 -07:00
pablonyx
04911db715 fix slashes (#4259) 2025-03-31 18:08:17 +00:00
rkuo-danswer
feae7d0cc4 disambiguate job name from ee version (#4403)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-31 11:48:28 -07:00
pablonyx
ac19c64b3c temporary fix for auth (#4402) 2025-03-31 11:10:41 -07:00
pablonyx
03d5c30fd2 fix (#4372) 2025-03-31 17:25:21 +00:00
joachim-danswer
e988c13e1d Additional logging for the path from Search Results to LLM Context (#4387)
* added logging

* nit

* nit
2025-03-31 00:38:43 +00:00
pablonyx
dc18d53133 Improve multi tenant anonymous user interaction (#3857)
* cleaner handling

* k

* k

* address nits

* fix typing
2025-03-31 00:33:32 +00:00
evan-danswer
a1cef389aa fallback to ignoring unicode chars when huggingface tokenizer fails (#4394) 2025-03-30 23:45:20 +00:00
pablonyx
db8d6ce538 formatting (#4316) 2025-03-30 23:43:17 +00:00
pablonyx
e8370dcb24 Update refresh conditional (#4375)
* update refresh conditional

* k
2025-03-30 17:28:35 -07:00
pablonyx
9951fe13ba Fix image input processing without LLMs (#4390)
* quick fix

* quick fix

* Revert "quick fix"

This reverts commit 906b29bd9b.

* nit
2025-03-30 19:28:49 +00:00
evan-danswer
56f8ab927b Contextual Retrieval (#4029)
* contextual rag implementation

* WIP

* indexing test fix

* workaround for chunking errors, WIP on fixing massive memory cost

* mypy and test fixes

* reformatting

* fixed rebase
2025-03-30 18:49:09 +00:00
rkuo-danswer
cb5bbd3812 Feature/mit integration tests (#4299)
* new mit integration test template

* edit

* fix problem with ACL type tags and MIT testing for test_connector_deletion

* fix test_connector_deletion_for_overlapping_connectors

* disable some enterprise only tests in MIT version

* disable a bunch of user group / curator tests in MIT version

* wire off more tests

* typo fix

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-03-30 02:41:08 +00:00
Yuhong Sun
742d29e504 Remove BETA 2025-03-29 15:38:46 -07:00
SubashMohan
ecc155d082 fix: ensure base_url ends with a trailing slash (#4388) 2025-03-29 14:34:30 -07:00
pablonyx
0857e4809d fix background color 2025-03-28 16:33:30 -07:00
Chris Weaver
22e00a1f5c Fix duplicate docs (#4378)
* Initial

* Fix duplicate docs

* Add tests

* Switch to list comprehension

* Fix test
2025-03-28 22:25:26 +00:00
Chris Weaver
0d0588a0c1 Remove OnyxContext (#4376)
* Remove OnyxContext

* Fix UT

* Fix tests v2
2025-03-28 12:39:51 -07:00
rkuo-danswer
aab777f844 Bugfix/acl prefix (#4377)
* fix acl prefixing

* increase timeout a tad

* block access to init'ing DocumentAccess directly, fix test to work with ee/MIT

* fix env var checks

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-28 05:52:35 +00:00
pablonyx
babbe7689a k (#4380) 2025-03-28 02:23:45 +00:00
evan-danswer
a123661c92 fixed shared folder issue (#4371)
* fixed shared folder issue

* fix existing tests

* default allow files shared with me for service account
2025-03-27 23:39:52 +00:00
pablonyx
c554889baf Fix actions link (#4374) 2025-03-27 16:39:35 -07:00
rkuo-danswer
f08fa878a6 refactor file extension checking and add test for blob s3 (#4369)
* refactor file extension checking and add test for blob s3

* code review

* fix checking ext

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 18:57:44 +00:00
pablonyx
d307534781 add some debug logging (#4328) 2025-03-27 11:49:32 -07:00
rkuo-danswer
6f54791910 adjust some vars in real time (#4365)
* adjust some vars in real time

* some sanity checking

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 17:30:08 +00:00
pablonyx
0d5497bb6b Add multi-tenant user invitation flow test (#4360) 2025-03-27 09:53:15 -07:00
Chris Weaver
7648627503 Save all logs + add log persistence to most Onyx-owned containers (#4368)
* Save all logs + add log persistence to most Onyx-owned containers

* Separate volumes for each container

* Small fixes
2025-03-26 22:25:39 -07:00
pablonyx
927554d5ca slight robustification (#4367) 2025-03-27 03:23:36 +00:00
pablonyx
7dcec6caf5 Fix session touching (#4363)
* fix session touching

* Revert "fix session touching"

This reverts commit c473d5c9a2.

* Revert "Revert "fix session touching""

This reverts commit 26a71d40b6.

* update

* quick nit
2025-03-27 01:18:46 +00:00
rkuo-danswer
036648146d possible fix for confluence query filter (#4280)
* possible fix for confluence query filter

* nuke the attachment filter query ... it doesn't work!

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 00:35:14 +00:00
rkuo-danswer
2aa4697ac8 permission sync runs so often that it starves out other tasks if run at high priority (#4364)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 00:22:53 +00:00
285 changed files with 15915 additions and 1731 deletions


@@ -0,0 +1,209 @@
name: Run MIT Integration Tests v2
concurrency:
group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
integration-tests-mit:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Web Docker image
run: |
docker pull onyxdotapp/onyx-web-server:latest
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
# we use the runs-on cache for docker builds
# in conjunction with runs-on runners, it has better speed and unlimited caching
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# images are built and run locally for testing purposes. Not pushed.
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
tags: onyxdotapp/onyx-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
tags: onyxdotapp/onyx-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/amd64
tags: onyxdotapp/onyx-integration:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Start Docker containers
run: |
cd deployment/docker_compose
AUTH_TYPE=basic \
POSTGRES_POOL_PRE_PING=true \
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f onyx-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
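The shell retry loop in the "Wait for service to be ready" step above can be expressed as a small Python helper when debugging the same readiness check locally. This is only a sketch, not code from the repository; `wait_for_service` and `probe` are hypothetical names, and the probe callable stands in for the `curl` against `/health`:

```python
import time


def wait_for_service(probe, timeout=300, interval=5):
    """Poll `probe` until it returns True or `timeout` seconds elapse.

    `probe` is any zero-argument callable that returns True when the
    service is healthy (e.g. an HTTP GET against /health returning 200).
    Returns True on success, False if the timeout is reached.
    """
    start = time.monotonic()
    while time.monotonic() - start < timeout:
        try:
            if probe():
                return True
        except OSError:
            # Connection refused / reset -- keep retrying, mirroring the
            # shell script's tolerance of transient curl errors.
            pass
        time.sleep(interval)
    return False
```

As in the workflow, transient connection errors are swallowed and retried; only the overall timeout terminates the loop with a failure.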
- name: Start Mock Services
run: |
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network onyx-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
onyxdotapp/onyx-integration:test \
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
continue-on-error: true
id: run_tests
- name: Check test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
# ------------------------------------------------------------
# Always gather logs BEFORE "down":
- name: Dump API server logs
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack down -v


@@ -9,6 +9,10 @@ on:
- cron: "0 16 * * *"
env:
# AWS
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}


@@ -6,396 +6,419 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"compounds": [
{
// Dummy entry used to label the group
"name": "--- Compound ---",
"configurations": [
"--- Individual ---"
],
"presentation": {
"group": "1",
}
},
{
"name": "Run All Onyx Services",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Slack Bot",
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery beat",
"Celery monitoring",
],
"presentation": {
"group": "1",
}
},
{
"name": "Web / Model / API",
"configurations": [
"Web Server",
"Model Server",
"API Server",
],
"presentation": {
"group": "1",
}
},
{
"name": "Celery (all)",
"configurations": [
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery beat",
"Celery monitoring",
],
"presentation": {
"group": "1",
}
}
{
// Dummy entry used to label the group
"name": "--- Compound ---",
"configurations": ["--- Individual ---"],
"presentation": {
"group": "1"
}
},
{
"name": "Run All Onyx Services",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Slack Bot",
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery user files indexing",
"Celery beat",
"Celery monitoring"
],
"presentation": {
"group": "1"
}
},
{
"name": "Web / Model / API",
"configurations": ["Web Server", "Model Server", "API Server"],
"presentation": {
"group": "1"
}
},
{
"name": "Celery (all)",
"configurations": [
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery user files indexing",
"Celery beat",
"Celery monitoring"
],
"presentation": {
"group": "1"
}
}
],
"configurations": [
{
// Dummy entry used to label the group
"name": "--- Individual ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "2",
"order": 0
}
},
{
"name": "Web Server",
"type": "node",
"request": "launch",
"cwd": "${workspaceRoot}/web",
"runtimeExecutable": "npm",
"envFile": "${workspaceFolder}/.vscode/.env",
"runtimeArgs": [
"run", "dev"
],
"presentation": {
"group": "2",
},
"console": "integratedTerminal",
"consoleTitle": "Web Server Console"
{
// Dummy entry used to label the group
"name": "--- Individual ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "2",
"order": 0
}
},
{
"name": "Web Server",
"type": "node",
"request": "launch",
"cwd": "${workspaceRoot}/web",
"runtimeExecutable": "npm",
"envFile": "${workspaceFolder}/.vscode/.env",
"runtimeArgs": ["run", "dev"],
"presentation": {
"group": "2"
},
{
"name": "Model Server",
"consoleName": "Model Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
"args": [
"model_server.main:app",
"--reload",
"--port",
"9000"
],
"presentation": {
"group": "2",
},
"consoleTitle": "Model Server Console"
"console": "integratedTerminal",
"consoleTitle": "Web Server Console"
},
{
"name": "Model Server",
"consoleName": "Model Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
{
"name": "API Server",
"consoleName": "API Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
"args": [
"onyx.main:app",
"--reload",
"--port",
"8080"
],
"presentation": {
"group": "2",
},
"consoleTitle": "API Server Console"
"args": ["model_server.main:app", "--reload", "--port", "9000"],
"presentation": {
"group": "2"
},
// For the listener to access the Slack API,
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
{
"name": "Slack Bot",
"consoleName": "Slack Bot",
"type": "debugpy",
"request": "launch",
"program": "onyx/onyxbot/slack/listener.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"presentation": {
"group": "2",
},
"consoleTitle": "Slack Bot Console"
"consoleTitle": "Model Server Console"
},
{
"name": "API Server",
"consoleName": "API Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
{
"name": "Celery primary",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.primary",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=primary@%n",
"-Q",
"celery",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery primary Console"
"args": ["onyx.main:app", "--reload", "--port", "8080"],
"presentation": {
"group": "2"
},
{
"name": "Celery light",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.light",
"worker",
"--pool=threads",
"--concurrency=64",
"--prefetch-multiplier=8",
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery light Console"
"consoleTitle": "API Server Console"
},
// For the listener to access the Slack API,
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
{
"name": "Slack Bot",
"consoleName": "Slack Bot",
"type": "debugpy",
"request": "launch",
"program": "onyx/onyxbot/slack/listener.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
{
"name": "Celery heavy",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery heavy Console"
"presentation": {
"group": "2"
},
{
"name": "Celery indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=indexing@%n",
"-Q",
"connector_indexing",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery indexing Console"
"consoleTitle": "Slack Bot Console"
},
{
"name": "Celery primary",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=solo",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery monitoring Console"
"args": [
"-A",
"onyx.background.celery.versioned_apps.primary",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=primary@%n",
"-Q",
"celery"
],
"presentation": {
"group": "2"
},
{
"name": "Celery beat",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.beat",
"beat",
"--loglevel=INFO",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery beat Console"
"consoleTitle": "Celery primary Console"
},
{
"name": "Celery light",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
{
"name": "Pytest",
"consoleName": "Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
        "args": [
          "-v"
          // Specify a specific module/test to run, or provide nothing to run all tests
          //"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
        ],
"presentation": {
"group": "2",
},
"consoleTitle": "Pytest Console"
"args": [
"-A",
"onyx.background.celery.versioned_apps.light",
"worker",
"--pool=threads",
"--concurrency=64",
"--prefetch-multiplier=8",
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert"
],
"presentation": {
"group": "2"
},
{
// Dummy entry used to label the group
"name": "--- Tasks ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "3",
"order": 0
}
},
{
"name": "Clear and Restart External Volumes and Containers",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"stopOnEntry": true,
"presentation": {
"group": "3",
},
"consoleTitle": "Celery light Console"
},
{
"name": "Celery heavy",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
{
// Celery jobs launched through a single background script (legacy)
// Recommend using the "Celery (all)" compound launch instead.
"name": "Background Jobs",
"consoleName": "Background Jobs",
"type": "debugpy",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync"
],
"presentation": {
"group": "2"
},
{
"name": "Install Python Requirements",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"-c",
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
"consoleTitle": "Celery heavy Console"
},
{
"name": "Celery indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=indexing@%n",
"-Q",
"connector_indexing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery indexing Console"
},
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=solo",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery monitoring Console"
},
{
"name": "Celery beat",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.beat",
"beat",
"--loglevel=INFO"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery beat Console"
},
{
"name": "Celery user files indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_files_indexing@%n",
"-Q",
"user_files_indexing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user files indexing Console"
},
{
"name": "Pytest",
"consoleName": "Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
      "args": [
        "-v"
        // Specify a specific module/test to run, or provide nothing to run all tests
        //"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
      ],
"presentation": {
"group": "2"
},
"consoleTitle": "Pytest Console"
},
{
// Dummy entry used to label the group
"name": "--- Tasks ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "3",
"order": 0
}
},
{
"name": "Clear and Restart External Volumes and Containers",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"${workspaceFolder}/backend/scripts/restart_containers.sh"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"stopOnEntry": true,
"presentation": {
"group": "3"
}
},
{
// Celery jobs launched through a single background script (legacy)
// Recommend using the "Celery (all)" compound launch instead.
"name": "Background Jobs",
"consoleName": "Background Jobs",
"type": "debugpy",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
},
{
"name": "Install Python Requirements",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"-c",
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
},
{
"name": "Debug React Web App in Chrome",
"type": "chrome",
"request": "launch",
"url": "http://localhost:3000",
"webRoot": "${workspaceFolder}/web"
}
]
}
}


@@ -0,0 +1,50 @@
"""enable contextual retrieval
Revision ID: 8e1ac4f39a9f
Revises: 9aadf32dfeb4
Create Date: 2024-12-20 13:29:09.918661
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8e1ac4f39a9f"
down_revision = "9aadf32dfeb4"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"search_settings",
sa.Column(
"enable_contextual_rag",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
op.add_column(
"search_settings",
sa.Column(
"contextual_rag_llm_name",
sa.String(),
nullable=True,
),
)
op.add_column(
"search_settings",
sa.Column(
"contextual_rag_llm_provider",
sa.String(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("search_settings", "enable_contextual_rag")
op.drop_column("search_settings", "contextual_rag_llm_name")
op.drop_column("search_settings", "contextual_rag_llm_provider")


@@ -0,0 +1,113 @@
"""add user files
Revision ID: 9aadf32dfeb4
Revises: 3781a5eb12cb
Create Date: 2025-01-26 16:08:21.551022
"""
from alembic import op
import sqlalchemy as sa
import datetime
# revision identifiers, used by Alembic.
revision = "9aadf32dfeb4"
down_revision = "3781a5eb12cb"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create user_folder table without parent_id
op.create_table(
"user_folder",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column("name", sa.String(length=255), nullable=True),
sa.Column("description", sa.String(length=255), nullable=True),
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
sa.Column(
"created_at", sa.DateTime(timezone=True), server_default=sa.func.now()
),
)
# Create user_file table with folder_id instead of parent_folder_id
op.create_table(
"user_file",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column(
"folder_id",
sa.Integer(),
sa.ForeignKey("user_folder.id"),
nullable=True,
),
sa.Column("link_url", sa.String(), nullable=True),
sa.Column("token_count", sa.Integer(), nullable=True),
sa.Column("file_type", sa.String(), nullable=True),
sa.Column("file_id", sa.String(length=255), nullable=False),
sa.Column("document_id", sa.String(length=255), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column(
"created_at",
sa.DateTime(),
default=datetime.datetime.utcnow,
),
sa.Column(
"cc_pair_id",
sa.Integer(),
sa.ForeignKey("connector_credential_pair.id"),
nullable=True,
unique=True,
),
)
# Create persona__user_file table
op.create_table(
"persona__user_file",
sa.Column(
"persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True
),
sa.Column(
"user_file_id",
sa.Integer(),
sa.ForeignKey("user_file.id"),
primary_key=True,
),
)
# Create persona__user_folder table
op.create_table(
"persona__user_folder",
sa.Column(
"persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True
),
sa.Column(
"user_folder_id",
sa.Integer(),
sa.ForeignKey("user_folder.id"),
primary_key=True,
),
)
op.add_column(
"connector_credential_pair",
sa.Column("is_user_file", sa.Boolean(), nullable=True, default=False),
)
# Update existing records to have is_user_file=False instead of NULL
op.execute(
"UPDATE connector_credential_pair SET is_user_file = FALSE WHERE is_user_file IS NULL"
)
def downgrade() -> None:
# Drop the persona__user_folder table
op.drop_table("persona__user_folder")
# Drop the persona__user_file table
op.drop_table("persona__user_file")
# Drop the user_file table
op.drop_table("user_file")
# Drop the user_folder table
op.drop_table("user_folder")
op.drop_column("connector_credential_pair", "is_user_file")

View File
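The migration above adds a nullable `is_user_file` column and then backfills existing rows so `NULL` becomes `FALSE`. A minimal sketch of that add-column-then-backfill pattern, using stdlib `sqlite3` instead of Alembic/Postgres purely for illustration (table name kept, everything else simplified):

```python
import sqlite3

# Sketch of the backfill pattern in the migration above: add a nullable
# boolean column, then UPDATE existing rows so NULL becomes FALSE.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE connector_credential_pair (id INTEGER PRIMARY KEY)")
conn.execute("INSERT INTO connector_credential_pair (id) VALUES (1), (2)")
conn.execute(
    "ALTER TABLE connector_credential_pair ADD COLUMN is_user_file BOOLEAN"
)
conn.execute(
    "UPDATE connector_credential_pair SET is_user_file = FALSE "
    "WHERE is_user_file IS NULL"
)
rows = conn.execute(
    "SELECT id, is_user_file FROM connector_credential_pair ORDER BY id"
).fetchall()
```

Backfilling in the same migration keeps the column effectively non-null for readers without requiring a `NOT NULL` constraint up front.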

@@ -93,12 +93,12 @@ def _get_access_for_documents(
)
# To avoid collisions of group namings between connectors, they need to be prefixed
access_map[document_id] = DocumentAccess(
user_emails=non_ee_access.user_emails,
user_groups=set(user_group_info.get(document_id, [])),
access_map[document_id] = DocumentAccess.build(
user_emails=list(non_ee_access.user_emails),
user_groups=user_group_info.get(document_id, []),
is_public=is_public_anywhere,
external_user_emails=ext_u_emails,
external_user_group_ids=ext_u_groups,
external_user_emails=list(ext_u_emails),
external_user_group_ids=list(ext_u_groups),
)
return access_map

View File

@@ -2,7 +2,6 @@ from ee.onyx.server.query_and_chat.models import OneShotQAResponse
from onyx.chat.models import AllCitations
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.process_message import ChatPacketStream
@@ -32,8 +31,6 @@ def gather_stream_for_answer_api(
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
elif isinstance(packet, AllCitations):
response.citations = packet.citations
elif isinstance(packet, OnyxContexts):
response.contexts = packet
if answer:
response.answer = answer

View File

@@ -25,6 +25,10 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
#####
# Auto Permission Sync
#####
DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# In seconds, default is 5 minutes
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
@@ -39,6 +43,7 @@ CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
@@ -72,6 +77,13 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)
GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("SLACK_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# The posthog client does not accept empty API keys or hosts; however, it fails silently
# when capture is called. These defaults prevent Posthog issues from breaking the Onyx app

View File
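The new frequency settings above all follow the same idiom: `int(os.environ.get(NAME) or default)`. A small sketch of that pattern (the helper name and `EXAMPLE_SYNC_FREQUENCY` variable are illustrative, not from the Onyx codebase):

```python
import os

def read_frequency(env_name: str, default_seconds: int = 5 * 60) -> int:
    """Parse an env var as seconds, falling back to a default.

    Using `or` (rather than a default argument to .get) also catches the
    case where the variable is set but empty, which int() would reject.
    """
    return int(os.environ.get(env_name) or default_seconds)

EXAMPLE_SYNC_FREQUENCY = read_frequency("EXAMPLE_SYNC_FREQUENCY")
```

This is why the configs read `or 5 * 60` instead of passing the default to `os.environ.get`: an empty string in the environment still yields the default rather than a `ValueError`.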

@@ -3,6 +3,8 @@ from collections.abc import Generator
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
@@ -66,13 +68,13 @@ GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
DocumentSource.SLACK: 5 * 60,
DocumentSource.SLACK: SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
}
# If nothing is specified here, we run the doc_sync every time the celery beat runs
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all group permissions every 30 minutes
DocumentSource.GOOGLE_DRIVE: 5 * 60,
DocumentSource.GOOGLE_DRIVE: GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY,
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
}

View File

@@ -44,7 +44,7 @@ async def _get_tenant_id_from_request(
Attempt to extract tenant_id from:
1) The API key header
2) The Redis-based token (stored in Cookie: fastapiusersauth)
3) Reset token cookie
3) The anonymous user cookie
Fallback: POSTGRES_DEFAULT_SCHEMA
"""
# Check for API key
@@ -52,41 +52,55 @@ async def _get_tenant_id_from_request(
if tenant_id is not None:
return tenant_id
# Check for anonymous user cookie
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
if anonymous_user_cookie:
try:
anonymous_user_data = decode_anonymous_user_jwt_token(anonymous_user_cookie)
return anonymous_user_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
except Exception as e:
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
# Continue and attempt to authenticate
try:
# Look up token data in Redis
token_data = await retrieve_auth_token_data_from_redis(request)
if not token_data:
logger.debug(
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
if token_data:
tenant_id_from_payload = token_data.get(
"tenant_id", POSTGRES_DEFAULT_SCHEMA
)
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
# so we maintain consistency by returning it here when no valid tenant is found.
return POSTGRES_DEFAULT_SCHEMA
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
tenant_id = (
str(tenant_id_from_payload)
if tenant_id_from_payload is not None
else None
)
# Since token_data.get() can return None, ensure we have a string
tenant_id = (
str(tenant_id_from_payload)
if tenant_id_from_payload is not None
else POSTGRES_DEFAULT_SCHEMA
if tenant_id and not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
# Check for anonymous user cookie
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
if anonymous_user_cookie:
try:
anonymous_user_data = decode_anonymous_user_jwt_token(
anonymous_user_cookie
)
tenant_id = anonymous_user_data.get(
"tenant_id", POSTGRES_DEFAULT_SCHEMA
)
if not tenant_id or not is_valid_schema_name(tenant_id):
raise HTTPException(
status_code=400, detail="Invalid tenant ID format"
)
return tenant_id
except Exception as e:
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
# Continue and attempt to authenticate
logger.debug(
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
# so we maintain consistency by returning it here when no valid tenant is found.
return POSTGRES_DEFAULT_SCHEMA
except Exception as e:
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")

View File
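The tenant-resolution change above reorders several fallbacks. A condensed, hypothetical sketch of the lookup order the docstring describes (API key, then token payload, then anonymous cookie, then the default schema) with validation at each step — the function and parameter names here are illustrative, not the real signatures:

```python
import re

POSTGRES_DEFAULT_SCHEMA = "public"
_SCHEMA_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")

def is_valid_schema_name(name: str) -> bool:
    return bool(_SCHEMA_RE.match(name))

def resolve_tenant_id(api_key_tenant, token_data, anonymous_cookie_tenant):
    """Sketch of the fallback chain: each source is tried in order and
    validated as a schema name before being returned."""
    if api_key_tenant is not None:
        return api_key_tenant
    if token_data:
        tenant = str(token_data.get("tenant_id") or POSTGRES_DEFAULT_SCHEMA)
        if not is_valid_schema_name(tenant):
            raise ValueError("Invalid tenant ID format")
        return tenant
    if anonymous_cookie_tenant:
        if not is_valid_schema_name(anonymous_cookie_tenant):
            raise ValueError("Invalid tenant ID format")
        return anonymous_cookie_tenant
    # Unauthenticated requests land in the default schema, matching the
    # initial value of CURRENT_TENANT_ID_CONTEXTVAR.
    return POSTGRES_DEFAULT_SCHEMA
```

Validating before returning is what keeps a forged cookie or token from smuggling an arbitrary string into a Postgres schema lookup.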

@@ -14,7 +14,6 @@ from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
)
from ee.onyx.server.query_and_chat.models import ChatBasicResponse
from ee.onyx.server.query_and_chat.models import SimpleDoc
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
@@ -56,25 +55,6 @@ logger = setup_logger()
router = APIRouter(prefix="/chat")
def _translate_doc_response_to_simple_doc(
doc_response: QADocsResponse,
) -> list[SimpleDoc]:
return [
SimpleDoc(
id=doc.document_id,
semantic_identifier=doc.semantic_identifier,
link=doc.link,
blurb=doc.blurb,
match_highlights=[
highlight for highlight in doc.match_highlights if highlight
],
source_type=doc.source_type,
metadata=doc.metadata,
)
for doc in doc_response.top_documents
]
def _get_final_context_doc_indices(
final_context_docs: list[LlmDoc] | None,
top_docs: list[SavedSearchDoc] | None,
@@ -111,9 +91,6 @@ def _convert_packet_stream_to_response(
elif isinstance(packet, QADocsResponse):
response.top_documents = packet.top_documents
# TODO: deprecate `simple_search_docs`
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)

View File

@@ -8,7 +8,6 @@ from pydantic import model_validator
from ee.onyx.server.manage.models import StandardAnswer
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
@@ -164,8 +163,6 @@ class ChatBasicResponse(BaseModel):
cited_documents: dict[int, str] | None = None
# FOR BACKWARDS COMPATIBILITY
# TODO: deprecate both of these
simple_search_docs: list[SimpleDoc] | None = None
llm_chunks_indices: list[int] | None = None
# agentic fields
@@ -220,4 +217,3 @@ class OneShotQAResponse(BaseModel):
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
chat_message_id: int | None = None
contexts: OnyxContexts | None = None

View File

@@ -1,4 +1,5 @@
import asyncio
import concurrent.futures
import logging
import uuid
@@ -94,6 +95,7 @@ async def get_or_provision_tenant(
# Notify control plane if we have created / assigned a new tenant
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)
return tenant_id
except Exception as e:
@@ -505,8 +507,24 @@ async def setup_tenant(tenant_id: str) -> None:
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Run Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
# Run Alembic migrations in a completely isolated way
# This function runs in a separate thread with its own event loop context
def isolated_migration_runner():
# Reset contextvar to prevent it from being shared with the parent thread
# This ensures no shared asyncio primitives between threads
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
# Run migrations in this isolated thread
run_alembic_migrations(tenant_id)
finally:
# Clean up the contextvar
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
# Use a dedicated ThreadPoolExecutor for complete isolation
# This prevents asyncio loop/lock sharing issues
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
loop = asyncio.get_running_loop()
await loop.run_in_executor(executor, isolated_migration_runner)
# Configure the tenant with default settings
with get_session_with_tenant(tenant_id=tenant_id) as db_session:

View File
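The `setup_tenant` change above replaces `asyncio.to_thread` with a dedicated single-worker executor whose runner sets and resets the tenant contextvar itself. A runnable sketch of that isolation pattern (the `run_migrations` stand-in replaces the real Alembic call):

```python
import asyncio
import concurrent.futures
import contextvars

CURRENT_TENANT_ID = contextvars.ContextVar("tenant_id", default=None)

def run_migrations(tenant_id: str) -> str:
    # Stand-in for the blocking run_alembic_migrations call.
    return f"migrated:{tenant_id}"

async def setup_tenant(tenant_id: str) -> str:
    def isolated_runner() -> str:
        # Set the contextvar inside the worker thread so no asyncio
        # primitives or context state are shared with the parent loop.
        token = CURRENT_TENANT_ID.set(tenant_id)
        try:
            return run_migrations(tenant_id)
        finally:
            CURRENT_TENANT_ID.reset(token)

    # A dedicated executor per call gives complete thread isolation.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(executor, isolated_runner)
```

Unlike `asyncio.to_thread`, `loop.run_in_executor` does not copy the caller's `contextvars` context into the worker, which is exactly the separation the diff is after.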

@@ -70,6 +70,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
"""
Add users to a tenant with proper transaction handling.
Checks if users already have a tenant mapping to avoid duplicates.
If a user already has an active mapping to any tenant, the new mapping will be added as inactive.
"""
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
@@ -88,9 +89,25 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
.first()
)
# If user already has an active mapping, add this one as inactive
if not existing_mapping:
# Only add if mapping doesn't exist
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
# Check if the user already has an active mapping to any tenant
has_active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.first()
)
db_session.add(
UserTenantMapping(
email=email,
tenant_id=tenant_id,
active=False if has_active_mapping else True,
)
)
# Commit the transaction
db_session.commit()
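The rule the hunk above implements: a user's first tenant mapping is active, and any later mapping is inserted inactive. A pure-Python sketch of that rule (the real code queries `UserTenantMapping` via SQLAlchemy; the list-based store here is illustrative):

```python
from dataclasses import dataclass

@dataclass
class UserTenantMapping:
    email: str
    tenant_id: str
    active: bool

def add_users_to_tenant(mappings, emails, tenant_id):
    """Sketch: if the user already has an active mapping to any tenant,
    the new mapping is added as inactive; otherwise it becomes active."""
    for email in emails:
        has_active = any(m.email == email and m.active for m in mappings)
        mappings.append(
            UserTenantMapping(email=email, tenant_id=tenant_id,
                              active=not has_active)
        )
    return mappings
```

This preserves a single "current" tenant per user while still recording membership in additional tenants.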

BIN
backend/hello-vmlinux.bin Normal file

Binary file not shown.

View File

@@ -18,7 +18,7 @@ def _get_access_for_document(
document_id=document_id,
)
return DocumentAccess.build(
doc_access = DocumentAccess.build(
user_emails=info[1] if info and info[1] else [],
user_groups=[],
external_user_emails=[],
@@ -26,6 +26,8 @@ def _get_access_for_document(
is_public=info[2] if info else False,
)
return doc_access
def get_access_for_document(
document_id: str,
@@ -38,12 +40,12 @@ def get_access_for_document(
def get_null_document_access() -> DocumentAccess:
return DocumentAccess(
user_emails=set(),
user_groups=set(),
return DocumentAccess.build(
user_emails=[],
user_groups=[],
is_public=False,
external_user_emails=set(),
external_user_group_ids=set(),
external_user_emails=[],
external_user_group_ids=[],
)
@@ -55,19 +57,18 @@ def _get_access_for_documents(
db_session=db_session,
document_ids=document_ids,
)
doc_access = {
document_id: DocumentAccess(
user_emails=set([email for email in user_emails if email]),
doc_access = {}
for document_id, user_emails, is_public in document_access_info:
doc_access[document_id] = DocumentAccess.build(
user_emails=[email for email in user_emails if email],
# MIT version will wipe all groups and external groups on update
user_groups=set(),
user_groups=[],
is_public=is_public,
external_user_emails=set(),
external_user_group_ids=set(),
external_user_emails=[],
external_user_group_ids=[],
)
for document_id, user_emails, is_public in document_access_info
}
# Sometimes the document has not be indexed by the indexing job yet, in those cases
# Sometimes the document has not been indexed by the indexing job yet, in those cases
# the document does not exist and so we use least permissive. Specifically the EE version
# checks the MIT version permissions and creates a superset. This ensures that this flow
# does not fail even if the Document has not yet been indexed.

View File

@@ -56,34 +56,46 @@ class DocExternalAccess:
)
@dataclass(frozen=True)
@dataclass(frozen=True, init=False)
class DocumentAccess(ExternalAccess):
# User emails for Onyx users, None indicates admin
user_emails: set[str | None]
# Names of user groups associated with this document
user_groups: set[str]
def to_acl(self) -> set[str]:
return set(
[
prefix_user_email(user_email)
for user_email in self.user_emails
if user_email
]
+ [prefix_user_group(group_name) for group_name in self.user_groups]
+ [
prefix_user_email(user_email)
for user_email in self.external_user_emails
]
+ [
# The group names are already prefixed by the source type
# This adds an additional prefix of "external_group:"
prefix_external_group(group_name)
for group_name in self.external_user_group_ids
]
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
external_user_emails: set[str]
external_user_group_ids: set[str]
is_public: bool
def __init__(self) -> None:
raise TypeError(
"Use `DocumentAccess.build(...)` instead of creating an instance directly."
)
def to_acl(self) -> set[str]:
# the acl's emitted by this function are prefixed by type
# to get the native objects, access the member variables directly
acl_set: set[str] = set()
for user_email in self.user_emails:
if user_email:
acl_set.add(prefix_user_email(user_email))
for group_name in self.user_groups:
acl_set.add(prefix_user_group(group_name))
for external_user_email in self.external_user_emails:
acl_set.add(prefix_user_email(external_user_email))
for external_group_id in self.external_user_group_ids:
acl_set.add(prefix_external_group(external_group_id))
if self.is_public:
acl_set.add(PUBLIC_DOC_PAT)
return acl_set
@classmethod
def build(
cls,
@@ -93,29 +105,32 @@ class DocumentAccess(ExternalAccess):
external_user_group_ids: list[str],
is_public: bool,
) -> "DocumentAccess":
return cls(
external_user_emails={
prefix_user_email(external_email)
for external_email in external_user_emails
},
external_user_group_ids={
prefix_external_group(external_group_id)
for external_group_id in external_user_group_ids
},
user_emails={
prefix_user_email(user_email)
for user_email in user_emails
if user_email
},
user_groups=set(user_groups),
is_public=is_public,
"""Don't prefix incoming data wth acl type, prefix on read from to_acl!"""
obj = object.__new__(cls)
object.__setattr__(
obj, "user_emails", {user_email for user_email in user_emails if user_email}
)
object.__setattr__(obj, "user_groups", set(user_groups))
object.__setattr__(
obj,
"external_user_emails",
{external_email for external_email in external_user_emails},
)
object.__setattr__(
obj,
"external_user_group_ids",
{external_group_id for external_group_id in external_user_group_ids},
)
object.__setattr__(obj, "is_public", is_public)
return obj
default_public_access = DocumentAccess(
external_user_emails=set(),
external_user_group_ids=set(),
user_emails=set(),
user_groups=set(),
default_public_access = DocumentAccess.build(
external_user_emails=[],
external_user_group_ids=[],
user_emails=[],
user_groups=[],
is_public=True,
)

View File
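The `DocumentAccess` refactor above uses a frozen dataclass whose direct constructor is disabled, forcing callers through a normalizing `build` classmethod that assigns fields via `object.__setattr__`. A minimal sketch of that pattern with invented field names:

```python
from dataclasses import dataclass

@dataclass(frozen=True, init=False)
class Access:
    """Sketch of the build-only frozen dataclass pattern: init=False
    suppresses the generated __init__, our own raises, and build()
    bypasses the frozen guard with object.__setattr__."""
    user_emails: frozenset
    is_public: bool

    def __init__(self) -> None:
        raise TypeError(
            "Use `Access.build(...)` instead of creating an instance directly."
        )

    @classmethod
    def build(cls, user_emails: list, is_public: bool) -> "Access":
        obj = object.__new__(cls)
        # object.__setattr__ sidesteps the frozen dataclass __setattr__,
        # which would otherwise raise FrozenInstanceError.
        object.__setattr__(
            obj, "user_emails", frozenset(e for e in user_emails if e)
        )
        object.__setattr__(obj, "is_public", is_public)
        return obj
```

Funneling construction through `build` gives one place to normalize inputs (lists to sets, dropping falsy emails), which is why the diff also moves prefixing out of `build` and into `to_acl`.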

@@ -7,7 +7,6 @@ from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
@@ -24,7 +23,7 @@ def process_llm_stream(
should_stream_answer: bool,
writer: StreamWriter,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[OnyxContext] | list[LlmDoc] | None = None,
displayed_search_results: list[LlmDoc] | None = None,
) -> AIMessageChunk:
tool_call_chunk = AIMessageChunk(content="")

View File

@@ -156,7 +156,6 @@ def generate_initial_answer(
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

View File

@@ -183,7 +183,6 @@ def generate_validate_refined_answer(
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

View File

@@ -57,7 +57,6 @@ def format_results(
for tool_response in yield_search_responses(
query=state.question,
get_retrieved_sections=lambda: reranked_documents,
get_reranked_sections=lambda: state.retrieved_documents,
get_final_context_sections=lambda: reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

View File

@@ -13,9 +13,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search.search_utils import section_to_llm_doc
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
@@ -59,9 +57,7 @@ def basic_use_tool_response(
search_response_summary = cast(SearchResponseSummary, yield_item.response)
for section in search_response_summary.top_sections:
if section.center_chunk.document_id not in initial_search_results:
initial_search_results.append(
context_from_inference_section(section)
)
initial_search_results.append(section_to_llm_doc(section))
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.behavior.skip_gen_ai_answer_generation:

View File

@@ -321,8 +321,10 @@ def dispatch_separated(
sep: str = DISPATCH_SEP_CHAR,
) -> list[BaseMessage_Content]:
num = 1
accumulated_tokens = ""
streamed_tokens: list[BaseMessage_Content] = []
for token in tokens:
accumulated_tokens += cast(str, token.content)
content = cast(str, token.content)
if sep in content:
sub_question_parts = content.split(sep)

View File

@@ -360,7 +360,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reason="Password must contain at least one special character from the following set: "
f"{PASSWORD_SPECIAL_CHARS}."
)
return
async def oauth_callback(

View File

@@ -1,6 +1,5 @@
from datetime import timedelta
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
@@ -10,12 +9,10 @@ from celery.utils.log import get_task_logger
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import SqlEngine
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
@@ -141,8 +138,6 @@ class DynamicTenantScheduler(PersistentScheduler):
"""Only updates the actual beat schedule on the celery app when it changes"""
do_update = False
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
task_logger.debug("_try_updating_schedule starting")
tenant_ids = get_all_tenant_ids()
@@ -152,16 +147,7 @@ class DynamicTenantScheduler(PersistentScheduler):
current_schedule = self.schedule.items()
# get potential new state
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
if beat_multiplier_raw is not None:
try:
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
beat_multiplier = float(beat_multiplier_bytes.decode())
except ValueError:
task_logger.error(
f"Invalid beat_multiplier value: {beat_multiplier_raw}"
)
beat_multiplier = OnyxRuntime.get_beat_multiplier()
new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)

View File
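The scheduler diff above replaces inline Redis parsing with `OnyxRuntime.get_beat_multiplier()`. A sketch of what such a getter plausibly encapsulates — read a float from Redis, fall back to the default on absence or garbage (the `FakeRedis` stub and key name are assumptions for the example):

```python
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0

class FakeRedis:
    """Tiny stub mimicking redis-py's get(): returns bytes or None."""
    def __init__(self, data=None):
        self._data = data or {}
    def get(self, key):
        v = self._data.get(key)
        return v.encode() if isinstance(v, str) else v

def get_beat_multiplier(r, key="runtime:beat_multiplier"):
    """Sketch of the logic the diff moves behind OnyxRuntime: tolerate a
    missing or unparseable value by falling back to the default."""
    raw = r.get(key)
    if raw is None:
        return CLOUD_BEAT_MULTIPLIER_DEFAULT
    try:
        return float(raw.decode())
    except ValueError:
        return CLOUD_BEAT_MULTIPLIER_DEFAULT
```

Centralizing the parse means every beat consumer gets the same fallback behavior instead of each call site re-implementing the try/except.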

@@ -111,6 +111,7 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.user_file_folder_sync",
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.tenant_provisioning",
]

View File

@@ -174,6 +174,9 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
f"search_settings={attempt.search_settings_id}"
)
logger.warning(failure_reason)
logger.exception(
f"Marking attempt {attempt.id} as canceled due to validation error 2"
)
mark_attempt_canceled(attempt.id, db_session, failure_reason)
@@ -285,5 +288,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.user_file_folder_sync",
]
)

View File

@@ -14,7 +14,7 @@ logger = setup_logger()
# Only set up memory monitoring in container environment
if is_running_in_container():
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_DIR = "/var/log/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files

View File

@@ -21,6 +21,7 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# we have a better implementation (backpressure, etc)
# Note that DynamicTenantScheduler can adjust the runtime value for this via Redis
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT = 1.0
# tasks that run in either self-hosted or cloud
beat_task_templates: list[dict] = []
@@ -63,6 +64,15 @@ beat_task_templates.extend(
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-user-file-folder-sync",
"task": OnyxCeleryTask.CHECK_FOR_USER_FILE_FOLDER_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-pruning",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,

View File

@@ -389,6 +389,8 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
credential_id_to_delete: int | None = None
connector_id_to_delete: int | None = None
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
@@ -443,26 +445,35 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
)
# Store IDs before potentially expiring cc_pair
connector_id_to_delete = cc_pair.connector_id
credential_id_to_delete = cc_pair.credential_id
# Explicitly delete document by connector credential pair records before deleting the connector
# This is needed because connector_id is a primary key in that table and cascading deletes won't work
delete_all_documents_by_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
connector_id=connector_id_to_delete,
credential_id=credential_id_to_delete,
)
# Flush to ensure document deletion happens before connector deletion
db_session.flush()
# Expire the cc_pair to ensure SQLAlchemy doesn't try to manage its state
# related to the deleted DocumentByConnectorCredentialPair during commit
db_session.expire(cc_pair)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
connector_id=connector_id_to_delete,
credential_id=credential_id_to_delete,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
connector_id=connector_id_to_delete,
)
if not connector or not len(connector.credentials):
task_logger.info(
@@ -495,15 +506,15 @@ def monitor_connector_deletion_taskset(
task_logger.exception(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
)
raise e
task_logger.info(
f"Connector deletion succeeded: "
f"cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"connector={connector_id_to_delete} "
f"credential={credential_id_to_delete} "
f"docs_deleted={fence_data.num_tasks}"
)
@@ -553,7 +564,7 @@ def validate_connector_deletion_fences(
def validate_connector_deletion_fence(
tenant_id: str,
key_bytes: bytes,
queued_tasks: set[str],
queued_upsert_tasks: set[str],
r: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
@@ -640,7 +651,7 @@ def validate_connector_deletion_fence(
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_tasks:
if member_str in queued_upsert_tasks:
continue
tasks_not_in_celery += 1

View File
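The deletion fix above copies `connector_id` and `credential_id` into locals before `db_session.expire(cc_pair)`, because attribute access on an expired (and soon-deleted) ORM object triggers a refresh that fails. A toy illustration of why reading first matters (the `ExpirableRow` class is a stand-in for a SQLAlchemy instance, not a real API):

```python
class ExpirableRow:
    """Toy stand-in for an ORM row: after expire(), attribute access
    raises, mimicking a failed refresh of a deleted row."""
    def __init__(self, **cols):
        self._cols = dict(cols)
        self._expired = False

    def __getattr__(self, name):
        # Only called for names not in __dict__, i.e. the "columns".
        if self.__dict__["_expired"]:
            raise RuntimeError("attribute refresh on expired/deleted row")
        cols = self.__dict__["_cols"]
        if name in cols:
            return cols[name]
        raise AttributeError(name)

    def expire(self):
        self._expired = True

def delete_pair(cc_pair):
    # Copy the ids out BEFORE expiring, as the diff above does; using
    # cc_pair.connector_id after expire() would blow up.
    connector_id = cc_pair.connector_id
    credential_id = cc_pair.credential_id
    cc_pair.expire()
    return connector_id, credential_id
```

In the real code the saved ids also make the final log lines safe, since `cc_pair` can no longer be dereferenced by the time deletion succeeds or fails.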

@@ -17,6 +17,7 @@ from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.document import upsert_document_external_perms
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
@@ -63,6 +64,7 @@ from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyn
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.server.utils import make_short_id
from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import format_error_for_logging
@@ -106,9 +108,10 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
if not source_sync_period:
return True
source_sync_period = DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
source_sync_period *= int(OnyxRuntime.get_doc_permission_sync_multiplier())
# If the last sync is greater than the full fetch period, we run the sync
next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
@@ -286,7 +289,7 @@ def try_creating_permissions_sync_task(
),
queue=OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.HIGH,
priority=OnyxCeleryPriority.MEDIUM,
)
# fill in the celery task id

View File
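The `_is_external_doc_permissions_sync_due` change above falls back to a default frequency and scales it by a runtime multiplier before comparing timestamps. A condensed sketch of that scheduling rule (simplified signature; the multiplier cast to `int` mirrors the diff):

```python
from datetime import datetime, timedelta

DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY = 5 * 60  # seconds

def is_sync_due(last_sync, now, source_period=None, multiplier=1.0):
    """Sketch: no per-source period means use the default; the runtime
    multiplier stretches the period; due when now passes last + period."""
    period = source_period or DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
    period *= int(multiplier)
    return now >= last_sync + timedelta(seconds=period)
```

Raising the multiplier via runtime config is how the cloud deployment can throttle permission syncs across all sources without a redeploy.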

@@ -271,7 +271,7 @@ def try_creating_external_group_sync_task(
),
queue=OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.HIGH,
priority=OnyxCeleryPriority.MEDIUM,
)
payload.celery_task_id = result.id

View File

@@ -72,6 +72,7 @@ from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_utils import is_fence
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
@@ -364,6 +365,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
Occasionally does some validation of existing state to clear up error conditions"""
time_start = time.monotonic()
task_logger.warning("check_for_indexing - Starting")
tasks_created = 0
locked = False
@@ -401,7 +403,11 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
logger.warning(f"Adding {key_bytes} to the lookup table.")
redis_client.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
redis_client.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
redis_client.set(
OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE,
1,
ex=OnyxRuntime.get_build_fence_lookup_table_interval(),
)
# 1/3: KICKOFF
@@ -428,7 +434,9 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
lock_beat.reacquire()
cc_pair_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
cc_pairs = fetch_connector_credential_pairs(db_session)
cc_pairs = fetch_connector_credential_pairs(
db_session, include_user_files=True
)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
@@ -447,12 +455,18 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
not search_settings_instance.status.is_current()
and not search_settings_instance.background_reindex_enabled
):
task_logger.warning("SKIPPING DUE TO NON-LIVE SEARCH SETTINGS")
continue
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
if redis_connector_index.fenced:
task_logger.info(
f"check_for_indexing - Skipping fenced connector: "
f"cc_pair={cc_pair_id} search_settings={search_settings_instance.id}"
)
continue
cc_pair = get_connector_credential_pair_from_id(
@@ -460,6 +474,9 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
cc_pair_id=cc_pair_id,
)
if not cc_pair:
task_logger.warning(
f"check_for_indexing - CC pair not found: cc_pair={cc_pair_id}"
)
continue
last_attempt = get_last_attempt_for_cc_pair(
@@ -473,7 +490,20 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
):
task_logger.info(
f"check_for_indexing - Not indexing cc_pair_id: {cc_pair_id} "
f"search_settings={search_settings_instance.id}, "
f"last_attempt={last_attempt.id if last_attempt else None}, "
f"secondary_index_building={len(search_settings_list) > 1}"
)
continue
else:
task_logger.info(
f"check_for_indexing - Will index cc_pair_id: {cc_pair_id} "
f"search_settings={search_settings_instance.id}, "
f"last_attempt={last_attempt.id if last_attempt else None}, "
f"secondary_index_building={len(search_settings_list) > 1}"
)
reindex = False
if search_settings_instance.status.is_current():
@@ -512,6 +542,12 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
f"search_settings={search_settings_instance.id}"
)
tasks_created += 1
else:
task_logger.info(
f"Failed to create indexing task: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id}"
)
lock_beat.reacquire()
@@ -1144,6 +1180,9 @@ def connector_indexing_proxy_task(
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
try:
with get_session_with_current_tenant() as db_session:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to termination signal"
)
mark_attempt_canceled(
index_attempt_id,
db_session,

View File
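In the diff above, `check_for_indexing` replaces the hard-coded 300s expiry on the `BLOCK_BUILD_FENCE_LOOKUP_TABLE` signal with a runtime-configurable interval, and skips rebuilding the active-fence lookup table while that key is still alive. A redis-free sketch of the set-with-expiry gate (class and function names are illustrative):

```python
import time

class TTLSignals:
    """In-memory stand-in for redis `SET key 1 EX <seconds>` + `EXISTS`."""
    def __init__(self, clock=time.monotonic):
        self._clock = clock
        self._expiry = {}

    def set(self, key, ex):
        self._expiry[key] = self._clock() + ex

    def exists(self, key):
        deadline = self._expiry.get(key)
        return deadline is not None and self._clock() < deadline

def maybe_rebuild_lookup_table(signals, interval, rebuild):
    # Skip the expensive rebuild while the block signal is still live
    if signals.exists("block_build_fence_lookup_table"):
        return False
    rebuild()
    signals.set("block_build_fence_lookup_table", ex=interval)
    return True
```

The expiry itself acts as the rate limiter: no explicit timestamp bookkeeping is needed, because the key simply stops existing once the interval elapses.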

@@ -371,6 +371,7 @@ def should_index(
# don't kick off indexing for `NOT_APPLICABLE` sources
if connector.source == DocumentSource.NOT_APPLICABLE:
print(f"Not indexing cc_pair={cc_pair.id}: NOT_APPLICABLE source")
return False
# User can still manually create single indexing attempts via the UI for the
@@ -380,6 +381,9 @@ def should_index(
search_settings_instance.status == IndexModelStatus.PRESENT
and secondary_index_building
):
print(
f"Not indexing cc_pair={cc_pair.id}: DISABLE_INDEX_UPDATE_ON_SWAP is True and secondary index building"
)
return False
# When switching over models, always index at least once
@@ -388,19 +392,31 @@ def should_index(
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
if last_index.status == IndexingStatus.SUCCESS:
print(
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with successful last index attempt={last_index.id}"
)
return False
# No new index if the last index attempt is waiting to start
if last_index.status == IndexingStatus.NOT_STARTED:
print(
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with NOT_STARTED last index attempt={last_index.id}"
)
return False
# No new index if the last index attempt is running
if last_index.status == IndexingStatus.IN_PROGRESS:
print(
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with IN_PROGRESS last index attempt={last_index.id}"
)
return False
else:
if (
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
): # Ingestion API
print(
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with Ingestion API source"
)
return False
return True
@@ -412,6 +428,9 @@ def should_index(
or connector.id == 0
or connector.source == DocumentSource.INGESTION_API
):
print(
f"Not indexing cc_pair={cc_pair.id}: Connector is paused or is Ingestion API"
)
return False
if search_settings_instance.status.is_current():
@@ -424,11 +443,16 @@ def should_index(
return True
if connector.refresh_freq is None:
print(f"Not indexing cc_pair={cc_pair.id}: refresh_freq is None")
return False
current_db_time = get_db_current_time(db_session)
time_since_index = current_db_time - last_index.time_updated
if time_since_index.total_seconds() < connector.refresh_freq:
print(
f"Not indexing cc_pair={cc_pair.id}: Last index attempt={last_index.id} "
f"too recent ({time_since_index.total_seconds()}s < {connector.refresh_freq}s)"
)
return False
return True
@@ -508,6 +532,13 @@ def try_creating_indexing_task(
custom_task_id = redis_connector_index.generate_generator_task_id()
# Determine which queue to use based on whether this is a user file
queue = (
OnyxCeleryQueues.USER_FILES_INDEXING
if cc_pair.is_user_file
else OnyxCeleryQueues.CONNECTOR_INDEXING
)
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
result = celery_app.send_task(
@@ -518,7 +549,7 @@ def try_creating_indexing_task(
search_settings_id=search_settings.id,
tenant_id=tenant_id,
),
queue=OnyxCeleryQueues.CONNECTOR_INDEXING,
queue=queue,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
)

View File

@@ -6,6 +6,7 @@ from tenacity import wait_random_exponential
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.document_index.interfaces import VespaDocumentUserFields
class RetryDocumentIndex:
@@ -52,11 +53,13 @@ class RetryDocumentIndex:
*,
tenant_id: str,
chunk_count: int | None,
fields: VespaDocumentFields,
fields: VespaDocumentFields | None,
user_fields: VespaDocumentUserFields | None,
) -> int:
return self.index.update_single(
doc_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
fields=fields,
user_fields=user_fields,
)

View File

@@ -164,6 +164,7 @@ def document_by_cc_pair_cleanup_task(
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
user_fields=None,
)
# there are still other cc_pair references to the doc, so just resync to Vespa

View File
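The cleanup task above goes through `RetryDocumentIndex`, which retries transient index failures with tenacity's `wait_random_exponential` (visible in the imports two files up). A stdlib-only sketch of that jittered exponential wait, assuming illustrative constants:

```python
import random

def random_exponential_wait(
    attempt: int, multiplier: float = 1.0, max_wait: float = 30.0
) -> float:
    """Pick a wait uniformly in [0, min(multiplier * 2**attempt, max_wait)],
    the same shape as tenacity's wait_random_exponential."""
    high = min(multiplier * (2 ** attempt), max_wait)
    return random.uniform(0, high)

def call_with_retries(fn, attempts: int = 3):
    for attempt in range(attempts):
        try:
            return fn()
        except ConnectionError:
            if attempt == attempts - 1:
                raise
            _ = random_exponential_wait(attempt)  # real code would sleep(_) here
```

The random component spreads retries from many workers apart in time, so a briefly unavailable index is not hammered in lockstep.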

@@ -0,0 +1,266 @@
import time
from typing import List
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from tenacity import RetryError
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector_credential_pair import (
get_connector_credential_pairs_with_user_files,
)
from onyx.db.document import get_document
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.search_settings import get_active_search_settings
from onyx.db.user_documents import fetch_user_files_for_documents
from onyx.db.user_documents import fetch_user_folders_for_documents
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_FOLDER_SYNC,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
)
def check_for_user_file_folder_sync(self: Task, *, tenant_id: str) -> bool | None:
"""Runs periodically to check for documents that need user file folder metadata updates.
This task fetches all connector credential pairs with user files, gets the documents
associated with them, and updates the user file and folder metadata in Vespa.
"""
time_start = time.monotonic()
r = get_redis_client()
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_USER_FILE_FOLDER_SYNC_BEAT_LOCK,
timeout=CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
with get_session_with_current_tenant() as db_session:
# Get all connector credential pairs that have user files
cc_pairs = get_connector_credential_pairs_with_user_files(db_session)
if not cc_pairs:
task_logger.info("No connector credential pairs with user files found")
return True
# Get all documents associated with these cc_pairs
document_ids = get_documents_for_cc_pairs(cc_pairs, db_session)
if not document_ids:
task_logger.info(
"No documents found for connector credential pairs with user files"
)
return True
# Fetch current user file and folder IDs for these documents
doc_id_to_user_file_id = fetch_user_files_for_documents(
document_ids=document_ids, db_session=db_session
)
doc_id_to_user_folder_id = fetch_user_folders_for_documents(
document_ids=document_ids, db_session=db_session
)
# Update Vespa metadata for each document
for doc_id in document_ids:
user_file_id = doc_id_to_user_file_id.get(doc_id)
user_folder_id = doc_id_to_user_folder_id.get(doc_id)
if user_file_id is not None or user_folder_id is not None:
# Schedule a task to update the document metadata
update_user_file_folder_metadata.apply_async(
args=(doc_id,), # Use tuple instead of list for args
kwargs={
"tenant_id": tenant_id,
"user_file_id": user_file_id,
"user_folder_id": user_folder_id,
},
queue="vespa_metadata_sync",
)
task_logger.info(
f"Scheduled metadata updates for {len(document_ids)} documents. "
f"Elapsed time: {time.monotonic() - time_start:.2f}s"
)
return True
except Exception as e:
task_logger.exception(f"Error in check_for_user_file_folder_sync: {e}")
return False
finally:
lock_beat.release()
def get_documents_for_cc_pairs(
cc_pairs: List[ConnectorCredentialPair], db_session: Session
) -> List[str]:
"""Get all document IDs associated with the given connector credential pairs."""
if not cc_pairs:
return []
cc_pair_ids = [cc_pair.id for cc_pair in cc_pairs]
# Query to get document IDs from DocumentByConnectorCredentialPair
# Note: DocumentByConnectorCredentialPair uses connector_id and credential_id, not cc_pair_id
doc_cc_pairs = (
db_session.query(Document.id)
.join(
DocumentByConnectorCredentialPair,
Document.id == DocumentByConnectorCredentialPair.id,
)
.filter(
db_session.query(ConnectorCredentialPair)
.filter(
ConnectorCredentialPair.id.in_(cc_pair_ids),
ConnectorCredentialPair.connector_id
== DocumentByConnectorCredentialPair.connector_id,
ConnectorCredentialPair.credential_id
== DocumentByConnectorCredentialPair.credential_id,
)
.exists()
)
.all()
)
return [doc_id for (doc_id,) in doc_cc_pairs]
@shared_task(
name=OnyxCeleryTask.UPDATE_USER_FILE_FOLDER_METADATA,
bind=True,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
max_retries=3,
)
def update_user_file_folder_metadata(
self: Task,
document_id: str,
*,
tenant_id: str,
user_file_id: int | None,
user_folder_id: int | None,
) -> bool:
"""Updates the user file and folder metadata for a document in Vespa."""
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
try:
with get_session_with_current_tenant() as db_session:
active_search_settings = get_active_search_settings(db_session)
doc_index = get_default_document_index(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
retry_index = RetryDocumentIndex(doc_index)
doc = get_document(document_id, db_session)
if not doc:
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=no_operation "
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
return False
# Create user fields object with file and folder IDs
user_fields = VespaDocumentUserFields(
user_file_id=str(user_file_id) if user_file_id is not None else None,
user_folder_id=str(user_folder_id)
if user_folder_id is not None
else None,
)
# Update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=None, # We're only updating user fields
user_fields=user_fields,
)
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=user_file_folder_sync "
f"user_file_id={user_file_id} "
f"user_folder_id={user_folder_id} "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
return True
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
except Exception as ex:
e: Exception | None = None
while True:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
task_logger.exception(
f"update_user_file_folder_metadata exceptioned: doc={document_id}"
)
completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
if (
self.max_retries is not None
and self.request.retries >= self.max_retries
):
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown) # this will raise a celery exception
break # we won't hit this, but it looks weird not to have it
finally:
task_logger.info(
f"update_user_file_folder_metadata completed: status={completion_status.value} doc={document_id}"
)
return False

View File
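The retry path in `update_user_file_folder_metadata` above computes its countdown as `2 ** (self.request.retries + 4)`, which the in-code comment describes as 16, 32, 64 seconds across the three allowed retries. As a quick check of that arithmetic:

```python
def retry_countdown(retries: int) -> int:
    # Exponential backoff from 2^4 to 2^6, matching the task's comment
    return 2 ** (retries + 4)

[retry_countdown(r) for r in range(3)]  # → [16, 32, 64]
```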

@@ -80,7 +80,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
# Useful for debugging timing issues with reacquisitions. TODO: remove once more generalized logging is in place
# Useful for debugging timing issues with reacquisitions.
# TODO: remove once more generalized logging is in place
task_logger.info("check_for_vespa_sync_task started")
time_start = time.monotonic()
@@ -572,6 +573,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
user_fields=None,
)
# update db last. Worst case = we crash right before this and

View File

@@ -274,7 +274,6 @@ def _run_indexing(
"Search settings must be set for indexing. This should not be possible."
)
# search_settings = index_attempt_start.search_settings
db_connector = index_attempt_start.connector_credential_pair.connector
db_credential = index_attempt_start.connector_credential_pair.credential
ctx = RunIndexingContext(
@@ -638,6 +637,9 @@ def _run_indexing(
# and mark the CCPair as invalid. This prevents the connector from being
# used in the future until the credentials are updated.
with get_session_with_current_tenant() as db_session_temp:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to validation error."
)
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
@@ -684,6 +686,9 @@ def _run_indexing(
elif isinstance(e, ConnectorStopSignal):
with get_session_with_current_tenant() as db_session_temp:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to stop signal."
)
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
@@ -746,6 +751,7 @@ def _run_indexing(
f"Connector succeeded: "
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
)
else:
mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
logger.info(

View File
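The `_run_indexing` changes above add logging before marking an attempt canceled, whether cancellation came from a credential validation error or a `ConnectorStopSignal`. A sketch of that classification step (the exception names mirror the diff; the mapping function and status strings are illustrative):

```python
class ConnectorValidationError(Exception):
    """Stand-in for the credential validation failure handled in the diff."""

class ConnectorStopSignal(Exception):
    """Stand-in for the explicit stop signal handled in the diff."""

def classify_failure(e: Exception) -> str:
    """Map indexing-run exceptions to the attempt disposition logged above."""
    if isinstance(e, ConnectorValidationError):
        return "canceled:validation_error"
    if isinstance(e, ConnectorStopSignal):
        return "canceled:stop_signal"
    return "failed"
```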

@@ -127,6 +127,10 @@ class StreamStopInfo(SubQuestionIdentifier):
return data
class UserKnowledgeFilePacket(BaseModel):
user_files: list[FileDescriptor]
class LLMRelevanceFilterResponse(BaseModel):
llm_selected_doc_indices: list[int]
@@ -194,17 +198,6 @@ class StreamingError(BaseModel):
stack_trace: str | None = None
class OnyxContext(BaseModel):
content: str
document_id: str
semantic_identifier: str
blurb: str
class OnyxContexts(BaseModel):
contexts: list[OnyxContext]
class OnyxAnswer(BaseModel):
answer: str | None
@@ -270,7 +263,6 @@ class PersonaOverrideConfig(BaseModel):
AnswerQuestionPossibleReturn = (
OnyxAnswerPiece
| CitationInfo
| OnyxContexts
| FileChatDisplay
| CustomToolResponse
| StreamingError

View File

@@ -29,7 +29,6 @@ from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import MessageSpecificCitations
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
@@ -37,6 +36,7 @@ from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import SubQuestionKey
from onyx.chat.models import UserKnowledgeFilePacket
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
@@ -52,6 +52,7 @@ from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import SearchType
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
@@ -65,6 +66,7 @@ from onyx.context.search.utils import relevant_sections_to_indices
from onyx.db.chat import attach_files_to_chat_message
from onyx.db.chat import create_db_search_doc
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import create_search_doc_from_user_file
from onyx.db.chat import get_chat_message
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_db_search_doc_by_id
@@ -73,6 +75,7 @@ from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
from onyx.db.chat import update_chat_session_updated_at_timestamp
from onyx.db.engine import get_session_context_manager
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists
@@ -80,12 +83,16 @@ from onyx.db.milestone import update_user_assistant_milestone
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.models import ToolCall
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.persona import get_persona_by_id
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_all_chat_files
from onyx.file_store.utils import load_all_user_file_files
from onyx.file_store.utils import load_all_user_files
from onyx.file_store.utils import save_files
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_llms_for_persona
@@ -98,6 +105,7 @@ from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.utils import get_json_line
from onyx.tools.force import ForceUseTool
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_constructor import construct_tools
@@ -130,7 +138,6 @@ from onyx.tools.tool_implementations.internet_search.internet_search_tool import
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
@@ -176,11 +183,14 @@ def _handle_search_tool_response_summary(
db_session: Session,
selected_search_docs: list[DbSearchDoc] | None,
dedupe_docs: bool = False,
user_files: list[UserFile] | None = None,
loaded_user_files: list[InMemoryChatFile] | None = None,
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
response_summary = cast(SearchResponseSummary, packet.response)
is_extended = isinstance(packet, ExtendedToolResponse)
dropped_inds = None
if not selected_search_docs:
top_docs = chunks_or_sections_to_search_docs(response_summary.top_sections)
@@ -194,9 +204,31 @@ def _handle_search_tool_response_summary(
create_db_search_doc(server_search_doc=doc, db_session=db_session)
for doc in deduped_docs
]
else:
reference_db_search_docs = selected_search_docs
doc_ids = {doc.id for doc in reference_db_search_docs}
if user_files is not None:
for user_file in user_files:
if user_file.id not in doc_ids:
associated_chat_file = None
if loaded_user_files is not None:
associated_chat_file = next(
(
file
for file in loaded_user_files
if file.file_id == str(user_file.file_id)
),
None,
)
# Use create_search_doc_from_user_file to properly add the document to the database
if associated_chat_file is not None:
db_doc = create_search_doc_from_user_file(
user_file, associated_chat_file, db_session
)
reference_db_search_docs.append(db_doc)
response_docs = [
translate_db_search_doc_to_server_search_doc(db_search_doc)
for db_search_doc in reference_db_search_docs
@@ -254,7 +286,10 @@ def _handle_internet_search_tool_response_summary(
def _get_force_search_settings(
new_msg_req: CreateChatMessageRequest, tools: list[Tool]
new_msg_req: CreateChatMessageRequest,
tools: list[Tool],
user_file_ids: list[int],
user_folder_ids: list[int],
) -> ForceUseTool:
internet_search_available = any(
isinstance(tool, InternetSearchTool) for tool in tools
@@ -262,8 +297,11 @@ def _get_force_search_settings(
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
if not internet_search_available and not search_tool_available:
# Does not matter much which tool is set here as force is false and neither tool is available
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
if new_msg_req.force_user_file_search:
return ForceUseTool(force_use=True, tool_name=SearchTool._NAME)
else:
# Does not matter much which tool is set here as force is false and neither tool is available
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
# Currently, the internet search tool does not support query override
@@ -273,12 +311,25 @@ def _get_force_search_settings(
else None
)
# Create override_kwargs for the search tool if user_file_ids are provided
override_kwargs = None
if (user_file_ids or user_folder_ids) and tool_name == SearchTool._NAME:
override_kwargs = SearchToolOverrideKwargs(
force_no_rerank=False,
alternate_db_session=None,
retrieved_sections_callback=None,
skip_query_analysis=False,
user_file_ids=user_file_ids,
user_folder_ids=user_folder_ids,
)
if new_msg_req.file_descriptors:
# If the user has uploaded files for this message, don't run any of the search tools
return ForceUseTool(force_use=False, tool_name=tool_name)
should_force_search = any(
[
new_msg_req.force_user_file_search,
new_msg_req.retrieval_options
and new_msg_req.retrieval_options.run_search
== OptionalSearchSetting.ALWAYS,
@@ -291,15 +342,22 @@ def _get_force_search_settings(
if should_force_search:
# If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
return ForceUseTool(
force_use=True,
tool_name=tool_name,
args=args,
override_kwargs=override_kwargs,
)
return ForceUseTool(
force_use=False, tool_name=tool_name, args=args, override_kwargs=override_kwargs
)
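`_get_force_search_settings` above gains user file/folder awareness: per-message uploaded files suppress forced search entirely, `force_user_file_search` forces it, and the user file IDs ride along as override kwargs. A distilled sketch of the decision order on the main path (names simplified, and the no-search-tool branch omitted):

```python
from dataclasses import dataclass, field

@dataclass
class ForceUse:
    force_use: bool
    user_file_ids: list[int] = field(default_factory=list)

def decide_force_search(
    has_chat_file_descriptors: bool,
    force_user_file_search: bool,
    retrieval_always: bool,
    user_file_ids: list[int],
) -> ForceUse:
    # Per-message uploaded files short-circuit: no forced search at all
    if has_chat_file_descriptors:
        return ForceUse(force_use=False)
    if force_user_file_search or retrieval_always:
        return ForceUse(force_use=True, user_file_ids=user_file_ids)
    return ForceUse(force_use=False, user_file_ids=user_file_ids)
```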
ChatPacket = (
StreamingError
| QADocsResponse
| OnyxContexts
| LLMRelevanceFilterResponse
| FinalUsedContextDocsResponse
| ChatMessageDetail
@@ -313,6 +371,7 @@ ChatPacket = (
| AgenticMessageResponseIDInfo
| StreamStopInfo
| AgentSearchPacket
| UserKnowledgeFilePacket
)
ChatPacketStream = Iterator[ChatPacket]
@@ -358,6 +417,10 @@ def stream_chat_message_objects(
llm: LLM
try:
# Move these variables inside the try block
file_id_to_user_file = {}
ordered_user_files = None
user_id = user.id if user is not None else None
chat_session = get_chat_session_by_id(
@@ -537,6 +600,70 @@ def stream_chat_message_objects(
)
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
latest_query_files = [file for file in files if file.file_id in req_file_ids]
user_file_ids = new_msg_req.user_file_ids or []
user_folder_ids = new_msg_req.user_folder_ids or []
if persona.user_files:
for file in persona.user_files:
user_file_ids.append(file.id)
if persona.user_folders:
for folder in persona.user_folders:
user_folder_ids.append(folder.id)
# Initialize flag for user file search
use_search_for_user_files = False
user_files: list[InMemoryChatFile] | None = None
search_for_ordering_only = False
user_file_files: list[UserFile] | None = None
if user_file_ids or user_folder_ids:
# Load user files
user_files = load_all_user_files(
user_file_ids or [],
user_folder_ids or [],
db_session,
)
user_file_files = load_all_user_file_files(
user_file_ids or [],
user_folder_ids or [],
db_session,
)
# Store mapping of file_id to file for later reordering
if user_files:
file_id_to_user_file = {file.file_id: file for file in user_files}
# Calculate token count for the files
from onyx.db.user_documents import calculate_user_files_token_count
from onyx.chat.prompt_builder.citations_prompt import (
compute_max_document_tokens_for_persona,
)
total_tokens = calculate_user_files_token_count(
user_file_ids or [],
user_folder_ids or [],
db_session,
)
# Calculate available tokens for documents based on prompt, user input, etc.
available_tokens = compute_max_document_tokens_for_persona(
db_session=db_session,
persona=persona,
actual_user_input=message_text, # Use the actual user message
)
logger.debug(
f"Total file tokens: {total_tokens}, Available tokens: {available_tokens}"
)
# ALWAYS use search for user files, but track if we need it for context or just ordering
use_search_for_user_files = True
# If files are small enough for context, we'll just use search for ordering
search_for_ordering_only = total_tokens <= available_tokens
if search_for_ordering_only:
# Add original user files to context since they fit
if user_files:
latest_query_files.extend(user_files)
if user_message:
attach_files_to_chat_message(
@@ -679,8 +806,10 @@ def stream_chat_message_objects(
prompt_config=prompt_config,
db_session=db_session,
user=user,
user_knowledge_present=bool(user_files or user_folder_ids),
llm=llm,
fast_llm=fast_llm,
use_file_search=new_msg_req.force_user_file_search,
search_tool_config=SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
@@ -710,17 +839,138 @@ def stream_chat_message_objects(
for tool_list in tool_dict.values():
tools.extend(tool_list)
force_use_tool = _get_force_search_settings(
new_msg_req, tools, user_file_ids, user_folder_ids
)
# Set force_use if user files exceed token limit
if use_search_for_user_files:
try:
# Check if search tool is available in the tools list
search_tool_available = any(
isinstance(tool, SearchTool) for tool in tools
)
# If no search tool is available, add one
if not search_tool_available:
logger.info("No search tool available, creating one for user files")
# Create a basic search tool config
search_tool_config = SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=retrieval_options or RetrievalDetails(),
)
# Create and add the search tool
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=search_tool_config.retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=search_tool_config.document_pruning_config,
answer_style_config=search_tool_config.answer_style_config,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
bypass_acl=bypass_acl,
)
# Add the search tool to the tools list
tools.append(search_tool)
logger.info(
"Added search tool for user files that exceed token limit"
)
# Now set force_use_tool.force_use to True
force_use_tool.force_use = True
force_use_tool.tool_name = SearchTool._NAME
# Set query argument if not already set
if not force_use_tool.args:
force_use_tool.args = {"query": final_msg.message}
# Pass the user file IDs to the search tool
if user_file_ids or user_folder_ids:
# Create a BaseFilters object with user_file_ids
if not retrieval_options:
retrieval_options = RetrievalDetails()
if not retrieval_options.filters:
retrieval_options.filters = BaseFilters()
# Set user file and folder IDs in the filters
retrieval_options.filters.user_file_ids = user_file_ids
retrieval_options.filters.user_folder_ids = user_folder_ids
# Create override kwargs for the search tool
override_kwargs = SearchToolOverrideKwargs(
force_no_rerank=search_for_ordering_only, # Skip reranking for ordering-only
alternate_db_session=None,
retrieved_sections_callback=None,
skip_query_analysis=search_for_ordering_only, # Skip query analysis for ordering-only
user_file_ids=user_file_ids,
user_folder_ids=user_folder_ids,
ordering_only=search_for_ordering_only, # Set ordering_only flag for fast path
)
# Set the override kwargs in the force_use_tool
force_use_tool.override_kwargs = override_kwargs
if search_for_ordering_only:
logger.info(
"Fast path: Configured search tool with optimized settings for ordering-only"
)
logger.info(
"Fast path: Skipping reranking and query analysis for ordering-only mode"
)
logger.info(
f"Using {len(user_file_ids or [])} files and {len(user_folder_ids or [])} folders"
)
else:
logger.info(
f"Configured search tool to use "
f"{len(user_file_ids or [])} files and {len(user_folder_ids or [])} folders"
)
except Exception as e:
logger.exception(
f"Error configuring search tool for user files: {str(e)}"
)
use_search_for_user_files = False
# TODO: unify message history with single message history
message_history = [
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
]
if not use_search_for_user_files and user_files:
yield UserKnowledgeFilePacket(
user_files=[
FileDescriptor(
id=str(file.file_id), type=ChatFileType.USER_KNOWLEDGE
)
for file in user_files
]
)
if search_for_ordering_only:
logger.info(
"Performance: Forcing LLMEvaluationType.SKIP to prevent chunk evaluation for ordering-only search"
)
search_request = SearchRequest(
query=final_msg.message,
evaluation_type=(
LLMEvaluationType.SKIP
if search_for_ordering_only
else (
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
)
),
human_selected_filters=(
retrieval_options.filters if retrieval_options else None
@@ -739,7 +989,6 @@ def stream_chat_message_objects(
),
)
force_use_tool = _get_force_search_settings(new_msg_req, tools)
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=final_msg.message,
@@ -808,8 +1057,22 @@ def stream_chat_message_objects(
info = info_by_subq[
SubQuestionKey(level=level, question_num=level_question_num)
]
# Skip LLM relevance processing entirely for ordering-only mode
if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID:
logger.info(
"Fast path: Completely bypassing section relevance processing for ordering-only mode"
)
# Skip this packet entirely since it would trigger LLM processing
continue
# TODO: don't need to dedupe here when we do it in agent flow
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
if search_for_ordering_only:
logger.info(
"Fast path: Skipping document deduplication for ordering-only mode"
)
(
info.qa_docs_response,
info.reference_db_search_docs,
@@ -819,16 +1082,91 @@ def stream_chat_message_objects(
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
# Skip deduping completely for ordering-only mode to save time
dedupe_docs=(
False
if search_for_ordering_only
else (
retrieval_options.dedupe_docs
if retrieval_options
else False
)
),
user_files=user_file_files if search_for_ordering_only else [],
loaded_user_files=user_files
if search_for_ordering_only
else [],
)
# If we're using search just for ordering user files
if (
search_for_ordering_only
and user_files
and info.qa_docs_response
):
logger.info(
f"ORDERING: Processing search results for ordering {len(user_files)} user files"
)
import time
ordering_start = time.time()
# Extract document order from search results
doc_order = []
for doc in info.qa_docs_response.top_documents:
doc_id = doc.document_id
if str(doc_id).startswith("USER_FILE_CONNECTOR__"):
file_id = doc_id.replace("USER_FILE_CONNECTOR__", "")
if file_id in file_id_to_user_file:
doc_order.append(file_id)
logger.info(
f"ORDERING: Found {len(doc_order)} files from search results"
)
# Add any files that weren't in search results at the end
missing_files = [
f_id
for f_id in file_id_to_user_file.keys()
if f_id not in doc_order
]
doc_order.extend(missing_files)
logger.info(
f"ORDERING: Added {len(missing_files)} missing files to the end"
)
# Reorder user files based on search results
ordered_user_files = [
file_id_to_user_file[f_id]
for f_id in doc_order
if f_id in file_id_to_user_file
]
logger.debug(f"ORDERING: Reordering took {time.time() - ordering_start:.2f}s")
yield UserKnowledgeFilePacket(
user_files=[
FileDescriptor(
id=str(file.file_id),
type=ChatFileType.USER_KNOWLEDGE,
)
for file in ordered_user_files
]
)
yield info.qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
if search_for_ordering_only:
logger.info(
"Performance: Skipping relevance filtering for ordering-only mode"
)
continue
if info.reference_db_search_docs is None:
logger.warning(
"No reference docs found for relevance filtering"
@@ -918,8 +1256,6 @@ def stream_chat_message_objects(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
yield cast(OnyxContexts, packet.response)
elif isinstance(packet, StreamStopInfo):
if packet.stop_reason == StreamStopReason.FINISHED:
@@ -940,7 +1276,7 @@ def stream_chat_message_objects(
]
info.tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except ValueError as e:
logger.exception("Failed to process chat message.")
@@ -1022,10 +1358,16 @@ def stream_chat_message_objects(
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
tool_call=(
ToolCall(
tool_id=tool_name_to_tool_id.get(info.tool_result.tool_name, 0)
if info.tool_result
else None,
tool_name=info.tool_result.tool_name if info.tool_result else None,
tool_arguments=info.tool_result.tool_args
if info.tool_result
else None,
tool_result=info.tool_result.tool_result
if info.tool_result
else None,
)
if info.tool_result
else None
@@ -1069,6 +1411,8 @@ def stream_chat_message_objects(
prev_message = next_answer_message
logger.debug("Committing messages")
# Explicitly update the timestamp on the chat session
update_chat_session_updated_at_timestamp(chat_session_id, db_session)
db_session.commit() # actually save user / assistant message
yield AgenticMessageResponseIDInfo(agentic_message_ids=agentic_message_ids)
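The user-file reordering logic above keeps files in the order the search returned them and appends any files the search never surfaced. A minimal self-contained sketch of that pattern (names are illustrative, not the Onyx APIs):

```python
def order_files_by_search_rank(all_file_ids, ranked_file_ids):
    # Files that appeared in search results keep their ranked order;
    # files the search never returned are appended at the end.
    known = set(all_file_ids)
    ordered = [f_id for f_id in ranked_file_ids if f_id in known]
    seen = set(ordered)
    ordered.extend(f_id for f_id in all_file_ids if f_id not in seen)
    return ordered
```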

View File

@@ -19,6 +19,7 @@ def translate_onyx_msg_to_langchain(
# attached. Just ignore them for now.
if not isinstance(msg, ChatMessage):
files = msg.files
content = build_content_with_imgs(
msg.message, files, message_type=msg.message_type, exclude_images=exclude_images
)

View File

@@ -153,6 +153,8 @@ def _apply_pruning(
# remove docs that are explicitly marked as not for QA
sections = _remove_sections_to_ignore(sections=sections)
section_idx_token_count: dict[int, int] = {}
final_section_ind = None
total_tokens = 0
for ind, section in enumerate(sections):
@@ -202,10 +204,20 @@ def _apply_pruning(
section_token_count = DOC_EMBEDDING_CONTEXT_SIZE
total_tokens += section_token_count
section_idx_token_count[ind] = section_token_count
if total_tokens > token_limit:
final_section_ind = ind
break
try:
logger.debug(f"Number of documents after pruning: {ind + 1}")
logger.debug("Number of tokens per document (pruned):")
for x, y in section_idx_token_count.items():
logger.debug(f"{x + 1}: {y}")
except Exception as e:
logger.error(f"Error logging prune statistics: {e}")
if final_section_ind is not None:
if is_manually_selected_docs or use_sections:
if final_section_ind != len(sections) - 1:
@@ -301,6 +313,10 @@ def prune_sections(
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
assert (
len(set([chunk.document_id for chunk in chunks])) == 1
), "One distinct document must be passed into merge_doc_chunks"
# Assuming there are no duplicates by this point
sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)
@@ -358,6 +374,26 @@ def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
reverse=True,
)
try:
num_original_sections = len(sections)
num_original_document_ids = len(
set([section.center_chunk.document_id for section in sections])
)
num_merged_sections = len(new_sections)
num_merged_document_ids = len(
set([section.center_chunk.document_id for section in new_sections])
)
logger.debug(
f"Merged {num_original_sections} sections from {num_original_document_ids} documents "
f"into {num_merged_sections} new sections in {num_merged_document_ids} documents"
)
logger.debug("Number of chunks per document (new ranking):")
for x, y in enumerate(new_sections):
logger.debug(f"{x + 1}: {len(y.chunks)}")
except Exception as e:
logger.error(f"Error logging merge statistics: {e}")
return new_sections
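The pruning loop above tracks a running token total and records the index of the first section that overflows the budget. The core of that bookkeeping can be sketched in isolation (a simplification of the real `_apply_pruning`, which also handles manually selected docs):

```python
def first_overflow_index(section_token_counts, token_limit):
    # Return the index of the first section that pushes the running
    # total past the limit, or None if everything fits.
    total_tokens = 0
    for ind, count in enumerate(section_token_counts):
        total_tokens += count
        if total_tokens > token_limit:
            return ind
    return None
```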

View File

@@ -3,7 +3,6 @@ from collections.abc import Sequence
from pydantic import BaseModel
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.context.search.models import InferenceChunk
@@ -12,7 +11,7 @@ class DocumentIdOrderMapping(BaseModel):
def map_document_id_order(
chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
) -> DocumentIdOrderMapping:
order_mapping = {}
current = 1 if one_indexed else 0
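The body of `map_document_id_order` assigns each distinct document its first-seen position. A minimal standalone version, using plain strings in place of chunk objects:

```python
def map_document_id_order(doc_ids, one_indexed=True):
    # Each distinct document id gets the ordinal of its first appearance.
    order_mapping = {}
    current = 1 if one_indexed else 0
    for doc_id in doc_ids:
        if doc_id not in order_mapping:
            order_mapping[doc_id] = current
            current += 1
    return order_mapping
```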

View File

@@ -180,6 +180,10 @@ def get_tool_call_for_non_tool_calling_llm_impl(
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
# If we have override_kwargs, add them to the tool_args
if force_use_tool.override_kwargs is not None:
tool_args["override_kwargs"] = force_use_tool.override_kwargs
return (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
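The hunk above threads `override_kwargs` into the tool arguments when a forced tool carries them. The merge itself is a one-liner; a sketch with illustrative names (not the Onyx signatures):

```python
def attach_override_kwargs(tool_args, override_kwargs):
    # Return a copy with the overrides attached; leave the original
    # args untouched when there is nothing to add.
    if override_kwargs is None:
        return tool_args
    return {**tool_args, "override_kwargs": override_kwargs}
```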

View File

@@ -170,7 +170,7 @@ POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "127.0.0.1"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
@@ -437,7 +437,7 @@ LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
# Slack specific configs
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 8)
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
@@ -495,6 +495,11 @@ NUM_SECONDARY_INDEXING_WORKERS = int(
ENABLE_MULTIPASS_INDEXING = (
os.environ.get("ENABLE_MULTIPASS_INDEXING", "").lower() == "true"
)
# Enable contextual retrieval
ENABLE_CONTEXTUAL_RAG = os.environ.get("ENABLE_CONTEXTUAL_RAG", "").lower() == "true"
DEFAULT_CONTEXTUAL_RAG_LLM_NAME = "gpt-4o-mini"
DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER = "DevEnvPresetOpenAI"
# Finer grained chunking for more detail retention
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
@@ -536,6 +541,17 @@ MAX_FILE_SIZE_BYTES = int(
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
) # 2GB in bytes
# Use document summary for contextual rag
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
# Use chunk summary for contextual rag
USE_CHUNK_SUMMARY = os.environ.get("USE_CHUNK_SUMMARY", "true").lower() == "true"
# Average summary embeddings for contextual rag (not yet implemented)
AVERAGE_SUMMARY_EMBEDDINGS = (
os.environ.get("AVERAGE_SUMMARY_EMBEDDINGS", "false").lower() == "true"
)
MAX_TOKENS_FOR_FULL_INCLUSION = 4096
#####
# Miscellaneous
#####
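The boolean configs added above (`ENABLE_CONTEXTUAL_RAG`, `USE_DOCUMENT_SUMMARY`, etc.) all follow the same env-var parsing convention: only the case-insensitive string "true" enables a flag. A small sketch of that convention (the helper name is illustrative; the real configs inline the expression):

```python
import os

def env_flag(name, default=False):
    # Only the exact string "true" (case-insensitive) enables the flag;
    # anything else, including "1" or "yes", is treated as false.
    raw = os.environ.get(name)
    if raw is None or raw == "":
        return default
    return raw.lower() == "true"
```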

View File

@@ -3,7 +3,7 @@ import os
INPUT_PROMPT_YAML = "./onyx/seeding/input_prompts.yaml"
PROMPTS_YAML = "./onyx/seeding/prompts.yaml"
PERSONAS_YAML = "./onyx/seeding/personas.yaml"
USER_FOLDERS_YAML = "./onyx/seeding/user_folders.yaml"
NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking
# We want this to be approximately the number of results we want to show on the first page

View File

@@ -102,6 +102,8 @@ CELERY_GENERIC_BEAT_LOCK_TIMEOUT = 120
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT = 120
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
@@ -269,6 +271,7 @@ class FileOrigin(str, Enum):
CONNECTOR = "connector"
GENERATED_REPORT = "generated_report"
INDEXING_CHECKPOINT = "indexing_checkpoint"
PLAINTEXT_CACHE = "plaintext_cache"
OTHER = "other"
@@ -309,6 +312,7 @@ class OnyxCeleryQueues:
# Indexing queue
CONNECTOR_INDEXING = "connector_indexing"
USER_FILES_INDEXING = "user_files_indexing"
# Monitoring queue
MONITORING = "monitoring"
@@ -327,6 +331,7 @@ class OnyxRedisLocks:
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
"da_lock:check_connector_external_group_sync_beat"
)
CHECK_USER_FILE_FOLDER_SYNC_BEAT_LOCK = "da_lock:check_user_file_folder_sync_beat"
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
PRE_PROVISION_TENANT_LOCK = "da_lock:pre_provision_tenant"
@@ -382,6 +387,7 @@ ONYX_CLOUD_TENANT_ID = "cloud"
# the redis namespace for runtime variables
ONYX_CLOUD_REDIS_RUNTIME = "runtime"
CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT = 600
class OnyxCeleryTask:
@@ -396,6 +402,7 @@ class OnyxCeleryTask:
# Tenant pre-provisioning
PRE_PROVISION_TENANT = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_pre_provision_tenant"
UPDATE_USER_FILE_FOLDER_METADATA = "update_user_file_folder_metadata"
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
@@ -404,6 +411,7 @@ class OnyxCeleryTask:
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
CHECK_FOR_USER_FILE_FOLDER_SYNC = "check_for_user_file_folder_sync"
# Connector checkpoint cleanup
CHECK_FOR_CHECKPOINT_CLEANUP = "check_for_checkpoint_cleanup"

View File

@@ -87,7 +87,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
credentials.get(key)
for key in ["aws_access_key_id", "aws_secret_access_key"]
):
raise ConnectorMissingCredentialError("Amazon S3")
session = boto3.Session(
aws_access_key_id=credentials["aws_access_key_id"],

View File

@@ -65,20 +65,6 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
_SLIM_DOC_BATCH_SIZE = 5000
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
"gif",
"mp4",
"mov",
"mp3",
"wav",
]
_FULL_EXTENSION_FILTER_STRING = "".join(
[
f" and title!~'*.{extension}'"
for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
]
)
ONE_HOUR = 3600
@@ -209,7 +195,6 @@ class ConfluenceConnector(
def _construct_attachment_query(self, confluence_page_id: str) -> str:
attachment_query = f"type=attachment and container='{confluence_page_id}'"
attachment_query += self.cql_label_filter
attachment_query += _FULL_EXTENSION_FILTER_STRING
return attachment_query
def _get_comment_string_for_page_id(self, page_id: str) -> str:
@@ -374,11 +359,13 @@ class ConfluenceConnector(
if not validate_attachment_filetype(
attachment,
):
logger.info(f"Skipping attachment: {attachment['title']}")
continue
# Attempt to get textual content or image summarization:
try:
logger.info(f"Processing attachment: {attachment['title']}")
response = convert_attachment_to_content(
confluence_client=self.confluence_client,
attachment=attachment,

View File

@@ -28,8 +28,9 @@ from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import is_text_file_extension
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.extract_file_text import read_text_file
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
@@ -69,7 +70,9 @@ def _process_egnyte_file(
file_name = file_metadata["name"]
extension = get_file_ext(file_name)
if not is_accepted_file_ext(
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
):
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
return None

View File

@@ -22,8 +22,9 @@ from onyx.db.engine import get_session_with_current_tenant
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
@@ -51,7 +52,7 @@ def _read_files_and_metadata(
file_content, ignore_dirs=True
):
yield os.path.join(directory_path, file_info.filename), subfile, metadata
elif is_accepted_file_ext(extension, OnyxExtensionType.All):
yield file_name, file_content, metadata
else:
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
@@ -122,7 +123,7 @@ def _process_file(
logger.warning(f"No file record found for '{file_name}' in PG; skipping.")
return []
if not is_accepted_file_ext(extension, OnyxExtensionType.All):
logger.warning(
f"Skipping file '{file_name}' with unrecognized extension '{extension}'"
)

View File

@@ -45,6 +45,8 @@ _FIREFLIES_API_QUERY = """
}
"""
ONE_MINUTE = 60
def _create_doc_from_transcript(transcript: dict) -> Document | None:
sections: List[TextSection] = []
@@ -106,6 +108,8 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
)
# If not all transcripts are being indexed, try using a more-recently-generated
# API key.
class FirefliesConnector(PollConnector, LoadConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
@@ -191,6 +195,9 @@ class FirefliesConnector(PollConnector, LoadConnector):
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
# add some leeway to account for any timezone funkiness and/or bad handling
# of start time on the Fireflies side
start = max(0, start - ONE_MINUTE)
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
"%Y-%m-%dT%H:%M:%S.000Z"
)
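The Fireflies change above widens the poll window by subtracting a one-minute leeway from the start time, clamped at the epoch, to absorb timezone or clock-skew quirks on the remote side. The adjustment in isolation:

```python
ONE_MINUTE = 60

def widen_poll_start(start, leeway_seconds=ONE_MINUTE):
    # Pull the window start back slightly to tolerate clock skew or
    # sloppy start-time handling upstream; never go below the epoch.
    return max(0, start - leeway_seconds)
```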

View File

@@ -276,7 +276,26 @@ class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
return checkpoint
assert checkpoint.cached_repo is not None, "No repo saved in checkpoint"
# Try to access the requester - different PyGithub versions may use different attribute names
try:
# Try direct access to a known attribute name first
if hasattr(self.github_client, "_requester"):
requester = self.github_client._requester
elif hasattr(self.github_client, "_Github__requester"):
requester = self.github_client._Github__requester
else:
# If we can't find the requester attribute, we need to fall back to recreating the repo
raise AttributeError("Could not find requester attribute")
repo = checkpoint.cached_repo.to_Repository(requester)
except Exception as e:
# If all else fails, re-fetch the repo directly
logger.warning(
f"Failed to deserialize repository: {e}. Attempting to re-fetch."
)
repo_id = checkpoint.cached_repo.id
repo = self.github_client.get_repo(repo_id)
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
logger.info(f"Fetching PRs for repo: {repo.name}")
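The GitHub fix above probes a list of version-dependent private attribute names (`_requester` vs `_Github__requester`) before falling back to re-fetching. The probing pattern, sketched generically (the demo class is hypothetical):

```python
def find_first_attr(obj, candidate_names):
    # Probe version-dependent attribute names and return the first one
    # that exists, so an upstream rename doesn't break deserialization.
    for name in candidate_names:
        if hasattr(obj, name):
            return getattr(obj, name)
    raise AttributeError(f"none of {candidate_names} found on {type(obj).__name__}")
```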

View File

@@ -28,7 +28,9 @@ from onyx.connectors.google_drive.doc_conversion import (
)
from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from onyx.connectors.google_drive.file_retrieval import (
get_all_files_in_my_drive_and_shared,
)
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
from onyx.connectors.google_drive.models import DriveRetrievalStage
@@ -86,13 +88,18 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:
def _convert_single_file(
creds: Any,
primary_admin_email: str,
allow_images: bool,
size_threshold: int,
retriever_email: str,
file: dict[str, Any],
) -> Document | ConnectorFailure | None:
# We used to always get the user email from the file owners when available,
# but this was causing issues with shared folders where the owner was not included in the service account
# now we use the email of the account that successfully listed the file. Leaving this in case we end up
# wanting to retry with file owners and/or admin email at some point.
# user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_email = retriever_email
# Only construct these services when needed
user_drive_service = lazy_eval(
lambda: get_drive_service(creds, user_email=user_email)
@@ -450,10 +457,11 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
logger.info(f"Getting all files in my drive as '{user_email}'")
yield from add_retrieval_info(
get_all_files_in_my_drive_and_shared(
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
is_slim=is_slim,
include_shared_with_me=self.include_files_shared_with_me,
start=curr_stage.completed_until if resuming else start,
end=end,
),
@@ -916,20 +924,28 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
convert_func = partial(
_convert_single_file,
self.creds,
self.primary_admin_email,
self.allow_images,
self.size_threshold,
)
# Fetch files in batches
batches_complete = 0
files_batch: list[RetrievedDriveFile] = []
def _yield_batch(
files_batch: list[RetrievedDriveFile],
) -> Iterator[Document | ConnectorFailure]:
nonlocal batches_complete
# Process the batch using run_functions_tuples_in_parallel
func_with_args = [
(
convert_func,
(
file.user_email,
file.drive_file,
),
)
for file in files_batch
]
results = cast(
list[Document | ConnectorFailure | None],
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
@@ -967,7 +983,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
)
continue
files_batch.append(retrieved_file)
if len(files_batch) < self.batch_size:
continue
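The batching loop above accumulates retrieved files and flushes whenever the batch reaches `batch_size`. The same pattern as a standalone generator (a sketch, not the connector's actual helper):

```python
def batched(items, batch_size):
    # Accumulate items and yield fixed-size batches, flushing any
    # remainder at the end.
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) >= batch_size:
            yield batch
            batch = []
    if batch:
        yield batch
```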

View File

@@ -30,6 +30,7 @@ from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.llm.interfaces import LLM
from onyx.utils.lazy import lazy_eval
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -76,6 +77,26 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
return is_valid_image_type(mime_type)
def download_request(service: GoogleDriveService, file_id: str) -> bytes:
"""
Download the file from Google Drive.
"""
# Download the raw file bytes via the Drive media endpoint
request = service.files().get_media(fileId=file_id)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to download {file_id}")
return bytes()
return response
def _download_and_extract_sections_basic(
file: dict[str, str],
service: GoogleDriveService,
@@ -87,35 +108,17 @@ def _download_and_extract_sections_basic(
mime_type = file["mimeType"]
link = file.get("webViewLink", "")
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
response_call = lazy_eval(lambda: download_request(service, file_id))
# Process based on mime type
if mime_type == "text/plain":
text = response_call().decode("utf-8")
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response_call()))
return [TextSection(link=link, text=text)]
elif (
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response_call()))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response_call()))
return [TextSection(link=link, text=text)]
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response_call(),
file_name=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]
# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
else:
# For unsupported file types, try to extract text
if mime_type in [
"application/vnd.google-apps.video",
"application/vnd.google-apps.audio",
"application/zip",
]:
return []
try:
text = extract_file_text(io.BytesIO(response_call()), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
def convert_drive_item_to_document(
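The refactor above wraps the download in `lazy_eval` so the bytes are fetched at most once, and only on the first mime-type branch that actually needs them. A minimal memoized-thunk sketch of what such a helper does (an assumption about its behavior, not the Onyx implementation):

```python
def lazy_eval(fn):
    # Wrap a zero-argument callable so the underlying work runs at most
    # once, on first call; the result is cached for later calls.
    computed = False
    result = None

    def wrapper():
        nonlocal computed, result
        if not computed:
            result = fn()
            computed = True
        return result

    return wrapper
```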

View File

@@ -214,10 +214,11 @@ def get_files_in_shared_drive(
yield file
def get_all_files_in_my_drive_and_shared(
service: GoogleDriveService,
update_traversed_ids_func: Callable,
is_slim: bool,
include_shared_with_me: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
@@ -229,7 +230,8 @@ def get_all_files_in_my_drive(
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
if not include_shared_with_me:
folder_query += " and 'me' in owners"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
@@ -246,7 +248,8 @@ def get_all_files_in_my_drive(
# Then get the files
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
if not include_shared_with_me:
file_query += " and 'me' in owners"
file_query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,

View File

@@ -75,7 +75,7 @@ class HighspotClient:
self.key = key
self.secret = secret
self.base_url = base_url.rstrip("/") + "/"
self.timeout = timeout
# Set up session with retry logic
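The Highspot fix above normalizes the configured base URL to exactly one trailing slash, so path joins behave the same whether or not the user's URL ended with "/". The normalization in isolation:

```python
def normalize_base_url(base_url):
    # Strip any trailing slashes, then add exactly one back, so
    # "host/api" and "host/api///" both become "host/api/".
    return base_url.rstrip("/") + "/"
```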

View File

@@ -20,8 +20,8 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -298,7 +298,7 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
elif (
is_valid_format
and file_extension in VALID_FILE_EXTENSIONS
and file_extension in ALL_ACCEPTED_FILE_EXTENSIONS
and can_download
):
# For documents, try to get the text content

View File

@@ -163,6 +163,9 @@ class DocumentBase(BaseModel):
attributes.append(k + INDEX_SEPARATOR + v)
return attributes
def get_text_content(self) -> str:
return " ".join([section.text for section in self.sections if section.text])
class Document(DocumentBase):
"""Used for Onyx ingestion api, the ID is required"""

View File

@@ -14,6 +14,8 @@ from typing import cast
from pydantic import BaseModel
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.http_retry import ConnectionErrorRetryHandler
from slack_sdk.http_retry import RetryHandler
from typing_extensions import override
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
@@ -26,6 +28,8 @@ from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
@@ -38,15 +42,16 @@ from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.slack.onyx_retry_handler import OnyxRedisSlackRetryHandler
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import get_message_link
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.utils import make_slack_api_call_w_retries
from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
_SLACK_LIMIT = 900
@@ -493,9 +498,13 @@ def _process_message(
)
class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
class SlackConnector(
SlimConnector, CredentialsConnector, CheckpointConnector[SlackCheckpoint]
):
FAST_TIMEOUT = 1
MAX_RETRIES = 7 # arbitrarily selected
def __init__(
self,
channels: list[str] | None = None,
@@ -514,16 +523,49 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
# just used for efficiency
self.text_cleaner: SlackTextCleaner | None = None
self.user_cache: dict[str, BasicExpertInfo | None] = {}
self.credentials_provider: CredentialsProviderInterface | None = None
self.credential_prefix: str | None = None
self.delay_lock: str | None = None # the redis key for the shared lock
self.delay_key: str | None = None # the redis key for the shared delay
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
raise NotImplementedError("Use set_credentials_provider with this connector.")
def set_credentials_provider(
self, credentials_provider: CredentialsProviderInterface
) -> None:
credentials = credentials_provider.get_credentials()
tenant_id = credentials_provider.get_tenant_id()
self.redis = get_redis_client(tenant_id=tenant_id)
self.credential_prefix = (
f"connector:slack:credential_{credentials_provider.get_provider_key()}"
)
self.delay_lock = f"{self.credential_prefix}:delay_lock"
self.delay_key = f"{self.credential_prefix}:delay"
# NOTE: slack has a built in RateLimitErrorRetryHandler, but it isn't designed
# for concurrent workers. We've extended it with OnyxRedisSlackRetryHandler.
connection_error_retry_handler = ConnectionErrorRetryHandler()
onyx_rate_limit_error_retry_handler = OnyxRedisSlackRetryHandler(
max_retry_count=self.MAX_RETRIES,
delay_lock=self.delay_lock,
delay_key=self.delay_key,
r=self.redis,
)
custom_retry_handlers: list[RetryHandler] = [
connection_error_retry_handler,
onyx_rate_limit_error_retry_handler,
]
bot_token = credentials["slack_bot_token"]
self.client = WebClient(token=bot_token)
self.client = WebClient(token=bot_token, retry_handlers=custom_retry_handlers)
# use for requests that must return quickly (e.g. realtime flows where user is waiting)
self.fast_client = WebClient(
token=bot_token, timeout=SlackConnector.FAST_TIMEOUT
)
self.text_cleaner = SlackTextCleaner(client=self.client)
return None
self.credentials_provider = credentials_provider
def retrieve_all_slim_documents(
self,

View File

@@ -0,0 +1,159 @@
import math
import random
import time
from typing import cast
from typing import Optional
from redis import Redis
from redis.lock import Lock as RedisLock
from slack_sdk.http_retry.handler import RetryHandler
from slack_sdk.http_retry.request import HttpRequest
from slack_sdk.http_retry.response import HttpResponse
from slack_sdk.http_retry.state import RetryState
from onyx.utils.logger import setup_logger
logger = setup_logger()
class OnyxRedisSlackRetryHandler(RetryHandler):
"""
This class uses Redis to share a rate limit among multiple threads.
Threads that encounter a rate limit will observe the shared delay, increment the
shared delay with the retry value, and use the new shared value as a wait interval.
This has the effect of serializing calls when a rate limit is hit, which is what
needs to happen if the server punishes us with additional limiting when we make
a call too early. We believe this is what Slack is doing based on empirical
observation, meaning we see indefinite hangs if we're too aggressive.
Another way to do this is just to do exponential backoff. Might be easier?
Adapted from slack's RateLimitErrorRetryHandler.
"""
LOCK_TTL = 60 # used to serialize access to the retry TTL
LOCK_BLOCKING_TIMEOUT = 60 # how long to wait for the lock
"""RetryHandler that does retries for rate limited errors."""
def __init__(
self,
max_retry_count: int,
delay_lock: str,
delay_key: str,
r: Redis,
):
"""
delay_lock: the redis key to use with RedisLock (to synchronize access to delay_key)
delay_key: the redis key containing a shared TTL
"""
super().__init__(max_retry_count=max_retry_count)
self._redis: Redis = r
self._delay_lock = delay_lock
self._delay_key = delay_key
def _can_retry(
self,
*,
state: RetryState,
request: HttpRequest,
response: Optional[HttpResponse] = None,
error: Optional[Exception] = None,
) -> bool:
return response is not None and response.status_code == 429
def prepare_for_next_attempt(
self,
*,
state: RetryState,
request: HttpRequest,
response: Optional[HttpResponse] = None,
error: Optional[Exception] = None,
) -> None:
"""This function is responsible for the wait before retrying ... i.e. we
actually sleep in this function."""
retry_after_value: list[str] | None = None
retry_after_header_name: Optional[str] = None
duration_s: float = 1.0 # seconds
if response is None:
# NOTE(rkuo): this logic comes from RateLimitErrorRetryHandler.
# This reads oddly, as if the caller itself could raise the exception.
# We don't have the luxury of changing this.
if error:
raise error
return
state.next_attempt_requested = True # this signals the caller to retry
# calculate wait duration based on retry-after + some jitter
for k in response.headers.keys():
if k.lower() == "retry-after":
retry_after_header_name = k
break
try:
if retry_after_header_name is None:
# This situation usually does not arise. Just in case.
raise ValueError(
"OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header name is None"
)
retry_after_value = response.headers.get(retry_after_header_name)
if not retry_after_value:
raise ValueError(
"OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header value is None"
)
retry_after_value_int = int(
retry_after_value[0]
) # will raise ValueError if somehow we can't convert to int
jitter = retry_after_value_int * 0.25 * random.random()
duration_s = math.ceil(retry_after_value_int + jitter)
except ValueError:
duration_s += random.random()
# lock and extend the ttl
lock: RedisLock = self._redis.lock(
self._delay_lock,
timeout=OnyxRedisSlackRetryHandler.LOCK_TTL,
thread_local=False,
)
acquired = lock.acquire(
blocking_timeout=OnyxRedisSlackRetryHandler.LOCK_BLOCKING_TIMEOUT / 2
)
ttl_ms: int | None = None
try:
if acquired:
# if we can get the lock, then read and extend the ttl
ttl_ms = cast(int, self._redis.pttl(self._delay_key))
if ttl_ms < 0: # negative values are error status codes ... see docs
ttl_ms = 0
ttl_ms_new = ttl_ms + int(duration_s * 1000.0)
self._redis.set(self._delay_key, "1", px=ttl_ms_new)
else:
# if we can't get the lock, just go ahead.
# TODO: if we know our actual parallelism, multiplying by that
# would be a pretty good idea
ttl_ms_new = int(duration_s * 1000.0)
finally:
if acquired:
lock.release()
logger.warning(
f"OnyxRedisSlackRetryHandler.prepare_for_next_attempt wait: "
f"retry-after={retry_after_value} "
f"shared_delay_ms={ttl_ms} new_shared_delay_ms={ttl_ms_new}"
)
# TODO: would be good to take an event var and sleep in short increments to
# allow for a clean exit / exception
time.sleep(ttl_ms_new / 1000.0)
state.increment_current_attempt()
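The shared-delay arithmetic in `prepare_for_next_attempt` can be sketched in isolation: each rate-limited worker reads the remaining shared TTL, extends it by retry-after plus up to 25% jitter, and waits for the new total, so concurrent workers serialize behind one another. This sketch uses an in-memory stand-in for the Redis `pttl`/`set px` calls and omits the lock.

```python
import math
import random
import time

class FakeRedis:
    """In-memory stand-in for the two Redis calls the handler uses."""
    def __init__(self) -> None:
        self._expiry: dict[str, float] = {}  # key -> absolute expiry, ms

    def pttl(self, key: str) -> int:
        now_ms = time.monotonic() * 1000
        if key not in self._expiry or self._expiry[key] <= now_ms:
            return -2  # Redis returns a negative status when the key is missing
        return int(self._expiry[key] - now_ms)

    def set(self, key: str, value: str, px: int) -> None:
        self._expiry[key] = time.monotonic() * 1000 + px

def extend_shared_delay(r: FakeRedis, delay_key: str, retry_after_s: int) -> int:
    """Read the shared TTL and extend it by retry-after plus jitter,
    mirroring prepare_for_next_attempt above. Returns the new wait in ms."""
    jitter = retry_after_s * 0.25 * random.random()
    duration_s = math.ceil(retry_after_s + jitter)
    ttl_ms = max(r.pttl(delay_key), 0)  # negative status codes mean "no delay yet"
    ttl_ms_new = ttl_ms + int(duration_s * 1000.0)
    r.set(delay_key, "1", px=ttl_ms_new)
    return ttl_ms_new

r = FakeRedis()
first = extend_shared_delay(r, "connector:slack:delay", 5)
second = extend_shared_delay(r, "connector:slack:delay", 5)
# the second caller inherits the first caller's remaining delay, so waits longer
```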

View File

@@ -1,5 +1,4 @@
import re
import time
from collections.abc import Callable
from collections.abc import Generator
from functools import lru_cache
@@ -64,71 +63,72 @@ def _make_slack_api_call_paginated(
return paginated_call
def make_slack_api_rate_limited(
call: Callable[..., SlackResponse], max_retries: int = 7
) -> Callable[..., SlackResponse]:
"""Wraps calls to slack API so that they automatically handle rate limiting"""
# NOTE(rkuo): we may not need this any more if the integrated retry handlers work as
# expected. Do we want to keep this around?
@wraps(call)
def rate_limited_call(**kwargs: Any) -> SlackResponse:
last_exception = None
# def make_slack_api_rate_limited(
# call: Callable[..., SlackResponse], max_retries: int = 7
# ) -> Callable[..., SlackResponse]:
# """Wraps calls to slack API so that they automatically handle rate limiting"""
for _ in range(max_retries):
try:
# Make the API call
response = call(**kwargs)
# @wraps(call)
# def rate_limited_call(**kwargs: Any) -> SlackResponse:
# last_exception = None
# Check for errors in the response, will raise `SlackApiError`
# if anything went wrong
response.validate()
return response
# for _ in range(max_retries):
# try:
# # Make the API call
# response = call(**kwargs)
except SlackApiError as e:
last_exception = e
try:
error = e.response["error"]
except KeyError:
error = "unknown error"
# # Check for errors in the response, will raise `SlackApiError`
# # if anything went wrong
# response.validate()
# return response
if error == "ratelimited":
# Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
retry_after = int(e.response.headers.get("Retry-After", 1))
logger.info(
f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
)
time.sleep(retry_after)
elif error in ["already_reacted", "no_reaction", "internal_error"]:
# Log internal_error and return the response instead of failing
logger.warning(
f"Slack call encountered '{error}', skipping and continuing..."
)
return e.response
else:
# Raise the error for non-transient errors
raise
# except SlackApiError as e:
# last_exception = e
# try:
# error = e.response["error"]
# except KeyError:
# error = "unknown error"
# If the code reaches this point, all retries have been exhausted
msg = f"Max retries ({max_retries}) exceeded"
if last_exception:
raise Exception(msg) from last_exception
else:
raise Exception(msg)
# if error == "ratelimited":
# # Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
# retry_after = int(e.response.headers.get("Retry-After", 1))
# logger.info(
# f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
# )
# time.sleep(retry_after)
# elif error in ["already_reacted", "no_reaction", "internal_error"]:
# # Log internal_error and return the response instead of failing
# logger.warning(
# f"Slack call encountered '{error}', skipping and continuing..."
# )
# return e.response
# else:
# # Raise the error for non-transient errors
# raise
return rate_limited_call
# # If the code reaches this point, all retries have been exhausted
# msg = f"Max retries ({max_retries}) exceeded"
# if last_exception:
# raise Exception(msg) from last_exception
# else:
# raise Exception(msg)
# return rate_limited_call
def make_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return basic_retry_wrapper(make_slack_api_rate_limited(call))(**kwargs)
return basic_retry_wrapper(call)(**kwargs)
def make_paginated_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
return _make_slack_api_call_paginated(
basic_retry_wrapper(make_slack_api_rate_limited(call))
)(**kwargs)
return _make_slack_api_call_paginated(basic_retry_wrapper(call))(**kwargs)
def expert_info_from_slack_id(
@@ -142,7 +142,7 @@ def expert_info_from_slack_id(
if user_id in user_cache:
return user_cache[user_id]
response = make_slack_api_rate_limited(client.users_info)(user=user_id)
response = client.users_info(user=user_id)
if not response["ok"]:
user_cache[user_id] = None
@@ -175,9 +175,7 @@ class SlackTextCleaner:
def _get_slack_name(self, user_id: str) -> str:
if user_id not in self._id_to_name_map:
try:
response = make_slack_api_rate_limited(self._client.users_info)(
user=user_id
)
response = self._client.users_info(user=user_id)
# prefer display name if set, since that is what is shown in Slack
self._id_to_name_map[user_id] = (
response["user"]["profile"]["display_name"]

View File

@@ -60,7 +60,7 @@ class SearchSettingsCreationRequest(InferenceSettings, IndexingSetting):
inference_settings = InferenceSettings.from_db_model(search_settings)
indexing_setting = IndexingSetting.from_db_model(search_settings)
return cls(**inference_settings.dict(), **indexing_setting.dict())
return cls(**inference_settings.model_dump(), **indexing_setting.model_dump())
class SavedSearchSettings(InferenceSettings, IndexingSetting):
@@ -80,6 +80,9 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
reduced_dimension=search_settings.reduced_dimension,
# Whether switching to this model requires re-indexing
background_reindex_enabled=search_settings.background_reindex_enabled,
enable_contextual_rag=search_settings.enable_contextual_rag,
contextual_rag_llm_name=search_settings.contextual_rag_llm_name,
contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider,
# Reranking Details
rerank_model_name=search_settings.rerank_model_name,
rerank_provider_type=search_settings.rerank_provider_type,
@@ -102,6 +105,8 @@ class BaseFilters(BaseModel):
document_set: list[str] | None = None
time_cutoff: datetime | None = None
tags: list[Tag] | None = None
user_file_ids: list[int] | None = None
user_folder_ids: list[int] | None = None
class IndexFilters(BaseFilters):
@@ -218,6 +223,8 @@ class InferenceChunk(BaseChunk):
# to specify that a set of words should be highlighted. For example:
# ["<hi>the</hi> <hi>answer</hi> is 42", "he couldn't find an <hi>answer</hi>"]
match_highlights: list[str]
doc_summary: str
chunk_context: str
# when the doc was last updated
updated_at: datetime | None

View File

@@ -158,6 +158,47 @@ class SearchPipeline:
return cast(list[InferenceChunk], self._retrieved_chunks)
def get_ordering_only_chunks(
self,
query: str,
user_file_ids: list[int] | None = None,
user_folder_ids: list[int] | None = None,
) -> list[InferenceChunk]:
"""Optimized method that only retrieves chunks for ordering purposes.
Skips all extra processing and uses minimal configuration to speed up retrieval.
"""
logger.info("Fast path: Using optimized chunk retrieval for ordering-only mode")
# Create minimal filters with just user file/folder IDs
filters = IndexFilters(
user_file_ids=user_file_ids or [],
user_folder_ids=user_folder_ids or [],
access_control_list=None,
)
# Use a simplified query that skips all unnecessary processing
minimal_query = SearchQuery(
query=query,
search_type=SearchType.SEMANTIC,
filters=filters,
# Set minimal options needed for retrieval
evaluation_type=LLMEvaluationType.SKIP,
recency_bias_multiplier=1.0,
chunks_above=0, # No need for surrounding context
chunks_below=0, # No need for surrounding context
processed_keywords=[], # Empty list instead of None
rerank_settings=None,
hybrid_alpha=0.0,
max_llm_filter_sections=0,
)
# Retrieve chunks using the minimal configuration
return retrieve_chunks(
query=minimal_query,
document_index=self.document_index,
db_session=self.db_session,
)
@log_function_time(print_only=True)
def _get_sections(self) -> list[InferenceSection]:
"""Returns an expanded section from each of the chunks.
@@ -339,6 +380,12 @@ class SearchPipeline:
self._retrieved_sections = self._get_sections()
return self._retrieved_sections
@property
def merged_retrieved_sections(self) -> list[InferenceSection]:
"""Should be used to display in the UI in order to prevent displaying
multiple sections for the same document as separate "documents"."""
return _merge_sections(sections=self.retrieved_sections)
@property
def reranked_sections(self) -> list[InferenceSection]:
"""Reranking is always done at the chunk level since section merging could create arbitrarily
@@ -385,6 +432,10 @@ class SearchPipeline:
self.search_query.evaluation_type == LLMEvaluationType.SKIP
or DISABLE_LLM_DOC_RELEVANCE
):
if self.search_query.evaluation_type == LLMEvaluationType.SKIP:
logger.info(
"Fast path: Skipping section relevance evaluation for ordering-only mode"
)
return None
if self.search_query.evaluation_type == LLMEvaluationType.UNSPECIFIED:
@@ -415,6 +466,10 @@ class SearchPipeline:
raise ValueError(
"Basic search evaluation operation called while DISABLE_LLM_DOC_RELEVANCE is enabled."
)
# NOTE: final_context_sections must be accessed before accessing self._postprocessing_generator
# since the property sets the generator. DO NOT REMOVE.
_ = self.final_context_sections
self._section_relevance = next(
cast(
Iterator[list[SectionRelevancePiece]],

View File

@@ -11,6 +11,7 @@ from langchain_core.messages import SystemMessage
from onyx.chat.models import SectionRelevancePiece
from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.app_configs import IMAGE_ANALYSIS_SYSTEM_PROMPT
from onyx.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.configs.llm_configs import get_search_time_image_analysis_enabled
from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MAX
@@ -196,9 +197,21 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
RETURN_SEPARATOR
)
def _remove_contextual_rag(chunk: InferenceChunkUncleaned) -> str:
# remove document summary
if chunk.content.startswith(chunk.doc_summary):
chunk.content = chunk.content[len(chunk.doc_summary) :].lstrip()
# remove chunk context
if chunk.content.endswith(chunk.chunk_context):
chunk.content = chunk.content[
: len(chunk.content) - len(chunk.chunk_context)
].rstrip()
return chunk.content
for chunk in chunks:
chunk.content = _remove_title(chunk)
chunk.content = _remove_metadata_suffix(chunk)
chunk.content = _remove_contextual_rag(chunk)
return [chunk.to_inference_chunk() for chunk in chunks]
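The prefix/suffix stripping done by `_remove_contextual_rag` can be sketched as a pure function (empty-string guards are added here for the standalone sketch, since `"".startswith`/`endswith` are always true):

```python
def strip_contextual_rag(content: str, doc_summary: str, chunk_context: str) -> str:
    # Remove the document summary prepended to the chunk, if present
    if doc_summary and content.startswith(doc_summary):
        content = content[len(doc_summary):].lstrip()
    # Remove the chunk context appended to the chunk, if present
    if chunk_context and content.endswith(chunk_context):
        content = content[: len(content) - len(chunk_context)].rstrip()
    return content
```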
@@ -354,6 +367,21 @@ def filter_sections(
Returns a list of the unique chunk IDs that were marked as relevant
"""
# Log evaluation type to help with debugging
logger.info(f"filter_sections called with evaluation_type={query.evaluation_type}")
# Fast path: immediately return empty list for SKIP evaluation type (ordering-only mode).
# NOTE: this makes any later SKIP check in this function unreachable.
if query.evaluation_type == LLMEvaluationType.SKIP:
return []
sections_to_filter = sections_to_filter[: query.max_llm_filter_sections]
contents = [
@@ -386,6 +414,16 @@ def search_postprocessing(
llm: LLM,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> Iterator[list[InferenceSection] | list[SectionRelevancePiece]]:
# Fast path for ordering-only: detect it by checking if evaluation_type is SKIP
if search_query.evaluation_type == LLMEvaluationType.SKIP:
logger.info(
"Fast path: Detected ordering-only mode, bypassing all post-processing"
)
# Immediately yield the sections without any processing and an empty relevance list
yield retrieved_sections
yield cast(list[SectionRelevancePiece], [])
return
post_processing_tasks: list[FunctionCall] = []
if not retrieved_sections:
@@ -422,10 +460,14 @@ def search_postprocessing(
sections_yielded = True
llm_filter_task_id = None
if search_query.evaluation_type in [
LLMEvaluationType.BASIC,
LLMEvaluationType.UNSPECIFIED,
]:
# Only add LLM filtering if not in SKIP mode and if LLM doc relevance is not disabled
if (
search_query.evaluation_type not in [LLMEvaluationType.SKIP]
and not DISABLE_LLM_DOC_RELEVANCE
and search_query.evaluation_type
in [LLMEvaluationType.BASIC, LLMEvaluationType.UNSPECIFIED]
):
logger.info("Adding LLM filtering task for document relevance evaluation")
post_processing_tasks.append(
FunctionCall(
filter_sections,
@@ -437,6 +479,10 @@ def search_postprocessing(
)
)
llm_filter_task_id = post_processing_tasks[-1].result_id
elif search_query.evaluation_type == LLMEvaluationType.SKIP:
logger.info("Fast path: Skipping LLM filtering task for ordering-only mode")
elif DISABLE_LLM_DOC_RELEVANCE:
logger.info("Skipping LLM filtering task because LLM doc relevance is disabled")
post_processing_results = (
run_functions_in_parallel(post_processing_tasks)

View File

@@ -165,7 +165,18 @@ def retrieval_preprocessing(
user_acl_filters = (
None if bypass_acl else build_access_filters_for_user(user, db_session)
)
user_file_ids = preset_filters.user_file_ids or []
user_folder_ids = preset_filters.user_folder_ids or []
if persona and persona.user_files:
user_file_ids = user_file_ids + [
file.id
for file in persona.user_files
if file.id not in (preset_filters.user_file_ids or [])
]
final_filters = IndexFilters(
user_file_ids=user_file_ids,
user_folder_ids=user_folder_ids,
source_type=preset_filters.source_type or predicted_source_filters,
document_set=preset_filters.document_set,
time_cutoff=time_filter or predicted_time_cutoff,

View File

@@ -26,6 +26,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import (
from onyx.auth.schemas import UserRole
from onyx.chat.models import DocumentRelevance
from onyx.configs.chat_configs import HARD_DELETE_CHATS
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDocs
@@ -44,9 +45,11 @@ from onyx.db.models import SearchDoc
from onyx.db.models import SearchDoc as DBSearchDoc
from onyx.db.models import ToolCall
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.persona import get_best_persona_id_for_user
from onyx.db.pg_file_store import delete_lobj_by_name
from onyx.file_store.models import FileDescriptor
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.server.query_and_chat.models import ChatMessageDetail
@@ -854,6 +857,87 @@ def get_db_search_doc_by_id(doc_id: int, db_session: Session) -> DBSearchDoc | N
return search_doc
def create_search_doc_from_user_file(
db_user_file: UserFile, associated_chat_file: InMemoryChatFile, db_session: Session
) -> SearchDoc:
"""Create a SearchDoc in the database from a UserFile and return it.
This ensures proper ID generation by SQLAlchemy and prevents duplicate key errors.
"""
blurb = ""
if associated_chat_file and associated_chat_file.content:
try:
# Try to decode as UTF-8, but handle errors gracefully
content_sample = associated_chat_file.content[:100]
# Remove null bytes which can cause SQL errors
content_sample = content_sample.replace(b"\x00", b"")
blurb = content_sample.decode("utf-8", errors="replace")
except Exception:
# If decoding fails completely, provide a generic description
blurb = f"[Binary file: {db_user_file.name}]"
db_search_doc = SearchDoc(
document_id=db_user_file.document_id,
chunk_ind=0, # Default to 0 for user files
semantic_id=db_user_file.name,
link=db_user_file.link_url,
blurb=blurb,
source_type=DocumentSource.FILE, # Assuming internal source for user files
boost=0, # Default boost
hidden=False, # Default visibility
doc_metadata={}, # Empty metadata
score=0.0, # Default score of 0.0 instead of None
is_relevant=None, # No relevance initially
relevance_explanation=None, # No explanation initially
match_highlights=[], # No highlights initially
updated_at=db_user_file.created_at, # Use created_at as updated_at
primary_owners=[], # Empty list instead of None
secondary_owners=[], # Empty list instead of None
is_internet=False, # Not from internet
)
db_session.add(db_search_doc)
db_session.flush() # Get the ID but don't commit yet
return db_search_doc
def translate_db_user_file_to_search_doc(
db_user_file: UserFile, associated_chat_file: InMemoryChatFile
) -> SearchDoc:
blurb = ""
if associated_chat_file and associated_chat_file.content:
try:
# Try to decode as UTF-8, but handle errors gracefully
content_sample = associated_chat_file.content[:100]
# Remove null bytes which can cause SQL errors
content_sample = content_sample.replace(b"\x00", b"")
blurb = content_sample.decode("utf-8", errors="replace")
except Exception:
# If decoding fails completely, provide a generic description
blurb = f"[Binary file: {db_user_file.name}]"
return SearchDoc(
# Don't set ID - let SQLAlchemy auto-generate it
document_id=db_user_file.document_id,
chunk_ind=0, # Default to 0 for user files
semantic_id=db_user_file.name,
link=db_user_file.link_url,
blurb=blurb,
source_type=DocumentSource.FILE, # Assuming internal source for user files
boost=0, # Default boost
hidden=False, # Default visibility
doc_metadata={}, # Empty metadata
score=0.0, # Default score of 0.0 instead of None
is_relevant=None, # No relevance initially
relevance_explanation=None, # No explanation initially
match_highlights=[], # No highlights initially
updated_at=db_user_file.created_at, # Use created_at as updated_at
primary_owners=[], # Empty list instead of None
secondary_owners=[], # Empty list instead of None
is_internet=False, # Not from internet
)
def translate_db_search_doc_to_server_search_doc(
db_search_doc: SearchDoc,
remove_doc_content: bool = False,
@@ -1089,3 +1173,20 @@ def log_agent_sub_question_results(
db_session.commit()
return None
def update_chat_session_updated_at_timestamp(
chat_session_id: UUID, db_session: Session
) -> None:
"""
Explicitly update the timestamp on a chat session without modifying other fields.
This is useful when adding messages to a chat session to reflect recent activity.
"""
# Direct SQL update to avoid loading the entire object if it's not already loaded
db_session.execute(
update(ChatSession)
.where(ChatSession.id == chat_session_id)
.values(time_updated=func.now())
)
# No commit - the caller is responsible for committing the transaction
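The blurb sanitization duplicated in `create_search_doc_from_user_file` and `translate_db_user_file_to_search_doc` above can be sketched standalone: take a short sample, drop null bytes (which break Postgres text columns), and decode with replacement so binary content never raises.

```python
def safe_blurb(content: bytes, fallback_name: str) -> str:
    """Build a display blurb from raw file bytes, mirroring the helpers above."""
    try:
        sample = content[:100]
        sample = sample.replace(b"\x00", b"")  # null bytes cause SQL errors
        return sample.decode("utf-8", errors="replace")
    except Exception:
        # errors="replace" makes decode failures rare; this is a last resort
        return f"[Binary file: {fallback_name}]"
```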

View File

@@ -27,6 +27,7 @@ from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserFile
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.db.models import UserRole
from onyx.server.models import StatusResponse
@@ -106,11 +107,13 @@ def get_connector_credential_pairs_for_user(
eager_load_connector: bool = False,
eager_load_credential: bool = False,
eager_load_user: bool = False,
include_user_files: bool = False,
) -> list[ConnectorCredentialPair]:
if eager_load_user:
assert (
eager_load_credential
), "eager_load_credential must be True if eager_load_user is True"
stmt = select(ConnectorCredentialPair).distinct()
if eager_load_connector:
@@ -126,6 +129,9 @@ def get_connector_credential_pairs_for_user(
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
if not include_user_files:
stmt = stmt.where(ConnectorCredentialPair.is_user_file != True) # noqa: E712
return list(db_session.scalars(stmt).unique().all())
@@ -153,14 +159,16 @@ def get_connector_credential_pairs_for_user_parallel(
def get_connector_credential_pairs(
db_session: Session,
ids: list[int] | None = None,
db_session: Session, ids: list[int] | None = None, include_user_files: bool = False
) -> list[ConnectorCredentialPair]:
stmt = select(ConnectorCredentialPair).distinct()
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
if not include_user_files:
stmt = stmt.where(ConnectorCredentialPair.is_user_file != True) # noqa: E712
return list(db_session.scalars(stmt).all())
@@ -207,12 +215,15 @@ def get_connector_credential_pair_for_user(
connector_id: int,
credential_id: int,
user: User | None,
include_user_files: bool = False,
get_editable: bool = True,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair)
stmt = _add_user_filters(stmt, user, get_editable)
stmt = stmt.where(ConnectorCredentialPair.connector_id == connector_id)
stmt = stmt.where(ConnectorCredentialPair.credential_id == credential_id)
if not include_user_files:
stmt = stmt.where(ConnectorCredentialPair.is_user_file != True) # noqa: E712
result = db_session.execute(stmt)
return result.scalar_one_or_none()
@@ -321,6 +332,9 @@ def _update_connector_credential_pair(
cc_pair.total_docs_indexed += net_docs
if status is not None:
cc_pair.status = status
if cc_pair.is_user_file:
cc_pair.status = ConnectorCredentialPairStatus.PAUSED
db_session.commit()
@@ -446,6 +460,7 @@ def add_credential_to_connector(
initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.ACTIVE,
last_successful_index_time: datetime | None = None,
seeding_flow: bool = False,
is_user_file: bool = False,
) -> StatusResponse:
connector = fetch_connector_by_id(connector_id, db_session)
@@ -511,6 +526,7 @@ def add_credential_to_connector(
access_type=access_type,
auto_sync_options=auto_sync_options,
last_successful_index_time=last_successful_index_time,
is_user_file=is_user_file,
)
db_session.add(association)
db_session.flush() # make sure the association has an id
@@ -587,8 +603,12 @@ def remove_credential_from_connector(
def fetch_connector_credential_pairs(
db_session: Session,
include_user_files: bool = False,
) -> list[ConnectorCredentialPair]:
return db_session.query(ConnectorCredentialPair).all()
stmt = select(ConnectorCredentialPair)
if not include_user_files:
stmt = stmt.where(ConnectorCredentialPair.is_user_file != True) # noqa: E712
return list(db_session.scalars(stmt).unique().all())
def resync_cc_pair(
@@ -634,3 +654,23 @@ def resync_cc_pair(
)
db_session.commit()
def get_connector_credential_pairs_with_user_files(
db_session: Session,
) -> list[ConnectorCredentialPair]:
"""
Get all connector credential pairs that have associated user files.
Args:
db_session: Database session
Returns:
List of ConnectorCredentialPair objects that have user files
"""
return (
db_session.query(ConnectorCredentialPair)
.join(UserFile, UserFile.cc_pair_id == ConnectorCredentialPair.id)
.distinct()
.all()
)

View File

@@ -605,7 +605,6 @@ def fetch_document_sets_for_document(
result = fetch_document_sets_for_documents([document_id], db_session)
if not result:
return []
return result[0][1]

View File

@@ -212,6 +212,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
back_populates="creator",
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
)
folders: Mapped[list["UserFolder"]] = relationship(
"UserFolder", back_populates="user"
)
files: Mapped[list["UserFile"]] = relationship("UserFile", back_populates="user")
@validates("email")
def validate_email(self, key: str, value: str) -> str:
@@ -419,6 +423,7 @@ class ConnectorCredentialPair(Base):
"""
__tablename__ = "connector_credential_pair"
is_user_file: Mapped[bool] = mapped_column(Boolean, default=False)
# NOTE: this `id` column has to use `Sequence` instead of `autoincrement=True`
# due to some SQLAlchemy quirks + this not being a primary key column
id: Mapped[int] = mapped_column(
@@ -505,6 +510,10 @@ class ConnectorCredentialPair(Base):
primaryjoin="foreign(ConnectorCredentialPair.creator_id) == remote(User.id)",
)
user_file: Mapped["UserFile"] = relationship(
"UserFile", back_populates="cc_pair", uselist=False
)
background_errors: Mapped[list["BackgroundError"]] = relationship(
"BackgroundError", back_populates="cc_pair", cascade="all, delete-orphan"
)
@@ -791,6 +800,15 @@ class SearchSettings(Base):
# Mini and Large Chunks (large chunk also checks for model max context)
multipass_indexing: Mapped[bool] = mapped_column(Boolean, default=True)
# Contextual RAG
enable_contextual_rag: Mapped[bool] = mapped_column(Boolean, default=False)
# Contextual RAG LLM
contextual_rag_llm_name: Mapped[str | None] = mapped_column(String, nullable=True)
contextual_rag_llm_provider: Mapped[str | None] = mapped_column(
String, nullable=True
)
multilingual_expansion: Mapped[list[str]] = mapped_column(
postgresql.ARRAY(String), default=[]
)
@@ -1799,6 +1817,17 @@ class Persona(Base):
secondary="persona__user_group",
viewonly=True,
)
# Relationship to UserFile
user_files: Mapped[list["UserFile"]] = relationship(
"UserFile",
secondary="persona__user_file",
back_populates="assistants",
)
user_folders: Mapped[list["UserFolder"]] = relationship(
"UserFolder",
secondary="persona__user_folder",
back_populates="assistants",
)
labels: Mapped[list["PersonaLabel"]] = relationship(
"PersonaLabel",
secondary=Persona__PersonaLabel.__table__,
@@ -1815,6 +1844,24 @@ class Persona(Base):
)
class Persona__UserFolder(Base):
__tablename__ = "persona__user_folder"
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
user_folder_id: Mapped[int] = mapped_column(
ForeignKey("user_folder.id"), primary_key=True
)
class Persona__UserFile(Base):
__tablename__ = "persona__user_file"
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
user_file_id: Mapped[int] = mapped_column(
ForeignKey("user_file.id"), primary_key=True
)
class PersonaLabel(Base):
__tablename__ = "persona_label"
@@ -2337,6 +2384,64 @@ class InputPrompt__User(Base):
disabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
class UserFolder(Base):
__tablename__ = "user_folder"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=False)
name: Mapped[str] = mapped_column(nullable=False)
description: Mapped[str] = mapped_column(nullable=False)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
user: Mapped["User"] = relationship(back_populates="folders")
files: Mapped[list["UserFile"]] = relationship(back_populates="folder")
assistants: Mapped[list["Persona"]] = relationship(
"Persona",
secondary=Persona__UserFolder.__table__,
back_populates="user_folders",
)
class UserDocument(str, Enum):
CHAT = "chat"
RECENT = "recent"
FILE = "file"
class UserFile(Base):
__tablename__ = "user_file"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=False)
assistants: Mapped[list["Persona"]] = relationship(
"Persona",
secondary=Persona__UserFile.__table__,
back_populates="user_files",
)
folder_id: Mapped[int | None] = mapped_column(
ForeignKey("user_folder.id"), nullable=True
)
file_id: Mapped[str] = mapped_column(nullable=False)
document_id: Mapped[str] = mapped_column(nullable=False)
name: Mapped[str] = mapped_column(nullable=False)
created_at: Mapped[datetime.datetime] = mapped_column(
default=datetime.datetime.utcnow
)
user: Mapped["User"] = relationship(back_populates="files")
folder: Mapped["UserFolder"] = relationship(back_populates="files")
token_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
cc_pair_id: Mapped[int | None] = mapped_column(
ForeignKey("connector_credential_pair.id"), nullable=True, unique=True
)
cc_pair: Mapped["ConnectorCredentialPair"] = relationship(
"ConnectorCredentialPair", back_populates="user_file"
)
link_url: Mapped[str | None] = mapped_column(String, nullable=True)
"""
Multi-tenancy related tables
"""

View File

@@ -33,6 +33,8 @@ from onyx.db.models import StarterMessage
from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserFile
from onyx.db.models import UserFolder
from onyx.db.models import UserGroup
from onyx.db.notification import create_notification
from onyx.server.features.persona.models import PersonaSharedNotificationData
@@ -209,7 +211,6 @@ def create_update_persona(
if not all_prompt_ids:
raise ValueError("No prompt IDs provided")
is_default_persona: bool | None = create_persona_request.is_default_persona
# Default persona validation
if create_persona_request.is_default_persona:
if not create_persona_request.is_public:
@@ -221,7 +222,7 @@ def create_update_persona(
user.role == UserRole.CURATOR
or user.role == UserRole.GLOBAL_CURATOR
):
is_default_persona = None
pass
elif user.role != UserRole.ADMIN:
raise ValueError("Only admins can make a default persona")
@@ -249,7 +250,9 @@ def create_update_persona(
num_chunks=create_persona_request.num_chunks,
llm_relevance_filter=create_persona_request.llm_relevance_filter,
llm_filter_extraction=create_persona_request.llm_filter_extraction,
is_default_persona=is_default_persona,
is_default_persona=create_persona_request.is_default_persona,
user_file_ids=create_persona_request.user_file_ids,
user_folder_ids=create_persona_request.user_folder_ids,
)
versioned_make_persona_private = fetch_versioned_implementation(
@@ -344,6 +347,8 @@ def get_personas_for_user(
selectinload(Persona.groups),
selectinload(Persona.users),
selectinload(Persona.labels),
selectinload(Persona.user_files),
selectinload(Persona.user_folders),
)
results = db_session.execute(stmt).scalars().all()
@@ -438,6 +443,8 @@ def upsert_persona(
builtin_persona: bool = False,
is_default_persona: bool | None = None,
label_ids: list[int] | None = None,
user_file_ids: list[int] | None = None,
user_folder_ids: list[int] | None = None,
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
chunks_below: int = CONTEXT_CHUNKS_BELOW,
) -> Persona:
@@ -463,6 +470,7 @@ def upsert_persona(
user=user,
get_editable=True,
)
# Fetch and attach tools by IDs
tools = None
if tool_ids is not None:
@@ -481,6 +489,26 @@ def upsert_persona(
if not document_sets and document_set_ids:
raise ValueError("document_sets not found")
# Fetch and attach user_files by IDs
user_files = None
if user_file_ids is not None:
user_files = (
db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).all()
)
if not user_files and user_file_ids:
raise ValueError("user_files not found")
# Fetch and attach user_folders by IDs
user_folders = None
if user_folder_ids is not None:
user_folders = (
db_session.query(UserFolder)
.filter(UserFolder.id.in_(user_folder_ids))
.all()
)
if not user_folders and user_folder_ids:
raise ValueError("user_folders not found")
# Fetch and attach prompts by IDs
prompts = None
if prompt_ids is not None:
@@ -549,6 +577,14 @@ def upsert_persona(
if tools is not None:
existing_persona.tools = tools or []
if user_file_ids is not None:
existing_persona.user_files.clear()
existing_persona.user_files = user_files or []
if user_folder_ids is not None:
existing_persona.user_folders.clear()
existing_persona.user_folders = user_folders or []
# We should only update display priority if it is not already set
if existing_persona.display_priority is None:
existing_persona.display_priority = display_priority
@@ -590,6 +626,8 @@ def upsert_persona(
is_default_persona=is_default_persona
if is_default_persona is not None
else False,
user_folders=user_folders or [],
user_files=user_files or [],
labels=labels or [],
)
db_session.add(new_persona)

View File

@@ -62,6 +62,9 @@ def create_search_settings(
multipass_indexing=search_settings.multipass_indexing,
embedding_precision=search_settings.embedding_precision,
reduced_dimension=search_settings.reduced_dimension,
enable_contextual_rag=search_settings.enable_contextual_rag,
contextual_rag_llm_name=search_settings.contextual_rag_llm_name,
contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider,
multilingual_expansion=search_settings.multilingual_expansion,
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
rerank_model_name=search_settings.rerank_model_name,
@@ -319,6 +322,7 @@ def get_old_default_embedding_model() -> IndexingSetting:
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
index_name="danswer_chunk",
multipass_indexing=False,
enable_contextual_rag=False,
api_url=None,
)
@@ -333,5 +337,6 @@ def get_new_default_embedding_model() -> IndexingSetting:
passage_prefix=ASYM_PASSAGE_PREFIX,
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
multipass_indexing=False,
enable_contextual_rag=False,
api_url=None,
)

View File

@@ -0,0 +1,466 @@
import datetime
import time
from typing import List
from uuid import UUID
from fastapi import UploadFile
from sqlalchemy import and_
from sqlalchemy import func
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.auth.users import get_current_tenant_id
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import InputType
from onyx.db.connector import create_connector
from onyx.db.connector_credential_pair import add_credential_to_connector
from onyx.db.credentials import create_credential
from onyx.db.enums import AccessType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import Persona
from onyx.db.models import Persona__UserFile
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.models import UserFolder
from onyx.server.documents.connector import trigger_indexing_for_cc_pair
from onyx.server.documents.connector import upload_files
from onyx.server.documents.models import ConnectorBase
from onyx.server.documents.models import CredentialBase
from onyx.server.models import StatusResponse
USER_FILE_CONSTANT = "USER_FILE_CONNECTOR"
def create_user_files(
files: List[UploadFile],
folder_id: int | None,
user: User | None,
db_session: Session,
link_url: str | None = None,
) -> list[UserFile]:
upload_response = upload_files(files, db_session)
user_files = []
for file_path, file in zip(upload_response.file_paths, files):
new_file = UserFile(
user_id=user.id if user else None,
folder_id=folder_id,
file_id=file_path,
document_id="USER_FILE_CONNECTOR__" + file_path,
name=file.filename,
token_count=None,
link_url=link_url,
)
db_session.add(new_file)
user_files.append(new_file)
db_session.commit()
return user_files
def create_user_file_with_indexing(
files: List[UploadFile],
folder_id: int | None,
user: User,
db_session: Session,
trigger_index: bool = True,
) -> list[UserFile]:
"""Create user files and trigger immediate indexing"""
# Create the user files first
user_files = create_user_files(files, folder_id, user, db_session)
# Create connector and credential for each file
for user_file in user_files:
cc_pair = create_file_connector_credential(user_file, user, db_session)
user_file.cc_pair_id = cc_pair.data
db_session.commit()
# Trigger immediate high-priority indexing for all created files
if trigger_index:
tenant_id = get_current_tenant_id()
for user_file in user_files:
# Use the existing trigger_indexing_for_cc_pair function but with highest priority
if user_file.cc_pair_id:
trigger_indexing_for_cc_pair(
[],
user_file.cc_pair.connector_id,
False,
tenant_id,
db_session,
is_user_file=True,
)
return user_files
def create_file_connector_credential(
user_file: UserFile, user: User, db_session: Session
) -> StatusResponse:
"""Create connector and credential for a user file"""
connector_base = ConnectorBase(
name=f"UserFile-{user_file.file_id}-{int(time.time())}",
source=DocumentSource.FILE,
input_type=InputType.LOAD_STATE,
connector_specific_config={
"file_locations": [user_file.file_id],
},
refresh_freq=None,
prune_freq=None,
indexing_start=None,
)
connector = create_connector(db_session=db_session, connector_data=connector_base)
credential_info = CredentialBase(
credential_json={},
admin_public=True,
source=DocumentSource.FILE,
curator_public=True,
groups=[],
name=f"UserFileCredential-{user_file.file_id}-{int(time.time())}",
is_user_file=True,
)
credential = create_credential(credential_info, user, db_session)
return add_credential_to_connector(
db_session=db_session,
user=user,
connector_id=connector.id,
credential_id=credential.id,
cc_pair_name=f"UserFileCCPair-{user_file.file_id}-{int(time.time())}",
access_type=AccessType.PRIVATE,
auto_sync_options=None,
groups=[],
is_user_file=True,
)
def get_user_file_indexing_status(
file_ids: list[int], db_session: Session
) -> dict[int, bool]:
"""Get indexing status for multiple user files"""
status_dict = {}
# Query UserFile with cc_pair join
files_with_pairs = (
db_session.query(UserFile)
.filter(UserFile.id.in_(file_ids))
.options(joinedload(UserFile.cc_pair))
.all()
)
for file in files_with_pairs:
if file.cc_pair and file.cc_pair.last_successful_index_time:
status_dict[file.id] = True
else:
status_dict[file.id] = False
return status_dict
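`get_user_file_indexing_status` reduces each file to a boolean: indexed iff its cc-pair exists and has a `last_successful_index_time`. A sketch of that reduction with dicts standing in for the ORM rows (illustrative only):

```python
def indexing_status(files: list[dict]) -> dict[int, bool]:
    # A file counts as indexed once its cc_pair has completed a successful run.
    return {
        f["id"]: bool(f.get("cc_pair") and f["cc_pair"].get("last_successful_index_time"))
        for f in files
    }
```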
def calculate_user_files_token_count(
file_ids: list[int], folder_ids: list[int], db_session: Session
) -> int:
"""Calculate total token count for specified files and folders"""
total_tokens = 0
# Get tokens from individual files
if file_ids:
file_tokens = (
db_session.query(func.sum(UserFile.token_count))
.filter(UserFile.id.in_(file_ids))
.scalar()
or 0
)
total_tokens += file_tokens
# Get tokens from folders
if folder_ids:
folder_files_tokens = (
db_session.query(func.sum(UserFile.token_count))
.filter(UserFile.folder_id.in_(folder_ids))
.scalar()
or 0
)
total_tokens += folder_files_tokens
return total_tokens
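`calculate_user_files_token_count` runs two independent `SUM()` queries (by file id, then by folder id) and falls back to `0` when the scalar is `NULL`. A pure-Python sketch of that arithmetic; note that, like the original's two separate queries, a file selected both directly and via its folder is counted twice:

```python
def total_token_count(files: list[dict], file_ids: list[int], folder_ids: list[int]) -> int:
    # First query: files selected directly by id (None token counts contribute 0).
    total = sum((f["token_count"] or 0) for f in files if f["id"] in file_ids)
    # Second query: files selected via their folder - summed separately,
    # so overlap with the first set is double-counted, as in the SQL version.
    total += sum((f["token_count"] or 0) for f in files if f["folder_id"] in folder_ids)
    return total
```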
def load_all_user_files(
file_ids: list[int], folder_ids: list[int], db_session: Session
) -> list[UserFile]:
"""Load all user files from specified file IDs and folder IDs"""
result = []
# Get individual files
if file_ids:
files = db_session.query(UserFile).filter(UserFile.id.in_(file_ids)).all()
result.extend(files)
# Get files from folders
if folder_ids:
folder_files = (
db_session.query(UserFile).filter(UserFile.folder_id.in_(folder_ids)).all()
)
result.extend(folder_files)
return result
def get_user_files_from_folder(folder_id: int, db_session: Session) -> list[UserFile]:
return db_session.query(UserFile).filter(UserFile.folder_id == folder_id).all()
def share_file_with_assistant(
file_id: int, assistant_id: int, db_session: Session
) -> None:
file = db_session.query(UserFile).filter(UserFile.id == file_id).first()
assistant = db_session.query(Persona).filter(Persona.id == assistant_id).first()
if file and assistant:
file.assistants.append(assistant)
db_session.commit()
def unshare_file_with_assistant(
file_id: int, assistant_id: int, db_session: Session
) -> None:
db_session.query(Persona__UserFile).filter(
and_(
Persona__UserFile.user_file_id == file_id,
Persona__UserFile.persona_id == assistant_id,
)
).delete()
db_session.commit()
def share_folder_with_assistant(
folder_id: int, assistant_id: int, db_session: Session
) -> None:
folder = db_session.query(UserFolder).filter(UserFolder.id == folder_id).first()
assistant = db_session.query(Persona).filter(Persona.id == assistant_id).first()
if folder and assistant:
for file in folder.files:
share_file_with_assistant(file.id, assistant_id, db_session)
def unshare_folder_with_assistant(
folder_id: int, assistant_id: int, db_session: Session
) -> None:
folder = db_session.query(UserFolder).filter(UserFolder.id == folder_id).first()
if folder:
for file in folder.files:
unshare_file_with_assistant(file.id, assistant_id, db_session)
def fetch_user_files_for_documents(
document_ids: list[str],
db_session: Session,
) -> dict[str, int | None]:
"""
Fetches user file IDs for the given document IDs.
Args:
document_ids: List of document IDs to fetch user files for
db_session: Database session
Returns:
Dictionary mapping document IDs to user file IDs (or None if no user file exists)
"""
# First, get the document to cc_pair mapping
doc_cc_pairs = (
db_session.query(Document.id, ConnectorCredentialPair.id)
.join(
DocumentByConnectorCredentialPair,
Document.id == DocumentByConnectorCredentialPair.id,
)
.join(
ConnectorCredentialPair,
and_(
DocumentByConnectorCredentialPair.connector_id
== ConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id
== ConnectorCredentialPair.credential_id,
),
)
.filter(Document.id.in_(document_ids))
.all()
)
# Get cc_pair to user_file mapping
cc_pair_to_user_file = (
db_session.query(ConnectorCredentialPair.id, UserFile.id)
.join(UserFile, UserFile.cc_pair_id == ConnectorCredentialPair.id)
.filter(
ConnectorCredentialPair.id.in_(
[cc_pair_id for _, cc_pair_id in doc_cc_pairs]
)
)
.all()
)
# Create mapping from cc_pair_id to user_file_id
cc_pair_to_user_file_dict = {
cc_pair_id: user_file_id for cc_pair_id, user_file_id in cc_pair_to_user_file
}
# Create the final result mapping document_id to user_file_id
result: dict[str, int | None] = {doc_id: None for doc_id in document_ids}
for doc_id, cc_pair_id in doc_cc_pairs:
if cc_pair_id in cc_pair_to_user_file_dict:
result[doc_id] = cc_pair_to_user_file_dict[cc_pair_id]
return result
def fetch_user_folders_for_documents(
document_ids: list[str],
db_session: Session,
) -> dict[str, int | None]:
"""
Fetches user folder IDs for the given document IDs.
For each document, returns the folder ID that the document's associated user file belongs to.
Args:
document_ids: List of document IDs to fetch user folders for
db_session: Database session
Returns:
Dictionary mapping document IDs to user folder IDs (or None if no user folder exists)
"""
# First, get the document to cc_pair mapping
doc_cc_pairs = (
db_session.query(Document.id, ConnectorCredentialPair.id)
.join(
DocumentByConnectorCredentialPair,
Document.id == DocumentByConnectorCredentialPair.id,
)
.join(
ConnectorCredentialPair,
and_(
DocumentByConnectorCredentialPair.connector_id
== ConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id
== ConnectorCredentialPair.credential_id,
),
)
.filter(Document.id.in_(document_ids))
.all()
)
# Get cc_pair to user_file and folder mapping
cc_pair_to_folder = (
db_session.query(ConnectorCredentialPair.id, UserFile.folder_id)
.join(UserFile, UserFile.cc_pair_id == ConnectorCredentialPair.id)
.filter(
ConnectorCredentialPair.id.in_(
[cc_pair_id for _, cc_pair_id in doc_cc_pairs]
)
)
.all()
)
# Create mapping from cc_pair_id to folder_id
cc_pair_to_folder_dict = {
cc_pair_id: folder_id for cc_pair_id, folder_id in cc_pair_to_folder
}
# Create the final result mapping document_id to folder_id
result: dict[str, int | None] = {doc_id: None for doc_id in document_ids}
for doc_id, cc_pair_id in doc_cc_pairs:
if cc_pair_id in cc_pair_to_folder_dict:
result[doc_id] = cc_pair_to_folder_dict[cc_pair_id]
return result
def get_user_file_from_id(db_session: Session, user_file_id: int) -> UserFile | None:
return db_session.query(UserFile).filter(UserFile.id == user_file_id).first()
# def fetch_user_files_for_documents(
# # document_ids: list[str],
# # db_session: Session,
# # ) -> dict[str, int | None]:
# # # Query UserFile objects for the given document_ids
# # user_files = (
# # db_session.query(UserFile).filter(UserFile.document_id.in_(document_ids)).all()
# # )
# # # Create a dictionary mapping document_ids to UserFile objects
# # result: dict[str, int | None] = {doc_id: None for doc_id in document_ids}
# # for user_file in user_files:
# # result[user_file.document_id] = user_file.id
# # return result
def upsert_user_folder(
db_session: Session,
id: int | None = None,
user_id: UUID | None = None,
name: str | None = None,
description: str | None = None,
created_at: datetime.datetime | None = None,
user: User | None = None,
files: list[UserFile] | None = None,
assistants: list[Persona] | None = None,
) -> UserFolder:
if id is not None:
user_folder = db_session.query(UserFolder).filter_by(id=id).first()
else:
user_folder = (
db_session.query(UserFolder).filter_by(name=name, user_id=user_id).first()
)
if user_folder:
if user_id is not None:
user_folder.user_id = user_id
if name is not None:
user_folder.name = name
if description is not None:
user_folder.description = description
if created_at is not None:
user_folder.created_at = created_at
if user is not None:
user_folder.user = user
if files is not None:
user_folder.files = files
if assistants is not None:
user_folder.assistants = assistants
else:
user_folder = UserFolder(
id=id,
user_id=user_id,
name=name,
description=description,
created_at=created_at or datetime.datetime.utcnow(),
user=user,
files=files or [],
assistants=assistants or [],
)
db_session.add(user_folder)
db_session.flush()
return user_folder
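`upsert_user_folder` applies the usual update-if-exists / insert-otherwise pattern, where only non-`None` arguments overwrite existing fields. A dict-backed sketch of that pattern (a stand-in for the session query, not the actual implementation):

```python
def upsert(store: dict[int, dict], key: int, **updates) -> dict:
    """Update-if-exists, insert-otherwise; None values never overwrite."""
    row = store.get(key)
    if row is not None:
        for field, value in updates.items():
            if value is not None:
                row[field] = value
    else:
        row = dict(updates)
        store[key] = row
    return row
```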
def get_user_folder_by_name(db_session: Session, name: str) -> UserFolder | None:
return db_session.query(UserFolder).filter(UserFolder.name == name).first()
def update_user_file_token_count__no_commit(
user_file_id_to_token_count: dict[int, int | None],
db_session: Session,
) -> None:
for user_file_id, token_count in user_file_id_to_token_count.items():
db_session.query(UserFile).filter(UserFile.id == user_file_id).update(
{UserFile.token_count: token_count}
)

View File

@@ -104,6 +104,16 @@ class VespaDocumentFields:
aggregated_chunk_boost_factor: float | None = None
@dataclass
class VespaDocumentUserFields:
"""
Fields that are specific to the user who is indexing the document.
"""
user_file_id: str | None = None
user_folder_id: str | None = None
@dataclass
class UpdateRequest:
"""
@@ -258,7 +268,8 @@ class Updatable(abc.ABC):
*,
tenant_id: str,
chunk_count: int | None,
fields: VespaDocumentFields,
fields: VespaDocumentFields | None,
user_fields: VespaDocumentUserFields | None,
) -> int:
"""
Updates all chunks for a document with the specified fields.

View File

@@ -98,6 +98,12 @@ schema DANSWER_CHUNK_NAME {
field metadata type string {
indexing: summary | attribute
}
field chunk_context type string {
indexing: summary | attribute
}
field doc_summary type string {
indexing: summary | attribute
}
field metadata_suffix type string {
indexing: summary | attribute
}
@@ -114,12 +120,22 @@ schema DANSWER_CHUNK_NAME {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
}
field document_sets type weightedset<string> {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
field user_file type int {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
field user_folder type int {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
}
# If using different tokenization settings, the fieldset has to be removed, and the field must

View File

@@ -24,9 +24,11 @@ from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import BLURB
from onyx.document_index.vespa_constants import BOOST
from onyx.document_index.vespa_constants import CHUNK_CONTEXT
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import CONTENT
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DOC_SUMMARY
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
@@ -126,7 +128,8 @@ def _vespa_hit_to_inference_chunk(
return InferenceChunkUncleaned(
chunk_id=fields[CHUNK_ID],
blurb=fields.get(BLURB, ""), # Unused
content=fields[CONTENT], # Includes extra title prefix and metadata suffix
content=fields[CONTENT], # Includes extra title prefix and metadata suffix;
# also sometimes context for contextual rag
source_links=source_links_dict or {0: ""},
section_continuation=fields[SECTION_CONTINUATION],
document_id=fields[DOCUMENT_ID],
@@ -143,6 +146,8 @@ def _vespa_hit_to_inference_chunk(
large_chunk_reference_ids=fields.get(LARGE_CHUNK_REFERENCE_IDS, []),
metadata=metadata,
metadata_suffix=fields.get(METADATA_SUFFIX),
doc_summary=fields.get(DOC_SUMMARY, ""),
chunk_context=fields.get(CHUNK_CONTEXT, ""),
match_highlights=match_highlights,
updated_at=updated_at,
)
@@ -345,6 +350,19 @@ def query_vespa(
filtered_hits = [hit for hit in hits if hit["fields"].get(CONTENT) is not None]
inference_chunks = [_vespa_hit_to_inference_chunk(hit) for hit in filtered_hits]
try:
num_retrieved_inference_chunks = len(inference_chunks)
num_retrieved_document_ids = len(
set([chunk.document_id for chunk in inference_chunks])
)
logger.debug(
f"Retrieved {num_retrieved_inference_chunks} inference chunks for {num_retrieved_document_ids} documents"
)
except Exception as e:
# Debug logging only, should not fail the retrieval
logger.error(f"Error logging retrieval statistics: {e}")
# Good Debugging Spot
return inference_chunks

View File

@@ -36,6 +36,7 @@ from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
from onyx.document_index.interfaces import UpdateRequest
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.document_index.vespa.chunk_retrieval import batch_search_api_retrieval
from onyx.document_index.vespa.chunk_retrieval import (
parallel_visit_api_retrieval,
@@ -70,6 +71,8 @@ from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import SEARCH_THREAD_NUMBER_PAT
from onyx.document_index.vespa_constants import TENANT_ID_PAT
from onyx.document_index.vespa_constants import TENANT_ID_REPLACEMENT
from onyx.document_index.vespa_constants import USER_FILE
from onyx.document_index.vespa_constants import USER_FOLDER
from onyx.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
from onyx.document_index.vespa_constants import VESPA_DIM_REPLACEMENT_PAT
from onyx.document_index.vespa_constants import VESPA_TIMEOUT
@@ -187,7 +190,7 @@ class VespaIndex(DocumentIndex):
) -> None:
if MULTI_TENANT:
logger.info(
"Skipping Vespa index seup for multitenant (would wipe all indices)"
"Skipping Vespa index setup for multitenant (would wipe all indices)"
)
return None
@@ -592,7 +595,8 @@ class VespaIndex(DocumentIndex):
self,
doc_chunk_id: UUID,
index_name: str,
fields: VespaDocumentFields,
fields: VespaDocumentFields | None,
user_fields: VespaDocumentUserFields | None,
doc_id: str,
http_client: httpx.Client,
) -> None:
@@ -603,21 +607,31 @@ class VespaIndex(DocumentIndex):
update_dict: dict[str, dict] = {"fields": {}}
if fields.boost is not None:
update_dict["fields"][BOOST] = {"assign": fields.boost}
if fields is not None:
if fields.boost is not None:
update_dict["fields"][BOOST] = {"assign": fields.boost}
if fields.document_sets is not None:
update_dict["fields"][DOCUMENT_SETS] = {
"assign": {document_set: 1 for document_set in fields.document_sets}
}
if fields.document_sets is not None:
update_dict["fields"][DOCUMENT_SETS] = {
"assign": {document_set: 1 for document_set in fields.document_sets}
}
if fields.access is not None:
update_dict["fields"][ACCESS_CONTROL_LIST] = {
"assign": {acl_entry: 1 for acl_entry in fields.access.to_acl()}
}
if fields.access is not None:
update_dict["fields"][ACCESS_CONTROL_LIST] = {
"assign": {acl_entry: 1 for acl_entry in fields.access.to_acl()}
}
if fields.hidden is not None:
update_dict["fields"][HIDDEN] = {"assign": fields.hidden}
if fields.hidden is not None:
update_dict["fields"][HIDDEN] = {"assign": fields.hidden}
if user_fields is not None:
if user_fields.user_file_id is not None:
update_dict["fields"][USER_FILE] = {"assign": user_fields.user_file_id}
if user_fields.user_folder_id is not None:
update_dict["fields"][USER_FOLDER] = {
"assign": user_fields.user_folder_id
}
if not update_dict["fields"]:
logger.error("Update request received but nothing to update.")
@@ -649,7 +663,8 @@ class VespaIndex(DocumentIndex):
*,
chunk_count: int | None,
tenant_id: str,
fields: VespaDocumentFields,
fields: VespaDocumentFields | None,
user_fields: VespaDocumentUserFields | None,
) -> int:
"""Note: if the document id does not exist, the update will be a no-op and the
function will complete with no errors or exceptions.
@@ -682,7 +697,12 @@ class VespaIndex(DocumentIndex):
for doc_chunk_id in doc_chunk_ids:
self._update_single_chunk(
doc_chunk_id, index_name, fields, doc_id, httpx_client
doc_chunk_id,
index_name,
fields,
user_fields,
doc_id,
httpx_client,
)
return doc_chunk_count
@@ -723,6 +743,7 @@ class VespaIndex(DocumentIndex):
tenant_id=tenant_id,
large_chunks_enabled=large_chunks_enabled,
)
for doc_chunk_ids_batch in batch_generator(
chunks_to_delete, BATCH_SIZE
):

View File

@@ -25,9 +25,11 @@ from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import AGGREGATED_CHUNK_BOOST_FACTOR
from onyx.document_index.vespa_constants import BLURB
from onyx.document_index.vespa_constants import BOOST
from onyx.document_index.vespa_constants import CHUNK_CONTEXT
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import CONTENT
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DOC_SUMMARY
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
@@ -49,6 +51,8 @@ from onyx.document_index.vespa_constants import SOURCE_TYPE
from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import TITLE
from onyx.document_index.vespa_constants import TITLE_EMBEDDING
from onyx.document_index.vespa_constants import USER_FILE
from onyx.document_index.vespa_constants import USER_FOLDER
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.utils.logger import setup_logger
@@ -174,7 +178,7 @@ def _index_vespa_chunk(
# For the BM25 index, the keyword suffix is used, the vector is already generated with the more
# natural language representation of the metadata section
CONTENT: remove_invalid_unicode_chars(
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_keyword}"
f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_keyword}"
),
# This duplication of `content` is needed for keyword highlighting
# Note that it's not exactly the same as the actual content
@@ -189,6 +193,8 @@ def _index_vespa_chunk(
# Save as a list for efficient extraction as an Attribute
METADATA_LIST: metadata_list,
METADATA_SUFFIX: remove_invalid_unicode_chars(chunk.metadata_suffix_keyword),
CHUNK_CONTEXT: chunk.chunk_context,
DOC_SUMMARY: chunk.doc_summary,
EMBEDDINGS: embeddings_name_vector_map,
TITLE_EMBEDDING: chunk.title_embedding,
DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at),
@@ -201,6 +207,8 @@ def _index_vespa_chunk(
ACCESS_CONTROL_LIST: {acl_entry: 1 for acl_entry in chunk.access.to_acl()},
DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
IMAGE_FILE_NAME: chunk.image_file_name,
USER_FILE: chunk.user_file if chunk.user_file is not None else None,
USER_FOLDER: chunk.user_folder if chunk.user_folder is not None else None,
BOOST: chunk.boost,
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
}

View File

@@ -5,7 +5,6 @@ from datetime import timezone
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
@@ -14,6 +13,8 @@ from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import SOURCE_TYPE
from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import USER_FILE
from onyx.document_index.vespa_constants import USER_FOLDER
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -27,14 +28,26 @@ def build_vespa_filters(
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
) -> str:
def _build_or_filters(key: str, vals: list[str] | None) -> str:
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields."""
if not key or not vals:
return ""
eq_elems = [f'{key} contains "{val}"' for val in vals if val]
if not eq_elems:
return ""
or_clause = " or ".join(eq_elems)
return f"({or_clause}) and "
def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
"""
For an integer field filter.
If vals is not None, we want *only* docs whose key matches one of vals.
"""
# If `vals` is None => skip the filter entirely
if vals is None or not vals:
return ""
# Otherwise build the OR filter
eq_elems = [f"{key} = {val}" for val in vals]
or_clause = " or ".join(eq_elems)
result = f"({or_clause}) and "
@@ -42,53 +55,57 @@ def build_vespa_filters(
def _build_time_filter(
cutoff: datetime | None,
# Slightly over 3 Months, approximately 1 fiscal quarter
untimed_doc_cutoff: timedelta = timedelta(days=92),
) -> str:
if not cutoff:
return ""
# Documents without an updated-at time are filtered out of queries asking for
# very recent documents (3 months by default). For time decay, such
# documents are treated as being 3 months old.
include_untimed = datetime.now(timezone.utc) - untimed_doc_cutoff > cutoff
cutoff_secs = int(cutoff.timestamp())
if include_untimed:
# Documents without updated_at are assigned -1 as their date
return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
# Start building the filter string
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
# If running in multi-tenant mode, we may want to filter by tenant_id
if filters.tenant_id and MULTI_TENANT:
filter_str += f'({TENANT_ID} contains "{filters.tenant_id}") and '
# CAREFUL touching this one, currently there is no second ACL double-check post retrieval
if filters.access_control_list is not None:
filter_str += _build_or_filters(
ACCESS_CONTROL_LIST, filters.access_control_list
)
# ACL filters
# if filters.access_control_list is not None:
# filter_str += _build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list)
# Source type filters
source_strs = (
[s.value for s in filters.source_type] if filters.source_type else None
)
filter_str += _build_or_filters(SOURCE_TYPE, source_strs)
# Tag filters
tag_attributes = None
if filters.tags:
# build e.g. "tag_key|tag_value"
tag_attributes = [
f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in filters.tags
]
filter_str += _build_or_filters(METADATA_LIST, tag_attributes)
# Document sets
filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
# New: user_file_ids as integer filters
filter_str += _build_int_or_filters(USER_FILE, filters.user_file_ids)
filter_str += _build_int_or_filters(USER_FOLDER, filters.user_folder_ids)
# Time filter
filter_str += _build_time_filter(filters.time_cutoff)
# Trim trailing " and "
if remove_trailing_and and filter_str.endswith(" and "):
filter_str = filter_str[:-5]  # remove the trailing " and "
return filter_str

View File

@@ -67,10 +67,14 @@ EMBEDDINGS = "embeddings"
TITLE_EMBEDDING = "title_embedding"
ACCESS_CONTROL_LIST = "access_control_list"
DOCUMENT_SETS = "document_sets"
USER_FILE = "user_file"
USER_FOLDER = "user_folder"
LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
METADATA = "metadata"
METADATA_LIST = "metadata_list"
METADATA_SUFFIX = "metadata_suffix"
DOC_SUMMARY = "doc_summary"
CHUNK_CONTEXT = "chunk_context"
BOOST = "boost"
AGGREGATED_CHUNK_BOOST_FACTOR = "aggregated_chunk_boost_factor"
DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch
@@ -106,6 +110,8 @@ YQL_BASE = (
f"{LARGE_CHUNK_REFERENCE_IDS}, "
f"{METADATA}, "
f"{METADATA_SUFFIX}, "
f"{DOC_SUMMARY}, "
f"{CHUNK_CONTEXT}, "
f"{CONTENT_SUMMARY} "
f"from {{index_name}} where "
)
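The doubled braces in `YQL_BASE` are worth noting: they survive the f-string so that a later `.format(index_name=...)` can fill them in. A minimal sketch of this two-stage pattern (field name borrowed from the constants above):

```python
# Two-stage string building: the f-string fills in field constants now,
# while {{index_name}} escapes to a literal {index_name} for a later .format().
DOC_SUMMARY = "doc_summary"

yql_base = f"select {DOC_SUMMARY} from {{index_name}} where "
print(yql_base.format(index_name="danswer_chunk"))
# select doc_summary from danswer_chunk where 
```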

View File

@@ -7,6 +7,8 @@ from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from email.parser import Parser as EmailParser
from enum import auto
from enum import IntFlag
from io import BytesIO
from pathlib import Path
from typing import Any
@@ -35,7 +37,7 @@ logger = setup_logger()
TEXT_SECTION_SEPARATOR = "\n\n"
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS = [
".txt",
".md",
".mdx",
@@ -49,7 +51,7 @@ PLAIN_TEXT_FILE_EXTENSIONS = [
".yaml",
]
ACCEPTED_DOCUMENT_FILE_EXTENSIONS = [
".pdf",
".docx",
".pptx",
@@ -57,12 +59,21 @@ VALID_FILE_EXTENSIONS = PLAIN_TEXT_FILE_EXTENSIONS + [
".eml",
".epub",
".html",
]
ACCEPTED_IMAGE_FILE_EXTENSIONS = [
".png",
".jpg",
".jpeg",
".webp",
]
ALL_ACCEPTED_FILE_EXTENSIONS = (
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
+ ACCEPTED_DOCUMENT_FILE_EXTENSIONS
+ ACCEPTED_IMAGE_FILE_EXTENSIONS
)
IMAGE_MEDIA_TYPES = [
"image/png",
"image/jpeg",
@@ -70,8 +81,15 @@ IMAGE_MEDIA_TYPES = [
]
class OnyxExtensionType(IntFlag):
Plain = auto()
Document = auto()
Multimedia = auto()
All = Plain | Document | Multimedia
def is_text_file_extension(file_name: str) -> bool:
return any(file_name.endswith(ext) for ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS)
def get_file_ext(file_path_or_name: str | Path) -> str:
@@ -83,8 +101,20 @@ def is_valid_media_type(media_type: str) -> bool:
return media_type in IMAGE_MEDIA_TYPES
def is_accepted_file_ext(ext: str, ext_type: OnyxExtensionType) -> bool:
if ext_type & OnyxExtensionType.Plain:
if ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS:
return True
if ext_type & OnyxExtensionType.Document:
if ext in ACCEPTED_DOCUMENT_FILE_EXTENSIONS:
return True
if ext_type & OnyxExtensionType.Multimedia:
if ext in ACCEPTED_IMAGE_FILE_EXTENSIONS:
return True
return False
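The `IntFlag` pattern above lets callers combine categories with bitwise OR and lets the check test each bit with bitwise AND. A self-contained sketch (class and extension sets are illustrative stand-ins, not the real Onyx constants):

```python
from enum import IntFlag, auto

# Illustrative re-creation of the extension-type flags.
class ExtType(IntFlag):
    Plain = auto()
    Document = auto()
    Multimedia = auto()
    All = Plain | Document | Multimedia

PLAIN = {".txt", ".md"}
DOCS = {".pdf", ".docx"}
IMAGES = {".png", ".jpg"}

def is_accepted(ext: str, ext_type: ExtType) -> bool:
    # Bitwise AND tests whether a flag is present in the combined mask.
    if ext_type & ExtType.Plain and ext in PLAIN:
        return True
    if ext_type & ExtType.Document and ext in DOCS:
        return True
    if ext_type & ExtType.Multimedia and ext in IMAGES:
        return True
    return False

print(is_accepted(".pdf", ExtType.Plain | ExtType.Document))  # True
print(is_accepted(".png", ExtType.Plain | ExtType.Document))  # False
```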
def is_text_file(file: IO[bytes]) -> bool:
@@ -382,6 +412,9 @@ def extract_file_text(
"""
Legacy function that returns *only text*, ignoring embedded images.
For backward-compatibility in code that only wants text.
NOTE: "ignoring" here means returning an empty string for files it cannot
handle (such as images).
"""
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
".pdf": pdf_to_text,
@@ -405,7 +438,9 @@ def extract_file_text(
if extension is None:
extension = get_file_ext(file_name)
if is_accepted_file_ext(
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
):
func = extension_to_function.get(extension, file_io_to_text)
file.seek(0)
return func(file)
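The dispatch in `extract_file_text` maps an extension to a handler and falls back to a generic text reader for unknown extensions. A minimal standalone sketch of the pattern (handler names are illustrative):

```python
from io import BytesIO
from typing import IO, Any, Callable

def txt_handler(f: IO[Any]) -> str:
    return f.read().decode("utf-8")

def default_handler(f: IO[Any]) -> str:
    # Lenient fallback for unrecognized extensions
    return f.read().decode("utf-8", errors="ignore")

extension_to_function: dict[str, Callable[[IO[Any]], str]] = {".txt": txt_handler}

def extract(file: IO[Any], ext: str) -> str:
    # dict.get with a default gives the fallback handler for free
    func = extension_to_function.get(ext, default_handler)
    file.seek(0)  # rewind in case the stream was already read
    return func(file)

print(extract(BytesIO(b"hi"), ".md"))  # hi
```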

View File

@@ -15,6 +15,7 @@ EXCLUDED_IMAGE_TYPES = [
"image/tiff",
"image/gif",
"image/svg+xml",
"image/avif",
]

View File

@@ -37,6 +37,7 @@ def delete_unstructured_api_key() -> None:
def _sdk_partition_request(
file: IO[Any], file_name: str, **kwargs: Any
) -> operations.PartitionRequest:
file.seek(0, 0)
try:
request = operations.PartitionRequest(
partition_parameters=shared.PartitionParameters(
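The added `file.seek(0, 0)` matters because the stream has presumably already been read by earlier processing; a second read without rewinding yields nothing. A quick demonstration:

```python
from io import BytesIO

buf = BytesIO(b"hello")
assert buf.read() == b"hello"
# A second read without rewinding returns nothing: the cursor sits at EOF.
assert buf.read() == b""
buf.seek(0, 0)  # rewind to the start, as the added seek above does
assert buf.read() == b"hello"
```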

View File

@@ -31,6 +31,7 @@ class FileStore(ABC):
file_origin: FileOrigin,
file_type: str,
file_metadata: dict | None = None,
commit: bool = True,
) -> None:
"""
Save a file to the blob store
@@ -42,6 +43,8 @@ class FileStore(ABC):
- display_name: Display name of the file
- file_origin: Origin of the file
- file_type: Type of the file
- file_metadata: Additional metadata for the file
- commit: Whether to commit the transaction after saving the file
"""
raise NotImplementedError
@@ -90,6 +93,7 @@ class PostgresBackedFileStore(FileStore):
file_origin: FileOrigin,
file_type: str,
file_metadata: dict | None = None,
commit: bool = True,
) -> None:
try:
# The large objects in postgres are saved as special objects can be listed with
@@ -104,7 +108,8 @@ class PostgresBackedFileStore(FileStore):
db_session=self.db_session,
file_metadata=file_metadata,
)
self.db_session.commit()
if commit:
self.db_session.commit()
except Exception:
self.db_session.rollback()
raise
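The new `commit` flag lets callers batch several saves into one transaction and commit once. A sketch of the same pattern with `sqlite3` standing in for the SQLAlchemy session (table and function names are illustrative):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE files (name TEXT)")

def save_file(name: str, commit: bool = True) -> None:
    try:
        conn.execute("INSERT INTO files (name) VALUES (?)", (name,))
        if commit:
            conn.commit()
    except Exception:
        conn.rollback()
        raise

# Batch several saves, then commit once: atomic and fewer round trips.
save_file("a.txt", commit=False)
save_file("b.txt", commit=False)
conn.commit()
print(conn.execute("SELECT COUNT(*) FROM files").fetchone()[0])  # 2
```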

View File

@@ -14,6 +14,7 @@ class ChatFileType(str, Enum):
# Plain text only contain the text
PLAIN_TEXT = "plain_text"
CSV = "csv"
USER_KNOWLEDGE = "user_knowledge"
class FileDescriptor(TypedDict):

View File

@@ -10,12 +10,62 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import FileOrigin
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import UserFile
from onyx.db.models import UserFolder
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.file_store.models import InMemoryChatFile
from onyx.utils.b64 import get_image_type
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
def user_file_id_to_plaintext_file_name(user_file_id: int) -> str:
"""Generate a consistent file name for storing plaintext content of a user file."""
return f"plaintext_{user_file_id}"
def store_user_file_plaintext(
user_file_id: int, plaintext_content: str, db_session: Session
) -> bool:
"""
Store plaintext content for a user file in the file store.
Args:
user_file_id: The ID of the user file
plaintext_content: The plaintext content to store
db_session: The database session
Returns:
bool: True if storage was successful, False otherwise
"""
# Skip empty content
if not plaintext_content:
return False
# Get plaintext file name
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
# Store the plaintext in the file store
file_store = get_default_file_store(db_session)
file_content = BytesIO(plaintext_content.encode("utf-8"))
try:
file_store.save_file(
file_name=plaintext_file_name,
content=file_content,
display_name=f"Plaintext for user file {user_file_id}",
file_origin=FileOrigin.PLAINTEXT_CACHE,
file_type="text/plain",
commit=False,
)
return True
except Exception as e:
logger.warning(f"Failed to store plaintext for user file {user_file_id}: {e}")
return False
def load_chat_file(
file_descriptor: FileDescriptor, db_session: Session
@@ -53,6 +103,83 @@ def load_all_chat_files(
return files
def load_user_folder(folder_id: int, db_session: Session) -> list[InMemoryChatFile]:
user_files = (
db_session.query(UserFile).filter(UserFile.folder_id == folder_id).all()
)
return [load_user_file(file.id, db_session) for file in user_files]
def load_user_file(file_id: int, db_session: Session) -> InMemoryChatFile:
user_file = db_session.query(UserFile).filter(UserFile.id == file_id).first()
if not user_file:
raise ValueError(f"User file with id {file_id} not found")
# Try to load plaintext version first
file_store = get_default_file_store(db_session)
plaintext_file_name = user_file_id_to_plaintext_file_name(file_id)
try:
file_io = file_store.read_file(plaintext_file_name, mode="b")
return InMemoryChatFile(
file_id=str(user_file.file_id),
content=file_io.read(),
file_type=ChatFileType.USER_KNOWLEDGE,
filename=user_file.name,
)
except Exception as e:
logger.warning(
f"Failed to load plaintext file {plaintext_file_name}, defaulting to original file: {e}"
)
# Fall back to original file if plaintext not available
file_io = file_store.read_file(user_file.file_id, mode="b")
return InMemoryChatFile(
file_id=str(user_file.file_id),
content=file_io.read(),
file_type=ChatFileType.USER_KNOWLEDGE,
filename=user_file.name,
)
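The try/except above is a cache-with-fallback read: prefer the precomputed plaintext, fall back to the original blob if the cache entry is missing. A dict-backed sketch of the same flow (the key naming mirrors the `plaintext_<id>` convention):

```python
# Dict standing in for the file store.
store = {"original_42": b"raw bytes", "plaintext_42": b"extracted text"}

def load(file_id: int) -> bytes:
    try:
        # Prefer the cached plaintext version
        return store[f"plaintext_{file_id}"]
    except KeyError:
        # Fall back to the original blob
        return store[f"original_{file_id}"]

print(load(42))  # b'extracted text'
del store["plaintext_42"]
print(load(42))  # b'raw bytes'
```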
def load_all_user_files(
user_file_ids: list[int],
user_folder_ids: list[int],
db_session: Session,
) -> list[InMemoryChatFile]:
return cast(
list[InMemoryChatFile],
run_functions_tuples_in_parallel(
[(load_user_file, (file_id, db_session)) for file_id in user_file_ids]
)
+ [
file
for folder_id in user_folder_ids
for file in load_user_folder(folder_id, db_session)
],
)
def load_all_user_file_files(
user_file_ids: list[int],
user_folder_ids: list[int],
db_session: Session,
) -> list[UserFile]:
user_files: list[UserFile] = []
for user_file_id in user_file_ids:
user_file = (
db_session.query(UserFile).filter(UserFile.id == user_file_id).first()
)
if user_file is not None:
user_files.append(user_file)
for user_folder_id in user_folder_ids:
user_files.extend(
db_session.query(UserFile)
.filter(UserFile.folder_id == user_folder_id)
.all()
)
return user_files
def save_file_from_url(url: str) -> str:
"""NOTE: using multiple sessions here, since this is often called
using multithreading. In practice, sharing a session has resulted in
@@ -71,6 +198,7 @@ def save_file_from_url(url: str) -> str:
display_name="GeneratedImage",
file_origin=FileOrigin.CHAT_IMAGE_GEN,
file_type="image/png;base64",
commit=True,
)
return unique_id
@@ -85,6 +213,7 @@ def save_file_from_base64(base64_string: str) -> str:
display_name="GeneratedImage",
file_origin=FileOrigin.CHAT_IMAGE_GEN,
file_type=get_image_type(base64_string),
commit=True,
)
return unique_id
@@ -128,3 +257,39 @@ def save_files(urls: list[str], base64_files: list[str]) -> list[str]:
]
return run_functions_tuples_in_parallel(funcs)
def load_all_persona_files_for_chat(
persona_id: int, db_session: Session
) -> tuple[list[InMemoryChatFile], list[int]]:
from onyx.db.models import Persona
from sqlalchemy.orm import joinedload
persona = (
db_session.query(Persona)
.filter(Persona.id == persona_id)
.options(
joinedload(Persona.user_files),
joinedload(Persona.user_folders).joinedload(UserFolder.files),
)
.one()
)
persona_file_calls = [
(load_user_file, (user_file.id, db_session)) for user_file in persona.user_files
]
persona_loaded_files = run_functions_tuples_in_parallel(persona_file_calls)
persona_folder_files = []
persona_folder_file_ids = []
for user_folder in persona.user_folders:
folder_files = load_user_folder(user_folder.id, db_session)
persona_folder_files.extend(folder_files)
persona_folder_file_ids.extend([file.id for file in user_folder.files])
persona_files = list(persona_loaded_files) + persona_folder_files
persona_file_ids = [
file.id for file in persona.user_files
] + persona_folder_file_ids
return persona_files, persona_file_ids

View File

@@ -1,7 +1,10 @@
from onyx.configs.app_configs import AVERAGE_SUMMARY_EMBEDDINGS
from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
from onyx.configs.app_configs import MINI_CHUNK_SIZE
from onyx.configs.app_configs import SKIP_METADATA_IN_CHUNK
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.configs.constants import SECTION_SEPARATOR
@@ -13,6 +16,7 @@ from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import Section
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import DocAwareChunk
from onyx.llm.utils import MAX_CONTEXT_TOKENS
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.utils.logger import setup_logger
from onyx.utils.text_processing import clean_text
@@ -82,6 +86,9 @@ def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwar
large_chunk_reference_ids=[chunk.chunk_id for chunk in chunks],
mini_chunk_texts=None,
large_chunk_id=large_chunk_id,
chunk_context="",
doc_summary="",
contextual_rag_reserved_tokens=0,
)
offset = 0
@@ -120,6 +127,7 @@ class Chunker:
tokenizer: BaseTokenizer,
enable_multipass: bool = False,
enable_large_chunks: bool = False,
enable_contextual_rag: bool = False,
blurb_size: int = BLURB_SIZE,
include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
@@ -133,9 +141,20 @@ class Chunker:
self.chunk_token_limit = chunk_token_limit
self.enable_multipass = enable_multipass
self.enable_large_chunks = enable_large_chunks
self.enable_contextual_rag = enable_contextual_rag
if enable_contextual_rag:
assert (
USE_CHUNK_SUMMARY or USE_DOCUMENT_SUMMARY
), "Contextual RAG requires at least one of chunk summary and document summary enabled"
self.default_contextual_rag_reserved_tokens = MAX_CONTEXT_TOKENS * (
int(USE_CHUNK_SUMMARY) + int(USE_DOCUMENT_SUMMARY)
)
self.tokenizer = tokenizer
self.callback = callback
self.max_context = 0
self.prompt_tokens = 0
self.blurb_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize,
chunk_size=blurb_size,
@@ -221,30 +240,12 @@ class Chunker:
metadata_suffix_keyword=metadata_suffix_keyword,
mini_chunk_texts=self._get_mini_chunk_texts(text),
large_chunk_id=None,
doc_summary="",
chunk_context="",
contextual_rag_reserved_tokens=0, # set per-document in _handle_single_document
)
chunks_list.append(new_chunk)
def _chunk_document(
self,
document: IndexingDocument,
title_prefix: str,
metadata_suffix_semantic: str,
metadata_suffix_keyword: str,
content_token_limit: int,
) -> list[DocAwareChunk]:
"""
Legacy method for backward compatibility.
Calls _chunk_document_with_sections with document.sections.
"""
return self._chunk_document_with_sections(
document,
document.processed_sections,
title_prefix,
metadata_suffix_semantic,
metadata_suffix_keyword,
content_token_limit,
)
def _chunk_document_with_sections(
self,
document: IndexingDocument,
@@ -264,7 +265,7 @@ class Chunker:
for section_idx, section in enumerate(sections):
# Get section text and other attributes
section_text = clean_text(str(section.text or ""))
section_link_text = section.link or ""
image_url = section.image_file_name
@@ -309,7 +310,7 @@ class Chunker:
continue
# CASE 2: Normal text section
section_token_count = len(self.tokenizer.encode(section_text))
# If the section is large on its own, split it separately
if section_token_count > content_token_limit:
@@ -332,8 +333,7 @@ class Chunker:
# If even the split_text is bigger than strict limit, further split
if (
STRICT_CHUNK_TOKEN_LIMIT
and len(self.tokenizer.encode(split_text)) > content_token_limit
):
smaller_chunks = self._split_oversized_chunk(
split_text, content_token_limit
@@ -363,10 +363,10 @@ class Chunker:
continue
# If we can still fit this section into the current chunk, do so
current_token_count = len(self.tokenizer.encode(chunk_text))
current_offset = len(shared_precompare_cleanup(chunk_text))
next_section_tokens = (
len(self.tokenizer.encode(SECTION_SEPARATOR)) + section_token_count
)
if next_section_tokens + current_token_count <= content_token_limit:
@@ -414,7 +414,7 @@ class Chunker:
# Title prep
title = self._extract_blurb(document.get_title_for_document_index() or "")
title_prefix = title + RETURN_SEPARATOR if title else ""
title_tokens = len(self.tokenizer.encode(title_prefix))
# Metadata prep
metadata_suffix_semantic = ""
@@ -427,15 +427,50 @@ class Chunker:
) = _get_metadata_suffix_for_document_index(
document.metadata, include_separator=True
)
metadata_tokens = len(self.tokenizer.encode(metadata_suffix_semantic))
# If metadata is too large, skip it in the semantic content
if metadata_tokens >= self.chunk_token_limit * MAX_METADATA_PERCENTAGE:
metadata_suffix_semantic = ""
metadata_tokens = 0
single_chunk_fits = True
doc_token_count = 0
if self.enable_contextual_rag:
doc_content = document.get_text_content()
tokenized_doc = self.tokenizer.tokenize(doc_content)
doc_token_count = len(tokenized_doc)
# check if doc + title + metadata fits in a single chunk. If so, no need for contextual RAG
single_chunk_fits = (
doc_token_count + title_tokens + metadata_tokens
<= self.chunk_token_limit
)
# expand the size of the context used for contextual rag based on whether chunk context and doc summary are used
context_size = 0
if (
self.enable_contextual_rag
and not single_chunk_fits
and not AVERAGE_SUMMARY_EMBEDDINGS
):
context_size += self.default_contextual_rag_reserved_tokens
# Adjust content token limit to accommodate title + metadata
content_token_limit = self.chunk_token_limit - title_tokens - metadata_tokens
content_token_limit = (
self.chunk_token_limit - title_tokens - metadata_tokens - context_size
)
# first check: if there is not enough actual chunk content when including contextual rag,
# then don't do contextual rag
if content_token_limit <= CHUNK_MIN_CONTENT:
context_size = 0 # Don't do contextual RAG
# revert to previous content token limit
content_token_limit = (
self.chunk_token_limit - title_tokens - metadata_tokens
)
# If there is not enough context remaining then just index the chunk with no prefix/suffix
if content_token_limit <= CHUNK_MIN_CONTENT:
# Not enough space left, so revert to full chunk without the prefix
content_token_limit = self.chunk_token_limit
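The token budget above can be traced with a worked sketch (all numbers, including the per-summary reservation, are made up for illustration): reserve space for the contextual-RAG summaries, and if that leaves too little room for actual chunk content, drop the reservation.

```python
# Worked sketch of the chunk token budget; values are illustrative.
CHUNK_TOKEN_LIMIT = 512
CHUNK_MIN_CONTENT = 256
MAX_CONTEXT_TOKENS = 512  # assumption: tokens reserved per summary type

title_tokens, metadata_tokens = 20, 40
use_chunk_summary, use_document_summary = True, True

# Reserve space for each enabled summary type.
context_size = MAX_CONTEXT_TOKENS * (int(use_chunk_summary) + int(use_document_summary))
content_token_limit = CHUNK_TOKEN_LIMIT - title_tokens - metadata_tokens - context_size

if content_token_limit <= CHUNK_MIN_CONTENT:
    # Not enough room for real content: drop the contextual-RAG reservation.
    context_size = 0
    content_token_limit = CHUNK_TOKEN_LIMIT - title_tokens - metadata_tokens

print(context_size, content_token_limit)  # 0 452
```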
@@ -459,6 +494,9 @@ class Chunker:
large_chunks = generate_large_chunks(normal_chunks)
normal_chunks.extend(large_chunks)
for chunk in normal_chunks:
chunk.contextual_rag_reserved_tokens = context_size
return normal_chunks
def chunk(self, documents: list[IndexingDocument]) -> list[DocAwareChunk]:

View File

@@ -121,7 +121,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
if chunk.large_chunk_reference_ids:
large_chunks_present = True
chunk_text = (
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}"
f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_semantic}"
) or chunk.source_document.get_title_for_document_index()
if not chunk_text:

View File

@@ -1,3 +1,4 @@
from collections import defaultdict
from collections.abc import Callable
from functools import partial
from typing import Protocol
@@ -8,7 +9,13 @@ from sqlalchemy.orm import Session
from onyx.access.access import get_access_for_documents
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_NAME
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.configs.model_configs import USE_INFORMATION_CONTENT_CLASSIFICATION
@@ -36,11 +43,15 @@ from onyx.db.document import upsert_documents
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import Document as DBDocument
from onyx.db.models import IndexModelStatus
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.db.pg_file_store import read_lobj
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_active_search_settings
from onyx.db.tag import create_or_add_document_tag
from onyx.db.tag import create_or_add_document_tag_list
from onyx.db.user_documents import fetch_user_files_for_documents
from onyx.db.user_documents import fetch_user_folders_for_documents
from onyx.db.user_documents import update_user_file_token_count__no_commit
from onyx.document_index.document_index_utils import (
get_multipass_config,
)
@@ -48,6 +59,7 @@ from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_store.utils import store_user_file_plaintext
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import embed_chunks_with_failure_handling
from onyx.indexing.embedder import IndexingEmbedder
@@ -57,11 +69,25 @@ from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llm_for_contextual_rag
from onyx.llm.interfaces import LLM
from onyx.llm.utils import get_max_input_tokens
from onyx.llm.utils import MAX_CONTEXT_TOKENS
from onyx.llm.utils import message_to_string
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_middle
from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_PROMPT1
from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_PROMPT2
from onyx.prompts.chat_prompts import DOCUMENT_SUMMARY_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
@@ -249,6 +275,8 @@ def index_doc_batch_with_handler(
db_session: Session,
tenant_id: str,
ignore_time_skip: bool = False,
enable_contextual_rag: bool = False,
llm: LLM | None = None,
) -> IndexingPipelineResult:
try:
index_pipeline_result = index_doc_batch(
@@ -261,6 +289,8 @@ def index_doc_batch_with_handler(
db_session=db_session,
ignore_time_skip=ignore_time_skip,
tenant_id=tenant_id,
enable_contextual_rag=enable_contextual_rag,
llm=llm,
)
except Exception as e:
# don't log the batch directly, it's too much text
@@ -439,7 +469,7 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
**document.dict(),
processed_sections=[
Section(
text=section.text if isinstance(section, TextSection) else "",
link=section.link,
image_file_name=section.image_file_name
if isinstance(section, ImageSection)
@@ -459,11 +489,11 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
for section in document.sections:
# For ImageSection, process and create base Section with both text and image_file_name
if isinstance(section, ImageSection):
# Default section with image path preserved - ensure text is always a string
processed_section = Section(
link=section.link,
image_file_name=section.image_file_name,
text="",  # Initialize with empty string
)
# Try to get image summary
@@ -506,13 +536,21 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
# For TextSection, create a base Section with text and link
elif isinstance(section, TextSection):
processed_section = Section(
text=section.text or "",  # Ensure text is always a string, not None
link=section.link,
image_file_name=None,
)
processed_sections.append(processed_section)
# If it's already a base Section (unlikely), just append it with text validation
else:
# Ensure text is always a string
processed_section = Section(
text=section.text if section.text is not None else "",
link=section.link,
image_file_name=section.image_file_name,
)
processed_sections.append(processed_section)
# Create IndexingDocument with original sections and processed_sections
indexed_document = IndexingDocument(
@@ -523,6 +561,145 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
return indexed_documents
def add_document_summaries(
chunks_by_doc: list[DocAwareChunk],
llm: LLM,
tokenizer: BaseTokenizer,
trunc_doc_tokens: int,
) -> list[int] | None:
"""
Adds a document summary to a list of chunks from the same document.
Returns the tokenized document, or None if no space is reserved for contextual RAG.
"""
doc_tokens = []
# this value is the same for each chunk in the document; 0 indicates
# there is not enough space for contextual RAG (the chunk content
# and possibly metadata took up too much space)
if chunks_by_doc[0].contextual_rag_reserved_tokens == 0:
return None
doc_tokens = tokenizer.encode(chunks_by_doc[0].source_document.get_text_content())
doc_content = tokenizer_trim_middle(doc_tokens, trunc_doc_tokens, tokenizer)
summary_prompt = DOCUMENT_SUMMARY_PROMPT.format(document=doc_content)
doc_summary = message_to_string(
llm.invoke(summary_prompt, max_tokens=MAX_CONTEXT_TOKENS)
)
for chunk in chunks_by_doc:
chunk.doc_summary = doc_summary
return doc_tokens
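`tokenizer_trim_middle` presumably truncates a long token sequence while keeping both the head and tail of the document. A hypothetical sketch of that idea (function name, ellipsis token, and split policy are all assumptions, not the real Onyx implementation):

```python
# Hypothetical middle-trim: keep the head and tail, drop the middle.
def trim_middle(tokens: list[int], max_tokens: int) -> list[int]:
    if len(tokens) <= max_tokens:
        return tokens
    half = (max_tokens - 1) // 2  # leave one slot for an ellipsis marker
    ELLIPSIS = -1  # stand-in token id
    return tokens[:half] + [ELLIPSIS] + tokens[-(max_tokens - 1 - half):]

print(trim_middle(list(range(10)), 5))  # [0, 1, -1, 8, 9]
```

Trimming the middle rather than the tail matters for summaries: documents often front-load context and end with conclusions, and both should survive truncation.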
def add_chunk_summaries(
chunks_by_doc: list[DocAwareChunk],
llm: LLM,
tokenizer: BaseTokenizer,
trunc_doc_chunk_tokens: int,
doc_tokens: list[int] | None,
) -> None:
"""
Adds chunk summaries to the chunks grouped by document id.
Chunk summaries look at the chunk as well as the entire document (or a summary,
if the document is too long) and describe how the chunk relates to the document.
"""
# all chunks within a document have the same contextual_rag_reserved_tokens
if chunks_by_doc[0].contextual_rag_reserved_tokens == 0:
return
# use values computed in above doc summary section if available
doc_tokens = doc_tokens or tokenizer.encode(
chunks_by_doc[0].source_document.get_text_content()
)
doc_content = tokenizer_trim_middle(doc_tokens, trunc_doc_chunk_tokens, tokenizer)
# only compute doc summary if needed
doc_info = (
doc_content
if len(doc_tokens) <= MAX_TOKENS_FOR_FULL_INCLUSION
else chunks_by_doc[0].doc_summary
)
if not doc_info:
# This happens if the document is too long AND document summaries are turned off
# In this case we compute a doc summary using the LLM
doc_info = message_to_string(
llm.invoke(
DOCUMENT_SUMMARY_PROMPT.format(document=doc_content),
max_tokens=MAX_CONTEXT_TOKENS,
)
)
context_prompt1 = CONTEXTUAL_RAG_PROMPT1.format(document=doc_info)
def assign_context(chunk: DocAwareChunk) -> None:
context_prompt2 = CONTEXTUAL_RAG_PROMPT2.format(chunk=chunk.content)
try:
chunk.chunk_context = message_to_string(
llm.invoke(
context_prompt1 + context_prompt2,
max_tokens=MAX_CONTEXT_TOKENS,
)
)
except LLMRateLimitError as e:
# Erroring during chunking is undesirable, so we log the error and continue
# TODO: for v2, add robust retry logic
logger.exception(f"Rate limit adding chunk summary: {e}", exc_info=e)
chunk.chunk_context = ""
except Exception as e:
logger.exception(f"Error adding chunk summary: {e}", exc_info=e)
chunk.chunk_context = ""
run_functions_tuples_in_parallel(
[(assign_context, (chunk,)) for chunk in chunks_by_doc]
)
def add_contextual_summaries(
chunks: list[DocAwareChunk],
llm: LLM,
tokenizer: BaseTokenizer,
chunk_token_limit: int,
) -> list[DocAwareChunk]:
"""
Adds Document summary and chunk-within-document context to the chunks
based on which environment variables are set.
"""
max_context = get_max_input_tokens(
model_name=llm.config.model_name,
model_provider=llm.config.model_provider,
output_tokens=MAX_CONTEXT_TOKENS,
)
doc2chunks = defaultdict(list)
for chunk in chunks:
doc2chunks[chunk.source_document.id].append(chunk)
# The number of tokens allowed for the document when computing a document summary
trunc_doc_summary_tokens = max_context - len(
tokenizer.encode(DOCUMENT_SUMMARY_PROMPT)
)
prompt_tokens = len(
tokenizer.encode(CONTEXTUAL_RAG_PROMPT1 + CONTEXTUAL_RAG_PROMPT2)
)
# The number of tokens allowed for the document when computing a
# "chunk in context of document" summary
trunc_doc_chunk_tokens = max_context - prompt_tokens - chunk_token_limit
for chunks_by_doc in doc2chunks.values():
doc_tokens = None
if USE_DOCUMENT_SUMMARY:
doc_tokens = add_document_summaries(
chunks_by_doc, llm, tokenizer, trunc_doc_summary_tokens
)
if USE_CHUNK_SUMMARY:
add_chunk_summaries(
chunks_by_doc, llm, tokenizer, trunc_doc_chunk_tokens, doc_tokens
)
return chunks
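The token budgeting above can be sketched with illustrative numbers (all values assumed here; the real ones come from the model config and the tokenized prompts):

```python
max_context = 8192          # assumed model input window (after output reservation)
prompt_tokens = 64          # assumed size of CONTEXTUAL_RAG_PROMPT1 + PROMPT2
chunk_token_limit = 512     # assumed max tokens per chunk

# tokens left for the document text in a "chunk in context of document" prompt
trunc_doc_chunk_tokens = max_context - prompt_tokens - chunk_token_limit
print(trunc_doc_chunk_tokens)  # 7616
```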
@log_function_time(debug_only=True)
def index_doc_batch(
*,
@@ -534,6 +711,8 @@ def index_doc_batch(
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
tenant_id: str,
enable_contextual_rag: bool = False,
llm: LLM | None = None,
ignore_time_skip: bool = False,
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
) -> IndexingPipelineResult:
@@ -595,6 +774,21 @@ def index_doc_batch(
# NOTE: no special handling for failures here, since the chunker is not
# a common source of failure for the indexing pipeline
chunks: list[DocAwareChunk] = chunker.chunk(ctx.indexable_docs)
llm_tokenizer: BaseTokenizer | None = None
# contextual RAG
if enable_contextual_rag:
assert llm is not None, "must provide an LLM for contextual RAG"
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
provider_type=llm.config.model_provider,
)
# Because the chunker's tokens are different from the LLM's tokens,
# we add a fudge factor to ensure we truncate prompts to the LLM's token limit
chunks = add_contextual_summaries(
chunks, llm, llm_tokenizer, chunker.chunk_token_limit * 2
)
logger.debug("Starting embedding")
chunks_with_embeddings, embedding_failures = (
@@ -638,6 +832,15 @@ def index_doc_batch(
)
}
doc_id_to_user_file_id: dict[str, int | None] = fetch_user_files_for_documents(
document_ids=updatable_ids, db_session=db_session
)
doc_id_to_user_folder_id: dict[
str, int | None
] = fetch_user_folders_for_documents(
document_ids=updatable_ids, db_session=db_session
)
doc_id_to_previous_chunk_cnt: dict[str, int | None] = {
document_id: chunk_count
for document_id, chunk_count in fetch_chunk_counts_for_documents(
@@ -657,6 +860,48 @@ def index_doc_batch(
for document_id in updatable_ids
}
try:
llm, _ = get_default_llms()
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
provider_type=llm.config.model_provider,
)
except Exception as e:
logger.error(f"Error getting tokenizer: {e}")
llm_tokenizer = None
# Calculate token counts for each document by combining all its chunks' content
user_file_id_to_token_count: dict[int, int | None] = {}
user_file_id_to_raw_text: dict[int, str] = {}
for document_id in updatable_ids:
# Only calculate token counts for documents that have a user file ID
if (
document_id in doc_id_to_user_file_id
and doc_id_to_user_file_id[document_id] is not None
):
user_file_id = doc_id_to_user_file_id[document_id]
if not user_file_id:
continue
document_chunks = [
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == document_id
]
if document_chunks:
combined_content = " ".join(
[chunk.content for chunk in document_chunks]
)
token_count = (
len(llm_tokenizer.encode(combined_content))
if llm_tokenizer
else 0
)
user_file_id_to_token_count[user_file_id] = token_count
user_file_id_to_raw_text[user_file_id] = combined_content
else:
user_file_id_to_token_count[user_file_id] = None
# we're concerned about race conditions where multiple simultaneous indexings might result
# in one set of metadata overwriting another one in vespa.
# we still write data here for the immediate and most likely correct sync, but
@@ -669,6 +914,10 @@ def index_doc_batch(
document_sets=set(
doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_file=doc_id_to_user_file_id.get(chunk.source_document.id, None),
user_folder=doc_id_to_user_folder_id.get(
chunk.source_document.id, None
),
boost=(
ctx.id_to_db_doc_map[chunk.source_document.id].boost
if chunk.source_document.id in ctx.id_to_db_doc_map
@@ -750,6 +999,11 @@ def index_doc_batch(
db_session=db_session,
)
update_user_file_token_count__no_commit(
user_file_id_to_token_count=user_file_id_to_token_count,
db_session=db_session,
)
# these documents can now be counted as part of the CC Pairs
# document count, so we need to mark them as indexed
# NOTE: even documents we skipped since they were already up
@@ -761,12 +1015,22 @@ def index_doc_batch(
document_ids=[doc.id for doc in filtered_documents],
db_session=db_session,
)
# Store the plaintext in the file store for faster retrieval
for user_file_id, raw_text in user_file_id_to_raw_text.items():
store_user_file_plaintext(
user_file_id=user_file_id,
plaintext_content=raw_text,
db_session=db_session,
)
# save the chunk boost components to postgres
update_chunk_boost_components__no_commit(
chunk_data=updatable_chunk_data, db_session=db_session
)
# Pause user file ccpairs
db_session.commit()
result = IndexingPipelineResult(
@@ -791,13 +1055,33 @@ def build_indexing_pipeline(
callback: IndexingHeartbeatInterface | None = None,
) -> IndexingPipelineProtocol:
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
search_settings = get_current_search_settings(db_session)
all_search_settings = get_active_search_settings(db_session)
if (
all_search_settings.secondary
and all_search_settings.secondary.status == IndexModelStatus.FUTURE
):
search_settings = all_search_settings.secondary
else:
search_settings = all_search_settings.primary
multipass_config = get_multipass_config(search_settings)
enable_contextual_rag = (
search_settings.enable_contextual_rag or ENABLE_CONTEXTUAL_RAG
)
llm = None
if enable_contextual_rag:
llm = get_llm_for_contextual_rag(
search_settings.contextual_rag_llm_name or DEFAULT_CONTEXTUAL_RAG_LLM_NAME,
search_settings.contextual_rag_llm_provider
or DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER,
)
chunker = chunker or Chunker(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=multipass_config.multipass_indexing,
enable_large_chunks=multipass_config.enable_large_chunks,
enable_contextual_rag=enable_contextual_rag,
# after every doc, update status in case there are a bunch of really long docs
callback=callback,
)
@@ -811,4 +1095,6 @@ def build_indexing_pipeline(
ignore_time_skip=ignore_time_skip,
db_session=db_session,
tenant_id=tenant_id,
enable_contextual_rag=enable_contextual_rag,
llm=llm,
)

View File

@@ -49,6 +49,15 @@ class DocAwareChunk(BaseChunk):
metadata_suffix_semantic: str
metadata_suffix_keyword: str
# This is the number of tokens reserved for contextual RAG
# in the chunk. doc_summary and chunk_context combined should
# contain at most this many tokens.
contextual_rag_reserved_tokens: int
# This is the summary for the document generated for contextual RAG
doc_summary: str
# This is the context for this chunk generated for contextual RAG
chunk_context: str
mini_chunk_texts: list[str] | None
large_chunk_id: int | None
@@ -91,6 +100,8 @@ class DocMetadataAwareIndexChunk(IndexChunk):
tenant_id: str
access: "DocumentAccess"
document_sets: set[str]
user_file: int | None
user_folder: int | None
boost: int
aggregated_chunk_boost_factor: float
@@ -100,6 +111,8 @@ class DocMetadataAwareIndexChunk(IndexChunk):
index_chunk: IndexChunk,
access: "DocumentAccess",
document_sets: set[str],
user_file: int | None,
user_folder: int | None,
boost: int,
aggregated_chunk_boost_factor: float,
tenant_id: str,
@@ -109,6 +122,8 @@ class DocMetadataAwareIndexChunk(IndexChunk):
**index_chunk_data,
access=access,
document_sets=document_sets,
user_file=user_file,
user_folder=user_folder,
boost=boost,
aggregated_chunk_boost_factor=aggregated_chunk_boost_factor,
tenant_id=tenant_id,
@@ -154,6 +169,9 @@ class IndexingSetting(EmbeddingModelDetail):
reduced_dimension: int | None = None
background_reindex_enabled: bool = True
enable_contextual_rag: bool
contextual_rag_llm_name: str | None = None
contextual_rag_llm_provider: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
@@ -178,6 +196,7 @@ class IndexingSetting(EmbeddingModelDetail):
embedding_precision=search_settings.embedding_precision,
reduced_dimension=search_settings.reduced_dimension,
background_reindex_enabled=search_settings.background_reindex_enabled,
enable_contextual_rag=search_settings.enable_contextual_rag,
)

View File

@@ -425,12 +425,12 @@ class DefaultMultiLLM(LLM):
messages=processed_prompt,
tools=tools,
tool_choice=tool_choice if tools else None,
max_tokens=max_tokens,
# streaming choice
stream=stream,
# model params
temperature=0,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
@@ -531,6 +531,7 @@ class DefaultMultiLLM(LLM):
tool_choice,
structured_response_format,
timeout_override,
max_tokens,
)
return

View File

@@ -16,6 +16,7 @@ from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.interfaces import LLM
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.llm.models import LLMProvider
from onyx.server.manage.llm.models import LLMProviderView
from onyx.utils.headers import build_llm_extra_headers
from onyx.utils.logger import setup_logger
@@ -154,6 +155,40 @@ def get_default_llm_with_vision(
return None
def llm_from_provider(
model_name: str,
llm_provider: LLMProvider,
timeout: int | None = None,
temperature: float | None = None,
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> LLM:
return get_llm(
provider=llm_provider.provider,
model=model_name,
deployment_name=llm_provider.deployment_name,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
)
def get_llm_for_contextual_rag(model_name: str, model_provider: str) -> LLM:
with get_session_context_manager() as db_session:
llm_provider = fetch_llm_provider_view(db_session, model_provider)
if not llm_provider:
raise ValueError(f"No LLM provider with name {model_provider} found")
return llm_from_provider(
model_name=model_name,
llm_provider=llm_provider,
)
def get_default_llms(
timeout: int | None = None,
temperature: float | None = None,
@@ -179,14 +214,9 @@ def get_default_llms(
raise ValueError("No fast default model name found")
def _create_llm(model: str) -> LLM:
return get_llm(
provider=llm_provider.provider,
model=model,
deployment_name=llm_provider.deployment_name,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
return llm_from_provider(
model_name=model,
llm_provider=llm_provider,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,

View File

@@ -1,4 +1,5 @@
import copy
import io
import json
from collections.abc import Callable
from collections.abc import Iterator
@@ -29,13 +30,20 @@ from litellm.exceptions import Timeout # type: ignore
from litellm.exceptions import UnprocessableEntityError # type: ignore
from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
from onyx.configs.constants import MessageType
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_TOKEN_ESTIMATE
from onyx.prompts.chat_prompts import DOCUMENT_SUMMARY_TOKEN_ESTIMATE
from onyx.prompts.constants import CODE_BLOCK_PAT
from onyx.utils.b64 import get_image_type
from onyx.utils.b64 import get_image_type_from_bytes
@@ -44,6 +52,10 @@ from shared_configs.configs import LOG_LEVEL
logger = setup_logger()
MAX_CONTEXT_TOKENS = 100
ONE_MILLION = 1_000_000
CHUNKS_PER_DOC_ESTIMATE = 5
def litellm_exception_to_error_msg(
e: Exception,
@@ -119,7 +131,12 @@ def _build_content(
text_files = [
file
for file in files
if file.file_type in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV)
if file.file_type
in (
ChatFileType.PLAIN_TEXT,
ChatFileType.CSV,
ChatFileType.USER_KNOWLEDGE,
)
]
if not text_files:
@@ -127,7 +144,18 @@ def _build_content(
final_message_with_files = "FILES:\n\n"
for file in text_files:
file_content = file.content.decode("utf-8")
try:
file_content = file.content.decode("utf-8")
except UnicodeDecodeError:
# Try to decode as binary
try:
file_content, _, _ = read_pdf_file(io.BytesIO(file.content))
except Exception:
file_content = f"[Binary file content - {file.file_type} format]"
logger.exception(
f"Could not decode binary file content for file type: {file.file_type}"
)
file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else ""
final_message_with_files += (
f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n"
@@ -155,7 +183,6 @@ def build_content_with_imgs(
img_urls = img_urls or []
b64_imgs = b64_imgs or []
message_main_content = _build_content(message, files)
if exclude_images or (not img_files and not img_urls):
@@ -403,19 +430,83 @@ def _find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | N
for model_name in filtered_model_names:
model_obj = model_map.get(f"{provider}/{model_name}")
if model_obj:
logger.debug(f"Using model object for {provider}/{model_name}")
return model_obj
# Then try all model names without provider prefix
for model_name in filtered_model_names:
model_obj = model_map.get(model_name)
if model_obj:
logger.debug(f"Using model object for {model_name}")
return model_obj
return None
def get_llm_contextual_cost(
llm: LLM,
) -> float:
"""
Approximate the cost of using the given LLM for indexing with Contextual RAG.
We use a precomputed estimate for the number of tokens in the contextualizing prompts,
and we assume that every chunk is maximized in terms of content and context.
We also assume that every document is maximized in terms of content, as currently if
a document is longer than a certain length, its summary is used instead of the full content.
We expect that the first assumption will overestimate more than the second one
underestimates, so this should be a fairly conservative price estimate. Also,
this does not account for the cost of documents that fit within a single chunk
which do not get contextualized.
"""
# calculate input costs
num_tokens = ONE_MILLION
num_input_chunks = num_tokens // DOC_EMBEDDING_CONTEXT_SIZE
# We assume that the documents are MAX_TOKENS_FOR_FULL_INCLUSION tokens long
# on average.
num_docs = num_tokens // MAX_TOKENS_FOR_FULL_INCLUSION
num_input_tokens = 0
num_output_tokens = 0
if not USE_CHUNK_SUMMARY and not USE_DOCUMENT_SUMMARY:
return 0
if USE_CHUNK_SUMMARY:
# Each per-chunk prompt includes:
# - The prompt tokens
# - the document tokens
# - the chunk tokens
# for each chunk, we prompt the LLM with the contextual RAG prompt
# and the full document content (or the doc summary, so this is an overestimate)
num_input_tokens += num_input_chunks * (
CONTEXTUAL_RAG_TOKEN_ESTIMATE + MAX_TOKENS_FOR_FULL_INCLUSION
)
# in aggregate, each chunk content is used as a prompt input once
# so the full input size is covered
num_input_tokens += num_tokens
# A single MAX_CONTEXT_TOKENS worth of output is generated per chunk
num_output_tokens += num_input_chunks * MAX_CONTEXT_TOKENS
# going over each doc once means all the tokens, plus the prompt tokens for
# the summary prompt. This CAN happen even when USE_DOCUMENT_SUMMARY is false,
# since doc summaries are used for longer documents when USE_CHUNK_SUMMARY is true.
# So, we include this unconditionally to overestimate.
num_input_tokens += num_tokens + num_docs * DOCUMENT_SUMMARY_TOKEN_ESTIMATE
num_output_tokens += num_docs * MAX_CONTEXT_TOKENS
usd_per_prompt, usd_per_completion = litellm.cost_per_token(
model=llm.config.model_name,
prompt_tokens=num_input_tokens,
completion_tokens=num_output_tokens,
)
# litellm returns the total USD cost for the given token counts; since the
# workload above is sized at one million tokens, this is USD per million tokens
return usd_per_prompt + usd_per_completion
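A worked instance of the estimate above with both summary flags on. The chunk size and average document length are assumed values here; the real ones come from the app and model configs:

```python
ONE_MILLION = 1_000_000
MAX_CONTEXT_TOKENS = 100
CONTEXTUAL_RAG_TOKEN_ESTIMATE = 64
DOCUMENT_SUMMARY_TOKEN_ESTIMATE = 29
DOC_EMBEDDING_CONTEXT_SIZE = 512       # assumed chunk size
MAX_TOKENS_FOR_FULL_INCLUSION = 4096   # assumed average document length

num_tokens = ONE_MILLION
num_input_chunks = num_tokens // DOC_EMBEDDING_CONTEXT_SIZE  # 1953
num_docs = num_tokens // MAX_TOKENS_FOR_FULL_INCLUSION       # 244

# chunk-context pass: prompt + full doc per chunk, plus each chunk's content once
num_input_tokens = num_input_chunks * (
    CONTEXTUAL_RAG_TOKEN_ESTIMATE + MAX_TOKENS_FOR_FULL_INCLUSION
) + num_tokens
num_output_tokens = num_input_chunks * MAX_CONTEXT_TOKENS

# doc-summary pass: every corpus token once, plus the summary prompt per doc
num_input_tokens += num_tokens + num_docs * DOCUMENT_SUMMARY_TOKEN_ESTIMATE
num_output_tokens += num_docs * MAX_CONTEXT_TOKENS

print(num_input_tokens, num_output_tokens)  # 10131556 219700
```

These totals then go to `litellm.cost_per_token` to get a dollar figure per million tokens indexed.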
def get_llm_max_tokens(
model_map: dict,
model_name: str,
@@ -440,14 +531,10 @@ def get_llm_max_tokens(
if "max_input_tokens" in model_obj:
max_tokens = model_obj["max_input_tokens"]
logger.debug(
f"Max tokens for {model_name}: {max_tokens} (from max_input_tokens)"
)
return max_tokens
if "max_tokens" in model_obj:
max_tokens = model_obj["max_tokens"]
logger.debug(f"Max tokens for {model_name}: {max_tokens} (from max_tokens)")
return max_tokens
logger.error(f"No max tokens found for LLM: {model_name}")
@@ -469,21 +556,16 @@ def get_llm_max_output_tokens(
model_obj = model_map.get(f"{model_provider}/{model_name}")
if not model_obj:
model_obj = model_map[model_name]
logger.debug(f"Using model object for {model_name}")
else:
logger.debug(f"Using model object for {model_provider}/{model_name}")
pass
if "max_output_tokens" in model_obj:
max_output_tokens = model_obj["max_output_tokens"]
logger.info(f"Max output tokens for {model_name}: {max_output_tokens}")
return max_output_tokens
# Fallback to a fraction of max_tokens if max_output_tokens is not specified
if "max_tokens" in model_obj:
max_output_tokens = int(model_obj["max_tokens"] * 0.1)
logger.info(
f"Fallback max output tokens for {model_name}: {max_output_tokens} (10% of max_tokens)"
)
return max_output_tokens
logger.error(f"No max output tokens found for LLM: {model_name}")

View File

@@ -97,6 +97,7 @@ from onyx.server.settings.api import basic_router as settings_router
from onyx.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
from onyx.server.user_documents.api import router as user_documents_router
from onyx.server.utils import BasicAuthenticationError
from onyx.setup import setup_multitenant_onyx
from onyx.setup import setup_onyx
@@ -297,6 +298,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, input_prompt_router)
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
include_router_with_global_prefix_prepended(application, cc_pair_router)
include_router_with_global_prefix_prepended(application, user_documents_router)
include_router_with_global_prefix_prepended(application, folder_router)
include_router_with_global_prefix_prepended(application, document_set_router)
include_router_with_global_prefix_prepended(application, search_settings_router)
@@ -391,6 +393,11 @@ def get_application() -> FastAPI:
prefix="/auth",
)
if (
AUTH_TYPE == AuthType.CLOUD
or AUTH_TYPE == AuthType.BASIC
or AUTH_TYPE == AuthType.GOOGLE_OAUTH
):
# Add refresh token endpoint for OAuth as well
include_auth_router_with_prefix(
application,

View File

@@ -3,6 +3,8 @@ from abc import ABC
from abc import abstractmethod
from copy import copy
from tokenizers import Encoding # type: ignore
from tokenizers import Tokenizer # type: ignore
from transformers import logging as transformer_logging # type:ignore
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
@@ -11,6 +13,8 @@ from onyx.context.search.models import InferenceChunk
from onyx.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider
TRIM_SEP_PAT = "\n... {n} tokens removed...\n"
logger = setup_logger()
transformer_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -67,16 +71,27 @@ class TiktokenTokenizer(BaseTokenizer):
class HuggingFaceTokenizer(BaseTokenizer):
def __init__(self, model_name: str):
from tokenizers import Tokenizer # type: ignore
self.encoder: Tokenizer = Tokenizer.from_pretrained(model_name)
self.encoder = Tokenizer.from_pretrained(model_name)
def _safer_encode(self, string: str) -> Encoding:
"""
Encode a string with the HuggingFaceTokenizer; if encoding fails, strip
the string down to ASCII (dropping unencodable characters) and retry.
This helps with malformed input such as lone surrogates like \udeb4.
"""
try:
return self.encoder.encode(string, add_special_tokens=False)
except Exception:
return self.encoder.encode(
string.encode("ascii", "ignore").decode(), add_special_tokens=False
)
def encode(self, string: str) -> list[int]:
# this returns no special tokens
return self.encoder.encode(string, add_special_tokens=False).ids
return self._safer_encode(string).ids
def tokenize(self, string: str) -> list[str]:
return self.encoder.encode(string, add_special_tokens=False).tokens
return self._safer_encode(string).tokens
def decode(self, tokens: list[int]) -> str:
return self.encoder.decode(tokens)
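The `_safer_encode` fallback can be shown standalone. Here a toy whitespace "tokenizer" stands in for the HuggingFace one, with a utf-8 encode simulating the failure on lone surrogates (a sketch of the fallback path, not the real tokenizer):

```python
def safer_tokenize(text: str) -> list[str]:
    def tokenize(s: str) -> list[str]:
        s.encode("utf-8")  # raises UnicodeEncodeError on lone surrogates
        return s.split()

    try:
        return tokenize(text)
    except UnicodeEncodeError:
        # same fallback as _safer_encode: drop non-ASCII characters and retry
        return tokenize(text.encode("ascii", "ignore").decode())

print(safer_tokenize("hello \udeb4 world"))  # ['hello', 'world']
print(safer_tokenize("plain text"))          # ['plain', 'text']
```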
@@ -159,9 +174,26 @@ def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: BaseTokenizer
) -> str:
tokens = tokenizer.encode(content)
if len(tokens) > desired_length:
content = tokenizer.decode(tokens[:desired_length])
return content
if len(tokens) <= desired_length:
return content
return tokenizer.decode(tokens[:desired_length])
def tokenizer_trim_middle(
tokens: list[int], desired_length: int, tokenizer: BaseTokenizer
) -> str:
if len(tokens) <= desired_length:
return tokenizer.decode(tokens)
sep_str = TRIM_SEP_PAT.format(n=len(tokens) - desired_length)
sep_tokens = tokenizer.encode(sep_str)
slice_size = (desired_length - len(sep_tokens)) // 2
assert slice_size > 0, "Slice size is not positive, desired length is too short"
return (
tokenizer.decode(tokens[:slice_size])
+ sep_str
+ tokenizer.decode(tokens[-slice_size:])
)
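`tokenizer_trim_middle` keeps the head and tail of the token stream and splices in a marker noting how much was cut. A sketch with a toy word-level tokenizer standing in for `BaseTokenizer`:

```python
TRIM_SEP_PAT = "\n... {n} tokens removed...\n"

def trim_middle(tokens: list[str], desired_length: int) -> str:
    # toy version: "tokens" are words, decode is a space join
    if len(tokens) <= desired_length:
        return " ".join(tokens)
    sep_str = TRIM_SEP_PAT.format(n=len(tokens) - desired_length)
    sep_tokens = sep_str.split()  # the separator eats into the budget too
    slice_size = (desired_length - len(sep_tokens)) // 2
    assert slice_size > 0, "desired length too short"
    return " ".join(tokens[:slice_size]) + sep_str + " ".join(tokens[-slice_size:])

tokens = [f"w{i}" for i in range(100)]
print(trim_middle(tokens, 10))
# w0 w1 w2
# ... 90 tokens removed...
# w97 w98 w99
```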
def tokenizer_trim_chunks(

View File

@@ -594,7 +594,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client)
if event_type == "message":
is_dm = event.get("channel_type") == "im"
is_tagged = bot_tag_id and bot_tag_id in msg
is_tagged = bot_tag_id and f"<@{bot_tag_id}>" in msg
is_onyx_bot_msg = bot_tag_id and bot_tag_id in event.get("user", "")
# OnyxBot should never respond to itself
@@ -727,7 +727,11 @@ def build_request_details(
event = cast(dict[str, Any], req.payload["event"])
msg = cast(str, event["text"])
channel = cast(str, event["channel"])
tagged = event.get("type") == "app_mention"
# Check for both app_mention events and messages containing bot tag
bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client)
tagged = (event.get("type") == "app_mention") or (
event.get("type") == "message" and bot_tag_id and f"<@{bot_tag_id}>" in msg
)
message_ts = event.get("ts")
thread_ts = event.get("thread_ts")
sender_id = event.get("user") or None
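The tagging fix above replaces a bare substring check on the bot id with a check for the full Slack mention syntax `<@ID>`. A sketch with a hypothetical id shows the false positive the old check allowed:

```python
bot_tag_id = "U123"  # hypothetical Slack bot user id

msg_plain = "ticket U123 needs triage"      # contains the id, but no mention
msg_tagged = "<@U123> summarize this thread"

print(bot_tag_id in msg_plain)              # True  (old check: false positive)
print(f"<@{bot_tag_id}>" in msg_plain)      # False (new check)
print(f"<@{bot_tag_id}>" in msg_tagged)     # True
```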

View File

@@ -145,7 +145,7 @@ def update_emote_react(
def remove_onyx_bot_tag(message_str: str, client: WebClient) -> str:
bot_tag_id = get_onyx_bot_slack_bot_id(web_client=client)
return re.sub(rf"<@{bot_tag_id}>\s", "", message_str)
return re.sub(rf"<@{bot_tag_id}>\s*", "", message_str)
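The `\s` to `\s*` change matters when the mention has no trailing whitespace, e.g. the message is nothing but the tag. A sketch with a hypothetical bot id:

```python
import re

bot_tag_id = "U123"  # hypothetical Slack bot user id

def strip_tag_old(message: str) -> str:
    # old pattern: only matches when whitespace follows the mention
    return re.sub(rf"<@{bot_tag_id}>\s", "", message)

def strip_tag_new(message: str) -> str:
    # new pattern: \s* also matches a mention with nothing after it
    return re.sub(rf"<@{bot_tag_id}>\s*", "", message)

print(strip_tag_old("<@U123>"))        # '<@U123>'  (tag survives)
print(strip_tag_new("<@U123>"))        # ''
print(strip_tag_new("<@U123> hello"))  # 'hello'
```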
def _check_for_url_in_block(block: Block) -> bool:

View File

@@ -220,3 +220,29 @@ Chat History:
Based on the above, what is a short name to convey the topic of the conversation?
""".strip()
# NOTE: the prompt is split into two parts partially for efficiency; doing it
# all in one template with sequential format() calls causes a backend error
# whenever the document contains any "{}", since python expects those braces
# to be filled by format() arguments
CONTEXTUAL_RAG_PROMPT1 = """<document>
{document}
</document>
Here is the chunk we want to situate within the whole document"""
CONTEXTUAL_RAG_PROMPT2 = """<chunk>
{chunk}
</chunk>
Please give a short succinct context to situate this chunk within the overall document
for the purposes of improving search retrieval of the chunk. Answer only with the succinct
context and nothing else. """
CONTEXTUAL_RAG_TOKEN_ESTIMATE = 64 # 19 + 45
DOCUMENT_SUMMARY_PROMPT = """<document>
{document}
</document>
Please give a short succinct summary of the entire document. Answer only with the succinct
summary and nothing else. """
DOCUMENT_SUMMARY_TOKEN_ESTIMATE = 29
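A minimal repro of the `format()` pitfall that motivates splitting the prompt: a single combined template cannot be filled in two passes once the document text itself contains braces.

```python
template = "<document>\n{document}\n</document>\n<chunk>\n{chunk}\n</chunk>"
doc = "python dicts use {} literals"  # real documents often contain braces

# First pass fails outright: format() demands values for ALL fields.
try:
    template.format(document=doc)
except KeyError as e:
    print("KeyError:", e)

# Even escaping {chunk} for a second pass, the document's own "{}" is then
# parsed as a format field on that second pass.
escaped = template.replace("{chunk}", "{{chunk}}")
partial = escaped.format(document=doc)
try:
    partial.format(chunk="some chunk")
except IndexError:
    print("IndexError: the document's braces are treated as a format field")
```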

View File

@@ -195,7 +195,7 @@ class RedisConnectorPermissionSync:
),
queue=OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
task_id=custom_task_id,
priority=OnyxCeleryPriority.HIGH,
priority=OnyxCeleryPriority.MEDIUM,
ignore_result=True,
)
async_results.append(result)

Some files were not shown because too many files have changed in this diff.