mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
102 Commits
content-mo
...
max
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e823919892 | ||
|
|
2f3020a4d3 | ||
|
|
4bae1318bb | ||
|
|
11c3f44c76 | ||
|
|
cb38ac8a97 | ||
|
|
b2120b9f39 | ||
|
|
ccd372cc4a | ||
|
|
ea30f1de1e | ||
|
|
a7130681d9 | ||
|
|
04911db715 | ||
|
|
feae7d0cc4 | ||
|
|
ac19c64b3c | ||
|
|
03d5c30fd2 | ||
|
|
e988c13e1d | ||
|
|
dc18d53133 | ||
|
|
a1cef389aa | ||
|
|
db8d6ce538 | ||
|
|
e8370dcb24 | ||
|
|
9951fe13ba | ||
|
|
56f8ab927b | ||
|
|
cb5bbd3812 | ||
|
|
742d29e504 | ||
|
|
ecc155d082 | ||
|
|
0857e4809d | ||
|
|
22e00a1f5c | ||
|
|
0d0588a0c1 | ||
|
|
aab777f844 | ||
|
|
babbe7689a | ||
|
|
a123661c92 | ||
|
|
c554889baf | ||
|
|
f08fa878a6 | ||
|
|
d307534781 | ||
|
|
6f54791910 | ||
|
|
0d5497bb6b | ||
|
|
7648627503 | ||
|
|
927554d5ca | ||
|
|
7dcec6caf5 | ||
|
|
036648146d | ||
|
|
2aa4697ac8 | ||
|
|
bc9b4e4f45 | ||
|
|
178a64f298 | ||
|
|
c79f1edf1d | ||
|
|
7c8e23aa54 | ||
|
|
d37b427d52 | ||
|
|
a65fefd226 | ||
|
|
bb09bde519 | ||
|
|
0f6cf0fc58 | ||
|
|
fed06b592d | ||
|
|
8d92a1524e | ||
|
|
ecfea9f5ed | ||
|
|
b269f1ba06 | ||
|
|
30c878efa5 | ||
|
|
2024776c19 | ||
|
|
431316929c | ||
|
|
c5b9c6e308 | ||
|
|
73dd188b3f | ||
|
|
79b061abbc | ||
|
|
552f1ead4f | ||
|
|
17925b49e8 | ||
|
|
55fb5c3ca5 | ||
|
|
99546e4a4d | ||
|
|
c25d56f4a5 | ||
|
|
35f3f4f120 | ||
|
|
25b69a8aca | ||
|
|
1b7d710b2a | ||
|
|
ae3d3db3f4 | ||
|
|
fb79a9e700 | ||
|
|
587ba11bbc | ||
|
|
fce81ebb60 | ||
|
|
61facfb0a8 | ||
|
|
52b96854a2 | ||
|
|
d123713c00 | ||
|
|
775c847f82 | ||
|
|
6d330131fd | ||
|
|
0292ca2445 | ||
|
|
15dd1e72ca | ||
|
|
91c9be37c0 | ||
|
|
2a01c854a0 | ||
|
|
85ebadc8eb | ||
|
|
5dda53eec3 | ||
|
|
72bf427cc2 | ||
|
|
f421c6010b | ||
|
|
0b87549f35 | ||
|
|
06624a988d | ||
|
|
ae774105e3 | ||
|
|
4dafc3aa6d | ||
|
|
5d7d471823 | ||
|
|
61366df34c | ||
|
|
1a444245f6 | ||
|
|
c32d234491 | ||
|
|
07b68436cf | ||
|
|
293d1a4476 | ||
|
|
ba514aaaa2 | ||
|
|
f45798b5dd | ||
|
|
64ff5df083 | ||
|
|
cf1b7e7a93 | ||
|
|
63692a6bd3 | ||
|
|
934700b928 | ||
|
|
b1a7cff9e0 | ||
|
|
463340b8a1 | ||
|
|
ba82888e1e | ||
|
|
39465d3104 |
209
.github/workflows/pr-mit-integration-tests.yml
vendored
Normal file
209
.github/workflows/pr-mit-integration-tests.yml
vendored
Normal file
@@ -0,0 +1,209 @@
|
||||
name: Run MIT Integration Tests v2
|
||||
concurrency:
|
||||
group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
|
||||
jobs:
|
||||
integration-tests-mit:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# We don't need to build the Web Docker image since it's not yet used
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# successfully.
|
||||
- name: Pull Web Docker image
|
||||
run: |
|
||||
docker pull onyxdotapp/onyx-web-server:latest
|
||||
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-integration:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
AUTH_TYPE=basic \
|
||||
POSTGRES_POOL_PRE_PING=true \
|
||||
POSTGRES_USE_NULL_POOL=true \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
INTEGRATION_TESTS_MODE=true \
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
docker logs -f onyx-stack-api_server-1 &
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in 5 minutes."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Use curl with error handling to ignore specific exit code 56
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "Service is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Start Mock Services
|
||||
run: |
|
||||
cd backend/tests/integration/mock_services
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
|
||||
- name: Run Standard Integration Tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/tests \
|
||||
/app/tests/integration/connector_job_tests
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
- name: Check test results
|
||||
run: |
|
||||
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
|
||||
echo "Integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Always gather logs BEFORE "down":
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
|
||||
- name: Dump all-container logs (optional)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
- name: Stop Docker containers
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
|
||||
@@ -9,6 +9,10 @@ on:
|
||||
- cron: "0 16 * * *"
|
||||
|
||||
env:
|
||||
# AWS
|
||||
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
|
||||
AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
|
||||
|
||||
# Confluence
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
|
||||
@@ -45,11 +49,16 @@ env:
|
||||
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
|
||||
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
|
||||
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
|
||||
# Github
|
||||
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
|
||||
# Gitbook
|
||||
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
|
||||
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
|
||||
# Notion
|
||||
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
|
||||
# Highspot
|
||||
HIGHSPOT_KEY: ${{ secrets.HIGHSPOT_KEY }}
|
||||
HIGHSPOT_SECRET: ${{ secrets.HIGHSPOT_SECRET }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
|
||||
775
.vscode/launch.template.jsonc
vendored
775
.vscode/launch.template.jsonc
vendored
@@ -6,396 +6,419 @@
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"compounds": [
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Compound ---",
|
||||
"configurations": [
|
||||
"--- Individual ---"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Run All Onyx Services",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web / Model / API",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery (all)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
}
|
||||
}
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Compound ---",
|
||||
"configurations": ["--- Individual ---"],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Run All Onyx Services",
|
||||
"configurations": [
|
||||
"Web Server",
|
||||
"Model Server",
|
||||
"API Server",
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery user files indexing",
|
||||
"Celery beat",
|
||||
"Celery monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web / Model / API",
|
||||
"configurations": ["Web Server", "Model Server", "API Server"],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery (all)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery user files indexing",
|
||||
"Celery beat",
|
||||
"Celery monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Individual ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web Server",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}/web",
|
||||
"runtimeExecutable": "npm",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"runtimeArgs": [
|
||||
"run", "dev"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"console": "integratedTerminal",
|
||||
"consoleTitle": "Web Server Console"
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Individual ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Web Server",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}/web",
|
||||
"runtimeExecutable": "npm",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"runtimeArgs": ["run", "dev"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
{
|
||||
"name": "Model Server",
|
||||
"consoleName": "Model Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
"args": [
|
||||
"model_server.main:app",
|
||||
"--reload",
|
||||
"--port",
|
||||
"9000"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Model Server Console"
|
||||
"console": "integratedTerminal",
|
||||
"consoleTitle": "Web Server Console"
|
||||
},
|
||||
{
|
||||
"name": "Model Server",
|
||||
"consoleName": "Model Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
{
|
||||
"name": "API Server",
|
||||
"consoleName": "API Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
"args": [
|
||||
"onyx.main:app",
|
||||
"--reload",
|
||||
"--port",
|
||||
"8080"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "API Server Console"
|
||||
"args": ["model_server.main:app", "--reload", "--port", "9000"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
// For the listener to access the Slack API,
|
||||
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
"consoleName": "Slack Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "onyx/onyxbot/slack/listener.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
"consoleTitle": "Model Server Console"
|
||||
},
|
||||
{
|
||||
"name": "API Server",
|
||||
"consoleName": "API Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
{
|
||||
"name": "Celery primary",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.primary",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=primary@%n",
|
||||
"-Q",
|
||||
"celery",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery primary Console"
|
||||
"args": ["onyx.main:app", "--reload", "--port", "8080"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
{
|
||||
"name": "Celery light",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.light",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=64",
|
||||
"--prefetch-multiplier=8",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery light Console"
|
||||
"consoleTitle": "API Server Console"
|
||||
},
|
||||
// For the listener to access the Slack API,
|
||||
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
"consoleName": "Slack Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "onyx/onyxbot/slack/listener.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
{
|
||||
"name": "Celery heavy",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
{
|
||||
"name": "Celery indexing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=indexing@%n",
|
||||
"-Q",
|
||||
"connector_indexing",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery indexing Console"
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery primary",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=solo",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.primary",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=primary@%n",
|
||||
"-Q",
|
||||
"celery"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
{
|
||||
"name": "Celery beat",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.beat",
|
||||
"beat",
|
||||
"--loglevel=INFO",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery beat Console"
|
||||
"consoleTitle": "Celery primary Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery light",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"consoleName": "Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-v"
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Pytest Console"
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.light",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=64",
|
||||
"--prefetch-multiplier=8",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Tasks ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "3",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"stopOnEntry": true,
|
||||
"presentation": {
|
||||
"group": "3",
|
||||
},
|
||||
"consoleTitle": "Celery light Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery heavy",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
{
|
||||
// Celery jobs launched through a single background script (legacy)
|
||||
// Recommend using the "Celery (all)" compound launch instead.
|
||||
"name": "Background Jobs",
|
||||
"consoleName": "Background Jobs",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
{
|
||||
"name": "Install Python Requirements",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"-c",
|
||||
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery indexing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=indexing@%n",
|
||||
"-Q",
|
||||
"connector_indexing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery indexing Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=solo",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery beat",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.beat",
|
||||
"beat",
|
||||
"--loglevel=INFO"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery beat Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery user files indexing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_files_indexing@%n",
|
||||
"-Q",
|
||||
"user_files_indexing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery user files indexing Console"
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"consoleName": "Pytest",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-v"
|
||||
// Specify a sepcific module/test to run or provide nothing to run all tests
|
||||
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Pytest Console"
|
||||
},
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Tasks ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "3",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart External Volumes and Containers",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"${workspaceFolder}/backend/scripts/restart_containers.sh"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"stopOnEntry": true,
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
// Celery jobs launched through a single background script (legacy)
|
||||
// Recommend using the "Celery (all)" compound launch instead.
|
||||
"name": "Background Jobs",
|
||||
"consoleName": "Background Jobs",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/dev_run_background_jobs.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Install Python Requirements",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"-c",
|
||||
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Debug React Web App in Chrome",
|
||||
"type": "chrome",
|
||||
"request": "launch",
|
||||
"url": "http://localhost:3000",
|
||||
"webRoot": "${workspaceFolder}/web"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ Edition features outside of personal development or testing purposes. Please rea
|
||||
founders@onyx.app for more information. Please visit https://github.com/onyx-dot-app/onyx"
|
||||
|
||||
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG ONYX_VERSION=0.8-dev
|
||||
ARG ONYX_VERSION=0.0.0-dev
|
||||
# DO_NOT_TRACK is used to disable telemetry for Unstructured
|
||||
ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true" \
|
||||
@@ -102,6 +102,7 @@ COPY ./alembic /app/alembic
|
||||
COPY ./alembic_tenants /app/alembic_tenants
|
||||
COPY ./alembic.ini /app/alembic.ini
|
||||
COPY supervisord.conf /usr/etc/supervisord.conf
|
||||
COPY ./static /app/static
|
||||
|
||||
# Escape hatch scripts
|
||||
COPY ./scripts/debugging /app/scripts/debugging
|
||||
|
||||
@@ -7,7 +7,7 @@ You can find it at https://hub.docker.com/r/onyx/onyx-model-server. For more det
|
||||
visit https://github.com/onyx-dot-app/onyx."
|
||||
|
||||
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG ONYX_VERSION=0.8-dev
|
||||
ARG ONYX_VERSION=0.0.0-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ keys = console
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
level = INFO
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
|
||||
@@ -25,6 +25,9 @@ from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
|
||||
from onyx.db.models import Base
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
|
||||
# Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be
|
||||
# hidden! (defaults to level=WARN)
|
||||
|
||||
# Alembic Config object
|
||||
config = context.config
|
||||
|
||||
@@ -36,6 +39,7 @@ if config.config_file_name is not None and config.attributes.get(
|
||||
target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ssl_context: ssl.SSLContext | None = None
|
||||
@@ -64,7 +68,7 @@ def include_object(
|
||||
return True
|
||||
|
||||
|
||||
def get_schema_options() -> tuple[str, bool, bool]:
|
||||
def get_schema_options() -> tuple[str, bool, bool, bool]:
|
||||
x_args_raw = context.get_x_argument()
|
||||
x_args = {}
|
||||
for arg in x_args_raw:
|
||||
@@ -76,6 +80,10 @@ def get_schema_options() -> tuple[str, bool, bool]:
|
||||
create_schema = x_args.get("create_schema", "true").lower() == "true"
|
||||
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
|
||||
|
||||
# continue on error with individual tenant
|
||||
# only applies to online migrations
|
||||
continue_on_error = x_args.get("continue", "false").lower() == "true"
|
||||
|
||||
if (
|
||||
MULTI_TENANT
|
||||
and schema_name == POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -86,14 +94,12 @@ def get_schema_options() -> tuple[str, bool, bool]:
|
||||
"Please specify a tenant-specific schema."
|
||||
)
|
||||
|
||||
return schema_name, create_schema, upgrade_all_tenants
|
||||
return schema_name, create_schema, upgrade_all_tenants, continue_on_error
|
||||
|
||||
|
||||
def do_run_migrations(
|
||||
connection: Connection, schema_name: str, create_schema: bool
|
||||
) -> None:
|
||||
logger.info(f"About to migrate schema: {schema_name}")
|
||||
|
||||
if create_schema:
|
||||
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
|
||||
connection.execute(text("COMMIT"))
|
||||
@@ -134,7 +140,12 @@ def provide_iam_token_for_alembic(
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
|
||||
(
|
||||
schema_name,
|
||||
create_schema,
|
||||
upgrade_all_tenants,
|
||||
continue_on_error,
|
||||
) = get_schema_options()
|
||||
|
||||
engine = create_async_engine(
|
||||
build_connection_string(),
|
||||
@@ -151,9 +162,15 @@ async def run_async_migrations() -> None:
|
||||
|
||||
if upgrade_all_tenants:
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
|
||||
i_tenant = 0
|
||||
num_tenants = len(tenant_schemas)
|
||||
for schema in tenant_schemas:
|
||||
i_tenant += 1
|
||||
logger.info(
|
||||
f"Migrating schema: index={i_tenant} num_tenants={num_tenants} schema={schema}"
|
||||
)
|
||||
try:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
async with engine.connect() as connection:
|
||||
await connection.run_sync(
|
||||
do_run_migrations,
|
||||
@@ -162,7 +179,12 @@ async def run_async_migrations() -> None:
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema}: {e}")
|
||||
raise
|
||||
if not continue_on_error:
|
||||
logger.error("--continue is not set, raising exception!")
|
||||
raise
|
||||
|
||||
logger.warning("--continue is set, continuing to next schema.")
|
||||
|
||||
else:
|
||||
try:
|
||||
logger.info(f"Migrating schema: {schema_name}")
|
||||
@@ -180,7 +202,11 @@ async def run_async_migrations() -> None:
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
schema_name, _, upgrade_all_tenants = get_schema_options()
|
||||
"""This doesn't really get used when we migrate in the cloud."""
|
||||
|
||||
logger.info("run_migrations_offline starting.")
|
||||
|
||||
schema_name, _, upgrade_all_tenants, continue_on_error = get_schema_options()
|
||||
url = build_connection_string()
|
||||
|
||||
if upgrade_all_tenants:
|
||||
@@ -230,6 +256,7 @@ def run_migrations_offline() -> None:
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
logger.info("run_migrations_online starting.")
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
|
||||
@@ -28,6 +28,20 @@ depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# First, drop any existing indexes to avoid conflicts
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
|
||||
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
|
||||
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
|
||||
|
||||
# Drop existing columns if they exist
|
||||
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
|
||||
op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
|
||||
|
||||
# Create a GIN index for full-text search on chat_message.message
|
||||
op.execute(
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
"""duplicated no-harm user file migration
|
||||
|
||||
Revision ID: 6a804aeb4830
|
||||
Revises: 8e1ac4f39a9f
|
||||
Create Date: 2025-04-01 07:26:10.539362
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import inspect
|
||||
import datetime
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6a804aeb4830"
|
||||
down_revision = "8e1ac4f39a9f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Check if user_file table already exists
|
||||
conn = op.get_bind()
|
||||
inspector = inspect(conn)
|
||||
|
||||
if not inspector.has_table("user_file"):
|
||||
# Create user_folder table without parent_id
|
||||
op.create_table(
|
||||
"user_folder",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
|
||||
sa.Column("name", sa.String(length=255), nullable=True),
|
||||
sa.Column("description", sa.String(length=255), nullable=True),
|
||||
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(timezone=True), server_default=sa.func.now()
|
||||
),
|
||||
)
|
||||
|
||||
# Create user_file table with folder_id instead of parent_folder_id
|
||||
op.create_table(
|
||||
"user_file",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
|
||||
sa.Column(
|
||||
"folder_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_folder.id"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("link_url", sa.String(), nullable=True),
|
||||
sa.Column("token_count", sa.Integer(), nullable=True),
|
||||
sa.Column("file_type", sa.String(), nullable=True),
|
||||
sa.Column("file_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("document_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(),
|
||||
default=datetime.datetime.utcnow,
|
||||
),
|
||||
sa.Column(
|
||||
"cc_pair_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("connector_credential_pair.id"),
|
||||
nullable=True,
|
||||
unique=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create persona__user_file table
|
||||
op.create_table(
|
||||
"persona__user_file",
|
||||
sa.Column(
|
||||
"persona_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("persona.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column(
|
||||
"user_file_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_file.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create persona__user_folder table
|
||||
op.create_table(
|
||||
"persona__user_folder",
|
||||
sa.Column(
|
||||
"persona_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("persona.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column(
|
||||
"user_folder_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_folder.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("is_user_file", sa.Boolean(), nullable=True, default=False),
|
||||
)
|
||||
|
||||
# Update existing records to have is_user_file=False instead of NULL
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET is_user_file = FALSE WHERE is_user_file IS NULL"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -0,0 +1,50 @@
|
||||
"""enable contextual retrieval
|
||||
|
||||
Revision ID: 8e1ac4f39a9f
|
||||
Revises: 9aadf32dfeb4
|
||||
Create Date: 2024-12-20 13:29:09.918661
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8e1ac4f39a9f"
|
||||
down_revision = "9aadf32dfeb4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"enable_contextual_rag",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"contextual_rag_llm_name",
|
||||
sa.String(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"contextual_rag_llm_provider",
|
||||
sa.String(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("search_settings", "enable_contextual_rag")
|
||||
op.drop_column("search_settings", "contextual_rag_llm_name")
|
||||
op.drop_column("search_settings", "contextual_rag_llm_provider")
|
||||
113
backend/alembic/versions/9aadf32dfeb4_add_user_files.py
Normal file
113
backend/alembic/versions/9aadf32dfeb4_add_user_files.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""add user files
|
||||
|
||||
Revision ID: 9aadf32dfeb4
|
||||
Revises: 3781a5eb12cb
|
||||
Create Date: 2025-01-26 16:08:21.551022
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import datetime
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9aadf32dfeb4"
|
||||
down_revision = "3781a5eb12cb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create user_folder table without parent_id
|
||||
op.create_table(
|
||||
"user_folder",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
|
||||
sa.Column("name", sa.String(length=255), nullable=True),
|
||||
sa.Column("description", sa.String(length=255), nullable=True),
|
||||
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(timezone=True), server_default=sa.func.now()
|
||||
),
|
||||
)
|
||||
|
||||
# Create user_file table with folder_id instead of parent_folder_id
|
||||
op.create_table(
|
||||
"user_file",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
|
||||
sa.Column(
|
||||
"folder_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_folder.id"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("link_url", sa.String(), nullable=True),
|
||||
sa.Column("token_count", sa.Integer(), nullable=True),
|
||||
sa.Column("file_type", sa.String(), nullable=True),
|
||||
sa.Column("file_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("document_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(),
|
||||
default=datetime.datetime.utcnow,
|
||||
),
|
||||
sa.Column(
|
||||
"cc_pair_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("connector_credential_pair.id"),
|
||||
nullable=True,
|
||||
unique=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create persona__user_file table
|
||||
op.create_table(
|
||||
"persona__user_file",
|
||||
sa.Column(
|
||||
"persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True
|
||||
),
|
||||
sa.Column(
|
||||
"user_file_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_file.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Create persona__user_folder table
|
||||
op.create_table(
|
||||
"persona__user_folder",
|
||||
sa.Column(
|
||||
"persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True
|
||||
),
|
||||
sa.Column(
|
||||
"user_folder_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_folder.id"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("is_user_file", sa.Boolean(), nullable=True, default=False),
|
||||
)
|
||||
|
||||
# Update existing records to have is_user_file=False instead of NULL
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET is_user_file = FALSE WHERE is_user_file IS NULL"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the persona__user_folder table
|
||||
op.drop_table("persona__user_folder")
|
||||
# Drop the persona__user_file table
|
||||
op.drop_table("persona__user_file")
|
||||
# Drop the user_file table
|
||||
op.drop_table("user_file")
|
||||
# Drop the user_folder table
|
||||
op.drop_table("user_folder")
|
||||
op.drop_column("connector_credential_pair", "is_user_file")
|
||||
@@ -0,0 +1,52 @@
|
||||
"""max_length_for_instruction_system_prompt
|
||||
|
||||
Revision ID: e995bdf0d6f7
|
||||
Revises: 8e1ac4f39a9f
|
||||
Create Date: 2025-04-01 18:32:45.123456
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e995bdf0d6f7"
|
||||
down_revision = "8e1ac4f39a9f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Alter system_prompt and task_prompt columns to have a maximum length of 8000 characters
|
||||
op.alter_column(
|
||||
"prompt",
|
||||
"system_prompt",
|
||||
existing_type=sa.Text(),
|
||||
type_=sa.String(8000),
|
||||
existing_nullable=False,
|
||||
)
|
||||
op.alter_column(
|
||||
"prompt",
|
||||
"task_prompt",
|
||||
existing_type=sa.Text(),
|
||||
type_=sa.String(8000),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert system_prompt and task_prompt columns back to Text type
|
||||
op.alter_column(
|
||||
"prompt",
|
||||
"system_prompt",
|
||||
existing_type=sa.String(8000),
|
||||
type_=sa.Text(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
op.alter_column(
|
||||
"prompt",
|
||||
"task_prompt",
|
||||
existing_type=sa.String(8000),
|
||||
type_=sa.Text(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
@@ -93,12 +93,12 @@ def _get_access_for_documents(
|
||||
)
|
||||
|
||||
# To avoid collisions of group namings between connectors, they need to be prefixed
|
||||
access_map[document_id] = DocumentAccess(
|
||||
user_emails=non_ee_access.user_emails,
|
||||
user_groups=set(user_group_info.get(document_id, [])),
|
||||
access_map[document_id] = DocumentAccess.build(
|
||||
user_emails=list(non_ee_access.user_emails),
|
||||
user_groups=user_group_info.get(document_id, []),
|
||||
is_public=is_public_anywhere,
|
||||
external_user_emails=ext_u_emails,
|
||||
external_user_group_ids=ext_u_groups,
|
||||
external_user_emails=list(ext_u_emails),
|
||||
external_user_group_ids=list(ext_u_groups),
|
||||
)
|
||||
return access_map
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from ee.onyx.server.query_and_chat.models import OneShotQAResponse
|
||||
from onyx.chat.models import AllCitations
|
||||
from onyx.chat.models import LLMRelevanceFilterResponse
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import OnyxContexts
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.process_message import ChatPacketStream
|
||||
@@ -32,8 +31,6 @@ def gather_stream_for_answer_api(
|
||||
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
||||
elif isinstance(packet, AllCitations):
|
||||
response.citations = packet.citations
|
||||
elif isinstance(packet, OnyxContexts):
|
||||
response.contexts = packet
|
||||
|
||||
if answer:
|
||||
response.answer = answer
|
||||
|
||||
@@ -25,6 +25,10 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
|
||||
#####
|
||||
# Auto Permission Sync
|
||||
#####
|
||||
DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
# In seconds, default is 5 minutes
|
||||
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
|
||||
@@ -39,6 +43,7 @@ CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
|
||||
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
|
||||
|
||||
|
||||
@@ -72,6 +77,13 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
|
||||
)
|
||||
|
||||
GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
os.environ.get("GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("SLACK_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
# The posthog client does not accept empty API keys or hosts however it fails silently
|
||||
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Rules defined here:
|
||||
https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.html
|
||||
"""
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
|
||||
@@ -263,13 +264,11 @@ def _fetch_all_page_restrictions(
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess],
|
||||
is_cloud: bool,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
For all pages, if a page has restrictions, then use those restrictions.
|
||||
Otherwise, use the space's restrictions.
|
||||
"""
|
||||
document_restrictions: list[DocExternalAccess] = []
|
||||
|
||||
for slim_doc in slim_docs:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
@@ -286,11 +285,9 @@ def _fetch_all_page_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
perm_sync_data=slim_doc.perm_sync_data,
|
||||
):
|
||||
document_restrictions.append(
|
||||
DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=restrictions,
|
||||
)
|
||||
yield DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=restrictions,
|
||||
)
|
||||
# If there are restrictions, then we don't need to use the space's restrictions
|
||||
continue
|
||||
@@ -324,11 +321,9 @@ def _fetch_all_page_restrictions(
|
||||
continue
|
||||
|
||||
# If there are no restrictions, then use the space's restrictions
|
||||
document_restrictions.append(
|
||||
DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=space_permissions,
|
||||
)
|
||||
yield DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=space_permissions,
|
||||
)
|
||||
if (
|
||||
not space_permissions.is_public
|
||||
@@ -342,13 +337,12 @@ def _fetch_all_page_restrictions(
|
||||
)
|
||||
|
||||
logger.debug("Finished fetching all page restrictions for space")
|
||||
return document_restrictions
|
||||
|
||||
|
||||
def confluence_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
@@ -387,7 +381,7 @@ def confluence_doc_sync(
|
||||
slim_docs.extend(doc_batch)
|
||||
|
||||
logger.debug("Fetching all page restrictions for space")
|
||||
return _fetch_all_page_restrictions(
|
||||
yield from _fetch_all_page_restrictions(
|
||||
confluence_client=confluence_connector.confluence_client,
|
||||
slim_docs=slim_docs,
|
||||
space_permissions_by_space_key=space_permissions_by_space_key,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
@@ -34,7 +35,7 @@ def _get_slim_doc_generator(
|
||||
def gmail_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
@@ -48,7 +49,6 @@ def gmail_doc_sync(
|
||||
cc_pair, gmail_connector, callback=callback
|
||||
)
|
||||
|
||||
document_external_access: list[DocExternalAccess] = []
|
||||
for slim_doc_batch in slim_doc_generator:
|
||||
for slim_doc in slim_doc_batch:
|
||||
if callback:
|
||||
@@ -60,17 +60,14 @@ def gmail_doc_sync(
|
||||
if slim_doc.perm_sync_data is None:
|
||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||
continue
|
||||
|
||||
if user_email := slim_doc.perm_sync_data.get("user_email"):
|
||||
ext_access = ExternalAccess(
|
||||
external_user_emails=set([user_email]),
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
|
||||
document_external_access.append(
|
||||
DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=ext_access,
|
||||
)
|
||||
yield DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=ext_access,
|
||||
)
|
||||
|
||||
return document_external_access
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
@@ -147,7 +148,7 @@ def _get_permissions_from_slim_doc(
|
||||
def gdrive_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
@@ -161,7 +162,6 @@ def gdrive_doc_sync(
|
||||
|
||||
slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)
|
||||
|
||||
document_external_accesses = []
|
||||
for slim_doc_batch in slim_doc_generator:
|
||||
for slim_doc in slim_doc_batch:
|
||||
if callback:
|
||||
@@ -174,10 +174,7 @@ def gdrive_doc_sync(
|
||||
google_drive_connector=google_drive_connector,
|
||||
slim_doc=slim_doc,
|
||||
)
|
||||
document_external_accesses.append(
|
||||
DocExternalAccess(
|
||||
external_access=ext_access,
|
||||
doc_id=slim_doc.id,
|
||||
)
|
||||
yield DocExternalAccess(
|
||||
external_access=ext_access,
|
||||
doc_id=slim_doc.id,
|
||||
)
|
||||
return document_external_accesses
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from slack_sdk import WebClient
|
||||
|
||||
from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
@@ -14,35 +16,6 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_slack_document_ids_and_channels(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
) -> dict[str, list[str]]:
|
||||
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
|
||||
|
||||
channel_doc_map: dict[str, list[str]] = {}
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
for doc_metadata in doc_metadata_batch:
|
||||
if doc_metadata.perm_sync_data is None:
|
||||
continue
|
||||
channel_id = doc_metadata.perm_sync_data["channel_id"]
|
||||
if channel_id not in channel_doc_map:
|
||||
channel_doc_map[channel_id] = []
|
||||
channel_doc_map[channel_id].append(doc_metadata.id)
|
||||
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"_get_slack_document_ids_and_channels: Stop signal detected"
|
||||
)
|
||||
|
||||
callback.progress("_get_slack_document_ids_and_channels", 1)
|
||||
|
||||
return channel_doc_map
|
||||
|
||||
|
||||
def _fetch_workspace_permissions(
|
||||
user_id_to_email_map: dict[str, str],
|
||||
) -> ExternalAccess:
|
||||
@@ -122,10 +95,37 @@ def _fetch_channel_permissions(
|
||||
return channel_permissions
|
||||
|
||||
|
||||
def _get_slack_document_access(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
channel_permissions: dict[str, ExternalAccess],
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
|
||||
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
for doc_metadata in doc_metadata_batch:
|
||||
if doc_metadata.perm_sync_data is None:
|
||||
continue
|
||||
channel_id = doc_metadata.perm_sync_data["channel_id"]
|
||||
yield DocExternalAccess(
|
||||
external_access=channel_permissions[channel_id],
|
||||
doc_id=doc_metadata.id,
|
||||
)
|
||||
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("_get_slack_document_access: Stop signal detected")
|
||||
|
||||
callback.progress("_get_slack_document_access", 1)
|
||||
|
||||
|
||||
def slack_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
@@ -136,9 +136,12 @@ def slack_doc_sync(
|
||||
token=cc_pair.credential.credential_json["slack_bot_token"]
|
||||
)
|
||||
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
|
||||
channel_doc_map = _get_slack_document_ids_and_channels(
|
||||
cc_pair=cc_pair, callback=callback
|
||||
)
|
||||
if not user_id_to_email_map:
|
||||
raise ValueError(
|
||||
"No user id to email map found. Please check to make sure that "
|
||||
"your Slack bot token has the `users:read.email` scope"
|
||||
)
|
||||
|
||||
workspace_permissions = _fetch_workspace_permissions(
|
||||
user_id_to_email_map=user_id_to_email_map,
|
||||
)
|
||||
@@ -148,18 +151,8 @@ def slack_doc_sync(
|
||||
user_id_to_email_map=user_id_to_email_map,
|
||||
)
|
||||
|
||||
document_external_accesses = []
|
||||
for channel_id, ext_access in channel_permissions.items():
|
||||
doc_ids = channel_doc_map.get(channel_id)
|
||||
if not doc_ids:
|
||||
# No documents found for channel the channel_id
|
||||
continue
|
||||
|
||||
for doc_id in doc_ids:
|
||||
document_external_accesses.append(
|
||||
DocExternalAccess(
|
||||
external_access=ext_access,
|
||||
doc_id=doc_id,
|
||||
)
|
||||
)
|
||||
return document_external_accesses
|
||||
yield from _get_slack_document_access(
|
||||
cc_pair=cc_pair,
|
||||
channel_permissions=channel_permissions,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
|
||||
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
|
||||
@@ -23,7 +26,7 @@ DocSyncFuncType = Callable[
|
||||
ConnectorCredentialPair,
|
||||
IndexingHeartbeatInterface | None,
|
||||
],
|
||||
list[DocExternalAccess],
|
||||
Generator[DocExternalAccess, None, None],
|
||||
]
|
||||
|
||||
GroupSyncFuncType = Callable[
|
||||
@@ -65,13 +68,13 @@ GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
|
||||
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
# Polling is not supported so we fetch all doc permissions every 5 minutes
|
||||
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
|
||||
DocumentSource.SLACK: 5 * 60,
|
||||
DocumentSource.SLACK: SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
|
||||
}
|
||||
|
||||
# If nothing is specified here, we run the doc_sync every time the celery beat runs
|
||||
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
# Polling is not supported so we fetch all group permissions every 30 minutes
|
||||
DocumentSource.GOOGLE_DRIVE: 5 * 60,
|
||||
DocumentSource.GOOGLE_DRIVE: GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY,
|
||||
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
|
||||
}
|
||||
|
||||
|
||||
@@ -64,7 +64,15 @@ def get_application() -> FastAPI:
|
||||
add_tenant_id_middleware(application, logger)
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
# For Google OAuth, refresh tokens are requested by:
|
||||
# 1. Adding the right scopes
|
||||
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
|
||||
oauth_client = GoogleOAuth2(
|
||||
OAUTH_CLIENT_ID,
|
||||
OAUTH_CLIENT_SECRET,
|
||||
# Use standard scopes that include profile and email
|
||||
scopes=["openid", "email", "profile"],
|
||||
)
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
@@ -87,6 +95,16 @@ def get_application() -> FastAPI:
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.OIDC:
|
||||
# Ensure we request offline_access for refresh tokens
|
||||
try:
|
||||
oidc_scopes = list(OIDC_SCOPE_OVERRIDE or BASE_SCOPES)
|
||||
if "offline_access" not in oidc_scopes:
|
||||
oidc_scopes.append("offline_access")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error configuring OIDC scopes: {e}")
|
||||
# Fall back to default scopes if there's an error
|
||||
oidc_scopes = BASE_SCOPES
|
||||
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
@@ -94,8 +112,8 @@ def get_application() -> FastAPI:
|
||||
OAUTH_CLIENT_ID,
|
||||
OAUTH_CLIENT_SECRET,
|
||||
OPENID_CONFIG_URL,
|
||||
# BASE_SCOPES is the same as not setting this
|
||||
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
|
||||
# Use the configured scopes
|
||||
base_scopes=oidc_scopes,
|
||||
),
|
||||
auth_backend,
|
||||
USER_AUTH_SECRET,
|
||||
|
||||
@@ -15,8 +15,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
|
||||
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
|
||||
from ee.onyx.server.enterprise_settings.store import _LOGO_FILENAME
|
||||
from ee.onyx.server.enterprise_settings.store import _LOGOTYPE_FILENAME
|
||||
from ee.onyx.server.enterprise_settings.store import get_logo_filename
|
||||
from ee.onyx.server.enterprise_settings.store import get_logotype_filename
|
||||
from ee.onyx.server.enterprise_settings.store import load_analytics_script
|
||||
from ee.onyx.server.enterprise_settings.store import load_settings
|
||||
from ee.onyx.server.enterprise_settings.store import store_analytics_script
|
||||
@@ -28,7 +28,7 @@ from onyx.auth.users import get_user_manager
|
||||
from onyx.auth.users import UserManager
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.file_store import PostgresBackedFileStore
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/enterprise-settings")
|
||||
@@ -131,31 +131,49 @@ def put_logo(
|
||||
upload_logo(file=file, db_session=db_session, is_logotype=is_logotype)
|
||||
|
||||
|
||||
def fetch_logo_or_logotype(is_logotype: bool, db_session: Session) -> Response:
|
||||
def fetch_logo_helper(db_session: Session) -> Response:
|
||||
try:
|
||||
file_store = get_default_file_store(db_session)
|
||||
filename = _LOGOTYPE_FILENAME if is_logotype else _LOGO_FILENAME
|
||||
file_io = file_store.read_file(filename, mode="b")
|
||||
# NOTE: specifying "image/jpeg" here, but it still works for pngs
|
||||
# TODO: do this properly
|
||||
return Response(content=file_io.read(), media_type="image/jpeg")
|
||||
file_store = PostgresBackedFileStore(db_session)
|
||||
onyx_file = file_store.get_file_with_mime_type(get_logo_filename())
|
||||
if not onyx_file:
|
||||
raise ValueError("get_onyx_file returned None!")
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No {'logotype' if is_logotype else 'logo'} file found",
|
||||
detail="No logo file found",
|
||||
)
|
||||
else:
|
||||
return Response(content=onyx_file.data, media_type=onyx_file.mime_type)
|
||||
|
||||
|
||||
def fetch_logotype_helper(db_session: Session) -> Response:
|
||||
try:
|
||||
file_store = PostgresBackedFileStore(db_session)
|
||||
onyx_file = file_store.get_file_with_mime_type(get_logotype_filename())
|
||||
if not onyx_file:
|
||||
raise ValueError("get_onyx_file returned None!")
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No logotype file found",
|
||||
)
|
||||
else:
|
||||
return Response(content=onyx_file.data, media_type=onyx_file.mime_type)
|
||||
|
||||
|
||||
@basic_router.get("/logotype")
|
||||
def fetch_logotype(db_session: Session = Depends(get_session)) -> Response:
|
||||
return fetch_logo_or_logotype(is_logotype=True, db_session=db_session)
|
||||
return fetch_logotype_helper(db_session)
|
||||
|
||||
|
||||
@basic_router.get("/logo")
|
||||
def fetch_logo(
|
||||
is_logotype: bool = False, db_session: Session = Depends(get_session)
|
||||
) -> Response:
|
||||
return fetch_logo_or_logotype(is_logotype=is_logotype, db_session=db_session)
|
||||
if is_logotype:
|
||||
return fetch_logotype_helper(db_session)
|
||||
|
||||
return fetch_logo_helper(db_session)
|
||||
|
||||
|
||||
@admin_router.put("/custom-analytics-script")
|
||||
|
||||
@@ -13,6 +13,7 @@ from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import KV_CUSTOM_ANALYTICS_SCRIPT_KEY
|
||||
from onyx.configs.constants import KV_ENTERPRISE_SETTINGS_KEY
|
||||
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
@@ -21,8 +22,18 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_LOGO_FILENAME = "__logo__"
|
||||
_LOGOTYPE_FILENAME = "__logotype__"
|
||||
|
||||
|
||||
def load_settings() -> EnterpriseSettings:
|
||||
"""Loads settings data directly from DB. This should be used primarily
|
||||
for checking what is actually in the DB, aka for editing and saving back settings.
|
||||
|
||||
Runtime settings actually used by the application should be checked with
|
||||
load_runtime_settings as defaults may be applied at runtime.
|
||||
"""
|
||||
|
||||
dynamic_config_store = get_kv_store()
|
||||
try:
|
||||
settings = EnterpriseSettings(
|
||||
@@ -36,9 +47,24 @@ def load_settings() -> EnterpriseSettings:
|
||||
|
||||
|
||||
def store_settings(settings: EnterpriseSettings) -> None:
|
||||
"""Stores settings directly to the kv store / db."""
|
||||
|
||||
get_kv_store().store(KV_ENTERPRISE_SETTINGS_KEY, settings.model_dump())
|
||||
|
||||
|
||||
def load_runtime_settings() -> EnterpriseSettings:
|
||||
"""Loads settings from DB and applies any defaults or transformations for use
|
||||
at runtime.
|
||||
|
||||
Should not be stored back to the DB.
|
||||
"""
|
||||
enterprise_settings = load_settings()
|
||||
if not enterprise_settings.application_name:
|
||||
enterprise_settings.application_name = ONYX_DEFAULT_APPLICATION_NAME
|
||||
|
||||
return enterprise_settings
|
||||
|
||||
|
||||
_CUSTOM_ANALYTICS_SECRET_KEY = os.environ.get("CUSTOM_ANALYTICS_SECRET_KEY")
|
||||
|
||||
|
||||
@@ -60,10 +86,6 @@ def store_analytics_script(analytics_script_upload: AnalyticsScriptUpload) -> No
|
||||
get_kv_store().store(KV_CUSTOM_ANALYTICS_SCRIPT_KEY, analytics_script_upload.script)
|
||||
|
||||
|
||||
_LOGO_FILENAME = "__logo__"
|
||||
_LOGOTYPE_FILENAME = "__logotype__"
|
||||
|
||||
|
||||
def is_valid_file_type(filename: str) -> bool:
|
||||
valid_extensions = (".png", ".jpg", ".jpeg")
|
||||
return filename.endswith(valid_extensions)
|
||||
@@ -116,3 +138,11 @@ def upload_logo(
|
||||
file_type=file_type,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def get_logo_filename() -> str:
|
||||
return _LOGO_FILENAME
|
||||
|
||||
|
||||
def get_logotype_filename() -> str:
|
||||
return _LOGOTYPE_FILENAME
|
||||
|
||||
@@ -44,7 +44,7 @@ async def _get_tenant_id_from_request(
|
||||
Attempt to extract tenant_id from:
|
||||
1) The API key header
|
||||
2) The Redis-based token (stored in Cookie: fastapiusersauth)
|
||||
3) Reset token cookie
|
||||
3) The anonymous user cookie
|
||||
Fallback: POSTGRES_DEFAULT_SCHEMA
|
||||
"""
|
||||
# Check for API key
|
||||
@@ -52,41 +52,55 @@ async def _get_tenant_id_from_request(
|
||||
if tenant_id is not None:
|
||||
return tenant_id
|
||||
|
||||
# Check for anonymous user cookie
|
||||
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
|
||||
if anonymous_user_cookie:
|
||||
try:
|
||||
anonymous_user_data = decode_anonymous_user_jwt_token(anonymous_user_cookie)
|
||||
return anonymous_user_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
except Exception as e:
|
||||
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
|
||||
# Continue and attempt to authenticate
|
||||
|
||||
try:
|
||||
# Look up token data in Redis
|
||||
|
||||
token_data = await retrieve_auth_token_data_from_redis(request)
|
||||
|
||||
if not token_data:
|
||||
logger.debug(
|
||||
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
|
||||
if token_data:
|
||||
tenant_id_from_payload = token_data.get(
|
||||
"tenant_id", POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
|
||||
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
|
||||
# so we maintain consistency by returning it here when no valid tenant is found.
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
tenant_id = (
|
||||
str(tenant_id_from_payload)
|
||||
if tenant_id_from_payload is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# Since token_data.get() can return None, ensure we have a string
|
||||
tenant_id = (
|
||||
str(tenant_id_from_payload)
|
||||
if tenant_id_from_payload is not None
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
if tenant_id and not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||
|
||||
# Check for anonymous user cookie
|
||||
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
|
||||
if anonymous_user_cookie:
|
||||
try:
|
||||
anonymous_user_data = decode_anonymous_user_jwt_token(
|
||||
anonymous_user_cookie
|
||||
)
|
||||
tenant_id = anonymous_user_data.get(
|
||||
"tenant_id", POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
|
||||
if not tenant_id or not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid tenant ID format"
|
||||
)
|
||||
|
||||
return tenant_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
|
||||
# Continue and attempt to authenticate
|
||||
|
||||
logger.debug(
|
||||
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
|
||||
)
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
|
||||
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
|
||||
# so we maintain consistency by returning it here when no valid tenant is found.
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
|
||||
|
||||
@@ -14,7 +14,6 @@ from ee.onyx.server.query_and_chat.models import (
|
||||
BasicCreateChatMessageWithHistoryRequest,
|
||||
)
|
||||
from ee.onyx.server.query_and_chat.models import ChatBasicResponse
|
||||
from ee.onyx.server.query_and_chat.models import SimpleDoc
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.chat_utils import combine_message_thread
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
@@ -56,25 +55,6 @@ logger = setup_logger()
|
||||
router = APIRouter(prefix="/chat")
|
||||
|
||||
|
||||
def _translate_doc_response_to_simple_doc(
|
||||
doc_response: QADocsResponse,
|
||||
) -> list[SimpleDoc]:
|
||||
return [
|
||||
SimpleDoc(
|
||||
id=doc.document_id,
|
||||
semantic_identifier=doc.semantic_identifier,
|
||||
link=doc.link,
|
||||
blurb=doc.blurb,
|
||||
match_highlights=[
|
||||
highlight for highlight in doc.match_highlights if highlight
|
||||
],
|
||||
source_type=doc.source_type,
|
||||
metadata=doc.metadata,
|
||||
)
|
||||
for doc in doc_response.top_documents
|
||||
]
|
||||
|
||||
|
||||
def _get_final_context_doc_indices(
|
||||
final_context_docs: list[LlmDoc] | None,
|
||||
top_docs: list[SavedSearchDoc] | None,
|
||||
@@ -111,9 +91,6 @@ def _convert_packet_stream_to_response(
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
response.top_documents = packet.top_documents
|
||||
|
||||
# TODO: deprecate `simple_search_docs`
|
||||
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
|
||||
|
||||
# This is a no-op if agent_sub_questions hasn't already been filled
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
|
||||
@@ -8,7 +8,6 @@ from pydantic import model_validator
|
||||
|
||||
from ee.onyx.server.manage.models import StandardAnswer
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import OnyxContexts
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import SubQuestionIdentifier
|
||||
@@ -164,8 +163,6 @@ class ChatBasicResponse(BaseModel):
|
||||
cited_documents: dict[int, str] | None = None
|
||||
|
||||
# FOR BACKWARDS COMPATIBILITY
|
||||
# TODO: deprecate both of these
|
||||
simple_search_docs: list[SimpleDoc] | None = None
|
||||
llm_chunks_indices: list[int] | None = None
|
||||
|
||||
# agentic fields
|
||||
@@ -220,4 +217,3 @@ class OneShotQAResponse(BaseModel):
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
chat_message_id: int | None = None
|
||||
contexts: OnyxContexts | None = None
|
||||
|
||||
@@ -36,8 +36,12 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
router = APIRouter(prefix="/auth/saml")
|
||||
|
||||
# Define non-authenticated user roles that should be re-created during SAML login
|
||||
NON_AUTHENTICATED_ROLES = {UserRole.SLACK_USER, UserRole.EXT_PERM_USER}
|
||||
|
||||
|
||||
async def upsert_saml_user(email: str) -> User:
|
||||
logger.debug(f"Attempting to upsert SAML user with email: {email}")
|
||||
get_async_session_context = contextlib.asynccontextmanager(
|
||||
get_async_session
|
||||
) # type:ignore
|
||||
@@ -48,9 +52,13 @@ async def upsert_saml_user(email: str) -> User:
|
||||
async with get_user_db_context(session) as user_db:
|
||||
async with get_user_manager_context(user_db) as user_manager:
|
||||
try:
|
||||
return await user_manager.get_by_email(email)
|
||||
user = await user_manager.get_by_email(email)
|
||||
# If user has a non-authenticated role, treat as non-existent
|
||||
if user.role in NON_AUTHENTICATED_ROLES:
|
||||
raise exceptions.UserNotExists()
|
||||
return user
|
||||
except exceptions.UserNotExists:
|
||||
logger.notice("Creating user from SAML login")
|
||||
logger.info("Creating user from SAML login")
|
||||
|
||||
user_count = await get_user_count()
|
||||
role = UserRole.ADMIN if user_count == 0 else UserRole.BASIC
|
||||
@@ -59,11 +67,10 @@ async def upsert_saml_user(email: str) -> User:
|
||||
password = fastapi_users_pw_helper.generate()
|
||||
hashed_pass = fastapi_users_pw_helper.hash(password)
|
||||
|
||||
user: User = await user_manager.create(
|
||||
user = await user_manager.create(
|
||||
UserCreate(
|
||||
email=email,
|
||||
password=hashed_pass,
|
||||
is_verified=True,
|
||||
role=role,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -87,11 +87,15 @@ async def get_or_provision_tenant(
|
||||
# If we have a pre-provisioned tenant, assign it to the user
|
||||
await assign_tenant_to_user(tenant_id, email, referral_source)
|
||||
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
|
||||
return tenant_id
|
||||
else:
|
||||
# If no pre-provisioned tenant is available, create a new one on-demand
|
||||
tenant_id = await create_tenant(email, referral_source)
|
||||
return tenant_id
|
||||
|
||||
# Notify control plane if we have created / assigned a new tenant
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
return tenant_id
|
||||
|
||||
except Exception as e:
|
||||
# If we've encountered an error, log and raise an exception
|
||||
@@ -116,10 +120,6 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str:
|
||||
# Provision tenant on data plane
|
||||
await provision_tenant(tenant_id, email)
|
||||
|
||||
# Notify control plane if not already done in provision_tenant
|
||||
if not DEV_MODE and referral_source:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Tenant provisioning failed: {str(e)}")
|
||||
# Attempt to rollback the tenant provisioning
|
||||
@@ -271,6 +271,7 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
fast_default_model_name="claude-3-5-sonnet-20241022",
|
||||
model_names=ANTHROPIC_MODEL_NAMES,
|
||||
display_model_names=["claude-3-5-sonnet-20241022"],
|
||||
api_key_changed=True,
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(anthropic_provider, db_session)
|
||||
@@ -283,7 +284,7 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
)
|
||||
|
||||
if OPENAI_DEFAULT_API_KEY:
|
||||
open_provider = LLMProviderUpsertRequest(
|
||||
openai_provider = LLMProviderUpsertRequest(
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
@@ -291,9 +292,10 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
fast_default_model_name="gpt-4o-mini",
|
||||
model_names=OPEN_AI_MODEL_NAMES,
|
||||
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
|
||||
api_key_changed=True,
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(open_provider, db_session)
|
||||
full_provider = upsert_llm_provider(openai_provider, db_session)
|
||||
update_default_provider(full_provider.id, db_session)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure OpenAI provider: {e}")
|
||||
@@ -504,8 +506,11 @@ async def setup_tenant(tenant_id: str) -> None:
|
||||
try:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Run Alembic migrations
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
# Run Alembic migrations in a way that isolates it from the current event loop
|
||||
# Create a new event loop for this synchronous operation
|
||||
loop = asyncio.get_event_loop()
|
||||
# Use run_in_executor which properly isolates the thread execution
|
||||
await loop.run_in_executor(None, lambda: run_alembic_migrations(tenant_id))
|
||||
|
||||
# Configure the tenant with default settings
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
@@ -559,7 +564,3 @@ async def assign_tenant_to_user(
|
||||
except Exception:
|
||||
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
|
||||
raise Exception("Failed to assign tenant to user")
|
||||
|
||||
# Notify control plane with retry logic
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
@@ -70,6 +70,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
"""
|
||||
Add users to a tenant with proper transaction handling.
|
||||
Checks if users already have a tenant mapping to avoid duplicates.
|
||||
If a user already has an active mapping to any tenant, the new mapping will be added as inactive.
|
||||
"""
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
@@ -88,9 +89,25 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
.first()
|
||||
)
|
||||
|
||||
# If user already has an active mapping, add this one as inactive
|
||||
if not existing_mapping:
|
||||
# Only add if mapping doesn't exist
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
# Check if the user already has an active mapping to any tenant
|
||||
has_active_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
db_session.add(
|
||||
UserTenantMapping(
|
||||
email=email,
|
||||
tenant_id=tenant_id,
|
||||
active=False if has_active_mapping else True,
|
||||
)
|
||||
)
|
||||
|
||||
# Commit the transaction
|
||||
db_session.commit()
|
||||
|
||||
BIN
backend/hello-vmlinux.bin
Normal file
BIN
backend/hello-vmlinux.bin
Normal file
Binary file not shown.
@@ -65,11 +65,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
|
||||
app.state.gpu_type = gpu_type
|
||||
|
||||
if TEMP_HF_CACHE_PATH.is_dir():
|
||||
logger.notice("Moving contents of temp_huggingface to huggingface cache.")
|
||||
_move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
|
||||
shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
|
||||
logger.notice("Moved contents of temp_huggingface to huggingface cache.")
|
||||
try:
|
||||
if TEMP_HF_CACHE_PATH.is_dir():
|
||||
logger.notice("Moving contents of temp_huggingface to huggingface cache.")
|
||||
_move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
|
||||
shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
|
||||
logger.notice("Moved contents of temp_huggingface to huggingface cache.")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error moving contents of temp_huggingface to huggingface cache: {e}. "
|
||||
"This is not a critical error and the model server will continue to run."
|
||||
)
|
||||
|
||||
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
|
||||
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
|
||||
|
||||
@@ -18,7 +18,7 @@ def _get_access_for_document(
|
||||
document_id=document_id,
|
||||
)
|
||||
|
||||
return DocumentAccess.build(
|
||||
doc_access = DocumentAccess.build(
|
||||
user_emails=info[1] if info and info[1] else [],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
@@ -26,6 +26,8 @@ def _get_access_for_document(
|
||||
is_public=info[2] if info else False,
|
||||
)
|
||||
|
||||
return doc_access
|
||||
|
||||
|
||||
def get_access_for_document(
|
||||
document_id: str,
|
||||
@@ -38,12 +40,12 @@ def get_access_for_document(
|
||||
|
||||
|
||||
def get_null_document_access() -> DocumentAccess:
|
||||
return DocumentAccess(
|
||||
user_emails=set(),
|
||||
user_groups=set(),
|
||||
return DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
is_public=False,
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
|
||||
|
||||
@@ -55,19 +57,18 @@ def _get_access_for_documents(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
doc_access = {
|
||||
document_id: DocumentAccess(
|
||||
user_emails=set([email for email in user_emails if email]),
|
||||
doc_access = {}
|
||||
for document_id, user_emails, is_public in document_access_info:
|
||||
doc_access[document_id] = DocumentAccess.build(
|
||||
user_emails=[email for email in user_emails if email],
|
||||
# MIT version will wipe all groups and external groups on update
|
||||
user_groups=set(),
|
||||
user_groups=[],
|
||||
is_public=is_public,
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
)
|
||||
for document_id, user_emails, is_public in document_access_info
|
||||
}
|
||||
|
||||
# Sometimes the document has not be indexed by the indexing job yet, in those cases
|
||||
# Sometimes the document has not been indexed by the indexing job yet, in those cases
|
||||
# the document does not exist and so we use least permissive. Specifically the EE version
|
||||
# checks the MIT version permissions and creates a superset. This ensures that this flow
|
||||
# does not fail even if the Document has not yet been indexed.
|
||||
|
||||
@@ -20,7 +20,7 @@ class ExternalAccess:
|
||||
class DocExternalAccess:
|
||||
"""
|
||||
This is just a class to wrap the external access and the document ID
|
||||
together. It's used for syncing document permissions to Redis.
|
||||
together. It's used for syncing document permissions to Vespa.
|
||||
"""
|
||||
|
||||
external_access: ExternalAccess
|
||||
@@ -56,34 +56,46 @@ class DocExternalAccess:
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, init=False)
|
||||
class DocumentAccess(ExternalAccess):
|
||||
# User emails for Onyx users, None indicates admin
|
||||
user_emails: set[str | None]
|
||||
|
||||
# Names of user groups associated with this document
|
||||
user_groups: set[str]
|
||||
|
||||
def to_acl(self) -> set[str]:
|
||||
return set(
|
||||
[
|
||||
prefix_user_email(user_email)
|
||||
for user_email in self.user_emails
|
||||
if user_email
|
||||
]
|
||||
+ [prefix_user_group(group_name) for group_name in self.user_groups]
|
||||
+ [
|
||||
prefix_user_email(user_email)
|
||||
for user_email in self.external_user_emails
|
||||
]
|
||||
+ [
|
||||
# The group names are already prefixed by the source type
|
||||
# This adds an additional prefix of "external_group:"
|
||||
prefix_external_group(group_name)
|
||||
for group_name in self.external_user_group_ids
|
||||
]
|
||||
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
|
||||
external_user_emails: set[str]
|
||||
external_user_group_ids: set[str]
|
||||
is_public: bool
|
||||
|
||||
def __init__(self) -> None:
|
||||
raise TypeError(
|
||||
"Use `DocumentAccess.build(...)` instead of creating an instance directly."
|
||||
)
|
||||
|
||||
def to_acl(self) -> set[str]:
|
||||
# the acl's emitted by this function are prefixed by type
|
||||
# to get the native objects, access the member variables directly
|
||||
|
||||
acl_set: set[str] = set()
|
||||
for user_email in self.user_emails:
|
||||
if user_email:
|
||||
acl_set.add(prefix_user_email(user_email))
|
||||
|
||||
for group_name in self.user_groups:
|
||||
acl_set.add(prefix_user_group(group_name))
|
||||
|
||||
for external_user_email in self.external_user_emails:
|
||||
acl_set.add(prefix_user_email(external_user_email))
|
||||
|
||||
for external_group_id in self.external_user_group_ids:
|
||||
acl_set.add(prefix_external_group(external_group_id))
|
||||
|
||||
if self.is_public:
|
||||
acl_set.add(PUBLIC_DOC_PAT)
|
||||
|
||||
return acl_set
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
@@ -93,29 +105,32 @@ class DocumentAccess(ExternalAccess):
|
||||
external_user_group_ids: list[str],
|
||||
is_public: bool,
|
||||
) -> "DocumentAccess":
|
||||
return cls(
|
||||
external_user_emails={
|
||||
prefix_user_email(external_email)
|
||||
for external_email in external_user_emails
|
||||
},
|
||||
external_user_group_ids={
|
||||
prefix_external_group(external_group_id)
|
||||
for external_group_id in external_user_group_ids
|
||||
},
|
||||
user_emails={
|
||||
prefix_user_email(user_email)
|
||||
for user_email in user_emails
|
||||
if user_email
|
||||
},
|
||||
user_groups=set(user_groups),
|
||||
is_public=is_public,
|
||||
"""Don't prefix incoming data wth acl type, prefix on read from to_acl!"""
|
||||
|
||||
obj = object.__new__(cls)
|
||||
object.__setattr__(
|
||||
obj, "user_emails", {user_email for user_email in user_emails if user_email}
|
||||
)
|
||||
object.__setattr__(obj, "user_groups", set(user_groups))
|
||||
object.__setattr__(
|
||||
obj,
|
||||
"external_user_emails",
|
||||
{external_email for external_email in external_user_emails},
|
||||
)
|
||||
object.__setattr__(
|
||||
obj,
|
||||
"external_user_group_ids",
|
||||
{external_group_id for external_group_id in external_user_group_ids},
|
||||
)
|
||||
object.__setattr__(obj, "is_public", is_public)
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
default_public_access = DocumentAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
user_emails=set(),
|
||||
user_groups=set(),
|
||||
default_public_access = DocumentAccess.build(
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
is_public=True,
|
||||
)
|
||||
|
||||
@@ -7,7 +7,6 @@ from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import OnyxContext
|
||||
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
@@ -24,7 +23,7 @@ def process_llm_stream(
|
||||
should_stream_answer: bool,
|
||||
writer: StreamWriter,
|
||||
final_search_results: list[LlmDoc] | None = None,
|
||||
displayed_search_results: list[OnyxContext] | list[LlmDoc] | None = None,
|
||||
displayed_search_results: list[LlmDoc] | None = None,
|
||||
) -> AIMessageChunk:
|
||||
tool_call_chunk = AIMessageChunk(content="")
|
||||
|
||||
|
||||
@@ -156,7 +156,6 @@ def generate_initial_answer(
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
|
||||
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
|
||||
get_final_context_sections=lambda: answer_generation_documents.context_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
|
||||
@@ -183,7 +183,6 @@ def generate_validate_refined_answer(
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
|
||||
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
|
||||
get_final_context_sections=lambda: answer_generation_documents.context_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
|
||||
@@ -57,7 +57,6 @@ def format_results(
|
||||
for tool_response in yield_search_responses(
|
||||
query=state.question,
|
||||
get_retrieved_sections=lambda: reranked_documents,
|
||||
get_reranked_sections=lambda: state.retrieved_documents,
|
||||
get_final_context_sections=lambda: reranked_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
|
||||
@@ -13,9 +13,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_utils import (
|
||||
context_from_inference_section,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_utils import section_to_llm_doc
|
||||
from onyx.tools.tool_implementations.search_like_tool_utils import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
@@ -59,9 +57,7 @@ def basic_use_tool_response(
|
||||
search_response_summary = cast(SearchResponseSummary, yield_item.response)
|
||||
for section in search_response_summary.top_sections:
|
||||
if section.center_chunk.document_id not in initial_search_results:
|
||||
initial_search_results.append(
|
||||
context_from_inference_section(section)
|
||||
)
|
||||
initial_search_results.append(section_to_llm_doc(section))
|
||||
|
||||
new_tool_call_chunk = AIMessageChunk(content="")
|
||||
if not agent_config.behavior.skip_gen_ai_answer_generation:
|
||||
|
||||
@@ -321,8 +321,10 @@ def dispatch_separated(
|
||||
sep: str = DISPATCH_SEP_CHAR,
|
||||
) -> list[BaseMessage_Content]:
|
||||
num = 1
|
||||
accumulated_tokens = ""
|
||||
streamed_tokens: list[BaseMessage_Content] = []
|
||||
for token in tokens:
|
||||
accumulated_tokens += cast(str, token.content)
|
||||
content = cast(str, token.content)
|
||||
if sep in content:
|
||||
sub_question_parts = content.split(sep)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import smtplib
|
||||
from datetime import datetime
|
||||
from email.mime.image import MIMEImage
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from email.utils import formatdate
|
||||
@@ -13,8 +14,13 @@ from onyx.configs.app_configs import SMTP_SERVER
|
||||
from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
|
||||
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
|
||||
from onyx.configs.constants import ONYX_SLACK_URL
|
||||
from onyx.db.models import User
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.file import FileWithMimeType
|
||||
from onyx.utils.url import add_url_params
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
HTML_EMAIL_TEMPLATE = """\
|
||||
@@ -56,6 +62,11 @@ HTML_EMAIL_TEMPLATE = """\
|
||||
}}
|
||||
.header img {{
|
||||
max-width: 140px;
|
||||
width: 140px;
|
||||
height: auto;
|
||||
filter: brightness(1.1) contrast(1.2);
|
||||
border-radius: 8px;
|
||||
padding: 5px;
|
||||
}}
|
||||
.body-content {{
|
||||
padding: 20px 30px;
|
||||
@@ -72,12 +83,16 @@ HTML_EMAIL_TEMPLATE = """\
|
||||
}}
|
||||
.cta-button {{
|
||||
display: inline-block;
|
||||
padding: 12px 20px;
|
||||
background-color: #000000;
|
||||
padding: 14px 24px;
|
||||
background-color: #0055FF;
|
||||
color: #ffffff !important;
|
||||
text-decoration: none;
|
||||
border-radius: 4px;
|
||||
font-weight: 500;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
margin-top: 10px;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
text-align: center;
|
||||
}}
|
||||
.footer {{
|
||||
font-size: 13px;
|
||||
@@ -97,8 +112,8 @@ HTML_EMAIL_TEMPLATE = """\
|
||||
<td class="header">
|
||||
<img
|
||||
style="background-color: #ffffff; border-radius: 8px;"
|
||||
src="https://www.onyx.app/logos/customer/onyx.png"
|
||||
alt="Onyx Logo"
|
||||
src="cid:logo.png"
|
||||
alt="{application_name} Logo"
|
||||
>
|
||||
</td>
|
||||
</tr>
|
||||
@@ -113,9 +128,8 @@ HTML_EMAIL_TEMPLATE = """\
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="footer">
|
||||
© {year} Onyx. All rights reserved.
|
||||
<br>
|
||||
Have questions? Join our Slack community <a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA">here</a>.
|
||||
© {year} {application_name}. All rights reserved.
|
||||
{slack_fragment}
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
@@ -125,17 +139,27 @@ HTML_EMAIL_TEMPLATE = """\
|
||||
|
||||
|
||||
def build_html_email(
|
||||
heading: str, message: str, cta_text: str | None = None, cta_link: str | None = None
|
||||
application_name: str | None,
|
||||
heading: str,
|
||||
message: str,
|
||||
cta_text: str | None = None,
|
||||
cta_link: str | None = None,
|
||||
) -> str:
|
||||
slack_fragment = ""
|
||||
if application_name == ONYX_DEFAULT_APPLICATION_NAME:
|
||||
slack_fragment = f'<br>Have questions? Join our Slack community <a href="{ONYX_SLACK_URL}">here</a>.'
|
||||
|
||||
if cta_text and cta_link:
|
||||
cta_block = f'<a class="cta-button" href="{cta_link}">{cta_text}</a>'
|
||||
else:
|
||||
cta_block = ""
|
||||
return HTML_EMAIL_TEMPLATE.format(
|
||||
application_name=application_name,
|
||||
title=heading,
|
||||
heading=heading,
|
||||
message=message,
|
||||
cta_block=cta_block,
|
||||
slack_fragment=slack_fragment,
|
||||
year=datetime.now().year,
|
||||
)
|
||||
|
||||
@@ -146,10 +170,12 @@ def send_email(
|
||||
html_body: str,
|
||||
text_body: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
inline_png: tuple[str, bytes] | None = None,
|
||||
) -> None:
|
||||
if not EMAIL_CONFIGURED:
|
||||
raise ValueError("Email is not configured.")
|
||||
|
||||
# Create a multipart/alternative message - this indicates these are alternative versions of the same content
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = subject
|
||||
msg["To"] = user_email
|
||||
@@ -158,11 +184,30 @@ def send_email(
|
||||
msg["Date"] = formatdate(localtime=True)
|
||||
msg["Message-ID"] = make_msgid(domain="onyx.app")
|
||||
|
||||
part_text = MIMEText(text_body, "plain")
|
||||
part_html = MIMEText(html_body, "html")
|
||||
# Add text part first (lowest priority)
|
||||
text_part = MIMEText(text_body, "plain")
|
||||
msg.attach(text_part)
|
||||
|
||||
msg.attach(part_text)
|
||||
msg.attach(part_html)
|
||||
if inline_png:
|
||||
# For HTML with images, create a multipart/related container
|
||||
related = MIMEMultipart("related")
|
||||
|
||||
# Add the HTML part to the related container
|
||||
html_part = MIMEText(html_body, "html")
|
||||
related.attach(html_part)
|
||||
|
||||
# Add image with proper Content-ID to the related container
|
||||
img = MIMEImage(inline_png[1], _subtype="png")
|
||||
img.add_header("Content-ID", f"<{inline_png[0]}>")
|
||||
img.add_header("Content-Disposition", "inline", filename=inline_png[0])
|
||||
related.attach(img)
|
||||
|
||||
# Add the related part to the message (higher priority than text)
|
||||
msg.attach(related)
|
||||
else:
|
||||
# No images, just add HTML directly (higher priority than text)
|
||||
html_part = MIMEText(html_body, "html")
|
||||
msg.attach(html_part)
|
||||
|
||||
try:
|
||||
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
|
||||
@@ -174,8 +219,21 @@ def send_email(
|
||||
|
||||
|
||||
def send_subscription_cancellation_email(user_email: str) -> None:
|
||||
"""This is templated but isn't meaningful for whitelabeling."""
|
||||
|
||||
# Example usage of the reusable HTML
|
||||
subject = "Your Onyx Subscription Has Been Canceled"
|
||||
try:
|
||||
load_runtime_settings_fn = fetch_versioned_implementation(
|
||||
"onyx.server.enterprise_settings.store", "load_runtime_settings"
|
||||
)
|
||||
settings = load_runtime_settings_fn()
|
||||
application_name = settings.application_name
|
||||
except ModuleNotFoundError:
|
||||
application_name = ONYX_DEFAULT_APPLICATION_NAME
|
||||
|
||||
onyx_file = OnyxRuntime.get_emailable_logo()
|
||||
|
||||
subject = f"Your {application_name} Subscription Has Been Canceled"
|
||||
heading = "Subscription Canceled"
|
||||
message = (
|
||||
"<p>We're sorry to see you go.</p>"
|
||||
@@ -184,23 +242,48 @@ def send_subscription_cancellation_email(user_email: str) -> None:
|
||||
)
|
||||
cta_text = "Renew Subscription"
|
||||
cta_link = "https://www.onyx.app/pricing"
|
||||
html_content = build_html_email(heading, message, cta_text, cta_link)
|
||||
html_content = build_html_email(
|
||||
application_name,
|
||||
heading,
|
||||
message,
|
||||
cta_text,
|
||||
cta_link,
|
||||
)
|
||||
text_content = (
|
||||
"We're sorry to see you go.\n"
|
||||
"Your subscription has been canceled and will end on your next billing date.\n"
|
||||
"If you change your mind, visit https://www.onyx.app/pricing"
|
||||
)
|
||||
send_email(user_email, subject, html_content, text_content)
|
||||
send_email(
|
||||
user_email,
|
||||
subject,
|
||||
html_content,
|
||||
text_content,
|
||||
inline_png=("logo.png", onyx_file.data),
|
||||
)
|
||||
|
||||
|
||||
def send_user_email_invite(
|
||||
user_email: str, current_user: User, auth_type: AuthType
|
||||
) -> None:
|
||||
subject = "Invitation to Join Onyx Organization"
|
||||
onyx_file: FileWithMimeType | None = None
|
||||
|
||||
try:
|
||||
load_runtime_settings_fn = fetch_versioned_implementation(
|
||||
"onyx.server.enterprise_settings.store", "load_runtime_settings"
|
||||
)
|
||||
settings = load_runtime_settings_fn()
|
||||
application_name = settings.application_name
|
||||
except ModuleNotFoundError:
|
||||
application_name = ONYX_DEFAULT_APPLICATION_NAME
|
||||
|
||||
onyx_file = OnyxRuntime.get_emailable_logo()
|
||||
|
||||
subject = f"Invitation to Join {application_name} Organization"
|
||||
heading = "You've Been Invited!"
|
||||
|
||||
# the exact action taken by the user, and thus the message, depends on the auth type
|
||||
message = f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
|
||||
message = f"<p>You have been invited by {current_user.email} to join an organization on {application_name}.</p>"
|
||||
if auth_type == AuthType.CLOUD:
|
||||
message += (
|
||||
"<p>To join the organization, please click the button below to set a password "
|
||||
@@ -226,19 +309,32 @@ def send_user_email_invite(
|
||||
|
||||
cta_text = "Join Organization"
|
||||
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
|
||||
html_content = build_html_email(heading, message, cta_text, cta_link)
|
||||
|
||||
html_content = build_html_email(
|
||||
application_name,
|
||||
heading,
|
||||
message,
|
||||
cta_text,
|
||||
cta_link,
|
||||
)
|
||||
|
||||
# text content is the fallback for clients that don't support HTML
|
||||
# not as critical, so not having special cases for each auth type
|
||||
text_content = (
|
||||
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
|
||||
f"You have been invited by {current_user.email} to join an organization on {application_name}.\n"
|
||||
"To join the organization, please visit the following link:\n"
|
||||
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
|
||||
)
|
||||
if auth_type == AuthType.CLOUD:
|
||||
text_content += "You'll be asked to set a password or login with Google to complete your registration."
|
||||
|
||||
send_email(user_email, subject, html_content, text_content)
|
||||
send_email(
|
||||
user_email,
|
||||
subject,
|
||||
html_content,
|
||||
text_content,
|
||||
inline_png=("logo.png", onyx_file.data),
|
||||
)
|
||||
|
||||
|
||||
def send_forgot_password_email(
|
||||
@@ -248,27 +344,80 @@ def send_forgot_password_email(
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
# Builds a forgot password email with or without fancy HTML
|
||||
subject = "Onyx Forgot Password"
|
||||
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
|
||||
if MULTI_TENANT:
|
||||
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
|
||||
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
|
||||
html_content = build_html_email("Reset Your Password", message)
|
||||
text_content = f"Click the following link to reset your password: {link}"
|
||||
send_email(user_email, subject, html_content, text_content, mail_from)
|
||||
try:
|
||||
load_runtime_settings_fn = fetch_versioned_implementation(
|
||||
"onyx.server.enterprise_settings.store", "load_runtime_settings"
|
||||
)
|
||||
settings = load_runtime_settings_fn()
|
||||
application_name = settings.application_name
|
||||
except ModuleNotFoundError:
|
||||
application_name = ONYX_DEFAULT_APPLICATION_NAME
|
||||
|
||||
onyx_file = OnyxRuntime.get_emailable_logo()
|
||||
|
||||
subject = f"Reset Your {application_name} Password"
|
||||
heading = "Reset Your Password"
|
||||
tenant_param = f"&tenant={tenant_id}" if tenant_id and MULTI_TENANT else ""
|
||||
message = "<p>Please click the button below to reset your password. This link will expire in 24 hours.</p>"
|
||||
cta_text = "Reset Password"
|
||||
cta_link = f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
|
||||
html_content = build_html_email(
|
||||
application_name,
|
||||
heading,
|
||||
message,
|
||||
cta_text,
|
||||
cta_link,
|
||||
)
|
||||
text_content = (
|
||||
f"Please click the following link to reset your password. This link will expire in 24 hours.\n"
|
||||
f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
|
||||
)
|
||||
send_email(
|
||||
user_email,
|
||||
subject,
|
||||
html_content,
|
||||
text_content,
|
||||
mail_from,
|
||||
inline_png=("logo.png", onyx_file.data),
|
||||
)
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
new_organization: bool = False,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
# Builds a verification email
|
||||
subject = "Onyx Email Verification"
|
||||
try:
|
||||
load_runtime_settings_fn = fetch_versioned_implementation(
|
||||
"onyx.server.enterprise_settings.store", "load_runtime_settings"
|
||||
)
|
||||
settings = load_runtime_settings_fn()
|
||||
application_name = settings.application_name
|
||||
except ModuleNotFoundError:
|
||||
application_name = ONYX_DEFAULT_APPLICATION_NAME
|
||||
|
||||
onyx_file = OnyxRuntime.get_emailable_logo()
|
||||
|
||||
subject = f"{application_name} Email Verification"
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
if new_organization:
|
||||
link = add_url_params(link, {"first_user": "true"})
|
||||
message = (
|
||||
f"<p>Click the following link to verify your email address:</p><p>{link}</p>"
|
||||
)
|
||||
html_content = build_html_email("Verify Your Email", message)
|
||||
html_content = build_html_email(
|
||||
application_name,
|
||||
"Verify Your Email",
|
||||
message,
|
||||
)
|
||||
text_content = f"Click the following link to verify your email address: {link}"
|
||||
send_email(user_email, subject, html_content, text_content, mail_from)
|
||||
send_email(
|
||||
user_email,
|
||||
subject,
|
||||
html_content,
|
||||
text_content,
|
||||
mail_from,
|
||||
inline_png=("logo.png", onyx_file.data),
|
||||
)
|
||||
|
||||
211
backend/onyx/auth/oauth_refresher.py
Normal file
211
backend/onyx/auth/oauth_refresher.py
Normal file
@@ -0,0 +1,211 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from fastapi_users.manager import BaseUserManager
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Standard OAuth refresh token endpoints
|
||||
REFRESH_ENDPOINTS = {
|
||||
"google": "https://oauth2.googleapis.com/token",
|
||||
}
|
||||
|
||||
|
||||
# NOTE: Keeping this as a utility function for potential future debugging,
|
||||
# but not using it in production code
|
||||
async def _test_expire_oauth_token(
|
||||
user: User,
|
||||
oauth_account: OAuthAccount,
|
||||
db_session: AsyncSession,
|
||||
user_manager: BaseUserManager[User, Any],
|
||||
expire_in_seconds: int = 10,
|
||||
) -> bool:
|
||||
"""
|
||||
Utility function for testing - Sets an OAuth token to expire in a short time
|
||||
to facilitate testing of the refresh flow.
|
||||
Not used in production code.
|
||||
"""
|
||||
try:
|
||||
new_expires_at = int(
|
||||
(datetime.now(timezone.utc).timestamp() + expire_in_seconds)
|
||||
)
|
||||
|
||||
updated_data: Dict[str, Any] = {"expires_at": new_expires_at}
|
||||
|
||||
await user_manager.user_db.update_oauth_account(
|
||||
user, cast(Any, oauth_account), updated_data
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(f"Error setting artificial expiration: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def refresh_oauth_token(
|
||||
user: User,
|
||||
oauth_account: OAuthAccount,
|
||||
db_session: AsyncSession,
|
||||
user_manager: BaseUserManager[User, Any],
|
||||
) -> bool:
|
||||
"""
|
||||
Attempt to refresh an OAuth token that's about to expire or has expired.
|
||||
Returns True if successful, False otherwise.
|
||||
"""
|
||||
if not oauth_account.refresh_token:
|
||||
logger.warning(
|
||||
f"No refresh token available for {user.email}'s {oauth_account.oauth_name} account"
|
||||
)
|
||||
return False
|
||||
|
||||
provider = oauth_account.oauth_name
|
||||
if provider not in REFRESH_ENDPOINTS:
|
||||
logger.warning(f"Refresh endpoint not configured for provider: {provider}")
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.info(f"Refreshing OAuth token for {user.email}'s {provider} account")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
REFRESH_ENDPOINTS[provider],
|
||||
data={
|
||||
"client_id": OAUTH_CLIENT_ID,
|
||||
"client_secret": OAUTH_CLIENT_SECRET,
|
||||
"refresh_token": oauth_account.refresh_token,
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
f"Failed to refresh OAuth token: Status {response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
token_data = response.json()
|
||||
|
||||
new_access_token = token_data.get("access_token")
|
||||
new_refresh_token = token_data.get(
|
||||
"refresh_token", oauth_account.refresh_token
|
||||
)
|
||||
expires_in = token_data.get("expires_in")
|
||||
|
||||
# Calculate new expiry time if provided
|
||||
new_expires_at: Optional[int] = None
|
||||
if expires_in:
|
||||
new_expires_at = int(
|
||||
(datetime.now(timezone.utc).timestamp() + expires_in)
|
||||
)
|
||||
|
||||
# Update the OAuth account
|
||||
updated_data: Dict[str, Any] = {
|
||||
"access_token": new_access_token,
|
||||
"refresh_token": new_refresh_token,
|
||||
}
|
||||
|
||||
if new_expires_at:
|
||||
updated_data["expires_at"] = new_expires_at
|
||||
|
||||
# Update oidc_expiry in user model if we're tracking it
|
||||
if TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
oidc_expiry = datetime.fromtimestamp(
|
||||
new_expires_at, tz=timezone.utc
|
||||
)
|
||||
await user_manager.user_db.update(
|
||||
user, {"oidc_expiry": oidc_expiry}
|
||||
)
|
||||
|
||||
# Update the OAuth account
|
||||
await user_manager.user_db.update_oauth_account(
|
||||
user, cast(Any, oauth_account), updated_data
|
||||
)
|
||||
|
||||
logger.info(f"Successfully refreshed OAuth token for {user.email}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error refreshing OAuth token: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def check_and_refresh_oauth_tokens(
|
||||
user: User,
|
||||
db_session: AsyncSession,
|
||||
user_manager: BaseUserManager[User, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Check if any OAuth tokens are expired or about to expire and refresh them.
|
||||
"""
|
||||
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
|
||||
return
|
||||
|
||||
now_timestamp = datetime.now(timezone.utc).timestamp()
|
||||
|
||||
# Buffer time to refresh tokens before they expire (in seconds)
|
||||
buffer_seconds = 300 # 5 minutes
|
||||
|
||||
for oauth_account in user.oauth_accounts:
|
||||
# Skip accounts without refresh tokens
|
||||
if not oauth_account.refresh_token:
|
||||
continue
|
||||
|
||||
# If token is about to expire, refresh it
|
||||
if (
|
||||
oauth_account.expires_at
|
||||
and oauth_account.expires_at - now_timestamp < buffer_seconds
|
||||
):
|
||||
logger.info(f"OAuth token for {user.email} is about to expire - refreshing")
|
||||
success = await refresh_oauth_token(
|
||||
user, oauth_account, db_session, user_manager
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.warning(
|
||||
"Failed to refresh OAuth token. User may need to re-authenticate."
|
||||
)
|
||||
|
||||
|
||||
async def check_oauth_account_has_refresh_token(
|
||||
user: User,
|
||||
oauth_account: OAuthAccount,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if an OAuth account has a refresh token.
|
||||
Returns True if a refresh token exists, False otherwise.
|
||||
"""
|
||||
return bool(oauth_account.refresh_token)
|
||||
|
||||
|
||||
async def get_oauth_accounts_requiring_refresh_token(user: User) -> List[OAuthAccount]:
|
||||
"""
|
||||
Returns a list of OAuth accounts for a user that are missing refresh tokens.
|
||||
These accounts will need re-authentication to get refresh tokens.
|
||||
"""
|
||||
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
|
||||
return []
|
||||
|
||||
accounts_needing_refresh = []
|
||||
for oauth_account in user.oauth_accounts:
|
||||
has_refresh_token = await check_oauth_account_has_refresh_token(
|
||||
user, oauth_account
|
||||
)
|
||||
if not has_refresh_token:
|
||||
accounts_needing_refresh.append(oauth_account)
|
||||
|
||||
return accounts_needing_refresh
|
||||
@@ -5,12 +5,16 @@ import string
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Protocol
|
||||
from typing import Tuple
|
||||
from typing import TypeVar
|
||||
|
||||
import jwt
|
||||
from email_validator import EmailNotValidError
|
||||
@@ -105,6 +109,7 @@ from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import async_return_default_schema
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -355,7 +360,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reason="Password must contain at least one special character from the following set: "
|
||||
f"{PASSWORD_SPECIAL_CHARS}."
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
async def oauth_callback(
|
||||
@@ -580,8 +584,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
logger.notice(
|
||||
f"Verification requested for user {user.id}. Verification token: {token}"
|
||||
)
|
||||
|
||||
send_user_verification_email(user.email, token)
|
||||
user_count = await get_user_count()
|
||||
send_user_verification_email(
|
||||
user.email, token, new_organization=user_count == 1
|
||||
)
|
||||
|
||||
async def authenticate(
|
||||
self, credentials: OAuth2PasswordRequestForm
|
||||
@@ -593,7 +599,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
tenant_id = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_tenant_id_for_email",
|
||||
None,
|
||||
POSTGRES_DEFAULT_SCHEMA,
|
||||
)(
|
||||
email=email,
|
||||
)
|
||||
@@ -687,16 +693,20 @@ cookie_transport = CookieTransport(
|
||||
)
|
||||
|
||||
|
||||
def get_redis_strategy() -> RedisStrategy:
|
||||
return TenantAwareRedisStrategy()
|
||||
T = TypeVar("T", covariant=True)
|
||||
ID = TypeVar("ID", contravariant=True)
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> DatabaseStrategy:
|
||||
return DatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
|
||||
)
|
||||
# Protocol for strategies that support token refreshing without inheritance.
|
||||
class RefreshableStrategy(Protocol):
|
||||
"""Protocol for authentication strategies that support token refreshing."""
|
||||
|
||||
async def refresh_token(self, token: Optional[str], user: Any) -> str:
|
||||
"""
|
||||
Refresh an existing token by extending its lifetime.
|
||||
Returns either the same token with extended expiration or a new token.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
@@ -755,6 +765,75 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
redis = await get_async_redis_connection()
|
||||
await redis.delete(f"{self.key_prefix}{token}")
|
||||
|
||||
async def refresh_token(self, token: Optional[str], user: User) -> str:
|
||||
"""Refresh a token by extending its expiration time in Redis."""
|
||||
if token is None:
|
||||
# If no token provided, create a new one
|
||||
return await self.write_token(user)
|
||||
|
||||
redis = await get_async_redis_connection()
|
||||
token_key = f"{self.key_prefix}{token}"
|
||||
|
||||
# Check if token exists
|
||||
token_data_str = await redis.get(token_key)
|
||||
if not token_data_str:
|
||||
# Token not found, create new one
|
||||
return await self.write_token(user)
|
||||
|
||||
# Token exists, extend its lifetime
|
||||
token_data = json.loads(token_data_str)
|
||||
await redis.set(
|
||||
token_key,
|
||||
json.dumps(token_data),
|
||||
ex=self.lifetime_seconds,
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
|
||||
class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]):
|
||||
"""Database strategy with token refreshing capabilities."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
access_token_db: AccessTokenDatabase[AccessToken],
|
||||
lifetime_seconds: Optional[int] = None,
|
||||
):
|
||||
super().__init__(access_token_db, lifetime_seconds)
|
||||
self._access_token_db = access_token_db
|
||||
|
||||
async def refresh_token(self, token: Optional[str], user: User) -> str:
|
||||
"""Refresh a token by updating its expiration time in the database."""
|
||||
if token is None:
|
||||
return await self.write_token(user)
|
||||
|
||||
# Find the token in database
|
||||
access_token = await self._access_token_db.get_by_token(token)
|
||||
|
||||
if access_token is None:
|
||||
# Token not found, create new one
|
||||
return await self.write_token(user)
|
||||
|
||||
# Update expiration time
|
||||
new_expires = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=float(self.lifetime_seconds or SESSION_EXPIRE_TIME_SECONDS)
|
||||
)
|
||||
await self._access_token_db.update(access_token, {"expires": new_expires})
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def get_redis_strategy() -> TenantAwareRedisStrategy:
|
||||
return TenantAwareRedisStrategy()
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> RefreshableDatabaseStrategy:
|
||||
return RefreshableDatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
|
||||
)
|
||||
|
||||
|
||||
if AUTH_BACKEND == AuthBackend.REDIS:
|
||||
auth_backend = AuthenticationBackend(
|
||||
@@ -805,6 +884,88 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
|
||||
return router
|
||||
|
||||
def get_refresh_router(
|
||||
self,
|
||||
backend: AuthenticationBackend,
|
||||
requires_verification: bool = REQUIRE_EMAIL_VERIFICATION,
|
||||
) -> APIRouter:
|
||||
"""
|
||||
Provide a router for session token refreshing.
|
||||
"""
|
||||
# Import the oauth_refresher here to avoid circular imports
|
||||
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
get_current_user_token = self.authenticator.current_user_token(
|
||||
active=True, verified=requires_verification
|
||||
)
|
||||
|
||||
refresh_responses: OpenAPIResponseType = {
|
||||
**{
|
||||
status.HTTP_401_UNAUTHORIZED: {
|
||||
"description": "Missing token or inactive user."
|
||||
}
|
||||
},
|
||||
**backend.transport.get_openapi_login_responses_success(),
|
||||
}
|
||||
|
||||
@router.post(
|
||||
"/refresh", name=f"auth:{backend.name}.refresh", responses=refresh_responses
|
||||
)
|
||||
async def refresh(
|
||||
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
|
||||
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(
|
||||
get_user_manager
|
||||
),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
) -> Response:
|
||||
try:
|
||||
user, token = user_token
|
||||
logger.info(f"Processing token refresh request for user {user.email}")
|
||||
|
||||
# Check if user has OAuth accounts that need refreshing
|
||||
await check_and_refresh_oauth_tokens(
|
||||
user=cast(User, user),
|
||||
db_session=db_session,
|
||||
user_manager=cast(Any, user_manager),
|
||||
)
|
||||
|
||||
# Check if strategy supports refreshing
|
||||
supports_refresh = hasattr(strategy, "refresh_token") and callable(
|
||||
getattr(strategy, "refresh_token")
|
||||
)
|
||||
|
||||
if supports_refresh:
|
||||
try:
|
||||
refresh_method = getattr(strategy, "refresh_token")
|
||||
new_token = await refresh_method(token, user)
|
||||
logger.info(
|
||||
f"Successfully refreshed session token for user {user.email}"
|
||||
)
|
||||
return await backend.transport.get_login_response(new_token)
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing session token: {str(e)}")
|
||||
# Fallback to logout and login if refresh fails
|
||||
await backend.logout(strategy, user, token)
|
||||
return await backend.login(strategy, user)
|
||||
|
||||
# Fallback: logout and login again
|
||||
logger.info(
|
||||
"Strategy doesn't support refresh - using logout/login flow"
|
||||
)
|
||||
await backend.logout(strategy, user, token)
|
||||
return await backend.login(strategy, user)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in refresh endpoint: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Token refresh failed: {str(e)}",
|
||||
)
|
||||
|
||||
return router
|
||||
|
||||
|
||||
fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
|
||||
get_user_manager, [auth_backend]
|
||||
@@ -1038,12 +1199,20 @@ def get_oauth_router(
|
||||
"referral_source": referral_source or "default_referral",
|
||||
}
|
||||
state = generate_state_token(state_data, state_secret)
|
||||
|
||||
# Get the basic authorization URL
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url,
|
||||
state,
|
||||
scopes,
|
||||
)
|
||||
|
||||
# For Google OAuth, add parameters to request refresh tokens
|
||||
if oauth_client.name == "google":
|
||||
authorization_url = add_url_params(
|
||||
authorization_url, {"access_type": "offline", "prompt": "consent"}
|
||||
)
|
||||
|
||||
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
|
||||
|
||||
@router.get(
|
||||
|
||||
@@ -34,7 +34,6 @@ from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGrou
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import ColoredFormatter
|
||||
from onyx.utils.logger import PlainFormatter
|
||||
@@ -225,7 +224,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout
|
||||
is reached."""
|
||||
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
@@ -311,7 +310,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
# Set up variables for waiting on primary worker
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
|
||||
time_start = time.monotonic()
|
||||
|
||||
logger.info("Waiting for primary worker to be ready...")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
@@ -10,12 +9,10 @@ from celery.utils.log import get_task_logger
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -141,8 +138,6 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
"""Only updates the actual beat schedule on the celery app when it changes"""
|
||||
do_update = False
|
||||
|
||||
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
task_logger.debug("_try_updating_schedule starting")
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
@@ -152,16 +147,7 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
current_schedule = self.schedule.items()
|
||||
|
||||
# get potential new state
|
||||
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
|
||||
if beat_multiplier_raw is not None:
|
||||
try:
|
||||
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
|
||||
beat_multiplier = float(beat_multiplier_bytes.decode())
|
||||
except ValueError:
|
||||
task_logger.error(
|
||||
f"Invalid beat_multiplier value: {beat_multiplier_raw}"
|
||||
)
|
||||
beat_multiplier = OnyxRuntime.get_beat_multiplier()
|
||||
|
||||
new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)
|
||||
|
||||
|
||||
@@ -111,6 +111,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.user_file_folder_sync",
|
||||
"onyx.background.celery.tasks.indexing",
|
||||
"onyx.background.celery.tasks.tenant_provisioning",
|
||||
]
|
||||
|
||||
@@ -38,10 +38,11 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_connector_stop import RedisConnectorStop
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -102,7 +103,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker. This is unnecessary in the multi tenant scenario
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
|
||||
info: dict[str, Any] = cast(dict, r.info("replication"))
|
||||
@@ -173,6 +174,9 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
f"search_settings={attempt.search_settings_id}"
|
||||
)
|
||||
logger.warning(failure_reason)
|
||||
logger.exception(
|
||||
f"Marking attempt {attempt.id} as canceled due to validation error 2"
|
||||
)
|
||||
mark_attempt_canceled(attempt.id, db_session, failure_reason)
|
||||
|
||||
|
||||
@@ -235,7 +239,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
|
||||
lock: RedisLock = worker.primary_worker_lock
|
||||
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
if lock.owned():
|
||||
task_logger.debug("Reacquiring primary worker lock.")
|
||||
@@ -284,5 +288,6 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
"onyx.background.celery.tasks.user_file_folder_sync",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@ logger = setup_logger()
|
||||
# Only set up memory monitoring in container environment
|
||||
if is_running_in_container():
|
||||
# Set up a dedicated memory monitoring logger
|
||||
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
|
||||
MEMORY_LOG_DIR = "/var/log/memory"
|
||||
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
|
||||
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
|
||||
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
|
||||
|
||||
@@ -21,6 +21,7 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
|
||||
# we have a better implementation (backpressure, etc)
|
||||
# Note that DynamicTenantScheduler can adjust the runtime value for this via Redis
|
||||
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
|
||||
CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT = 1.0
|
||||
|
||||
# tasks that run in either self-hosted on cloud
|
||||
beat_task_templates: list[dict] = []
|
||||
@@ -63,6 +64,15 @@ beat_task_templates.extend(
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-user-file-folder-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_USER_FILE_FOLDER_SYNC,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-pruning",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
@@ -194,6 +204,16 @@ if not MULTI_TENANT:
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-process-memory",
|
||||
"task": OnyxCeleryTask.MONITOR_PROCESS_MEMORY,
|
||||
"schedule": timedelta(minutes=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -30,6 +30,9 @@ from onyx.db.connector_credential_pair import (
|
||||
)
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import (
|
||||
delete_all_documents_by_connector_credential_pair__no_commit,
|
||||
)
|
||||
from onyx.db.document import get_document_ids_for_connector_credential_pair
|
||||
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
@@ -386,6 +389,8 @@ def monitor_connector_deletion_taskset(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
credential_id_to_delete: int | None = None
|
||||
connector_id_to_delete: int | None = None
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
|
||||
@@ -440,16 +445,35 @@ def monitor_connector_deletion_taskset(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Store IDs before potentially expiring cc_pair
|
||||
connector_id_to_delete = cc_pair.connector_id
|
||||
credential_id_to_delete = cc_pair.credential_id
|
||||
|
||||
# Explicitly delete document by connector credential pair records before deleting the connector
|
||||
# This is needed because connector_id is a primary key in that table and cascading deletes won't work
|
||||
delete_all_documents_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id_to_delete,
|
||||
credential_id=credential_id_to_delete,
|
||||
)
|
||||
|
||||
# Flush to ensure document deletion happens before connector deletion
|
||||
db_session.flush()
|
||||
|
||||
# Expire the cc_pair to ensure SQLAlchemy doesn't try to manage its state
|
||||
# related to the deleted DocumentByConnectorCredentialPair during commit
|
||||
db_session.expire(cc_pair)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
connector_id=connector_id_to_delete,
|
||||
credential_id=credential_id_to_delete,
|
||||
)
|
||||
# if there are no credentials left, delete the connector
|
||||
connector = fetch_connector_by_id(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
connector_id=connector_id_to_delete,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
task_logger.info(
|
||||
@@ -482,15 +506,15 @@ def monitor_connector_deletion_taskset(
|
||||
|
||||
task_logger.exception(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
|
||||
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
|
||||
)
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
f"Connector deletion succeeded: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector={cc_pair.connector_id} "
|
||||
f"credential={cc_pair.credential_id} "
|
||||
f"connector={connector_id_to_delete} "
|
||||
f"credential={credential_id_to_delete} "
|
||||
f"docs_deleted={fence_data.num_tasks}"
|
||||
)
|
||||
|
||||
@@ -540,7 +564,7 @@ def validate_connector_deletion_fences(
|
||||
def validate_connector_deletion_fence(
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
queued_tasks: set[str],
|
||||
queued_upsert_tasks: set[str],
|
||||
r: Redis,
|
||||
) -> None:
|
||||
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
|
||||
@@ -627,7 +651,7 @@ def validate_connector_deletion_fence(
|
||||
|
||||
member_bytes = cast(bytes, member)
|
||||
member_str = member_bytes.decode("utf-8")
|
||||
if member_str in queued_tasks:
|
||||
if member_str in queued_upsert_tasks:
|
||||
continue
|
||||
|
||||
tasks_not_in_celery += 1
|
||||
|
||||
@@ -17,6 +17,7 @@ from redis.exceptions import LockError
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.db.document import upsert_document_external_perms
|
||||
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
|
||||
@@ -46,7 +47,6 @@ from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.connectors.factory import validate_ccpair_for_user
|
||||
from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
@@ -64,11 +64,14 @@ from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyn
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.server.utils import make_short_id
|
||||
from onyx.utils.logger import doc_permission_sync_ctx
|
||||
from onyx.utils.logger import format_error_for_logging
|
||||
from onyx.utils.logger import LoggerContextVars
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -105,9 +108,10 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
|
||||
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||
|
||||
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
|
||||
if not source_sync_period:
|
||||
return True
|
||||
source_sync_period = DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
|
||||
source_sync_period *= int(OnyxRuntime.get_doc_permission_sync_multiplier())
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
|
||||
@@ -285,7 +289,7 @@ def try_creating_permissions_sync_task(
|
||||
),
|
||||
queue=OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
# fill in the celery task id
|
||||
@@ -420,12 +424,7 @@ def connector_permission_sync_generator_task(
|
||||
task_logger.exception(
|
||||
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
|
||||
)
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
# TODO: add some notification to the admins here
|
||||
raise
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
@@ -453,23 +452,23 @@ def connector_permission_sync_generator_task(
|
||||
redis_connector.permissions.set_fence(new_payload)
|
||||
|
||||
callback = PermissionSyncCallback(redis_connector, lock, r)
|
||||
document_external_accesses: list[DocExternalAccess] = doc_sync_func(
|
||||
cc_pair, callback
|
||||
)
|
||||
document_external_accesses = doc_sync_func(cc_pair, callback)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.permissions.generate_tasks(
|
||||
celery_app=self.app,
|
||||
lock=lock,
|
||||
new_permissions=document_external_accesses,
|
||||
source_string=source_type,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
tasks_generated = 0
|
||||
for doc_external_access in document_external_accesses:
|
||||
redis_connector.permissions.generate_tasks(
|
||||
celery_app=self.app,
|
||||
lock=lock,
|
||||
new_permissions=[doc_external_access],
|
||||
source_string=source_type,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
)
|
||||
tasks_generated += 1
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.permissions.generate_tasks finished. "
|
||||
@@ -881,6 +880,21 @@ def monitor_ccpair_permissions_taskset(
|
||||
f"remaining={remaining} "
|
||||
f"initial={initial}"
|
||||
)
|
||||
|
||||
# Add telemetry for permission syncing progress
|
||||
optional_telemetry(
|
||||
record_type=RecordType.PERMISSION_SYNC_PROGRESS,
|
||||
data={
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"id": payload.id if payload else None,
|
||||
"total_docs": initial if initial is not None else 0,
|
||||
"remaining_docs": remaining,
|
||||
"synced_docs": (initial - remaining) if initial is not None else 0,
|
||||
"is_complete": remaining == 0,
|
||||
},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
|
||||
@@ -41,7 +41,6 @@ from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.factory import validate_ccpair_for_user
|
||||
from onyx.db.connector import mark_cc_pair_as_external_group_synced
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -272,7 +271,7 @@ def try_creating_external_group_sync_task(
|
||||
),
|
||||
queue=OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
payload.celery_task_id = result.id
|
||||
@@ -402,12 +401,7 @@ def connector_external_group_sync_generator_task(
|
||||
task_logger.exception(
|
||||
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
|
||||
)
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
# TODO: add some notification to the admins here
|
||||
raise
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
@@ -425,12 +419,9 @@ def connector_external_group_sync_generator_task(
|
||||
try:
|
||||
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
|
||||
except ConnectorValidationError as e:
|
||||
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
# TODO: add some notification to the admins here
|
||||
logger.exception(
|
||||
f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
@@ -72,6 +72,7 @@ from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_utils import is_fence
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
@@ -364,6 +365,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
Occcasionally does some validation of existing state to clear up error conditions"""
|
||||
|
||||
time_start = time.monotonic()
|
||||
task_logger.warning("check_for_indexing - Starting")
|
||||
|
||||
tasks_created = 0
|
||||
locked = False
|
||||
@@ -401,7 +403,11 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
logger.warning(f"Adding {key_bytes} to the lookup table.")
|
||||
redis_client.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
|
||||
redis_client.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
|
||||
redis_client.set(
|
||||
OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE,
|
||||
1,
|
||||
ex=OnyxRuntime.get_build_fence_lookup_table_interval(),
|
||||
)
|
||||
|
||||
# 1/3: KICKOFF
|
||||
|
||||
@@ -428,7 +434,9 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pairs = fetch_connector_credential_pairs(db_session)
|
||||
cc_pairs = fetch_connector_credential_pairs(
|
||||
db_session, include_user_files=True
|
||||
)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
@@ -447,12 +455,18 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
not search_settings_instance.status.is_current()
|
||||
and not search_settings_instance.background_reindex_enabled
|
||||
):
|
||||
task_logger.warning("SKIPPING DUE TO NON-LIVE SEARCH SETTINGS")
|
||||
|
||||
continue
|
||||
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings_instance.id
|
||||
)
|
||||
if redis_connector_index.fenced:
|
||||
task_logger.info(
|
||||
f"check_for_indexing - Skipping fenced connector: "
|
||||
f"cc_pair={cc_pair_id} search_settings={search_settings_instance.id}"
|
||||
)
|
||||
continue
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
@@ -460,6 +474,9 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"check_for_indexing - CC pair not found: cc_pair={cc_pair_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
@@ -473,7 +490,20 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
secondary_index_building=len(search_settings_list) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
task_logger.info(
|
||||
f"check_for_indexing - Not indexing cc_pair_id: {cc_pair_id} "
|
||||
f"search_settings={search_settings_instance.id}, "
|
||||
f"last_attempt={last_attempt.id if last_attempt else None}, "
|
||||
f"secondary_index_building={len(search_settings_list) > 1}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
task_logger.info(
|
||||
f"check_for_indexing - Will index cc_pair_id: {cc_pair_id} "
|
||||
f"search_settings={search_settings_instance.id}, "
|
||||
f"last_attempt={last_attempt.id if last_attempt else None}, "
|
||||
f"secondary_index_building={len(search_settings_list) > 1}"
|
||||
)
|
||||
|
||||
reindex = False
|
||||
if search_settings_instance.status.is_current():
|
||||
@@ -512,6 +542,12 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
f"search_settings={search_settings_instance.id}"
|
||||
)
|
||||
tasks_created += 1
|
||||
else:
|
||||
task_logger.info(
|
||||
f"Failed to create indexing task: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings_instance.id}"
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
@@ -1144,6 +1180,9 @@ def connector_indexing_proxy_task(
|
||||
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
logger.exception(
|
||||
f"Marking attempt {index_attempt_id} as canceled due to termination signal"
|
||||
)
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
|
||||
@@ -371,6 +371,7 @@ def should_index(
|
||||
|
||||
# don't kick off indexing for `NOT_APPLICABLE` sources
|
||||
if connector.source == DocumentSource.NOT_APPLICABLE:
|
||||
print(f"Not indexing cc_pair={cc_pair.id}: NOT_APPLICABLE source")
|
||||
return False
|
||||
|
||||
# User can still manually create single indexing attempts via the UI for the
|
||||
@@ -380,6 +381,9 @@ def should_index(
|
||||
search_settings_instance.status == IndexModelStatus.PRESENT
|
||||
and secondary_index_building
|
||||
):
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: DISABLE_INDEX_UPDATE_ON_SWAP is True and secondary index building"
|
||||
)
|
||||
return False
|
||||
|
||||
# When switching over models, always index at least once
|
||||
@@ -388,19 +392,31 @@ def should_index(
|
||||
# No new index if the last index attempt succeeded
|
||||
# Once is enough. The model will never be able to swap otherwise.
|
||||
if last_index.status == IndexingStatus.SUCCESS:
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with successful last index attempt={last_index.id}"
|
||||
)
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is waiting to start
|
||||
if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with NOT_STARTED last index attempt={last_index.id}"
|
||||
)
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is running
|
||||
if last_index.status == IndexingStatus.IN_PROGRESS:
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with IN_PROGRESS last index attempt={last_index.id}"
|
||||
)
|
||||
return False
|
||||
else:
|
||||
if (
|
||||
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
|
||||
): # Ingestion API
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with Ingestion API source"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -412,6 +428,9 @@ def should_index(
|
||||
or connector.id == 0
|
||||
or connector.source == DocumentSource.INGESTION_API
|
||||
):
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: Connector is paused or is Ingestion API"
|
||||
)
|
||||
return False
|
||||
|
||||
if search_settings_instance.status.is_current():
|
||||
@@ -424,11 +443,16 @@ def should_index(
|
||||
return True
|
||||
|
||||
if connector.refresh_freq is None:
|
||||
print(f"Not indexing cc_pair={cc_pair.id}: refresh_freq is None")
|
||||
return False
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
time_since_index = current_db_time - last_index.time_updated
|
||||
if time_since_index.total_seconds() < connector.refresh_freq:
|
||||
print(
|
||||
f"Not indexing cc_pair={cc_pair.id}: Last index attempt={last_index.id} "
|
||||
f"too recent ({time_since_index.total_seconds()}s < {connector.refresh_freq}s)"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -508,6 +532,13 @@ def try_creating_indexing_task(
|
||||
|
||||
custom_task_id = redis_connector_index.generate_generator_task_id()
|
||||
|
||||
# Determine which queue to use based on whether this is a user file
|
||||
queue = (
|
||||
OnyxCeleryQueues.USER_FILES_INDEXING
|
||||
if cc_pair.is_user_file
|
||||
else OnyxCeleryQueues.CONNECTOR_INDEXING
|
||||
)
|
||||
|
||||
# when the task is sent, we have yet to finish setting up the fence
|
||||
# therefore, the task must contain code that blocks until the fence is ready
|
||||
result = celery_app.send_task(
|
||||
@@ -518,7 +549,7 @@ def try_creating_indexing_task(
|
||||
search_settings_id=search_settings.id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=OnyxCeleryQueues.CONNECTOR_INDEXING,
|
||||
queue=queue,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from itertools import islice
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
import psutil
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
@@ -19,6 +20,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -39,8 +41,10 @@ from onyx.db.models import UserGroup
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.utils.logger import is_running_in_container
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
|
||||
@@ -904,3 +908,93 @@ def monitor_celery_queues_helper(
|
||||
f"external_group_sync={n_external_group_sync} "
|
||||
f"permissions_upsert={n_permissions_upsert} "
|
||||
)
|
||||
|
||||
|
||||
"""Memory monitoring"""
|
||||
|
||||
|
||||
def _get_cmdline_for_process(process: psutil.Process) -> str | None:
|
||||
try:
|
||||
return " ".join(process.cmdline())
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.MONITOR_PROCESS_MEMORY,
|
||||
ignore_result=True,
|
||||
soft_time_limit=_MONITORING_SOFT_TIME_LIMIT,
|
||||
time_limit=_MONITORING_TIME_LIMIT,
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
bind=True,
|
||||
)
|
||||
def monitor_process_memory(self: Task, *, tenant_id: str) -> None:
|
||||
"""
|
||||
Task to monitor memory usage of supervisor-managed processes.
|
||||
This periodically checks the memory usage of processes and logs information
|
||||
in a standardized format.
|
||||
|
||||
The task looks for processes managed by supervisor and logs their
|
||||
memory usage statistics. This is useful for monitoring memory consumption
|
||||
over time and identifying potential memory leaks.
|
||||
"""
|
||||
# don't run this task in multi-tenant mode, have other, better means of monitoring
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
# Skip memory monitoring if not in container
|
||||
if not is_running_in_container():
|
||||
return
|
||||
|
||||
try:
|
||||
# Get all supervisor-managed processes
|
||||
supervisor_processes: dict[int, str] = {}
|
||||
|
||||
# Map cmd line elements to more readable process names
|
||||
process_type_mapping = {
|
||||
"--hostname=primary": "primary",
|
||||
"--hostname=light": "light",
|
||||
"--hostname=heavy": "heavy",
|
||||
"--hostname=indexing": "indexing",
|
||||
"--hostname=monitoring": "monitoring",
|
||||
"beat": "beat",
|
||||
"slack/listener.py": "slack",
|
||||
}
|
||||
|
||||
# Find all python processes that are likely celery workers
|
||||
for proc in psutil.process_iter():
|
||||
cmdline = _get_cmdline_for_process(proc)
|
||||
if not cmdline:
|
||||
continue
|
||||
|
||||
# Match supervisor-managed processes
|
||||
for process_name, process_type in process_type_mapping.items():
|
||||
if process_name in cmdline:
|
||||
if process_type in supervisor_processes.values():
|
||||
task_logger.error(
|
||||
f"Duplicate process type for type {process_type} "
|
||||
f"with cmd {cmdline} with pid={proc.pid}."
|
||||
)
|
||||
continue
|
||||
|
||||
supervisor_processes[proc.pid] = process_type
|
||||
break
|
||||
|
||||
if len(supervisor_processes) != len(process_type_mapping):
|
||||
task_logger.error(
|
||||
"Missing processes: "
|
||||
f"{set(process_type_mapping.keys()).symmetric_difference(supervisor_processes.values())}"
|
||||
)
|
||||
|
||||
# Log memory usage for each process
|
||||
for pid, process_type in supervisor_processes.items():
|
||||
try:
|
||||
emit_process_memory(pid, process_type, {})
|
||||
except psutil.NoSuchProcess:
|
||||
# Process may have terminated since we obtained the list
|
||||
continue
|
||||
except Exception as e:
|
||||
task_logger.exception(f"Error monitoring process {pid}: {str(e)}")
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Error in monitor_process_memory task")
|
||||
|
||||
@@ -6,6 +6,7 @@ from tenacity import wait_random_exponential
|
||||
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
|
||||
|
||||
class RetryDocumentIndex:
|
||||
@@ -52,11 +53,13 @@ class RetryDocumentIndex:
|
||||
*,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
fields: VespaDocumentFields,
|
||||
fields: VespaDocumentFields | None,
|
||||
user_fields: VespaDocumentUserFields | None,
|
||||
) -> int:
|
||||
return self.index.update_single(
|
||||
doc_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
fields=fields,
|
||||
user_fields=user_fields,
|
||||
)
|
||||
|
||||
@@ -164,6 +164,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=fields,
|
||||
user_fields=None,
|
||||
)
|
||||
|
||||
# there are still other cc_pair references to the doc, so just resync to Vespa
|
||||
|
||||
@@ -0,0 +1,266 @@
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
from tenacity import RetryError
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_credential_pairs_with_user_files,
|
||||
)
|
||||
from onyx.db.document import get_document
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.user_documents import fetch_user_files_for_documents
|
||||
from onyx.db.user_documents import fetch_user_folders_for_documents
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_FOLDER_SYNC,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_user_file_folder_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
"""Runs periodically to check for documents that need user file folder metadata updates.
|
||||
This task fetches all connector credential pairs with user files, gets the documents
|
||||
associated with them, and updates the user file and folder metadata in Vespa.
|
||||
"""
|
||||
|
||||
time_start = time.monotonic()
|
||||
|
||||
r = get_redis_client()
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_USER_FILE_FOLDER_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Get all connector credential pairs that have user files
|
||||
cc_pairs = get_connector_credential_pairs_with_user_files(db_session)
|
||||
|
||||
if not cc_pairs:
|
||||
task_logger.info("No connector credential pairs with user files found")
|
||||
return True
|
||||
|
||||
# Get all documents associated with these cc_pairs
|
||||
document_ids = get_documents_for_cc_pairs(cc_pairs, db_session)
|
||||
|
||||
if not document_ids:
|
||||
task_logger.info(
|
||||
"No documents found for connector credential pairs with user files"
|
||||
)
|
||||
return True
|
||||
|
||||
# Fetch current user file and folder IDs for these documents
|
||||
doc_id_to_user_file_id = fetch_user_files_for_documents(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
)
|
||||
doc_id_to_user_folder_id = fetch_user_folders_for_documents(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
)
|
||||
|
||||
# Update Vespa metadata for each document
|
||||
for doc_id in document_ids:
|
||||
user_file_id = doc_id_to_user_file_id.get(doc_id)
|
||||
user_folder_id = doc_id_to_user_folder_id.get(doc_id)
|
||||
|
||||
if user_file_id is not None or user_folder_id is not None:
|
||||
# Schedule a task to update the document metadata
|
||||
update_user_file_folder_metadata.apply_async(
|
||||
args=(doc_id,), # Use tuple instead of list for args
|
||||
kwargs={
|
||||
"tenant_id": tenant_id,
|
||||
"user_file_id": user_file_id,
|
||||
"user_folder_id": user_folder_id,
|
||||
},
|
||||
queue="vespa_metadata_sync",
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Scheduled metadata updates for {len(document_ids)} documents. "
|
||||
f"Elapsed time: {time.monotonic() - time_start:.2f}s"
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
task_logger.exception(f"Error in check_for_user_file_folder_sync: {e}")
|
||||
return False
|
||||
finally:
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def get_documents_for_cc_pairs(
|
||||
cc_pairs: List[ConnectorCredentialPair], db_session: Session
|
||||
) -> List[str]:
|
||||
"""Get all document IDs associated with the given connector credential pairs."""
|
||||
if not cc_pairs:
|
||||
return []
|
||||
|
||||
cc_pair_ids = [cc_pair.id for cc_pair in cc_pairs]
|
||||
|
||||
# Query to get document IDs from DocumentByConnectorCredentialPair
|
||||
# Note: DocumentByConnectorCredentialPair uses connector_id and credential_id, not cc_pair_id
|
||||
doc_cc_pairs = (
|
||||
db_session.query(Document.id)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
Document.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.filter(
|
||||
db_session.query(ConnectorCredentialPair)
|
||||
.filter(
|
||||
ConnectorCredentialPair.id.in_(cc_pair_ids),
|
||||
ConnectorCredentialPair.connector_id
|
||||
== DocumentByConnectorCredentialPair.connector_id,
|
||||
ConnectorCredentialPair.credential_id
|
||||
== DocumentByConnectorCredentialPair.credential_id,
|
||||
)
|
||||
.exists()
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [doc_id for (doc_id,) in doc_cc_pairs]
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.UPDATE_USER_FILE_FOLDER_METADATA,
|
||||
bind=True,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=3,
|
||||
)
|
||||
def update_user_file_folder_metadata(
|
||||
self: Task,
|
||||
document_id: str,
|
||||
*,
|
||||
tenant_id: str,
|
||||
user_file_id: int | None,
|
||||
user_folder_id: int | None,
|
||||
) -> bool:
|
||||
"""Updates the user file and folder metadata for a document in Vespa."""
|
||||
start = time.monotonic()
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
|
||||
doc = get_document(document_id, db_session)
|
||||
if not doc:
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} "
|
||||
f"action=no_operation "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
|
||||
return False
|
||||
|
||||
# Create user fields object with file and folder IDs
|
||||
user_fields = VespaDocumentUserFields(
|
||||
user_file_id=str(user_file_id) if user_file_id is not None else None,
|
||||
user_folder_id=str(user_folder_id)
|
||||
if user_folder_id is not None
|
||||
else None,
|
||||
)
|
||||
|
||||
# Update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = retry_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=None, # We're only updating user fields
|
||||
user_fields=user_fields,
|
||||
)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} "
|
||||
f"action=user_file_folder_sync "
|
||||
f"user_file_id={user_file_id} "
|
||||
f"user_folder_id={user_folder_id} "
|
||||
f"chunks={chunks_affected} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
|
||||
return True
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
|
||||
except Exception as ex:
|
||||
e: Exception | None = None
|
||||
while True:
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.warning(
|
||||
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
|
||||
)
|
||||
|
||||
# only set the inner exception if it is of type Exception
|
||||
e_temp = ex.last_attempt.exception()
|
||||
if isinstance(e_temp, Exception):
|
||||
e = e_temp
|
||||
else:
|
||||
e = ex
|
||||
|
||||
task_logger.exception(
|
||||
f"update_user_file_folder_metadata exceptioned: doc={document_id}"
|
||||
)
|
||||
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
|
||||
if (
|
||||
self.max_retries is not None
|
||||
and self.request.retries >= self.max_retries
|
||||
):
|
||||
completion_status = (
|
||||
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
|
||||
)
|
||||
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown) # this will raise a celery exception
|
||||
break # we won't hit this, but it looks weird not to have it
|
||||
finally:
|
||||
task_logger.info(
|
||||
f"update_user_file_folder_metadata completed: status={completion_status.value} doc={document_id}"
|
||||
)
|
||||
|
||||
return False
|
||||
@@ -80,7 +80,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
|
||||
# Useful for debugging timing issues with reacquisitions. TODO: remove once more generalized logging is in place
|
||||
# Useful for debugging timing issues with reacquisitions.
|
||||
# TODO: remove once more generalized logging is in place
|
||||
task_logger.info("check_for_vespa_sync_task started")
|
||||
|
||||
time_start = time.monotonic()
|
||||
@@ -572,6 +573,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=fields,
|
||||
user_fields=None,
|
||||
)
|
||||
|
||||
# update db last. Worst case = we crash right before this and
|
||||
|
||||
@@ -6,6 +6,8 @@ from sqlalchemy import and_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.db.engine import get_db_current_time
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
@@ -16,7 +18,6 @@ from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.object_size_check import deep_getsizeof
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_NUM_RECENT_ATTEMPTS_TO_CONSIDER = 20
|
||||
@@ -52,7 +53,7 @@ def save_checkpoint(
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
db_session: Session, index_attempt_id: int
|
||||
db_session: Session, index_attempt_id: int, connector: BaseConnector
|
||||
) -> ConnectorCheckpoint | None:
|
||||
"""Load a checkpoint for a given index attempt from the file store"""
|
||||
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
|
||||
@@ -60,6 +61,8 @@ def load_checkpoint(
|
||||
try:
|
||||
checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
|
||||
checkpoint_data = checkpoint_io.read().decode("utf-8")
|
||||
if isinstance(connector, CheckpointConnector):
|
||||
return connector.validate_checkpoint_json(checkpoint_data)
|
||||
return ConnectorCheckpoint.model_validate_json(checkpoint_data)
|
||||
except RuntimeError:
|
||||
return None
|
||||
@@ -71,6 +74,7 @@ def get_latest_valid_checkpoint(
|
||||
search_settings_id: int,
|
||||
window_start: datetime,
|
||||
window_end: datetime,
|
||||
connector: BaseConnector,
|
||||
) -> ConnectorCheckpoint:
|
||||
"""Get the latest valid checkpoint for a given connector credential pair"""
|
||||
checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
|
||||
@@ -105,7 +109,7 @@ def get_latest_valid_checkpoint(
|
||||
f"for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
|
||||
"from scratch."
|
||||
)
|
||||
return ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
return connector.build_dummy_checkpoint()
|
||||
|
||||
# assumes latest checkpoint is the furthest along. This only isn't true
|
||||
# if something else has gone wrong.
|
||||
@@ -113,12 +117,13 @@ def get_latest_valid_checkpoint(
|
||||
checkpoint_candidates[0] if checkpoint_candidates else None
|
||||
)
|
||||
|
||||
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
checkpoint = connector.build_dummy_checkpoint()
|
||||
if latest_valid_checkpoint_candidate:
|
||||
try:
|
||||
previous_checkpoint = load_checkpoint(
|
||||
db_session=db_session,
|
||||
index_attempt_id=latest_valid_checkpoint_candidate.id,
|
||||
connector=connector,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
@@ -193,7 +198,7 @@ def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
|
||||
|
||||
def check_checkpoint_size(checkpoint: ConnectorCheckpoint) -> None:
|
||||
"""Check if the checkpoint content size exceeds the limit (200MB)"""
|
||||
content_size = deep_getsizeof(checkpoint.checkpoint_content)
|
||||
content_size = deep_getsizeof(checkpoint.model_dump())
|
||||
if content_size > 200_000_000: # 200MB in bytes
|
||||
raise ValueError(
|
||||
f"Checkpoint content size ({content_size} bytes) exceeds 200MB limit"
|
||||
|
||||
@@ -24,7 +24,6 @@ from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
@@ -32,8 +31,11 @@ from onyx.connectors.models import TextSection
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
from onyx.db.constants import CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.index_attempt import create_index_attempt_error
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
|
||||
@@ -46,8 +48,6 @@ from onyx.db.index_attempt import transition_attempt_to_in_progress
|
||||
from onyx.db.index_attempt import update_docs_indexed
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
@@ -56,9 +56,12 @@ from onyx.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.logger import TaskAttemptSingleton
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -271,7 +274,6 @@ def _run_indexing(
|
||||
"Search settings must be set for indexing. This should not be possible."
|
||||
)
|
||||
|
||||
# search_settings = index_attempt_start.search_settings
|
||||
db_connector = index_attempt_start.connector_credential_pair.connector
|
||||
db_credential = index_attempt_start.connector_credential_pair.credential
|
||||
ctx = RunIndexingContext(
|
||||
@@ -387,6 +389,7 @@ def _run_indexing(
|
||||
net_doc_change = 0
|
||||
document_count = 0
|
||||
chunk_count = 0
|
||||
index_attempt: IndexAttempt | None = None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
@@ -405,7 +408,7 @@ def _run_indexing(
|
||||
# the beginning in order to avoid weird interactions between
|
||||
# checkpointing / failure handling.
|
||||
if index_attempt.from_beginning:
|
||||
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
checkpoint = connector_runner.connector.build_dummy_checkpoint()
|
||||
else:
|
||||
checkpoint = get_latest_valid_checkpoint(
|
||||
db_session=db_session_temp,
|
||||
@@ -413,6 +416,7 @@ def _run_indexing(
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
window_start=window_start,
|
||||
window_end=window_end,
|
||||
connector=connector_runner.connector,
|
||||
)
|
||||
|
||||
unresolved_errors = get_index_attempt_errors_for_cc_pair(
|
||||
@@ -433,7 +437,7 @@ def _run_indexing(
|
||||
|
||||
while checkpoint.has_more:
|
||||
logger.info(
|
||||
f"Running '{ctx.source}' connector with checkpoint: {checkpoint}"
|
||||
f"Running '{ctx.source.value}' connector with checkpoint: {checkpoint}"
|
||||
)
|
||||
for document_batch, failure, next_checkpoint in connector_runner.run(
|
||||
checkpoint
|
||||
@@ -568,6 +572,22 @@ def _run_indexing(
|
||||
if callback:
|
||||
callback.progress("_run_indexing", len(doc_batch_cleaned))
|
||||
|
||||
# Add telemetry for indexing progress
|
||||
optional_telemetry(
|
||||
record_type=RecordType.INDEXING_PROGRESS,
|
||||
data={
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"cc_pair_id": ctx.cc_pair_id,
|
||||
"connector_id": ctx.connector_id,
|
||||
"credential_id": ctx.credential_id,
|
||||
"total_docs_indexed": document_count,
|
||||
"total_chunks": chunk_count,
|
||||
"batch_num": batch_num,
|
||||
"source": ctx.source.value,
|
||||
},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
memory_tracer.increment_and_maybe_trace()
|
||||
|
||||
# `make sure the checkpoints aren't getting too large`at some regular interval
|
||||
@@ -583,6 +603,30 @@ def _run_indexing(
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
# Add telemetry for completed indexing
|
||||
redis_connector = RedisConnector(tenant_id, ctx.cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
index_attempt_start.search_settings_id
|
||||
)
|
||||
final_progress = redis_connector_index.get_progress() or 0
|
||||
|
||||
optional_telemetry(
|
||||
record_type=RecordType.INDEXING_COMPLETE,
|
||||
data={
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"cc_pair_id": ctx.cc_pair_id,
|
||||
"connector_id": ctx.connector_id,
|
||||
"credential_id": ctx.credential_id,
|
||||
"total_docs_indexed": document_count,
|
||||
"total_chunks": chunk_count,
|
||||
"batch_count": batch_num,
|
||||
"time_elapsed_seconds": time.monotonic() - start_time,
|
||||
"source": ctx.source.value,
|
||||
"redis_progress": final_progress,
|
||||
},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Connector run exceptioned after elapsed time: "
|
||||
@@ -593,24 +637,58 @@ def _run_indexing(
|
||||
# and mark the CCPair as invalid. This prevents the connector from being
|
||||
# used in the future until the credentials are updated.
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
logger.exception(
|
||||
f"Marking attempt {index_attempt_id} as canceled due to validation error."
|
||||
)
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
reason=str(e),
|
||||
reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}",
|
||||
)
|
||||
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
if not index_attempt:
|
||||
# should always be set by now
|
||||
raise RuntimeError("Should never happen.")
|
||||
|
||||
VALIDATION_ERROR_THRESHOLD = 5
|
||||
|
||||
recent_index_attempts = get_recent_completed_attempts_for_cc_pair(
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
limit=VALIDATION_ERROR_THRESHOLD,
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
num_validation_errors = len(
|
||||
[
|
||||
index_attempt
|
||||
for index_attempt in recent_index_attempts
|
||||
if index_attempt.error_msg
|
||||
and index_attempt.error_msg.startswith(
|
||||
CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
if num_validation_errors >= VALIDATION_ERROR_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Connector {ctx.connector_id} has {num_validation_errors} consecutive validation"
|
||||
f" errors. Marking the CC Pair as invalid."
|
||||
)
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
memory_tracer.stop()
|
||||
raise e
|
||||
|
||||
elif isinstance(e, ConnectorStopSignal):
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
logger.exception(
|
||||
f"Marking attempt {index_attempt_id} as canceled due to stop signal."
|
||||
)
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
@@ -673,6 +751,7 @@ def _run_indexing(
|
||||
f"Connector succeeded: "
|
||||
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
|
||||
else:
|
||||
mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
|
||||
logger.info(
|
||||
|
||||
@@ -30,7 +30,7 @@ from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.gpu_utils import gpu_status_request
|
||||
from onyx.utils.gpu_utils import fast_gpu_status_request
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -88,7 +88,9 @@ class Answer:
|
||||
rerank_settings is not None
|
||||
and rerank_settings.rerank_provider_type is not None
|
||||
)
|
||||
allow_agent_reranking = gpu_status_request() or using_cloud_reranking
|
||||
allow_agent_reranking = (
|
||||
fast_gpu_status_request(indexing=False) or using_cloud_reranking
|
||||
)
|
||||
|
||||
# TODO: this is a hack to force the query to be used for the search tool
|
||||
# this should be removed once we fully unify graph inputs (i.e.
|
||||
|
||||
@@ -127,6 +127,10 @@ class StreamStopInfo(SubQuestionIdentifier):
|
||||
return data
|
||||
|
||||
|
||||
class UserKnowledgeFilePacket(BaseModel):
|
||||
user_files: list[FileDescriptor]
|
||||
|
||||
|
||||
class LLMRelevanceFilterResponse(BaseModel):
|
||||
llm_selected_doc_indices: list[int]
|
||||
|
||||
@@ -194,17 +198,6 @@ class StreamingError(BaseModel):
|
||||
stack_trace: str | None = None
|
||||
|
||||
|
||||
class OnyxContext(BaseModel):
|
||||
content: str
|
||||
document_id: str
|
||||
semantic_identifier: str
|
||||
blurb: str
|
||||
|
||||
|
||||
class OnyxContexts(BaseModel):
|
||||
contexts: list[OnyxContext]
|
||||
|
||||
|
||||
class OnyxAnswer(BaseModel):
|
||||
answer: str | None
|
||||
|
||||
@@ -270,7 +263,6 @@ class PersonaOverrideConfig(BaseModel):
|
||||
AnswerQuestionPossibleReturn = (
|
||||
OnyxAnswerPiece
|
||||
| CitationInfo
|
||||
| OnyxContexts
|
||||
| FileChatDisplay
|
||||
| CustomToolResponse
|
||||
| StreamingError
|
||||
|
||||
@@ -29,7 +29,6 @@ from onyx.chat.models import LLMRelevanceFilterResponse
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.chat.models import MessageSpecificCitations
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import OnyxContexts
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
@@ -37,6 +36,7 @@ from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import SubQuestionKey
|
||||
from onyx.chat.models import UserKnowledgeFilePacket
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
|
||||
@@ -52,6 +52,7 @@ from onyx.context.search.enums import LLMEvaluationType
|
||||
from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.enums import QueryFlow
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.models import SearchRequest
|
||||
@@ -65,6 +66,7 @@ from onyx.context.search.utils import relevant_sections_to_indices
|
||||
from onyx.db.chat import attach_files_to_chat_message
|
||||
from onyx.db.chat import create_db_search_doc
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import create_search_doc_from_user_file
|
||||
from onyx.db.chat import get_chat_message
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_db_search_doc_by_id
|
||||
@@ -73,6 +75,7 @@ from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from onyx.db.chat import update_chat_session_updated_at_timestamp
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.milestone import check_multi_assistant_milestone
|
||||
from onyx.db.milestone import create_milestone_if_not_exists
|
||||
@@ -80,12 +83,16 @@ from onyx.db.milestone import update_user_assistant_milestone
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import load_all_chat_files
|
||||
from onyx.file_store.utils import load_all_user_file_files
|
||||
from onyx.file_store.utils import load_all_user_files
|
||||
from onyx.file_store.utils import save_files
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
@@ -98,6 +105,7 @@ from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
@@ -130,7 +138,6 @@ from onyx.tools.tool_implementations.internet_search.internet_search_tool import
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
@@ -176,11 +183,14 @@ def _handle_search_tool_response_summary(
|
||||
db_session: Session,
|
||||
selected_search_docs: list[DbSearchDoc] | None,
|
||||
dedupe_docs: bool = False,
|
||||
user_files: list[UserFile] | None = None,
|
||||
loaded_user_files: list[InMemoryChatFile] | None = None,
|
||||
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
|
||||
response_sumary = cast(SearchResponseSummary, packet.response)
|
||||
|
||||
is_extended = isinstance(packet, ExtendedToolResponse)
|
||||
dropped_inds = None
|
||||
|
||||
if not selected_search_docs:
|
||||
top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections)
|
||||
|
||||
@@ -194,9 +204,31 @@ def _handle_search_tool_response_summary(
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in deduped_docs
|
||||
]
|
||||
|
||||
else:
|
||||
reference_db_search_docs = selected_search_docs
|
||||
|
||||
doc_ids = {doc.id for doc in reference_db_search_docs}
|
||||
if user_files is not None:
|
||||
for user_file in user_files:
|
||||
if user_file.id not in doc_ids:
|
||||
associated_chat_file = None
|
||||
if loaded_user_files is not None:
|
||||
associated_chat_file = next(
|
||||
(
|
||||
file
|
||||
for file in loaded_user_files
|
||||
if file.file_id == str(user_file.file_id)
|
||||
),
|
||||
None,
|
||||
)
|
||||
# Use create_search_doc_from_user_file to properly add the document to the database
|
||||
if associated_chat_file is not None:
|
||||
db_doc = create_search_doc_from_user_file(
|
||||
user_file, associated_chat_file, db_session
|
||||
)
|
||||
reference_db_search_docs.append(db_doc)
|
||||
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
@@ -254,7 +286,10 @@ def _handle_internet_search_tool_response_summary(
|
||||
|
||||
|
||||
def _get_force_search_settings(
|
||||
new_msg_req: CreateChatMessageRequest, tools: list[Tool]
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
tools: list[Tool],
|
||||
user_file_ids: list[int],
|
||||
user_folder_ids: list[int],
|
||||
) -> ForceUseTool:
|
||||
internet_search_available = any(
|
||||
isinstance(tool, InternetSearchTool) for tool in tools
|
||||
@@ -262,8 +297,11 @@ def _get_force_search_settings(
|
||||
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
|
||||
|
||||
if not internet_search_available and not search_tool_available:
|
||||
# Does not matter much which tool is set here as force is false and neither tool is available
|
||||
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
|
||||
if new_msg_req.force_user_file_search:
|
||||
return ForceUseTool(force_use=True, tool_name=SearchTool._NAME)
|
||||
else:
|
||||
# Does not matter much which tool is set here as force is false and neither tool is available
|
||||
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
|
||||
|
||||
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
|
||||
# Currently, the internet search tool does not support query override
|
||||
@@ -273,12 +311,25 @@ def _get_force_search_settings(
|
||||
else None
|
||||
)
|
||||
|
||||
# Create override_kwargs for the search tool if user_file_ids are provided
|
||||
override_kwargs = None
|
||||
if (user_file_ids or user_folder_ids) and tool_name == SearchTool._NAME:
|
||||
override_kwargs = SearchToolOverrideKwargs(
|
||||
force_no_rerank=False,
|
||||
alternate_db_session=None,
|
||||
retrieved_sections_callback=None,
|
||||
skip_query_analysis=False,
|
||||
user_file_ids=user_file_ids,
|
||||
user_folder_ids=user_folder_ids,
|
||||
)
|
||||
|
||||
if new_msg_req.file_descriptors:
|
||||
# If user has uploaded files they're using, don't run any of the search tools
|
||||
return ForceUseTool(force_use=False, tool_name=tool_name)
|
||||
|
||||
should_force_search = any(
|
||||
[
|
||||
new_msg_req.force_user_file_search,
|
||||
new_msg_req.retrieval_options
|
||||
and new_msg_req.retrieval_options.run_search
|
||||
== OptionalSearchSetting.ALWAYS,
|
||||
@@ -291,15 +342,22 @@ def _get_force_search_settings(
|
||||
if should_force_search:
|
||||
# If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
|
||||
args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
|
||||
return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
|
||||
|
||||
return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
|
||||
return ForceUseTool(
|
||||
force_use=True,
|
||||
tool_name=tool_name,
|
||||
args=args,
|
||||
override_kwargs=override_kwargs,
|
||||
)
|
||||
|
||||
return ForceUseTool(
|
||||
force_use=False, tool_name=tool_name, args=args, override_kwargs=override_kwargs
|
||||
)
|
||||
|
||||
|
||||
ChatPacket = (
|
||||
StreamingError
|
||||
| QADocsResponse
|
||||
| OnyxContexts
|
||||
| LLMRelevanceFilterResponse
|
||||
| FinalUsedContextDocsResponse
|
||||
| ChatMessageDetail
|
||||
@@ -313,6 +371,7 @@ ChatPacket = (
|
||||
| AgenticMessageResponseIDInfo
|
||||
| StreamStopInfo
|
||||
| AgentSearchPacket
|
||||
| UserKnowledgeFilePacket
|
||||
)
|
||||
ChatPacketStream = Iterator[ChatPacket]
|
||||
|
||||
@@ -358,6 +417,10 @@ def stream_chat_message_objects(
|
||||
llm: LLM
|
||||
|
||||
try:
|
||||
# Move these variables inside the try block
|
||||
file_id_to_user_file = {}
|
||||
ordered_user_files = None
|
||||
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
chat_session = get_chat_session_by_id(
|
||||
@@ -537,6 +600,70 @@ def stream_chat_message_objects(
|
||||
)
|
||||
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
|
||||
latest_query_files = [file for file in files if file.file_id in req_file_ids]
|
||||
user_file_ids = new_msg_req.user_file_ids or []
|
||||
user_folder_ids = new_msg_req.user_folder_ids or []
|
||||
|
||||
if persona.user_files:
|
||||
for file in persona.user_files:
|
||||
user_file_ids.append(file.id)
|
||||
if persona.user_folders:
|
||||
for folder in persona.user_folders:
|
||||
user_folder_ids.append(folder.id)
|
||||
|
||||
# Initialize flag for user file search
|
||||
use_search_for_user_files = False
|
||||
|
||||
user_files: list[InMemoryChatFile] | None = None
|
||||
search_for_ordering_only = False
|
||||
user_file_files: list[UserFile] | None = None
|
||||
if user_file_ids or user_folder_ids:
|
||||
# Load user files
|
||||
user_files = load_all_user_files(
|
||||
user_file_ids or [],
|
||||
user_folder_ids or [],
|
||||
db_session,
|
||||
)
|
||||
user_file_files = load_all_user_file_files(
|
||||
user_file_ids or [],
|
||||
user_folder_ids or [],
|
||||
db_session,
|
||||
)
|
||||
# Store mapping of file_id to file for later reordering
|
||||
if user_files:
|
||||
file_id_to_user_file = {file.file_id: file for file in user_files}
|
||||
|
||||
# Calculate token count for the files
|
||||
from onyx.db.user_documents import calculate_user_files_token_count
|
||||
from onyx.chat.prompt_builder.citations_prompt import (
|
||||
compute_max_document_tokens_for_persona,
|
||||
)
|
||||
|
||||
total_tokens = calculate_user_files_token_count(
|
||||
user_file_ids or [],
|
||||
user_folder_ids or [],
|
||||
db_session,
|
||||
)
|
||||
|
||||
# Calculate available tokens for documents based on prompt, user input, etc.
|
||||
available_tokens = compute_max_document_tokens_for_persona(
|
||||
db_session=db_session,
|
||||
persona=persona,
|
||||
actual_user_input=message_text, # Use the actual user message
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Total file tokens: {total_tokens}, Available tokens: {available_tokens}"
|
||||
)
|
||||
|
||||
# ALWAYS use search for user files, but track if we need it for context or just ordering
|
||||
use_search_for_user_files = True
|
||||
# If files are small enough for context, we'll just use search for ordering
|
||||
search_for_ordering_only = total_tokens <= available_tokens
|
||||
|
||||
if search_for_ordering_only:
|
||||
# Add original user files to context since they fit
|
||||
if user_files:
|
||||
latest_query_files.extend(user_files)
|
||||
|
||||
if user_message:
|
||||
attach_files_to_chat_message(
|
||||
@@ -679,8 +806,10 @@ def stream_chat_message_objects(
|
||||
prompt_config=prompt_config,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
user_knowledge_present=bool(user_files or user_folder_ids),
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
use_file_search=new_msg_req.force_user_file_search,
|
||||
search_tool_config=SearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
@@ -710,17 +839,138 @@ def stream_chat_message_objects(
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
force_use_tool = _get_force_search_settings(
|
||||
new_msg_req, tools, user_file_ids, user_folder_ids
|
||||
)
|
||||
|
||||
# Set force_use if user files exceed token limit
|
||||
if use_search_for_user_files:
|
||||
try:
|
||||
# Check if search tool is available in the tools list
|
||||
search_tool_available = any(
|
||||
isinstance(tool, SearchTool) for tool in tools
|
||||
)
|
||||
|
||||
# If no search tool is available, add one
|
||||
if not search_tool_available:
|
||||
logger.info("No search tool available, creating one for user files")
|
||||
# Create a basic search tool config
|
||||
search_tool_config = SearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=retrieval_options or RetrievalDetails(),
|
||||
)
|
||||
|
||||
# Create and add the search tool
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=search_tool_config.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=search_tool_config.document_pruning_config,
|
||||
answer_style_config=search_tool_config.answer_style_config,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
bypass_acl=bypass_acl,
|
||||
)
|
||||
|
||||
# Add the search tool to the tools list
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.info(
|
||||
"Added search tool for user files that exceed token limit"
|
||||
)
|
||||
|
||||
# Now set force_use_tool.force_use to True
|
||||
force_use_tool.force_use = True
|
||||
force_use_tool.tool_name = SearchTool._NAME
|
||||
|
||||
# Set query argument if not already set
|
||||
if not force_use_tool.args:
|
||||
force_use_tool.args = {"query": final_msg.message}
|
||||
|
||||
# Pass the user file IDs to the search tool
|
||||
if user_file_ids or user_folder_ids:
|
||||
# Create a BaseFilters object with user_file_ids
|
||||
if not retrieval_options:
|
||||
retrieval_options = RetrievalDetails()
|
||||
if not retrieval_options.filters:
|
||||
retrieval_options.filters = BaseFilters()
|
||||
|
||||
# Set user file and folder IDs in the filters
|
||||
retrieval_options.filters.user_file_ids = user_file_ids
|
||||
retrieval_options.filters.user_folder_ids = user_folder_ids
|
||||
|
||||
# Create override kwargs for the search tool
|
||||
override_kwargs = SearchToolOverrideKwargs(
|
||||
force_no_rerank=search_for_ordering_only, # Skip reranking for ordering-only
|
||||
alternate_db_session=None,
|
||||
retrieved_sections_callback=None,
|
||||
skip_query_analysis=search_for_ordering_only, # Skip query analysis for ordering-only
|
||||
user_file_ids=user_file_ids,
|
||||
user_folder_ids=user_folder_ids,
|
||||
ordering_only=search_for_ordering_only, # Set ordering_only flag for fast path
|
||||
)
|
||||
|
||||
# Set the override kwargs in the force_use_tool
|
||||
force_use_tool.override_kwargs = override_kwargs
|
||||
|
||||
if search_for_ordering_only:
|
||||
logger.info(
|
||||
"Fast path: Configured search tool with optimized settings for ordering-only"
|
||||
)
|
||||
logger.info(
|
||||
"Fast path: Skipping reranking and query analysis for ordering-only mode"
|
||||
)
|
||||
logger.info(
|
||||
f"Using {len(user_file_ids or [])} files and {len(user_folder_ids or [])} folders"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Configured search tool to use ",
|
||||
f"{len(user_file_ids or [])} files and {len(user_folder_ids or [])} folders",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error configuring search tool for user files: {str(e)}"
|
||||
)
|
||||
use_search_for_user_files = False
|
||||
|
||||
# TODO: unify message history with single message history
|
||||
message_history = [
|
||||
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
|
||||
]
|
||||
if not use_search_for_user_files and user_files:
|
||||
yield UserKnowledgeFilePacket(
|
||||
user_files=[
|
||||
FileDescriptor(
|
||||
id=str(file.file_id), type=ChatFileType.USER_KNOWLEDGE
|
||||
)
|
||||
for file in user_files
|
||||
]
|
||||
)
|
||||
|
||||
if search_for_ordering_only:
|
||||
logger.info(
|
||||
"Performance: Forcing LLMEvaluationType.SKIP to prevent chunk evaluation for ordering-only search"
|
||||
)
|
||||
|
||||
search_request = SearchRequest(
|
||||
query=final_msg.message,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
LLMEvaluationType.SKIP
|
||||
if search_for_ordering_only
|
||||
else (
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
)
|
||||
),
|
||||
human_selected_filters=(
|
||||
retrieval_options.filters if retrieval_options else None
|
||||
@@ -739,7 +989,6 @@ def stream_chat_message_objects(
|
||||
),
|
||||
)
|
||||
|
||||
force_use_tool = _get_force_search_settings(new_msg_req, tools)
|
||||
prompt_builder = AnswerPromptBuilder(
|
||||
user_message=default_build_user_message(
|
||||
user_query=final_msg.message,
|
||||
@@ -808,8 +1057,22 @@ def stream_chat_message_objects(
|
||||
info = info_by_subq[
|
||||
SubQuestionKey(level=level, question_num=level_question_num)
|
||||
]
|
||||
|
||||
# Skip LLM relevance processing entirely for ordering-only mode
|
||||
if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
logger.info(
|
||||
"Fast path: Completely bypassing section relevance processing for ordering-only mode"
|
||||
)
|
||||
# Skip this packet entirely since it would trigger LLM processing
|
||||
continue
|
||||
|
||||
# TODO: don't need to dedupe here when we do it in agent flow
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
if search_for_ordering_only:
|
||||
logger.info(
|
||||
"Fast path: Skipping document deduplication for ordering-only mode"
|
||||
)
|
||||
|
||||
(
|
||||
info.qa_docs_response,
|
||||
info.reference_db_search_docs,
|
||||
@@ -819,16 +1082,91 @@ def stream_chat_message_objects(
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
# Skip deduping completely for ordering-only mode to save time
|
||||
dedupe_docs=(
|
||||
retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False
|
||||
False
|
||||
if search_for_ordering_only
|
||||
else (
|
||||
retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False
|
||||
)
|
||||
),
|
||||
user_files=user_file_files if search_for_ordering_only else [],
|
||||
loaded_user_files=user_files
|
||||
if search_for_ordering_only
|
||||
else [],
|
||||
)
|
||||
|
||||
# If we're using search just for ordering user files
|
||||
if (
|
||||
search_for_ordering_only
|
||||
and user_files
|
||||
and info.qa_docs_response
|
||||
):
|
||||
logger.info(
|
||||
f"ORDERING: Processing search results for ordering {len(user_files)} user files"
|
||||
)
|
||||
import time
|
||||
|
||||
ordering_start = time.time()
|
||||
|
||||
# Extract document order from search results
|
||||
doc_order = []
|
||||
for doc in info.qa_docs_response.top_documents:
|
||||
doc_id = doc.document_id
|
||||
if str(doc_id).startswith("USER_FILE_CONNECTOR__"):
|
||||
file_id = doc_id.replace("USER_FILE_CONNECTOR__", "")
|
||||
if file_id in file_id_to_user_file:
|
||||
doc_order.append(file_id)
|
||||
|
||||
logger.info(
|
||||
f"ORDERING: Found {len(doc_order)} files from search results"
|
||||
)
|
||||
|
||||
# Add any files that weren't in search results at the end
|
||||
missing_files = [
|
||||
f_id
|
||||
for f_id in file_id_to_user_file.keys()
|
||||
if f_id not in doc_order
|
||||
]
|
||||
|
||||
missing_files.extend(doc_order)
|
||||
doc_order = missing_files
|
||||
|
||||
logger.info(
|
||||
f"ORDERING: Added {len(missing_files)} missing files to the end"
|
||||
)
|
||||
|
||||
# Reorder user files based on search results
|
||||
ordered_user_files = [
|
||||
file_id_to_user_file[f_id]
|
||||
for f_id in doc_order
|
||||
if f_id in file_id_to_user_file
|
||||
]
|
||||
|
||||
time.time() - ordering_start
|
||||
|
||||
yield UserKnowledgeFilePacket(
|
||||
user_files=[
|
||||
FileDescriptor(
|
||||
id=str(file.file_id),
|
||||
type=ChatFileType.USER_KNOWLEDGE,
|
||||
)
|
||||
for file in ordered_user_files
|
||||
]
|
||||
)
|
||||
|
||||
yield info.qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
|
||||
if search_for_ordering_only:
|
||||
logger.info(
|
||||
"Performance: Skipping relevance filtering for ordering-only mode"
|
||||
)
|
||||
continue
|
||||
|
||||
if info.reference_db_search_docs is None:
|
||||
logger.warning(
|
||||
"No reference docs found for relevance filtering"
|
||||
@@ -918,8 +1256,6 @@ def stream_chat_message_objects(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
|
||||
yield cast(OnyxContexts, packet.response)
|
||||
|
||||
elif isinstance(packet, StreamStopInfo):
|
||||
if packet.stop_reason == StreamStopReason.FINISHED:
|
||||
@@ -940,7 +1276,7 @@ def stream_chat_message_objects(
|
||||
]
|
||||
info.tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
logger.debug("Reached end of stream")
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
@@ -1022,10 +1358,16 @@ def stream_chat_message_objects(
|
||||
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
|
||||
tool_call=(
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[info.tool_result.tool_name],
|
||||
tool_name=info.tool_result.tool_name,
|
||||
tool_arguments=info.tool_result.tool_args,
|
||||
tool_result=info.tool_result.tool_result,
|
||||
tool_id=tool_name_to_tool_id.get(info.tool_result.tool_name, 0)
|
||||
if info.tool_result
|
||||
else None,
|
||||
tool_name=info.tool_result.tool_name if info.tool_result else None,
|
||||
tool_arguments=info.tool_result.tool_args
|
||||
if info.tool_result
|
||||
else None,
|
||||
tool_result=info.tool_result.tool_result
|
||||
if info.tool_result
|
||||
else None,
|
||||
)
|
||||
if info.tool_result
|
||||
else None
|
||||
@@ -1069,6 +1411,8 @@ def stream_chat_message_objects(
|
||||
prev_message = next_answer_message
|
||||
|
||||
logger.debug("Committing messages")
|
||||
# Explicitly update the timestamp on the chat session
|
||||
update_chat_session_updated_at_timestamp(chat_session_id, db_session)
|
||||
db_session.commit() # actually save user / assistant message
|
||||
|
||||
yield AgenticMessageResponseIDInfo(agentic_message_ids=agentic_message_ids)
|
||||
|
||||
@@ -19,6 +19,7 @@ def translate_onyx_msg_to_langchain(
|
||||
# attached. Just ignore them for now.
|
||||
if not isinstance(msg, ChatMessage):
|
||||
files = msg.files
|
||||
|
||||
content = build_content_with_imgs(
|
||||
msg.message, files, message_type=msg.message_type, exclude_images=exclude_images
|
||||
)
|
||||
|
||||
@@ -153,6 +153,8 @@ def _apply_pruning(
|
||||
# remove docs that are explicitly marked as not for QA
|
||||
sections = _remove_sections_to_ignore(sections=sections)
|
||||
|
||||
section_idx_token_count: dict[int, int] = {}
|
||||
|
||||
final_section_ind = None
|
||||
total_tokens = 0
|
||||
for ind, section in enumerate(sections):
|
||||
@@ -202,10 +204,20 @@ def _apply_pruning(
|
||||
section_token_count = DOC_EMBEDDING_CONTEXT_SIZE
|
||||
|
||||
total_tokens += section_token_count
|
||||
section_idx_token_count[ind] = section_token_count
|
||||
|
||||
if total_tokens > token_limit:
|
||||
final_section_ind = ind
|
||||
break
|
||||
|
||||
try:
|
||||
logger.debug(f"Number of documents after pruning: {ind + 1}")
|
||||
logger.debug("Number of tokens per document (pruned):")
|
||||
for x, y in section_idx_token_count.items():
|
||||
logger.debug(f"{x + 1}: {y}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging prune statistics: {e}")
|
||||
|
||||
if final_section_ind is not None:
|
||||
if is_manually_selected_docs or use_sections:
|
||||
if final_section_ind != len(sections) - 1:
|
||||
@@ -301,6 +313,10 @@ def prune_sections(
|
||||
|
||||
|
||||
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
|
||||
assert (
|
||||
len(set([chunk.document_id for chunk in chunks])) == 1
|
||||
), "One distinct document must be passed into merge_doc_chunks"
|
||||
|
||||
# Assuming there are no duplicates by this point
|
||||
sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)
|
||||
|
||||
@@ -358,6 +374,26 @@ def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
try:
|
||||
num_original_sections = len(sections)
|
||||
num_original_document_ids = len(
|
||||
set([section.center_chunk.document_id for section in sections])
|
||||
)
|
||||
num_merged_sections = len(new_sections)
|
||||
num_merged_document_ids = len(
|
||||
set([section.center_chunk.document_id for section in new_sections])
|
||||
)
|
||||
logger.debug(
|
||||
f"Merged {num_original_sections} sections from {num_original_document_ids} documents "
|
||||
f"into {num_merged_sections} new sections in {num_merged_document_ids} documents"
|
||||
)
|
||||
|
||||
logger.debug("Number of chunks per document (new ranking):")
|
||||
for x, y in enumerate(new_sections):
|
||||
logger.debug(f"{x + 1}: {len(y.chunks)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging merge statistics: {e}")
|
||||
|
||||
return new_sections
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from collections.abc import Sequence
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import OnyxContext
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
|
||||
|
||||
@@ -12,7 +11,7 @@ class DocumentIdOrderMapping(BaseModel):
|
||||
|
||||
|
||||
def map_document_id_order(
|
||||
chunks: Sequence[InferenceChunk | LlmDoc | OnyxContext], one_indexed: bool = True
|
||||
chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
|
||||
) -> DocumentIdOrderMapping:
|
||||
order_mapping = {}
|
||||
current = 1 if one_indexed else 0
|
||||
|
||||
@@ -180,6 +180,10 @@ def get_tool_call_for_non_tool_calling_llm_impl(
|
||||
if tool_args is None:
|
||||
raise RuntimeError(f"Tool '{tool.name}' did not return args")
|
||||
|
||||
# If we have override_kwargs, add them to the tool_args
|
||||
if force_use_tool.override_kwargs is not None:
|
||||
tool_args["override_kwargs"] = force_use_tool.override_kwargs
|
||||
|
||||
return (tool, tool_args)
|
||||
else:
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
import urllib.parse
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import cast
|
||||
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
@@ -33,6 +35,10 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
|
||||
) # 1 day
|
||||
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
|
||||
|
||||
# Controls whether to allow admin query history reports with:
|
||||
# 1. associated user emails
|
||||
# 2. anonymized user emails
|
||||
# 3. no queries
|
||||
ONYX_QUERY_HISTORY_TYPE = QueryHistoryType(
|
||||
(os.environ.get("ONYX_QUERY_HISTORY_TYPE") or QueryHistoryType.NORMAL.value).lower()
|
||||
)
|
||||
@@ -153,10 +159,9 @@ VESPA_CLOUD_CERT_PATH = os.environ.get("VESPA_CLOUD_CERT_PATH")
|
||||
VESPA_CLOUD_KEY_PATH = os.environ.get("VESPA_CLOUD_KEY_PATH")
|
||||
|
||||
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
|
||||
try:
|
||||
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
|
||||
except ValueError:
|
||||
INDEX_BATCH_SIZE = 16
|
||||
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE") or 16)
|
||||
|
||||
MAX_DRIVE_WORKERS = int(os.environ.get("MAX_DRIVE_WORKERS", 4))
|
||||
|
||||
# Below are intended to match the env variables names used by the official postgres docker image
|
||||
# https://hub.docker.com/_/postgres
|
||||
@@ -165,7 +170,7 @@ POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
|
||||
POSTGRES_PASSWORD = urllib.parse.quote_plus(
|
||||
os.environ.get("POSTGRES_PASSWORD") or "password"
|
||||
)
|
||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "127.0.0.1"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
|
||||
@@ -341,8 +346,8 @@ HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY = os.environ.get(
|
||||
HtmlBasedConnectorTransformLinksStrategy.STRIP,
|
||||
)
|
||||
|
||||
NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP = (
|
||||
os.environ.get("NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
|
||||
NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP = (
|
||||
os.environ.get("NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
@@ -380,10 +385,27 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
|
||||
# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
|
||||
# https://jira.atlassian.com/browse/CONFCLOUD-69670
|
||||
|
||||
|
||||
def get_current_tz_offset() -> int:
|
||||
# datetime now() gets local time, datetime.now(timezone.utc) gets UTC time.
|
||||
# remove tzinfo to compare non-timezone-aware objects.
|
||||
time_diff = datetime.now() - datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
return round(time_diff.total_seconds() / 3600)
|
||||
|
||||
|
||||
# enter as a floating point offset from UTC in hours (-24 < val < 24)
|
||||
# this will be applied globally, so it probably makes sense to transition this to per
|
||||
# connector as some point.
|
||||
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
|
||||
# For the default value, we assume that the user's local timezone is more likely to be
|
||||
# correct (i.e. the configured user's timezone or the default server one) than UTC.
|
||||
# https://developer.atlassian.com/cloud/confluence/cql-fields/#created
|
||||
CONFLUENCE_TIMEZONE_OFFSET = float(
|
||||
os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset())
|
||||
)
|
||||
|
||||
GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
|
||||
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
|
||||
)
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
@@ -414,6 +436,9 @@ EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
|
||||
LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
|
||||
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
|
||||
|
||||
# Slack specific configs
|
||||
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 8)
|
||||
|
||||
DASK_JOB_CLIENT_ENABLED = (
|
||||
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
@@ -470,6 +495,11 @@ NUM_SECONDARY_INDEXING_WORKERS = int(
|
||||
ENABLE_MULTIPASS_INDEXING = (
|
||||
os.environ.get("ENABLE_MULTIPASS_INDEXING", "").lower() == "true"
|
||||
)
|
||||
# Enable contextual retrieval
|
||||
ENABLE_CONTEXTUAL_RAG = os.environ.get("ENABLE_CONTEXTUAL_RAG", "").lower() == "true"
|
||||
|
||||
DEFAULT_CONTEXTUAL_RAG_LLM_NAME = "gpt-4o-mini"
|
||||
DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER = "DevEnvPresetOpenAI"
|
||||
# Finer grained chunking for more detail retention
|
||||
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
|
||||
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
|
||||
@@ -511,6 +541,17 @@ MAX_FILE_SIZE_BYTES = int(
|
||||
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
|
||||
) # 2GB in bytes
|
||||
|
||||
# Use document summary for contextual rag
|
||||
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
|
||||
# Use chunk summary for contextual rag
|
||||
USE_CHUNK_SUMMARY = os.environ.get("USE_CHUNK_SUMMARY", "true").lower() == "true"
|
||||
# Average summary embeddings for contextual rag (not yet implemented)
|
||||
AVERAGE_SUMMARY_EMBEDDINGS = (
|
||||
os.environ.get("AVERAGE_SUMMARY_EMBEDDINGS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
MAX_TOKENS_FOR_FULL_INCLUSION = 4096
|
||||
|
||||
#####
|
||||
# Miscellaneous
|
||||
#####
|
||||
@@ -667,3 +708,7 @@ IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
|
||||
"IMAGE_ANALYSIS_SYSTEM_PROMPT",
|
||||
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
DISABLE_AUTO_AUTH_REFRESH = (
|
||||
os.environ.get("DISABLE_AUTO_AUTH_REFRESH", "").lower() == "true"
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
INPUT_PROMPT_YAML = "./onyx/seeding/input_prompts.yaml"
|
||||
PROMPTS_YAML = "./onyx/seeding/prompts.yaml"
|
||||
PERSONAS_YAML = "./onyx/seeding/personas.yaml"
|
||||
|
||||
USER_FOLDERS_YAML = "./onyx/seeding/user_folders.yaml"
|
||||
NUM_RETURNED_HITS = 50
|
||||
# Used for LLM filtering and reranking
|
||||
# We want this to be approximately the number of results we want to show on the first page
|
||||
|
||||
@@ -3,6 +3,10 @@ import socket
|
||||
from enum import auto
|
||||
from enum import Enum
|
||||
|
||||
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
|
||||
ONYX_SLACK_URL = "https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA"
|
||||
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
|
||||
|
||||
SOURCE_TYPE = "source_type"
|
||||
# stored in the `metadata` of a chunk. Used to signify that this chunk should
|
||||
# not be used for QA. For example, Google Drive file types which can't be parsed
|
||||
@@ -40,6 +44,7 @@ DISABLED_GEN_AI_MSG = (
|
||||
"You can still use Onyx as a search engine."
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_PERSONA_ID = 0
|
||||
|
||||
DEFAULT_CC_PAIR_ID = 1
|
||||
@@ -97,6 +102,8 @@ CELERY_GENERIC_BEAT_LOCK_TIMEOUT = 120
|
||||
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
|
||||
|
||||
CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT = 120
|
||||
|
||||
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
|
||||
|
||||
|
||||
@@ -174,6 +181,7 @@ class DocumentSource(str, Enum):
|
||||
FIREFLIES = "fireflies"
|
||||
EGNYTE = "egnyte"
|
||||
AIRTABLE = "airtable"
|
||||
HIGHSPOT = "highspot"
|
||||
|
||||
# Special case just for integration tests
|
||||
MOCK_CONNECTOR = "mock_connector"
|
||||
@@ -263,6 +271,7 @@ class FileOrigin(str, Enum):
|
||||
CONNECTOR = "connector"
|
||||
GENERATED_REPORT = "generated_report"
|
||||
INDEXING_CHECKPOINT = "indexing_checkpoint"
|
||||
PLAINTEXT_CACHE = "plaintext_cache"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
@@ -303,6 +312,7 @@ class OnyxCeleryQueues:
|
||||
|
||||
# Indexing queue
|
||||
CONNECTOR_INDEXING = "connector_indexing"
|
||||
USER_FILES_INDEXING = "user_files_indexing"
|
||||
|
||||
# Monitoring queue
|
||||
MONITORING = "monitoring"
|
||||
@@ -321,6 +331,7 @@ class OnyxRedisLocks:
|
||||
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
|
||||
"da_lock:check_connector_external_group_sync_beat"
|
||||
)
|
||||
CHECK_USER_FILE_FOLDER_SYNC_BEAT_LOCK = "da_lock:check_user_file_folder_sync_beat"
|
||||
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
|
||||
CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
|
||||
PRE_PROVISION_TENANT_LOCK = "da_lock:pre_provision_tenant"
|
||||
@@ -376,6 +387,7 @@ ONYX_CLOUD_TENANT_ID = "cloud"
|
||||
|
||||
# the redis namespace for runtime variables
|
||||
ONYX_CLOUD_REDIS_RUNTIME = "runtime"
|
||||
CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT = 600
|
||||
|
||||
|
||||
class OnyxCeleryTask:
|
||||
@@ -388,6 +400,10 @@ class OnyxCeleryTask:
|
||||
)
|
||||
CHECK_AVAILABLE_TENANTS = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_available_tenants"
|
||||
|
||||
# Tenant pre-provisioning
|
||||
PRE_PROVISION_TENANT = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_pre_provision_tenant"
|
||||
UPDATE_USER_FILE_FOLDER_METADATA = "update_user_file_folder_metadata"
|
||||
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
CHECK_FOR_INDEXING = "check_for_indexing"
|
||||
@@ -395,6 +411,7 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
|
||||
CHECK_FOR_USER_FILE_FOLDER_SYNC = "check_for_user_file_folder_sync"
|
||||
|
||||
# Connector checkpoint cleanup
|
||||
CHECK_FOR_CHECKPOINT_CLEANUP = "check_for_checkpoint_cleanup"
|
||||
@@ -402,9 +419,7 @@ class OnyxCeleryTask:
|
||||
|
||||
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
|
||||
MONITOR_CELERY_QUEUES = "monitor_celery_queues"
|
||||
|
||||
# Tenant pre-provisioning
|
||||
PRE_PROVISION_TENANT = "pre_provision_tenant"
|
||||
MONITOR_PROCESS_MEMORY = "monitor_process_memory"
|
||||
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
|
||||
@@ -87,7 +87,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
credentials.get(key)
|
||||
for key in ["aws_access_key_id", "aws_secret_access_key"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Google Cloud Storage")
|
||||
raise ConnectorMissingCredentialError("Amazon S3")
|
||||
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=credentials["aws_access_key_id"],
|
||||
|
||||
@@ -65,19 +65,7 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
|
||||
|
||||
_SLIM_DOC_BATCH_SIZE = 5000
|
||||
|
||||
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
|
||||
"gif",
|
||||
"mp4",
|
||||
"mov",
|
||||
"mp3",
|
||||
"wav",
|
||||
]
|
||||
_FULL_EXTENSION_FILTER_STRING = "".join(
|
||||
[
|
||||
f" and title!~'*.{extension}'"
|
||||
for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
|
||||
]
|
||||
)
|
||||
ONE_HOUR = 3600
|
||||
|
||||
|
||||
class ConfluenceConnector(
|
||||
@@ -114,6 +102,7 @@ class ConfluenceConnector(
|
||||
self.timezone_offset = timezone_offset
|
||||
self._confluence_client: OnyxConfluence | None = None
|
||||
self._fetched_titles: set[str] = set()
|
||||
self.allow_images = False
|
||||
|
||||
# Remove trailing slash from wiki_base if present
|
||||
self.wiki_base = wiki_base.rstrip("/")
|
||||
@@ -158,6 +147,9 @@ class ConfluenceConnector(
|
||||
"max_backoff_seconds": 60,
|
||||
}
|
||||
|
||||
def set_allow_images(self, value: bool) -> None:
|
||||
self.allow_images = value
|
||||
|
||||
@property
|
||||
def confluence_client(self) -> OnyxConfluence:
|
||||
if self._confluence_client is None:
|
||||
@@ -203,7 +195,6 @@ class ConfluenceConnector(
|
||||
def _construct_attachment_query(self, confluence_page_id: str) -> str:
|
||||
attachment_query = f"type=attachment and container='{confluence_page_id}'"
|
||||
attachment_query += self.cql_label_filter
|
||||
attachment_query += _FULL_EXTENSION_FILTER_STRING
|
||||
return attachment_query
|
||||
|
||||
def _get_comment_string_for_page_id(self, page_id: str) -> str:
|
||||
@@ -233,7 +224,9 @@ class ConfluenceConnector(
|
||||
# Extract basic page information
|
||||
page_id = page["id"]
|
||||
page_title = page["title"]
|
||||
page_url = f"{self.wiki_base}{page['_links']['webui']}"
|
||||
page_url = build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
|
||||
# Get the page content
|
||||
page_content = extract_text_from_confluence_html(
|
||||
@@ -264,6 +257,7 @@ class ConfluenceConnector(
|
||||
self.confluence_client,
|
||||
attachment,
|
||||
page_id,
|
||||
self.allow_images,
|
||||
)
|
||||
|
||||
if result and result.text:
|
||||
@@ -304,13 +298,14 @@ class ConfluenceConnector(
|
||||
if "version" in page and "by" in page["version"]:
|
||||
author = page["version"]["by"]
|
||||
display_name = author.get("displayName", "Unknown")
|
||||
primary_owners.append(BasicExpertInfo(display_name=display_name))
|
||||
email = author.get("email", "unknown@domain.invalid")
|
||||
primary_owners.append(
|
||||
BasicExpertInfo(display_name=display_name, email=email)
|
||||
)
|
||||
|
||||
# Create the document
|
||||
return Document(
|
||||
id=build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
),
|
||||
id=page_url,
|
||||
sections=sections,
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=page_title,
|
||||
@@ -364,15 +359,18 @@ class ConfluenceConnector(
|
||||
if not validate_attachment_filetype(
|
||||
attachment,
|
||||
):
|
||||
logger.info(f"Skipping attachment: {attachment['title']}")
|
||||
continue
|
||||
|
||||
logger.info(f"Processing attachment: {attachment['title']}")
|
||||
|
||||
# Attempt to get textual content or image summarization:
|
||||
try:
|
||||
logger.info(f"Processing attachment: {attachment['title']}")
|
||||
response = convert_attachment_to_content(
|
||||
confluence_client=self.confluence_client,
|
||||
attachment=attachment,
|
||||
page_id=page["id"],
|
||||
allow_images=self.allow_images,
|
||||
)
|
||||
if response is None:
|
||||
continue
|
||||
@@ -420,7 +418,17 @@ class ConfluenceConnector(
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
return self._fetch_document_batches(start, end)
|
||||
try:
|
||||
return self._fetch_document_batches(start, end)
|
||||
except Exception as e:
|
||||
if "field 'updated' is invalid" in str(e) and start is not None:
|
||||
logger.warning(
|
||||
"Confluence says we provided an invalid 'updated' field. This may indicate"
|
||||
"a real issue, but can also appear during edge cases like daylight"
|
||||
f"savings time changes. Retrying with a 1 hour offset. Error: {e}"
|
||||
)
|
||||
return self._fetch_document_batches(start - ONE_HOUR, end)
|
||||
raise
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
|
||||
@@ -498,10 +498,12 @@ class OnyxConfluence:
|
||||
new_start = get_start_param_from_url(url_suffix)
|
||||
previous_start = get_start_param_from_url(old_url_suffix)
|
||||
if new_start - previous_start > len(results):
|
||||
logger.warning(
|
||||
logger.debug(
|
||||
f"Start was updated by more than the amount of results "
|
||||
f"retrieved. This is a bug with Confluence. Start: {new_start}, "
|
||||
f"Previous Start: {previous_start}, Len Results: {len(results)}."
|
||||
f"retrieved for `{url_suffix}`. This is a bug with Confluence, "
|
||||
"but we have logic to work around it - don't worry this isn't"
|
||||
f" causing an issue. Start: {new_start}, Previous Start: "
|
||||
f"{previous_start}, Len Results: {len(results)}."
|
||||
)
|
||||
|
||||
# Update the url_suffix to use the adjusted start
|
||||
|
||||
@@ -112,6 +112,7 @@ def process_attachment(
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
parent_content_id: str | None,
|
||||
allow_images: bool,
|
||||
) -> AttachmentProcessingResult:
|
||||
"""
|
||||
Processes a Confluence attachment. If it's a document, extracts text,
|
||||
@@ -119,7 +120,7 @@ def process_attachment(
|
||||
"""
|
||||
try:
|
||||
# Get the media type from the attachment metadata
|
||||
media_type = attachment.get("metadata", {}).get("mediaType", "")
|
||||
media_type: str = attachment.get("metadata", {}).get("mediaType", "")
|
||||
# Validate the attachment type
|
||||
if not validate_attachment_filetype(attachment):
|
||||
return AttachmentProcessingResult(
|
||||
@@ -138,7 +139,14 @@ def process_attachment(
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
|
||||
if not media_type.startswith("image/"):
|
||||
if media_type.startswith("image/"):
|
||||
if not allow_images:
|
||||
return AttachmentProcessingResult(
|
||||
text=None,
|
||||
file_name=None,
|
||||
error="Image downloading is not enabled",
|
||||
)
|
||||
else:
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {attachment_link} due to size. "
|
||||
@@ -294,6 +302,7 @@ def convert_attachment_to_content(
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
page_id: str,
|
||||
allow_images: bool,
|
||||
) -> tuple[str | None, str | None] | None:
|
||||
"""
|
||||
Facade function which:
|
||||
@@ -309,7 +318,7 @@ def convert_attachment_to_content(
|
||||
)
|
||||
return None
|
||||
|
||||
result = process_attachment(confluence_client, attachment, page_id)
|
||||
result = process_attachment(confluence_client, attachment, page_id, allow_images)
|
||||
if result.error is not None:
|
||||
logger.warning(
|
||||
f"Attachment {attachment['title']} encountered error: {result.error}"
|
||||
|
||||
@@ -2,6 +2,8 @@ import sys
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
@@ -19,8 +21,10 @@ logger = setup_logger()
|
||||
|
||||
TimeRange = tuple[datetime, datetime]
|
||||
|
||||
CT = TypeVar("CT", bound=ConnectorCheckpoint)
|
||||
|
||||
class CheckpointOutputWrapper:
|
||||
|
||||
class CheckpointOutputWrapper(Generic[CT]):
|
||||
"""
|
||||
Wraps a CheckpointOutput generator to give things back in a more digestible format.
|
||||
The connector format is easier for the connector implementor (e.g. it enforces exactly
|
||||
@@ -29,20 +33,20 @@ class CheckpointOutputWrapper:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.next_checkpoint: ConnectorCheckpoint | None = None
|
||||
self.next_checkpoint: CT | None = None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
checkpoint_connector_generator: CheckpointOutput,
|
||||
checkpoint_connector_generator: CheckpointOutput[CT],
|
||||
) -> Generator[
|
||||
tuple[Document | None, ConnectorFailure | None, ConnectorCheckpoint | None],
|
||||
tuple[Document | None, ConnectorFailure | None, CT | None],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
# grabs the final return value and stores it in the `next_checkpoint` variable
|
||||
def _inner_wrapper(
|
||||
checkpoint_connector_generator: CheckpointOutput,
|
||||
) -> CheckpointOutput:
|
||||
checkpoint_connector_generator: CheckpointOutput[CT],
|
||||
) -> CheckpointOutput[CT]:
|
||||
self.next_checkpoint = yield from checkpoint_connector_generator
|
||||
return self.next_checkpoint # not used
|
||||
|
||||
@@ -64,7 +68,7 @@ class CheckpointOutputWrapper:
|
||||
yield None, None, self.next_checkpoint
|
||||
|
||||
|
||||
class ConnectorRunner:
|
||||
class ConnectorRunner(Generic[CT]):
|
||||
"""
|
||||
Handles:
|
||||
- Batching
|
||||
@@ -85,11 +89,9 @@ class ConnectorRunner:
|
||||
self.doc_batch: list[Document] = []
|
||||
|
||||
def run(
|
||||
self, checkpoint: ConnectorCheckpoint
|
||||
self, checkpoint: CT
|
||||
) -> Generator[
|
||||
tuple[
|
||||
list[Document] | None, ConnectorFailure | None, ConnectorCheckpoint | None
|
||||
],
|
||||
tuple[list[Document] | None, ConnectorFailure | None, CT | None],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
@@ -105,9 +107,9 @@ class ConnectorRunner:
|
||||
end=self.time_range[1].timestamp(),
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
next_checkpoint: ConnectorCheckpoint | None = None
|
||||
next_checkpoint: CT | None = None
|
||||
# this is guaranteed to always run at least once with next_checkpoint being non-None
|
||||
for document, failure, next_checkpoint in CheckpointOutputWrapper()(
|
||||
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
|
||||
checkpoint_connector_generator
|
||||
):
|
||||
if document is not None:
|
||||
@@ -132,7 +134,7 @@ class ConnectorRunner:
|
||||
)
|
||||
|
||||
else:
|
||||
finished_checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
finished_checkpoint = self.connector.build_dummy_checkpoint()
|
||||
finished_checkpoint.has_more = False
|
||||
|
||||
if isinstance(self.connector, PollConnector):
|
||||
|
||||
@@ -28,8 +28,9 @@ from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import detect_encoding
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_accepted_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_text_file_extension
|
||||
from onyx.file_processing.extract_file_text import is_valid_file_ext
|
||||
from onyx.file_processing.extract_file_text import OnyxExtensionType
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import request_with_retries
|
||||
@@ -69,7 +70,9 @@ def _process_egnyte_file(
|
||||
|
||||
file_name = file_metadata["name"]
|
||||
extension = get_file_ext(file_name)
|
||||
if not is_valid_file_ext(extension):
|
||||
if not is_accepted_file_ext(
|
||||
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
|
||||
):
|
||||
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
|
||||
return None
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
from onyx.connectors.axero.connector import AxeroConnector
|
||||
@@ -30,6 +31,7 @@ from onyx.connectors.gong.connector import GongConnector
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_site.connector import GoogleSitesConnector
|
||||
from onyx.connectors.guru.connector import GuruConnector
|
||||
from onyx.connectors.highspot.connector import HighspotConnector
|
||||
from onyx.connectors.hubspot.connector import HubSpotConnector
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
@@ -117,6 +119,7 @@ def identify_connector_class(
|
||||
DocumentSource.FIREFLIES: FirefliesConnector,
|
||||
DocumentSource.EGNYTE: EgnyteConnector,
|
||||
DocumentSource.AIRTABLE: AirtableConnector,
|
||||
DocumentSource.HIGHSPOT: HighspotConnector,
|
||||
# just for integration tests
|
||||
DocumentSource.MOCK_CONNECTOR: MockConnector,
|
||||
}
|
||||
@@ -182,6 +185,8 @@ def instantiate_connector(
|
||||
if new_credentials is not None:
|
||||
backend_update_credential_json(credential, new_credentials, db_session)
|
||||
|
||||
connector.set_allow_images(get_image_extraction_and_analysis_enabled())
|
||||
|
||||
return connector
|
||||
|
||||
|
||||
|
||||
@@ -22,8 +22,9 @@ from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_valid_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_accepted_file_ext
|
||||
from onyx.file_processing.extract_file_text import load_files_from_zip
|
||||
from onyx.file_processing.extract_file_text import OnyxExtensionType
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -51,7 +52,7 @@ def _read_files_and_metadata(
|
||||
file_content, ignore_dirs=True
|
||||
):
|
||||
yield os.path.join(directory_path, file_info.filename), subfile, metadata
|
||||
elif is_valid_file_ext(extension):
|
||||
elif is_accepted_file_ext(extension, OnyxExtensionType.All):
|
||||
yield file_name, file_content, metadata
|
||||
else:
|
||||
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
|
||||
@@ -122,7 +123,7 @@ def _process_file(
|
||||
logger.warning(f"No file record found for '{file_name}' in PG; skipping.")
|
||||
return []
|
||||
|
||||
if not is_valid_file_ext(extension):
|
||||
if not is_accepted_file_ext(extension, OnyxExtensionType.All):
|
||||
logger.warning(
|
||||
f"Skipping file '{file_name}' with unrecognized extension '{extension}'"
|
||||
)
|
||||
@@ -219,24 +220,34 @@ def _process_file(
|
||||
|
||||
# 2) Otherwise: text-based approach. Possibly with embedded images.
|
||||
file.seek(0)
|
||||
text_content = ""
|
||||
embedded_images: list[tuple[bytes, str]] = []
|
||||
|
||||
# Extract text and images from the file
|
||||
text_content, embedded_images = extract_text_and_images(
|
||||
extraction_result = extract_text_and_images(
|
||||
file=file,
|
||||
file_name=file_name,
|
||||
pdf_pass=pdf_pass,
|
||||
)
|
||||
|
||||
# Merge file-specific metadata (from file content) with provided metadata
|
||||
if extraction_result.metadata:
|
||||
logger.debug(
|
||||
f"Found file-specific metadata for {file_name}: {extraction_result.metadata}"
|
||||
)
|
||||
metadata.update(extraction_result.metadata)
|
||||
|
||||
# Build sections: first the text as a single Section
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
link_in_meta = metadata.get("link")
|
||||
if text_content.strip():
|
||||
sections.append(TextSection(link=link_in_meta, text=text_content.strip()))
|
||||
if extraction_result.text_content.strip():
|
||||
logger.debug(f"Creating TextSection for {file_name} with link: {link_in_meta}")
|
||||
sections.append(
|
||||
TextSection(link=link_in_meta, text=extraction_result.text_content.strip())
|
||||
)
|
||||
|
||||
# Then any extracted images from docx, etc.
|
||||
for idx, (img_data, img_name) in enumerate(embedded_images, start=1):
|
||||
for idx, (img_data, img_name) in enumerate(
|
||||
extraction_result.embedded_images, start=1
|
||||
):
|
||||
# Store each embedded image as a separate file in PGFileStore
|
||||
# and create a section with the image reference
|
||||
try:
|
||||
|
||||
@@ -45,6 +45,8 @@ _FIREFLIES_API_QUERY = """
|
||||
}
|
||||
"""
|
||||
|
||||
ONE_MINUTE = 60
|
||||
|
||||
|
||||
def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
sections: List[TextSection] = []
|
||||
@@ -106,6 +108,8 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
)
|
||||
|
||||
|
||||
# If not all transcripts are being indexed, try using a more-recently-generated
|
||||
# API key.
|
||||
class FirefliesConnector(PollConnector, LoadConnector):
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
@@ -191,6 +195,9 @@ class FirefliesConnector(PollConnector, LoadConnector):
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
# add some leeway to account for any timezone funkiness and/or bad handling
|
||||
# of start time on the Fireflies side
|
||||
start = max(0, start - ONE_MINUTE)
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%dT%H:%M:%S.000Z"
|
||||
)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import copy
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -13,26 +15,30 @@ from github.GithubException import GithubException
|
||||
from github.Issue import Issue
|
||||
from github.PaginatedList import PaginatedList
|
||||
from github.PullRequest import PullRequest
|
||||
from github.Requester import Requester
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import ConnectorCheckpoint
|
||||
from onyx.connectors.interfaces import ConnectorFailure
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.batching import batch_generator
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
ITEMS_PER_PAGE = 100
|
||||
|
||||
_MAX_NUM_RATE_LIMIT_RETRIES = 5
|
||||
|
||||
@@ -48,7 +54,7 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
|
||||
|
||||
def _get_batch_rate_limited(
|
||||
git_objs: PaginatedList, page_num: int, github_client: Github, attempt_num: int = 0
|
||||
) -> list[Any]:
|
||||
) -> list[PullRequest | Issue]:
|
||||
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
|
||||
raise RuntimeError(
|
||||
"Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
|
||||
@@ -69,21 +75,6 @@ def _get_batch_rate_limited(
|
||||
)
|
||||
|
||||
|
||||
def _batch_github_objects(
|
||||
git_objs: PaginatedList, github_client: Github, batch_size: int
|
||||
) -> Iterator[list[Any]]:
|
||||
page_num = 0
|
||||
while True:
|
||||
batch = _get_batch_rate_limited(git_objs, page_num, github_client)
|
||||
page_num += 1
|
||||
|
||||
if not batch:
|
||||
break
|
||||
|
||||
for mini_batch in batch_generator(batch, batch_size=batch_size):
|
||||
yield mini_batch
|
||||
|
||||
|
||||
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
|
||||
return Document(
|
||||
id=pull_request.html_url,
|
||||
@@ -95,7 +86,9 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
|
||||
# updated_at is UTC time but is timezone unaware, explicitly add UTC
|
||||
# as there is logic in indexing to prevent wrong timestamped docs
|
||||
# due to local time discrepancies with UTC
|
||||
doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc),
|
||||
doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc)
|
||||
if pull_request.updated_at
|
||||
else None,
|
||||
metadata={
|
||||
"merged": str(pull_request.merged),
|
||||
"state": pull_request.state,
|
||||
@@ -122,31 +115,58 @@ def _convert_issue_to_document(issue: Issue) -> Document:
|
||||
)
|
||||
|
||||
|
||||
class GithubConnector(LoadConnector, PollConnector):
|
||||
class SerializedRepository(BaseModel):
|
||||
# id is part of the raw_data as well, just pulled out for convenience
|
||||
id: int
|
||||
headers: dict[str, str | int]
|
||||
raw_data: dict[str, Any]
|
||||
|
||||
def to_Repository(self, requester: Requester) -> Repository.Repository:
|
||||
return Repository.Repository(
|
||||
requester, self.headers, self.raw_data, completed=True
|
||||
)
|
||||
|
||||
|
||||
class GithubConnectorStage(Enum):
|
||||
START = "start"
|
||||
PRS = "prs"
|
||||
ISSUES = "issues"
|
||||
|
||||
|
||||
class GithubConnectorCheckpoint(ConnectorCheckpoint):
|
||||
stage: GithubConnectorStage
|
||||
curr_page: int
|
||||
|
||||
cached_repo_ids: list[int] | None = None
|
||||
cached_repo: SerializedRepository | None = None
|
||||
|
||||
|
||||
class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
|
||||
def __init__(
|
||||
self,
|
||||
repo_owner: str,
|
||||
repositories: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
state_filter: str = "all",
|
||||
include_prs: bool = True,
|
||||
include_issues: bool = False,
|
||||
) -> None:
|
||||
self.repo_owner = repo_owner
|
||||
self.repositories = repositories
|
||||
self.batch_size = batch_size
|
||||
self.state_filter = state_filter
|
||||
self.include_prs = include_prs
|
||||
self.include_issues = include_issues
|
||||
self.github_client: Github | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# defaults to 30 items per page, can be set to as high as 100
|
||||
self.github_client = (
|
||||
Github(
|
||||
credentials["github_access_token"], base_url=GITHUB_CONNECTOR_BASE_URL
|
||||
credentials["github_access_token"],
|
||||
base_url=GITHUB_CONNECTOR_BASE_URL,
|
||||
per_page=ITEMS_PER_PAGE,
|
||||
)
|
||||
if GITHUB_CONNECTOR_BASE_URL
|
||||
else Github(credentials["github_access_token"])
|
||||
else Github(credentials["github_access_token"], per_page=ITEMS_PER_PAGE)
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -217,85 +237,212 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
return self._get_all_repos(github_client, attempt_num + 1)
|
||||
|
||||
def _fetch_from_github(
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
self,
|
||||
checkpoint: GithubConnectorCheckpoint,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
|
||||
if self.github_client is None:
|
||||
raise ConnectorMissingCredentialError("GitHub")
|
||||
|
||||
repos = []
|
||||
if self.repositories:
|
||||
if "," in self.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = self._get_github_repos(self.github_client)
|
||||
checkpoint = copy.deepcopy(checkpoint)
|
||||
|
||||
# First run of the connector, fetch all repos and store in checkpoint
|
||||
if checkpoint.cached_repo_ids is None:
|
||||
repos = []
|
||||
if self.repositories:
|
||||
if "," in self.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = self._get_github_repos(self.github_client)
|
||||
else:
|
||||
# Single repository (backward compatibility)
|
||||
repos = [self._get_github_repo(self.github_client)]
|
||||
else:
|
||||
# Single repository (backward compatibility)
|
||||
repos = [self._get_github_repo(self.github_client)]
|
||||
else:
|
||||
# All repositories
|
||||
repos = self._get_all_repos(self.github_client)
|
||||
# All repositories
|
||||
repos = self._get_all_repos(self.github_client)
|
||||
if not repos:
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
|
||||
|
||||
for repo in repos:
|
||||
if self.include_prs:
|
||||
logger.info(f"Fetching PRs for repo: {repo.name}")
|
||||
pull_requests = repo.get_pulls(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
checkpoint.cached_repo_ids = sorted([repo.id for repo in repos])
|
||||
checkpoint.cached_repo = SerializedRepository(
|
||||
id=checkpoint.cached_repo_ids[0],
|
||||
headers=repos[0].raw_headers,
|
||||
raw_data=repos[0].raw_data,
|
||||
)
|
||||
checkpoint.stage = GithubConnectorStage.PRS
|
||||
checkpoint.curr_page = 0
|
||||
# save checkpoint with repo ids retrieved
|
||||
return checkpoint
|
||||
|
||||
for pr_batch in _batch_github_objects(
|
||||
pull_requests, self.github_client, self.batch_size
|
||||
assert checkpoint.cached_repo is not None, "No repo saved in checkpoint"
|
||||
|
||||
# Try to access the requester - different PyGithub versions may use different attribute names
|
||||
try:
|
||||
# Try direct access to a known attribute name first
|
||||
if hasattr(self.github_client, "_requester"):
|
||||
requester = self.github_client._requester
|
||||
elif hasattr(self.github_client, "_Github__requester"):
|
||||
requester = self.github_client._Github__requester
|
||||
else:
|
||||
# If we can't find the requester attribute, we need to fall back to recreating the repo
|
||||
raise AttributeError("Could not find requester attribute")
|
||||
|
||||
repo = checkpoint.cached_repo.to_Repository(requester)
|
||||
except Exception as e:
|
||||
# If all else fails, re-fetch the repo directly
|
||||
logger.warning(
|
||||
f"Failed to deserialize repository: {e}. Attempting to re-fetch."
|
||||
)
|
||||
repo_id = checkpoint.cached_repo.id
|
||||
repo = self.github_client.get_repo(repo_id)
|
||||
|
||||
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
|
||||
logger.info(f"Fetching PRs for repo: {repo.name}")
|
||||
pull_requests = repo.get_pulls(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
pr_batch = _get_batch_rate_limited(
|
||||
pull_requests, checkpoint.curr_page, self.github_client
|
||||
)
|
||||
checkpoint.curr_page += 1
|
||||
done_with_prs = False
|
||||
for pr in pr_batch:
|
||||
# we iterate backwards in time, so at this point we stop processing prs
|
||||
if (
|
||||
start is not None
|
||||
and pr.updated_at
|
||||
and pr.updated_at.replace(tzinfo=timezone.utc) < start
|
||||
):
|
||||
doc_batch: list[Document] = []
|
||||
for pr in pr_batch:
|
||||
if start is not None and pr.updated_at < start:
|
||||
yield doc_batch
|
||||
break
|
||||
if end is not None and pr.updated_at > end:
|
||||
continue
|
||||
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
|
||||
yield doc_batch
|
||||
|
||||
if self.include_issues:
|
||||
logger.info(f"Fetching issues for repo: {repo.name}")
|
||||
issues = repo.get_issues(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
for issue_batch in _batch_github_objects(
|
||||
issues, self.github_client, self.batch_size
|
||||
yield from doc_batch
|
||||
done_with_prs = True
|
||||
break
|
||||
# Skip PRs updated after the end date
|
||||
if (
|
||||
end is not None
|
||||
and pr.updated_at
|
||||
and pr.updated_at.replace(tzinfo=timezone.utc) > end
|
||||
):
|
||||
doc_batch = []
|
||||
for issue in issue_batch:
|
||||
issue = cast(Issue, issue)
|
||||
if start is not None and issue.updated_at < start:
|
||||
yield doc_batch
|
||||
break
|
||||
if end is not None and issue.updated_at > end:
|
||||
continue
|
||||
if issue.pull_request is not None:
|
||||
# PRs are handled separately
|
||||
continue
|
||||
doc_batch.append(_convert_issue_to_document(issue))
|
||||
yield doc_batch
|
||||
continue
|
||||
try:
|
||||
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
|
||||
except Exception as e:
|
||||
error_msg = f"Error converting PR to document: {e}"
|
||||
logger.exception(error_msg)
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=str(pr.id), document_link=pr.html_url
|
||||
),
|
||||
failure_message=error_msg,
|
||||
exception=e,
|
||||
)
|
||||
continue
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_from_github()
|
||||
# if we found any PRs on the page, yield any associated documents and return the checkpoint
|
||||
if not done_with_prs and len(pr_batch) > 0:
|
||||
yield from doc_batch
|
||||
return checkpoint
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_datetime = datetime.utcfromtimestamp(start)
|
||||
end_datetime = datetime.utcfromtimestamp(end)
|
||||
# if we went past the start date during the loop or there are no more
|
||||
# prs to get, we move on to issues
|
||||
checkpoint.stage = GithubConnectorStage.ISSUES
|
||||
checkpoint.curr_page = 0
|
||||
|
||||
checkpoint.stage = GithubConnectorStage.ISSUES
|
||||
|
||||
if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
|
||||
logger.info(f"Fetching issues for repo: {repo.name}")
|
||||
issues = repo.get_issues(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
doc_batch = []
|
||||
issue_batch = _get_batch_rate_limited(
|
||||
issues, checkpoint.curr_page, self.github_client
|
||||
)
|
||||
checkpoint.curr_page += 1
|
||||
done_with_issues = False
|
||||
for issue in cast(list[Issue], issue_batch):
|
||||
# we iterate backwards in time, so at this point we stop processing prs
|
||||
if (
|
||||
start is not None
|
||||
and issue.updated_at.replace(tzinfo=timezone.utc) < start
|
||||
):
|
||||
yield from doc_batch
|
||||
done_with_issues = True
|
||||
break
|
||||
# Skip PRs updated after the end date
|
||||
if (
|
||||
end is not None
|
||||
and issue.updated_at.replace(tzinfo=timezone.utc) > end
|
||||
):
|
||||
continue
|
||||
|
||||
if issue.pull_request is not None:
|
||||
# PRs are handled separately
|
||||
continue
|
||||
|
||||
try:
|
||||
doc_batch.append(_convert_issue_to_document(issue))
|
||||
except Exception as e:
|
||||
error_msg = f"Error converting issue to document: {e}"
|
||||
logger.exception(error_msg)
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=str(issue.id),
|
||||
document_link=issue.html_url,
|
||||
),
|
||||
failure_message=error_msg,
|
||||
exception=e,
|
||||
)
|
||||
continue
|
||||
|
||||
# if we found any issues on the page, yield them and return the checkpoint
|
||||
if not done_with_issues and len(issue_batch) > 0:
|
||||
yield from doc_batch
|
||||
return checkpoint
|
||||
|
||||
# if we went past the start date during the loop or there are no more
|
||||
# issues to get, we move on to the next repo
|
||||
checkpoint.stage = GithubConnectorStage.PRS
|
||||
checkpoint.curr_page = 0
|
||||
|
||||
checkpoint.has_more = len(checkpoint.cached_repo_ids) > 1
|
||||
if checkpoint.cached_repo_ids:
|
||||
next_id = checkpoint.cached_repo_ids.pop()
|
||||
next_repo = self.github_client.get_repo(next_id)
|
||||
checkpoint.cached_repo = SerializedRepository(
|
||||
id=next_id,
|
||||
headers=next_repo.raw_headers,
|
||||
raw_data=next_repo.raw_data,
|
||||
)
|
||||
|
||||
return checkpoint
|
||||
|
||||
@override
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: GithubConnectorCheckpoint,
|
||||
) -> CheckpointOutput[GithubConnectorCheckpoint]:
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
|
||||
# Move start time back by 3 hours, since some Issues/PRs are getting dropped
|
||||
# Could be due to delayed processing on GitHub side
|
||||
# The non-updated issues since last poll will be shortcut-ed and not embedded
|
||||
adjusted_start_datetime = start_datetime - timedelta(hours=3)
|
||||
|
||||
epoch = datetime.utcfromtimestamp(0)
|
||||
epoch = datetime.fromtimestamp(0, tz=timezone.utc)
|
||||
if adjusted_start_datetime < epoch:
|
||||
adjusted_start_datetime = epoch
|
||||
|
||||
return self._fetch_from_github(adjusted_start_datetime, end_datetime)
|
||||
return self._fetch_from_github(
|
||||
checkpoint, start=adjusted_start_datetime, end=end_datetime
|
||||
)
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self.github_client is None:
|
||||
@@ -397,6 +544,16 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
f"Unexpected error during GitHub settings validation: {exc}"
|
||||
)
|
||||
|
||||
def validate_checkpoint_json(
|
||||
self, checkpoint_json: str
|
||||
) -> GithubConnectorCheckpoint:
|
||||
return GithubConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
|
||||
return GithubConnectorCheckpoint(
|
||||
stage=GithubConnectorStage.PRS, curr_page=0, has_more=True
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
@@ -406,7 +563,9 @@ if __name__ == "__main__":
|
||||
repositories=os.environ["REPOSITORIES"],
|
||||
)
|
||||
connector.load_credentials(
|
||||
{"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}
|
||||
{"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"]}
|
||||
)
|
||||
document_batches = connector.load_from_checkpoint(
|
||||
0, time.time(), connector.build_dummy_checkpoint()
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,5 @@
|
||||
import io
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
@@ -13,7 +14,9 @@ from onyx.connectors.google_drive.models import GoogleDriveFileType
|
||||
from onyx.connectors.google_drive.section_extraction import get_document_sections
|
||||
from onyx.connectors.google_utils.resources import GoogleDocsService
|
||||
from onyx.connectors.google_utils.resources import GoogleDriveService
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
@@ -27,6 +30,7 @@ from onyx.file_processing.file_validation import is_valid_image_type
|
||||
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.utils.lazy import lazy_eval
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -73,9 +77,30 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
|
||||
return is_valid_image_type(mime_type)
|
||||
|
||||
|
||||
def _extract_sections_basic(
|
||||
def download_request(service: GoogleDriveService, file_id: str) -> bytes:
|
||||
"""
|
||||
Download the file from Google Drive.
|
||||
"""
|
||||
# For other file types, download the file
|
||||
# Use the correct API call for downloading files
|
||||
request = service.files().get_media(fileId=file_id)
|
||||
response_bytes = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(response_bytes, request)
|
||||
done = False
|
||||
while not done:
|
||||
_, done = downloader.next_chunk()
|
||||
|
||||
response = response_bytes.getvalue()
|
||||
if not response:
|
||||
logger.warning(f"Failed to download {file_id}")
|
||||
return bytes()
|
||||
return response
|
||||
|
||||
|
||||
def _download_and_extract_sections_basic(
|
||||
file: dict[str, str],
|
||||
service: GoogleDriveService,
|
||||
allow_images: bool,
|
||||
) -> list[TextSection | ImageSection]:
|
||||
"""Extract text and images from a Google Drive file."""
|
||||
file_id = file["id"]
|
||||
@@ -83,31 +108,17 @@ def _extract_sections_basic(
|
||||
mime_type = file["mimeType"]
|
||||
link = file.get("webViewLink", "")
|
||||
|
||||
try:
|
||||
# For Google Docs, Sheets, and Slides, export as plain text
|
||||
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
|
||||
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
|
||||
# Use the correct API call for exporting files
|
||||
request = service.files().export_media(
|
||||
fileId=file_id, mimeType=export_mime_type
|
||||
)
|
||||
response_bytes = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(response_bytes, request)
|
||||
done = False
|
||||
while not done:
|
||||
_, done = downloader.next_chunk()
|
||||
# skip images if not explicitly enabled
|
||||
if not allow_images and is_gdrive_image_mime_type(mime_type):
|
||||
return []
|
||||
|
||||
response = response_bytes.getvalue()
|
||||
if not response:
|
||||
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
|
||||
return []
|
||||
|
||||
text = response.decode("utf-8")
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
# For other file types, download the file
|
||||
# Use the correct API call for downloading files
|
||||
request = service.files().get_media(fileId=file_id)
|
||||
# For Google Docs, Sheets, and Slides, export as plain text
|
||||
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
|
||||
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
|
||||
# Use the correct API call for exporting files
|
||||
request = service.files().export_media(
|
||||
fileId=file_id, mimeType=export_mime_type
|
||||
)
|
||||
response_bytes = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(response_bytes, request)
|
||||
done = False
|
||||
@@ -116,98 +127,112 @@ def _extract_sections_basic(
|
||||
|
||||
response = response_bytes.getvalue()
|
||||
if not response:
|
||||
logger.warning(f"Failed to download {file_name}")
|
||||
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
|
||||
return []
|
||||
|
||||
# Process based on mime type
|
||||
if mime_type == "text/plain":
|
||||
text = response.decode("utf-8")
|
||||
return [TextSection(link=link, text=text)]
|
||||
text = response.decode("utf-8")
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
):
|
||||
text, _ = docx_to_text_and_images(io.BytesIO(response))
|
||||
return [TextSection(link=link, text=text)]
|
||||
# For other file types, download the file
|
||||
# Use the correct API call for downloading files
|
||||
response_call = lazy_eval(lambda: download_request(service, file_id))
|
||||
|
||||
elif (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
):
|
||||
text = xlsx_to_text(io.BytesIO(response))
|
||||
return [TextSection(link=link, text=text)]
|
||||
# Process based on mime type
|
||||
if mime_type == "text/plain":
|
||||
text = response_call().decode("utf-8")
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
):
|
||||
text = pptx_to_text(io.BytesIO(response))
|
||||
return [TextSection(link=link, text=text)]
|
||||
elif (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
):
|
||||
text, _ = docx_to_text_and_images(io.BytesIO(response_call()))
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif is_gdrive_image_mime_type(mime_type):
|
||||
# For images, store them for later processing
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
elif (
|
||||
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
):
|
||||
text = xlsx_to_text(io.BytesIO(response_call()))
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
):
|
||||
text = pptx_to_text(io.BytesIO(response_call()))
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif is_gdrive_image_mime_type(mime_type):
|
||||
# For images, store them for later processing
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=response_call(),
|
||||
file_name=file_id,
|
||||
display_name=file_name,
|
||||
media_type=mime_type,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
link=link,
|
||||
)
|
||||
sections.append(section)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process image {file_name}: {e}")
|
||||
return sections
|
||||
|
||||
elif mime_type == "application/pdf":
|
||||
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
|
||||
pdf_sections: list[TextSection | ImageSection] = [
|
||||
TextSection(link=link, text=text)
|
||||
]
|
||||
|
||||
# Process embedded images in the PDF
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for idx, (img_data, img_name) in enumerate(images):
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=response,
|
||||
file_name=file_id,
|
||||
display_name=file_name,
|
||||
media_type=mime_type,
|
||||
image_data=img_data,
|
||||
file_name=f"{file_id}_img_{idx}",
|
||||
display_name=img_name or f"{file_name} - image {idx}",
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
link=link,
|
||||
)
|
||||
sections.append(section)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process image {file_name}: {e}")
|
||||
return sections
|
||||
pdf_sections.append(section)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process PDF images in {file_name}: {e}")
|
||||
return pdf_sections
|
||||
|
||||
elif mime_type == "application/pdf":
|
||||
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
|
||||
pdf_sections: list[TextSection | ImageSection] = [
|
||||
TextSection(link=link, text=text)
|
||||
]
|
||||
|
||||
# Process embedded images in the PDF
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for idx, (img_data, img_name) in enumerate(images):
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=img_data,
|
||||
file_name=f"{file_id}_img_{idx}",
|
||||
display_name=img_name or f"{file_name} - image {idx}",
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
pdf_sections.append(section)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process PDF images in {file_name}: {e}")
|
||||
return pdf_sections
|
||||
|
||||
else:
|
||||
# For unsupported file types, try to extract text
|
||||
try:
|
||||
text = extract_file_text(io.BytesIO(response), file_name)
|
||||
return [TextSection(link=link, text=text)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract text from {file_name}: {e}")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file {file_name}: {e}")
|
||||
return []
|
||||
else:
|
||||
# For unsupported file types, try to extract text
|
||||
if mime_type in [
|
||||
"application/vnd.google-apps.video",
|
||||
"application/vnd.google-apps.audio",
|
||||
"application/zip",
|
||||
]:
|
||||
return []
|
||||
# For unsupported file types, try to extract text
|
||||
try:
|
||||
text = extract_file_text(io.BytesIO(response_call()), file_name)
|
||||
return [TextSection(link=link, text=text)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract text from {file_name}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def convert_drive_item_to_document(
|
||||
file: GoogleDriveFileType,
|
||||
drive_service: GoogleDriveService,
|
||||
docs_service: GoogleDocsService,
|
||||
) -> Document | None:
|
||||
drive_service: Callable[[], GoogleDriveService],
|
||||
docs_service: Callable[[], GoogleDocsService],
|
||||
allow_images: bool,
|
||||
size_threshold: int,
|
||||
) -> Document | ConnectorFailure | None:
|
||||
"""
|
||||
Main entry point for converting a Google Drive file => Document object.
|
||||
"""
|
||||
doc_id = ""
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
|
||||
try:
|
||||
# skip shortcuts or folders
|
||||
if file.get("mimeType") in [DRIVE_SHORTCUT_TYPE, DRIVE_FOLDER_TYPE]:
|
||||
@@ -215,13 +240,11 @@ def convert_drive_item_to_document(
|
||||
return None
|
||||
|
||||
# If it's a Google Doc, we might do advanced parsing
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
|
||||
# Try to get sections using the advanced method first
|
||||
if file.get("mimeType") == GDriveMimeType.DOC.value:
|
||||
try:
|
||||
# get_document_sections is the advanced approach for Google Docs
|
||||
doc_sections = get_document_sections(
|
||||
docs_service=docs_service, doc_id=file.get("id", "")
|
||||
docs_service=docs_service(), doc_id=file.get("id", "")
|
||||
)
|
||||
if doc_sections:
|
||||
sections = cast(list[TextSection | ImageSection], doc_sections)
|
||||
@@ -230,9 +253,24 @@ def convert_drive_item_to_document(
|
||||
f"Error in advanced parsing: {e}. Falling back to basic extraction."
|
||||
)
|
||||
|
||||
size_str = file.get("size")
|
||||
if size_str:
|
||||
try:
|
||||
size_int = int(size_str)
|
||||
except ValueError:
|
||||
logger.warning(f"Parsing string to int failed: size_str={size_str}")
|
||||
else:
|
||||
if size_int > size_threshold:
|
||||
logger.warning(
|
||||
f"{file.get('name')} exceeds size threshold of {size_threshold}. Skipping."
|
||||
)
|
||||
return None
|
||||
|
||||
# If we don't have sections yet, use the basic extraction method
|
||||
if not sections:
|
||||
sections = _extract_sections_basic(file, drive_service)
|
||||
sections = _download_and_extract_sections_basic(
|
||||
file, drive_service(), allow_images
|
||||
)
|
||||
|
||||
# If we still don't have any sections, skip this file
|
||||
if not sections:
|
||||
@@ -257,8 +295,19 @@ def convert_drive_item_to_document(
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting file {file.get('name')}: {e}")
|
||||
return None
|
||||
error_str = f"Error converting file '{file.get('name')}' to Document: {e}"
|
||||
logger.exception(error_str)
|
||||
return ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=doc_id,
|
||||
document_link=sections[0].link
|
||||
if sections
|
||||
else None, # TODO: see if this is the best way to get a link
|
||||
),
|
||||
failed_entity=None,
|
||||
failure_message=error_str,
|
||||
exception=e,
|
||||
)
|
||||
|
||||
|
||||
def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from datetime import timezone
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
|
||||
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
|
||||
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
|
||||
from onyx.connectors.google_drive.models import DriveRetrievalStage
|
||||
from onyx.connectors.google_drive.models import GoogleDriveFileType
|
||||
from onyx.connectors.google_drive.models import RetrievedDriveFile
|
||||
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
from onyx.connectors.google_utils.google_utils import GoogleFields
|
||||
from onyx.connectors.google_utils.google_utils import ORDER_BY_KEY
|
||||
from onyx.connectors.google_utils.resources import GoogleDriveService
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
FILE_FIELDS = (
|
||||
@@ -31,11 +37,13 @@ def _generate_time_range_filter(
|
||||
) -> str:
|
||||
time_range_filter = ""
|
||||
if start is not None:
|
||||
time_start = datetime.utcfromtimestamp(start).isoformat() + "Z"
|
||||
time_range_filter += f" and modifiedTime >= '{time_start}'"
|
||||
time_start = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
|
||||
time_range_filter += (
|
||||
f" and {GoogleFields.MODIFIED_TIME.value} >= '{time_start}'"
|
||||
)
|
||||
if end is not None:
|
||||
time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z"
|
||||
time_range_filter += f" and modifiedTime <= '{time_stop}'"
|
||||
time_stop = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
|
||||
time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} <= '{time_stop}'"
|
||||
return time_range_filter
|
||||
|
||||
|
||||
@@ -66,9 +74,9 @@ def _get_folders_in_parent(
|
||||
def _get_files_in_parent(
|
||||
service: Resource,
|
||||
parent_id: str,
|
||||
is_slim: bool,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
is_slim: bool = False,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
|
||||
query += " and trashed = false"
|
||||
@@ -83,6 +91,7 @@ def _get_files_in_parent(
|
||||
includeItemsFromAllDrives=True,
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=query,
|
||||
**({} if is_slim else {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}),
|
||||
):
|
||||
yield file
|
||||
|
||||
@@ -90,30 +99,50 @@ def _get_files_in_parent(
|
||||
def crawl_folders_for_files(
|
||||
service: Resource,
|
||||
parent_id: str,
|
||||
is_slim: bool,
|
||||
user_email: str,
|
||||
traversed_parent_ids: set[str],
|
||||
update_traversed_ids_func: Callable[[str], None],
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
) -> Iterator[RetrievedDriveFile]:
|
||||
"""
|
||||
This function starts crawling from any folder. It is slower though.
|
||||
"""
|
||||
if parent_id in traversed_parent_ids:
|
||||
logger.info(f"Skipping subfolder since already traversed: {parent_id}")
|
||||
return
|
||||
|
||||
found_files = False
|
||||
for file in _get_files_in_parent(
|
||||
service=service,
|
||||
start=start,
|
||||
end=end,
|
||||
parent_id=parent_id,
|
||||
):
|
||||
found_files = True
|
||||
yield file
|
||||
|
||||
if found_files:
|
||||
update_traversed_ids_func(parent_id)
|
||||
logger.info("Entered crawl_folders_for_files with parent_id: " + parent_id)
|
||||
if parent_id not in traversed_parent_ids:
|
||||
logger.info("Parent id not in traversed parent ids, getting files")
|
||||
found_files = False
|
||||
file = {}
|
||||
try:
|
||||
for file in _get_files_in_parent(
|
||||
service=service,
|
||||
parent_id=parent_id,
|
||||
is_slim=is_slim,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
found_files = True
|
||||
logger.info(f"Found file: {file['name']}, user email: {user_email}")
|
||||
yield RetrievedDriveFile(
|
||||
drive_file=file,
|
||||
user_email=user_email,
|
||||
parent_id=parent_id,
|
||||
completion_stage=DriveRetrievalStage.FOLDER_FILES,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting files in parent {parent_id}: {e}")
|
||||
yield RetrievedDriveFile(
|
||||
drive_file=file,
|
||||
user_email=user_email,
|
||||
parent_id=parent_id,
|
||||
completion_stage=DriveRetrievalStage.FOLDER_FILES,
|
||||
error=e,
|
||||
)
|
||||
if found_files:
|
||||
update_traversed_ids_func(parent_id)
|
||||
else:
|
||||
logger.info(f"Skipping subfolder files since already traversed: {parent_id}")
|
||||
|
||||
for subfolder in _get_folders_in_parent(
|
||||
service=service,
|
||||
@@ -123,6 +152,8 @@ def crawl_folders_for_files(
|
||||
yield from crawl_folders_for_files(
|
||||
service=service,
|
||||
parent_id=subfolder["id"],
|
||||
is_slim=is_slim,
|
||||
user_email=user_email,
|
||||
traversed_parent_ids=traversed_parent_ids,
|
||||
update_traversed_ids_func=update_traversed_ids_func,
|
||||
start=start,
|
||||
@@ -133,16 +164,19 @@ def crawl_folders_for_files(
|
||||
def get_files_in_shared_drive(
|
||||
service: Resource,
|
||||
drive_id: str,
|
||||
is_slim: bool = False,
|
||||
is_slim: bool,
|
||||
update_traversed_ids_func: Callable[[str], None] = lambda _: None,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
kwargs = {}
|
||||
if not is_slim:
|
||||
kwargs[ORDER_BY_KEY] = GoogleFields.MODIFIED_TIME.value
|
||||
|
||||
# If we know we are going to folder crawl later, we can cache the folders here
|
||||
# Get all folders being queried and add them to the traversed set
|
||||
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
|
||||
folder_query += " and trashed = false"
|
||||
found_folders = False
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
@@ -155,15 +189,13 @@ def get_files_in_shared_drive(
|
||||
q=folder_query,
|
||||
):
|
||||
update_traversed_ids_func(file["id"])
|
||||
found_folders = True
|
||||
if found_folders:
|
||||
update_traversed_ids_func(drive_id)
|
||||
|
||||
# Get all files in the shared drive
|
||||
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
|
||||
file_query += " and trashed = false"
|
||||
file_query += _generate_time_range_filter(start, end)
|
||||
yield from execute_paginated_retrieval(
|
||||
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
@@ -173,21 +205,33 @@ def get_files_in_shared_drive(
|
||||
includeItemsFromAllDrives=True,
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=file_query,
|
||||
)
|
||||
**kwargs,
|
||||
):
|
||||
# If we found any files, mark this drive as traversed. When a user has access to a drive,
|
||||
# they have access to all the files in the drive. Also not a huge deal if we re-traverse
|
||||
# empty drives.
|
||||
update_traversed_ids_func(drive_id)
|
||||
yield file
|
||||
|
||||
|
||||
def get_all_files_in_my_drive(
|
||||
service: Any,
|
||||
def get_all_files_in_my_drive_and_shared(
|
||||
service: GoogleDriveService,
|
||||
update_traversed_ids_func: Callable,
|
||||
is_slim: bool = False,
|
||||
is_slim: bool,
|
||||
include_shared_with_me: bool,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
kwargs = {}
|
||||
if not is_slim:
|
||||
kwargs[ORDER_BY_KEY] = GoogleFields.MODIFIED_TIME.value
|
||||
|
||||
# If we know we are going to folder crawl later, we can cache the folders here
|
||||
# Get all folders being queried and add them to the traversed set
|
||||
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
|
||||
folder_query += " and trashed = false"
|
||||
folder_query += " and 'me' in owners"
|
||||
if not include_shared_with_me:
|
||||
folder_query += " and 'me' in owners"
|
||||
found_folders = False
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
@@ -196,7 +240,7 @@ def get_all_files_in_my_drive(
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=folder_query,
|
||||
):
|
||||
update_traversed_ids_func(file["id"])
|
||||
update_traversed_ids_func(file[GoogleFields.ID])
|
||||
found_folders = True
|
||||
if found_folders:
|
||||
update_traversed_ids_func(get_root_folder_id(service))
|
||||
@@ -204,27 +248,34 @@ def get_all_files_in_my_drive(
|
||||
# Then get the files
|
||||
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
|
||||
file_query += " and trashed = false"
|
||||
file_query += " and 'me' in owners"
|
||||
if not include_shared_with_me:
|
||||
file_query += " and 'me' in owners"
|
||||
file_query += _generate_time_range_filter(start, end)
|
||||
yield from execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=False,
|
||||
corpora="user",
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=file_query,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def get_all_files_for_oauth(
|
||||
service: Any,
|
||||
service: GoogleDriveService,
|
||||
include_files_shared_with_me: bool,
|
||||
include_my_drives: bool,
|
||||
# One of the above 2 should be true
|
||||
include_shared_drives: bool,
|
||||
is_slim: bool = False,
|
||||
is_slim: bool,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
kwargs = {}
|
||||
if not is_slim:
|
||||
kwargs[ORDER_BY_KEY] = GoogleFields.MODIFIED_TIME.value
|
||||
|
||||
should_get_all = (
|
||||
include_shared_drives and include_my_drives and include_files_shared_with_me
|
||||
)
|
||||
@@ -243,11 +294,13 @@ def get_all_files_for_oauth(
|
||||
yield from execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=False,
|
||||
corpora=corpora,
|
||||
includeItemsFromAllDrives=should_get_all,
|
||||
supportsAllDrives=should_get_all,
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=file_query,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -255,4 +308,8 @@ def get_all_files_for_oauth(
|
||||
def get_root_folder_id(service: Resource) -> str:
|
||||
# we dont paginate here because there is only one root folder per user
|
||||
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
|
||||
return service.files().get(fileId="root", fields="id").execute()["id"]
|
||||
return (
|
||||
service.files()
|
||||
.get(fileId="root", fields=GoogleFields.ID.value)
|
||||
.execute()[GoogleFields.ID.value]
|
||||
)
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import field_serializer
|
||||
from pydantic import field_validator
|
||||
|
||||
from onyx.connectors.interfaces import ConnectorCheckpoint
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.utils.threadpool_concurrency import ThreadSafeDict
|
||||
|
||||
|
||||
class GDriveMimeType(str, Enum):
|
||||
DOC = "application/vnd.google-apps.document"
|
||||
@@ -20,3 +29,128 @@ class GDriveMimeType(str, Enum):
|
||||
|
||||
|
||||
GoogleDriveFileType = dict[str, Any]
|
||||
|
||||
|
||||
TOKEN_EXPIRATION_TIME = 3600 # 1 hour
|
||||
|
||||
|
||||
# These correspond to The major stages of retrieval for google drive.
|
||||
# The stages for the oauth flow are:
|
||||
# get_all_files_for_oauth(),
|
||||
# get_all_drive_ids(),
|
||||
# get_files_in_shared_drive(),
|
||||
# crawl_folders_for_files()
|
||||
#
|
||||
# The stages for the service account flow are roughly:
|
||||
# get_all_user_emails(),
|
||||
# get_all_drive_ids(),
|
||||
# get_files_in_shared_drive(),
|
||||
# Then for each user:
|
||||
# get_files_in_my_drive()
|
||||
# get_files_in_shared_drive()
|
||||
# crawl_folders_for_files()
|
||||
class DriveRetrievalStage(str, Enum):
|
||||
START = "start"
|
||||
DONE = "done"
|
||||
# OAuth specific stages
|
||||
OAUTH_FILES = "oauth_files"
|
||||
|
||||
# Service account specific stages
|
||||
USER_EMAILS = "user_emails"
|
||||
MY_DRIVE_FILES = "my_drive_files"
|
||||
|
||||
# Used for both oauth and service account flows
|
||||
DRIVE_IDS = "drive_ids"
|
||||
SHARED_DRIVE_FILES = "shared_drive_files"
|
||||
FOLDER_FILES = "folder_files"
|
||||
|
||||
|
||||
class StageCompletion(BaseModel):
|
||||
"""
|
||||
Describes the point in the retrieval+indexing process that the
|
||||
connector is at. completed_until is the timestamp of the latest
|
||||
file that has been retrieved or error that has been yielded.
|
||||
Optional fields are used for retrieval stages that need more information
|
||||
for resuming than just the timestamp of the latest file.
|
||||
"""
|
||||
|
||||
stage: DriveRetrievalStage
|
||||
completed_until: SecondsSinceUnixEpoch
|
||||
completed_until_parent_id: str | None = None
|
||||
|
||||
# only used for shared drives
|
||||
processed_drive_ids: set[str] = set()
|
||||
|
||||
def update(
|
||||
self,
|
||||
stage: DriveRetrievalStage,
|
||||
completed_until: SecondsSinceUnixEpoch,
|
||||
completed_until_parent_id: str | None = None,
|
||||
) -> None:
|
||||
self.stage = stage
|
||||
self.completed_until = completed_until
|
||||
self.completed_until_parent_id = completed_until_parent_id
|
||||
|
||||
|
||||
class RetrievedDriveFile(BaseModel):
|
||||
"""
|
||||
Describes a file that has been retrieved from google drive.
|
||||
user_email is the email of the user that the file was retrieved
|
||||
by impersonating. If an error worthy of being reported is encountered,
|
||||
error should be set and later propagated as a ConnectorFailure.
|
||||
"""
|
||||
|
||||
# The stage at which this file was retrieved
|
||||
completion_stage: DriveRetrievalStage
|
||||
|
||||
# The file that was retrieved
|
||||
drive_file: GoogleDriveFileType
|
||||
|
||||
# The email of the user that the file was retrieved by impersonating
|
||||
user_email: str
|
||||
|
||||
# The id of the parent folder or drive of the file
|
||||
parent_id: str | None = None
|
||||
|
||||
# Any unexpected error that occurred while retrieving the file.
|
||||
# In particular, this is not used for 403/404 errors, which are expected
|
||||
# in the context of impersonating all the users to try to retrieve all
|
||||
# files from all their Drives and Folders.
|
||||
error: Exception | None = None
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class GoogleDriveCheckpoint(ConnectorCheckpoint):
|
||||
# Checkpoint version of _retrieved_ids
|
||||
retrieved_folder_and_drive_ids: set[str]
|
||||
|
||||
# Describes the point in the retrieval+indexing process that the
|
||||
# checkpoint is at. when this is set to a given stage, the connector
|
||||
# has finished yielding all values from the previous stage.
|
||||
completion_stage: DriveRetrievalStage
|
||||
|
||||
# The latest timestamp of a file that has been retrieved per user email.
|
||||
# StageCompletion is used to track the completion of each stage, but the
|
||||
# timestamp part is not used for folder crawling.
|
||||
completion_map: ThreadSafeDict[str, StageCompletion]
|
||||
|
||||
# cached version of the drive and folder ids to retrieve
|
||||
drive_ids_to_retrieve: list[str] | None = None
|
||||
folder_ids_to_retrieve: list[str] | None = None
|
||||
|
||||
# cached user emails
|
||||
user_emails: list[str] | None = None
|
||||
|
||||
@field_serializer("completion_map")
|
||||
def serialize_completion_map(
|
||||
self, completion_map: ThreadSafeDict[str, StageCompletion], _info: Any
|
||||
) -> dict[str, StageCompletion]:
|
||||
return completion_map._dict
|
||||
|
||||
@field_validator("completion_map", mode="before")
|
||||
def validate_completion_map(cls, v: Any) -> ThreadSafeDict[str, StageCompletion]:
|
||||
assert isinstance(v, dict) or isinstance(v, ThreadSafeDict)
|
||||
return ThreadSafeDict(
|
||||
{k: StageCompletion.model_validate(v) for k, v in v.items()}
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
@@ -16,20 +17,37 @@ logger = setup_logger()
|
||||
|
||||
|
||||
# Google Drive APIs are quite flakey and may 500 for an
|
||||
# extended period of time. Trying to combat here by adding a very
|
||||
# long retry period (~20 minutes of trying every minute)
|
||||
add_retries = retry_builder(tries=50, max_delay=30)
|
||||
# extended period of time. This is now addressed by checkpointing.
|
||||
#
|
||||
# NOTE: We previously tried to combat this here by adding a very
|
||||
# long retry period (~20 minutes of trying, one request a minute.)
|
||||
# This is no longer necessary due to checkpointing.
|
||||
add_retries = retry_builder(tries=5, max_delay=10)
|
||||
|
||||
NEXT_PAGE_TOKEN_KEY = "nextPageToken"
|
||||
PAGE_TOKEN_KEY = "pageToken"
|
||||
ORDER_BY_KEY = "orderBy"
|
||||
|
||||
|
||||
# See https://developers.google.com/drive/api/reference/rest/v3/files/list for more
|
||||
class GoogleFields(str, Enum):
|
||||
ID = "id"
|
||||
CREATED_TIME = "createdTime"
|
||||
MODIFIED_TIME = "modifiedTime"
|
||||
NAME = "name"
|
||||
SIZE = "size"
|
||||
PARENTS = "parents"
|
||||
|
||||
|
||||
def _execute_with_retry(request: Any) -> Any:
|
||||
max_attempts = 10
|
||||
max_attempts = 6
|
||||
attempt = 1
|
||||
|
||||
while attempt < max_attempts:
|
||||
# Note for reasons unknown, the Google API will sometimes return a 429
|
||||
# and even after waiting the retry period, it will return another 429.
|
||||
# It could be due to a few possibilities:
|
||||
# 1. Other things are also requesting from the Gmail API with the same key
|
||||
# 1. Other things are also requesting from the Drive/Gmail API with the same key
|
||||
# 2. It's a rolling rate limit so the moment we get some amount of requests cleared, we hit it again very quickly
|
||||
# 3. The retry-after has a maximum and we've already hit the limit for the day
|
||||
# or it's something else...
|
||||
@@ -90,11 +108,11 @@ def execute_paginated_retrieval(
|
||||
retrieval_function: The specific list function to call (e.g., service.files().list)
|
||||
**kwargs: Arguments to pass to the list function
|
||||
"""
|
||||
next_page_token = ""
|
||||
next_page_token = kwargs.get(PAGE_TOKEN_KEY, "")
|
||||
while next_page_token is not None:
|
||||
request_kwargs = kwargs.copy()
|
||||
if next_page_token:
|
||||
request_kwargs["pageToken"] = next_page_token
|
||||
request_kwargs[PAGE_TOKEN_KEY] = next_page_token
|
||||
|
||||
try:
|
||||
results = retrieval_function(**request_kwargs).execute()
|
||||
@@ -117,7 +135,7 @@ def execute_paginated_retrieval(
|
||||
logger.exception("Error executing request:")
|
||||
raise e
|
||||
|
||||
next_page_token = results.get("nextPageToken")
|
||||
next_page_token = results.get(NEXT_PAGE_TOKEN_KEY)
|
||||
if list_key:
|
||||
for item in results.get(list_key, []):
|
||||
yield item
|
||||
|
||||
4
backend/onyx/connectors/highspot/__init__.py
Normal file
4
backend/onyx/connectors/highspot/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Highspot connector package for Onyx.
|
||||
Enables integration with Highspot's knowledge base.
|
||||
"""
|
||||
280
backend/onyx/connectors/highspot/client.py
Normal file
280
backend/onyx/connectors/highspot/client.py
Normal file
@@ -0,0 +1,280 @@
|
||||
import base64
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.exceptions import HTTPError
|
||||
from requests.exceptions import RequestException
|
||||
from requests.exceptions import Timeout
|
||||
from urllib3.util.retry import Retry
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class HighspotClientError(Exception):
|
||||
"""Base exception for Highspot API client errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: Optional[int] = None):
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class HighspotAuthenticationError(HighspotClientError):
|
||||
"""Exception raised for authentication errors."""
|
||||
|
||||
|
||||
class HighspotRateLimitError(HighspotClientError):
|
||||
"""Exception raised when rate limit is exceeded."""
|
||||
|
||||
def __init__(self, message: str, retry_after: Optional[str] = None):
|
||||
self.retry_after = retry_after
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class HighspotClient:
|
||||
"""
|
||||
Client for interacting with the Highspot API.
|
||||
|
||||
Uses basic authentication with provided key (username) and secret (password).
|
||||
Implements retry logic, error handling, and connection pooling.
|
||||
"""
|
||||
|
||||
BASE_URL = "https://api-su2.highspot.com/v1.0/"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
secret: str,
|
||||
base_url: str = BASE_URL,
|
||||
timeout: int = 30,
|
||||
max_retries: int = 3,
|
||||
backoff_factor: float = 0.5,
|
||||
status_forcelist: Optional[List[int]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the Highspot API client.
|
||||
|
||||
Args:
|
||||
key: API key (used as username)
|
||||
secret: API secret (used as password)
|
||||
base_url: Base URL for the Highspot API
|
||||
timeout: Request timeout in seconds
|
||||
max_retries: Maximum number of retries for failed requests
|
||||
backoff_factor: Backoff factor for retries
|
||||
status_forcelist: HTTP status codes to retry on
|
||||
"""
|
||||
if not key or not secret:
|
||||
raise ValueError("API key and secret are required")
|
||||
|
||||
self.key = key
|
||||
self.secret = secret
|
||||
self.base_url = base_url.rstrip("/") + "/"
|
||||
self.timeout = timeout
|
||||
|
||||
# Set up session with retry logic
|
||||
self.session = requests.Session()
|
||||
retry_strategy = Retry(
|
||||
total=max_retries,
|
||||
backoff_factor=backoff_factor,
|
||||
status_forcelist=status_forcelist or [429, 500, 502, 503, 504],
|
||||
allowed_methods=["GET", "POST", "PUT", "DELETE"],
|
||||
)
|
||||
adapter = HTTPAdapter(max_retries=retry_strategy)
|
||||
self.session.mount("http://", adapter)
|
||||
self.session.mount("https://", adapter)
|
||||
|
||||
# Set up authentication
|
||||
self._setup_auth()
|
||||
|
||||
def _setup_auth(self) -> None:
|
||||
"""Set up basic authentication for the session."""
|
||||
auth = f"{self.key}:{self.secret}"
|
||||
encoded_auth = base64.b64encode(auth.encode()).decode()
|
||||
self.session.headers.update(
|
||||
{
|
||||
"Authorization": f"Basic {encoded_auth}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
)
|
||||
|
||||
def _make_request(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Make a request to the Highspot API.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
endpoint: API endpoint
|
||||
params: URL parameters
|
||||
data: Form data
|
||||
json_data: JSON data
|
||||
headers: Additional headers
|
||||
|
||||
Returns:
|
||||
API response as a dictionary
|
||||
|
||||
Raises:
|
||||
HighspotClientError: On API errors
|
||||
HighspotAuthenticationError: On authentication errors
|
||||
HighspotRateLimitError: On rate limiting
|
||||
requests.exceptions.RequestException: On request failures
|
||||
"""
|
||||
url = urljoin(self.base_url, endpoint)
|
||||
request_headers = {}
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
|
||||
try:
|
||||
logger.debug(f"Making {method} request to {url}")
|
||||
response = self.session.request(
|
||||
method=method,
|
||||
url=url,
|
||||
params=params,
|
||||
data=data,
|
||||
json=json_data,
|
||||
headers=request_headers,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
if response.content and response.content.strip():
|
||||
return response.json()
|
||||
return {}
|
||||
|
||||
except HTTPError as e:
|
||||
status_code = e.response.status_code
|
||||
error_msg = str(e)
|
||||
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
if isinstance(error_data, dict):
|
||||
error_msg = error_data.get("message", str(e))
|
||||
except (ValueError, KeyError):
|
||||
pass
|
||||
|
||||
if status_code == 401:
|
||||
raise HighspotAuthenticationError(f"Authentication failed: {error_msg}")
|
||||
elif status_code == 429:
|
||||
retry_after = e.response.headers.get("Retry-After")
|
||||
raise HighspotRateLimitError(
|
||||
f"Rate limit exceeded: {error_msg}", retry_after=retry_after
|
||||
)
|
||||
else:
|
||||
raise HighspotClientError(
|
||||
f"API error {status_code}: {error_msg}", status_code=status_code
|
||||
)
|
||||
|
||||
except Timeout:
|
||||
raise HighspotClientError("Request timed out")
|
||||
except RequestException as e:
|
||||
raise HighspotClientError(f"Request failed: {str(e)}")
|
||||
|
||||
def get_spots(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all available spots.
|
||||
|
||||
Returns:
|
||||
List of spots with their names and IDs
|
||||
"""
|
||||
params = {"right": "view"}
|
||||
response = self._make_request("GET", "spots", params=params)
|
||||
logger.info(f"Received {response} spots")
|
||||
total_counts = response.get("counts_total")
|
||||
# Fix comparison to handle None value
|
||||
if total_counts is not None and total_counts > 0:
|
||||
return response.get("collection", [])
|
||||
return []
|
||||
|
||||
def get_spot(self, spot_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get details for a specific spot.
|
||||
|
||||
Args:
|
||||
spot_id: ID of the spot
|
||||
|
||||
Returns:
|
||||
Spot details
|
||||
"""
|
||||
if not spot_id:
|
||||
raise ValueError("spot_id is required")
|
||||
return self._make_request("GET", f"spots/{spot_id}")
|
||||
|
||||
def get_spot_items(
|
||||
self, spot_id: str, offset: int = 0, page_size: int = 100
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get items in a specific spot.
|
||||
|
||||
Args:
|
||||
spot_id: ID of the spot
|
||||
offset: offset number
|
||||
page_size: Number of items per page
|
||||
|
||||
Returns:
|
||||
Items in the spot
|
||||
"""
|
||||
if not spot_id:
|
||||
raise ValueError("spot_id is required")
|
||||
|
||||
params = {"spot": spot_id, "start": offset, "limit": page_size}
|
||||
return self._make_request("GET", "items", params=params)
|
||||
|
||||
def get_item(self, item_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get details for a specific item.
|
||||
|
||||
Args:
|
||||
item_id: ID of the item
|
||||
|
||||
Returns:
|
||||
Item details
|
||||
"""
|
||||
if not item_id:
|
||||
raise ValueError("item_id is required")
|
||||
return self._make_request("GET", f"items/{item_id}")
|
||||
|
||||
def get_item_content(self, item_id: str) -> bytes:
|
||||
"""
|
||||
Get the raw content of an item.
|
||||
|
||||
Args:
|
||||
item_id: ID of the item
|
||||
|
||||
Returns:
|
||||
Raw content bytes
|
||||
"""
|
||||
if not item_id:
|
||||
raise ValueError("item_id is required")
|
||||
|
||||
url = urljoin(self.base_url, f"items/{item_id}/content")
|
||||
response = self.session.get(url, timeout=self.timeout)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
def health_check(self) -> bool:
|
||||
"""
|
||||
Check if the API is accessible and credentials are valid.
|
||||
|
||||
Returns:
|
||||
True if API is accessible, False otherwise
|
||||
"""
|
||||
try:
|
||||
self._make_request("GET", "spots", params={"limit": 1})
|
||||
return True
|
||||
except (HighspotClientError, HighspotAuthenticationError):
|
||||
return False
|
||||
431
backend/onyx/connectors/highspot/connector.py
Normal file
431
backend/onyx/connectors/highspot/connector.py
Normal file
@@ -0,0 +1,431 @@
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.highspot.client import HighspotClient
|
||||
from onyx.connectors.highspot.client import HighspotClientError
|
||||
from onyx.connectors.highspot.utils import scrape_url_content
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
_SLIM_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
"""
|
||||
Connector for loading data from Highspot.
|
||||
|
||||
Retrieves content from specified spots using the Highspot API.
|
||||
If no spots are specified, retrieves content from all available spots.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spot_names: List[str] = [],
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
):
|
||||
"""
|
||||
Initialize the Highspot connector.
|
||||
|
||||
Args:
|
||||
spot_names: List of spot names to retrieve content from (if empty, gets all spots)
|
||||
batch_size: Number of items to retrieve in each batch
|
||||
"""
|
||||
self.spot_names = spot_names
|
||||
self.batch_size = batch_size
|
||||
self._client: Optional[HighspotClient] = None
|
||||
self._spot_id_map: Dict[str, str] = {} # Maps spot names to spot IDs
|
||||
self._all_spots_fetched = False
|
||||
self.highspot_url: Optional[str] = None
|
||||
self.key: Optional[str] = None
|
||||
self.secret: Optional[str] = None
|
||||
|
||||
@property
|
||||
def client(self) -> HighspotClient:
|
||||
if self._client is None:
|
||||
if not self.key or not self.secret:
|
||||
raise ConnectorMissingCredentialError("Highspot")
|
||||
# Ensure highspot_url is a string, use default if None
|
||||
base_url = (
|
||||
self.highspot_url
|
||||
if self.highspot_url is not None
|
||||
else HighspotClient.BASE_URL
|
||||
)
|
||||
self._client = HighspotClient(self.key, self.secret, base_url=base_url)
|
||||
return self._client
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
logger.info("Loading Highspot credentials")
|
||||
self.highspot_url = credentials.get("highspot_url")
|
||||
self.key = credentials.get("highspot_key")
|
||||
self.secret = credentials.get("highspot_secret")
|
||||
return None
|
||||
|
||||
def _populate_spot_id_map(self) -> None:
|
||||
"""
|
||||
Populate the spot ID map with all available spots.
|
||||
Keys are stored as lowercase for case-insensitive lookups.
|
||||
"""
|
||||
spots = self.client.get_spots()
|
||||
for spot in spots:
|
||||
if "title" in spot and "id" in spot:
|
||||
spot_name = spot["title"]
|
||||
self._spot_id_map[spot_name.lower()] = spot["id"]
|
||||
|
||||
self._all_spots_fetched = True
|
||||
logger.info(f"Retrieved {len(self._spot_id_map)} spots from Highspot")
|
||||
|
||||
def _get_all_spot_names(self) -> List[str]:
|
||||
"""
|
||||
Retrieve all available spot names.
|
||||
|
||||
Returns:
|
||||
List of all spot names
|
||||
"""
|
||||
if not self._all_spots_fetched:
|
||||
self._populate_spot_id_map()
|
||||
|
||||
return [spot_name for spot_name in self._spot_id_map.keys()]
|
||||
|
||||
def _get_spot_id_from_name(self, spot_name: str) -> str:
|
||||
"""
|
||||
Get spot ID from a spot name.
|
||||
|
||||
Args:
|
||||
spot_name: Name of the spot
|
||||
|
||||
Returns:
|
||||
ID of the spot
|
||||
|
||||
Raises:
|
||||
ValueError: If spot name is not found
|
||||
"""
|
||||
if not self._all_spots_fetched:
|
||||
self._populate_spot_id_map()
|
||||
|
||||
spot_name_lower = spot_name.lower()
|
||||
if spot_name_lower not in self._spot_id_map:
|
||||
raise ValueError(f"Spot '{spot_name}' not found")
|
||||
|
||||
return self._spot_id_map[spot_name_lower]
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Load content from configured spots in Highspot.
|
||||
If no spots are configured, loads from all spots.
|
||||
|
||||
Yields:
|
||||
Batches of Document objects
|
||||
"""
|
||||
return self.poll_source(None, None)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Poll Highspot for content updated since the start time.
|
||||
|
||||
Args:
|
||||
start: Start time as seconds since Unix epoch
|
||||
end: End time as seconds since Unix epoch
|
||||
|
||||
Yields:
|
||||
Batches of Document objects
|
||||
"""
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
# If no spots specified, get all spots
|
||||
spot_names_to_process = self.spot_names
|
||||
if not spot_names_to_process:
|
||||
spot_names_to_process = self._get_all_spot_names()
|
||||
logger.info(
|
||||
f"No spots specified, using all {len(spot_names_to_process)} available spots"
|
||||
)
|
||||
|
||||
for spot_name in spot_names_to_process:
|
||||
try:
|
||||
spot_id = self._get_spot_id_from_name(spot_name)
|
||||
if spot_id is None:
|
||||
logger.warning(f"Spot ID not found for spot {spot_name}")
|
||||
continue
|
||||
offset = 0
|
||||
has_more = True
|
||||
|
||||
while has_more:
|
||||
logger.info(
|
||||
f"Retrieving items from spot {spot_name}, offset {offset}"
|
||||
)
|
||||
response = self.client.get_spot_items(
|
||||
spot_id=spot_id, offset=offset, page_size=self.batch_size
|
||||
)
|
||||
items = response.get("collection", [])
|
||||
logger.info(f"Received Items: {items}")
|
||||
if not items:
|
||||
has_more = False
|
||||
continue
|
||||
|
||||
for item in items:
|
||||
try:
|
||||
item_id = item.get("id")
|
||||
if not item_id:
|
||||
logger.warning("Item without ID found, skipping")
|
||||
continue
|
||||
|
||||
item_details = self.client.get_item(item_id)
|
||||
if not item_details:
|
||||
logger.warning(
|
||||
f"Item {item_id} details not found, skipping"
|
||||
)
|
||||
continue
|
||||
# Apply time filter if specified
|
||||
if start or end:
|
||||
updated_at = item_details.get("date_updated")
|
||||
if updated_at:
|
||||
# Convert to datetime for comparison
|
||||
try:
|
||||
updated_time = datetime.fromisoformat(
|
||||
updated_at.replace("Z", "+00:00")
|
||||
)
|
||||
if (
|
||||
start and updated_time.timestamp() < start
|
||||
) or (end and updated_time.timestamp() > end):
|
||||
continue
|
||||
except (ValueError, TypeError):
|
||||
# Skip if date cannot be parsed
|
||||
logger.warning(
|
||||
f"Invalid date format for item {item_id}: {updated_at}"
|
||||
)
|
||||
continue
|
||||
|
||||
content = self._get_item_content(item_details)
|
||||
title = item_details.get("title", "")
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=f"HIGHSPOT_{item_id}",
|
||||
sections=[
|
||||
TextSection(
|
||||
link=item_details.get(
|
||||
"url",
|
||||
f"https://www.highspot.com/items/{item_id}",
|
||||
),
|
||||
text=content,
|
||||
)
|
||||
],
|
||||
source=DocumentSource.HIGHSPOT,
|
||||
semantic_identifier=title,
|
||||
metadata={
|
||||
"spot_name": spot_name,
|
||||
"type": item_details.get("content_type", ""),
|
||||
"created_at": item_details.get(
|
||||
"date_added", ""
|
||||
),
|
||||
"author": item_details.get("author", ""),
|
||||
"language": item_details.get("language", ""),
|
||||
"can_download": str(
|
||||
item_details.get("can_download", False)
|
||||
),
|
||||
},
|
||||
doc_updated_at=item_details.get("date_updated"),
|
||||
)
|
||||
)
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
except HighspotClientError as e:
|
||||
item_id = "ID" if not item_id else item_id
|
||||
logger.error(f"Error retrieving item {item_id}: {str(e)}")
|
||||
|
||||
has_more = len(items) >= self.batch_size
|
||||
offset += self.batch_size
|
||||
|
||||
except (HighspotClientError, ValueError) as e:
|
||||
logger.error(f"Error processing spot {spot_name}: {str(e)}")
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def _get_item_content(self, item_details: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Get the text content of an item.
|
||||
|
||||
Args:
|
||||
item_details: Item details from the API
|
||||
|
||||
Returns:
|
||||
Text content of the item
|
||||
"""
|
||||
item_id = item_details.get("id", "")
|
||||
content_name = item_details.get("content_name", "")
|
||||
is_valid_format = content_name and "." in content_name
|
||||
file_extension = content_name.split(".")[-1].lower() if is_valid_format else ""
|
||||
file_extension = "." + file_extension if file_extension else ""
|
||||
can_download = item_details.get("can_download", False)
|
||||
content_type = item_details.get("content_type", "")
|
||||
|
||||
# Extract title and description once at the beginning
|
||||
title, description = self._extract_title_and_description(item_details)
|
||||
default_content = f"{title}\n{description}"
|
||||
logger.info(f"Processing item {item_id} with extension {file_extension}")
|
||||
|
||||
try:
|
||||
if content_type == "WebLink":
|
||||
url = item_details.get("url")
|
||||
if not url:
|
||||
return default_content
|
||||
content = scrape_url_content(url, True)
|
||||
return content if content else default_content
|
||||
|
||||
elif (
|
||||
is_valid_format
|
||||
and file_extension in ALL_ACCEPTED_FILE_EXTENSIONS
|
||||
and can_download
|
||||
):
|
||||
# For documents, try to get the text content
|
||||
if not item_id: # Ensure item_id is defined
|
||||
return default_content
|
||||
|
||||
content_response = self.client.get_item_content(item_id)
|
||||
# Process and extract text from binary content based on type
|
||||
if content_response:
|
||||
text_content = extract_file_text(
|
||||
BytesIO(content_response), content_name
|
||||
)
|
||||
return text_content
|
||||
return default_content
|
||||
|
||||
else:
|
||||
return default_content
|
||||
|
||||
except HighspotClientError as e:
|
||||
# Use item_id safely in the warning message
|
||||
error_context = f"item {item_id}" if item_id else "item"
|
||||
logger.warning(f"Could not retrieve content for {error_context}: {str(e)}")
|
||||
return ""
|
||||
|
||||
def _extract_title_and_description(
|
||||
self, item_details: Dict[str, Any]
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Extract the title and description from item details.
|
||||
|
||||
Args:
|
||||
item_details: Item details from the API
|
||||
|
||||
Returns:
|
||||
Tuple of title and description
|
||||
"""
|
||||
title = item_details.get("title", "")
|
||||
description = item_details.get("description", "")
|
||||
return title, description
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
"""
|
||||
Retrieve all document IDs from the configured spots.
|
||||
If no spots are configured, retrieves from all spots.
|
||||
|
||||
Args:
|
||||
start: Optional start time filter
|
||||
end: Optional end time filter
|
||||
callback: Optional indexing heartbeat callback
|
||||
|
||||
Yields:
|
||||
Batches of SlimDocument objects
|
||||
"""
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
|
||||
# If no spots specified, get all spots
|
||||
spot_names_to_process = self.spot_names
|
||||
if not spot_names_to_process:
|
||||
spot_names_to_process = self._get_all_spot_names()
|
||||
logger.info(
|
||||
f"No spots specified, using all {len(spot_names_to_process)} available spots for slim documents"
|
||||
)
|
||||
|
||||
for spot_name in spot_names_to_process:
|
||||
try:
|
||||
spot_id = self._get_spot_id_from_name(spot_name)
|
||||
offset = 0
|
||||
has_more = True
|
||||
|
||||
while has_more:
|
||||
logger.info(
|
||||
f"Retrieving slim documents from spot {spot_name}, offset {offset}"
|
||||
)
|
||||
response = self.client.get_spot_items(
|
||||
spot_id=spot_id, offset=offset, page_size=self.batch_size
|
||||
)
|
||||
|
||||
items = response.get("collection", [])
|
||||
if not items:
|
||||
has_more = False
|
||||
continue
|
||||
|
||||
for item in items:
|
||||
item_id = item.get("id")
|
||||
if not item_id:
|
||||
continue
|
||||
|
||||
slim_doc_batch.append(SlimDocument(id=f"HIGHSPOT_{item_id}"))
|
||||
|
||||
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
|
||||
yield slim_doc_batch
|
||||
slim_doc_batch = []
|
||||
|
||||
has_more = len(items) >= self.batch_size
|
||||
offset += self.batch_size
|
||||
|
||||
except (HighspotClientError, ValueError) as e:
|
||||
logger.error(
|
||||
f"Error retrieving slim documents from spot {spot_name}: {str(e)}"
|
||||
)
|
||||
|
||||
if slim_doc_batch:
|
||||
yield slim_doc_batch
|
||||
|
||||
def validate_credentials(self) -> bool:
|
||||
"""
|
||||
Validate that the provided credentials can access the Highspot API.
|
||||
|
||||
Returns:
|
||||
True if credentials are valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
return self.client.health_check()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to validate credentials: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
spot_names: List[str] = []
|
||||
connector = HighspotConnector(spot_names)
|
||||
credentials = {"highspot_key": "", "highspot_secret": ""}
|
||||
connector.load_credentials(credentials=credentials)
|
||||
for doc in connector.load_from_state():
|
||||
print(doc)
|
||||
122
backend/onyx/connectors/highspot/utils.py
Normal file
122
backend/onyx/connectors/highspot/utils.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from playwright.sync_api import sync_playwright
|
||||
|
||||
from onyx.file_processing.html_utils import web_html_cleanup
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Constants
|
||||
WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
|
||||
JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser"
|
||||
DEFAULT_TIMEOUT = 60000 # 60 seconds
|
||||
|
||||
|
||||
def scrape_url_content(
|
||||
url: str, scroll_before_scraping: bool = False, timeout_ms: int = DEFAULT_TIMEOUT
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Scrapes content from a given URL and returns the cleaned text.
|
||||
|
||||
Args:
|
||||
url: The URL to scrape
|
||||
scroll_before_scraping: Whether to scroll through the page to load lazy content
|
||||
timeout_ms: Timeout in milliseconds for page navigation and loading
|
||||
|
||||
Returns:
|
||||
The cleaned text content of the page or None if scraping fails
|
||||
"""
|
||||
playwright = None
|
||||
browser = None
|
||||
try:
|
||||
validate_url(url)
|
||||
playwright = sync_playwright().start()
|
||||
browser = playwright.chromium.launch(headless=True)
|
||||
context = browser.new_context()
|
||||
page = context.new_page()
|
||||
|
||||
logger.info(f"Navigating to URL: {url}")
|
||||
try:
|
||||
page.goto(url, timeout=timeout_ms)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to navigate to {url}: {str(e)}")
|
||||
return None
|
||||
|
||||
if scroll_before_scraping:
|
||||
logger.debug("Scrolling page to load lazy content")
|
||||
scroll_attempts = 0
|
||||
previous_height = page.evaluate("document.body.scrollHeight")
|
||||
while scroll_attempts < WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS:
|
||||
page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
|
||||
try:
|
||||
page.wait_for_load_state("networkidle", timeout=timeout_ms)
|
||||
except Exception as e:
|
||||
logger.warning(f"Network idle wait timed out: {str(e)}")
|
||||
break
|
||||
|
||||
new_height = page.evaluate("document.body.scrollHeight")
|
||||
if new_height == previous_height:
|
||||
break
|
||||
previous_height = new_height
|
||||
scroll_attempts += 1
|
||||
|
||||
content = page.content()
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
|
||||
parsed_html = web_html_cleanup(soup)
|
||||
|
||||
if JAVASCRIPT_DISABLED_MESSAGE in parsed_html.cleaned_text:
|
||||
logger.debug("JavaScript disabled message detected, checking iframes")
|
||||
try:
|
||||
iframe_count = page.frame_locator("iframe").locator("html").count()
|
||||
if iframe_count > 0:
|
||||
iframe_texts = (
|
||||
page.frame_locator("iframe").locator("html").all_inner_texts()
|
||||
)
|
||||
iframe_content = "\n".join(iframe_texts)
|
||||
|
||||
if len(parsed_html.cleaned_text) < 700:
|
||||
parsed_html.cleaned_text = iframe_content
|
||||
else:
|
||||
parsed_html.cleaned_text += "\n" + iframe_content
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing iframes: {str(e)}")
|
||||
|
||||
return parsed_html.cleaned_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error scraping URL {url}: {str(e)}")
|
||||
return None
|
||||
|
||||
finally:
|
||||
if browser:
|
||||
try:
|
||||
browser.close()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error closing browser: {str(e)}")
|
||||
if playwright:
|
||||
try:
|
||||
playwright.stop()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error stopping playwright: {str(e)}")
|
||||
|
||||
|
||||
def validate_url(url: str) -> None:
|
||||
"""
|
||||
Validates that a URL is properly formatted.
|
||||
|
||||
Args:
|
||||
url: The URL to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If URL is not valid
|
||||
"""
|
||||
parse = urlparse(url)
|
||||
if parse.scheme != "http" and parse.scheme != "https":
|
||||
raise ValueError("URL must be of scheme https?://")
|
||||
|
||||
if not parse.hostname:
|
||||
raise ValueError("URL must include a hostname")
|
||||
@@ -4,6 +4,7 @@ from collections.abc import Iterator
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeAlias
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -19,10 +20,11 @@ SecondsSinceUnixEpoch = float
|
||||
|
||||
GenerateDocumentsOutput = Iterator[list[Document]]
|
||||
GenerateSlimDocumentOutput = Iterator[list[SlimDocument]]
|
||||
CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]
|
||||
|
||||
CT = TypeVar("CT", bound=ConnectorCheckpoint)
|
||||
|
||||
|
||||
class BaseConnector(abc.ABC):
|
||||
class BaseConnector(abc.ABC, Generic[CT]):
|
||||
REDIS_KEY_PREFIX = "da_connector_data:"
|
||||
# Common image file extensions supported across connectors
|
||||
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
|
||||
@@ -57,6 +59,14 @@ class BaseConnector(abc.ABC):
|
||||
Default is a no-op (always successful).
|
||||
"""
|
||||
|
||||
def set_allow_images(self, value: bool) -> None:
|
||||
"""Implement if the underlying connector wants to skip/allow image downloading
|
||||
based on the application level image analysis setting."""
|
||||
|
||||
def build_dummy_checkpoint(self) -> CT:
|
||||
# TODO: find a way to make this work without type: ignore
|
||||
return ConnectorCheckpoint(has_more=True) # type: ignore
|
||||
|
||||
|
||||
# Large set update or reindex, generally pulling a complete state or from a savestate file
|
||||
class LoadConnector(BaseConnector):
|
||||
@@ -74,6 +84,8 @@ class PollConnector(BaseConnector):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Slim connectors can retrieve just the ids and
|
||||
# permission syncing information for connected documents
|
||||
class SlimConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def retrieve_all_slim_documents(
|
||||
@@ -186,14 +198,17 @@ class EventConnector(BaseConnector):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CheckpointConnector(BaseConnector):
|
||||
CheckpointOutput: TypeAlias = Generator[Document | ConnectorFailure, None, CT]
|
||||
|
||||
|
||||
class CheckpointConnector(BaseConnector[CT]):
|
||||
@abc.abstractmethod
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> CheckpointOutput:
|
||||
checkpoint: CT,
|
||||
) -> CheckpointOutput[CT]:
|
||||
"""Yields back documents or failures. Final return is the new checkpoint.
|
||||
|
||||
Final return can be access via either:
|
||||
@@ -214,3 +229,12 @@ class CheckpointConnector(BaseConnector):
|
||||
```
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def build_dummy_checkpoint(self) -> CT:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> CT:
|
||||
"""Validate the checkpoint json and return the checkpoint object"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Any
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
@@ -15,14 +16,18 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class MockConnectorCheckpoint(ConnectorCheckpoint):
|
||||
last_document_id: str | None = None
|
||||
|
||||
|
||||
class SingleConnectorYield(BaseModel):
|
||||
documents: list[Document]
|
||||
checkpoint: ConnectorCheckpoint
|
||||
checkpoint: MockConnectorCheckpoint
|
||||
failures: list[ConnectorFailure]
|
||||
unhandled_exception: str | None = None
|
||||
|
||||
|
||||
class MockConnector(CheckpointConnector):
|
||||
class MockConnector(CheckpointConnector[MockConnectorCheckpoint]):
|
||||
def __init__(
|
||||
self,
|
||||
mock_server_host: str,
|
||||
@@ -48,7 +53,7 @@ class MockConnector(CheckpointConnector):
|
||||
def _get_mock_server_url(self, endpoint: str) -> str:
|
||||
return f"http://{self.mock_server_host}:{self.mock_server_port}/{endpoint}"
|
||||
|
||||
def _save_checkpoint(self, checkpoint: ConnectorCheckpoint) -> None:
|
||||
def _save_checkpoint(self, checkpoint: MockConnectorCheckpoint) -> None:
|
||||
response = self.client.post(
|
||||
self._get_mock_server_url("add-checkpoint"),
|
||||
json=checkpoint.model_dump(mode="json"),
|
||||
@@ -59,8 +64,8 @@ class MockConnector(CheckpointConnector):
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> CheckpointOutput:
|
||||
checkpoint: MockConnectorCheckpoint,
|
||||
) -> CheckpointOutput[MockConnectorCheckpoint]:
|
||||
if self.connector_yields is None:
|
||||
raise ValueError("No connector yields configured")
|
||||
|
||||
@@ -84,3 +89,13 @@ class MockConnector(CheckpointConnector):
|
||||
yield failure
|
||||
|
||||
return current_yield.checkpoint
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> MockConnectorCheckpoint:
|
||||
return MockConnectorCheckpoint(
|
||||
has_more=True,
|
||||
last_document_id=None,
|
||||
)
|
||||
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> MockConnectorCheckpoint:
|
||||
return MockConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
@@ -164,6 +163,9 @@ class DocumentBase(BaseModel):
|
||||
attributes.append(k + INDEX_SEPARATOR + v)
|
||||
return attributes
|
||||
|
||||
def get_text_content(self) -> str:
|
||||
return " ".join([section.text for section in self.sections if section.text])
|
||||
|
||||
|
||||
class Document(DocumentBase):
|
||||
"""Used for Onyx ingestion api, the ID is required"""
|
||||
@@ -232,21 +234,16 @@ class IndexAttemptMetadata(BaseModel):
|
||||
|
||||
class ConnectorCheckpoint(BaseModel):
|
||||
# TODO: maybe move this to something disk-based to handle extremely large checkpoints?
|
||||
checkpoint_content: dict
|
||||
has_more: bool
|
||||
|
||||
@classmethod
|
||||
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
|
||||
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the checkpoint, with truncation for large checkpoint content."""
|
||||
MAX_CHECKPOINT_CONTENT_CHARS = 1000
|
||||
|
||||
content_str = json.dumps(self.checkpoint_content)
|
||||
content_str = self.model_dump_json()
|
||||
if len(content_str) > MAX_CHECKPOINT_CONTENT_CHARS:
|
||||
content_str = content_str[: MAX_CHECKPOINT_CONTENT_CHARS - 3] + "..."
|
||||
return f"ConnectorCheckpoint(checkpoint_content={content_str}, has_more={self.has_more})"
|
||||
return content_str
|
||||
|
||||
|
||||
class DocumentFailure(BaseModel):
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import fields
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP
|
||||
from onyx.configs.app_configs import NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rl_requests,
|
||||
@@ -25,6 +25,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.batching import batch_generator
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -38,8 +39,7 @@ _NOTION_CALL_TIMEOUT = 30 # 30 seconds
|
||||
# TODO: Tables need to be ingested, Pages need to have their metadata ingested
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotionPage:
|
||||
class NotionPage(BaseModel):
|
||||
"""Represents a Notion Page object"""
|
||||
|
||||
id: str
|
||||
@@ -49,17 +49,10 @@ class NotionPage:
|
||||
properties: dict[str, Any]
|
||||
url: str
|
||||
|
||||
database_name: str | None # Only applicable to the database type page (wiki)
|
||||
|
||||
def __init__(self, **kwargs: dict[str, Any]) -> None:
|
||||
names = set([f.name for f in fields(self)])
|
||||
for k, v in kwargs.items():
|
||||
if k in names:
|
||||
setattr(self, k, v)
|
||||
database_name: str | None = None # Only applicable to the database type page (wiki)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotionBlock:
|
||||
class NotionBlock(BaseModel):
|
||||
"""Represents a Notion Block object"""
|
||||
|
||||
id: str # Used for the URL
|
||||
@@ -69,20 +62,13 @@ class NotionBlock:
|
||||
prefix: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotionSearchResponse:
|
||||
class NotionSearchResponse(BaseModel):
|
||||
"""Represents the response from the Notion Search API"""
|
||||
|
||||
results: list[dict[str, Any]]
|
||||
next_cursor: Optional[str]
|
||||
has_more: bool = False
|
||||
|
||||
def __init__(self, **kwargs: dict[str, Any]) -> None:
|
||||
names = set([f.name for f in fields(self)])
|
||||
for k, v in kwargs.items():
|
||||
if k in names:
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
class NotionConnector(LoadConnector, PollConnector):
|
||||
"""Notion Page connector that reads all Notion pages
|
||||
@@ -95,7 +81,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
recursive_index_enabled: bool = NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP,
|
||||
recursive_index_enabled: bool = not NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP,
|
||||
root_page_id: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize with parameters."""
|
||||
@@ -464,23 +450,53 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
page_blocks, child_page_ids = self._read_blocks(page.id)
|
||||
all_child_page_ids.extend(child_page_ids)
|
||||
|
||||
if not page_blocks:
|
||||
continue
|
||||
# okay to mark here since there's no way for this to not succeed
|
||||
# without a critical failure
|
||||
self.indexed_pages.add(page.id)
|
||||
|
||||
page_title = (
|
||||
self._read_page_title(page) or f"Untitled Page with ID {page.id}"
|
||||
)
|
||||
raw_page_title = self._read_page_title(page)
|
||||
page_title = raw_page_title or f"Untitled Page with ID {page.id}"
|
||||
|
||||
if not page_blocks:
|
||||
if not raw_page_title:
|
||||
logger.warning(
|
||||
f"No blocks OR title found for page with ID '{page.id}'. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
logger.debug(f"No blocks found for page with ID '{page.id}'")
|
||||
"""
|
||||
Something like:
|
||||
|
||||
TITLE
|
||||
|
||||
PROP1: PROP1_VALUE
|
||||
PROP2: PROP2_VALUE
|
||||
"""
|
||||
text = page_title
|
||||
if page.properties:
|
||||
text += "\n\n" + "\n".join(
|
||||
[f"{key}: {value}" for key, value in page.properties.items()]
|
||||
)
|
||||
sections = [
|
||||
TextSection(
|
||||
link=f"{page.url}",
|
||||
text=text,
|
||||
)
|
||||
]
|
||||
else:
|
||||
sections = [
|
||||
TextSection(
|
||||
link=f"{page.url}#{block.id.replace('-', '')}",
|
||||
text=block.prefix + block.text,
|
||||
)
|
||||
for block in page_blocks
|
||||
]
|
||||
|
||||
yield (
|
||||
Document(
|
||||
id=page.id,
|
||||
sections=[
|
||||
TextSection(
|
||||
link=f"{page.url}#{block.id.replace('-', '')}",
|
||||
text=block.prefix + block.text,
|
||||
)
|
||||
for block in page_blocks
|
||||
],
|
||||
sections=cast(list[TextSection | ImageSection], sections),
|
||||
source=DocumentSource.NOTION,
|
||||
semantic_identifier=page_title,
|
||||
doc_updated_at=datetime.fromisoformat(
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
|
||||
@@ -15,14 +16,16 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_t
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info
|
||||
@@ -42,121 +45,112 @@ _JIRA_SLIM_PAGE_SIZE = 500
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
|
||||
|
||||
def _paginate_jql_search(
|
||||
def _perform_jql_search(
|
||||
jira_client: JIRA,
|
||||
jql: str,
|
||||
start: int,
|
||||
max_results: int,
|
||||
fields: str | None = None,
|
||||
) -> Iterable[Issue]:
|
||||
start = 0
|
||||
while True:
|
||||
logger.debug(
|
||||
f"Fetching Jira issues with JQL: {jql}, "
|
||||
f"starting at {start}, max results: {max_results}"
|
||||
)
|
||||
issues = jira_client.search_issues(
|
||||
jql_str=jql,
|
||||
startAt=start,
|
||||
maxResults=max_results,
|
||||
fields=fields,
|
||||
)
|
||||
logger.debug(
|
||||
f"Fetching Jira issues with JQL: {jql}, "
|
||||
f"starting at {start}, max results: {max_results}"
|
||||
)
|
||||
issues = jira_client.search_issues(
|
||||
jql_str=jql,
|
||||
startAt=start,
|
||||
maxResults=max_results,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
for issue in issues:
|
||||
if isinstance(issue, Issue):
|
||||
yield issue
|
||||
else:
|
||||
raise Exception(f"Found Jira object not of type Issue: {issue}")
|
||||
|
||||
if len(issues) < max_results:
|
||||
break
|
||||
|
||||
start += max_results
|
||||
for issue in issues:
|
||||
if isinstance(issue, Issue):
|
||||
yield issue
|
||||
else:
|
||||
raise RuntimeError(f"Found Jira object not of type Issue: {issue}")
|
||||
|
||||
|
||||
def fetch_jira_issues_batch(
|
||||
def process_jira_issue(
|
||||
jira_client: JIRA,
|
||||
jql: str,
|
||||
batch_size: int,
|
||||
issue: Issue,
|
||||
comment_email_blacklist: tuple[str, ...] = (),
|
||||
labels_to_skip: set[str] | None = None,
|
||||
) -> Iterable[Document]:
|
||||
for issue in _paginate_jql_search(
|
||||
jira_client=jira_client,
|
||||
jql=jql,
|
||||
max_results=batch_size,
|
||||
):
|
||||
if labels_to_skip:
|
||||
if any(label in issue.fields.labels for label in labels_to_skip):
|
||||
logger.info(
|
||||
f"Skipping {issue.key} because it has a label to skip. Found "
|
||||
f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
|
||||
)
|
||||
continue
|
||||
|
||||
description = (
|
||||
issue.fields.description
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(issue.raw["fields"]["description"])
|
||||
)
|
||||
comments = get_comment_strs(
|
||||
issue=issue,
|
||||
comment_email_blacklist=comment_email_blacklist,
|
||||
)
|
||||
ticket_content = f"{description}\n" + "\n".join(
|
||||
[f"Comment: {comment}" for comment in comments if comment]
|
||||
)
|
||||
|
||||
# Check ticket size
|
||||
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
|
||||
) -> Document | None:
|
||||
if labels_to_skip:
|
||||
if any(label in issue.fields.labels for label in labels_to_skip):
|
||||
logger.info(
|
||||
f"Skipping {issue.key} because it exceeds the maximum size of "
|
||||
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
|
||||
f"Skipping {issue.key} because it has a label to skip. Found "
|
||||
f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
|
||||
)
|
||||
continue
|
||||
return None
|
||||
|
||||
page_url = f"{jira_client.client_info()}/browse/{issue.key}"
|
||||
description = (
|
||||
issue.fields.description
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(issue.raw["fields"]["description"])
|
||||
)
|
||||
comments = get_comment_strs(
|
||||
issue=issue,
|
||||
comment_email_blacklist=comment_email_blacklist,
|
||||
)
|
||||
ticket_content = f"{description}\n" + "\n".join(
|
||||
[f"Comment: {comment}" for comment in comments if comment]
|
||||
)
|
||||
|
||||
people = set()
|
||||
try:
|
||||
creator = best_effort_get_field_from_issue(issue, "creator")
|
||||
if basic_expert_info := best_effort_basic_expert_info(creator):
|
||||
people.add(basic_expert_info)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
|
||||
try:
|
||||
assignee = best_effort_get_field_from_issue(issue, "assignee")
|
||||
if basic_expert_info := best_effort_basic_expert_info(assignee):
|
||||
people.add(basic_expert_info)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
|
||||
metadata_dict = {}
|
||||
if priority := best_effort_get_field_from_issue(issue, "priority"):
|
||||
metadata_dict["priority"] = priority.name
|
||||
if status := best_effort_get_field_from_issue(issue, "status"):
|
||||
metadata_dict["status"] = status.name
|
||||
if resolution := best_effort_get_field_from_issue(issue, "resolution"):
|
||||
metadata_dict["resolution"] = resolution.name
|
||||
if labels := best_effort_get_field_from_issue(issue, "labels"):
|
||||
metadata_dict["label"] = labels
|
||||
|
||||
yield Document(
|
||||
id=page_url,
|
||||
sections=[TextSection(link=page_url, text=ticket_content)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
|
||||
title=f"{issue.key} {issue.fields.summary}",
|
||||
doc_updated_at=time_str_to_utc(issue.fields.updated),
|
||||
primary_owners=list(people) or None,
|
||||
# TODO add secondary_owners (commenters) if needed
|
||||
metadata=metadata_dict,
|
||||
# Check ticket size
|
||||
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
|
||||
logger.info(
|
||||
f"Skipping {issue.key} because it exceeds the maximum size of "
|
||||
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
|
||||
)
|
||||
return None
|
||||
|
||||
page_url = build_jira_url(jira_client, issue.key)
|
||||
|
||||
people = set()
|
||||
try:
|
||||
creator = best_effort_get_field_from_issue(issue, "creator")
|
||||
if basic_expert_info := best_effort_basic_expert_info(creator):
|
||||
people.add(basic_expert_info)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
|
||||
try:
|
||||
assignee = best_effort_get_field_from_issue(issue, "assignee")
|
||||
if basic_expert_info := best_effort_basic_expert_info(assignee):
|
||||
people.add(basic_expert_info)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
|
||||
metadata_dict = {}
|
||||
if priority := best_effort_get_field_from_issue(issue, "priority"):
|
||||
metadata_dict["priority"] = priority.name
|
||||
if status := best_effort_get_field_from_issue(issue, "status"):
|
||||
metadata_dict["status"] = status.name
|
||||
if resolution := best_effort_get_field_from_issue(issue, "resolution"):
|
||||
metadata_dict["resolution"] = resolution.name
|
||||
if labels := best_effort_get_field_from_issue(issue, "labels"):
|
||||
metadata_dict["labels"] = labels
|
||||
|
||||
return Document(
|
||||
id=page_url,
|
||||
sections=[TextSection(link=page_url, text=ticket_content)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
|
||||
title=f"{issue.key} {issue.fields.summary}",
|
||||
doc_updated_at=time_str_to_utc(issue.fields.updated),
|
||||
primary_owners=list(people) or None,
|
||||
metadata=metadata_dict,
|
||||
)
|
||||
|
||||
|
||||
class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class JiraConnectorCheckpoint(ConnectorCheckpoint):
|
||||
offset: int | None = None
|
||||
|
||||
|
||||
class JiraConnector(CheckpointConnector[JiraConnectorCheckpoint], SlimConnector):
|
||||
def __init__(
|
||||
self,
|
||||
jira_base_url: str,
|
||||
@@ -200,33 +194,10 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_jql_query(self) -> str:
|
||||
"""Get the JQL query based on whether a specific project is set"""
|
||||
if self.jira_project:
|
||||
return f"project = {self.quoted_jira_project}"
|
||||
return "" # Empty string means all accessible projects
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
jql = self._get_jql_query()
|
||||
|
||||
document_batch = []
|
||||
for doc in fetch_jira_issues_batch(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
batch_size=_JIRA_FULL_PAGE_SIZE,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
):
|
||||
document_batch.append(doc)
|
||||
if len(document_batch) >= self.batch_size:
|
||||
yield document_batch
|
||||
document_batch = []
|
||||
|
||||
yield document_batch
|
||||
|
||||
def poll_source(
|
||||
def _get_jql_query(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
) -> str:
|
||||
"""Get the JQL query based on whether a specific project is set and time range"""
|
||||
start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
@@ -234,25 +205,61 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
|
||||
base_jql = self._get_jql_query()
|
||||
jql = (
|
||||
f"{base_jql} AND " if base_jql else ""
|
||||
) + f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'"
|
||||
time_jql = f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'"
|
||||
|
||||
document_batch = []
|
||||
for doc in fetch_jira_issues_batch(
|
||||
if self.jira_project:
|
||||
base_jql = f"project = {self.quoted_jira_project}"
|
||||
return f"{base_jql} AND {time_jql}"
|
||||
|
||||
return time_jql
|
||||
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: JiraConnectorCheckpoint,
|
||||
) -> CheckpointOutput[JiraConnectorCheckpoint]:
|
||||
jql = self._get_jql_query(start, end)
|
||||
|
||||
# Get the current offset from checkpoint or start at 0
|
||||
starting_offset = checkpoint.offset or 0
|
||||
current_offset = starting_offset
|
||||
|
||||
for issue in _perform_jql_search(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
batch_size=_JIRA_FULL_PAGE_SIZE,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
start=current_offset,
|
||||
max_results=_JIRA_FULL_PAGE_SIZE,
|
||||
):
|
||||
document_batch.append(doc)
|
||||
if len(document_batch) >= self.batch_size:
|
||||
yield document_batch
|
||||
document_batch = []
|
||||
issue_key = issue.key
|
||||
try:
|
||||
if document := process_jira_issue(
|
||||
jira_client=self.jira_client,
|
||||
issue=issue,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
):
|
||||
yield document
|
||||
|
||||
yield document_batch
|
||||
except Exception as e:
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=issue_key,
|
||||
document_link=build_jira_url(self.jira_client, issue_key),
|
||||
),
|
||||
failure_message=f"Failed to process Jira issue: {str(e)}",
|
||||
exception=e,
|
||||
)
|
||||
|
||||
current_offset += 1
|
||||
|
||||
# Update checkpoint
|
||||
checkpoint = JiraConnectorCheckpoint(
|
||||
offset=current_offset,
|
||||
# if we didn't retrieve a full batch, we're done
|
||||
has_more=current_offset - starting_offset == _JIRA_FULL_PAGE_SIZE,
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
@@ -260,12 +267,13 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
jql = self._get_jql_query()
|
||||
jql = self._get_jql_query(start or 0, end or float("inf"))
|
||||
|
||||
slim_doc_batch = []
|
||||
for issue in _paginate_jql_search(
|
||||
for issue in _perform_jql_search(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
start=0,
|
||||
max_results=_JIRA_SLIM_PAGE_SIZE,
|
||||
fields="key",
|
||||
):
|
||||
@@ -334,6 +342,16 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
raise RuntimeError(f"Unexpected Jira error during validation: {e}")
|
||||
|
||||
@override
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> JiraConnectorCheckpoint:
|
||||
return JiraConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> JiraConnectorCheckpoint:
|
||||
return JiraConnectorCheckpoint(
|
||||
has_more=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
@@ -350,5 +368,7 @@ if __name__ == "__main__":
|
||||
"jira_api_token": os.environ["JIRA_API_TOKEN"],
|
||||
}
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
document_batches = connector.load_from_checkpoint(
|
||||
0, float("inf"), JiraConnectorCheckpoint(has_more=True)
|
||||
)
|
||||
print(next(document_batches))
|
||||
|
||||
@@ -10,13 +10,17 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.http_retry import ConnectionErrorRetryHandler
|
||||
from slack_sdk.http_retry import RetryHandler
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import SLACK_NUM_THREADS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
@@ -24,6 +28,8 @@ from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import CredentialsConnector
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
@@ -36,15 +42,16 @@ from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.slack.onyx_retry_handler import OnyxRedisSlackRetryHandler
|
||||
from onyx.connectors.slack.utils import expert_info_from_slack_id
|
||||
from onyx.connectors.slack.utils import get_message_link
|
||||
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
|
||||
from onyx.connectors.slack.utils import make_slack_api_call_w_retries
|
||||
from onyx.connectors.slack.utils import SlackTextCleaner
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_SLACK_LIMIT = 900
|
||||
@@ -56,8 +63,8 @@ MessageType = dict[str, Any]
|
||||
ThreadType = list[MessageType]
|
||||
|
||||
|
||||
class SlackCheckpointContent(TypedDict):
|
||||
channel_ids: list[str]
|
||||
class SlackCheckpoint(ConnectorCheckpoint):
|
||||
channel_ids: list[str] | None
|
||||
channel_completion_map: dict[str, str]
|
||||
current_channel: ChannelType | None
|
||||
seen_thread_ts: list[str]
|
||||
@@ -412,8 +419,8 @@ def _get_all_doc_ids(
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
message_ts_set: set[str] = set()
|
||||
for message_batch in channel_message_batches:
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
for message in message_batch:
|
||||
if msg_filter_func(message):
|
||||
continue
|
||||
@@ -421,18 +428,27 @@ def _get_all_doc_ids(
|
||||
# The document id is the channel id and the ts of the first message in the thread
|
||||
# Since we already have the first message of the thread, we dont have to
|
||||
# fetch the thread for id retrieval, saving time and API calls
|
||||
message_ts_set.add(message["ts"])
|
||||
|
||||
channel_metadata_list: list[SlimDocument] = []
|
||||
for message_ts in message_ts_set:
|
||||
channel_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=_build_doc_id(channel_id=channel_id, thread_ts=message_ts),
|
||||
perm_sync_data={"channel_id": channel_id},
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=_build_doc_id(
|
||||
channel_id=channel_id, thread_ts=message["ts"]
|
||||
),
|
||||
perm_sync_data={"channel_id": channel_id},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield channel_metadata_list
|
||||
yield slim_doc_batch
|
||||
|
||||
|
||||
class ProcessedSlackMessage(BaseModel):
|
||||
doc: Document | None
|
||||
# if the message is part of a thread, this is the thread_ts
|
||||
# otherwise, this is the message_ts. Either way, will be a unique identifier.
|
||||
# In the future, if the message becomes a thread, then the thread_ts
|
||||
# will be set to the message_ts.
|
||||
thread_or_message_ts: str
|
||||
failure: ConnectorFailure | None
|
||||
|
||||
|
||||
def _process_message(
|
||||
@@ -443,8 +459,9 @@ def _process_message(
|
||||
user_cache: dict[str, BasicExpertInfo | None],
|
||||
seen_thread_ts: set[str],
|
||||
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
|
||||
) -> tuple[Document | None, str | None, ConnectorFailure | None]:
|
||||
) -> ProcessedSlackMessage:
|
||||
thread_ts = message.get("thread_ts")
|
||||
thread_or_message_ts = thread_ts or message["ts"]
|
||||
try:
|
||||
# causes random failures for testing checkpointing / continue on failure
|
||||
# import random
|
||||
@@ -460,16 +477,18 @@ def _process_message(
|
||||
seen_thread_ts=seen_thread_ts,
|
||||
msg_filter_func=msg_filter_func,
|
||||
)
|
||||
return (doc, thread_ts, None)
|
||||
return ProcessedSlackMessage(
|
||||
doc=doc, thread_or_message_ts=thread_or_message_ts, failure=None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing message {message['ts']}")
|
||||
return (
|
||||
None,
|
||||
thread_ts,
|
||||
ConnectorFailure(
|
||||
return ProcessedSlackMessage(
|
||||
doc=None,
|
||||
thread_or_message_ts=thread_or_message_ts,
|
||||
failure=ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=_build_doc_id(
|
||||
channel_id=channel["id"], thread_ts=(thread_ts or message["ts"])
|
||||
channel_id=channel["id"], thread_ts=thread_or_message_ts
|
||||
),
|
||||
document_link=get_message_link(message, client, channel["id"]),
|
||||
),
|
||||
@@ -479,7 +498,13 @@ def _process_message(
|
||||
)
|
||||
|
||||
|
||||
class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
class SlackConnector(
|
||||
SlimConnector, CredentialsConnector, CheckpointConnector[SlackCheckpoint]
|
||||
):
|
||||
FAST_TIMEOUT = 1
|
||||
|
||||
MAX_RETRIES = 7 # arbitrarily selected
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: list[str] | None = None,
|
||||
@@ -487,21 +512,60 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
# regexes, and will only index channels that fully match the regexes
|
||||
channel_regex_enabled: bool = False,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
num_threads: int = SLACK_NUM_THREADS,
|
||||
) -> None:
|
||||
self.channels = channels
|
||||
self.channel_regex_enabled = channel_regex_enabled
|
||||
self.batch_size = batch_size
|
||||
self.num_threads = num_threads
|
||||
self.client: WebClient | None = None
|
||||
|
||||
self.fast_client: WebClient | None = None
|
||||
# just used for efficiency
|
||||
self.text_cleaner: SlackTextCleaner | None = None
|
||||
self.user_cache: dict[str, BasicExpertInfo | None] = {}
|
||||
self.credentials_provider: CredentialsProviderInterface | None = None
|
||||
self.credential_prefix: str | None = None
|
||||
self.delay_lock: str | None = None # the redis key for the shared lock
|
||||
self.delay_key: str | None = None # the redis key for the shared delay
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
raise NotImplementedError("Use set_credentials_provider with this connector.")
|
||||
|
||||
def set_credentials_provider(
|
||||
self, credentials_provider: CredentialsProviderInterface
|
||||
) -> None:
|
||||
credentials = credentials_provider.get_credentials()
|
||||
tenant_id = credentials_provider.get_tenant_id()
|
||||
self.redis = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
self.credential_prefix = (
|
||||
f"connector:slack:credential_{credentials_provider.get_provider_key()}"
|
||||
)
|
||||
self.delay_lock = f"{self.credential_prefix}:delay_lock"
|
||||
self.delay_key = f"{self.credential_prefix}:delay"
|
||||
|
||||
# NOTE: slack has a built in RateLimitErrorRetryHandler, but it isn't designed
|
||||
# for concurrent workers. We've extended it with OnyxRedisSlackRetryHandler.
|
||||
connection_error_retry_handler = ConnectionErrorRetryHandler()
|
||||
onyx_rate_limit_error_retry_handler = OnyxRedisSlackRetryHandler(
|
||||
max_retry_count=self.MAX_RETRIES,
|
||||
delay_lock=self.delay_lock,
|
||||
delay_key=self.delay_key,
|
||||
r=self.redis,
|
||||
)
|
||||
custom_retry_handlers: list[RetryHandler] = [
|
||||
connection_error_retry_handler,
|
||||
onyx_rate_limit_error_retry_handler,
|
||||
]
|
||||
|
||||
bot_token = credentials["slack_bot_token"]
|
||||
self.client = WebClient(token=bot_token)
|
||||
self.client = WebClient(token=bot_token, retry_handlers=custom_retry_handlers)
|
||||
# use for requests that must return quickly (e.g. realtime flows where user is waiting)
|
||||
self.fast_client = WebClient(
|
||||
token=bot_token, timeout=SlackConnector.FAST_TIMEOUT
|
||||
)
|
||||
self.text_cleaner = SlackTextCleaner(client=self.client)
|
||||
return None
|
||||
self.credentials_provider = credentials_provider
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
@@ -523,8 +587,8 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> CheckpointOutput:
|
||||
checkpoint: SlackCheckpoint,
|
||||
) -> CheckpointOutput[SlackCheckpoint]:
|
||||
"""Rough outline:
|
||||
|
||||
Step 1: Get all channels, yield back Checkpoint.
|
||||
@@ -540,49 +604,36 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
if self.client is None or self.text_cleaner is None:
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
|
||||
checkpoint_content = cast(
|
||||
SlackCheckpointContent,
|
||||
(
|
||||
copy.deepcopy(checkpoint.checkpoint_content)
|
||||
or {
|
||||
"channel_ids": None,
|
||||
"channel_completion_map": {},
|
||||
"current_channel": None,
|
||||
"seen_thread_ts": [],
|
||||
}
|
||||
),
|
||||
)
|
||||
checkpoint = cast(SlackCheckpoint, copy.deepcopy(checkpoint))
|
||||
|
||||
# if this is the very first time we've called this, need to
|
||||
# get all relevant channels and save them into the checkpoint
|
||||
if checkpoint_content["channel_ids"] is None:
|
||||
if checkpoint.channel_ids is None:
|
||||
raw_channels = get_channels(self.client)
|
||||
filtered_channels = filter_channels(
|
||||
raw_channels, self.channels, self.channel_regex_enabled
|
||||
)
|
||||
checkpoint.channel_ids = [c["id"] for c in filtered_channels]
|
||||
if len(filtered_channels) == 0:
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
|
||||
|
||||
checkpoint_content["channel_ids"] = [c["id"] for c in filtered_channels]
|
||||
checkpoint_content["current_channel"] = filtered_channels[0]
|
||||
checkpoint = ConnectorCheckpoint(
|
||||
checkpoint_content=checkpoint_content, # type: ignore
|
||||
has_more=True,
|
||||
)
|
||||
checkpoint.current_channel = filtered_channels[0]
|
||||
checkpoint.has_more = True
|
||||
return checkpoint
|
||||
|
||||
final_channel_ids = checkpoint_content["channel_ids"]
|
||||
channel = checkpoint_content["current_channel"]
|
||||
final_channel_ids = checkpoint.channel_ids
|
||||
channel = checkpoint.current_channel
|
||||
if channel is None:
|
||||
raise ValueError("current_channel key not found in checkpoint")
|
||||
raise ValueError("current_channel key not set in checkpoint")
|
||||
|
||||
channel_id = channel["id"]
|
||||
if channel_id not in final_channel_ids:
|
||||
raise ValueError(f"Channel {channel_id} not found in checkpoint")
|
||||
|
||||
oldest = str(start) if start else None
|
||||
latest = checkpoint_content["channel_completion_map"].get(channel_id, str(end))
|
||||
seen_thread_ts = set(checkpoint_content["seen_thread_ts"])
|
||||
latest = checkpoint.channel_completion_map.get(channel_id, str(end))
|
||||
seen_thread_ts = set(checkpoint.seen_thread_ts)
|
||||
try:
|
||||
logger.debug(
|
||||
f"Getting messages for channel {channel} within range {oldest} - {latest}"
|
||||
@@ -593,8 +644,8 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
new_latest = message_batch[-1]["ts"] if message_batch else latest
|
||||
|
||||
# Process messages in parallel using ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=8) as executor:
|
||||
futures: list[Future] = []
|
||||
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
|
||||
futures: list[Future[ProcessedSlackMessage]] = []
|
||||
for message in message_batch:
|
||||
# Capture the current context so that the thread gets the current tenant ID
|
||||
current_context = contextvars.copy_context()
|
||||
@@ -612,46 +663,46 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
)
|
||||
|
||||
for future in as_completed(futures):
|
||||
doc, thread_ts, failures = future.result()
|
||||
processed_slack_message = future.result()
|
||||
doc = processed_slack_message.doc
|
||||
thread_or_message_ts = processed_slack_message.thread_or_message_ts
|
||||
failure = processed_slack_message.failure
|
||||
if doc:
|
||||
# handle race conditions here since this is single
|
||||
# threaded. Multi-threaded _process_message reads from this
|
||||
# but since this is single threaded, we won't run into simul
|
||||
# writes. At worst, we can duplicate a thread, which will be
|
||||
# deduped later on.
|
||||
if thread_ts not in seen_thread_ts:
|
||||
if thread_or_message_ts not in seen_thread_ts:
|
||||
yield doc
|
||||
|
||||
if thread_ts:
|
||||
seen_thread_ts.add(thread_ts)
|
||||
elif failures:
|
||||
for failure in failures:
|
||||
yield failure
|
||||
assert (
|
||||
thread_or_message_ts
|
||||
), "found non-None doc with None thread_or_message_ts"
|
||||
seen_thread_ts.add(thread_or_message_ts)
|
||||
elif failure:
|
||||
yield failure
|
||||
|
||||
checkpoint_content["seen_thread_ts"] = list(seen_thread_ts)
|
||||
checkpoint_content["channel_completion_map"][channel["id"]] = new_latest
|
||||
checkpoint.seen_thread_ts = list(seen_thread_ts)
|
||||
checkpoint.channel_completion_map[channel["id"]] = new_latest
|
||||
if has_more_in_channel:
|
||||
checkpoint_content["current_channel"] = channel
|
||||
checkpoint.current_channel = channel
|
||||
else:
|
||||
new_channel_id = next(
|
||||
(
|
||||
channel_id
|
||||
for channel_id in final_channel_ids
|
||||
if channel_id
|
||||
not in checkpoint_content["channel_completion_map"]
|
||||
if channel_id not in checkpoint.channel_completion_map
|
||||
),
|
||||
None,
|
||||
)
|
||||
if new_channel_id:
|
||||
new_channel = _get_channel_by_id(self.client, new_channel_id)
|
||||
checkpoint_content["current_channel"] = new_channel
|
||||
checkpoint.current_channel = new_channel
|
||||
else:
|
||||
checkpoint_content["current_channel"] = None
|
||||
checkpoint.current_channel = None
|
||||
|
||||
checkpoint = ConnectorCheckpoint(
|
||||
checkpoint_content=checkpoint_content, # type: ignore
|
||||
has_more=checkpoint_content["current_channel"] is not None,
|
||||
)
|
||||
checkpoint.has_more = checkpoint.current_channel is not None
|
||||
return checkpoint
|
||||
|
||||
except Exception as e:
|
||||
@@ -675,12 +726,12 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
2. Ensure the bot has enough scope to list channels.
|
||||
3. Check that every channel specified in self.channels exists (only when regex is not enabled).
|
||||
"""
|
||||
if self.client is None:
|
||||
if self.fast_client is None:
|
||||
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
|
||||
|
||||
try:
|
||||
# 1) Validate connection to workspace
|
||||
auth_response = self.client.auth_test()
|
||||
auth_response = self.fast_client.auth_test()
|
||||
if not auth_response.get("ok", False):
|
||||
error_msg = auth_response.get(
|
||||
"error", "Unknown error from Slack auth_test"
|
||||
@@ -688,7 +739,7 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
raise ConnectorValidationError(f"Failed Slack auth_test: {error_msg}")
|
||||
|
||||
# 2) Minimal test to confirm listing channels works
|
||||
test_resp = self.client.conversations_list(
|
||||
test_resp = self.fast_client.conversations_list(
|
||||
limit=1, types=["public_channel"]
|
||||
)
|
||||
if not test_resp.get("ok", False):
|
||||
@@ -706,29 +757,41 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
)
|
||||
|
||||
# 3) If channels are specified and regex is not enabled, verify each is accessible
|
||||
if self.channels and not self.channel_regex_enabled:
|
||||
accessible_channels = get_channels(
|
||||
client=self.client,
|
||||
exclude_archived=True,
|
||||
get_public=True,
|
||||
get_private=True,
|
||||
)
|
||||
# For quick lookups by name or ID, build a map:
|
||||
accessible_channel_names = {ch["name"] for ch in accessible_channels}
|
||||
accessible_channel_ids = {ch["id"] for ch in accessible_channels}
|
||||
# NOTE: removed this for now since it may be too slow for large workspaces which may
|
||||
# have some automations which create a lot of channels (100k+)
|
||||
|
||||
for user_channel in self.channels:
|
||||
if (
|
||||
user_channel not in accessible_channel_names
|
||||
and user_channel not in accessible_channel_ids
|
||||
):
|
||||
raise ConnectorValidationError(
|
||||
f"Channel '{user_channel}' not found or inaccessible in this workspace."
|
||||
)
|
||||
# if self.channels and not self.channel_regex_enabled:
|
||||
# accessible_channels = get_channels(
|
||||
# client=self.fast_client,
|
||||
# exclude_archived=True,
|
||||
# get_public=True,
|
||||
# get_private=True,
|
||||
# )
|
||||
# # For quick lookups by name or ID, build a map:
|
||||
# accessible_channel_names = {ch["name"] for ch in accessible_channels}
|
||||
# accessible_channel_ids = {ch["id"] for ch in accessible_channels}
|
||||
|
||||
# for user_channel in self.channels:
|
||||
# if (
|
||||
# user_channel not in accessible_channel_names
|
||||
# and user_channel not in accessible_channel_ids
|
||||
# ):
|
||||
# raise ConnectorValidationError(
|
||||
# f"Channel '{user_channel}' not found or inaccessible in this workspace."
|
||||
# )
|
||||
|
||||
except SlackApiError as e:
|
||||
slack_error = e.response.get("error", "")
|
||||
if slack_error == "missing_scope":
|
||||
if slack_error == "ratelimited":
|
||||
# Handle rate limiting specifically
|
||||
retry_after = int(e.response.headers.get("Retry-After", 1))
|
||||
logger.warning(
|
||||
f"Slack API rate limited during validation. Retry suggested after {retry_after} seconds. "
|
||||
"Proceeding with validation, but be aware that connector operations might be throttled."
|
||||
)
|
||||
# Continue validation without failing - the connector is likely valid but just rate limited
|
||||
return
|
||||
elif slack_error == "missing_scope":
|
||||
raise InsufficientPermissionsError(
|
||||
"Slack bot token lacks the necessary scope to list/access channels. "
|
||||
"Please ensure your Slack app has 'channels:read' (and/or 'groups:read' for private channels)."
|
||||
@@ -751,6 +814,20 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
f"Unexpected error during Slack settings validation: {e}"
|
||||
)
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> SlackCheckpoint:
|
||||
return SlackCheckpoint(
|
||||
channel_ids=None,
|
||||
channel_completion_map={},
|
||||
current_channel=None,
|
||||
seen_thread_ts=[],
|
||||
has_more=True,
|
||||
)
|
||||
|
||||
@override
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> SlackCheckpoint:
|
||||
return SlackCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
@@ -765,9 +842,11 @@ if __name__ == "__main__":
|
||||
current = time.time()
|
||||
one_day_ago = current - 24 * 60 * 60 # 1 day
|
||||
|
||||
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
checkpoint = connector.build_dummy_checkpoint()
|
||||
|
||||
gen = connector.load_from_checkpoint(one_day_ago, current, checkpoint)
|
||||
gen = connector.load_from_checkpoint(
|
||||
one_day_ago, current, cast(SlackCheckpoint, checkpoint)
|
||||
)
|
||||
try:
|
||||
for document_or_failure in gen:
|
||||
if isinstance(document_or_failure, Document):
|
||||
|
||||
159
backend/onyx/connectors/slack/onyx_retry_handler.py
Normal file
159
backend/onyx/connectors/slack/onyx_retry_handler.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from slack_sdk.http_retry.handler import RetryHandler
|
||||
from slack_sdk.http_retry.request import HttpRequest
|
||||
from slack_sdk.http_retry.response import HttpResponse
|
||||
from slack_sdk.http_retry.state import RetryState
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class OnyxRedisSlackRetryHandler(RetryHandler):
|
||||
"""
|
||||
This class uses Redis to share a rate limit among multiple threads.
|
||||
|
||||
Threads that encounter a rate limit will observe the shared delay, increment the
|
||||
shared delay with the retry value, and use the new shared value as a wait interval.
|
||||
|
||||
This has the effect of serializing calls when a rate limit is hit, which is what
|
||||
needs to happens if the server punishes us with additional limiting when we make
|
||||
a call too early. We believe this is what Slack is doing based on empirical
|
||||
observation, meaning we see indefinite hangs if we're too aggressive.
|
||||
|
||||
Another way to do this is just to do exponential backoff. Might be easier?
|
||||
|
||||
Adapted from slack's RateLimitErrorRetryHandler.
|
||||
"""
|
||||
|
||||
LOCK_TTL = 60 # used to serialize access to the retry TTL
|
||||
LOCK_BLOCKING_TIMEOUT = 60 # how long to wait for the lock
|
||||
|
||||
"""RetryHandler that does retries for rate limited errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_retry_count: int,
|
||||
delay_lock: str,
|
||||
delay_key: str,
|
||||
r: Redis,
|
||||
):
|
||||
"""
|
||||
delay_lock: the redis key to use with RedisLock (to synchronize access to delay_key)
|
||||
delay_key: the redis key containing a shared TTL
|
||||
"""
|
||||
super().__init__(max_retry_count=max_retry_count)
|
||||
self._redis: Redis = r
|
||||
self._delay_lock = delay_lock
|
||||
self._delay_key = delay_key
|
||||
|
||||
def _can_retry(
|
||||
self,
|
||||
*,
|
||||
state: RetryState,
|
||||
request: HttpRequest,
|
||||
response: Optional[HttpResponse] = None,
|
||||
error: Optional[Exception] = None,
|
||||
) -> bool:
|
||||
return response is not None and response.status_code == 429
|
||||
|
||||
def prepare_for_next_attempt(
|
||||
self,
|
||||
*,
|
||||
state: RetryState,
|
||||
request: HttpRequest,
|
||||
response: Optional[HttpResponse] = None,
|
||||
error: Optional[Exception] = None,
|
||||
) -> None:
|
||||
"""It seems this function is responsible for the wait to retry ... aka we
|
||||
actually sleep in this function."""
|
||||
retry_after_value: list[str] | None = None
|
||||
retry_after_header_name: Optional[str] = None
|
||||
duration_s: float = 1.0 # seconds
|
||||
|
||||
if response is None:
|
||||
# NOTE(rkuo): this logic comes from RateLimitErrorRetryHandler.
|
||||
# This reads oddly, as if the caller itself could raise the exception.
|
||||
# We don't have the luxury of changing this.
|
||||
if error:
|
||||
raise error
|
||||
|
||||
return
|
||||
|
||||
state.next_attempt_requested = True # this signals the caller to retry
|
||||
|
||||
# calculate wait duration based on retry-after + some jitter
|
||||
for k in response.headers.keys():
|
||||
if k.lower() == "retry-after":
|
||||
retry_after_header_name = k
|
||||
break
|
||||
|
||||
try:
|
||||
if retry_after_header_name is None:
|
||||
# This situation usually does not arise. Just in case.
|
||||
raise ValueError(
|
||||
"OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header name is None"
|
||||
)
|
||||
|
||||
retry_after_value = response.headers.get(retry_after_header_name)
|
||||
if not retry_after_value:
|
||||
raise ValueError(
|
||||
"OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header value is None"
|
||||
)
|
||||
|
||||
retry_after_value_int = int(
|
||||
retry_after_value[0]
|
||||
) # will raise ValueError if somehow we can't convert to int
|
||||
jitter = retry_after_value_int * 0.25 * random.random()
|
||||
duration_s = math.ceil(retry_after_value_int + jitter)
|
||||
except ValueError:
|
||||
duration_s += random.random()
|
||||
|
||||
# lock and extend the ttl
|
||||
lock: RedisLock = self._redis.lock(
|
||||
self._delay_lock,
|
||||
timeout=OnyxRedisSlackRetryHandler.LOCK_TTL,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(
|
||||
blocking_timeout=OnyxRedisSlackRetryHandler.LOCK_BLOCKING_TIMEOUT / 2
|
||||
)
|
||||
|
||||
ttl_ms: int | None = None
|
||||
|
||||
try:
|
||||
if acquired:
|
||||
# if we can get the lock, then read and extend the ttl
|
||||
ttl_ms = cast(int, self._redis.pttl(self._delay_key))
|
||||
if ttl_ms < 0: # negative values are error status codes ... see docs
|
||||
ttl_ms = 0
|
||||
ttl_ms_new = ttl_ms + int(duration_s * 1000.0)
|
||||
self._redis.set(self._delay_key, "1", px=ttl_ms_new)
|
||||
else:
|
||||
# if we can't get the lock, just go ahead.
|
||||
# TODO: if we know our actual parallelism, multiplying by that
|
||||
# would be a pretty good idea
|
||||
ttl_ms_new = int(duration_s * 1000.0)
|
||||
finally:
|
||||
if acquired:
|
||||
lock.release()
|
||||
|
||||
logger.warning(
|
||||
f"OnyxRedisSlackRetryHandler.prepare_for_next_attempt wait: "
|
||||
f"retry-after={retry_after_value} "
|
||||
f"shared_delay_ms={ttl_ms} new_shared_delay_ms={ttl_ms_new}"
|
||||
)
|
||||
|
||||
# TODO: would be good to take an event var and sleep in short increments to
|
||||
# allow for a clean exit / exception
|
||||
time.sleep(ttl_ms_new / 1000.0)
|
||||
|
||||
state.increment_current_attempt()
|
||||
@@ -1,5 +1,4 @@
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from functools import lru_cache
|
||||
@@ -64,71 +63,72 @@ def _make_slack_api_call_paginated(
|
||||
return paginated_call
|
||||
|
||||
|
||||
def make_slack_api_rate_limited(
|
||||
call: Callable[..., SlackResponse], max_retries: int = 7
|
||||
) -> Callable[..., SlackResponse]:
|
||||
"""Wraps calls to slack API so that they automatically handle rate limiting"""
|
||||
# NOTE(rkuo): we may not need this any more if the integrated retry handlers work as
|
||||
# expected. Do we want to keep this around?
|
||||
|
||||
@wraps(call)
|
||||
def rate_limited_call(**kwargs: Any) -> SlackResponse:
|
||||
last_exception = None
|
||||
# def make_slack_api_rate_limited(
|
||||
# call: Callable[..., SlackResponse], max_retries: int = 7
|
||||
# ) -> Callable[..., SlackResponse]:
|
||||
# """Wraps calls to slack API so that they automatically handle rate limiting"""
|
||||
|
||||
for _ in range(max_retries):
|
||||
try:
|
||||
# Make the API call
|
||||
response = call(**kwargs)
|
||||
# @wraps(call)
|
||||
# def rate_limited_call(**kwargs: Any) -> SlackResponse:
|
||||
# last_exception = None
|
||||
|
||||
# Check for errors in the response, will raise `SlackApiError`
|
||||
# if anything went wrong
|
||||
response.validate()
|
||||
return response
|
||||
# for _ in range(max_retries):
|
||||
# try:
|
||||
# # Make the API call
|
||||
# response = call(**kwargs)
|
||||
|
||||
except SlackApiError as e:
|
||||
last_exception = e
|
||||
try:
|
||||
error = e.response["error"]
|
||||
except KeyError:
|
||||
error = "unknown error"
|
||||
# # Check for errors in the response, will raise `SlackApiError`
|
||||
# # if anything went wrong
|
||||
# response.validate()
|
||||
# return response
|
||||
|
||||
if error == "ratelimited":
|
||||
# Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
|
||||
retry_after = int(e.response.headers.get("Retry-After", 1))
|
||||
logger.info(
|
||||
f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
|
||||
)
|
||||
time.sleep(retry_after)
|
||||
elif error in ["already_reacted", "no_reaction", "internal_error"]:
|
||||
# Log internal_error and return the response instead of failing
|
||||
logger.warning(
|
||||
f"Slack call encountered '{error}', skipping and continuing..."
|
||||
)
|
||||
return e.response
|
||||
else:
|
||||
# Raise the error for non-transient errors
|
||||
raise
|
||||
# except SlackApiError as e:
|
||||
# last_exception = e
|
||||
# try:
|
||||
# error = e.response["error"]
|
||||
# except KeyError:
|
||||
# error = "unknown error"
|
||||
|
||||
# If the code reaches this point, all retries have been exhausted
|
||||
msg = f"Max retries ({max_retries}) exceeded"
|
||||
if last_exception:
|
||||
raise Exception(msg) from last_exception
|
||||
else:
|
||||
raise Exception(msg)
|
||||
# if error == "ratelimited":
|
||||
# # Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
|
||||
# retry_after = int(e.response.headers.get("Retry-After", 1))
|
||||
# logger.info(
|
||||
# f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
|
||||
# )
|
||||
# time.sleep(retry_after)
|
||||
# elif error in ["already_reacted", "no_reaction", "internal_error"]:
|
||||
# # Log internal_error and return the response instead of failing
|
||||
# logger.warning(
|
||||
# f"Slack call encountered '{error}', skipping and continuing..."
|
||||
# )
|
||||
# return e.response
|
||||
# else:
|
||||
# # Raise the error for non-transient errors
|
||||
# raise
|
||||
|
||||
return rate_limited_call
|
||||
# # If the code reaches this point, all retries have been exhausted
|
||||
# msg = f"Max retries ({max_retries}) exceeded"
|
||||
# if last_exception:
|
||||
# raise Exception(msg) from last_exception
|
||||
# else:
|
||||
# raise Exception(msg)
|
||||
|
||||
# return rate_limited_call
|
||||
|
||||
|
||||
def make_slack_api_call_w_retries(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> SlackResponse:
|
||||
return basic_retry_wrapper(make_slack_api_rate_limited(call))(**kwargs)
|
||||
return basic_retry_wrapper(call)(**kwargs)
|
||||
|
||||
|
||||
def make_paginated_slack_api_call_w_retries(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
return _make_slack_api_call_paginated(
|
||||
basic_retry_wrapper(make_slack_api_rate_limited(call))
|
||||
)(**kwargs)
|
||||
return _make_slack_api_call_paginated(basic_retry_wrapper(call))(**kwargs)
|
||||
|
||||
|
||||
def expert_info_from_slack_id(
|
||||
@@ -142,7 +142,7 @@ def expert_info_from_slack_id(
|
||||
if user_id in user_cache:
|
||||
return user_cache[user_id]
|
||||
|
||||
response = make_slack_api_rate_limited(client.users_info)(user=user_id)
|
||||
response = client.users_info(user=user_id)
|
||||
|
||||
if not response["ok"]:
|
||||
user_cache[user_id] = None
|
||||
@@ -175,9 +175,7 @@ class SlackTextCleaner:
|
||||
def _get_slack_name(self, user_id: str) -> str:
|
||||
if user_id not in self._id_to_name_map:
|
||||
try:
|
||||
response = make_slack_api_rate_limited(self._client.users_info)(
|
||||
user=user_id
|
||||
)
|
||||
response = self._client.users_info(user=user_id)
|
||||
# prefer display name if set, since that is what is shown in Slack
|
||||
self._id_to_name_map[user_id] = (
|
||||
response["user"]["profile"]["display_name"]
|
||||
|
||||
@@ -1,23 +1,32 @@
|
||||
import copy
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from requests.exceptions import HTTPError
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
time_str_to_utc,
|
||||
)
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import ConnectorFailure
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
@@ -26,6 +35,7 @@ from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
MAX_PAGE_SIZE = 30 # Zendesk API maximum
|
||||
MAX_AUTHOR_MAP_SIZE = 50_000 # Reset author map cache if it gets too large
|
||||
_SLIM_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
@@ -53,10 +63,22 @@ class ZendeskClient:
|
||||
# Sleep for the duration indicated by the Retry-After header
|
||||
time.sleep(int(retry_after))
|
||||
|
||||
elif (
|
||||
response.status_code == 403
|
||||
and response.json().get("error") == "SupportProductInactive"
|
||||
):
|
||||
return response.json()
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
class ZendeskPageResponse(BaseModel):
|
||||
data: list[dict[str, Any]]
|
||||
meta: dict[str, Any]
|
||||
has_more: bool
|
||||
|
||||
|
||||
def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
|
||||
content_tags: dict[str, str] = {}
|
||||
params = {"page[size]": MAX_PAGE_SIZE}
|
||||
@@ -82,11 +104,9 @@ def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
|
||||
def _get_articles(
|
||||
client: ZendeskClient, start_time: int | None = None, page_size: int = MAX_PAGE_SIZE
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
params = (
|
||||
{"start_time": start_time, "page[size]": page_size}
|
||||
if start_time
|
||||
else {"page[size]": page_size}
|
||||
)
|
||||
params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"}
|
||||
if start_time is not None:
|
||||
params["start_time"] = start_time
|
||||
|
||||
while True:
|
||||
data = client.make_request("help_center/articles", params)
|
||||
@@ -98,10 +118,30 @@ def _get_articles(
|
||||
params["page[after]"] = data["meta"]["after_cursor"]
|
||||
|
||||
|
||||
def _get_article_page(
|
||||
client: ZendeskClient,
|
||||
start_time: int | None = None,
|
||||
after_cursor: str | None = None,
|
||||
page_size: int = MAX_PAGE_SIZE,
|
||||
) -> ZendeskPageResponse:
|
||||
params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"}
|
||||
if start_time is not None:
|
||||
params["start_time"] = start_time
|
||||
if after_cursor is not None:
|
||||
params["page[after]"] = after_cursor
|
||||
|
||||
data = client.make_request("help_center/articles", params)
|
||||
return ZendeskPageResponse(
|
||||
data=data["articles"],
|
||||
meta=data["meta"],
|
||||
has_more=bool(data["meta"].get("has_more", False)),
|
||||
)
|
||||
|
||||
|
||||
def _get_tickets(
|
||||
client: ZendeskClient, start_time: int | None = None
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
params = {"start_time": start_time} if start_time else {"start_time": 0}
|
||||
params = {"start_time": start_time or 0}
|
||||
|
||||
while True:
|
||||
data = client.make_request("incremental/tickets.json", params)
|
||||
@@ -114,9 +154,33 @@ def _get_tickets(
|
||||
break
|
||||
|
||||
|
||||
def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
|
||||
# TODO: maybe these don't need to be their own functions?
|
||||
def _get_tickets_page(
|
||||
client: ZendeskClient, start_time: int | None = None
|
||||
) -> ZendeskPageResponse:
|
||||
params = {"start_time": start_time or 0}
|
||||
|
||||
# NOTE: for some reason zendesk doesn't seem to be respecting the start_time param
|
||||
# in my local testing with very few tickets. We'll look into it if this becomes an
|
||||
# issue in larger deployments
|
||||
data = client.make_request("incremental/tickets.json", params)
|
||||
if data.get("error") == "SupportProductInactive":
|
||||
raise ValueError(
|
||||
"Zendesk Support Product is not active for this account, No tickets to index"
|
||||
)
|
||||
return ZendeskPageResponse(
|
||||
data=data["tickets"],
|
||||
meta={"end_time": data["end_time"]},
|
||||
has_more=not bool(data.get("end_of_stream", False)),
|
||||
)
|
||||
|
||||
|
||||
def _fetch_author(
|
||||
client: ZendeskClient, author_id: str | int
|
||||
) -> BasicExpertInfo | None:
|
||||
# Skip fetching if author_id is invalid
|
||||
if not author_id or author_id == "-1":
|
||||
# cast to str to avoid issues with zendesk changing their types
|
||||
if not author_id or str(author_id) == "-1":
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -278,13 +342,22 @@ def _ticket_to_document(
|
||||
)
|
||||
|
||||
|
||||
class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class ZendeskConnectorCheckpoint(ConnectorCheckpoint):
|
||||
# We use cursor-based paginated retrieval for articles
|
||||
after_cursor_articles: str | None
|
||||
|
||||
# We use timestamp-based paginated retrieval for tickets
|
||||
next_start_time_tickets: int | None
|
||||
|
||||
cached_author_map: dict[str, BasicExpertInfo] | None
|
||||
cached_content_tags: dict[str, str] | None
|
||||
|
||||
|
||||
class ZendeskConnector(SlimConnector, CheckpointConnector[ZendeskConnectorCheckpoint]):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
content_type: str = "articles",
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.content_type = content_type
|
||||
self.subdomain = ""
|
||||
# Fetch all tags ahead of time
|
||||
@@ -304,33 +377,50 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self.poll_source(None, None)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
@override
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ZendeskConnectorCheckpoint,
|
||||
) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
|
||||
if self.client is None:
|
||||
raise ZendeskCredentialsNotSetUpError()
|
||||
|
||||
self.content_tags = _get_content_tag_mapping(self.client)
|
||||
if checkpoint.cached_content_tags is None:
|
||||
checkpoint.cached_content_tags = _get_content_tag_mapping(self.client)
|
||||
return checkpoint # save the content tags to the checkpoint
|
||||
self.content_tags = checkpoint.cached_content_tags
|
||||
|
||||
if self.content_type == "articles":
|
||||
yield from self._poll_articles(start)
|
||||
checkpoint = yield from self._retrieve_articles(start, end, checkpoint)
|
||||
return checkpoint
|
||||
elif self.content_type == "tickets":
|
||||
yield from self._poll_tickets(start)
|
||||
checkpoint = yield from self._retrieve_tickets(start, end, checkpoint)
|
||||
return checkpoint
|
||||
else:
|
||||
raise ValueError(f"Unsupported content_type: {self.content_type}")
|
||||
|
||||
def _poll_articles(
|
||||
self, start: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
articles = _get_articles(self.client, start_time=int(start) if start else None)
|
||||
|
||||
def _retrieve_articles(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
checkpoint: ZendeskConnectorCheckpoint,
|
||||
) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
|
||||
checkpoint = copy.deepcopy(checkpoint)
|
||||
# This one is built on the fly as there may be more many more authors than tags
|
||||
author_map: dict[str, BasicExpertInfo] = {}
|
||||
author_map: dict[str, BasicExpertInfo] = checkpoint.cached_author_map or {}
|
||||
after_cursor = checkpoint.after_cursor_articles
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
doc_batch = []
|
||||
response = _get_article_page(
|
||||
self.client,
|
||||
start_time=int(start) if start else None,
|
||||
after_cursor=after_cursor,
|
||||
)
|
||||
articles = response.data
|
||||
has_more = response.has_more
|
||||
after_cursor = response.meta.get("after_cursor")
|
||||
for article in articles:
|
||||
if (
|
||||
article.get("body") is None
|
||||
@@ -342,66 +432,109 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
):
|
||||
continue
|
||||
|
||||
new_author_map, documents = _article_to_document(
|
||||
article, self.content_tags, author_map, self.client
|
||||
)
|
||||
try:
|
||||
new_author_map, document = _article_to_document(
|
||||
article, self.content_tags, author_map, self.client
|
||||
)
|
||||
except Exception as e:
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=f"{article.get('id')}",
|
||||
document_link=article.get("html_url", ""),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
)
|
||||
continue
|
||||
|
||||
if new_author_map:
|
||||
author_map.update(new_author_map)
|
||||
|
||||
doc_batch.append(documents)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch.clear()
|
||||
doc_batch.append(document)
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
if not has_more:
|
||||
yield from doc_batch
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
|
||||
|
||||
def _poll_tickets(
|
||||
self, start: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
# Sometimes no documents are retrieved, but the cursor
|
||||
# is still updated so the connector makes progress.
|
||||
yield from doc_batch
|
||||
checkpoint.after_cursor_articles = after_cursor
|
||||
|
||||
last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None
|
||||
checkpoint.has_more = bool(
|
||||
end is None
|
||||
or last_doc_updated_at is None
|
||||
or last_doc_updated_at.timestamp() <= end
|
||||
)
|
||||
checkpoint.cached_author_map = (
|
||||
author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
def _retrieve_tickets(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
checkpoint: ZendeskConnectorCheckpoint,
|
||||
) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
|
||||
checkpoint = copy.deepcopy(checkpoint)
|
||||
if self.client is None:
|
||||
raise ZendeskCredentialsNotSetUpError()
|
||||
|
||||
author_map: dict[str, BasicExpertInfo] = {}
|
||||
author_map: dict[str, BasicExpertInfo] = checkpoint.cached_author_map or {}
|
||||
|
||||
ticket_generator = _get_tickets(
|
||||
self.client, start_time=int(start) if start else None
|
||||
doc_batch: list[Document] = []
|
||||
next_start_time = int(checkpoint.next_start_time_tickets or start or 0)
|
||||
ticket_response = _get_tickets_page(self.client, start_time=next_start_time)
|
||||
tickets = ticket_response.data
|
||||
has_more = ticket_response.has_more
|
||||
next_start_time = ticket_response.meta["end_time"]
|
||||
for ticket in tickets:
|
||||
if ticket.get("status") == "deleted":
|
||||
continue
|
||||
|
||||
try:
|
||||
new_author_map, document = _ticket_to_document(
|
||||
ticket=ticket,
|
||||
author_map=author_map,
|
||||
client=self.client,
|
||||
default_subdomain=self.subdomain,
|
||||
)
|
||||
except Exception as e:
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=f"{ticket.get('id')}",
|
||||
document_link=ticket.get("url", ""),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
)
|
||||
continue
|
||||
|
||||
if new_author_map:
|
||||
author_map.update(new_author_map)
|
||||
|
||||
doc_batch.append(document)
|
||||
|
||||
if not has_more:
|
||||
yield from doc_batch
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
|
||||
|
||||
yield from doc_batch
|
||||
checkpoint.next_start_time_tickets = next_start_time
|
||||
last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None
|
||||
checkpoint.has_more = bool(
|
||||
end is None
|
||||
or last_doc_updated_at is None
|
||||
or last_doc_updated_at.timestamp() <= end
|
||||
)
|
||||
|
||||
while True:
|
||||
doc_batch = []
|
||||
for _ in range(self.batch_size):
|
||||
try:
|
||||
ticket = next(ticket_generator)
|
||||
|
||||
# Check if the ticket status is deleted and skip it if so
|
||||
if ticket.get("status") == "deleted":
|
||||
continue
|
||||
|
||||
new_author_map, documents = _ticket_to_document(
|
||||
ticket=ticket,
|
||||
author_map=author_map,
|
||||
client=self.client,
|
||||
default_subdomain=self.subdomain,
|
||||
)
|
||||
|
||||
if new_author_map:
|
||||
author_map.update(new_author_map)
|
||||
|
||||
doc_batch.append(documents)
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch.clear()
|
||||
|
||||
except StopIteration:
|
||||
# No more tickets to process
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
return
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
checkpoint.cached_author_map = (
|
||||
author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
@@ -441,10 +574,51 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
if slim_doc_batch:
|
||||
yield slim_doc_batch
|
||||
|
||||
@override
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self.client is None:
|
||||
raise ZendeskCredentialsNotSetUpError()
|
||||
|
||||
try:
|
||||
_get_article_page(self.client, start_time=0)
|
||||
except HTTPError as e:
|
||||
# Check for HTTP status codes
|
||||
if e.response.status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Your Zendesk credentials appear to be invalid or expired (HTTP 401)."
|
||||
) from e
|
||||
elif e.response.status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"Your Zendesk token does not have sufficient permissions (HTTP 403)."
|
||||
) from e
|
||||
elif e.response.status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
"Zendesk resource not found (HTTP 404)."
|
||||
) from e
|
||||
else:
|
||||
raise ConnectorValidationError(
|
||||
f"Unexpected Zendesk error (status={e.response.status_code}): {e}"
|
||||
) from e
|
||||
|
||||
@override
|
||||
def validate_checkpoint_json(
|
||||
self, checkpoint_json: str
|
||||
) -> ZendeskConnectorCheckpoint:
|
||||
return ZendeskConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> ZendeskConnectorCheckpoint:
|
||||
return ZendeskConnectorCheckpoint(
|
||||
after_cursor_articles=None,
|
||||
next_start_time_tickets=None,
|
||||
cached_author_map=None,
|
||||
cached_content_tags=None,
|
||||
has_more=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
import time
|
||||
|
||||
connector = ZendeskConnector()
|
||||
connector.load_credentials(
|
||||
@@ -457,6 +631,8 @@ if __name__ == "__main__":
|
||||
|
||||
current = time.time()
|
||||
one_day_ago = current - 24 * 60 * 60 # 1 day
|
||||
document_batches = connector.poll_source(one_day_ago, current)
|
||||
document_batches = connector.load_from_checkpoint(
|
||||
one_day_ago, current, connector.build_dummy_checkpoint()
|
||||
)
|
||||
|
||||
print(next(document_batches))
|
||||
|
||||
@@ -60,7 +60,7 @@ class SearchSettingsCreationRequest(InferenceSettings, IndexingSetting):
|
||||
inference_settings = InferenceSettings.from_db_model(search_settings)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
|
||||
return cls(**inference_settings.dict(), **indexing_setting.dict())
|
||||
return cls(**inference_settings.model_dump(), **indexing_setting.model_dump())
|
||||
|
||||
|
||||
class SavedSearchSettings(InferenceSettings, IndexingSetting):
|
||||
@@ -80,6 +80,9 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
|
||||
reduced_dimension=search_settings.reduced_dimension,
|
||||
# Whether switching to this model requires re-indexing
|
||||
background_reindex_enabled=search_settings.background_reindex_enabled,
|
||||
enable_contextual_rag=search_settings.enable_contextual_rag,
|
||||
contextual_rag_llm_name=search_settings.contextual_rag_llm_name,
|
||||
contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider,
|
||||
# Reranking Details
|
||||
rerank_model_name=search_settings.rerank_model_name,
|
||||
rerank_provider_type=search_settings.rerank_provider_type,
|
||||
@@ -102,6 +105,8 @@ class BaseFilters(BaseModel):
|
||||
document_set: list[str] | None = None
|
||||
time_cutoff: datetime | None = None
|
||||
tags: list[Tag] | None = None
|
||||
user_file_ids: list[int] | None = None
|
||||
user_folder_ids: list[int] | None = None
|
||||
|
||||
|
||||
class IndexFilters(BaseFilters):
|
||||
@@ -218,6 +223,8 @@ class InferenceChunk(BaseChunk):
|
||||
# to specify that a set of words should be highlighted. For example:
|
||||
# ["<hi>the</hi> <hi>answer</hi> is 42", "he couldn't find an <hi>answer</hi>"]
|
||||
match_highlights: list[str]
|
||||
doc_summary: str
|
||||
chunk_context: str
|
||||
|
||||
# when the doc was last updated
|
||||
updated_at: datetime | None
|
||||
|
||||
@@ -158,6 +158,47 @@ class SearchPipeline:
|
||||
|
||||
return cast(list[InferenceChunk], self._retrieved_chunks)
|
||||
|
||||
def get_ordering_only_chunks(
|
||||
self,
|
||||
query: str,
|
||||
user_file_ids: list[int] | None = None,
|
||||
user_folder_ids: list[int] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Optimized method that only retrieves chunks for ordering purposes.
|
||||
Skips all extra processing and uses minimal configuration to speed up retrieval.
|
||||
"""
|
||||
logger.info("Fast path: Using optimized chunk retrieval for ordering-only mode")
|
||||
|
||||
# Create minimal filters with just user file/folder IDs
|
||||
filters = IndexFilters(
|
||||
user_file_ids=user_file_ids or [],
|
||||
user_folder_ids=user_folder_ids or [],
|
||||
access_control_list=None,
|
||||
)
|
||||
|
||||
# Use a simplified query that skips all unnecessary processing
|
||||
minimal_query = SearchQuery(
|
||||
query=query,
|
||||
search_type=SearchType.SEMANTIC,
|
||||
filters=filters,
|
||||
# Set minimal options needed for retrieval
|
||||
evaluation_type=LLMEvaluationType.SKIP,
|
||||
recency_bias_multiplier=1.0,
|
||||
chunks_above=0, # No need for surrounding context
|
||||
chunks_below=0, # No need for surrounding context
|
||||
processed_keywords=[], # Empty list instead of None
|
||||
rerank_settings=None,
|
||||
hybrid_alpha=0.0,
|
||||
max_llm_filter_sections=0,
|
||||
)
|
||||
|
||||
# Retrieve chunks using the minimal configuration
|
||||
return retrieve_chunks(
|
||||
query=minimal_query,
|
||||
document_index=self.document_index,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def _get_sections(self) -> list[InferenceSection]:
|
||||
"""Returns an expanded section from each of the chunks.
|
||||
@@ -339,6 +380,12 @@ class SearchPipeline:
|
||||
self._retrieved_sections = self._get_sections()
|
||||
return self._retrieved_sections
|
||||
|
||||
@property
|
||||
def merged_retrieved_sections(self) -> list[InferenceSection]:
|
||||
"""Should be used to display in the UI in order to prevent displaying
|
||||
multiple sections for the same document as separate "documents"."""
|
||||
return _merge_sections(sections=self.retrieved_sections)
|
||||
|
||||
@property
|
||||
def reranked_sections(self) -> list[InferenceSection]:
|
||||
"""Reranking is always done at the chunk level since section merging could create arbitrarily
|
||||
@@ -385,6 +432,10 @@ class SearchPipeline:
|
||||
self.search_query.evaluation_type == LLMEvaluationType.SKIP
|
||||
or DISABLE_LLM_DOC_RELEVANCE
|
||||
):
|
||||
if self.search_query.evaluation_type == LLMEvaluationType.SKIP:
|
||||
logger.info(
|
||||
"Fast path: Skipping section relevance evaluation for ordering-only mode"
|
||||
)
|
||||
return None
|
||||
|
||||
if self.search_query.evaluation_type == LLMEvaluationType.UNSPECIFIED:
|
||||
@@ -415,6 +466,10 @@ class SearchPipeline:
|
||||
raise ValueError(
|
||||
"Basic search evaluation operation called while DISABLE_LLM_DOC_RELEVANCE is enabled."
|
||||
)
|
||||
# NOTE: final_context_sections must be accessed before accessing self._postprocessing_generator
|
||||
# since the property sets the generator. DO NOT REMOVE.
|
||||
_ = self.final_context_sections
|
||||
|
||||
self._section_relevance = next(
|
||||
cast(
|
||||
Iterator[list[SectionRelevancePiece]],
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user