mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
4 Commits
testing_li
...
fix_textse
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a475f1b328 | ||
|
|
e37bb2209e | ||
|
|
5aba739aee | ||
|
|
906b29bd9b |
209
.github/workflows/pr-mit-integration-tests.yml
vendored
209
.github/workflows/pr-mit-integration-tests.yml
vendored
@@ -1,209 +0,0 @@
|
||||
name: Run MIT Integration Tests v2
|
||||
concurrency:
|
||||
group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
|
||||
jobs:
|
||||
integration-tests-mit:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# We don't need to build the Web Docker image since it's not yet used
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# successfully.
|
||||
- name: Pull Web Docker image
|
||||
run: |
|
||||
docker pull onyxdotapp/onyx-web-server:latest
|
||||
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-integration:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
AUTH_TYPE=basic \
|
||||
POSTGRES_POOL_PRE_PING=true \
|
||||
POSTGRES_USE_NULL_POOL=true \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
INTEGRATION_TESTS_MODE=true \
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
docker logs -f onyx-stack-api_server-1 &
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in 5 minutes."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Use curl with error handling to ignore specific exit code 56
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "Service is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Start Mock Services
|
||||
run: |
|
||||
cd backend/tests/integration/mock_services
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
|
||||
- name: Run Standard Integration Tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/tests \
|
||||
/app/tests/integration/connector_job_tests
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
- name: Check test results
|
||||
run: |
|
||||
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
|
||||
echo "Integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Always gather logs BEFORE "down":
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
|
||||
- name: Dump all-container logs (optional)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
- name: Stop Docker containers
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
|
||||
775
.vscode/launch.template.jsonc
vendored
775
.vscode/launch.template.jsonc
vendored
@@ -6,419 +6,396 @@
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"compounds": [
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Compound ---",
|
||||
"configurations": ["--- Individual ---"],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Run All Onyx Services",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery user files indexing",
|
||||
"Celery beat",
|
||||
"Celery monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web / Model / API",
|
||||
"configurations": ["Web Server", "Model Server", "API Server"],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery (all)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery user files indexing",
|
||||
"Celery beat",
|
||||
"Celery monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
}
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Compound ---",
|
||||
"configurations": [
|
||||
"--- Individual ---"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Run All Onyx Services",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web / Model / API",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery (all)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Individual ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web Server",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}/web",
|
||||
"runtimeExecutable": "npm",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"runtimeArgs": ["run", "dev"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Individual ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web Server",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}/web",
|
||||
"runtimeExecutable": "npm",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"runtimeArgs": [
|
||||
"run", "dev"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"console": "integratedTerminal",
|
||||
"consoleTitle": "Web Server Console"
|
||||
},
|
||||
"console": "integratedTerminal",
|
||||
"consoleTitle": "Web Server Console"
|
||||
},
|
||||
{
|
||||
"name": "Model Server",
|
||||
"consoleName": "Model Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
{
|
||||
"name": "Model Server",
|
||||
"consoleName": "Model Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
"args": [
|
||||
"model_server.main:app",
|
||||
"--reload",
|
||||
"--port",
|
||||
"9000"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Model Server Console"
|
||||
},
|
||||
"args": ["model_server.main:app", "--reload", "--port", "9000"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
{
|
||||
"name": "API Server",
|
||||
"consoleName": "API Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
"args": [
|
||||
"onyx.main:app",
|
||||
"--reload",
|
||||
"--port",
|
||||
"8080"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "API Server Console"
|
||||
},
|
||||
"consoleTitle": "Model Server Console"
|
||||
},
|
||||
{
|
||||
"name": "API Server",
|
||||
"consoleName": "API Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
// For the listener to access the Slack API,
|
||||
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
"consoleName": "Slack Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "onyx/onyxbot/slack/listener.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
},
|
||||
"args": ["onyx.main:app", "--reload", "--port", "8080"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
{
|
||||
"name": "Celery primary",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.primary",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=primary@%n",
|
||||
"-Q",
|
||||
"celery",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery primary Console"
|
||||
},
|
||||
"consoleTitle": "API Server Console"
|
||||
},
|
||||
// For the listener to access the Slack API,
|
||||
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
"consoleName": "Slack Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "onyx/onyxbot/slack/listener.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
{
|
||||
"name": "Celery light",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.light",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=64",
|
||||
"--prefetch-multiplier=8",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery light Console"
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
{
|
||||
"name": "Celery heavy",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
},
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery primary",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
{
|
||||
"name": "Celery indexing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=indexing@%n",
|
||||
"-Q",
|
||||
"connector_indexing",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery indexing Console"
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.primary",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=primary@%n",
|
||||
"-Q",
|
||||
"celery"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=solo",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
},
|
||||
"consoleTitle": "Celery primary Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery light",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
{
|
||||
"name": "Celery beat",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.beat",
|
||||
"beat",
|
||||
"--loglevel=INFO",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery beat Console"
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.light",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=64",
|
||||
"--prefetch-multiplier=8",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
{
|
||||
"name": "Pytest",
|
||||
"consoleName": "Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-v"
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Pytest Console"
|
||||
},
|
||||
"consoleTitle": "Celery light Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery heavy",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Tasks ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "3",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"stopOnEntry": true,
|
||||
"presentation": {
|
||||
"group": "3",
|
||||
},
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
{
|
||||
// Celery jobs launched through a single background script (legacy)
|
||||
// Recommend using the "Celery (all)" compound launch instead.
|
||||
"name": "Background Jobs",
|
||||
"consoleName": "Background Jobs",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
},
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery indexing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
{
|
||||
"name": "Install Python Requirements",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"-c",
|
||||
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=indexing@%n",
|
||||
"-Q",
|
||||
"connector_indexing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery indexing Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=solo",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery beat",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.beat",
|
||||
"beat",
|
||||
"--loglevel=INFO"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery beat Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery user files indexing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_files_indexing@%n",
|
||||
"-Q",
|
||||
"user_files_indexing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery user files indexing Console"
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"consoleName": "Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-v"
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Pytest Console"
|
||||
},
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Tasks ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "3",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"${workspaceFolder}/backend/scripts/restart_containers.sh"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"stopOnEntry": true,
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
// Celery jobs launched through a single background script (legacy)
|
||||
// Recommend using the "Celery (all)" compound launch instead.
|
||||
"name": "Background Jobs",
|
||||
"consoleName": "Background Jobs",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Install Python Requirements",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"-c",
|
||||
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Debug React Web App in Chrome",
|
||||
"type": "chrome",
|
||||
"request": "launch",
|
||||
"url": "http://localhost:3000",
|
||||
"webRoot": "${workspaceFolder}/web"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -46,7 +46,6 @@ WORKDIR /app
|
||||
|
||||
# Utils used by model server
|
||||
COPY ./onyx/utils/logger.py /app/onyx/utils/logger.py
|
||||
COPY ./onyx/utils/middleware.py /app/onyx/utils/middleware.py
|
||||
|
||||
# Place to fetch version information
|
||||
COPY ./onyx/__init__.py /app/onyx/__init__.py
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
"""update prompt length
|
||||
|
||||
Revision ID: 4794bc13e484
|
||||
Revises: f7505c5b0284
|
||||
Create Date: 2025-04-02 11:26:36.180328
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4794bc13e484"
|
||||
down_revision = "f7505c5b0284"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"prompt",
|
||||
"system_prompt",
|
||||
existing_type=sa.TEXT(),
|
||||
type_=sa.String(length=5000000),
|
||||
existing_nullable=False,
|
||||
)
|
||||
op.alter_column(
|
||||
"prompt",
|
||||
"task_prompt",
|
||||
existing_type=sa.TEXT(),
|
||||
type_=sa.String(length=5000000),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"prompt",
|
||||
"system_prompt",
|
||||
existing_type=sa.String(length=5000000),
|
||||
type_=sa.TEXT(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
op.alter_column(
|
||||
"prompt",
|
||||
"task_prompt",
|
||||
existing_type=sa.String(length=5000000),
|
||||
type_=sa.TEXT(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
@@ -1,117 +0,0 @@
|
||||
"""duplicated no-harm user file migration
|
||||
|
||||
Revision ID: 6a804aeb4830
|
||||
Revises: 8e1ac4f39a9f
|
||||
Create Date: 2025-04-01 07:26:10.539362
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import inspect
|
||||
import datetime
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6a804aeb4830"
|
||||
down_revision = "8e1ac4f39a9f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Check if user_file table already exists
|
||||
conn = op.get_bind()
|
||||
inspector = inspect(conn)
|
||||
|
||||
if not inspector.has_table("user_file"):
|
||||
# Create user_folder table without parent_id
|
||||
op.create_table(
|
||||
"user_folder",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
|
||||
sa.Column("name", sa.String(length=255), nullable=True),
|
||||
sa.Column("description", sa.String(length=255), nullable=True),
|
||||
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(timezone=True), server_default=sa.func.now()
|
||||
),
|
||||
)
|
||||
|
||||
# Create user_file table with folder_id instead of parent_folder_id
|
||||
op.create_table(
|
||||
"user_file",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
|
||||
sa.Column(
|
||||
"folder_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_folder.id"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("link_url", sa.String(), nullable=True),
|
||||
sa.Column("token_count", sa.Integer(), nullable=True),
|
||||
sa.Column("file_type", sa.String(), nullable=True),
|
||||
sa.Column("file_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("document_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(),
|
||||
default=datetime.datetime.utcnow,
|
||||
),
|
||||
sa.Column(
|
||||
"cc_pair_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("connector_credential_pair.id"),
|
||||
nullable=True,
|
||||
unique=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create persona__user_file table
|
||||
op.create_table(
|
||||
"persona__user_file",
|
||||
sa.Column(
|
||||
"persona_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("persona.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column(
|
||||
"user_file_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_file.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create persona__user_folder table
|
||||
op.create_table(
|
||||
"persona__user_folder",
|
||||
sa.Column(
|
||||
"persona_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("persona.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column(
|
||||
"user_folder_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_folder.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("is_user_file", sa.Boolean(), nullable=True, default=False),
|
||||
)
|
||||
|
||||
# Update existing records to have is_user_file=False instead of NULL
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET is_user_file = FALSE WHERE is_user_file IS NULL"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -1,50 +0,0 @@
|
||||
"""enable contextual retrieval
|
||||
|
||||
Revision ID: 8e1ac4f39a9f
|
||||
Revises: 9aadf32dfeb4
|
||||
Create Date: 2024-12-20 13:29:09.918661
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8e1ac4f39a9f"
|
||||
down_revision = "9aadf32dfeb4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"enable_contextual_rag",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"contextual_rag_llm_name",
|
||||
sa.String(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"contextual_rag_llm_provider",
|
||||
sa.String(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("search_settings", "enable_contextual_rag")
|
||||
op.drop_column("search_settings", "contextual_rag_llm_name")
|
||||
op.drop_column("search_settings", "contextual_rag_llm_provider")
|
||||
@@ -1,113 +0,0 @@
|
||||
"""add user files
|
||||
|
||||
Revision ID: 9aadf32dfeb4
|
||||
Revises: 3781a5eb12cb
|
||||
Create Date: 2025-01-26 16:08:21.551022
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import datetime
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9aadf32dfeb4"
|
||||
down_revision = "3781a5eb12cb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create user_folder table without parent_id
|
||||
op.create_table(
|
||||
"user_folder",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
|
||||
sa.Column("name", sa.String(length=255), nullable=True),
|
||||
sa.Column("description", sa.String(length=255), nullable=True),
|
||||
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(timezone=True), server_default=sa.func.now()
|
||||
),
|
||||
)
|
||||
|
||||
# Create user_file table with folder_id instead of parent_folder_id
|
||||
op.create_table(
|
||||
"user_file",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
|
||||
sa.Column(
|
||||
"folder_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_folder.id"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("link_url", sa.String(), nullable=True),
|
||||
sa.Column("token_count", sa.Integer(), nullable=True),
|
||||
sa.Column("file_type", sa.String(), nullable=True),
|
||||
sa.Column("file_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("document_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(),
|
||||
default=datetime.datetime.utcnow,
|
||||
),
|
||||
sa.Column(
|
||||
"cc_pair_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("connector_credential_pair.id"),
|
||||
nullable=True,
|
||||
unique=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create persona__user_file table
|
||||
op.create_table(
|
||||
"persona__user_file",
|
||||
sa.Column(
|
||||
"persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True
|
||||
),
|
||||
sa.Column(
|
||||
"user_file_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_file.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create persona__user_folder table
|
||||
op.create_table(
|
||||
"persona__user_folder",
|
||||
sa.Column(
|
||||
"persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True
|
||||
),
|
||||
sa.Column(
|
||||
"user_folder_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_folder.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("is_user_file", sa.Boolean(), nullable=True, default=False),
|
||||
)
|
||||
|
||||
# Update existing records to have is_user_file=False instead of NULL
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET is_user_file = FALSE WHERE is_user_file IS NULL"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the persona__user_folder table
|
||||
op.drop_table("persona__user_folder")
|
||||
# Drop the persona__user_file table
|
||||
op.drop_table("persona__user_file")
|
||||
# Drop the user_file table
|
||||
op.drop_table("user_file")
|
||||
# Drop the user_folder table
|
||||
op.drop_table("user_folder")
|
||||
op.drop_column("connector_credential_pair", "is_user_file")
|
||||
@@ -1,50 +0,0 @@
|
||||
"""add prompt length limit
|
||||
|
||||
Revision ID: f71470ba9274
|
||||
Revises: 6a804aeb4830
|
||||
Create Date: 2025-04-01 15:07:14.977435
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f71470ba9274"
|
||||
down_revision = "6a804aeb4830"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# op.alter_column(
|
||||
# "prompt",
|
||||
# "system_prompt",
|
||||
# existing_type=sa.TEXT(),
|
||||
# type_=sa.String(length=8000),
|
||||
# existing_nullable=False,
|
||||
# )
|
||||
# op.alter_column(
|
||||
# "prompt",
|
||||
# "task_prompt",
|
||||
# existing_type=sa.TEXT(),
|
||||
# type_=sa.String(length=8000),
|
||||
# existing_nullable=False,
|
||||
# )
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# op.alter_column(
|
||||
# "prompt",
|
||||
# "system_prompt",
|
||||
# existing_type=sa.String(length=8000),
|
||||
# type_=sa.TEXT(),
|
||||
# existing_nullable=False,
|
||||
# )
|
||||
# op.alter_column(
|
||||
# "prompt",
|
||||
# "task_prompt",
|
||||
# existing_type=sa.String(length=8000),
|
||||
# type_=sa.TEXT(),
|
||||
# existing_nullable=False,
|
||||
# )
|
||||
pass
|
||||
@@ -1,77 +0,0 @@
|
||||
"""updated constraints for ccpairs
|
||||
|
||||
Revision ID: f7505c5b0284
|
||||
Revises: f71470ba9274
|
||||
Create Date: 2025-04-01 17:50:42.504818
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f7505c5b0284"
|
||||
down_revision = "f71470ba9274"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1) Drop the old foreign-key constraints
|
||||
op.drop_constraint(
|
||||
"document_by_connector_credential_pair_connector_id_fkey",
|
||||
"document_by_connector_credential_pair",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"document_by_connector_credential_pair_credential_id_fkey",
|
||||
"document_by_connector_credential_pair",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
# 2) Re-add them with ondelete='CASCADE'
|
||||
op.create_foreign_key(
|
||||
"document_by_connector_credential_pair_connector_id_fkey",
|
||||
source_table="document_by_connector_credential_pair",
|
||||
referent_table="connector",
|
||||
local_cols=["connector_id"],
|
||||
remote_cols=["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"document_by_connector_credential_pair_credential_id_fkey",
|
||||
source_table="document_by_connector_credential_pair",
|
||||
referent_table="credential",
|
||||
local_cols=["credential_id"],
|
||||
remote_cols=["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Reverse the changes for rollback
|
||||
op.drop_constraint(
|
||||
"document_by_connector_credential_pair_connector_id_fkey",
|
||||
"document_by_connector_credential_pair",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"document_by_connector_credential_pair_credential_id_fkey",
|
||||
"document_by_connector_credential_pair",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
# Recreate without CASCADE
|
||||
op.create_foreign_key(
|
||||
"document_by_connector_credential_pair_connector_id_fkey",
|
||||
"document_by_connector_credential_pair",
|
||||
"connector",
|
||||
["connector_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"document_by_connector_credential_pair_credential_id_fkey",
|
||||
"document_by_connector_credential_pair",
|
||||
"credential",
|
||||
["credential_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -159,9 +159,6 @@ def _get_space_permissions(
|
||||
|
||||
# Stores the permissions for each space
|
||||
space_permissions_by_space_key[space_key] = space_permissions
|
||||
logger.info(
|
||||
f"Found space permissions for space '{space_key}': {space_permissions}"
|
||||
)
|
||||
|
||||
return space_permissions_by_space_key
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ def _post_query_chunk_censoring(
|
||||
# if user is None, permissions are not enforced
|
||||
return chunks
|
||||
|
||||
final_chunk_dict: dict[str, InferenceChunk] = {}
|
||||
chunks_to_keep = []
|
||||
chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}
|
||||
|
||||
sources_to_censor = _get_all_censoring_enabled_sources()
|
||||
@@ -64,7 +64,7 @@ def _post_query_chunk_censoring(
|
||||
if chunk.source_type in sources_to_censor:
|
||||
chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
|
||||
else:
|
||||
final_chunk_dict[chunk.unique_id] = chunk
|
||||
chunks_to_keep.append(chunk)
|
||||
|
||||
# For each source, filter out the chunks using the permission
|
||||
# check function for that source
|
||||
@@ -79,16 +79,6 @@ def _post_query_chunk_censoring(
|
||||
f" chunks for this source and continuing: {e}"
|
||||
)
|
||||
continue
|
||||
chunks_to_keep.extend(censored_chunks)
|
||||
|
||||
for censored_chunk in censored_chunks:
|
||||
final_chunk_dict[censored_chunk.unique_id] = censored_chunk
|
||||
|
||||
# IMPORTANT: make sure to retain the same ordering as the original `chunks` passed in
|
||||
final_chunk_list: list[InferenceChunk] = []
|
||||
for chunk in chunks:
|
||||
# only if the chunk is in the final censored chunks, add it to the final list
|
||||
# if it is missing, that means it was intentionally left out
|
||||
if chunk.unique_id in final_chunk_dict:
|
||||
final_chunk_list.append(final_chunk_dict[chunk.unique_id])
|
||||
|
||||
return final_chunk_list
|
||||
return chunks_to_keep
|
||||
|
||||
@@ -58,7 +58,6 @@ def _get_objects_access_for_user_email_from_salesforce(
|
||||
f"Time taken to get Salesforce user ID: {end_time - start_time} seconds"
|
||||
)
|
||||
if user_id is None:
|
||||
logger.warning(f"User '{user_email}' not found in Salesforce")
|
||||
return None
|
||||
|
||||
# This is the only query that is not cached in the function
|
||||
@@ -66,7 +65,6 @@ def _get_objects_access_for_user_email_from_salesforce(
|
||||
object_id_to_access = get_objects_access_for_user_id(
|
||||
salesforce_client, user_id, list(object_ids)
|
||||
)
|
||||
logger.debug(f"Object ID to access: {object_id_to_access}")
|
||||
return object_id_to_access
|
||||
|
||||
|
||||
|
||||
@@ -42,18 +42,11 @@ def get_any_salesforce_client_for_doc_id(
|
||||
|
||||
|
||||
def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None:
|
||||
query = f"SELECT Id FROM User WHERE Username = '{user_email}' AND IsActive = true"
|
||||
query = f"SELECT Id FROM User WHERE Email = '{user_email}'"
|
||||
result = sf_client.query(query)
|
||||
if len(result["records"]) > 0:
|
||||
return result["records"][0]["Id"]
|
||||
|
||||
# try emails
|
||||
query = f"SELECT Id FROM User WHERE Email = '{user_email}' AND IsActive = true"
|
||||
result = sf_client.query(query)
|
||||
if len(result["records"]) > 0:
|
||||
return result["records"][0]["Id"]
|
||||
|
||||
return None
|
||||
if len(result["records"]) == 0:
|
||||
return None
|
||||
return result["records"][0]["Id"]
|
||||
|
||||
|
||||
# This contains only the user_ids that we have found in Salesforce.
|
||||
|
||||
@@ -44,7 +44,7 @@ async def _get_tenant_id_from_request(
|
||||
Attempt to extract tenant_id from:
|
||||
1) The API key header
|
||||
2) The Redis-based token (stored in Cookie: fastapiusersauth)
|
||||
3) The anonymous user cookie
|
||||
3) Reset token cookie
|
||||
Fallback: POSTGRES_DEFAULT_SCHEMA
|
||||
"""
|
||||
# Check for API key
|
||||
@@ -52,55 +52,41 @@ async def _get_tenant_id_from_request(
|
||||
if tenant_id is not None:
|
||||
return tenant_id
|
||||
|
||||
# Check for anonymous user cookie
|
||||
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
|
||||
if anonymous_user_cookie:
|
||||
try:
|
||||
anonymous_user_data = decode_anonymous_user_jwt_token(anonymous_user_cookie)
|
||||
return anonymous_user_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
except Exception as e:
|
||||
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
|
||||
# Continue and attempt to authenticate
|
||||
|
||||
try:
|
||||
# Look up token data in Redis
|
||||
|
||||
token_data = await retrieve_auth_token_data_from_redis(request)
|
||||
|
||||
if token_data:
|
||||
tenant_id_from_payload = token_data.get(
|
||||
"tenant_id", POSTGRES_DEFAULT_SCHEMA
|
||||
if not token_data:
|
||||
logger.debug(
|
||||
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
|
||||
)
|
||||
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
|
||||
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
|
||||
# so we maintain consistency by returning it here when no valid tenant is found.
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
tenant_id = (
|
||||
str(tenant_id_from_payload)
|
||||
if tenant_id_from_payload is not None
|
||||
else None
|
||||
)
|
||||
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
if tenant_id and not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||
|
||||
# Check for anonymous user cookie
|
||||
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
|
||||
if anonymous_user_cookie:
|
||||
try:
|
||||
anonymous_user_data = decode_anonymous_user_jwt_token(
|
||||
anonymous_user_cookie
|
||||
)
|
||||
tenant_id = anonymous_user_data.get(
|
||||
"tenant_id", POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
|
||||
if not tenant_id or not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid tenant ID format"
|
||||
)
|
||||
|
||||
return tenant_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
|
||||
# Continue and attempt to authenticate
|
||||
|
||||
logger.debug(
|
||||
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
|
||||
# Since token_data.get() can return None, ensure we have a string
|
||||
tenant_id = (
|
||||
str(tenant_id_from_payload)
|
||||
if tenant_id_from_payload is not None
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
|
||||
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
|
||||
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
|
||||
# so we maintain consistency by returning it here when no valid tenant is found.
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
|
||||
|
||||
@@ -36,6 +36,9 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
router = APIRouter(prefix="/auth/saml")
|
||||
|
||||
# Define non-authenticated user roles that should be re-created during SAML login
|
||||
NON_AUTHENTICATED_ROLES = {UserRole.SLACK_USER, UserRole.EXT_PERM_USER}
|
||||
|
||||
|
||||
async def upsert_saml_user(email: str) -> User:
|
||||
logger.debug(f"Attempting to upsert SAML user with email: {email}")
|
||||
@@ -51,7 +54,7 @@ async def upsert_saml_user(email: str) -> User:
|
||||
try:
|
||||
user = await user_manager.get_by_email(email)
|
||||
# If user has a non-authenticated role, treat as non-existent
|
||||
if not user.role.is_web_login():
|
||||
if user.role in NON_AUTHENTICATED_ROLES:
|
||||
raise exceptions.UserNotExists()
|
||||
return user
|
||||
except exceptions.UserNotExists:
|
||||
|
||||
@@ -94,7 +94,6 @@ async def get_or_provision_tenant(
|
||||
# Notify control plane if we have created / assigned a new tenant
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
return tenant_id
|
||||
|
||||
except Exception as e:
|
||||
@@ -506,11 +505,8 @@ async def setup_tenant(tenant_id: str) -> None:
|
||||
try:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Run Alembic migrations in a way that isolates it from the current event loop
|
||||
# Create a new event loop for this synchronous operation
|
||||
loop = asyncio.get_event_loop()
|
||||
# Use run_in_executor which properly isolates the thread execution
|
||||
await loop.run_in_executor(None, lambda: run_alembic_migrations(tenant_id))
|
||||
# Run Alembic migrations
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
|
||||
# Configure the tenant with default settings
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
|
||||
Binary file not shown.
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -9,7 +8,6 @@ import sentry_sdk
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
||||
from sentry_sdk.integrations.starlette import StarletteIntegration
|
||||
from transformers import logging as transformer_logging # type:ignore
|
||||
@@ -22,8 +20,6 @@ from model_server.management_endpoints import router as management_router
|
||||
from model_server.utils import get_gpu_type
|
||||
from onyx import __version__
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.logger import setup_uvicorn_logger
|
||||
from onyx.utils.middleware import add_onyx_request_id_middleware
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.configs import MIN_THREADS_ML_MODELS
|
||||
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
|
||||
@@ -40,12 +36,6 @@ transformer_logging.set_verbosity_error()
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
file_handlers = [
|
||||
h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)
|
||||
]
|
||||
|
||||
setup_uvicorn_logger(shared_file_handlers=file_handlers)
|
||||
|
||||
|
||||
def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -> None:
|
||||
"""
|
||||
@@ -122,15 +112,6 @@ def get_model_app() -> FastAPI:
|
||||
application.include_router(encoders_router)
|
||||
application.include_router(custom_models_router)
|
||||
|
||||
request_id_prefix = "INF"
|
||||
if INDEXING_ONLY:
|
||||
request_id_prefix = "IDX"
|
||||
|
||||
add_onyx_request_id_middleware(application, request_id_prefix, logger)
|
||||
|
||||
# Initialize and instrument the app
|
||||
Instrumentator().instrument(application).expose(application)
|
||||
|
||||
return application
|
||||
|
||||
|
||||
|
||||
@@ -57,9 +57,8 @@ def _get_access_for_documents(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
doc_access = {}
|
||||
for document_id, user_emails, is_public in document_access_info:
|
||||
doc_access[document_id] = DocumentAccess.build(
|
||||
doc_access = {
|
||||
document_id: DocumentAccess.build(
|
||||
user_emails=[email for email in user_emails if email],
|
||||
# MIT version will wipe all groups and external groups on update
|
||||
user_groups=[],
|
||||
@@ -67,6 +66,8 @@ def _get_access_for_documents(
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
for document_id, user_emails, is_public in document_access_info
|
||||
}
|
||||
|
||||
# Sometimes the document has not been indexed by the indexing job yet, in those cases
|
||||
# the document does not exist and so we use least permissive. Specifically the EE version
|
||||
|
||||
@@ -15,22 +15,6 @@ class ExternalAccess:
|
||||
# Whether the document is public in the external system or Onyx
|
||||
is_public: bool
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Prevent extremely long logs"""
|
||||
|
||||
def truncate_set(s: set[str], max_len: int = 100) -> str:
|
||||
s_str = str(s)
|
||||
if len(s_str) > max_len:
|
||||
return f"{s_str[:max_len]}... ({len(s)} items)"
|
||||
return s_str
|
||||
|
||||
return (
|
||||
f"ExternalAccess("
|
||||
f"external_user_emails={truncate_set(self.external_user_emails)}, "
|
||||
f"external_user_group_ids={truncate_set(self.external_user_group_ids)}, "
|
||||
f"is_public={self.is_public})"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocExternalAccess:
|
||||
|
||||
@@ -321,10 +321,8 @@ def dispatch_separated(
|
||||
sep: str = DISPATCH_SEP_CHAR,
|
||||
) -> list[BaseMessage_Content]:
|
||||
num = 1
|
||||
accumulated_tokens = ""
|
||||
streamed_tokens: list[BaseMessage_Content] = []
|
||||
for token in tokens:
|
||||
accumulated_tokens += cast(str, token.content)
|
||||
content = cast(str, token.content)
|
||||
if sep in content:
|
||||
sub_question_parts = content.split(sep)
|
||||
|
||||
@@ -23,7 +23,6 @@ from onyx.utils.url import add_url_params
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
HTML_EMAIL_TEMPLATE = """\
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
@@ -56,7 +56,6 @@ from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.email_utils import send_forgot_password_email
|
||||
from onyx.auth.email_utils import send_user_verification_email
|
||||
@@ -361,6 +360,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reason="Password must contain at least one special character from the following set: "
|
||||
f"{PASSWORD_SPECIAL_CHARS}."
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
async def oauth_callback(
|
||||
@@ -514,25 +514,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
return user
|
||||
|
||||
async def on_after_login(
|
||||
self,
|
||||
user: User,
|
||||
request: Optional[Request] = None,
|
||||
response: Optional[Response] = None,
|
||||
) -> None:
|
||||
try:
|
||||
if response and request and ANONYMOUS_USER_COOKIE_NAME in request.cookies:
|
||||
response.delete_cookie(
|
||||
ANONYMOUS_USER_COOKIE_NAME,
|
||||
# Ensure cookie deletion doesn't override other cookies by setting the same path/domain
|
||||
path="/",
|
||||
domain=None,
|
||||
secure=WEB_DOMAIN.startswith("https"),
|
||||
)
|
||||
logger.debug(f"Deleted anonymous user cookie for user {user.email}")
|
||||
except Exception:
|
||||
logger.exception("Error deleting anonymous user cookie")
|
||||
|
||||
async def on_after_register(
|
||||
self, user: User, request: Optional[Request] = None
|
||||
) -> None:
|
||||
@@ -1322,7 +1303,6 @@ def get_oauth_router(
|
||||
# Login user
|
||||
response = await backend.login(strategy, user)
|
||||
await user_manager.on_after_login(user, request, response)
|
||||
|
||||
# Prepare redirect response
|
||||
if tenant_id is None:
|
||||
# Use URL utility to add parameters
|
||||
@@ -1332,14 +1312,9 @@ def get_oauth_router(
|
||||
# No parameters to add
|
||||
redirect_response = RedirectResponse(next_url, status_code=302)
|
||||
|
||||
# Copy headers from auth response to redirect response, with special handling for Set-Cookie
|
||||
# Copy headers and other attributes from 'response' to 'redirect_response'
|
||||
for header_name, header_value in response.headers.items():
|
||||
# FastAPI can have multiple Set-Cookie headers as a list
|
||||
if header_name.lower() == "set-cookie" and isinstance(header_value, list):
|
||||
for cookie_value in header_value:
|
||||
redirect_response.headers.append(header_name, cookie_value)
|
||||
else:
|
||||
redirect_response.headers[header_name] = header_value
|
||||
redirect_response.headers[header_name] = header_value
|
||||
|
||||
if hasattr(response, "body"):
|
||||
redirect_response.body = response.body
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -306,7 +305,7 @@ def wait_for_db(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
|
||||
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info(f"Running as a secondary celery worker: pid={os.getpid()}")
|
||||
logger.info("Running as a secondary celery worker.")
|
||||
|
||||
# Set up variables for waiting on primary worker
|
||||
WAIT_INTERVAL = 5
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
from celery import Celery
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.client")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
@@ -111,7 +111,6 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.user_file_folder_sync",
|
||||
"onyx.background.celery.tasks.indexing",
|
||||
"onyx.background.celery.tasks.tenant_provisioning",
|
||||
]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -96,7 +95,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
logger.info(f"Running as the primary celery worker: pid={os.getpid()}")
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
@@ -175,9 +174,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
f"search_settings={attempt.search_settings_id}"
|
||||
)
|
||||
logger.warning(failure_reason)
|
||||
logger.exception(
|
||||
f"Marking attempt {attempt.id} as canceled due to validation error 2"
|
||||
)
|
||||
mark_attempt_canceled(attempt.id, db_session, failure_reason)
|
||||
|
||||
|
||||
@@ -289,6 +285,5 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
"onyx.background.celery.tasks.user_file_folder_sync",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
@@ -64,15 +64,6 @@ beat_task_templates.extend(
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-user-file-folder-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_USER_FILE_FOLDER_SYNC,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-pruning",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
|
||||
@@ -886,8 +886,11 @@ def monitor_ccpair_permissions_taskset(
|
||||
record_type=RecordType.PERMISSION_SYNC_PROGRESS,
|
||||
data={
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"total_docs_synced": initial if initial is not None else 0,
|
||||
"remaining_docs_to_sync": remaining,
|
||||
"id": payload.id if payload else None,
|
||||
"total_docs": initial if initial is not None else 0,
|
||||
"remaining_docs": remaining,
|
||||
"synced_docs": (initial - remaining) if initial is not None else 0,
|
||||
"is_complete": remaining == 0,
|
||||
},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
@@ -903,13 +906,6 @@ def monitor_ccpair_permissions_taskset(
|
||||
f"num_synced={initial}"
|
||||
)
|
||||
|
||||
# Add telemetry for permission syncing complete
|
||||
optional_telemetry(
|
||||
record_type=RecordType.PERMISSION_SYNC_COMPLETE,
|
||||
data={"cc_pair_id": cc_pair_id},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
|
||||
@@ -365,7 +365,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
Occcasionally does some validation of existing state to clear up error conditions"""
|
||||
|
||||
time_start = time.monotonic()
|
||||
task_logger.warning("check_for_indexing - Starting")
|
||||
|
||||
tasks_created = 0
|
||||
locked = False
|
||||
@@ -434,9 +433,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pairs = fetch_connector_credential_pairs(
|
||||
db_session, include_user_files=True
|
||||
)
|
||||
cc_pairs = fetch_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
@@ -455,18 +452,12 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
not search_settings_instance.status.is_current()
|
||||
and not search_settings_instance.background_reindex_enabled
|
||||
):
|
||||
task_logger.warning("SKIPPING DUE TO NON-LIVE SEARCH SETTINGS")
|
||||
|
||||
continue
|
||||
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings_instance.id
|
||||
)
|
||||
if redis_connector_index.fenced:
|
||||
task_logger.info(
|
||||
f"check_for_indexing - Skipping fenced connector: "
|
||||
f"cc_pair={cc_pair_id} search_settings={search_settings_instance.id}"
|
||||
)
|
||||
continue
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
@@ -474,9 +465,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"check_for_indexing - CC pair not found: cc_pair={cc_pair_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
@@ -490,20 +478,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
secondary_index_building=len(search_settings_list) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
task_logger.info(
|
||||
f"check_for_indexing - Not indexing cc_pair_id: {cc_pair_id} "
|
||||
f"search_settings={search_settings_instance.id}, "
|
||||
f"last_attempt={last_attempt.id if last_attempt else None}, "
|
||||
f"secondary_index_building={len(search_settings_list) > 1}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
task_logger.info(
|
||||
f"check_for_indexing - Will index cc_pair_id: {cc_pair_id} "
|
||||
f"search_settings={search_settings_instance.id}, "
|
||||
f"last_attempt={last_attempt.id if last_attempt else None}, "
|
||||
f"secondary_index_building={len(search_settings_list) > 1}"
|
||||
)
|
||||
|
||||
reindex = False
|
||||
if search_settings_instance.status.is_current():
|
||||
@@ -542,12 +517,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
f"search_settings={search_settings_instance.id}"
|
||||
)
|
||||
tasks_created += 1
|
||||
else:
|
||||
task_logger.info(
|
||||
f"Failed to create indexing task: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings_instance.id}"
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
@@ -1180,9 +1149,6 @@ def connector_indexing_proxy_task(
|
||||
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
logger.exception(
|
||||
f"Marking attempt {index_attempt_id} as canceled due to termination signal"
|
||||
)
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
|
||||
@@ -371,7 +371,6 @@ def should_index(
|
||||
|
||||
# don't kick off indexing for `NOT_APPLICABLE` sources
|
||||
if connector.source == DocumentSource.NOT_APPLICABLE:
|
||||
print(f"Not indexing cc_pair={cc_pair.id}: NOT_APPLICABLE source")
|
||||
return False
|
||||
|
||||
# User can still manually create single indexing attempts via the UI for the
|
||||
@@ -381,9 +380,6 @@ def should_index(
|
||||
search_settings_instance.status == IndexModelStatus.PRESENT
|
||||
and secondary_index_building
|
||||
):
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: DISABLE_INDEX_UPDATE_ON_SWAP is True and secondary index building"
|
||||
)
|
||||
return False
|
||||
|
||||
# When switching over models, always index at least once
|
||||
@@ -392,31 +388,19 @@ def should_index(
|
||||
# No new index if the last index attempt succeeded
|
||||
# Once is enough. The model will never be able to swap otherwise.
|
||||
if last_index.status == IndexingStatus.SUCCESS:
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with successful last index attempt={last_index.id}"
|
||||
)
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is waiting to start
|
||||
if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with NOT_STARTED last index attempt={last_index.id}"
|
||||
)
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is running
|
||||
if last_index.status == IndexingStatus.IN_PROGRESS:
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with IN_PROGRESS last index attempt={last_index.id}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
if (
|
||||
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
|
||||
): # Ingestion API
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with Ingestion API source"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -428,9 +412,6 @@ def should_index(
|
||||
or connector.id == 0
|
||||
or connector.source == DocumentSource.INGESTION_API
|
||||
):
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: Connector is paused or is Ingestion API"
|
||||
)
|
||||
return False
|
||||
|
||||
if search_settings_instance.status.is_current():
|
||||
@@ -443,16 +424,11 @@ def should_index(
|
||||
return True
|
||||
|
||||
if connector.refresh_freq is None:
|
||||
print(f"Not indexing cc_pair={cc_pair.id}: refresh_freq is None")
|
||||
return False
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
time_since_index = current_db_time - last_index.time_updated
|
||||
if time_since_index.total_seconds() < connector.refresh_freq:
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: Last index attempt={last_index.id} "
|
||||
f"too recent ({time_since_index.total_seconds()}s < {connector.refresh_freq}s)"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -532,13 +508,6 @@ def try_creating_indexing_task(
|
||||
|
||||
custom_task_id = redis_connector_index.generate_generator_task_id()
|
||||
|
||||
# Determine which queue to use based on whether this is a user file
|
||||
queue = (
|
||||
OnyxCeleryQueues.USER_FILES_INDEXING
|
||||
if cc_pair.is_user_file
|
||||
else OnyxCeleryQueues.CONNECTOR_INDEXING
|
||||
)
|
||||
|
||||
# when the task is sent, we have yet to finish setting up the fence
|
||||
# therefore, the task must contain code that blocks until the fence is ready
|
||||
result = celery_app.send_task(
|
||||
@@ -549,7 +518,7 @@ def try_creating_indexing_task(
|
||||
search_settings_id=search_settings.id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=queue,
|
||||
queue=OnyxCeleryQueues.CONNECTOR_INDEXING,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@ from tenacity import wait_random_exponential
|
||||
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
|
||||
|
||||
class RetryDocumentIndex:
|
||||
@@ -53,13 +52,11 @@ class RetryDocumentIndex:
|
||||
*,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
fields: VespaDocumentFields | None,
|
||||
user_fields: VespaDocumentUserFields | None,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
return self.index.update_single(
|
||||
doc_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
fields=fields,
|
||||
user_fields=user_fields,
|
||||
)
|
||||
|
||||
@@ -164,7 +164,6 @@ def document_by_cc_pair_cleanup_task(
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=fields,
|
||||
user_fields=None,
|
||||
)
|
||||
|
||||
# there are still other cc_pair references to the doc, so just resync to Vespa
|
||||
|
||||
@@ -1,266 +0,0 @@
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
from tenacity import RetryError
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_credential_pairs_with_user_files,
|
||||
)
|
||||
from onyx.db.document import get_document
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.user_documents import fetch_user_files_for_documents
|
||||
from onyx.db.user_documents import fetch_user_folders_for_documents
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_FOLDER_SYNC,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_user_file_folder_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
"""Runs periodically to check for documents that need user file folder metadata updates.
|
||||
This task fetches all connector credential pairs with user files, gets the documents
|
||||
associated with them, and updates the user file and folder metadata in Vespa.
|
||||
"""
|
||||
|
||||
time_start = time.monotonic()
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_USER_FILE_FOLDER_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Get all connector credential pairs that have user files
|
||||
cc_pairs = get_connector_credential_pairs_with_user_files(db_session)
|
||||
|
||||
if not cc_pairs:
|
||||
task_logger.info("No connector credential pairs with user files found")
|
||||
return True
|
||||
|
||||
# Get all documents associated with these cc_pairs
|
||||
document_ids = get_documents_for_cc_pairs(cc_pairs, db_session)
|
||||
|
||||
if not document_ids:
|
||||
task_logger.info(
|
||||
"No documents found for connector credential pairs with user files"
|
||||
)
|
||||
return True
|
||||
|
||||
# Fetch current user file and folder IDs for these documents
|
||||
doc_id_to_user_file_id = fetch_user_files_for_documents(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
)
|
||||
doc_id_to_user_folder_id = fetch_user_folders_for_documents(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
)
|
||||
|
||||
# Update Vespa metadata for each document
|
||||
for doc_id in document_ids:
|
||||
user_file_id = doc_id_to_user_file_id.get(doc_id)
|
||||
user_folder_id = doc_id_to_user_folder_id.get(doc_id)
|
||||
|
||||
if user_file_id is not None or user_folder_id is not None:
|
||||
# Schedule a task to update the document metadata
|
||||
update_user_file_folder_metadata.apply_async(
|
||||
args=(doc_id,), # Use tuple instead of list for args
|
||||
kwargs={
|
||||
"tenant_id": tenant_id,
|
||||
"user_file_id": user_file_id,
|
||||
"user_folder_id": user_folder_id,
|
||||
},
|
||||
queue="vespa_metadata_sync",
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Scheduled metadata updates for {len(document_ids)} documents. "
|
||||
f"Elapsed time: {time.monotonic() - time_start:.2f}s"
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
task_logger.exception(f"Error in check_for_user_file_folder_sync: {e}")
|
||||
return False
|
||||
finally:
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def get_documents_for_cc_pairs(
|
||||
cc_pairs: List[ConnectorCredentialPair], db_session: Session
|
||||
) -> List[str]:
|
||||
"""Get all document IDs associated with the given connector credential pairs."""
|
||||
if not cc_pairs:
|
||||
return []
|
||||
|
||||
cc_pair_ids = [cc_pair.id for cc_pair in cc_pairs]
|
||||
|
||||
# Query to get document IDs from DocumentByConnectorCredentialPair
|
||||
# Note: DocumentByConnectorCredentialPair uses connector_id and credential_id, not cc_pair_id
|
||||
doc_cc_pairs = (
|
||||
db_session.query(Document.id)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
Document.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.filter(
|
||||
db_session.query(ConnectorCredentialPair)
|
||||
.filter(
|
||||
ConnectorCredentialPair.id.in_(cc_pair_ids),
|
||||
ConnectorCredentialPair.connector_id
|
||||
== DocumentByConnectorCredentialPair.connector_id,
|
||||
ConnectorCredentialPair.credential_id
|
||||
== DocumentByConnectorCredentialPair.credential_id,
|
||||
)
|
||||
.exists()
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [doc_id for (doc_id,) in doc_cc_pairs]
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.UPDATE_USER_FILE_FOLDER_METADATA,
|
||||
bind=True,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=3,
|
||||
)
|
||||
def update_user_file_folder_metadata(
|
||||
self: Task,
|
||||
document_id: str,
|
||||
*,
|
||||
tenant_id: str,
|
||||
user_file_id: int | None,
|
||||
user_folder_id: int | None,
|
||||
) -> bool:
|
||||
"""Updates the user file and folder metadata for a document in Vespa."""
|
||||
start = time.monotonic()
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
|
||||
doc = get_document(document_id, db_session)
|
||||
if not doc:
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} "
|
||||
f"action=no_operation "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
|
||||
return False
|
||||
|
||||
# Create user fields object with file and folder IDs
|
||||
user_fields = VespaDocumentUserFields(
|
||||
user_file_id=str(user_file_id) if user_file_id is not None else None,
|
||||
user_folder_id=str(user_folder_id)
|
||||
if user_folder_id is not None
|
||||
else None,
|
||||
)
|
||||
|
||||
# Update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = retry_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=None, # We're only updating user fields
|
||||
user_fields=user_fields,
|
||||
)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} "
|
||||
f"action=user_file_folder_sync "
|
||||
f"user_file_id={user_file_id} "
|
||||
f"user_folder_id={user_folder_id} "
|
||||
f"chunks={chunks_affected} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
|
||||
return True
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
|
||||
except Exception as ex:
|
||||
e: Exception | None = None
|
||||
while True:
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.warning(
|
||||
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
|
||||
)
|
||||
|
||||
# only set the inner exception if it is of type Exception
|
||||
e_temp = ex.last_attempt.exception()
|
||||
if isinstance(e_temp, Exception):
|
||||
e = e_temp
|
||||
else:
|
||||
e = ex
|
||||
|
||||
task_logger.exception(
|
||||
f"update_user_file_folder_metadata exceptioned: doc={document_id}"
|
||||
)
|
||||
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
|
||||
if (
|
||||
self.max_retries is not None
|
||||
and self.request.retries >= self.max_retries
|
||||
):
|
||||
completion_status = (
|
||||
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
|
||||
)
|
||||
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown) # this will raise a celery exception
|
||||
break # we won't hit this, but it looks weird not to have it
|
||||
finally:
|
||||
task_logger.info(
|
||||
f"update_user_file_folder_metadata completed: status={completion_status.value} doc={document_id}"
|
||||
)
|
||||
|
||||
return False
|
||||
@@ -80,8 +80,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
|
||||
# Useful for debugging timing issues with reacquisitions.
|
||||
# TODO: remove once more generalized logging is in place
|
||||
# Useful for debugging timing issues with reacquisitions. TODO: remove once more generalized logging is in place
|
||||
task_logger.info("check_for_vespa_sync_task started")
|
||||
|
||||
time_start = time.monotonic()
|
||||
@@ -573,7 +572,6 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=fields,
|
||||
user_fields=None,
|
||||
)
|
||||
|
||||
# update db last. Worst case = we crash right before this and
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
"""Factory stub for running celery worker / celery beat.
|
||||
This code is different from the primary/beat stubs because there is no EE version to
|
||||
fetch. Port over the code in those files if we add an EE version of this worker.
|
||||
|
||||
This is an app stub purely for sending tasks as a client.
|
||||
"""
|
||||
from celery import Celery
|
||||
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
|
||||
def get_app() -> Celery:
|
||||
from onyx.background.celery.apps.client import celery_app
|
||||
|
||||
return celery_app
|
||||
|
||||
|
||||
app = get_app()
|
||||
@@ -56,6 +56,7 @@ from onyx.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.logger import TaskAttemptSingleton
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
@@ -273,6 +274,7 @@ def _run_indexing(
|
||||
"Search settings must be set for indexing. This should not be possible."
|
||||
)
|
||||
|
||||
# search_settings = index_attempt_start.search_settings
|
||||
db_connector = index_attempt_start.connector_credential_pair.connector
|
||||
db_credential = index_attempt_start.connector_credential_pair.credential
|
||||
ctx = RunIndexingContext(
|
||||
@@ -577,8 +579,11 @@ def _run_indexing(
|
||||
data={
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"cc_pair_id": ctx.cc_pair_id,
|
||||
"current_docs_indexed": document_count,
|
||||
"current_chunks_indexed": chunk_count,
|
||||
"connector_id": ctx.connector_id,
|
||||
"credential_id": ctx.credential_id,
|
||||
"total_docs_indexed": document_count,
|
||||
"total_chunks": chunk_count,
|
||||
"batch_num": batch_num,
|
||||
"source": ctx.source.value,
|
||||
},
|
||||
tenant_id=tenant_id,
|
||||
@@ -599,15 +604,26 @@ def _run_indexing(
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
# Add telemetry for completed indexing
|
||||
redis_connector = RedisConnector(tenant_id, ctx.cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
index_attempt_start.search_settings_id
|
||||
)
|
||||
final_progress = redis_connector_index.get_progress() or 0
|
||||
|
||||
optional_telemetry(
|
||||
record_type=RecordType.INDEXING_COMPLETE,
|
||||
data={
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"cc_pair_id": ctx.cc_pair_id,
|
||||
"connector_id": ctx.connector_id,
|
||||
"credential_id": ctx.credential_id,
|
||||
"total_docs_indexed": document_count,
|
||||
"total_chunks": chunk_count,
|
||||
"batch_count": batch_num,
|
||||
"time_elapsed_seconds": time.monotonic() - start_time,
|
||||
"source": ctx.source.value,
|
||||
"redis_progress": final_progress,
|
||||
},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
@@ -622,9 +638,6 @@ def _run_indexing(
|
||||
# and mark the CCPair as invalid. This prevents the connector from being
|
||||
# used in the future until the credentials are updated.
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
logger.exception(
|
||||
f"Marking attempt {index_attempt_id} as canceled due to validation error."
|
||||
)
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
@@ -671,9 +684,6 @@ def _run_indexing(
|
||||
|
||||
elif isinstance(e, ConnectorStopSignal):
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
logger.exception(
|
||||
f"Marking attempt {index_attempt_id} as canceled due to stop signal."
|
||||
)
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
@@ -736,7 +746,6 @@ def _run_indexing(
|
||||
f"Connector succeeded: "
|
||||
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
|
||||
else:
|
||||
mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
|
||||
logger.info(
|
||||
|
||||
@@ -127,10 +127,6 @@ class StreamStopInfo(SubQuestionIdentifier):
|
||||
return data
|
||||
|
||||
|
||||
class UserKnowledgeFilePacket(BaseModel):
|
||||
user_files: list[FileDescriptor]
|
||||
|
||||
|
||||
class LLMRelevanceFilterResponse(BaseModel):
|
||||
llm_selected_doc_indices: list[int]
|
||||
|
||||
|
||||
@@ -36,14 +36,12 @@ from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import SubQuestionKey
|
||||
from onyx.chat.models import UserKnowledgeFilePacket
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.configs.chat_configs import SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE
|
||||
from onyx.configs.constants import AGENT_SEARCH_INITIAL_KEY
|
||||
from onyx.configs.constants import BASIC_KEY
|
||||
from onyx.configs.constants import MessageType
|
||||
@@ -53,7 +51,6 @@ from onyx.context.search.enums import LLMEvaluationType
|
||||
from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.enums import QueryFlow
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.models import SearchRequest
|
||||
@@ -67,7 +64,6 @@ from onyx.context.search.utils import relevant_sections_to_indices
|
||||
from onyx.db.chat import attach_files_to_chat_message
|
||||
from onyx.db.chat import create_db_search_doc
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import create_search_doc_from_user_file
|
||||
from onyx.db.chat import get_chat_message
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_db_search_doc_by_id
|
||||
@@ -84,16 +80,12 @@ from onyx.db.milestone import update_user_assistant_milestone
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import load_all_chat_files
|
||||
from onyx.file_store.utils import load_all_user_file_files
|
||||
from onyx.file_store.utils import load_all_user_files
|
||||
from onyx.file_store.utils import save_files
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
@@ -106,7 +98,6 @@ from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
@@ -184,14 +175,11 @@ def _handle_search_tool_response_summary(
|
||||
db_session: Session,
|
||||
selected_search_docs: list[DbSearchDoc] | None,
|
||||
dedupe_docs: bool = False,
|
||||
user_files: list[UserFile] | None = None,
|
||||
loaded_user_files: list[InMemoryChatFile] | None = None,
|
||||
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
|
||||
response_sumary = cast(SearchResponseSummary, packet.response)
|
||||
|
||||
is_extended = isinstance(packet, ExtendedToolResponse)
|
||||
dropped_inds = None
|
||||
|
||||
if not selected_search_docs:
|
||||
top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections)
|
||||
|
||||
@@ -205,31 +193,9 @@ def _handle_search_tool_response_summary(
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in deduped_docs
|
||||
]
|
||||
|
||||
else:
|
||||
reference_db_search_docs = selected_search_docs
|
||||
|
||||
doc_ids = {doc.id for doc in reference_db_search_docs}
|
||||
if user_files is not None:
|
||||
for user_file in user_files:
|
||||
if user_file.id not in doc_ids:
|
||||
associated_chat_file = None
|
||||
if loaded_user_files is not None:
|
||||
associated_chat_file = next(
|
||||
(
|
||||
file
|
||||
for file in loaded_user_files
|
||||
if file.file_id == str(user_file.file_id)
|
||||
),
|
||||
None,
|
||||
)
|
||||
# Use create_search_doc_from_user_file to properly add the document to the database
|
||||
if associated_chat_file is not None:
|
||||
db_doc = create_search_doc_from_user_file(
|
||||
user_file, associated_chat_file, db_session
|
||||
)
|
||||
reference_db_search_docs.append(db_doc)
|
||||
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
@@ -287,10 +253,7 @@ def _handle_internet_search_tool_response_summary(
|
||||
|
||||
|
||||
def _get_force_search_settings(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
tools: list[Tool],
|
||||
user_file_ids: list[int],
|
||||
user_folder_ids: list[int],
|
||||
new_msg_req: CreateChatMessageRequest, tools: list[Tool]
|
||||
) -> ForceUseTool:
|
||||
internet_search_available = any(
|
||||
isinstance(tool, InternetSearchTool) for tool in tools
|
||||
@@ -298,11 +261,8 @@ def _get_force_search_settings(
|
||||
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
|
||||
|
||||
if not internet_search_available and not search_tool_available:
|
||||
if new_msg_req.force_user_file_search:
|
||||
return ForceUseTool(force_use=True, tool_name=SearchTool._NAME)
|
||||
else:
|
||||
# Does not matter much which tool is set here as force is false and neither tool is available
|
||||
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
|
||||
# Does not matter much which tool is set here as force is false and neither tool is available
|
||||
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
|
||||
|
||||
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
|
||||
# Currently, the internet search tool does not support query override
|
||||
@@ -312,25 +272,12 @@ def _get_force_search_settings(
|
||||
else None
|
||||
)
|
||||
|
||||
# Create override_kwargs for the search tool if user_file_ids are provided
|
||||
override_kwargs = None
|
||||
if (user_file_ids or user_folder_ids) and tool_name == SearchTool._NAME:
|
||||
override_kwargs = SearchToolOverrideKwargs(
|
||||
force_no_rerank=False,
|
||||
alternate_db_session=None,
|
||||
retrieved_sections_callback=None,
|
||||
skip_query_analysis=False,
|
||||
user_file_ids=user_file_ids,
|
||||
user_folder_ids=user_folder_ids,
|
||||
)
|
||||
|
||||
if new_msg_req.file_descriptors:
|
||||
# If user has uploaded files they're using, don't run any of the search tools
|
||||
return ForceUseTool(force_use=False, tool_name=tool_name)
|
||||
|
||||
should_force_search = any(
|
||||
[
|
||||
new_msg_req.force_user_file_search,
|
||||
new_msg_req.retrieval_options
|
||||
and new_msg_req.retrieval_options.run_search
|
||||
== OptionalSearchSetting.ALWAYS,
|
||||
@@ -343,17 +290,9 @@ def _get_force_search_settings(
|
||||
if should_force_search:
|
||||
# If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
|
||||
args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
|
||||
return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
|
||||
|
||||
return ForceUseTool(
|
||||
force_use=True,
|
||||
tool_name=tool_name,
|
||||
args=args,
|
||||
override_kwargs=override_kwargs,
|
||||
)
|
||||
|
||||
return ForceUseTool(
|
||||
force_use=False, tool_name=tool_name, args=args, override_kwargs=override_kwargs
|
||||
)
|
||||
return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
|
||||
|
||||
|
||||
ChatPacket = (
|
||||
@@ -372,7 +311,6 @@ ChatPacket = (
|
||||
| AgenticMessageResponseIDInfo
|
||||
| StreamStopInfo
|
||||
| AgentSearchPacket
|
||||
| UserKnowledgeFilePacket
|
||||
)
|
||||
ChatPacketStream = Iterator[ChatPacket]
|
||||
|
||||
@@ -418,10 +356,6 @@ def stream_chat_message_objects(
|
||||
llm: LLM
|
||||
|
||||
try:
|
||||
# Move these variables inside the try block
|
||||
file_id_to_user_file = {}
|
||||
ordered_user_files = None
|
||||
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
chat_session = get_chat_session_by_id(
|
||||
@@ -601,70 +535,6 @@ def stream_chat_message_objects(
|
||||
)
|
||||
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
|
||||
latest_query_files = [file for file in files if file.file_id in req_file_ids]
|
||||
user_file_ids = new_msg_req.user_file_ids or []
|
||||
user_folder_ids = new_msg_req.user_folder_ids or []
|
||||
|
||||
if persona.user_files:
|
||||
for file in persona.user_files:
|
||||
user_file_ids.append(file.id)
|
||||
if persona.user_folders:
|
||||
for folder in persona.user_folders:
|
||||
user_folder_ids.append(folder.id)
|
||||
|
||||
# Initialize flag for user file search
|
||||
use_search_for_user_files = False
|
||||
|
||||
user_files: list[InMemoryChatFile] | None = None
|
||||
search_for_ordering_only = False
|
||||
user_file_files: list[UserFile] | None = None
|
||||
if user_file_ids or user_folder_ids:
|
||||
# Load user files
|
||||
user_files = load_all_user_files(
|
||||
user_file_ids or [],
|
||||
user_folder_ids or [],
|
||||
db_session,
|
||||
)
|
||||
user_file_files = load_all_user_file_files(
|
||||
user_file_ids or [],
|
||||
user_folder_ids or [],
|
||||
db_session,
|
||||
)
|
||||
# Store mapping of file_id to file for later reordering
|
||||
if user_files:
|
||||
file_id_to_user_file = {file.file_id: file for file in user_files}
|
||||
|
||||
# Calculate token count for the files
|
||||
from onyx.db.user_documents import calculate_user_files_token_count
|
||||
from onyx.chat.prompt_builder.citations_prompt import (
|
||||
compute_max_document_tokens_for_persona,
|
||||
)
|
||||
|
||||
total_tokens = calculate_user_files_token_count(
|
||||
user_file_ids or [],
|
||||
user_folder_ids or [],
|
||||
db_session,
|
||||
)
|
||||
|
||||
# Calculate available tokens for documents based on prompt, user input, etc.
|
||||
available_tokens = compute_max_document_tokens_for_persona(
|
||||
db_session=db_session,
|
||||
persona=persona,
|
||||
actual_user_input=message_text, # Use the actual user message
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Total file tokens: {total_tokens}, Available tokens: {available_tokens}"
|
||||
)
|
||||
|
||||
# ALWAYS use search for user files, but track if we need it for context or just ordering
|
||||
use_search_for_user_files = True
|
||||
# If files are small enough for context, we'll just use search for ordering
|
||||
search_for_ordering_only = total_tokens <= available_tokens
|
||||
|
||||
if search_for_ordering_only:
|
||||
# Add original user files to context since they fit
|
||||
if user_files:
|
||||
latest_query_files.extend(user_files)
|
||||
|
||||
if user_message:
|
||||
attach_files_to_chat_message(
|
||||
@@ -693,13 +563,8 @@ def stream_chat_message_objects(
|
||||
doc_identifiers=identifier_tuples,
|
||||
document_index=document_index,
|
||||
)
|
||||
|
||||
# Add a maximum context size in the case of user-selected docs to prevent
|
||||
# slight inaccuracies in context window size pruning from causing
|
||||
# the entire query to fail
|
||||
document_pruning_config = DocumentPruningConfig(
|
||||
is_manually_selected_docs=True,
|
||||
max_window_percentage=SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE,
|
||||
is_manually_selected_docs=True
|
||||
)
|
||||
|
||||
# In case the search doc is deleted, just don't include it
|
||||
@@ -812,10 +677,8 @@ def stream_chat_message_objects(
|
||||
prompt_config=prompt_config,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
user_knowledge_present=bool(user_files or user_folder_ids),
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
use_file_search=new_msg_req.force_user_file_search,
|
||||
search_tool_config=SearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
@@ -845,138 +708,17 @@ def stream_chat_message_objects(
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
force_use_tool = _get_force_search_settings(
|
||||
new_msg_req, tools, user_file_ids, user_folder_ids
|
||||
)
|
||||
|
||||
# Set force_use if user files exceed token limit
|
||||
if use_search_for_user_files:
|
||||
try:
|
||||
# Check if search tool is available in the tools list
|
||||
search_tool_available = any(
|
||||
isinstance(tool, SearchTool) for tool in tools
|
||||
)
|
||||
|
||||
# If no search tool is available, add one
|
||||
if not search_tool_available:
|
||||
logger.info("No search tool available, creating one for user files")
|
||||
# Create a basic search tool config
|
||||
search_tool_config = SearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=retrieval_options or RetrievalDetails(),
|
||||
)
|
||||
|
||||
# Create and add the search tool
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=search_tool_config.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=search_tool_config.document_pruning_config,
|
||||
answer_style_config=search_tool_config.answer_style_config,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
bypass_acl=bypass_acl,
|
||||
)
|
||||
|
||||
# Add the search tool to the tools list
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.info(
|
||||
"Added search tool for user files that exceed token limit"
|
||||
)
|
||||
|
||||
# Now set force_use_tool.force_use to True
|
||||
force_use_tool.force_use = True
|
||||
force_use_tool.tool_name = SearchTool._NAME
|
||||
|
||||
# Set query argument if not already set
|
||||
if not force_use_tool.args:
|
||||
force_use_tool.args = {"query": final_msg.message}
|
||||
|
||||
# Pass the user file IDs to the search tool
|
||||
if user_file_ids or user_folder_ids:
|
||||
# Create a BaseFilters object with user_file_ids
|
||||
if not retrieval_options:
|
||||
retrieval_options = RetrievalDetails()
|
||||
if not retrieval_options.filters:
|
||||
retrieval_options.filters = BaseFilters()
|
||||
|
||||
# Set user file and folder IDs in the filters
|
||||
retrieval_options.filters.user_file_ids = user_file_ids
|
||||
retrieval_options.filters.user_folder_ids = user_folder_ids
|
||||
|
||||
# Create override kwargs for the search tool
|
||||
override_kwargs = SearchToolOverrideKwargs(
|
||||
force_no_rerank=search_for_ordering_only, # Skip reranking for ordering-only
|
||||
alternate_db_session=None,
|
||||
retrieved_sections_callback=None,
|
||||
skip_query_analysis=search_for_ordering_only, # Skip query analysis for ordering-only
|
||||
user_file_ids=user_file_ids,
|
||||
user_folder_ids=user_folder_ids,
|
||||
ordering_only=search_for_ordering_only, # Set ordering_only flag for fast path
|
||||
)
|
||||
|
||||
# Set the override kwargs in the force_use_tool
|
||||
force_use_tool.override_kwargs = override_kwargs
|
||||
|
||||
if search_for_ordering_only:
|
||||
logger.info(
|
||||
"Fast path: Configured search tool with optimized settings for ordering-only"
|
||||
)
|
||||
logger.info(
|
||||
"Fast path: Skipping reranking and query analysis for ordering-only mode"
|
||||
)
|
||||
logger.info(
|
||||
f"Using {len(user_file_ids or [])} files and {len(user_folder_ids or [])} folders"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Configured search tool to use ",
|
||||
f"{len(user_file_ids or [])} files and {len(user_folder_ids or [])} folders",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error configuring search tool for user files: {str(e)}"
|
||||
)
|
||||
use_search_for_user_files = False
|
||||
|
||||
# TODO: unify message history with single message history
|
||||
message_history = [
|
||||
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
|
||||
]
|
||||
if not use_search_for_user_files and user_files:
|
||||
yield UserKnowledgeFilePacket(
|
||||
user_files=[
|
||||
FileDescriptor(
|
||||
id=str(file.file_id), type=ChatFileType.USER_KNOWLEDGE
|
||||
)
|
||||
for file in user_files
|
||||
]
|
||||
)
|
||||
|
||||
if search_for_ordering_only:
|
||||
logger.info(
|
||||
"Performance: Forcing LLMEvaluationType.SKIP to prevent chunk evaluation for ordering-only search"
|
||||
)
|
||||
|
||||
search_request = SearchRequest(
|
||||
query=final_msg.message,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.SKIP
|
||||
if search_for_ordering_only
|
||||
else (
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
)
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
human_selected_filters=(
|
||||
retrieval_options.filters if retrieval_options else None
|
||||
@@ -995,6 +737,7 @@ def stream_chat_message_objects(
|
||||
),
|
||||
)
|
||||
|
||||
force_use_tool = _get_force_search_settings(new_msg_req, tools)
|
||||
prompt_builder = AnswerPromptBuilder(
|
||||
user_message=default_build_user_message(
|
||||
user_query=final_msg.message,
|
||||
@@ -1063,22 +806,8 @@ def stream_chat_message_objects(
|
||||
info = info_by_subq[
|
||||
SubQuestionKey(level=level, question_num=level_question_num)
|
||||
]
|
||||
|
||||
# Skip LLM relevance processing entirely for ordering-only mode
|
||||
if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
logger.info(
|
||||
"Fast path: Completely bypassing section relevance processing for ordering-only mode"
|
||||
)
|
||||
# Skip this packet entirely since it would trigger LLM processing
|
||||
continue
|
||||
|
||||
# TODO: don't need to dedupe here when we do it in agent flow
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
if search_for_ordering_only:
|
||||
logger.info(
|
||||
"Fast path: Skipping document deduplication for ordering-only mode"
|
||||
)
|
||||
|
||||
(
|
||||
info.qa_docs_response,
|
||||
info.reference_db_search_docs,
|
||||
@@ -1088,91 +817,16 @@ def stream_chat_message_objects(
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
# Skip deduping completely for ordering-only mode to save time
|
||||
dedupe_docs=(
|
||||
False
|
||||
if search_for_ordering_only
|
||||
else (
|
||||
retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False
|
||||
)
|
||||
retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False
|
||||
),
|
||||
user_files=user_file_files if search_for_ordering_only else [],
|
||||
loaded_user_files=user_files
|
||||
if search_for_ordering_only
|
||||
else [],
|
||||
)
|
||||
|
||||
# If we're using search just for ordering user files
|
||||
if (
|
||||
search_for_ordering_only
|
||||
and user_files
|
||||
and info.qa_docs_response
|
||||
):
|
||||
logger.info(
|
||||
f"ORDERING: Processing search results for ordering {len(user_files)} user files"
|
||||
)
|
||||
import time
|
||||
|
||||
ordering_start = time.time()
|
||||
|
||||
# Extract document order from search results
|
||||
doc_order = []
|
||||
for doc in info.qa_docs_response.top_documents:
|
||||
doc_id = doc.document_id
|
||||
if str(doc_id).startswith("USER_FILE_CONNECTOR__"):
|
||||
file_id = doc_id.replace("USER_FILE_CONNECTOR__", "")
|
||||
if file_id in file_id_to_user_file:
|
||||
doc_order.append(file_id)
|
||||
|
||||
logger.info(
|
||||
f"ORDERING: Found {len(doc_order)} files from search results"
|
||||
)
|
||||
|
||||
# Add any files that weren't in search results at the end
|
||||
missing_files = [
|
||||
f_id
|
||||
for f_id in file_id_to_user_file.keys()
|
||||
if f_id not in doc_order
|
||||
]
|
||||
|
||||
missing_files.extend(doc_order)
|
||||
doc_order = missing_files
|
||||
|
||||
logger.info(
|
||||
f"ORDERING: Added {len(missing_files)} missing files to the end"
|
||||
)
|
||||
|
||||
# Reorder user files based on search results
|
||||
ordered_user_files = [
|
||||
file_id_to_user_file[f_id]
|
||||
for f_id in doc_order
|
||||
if f_id in file_id_to_user_file
|
||||
]
|
||||
|
||||
time.time() - ordering_start
|
||||
|
||||
yield UserKnowledgeFilePacket(
|
||||
user_files=[
|
||||
FileDescriptor(
|
||||
id=str(file.file_id),
|
||||
type=ChatFileType.USER_KNOWLEDGE,
|
||||
)
|
||||
for file in ordered_user_files
|
||||
]
|
||||
)
|
||||
|
||||
yield info.qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
|
||||
if search_for_ordering_only:
|
||||
logger.info(
|
||||
"Performance: Skipping relevance filtering for ordering-only mode"
|
||||
)
|
||||
continue
|
||||
|
||||
if info.reference_db_search_docs is None:
|
||||
logger.warning(
|
||||
"No reference docs found for relevance filtering"
|
||||
@@ -1282,7 +936,7 @@ def stream_chat_message_objects(
|
||||
]
|
||||
info.tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
|
||||
logger.debug("Reached end of stream")
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
@@ -1364,16 +1018,10 @@ def stream_chat_message_objects(
|
||||
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
|
||||
tool_call=(
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id.get(info.tool_result.tool_name, 0)
|
||||
if info.tool_result
|
||||
else None,
|
||||
tool_name=info.tool_result.tool_name if info.tool_result else None,
|
||||
tool_arguments=info.tool_result.tool_args
|
||||
if info.tool_result
|
||||
else None,
|
||||
tool_result=info.tool_result.tool_result
|
||||
if info.tool_result
|
||||
else None,
|
||||
tool_id=tool_name_to_tool_id[info.tool_result.tool_name],
|
||||
tool_name=info.tool_result.tool_name,
|
||||
tool_arguments=info.tool_result.tool_args,
|
||||
tool_result=info.tool_result.tool_result,
|
||||
)
|
||||
if info.tool_result
|
||||
else None
|
||||
|
||||
@@ -19,7 +19,6 @@ def translate_onyx_msg_to_langchain(
|
||||
# attached. Just ignore them for now.
|
||||
if not isinstance(msg, ChatMessage):
|
||||
files = msg.files
|
||||
|
||||
content = build_content_with_imgs(
|
||||
msg.message, files, message_type=msg.message_type, exclude_images=exclude_images
|
||||
)
|
||||
|
||||
@@ -153,8 +153,6 @@ def _apply_pruning(
|
||||
# remove docs that are explicitly marked as not for QA
|
||||
sections = _remove_sections_to_ignore(sections=sections)
|
||||
|
||||
section_idx_token_count: dict[int, int] = {}
|
||||
|
||||
final_section_ind = None
|
||||
total_tokens = 0
|
||||
for ind, section in enumerate(sections):
|
||||
@@ -204,20 +202,10 @@ def _apply_pruning(
|
||||
section_token_count = DOC_EMBEDDING_CONTEXT_SIZE
|
||||
|
||||
total_tokens += section_token_count
|
||||
section_idx_token_count[ind] = section_token_count
|
||||
|
||||
if total_tokens > token_limit:
|
||||
final_section_ind = ind
|
||||
break
|
||||
|
||||
try:
|
||||
logger.debug(f"Number of documents after pruning: {ind + 1}")
|
||||
logger.debug("Number of tokens per document (pruned):")
|
||||
for x, y in section_idx_token_count.items():
|
||||
logger.debug(f"{x + 1}: {y}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging prune statistics: {e}")
|
||||
|
||||
if final_section_ind is not None:
|
||||
if is_manually_selected_docs or use_sections:
|
||||
if final_section_ind != len(sections) - 1:
|
||||
@@ -312,14 +300,11 @@ def prune_sections(
|
||||
)
|
||||
|
||||
|
||||
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> tuple[InferenceSection, int]:
|
||||
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
|
||||
assert (
|
||||
len(set([chunk.document_id for chunk in chunks])) == 1
|
||||
), "One distinct document must be passed into merge_doc_chunks"
|
||||
|
||||
ADJACENT_CHUNK_SEP = "\n"
|
||||
DISTANT_CHUNK_SEP = "\n\n...\n\n"
|
||||
|
||||
# Assuming there are no duplicates by this point
|
||||
sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)
|
||||
|
||||
@@ -327,48 +312,33 @@ def _merge_doc_chunks(chunks: list[InferenceChunk]) -> tuple[InferenceSection, i
|
||||
chunks, key=lambda x: x.score if x.score is not None else float("-inf")
|
||||
)
|
||||
|
||||
added_chars = 0
|
||||
merged_content = []
|
||||
for i, chunk in enumerate(sorted_chunks):
|
||||
if i > 0:
|
||||
prev_chunk_id = sorted_chunks[i - 1].chunk_id
|
||||
sep = (
|
||||
ADJACENT_CHUNK_SEP
|
||||
if chunk.chunk_id == prev_chunk_id + 1
|
||||
else DISTANT_CHUNK_SEP
|
||||
)
|
||||
merged_content.append(sep)
|
||||
added_chars += len(sep)
|
||||
if chunk.chunk_id == prev_chunk_id + 1:
|
||||
merged_content.append("\n")
|
||||
else:
|
||||
merged_content.append("\n\n...\n\n")
|
||||
merged_content.append(chunk.content)
|
||||
|
||||
combined_content = "".join(merged_content)
|
||||
|
||||
return (
|
||||
InferenceSection(
|
||||
center_chunk=center_chunk,
|
||||
chunks=sorted_chunks,
|
||||
combined_content=combined_content,
|
||||
),
|
||||
added_chars,
|
||||
return InferenceSection(
|
||||
center_chunk=center_chunk,
|
||||
chunks=sorted_chunks,
|
||||
combined_content=combined_content,
|
||||
)
|
||||
|
||||
|
||||
def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
|
||||
docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict)
|
||||
doc_order: dict[str, int] = {}
|
||||
combined_section_lengths: dict[str, int] = defaultdict(lambda: 0)
|
||||
|
||||
# chunk de-duping and doc ordering
|
||||
for index, section in enumerate(sections):
|
||||
if section.center_chunk.document_id not in doc_order:
|
||||
doc_order[section.center_chunk.document_id] = index
|
||||
|
||||
combined_section_lengths[section.center_chunk.document_id] += len(
|
||||
section.combined_content
|
||||
)
|
||||
|
||||
chunks_map = docs_map[section.center_chunk.document_id]
|
||||
for chunk in [section.center_chunk] + section.chunks:
|
||||
chunks_map = docs_map[section.center_chunk.document_id]
|
||||
existing_chunk = chunks_map.get(chunk.chunk_id)
|
||||
if (
|
||||
existing_chunk is None
|
||||
@@ -379,22 +349,8 @@ def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
|
||||
chunks_map[chunk.chunk_id] = chunk
|
||||
|
||||
new_sections = []
|
||||
for doc_id, section_chunks in docs_map.items():
|
||||
section_chunks_list = list(section_chunks.values())
|
||||
merged_section, added_chars = _merge_doc_chunks(chunks=section_chunks_list)
|
||||
|
||||
previous_length = combined_section_lengths[doc_id] + added_chars
|
||||
# After merging, ensure the content respects the pruning done earlier. Each
|
||||
# combined section is restricted to the sum of the lengths of the sections
|
||||
# from the pruning step. Technically the correct approach would be to prune based
|
||||
# on tokens AGAIN, but this is a good approximation and worth not adding the
|
||||
# tokenization overhead. This could also be fixed if we added a way of removing
|
||||
# chunks from sections in the pruning step; at the moment this issue largely
|
||||
# exists because we only trim the final section's combined_content.
|
||||
merged_section.combined_content = merged_section.combined_content[
|
||||
:previous_length
|
||||
]
|
||||
new_sections.append(merged_section)
|
||||
for section_chunks in docs_map.values():
|
||||
new_sections.append(_merge_doc_chunks(chunks=list(section_chunks.values())))
|
||||
|
||||
# Sort by highest score, then by original document order
|
||||
# It is now 1 large section per doc, the center chunk being the one with the highest score
|
||||
@@ -406,26 +362,6 @@ def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
try:
|
||||
num_original_sections = len(sections)
|
||||
num_original_document_ids = len(
|
||||
set([section.center_chunk.document_id for section in sections])
|
||||
)
|
||||
num_merged_sections = len(new_sections)
|
||||
num_merged_document_ids = len(
|
||||
set([section.center_chunk.document_id for section in new_sections])
|
||||
)
|
||||
logger.debug(
|
||||
f"Merged {num_original_sections} sections from {num_original_document_ids} documents "
|
||||
f"into {num_merged_sections} new sections in {num_merged_document_ids} documents"
|
||||
)
|
||||
|
||||
logger.debug("Number of chunks per document (new ranking):")
|
||||
for x, y in enumerate(new_sections):
|
||||
logger.debug(f"{x + 1}: {len(y.chunks)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging merge statistics: {e}")
|
||||
|
||||
return new_sections
|
||||
|
||||
|
||||
|
||||
@@ -180,10 +180,6 @@ def get_tool_call_for_non_tool_calling_llm_impl(
|
||||
if tool_args is None:
|
||||
raise RuntimeError(f"Tool '{tool.name}' did not return args")
|
||||
|
||||
# If we have override_kwargs, add them to the tool_args
|
||||
if force_use_tool.override_kwargs is not None:
|
||||
tool_args["override_kwargs"] = force_use_tool.override_kwargs
|
||||
|
||||
return (tool, tool_args)
|
||||
else:
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
|
||||
@@ -170,7 +170,7 @@ POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
|
||||
POSTGRES_PASSWORD = urllib.parse.quote_plus(
|
||||
os.environ.get("POSTGRES_PASSWORD") or "password"
|
||||
)
|
||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "127.0.0.1"
|
||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
|
||||
@@ -437,7 +437,7 @@ LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
|
||||
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
|
||||
|
||||
# Slack specific configs
|
||||
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 8)
|
||||
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 2)
|
||||
|
||||
DASK_JOB_CLIENT_ENABLED = (
|
||||
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||
@@ -495,11 +495,6 @@ NUM_SECONDARY_INDEXING_WORKERS = int(
|
||||
ENABLE_MULTIPASS_INDEXING = (
|
||||
os.environ.get("ENABLE_MULTIPASS_INDEXING", "").lower() == "true"
|
||||
)
|
||||
# Enable contextual retrieval
|
||||
ENABLE_CONTEXTUAL_RAG = os.environ.get("ENABLE_CONTEXTUAL_RAG", "").lower() == "true"
|
||||
|
||||
DEFAULT_CONTEXTUAL_RAG_LLM_NAME = "gpt-4o-mini"
|
||||
DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER = "DevEnvPresetOpenAI"
|
||||
# Finer grained chunking for more detail retention
|
||||
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
|
||||
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
|
||||
@@ -541,17 +536,6 @@ MAX_FILE_SIZE_BYTES = int(
|
||||
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
|
||||
) # 2GB in bytes
|
||||
|
||||
# Use document summary for contextual rag
|
||||
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
|
||||
# Use chunk summary for contextual rag
|
||||
USE_CHUNK_SUMMARY = os.environ.get("USE_CHUNK_SUMMARY", "true").lower() == "true"
|
||||
# Average summary embeddings for contextual rag (not yet implemented)
|
||||
AVERAGE_SUMMARY_EMBEDDINGS = (
|
||||
os.environ.get("AVERAGE_SUMMARY_EMBEDDINGS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
MAX_TOKENS_FOR_FULL_INCLUSION = 4096
|
||||
|
||||
#####
|
||||
# Miscellaneous
|
||||
#####
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
INPUT_PROMPT_YAML = "./onyx/seeding/input_prompts.yaml"
|
||||
PROMPTS_YAML = "./onyx/seeding/prompts.yaml"
|
||||
PERSONAS_YAML = "./onyx/seeding/personas.yaml"
|
||||
USER_FOLDERS_YAML = "./onyx/seeding/user_folders.yaml"
|
||||
|
||||
NUM_RETURNED_HITS = 50
|
||||
# Used for LLM filtering and reranking
|
||||
# We want this to be approximately the number of results we want to show on the first page
|
||||
@@ -16,9 +16,6 @@ MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
|
||||
# ~3k input, half for docs, half for chat history + prompts
|
||||
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
|
||||
|
||||
# Maximum percentage of the context window to fill with selected sections
|
||||
SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE = 0.8
|
||||
|
||||
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
|
||||
# Capped in Vespa at 0.5
|
||||
DOC_TIME_DECAY = float(
|
||||
|
||||
@@ -102,8 +102,6 @@ CELERY_GENERIC_BEAT_LOCK_TIMEOUT = 120
|
||||
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
|
||||
|
||||
CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT = 120
|
||||
|
||||
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
|
||||
|
||||
|
||||
@@ -271,7 +269,6 @@ class FileOrigin(str, Enum):
|
||||
CONNECTOR = "connector"
|
||||
GENERATED_REPORT = "generated_report"
|
||||
INDEXING_CHECKPOINT = "indexing_checkpoint"
|
||||
PLAINTEXT_CACHE = "plaintext_cache"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
@@ -312,7 +309,6 @@ class OnyxCeleryQueues:
|
||||
|
||||
# Indexing queue
|
||||
CONNECTOR_INDEXING = "connector_indexing"
|
||||
USER_FILES_INDEXING = "user_files_indexing"
|
||||
|
||||
# Monitoring queue
|
||||
MONITORING = "monitoring"
|
||||
@@ -331,7 +327,6 @@ class OnyxRedisLocks:
|
||||
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
|
||||
"da_lock:check_connector_external_group_sync_beat"
|
||||
)
|
||||
CHECK_USER_FILE_FOLDER_SYNC_BEAT_LOCK = "da_lock:check_user_file_folder_sync_beat"
|
||||
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
|
||||
CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
|
||||
PRE_PROVISION_TENANT_LOCK = "da_lock:pre_provision_tenant"
|
||||
@@ -402,7 +397,6 @@ class OnyxCeleryTask:
|
||||
|
||||
# Tenant pre-provisioning
|
||||
PRE_PROVISION_TENANT = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_pre_provision_tenant"
|
||||
UPDATE_USER_FILE_FOLDER_METADATA = "update_user_file_folder_metadata"
|
||||
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
@@ -411,7 +405,6 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
|
||||
CHECK_FOR_USER_FILE_FOLDER_SYNC = "check_for_user_file_folder_sync"
|
||||
|
||||
# Connector checkpoint cleanup
|
||||
CHECK_FOR_CHECKPOINT_CLEANUP = "check_for_checkpoint_cleanup"
|
||||
|
||||
@@ -13,7 +13,6 @@ from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import urljoin
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
@@ -343,14 +342,9 @@ def build_confluence_document_id(
|
||||
Returns:
|
||||
str: The document id
|
||||
"""
|
||||
|
||||
# NOTE: urljoin is tricky and will drop the last segment of the base if it doesn't
|
||||
# end with "/" because it believes that makes it a file.
|
||||
final_url = base_url.rstrip("/") + "/"
|
||||
if is_cloud and not final_url.endswith("/wiki/"):
|
||||
final_url = urljoin(final_url, "wiki") + "/"
|
||||
final_url = urljoin(final_url, content_url.lstrip("/"))
|
||||
return final_url
|
||||
if is_cloud and not base_url.endswith("/wiki"):
|
||||
base_url += "/wiki"
|
||||
return f"{base_url}{content_url}"
|
||||
|
||||
|
||||
def datetime_from_string(datetime_string: str) -> datetime:
|
||||
@@ -460,19 +454,6 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
|
||||
logger.warning("HTTPError with `None` as response or as headers")
|
||||
raise e
|
||||
|
||||
# Confluence Server returns 403 when rate limited
|
||||
if e.response.status_code == 403:
|
||||
FORBIDDEN_MAX_RETRY_ATTEMPTS = 7
|
||||
FORBIDDEN_RETRY_DELAY = 10
|
||||
if attempt < FORBIDDEN_MAX_RETRY_ATTEMPTS:
|
||||
logger.warning(
|
||||
"403 error. This sometimes happens when we hit "
|
||||
f"Confluence rate limits. Retrying in {FORBIDDEN_RETRY_DELAY} seconds..."
|
||||
)
|
||||
return FORBIDDEN_RETRY_DELAY
|
||||
|
||||
raise e
|
||||
|
||||
if (
|
||||
e.response.status_code != 429
|
||||
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
|
||||
|
||||
@@ -45,8 +45,6 @@ _FIREFLIES_API_QUERY = """
|
||||
}
|
||||
"""
|
||||
|
||||
ONE_MINUTE = 60
|
||||
|
||||
|
||||
def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
sections: List[TextSection] = []
|
||||
@@ -108,8 +106,6 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
)
|
||||
|
||||
|
||||
# If not all transcripts are being indexed, try using a more-recently-generated
|
||||
# API key.
|
||||
class FirefliesConnector(PollConnector, LoadConnector):
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
@@ -195,9 +191,6 @@ class FirefliesConnector(PollConnector, LoadConnector):
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
# add some leeway to account for any timezone funkiness and/or bad handling
|
||||
# of start time on the Fireflies side
|
||||
start = max(0, start - ONE_MINUTE)
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%dT%H:%M:%S.000Z"
|
||||
)
|
||||
|
||||
@@ -276,26 +276,7 @@ class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
|
||||
return checkpoint
|
||||
|
||||
assert checkpoint.cached_repo is not None, "No repo saved in checkpoint"
|
||||
|
||||
# Try to access the requester - different PyGithub versions may use different attribute names
|
||||
try:
|
||||
# Try direct access to a known attribute name first
|
||||
if hasattr(self.github_client, "_requester"):
|
||||
requester = self.github_client._requester
|
||||
elif hasattr(self.github_client, "_Github__requester"):
|
||||
requester = self.github_client._Github__requester
|
||||
else:
|
||||
# If we can't find the requester attribute, we need to fall back to recreating the repo
|
||||
raise AttributeError("Could not find requester attribute")
|
||||
|
||||
repo = checkpoint.cached_repo.to_Repository(requester)
|
||||
except Exception as e:
|
||||
# If all else fails, re-fetch the repo directly
|
||||
logger.warning(
|
||||
f"Failed to deserialize repository: {e}. Attempting to re-fetch."
|
||||
)
|
||||
repo_id = checkpoint.cached_repo.id
|
||||
repo = self.github_client.get_repo(repo_id)
|
||||
repo = checkpoint.cached_repo.to_Repository(self.github_client.requester)
|
||||
|
||||
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
|
||||
logger.info(f"Fetching PRs for repo: {repo.name}")
|
||||
|
||||
@@ -445,9 +445,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
|
||||
logger.warning(
|
||||
f"User '{user_email}' does not have access to the drive APIs."
|
||||
)
|
||||
# mark this user as done so we don't try to retrieve anything for them
|
||||
# again
|
||||
curr_stage.stage = DriveRetrievalStage.DONE
|
||||
return
|
||||
raise
|
||||
|
||||
@@ -584,25 +581,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
|
||||
drive_ids_to_retrieve, checkpoint
|
||||
)
|
||||
|
||||
# only process emails that we haven't already completed retrieval for
|
||||
non_completed_org_emails = [
|
||||
user_email
|
||||
for user_email, stage in checkpoint.completion_map.items()
|
||||
if stage != DriveRetrievalStage.DONE
|
||||
]
|
||||
|
||||
# don't process too many emails before returning a checkpoint. This is
|
||||
# to resolve the case where there are a ton of emails that don't have access
|
||||
# to the drive APIs. Without this, we could loop through these emails for
|
||||
# more than 3 hours, causing a timeout and stalling progress.
|
||||
email_batch_takes_us_to_completion = True
|
||||
MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING = 50
|
||||
if len(non_completed_org_emails) > MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING:
|
||||
non_completed_org_emails = non_completed_org_emails[
|
||||
:MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING
|
||||
]
|
||||
email_batch_takes_us_to_completion = False
|
||||
|
||||
user_retrieval_gens = [
|
||||
self._impersonate_user_for_retrieval(
|
||||
email,
|
||||
@@ -613,14 +591,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
|
||||
start,
|
||||
end,
|
||||
)
|
||||
for email in non_completed_org_emails
|
||||
for email in all_org_emails
|
||||
]
|
||||
yield from parallel_yield(user_retrieval_gens, max_workers=MAX_DRIVE_WORKERS)
|
||||
|
||||
# if there are more emails to process, don't mark as complete
|
||||
if not email_batch_takes_us_to_completion:
|
||||
return
|
||||
|
||||
remaining_folders = (
|
||||
drive_ids_to_retrieve | folder_ids_to_retrieve
|
||||
) - self._retrieved_ids
|
||||
|
||||
@@ -30,7 +30,6 @@ from onyx.file_processing.file_validation import is_valid_image_type
|
||||
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.utils.lazy import lazy_eval
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -77,26 +76,6 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
|
||||
return is_valid_image_type(mime_type)
|
||||
|
||||
|
||||
def download_request(service: GoogleDriveService, file_id: str) -> bytes:
|
||||
"""
|
||||
Download the file from Google Drive.
|
||||
"""
|
||||
# For other file types, download the file
|
||||
# Use the correct API call for downloading files
|
||||
request = service.files().get_media(fileId=file_id)
|
||||
response_bytes = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(response_bytes, request)
|
||||
done = False
|
||||
while not done:
|
||||
_, done = downloader.next_chunk()
|
||||
|
||||
response = response_bytes.getvalue()
|
||||
if not response:
|
||||
logger.warning(f"Failed to download {file_id}")
|
||||
return bytes()
|
||||
return response
|
||||
|
||||
|
||||
def _download_and_extract_sections_basic(
|
||||
file: dict[str, str],
|
||||
service: GoogleDriveService,
|
||||
@@ -135,31 +114,41 @@ def _download_and_extract_sections_basic(
|
||||
|
||||
# For other file types, download the file
|
||||
# Use the correct API call for downloading files
|
||||
response_call = lazy_eval(lambda: download_request(service, file_id))
|
||||
request = service.files().get_media(fileId=file_id)
|
||||
response_bytes = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(response_bytes, request)
|
||||
done = False
|
||||
while not done:
|
||||
_, done = downloader.next_chunk()
|
||||
|
||||
response = response_bytes.getvalue()
|
||||
if not response:
|
||||
logger.warning(f"Failed to download {file_name}")
|
||||
return []
|
||||
|
||||
# Process based on mime type
|
||||
if mime_type == "text/plain":
|
||||
text = response_call().decode("utf-8")
|
||||
text = response.decode("utf-8")
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
):
|
||||
text, _ = docx_to_text_and_images(io.BytesIO(response_call()))
|
||||
text, _ = docx_to_text_and_images(io.BytesIO(response))
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif (
|
||||
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
):
|
||||
text = xlsx_to_text(io.BytesIO(response_call()))
|
||||
text = xlsx_to_text(io.BytesIO(response))
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
):
|
||||
text = pptx_to_text(io.BytesIO(response_call()))
|
||||
text = pptx_to_text(io.BytesIO(response))
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif is_gdrive_image_mime_type(mime_type):
|
||||
@@ -169,7 +158,7 @@ def _download_and_extract_sections_basic(
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=response_call(),
|
||||
image_data=response,
|
||||
file_name=file_id,
|
||||
display_name=file_name,
|
||||
media_type=mime_type,
|
||||
@@ -182,7 +171,7 @@ def _download_and_extract_sections_basic(
|
||||
return sections
|
||||
|
||||
elif mime_type == "application/pdf":
|
||||
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
|
||||
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
|
||||
pdf_sections: list[TextSection | ImageSection] = [
|
||||
TextSection(link=link, text=text)
|
||||
]
|
||||
@@ -205,15 +194,8 @@ def _download_and_extract_sections_basic(
|
||||
|
||||
else:
|
||||
# For unsupported file types, try to extract text
|
||||
if mime_type in [
|
||||
"application/vnd.google-apps.video",
|
||||
"application/vnd.google-apps.audio",
|
||||
"application/zip",
|
||||
]:
|
||||
return []
|
||||
# For unsupported file types, try to extract text
|
||||
try:
|
||||
text = extract_file_text(io.BytesIO(response_call()), file_name)
|
||||
text = extract_file_text(io.BytesIO(response), file_name)
|
||||
return [TextSection(link=link, text=text)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract text from {file_name}: {e}")
|
||||
|
||||
@@ -75,7 +75,7 @@ class HighspotClient:
|
||||
|
||||
self.key = key
|
||||
self.secret = secret
|
||||
self.base_url = base_url.rstrip("/") + "/"
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
|
||||
# Set up session with retry logic
|
||||
|
||||
@@ -163,9 +163,6 @@ class DocumentBase(BaseModel):
|
||||
attributes.append(k + INDEX_SEPARATOR + v)
|
||||
return attributes
|
||||
|
||||
def get_text_content(self) -> str:
|
||||
return " ".join([section.text for section in self.sections if section.text])
|
||||
|
||||
|
||||
class Document(DocumentBase):
|
||||
"""Used for Onyx ingestion api, the ID is required"""
|
||||
|
||||
@@ -14,8 +14,6 @@ from typing import cast
|
||||
from pydantic import BaseModel
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.http_retry import ConnectionErrorRetryHandler
|
||||
from slack_sdk.http_retry import RetryHandler
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
|
||||
@@ -28,8 +26,6 @@ from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import CredentialsConnector
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
@@ -42,16 +38,15 @@ from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.slack.onyx_retry_handler import OnyxRedisSlackRetryHandler
|
||||
from onyx.connectors.slack.utils import expert_info_from_slack_id
|
||||
from onyx.connectors.slack.utils import get_message_link
|
||||
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
|
||||
from onyx.connectors.slack.utils import make_slack_api_call_w_retries
|
||||
from onyx.connectors.slack.utils import SlackTextCleaner
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_SLACK_LIMIT = 900
|
||||
@@ -498,13 +493,9 @@ def _process_message(
|
||||
)
|
||||
|
||||
|
||||
class SlackConnector(
|
||||
SlimConnector, CredentialsConnector, CheckpointConnector[SlackCheckpoint]
|
||||
):
|
||||
class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
|
||||
FAST_TIMEOUT = 1
|
||||
|
||||
MAX_RETRIES = 7 # arbitrarily selected
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: list[str] | None = None,
|
||||
@@ -523,49 +514,16 @@ class SlackConnector(
|
||||
# just used for efficiency
|
||||
self.text_cleaner: SlackTextCleaner | None = None
|
||||
self.user_cache: dict[str, BasicExpertInfo | None] = {}
|
||||
self.credentials_provider: CredentialsProviderInterface | None = None
|
||||
self.credential_prefix: str | None = None
|
||||
self.delay_lock: str | None = None # the redis key for the shared lock
|
||||
self.delay_key: str | None = None # the redis key for the shared delay
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
raise NotImplementedError("Use set_credentials_provider with this connector.")
|
||||
|
||||
def set_credentials_provider(
|
||||
self, credentials_provider: CredentialsProviderInterface
|
||||
) -> None:
|
||||
credentials = credentials_provider.get_credentials()
|
||||
tenant_id = credentials_provider.get_tenant_id()
|
||||
self.redis = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
self.credential_prefix = (
|
||||
f"connector:slack:credential_{credentials_provider.get_provider_key()}"
|
||||
)
|
||||
self.delay_lock = f"{self.credential_prefix}:delay_lock"
|
||||
self.delay_key = f"{self.credential_prefix}:delay"
|
||||
|
||||
# NOTE: slack has a built in RateLimitErrorRetryHandler, but it isn't designed
|
||||
# for concurrent workers. We've extended it with OnyxRedisSlackRetryHandler.
|
||||
connection_error_retry_handler = ConnectionErrorRetryHandler()
|
||||
onyx_rate_limit_error_retry_handler = OnyxRedisSlackRetryHandler(
|
||||
max_retry_count=self.MAX_RETRIES,
|
||||
delay_lock=self.delay_lock,
|
||||
delay_key=self.delay_key,
|
||||
r=self.redis,
|
||||
)
|
||||
custom_retry_handlers: list[RetryHandler] = [
|
||||
connection_error_retry_handler,
|
||||
onyx_rate_limit_error_retry_handler,
|
||||
]
|
||||
|
||||
bot_token = credentials["slack_bot_token"]
|
||||
self.client = WebClient(token=bot_token, retry_handlers=custom_retry_handlers)
|
||||
self.client = WebClient(token=bot_token)
|
||||
# use for requests that must return quickly (e.g. realtime flows where user is waiting)
|
||||
self.fast_client = WebClient(
|
||||
token=bot_token, timeout=SlackConnector.FAST_TIMEOUT
|
||||
)
|
||||
self.text_cleaner = SlackTextCleaner(client=self.client)
|
||||
self.credentials_provider = credentials_provider
|
||||
return None
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from slack_sdk.http_retry.handler import RetryHandler
|
||||
from slack_sdk.http_retry.request import HttpRequest
|
||||
from slack_sdk.http_retry.response import HttpResponse
|
||||
from slack_sdk.http_retry.state import RetryState
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class OnyxRedisSlackRetryHandler(RetryHandler):
|
||||
"""
|
||||
This class uses Redis to share a rate limit among multiple threads.
|
||||
|
||||
Threads that encounter a rate limit will observe the shared delay, increment the
|
||||
shared delay with the retry value, and use the new shared value as a wait interval.
|
||||
|
||||
This has the effect of serializing calls when a rate limit is hit, which is what
|
||||
needs to happens if the server punishes us with additional limiting when we make
|
||||
a call too early. We believe this is what Slack is doing based on empirical
|
||||
observation, meaning we see indefinite hangs if we're too aggressive.
|
||||
|
||||
Another way to do this is just to do exponential backoff. Might be easier?
|
||||
|
||||
Adapted from slack's RateLimitErrorRetryHandler.
|
||||
"""
|
||||
|
||||
LOCK_TTL = 60 # used to serialize access to the retry TTL
|
||||
LOCK_BLOCKING_TIMEOUT = 60 # how long to wait for the lock
|
||||
|
||||
"""RetryHandler that does retries for rate limited errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_retry_count: int,
|
||||
delay_lock: str,
|
||||
delay_key: str,
|
||||
r: Redis,
|
||||
):
|
||||
"""
|
||||
delay_lock: the redis key to use with RedisLock (to synchronize access to delay_key)
|
||||
delay_key: the redis key containing a shared TTL
|
||||
"""
|
||||
super().__init__(max_retry_count=max_retry_count)
|
||||
self._redis: Redis = r
|
||||
self._delay_lock = delay_lock
|
||||
self._delay_key = delay_key
|
||||
|
||||
def _can_retry(
|
||||
self,
|
||||
*,
|
||||
state: RetryState,
|
||||
request: HttpRequest,
|
||||
response: Optional[HttpResponse] = None,
|
||||
error: Optional[Exception] = None,
|
||||
) -> bool:
|
||||
return response is not None and response.status_code == 429
|
||||
|
||||
def prepare_for_next_attempt(
|
||||
self,
|
||||
*,
|
||||
state: RetryState,
|
||||
request: HttpRequest,
|
||||
response: Optional[HttpResponse] = None,
|
||||
error: Optional[Exception] = None,
|
||||
) -> None:
|
||||
"""It seems this function is responsible for the wait to retry ... aka we
|
||||
actually sleep in this function."""
|
||||
retry_after_value: list[str] | None = None
|
||||
retry_after_header_name: Optional[str] = None
|
||||
duration_s: float = 1.0 # seconds
|
||||
|
||||
if response is None:
|
||||
# NOTE(rkuo): this logic comes from RateLimitErrorRetryHandler.
|
||||
# This reads oddly, as if the caller itself could raise the exception.
|
||||
# We don't have the luxury of changing this.
|
||||
if error:
|
||||
raise error
|
||||
|
||||
return
|
||||
|
||||
state.next_attempt_requested = True # this signals the caller to retry
|
||||
|
||||
# calculate wait duration based on retry-after + some jitter
|
||||
for k in response.headers.keys():
|
||||
if k.lower() == "retry-after":
|
||||
retry_after_header_name = k
|
||||
break
|
||||
|
||||
try:
|
||||
if retry_after_header_name is None:
|
||||
# This situation usually does not arise. Just in case.
|
||||
raise ValueError(
|
||||
"OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header name is None"
|
||||
)
|
||||
|
||||
retry_after_value = response.headers.get(retry_after_header_name)
|
||||
if not retry_after_value:
|
||||
raise ValueError(
|
||||
"OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header value is None"
|
||||
)
|
||||
|
||||
retry_after_value_int = int(
|
||||
retry_after_value[0]
|
||||
) # will raise ValueError if somehow we can't convert to int
|
||||
jitter = retry_after_value_int * 0.25 * random.random()
|
||||
duration_s = math.ceil(retry_after_value_int + jitter)
|
||||
except ValueError:
|
||||
duration_s += random.random()
|
||||
|
||||
# lock and extend the ttl
|
||||
lock: RedisLock = self._redis.lock(
|
||||
self._delay_lock,
|
||||
timeout=OnyxRedisSlackRetryHandler.LOCK_TTL,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(
|
||||
blocking_timeout=OnyxRedisSlackRetryHandler.LOCK_BLOCKING_TIMEOUT / 2
|
||||
)
|
||||
|
||||
ttl_ms: int | None = None
|
||||
|
||||
try:
|
||||
if acquired:
|
||||
# if we can get the lock, then read and extend the ttl
|
||||
ttl_ms = cast(int, self._redis.pttl(self._delay_key))
|
||||
if ttl_ms < 0: # negative values are error status codes ... see docs
|
||||
ttl_ms = 0
|
||||
ttl_ms_new = ttl_ms + int(duration_s * 1000.0)
|
||||
self._redis.set(self._delay_key, "1", px=ttl_ms_new)
|
||||
else:
|
||||
# if we can't get the lock, just go ahead.
|
||||
# TODO: if we know our actual parallelism, multiplying by that
|
||||
# would be a pretty good idea
|
||||
ttl_ms_new = int(duration_s * 1000.0)
|
||||
finally:
|
||||
if acquired:
|
||||
lock.release()
|
||||
|
||||
logger.warning(
|
||||
f"OnyxRedisSlackRetryHandler.prepare_for_next_attempt wait: "
|
||||
f"retry-after={retry_after_value} "
|
||||
f"shared_delay_ms={ttl_ms} new_shared_delay_ms={ttl_ms_new}"
|
||||
)
|
||||
|
||||
# TODO: would be good to take an event var and sleep in short increments to
|
||||
# allow for a clean exit / exception
|
||||
time.sleep(ttl_ms_new / 1000.0)
|
||||
|
||||
state.increment_current_attempt()
|
||||
@@ -1,4 +1,5 @@
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from functools import lru_cache
|
||||
@@ -63,72 +64,71 @@ def _make_slack_api_call_paginated(
|
||||
return paginated_call
|
||||
|
||||
|
||||
# NOTE(rkuo): we may not need this any more if the integrated retry handlers work as
|
||||
# expected. Do we want to keep this around?
|
||||
def make_slack_api_rate_limited(
|
||||
call: Callable[..., SlackResponse], max_retries: int = 7
|
||||
) -> Callable[..., SlackResponse]:
|
||||
"""Wraps calls to slack API so that they automatically handle rate limiting"""
|
||||
|
||||
# def make_slack_api_rate_limited(
|
||||
# call: Callable[..., SlackResponse], max_retries: int = 7
|
||||
# ) -> Callable[..., SlackResponse]:
|
||||
# """Wraps calls to slack API so that they automatically handle rate limiting"""
|
||||
@wraps(call)
|
||||
def rate_limited_call(**kwargs: Any) -> SlackResponse:
|
||||
last_exception = None
|
||||
|
||||
# @wraps(call)
|
||||
# def rate_limited_call(**kwargs: Any) -> SlackResponse:
|
||||
# last_exception = None
|
||||
for _ in range(max_retries):
|
||||
try:
|
||||
# Make the API call
|
||||
response = call(**kwargs)
|
||||
|
||||
# for _ in range(max_retries):
|
||||
# try:
|
||||
# # Make the API call
|
||||
# response = call(**kwargs)
|
||||
# Check for errors in the response, will raise `SlackApiError`
|
||||
# if anything went wrong
|
||||
response.validate()
|
||||
return response
|
||||
|
||||
# # Check for errors in the response, will raise `SlackApiError`
|
||||
# # if anything went wrong
|
||||
# response.validate()
|
||||
# return response
|
||||
except SlackApiError as e:
|
||||
last_exception = e
|
||||
try:
|
||||
error = e.response["error"]
|
||||
except KeyError:
|
||||
error = "unknown error"
|
||||
|
||||
# except SlackApiError as e:
|
||||
# last_exception = e
|
||||
# try:
|
||||
# error = e.response["error"]
|
||||
# except KeyError:
|
||||
# error = "unknown error"
|
||||
if error == "ratelimited":
|
||||
# Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
|
||||
retry_after = int(e.response.headers.get("Retry-After", 1))
|
||||
logger.info(
|
||||
f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
|
||||
)
|
||||
time.sleep(retry_after)
|
||||
elif error in ["already_reacted", "no_reaction", "internal_error"]:
|
||||
# Log internal_error and return the response instead of failing
|
||||
logger.warning(
|
||||
f"Slack call encountered '{error}', skipping and continuing..."
|
||||
)
|
||||
return e.response
|
||||
else:
|
||||
# Raise the error for non-transient errors
|
||||
raise
|
||||
|
||||
# if error == "ratelimited":
|
||||
# # Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
|
||||
# retry_after = int(e.response.headers.get("Retry-After", 1))
|
||||
# logger.info(
|
||||
# f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
|
||||
# )
|
||||
# time.sleep(retry_after)
|
||||
# elif error in ["already_reacted", "no_reaction", "internal_error"]:
|
||||
# # Log internal_error and return the response instead of failing
|
||||
# logger.warning(
|
||||
# f"Slack call encountered '{error}', skipping and continuing..."
|
||||
# )
|
||||
# return e.response
|
||||
# else:
|
||||
# # Raise the error for non-transient errors
|
||||
# raise
|
||||
# If the code reaches this point, all retries have been exhausted
|
||||
msg = f"Max retries ({max_retries}) exceeded"
|
||||
if last_exception:
|
||||
raise Exception(msg) from last_exception
|
||||
else:
|
||||
raise Exception(msg)
|
||||
|
||||
# # If the code reaches this point, all retries have been exhausted
|
||||
# msg = f"Max retries ({max_retries}) exceeded"
|
||||
# if last_exception:
|
||||
# raise Exception(msg) from last_exception
|
||||
# else:
|
||||
# raise Exception(msg)
|
||||
|
||||
# return rate_limited_call
|
||||
return rate_limited_call
|
||||
|
||||
|
||||
def make_slack_api_call_w_retries(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> SlackResponse:
|
||||
return basic_retry_wrapper(call)(**kwargs)
|
||||
return basic_retry_wrapper(make_slack_api_rate_limited(call))(**kwargs)
|
||||
|
||||
|
||||
def make_paginated_slack_api_call_w_retries(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
return _make_slack_api_call_paginated(basic_retry_wrapper(call))(**kwargs)
|
||||
return _make_slack_api_call_paginated(
|
||||
basic_retry_wrapper(make_slack_api_rate_limited(call))
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
def expert_info_from_slack_id(
|
||||
@@ -142,7 +142,7 @@ def expert_info_from_slack_id(
|
||||
if user_id in user_cache:
|
||||
return user_cache[user_id]
|
||||
|
||||
response = client.users_info(user=user_id)
|
||||
response = make_slack_api_rate_limited(client.users_info)(user=user_id)
|
||||
|
||||
if not response["ok"]:
|
||||
user_cache[user_id] = None
|
||||
@@ -175,7 +175,9 @@ class SlackTextCleaner:
|
||||
def _get_slack_name(self, user_id: str) -> str:
|
||||
if user_id not in self._id_to_name_map:
|
||||
try:
|
||||
response = self._client.users_info(user=user_id)
|
||||
response = make_slack_api_rate_limited(self._client.users_info)(
|
||||
user=user_id
|
||||
)
|
||||
# prefer display name if set, since that is what is shown in Slack
|
||||
self._id_to_name_map[user_id] = (
|
||||
response["user"]["profile"]["display_name"]
|
||||
|
||||
@@ -60,7 +60,7 @@ class SearchSettingsCreationRequest(InferenceSettings, IndexingSetting):
|
||||
inference_settings = InferenceSettings.from_db_model(search_settings)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
|
||||
return cls(**inference_settings.model_dump(), **indexing_setting.model_dump())
|
||||
return cls(**inference_settings.dict(), **indexing_setting.dict())
|
||||
|
||||
|
||||
class SavedSearchSettings(InferenceSettings, IndexingSetting):
|
||||
@@ -80,9 +80,6 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
|
||||
reduced_dimension=search_settings.reduced_dimension,
|
||||
# Whether switching to this model requires re-indexing
|
||||
background_reindex_enabled=search_settings.background_reindex_enabled,
|
||||
enable_contextual_rag=search_settings.enable_contextual_rag,
|
||||
contextual_rag_llm_name=search_settings.contextual_rag_llm_name,
|
||||
contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider,
|
||||
# Reranking Details
|
||||
rerank_model_name=search_settings.rerank_model_name,
|
||||
rerank_provider_type=search_settings.rerank_provider_type,
|
||||
@@ -105,8 +102,6 @@ class BaseFilters(BaseModel):
|
||||
document_set: list[str] | None = None
|
||||
time_cutoff: datetime | None = None
|
||||
tags: list[Tag] | None = None
|
||||
user_file_ids: list[int] | None = None
|
||||
user_folder_ids: list[int] | None = None
|
||||
|
||||
|
||||
class IndexFilters(BaseFilters):
|
||||
@@ -223,8 +218,6 @@ class InferenceChunk(BaseChunk):
|
||||
# to specify that a set of words should be highlighted. For example:
|
||||
# ["<hi>the</hi> <hi>answer</hi> is 42", "he couldn't find an <hi>answer</hi>"]
|
||||
match_highlights: list[str]
|
||||
doc_summary: str
|
||||
chunk_context: str
|
||||
|
||||
# when the doc was last updated
|
||||
updated_at: datetime | None
|
||||
|
||||
@@ -5,13 +5,11 @@ from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import ContextualPruningConfig
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.models import SectionRelevancePiece
|
||||
from onyx.chat.prune_and_merge import _merge_sections
|
||||
from onyx.chat.prune_and_merge import ChunkRange
|
||||
from onyx.chat.prune_and_merge import merge_chunk_intervals
|
||||
from onyx.chat.prune_and_merge import prune_and_merge_sections
|
||||
from onyx.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
from onyx.context.search.enums import QueryFlow
|
||||
@@ -63,7 +61,6 @@ class SearchPipeline:
|
||||
| None = None,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
prompt_config: PromptConfig | None = None,
|
||||
contextual_pruning_config: ContextualPruningConfig | None = None,
|
||||
):
|
||||
# NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
|
||||
# and typically are None. The preprocessing will fetch default values to replace these empty overrides.
|
||||
@@ -80,9 +77,6 @@ class SearchPipeline:
|
||||
self.search_settings = get_current_search_settings(db_session)
|
||||
self.document_index = get_default_document_index(self.search_settings, None)
|
||||
self.prompt_config: PromptConfig | None = prompt_config
|
||||
self.contextual_pruning_config: ContextualPruningConfig | None = (
|
||||
contextual_pruning_config
|
||||
)
|
||||
|
||||
# Preprocessing steps generate this
|
||||
self._search_query: SearchQuery | None = None
|
||||
@@ -164,47 +158,6 @@ class SearchPipeline:
|
||||
|
||||
return cast(list[InferenceChunk], self._retrieved_chunks)
|
||||
|
||||
def get_ordering_only_chunks(
|
||||
self,
|
||||
query: str,
|
||||
user_file_ids: list[int] | None = None,
|
||||
user_folder_ids: list[int] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Optimized method that only retrieves chunks for ordering purposes.
|
||||
Skips all extra processing and uses minimal configuration to speed up retrieval.
|
||||
"""
|
||||
logger.info("Fast path: Using optimized chunk retrieval for ordering-only mode")
|
||||
|
||||
# Create minimal filters with just user file/folder IDs
|
||||
filters = IndexFilters(
|
||||
user_file_ids=user_file_ids or [],
|
||||
user_folder_ids=user_folder_ids or [],
|
||||
access_control_list=None,
|
||||
)
|
||||
|
||||
# Use a simplified query that skips all unnecessary processing
|
||||
minimal_query = SearchQuery(
|
||||
query=query,
|
||||
search_type=SearchType.SEMANTIC,
|
||||
filters=filters,
|
||||
# Set minimal options needed for retrieval
|
||||
evaluation_type=LLMEvaluationType.SKIP,
|
||||
recency_bias_multiplier=1.0,
|
||||
chunks_above=0, # No need for surrounding context
|
||||
chunks_below=0, # No need for surrounding context
|
||||
processed_keywords=[], # Empty list instead of None
|
||||
rerank_settings=None,
|
||||
hybrid_alpha=0.0,
|
||||
max_llm_filter_sections=0,
|
||||
)
|
||||
|
||||
# Retrieve chunks using the minimal configuration
|
||||
return retrieve_chunks(
|
||||
query=minimal_query,
|
||||
document_index=self.document_index,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def _get_sections(self) -> list[InferenceSection]:
|
||||
"""Returns an expanded section from each of the chunks.
|
||||
@@ -227,16 +180,13 @@ class SearchPipeline:
|
||||
|
||||
# If ee is enabled, censor the chunk sections based on user access
|
||||
# Otherwise, return the retrieved chunks
|
||||
censored_chunks = cast(
|
||||
list[InferenceChunk],
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.external_permissions.post_query_censoring",
|
||||
"_post_query_chunk_censoring",
|
||||
retrieved_chunks,
|
||||
)(
|
||||
chunks=retrieved_chunks,
|
||||
user=self.user,
|
||||
),
|
||||
censored_chunks = fetch_ee_implementation_or_noop(
|
||||
"onyx.external_permissions.post_query_censoring",
|
||||
"_post_query_chunk_censoring",
|
||||
retrieved_chunks,
|
||||
)(
|
||||
chunks=retrieved_chunks,
|
||||
user=self.user,
|
||||
)
|
||||
|
||||
above = self.search_query.chunks_above
|
||||
@@ -429,26 +379,7 @@ class SearchPipeline:
|
||||
if self._final_context_sections is not None:
|
||||
return self._final_context_sections
|
||||
|
||||
if (
|
||||
self.contextual_pruning_config is not None
|
||||
and self.prompt_config is not None
|
||||
):
|
||||
self._final_context_sections = prune_and_merge_sections(
|
||||
sections=self.reranked_sections,
|
||||
section_relevance_list=None,
|
||||
prompt_config=self.prompt_config,
|
||||
llm_config=self.llm.config,
|
||||
question=self.search_query.query,
|
||||
contextual_pruning_config=self.contextual_pruning_config,
|
||||
)
|
||||
|
||||
else:
|
||||
logger.error(
|
||||
"Contextual pruning or prompt config not set, using default merge"
|
||||
)
|
||||
self._final_context_sections = _merge_sections(
|
||||
sections=self.reranked_sections
|
||||
)
|
||||
self._final_context_sections = _merge_sections(sections=self.reranked_sections)
|
||||
return self._final_context_sections
|
||||
|
||||
@property
|
||||
@@ -460,10 +391,6 @@ class SearchPipeline:
|
||||
self.search_query.evaluation_type == LLMEvaluationType.SKIP
|
||||
or DISABLE_LLM_DOC_RELEVANCE
|
||||
):
|
||||
if self.search_query.evaluation_type == LLMEvaluationType.SKIP:
|
||||
logger.info(
|
||||
"Fast path: Skipping section relevance evaluation for ordering-only mode"
|
||||
)
|
||||
return None
|
||||
|
||||
if self.search_query.evaluation_type == LLMEvaluationType.UNSPECIFIED:
|
||||
|
||||
@@ -11,7 +11,6 @@ from langchain_core.messages import SystemMessage
|
||||
from onyx.chat.models import SectionRelevancePiece
|
||||
from onyx.configs.app_configs import BLURB_SIZE
|
||||
from onyx.configs.app_configs import IMAGE_ANALYSIS_SYSTEM_PROMPT
|
||||
from onyx.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
|
||||
from onyx.configs.constants import RETURN_SEPARATOR
|
||||
from onyx.configs.llm_configs import get_search_time_image_analysis_enabled
|
||||
from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MAX
|
||||
@@ -197,21 +196,9 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
|
||||
RETURN_SEPARATOR
|
||||
)
|
||||
|
||||
def _remove_contextual_rag(chunk: InferenceChunkUncleaned) -> str:
|
||||
# remove document summary
|
||||
if chunk.content.startswith(chunk.doc_summary):
|
||||
chunk.content = chunk.content[len(chunk.doc_summary) :].lstrip()
|
||||
# remove chunk context
|
||||
if chunk.content.endswith(chunk.chunk_context):
|
||||
chunk.content = chunk.content[
|
||||
: len(chunk.content) - len(chunk.chunk_context)
|
||||
].rstrip()
|
||||
return chunk.content
|
||||
|
||||
for chunk in chunks:
|
||||
chunk.content = _remove_title(chunk)
|
||||
chunk.content = _remove_metadata_suffix(chunk)
|
||||
chunk.content = _remove_contextual_rag(chunk)
|
||||
|
||||
return [chunk.to_inference_chunk() for chunk in chunks]
|
||||
|
||||
@@ -367,21 +354,6 @@ def filter_sections(
|
||||
|
||||
Returns a list of the unique chunk IDs that were marked as relevant
|
||||
"""
|
||||
# Log evaluation type to help with debugging
|
||||
logger.info(f"filter_sections called with evaluation_type={query.evaluation_type}")
|
||||
|
||||
# Fast path: immediately return empty list for SKIP evaluation type (ordering-only mode)
|
||||
if query.evaluation_type == LLMEvaluationType.SKIP:
|
||||
return []
|
||||
|
||||
# Additional safeguard: Log a warning if this function is ever called with SKIP evaluation type
|
||||
# This should never happen if our fast paths are working correctly
|
||||
if query.evaluation_type == LLMEvaluationType.SKIP:
|
||||
logger.warning(
|
||||
"WARNING: filter_sections called with SKIP evaluation_type. This should never happen!"
|
||||
)
|
||||
return []
|
||||
|
||||
sections_to_filter = sections_to_filter[: query.max_llm_filter_sections]
|
||||
|
||||
contents = [
|
||||
@@ -414,16 +386,6 @@ def search_postprocessing(
|
||||
llm: LLM,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> Iterator[list[InferenceSection] | list[SectionRelevancePiece]]:
|
||||
# Fast path for ordering-only: detect it by checking if evaluation_type is SKIP
|
||||
if search_query.evaluation_type == LLMEvaluationType.SKIP:
|
||||
logger.info(
|
||||
"Fast path: Detected ordering-only mode, bypassing all post-processing"
|
||||
)
|
||||
# Immediately yield the sections without any processing and an empty relevance list
|
||||
yield retrieved_sections
|
||||
yield cast(list[SectionRelevancePiece], [])
|
||||
return
|
||||
|
||||
post_processing_tasks: list[FunctionCall] = []
|
||||
|
||||
if not retrieved_sections:
|
||||
@@ -460,14 +422,10 @@ def search_postprocessing(
|
||||
sections_yielded = True
|
||||
|
||||
llm_filter_task_id = None
|
||||
# Only add LLM filtering if not in SKIP mode and if LLM doc relevance is not disabled
|
||||
if (
|
||||
search_query.evaluation_type not in [LLMEvaluationType.SKIP]
|
||||
and not DISABLE_LLM_DOC_RELEVANCE
|
||||
and search_query.evaluation_type
|
||||
in [LLMEvaluationType.BASIC, LLMEvaluationType.UNSPECIFIED]
|
||||
):
|
||||
logger.info("Adding LLM filtering task for document relevance evaluation")
|
||||
if search_query.evaluation_type in [
|
||||
LLMEvaluationType.BASIC,
|
||||
LLMEvaluationType.UNSPECIFIED,
|
||||
]:
|
||||
post_processing_tasks.append(
|
||||
FunctionCall(
|
||||
filter_sections,
|
||||
@@ -479,10 +437,6 @@ def search_postprocessing(
|
||||
)
|
||||
)
|
||||
llm_filter_task_id = post_processing_tasks[-1].result_id
|
||||
elif search_query.evaluation_type == LLMEvaluationType.SKIP:
|
||||
logger.info("Fast path: Skipping LLM filtering task for ordering-only mode")
|
||||
elif DISABLE_LLM_DOC_RELEVANCE:
|
||||
logger.info("Skipping LLM filtering task because LLM doc relevance is disabled")
|
||||
|
||||
post_processing_results = (
|
||||
run_functions_in_parallel(post_processing_tasks)
|
||||
|
||||
@@ -165,18 +165,7 @@ def retrieval_preprocessing(
|
||||
user_acl_filters = (
|
||||
None if bypass_acl else build_access_filters_for_user(user, db_session)
|
||||
)
|
||||
user_file_ids = preset_filters.user_file_ids or []
|
||||
user_folder_ids = preset_filters.user_folder_ids or []
|
||||
if persona and persona.user_files:
|
||||
user_file_ids = user_file_ids + [
|
||||
file.id
|
||||
for file in persona.user_files
|
||||
if file.id not in (preset_filters.user_file_ids or [])
|
||||
]
|
||||
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
user_folder_ids=user_folder_ids,
|
||||
source_type=preset_filters.source_type or predicted_source_filters,
|
||||
document_set=preset_filters.document_set,
|
||||
time_cutoff=time_filter or predicted_time_cutoff,
|
||||
|
||||
@@ -26,7 +26,6 @@ from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.chat.models import DocumentRelevance
|
||||
from onyx.configs.chat_configs import HARD_DELETE_CHATS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RetrievalDocs
|
||||
@@ -45,11 +44,9 @@ from onyx.db.models import SearchDoc
|
||||
from onyx.db.models import SearchDoc as DBSearchDoc
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.persona import get_best_persona_id_for_user
|
||||
from onyx.db.pg_file_store import delete_lobj_by_name
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.override_models import PromptOverride
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
@@ -857,87 +854,6 @@ def get_db_search_doc_by_id(doc_id: int, db_session: Session) -> DBSearchDoc | N
|
||||
return search_doc
|
||||
|
||||
|
||||
def create_search_doc_from_user_file(
|
||||
db_user_file: UserFile, associated_chat_file: InMemoryChatFile, db_session: Session
|
||||
) -> SearchDoc:
|
||||
"""Create a SearchDoc in the database from a UserFile and return it.
|
||||
This ensures proper ID generation by SQLAlchemy and prevents duplicate key errors.
|
||||
"""
|
||||
blurb = ""
|
||||
if associated_chat_file and associated_chat_file.content:
|
||||
try:
|
||||
# Try to decode as UTF-8, but handle errors gracefully
|
||||
content_sample = associated_chat_file.content[:100]
|
||||
# Remove null bytes which can cause SQL errors
|
||||
content_sample = content_sample.replace(b"\x00", b"")
|
||||
blurb = content_sample.decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
# If decoding fails completely, provide a generic description
|
||||
blurb = f"[Binary file: {db_user_file.name}]"
|
||||
|
||||
db_search_doc = SearchDoc(
|
||||
document_id=db_user_file.document_id,
|
||||
chunk_ind=0, # Default to 0 for user files
|
||||
semantic_id=db_user_file.name,
|
||||
link=db_user_file.link_url,
|
||||
blurb=blurb,
|
||||
source_type=DocumentSource.FILE, # Assuming internal source for user files
|
||||
boost=0, # Default boost
|
||||
hidden=False, # Default visibility
|
||||
doc_metadata={}, # Empty metadata
|
||||
score=0.0, # Default score of 0.0 instead of None
|
||||
is_relevant=None, # No relevance initially
|
||||
relevance_explanation=None, # No explanation initially
|
||||
match_highlights=[], # No highlights initially
|
||||
updated_at=db_user_file.created_at, # Use created_at as updated_at
|
||||
primary_owners=[], # Empty list instead of None
|
||||
secondary_owners=[], # Empty list instead of None
|
||||
is_internet=False, # Not from internet
|
||||
)
|
||||
|
||||
db_session.add(db_search_doc)
|
||||
db_session.flush() # Get the ID but don't commit yet
|
||||
|
||||
return db_search_doc
|
||||
|
||||
|
||||
def translate_db_user_file_to_search_doc(
|
||||
db_user_file: UserFile, associated_chat_file: InMemoryChatFile
|
||||
) -> SearchDoc:
|
||||
blurb = ""
|
||||
if associated_chat_file and associated_chat_file.content:
|
||||
try:
|
||||
# Try to decode as UTF-8, but handle errors gracefully
|
||||
content_sample = associated_chat_file.content[:100]
|
||||
# Remove null bytes which can cause SQL errors
|
||||
content_sample = content_sample.replace(b"\x00", b"")
|
||||
blurb = content_sample.decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
# If decoding fails completely, provide a generic description
|
||||
blurb = f"[Binary file: {db_user_file.name}]"
|
||||
|
||||
return SearchDoc(
|
||||
# Don't set ID - let SQLAlchemy auto-generate it
|
||||
document_id=db_user_file.document_id,
|
||||
chunk_ind=0, # Default to 0 for user files
|
||||
semantic_id=db_user_file.name,
|
||||
link=db_user_file.link_url,
|
||||
blurb=blurb,
|
||||
source_type=DocumentSource.FILE, # Assuming internal source for user files
|
||||
boost=0, # Default boost
|
||||
hidden=False, # Default visibility
|
||||
doc_metadata={}, # Empty metadata
|
||||
score=0.0, # Default score of 0.0 instead of None
|
||||
is_relevant=None, # No relevance initially
|
||||
relevance_explanation=None, # No explanation initially
|
||||
match_highlights=[], # No highlights initially
|
||||
updated_at=db_user_file.created_at, # Use created_at as updated_at
|
||||
primary_owners=[], # Empty list instead of None
|
||||
secondary_owners=[], # Empty list instead of None
|
||||
is_internet=False, # Not from internet
|
||||
)
|
||||
|
||||
|
||||
def translate_db_search_doc_to_server_search_doc(
|
||||
db_search_doc: SearchDoc,
|
||||
remove_doc_content: bool = False,
|
||||
|
||||
@@ -27,7 +27,6 @@ from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.server.models import StatusResponse
|
||||
@@ -107,13 +106,11 @@ def get_connector_credential_pairs_for_user(
|
||||
eager_load_connector: bool = False,
|
||||
eager_load_credential: bool = False,
|
||||
eager_load_user: bool = False,
|
||||
include_user_files: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
if eager_load_user:
|
||||
assert (
|
||||
eager_load_credential
|
||||
), "eager_load_credential must be True if eager_load_user is True"
|
||||
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
|
||||
if eager_load_connector:
|
||||
@@ -129,9 +126,6 @@ def get_connector_credential_pairs_for_user(
|
||||
if ids:
|
||||
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
|
||||
|
||||
if not include_user_files:
|
||||
stmt = stmt.where(ConnectorCredentialPair.is_user_file != True) # noqa: E712
|
||||
|
||||
return list(db_session.scalars(stmt).unique().all())
|
||||
|
||||
|
||||
@@ -159,16 +153,14 @@ def get_connector_credential_pairs_for_user_parallel(
|
||||
|
||||
|
||||
def get_connector_credential_pairs(
|
||||
db_session: Session, ids: list[int] | None = None, include_user_files: bool = False
|
||||
db_session: Session,
|
||||
ids: list[int] | None = None,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
|
||||
if ids:
|
||||
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
|
||||
|
||||
if not include_user_files:
|
||||
stmt = stmt.where(ConnectorCredentialPair.is_user_file != True) # noqa: E712
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
@@ -215,15 +207,12 @@ def get_connector_credential_pair_for_user(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
user: User | None,
|
||||
include_user_files: bool = False,
|
||||
get_editable: bool = True,
|
||||
) -> ConnectorCredentialPair | None:
|
||||
stmt = select(ConnectorCredentialPair)
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
stmt = stmt.where(ConnectorCredentialPair.connector_id == connector_id)
|
||||
stmt = stmt.where(ConnectorCredentialPair.credential_id == credential_id)
|
||||
if not include_user_files:
|
||||
stmt = stmt.where(ConnectorCredentialPair.is_user_file != True) # noqa: E712
|
||||
result = db_session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@@ -332,9 +321,6 @@ def _update_connector_credential_pair(
|
||||
cc_pair.total_docs_indexed += net_docs
|
||||
if status is not None:
|
||||
cc_pair.status = status
|
||||
if cc_pair.is_user_file:
|
||||
cc_pair.status = ConnectorCredentialPairStatus.PAUSED
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -460,7 +446,6 @@ def add_credential_to_connector(
|
||||
initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.ACTIVE,
|
||||
last_successful_index_time: datetime | None = None,
|
||||
seeding_flow: bool = False,
|
||||
is_user_file: bool = False,
|
||||
) -> StatusResponse:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
|
||||
@@ -526,7 +511,6 @@ def add_credential_to_connector(
|
||||
access_type=access_type,
|
||||
auto_sync_options=auto_sync_options,
|
||||
last_successful_index_time=last_successful_index_time,
|
||||
is_user_file=is_user_file,
|
||||
)
|
||||
db_session.add(association)
|
||||
db_session.flush() # make sure the association has an id
|
||||
@@ -603,29 +587,14 @@ def remove_credential_from_connector(
|
||||
|
||||
def fetch_connector_credential_pairs(
|
||||
db_session: Session,
|
||||
include_user_files: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
stmt = select(ConnectorCredentialPair)
|
||||
if not include_user_files:
|
||||
stmt = stmt.where(ConnectorCredentialPair.is_user_file != True) # noqa: E712
|
||||
return list(db_session.scalars(stmt).unique().all())
|
||||
return db_session.query(ConnectorCredentialPair).all()
|
||||
|
||||
|
||||
def resync_cc_pair(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
search_settings_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""
|
||||
Updates state stored in the connector_credential_pair table based on the
|
||||
latest index attempt for the given search settings.
|
||||
|
||||
Args:
|
||||
cc_pair: ConnectorCredentialPair to resync
|
||||
search_settings_id: SearchSettings to use for resync
|
||||
db_session: Database session
|
||||
"""
|
||||
|
||||
def find_latest_index_attempt(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
@@ -638,10 +607,11 @@ def resync_cc_pair(
|
||||
ConnectorCredentialPair,
|
||||
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
|
||||
)
|
||||
.join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
|
||||
.filter(
|
||||
ConnectorCredentialPair.connector_id == connector_id,
|
||||
ConnectorCredentialPair.credential_id == credential_id,
|
||||
IndexAttempt.search_settings_id == search_settings_id,
|
||||
SearchSettings.status == IndexModelStatus.PRESENT,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -664,23 +634,3 @@ def resync_cc_pair(
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_connector_credential_pairs_with_user_files(
|
||||
db_session: Session,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
"""
|
||||
Get all connector credential pairs that have associated user files.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
List of ConnectorCredentialPair objects that have user files
|
||||
"""
|
||||
return (
|
||||
db_session.query(ConnectorCredentialPair)
|
||||
.join(UserFile, UserFile.cc_pair_id == ConnectorCredentialPair.id)
|
||||
.distinct()
|
||||
.all()
|
||||
)
|
||||
|
||||
@@ -43,8 +43,6 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
ONE_HOUR_IN_SECONDS = 60 * 60
|
||||
|
||||
|
||||
def check_docs_exist(db_session: Session) -> bool:
|
||||
stmt = select(exists(DbDocument))
|
||||
@@ -609,46 +607,6 @@ def delete_documents_complete__no_commit(
|
||||
delete_documents__no_commit(db_session, document_ids)
|
||||
|
||||
|
||||
def delete_all_documents_for_connector_credential_pair(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
timeout: int = ONE_HOUR_IN_SECONDS,
|
||||
) -> None:
|
||||
"""Delete all documents for a given connector credential pair.
|
||||
This will delete all documents and their associated data (chunks, feedback, tags, etc.)
|
||||
|
||||
NOTE: a bit inefficient, but it's not a big deal since this is done rarely - only during
|
||||
an index swap. If we wanted to make this more efficient, we could use a single delete
|
||||
statement + cascade.
|
||||
"""
|
||||
batch_size = 1000
|
||||
start_time = time.monotonic()
|
||||
|
||||
while True:
|
||||
# Get document IDs in batches
|
||||
stmt = (
|
||||
select(DocumentByConnectorCredentialPair.id)
|
||||
.where(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
.limit(batch_size)
|
||||
)
|
||||
document_ids = db_session.scalars(stmt).all()
|
||||
|
||||
if not document_ids:
|
||||
break
|
||||
|
||||
delete_documents_complete__no_commit(
|
||||
db_session=db_session, document_ids=list(document_ids)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
if time.monotonic() - start_time > timeout:
|
||||
raise RuntimeError("Timeout reached while deleting documents")
|
||||
|
||||
|
||||
def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool:
|
||||
"""Acquire locks for the specified documents. Ideally this shouldn't be
|
||||
called with large list of document_ids (an exception could be made if the
|
||||
|
||||
@@ -605,6 +605,7 @@ def fetch_document_sets_for_document(
|
||||
result = fetch_document_sets_for_documents([document_id], db_session)
|
||||
if not result:
|
||||
return []
|
||||
|
||||
return result[0][1]
|
||||
|
||||
|
||||
|
||||
@@ -217,6 +217,7 @@ def mark_attempt_in_progress(
|
||||
"index_attempt_id": index_attempt.id,
|
||||
"status": IndexingStatus.IN_PROGRESS.value,
|
||||
"cc_pair_id": index_attempt.connector_credential_pair_id,
|
||||
"search_settings_id": index_attempt.search_settings_id,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
@@ -245,6 +246,9 @@ def mark_attempt_succeeded(
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"status": IndexingStatus.SUCCESS.value,
|
||||
"cc_pair_id": attempt.connector_credential_pair_id,
|
||||
"search_settings_id": attempt.search_settings_id,
|
||||
"total_docs_indexed": attempt.total_docs_indexed,
|
||||
"new_docs_indexed": attempt.new_docs_indexed,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
@@ -273,6 +277,9 @@ def mark_attempt_partially_succeeded(
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"status": IndexingStatus.COMPLETED_WITH_ERRORS.value,
|
||||
"cc_pair_id": attempt.connector_credential_pair_id,
|
||||
"search_settings_id": attempt.search_settings_id,
|
||||
"total_docs_indexed": attempt.total_docs_indexed,
|
||||
"new_docs_indexed": attempt.new_docs_indexed,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
@@ -305,6 +312,10 @@ def mark_attempt_canceled(
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"status": IndexingStatus.CANCELED.value,
|
||||
"cc_pair_id": attempt.connector_credential_pair_id,
|
||||
"search_settings_id": attempt.search_settings_id,
|
||||
"reason": reason,
|
||||
"total_docs_indexed": attempt.total_docs_indexed,
|
||||
"new_docs_indexed": attempt.new_docs_indexed,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
@@ -339,6 +350,10 @@ def mark_attempt_failed(
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"status": IndexingStatus.FAILED.value,
|
||||
"cc_pair_id": attempt.connector_credential_pair_id,
|
||||
"search_settings_id": attempt.search_settings_id,
|
||||
"reason": failure_reason,
|
||||
"total_docs_indexed": attempt.total_docs_indexed,
|
||||
"new_docs_indexed": attempt.new_docs_indexed,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
@@ -710,25 +725,6 @@ def cancel_indexing_attempts_past_model(
|
||||
)
|
||||
|
||||
|
||||
def cancel_indexing_attempts_for_search_settings(
|
||||
search_settings_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Stops all indexing attempts that are in progress or not started for
|
||||
the specified search settings."""
|
||||
|
||||
db_session.execute(
|
||||
update(IndexAttempt)
|
||||
.where(
|
||||
IndexAttempt.status.in_(
|
||||
[IndexingStatus.IN_PROGRESS, IndexingStatus.NOT_STARTED]
|
||||
),
|
||||
IndexAttempt.search_settings_id == search_settings_id,
|
||||
)
|
||||
.values(status=IndexingStatus.FAILED)
|
||||
)
|
||||
|
||||
|
||||
def count_unique_cc_pairs_with_successful_index_attempts(
|
||||
search_settings_id: int | None,
|
||||
db_session: Session,
|
||||
|
||||
@@ -212,10 +212,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
back_populates="creator",
|
||||
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
|
||||
)
|
||||
folders: Mapped[list["UserFolder"]] = relationship(
|
||||
"UserFolder", back_populates="user"
|
||||
)
|
||||
files: Mapped[list["UserFile"]] = relationship("UserFile", back_populates="user")
|
||||
|
||||
@validates("email")
|
||||
def validate_email(self, key: str, value: str) -> str:
|
||||
@@ -423,7 +419,6 @@ class ConnectorCredentialPair(Base):
|
||||
"""
|
||||
|
||||
__tablename__ = "connector_credential_pair"
|
||||
is_user_file: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# NOTE: this `id` column has to use `Sequence` instead of `autoincrement=True`
|
||||
# due to some SQLAlchemy quirks + this not being a primary key column
|
||||
id: Mapped[int] = mapped_column(
|
||||
@@ -510,10 +505,6 @@ class ConnectorCredentialPair(Base):
|
||||
primaryjoin="foreign(ConnectorCredentialPair.creator_id) == remote(User.id)",
|
||||
)
|
||||
|
||||
user_file: Mapped["UserFile"] = relationship(
|
||||
"UserFile", back_populates="cc_pair", uselist=False
|
||||
)
|
||||
|
||||
background_errors: Mapped[list["BackgroundError"]] = relationship(
|
||||
"BackgroundError", back_populates="cc_pair", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -703,11 +694,7 @@ class Connector(Base):
|
||||
)
|
||||
documents_by_connector: Mapped[
|
||||
list["DocumentByConnectorCredentialPair"]
|
||||
] = relationship(
|
||||
"DocumentByConnectorCredentialPair",
|
||||
back_populates="connector",
|
||||
passive_deletes=True,
|
||||
)
|
||||
] = relationship("DocumentByConnectorCredentialPair", back_populates="connector")
|
||||
|
||||
# synchronize this validation logic with RefreshFrequencySchema etc on front end
|
||||
# until we have a centralized validation schema
|
||||
@@ -761,11 +748,7 @@ class Credential(Base):
|
||||
)
|
||||
documents_by_credential: Mapped[
|
||||
list["DocumentByConnectorCredentialPair"]
|
||||
] = relationship(
|
||||
"DocumentByConnectorCredentialPair",
|
||||
back_populates="credential",
|
||||
passive_deletes=True,
|
||||
)
|
||||
] = relationship("DocumentByConnectorCredentialPair", back_populates="credential")
|
||||
|
||||
user: Mapped[User | None] = relationship("User", back_populates="credentials")
|
||||
|
||||
@@ -808,15 +791,6 @@ class SearchSettings(Base):
|
||||
# Mini and Large Chunks (large chunk also checks for model max context)
|
||||
multipass_indexing: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
# Contextual RAG
|
||||
enable_contextual_rag: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# Contextual RAG LLM
|
||||
contextual_rag_llm_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
contextual_rag_llm_provider: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
)
|
||||
|
||||
multilingual_expansion: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), default=[]
|
||||
)
|
||||
@@ -1118,10 +1092,10 @@ class DocumentByConnectorCredentialPair(Base):
|
||||
id: Mapped[str] = mapped_column(ForeignKey("document.id"), primary_key=True)
|
||||
# TODO: transition this to use the ConnectorCredentialPair id directly
|
||||
connector_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("connector.id", ondelete="CASCADE"), primary_key=True
|
||||
ForeignKey("connector.id"), primary_key=True
|
||||
)
|
||||
credential_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("credential.id", ondelete="CASCADE"), primary_key=True
|
||||
ForeignKey("credential.id"), primary_key=True
|
||||
)
|
||||
|
||||
# used to better keep track of document counts at a connector level
|
||||
@@ -1131,10 +1105,10 @@ class DocumentByConnectorCredentialPair(Base):
|
||||
has_been_indexed: Mapped[bool] = mapped_column(Boolean)
|
||||
|
||||
connector: Mapped[Connector] = relationship(
|
||||
"Connector", back_populates="documents_by_connector", passive_deletes=True
|
||||
"Connector", back_populates="documents_by_connector"
|
||||
)
|
||||
credential: Mapped[Credential] = relationship(
|
||||
"Credential", back_populates="documents_by_credential", passive_deletes=True
|
||||
"Credential", back_populates="documents_by_credential"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
@@ -1658,8 +1632,8 @@ class Prompt(Base):
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
description: Mapped[str] = mapped_column(String)
|
||||
system_prompt: Mapped[str] = mapped_column(String(length=8000))
|
||||
task_prompt: Mapped[str] = mapped_column(String(length=8000))
|
||||
system_prompt: Mapped[str] = mapped_column(Text)
|
||||
task_prompt: Mapped[str] = mapped_column(Text)
|
||||
include_citations: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
datetime_aware: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
# Default prompts are configured via backend during deployment
|
||||
@@ -1825,17 +1799,6 @@ class Persona(Base):
|
||||
secondary="persona__user_group",
|
||||
viewonly=True,
|
||||
)
|
||||
# Relationship to UserFile
|
||||
user_files: Mapped[list["UserFile"]] = relationship(
|
||||
"UserFile",
|
||||
secondary="persona__user_file",
|
||||
back_populates="assistants",
|
||||
)
|
||||
user_folders: Mapped[list["UserFolder"]] = relationship(
|
||||
"UserFolder",
|
||||
secondary="persona__user_folder",
|
||||
back_populates="assistants",
|
||||
)
|
||||
labels: Mapped[list["PersonaLabel"]] = relationship(
|
||||
"PersonaLabel",
|
||||
secondary=Persona__PersonaLabel.__table__,
|
||||
@@ -1852,24 +1815,6 @@ class Persona(Base):
|
||||
)
|
||||
|
||||
|
||||
class Persona__UserFolder(Base):
|
||||
__tablename__ = "persona__user_folder"
|
||||
|
||||
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
|
||||
user_folder_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_folder.id"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class Persona__UserFile(Base):
|
||||
__tablename__ = "persona__user_file"
|
||||
|
||||
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
|
||||
user_file_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_file.id"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class PersonaLabel(Base):
|
||||
__tablename__ = "persona_label"
|
||||
|
||||
@@ -2392,64 +2337,6 @@ class InputPrompt__User(Base):
|
||||
disabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
|
||||
class UserFolder(Base):
|
||||
__tablename__ = "user_folder"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=False)
|
||||
name: Mapped[str] = mapped_column(nullable=False)
|
||||
description: Mapped[str] = mapped_column(nullable=False)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
user: Mapped["User"] = relationship(back_populates="folders")
|
||||
files: Mapped[list["UserFile"]] = relationship(back_populates="folder")
|
||||
assistants: Mapped[list["Persona"]] = relationship(
|
||||
"Persona",
|
||||
secondary=Persona__UserFolder.__table__,
|
||||
back_populates="user_folders",
|
||||
)
|
||||
|
||||
|
||||
class UserDocument(str, Enum):
|
||||
CHAT = "chat"
|
||||
RECENT = "recent"
|
||||
FILE = "file"
|
||||
|
||||
|
||||
class UserFile(Base):
|
||||
__tablename__ = "user_file"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=False)
|
||||
assistants: Mapped[list["Persona"]] = relationship(
|
||||
"Persona",
|
||||
secondary=Persona__UserFile.__table__,
|
||||
back_populates="user_files",
|
||||
)
|
||||
folder_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("user_folder.id"), nullable=True
|
||||
)
|
||||
|
||||
file_id: Mapped[str] = mapped_column(nullable=False)
|
||||
document_id: Mapped[str] = mapped_column(nullable=False)
|
||||
name: Mapped[str] = mapped_column(nullable=False)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
default=datetime.datetime.utcnow
|
||||
)
|
||||
user: Mapped["User"] = relationship(back_populates="files")
|
||||
folder: Mapped["UserFolder"] = relationship(back_populates="files")
|
||||
token_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
cc_pair_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("connector_credential_pair.id"), nullable=True, unique=True
|
||||
)
|
||||
cc_pair: Mapped["ConnectorCredentialPair"] = relationship(
|
||||
"ConnectorCredentialPair", back_populates="user_file"
|
||||
)
|
||||
link_url: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
|
||||
"""
|
||||
Multi-tenancy related tables
|
||||
"""
|
||||
|
||||
@@ -33,12 +33,10 @@ from onyx.db.models import StarterMessage
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserFolder
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.server.features.persona.models import FullPersonaSnapshot
|
||||
from onyx.server.features.persona.models import PersonaSharedNotificationData
|
||||
from onyx.server.features.persona.models import PersonaSnapshot
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
@@ -201,7 +199,7 @@ def create_update_persona(
|
||||
create_persona_request: PersonaUpsertRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> FullPersonaSnapshot:
|
||||
) -> PersonaSnapshot:
|
||||
"""Higher level function than upsert_persona, although either is valid to use."""
|
||||
# Permission to actually use these is checked later
|
||||
|
||||
@@ -211,6 +209,7 @@ def create_update_persona(
|
||||
if not all_prompt_ids:
|
||||
raise ValueError("No prompt IDs provided")
|
||||
|
||||
is_default_persona: bool | None = create_persona_request.is_default_persona
|
||||
# Default persona validation
|
||||
if create_persona_request.is_default_persona:
|
||||
if not create_persona_request.is_public:
|
||||
@@ -222,7 +221,7 @@ def create_update_persona(
|
||||
user.role == UserRole.CURATOR
|
||||
or user.role == UserRole.GLOBAL_CURATOR
|
||||
):
|
||||
pass
|
||||
is_default_persona = None
|
||||
elif user.role != UserRole.ADMIN:
|
||||
raise ValueError("Only admins can make a default persona")
|
||||
|
||||
@@ -250,9 +249,7 @@ def create_update_persona(
|
||||
num_chunks=create_persona_request.num_chunks,
|
||||
llm_relevance_filter=create_persona_request.llm_relevance_filter,
|
||||
llm_filter_extraction=create_persona_request.llm_filter_extraction,
|
||||
is_default_persona=create_persona_request.is_default_persona,
|
||||
user_file_ids=create_persona_request.user_file_ids,
|
||||
user_folder_ids=create_persona_request.user_folder_ids,
|
||||
is_default_persona=is_default_persona,
|
||||
)
|
||||
|
||||
versioned_make_persona_private = fetch_versioned_implementation(
|
||||
@@ -271,7 +268,7 @@ def create_update_persona(
|
||||
logger.exception("Failed to create persona")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return FullPersonaSnapshot.from_model(persona)
|
||||
return PersonaSnapshot.from_model(persona)
|
||||
|
||||
|
||||
def update_persona_shared_users(
|
||||
@@ -347,8 +344,6 @@ def get_personas_for_user(
|
||||
selectinload(Persona.groups),
|
||||
selectinload(Persona.users),
|
||||
selectinload(Persona.labels),
|
||||
selectinload(Persona.user_files),
|
||||
selectinload(Persona.user_folders),
|
||||
)
|
||||
|
||||
results = db_session.execute(stmt).scalars().all()
|
||||
@@ -443,8 +438,6 @@ def upsert_persona(
|
||||
builtin_persona: bool = False,
|
||||
is_default_persona: bool | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
user_file_ids: list[int] | None = None,
|
||||
user_folder_ids: list[int] | None = None,
|
||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||
) -> Persona:
|
||||
@@ -470,7 +463,6 @@ def upsert_persona(
|
||||
user=user,
|
||||
get_editable=True,
|
||||
)
|
||||
|
||||
# Fetch and attach tools by IDs
|
||||
tools = None
|
||||
if tool_ids is not None:
|
||||
@@ -489,26 +481,6 @@ def upsert_persona(
|
||||
if not document_sets and document_set_ids:
|
||||
raise ValueError("document_sets not found")
|
||||
|
||||
# Fetch and attach user_files by IDs
|
||||
user_files = None
|
||||
if user_file_ids is not None:
|
||||
user_files = (
|
||||
db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).all()
|
||||
)
|
||||
if not user_files and user_file_ids:
|
||||
raise ValueError("user_files not found")
|
||||
|
||||
# Fetch and attach user_folders by IDs
|
||||
user_folders = None
|
||||
if user_folder_ids is not None:
|
||||
user_folders = (
|
||||
db_session.query(UserFolder)
|
||||
.filter(UserFolder.id.in_(user_folder_ids))
|
||||
.all()
|
||||
)
|
||||
if not user_folders and user_folder_ids:
|
||||
raise ValueError("user_folders not found")
|
||||
|
||||
# Fetch and attach prompts by IDs
|
||||
prompts = None
|
||||
if prompt_ids is not None:
|
||||
@@ -577,14 +549,6 @@ def upsert_persona(
|
||||
if tools is not None:
|
||||
existing_persona.tools = tools or []
|
||||
|
||||
if user_file_ids is not None:
|
||||
existing_persona.user_files.clear()
|
||||
existing_persona.user_files = user_files or []
|
||||
|
||||
if user_folder_ids is not None:
|
||||
existing_persona.user_folders.clear()
|
||||
existing_persona.user_folders = user_folders or []
|
||||
|
||||
# We should only update display priority if it is not already set
|
||||
if existing_persona.display_priority is None:
|
||||
existing_persona.display_priority = display_priority
|
||||
@@ -626,8 +590,6 @@ def upsert_persona(
|
||||
is_default_persona=is_default_persona
|
||||
if is_default_persona is not None
|
||||
else False,
|
||||
user_folders=user_folders or [],
|
||||
user_files=user_files or [],
|
||||
labels=labels or [],
|
||||
)
|
||||
db_session.add(new_persona)
|
||||
|
||||
@@ -62,9 +62,6 @@ def create_search_settings(
|
||||
multipass_indexing=search_settings.multipass_indexing,
|
||||
embedding_precision=search_settings.embedding_precision,
|
||||
reduced_dimension=search_settings.reduced_dimension,
|
||||
enable_contextual_rag=search_settings.enable_contextual_rag,
|
||||
contextual_rag_llm_name=search_settings.contextual_rag_llm_name,
|
||||
contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider,
|
||||
multilingual_expansion=search_settings.multilingual_expansion,
|
||||
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
|
||||
rerank_model_name=search_settings.rerank_model_name,
|
||||
@@ -322,7 +319,6 @@ def get_old_default_embedding_model() -> IndexingSetting:
|
||||
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
|
||||
index_name="danswer_chunk",
|
||||
multipass_indexing=False,
|
||||
enable_contextual_rag=False,
|
||||
api_url=None,
|
||||
)
|
||||
|
||||
@@ -337,6 +333,5 @@ def get_new_default_embedding_model() -> IndexingSetting:
|
||||
passage_prefix=ASYM_PASSAGE_PREFIX,
|
||||
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
|
||||
multipass_indexing=False,
|
||||
enable_contextual_rag=False,
|
||||
api_url=None,
|
||||
)
|
||||
|
||||
@@ -3,9 +3,8 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.connector_credential_pair import resync_cc_pair
|
||||
from onyx.db.document import delete_all_documents_for_connector_credential_pair
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.index_attempt import cancel_indexing_attempts_for_search_settings
|
||||
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from onyx.db.index_attempt import (
|
||||
count_unique_cc_pairs_with_successful_index_attempts,
|
||||
)
|
||||
@@ -27,50 +26,32 @@ def _perform_index_swap(
|
||||
current_search_settings: SearchSettings,
|
||||
secondary_search_settings: SearchSettings,
|
||||
all_cc_pairs: list[ConnectorCredentialPair],
|
||||
cleanup_documents: bool = False,
|
||||
) -> None:
|
||||
"""Swap the indices and expire the old one."""
|
||||
if len(all_cc_pairs) > 0:
|
||||
kv_store = get_kv_store()
|
||||
kv_store.store(KV_REINDEX_KEY, False)
|
||||
|
||||
# Expire jobs for the now past index/embedding model
|
||||
cancel_indexing_attempts_for_search_settings(
|
||||
search_settings_id=current_search_settings.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Recount aggregates
|
||||
for cc_pair in all_cc_pairs:
|
||||
resync_cc_pair(
|
||||
cc_pair=cc_pair,
|
||||
# sync based on the new search settings
|
||||
search_settings_id=secondary_search_settings.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
if cleanup_documents:
|
||||
# clean up all DocumentByConnectorCredentialPair / Document rows, since we're
|
||||
# doing an instant swap and no documents will exist in the new index.
|
||||
for cc_pair in all_cc_pairs:
|
||||
delete_all_documents_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
|
||||
# swap over search settings
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
update_search_settings_status(
|
||||
search_settings=current_search_settings,
|
||||
new_status=IndexModelStatus.PAST,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
update_search_settings_status(
|
||||
search_settings=secondary_search_settings,
|
||||
new_status=IndexModelStatus.PRESENT,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
if len(all_cc_pairs) > 0:
|
||||
kv_store = get_kv_store()
|
||||
kv_store.store(KV_REINDEX_KEY, False)
|
||||
|
||||
# Expire jobs for the now past index/embedding model
|
||||
cancel_indexing_attempts_past_model(db_session)
|
||||
|
||||
# Recount aggregates
|
||||
for cc_pair in all_cc_pairs:
|
||||
resync_cc_pair(cc_pair, db_session=db_session)
|
||||
|
||||
# remove the old index from the vector db
|
||||
document_index = get_default_document_index(secondary_search_settings, None)
|
||||
document_index.ensure_indices_exist(
|
||||
@@ -107,9 +88,6 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
current_search_settings=current_search_settings,
|
||||
secondary_search_settings=secondary_search_settings,
|
||||
all_cc_pairs=all_cc_pairs,
|
||||
# clean up all DocumentByConnectorCredentialPair / Document rows, since we're
|
||||
# doing an instant swap.
|
||||
cleanup_documents=True,
|
||||
)
|
||||
return current_search_settings
|
||||
|
||||
|
||||
@@ -1,466 +0,0 @@
|
||||
import datetime
|
||||
import time
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import get_current_tenant_id
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.connector import create_connector
|
||||
from onyx.db.connector_credential_pair import add_credential_to_connector
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserFolder
|
||||
from onyx.server.documents.connector import trigger_indexing_for_cc_pair
|
||||
from onyx.server.documents.connector import upload_files
|
||||
from onyx.server.documents.models import ConnectorBase
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.server.models import StatusResponse
|
||||
|
||||
USER_FILE_CONSTANT = "USER_FILE_CONNECTOR"
|
||||
|
||||
|
||||
def create_user_files(
|
||||
files: List[UploadFile],
|
||||
folder_id: int | None,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
link_url: str | None = None,
|
||||
) -> list[UserFile]:
|
||||
upload_response = upload_files(files, db_session)
|
||||
user_files = []
|
||||
|
||||
for file_path, file in zip(upload_response.file_paths, files):
|
||||
new_file = UserFile(
|
||||
user_id=user.id if user else None,
|
||||
folder_id=folder_id,
|
||||
file_id=file_path,
|
||||
document_id="USER_FILE_CONNECTOR__" + file_path,
|
||||
name=file.filename,
|
||||
token_count=None,
|
||||
link_url=link_url,
|
||||
)
|
||||
db_session.add(new_file)
|
||||
user_files.append(new_file)
|
||||
db_session.commit()
|
||||
return user_files
|
||||
|
||||
|
||||
def create_user_file_with_indexing(
|
||||
files: List[UploadFile],
|
||||
folder_id: int | None,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
trigger_index: bool = True,
|
||||
) -> list[UserFile]:
|
||||
"""Create user files and trigger immediate indexing"""
|
||||
# Create the user files first
|
||||
user_files = create_user_files(files, folder_id, user, db_session)
|
||||
|
||||
# Create connector and credential for each file
|
||||
for user_file in user_files:
|
||||
cc_pair = create_file_connector_credential(user_file, user, db_session)
|
||||
user_file.cc_pair_id = cc_pair.data
|
||||
|
||||
db_session.commit()
|
||||
|
||||
# Trigger immediate high-priority indexing for all created files
|
||||
if trigger_index:
|
||||
tenant_id = get_current_tenant_id()
|
||||
for user_file in user_files:
|
||||
# Use the existing trigger_indexing_for_cc_pair function but with highest priority
|
||||
if user_file.cc_pair_id:
|
||||
trigger_indexing_for_cc_pair(
|
||||
[],
|
||||
user_file.cc_pair.connector_id,
|
||||
False,
|
||||
tenant_id,
|
||||
db_session,
|
||||
is_user_file=True,
|
||||
)
|
||||
|
||||
return user_files
|
||||
|
||||
|
||||
def create_file_connector_credential(
|
||||
user_file: UserFile, user: User, db_session: Session
|
||||
) -> StatusResponse:
|
||||
"""Create connector and credential for a user file"""
|
||||
connector_base = ConnectorBase(
|
||||
name=f"UserFile-{user_file.file_id}-{int(time.time())}",
|
||||
source=DocumentSource.FILE,
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={
|
||||
"file_locations": [user_file.file_id],
|
||||
},
|
||||
refresh_freq=None,
|
||||
prune_freq=None,
|
||||
indexing_start=None,
|
||||
)
|
||||
|
||||
connector = create_connector(db_session=db_session, connector_data=connector_base)
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json={},
|
||||
admin_public=True,
|
||||
source=DocumentSource.FILE,
|
||||
curator_public=True,
|
||||
groups=[],
|
||||
name=f"UserFileCredential-{user_file.file_id}-{int(time.time())}",
|
||||
is_user_file=True,
|
||||
)
|
||||
|
||||
credential = create_credential(credential_info, user, db_session)
|
||||
|
||||
return add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
cc_pair_name=f"UserFileCCPair-{user_file.file_id}-{int(time.time())}",
|
||||
access_type=AccessType.PRIVATE,
|
||||
auto_sync_options=None,
|
||||
groups=[],
|
||||
is_user_file=True,
|
||||
)
|
||||
|
||||
|
||||
def get_user_file_indexing_status(
|
||||
file_ids: list[int], db_session: Session
|
||||
) -> dict[int, bool]:
|
||||
"""Get indexing status for multiple user files"""
|
||||
status_dict = {}
|
||||
|
||||
# Query UserFile with cc_pair join
|
||||
files_with_pairs = (
|
||||
db_session.query(UserFile)
|
||||
.filter(UserFile.id.in_(file_ids))
|
||||
.options(joinedload(UserFile.cc_pair))
|
||||
.all()
|
||||
)
|
||||
|
||||
for file in files_with_pairs:
|
||||
if file.cc_pair and file.cc_pair.last_successful_index_time:
|
||||
status_dict[file.id] = True
|
||||
else:
|
||||
status_dict[file.id] = False
|
||||
|
||||
return status_dict
|
||||
|
||||
|
||||
def calculate_user_files_token_count(
|
||||
file_ids: list[int], folder_ids: list[int], db_session: Session
|
||||
) -> int:
|
||||
"""Calculate total token count for specified files and folders"""
|
||||
total_tokens = 0
|
||||
|
||||
# Get tokens from individual files
|
||||
if file_ids:
|
||||
file_tokens = (
|
||||
db_session.query(func.sum(UserFile.token_count))
|
||||
.filter(UserFile.id.in_(file_ids))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
total_tokens += file_tokens
|
||||
|
||||
# Get tokens from folders
|
||||
if folder_ids:
|
||||
folder_files_tokens = (
|
||||
db_session.query(func.sum(UserFile.token_count))
|
||||
.filter(UserFile.folder_id.in_(folder_ids))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
total_tokens += folder_files_tokens
|
||||
|
||||
return total_tokens
|
||||
|
||||
|
||||
def load_all_user_files(
|
||||
file_ids: list[int], folder_ids: list[int], db_session: Session
|
||||
) -> list[UserFile]:
|
||||
"""Load all user files from specified file IDs and folder IDs"""
|
||||
result = []
|
||||
|
||||
# Get individual files
|
||||
if file_ids:
|
||||
files = db_session.query(UserFile).filter(UserFile.id.in_(file_ids)).all()
|
||||
result.extend(files)
|
||||
|
||||
# Get files from folders
|
||||
if folder_ids:
|
||||
folder_files = (
|
||||
db_session.query(UserFile).filter(UserFile.folder_id.in_(folder_ids)).all()
|
||||
)
|
||||
result.extend(folder_files)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_user_files_from_folder(folder_id: int, db_session: Session) -> list[UserFile]:
|
||||
return db_session.query(UserFile).filter(UserFile.folder_id == folder_id).all()
|
||||
|
||||
|
||||
def share_file_with_assistant(
|
||||
file_id: int, assistant_id: int, db_session: Session
|
||||
) -> None:
|
||||
file = db_session.query(UserFile).filter(UserFile.id == file_id).first()
|
||||
assistant = db_session.query(Persona).filter(Persona.id == assistant_id).first()
|
||||
|
||||
if file and assistant:
|
||||
file.assistants.append(assistant)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def unshare_file_with_assistant(
|
||||
file_id: int, assistant_id: int, db_session: Session
|
||||
) -> None:
|
||||
db_session.query(Persona__UserFile).filter(
|
||||
and_(
|
||||
Persona__UserFile.user_file_id == file_id,
|
||||
Persona__UserFile.persona_id == assistant_id,
|
||||
)
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def share_folder_with_assistant(
|
||||
folder_id: int, assistant_id: int, db_session: Session
|
||||
) -> None:
|
||||
folder = db_session.query(UserFolder).filter(UserFolder.id == folder_id).first()
|
||||
assistant = db_session.query(Persona).filter(Persona.id == assistant_id).first()
|
||||
|
||||
if folder and assistant:
|
||||
for file in folder.files:
|
||||
share_file_with_assistant(file.id, assistant_id, db_session)
|
||||
|
||||
|
||||
def unshare_folder_with_assistant(
|
||||
folder_id: int, assistant_id: int, db_session: Session
|
||||
) -> None:
|
||||
folder = db_session.query(UserFolder).filter(UserFolder.id == folder_id).first()
|
||||
|
||||
if folder:
|
||||
for file in folder.files:
|
||||
unshare_file_with_assistant(file.id, assistant_id, db_session)
|
||||
|
||||
|
||||
def fetch_user_files_for_documents(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, int | None]:
|
||||
"""
|
||||
Fetches user file IDs for the given document IDs.
|
||||
|
||||
Args:
|
||||
document_ids: List of document IDs to fetch user files for
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
Dictionary mapping document IDs to user file IDs (or None if no user file exists)
|
||||
"""
|
||||
# First, get the document to cc_pair mapping
|
||||
doc_cc_pairs = (
|
||||
db_session.query(Document.id, ConnectorCredentialPair.id)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
Document.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id
|
||||
== ConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id
|
||||
== ConnectorCredentialPair.credential_id,
|
||||
),
|
||||
)
|
||||
.filter(Document.id.in_(document_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
# Get cc_pair to user_file mapping
|
||||
cc_pair_to_user_file = (
|
||||
db_session.query(ConnectorCredentialPair.id, UserFile.id)
|
||||
.join(UserFile, UserFile.cc_pair_id == ConnectorCredentialPair.id)
|
||||
.filter(
|
||||
ConnectorCredentialPair.id.in_(
|
||||
[cc_pair_id for _, cc_pair_id in doc_cc_pairs]
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Create mapping from cc_pair_id to user_file_id
|
||||
cc_pair_to_user_file_dict = {
|
||||
cc_pair_id: user_file_id for cc_pair_id, user_file_id in cc_pair_to_user_file
|
||||
}
|
||||
|
||||
# Create the final result mapping document_id to user_file_id
|
||||
result: dict[str, int | None] = {doc_id: None for doc_id in document_ids}
|
||||
for doc_id, cc_pair_id in doc_cc_pairs:
|
||||
if cc_pair_id in cc_pair_to_user_file_dict:
|
||||
result[doc_id] = cc_pair_to_user_file_dict[cc_pair_id]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def fetch_user_folders_for_documents(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, int | None]:
|
||||
"""
|
||||
Fetches user folder IDs for the given document IDs.
|
||||
|
||||
For each document, returns the folder ID that the document's associated user file belongs to.
|
||||
|
||||
Args:
|
||||
document_ids: List of document IDs to fetch user folders for
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
Dictionary mapping document IDs to user folder IDs (or None if no user folder exists)
|
||||
"""
|
||||
# First, get the document to cc_pair mapping
|
||||
doc_cc_pairs = (
|
||||
db_session.query(Document.id, ConnectorCredentialPair.id)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
Document.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id
|
||||
== ConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id
|
||||
== ConnectorCredentialPair.credential_id,
|
||||
),
|
||||
)
|
||||
.filter(Document.id.in_(document_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
# Get cc_pair to user_file and folder mapping
|
||||
cc_pair_to_folder = (
|
||||
db_session.query(ConnectorCredentialPair.id, UserFile.folder_id)
|
||||
.join(UserFile, UserFile.cc_pair_id == ConnectorCredentialPair.id)
|
||||
.filter(
|
||||
ConnectorCredentialPair.id.in_(
|
||||
[cc_pair_id for _, cc_pair_id in doc_cc_pairs]
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Create mapping from cc_pair_id to folder_id
|
||||
cc_pair_to_folder_dict = {
|
||||
cc_pair_id: folder_id for cc_pair_id, folder_id in cc_pair_to_folder
|
||||
}
|
||||
|
||||
# Create the final result mapping document_id to folder_id
|
||||
result: dict[str, int | None] = {doc_id: None for doc_id in document_ids}
|
||||
for doc_id, cc_pair_id in doc_cc_pairs:
|
||||
if cc_pair_id in cc_pair_to_folder_dict:
|
||||
result[doc_id] = cc_pair_to_folder_dict[cc_pair_id]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_user_file_from_id(db_session: Session, user_file_id: int) -> UserFile | None:
|
||||
return db_session.query(UserFile).filter(UserFile.id == user_file_id).first()
|
||||
|
||||
|
||||
# def fetch_user_files_for_documents(
|
||||
# # document_ids: list[str],
|
||||
# # db_session: Session,
|
||||
# # ) -> dict[str, int | None]:
|
||||
# # # Query UserFile objects for the given document_ids
|
||||
# # user_files = (
|
||||
# # db_session.query(UserFile).filter(UserFile.document_id.in_(document_ids)).all()
|
||||
# # )
|
||||
|
||||
# # # Create a dictionary mapping document_ids to UserFile objects
|
||||
# # result: dict[str, int | None] = {doc_id: None for doc_id in document_ids}
|
||||
# # for user_file in user_files:
|
||||
# # result[user_file.document_id] = user_file.id
|
||||
|
||||
# # return result
|
||||
|
||||
|
||||
def upsert_user_folder(
|
||||
db_session: Session,
|
||||
id: int | None = None,
|
||||
user_id: UUID | None = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
created_at: datetime.datetime | None = None,
|
||||
user: User | None = None,
|
||||
files: list[UserFile] | None = None,
|
||||
assistants: list[Persona] | None = None,
|
||||
) -> UserFolder:
|
||||
if id is not None:
|
||||
user_folder = db_session.query(UserFolder).filter_by(id=id).first()
|
||||
else:
|
||||
user_folder = (
|
||||
db_session.query(UserFolder).filter_by(name=name, user_id=user_id).first()
|
||||
)
|
||||
|
||||
if user_folder:
|
||||
if user_id is not None:
|
||||
user_folder.user_id = user_id
|
||||
if name is not None:
|
||||
user_folder.name = name
|
||||
if description is not None:
|
||||
user_folder.description = description
|
||||
if created_at is not None:
|
||||
user_folder.created_at = created_at
|
||||
if user is not None:
|
||||
user_folder.user = user
|
||||
if files is not None:
|
||||
user_folder.files = files
|
||||
if assistants is not None:
|
||||
user_folder.assistants = assistants
|
||||
else:
|
||||
user_folder = UserFolder(
|
||||
id=id,
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
description=description,
|
||||
created_at=created_at or datetime.datetime.utcnow(),
|
||||
user=user,
|
||||
files=files or [],
|
||||
assistants=assistants or [],
|
||||
)
|
||||
db_session.add(user_folder)
|
||||
|
||||
db_session.flush()
|
||||
return user_folder
|
||||
|
||||
|
||||
def get_user_folder_by_name(db_session: Session, name: str) -> UserFolder | None:
|
||||
return db_session.query(UserFolder).filter(UserFolder.name == name).first()
|
||||
|
||||
|
||||
def update_user_file_token_count__no_commit(
|
||||
user_file_id_to_token_count: dict[int, int | None],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
for user_file_id, token_count in user_file_id_to_token_count.items():
|
||||
db_session.query(UserFile).filter(UserFile.id == user_file_id).update(
|
||||
{UserFile.token_count: token_count}
|
||||
)
|
||||
@@ -104,16 +104,6 @@ class VespaDocumentFields:
|
||||
aggregated_chunk_boost_factor: float | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VespaDocumentUserFields:
|
||||
"""
|
||||
Fields that are specific to the user who is indexing the document.
|
||||
"""
|
||||
|
||||
user_file_id: str | None = None
|
||||
user_folder_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateRequest:
|
||||
"""
|
||||
@@ -268,8 +258,7 @@ class Updatable(abc.ABC):
|
||||
*,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
fields: VespaDocumentFields | None,
|
||||
user_fields: VespaDocumentUserFields | None,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
"""
|
||||
Updates all chunks for a document with the specified fields.
|
||||
|
||||
@@ -98,12 +98,6 @@ schema DANSWER_CHUNK_NAME {
|
||||
field metadata type string {
|
||||
indexing: summary | attribute
|
||||
}
|
||||
field chunk_context type string {
|
||||
indexing: summary | attribute
|
||||
}
|
||||
field doc_summary type string {
|
||||
indexing: summary | attribute
|
||||
}
|
||||
field metadata_suffix type string {
|
||||
indexing: summary | attribute
|
||||
}
|
||||
@@ -120,22 +114,12 @@ schema DANSWER_CHUNK_NAME {
|
||||
indexing: summary | attribute
|
||||
rank: filter
|
||||
attribute: fast-search
|
||||
}
|
||||
}
|
||||
field document_sets type weightedset<string> {
|
||||
indexing: summary | attribute
|
||||
rank: filter
|
||||
attribute: fast-search
|
||||
}
|
||||
field user_file type int {
|
||||
indexing: summary | attribute
|
||||
rank: filter
|
||||
attribute: fast-search
|
||||
}
|
||||
field user_folder type int {
|
||||
indexing: summary | attribute
|
||||
rank: filter
|
||||
attribute: fast-search
|
||||
}
|
||||
}
|
||||
|
||||
# If using different tokenization settings, the fieldset has to be removed, and the field must
|
||||
|
||||
@@ -24,11 +24,9 @@ from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
|
||||
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
|
||||
from onyx.document_index.vespa_constants import BLURB
|
||||
from onyx.document_index.vespa_constants import BOOST
|
||||
from onyx.document_index.vespa_constants import CHUNK_CONTEXT
|
||||
from onyx.document_index.vespa_constants import CHUNK_ID
|
||||
from onyx.document_index.vespa_constants import CONTENT
|
||||
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
|
||||
from onyx.document_index.vespa_constants import DOC_SUMMARY
|
||||
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
@@ -128,8 +126,7 @@ def _vespa_hit_to_inference_chunk(
|
||||
return InferenceChunkUncleaned(
|
||||
chunk_id=fields[CHUNK_ID],
|
||||
blurb=fields.get(BLURB, ""), # Unused
|
||||
content=fields[CONTENT], # Includes extra title prefix and metadata suffix;
|
||||
# also sometimes context for contextual rag
|
||||
content=fields[CONTENT], # Includes extra title prefix and metadata suffix
|
||||
source_links=source_links_dict or {0: ""},
|
||||
section_continuation=fields[SECTION_CONTINUATION],
|
||||
document_id=fields[DOCUMENT_ID],
|
||||
@@ -146,8 +143,6 @@ def _vespa_hit_to_inference_chunk(
|
||||
large_chunk_reference_ids=fields.get(LARGE_CHUNK_REFERENCE_IDS, []),
|
||||
metadata=metadata,
|
||||
metadata_suffix=fields.get(METADATA_SUFFIX),
|
||||
doc_summary=fields.get(DOC_SUMMARY, ""),
|
||||
chunk_context=fields.get(CHUNK_CONTEXT, ""),
|
||||
match_highlights=match_highlights,
|
||||
updated_at=updated_at,
|
||||
)
|
||||
@@ -350,19 +345,6 @@ def query_vespa(
|
||||
filtered_hits = [hit for hit in hits if hit["fields"].get(CONTENT) is not None]
|
||||
|
||||
inference_chunks = [_vespa_hit_to_inference_chunk(hit) for hit in filtered_hits]
|
||||
|
||||
try:
|
||||
num_retrieved_inference_chunks = len(inference_chunks)
|
||||
num_retrieved_document_ids = len(
|
||||
set([chunk.document_id for chunk in inference_chunks])
|
||||
)
|
||||
logger.debug(
|
||||
f"Retrieved {num_retrieved_inference_chunks} inference chunks for {num_retrieved_document_ids} documents"
|
||||
)
|
||||
except Exception as e:
|
||||
# Debug logging only, should not fail the retrieval
|
||||
logger.error(f"Error logging retrieval statistics: {e}")
|
||||
|
||||
# Good Debugging Spot
|
||||
return inference_chunks
|
||||
|
||||
|
||||
@@ -36,7 +36,6 @@ from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
|
||||
from onyx.document_index.interfaces import UpdateRequest
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.vespa.chunk_retrieval import batch_search_api_retrieval
|
||||
from onyx.document_index.vespa.chunk_retrieval import (
|
||||
parallel_visit_api_retrieval,
|
||||
@@ -71,8 +70,6 @@ from onyx.document_index.vespa_constants import NUM_THREADS
|
||||
from onyx.document_index.vespa_constants import SEARCH_THREAD_NUMBER_PAT
|
||||
from onyx.document_index.vespa_constants import TENANT_ID_PAT
|
||||
from onyx.document_index.vespa_constants import TENANT_ID_REPLACEMENT
|
||||
from onyx.document_index.vespa_constants import USER_FILE
|
||||
from onyx.document_index.vespa_constants import USER_FOLDER
|
||||
from onyx.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
|
||||
from onyx.document_index.vespa_constants import VESPA_DIM_REPLACEMENT_PAT
|
||||
from onyx.document_index.vespa_constants import VESPA_TIMEOUT
|
||||
@@ -190,7 +187,7 @@ class VespaIndex(DocumentIndex):
|
||||
) -> None:
|
||||
if MULTI_TENANT:
|
||||
logger.info(
|
||||
"Skipping Vespa index setup for multitenant (would wipe all indices)"
|
||||
"Skipping Vespa index seup for multitenant (would wipe all indices)"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -595,8 +592,7 @@ class VespaIndex(DocumentIndex):
|
||||
self,
|
||||
doc_chunk_id: UUID,
|
||||
index_name: str,
|
||||
fields: VespaDocumentFields | None,
|
||||
user_fields: VespaDocumentUserFields | None,
|
||||
fields: VespaDocumentFields,
|
||||
doc_id: str,
|
||||
http_client: httpx.Client,
|
||||
) -> None:
|
||||
@@ -607,31 +603,21 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
update_dict: dict[str, dict] = {"fields": {}}
|
||||
|
||||
if fields is not None:
|
||||
if fields.boost is not None:
|
||||
update_dict["fields"][BOOST] = {"assign": fields.boost}
|
||||
if fields.boost is not None:
|
||||
update_dict["fields"][BOOST] = {"assign": fields.boost}
|
||||
|
||||
if fields.document_sets is not None:
|
||||
update_dict["fields"][DOCUMENT_SETS] = {
|
||||
"assign": {document_set: 1 for document_set in fields.document_sets}
|
||||
}
|
||||
if fields.document_sets is not None:
|
||||
update_dict["fields"][DOCUMENT_SETS] = {
|
||||
"assign": {document_set: 1 for document_set in fields.document_sets}
|
||||
}
|
||||
|
||||
if fields.access is not None:
|
||||
update_dict["fields"][ACCESS_CONTROL_LIST] = {
|
||||
"assign": {acl_entry: 1 for acl_entry in fields.access.to_acl()}
|
||||
}
|
||||
if fields.access is not None:
|
||||
update_dict["fields"][ACCESS_CONTROL_LIST] = {
|
||||
"assign": {acl_entry: 1 for acl_entry in fields.access.to_acl()}
|
||||
}
|
||||
|
||||
if fields.hidden is not None:
|
||||
update_dict["fields"][HIDDEN] = {"assign": fields.hidden}
|
||||
|
||||
if user_fields is not None:
|
||||
if user_fields.user_file_id is not None:
|
||||
update_dict["fields"][USER_FILE] = {"assign": user_fields.user_file_id}
|
||||
|
||||
if user_fields.user_folder_id is not None:
|
||||
update_dict["fields"][USER_FOLDER] = {
|
||||
"assign": user_fields.user_folder_id
|
||||
}
|
||||
if fields.hidden is not None:
|
||||
update_dict["fields"][HIDDEN] = {"assign": fields.hidden}
|
||||
|
||||
if not update_dict["fields"]:
|
||||
logger.error("Update request received but nothing to update.")
|
||||
@@ -663,8 +649,7 @@ class VespaIndex(DocumentIndex):
|
||||
*,
|
||||
chunk_count: int | None,
|
||||
tenant_id: str,
|
||||
fields: VespaDocumentFields | None,
|
||||
user_fields: VespaDocumentUserFields | None,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
"""Note: if the document id does not exist, the update will be a no-op and the
|
||||
function will complete with no errors or exceptions.
|
||||
@@ -697,12 +682,7 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
for doc_chunk_id in doc_chunk_ids:
|
||||
self._update_single_chunk(
|
||||
doc_chunk_id,
|
||||
index_name,
|
||||
fields,
|
||||
user_fields,
|
||||
doc_id,
|
||||
httpx_client,
|
||||
doc_chunk_id, index_name, fields, doc_id, httpx_client
|
||||
)
|
||||
|
||||
return doc_chunk_count
|
||||
@@ -743,7 +723,6 @@ class VespaIndex(DocumentIndex):
|
||||
tenant_id=tenant_id,
|
||||
large_chunks_enabled=large_chunks_enabled,
|
||||
)
|
||||
|
||||
for doc_chunk_ids_batch in batch_generator(
|
||||
chunks_to_delete, BATCH_SIZE
|
||||
):
|
||||
|
||||
@@ -25,11 +25,9 @@ from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
|
||||
from onyx.document_index.vespa_constants import AGGREGATED_CHUNK_BOOST_FACTOR
|
||||
from onyx.document_index.vespa_constants import BLURB
|
||||
from onyx.document_index.vespa_constants import BOOST
|
||||
from onyx.document_index.vespa_constants import CHUNK_CONTEXT
|
||||
from onyx.document_index.vespa_constants import CHUNK_ID
|
||||
from onyx.document_index.vespa_constants import CONTENT
|
||||
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
|
||||
from onyx.document_index.vespa_constants import DOC_SUMMARY
|
||||
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
@@ -51,8 +49,6 @@ from onyx.document_index.vespa_constants import SOURCE_TYPE
|
||||
from onyx.document_index.vespa_constants import TENANT_ID
|
||||
from onyx.document_index.vespa_constants import TITLE
|
||||
from onyx.document_index.vespa_constants import TITLE_EMBEDDING
|
||||
from onyx.document_index.vespa_constants import USER_FILE
|
||||
from onyx.document_index.vespa_constants import USER_FOLDER
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -178,7 +174,7 @@ def _index_vespa_chunk(
|
||||
# For the BM25 index, the keyword suffix is used, the vector is already generated with the more
|
||||
# natural language representation of the metadata section
|
||||
CONTENT: remove_invalid_unicode_chars(
|
||||
f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_keyword}"
|
||||
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_keyword}"
|
||||
),
|
||||
# This duplication of `content` is needed for keyword highlighting
|
||||
# Note that it's not exactly the same as the actual content
|
||||
@@ -193,8 +189,6 @@ def _index_vespa_chunk(
|
||||
# Save as a list for efficient extraction as an Attribute
|
||||
METADATA_LIST: metadata_list,
|
||||
METADATA_SUFFIX: remove_invalid_unicode_chars(chunk.metadata_suffix_keyword),
|
||||
CHUNK_CONTEXT: chunk.chunk_context,
|
||||
DOC_SUMMARY: chunk.doc_summary,
|
||||
EMBEDDINGS: embeddings_name_vector_map,
|
||||
TITLE_EMBEDDING: chunk.title_embedding,
|
||||
DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at),
|
||||
@@ -207,8 +201,6 @@ def _index_vespa_chunk(
|
||||
ACCESS_CONTROL_LIST: {acl_entry: 1 for acl_entry in chunk.access.to_acl()},
|
||||
DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
|
||||
IMAGE_FILE_NAME: chunk.image_file_name,
|
||||
USER_FILE: chunk.user_file if chunk.user_file is not None else None,
|
||||
USER_FOLDER: chunk.user_folder if chunk.user_folder is not None else None,
|
||||
BOOST: chunk.boost,
|
||||
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
|
||||
}
|
||||
|
||||
@@ -14,8 +14,6 @@ from onyx.document_index.vespa_constants import HIDDEN
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import SOURCE_TYPE
|
||||
from onyx.document_index.vespa_constants import TENANT_ID
|
||||
from onyx.document_index.vespa_constants import USER_FILE
|
||||
from onyx.document_index.vespa_constants import USER_FOLDER
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -29,26 +27,14 @@ def build_vespa_filters(
|
||||
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
|
||||
) -> str:
|
||||
def _build_or_filters(key: str, vals: list[str] | None) -> str:
|
||||
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields."""
|
||||
if not key or not vals:
|
||||
return ""
|
||||
eq_elems = [f'{key} contains "{val}"' for val in vals if val]
|
||||
if not eq_elems:
|
||||
return ""
|
||||
or_clause = " or ".join(eq_elems)
|
||||
return f"({or_clause}) and "
|
||||
|
||||
def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
|
||||
"""
|
||||
For an integer field filter.
|
||||
If vals is not None, we want *only* docs whose key matches one of vals.
|
||||
"""
|
||||
# If `vals` is None => skip the filter entirely
|
||||
if vals is None or not vals:
|
||||
if vals is None:
|
||||
return ""
|
||||
|
||||
# Otherwise build the OR filter
|
||||
eq_elems = [f"{key} = {val}" for val in vals]
|
||||
valid_vals = [val for val in vals if val]
|
||||
if not key or not valid_vals:
|
||||
return ""
|
||||
|
||||
eq_elems = [f'{key} contains "{elem}"' for elem in valid_vals]
|
||||
or_clause = " or ".join(eq_elems)
|
||||
result = f"({or_clause}) and "
|
||||
|
||||
@@ -56,59 +42,53 @@ def build_vespa_filters(
|
||||
|
||||
def _build_time_filter(
|
||||
cutoff: datetime | None,
|
||||
# Slightly over 3 Months, approximately 1 fiscal quarter
|
||||
untimed_doc_cutoff: timedelta = timedelta(days=92),
|
||||
) -> str:
|
||||
if not cutoff:
|
||||
return ""
|
||||
|
||||
# For Documents that don't have an updated at, filter them out for queries asking for
|
||||
# very recent documents (3 months) default. Documents that don't have an updated at
|
||||
# time are assigned 3 months for time decay value
|
||||
include_untimed = datetime.now(timezone.utc) - untimed_doc_cutoff > cutoff
|
||||
cutoff_secs = int(cutoff.timestamp())
|
||||
|
||||
if include_untimed:
|
||||
# Documents without updated_at are assigned -1 as their date
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
|
||||
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
|
||||
|
||||
# Start building the filter string
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
|
||||
# If running in multi-tenant mode
|
||||
# If running in multi-tenant mode, we may want to filter by tenant_id
|
||||
if filters.tenant_id and MULTI_TENANT:
|
||||
filter_str += f'({TENANT_ID} contains "{filters.tenant_id}") and '
|
||||
|
||||
# ACL filters
|
||||
# CAREFUL touching this one, currently there is no second ACL double-check post retrieval
|
||||
if filters.access_control_list is not None:
|
||||
filter_str += _build_or_filters(
|
||||
ACCESS_CONTROL_LIST, filters.access_control_list
|
||||
)
|
||||
|
||||
# Source type filters
|
||||
source_strs = (
|
||||
[s.value for s in filters.source_type] if filters.source_type else None
|
||||
)
|
||||
filter_str += _build_or_filters(SOURCE_TYPE, source_strs)
|
||||
|
||||
# Tag filters
|
||||
tag_attributes = None
|
||||
if filters.tags:
|
||||
# build e.g. "tag_key|tag_value"
|
||||
tag_attributes = [
|
||||
f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in filters.tags
|
||||
]
|
||||
tags = filters.tags
|
||||
if tags:
|
||||
tag_attributes = [tag.tag_key + INDEX_SEPARATOR + tag.tag_value for tag in tags]
|
||||
filter_str += _build_or_filters(METADATA_LIST, tag_attributes)
|
||||
|
||||
# Document sets
|
||||
filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
|
||||
# New: user_file_ids as integer filters
|
||||
filter_str += _build_int_or_filters(USER_FILE, filters.user_file_ids)
|
||||
|
||||
filter_str += _build_int_or_filters(USER_FOLDER, filters.user_folder_ids)
|
||||
|
||||
# Time filter
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
|
||||
# Trim trailing " and "
|
||||
if remove_trailing_and and filter_str.endswith(" and "):
|
||||
filter_str = filter_str[:-5]
|
||||
filter_str = filter_str[:-5] # We remove the trailing " and "
|
||||
|
||||
return filter_str
|
||||
|
||||
|
||||
@@ -67,14 +67,10 @@ EMBEDDINGS = "embeddings"
|
||||
TITLE_EMBEDDING = "title_embedding"
|
||||
ACCESS_CONTROL_LIST = "access_control_list"
|
||||
DOCUMENT_SETS = "document_sets"
|
||||
USER_FILE = "user_file"
|
||||
USER_FOLDER = "user_folder"
|
||||
LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
|
||||
METADATA = "metadata"
|
||||
METADATA_LIST = "metadata_list"
|
||||
METADATA_SUFFIX = "metadata_suffix"
|
||||
DOC_SUMMARY = "doc_summary"
|
||||
CHUNK_CONTEXT = "chunk_context"
|
||||
BOOST = "boost"
|
||||
AGGREGATED_CHUNK_BOOST_FACTOR = "aggregated_chunk_boost_factor"
|
||||
DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch
|
||||
@@ -110,8 +106,6 @@ YQL_BASE = (
|
||||
f"{LARGE_CHUNK_REFERENCE_IDS}, "
|
||||
f"{METADATA}, "
|
||||
f"{METADATA_SUFFIX}, "
|
||||
f"{DOC_SUMMARY}, "
|
||||
f"{CHUNK_CONTEXT}, "
|
||||
f"{CONTENT_SUMMARY} "
|
||||
f"from {{index_name}} where "
|
||||
)
|
||||
|
||||
@@ -37,7 +37,6 @@ def delete_unstructured_api_key() -> None:
|
||||
def _sdk_partition_request(
|
||||
file: IO[Any], file_name: str, **kwargs: Any
|
||||
) -> operations.PartitionRequest:
|
||||
file.seek(0, 0)
|
||||
try:
|
||||
request = operations.PartitionRequest(
|
||||
partition_parameters=shared.PartitionParameters(
|
||||
|
||||
@@ -31,7 +31,6 @@ class FileStore(ABC):
|
||||
file_origin: FileOrigin,
|
||||
file_type: str,
|
||||
file_metadata: dict | None = None,
|
||||
commit: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Save a file to the blob store
|
||||
@@ -43,8 +42,6 @@ class FileStore(ABC):
|
||||
- display_name: Display name of the file
|
||||
- file_origin: Origin of the file
|
||||
- file_type: Type of the file
|
||||
- file_metadata: Additional metadata for the file
|
||||
- commit: Whether to commit the transaction after saving the file
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -93,7 +90,6 @@ class PostgresBackedFileStore(FileStore):
|
||||
file_origin: FileOrigin,
|
||||
file_type: str,
|
||||
file_metadata: dict | None = None,
|
||||
commit: bool = True,
|
||||
) -> None:
|
||||
try:
|
||||
# The large objects in postgres are saved as special objects can be listed with
|
||||
@@ -108,8 +104,7 @@ class PostgresBackedFileStore(FileStore):
|
||||
db_session=self.db_session,
|
||||
file_metadata=file_metadata,
|
||||
)
|
||||
if commit:
|
||||
self.db_session.commit()
|
||||
self.db_session.commit()
|
||||
except Exception:
|
||||
self.db_session.rollback()
|
||||
raise
|
||||
|
||||
@@ -14,7 +14,6 @@ class ChatFileType(str, Enum):
|
||||
# Plain text only contain the text
|
||||
PLAIN_TEXT = "plain_text"
|
||||
CSV = "csv"
|
||||
USER_KNOWLEDGE = "user_knowledge"
|
||||
|
||||
|
||||
class FileDescriptor(TypedDict):
|
||||
|
||||
@@ -10,62 +10,12 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserFolder
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.utils.b64 import get_image_type
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def user_file_id_to_plaintext_file_name(user_file_id: int) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a user file."""
|
||||
return f"plaintext_{user_file_id}"
|
||||
|
||||
|
||||
def store_user_file_plaintext(
|
||||
user_file_id: int, plaintext_content: str, db_session: Session
|
||||
) -> bool:
|
||||
"""
|
||||
Store plaintext content for a user file in the file store.
|
||||
|
||||
Args:
|
||||
user_file_id: The ID of the user file
|
||||
plaintext_content: The plaintext content to store
|
||||
db_session: The database session
|
||||
|
||||
Returns:
|
||||
bool: True if storage was successful, False otherwise
|
||||
"""
|
||||
# Skip empty content
|
||||
if not plaintext_content:
|
||||
return False
|
||||
|
||||
# Get plaintext file name
|
||||
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
|
||||
|
||||
# Store the plaintext in the file store
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_content = BytesIO(plaintext_content.encode("utf-8"))
|
||||
try:
|
||||
file_store.save_file(
|
||||
file_name=plaintext_file_name,
|
||||
content=file_content,
|
||||
display_name=f"Plaintext for user file {user_file_id}",
|
||||
file_origin=FileOrigin.PLAINTEXT_CACHE,
|
||||
file_type="text/plain",
|
||||
commit=False,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store plaintext for user file {user_file_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def load_chat_file(
|
||||
file_descriptor: FileDescriptor, db_session: Session
|
||||
@@ -103,83 +53,6 @@ def load_all_chat_files(
|
||||
return files
|
||||
|
||||
|
||||
def load_user_folder(folder_id: int, db_session: Session) -> list[InMemoryChatFile]:
|
||||
user_files = (
|
||||
db_session.query(UserFile).filter(UserFile.folder_id == folder_id).all()
|
||||
)
|
||||
return [load_user_file(file.id, db_session) for file in user_files]
|
||||
|
||||
|
||||
def load_user_file(file_id: int, db_session: Session) -> InMemoryChatFile:
|
||||
user_file = db_session.query(UserFile).filter(UserFile.id == file_id).first()
|
||||
if not user_file:
|
||||
raise ValueError(f"User file with id {file_id} not found")
|
||||
|
||||
# Try to load plaintext version first
|
||||
file_store = get_default_file_store(db_session)
|
||||
plaintext_file_name = user_file_id_to_plaintext_file_name(file_id)
|
||||
|
||||
try:
|
||||
file_io = file_store.read_file(plaintext_file_name, mode="b")
|
||||
return InMemoryChatFile(
|
||||
file_id=str(user_file.file_id),
|
||||
content=file_io.read(),
|
||||
file_type=ChatFileType.USER_KNOWLEDGE,
|
||||
filename=user_file.name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to load plaintext file {plaintext_file_name}, defaulting to original file: {e}"
|
||||
)
|
||||
# Fall back to original file if plaintext not available
|
||||
file_io = file_store.read_file(user_file.file_id, mode="b")
|
||||
return InMemoryChatFile(
|
||||
file_id=str(user_file.file_id),
|
||||
content=file_io.read(),
|
||||
file_type=ChatFileType.USER_KNOWLEDGE,
|
||||
filename=user_file.name,
|
||||
)
|
||||
|
||||
|
||||
def load_all_user_files(
|
||||
user_file_ids: list[int],
|
||||
user_folder_ids: list[int],
|
||||
db_session: Session,
|
||||
) -> list[InMemoryChatFile]:
|
||||
return cast(
|
||||
list[InMemoryChatFile],
|
||||
run_functions_tuples_in_parallel(
|
||||
[(load_user_file, (file_id, db_session)) for file_id in user_file_ids]
|
||||
)
|
||||
+ [
|
||||
file
|
||||
for folder_id in user_folder_ids
|
||||
for file in load_user_folder(folder_id, db_session)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def load_all_user_file_files(
|
||||
user_file_ids: list[int],
|
||||
user_folder_ids: list[int],
|
||||
db_session: Session,
|
||||
) -> list[UserFile]:
|
||||
user_files: list[UserFile] = []
|
||||
for user_file_id in user_file_ids:
|
||||
user_file = (
|
||||
db_session.query(UserFile).filter(UserFile.id == user_file_id).first()
|
||||
)
|
||||
if user_file is not None:
|
||||
user_files.append(user_file)
|
||||
for user_folder_id in user_folder_ids:
|
||||
user_files.extend(
|
||||
db_session.query(UserFile)
|
||||
.filter(UserFile.folder_id == user_folder_id)
|
||||
.all()
|
||||
)
|
||||
return user_files
|
||||
|
||||
|
||||
def save_file_from_url(url: str) -> str:
|
||||
"""NOTE: using multiple sessions here, since this is often called
|
||||
using multithreading. In practice, sharing a session has resulted in
|
||||
@@ -198,7 +71,6 @@ def save_file_from_url(url: str) -> str:
|
||||
display_name="GeneratedImage",
|
||||
file_origin=FileOrigin.CHAT_IMAGE_GEN,
|
||||
file_type="image/png;base64",
|
||||
commit=True,
|
||||
)
|
||||
return unique_id
|
||||
|
||||
@@ -213,7 +85,6 @@ def save_file_from_base64(base64_string: str) -> str:
|
||||
display_name="GeneratedImage",
|
||||
file_origin=FileOrigin.CHAT_IMAGE_GEN,
|
||||
file_type=get_image_type(base64_string),
|
||||
commit=True,
|
||||
)
|
||||
return unique_id
|
||||
|
||||
@@ -257,39 +128,3 @@ def save_files(urls: list[str], base64_files: list[str]) -> list[str]:
|
||||
]
|
||||
|
||||
return run_functions_tuples_in_parallel(funcs)
|
||||
|
||||
|
||||
def load_all_persona_files_for_chat(
|
||||
persona_id: int, db_session: Session
|
||||
) -> tuple[list[InMemoryChatFile], list[int]]:
|
||||
from onyx.db.models import Persona
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
persona = (
|
||||
db_session.query(Persona)
|
||||
.filter(Persona.id == persona_id)
|
||||
.options(
|
||||
joinedload(Persona.user_files),
|
||||
joinedload(Persona.user_folders).joinedload(UserFolder.files),
|
||||
)
|
||||
.one()
|
||||
)
|
||||
|
||||
persona_file_calls = [
|
||||
(load_user_file, (user_file.id, db_session)) for user_file in persona.user_files
|
||||
]
|
||||
persona_loaded_files = run_functions_tuples_in_parallel(persona_file_calls)
|
||||
|
||||
persona_folder_files = []
|
||||
persona_folder_file_ids = []
|
||||
for user_folder in persona.user_folders:
|
||||
folder_files = load_user_folder(user_folder.id, db_session)
|
||||
persona_folder_files.extend(folder_files)
|
||||
persona_folder_file_ids.extend([file.id for file in user_folder.files])
|
||||
|
||||
persona_files = list(persona_loaded_files) + persona_folder_files
|
||||
persona_file_ids = [
|
||||
file.id for file in persona.user_files
|
||||
] + persona_folder_file_ids
|
||||
|
||||
return persona_files, persona_file_ids
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
from onyx.configs.app_configs import AVERAGE_SUMMARY_EMBEDDINGS
|
||||
from onyx.configs.app_configs import BLURB_SIZE
|
||||
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
|
||||
from onyx.configs.app_configs import MINI_CHUNK_SIZE
|
||||
from onyx.configs.app_configs import SKIP_METADATA_IN_CHUNK
|
||||
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
|
||||
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import RETURN_SEPARATOR
|
||||
from onyx.configs.constants import SECTION_SEPARATOR
|
||||
@@ -16,7 +13,6 @@ from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.llm.utils import MAX_CONTEXT_TOKENS
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.text_processing import clean_text
|
||||
@@ -86,9 +82,6 @@ def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwar
|
||||
large_chunk_reference_ids=[chunk.chunk_id for chunk in chunks],
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=large_chunk_id,
|
||||
chunk_context="",
|
||||
doc_summary="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
)
|
||||
|
||||
offset = 0
|
||||
@@ -127,7 +120,6 @@ class Chunker:
|
||||
tokenizer: BaseTokenizer,
|
||||
enable_multipass: bool = False,
|
||||
enable_large_chunks: bool = False,
|
||||
enable_contextual_rag: bool = False,
|
||||
blurb_size: int = BLURB_SIZE,
|
||||
include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
|
||||
chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
@@ -141,20 +133,9 @@ class Chunker:
|
||||
self.chunk_token_limit = chunk_token_limit
|
||||
self.enable_multipass = enable_multipass
|
||||
self.enable_large_chunks = enable_large_chunks
|
||||
self.enable_contextual_rag = enable_contextual_rag
|
||||
if enable_contextual_rag:
|
||||
assert (
|
||||
USE_CHUNK_SUMMARY or USE_DOCUMENT_SUMMARY
|
||||
), "Contextual RAG requires at least one of chunk summary and document summary enabled"
|
||||
self.default_contextual_rag_reserved_tokens = MAX_CONTEXT_TOKENS * (
|
||||
int(USE_CHUNK_SUMMARY) + int(USE_DOCUMENT_SUMMARY)
|
||||
)
|
||||
self.tokenizer = tokenizer
|
||||
self.callback = callback
|
||||
|
||||
self.max_context = 0
|
||||
self.prompt_tokens = 0
|
||||
|
||||
self.blurb_splitter = SentenceSplitter(
|
||||
tokenizer=tokenizer.tokenize,
|
||||
chunk_size=blurb_size,
|
||||
@@ -240,9 +221,6 @@ class Chunker:
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
mini_chunk_texts=self._get_mini_chunk_texts(text),
|
||||
large_chunk_id=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
contextual_rag_reserved_tokens=0, # set per-document in _handle_single_document
|
||||
)
|
||||
chunks_list.append(new_chunk)
|
||||
|
||||
@@ -310,7 +288,7 @@ class Chunker:
|
||||
continue
|
||||
|
||||
# CASE 2: Normal text section
|
||||
section_token_count = len(self.tokenizer.encode(section_text))
|
||||
section_token_count = len(self.tokenizer.tokenize(section_text))
|
||||
|
||||
# If the section is large on its own, split it separately
|
||||
if section_token_count > content_token_limit:
|
||||
@@ -333,7 +311,8 @@ class Chunker:
|
||||
# If even the split_text is bigger than strict limit, further split
|
||||
if (
|
||||
STRICT_CHUNK_TOKEN_LIMIT
|
||||
and len(self.tokenizer.encode(split_text)) > content_token_limit
|
||||
and len(self.tokenizer.tokenize(split_text))
|
||||
> content_token_limit
|
||||
):
|
||||
smaller_chunks = self._split_oversized_chunk(
|
||||
split_text, content_token_limit
|
||||
@@ -363,10 +342,10 @@ class Chunker:
|
||||
continue
|
||||
|
||||
# If we can still fit this section into the current chunk, do so
|
||||
current_token_count = len(self.tokenizer.encode(chunk_text))
|
||||
current_token_count = len(self.tokenizer.tokenize(chunk_text))
|
||||
current_offset = len(shared_precompare_cleanup(chunk_text))
|
||||
next_section_tokens = (
|
||||
len(self.tokenizer.encode(SECTION_SEPARATOR)) + section_token_count
|
||||
len(self.tokenizer.tokenize(SECTION_SEPARATOR)) + section_token_count
|
||||
)
|
||||
|
||||
if next_section_tokens + current_token_count <= content_token_limit:
|
||||
@@ -414,7 +393,7 @@ class Chunker:
|
||||
# Title prep
|
||||
title = self._extract_blurb(document.get_title_for_document_index() or "")
|
||||
title_prefix = title + RETURN_SEPARATOR if title else ""
|
||||
title_tokens = len(self.tokenizer.encode(title_prefix))
|
||||
title_tokens = len(self.tokenizer.tokenize(title_prefix))
|
||||
|
||||
# Metadata prep
|
||||
metadata_suffix_semantic = ""
|
||||
@@ -427,50 +406,15 @@ class Chunker:
|
||||
) = _get_metadata_suffix_for_document_index(
|
||||
document.metadata, include_separator=True
|
||||
)
|
||||
metadata_tokens = len(self.tokenizer.encode(metadata_suffix_semantic))
|
||||
metadata_tokens = len(self.tokenizer.tokenize(metadata_suffix_semantic))
|
||||
|
||||
# If metadata is too large, skip it in the semantic content
|
||||
if metadata_tokens >= self.chunk_token_limit * MAX_METADATA_PERCENTAGE:
|
||||
metadata_suffix_semantic = ""
|
||||
metadata_tokens = 0
|
||||
|
||||
single_chunk_fits = True
|
||||
doc_token_count = 0
|
||||
if self.enable_contextual_rag:
|
||||
doc_content = document.get_text_content()
|
||||
tokenized_doc = self.tokenizer.tokenize(doc_content)
|
||||
doc_token_count = len(tokenized_doc)
|
||||
|
||||
# check if doc + title + metadata fits in a single chunk. If so, no need for contextual RAG
|
||||
single_chunk_fits = (
|
||||
doc_token_count + title_tokens + metadata_tokens
|
||||
<= self.chunk_token_limit
|
||||
)
|
||||
|
||||
# expand the size of the context used for contextual rag based on whether chunk context and doc summary are used
|
||||
context_size = 0
|
||||
if (
|
||||
self.enable_contextual_rag
|
||||
and not single_chunk_fits
|
||||
and not AVERAGE_SUMMARY_EMBEDDINGS
|
||||
):
|
||||
context_size += self.default_contextual_rag_reserved_tokens
|
||||
|
||||
# Adjust content token limit to accommodate title + metadata
|
||||
content_token_limit = (
|
||||
self.chunk_token_limit - title_tokens - metadata_tokens - context_size
|
||||
)
|
||||
|
||||
# first check: if there is not enough actual chunk content when including contextual rag,
|
||||
# then don't do contextual rag
|
||||
if content_token_limit <= CHUNK_MIN_CONTENT:
|
||||
context_size = 0 # Don't do contextual RAG
|
||||
# revert to previous content token limit
|
||||
content_token_limit = (
|
||||
self.chunk_token_limit - title_tokens - metadata_tokens
|
||||
)
|
||||
|
||||
# If there is not enough context remaining then just index the chunk with no prefix/suffix
|
||||
content_token_limit = self.chunk_token_limit - title_tokens - metadata_tokens
|
||||
if content_token_limit <= CHUNK_MIN_CONTENT:
|
||||
# Not enough space left, so revert to full chunk without the prefix
|
||||
content_token_limit = self.chunk_token_limit
|
||||
@@ -494,9 +438,6 @@ class Chunker:
|
||||
large_chunks = generate_large_chunks(normal_chunks)
|
||||
normal_chunks.extend(large_chunks)
|
||||
|
||||
for chunk in normal_chunks:
|
||||
chunk.contextual_rag_reserved_tokens = context_size
|
||||
|
||||
return normal_chunks
|
||||
|
||||
def chunk(self, documents: list[IndexingDocument]) -> list[DocAwareChunk]:
|
||||
|
||||
@@ -121,7 +121,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
if chunk.large_chunk_reference_ids:
|
||||
large_chunks_present = True
|
||||
chunk_text = (
|
||||
f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_semantic}"
|
||||
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}"
|
||||
) or chunk.source_document.get_title_for_document_index()
|
||||
|
||||
if not chunk_text:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import Protocol
|
||||
@@ -9,13 +8,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.access import get_access_for_documents
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_NAME
|
||||
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER
|
||||
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
|
||||
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
|
||||
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
|
||||
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
|
||||
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.configs.model_configs import USE_INFORMATION_CONTENT_CLASSIFICATION
|
||||
@@ -43,15 +36,11 @@ from onyx.db.document import upsert_documents
|
||||
from onyx.db.document_set import fetch_document_sets_for_documents
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import Document as DBDocument
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
|
||||
from onyx.db.pg_file_store import read_lobj
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.tag import create_or_add_document_tag
|
||||
from onyx.db.tag import create_or_add_document_tag_list
|
||||
from onyx.db.user_documents import fetch_user_files_for_documents
|
||||
from onyx.db.user_documents import fetch_user_folders_for_documents
|
||||
from onyx.db.user_documents import update_user_file_token_count__no_commit
|
||||
from onyx.document_index.document_index_utils import (
|
||||
get_multipass_config,
|
||||
)
|
||||
@@ -59,7 +48,6 @@ from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
|
||||
from onyx.file_store.utils import store_user_file_plaintext
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import embed_chunks_with_failure_handling
|
||||
from onyx.indexing.embedder import IndexingEmbedder
|
||||
@@ -69,25 +57,11 @@ from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.factory import get_default_llm_with_vision
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.llm.factory import get_llm_for_contextual_rag
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
from onyx.llm.utils import MAX_CONTEXT_TOKENS
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.natural_language_processing.utils import tokenizer_trim_middle
|
||||
from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_PROMPT1
|
||||
from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_PROMPT2
|
||||
from onyx.prompts.chat_prompts import DOCUMENT_SUMMARY_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
|
||||
@@ -275,8 +249,6 @@ def index_doc_batch_with_handler(
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
ignore_time_skip: bool = False,
|
||||
enable_contextual_rag: bool = False,
|
||||
llm: LLM | None = None,
|
||||
) -> IndexingPipelineResult:
|
||||
try:
|
||||
index_pipeline_result = index_doc_batch(
|
||||
@@ -289,8 +261,6 @@ def index_doc_batch_with_handler(
|
||||
db_session=db_session,
|
||||
ignore_time_skip=ignore_time_skip,
|
||||
tenant_id=tenant_id,
|
||||
enable_contextual_rag=enable_contextual_rag,
|
||||
llm=llm,
|
||||
)
|
||||
except Exception as e:
|
||||
# don't log the batch directly, it's too much text
|
||||
@@ -561,145 +531,6 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
return indexed_documents
|
||||
|
||||
|
||||
def add_document_summaries(
|
||||
chunks_by_doc: list[DocAwareChunk],
|
||||
llm: LLM,
|
||||
tokenizer: BaseTokenizer,
|
||||
trunc_doc_tokens: int,
|
||||
) -> list[int] | None:
|
||||
"""
|
||||
Adds a document summary to a list of chunks from the same document.
|
||||
Returns the number of tokens in the document.
|
||||
"""
|
||||
|
||||
doc_tokens = []
|
||||
# this is value is the same for each chunk in the document; 0 indicates
|
||||
# There is not enough space for contextual RAG (the chunk content
|
||||
# and possibly metadata took up too much space)
|
||||
if chunks_by_doc[0].contextual_rag_reserved_tokens == 0:
|
||||
return None
|
||||
|
||||
doc_tokens = tokenizer.encode(chunks_by_doc[0].source_document.get_text_content())
|
||||
doc_content = tokenizer_trim_middle(doc_tokens, trunc_doc_tokens, tokenizer)
|
||||
summary_prompt = DOCUMENT_SUMMARY_PROMPT.format(document=doc_content)
|
||||
doc_summary = message_to_string(
|
||||
llm.invoke(summary_prompt, max_tokens=MAX_CONTEXT_TOKENS)
|
||||
)
|
||||
|
||||
for chunk in chunks_by_doc:
|
||||
chunk.doc_summary = doc_summary
|
||||
|
||||
return doc_tokens
|
||||
|
||||
|
||||
def add_chunk_summaries(
|
||||
chunks_by_doc: list[DocAwareChunk],
|
||||
llm: LLM,
|
||||
tokenizer: BaseTokenizer,
|
||||
trunc_doc_chunk_tokens: int,
|
||||
doc_tokens: list[int] | None,
|
||||
) -> None:
|
||||
"""
|
||||
Adds chunk summaries to the chunks grouped by document id.
|
||||
Chunk summaries look at the chunk as well as the entire document (or a summary,
|
||||
if the document is too long) and describe how the chunk relates to the document.
|
||||
"""
|
||||
# all chunks within a document have the same contextual_rag_reserved_tokens
|
||||
if chunks_by_doc[0].contextual_rag_reserved_tokens == 0:
|
||||
return
|
||||
|
||||
# use values computed in above doc summary section if available
|
||||
doc_tokens = doc_tokens or tokenizer.encode(
|
||||
chunks_by_doc[0].source_document.get_text_content()
|
||||
)
|
||||
doc_content = tokenizer_trim_middle(doc_tokens, trunc_doc_chunk_tokens, tokenizer)
|
||||
|
||||
# only compute doc summary if needed
|
||||
doc_info = (
|
||||
doc_content
|
||||
if len(doc_tokens) <= MAX_TOKENS_FOR_FULL_INCLUSION
|
||||
else chunks_by_doc[0].doc_summary
|
||||
)
|
||||
if not doc_info:
|
||||
# This happens if the document is too long AND document summaries are turned off
|
||||
# In this case we compute a doc summary using the LLM
|
||||
doc_info = message_to_string(
|
||||
llm.invoke(
|
||||
DOCUMENT_SUMMARY_PROMPT.format(document=doc_content),
|
||||
max_tokens=MAX_CONTEXT_TOKENS,
|
||||
)
|
||||
)
|
||||
|
||||
context_prompt1 = CONTEXTUAL_RAG_PROMPT1.format(document=doc_info)
|
||||
|
||||
def assign_context(chunk: DocAwareChunk) -> None:
|
||||
context_prompt2 = CONTEXTUAL_RAG_PROMPT2.format(chunk=chunk.content)
|
||||
try:
|
||||
chunk.chunk_context = message_to_string(
|
||||
llm.invoke(
|
||||
context_prompt1 + context_prompt2,
|
||||
max_tokens=MAX_CONTEXT_TOKENS,
|
||||
)
|
||||
)
|
||||
except LLMRateLimitError as e:
|
||||
# Erroring during chunker is undesirable, so we log the error and continue
|
||||
# TODO: for v2, add robust retry logic
|
||||
logger.exception(f"Rate limit adding chunk summary: {e}", exc_info=e)
|
||||
chunk.chunk_context = ""
|
||||
except Exception as e:
|
||||
logger.exception(f"Error adding chunk summary: {e}", exc_info=e)
|
||||
chunk.chunk_context = ""
|
||||
|
||||
run_functions_tuples_in_parallel(
|
||||
[(assign_context, (chunk,)) for chunk in chunks_by_doc]
|
||||
)
|
||||
|
||||
|
||||
def add_contextual_summaries(
|
||||
chunks: list[DocAwareChunk],
|
||||
llm: LLM,
|
||||
tokenizer: BaseTokenizer,
|
||||
chunk_token_limit: int,
|
||||
) -> list[DocAwareChunk]:
|
||||
"""
|
||||
Adds Document summary and chunk-within-document context to the chunks
|
||||
based on which environment variables are set.
|
||||
"""
|
||||
max_context = get_max_input_tokens(
|
||||
model_name=llm.config.model_name,
|
||||
model_provider=llm.config.model_provider,
|
||||
output_tokens=MAX_CONTEXT_TOKENS,
|
||||
)
|
||||
doc2chunks = defaultdict(list)
|
||||
for chunk in chunks:
|
||||
doc2chunks[chunk.source_document.id].append(chunk)
|
||||
|
||||
# The number of tokens allowed for the document when computing a document summary
|
||||
trunc_doc_summary_tokens = max_context - len(
|
||||
tokenizer.encode(DOCUMENT_SUMMARY_PROMPT)
|
||||
)
|
||||
|
||||
prompt_tokens = len(
|
||||
tokenizer.encode(CONTEXTUAL_RAG_PROMPT1 + CONTEXTUAL_RAG_PROMPT2)
|
||||
)
|
||||
# The number of tokens allowed for the document when computing a
|
||||
# "chunk in context of document" summary
|
||||
trunc_doc_chunk_tokens = max_context - prompt_tokens - chunk_token_limit
|
||||
for chunks_by_doc in doc2chunks.values():
|
||||
doc_tokens = None
|
||||
if USE_DOCUMENT_SUMMARY:
|
||||
doc_tokens = add_document_summaries(
|
||||
chunks_by_doc, llm, tokenizer, trunc_doc_summary_tokens
|
||||
)
|
||||
|
||||
if USE_CHUNK_SUMMARY:
|
||||
add_chunk_summaries(
|
||||
chunks_by_doc, llm, tokenizer, trunc_doc_chunk_tokens, doc_tokens
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
@log_function_time(debug_only=True)
|
||||
def index_doc_batch(
|
||||
*,
|
||||
@@ -711,8 +542,6 @@ def index_doc_batch(
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
enable_contextual_rag: bool = False,
|
||||
llm: LLM | None = None,
|
||||
ignore_time_skip: bool = False,
|
||||
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
|
||||
) -> IndexingPipelineResult:
|
||||
@@ -774,21 +603,6 @@ def index_doc_batch(
|
||||
# NOTE: no special handling for failures here, since the chunker is not
|
||||
# a common source of failure for the indexing pipeline
|
||||
chunks: list[DocAwareChunk] = chunker.chunk(ctx.indexable_docs)
|
||||
llm_tokenizer: BaseTokenizer | None = None
|
||||
|
||||
# contextual RAG
|
||||
if enable_contextual_rag:
|
||||
assert llm is not None, "must provide an LLM for contextual RAG"
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
provider_type=llm.config.model_provider,
|
||||
)
|
||||
|
||||
# Because the chunker's tokens are different from the LLM's tokens,
|
||||
# We add a fudge factor to ensure we truncate prompts to the LLM's token limit
|
||||
chunks = add_contextual_summaries(
|
||||
chunks, llm, llm_tokenizer, chunker.chunk_token_limit * 2
|
||||
)
|
||||
|
||||
logger.debug("Starting embedding")
|
||||
chunks_with_embeddings, embedding_failures = (
|
||||
@@ -832,15 +646,6 @@ def index_doc_batch(
|
||||
)
|
||||
}
|
||||
|
||||
doc_id_to_user_file_id: dict[str, int | None] = fetch_user_files_for_documents(
|
||||
document_ids=updatable_ids, db_session=db_session
|
||||
)
|
||||
doc_id_to_user_folder_id: dict[
|
||||
str, int | None
|
||||
] = fetch_user_folders_for_documents(
|
||||
document_ids=updatable_ids, db_session=db_session
|
||||
)
|
||||
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int | None] = {
|
||||
document_id: chunk_count
|
||||
for document_id, chunk_count in fetch_chunk_counts_for_documents(
|
||||
@@ -860,48 +665,6 @@ def index_doc_batch(
|
||||
for document_id in updatable_ids
|
||||
}
|
||||
|
||||
try:
|
||||
llm, _ = get_default_llms()
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
provider_type=llm.config.model_provider,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting tokenizer: {e}")
|
||||
llm_tokenizer = None
|
||||
|
||||
# Calculate token counts for each document by combining all its chunks' content
|
||||
user_file_id_to_token_count: dict[int, int | None] = {}
|
||||
user_file_id_to_raw_text: dict[int, str] = {}
|
||||
for document_id in updatable_ids:
|
||||
# Only calculate token counts for documents that have a user file ID
|
||||
if (
|
||||
document_id in doc_id_to_user_file_id
|
||||
and doc_id_to_user_file_id[document_id] is not None
|
||||
):
|
||||
user_file_id = doc_id_to_user_file_id[document_id]
|
||||
if not user_file_id:
|
||||
continue
|
||||
document_chunks = [
|
||||
chunk
|
||||
for chunk in chunks_with_embeddings
|
||||
if chunk.source_document.id == document_id
|
||||
]
|
||||
if document_chunks:
|
||||
combined_content = " ".join(
|
||||
[chunk.content for chunk in document_chunks]
|
||||
)
|
||||
token_count = (
|
||||
len(llm_tokenizer.encode(combined_content))
|
||||
if llm_tokenizer
|
||||
else 0
|
||||
)
|
||||
user_file_id_to_token_count[user_file_id] = token_count
|
||||
user_file_id_to_raw_text[user_file_id] = combined_content
|
||||
else:
|
||||
user_file_id_to_token_count[user_file_id] = None
|
||||
|
||||
# we're concerned about race conditions where multiple simultaneous indexings might result
|
||||
# in one set of metadata overwriting another one in vespa.
|
||||
# we still write data here for the immediate and most likely correct sync, but
|
||||
@@ -914,10 +677,6 @@ def index_doc_batch(
|
||||
document_sets=set(
|
||||
doc_id_to_document_set.get(chunk.source_document.id, [])
|
||||
),
|
||||
user_file=doc_id_to_user_file_id.get(chunk.source_document.id, None),
|
||||
user_folder=doc_id_to_user_folder_id.get(
|
||||
chunk.source_document.id, None
|
||||
),
|
||||
boost=(
|
||||
ctx.id_to_db_doc_map[chunk.source_document.id].boost
|
||||
if chunk.source_document.id in ctx.id_to_db_doc_map
|
||||
@@ -999,11 +758,6 @@ def index_doc_batch(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
update_user_file_token_count__no_commit(
|
||||
user_file_id_to_token_count=user_file_id_to_token_count,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# these documents can now be counted as part of the CC Pairs
|
||||
# document count, so we need to mark them as indexed
|
||||
# NOTE: even documents we skipped since they were already up
|
||||
@@ -1015,22 +769,12 @@ def index_doc_batch(
|
||||
document_ids=[doc.id for doc in filtered_documents],
|
||||
db_session=db_session,
|
||||
)
|
||||
# Store the plaintext in the file store for faster retrieval
|
||||
for user_file_id, raw_text in user_file_id_to_raw_text.items():
|
||||
# Use the dedicated function to store plaintext
|
||||
store_user_file_plaintext(
|
||||
user_file_id=user_file_id,
|
||||
plaintext_content=raw_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# save the chunk boost components to postgres
|
||||
update_chunk_boost_components__no_commit(
|
||||
chunk_data=updatable_chunk_data, db_session=db_session
|
||||
)
|
||||
|
||||
# Pause user file ccpairs
|
||||
|
||||
db_session.commit()
|
||||
|
||||
result = IndexingPipelineResult(
|
||||
@@ -1055,33 +799,13 @@ def build_indexing_pipeline(
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> IndexingPipelineProtocol:
|
||||
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
|
||||
all_search_settings = get_active_search_settings(db_session)
|
||||
if (
|
||||
all_search_settings.secondary
|
||||
and all_search_settings.secondary.status == IndexModelStatus.FUTURE
|
||||
):
|
||||
search_settings = all_search_settings.secondary
|
||||
else:
|
||||
search_settings = all_search_settings.primary
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
multipass_config = get_multipass_config(search_settings)
|
||||
|
||||
enable_contextual_rag = (
|
||||
search_settings.enable_contextual_rag or ENABLE_CONTEXTUAL_RAG
|
||||
)
|
||||
llm = None
|
||||
if enable_contextual_rag:
|
||||
llm = get_llm_for_contextual_rag(
|
||||
search_settings.contextual_rag_llm_name or DEFAULT_CONTEXTUAL_RAG_LLM_NAME,
|
||||
search_settings.contextual_rag_llm_provider
|
||||
or DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER,
|
||||
)
|
||||
|
||||
chunker = chunker or Chunker(
|
||||
tokenizer=embedder.embedding_model.tokenizer,
|
||||
enable_multipass=multipass_config.multipass_indexing,
|
||||
enable_large_chunks=multipass_config.enable_large_chunks,
|
||||
enable_contextual_rag=enable_contextual_rag,
|
||||
# after every doc, update status in case there are a bunch of really long docs
|
||||
callback=callback,
|
||||
)
|
||||
@@ -1095,6 +819,4 @@ def build_indexing_pipeline(
|
||||
ignore_time_skip=ignore_time_skip,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
enable_contextual_rag=enable_contextual_rag,
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
@@ -49,15 +49,6 @@ class DocAwareChunk(BaseChunk):
|
||||
metadata_suffix_semantic: str
|
||||
metadata_suffix_keyword: str
|
||||
|
||||
# This is the number of tokens reserved for contextual RAG
|
||||
# in the chunk. doc_summary and chunk_context conbined should
|
||||
# contain at most this many tokens.
|
||||
contextual_rag_reserved_tokens: int
|
||||
# This is the summary for the document generated for contextual RAG
|
||||
doc_summary: str
|
||||
# This is the context for this chunk generated for contextual RAG
|
||||
chunk_context: str
|
||||
|
||||
mini_chunk_texts: list[str] | None
|
||||
|
||||
large_chunk_id: int | None
|
||||
@@ -100,8 +91,6 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
tenant_id: str
|
||||
access: "DocumentAccess"
|
||||
document_sets: set[str]
|
||||
user_file: int | None
|
||||
user_folder: int | None
|
||||
boost: int
|
||||
aggregated_chunk_boost_factor: float
|
||||
|
||||
@@ -111,8 +100,6 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
index_chunk: IndexChunk,
|
||||
access: "DocumentAccess",
|
||||
document_sets: set[str],
|
||||
user_file: int | None,
|
||||
user_folder: int | None,
|
||||
boost: int,
|
||||
aggregated_chunk_boost_factor: float,
|
||||
tenant_id: str,
|
||||
@@ -122,8 +109,6 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
**index_chunk_data,
|
||||
access=access,
|
||||
document_sets=document_sets,
|
||||
user_file=user_file,
|
||||
user_folder=user_folder,
|
||||
boost=boost,
|
||||
aggregated_chunk_boost_factor=aggregated_chunk_boost_factor,
|
||||
tenant_id=tenant_id,
|
||||
@@ -169,9 +154,6 @@ class IndexingSetting(EmbeddingModelDetail):
|
||||
reduced_dimension: int | None = None
|
||||
|
||||
background_reindex_enabled: bool = True
|
||||
enable_contextual_rag: bool
|
||||
contextual_rag_llm_name: str | None = None
|
||||
contextual_rag_llm_provider: str | None = None
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
@@ -196,7 +178,6 @@ class IndexingSetting(EmbeddingModelDetail):
|
||||
embedding_precision=search_settings.embedding_precision,
|
||||
reduced_dimension=search_settings.reduced_dimension,
|
||||
background_reindex_enabled=search_settings.background_reindex_enabled,
|
||||
enable_contextual_rag=search_settings.enable_contextual_rag,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -425,12 +425,12 @@ class DefaultMultiLLM(LLM):
|
||||
messages=processed_prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice if tools else None,
|
||||
max_tokens=max_tokens,
|
||||
# streaming choice
|
||||
stream=stream,
|
||||
# model params
|
||||
temperature=0,
|
||||
timeout=timeout_override or self._timeout,
|
||||
max_tokens=max_tokens,
|
||||
# For now, we don't support parallel tool calls
|
||||
# NOTE: we can't pass this in if tools are not specified
|
||||
# or else OpenAI throws an error
|
||||
@@ -531,7 +531,6 @@ class DefaultMultiLLM(LLM):
|
||||
tool_choice,
|
||||
structured_response_format,
|
||||
timeout_override,
|
||||
max_tokens,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.server.manage.llm.models import LLMProvider
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.utils.headers import build_llm_extra_headers
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -155,40 +154,6 @@ def get_default_llm_with_vision(
|
||||
return None
|
||||
|
||||
|
||||
def llm_from_provider(
|
||||
model_name: str,
|
||||
llm_provider: LLMProvider,
|
||||
timeout: int | None = None,
|
||||
temperature: float | None = None,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
long_term_logger: LongTermLogger | None = None,
|
||||
) -> LLM:
|
||||
return get_llm(
|
||||
provider=llm_provider.provider,
|
||||
model=model_name,
|
||||
deployment_name=llm_provider.deployment_name,
|
||||
api_key=llm_provider.api_key,
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
custom_config=llm_provider.custom_config,
|
||||
timeout=timeout,
|
||||
temperature=temperature,
|
||||
additional_headers=additional_headers,
|
||||
long_term_logger=long_term_logger,
|
||||
)
|
||||
|
||||
|
||||
def get_llm_for_contextual_rag(model_name: str, model_provider: str) -> LLM:
|
||||
with get_session_context_manager() as db_session:
|
||||
llm_provider = fetch_llm_provider_view(db_session, model_provider)
|
||||
if not llm_provider:
|
||||
raise ValueError("No LLM provider with name {} found".format(model_provider))
|
||||
return llm_from_provider(
|
||||
model_name=model_name,
|
||||
llm_provider=llm_provider,
|
||||
)
|
||||
|
||||
|
||||
def get_default_llms(
|
||||
timeout: int | None = None,
|
||||
temperature: float | None = None,
|
||||
@@ -214,9 +179,14 @@ def get_default_llms(
|
||||
raise ValueError("No fast default model name found")
|
||||
|
||||
def _create_llm(model: str) -> LLM:
|
||||
return llm_from_provider(
|
||||
model_name=model,
|
||||
llm_provider=llm_provider,
|
||||
return get_llm(
|
||||
provider=llm_provider.provider,
|
||||
model=model,
|
||||
deployment_name=llm_provider.deployment_name,
|
||||
api_key=llm_provider.api_key,
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
custom_config=llm_provider.custom_config,
|
||||
timeout=timeout,
|
||||
temperature=temperature,
|
||||
additional_headers=additional_headers,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import copy
|
||||
import io
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
@@ -30,20 +29,13 @@ from litellm.exceptions import Timeout # type: ignore
|
||||
from litellm.exceptions import UnprocessableEntityError # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS
|
||||
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
|
||||
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
|
||||
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
|
||||
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
|
||||
from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_TOKEN_ESTIMATE
|
||||
from onyx.prompts.chat_prompts import DOCUMENT_SUMMARY_TOKEN_ESTIMATE
|
||||
from onyx.prompts.constants import CODE_BLOCK_PAT
|
||||
from onyx.utils.b64 import get_image_type
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
@@ -52,10 +44,6 @@ from shared_configs.configs import LOG_LEVEL
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
MAX_CONTEXT_TOKENS = 100
|
||||
ONE_MILLION = 1_000_000
|
||||
CHUNKS_PER_DOC_ESTIMATE = 5
|
||||
|
||||
|
||||
def litellm_exception_to_error_msg(
|
||||
e: Exception,
|
||||
@@ -131,12 +119,7 @@ def _build_content(
|
||||
text_files = [
|
||||
file
|
||||
for file in files
|
||||
if file.file_type
|
||||
in (
|
||||
ChatFileType.PLAIN_TEXT,
|
||||
ChatFileType.CSV,
|
||||
ChatFileType.USER_KNOWLEDGE,
|
||||
)
|
||||
if file.file_type in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV)
|
||||
]
|
||||
|
||||
if not text_files:
|
||||
@@ -144,18 +127,7 @@ def _build_content(
|
||||
|
||||
final_message_with_files = "FILES:\n\n"
|
||||
for file in text_files:
|
||||
try:
|
||||
file_content = file.content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
# Try to decode as binary
|
||||
try:
|
||||
file_content, _, _ = read_pdf_file(io.BytesIO(file.content))
|
||||
except Exception:
|
||||
file_content = f"[Binary file content - {file.file_type} format]"
|
||||
logger.exception(
|
||||
f"Could not decode binary file content for file type: {file.file_type}"
|
||||
)
|
||||
# logger.warning(f"Could not decode binary file content for file type: {file.file_type}")
|
||||
file_content = file.content.decode("utf-8")
|
||||
file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else ""
|
||||
final_message_with_files += (
|
||||
f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n"
|
||||
@@ -183,6 +155,7 @@ def build_content_with_imgs(
|
||||
|
||||
img_urls = img_urls or []
|
||||
b64_imgs = b64_imgs or []
|
||||
|
||||
message_main_content = _build_content(message, files)
|
||||
|
||||
if exclude_images or (not img_files and not img_urls):
|
||||
@@ -430,83 +403,19 @@ def _find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | N
|
||||
for model_name in filtered_model_names:
|
||||
model_obj = model_map.get(f"{provider}/{model_name}")
|
||||
if model_obj:
|
||||
logger.debug(f"Using model object for {provider}/{model_name}")
|
||||
return model_obj
|
||||
|
||||
# Then try all model names without provider prefix
|
||||
for model_name in filtered_model_names:
|
||||
model_obj = model_map.get(model_name)
|
||||
if model_obj:
|
||||
logger.debug(f"Using model object for {model_name}")
|
||||
return model_obj
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_llm_contextual_cost(
|
||||
llm: LLM,
|
||||
) -> float:
|
||||
"""
|
||||
Approximate the cost of using the given LLM for indexing with Contextual RAG.
|
||||
|
||||
We use a precomputed estimate for the number of tokens in the contextualizing prompts,
|
||||
and we assume that every chunk is maximized in terms of content and context.
|
||||
We also assume that every document is maximized in terms of content, as currently if
|
||||
a document is longer than a certain length, its summary is used instead of the full content.
|
||||
|
||||
We expect that the first assumption will overestimate more than the second one
|
||||
underestimates, so this should be a fairly conservative price estimate. Also,
|
||||
this does not account for the cost of documents that fit within a single chunk
|
||||
which do not get contextualized.
|
||||
"""
|
||||
|
||||
# calculate input costs
|
||||
num_tokens = ONE_MILLION
|
||||
num_input_chunks = num_tokens // DOC_EMBEDDING_CONTEXT_SIZE
|
||||
|
||||
# We assume that the documents are MAX_TOKENS_FOR_FULL_INCLUSION tokens long
|
||||
# on average.
|
||||
num_docs = num_tokens // MAX_TOKENS_FOR_FULL_INCLUSION
|
||||
|
||||
num_input_tokens = 0
|
||||
num_output_tokens = 0
|
||||
|
||||
if not USE_CHUNK_SUMMARY and not USE_DOCUMENT_SUMMARY:
|
||||
return 0
|
||||
|
||||
if USE_CHUNK_SUMMARY:
|
||||
# Each per-chunk prompt includes:
|
||||
# - The prompt tokens
|
||||
# - the document tokens
|
||||
# - the chunk tokens
|
||||
|
||||
# for each chunk, we prompt the LLM with the contextual RAG prompt
|
||||
# and the full document content (or the doc summary, so this is an overestimate)
|
||||
num_input_tokens += num_input_chunks * (
|
||||
CONTEXTUAL_RAG_TOKEN_ESTIMATE + MAX_TOKENS_FOR_FULL_INCLUSION
|
||||
)
|
||||
|
||||
# in aggregate, each chunk content is used as a prompt input once
|
||||
# so the full input size is covered
|
||||
num_input_tokens += num_tokens
|
||||
|
||||
# A single MAX_CONTEXT_TOKENS worth of output is generated per chunk
|
||||
num_output_tokens += num_input_chunks * MAX_CONTEXT_TOKENS
|
||||
|
||||
# going over each doc once means all the tokens, plus the prompt tokens for
|
||||
# the summary prompt. This CAN happen even when USE_DOCUMENT_SUMMARY is false,
|
||||
# since doc summaries are used for longer documents when USE_CHUNK_SUMMARY is true.
|
||||
# So, we include this unconditionally to overestimate.
|
||||
num_input_tokens += num_tokens + num_docs * DOCUMENT_SUMMARY_TOKEN_ESTIMATE
|
||||
num_output_tokens += num_docs * MAX_CONTEXT_TOKENS
|
||||
|
||||
usd_per_prompt, usd_per_completion = litellm.cost_per_token(
|
||||
model=llm.config.model_name,
|
||||
prompt_tokens=num_input_tokens,
|
||||
completion_tokens=num_output_tokens,
|
||||
)
|
||||
# Costs are in USD dollars per million tokens
|
||||
return usd_per_prompt + usd_per_completion
|
||||
|
||||
|
||||
def get_llm_max_tokens(
|
||||
model_map: dict,
|
||||
model_name: str,
|
||||
@@ -531,10 +440,14 @@ def get_llm_max_tokens(
|
||||
|
||||
if "max_input_tokens" in model_obj:
|
||||
max_tokens = model_obj["max_input_tokens"]
|
||||
logger.debug(
|
||||
f"Max tokens for {model_name}: {max_tokens} (from max_input_tokens)"
|
||||
)
|
||||
return max_tokens
|
||||
|
||||
if "max_tokens" in model_obj:
|
||||
max_tokens = model_obj["max_tokens"]
|
||||
logger.debug(f"Max tokens for {model_name}: {max_tokens} (from max_tokens)")
|
||||
return max_tokens
|
||||
|
||||
logger.error(f"No max tokens found for LLM: {model_name}")
|
||||
@@ -556,16 +469,21 @@ def get_llm_max_output_tokens(
|
||||
model_obj = model_map.get(f"{model_provider}/{model_name}")
|
||||
if not model_obj:
|
||||
model_obj = model_map[model_name]
|
||||
logger.debug(f"Using model object for {model_name}")
|
||||
else:
|
||||
pass
|
||||
logger.debug(f"Using model object for {model_provider}/{model_name}")
|
||||
|
||||
if "max_output_tokens" in model_obj:
|
||||
max_output_tokens = model_obj["max_output_tokens"]
|
||||
logger.info(f"Max output tokens for {model_name}: {max_output_tokens}")
|
||||
return max_output_tokens
|
||||
|
||||
# Fallback to a fraction of max_tokens if max_output_tokens is not specified
|
||||
if "max_tokens" in model_obj:
|
||||
max_output_tokens = int(model_obj["max_tokens"] * 0.1)
|
||||
logger.info(
|
||||
f"Fallback max output tokens for {model_name}: {max_output_tokens} (10% of max_tokens)"
|
||||
)
|
||||
return max_output_tokens
|
||||
|
||||
logger.error(f"No max output tokens found for LLM: {model_name}")
|
||||
@@ -602,7 +520,7 @@ def get_max_input_tokens(
|
||||
)
|
||||
|
||||
if input_toks <= 0:
|
||||
return GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
raise RuntimeError("No tokens for input for the LLM given settings")
|
||||
|
||||
return input_toks
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -17,7 +16,6 @@ from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
||||
from sentry_sdk.integrations.starlette import StarletteIntegration
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -99,13 +97,10 @@ from onyx.server.settings.api import basic_router as settings_router
|
||||
from onyx.server.token_rate_limits.api import (
|
||||
router as token_rate_limit_settings_router,
|
||||
)
|
||||
from onyx.server.user_documents.api import router as user_documents_router
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.setup import setup_multitenant_onyx
|
||||
from onyx.setup import setup_onyx
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.logger import setup_uvicorn_logger
|
||||
from onyx.utils.middleware import add_onyx_request_id_middleware
|
||||
from onyx.utils.telemetry import get_or_generate_uuid
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
@@ -120,12 +115,6 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
file_handlers = [
|
||||
h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)
|
||||
]
|
||||
|
||||
setup_uvicorn_logger(shared_file_handlers=file_handlers)
|
||||
|
||||
|
||||
def validation_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
if not isinstance(exc, RequestValidationError):
|
||||
@@ -308,7 +297,6 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, input_prompt_router)
|
||||
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
|
||||
include_router_with_global_prefix_prepended(application, cc_pair_router)
|
||||
include_router_with_global_prefix_prepended(application, user_documents_router)
|
||||
include_router_with_global_prefix_prepended(application, folder_router)
|
||||
include_router_with_global_prefix_prepended(application, document_set_router)
|
||||
include_router_with_global_prefix_prepended(application, search_settings_router)
|
||||
@@ -403,11 +391,6 @@ def get_application() -> FastAPI:
|
||||
prefix="/auth",
|
||||
)
|
||||
|
||||
if (
|
||||
AUTH_TYPE == AuthType.CLOUD
|
||||
or AUTH_TYPE == AuthType.BASIC
|
||||
or AUTH_TYPE == AuthType.GOOGLE_OAUTH
|
||||
):
|
||||
# Add refresh token endpoint for OAuth as well
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
@@ -431,14 +414,9 @@ def get_application() -> FastAPI:
|
||||
if LOG_ENDPOINT_LATENCY:
|
||||
add_latency_logging_middleware(application, logger)
|
||||
|
||||
add_onyx_request_id_middleware(application, "API", logger)
|
||||
|
||||
# Ensure all routes have auth enabled or are explicitly marked as public
|
||||
check_router_auth(application)
|
||||
|
||||
# Initialize and instrument the app
|
||||
Instrumentator().instrument(application).expose(application)
|
||||
|
||||
return application
|
||||
|
||||
|
||||
|
||||
@@ -3,8 +3,6 @@ from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from copy import copy
|
||||
|
||||
from tokenizers import Encoding # type: ignore
|
||||
from tokenizers import Tokenizer # type: ignore
|
||||
from transformers import logging as transformer_logging # type:ignore
|
||||
|
||||
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
@@ -13,8 +11,6 @@ from onyx.context.search.models import InferenceChunk
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
TRIM_SEP_PAT = "\n... {n} tokens removed...\n"
|
||||
|
||||
logger = setup_logger()
|
||||
transformer_logging.set_verbosity_error()
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@@ -71,27 +67,16 @@ class TiktokenTokenizer(BaseTokenizer):
|
||||
|
||||
class HuggingFaceTokenizer(BaseTokenizer):
|
||||
def __init__(self, model_name: str):
|
||||
self.encoder: Tokenizer = Tokenizer.from_pretrained(model_name)
|
||||
from tokenizers import Tokenizer # type: ignore
|
||||
|
||||
def _safer_encode(self, string: str) -> Encoding:
|
||||
"""
|
||||
Encode a string using the HuggingFaceTokenizer, but if it fails,
|
||||
encode the string as ASCII and decode it back to a string. This helps
|
||||
in cases where the string has weird characters like \udeb4.
|
||||
"""
|
||||
try:
|
||||
return self.encoder.encode(string, add_special_tokens=False)
|
||||
except Exception:
|
||||
return self.encoder.encode(
|
||||
string.encode("ascii", "ignore").decode(), add_special_tokens=False
|
||||
)
|
||||
self.encoder = Tokenizer.from_pretrained(model_name)
|
||||
|
||||
def encode(self, string: str) -> list[int]:
|
||||
# this returns no special tokens
|
||||
return self._safer_encode(string).ids
|
||||
return self.encoder.encode(string, add_special_tokens=False).ids
|
||||
|
||||
def tokenize(self, string: str) -> list[str]:
|
||||
return self._safer_encode(string).tokens
|
||||
return self.encoder.encode(string, add_special_tokens=False).tokens
|
||||
|
||||
def decode(self, tokens: list[int]) -> str:
|
||||
return self.encoder.decode(tokens)
|
||||
@@ -174,26 +159,9 @@ def tokenizer_trim_content(
|
||||
content: str, desired_length: int, tokenizer: BaseTokenizer
|
||||
) -> str:
|
||||
tokens = tokenizer.encode(content)
|
||||
if len(tokens) <= desired_length:
|
||||
return content
|
||||
|
||||
return tokenizer.decode(tokens[:desired_length])
|
||||
|
||||
|
||||
def tokenizer_trim_middle(
|
||||
tokens: list[int], desired_length: int, tokenizer: BaseTokenizer
|
||||
) -> str:
|
||||
if len(tokens) <= desired_length:
|
||||
return tokenizer.decode(tokens)
|
||||
sep_str = TRIM_SEP_PAT.format(n=len(tokens) - desired_length)
|
||||
sep_tokens = tokenizer.encode(sep_str)
|
||||
slice_size = (desired_length - len(sep_tokens)) // 2
|
||||
assert slice_size > 0, "Slice size is not positive, desired length is too short"
|
||||
return (
|
||||
tokenizer.decode(tokens[:slice_size])
|
||||
+ sep_str
|
||||
+ tokenizer.decode(tokens[-slice_size:])
|
||||
)
|
||||
if len(tokens) > desired_length:
|
||||
content = tokenizer.decode(tokens[:desired_length])
|
||||
return content
|
||||
|
||||
|
||||
def tokenizer_trim_chunks(
|
||||
|
||||
@@ -42,7 +42,6 @@ from onyx.context.search.retrieval.search_runner import (
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.db.models import SlackBot
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.slack_bot import fetch_slack_bot
|
||||
@@ -595,7 +594,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
|
||||
bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client)
|
||||
if event_type == "message":
|
||||
is_dm = event.get("channel_type") == "im"
|
||||
is_tagged = bot_tag_id and f"<@{bot_tag_id}>" in msg
|
||||
is_tagged = bot_tag_id and bot_tag_id in msg
|
||||
is_onyx_bot_msg = bot_tag_id and bot_tag_id in event.get("user", "")
|
||||
|
||||
# OnyxBot should never respond to itself
|
||||
@@ -728,11 +727,7 @@ def build_request_details(
|
||||
event = cast(dict[str, Any], req.payload["event"])
|
||||
msg = cast(str, event["text"])
|
||||
channel = cast(str, event["channel"])
|
||||
# Check for both app_mention events and messages containing bot tag
|
||||
bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client)
|
||||
tagged = (event.get("type") == "app_mention") or (
|
||||
event.get("type") == "message" and bot_tag_id and f"<@{bot_tag_id}>" in msg
|
||||
)
|
||||
tagged = event.get("type") == "app_mention"
|
||||
message_ts = event.get("ts")
|
||||
thread_ts = event.get("thread_ts")
|
||||
sender_id = event.get("user") or None
|
||||
@@ -973,9 +968,6 @@ def _get_socket_client(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Initialize the SqlEngine
|
||||
SqlEngine.init_engine(pool_size=20, max_overflow=5)
|
||||
|
||||
# Initialize the tenant handler which will manage tenant connections
|
||||
logger.info("Starting SlackbotHandler")
|
||||
tenant_handler = SlackbotHandler()
|
||||
|
||||
@@ -145,7 +145,7 @@ def update_emote_react(
|
||||
|
||||
def remove_onyx_bot_tag(message_str: str, client: WebClient) -> str:
|
||||
bot_tag_id = get_onyx_bot_slack_bot_id(web_client=client)
|
||||
return re.sub(rf"<@{bot_tag_id}>\s*", "", message_str)
|
||||
return re.sub(rf"<@{bot_tag_id}>\s", "", message_str)
|
||||
|
||||
|
||||
def _check_for_url_in_block(block: Block) -> bool:
|
||||
|
||||
@@ -220,29 +220,3 @@ Chat History:
|
||||
|
||||
Based on the above, what is a short name to convey the topic of the conversation?
|
||||
""".strip()
|
||||
|
||||
# NOTE: the prompt separation is partially done for efficiency; previously I tried
|
||||
# to do it all in one prompt with sequential format() calls but this will cause a backend
|
||||
# error when the document contains any {} as python will expect the {} to be filled by
|
||||
# format() arguments
|
||||
CONTEXTUAL_RAG_PROMPT1 = """<document>
|
||||
{document}
|
||||
</document>
|
||||
Here is the chunk we want to situate within the whole document"""
|
||||
|
||||
CONTEXTUAL_RAG_PROMPT2 = """<chunk>
|
||||
{chunk}
|
||||
</chunk>
|
||||
Please give a short succinct context to situate this chunk within the overall document
|
||||
for the purposes of improving search retrieval of the chunk. Answer only with the succinct
|
||||
context and nothing else. """
|
||||
|
||||
CONTEXTUAL_RAG_TOKEN_ESTIMATE = 64 # 19 + 45
|
||||
|
||||
DOCUMENT_SUMMARY_PROMPT = """<document>
|
||||
{document}
|
||||
</document>
|
||||
Please give a short succinct summary of the entire document. Answer only with the succinct
|
||||
summary and nothing else. """
|
||||
|
||||
DOCUMENT_SUMMARY_TOKEN_ESTIMATE = 29
|
||||
|
||||
@@ -195,7 +195,7 @@ class RedisConnectorPermissionSync:
|
||||
),
|
||||
queue=OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
ignore_result=True,
|
||||
)
|
||||
async_results.append(result)
|
||||
|
||||
@@ -125,7 +125,6 @@ class TenantRedis(redis.Redis):
|
||||
"hset",
|
||||
"hdel",
|
||||
"ttl",
|
||||
"pttl",
|
||||
] # Regular methods that need simple prefixing
|
||||
|
||||
if item == "scan_iter" or item == "sscan_iter":
|
||||
|
||||
@@ -87,9 +87,6 @@ def _create_indexable_chunks(
|
||||
metadata_suffix_keyword="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_reference_ids=[],
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=preprocessed_doc["content_embedding"],
|
||||
mini_chunk_embeddings=[],
|
||||
@@ -98,8 +95,6 @@ def _create_indexable_chunks(
|
||||
tenant_id=tenant_id if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA,
|
||||
access=default_public_access,
|
||||
document_sets=set(),
|
||||
user_file=None,
|
||||
user_folder=None,
|
||||
boost=DEFAULT_BOOST,
|
||||
large_chunk_id=None,
|
||||
image_file_name=None,
|
||||
|
||||
@@ -5,7 +5,6 @@ from onyx.configs.chat_configs import INPUT_PROMPT_YAML
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.configs.chat_configs import PERSONAS_YAML
|
||||
from onyx.configs.chat_configs import PROMPTS_YAML
|
||||
from onyx.configs.chat_configs import USER_FOLDERS_YAML
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.document_set import get_or_create_document_set_by_name
|
||||
from onyx.db.input_prompt import insert_input_prompt_if_not_exists
|
||||
@@ -16,29 +15,6 @@ from onyx.db.models import Tool as ToolDBModel
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.db.prompts import get_prompt_by_name
|
||||
from onyx.db.prompts import upsert_prompt
|
||||
from onyx.db.user_documents import upsert_user_folder
|
||||
|
||||
|
||||
def load_user_folders_from_yaml(
|
||||
db_session: Session,
|
||||
user_folders_yaml: str = USER_FOLDERS_YAML,
|
||||
) -> None:
|
||||
with open(user_folders_yaml, "r") as file:
|
||||
data = yaml.safe_load(file)
|
||||
|
||||
all_user_folders = data.get("user_folders", [])
|
||||
for user_folder in all_user_folders:
|
||||
upsert_user_folder(
|
||||
db_session=db_session,
|
||||
id=user_folder.get("id"),
|
||||
name=user_folder.get("name"),
|
||||
description=user_folder.get("description"),
|
||||
created_at=user_folder.get("created_at"),
|
||||
user=user_folder.get("user"),
|
||||
files=user_folder.get("files"),
|
||||
assistants=user_folder.get("assistants"),
|
||||
)
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def load_prompts_from_yaml(
|
||||
@@ -203,4 +179,3 @@ def load_chat_yamls(
|
||||
load_prompts_from_yaml(db_session, prompt_yaml)
|
||||
load_personas_from_yaml(db_session, personas_yaml)
|
||||
load_input_prompts_from_yaml(db_session, input_prompts_yaml)
|
||||
load_user_folders_from_yaml(db_session)
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
user_folders:
|
||||
- id: -1
|
||||
name: "Recent Documents"
|
||||
description: "Documents uploaded by the user"
|
||||
files: []
|
||||
assistants: []
|
||||
@@ -49,7 +49,6 @@ PUBLIC_ENDPOINT_SPECS = [
|
||||
("/auth/oauth/callback", {"GET"}),
|
||||
# anonymous user on cloud
|
||||
("/tenants/anonymous-user", {"POST"}),
|
||||
("/metrics", {"GET"}), # added by prometheus_fastapi_instrumentator
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ from onyx.background.celery.tasks.external_group_syncing.tasks import (
|
||||
from onyx.background.celery.tasks.pruning.tasks import (
|
||||
try_creating_prune_generator_task,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.background.celery.versioned_apps.primary import app as primary_app
|
||||
from onyx.background.indexing.models import IndexAttemptErrorPydantic
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -219,7 +219,7 @@ def update_cc_pair_status(
|
||||
continue
|
||||
|
||||
# Revoke the task to prevent it from running
|
||||
client_app.control.revoke(index_payload.celery_task_id)
|
||||
primary_app.control.revoke(index_payload.celery_task_id)
|
||||
|
||||
# If it is running, then signaling for termination will get the
|
||||
# watchdog thread to kill the spawned task
|
||||
@@ -238,7 +238,7 @@ def update_cc_pair_status(
|
||||
db_session.commit()
|
||||
|
||||
# this speeds up the start of indexing by firing the check immediately
|
||||
client_app.send_task(
|
||||
primary_app.send_task(
|
||||
OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
kwargs=dict(tenant_id=tenant_id),
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
@@ -376,7 +376,7 @@ def prune_cc_pair(
|
||||
f"{cc_pair.connector.name} connector."
|
||||
)
|
||||
payload_id = try_creating_prune_generator_task(
|
||||
client_app, cc_pair, db_session, r, tenant_id
|
||||
primary_app, cc_pair, db_session, r, tenant_id
|
||||
)
|
||||
if not payload_id:
|
||||
raise HTTPException(
|
||||
@@ -450,7 +450,7 @@ def sync_cc_pair(
|
||||
f"{cc_pair.connector.name} connector."
|
||||
)
|
||||
payload_id = try_creating_permissions_sync_task(
|
||||
client_app, cc_pair_id, r, tenant_id
|
||||
primary_app, cc_pair_id, r, tenant_id
|
||||
)
|
||||
if not payload_id:
|
||||
raise HTTPException(
|
||||
@@ -524,7 +524,7 @@ def sync_cc_pair_groups(
|
||||
f"{cc_pair.connector.name} connector."
|
||||
)
|
||||
payload_id = try_creating_external_group_sync_task(
|
||||
client_app, cc_pair_id, r, tenant_id
|
||||
primary_app, cc_pair_id, r, tenant_id
|
||||
)
|
||||
if not payload_id:
|
||||
raise HTTPException(
|
||||
@@ -634,7 +634,7 @@ def associate_credential_to_connector(
|
||||
)
|
||||
|
||||
# trigger indexing immediately
|
||||
client_app.send_task(
|
||||
primary_app.send_task(
|
||||
OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
kwargs={"tenant_id": tenant_id},
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user