mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 15:55:45 +00:00
Compare commits
4 Commits
new_image_
...
playwright
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3ce02ccc01 | ||
|
|
d78c8e2e05 | ||
|
|
501ad93153 | ||
|
|
f4686440ae |
1
.github/CODEOWNERS
vendored
1
.github/CODEOWNERS
vendored
@@ -1 +0,0 @@
|
||||
* @onyx-dot-app/onyx-core-team
|
||||
@@ -12,40 +12,29 @@ env:
|
||||
BUILDKIT_PROGRESS: plain
|
||||
|
||||
jobs:
|
||||
|
||||
# Bypassing this for now as the idea of not building is glitching
|
||||
# releases and builds that depends on everything being tagged in docker
|
||||
# 1) Preliminary job to check if the changed files are relevant
|
||||
# check_model_server_changes:
|
||||
# runs-on: ubuntu-latest
|
||||
# outputs:
|
||||
# changed: ${{ steps.check.outputs.changed }}
|
||||
# steps:
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@v4
|
||||
#
|
||||
# - name: Check if relevant files changed
|
||||
# id: check
|
||||
# run: |
|
||||
# # Default to "false"
|
||||
# echo "changed=false" >> $GITHUB_OUTPUT
|
||||
#
|
||||
# # Compare the previous commit (github.event.before) to the current one (github.sha)
|
||||
# # If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
|
||||
# # set changed=true
|
||||
# if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
|
||||
# | grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
|
||||
# echo "changed=true" >> $GITHUB_OUTPUT
|
||||
# fi
|
||||
|
||||
# 1) Preliminary job to check if the changed files are relevant
|
||||
check_model_server_changes:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
changed: "true"
|
||||
changed: ${{ steps.check.outputs.changed }}
|
||||
steps:
|
||||
- name: Bypass check and set output
|
||||
run: echo "changed=true" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Check if relevant files changed
|
||||
id: check
|
||||
run: |
|
||||
# Default to "false"
|
||||
echo "changed=false" >> $GITHUB_OUTPUT
|
||||
|
||||
# Compare the previous commit (github.event.before) to the current one (github.sha)
|
||||
# If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
|
||||
# set changed=true
|
||||
if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
|
||||
| grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
|
||||
echo "changed=true" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
build-amd64:
|
||||
needs: [check_model_server_changes]
|
||||
if: needs.check_model_server_changes.outputs.changed == 'true'
|
||||
|
||||
94
.github/workflows/nightly-scan-licenses.yml
vendored
94
.github/workflows/nightly-scan-licenses.yml
vendored
@@ -53,90 +53,24 @@ jobs:
|
||||
exclude: '(?i)^(pylint|aio[-_]*).*'
|
||||
|
||||
- name: Print report
|
||||
if: always()
|
||||
if: ${{ always() }}
|
||||
run: echo "${{ steps.license_check_report.outputs.report }}"
|
||||
|
||||
- name: Install npm dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Run Trivy vulnerability scanner in repo mode
|
||||
uses: aquasecurity/trivy-action@0.28.0
|
||||
with:
|
||||
scan-type: fs
|
||||
scanners: license
|
||||
format: table
|
||||
# format: sarif
|
||||
# output: trivy-results.sarif
|
||||
severity: HIGH,CRITICAL
|
||||
|
||||
# be careful enabling the sarif and upload as it may spam the security tab
|
||||
# with a huge amount of items. Work out the issues before enabling upload.
|
||||
# - name: Run Trivy vulnerability scanner in repo mode
|
||||
# if: always()
|
||||
# uses: aquasecurity/trivy-action@0.29.0
|
||||
# - name: Upload Trivy scan results to GitHub Security tab
|
||||
# uses: github/codeql-action/upload-sarif@v3
|
||||
# with:
|
||||
# scan-type: fs
|
||||
# scan-ref: .
|
||||
# scanners: license
|
||||
# format: table
|
||||
# severity: HIGH,CRITICAL
|
||||
# # format: sarif
|
||||
# # output: trivy-results.sarif
|
||||
#
|
||||
# # - name: Upload Trivy scan results to GitHub Security tab
|
||||
# # uses: github/codeql-action/upload-sarif@v3
|
||||
# # with:
|
||||
# # sarif_file: trivy-results.sarif
|
||||
|
||||
scan-trivy:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# Backend
|
||||
- name: Pull backend docker image
|
||||
run: docker pull onyxdotapp/onyx-backend:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on backend
|
||||
uses: aquasecurity/trivy-action@0.29.0
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-backend:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
|
||||
|
||||
# Web server
|
||||
- name: Pull web server docker image
|
||||
run: docker pull onyxdotapp/onyx-web-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on web server
|
||||
uses: aquasecurity/trivy-action@0.29.0
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-web-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
|
||||
# Model server
|
||||
- name: Pull model server docker image
|
||||
run: docker pull onyxdotapp/onyx-model-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@0.29.0
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-model-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
# sarif_file: trivy-results.sarif
|
||||
|
||||
5
.github/workflows/pr-integration-tests.yml
vendored
5
.github/workflows/pr-integration-tests.yml
vendored
@@ -145,7 +145,7 @@ jobs:
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
|
||||
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
@@ -157,7 +157,6 @@ jobs:
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
INTEGRATION_TESTS_MODE=true \
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
|
||||
id: start_docker
|
||||
|
||||
@@ -200,7 +199,7 @@ jobs:
|
||||
cd backend/tests/integration/mock_services
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
|
||||
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
|
||||
- name: Run Standard Integration Tests
|
||||
run: |
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
name: Connector Tests
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
schedule:
|
||||
@@ -48,13 +47,11 @@ env:
|
||||
# Gitbook
|
||||
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
|
||||
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
|
||||
# Notion
|
||||
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
@@ -77,8 +74,6 @@ jobs:
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
playwright install chromium
|
||||
playwright install-deps chromium
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
|
||||
97
.github/workflows/pr-python-model-tests.yml
vendored
97
.github/workflows/pr-python-model-tests.yml
vendored
@@ -1,29 +1,18 @@
|
||||
name: Model Server Tests
|
||||
name: Connector Tests
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# This cron expression runs the job daily at 16:00 UTC (9am PT)
|
||||
- cron: "0 16 * * *"
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
branch:
|
||||
description: 'Branch to run the workflow on'
|
||||
required: false
|
||||
default: 'main'
|
||||
|
||||
|
||||
env:
|
||||
# Bedrock
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
|
||||
|
||||
# API keys for testing
|
||||
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
|
||||
LITELLM_API_KEY: ${{ secrets.LITELLM_API_KEY }}
|
||||
LITELLM_API_URL: ${{ secrets.LITELLM_API_URL }}
|
||||
# OpenAI
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
|
||||
AZURE_API_URL: ${{ secrets.AZURE_API_URL }}
|
||||
|
||||
jobs:
|
||||
model-check:
|
||||
@@ -37,23 +26,6 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# We don't need to build the Web Docker image since it's not yet used
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# successfully.
|
||||
- name: Pull Model Server Docker image
|
||||
run: |
|
||||
docker pull onyxdotapp/onyx-model-server:latest
|
||||
docker tag onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:test
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
@@ -69,49 +41,6 @@ jobs:
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.model-server-test.yml -p onyx-stack up -d indexing_model_server
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in 5 minutes."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Use curl with error handling to ignore specific exit code 56
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:9000/api/health || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "Service is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
@@ -127,23 +56,3 @@ jobs:
|
||||
-H 'Content-type: application/json' \
|
||||
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
|
||||
$SLACK_WEBHOOK
|
||||
|
||||
- name: Dump all-container logs (optional)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.model-server-test.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Stop Docker containers
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.model-server-test.yml -p onyx-stack down -v
|
||||
|
||||
|
||||
123
README.md
123
README.md
@@ -24,94 +24,113 @@
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI platform connected to your company's docs, apps, and people.
|
||||
Onyx provides a feature rich Chat interface and plugs into any LLM of your choice.
|
||||
Keep knowledge and access controls sync-ed across over 40 connectors like Google Drive, Slack, Confluence, Salesforce, etc.
|
||||
Create custom AI agents with unique prompts, knowledge, and actions that the agents can take.
|
||||
Onyx can be deployed securely anywhere and for any scale - on a laptop, on-premise, or to cloud.
|
||||
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
|
||||
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
|
||||
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
|
||||
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
|
||||
for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for
|
||||
configuring AI Assistants.
|
||||
|
||||
Onyx also serves as a Enterprise Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
|
||||
By combining LLMs and team specific knowledge, Onyx becomes a subject matter expert for the team. Imagine ChatGPT if
|
||||
it had access to your team's unique knowledge! It enables questions such as "A customer wants feature X, is this already
|
||||
supported?" or "Where's the pull request for feature Y?"
|
||||
|
||||
<h3>Feature Highlights</h3>
|
||||
<h3>Usage</h3>
|
||||
|
||||
**Deep research over your team's knowledge:**
|
||||
Onyx Web App:
|
||||
|
||||
https://private-user-images.githubusercontent.com/32520769/414509312-48392e83-95d0-4fb5-8650-a396e05e0a32.mp4?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk5Mjg2MzYsIm5iZiI6MTczOTkyODMzNiwicGF0aCI6Ii8zMjUyMDc2OS80MTQ1MDkzMTItNDgzOTJlODMtOTVkMC00ZmI1LTg2NTAtYTM5NmUwNWUwYTMyLm1wND9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMTklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjE5VDAxMjUzNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWFhMzk5Njg2Y2Y5YjFmNDNiYTQ2YzM5ZTg5YWJiYTU2NWMyY2YwNmUyODE2NWUxMDRiMWQxZWJmODI4YTA0MTUmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.a9D8A0sgKE9AoaoE-mfFbJ6_OKYeqaf7TZ4Han2JfW8
|
||||
https://github.com/onyx-dot-app/onyx/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410
|
||||
|
||||
Or, plug Onyx into your existing Slack workflows (more integrations to come 😁):
|
||||
|
||||
**Use Onyx as a secure AI Chat with any LLM:**
|
||||
|
||||

|
||||
|
||||
|
||||
**Easily set up connectors to your apps:**
|
||||
|
||||

|
||||
|
||||
|
||||
**Access Onyx where your team already works:**
|
||||
|
||||

|
||||
https://github.com/onyx-dot-app/onyx/assets/25087905/3e19739b-d178-4371-9a38-011430bdec1b
|
||||
|
||||
For more details on the Admin UI to manage connectors and users, check out our
|
||||
<strong><a href="https://www.youtube.com/watch?v=geNzY1nbCnU">Full Video Demo</a></strong>!
|
||||
|
||||
## Deployment
|
||||
**To try it out for free and get started in seconds, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.
|
||||
|
||||
Onyx can also be run locally (even on a laptop) or deployed on a virtual machine with a single
|
||||
Onyx can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
|
||||
`docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more.
|
||||
|
||||
We also have built-in support for high-availability/scalable deployment on Kubernetes.
|
||||
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).
|
||||
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment/kubernetes).
|
||||
|
||||
## 💃 Main Features
|
||||
|
||||
## 🔍 Other Notable Benefits of Onyx
|
||||
- Custom deep learning models for indexing and inference time, only through Onyx + learning from user feedback.
|
||||
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
|
||||
- Knowledge curation features like document-sets, query history, usage analytics, etc.
|
||||
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.
|
||||
|
||||
- Chat UI with the ability to select documents to chat with.
|
||||
- Create custom AI Assistants with different prompts and backing knowledge sets.
|
||||
- Connect Onyx with LLM of your choice (self-host for a fully airgapped solution).
|
||||
- Document Search + AI Answers for natural language queries.
|
||||
- Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
|
||||
- Slack integration to get answers and search results directly in Slack.
|
||||
|
||||
## 🚧 Roadmap
|
||||
- New methods in information retrieval (StructRAG, LightGraphRAG, etc.)
|
||||
- Personalized Search
|
||||
- Organizational understanding and ability to locate and suggest experts from your team.
|
||||
- Code Search
|
||||
- SQL and Structured Query Language
|
||||
|
||||
- Chat/Prompt sharing with specific teammates and user groups.
|
||||
- Multimodal model support, chat with images, video etc.
|
||||
- Choosing between LLMs and parameters during chat session.
|
||||
- Tool calling and agent configurations options.
|
||||
- Organizational understanding and ability to locate and suggest experts from your team.
|
||||
|
||||
## Other Notable Benefits of Onyx
|
||||
|
||||
- User Authentication with document level access management.
|
||||
- Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
|
||||
- Admin Dashboard to configure connectors, document-sets, access, etc.
|
||||
- Custom deep learning models + learn from user feedback.
|
||||
- Easy deployment and ability to host Onyx anywhere of your choosing.
|
||||
|
||||
## 🔌 Connectors
|
||||
Keep knowledge and access up to sync across 40+ connectors:
|
||||
|
||||
Efficiently pulls the latest changes from:
|
||||
|
||||
- Slack
|
||||
- GitHub
|
||||
- Google Drive
|
||||
- Confluence
|
||||
- Slack
|
||||
- Gmail
|
||||
- Salesforce
|
||||
- Microsoft Sharepoint
|
||||
- Github
|
||||
- Jira
|
||||
- Zendesk
|
||||
- Gmail
|
||||
- Notion
|
||||
- Gong
|
||||
- Microsoft Teams
|
||||
- Dropbox
|
||||
- Slab
|
||||
- Linear
|
||||
- Productboard
|
||||
- Guru
|
||||
- Bookstack
|
||||
- Document360
|
||||
- Sharepoint
|
||||
- Hubspot
|
||||
- Local Files
|
||||
- Websites
|
||||
- And more ...
|
||||
|
||||
See the full list [here](https://docs.onyx.app/connectors).
|
||||
## 📚 Editions
|
||||
|
||||
|
||||
## 📚 Licensing
|
||||
There are two editions of Onyx:
|
||||
|
||||
- Onyx Community Edition (CE) is available freely under the MIT Expat license. Simply follow the Deployment guide above.
|
||||
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations.
|
||||
For feature details, check out [our website](https://www.onyx.app/pricing).
|
||||
- Onyx Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Onyx you will get if you follow the Deployment guide above.
|
||||
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
|
||||
- Single Sign-On (SSO), with support for both SAML and OIDC
|
||||
- Role-based access control
|
||||
- Document permission inheritance from connected sources
|
||||
- Usage analytics and query history accessible to admins
|
||||
- Whitelabeling
|
||||
- API key authentication
|
||||
- Encryption of secrets
|
||||
- And many more! Checkout [our website](https://www.onyx.app/) for the latest.
|
||||
|
||||
To try the Onyx Enterprise Edition:
|
||||
1. Checkout [Onyx Cloud](https://cloud.onyx.app/signup).
|
||||
2. For self-hosting the Enterprise Edition, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).
|
||||
|
||||
1. Checkout our [Cloud product](https://cloud.onyx.app/signup).
|
||||
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).
|
||||
|
||||
## 💡 Contributing
|
||||
|
||||
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
|
||||
|
||||
## ⭐Star History
|
||||
|
||||
[](https://star-history.com/#onyx-dot-app/onyx&Date)
|
||||
|
||||
|
||||
@@ -28,11 +28,11 @@ RUN apt-get update && \
|
||||
curl \
|
||||
zip \
|
||||
ca-certificates \
|
||||
libgnutls30 \
|
||||
libblkid1 \
|
||||
libmount1 \
|
||||
libsmartcols1 \
|
||||
libuuid1 \
|
||||
libgnutls30=3.7.9-2+deb12u3 \
|
||||
libblkid1=2.38.1-5+deb12u1 \
|
||||
libmount1=2.38.1-5+deb12u1 \
|
||||
libsmartcols1=2.38.1-5+deb12u1 \
|
||||
libuuid1=2.38.1-5+deb12u1 \
|
||||
libxmlsec1-dev \
|
||||
pkg-config \
|
||||
gcc \
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
"""Add indexes to document__tag
|
||||
|
||||
Revision ID: 1a03d2c2856b
|
||||
Revises: 9c00a2bccb83
|
||||
Create Date: 2025-02-18 10:45:13.957807
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1a03d2c2856b"
|
||||
down_revision = "9c00a2bccb83"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_index(
|
||||
op.f("ix_document__tag_tag_id"),
|
||||
"document__tag",
|
||||
["tag_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f("ix_document__tag_tag_id"), table_name="document__tag")
|
||||
@@ -1,125 +0,0 @@
|
||||
"""Update GitHub connector repo_name to repositories
|
||||
|
||||
Revision ID: 3934b1bc7b62
|
||||
Revises: b7c2b63c4a03
|
||||
Create Date: 2025-03-05 10:50:30.516962
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import json
|
||||
import logging
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3934b1bc7b62"
|
||||
down_revision = "b7c2b63c4a03"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Get all GitHub connectors
|
||||
conn = op.get_bind()
|
||||
|
||||
# First get all GitHub connectors
|
||||
github_connectors = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, connector_specific_config
|
||||
FROM connector
|
||||
WHERE source = 'GITHUB'
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
# Update each connector's config
|
||||
updated_count = 0
|
||||
for connector_id, config in github_connectors:
|
||||
try:
|
||||
if not config:
|
||||
logger.warning(f"Connector {connector_id} has no config, skipping")
|
||||
continue
|
||||
|
||||
# Parse the config if it's a string
|
||||
if isinstance(config, str):
|
||||
config = json.loads(config)
|
||||
|
||||
if "repo_name" not in config:
|
||||
continue
|
||||
|
||||
# Create new config with repositories instead of repo_name
|
||||
new_config = dict(config)
|
||||
repo_name_value = new_config.pop("repo_name")
|
||||
new_config["repositories"] = repo_name_value
|
||||
|
||||
# Update the connector with the new config
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = :new_config
|
||||
WHERE id = :connector_id
|
||||
"""
|
||||
),
|
||||
{"connector_id": connector_id, "new_config": json.dumps(new_config)},
|
||||
)
|
||||
updated_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating connector {connector_id}: {str(e)}")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Get all GitHub connectors
|
||||
conn = op.get_bind()
|
||||
|
||||
logger.debug(
|
||||
"Starting rollback of GitHub connectors from repositories to repo_name"
|
||||
)
|
||||
|
||||
github_connectors = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, connector_specific_config
|
||||
FROM connector
|
||||
WHERE source = 'GITHUB'
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
logger.debug(f"Found {len(github_connectors)} GitHub connectors to rollback")
|
||||
|
||||
# Revert each GitHub connector to use repo_name instead of repositories
|
||||
reverted_count = 0
|
||||
for connector_id, config in github_connectors:
|
||||
try:
|
||||
if not config:
|
||||
continue
|
||||
|
||||
# Parse the config if it's a string
|
||||
if isinstance(config, str):
|
||||
config = json.loads(config)
|
||||
|
||||
if "repositories" not in config:
|
||||
continue
|
||||
|
||||
# Create new config with repo_name instead of repositories
|
||||
new_config = dict(config)
|
||||
repositories_value = new_config.pop("repositories")
|
||||
new_config["repo_name"] = repositories_value
|
||||
|
||||
# Update the connector with the new config
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = :new_config
|
||||
WHERE id = :connector_id
|
||||
"""
|
||||
),
|
||||
{"new_config": json.dumps(new_config), "connector_id": connector_id},
|
||||
)
|
||||
reverted_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error reverting connector {connector_id}: {str(e)}")
|
||||
@@ -1,84 +0,0 @@
|
||||
"""improved index
|
||||
|
||||
Revision ID: 3bd4c84fe72f
|
||||
Revises: 8f43500ee275
|
||||
Create Date: 2025-02-26 13:07:56.217791
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3bd4c84fe72f"
|
||||
down_revision = "8f43500ee275"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
# NOTE:
|
||||
# This migration addresses issues with the previous migration (8f43500ee275) which caused
|
||||
# an outage by creating an index without using CONCURRENTLY. This migration:
|
||||
#
|
||||
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
|
||||
# 2. Uses CONCURRENTLY for all index creation to prevent table locking
|
||||
# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work
|
||||
# (see: https://www.postgresql.org/docs/9.4/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY)
|
||||
# (see: https://github.com/sqlalchemy/alembic/issues/277)
|
||||
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create a GIN index for full-text search on chat_message.message
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE chat_message
|
||||
ADD COLUMN message_tsv tsvector
|
||||
GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
|
||||
"""
|
||||
)
|
||||
|
||||
# Commit the current transaction before creating concurrent indexes
|
||||
op.execute("COMMIT")
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
|
||||
ON chat_message
|
||||
USING GIN (message_tsv)
|
||||
"""
|
||||
)
|
||||
|
||||
# Also add a stored tsvector column for chat_session.description
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE chat_session
|
||||
ADD COLUMN description_tsv tsvector
|
||||
GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED;
|
||||
"""
|
||||
)
|
||||
|
||||
# Commit again before creating the second concurrent index
|
||||
op.execute("COMMIT")
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
|
||||
ON chat_session
|
||||
USING GIN (description_tsv)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the indexes first (use CONCURRENTLY for dropping too)
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
|
||||
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
|
||||
|
||||
# Then drop the columns
|
||||
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
|
||||
op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
|
||||
|
||||
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
|
||||
@@ -5,10 +5,7 @@ Revises: f1ca58b2f2ec
|
||||
Create Date: 2025-01-29 07:48:46.784041
|
||||
|
||||
"""
|
||||
import logging
|
||||
from typing import cast
|
||||
from alembic import op
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.sql import text
|
||||
|
||||
|
||||
@@ -18,45 +15,21 @@ down_revision = "f1ca58b2f2ec"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Conflicts on lowercasing will result in the uppercased email getting a
|
||||
unique integer suffix when converted to lowercase."""
|
||||
|
||||
# Get database connection
|
||||
connection = op.get_bind()
|
||||
|
||||
# Fetch all user emails that are not already lowercase
|
||||
user_emails = connection.execute(
|
||||
text('SELECT id, email FROM "user" WHERE email != LOWER(email)')
|
||||
).fetchall()
|
||||
|
||||
for user_id, email in user_emails:
|
||||
email = cast(str, email)
|
||||
username, domain = email.rsplit("@", 1)
|
||||
new_email = f"{username.lower()}@{domain.lower()}"
|
||||
attempt = 1
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Try updating the email
|
||||
connection.execute(
|
||||
text('UPDATE "user" SET email = :new_email WHERE id = :user_id'),
|
||||
{"new_email": new_email, "user_id": user_id},
|
||||
)
|
||||
break # Success, exit loop
|
||||
except IntegrityError:
|
||||
next_email = f"{username.lower()}_{attempt}@{domain.lower()}"
|
||||
# Email conflict occurred, append `_1`, `_2`, etc., to the username
|
||||
logger.warning(
|
||||
f"Conflict while lowercasing email: "
|
||||
f"old_email={email} "
|
||||
f"conflicting_email={new_email} "
|
||||
f"next_email={next_email}"
|
||||
)
|
||||
new_email = next_email
|
||||
attempt += 1
|
||||
# Update all user emails to lowercase
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET email = LOWER(email)
|
||||
WHERE email != LOWER(email)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
"""add index
|
||||
|
||||
Revision ID: 8f43500ee275
|
||||
Revises: da42808081e3
|
||||
Create Date: 2025-02-24 17:35:33.072714
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8f43500ee275"
|
||||
down_revision = "da42808081e3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create a basic index on the lowercase message column for direct text matching
|
||||
# Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
|
||||
# op.execute(
|
||||
# """
|
||||
# CREATE INDEX idx_chat_message_message_lower
|
||||
# ON chat_message (LOWER(substring(message, 1, 1500)))
|
||||
# """
|
||||
# )
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the index
|
||||
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
|
||||
@@ -1,43 +0,0 @@
|
||||
"""chat_message_agentic
|
||||
|
||||
Revision ID: 9c00a2bccb83
|
||||
Revises: b7a7eee5aa15
|
||||
Create Date: 2025-02-17 11:15:43.081150
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9c00a2bccb83"
|
||||
down_revision = "b7a7eee5aa15"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# First add the column as nullable
|
||||
op.add_column("chat_message", sa.Column("is_agentic", sa.Boolean(), nullable=True))
|
||||
|
||||
# Update existing rows based on presence of SubQuestions
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET is_agentic = EXISTS (
|
||||
SELECT 1
|
||||
FROM agent__sub_question
|
||||
WHERE agent__sub_question.primary_question_id = chat_message.id
|
||||
)
|
||||
WHERE is_agentic IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# Make the column non-nullable with a default value of False
|
||||
op.alter_column(
|
||||
"chat_message", "is_agentic", nullable=False, server_default=sa.text("false")
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_message", "is_agentic")
|
||||
@@ -1,29 +0,0 @@
|
||||
"""remove inactive ccpair status on downgrade
|
||||
|
||||
Revision ID: acaab4ef4507
|
||||
Revises: b388730a2899
|
||||
Create Date: 2025-02-16 18:21:41.330212
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from sqlalchemy import update
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "acaab4ef4507"
|
||||
down_revision = "b388730a2899"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
update(ConnectorCredentialPair)
|
||||
.where(ConnectorCredentialPair.status == ConnectorCredentialPairStatus.INVALID)
|
||||
.values(status=ConnectorCredentialPairStatus.ACTIVE)
|
||||
)
|
||||
@@ -1,31 +0,0 @@
|
||||
"""nullable preferences
|
||||
|
||||
Revision ID: b388730a2899
|
||||
Revises: 1a03d2c2856b
|
||||
Create Date: 2025-02-17 18:49:22.643902
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b388730a2899"
|
||||
down_revision = "1a03d2c2856b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("user", "temperature_override_enabled", nullable=True)
|
||||
op.alter_column("user", "auto_scroll", nullable=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Ensure no null values before making columns non-nullable
|
||||
op.execute(
|
||||
'UPDATE "user" SET temperature_override_enabled = false WHERE temperature_override_enabled IS NULL'
|
||||
)
|
||||
op.execute('UPDATE "user" SET auto_scroll = false WHERE auto_scroll IS NULL')
|
||||
|
||||
op.alter_column("user", "temperature_override_enabled", nullable=False)
|
||||
op.alter_column("user", "auto_scroll", nullable=False)
|
||||
@@ -1,55 +0,0 @@
|
||||
"""add background_reindex_enabled field
|
||||
|
||||
Revision ID: b7c2b63c4a03
|
||||
Revises: f11b408e39d3
|
||||
Create Date: 2024-03-26 12:34:56.789012
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b7c2b63c4a03"
|
||||
down_revision = "f11b408e39d3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add background_reindex_enabled column with default value of True
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"background_reindex_enabled",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="true",
|
||||
),
|
||||
)
|
||||
|
||||
# Add embedding_precision column with default value of FLOAT
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"embedding_precision",
|
||||
sa.Enum(EmbeddingPrecision, native_enum=False),
|
||||
nullable=False,
|
||||
server_default=EmbeddingPrecision.FLOAT.name,
|
||||
),
|
||||
)
|
||||
|
||||
# Add reduced_dimension column with default value of None
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column("reduced_dimension", sa.Integer(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the background_reindex_enabled column
|
||||
op.drop_column("search_settings", "background_reindex_enabled")
|
||||
op.drop_column("search_settings", "embedding_precision")
|
||||
op.drop_column("search_settings", "reduced_dimension")
|
||||
@@ -1,120 +0,0 @@
|
||||
"""migrate jira connectors to new format
|
||||
|
||||
Revision ID: da42808081e3
|
||||
Revises: f13db29f3101
|
||||
Create Date: 2025-02-24 11:24:54.396040
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import json
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.onyx_jira.utils import extract_jira_project
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "da42808081e3"
|
||||
down_revision = "f13db29f3101"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Get all Jira connectors
|
||||
conn = op.get_bind()
|
||||
|
||||
# First get all Jira connectors
|
||||
jira_connectors = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, connector_specific_config
|
||||
FROM connector
|
||||
WHERE source = :source
|
||||
"""
|
||||
),
|
||||
{"source": DocumentSource.JIRA.value.upper()},
|
||||
).fetchall()
|
||||
|
||||
# Update each connector's config
|
||||
for connector_id, old_config in jira_connectors:
|
||||
if not old_config:
|
||||
continue
|
||||
|
||||
# Extract project key from URL if it exists
|
||||
new_config: dict[str, str | None] = {}
|
||||
if project_url := old_config.get("jira_project_url"):
|
||||
# Parse the URL to get base and project
|
||||
try:
|
||||
jira_base, project_key = extract_jira_project(project_url)
|
||||
new_config = {"jira_base_url": jira_base, "project_key": project_key}
|
||||
except ValueError:
|
||||
# If URL parsing fails, just use the URL as the base
|
||||
new_config = {
|
||||
"jira_base_url": project_url.split("/projects/")[0],
|
||||
"project_key": None,
|
||||
}
|
||||
else:
|
||||
# For connectors without a project URL, we need admin intervention
|
||||
# Mark these for review
|
||||
print(
|
||||
f"WARNING: Jira connector {connector_id} has no project URL configured"
|
||||
)
|
||||
continue
|
||||
|
||||
# Update the connector config
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = :new_config
|
||||
WHERE id = :id
|
||||
"""
|
||||
),
|
||||
{"id": connector_id, "new_config": json.dumps(new_config)},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Get all Jira connectors
|
||||
conn = op.get_bind()
|
||||
|
||||
# First get all Jira connectors
|
||||
jira_connectors = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, connector_specific_config
|
||||
FROM connector
|
||||
WHERE source = :source
|
||||
"""
|
||||
),
|
||||
{"source": DocumentSource.JIRA.value.upper()},
|
||||
).fetchall()
|
||||
|
||||
# Update each connector's config back to the old format
|
||||
for connector_id, new_config in jira_connectors:
|
||||
if not new_config:
|
||||
continue
|
||||
|
||||
old_config = {}
|
||||
base_url = new_config.get("jira_base_url")
|
||||
project_key = new_config.get("project_key")
|
||||
|
||||
if base_url and project_key:
|
||||
old_config = {"jira_project_url": f"{base_url}/projects/{project_key}"}
|
||||
elif base_url:
|
||||
old_config = {"jira_project_url": base_url}
|
||||
else:
|
||||
continue
|
||||
|
||||
# Update the connector config
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = :old_config
|
||||
WHERE id = :id
|
||||
"""
|
||||
),
|
||||
{"id": connector_id, "old_config": old_config},
|
||||
)
|
||||
@@ -1,45 +0,0 @@
|
||||
"""add_default_vision_provider_to_llm_provider
|
||||
|
||||
Revision ID: df46c75b714e
|
||||
Revises: 3934b1bc7b62
|
||||
Create Date: 2025-03-11 16:20:19.038945
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "df46c75b714e"
|
||||
down_revision = "3934b1bc7b62"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"llm_provider",
|
||||
sa.Column(
|
||||
"is_default_vision_provider",
|
||||
sa.Boolean(),
|
||||
nullable=True,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"llm_provider", sa.Column("default_vision_model", sa.String(), nullable=True)
|
||||
)
|
||||
# Add unique constraint for is_default_vision_provider
|
||||
op.create_unique_constraint(
|
||||
"uq_llm_provider_is_default_vision_provider",
|
||||
"llm_provider",
|
||||
["is_default_vision_provider"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"uq_llm_provider_is_default_vision_provider", "llm_provider", type_="unique"
|
||||
)
|
||||
op.drop_column("llm_provider", "default_vision_model")
|
||||
op.drop_column("llm_provider", "is_default_vision_provider")
|
||||
@@ -1,36 +0,0 @@
|
||||
"""force lowercase all users
|
||||
|
||||
Revision ID: f11b408e39d3
|
||||
Revises: 3bd4c84fe72f
|
||||
Create Date: 2025-02-26 17:04:55.683500
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f11b408e39d3"
|
||||
down_revision = "3bd4c84fe72f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1) Convert all existing user emails to lowercase
|
||||
from alembic import op
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET email = LOWER(email)
|
||||
"""
|
||||
)
|
||||
|
||||
# 2) Add a check constraint to ensure emails are always lowercase
|
||||
op.create_check_constraint("ensure_lowercase_email", "user", "email = LOWER(email)")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the check constraint
|
||||
from alembic import op
|
||||
|
||||
op.drop_constraint("ensure_lowercase_email", "user", type_="check")
|
||||
@@ -1,27 +0,0 @@
|
||||
"""Add composite index for last_modified and last_synced to document
|
||||
|
||||
Revision ID: f13db29f3101
|
||||
Revises: b388730a2899
|
||||
Create Date: 2025-02-18 22:48:11.511389
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f13db29f3101"
|
||||
down_revision = "acaab4ef4507"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_index(
|
||||
"ix_document_sync_status",
|
||||
"document",
|
||||
["last_modified", "last_synced"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_document_sync_status", table_name="document")
|
||||
@@ -1,42 +0,0 @@
|
||||
"""lowercase multi-tenant user auth
|
||||
|
||||
Revision ID: 34e3630c7f32
|
||||
Revises: a4f6ee863c47
|
||||
Create Date: 2025-02-26 15:03:01.211894
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "34e3630c7f32"
|
||||
down_revision = "a4f6ee863c47"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1) Convert all existing rows to lowercase
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE user_tenant_mapping
|
||||
SET email = LOWER(email)
|
||||
"""
|
||||
)
|
||||
# 2) Add a check constraint so that emails cannot be written in uppercase
|
||||
op.create_check_constraint(
|
||||
"ensure_lowercase_email",
|
||||
"user_tenant_mapping",
|
||||
"email = LOWER(email)",
|
||||
schema="public",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the check constraint
|
||||
op.drop_constraint(
|
||||
"ensure_lowercase_email",
|
||||
"user_tenant_mapping",
|
||||
schema="public",
|
||||
type_="check",
|
||||
)
|
||||
@@ -1,33 +0,0 @@
|
||||
"""add new available tenant table
|
||||
|
||||
Revision ID: 3b45e0018bf1
|
||||
Revises: ac842f85f932
|
||||
Create Date: 2025-03-06 09:55:18.229910
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3b45e0018bf1"
|
||||
down_revision = "ac842f85f932"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create new_available_tenant table
|
||||
op.create_table(
|
||||
"available_tenant",
|
||||
sa.Column("tenant_id", sa.String(), nullable=False),
|
||||
sa.Column("alembic_version", sa.String(), nullable=False),
|
||||
sa.Column("date_created", sa.DateTime(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("tenant_id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop new_available_tenant table
|
||||
op.drop_table("available_tenant")
|
||||
@@ -1,51 +0,0 @@
|
||||
"""new column user tenant mapping
|
||||
|
||||
Revision ID: ac842f85f932
|
||||
Revises: 34e3630c7f32
|
||||
Create Date: 2025-03-03 13:30:14.802874
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ac842f85f932"
|
||||
down_revision = "34e3630c7f32"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add active column with default value of True
|
||||
op.add_column(
|
||||
"user_tenant_mapping",
|
||||
sa.Column(
|
||||
"active",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="true",
|
||||
),
|
||||
schema="public",
|
||||
)
|
||||
|
||||
op.drop_constraint("uq_email", "user_tenant_mapping", schema="public")
|
||||
|
||||
# Create a unique index for active=true records
|
||||
# This ensures a user can only be active in one tenant at a time
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX uq_user_active_email_idx ON public.user_tenant_mapping (email) WHERE active = true"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the unique index for active=true records
|
||||
op.execute("DROP INDEX IF EXISTS uq_user_active_email_idx")
|
||||
|
||||
op.create_unique_constraint(
|
||||
"uq_email", "user_tenant_mapping", ["email"], schema="public"
|
||||
)
|
||||
|
||||
# Remove the active column
|
||||
op.drop_column("user_tenant_mapping", "active", schema="public")
|
||||
@@ -4,11 +4,12 @@ from ee.onyx.server.reporting.usage_export_generation import create_new_usage_re
|
||||
from onyx.background.celery.apps.primary import celery_app
|
||||
from onyx.background.task_utils import build_celery_task_wrapper
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.db.chat import delete_chat_session
|
||||
from onyx.db.chat import get_chat_sessions_older_than
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.chat import delete_chat_sessions_older_than
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -17,28 +18,11 @@ logger = setup_logger()
|
||||
|
||||
@build_celery_task_wrapper(name_chat_ttl_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
old_chat_sessions = get_chat_sessions_older_than(
|
||||
retention_limit_days, db_session
|
||||
)
|
||||
|
||||
for user_id, session_id in old_chat_sessions:
|
||||
# one session per delete so that we don't blow up if a deletion fails.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
delete_chat_session(
|
||||
user_id,
|
||||
session_id,
|
||||
db_session,
|
||||
include_deleted=True,
|
||||
hard_delete=True,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"delete_chat_session exceptioned. "
|
||||
f"user_id={user_id} session_id={session_id}"
|
||||
)
|
||||
def perform_ttl_management_task(
|
||||
retention_limit_days: int, *, tenant_id: str | None
|
||||
) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
delete_chat_sessions_older_than(retention_limit_days, db_session)
|
||||
|
||||
|
||||
#####
|
||||
@@ -51,19 +35,24 @@ def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) ->
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_ttl_management_task(*, tenant_id: str) -> None:
|
||||
def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
"""Runs periodically to check if any ttl tasks should be run and adds them
|
||||
to the queue"""
|
||||
token = None
|
||||
if MULTI_TENANT and tenant_id is not None:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
settings = load_settings()
|
||||
retention_limit_days = settings.maximum_chat_retention_days
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
if should_perform_chat_ttl_check(retention_limit_days, db_session):
|
||||
perform_ttl_management_task.apply_async(
|
||||
kwargs=dict(
|
||||
retention_limit_days=retention_limit_days, tenant_id=tenant_id
|
||||
),
|
||||
)
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
@@ -71,9 +60,9 @@ def check_ttl_management_task(*, tenant_id: str) -> None:
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def autogenerate_usage_report_task(*, tenant_id: str) -> None:
|
||||
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
|
||||
"""This generates usage report under the /admin/generate-usage/report endpoint"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
create_new_usage_report(
|
||||
db_session=db_session,
|
||||
user_id=None,
|
||||
|
||||
@@ -18,7 +18,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def monitor_usergroup_taskset(
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
"""This function is likely to move in the worker refactor happening next."""
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
|
||||
@@ -59,14 +59,10 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
|
||||
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
|
||||
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
|
||||
"OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
|
||||
)
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
|
||||
)
|
||||
OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "")
|
||||
OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "")
|
||||
OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
|
||||
OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
|
||||
OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
|
||||
OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
|
||||
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
|
||||
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
|
||||
|
||||
@@ -4,7 +4,6 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
@@ -36,11 +35,10 @@ def _delete_connector_credential_pair_user_groups_relationship__no_commit(
|
||||
def get_cc_pairs_by_source(
|
||||
db_session: Session,
|
||||
source_type: DocumentSource,
|
||||
access_type: AccessType | None = None,
|
||||
status: ConnectorCredentialPairStatus | None = None,
|
||||
only_sync: bool,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
"""
|
||||
Get all cc_pairs for a given source type with optional filtering by access_type and status
|
||||
Get all cc_pairs for a given source type (and optionally only sync)
|
||||
result is sorted by cc_pair id
|
||||
"""
|
||||
query = (
|
||||
@@ -50,11 +48,8 @@ def get_cc_pairs_by_source(
|
||||
.order_by(ConnectorCredentialPair.id)
|
||||
)
|
||||
|
||||
if access_type is not None:
|
||||
query = query.filter(ConnectorCredentialPair.access_type == access_type)
|
||||
|
||||
if status is not None:
|
||||
query = query.filter(ConnectorCredentialPair.status == status)
|
||||
if only_sync:
|
||||
query = query.filter(ConnectorCredentialPair.access_type == AccessType.SYNC)
|
||||
|
||||
cc_pairs = query.all()
|
||||
return cc_pairs
|
||||
|
||||
@@ -134,9 +134,7 @@ def fetch_chat_sessions_eagerly_by_time(
|
||||
limit: int | None = 500,
|
||||
initial_time: datetime | None = None,
|
||||
) -> list[ChatSession]:
|
||||
"""Sorted by oldest to newest, then by message id"""
|
||||
|
||||
asc_time_order: UnaryExpression = asc(ChatSession.time_created)
|
||||
time_order: UnaryExpression = desc(ChatSession.time_created)
|
||||
message_order: UnaryExpression = asc(ChatMessage.id)
|
||||
|
||||
filters: list[ColumnElement | BinaryExpression] = [
|
||||
@@ -149,7 +147,8 @@ def fetch_chat_sessions_eagerly_by_time(
|
||||
subquery = (
|
||||
db_session.query(ChatSession.id, ChatSession.time_created)
|
||||
.filter(*filters)
|
||||
.order_by(asc_time_order)
|
||||
.order_by(ChatSession.id, time_order)
|
||||
.distinct(ChatSession.id)
|
||||
.limit(limit)
|
||||
.subquery()
|
||||
)
|
||||
@@ -165,7 +164,7 @@ def fetch_chat_sessions_eagerly_by_time(
|
||||
ChatMessage.chat_message_feedbacks
|
||||
),
|
||||
)
|
||||
.order_by(asc_time_order, message_order)
|
||||
.order_by(time_order, message_order)
|
||||
)
|
||||
|
||||
chat_sessions = query.all()
|
||||
|
||||
@@ -16,20 +16,13 @@ from onyx.db.models import UsageReport
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
|
||||
|
||||
# Gets skeletons of all messages in the given range
|
||||
# Gets skeletons of all message
|
||||
def get_empty_chat_messages_entries__paginated(
|
||||
db_session: Session,
|
||||
period: tuple[datetime, datetime],
|
||||
limit: int | None = 500,
|
||||
initial_time: datetime | None = None,
|
||||
) -> tuple[Optional[datetime], list[ChatMessageSkeleton]]:
|
||||
"""Returns a tuple where:
|
||||
first element is the most recent timestamp out of the sessions iterated
|
||||
- this timestamp can be used to paginate forward in time
|
||||
second element is a list of messages belonging to all the sessions iterated
|
||||
|
||||
Only messages of type USER are returned
|
||||
"""
|
||||
chat_sessions = fetch_chat_sessions_eagerly_by_time(
|
||||
start=period[0],
|
||||
end=period[1],
|
||||
@@ -59,17 +52,18 @@ def get_empty_chat_messages_entries__paginated(
|
||||
if len(chat_sessions) == 0:
|
||||
return None, []
|
||||
|
||||
return chat_sessions[-1].time_created, message_skeletons
|
||||
return chat_sessions[0].time_created, message_skeletons
|
||||
|
||||
|
||||
def get_all_empty_chat_message_entries(
|
||||
db_session: Session,
|
||||
period: tuple[datetime, datetime],
|
||||
) -> Generator[list[ChatMessageSkeleton], None, None]:
|
||||
"""period is the range of time over which to fetch messages."""
|
||||
initial_time: Optional[datetime] = period[0]
|
||||
ind = 0
|
||||
while True:
|
||||
# iterate from oldest to newest
|
||||
ind += 1
|
||||
|
||||
time_created, message_skeletons = get_empty_chat_messages_entries__paginated(
|
||||
db_session,
|
||||
period,
|
||||
|
||||
@@ -424,7 +424,7 @@ def _validate_curator_status__no_commit(
|
||||
)
|
||||
|
||||
# if the user is a curator in any of their groups, set their role to CURATOR
|
||||
# otherwise, set their role to BASIC only if they were previously a CURATOR
|
||||
# otherwise, set their role to BASIC
|
||||
if curator_relationships:
|
||||
user.role = UserRole.CURATOR
|
||||
elif user.role == UserRole.CURATOR:
|
||||
@@ -631,16 +631,7 @@ def update_user_group(
|
||||
removed_users = db_session.scalars(
|
||||
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
|
||||
).unique()
|
||||
|
||||
# Filter out admin and global curator users before validating curator status
|
||||
users_to_validate = [
|
||||
user
|
||||
for user in removed_users
|
||||
if user.role not in [UserRole.ADMIN, UserRole.GLOBAL_CURATOR]
|
||||
]
|
||||
|
||||
if users_to_validate:
|
||||
_validate_curator_status__no_commit(db_session, users_to_validate)
|
||||
_validate_curator_status__no_commit(db_session, list(removed_users))
|
||||
|
||||
# update "time_updated" to now
|
||||
db_user_group.time_last_modified_by_user = func.now()
|
||||
|
||||
@@ -9,16 +9,12 @@ from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GR
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
get_user_email_from_username__server,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -346,8 +342,7 @@ def _fetch_all_page_restrictions(
|
||||
|
||||
|
||||
def confluence_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -359,11 +354,7 @@ def confluence_doc_sync(
|
||||
confluence_connector = ConfluenceConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
|
||||
provider = OnyxDBCredentialsProvider(
|
||||
get_current_tenant_id(), "confluence", cc_pair.credential_id
|
||||
)
|
||||
confluence_connector.set_credentials_provider(provider)
|
||||
confluence_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
get_user_email_from_username__server,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -16,24 +14,30 @@ def _build_group_member_email_map(
|
||||
confluence_client: OnyxConfluence, cc_pair_id: int
|
||||
) -> dict[str, set[str]]:
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
for user in confluence_client.paginated_cql_user_retrieval():
|
||||
logger.debug(f"Processing groups for user: {user}")
|
||||
for user_result in confluence_client.paginated_cql_user_retrieval():
|
||||
logger.debug(f"Processing groups for user: {user_result}")
|
||||
|
||||
email = user.email
|
||||
user = user_result.get("user", {})
|
||||
if not user:
|
||||
msg = f"user result missing user field: {user_result}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
logger.error(msg)
|
||||
continue
|
||||
|
||||
email = user.get("email")
|
||||
if not email:
|
||||
# This field is only present in Confluence Server
|
||||
user_name = user.username
|
||||
user_name = user.get("username")
|
||||
# If it is present, try to get the email using a Server-specific method
|
||||
if user_name:
|
||||
email = get_user_email_from_username__server(
|
||||
confluence_client=confluence_client,
|
||||
user_name=user_name,
|
||||
)
|
||||
|
||||
if not email:
|
||||
# If we still don't have an email, skip this user
|
||||
msg = f"user result missing email field: {user}"
|
||||
if user.type == "app":
|
||||
msg = f"user result missing email field: {user_result}"
|
||||
if user.get("type") == "app":
|
||||
logger.warning(msg)
|
||||
else:
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
@@ -41,7 +45,7 @@ def _build_group_member_email_map(
|
||||
continue
|
||||
|
||||
all_users_groups: set[str] = set()
|
||||
for group in confluence_client.paginated_groups_by_user_retrieval(user.user_id):
|
||||
for group in confluence_client.paginated_groups_by_user_retrieval(user):
|
||||
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
|
||||
group_id = group["name"]
|
||||
group_member_emails.setdefault(group_id, set()).add(email)
|
||||
@@ -63,27 +67,13 @@ def _build_group_member_email_map(
|
||||
|
||||
|
||||
def confluence_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[ExternalUserGroup]:
|
||||
provider = OnyxDBCredentialsProvider(tenant_id, "confluence", cc_pair.credential_id)
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
wiki_base: str = cc_pair.connector.connector_specific_config["wiki_base"]
|
||||
url = wiki_base.rstrip("/")
|
||||
|
||||
probe_kwargs = {
|
||||
"max_backoff_retries": 6,
|
||||
"max_backoff_seconds": 10,
|
||||
}
|
||||
|
||||
final_kwargs = {
|
||||
"max_backoff_retries": 10,
|
||||
"max_backoff_seconds": 60,
|
||||
}
|
||||
|
||||
confluence_client = OnyxConfluence(is_cloud, url, provider)
|
||||
confluence_client._probe_connection(**probe_kwargs)
|
||||
confluence_client._initialize_connection(**final_kwargs)
|
||||
confluence_client = build_confluence_client(
|
||||
credentials=cc_pair.credential.credential_json,
|
||||
is_cloud=cc_pair.connector.connector_specific_config.get("is_cloud", False),
|
||||
wiki_base=cc_pair.connector.connector_specific_config["wiki_base"],
|
||||
)
|
||||
|
||||
group_member_email_map = _build_group_member_email_map(
|
||||
confluence_client=confluence_client,
|
||||
|
||||
@@ -32,8 +32,7 @@ def _get_slim_doc_generator(
|
||||
|
||||
|
||||
def gmail_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
|
||||
@@ -62,14 +62,12 @@ def _fetch_permissions_for_permission_ids(
|
||||
user_email=(owner_email or google_drive_connector.primary_admin_email),
|
||||
)
|
||||
|
||||
# We continue on 404 or 403 because the document may not exist or the user may not have access to it
|
||||
fetched_permissions = execute_paginated_retrieval(
|
||||
retrieval_function=drive_service.permissions().list,
|
||||
list_key="permissions",
|
||||
fileId=doc_id,
|
||||
fields="permissions(id, emailAddress, type, domain)",
|
||||
supportsAllDrives=True,
|
||||
continue_on_404_or_403=True,
|
||||
)
|
||||
|
||||
permissions_for_doc_id = []
|
||||
@@ -106,13 +104,7 @@ def _get_permissions_from_slim_doc(
|
||||
user_emails: set[str] = set()
|
||||
group_emails: set[str] = set()
|
||||
public = False
|
||||
skipped_permissions = 0
|
||||
|
||||
for permission in permissions_list:
|
||||
if not permission:
|
||||
skipped_permissions += 1
|
||||
continue
|
||||
|
||||
permission_type = permission["type"]
|
||||
if permission_type == "user":
|
||||
user_emails.add(permission["emailAddress"])
|
||||
@@ -129,11 +121,6 @@ def _get_permissions_from_slim_doc(
|
||||
elif permission_type == "anyone":
|
||||
public = True
|
||||
|
||||
if skipped_permissions > 0:
|
||||
logger.warning(
|
||||
f"Skipped {skipped_permissions} permissions of {len(permissions_list)} for document {slim_doc.id}"
|
||||
)
|
||||
|
||||
drive_id = permission_info.get("drive_id")
|
||||
group_ids = group_emails | ({drive_id} if drive_id is not None else set())
|
||||
|
||||
@@ -145,8 +132,7 @@ def _get_permissions_from_slim_doc(
|
||||
|
||||
|
||||
def gdrive_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
|
||||
@@ -119,7 +119,6 @@ def _build_onyx_groups(
|
||||
|
||||
|
||||
def gdrive_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[ExternalUserGroup]:
|
||||
# Initialize connector and build credential/service objects
|
||||
|
||||
@@ -123,8 +123,7 @@ def _fetch_channel_permissions(
|
||||
|
||||
|
||||
def slack_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
|
||||
@@ -28,7 +28,6 @@ DocSyncFuncType = Callable[
|
||||
|
||||
GroupSyncFuncType = Callable[
|
||||
[
|
||||
str,
|
||||
ConnectorCredentialPair,
|
||||
],
|
||||
list[ExternalUserGroup],
|
||||
|
||||
@@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
)
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
|
||||
from ee.onyx.server.oauth.api import router as ee_oauth_router
|
||||
from ee.onyx.server.oauth import router as oauth_router
|
||||
from ee.onyx.server.query_and_chat.chat_backend import (
|
||||
router as chat_router,
|
||||
)
|
||||
@@ -128,7 +128,7 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, standard_answer_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_oauth_router)
|
||||
include_router_with_global_prefix_prepended(application, oauth_router)
|
||||
|
||||
# Enterprise-only global settings
|
||||
include_router_with_global_prefix_prepended(
|
||||
@@ -152,8 +152,4 @@ def get_application() -> FastAPI:
|
||||
# environment variable. Used to automate deployment for multiple environments.
|
||||
seed_db()
|
||||
|
||||
# for debugging discovered routes
|
||||
# for route in application.router.routes:
|
||||
# print(f"Path: {route.path}, Methods: {route.methods}")
|
||||
|
||||
return application
|
||||
|
||||
@@ -22,7 +22,7 @@ from onyx.onyxbot.slack.blocks import get_restate_blocks
|
||||
from onyx.onyxbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
|
||||
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
|
||||
from onyx.onyxbot.slack.models import SlackMessageInfo
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread
|
||||
from onyx.onyxbot.slack.utils import update_emote_react
|
||||
from onyx.utils.logger import OnyxLoggingAdapter
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -216,7 +216,7 @@ def _handle_standard_answers(
|
||||
all_blocks = restate_question_blocks + answer_blocks
|
||||
|
||||
try:
|
||||
respond_in_thread_or_channel(
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=message_info.channel_to_respond,
|
||||
receiver_ids=receiver_ids,
|
||||
@@ -231,7 +231,6 @@ def _handle_standard_answers(
|
||||
client=client,
|
||||
channel=message_info.channel_to_respond,
|
||||
thread_ts=slack_thread_id,
|
||||
receiver_ids=receiver_ids,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@@ -33,7 +33,7 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
||||
return await call_next(request)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in tenant ID middleware: {str(e)}")
|
||||
logger.error(f"Error in tenant ID middleware: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ async def _get_tenant_id_from_request(
|
||||
"""
|
||||
# Check for API key
|
||||
tenant_id = extract_tenant_from_api_key_header(request)
|
||||
if tenant_id is not None:
|
||||
if tenant_id:
|
||||
return tenant_id
|
||||
|
||||
# Check for anonymous user cookie
|
||||
|
||||
631
backend/ee/onyx/server/oauth.py
Normal file
631
backend/ee/onyx/server/oauth.py
Normal file
@@ -0,0 +1,631 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
|
||||
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/oauth")
|
||||
|
||||
|
||||
class SlackOAuth:
|
||||
# https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
|
||||
# Example: https://api.slack.com/authentication/oauth-v2#exchanging
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
CLIENT_ID = OAUTH_SLACK_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET
|
||||
|
||||
TOKEN_URL = "https://slack.com/api/oauth.v2.access"
|
||||
|
||||
# SCOPE is per https://docs.onyx.app/connectors/slack
|
||||
BOT_SCOPE = (
|
||||
"channels:history,"
|
||||
"channels:read,"
|
||||
"groups:history,"
|
||||
"groups:read,"
|
||||
"channels:join,"
|
||||
"im:history,"
|
||||
"users:read,"
|
||||
"users:read.email,"
|
||||
"usergroups:read"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
url = (
|
||||
f"https://slack.com/oauth/v2/authorize"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&scope={cls.BOT_SCOPE}"
|
||||
f"&state={state}"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = SlackOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> OAuthSession:
|
||||
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
class ConfluenceCloudOAuth:
|
||||
"""work in progress"""
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
|
||||
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
|
||||
CONFLUENCE_OAUTH_SCOPE = (
|
||||
"read:confluence-props%20"
|
||||
"read:confluence-content.all%20"
|
||||
"read:confluence-content.summary%20"
|
||||
"read:confluence-content.permission%20"
|
||||
"read:confluence-user%20"
|
||||
"read:confluence-groups%20"
|
||||
"readonly:content.attachment:confluence"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
# eventually for Confluence Data Center
|
||||
# oauth_url = (
|
||||
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
|
||||
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
|
||||
# f"&redirect_uri={redirectme_uri}"
|
||||
# )
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
url = (
|
||||
"https://auth.atlassian.com/authorize"
|
||||
f"?audience=api.atlassian.com"
|
||||
f"&client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
|
||||
f"&state={state}"
|
||||
"&response_type=code"
|
||||
"&prompt=consent"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = ConfluenceCloudOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
|
||||
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
class GoogleDriveOAuth:
|
||||
# https://developers.google.com/identity/protocols/oauth2
|
||||
# https://developers.google.com/identity/protocols/oauth2/web-server
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
|
||||
TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
# SCOPE is per https://docs.onyx.app/connectors/google-drive
|
||||
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
|
||||
SCOPE = (
|
||||
"https://www.googleapis.com/auth/drive.readonly%20"
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
|
||||
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
|
||||
"https://www.googleapis.com/auth/admin.directory.group.readonly"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
# without prompt=consent, a refresh token is only issued the first time the user approves
|
||||
url = (
|
||||
f"https://accounts.google.com/o/oauth2/v2/auth"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
"&response_type=code"
|
||||
f"&scope={cls.SCOPE}"
|
||||
"&access_type=offline"
|
||||
f"&state={state}"
|
||||
"&prompt=consent"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = GoogleDriveOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> OAuthSession:
|
||||
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
@router.post("/prepare-authorization-request")
|
||||
def prepare_authorization_request(
|
||||
connector: DocumentSource,
|
||||
redirect_on_success: str | None,
|
||||
user: User = Depends(current_user),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""Used by the frontend to generate the url for the user's browser during auth request.
|
||||
|
||||
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
|
||||
"""
|
||||
|
||||
# create random oauth state param for security and to retrieve user data later
|
||||
oauth_uuid = uuid.uuid4()
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
# urlsafe b64 encode the uuid for the oauth url
|
||||
oauth_state = (
|
||||
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
session: str
|
||||
|
||||
if connector == DocumentSource.SLACK:
|
||||
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
|
||||
session = SlackOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
elif connector == DocumentSource.GOOGLE_DRIVE:
|
||||
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
|
||||
session = GoogleDriveOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
# elif connector == DocumentSource.CONFLUENCE:
|
||||
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
|
||||
# session = ConfluenceCloudOAuth.session_dump_json(
|
||||
# email=user.email, redirect_on_success=redirect_on_success
|
||||
# )
|
||||
# elif connector == DocumentSource.JIRA:
|
||||
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
|
||||
else:
|
||||
oauth_url = None
|
||||
|
||||
if not oauth_url:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"The document source type {connector} does not have OAuth implemented",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# store important session state to retrieve when the user is redirected back
|
||||
# 10 min is the max we want an oauth flow to be valid
|
||||
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
|
||||
|
||||
return JSONResponse(content={"url": oauth_url})
|
||||
|
||||
|
||||
@router.post("/connector/slack/callback")
|
||||
def handle_slack_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Slack client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
try:
|
||||
session = SlackOAuth.parse_session(session_json)
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
SlackOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": SlackOAuth.CLIENT_ID,
|
||||
"client_secret": SlackOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": SlackOAuth.REDIRECT_URI,
|
||||
},
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
if not response_data.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Slack OAuth failed: {response_data.get('error')}",
|
||||
)
|
||||
|
||||
# Extract token and team information
|
||||
access_token: str = response_data.get("access_token")
|
||||
team_id: str = response_data.get("team", {}).get("id")
|
||||
authed_user_id: str = response_data.get("authed_user", {}).get("id")
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json={"slack_bot_token": access_token},
|
||||
admin_public=True,
|
||||
source=DocumentSource.SLACK,
|
||||
name="Slack OAuth",
|
||||
)
|
||||
|
||||
create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Slack OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Slack OAuth completed successfully.",
|
||||
"team_id": team_id,
|
||||
"authed_user_id": authed_user_id,
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# Work in progress
|
||||
# @router.post("/connector/confluence/callback")
|
||||
# def handle_confluence_oauth_callback(
|
||||
# code: str,
|
||||
# state: str,
|
||||
# user: User = Depends(current_user),
|
||||
# db_session: Session = Depends(get_session),
|
||||
# tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
# ) -> JSONResponse:
|
||||
# if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
|
||||
# raise HTTPException(
|
||||
# status_code=500,
|
||||
# detail="Confluence client ID or client secret is not configured."
|
||||
# )
|
||||
|
||||
# r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# # recover the state
|
||||
# padded_state = state + '=' * (-len(state) % 4) # Add padding back (Base64 decoding requires padding)
|
||||
# uuid_bytes = base64.urlsafe_b64decode(padded_state) # Decode the Base64 string back to bytes
|
||||
|
||||
# # Convert bytes back to a UUID
|
||||
# oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
# oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
# r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
# result = r.get(r_key)
|
||||
# if not result:
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail=f"Confluence OAuth failed - OAuth state key not found: key={r_key}"
|
||||
# )
|
||||
|
||||
# try:
|
||||
# session = ConfluenceCloudOAuth.parse_session(result)
|
||||
|
||||
# # Exchange the authorization code for an access token
|
||||
# response = requests.post(
|
||||
# ConfluenceCloudOAuth.TOKEN_URL,
|
||||
# headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
# data={
|
||||
# "client_id": ConfluenceCloudOAuth.CLIENT_ID,
|
||||
# "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
|
||||
# "code": code,
|
||||
# "redirect_uri": ConfluenceCloudOAuth.DEV_REDIRECT_URI,
|
||||
# },
|
||||
# )
|
||||
|
||||
# response_data = response.json()
|
||||
|
||||
# if not response_data.get("ok"):
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail=f"ConfluenceCloudOAuth OAuth failed: {response_data.get('error')}"
|
||||
# )
|
||||
|
||||
# # Extract token and team information
|
||||
# access_token: str = response_data.get("access_token")
|
||||
# team_id: str = response_data.get("team", {}).get("id")
|
||||
# authed_user_id: str = response_data.get("authed_user", {}).get("id")
|
||||
|
||||
# credential_info = CredentialBase(
|
||||
# credential_json={"slack_bot_token": access_token},
|
||||
# admin_public=True,
|
||||
# source=DocumentSource.CONFLUENCE,
|
||||
# name="Confluence OAuth",
|
||||
# )
|
||||
|
||||
# logger.info(f"Slack access token: {access_token}")
|
||||
|
||||
# credential = create_credential(credential_info, user, db_session)
|
||||
|
||||
# logger.info(f"new_credential_id={credential.id}")
|
||||
# except Exception as e:
|
||||
# return JSONResponse(
|
||||
# status_code=500,
|
||||
# content={
|
||||
# "success": False,
|
||||
# "message": f"An error occurred during Slack OAuth: {str(e)}",
|
||||
# },
|
||||
# )
|
||||
# finally:
|
||||
# r.delete(r_key)
|
||||
|
||||
# # return the result
|
||||
# return JSONResponse(
|
||||
# content={
|
||||
# "success": True,
|
||||
# "message": "Slack OAuth completed successfully.",
|
||||
# "team_id": team_id,
|
||||
# "authed_user_id": authed_user_id,
|
||||
# "redirect_on_success": session.redirect_on_success,
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
@router.post("/connector/google-drive/callback")
|
||||
def handle_google_drive_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Google Drive client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
session: GoogleDriveOAuth.OAuthSession
|
||||
try:
|
||||
session = GoogleDriveOAuth.parse_session(session_json)
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
GoogleDriveOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": GoogleDriveOAuth.CLIENT_ID,
|
||||
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
authorization_response: dict[str, Any] = response.json()
|
||||
|
||||
# the connector wants us to store the json in its authorized_user_info format
|
||||
# returned from OAuthCredentials.get_authorized_user_info().
|
||||
# So refresh immediately via get_google_oauth_creds with the params filled in
|
||||
# from fields in authorization_response to get the json we need
|
||||
authorized_user_info = {}
|
||||
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
|
||||
|
||||
token_json_str = json.dumps(authorized_user_info)
|
||||
oauth_creds = get_google_oauth_creds(
|
||||
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
|
||||
)
|
||||
if not oauth_creds:
|
||||
raise RuntimeError("get_google_oauth_creds returned None.")
|
||||
|
||||
# save off the credentials
|
||||
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
|
||||
|
||||
credential_dict: dict[str, str] = {}
|
||||
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
|
||||
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
|
||||
credential_dict[
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD
|
||||
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
name="OAuth (interactive)",
|
||||
)
|
||||
|
||||
create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Google Drive OAuth completed successfully.",
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
}
|
||||
)
|
||||
@@ -1,91 +0,0 @@
|
||||
import base64
|
||||
import uuid
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from ee.onyx.server.oauth.confluence_cloud import ConfluenceCloudOAuth
|
||||
from ee.onyx.server.oauth.google_drive import GoogleDriveOAuth
|
||||
from ee.onyx.server.oauth.slack import SlackOAuth
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@router.post("/prepare-authorization-request")
|
||||
def prepare_authorization_request(
|
||||
connector: DocumentSource,
|
||||
redirect_on_success: str | None,
|
||||
user: User = Depends(current_admin_user),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""Used by the frontend to generate the url for the user's browser during auth request.
|
||||
|
||||
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
|
||||
"""
|
||||
|
||||
# create random oauth state param for security and to retrieve user data later
|
||||
oauth_uuid = uuid.uuid4()
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
# urlsafe b64 encode the uuid for the oauth url
|
||||
oauth_state = (
|
||||
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
|
||||
session: str | None = None
|
||||
if connector == DocumentSource.SLACK:
|
||||
if not DEV_MODE:
|
||||
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
|
||||
else:
|
||||
oauth_url = SlackOAuth.generate_dev_oauth_url(oauth_state)
|
||||
|
||||
session = SlackOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
elif connector == DocumentSource.CONFLUENCE:
|
||||
if not DEV_MODE:
|
||||
oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
|
||||
else:
|
||||
oauth_url = ConfluenceCloudOAuth.generate_dev_oauth_url(oauth_state)
|
||||
session = ConfluenceCloudOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
elif connector == DocumentSource.GOOGLE_DRIVE:
|
||||
if not DEV_MODE:
|
||||
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
|
||||
else:
|
||||
oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
|
||||
session = GoogleDriveOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
else:
|
||||
oauth_url = None
|
||||
|
||||
if not oauth_url:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"The document source type {connector} does not have OAuth implemented",
|
||||
)
|
||||
|
||||
if not session:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"The document source type {connector} failed to generate an OAuth session.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# store important session state to retrieve when the user is redirected back
|
||||
# 10 min is the max we want an oauth flow to be valid
|
||||
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
|
||||
|
||||
return JSONResponse(content={"url": oauth_url})
|
||||
@@ -1,3 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/oauth")
|
||||
@@ -1,362 +0,0 @@
|
||||
import base64
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.credentials import update_credential_json
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ConfluenceCloudOAuth:
|
||||
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
expires_in: int
|
||||
token_type: str
|
||||
refresh_token: str
|
||||
scope: str
|
||||
|
||||
class AccessibleResources(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
url: str
|
||||
scopes: list[str]
|
||||
avatarUrl: str
|
||||
|
||||
CLIENT_ID = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
|
||||
TOKEN_URL = CONFLUENCE_OAUTH_TOKEN_URL
|
||||
|
||||
ACCESSIBLE_RESOURCE_URL = (
|
||||
"https://api.atlassian.com/oauth/token/accessible-resources"
|
||||
)
|
||||
|
||||
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
|
||||
CONFLUENCE_OAUTH_SCOPE = (
|
||||
# classic scope
|
||||
"read:confluence-space.summary%20"
|
||||
"read:confluence-props%20"
|
||||
"read:confluence-content.all%20"
|
||||
"read:confluence-content.summary%20"
|
||||
"read:confluence-content.permission%20"
|
||||
"read:confluence-user%20"
|
||||
"read:confluence-groups%20"
|
||||
"readonly:content.attachment:confluence%20"
|
||||
"search:confluence%20"
|
||||
# granular scope
|
||||
"read:attachment:confluence%20" # possibly unneeded unless calling v2 attachments api
|
||||
"read:content-details:confluence%20" # for permission sync
|
||||
"offline_access"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
# eventually for Confluence Data Center
|
||||
# oauth_url = (
|
||||
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
|
||||
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
|
||||
# f"&redirect_uri={redirectme_uri}"
|
||||
# )
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
# https://developer.atlassian.com/cloud/jira/platform/oauth-2-3lo-apps/#1--direct-the-user-to-the-authorization-url-to-get-an-authorization-code
|
||||
|
||||
url = (
|
||||
"https://auth.atlassian.com/authorize"
|
||||
f"?audience=api.atlassian.com"
|
||||
f"&client_id={cls.CLIENT_ID}"
|
||||
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&state={state}"
|
||||
"&response_type=code"
|
||||
"&prompt=consent"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = ConfluenceCloudOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> OAuthSession:
|
||||
session = ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def generate_finalize_url(cls, credential_id: int) -> str:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/finalize?credential={credential_id}"
|
||||
|
||||
|
||||
@router.post("/connector/confluence/callback")
|
||||
def confluence_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""Handles the backend logic for the frontend page that the user is redirected to
|
||||
after visiting the oauth authorization url."""
|
||||
|
||||
if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Confluence Cloud client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Confluence Cloud OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
try:
|
||||
session = ConfluenceCloudOAuth.parse_session(session_json)
|
||||
|
||||
if not DEV_MODE:
|
||||
redirect_uri = ConfluenceCloudOAuth.REDIRECT_URI
|
||||
else:
|
||||
redirect_uri = ConfluenceCloudOAuth.DEV_REDIRECT_URI
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
ConfluenceCloudOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": ConfluenceCloudOAuth.CLIENT_ID,
|
||||
"client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
)
|
||||
|
||||
token_response: ConfluenceCloudOAuth.TokenResponse | None = None
|
||||
|
||||
try:
|
||||
token_response = ConfluenceCloudOAuth.TokenResponse.model_validate_json(
|
||||
response.text
|
||||
)
|
||||
except Exception:
|
||||
raise RuntimeError(
|
||||
"Confluence Cloud OAuth failed during code/token exchange."
|
||||
)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + timedelta(seconds=token_response.expires_in)
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json={
|
||||
"confluence_access_token": token_response.access_token,
|
||||
"confluence_refresh_token": token_response.refresh_token,
|
||||
"created_at": now.isoformat(),
|
||||
"expires_at": expires_at.isoformat(),
|
||||
"expires_in": token_response.expires_in,
|
||||
"scope": token_response.scope,
|
||||
},
|
||||
admin_public=True,
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
name="Confluence Cloud OAuth",
|
||||
)
|
||||
|
||||
credential = create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Confluence Cloud OAuth completed successfully.",
|
||||
"finalize_url": ConfluenceCloudOAuth.generate_finalize_url(credential.id),
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/connector/confluence/accessible-resources")
|
||||
def confluence_oauth_accessible_resources(
|
||||
credential_id: int,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""Atlassian's API is weird and does not supply us with enough info to be in a
|
||||
usable state after authorizing. All API's require a cloud id. We have to list
|
||||
the accessible resources/sites and let the user choose which site to use."""
|
||||
|
||||
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
||||
if not credential:
|
||||
raise HTTPException(400, f"Credential {credential_id} not found.")
|
||||
|
||||
credential_dict = credential.credential_json
|
||||
access_token = credential_dict["confluence_access_token"]
|
||||
|
||||
try:
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.get(
|
||||
ConfluenceCloudOAuth.ACCESSIBLE_RESOURCE_URL,
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
accessible_resources_data = response.json()
|
||||
|
||||
# Validate the list of AccessibleResources
|
||||
try:
|
||||
accessible_resources = [
|
||||
ConfluenceCloudOAuth.AccessibleResources(**resource)
|
||||
for resource in accessible_resources_data
|
||||
]
|
||||
except ValidationError as e:
|
||||
raise RuntimeError(f"Failed to parse accessible resources: {e}")
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred retrieving Confluence Cloud accessible resources: {str(e)}",
|
||||
},
|
||||
)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Confluence Cloud get accessible resources completed successfully.",
|
||||
"accessible_resources": [
|
||||
resource.model_dump() for resource in accessible_resources
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/connector/confluence/finalize")
|
||||
def confluence_oauth_finalize(
|
||||
credential_id: int,
|
||||
cloud_id: str,
|
||||
cloud_name: str,
|
||||
cloud_url: str,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""Saves the info for the selected cloud site to the credential.
|
||||
This is the final step in the confluence oauth flow where after the traditional
|
||||
OAuth process, the user has to select a site to associate with the credentials.
|
||||
After this, the credential is usable."""
|
||||
|
||||
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
||||
if not credential:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.",
|
||||
)
|
||||
|
||||
new_credential_json: dict[str, Any] = dict(credential.credential_json)
|
||||
new_credential_json["cloud_id"] = cloud_id
|
||||
new_credential_json["cloud_name"] = cloud_name
|
||||
new_credential_json["wiki_base"] = cloud_url
|
||||
|
||||
try:
|
||||
update_credential_json(credential_id, new_credential_json, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Confluence Cloud OAuth finalized successfully.",
|
||||
"redirect_url": f"{WEB_DOMAIN}/admin/connectors/confluence",
|
||||
}
|
||||
)
|
||||
@@ -1,229 +0,0 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
|
||||
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
|
||||
|
||||
class GoogleDriveOAuth:
|
||||
# https://developers.google.com/identity/protocols/oauth2
|
||||
# https://developers.google.com/identity/protocols/oauth2/web-server
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
|
||||
TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
# SCOPE is per https://docs.danswer.dev/connectors/google-drive
|
||||
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
|
||||
SCOPE = (
|
||||
"https://www.googleapis.com/auth/drive.readonly%20"
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
|
||||
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
|
||||
"https://www.googleapis.com/auth/admin.directory.group.readonly"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
# without prompt=consent, a refresh token is only issued the first time the user approves
|
||||
url = (
|
||||
f"https://accounts.google.com/o/oauth2/v2/auth"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
"&response_type=code"
|
||||
f"&scope={cls.SCOPE}"
|
||||
"&access_type=offline"
|
||||
f"&state={state}"
|
||||
"&prompt=consent"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = GoogleDriveOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> OAuthSession:
|
||||
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
@router.post("/connector/google-drive/callback")
|
||||
def handle_google_drive_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Google Drive client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
try:
|
||||
session = GoogleDriveOAuth.parse_session(session_json)
|
||||
|
||||
if not DEV_MODE:
|
||||
redirect_uri = GoogleDriveOAuth.REDIRECT_URI
|
||||
else:
|
||||
redirect_uri = GoogleDriveOAuth.DEV_REDIRECT_URI
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
GoogleDriveOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": GoogleDriveOAuth.CLIENT_ID,
|
||||
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
authorization_response: dict[str, Any] = response.json()
|
||||
|
||||
# the connector wants us to store the json in its authorized_user_info format
|
||||
# returned from OAuthCredentials.get_authorized_user_info().
|
||||
# So refresh immediately via get_google_oauth_creds with the params filled in
|
||||
# from fields in authorization_response to get the json we need
|
||||
authorized_user_info = {}
|
||||
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
|
||||
|
||||
token_json_str = json.dumps(authorized_user_info)
|
||||
oauth_creds = get_google_oauth_creds(
|
||||
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
|
||||
)
|
||||
if not oauth_creds:
|
||||
raise RuntimeError("get_google_oauth_creds returned None.")
|
||||
|
||||
# save off the credentials
|
||||
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
|
||||
|
||||
credential_dict: dict[str, str] = {}
|
||||
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
|
||||
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
|
||||
credential_dict[
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD
|
||||
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
name="OAuth (interactive)",
|
||||
)
|
||||
|
||||
create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Google Drive OAuth completed successfully.",
|
||||
"finalize_url": None,
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
}
|
||||
)
|
||||
@@ -1,197 +0,0 @@
|
||||
import base64
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
|
||||
|
||||
class SlackOAuth:
|
||||
# https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
|
||||
# Example: https://api.slack.com/authentication/oauth-v2#exchanging
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
CLIENT_ID = OAUTH_SLACK_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET
|
||||
|
||||
TOKEN_URL = "https://slack.com/api/oauth.v2.access"
|
||||
|
||||
# SCOPE is per https://docs.danswer.dev/connectors/slack
|
||||
BOT_SCOPE = (
|
||||
"channels:history,"
|
||||
"channels:read,"
|
||||
"groups:history,"
|
||||
"groups:read,"
|
||||
"channels:join,"
|
||||
"im:history,"
|
||||
"users:read,"
|
||||
"users:read.email,"
|
||||
"usergroups:read"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
url = (
|
||||
f"https://slack.com/oauth/v2/authorize"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&scope={cls.BOT_SCOPE}"
|
||||
f"&state={state}"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = SlackOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> OAuthSession:
|
||||
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
@router.post("/connector/slack/callback")
|
||||
def handle_slack_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Slack client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
try:
|
||||
session = SlackOAuth.parse_session(session_json)
|
||||
|
||||
if not DEV_MODE:
|
||||
redirect_uri = SlackOAuth.REDIRECT_URI
|
||||
else:
|
||||
redirect_uri = SlackOAuth.DEV_REDIRECT_URI
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
SlackOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": SlackOAuth.CLIENT_ID,
|
||||
"client_secret": SlackOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
if not response_data.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Slack OAuth failed: {response_data.get('error')}",
|
||||
)
|
||||
|
||||
# Extract token and team information
|
||||
access_token: str = response_data.get("access_token")
|
||||
team_id: str = response_data.get("team", {}).get("id")
|
||||
authed_user_id: str = response_data.get("authed_user", {}).get("id")
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json={"slack_bot_token": access_token},
|
||||
admin_public=True,
|
||||
source=DocumentSource.SLACK,
|
||||
name="Slack OAuth",
|
||||
)
|
||||
|
||||
create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Slack OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Slack OAuth completed successfully.",
|
||||
"finalize_url": None,
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
"team_id": team_id,
|
||||
"authed_user_id": authed_user_id,
|
||||
}
|
||||
)
|
||||
@@ -1,14 +1,10 @@
|
||||
import re
|
||||
from typing import cast
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.query_and_chat.models import AgentAnswer
|
||||
from ee.onyx.server.query_and_chat.models import AgentSubQuery
|
||||
from ee.onyx.server.query_and_chat.models import AgentSubQuestion
|
||||
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
|
||||
from ee.onyx.server.query_and_chat.models import (
|
||||
BasicCreateChatMessageWithHistoryRequest,
|
||||
@@ -18,19 +14,13 @@ from ee.onyx.server.query_and_chat.models import SimpleDoc
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.chat_utils import combine_message_thread
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AllCitations
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import FinalUsedContextDocsResponse
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import LLMRelevanceFilterResponse
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.chat.models import SubQuestionIdentifier
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.chat.process_message import ChatPacketStream
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
@@ -99,12 +89,6 @@ def _convert_packet_stream_to_response(
|
||||
final_context_docs: list[LlmDoc] = []
|
||||
|
||||
answer = ""
|
||||
|
||||
# accumulate stream data with these dicts
|
||||
agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {}
|
||||
agent_answers: dict[tuple[int, int], AgentAnswer] = {}
|
||||
agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {}
|
||||
|
||||
for packet in packets:
|
||||
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
@@ -113,15 +97,6 @@ def _convert_packet_stream_to_response(
|
||||
|
||||
# TODO: deprecate `simple_search_docs`
|
||||
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
|
||||
|
||||
# This is a no-op if agent_sub_questions hasn't already been filled
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
if id in agent_sub_questions:
|
||||
agent_sub_questions[id].document_ids = [
|
||||
saved_search_doc.document_id
|
||||
for saved_search_doc in packet.top_documents
|
||||
]
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
@@ -138,104 +113,11 @@ def _convert_packet_stream_to_response(
|
||||
citation.citation_num: citation.document_id
|
||||
for citation in packet.citations
|
||||
}
|
||||
# agentic packets
|
||||
elif isinstance(packet, SubQuestionPiece):
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
if agent_sub_questions.get(id) is None:
|
||||
agent_sub_questions[id] = AgentSubQuestion(
|
||||
level=packet.level,
|
||||
level_question_num=packet.level_question_num,
|
||||
sub_question=packet.sub_question,
|
||||
document_ids=[],
|
||||
)
|
||||
else:
|
||||
agent_sub_questions[id].sub_question += packet.sub_question
|
||||
|
||||
elif isinstance(packet, AgentAnswerPiece):
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
if agent_answers.get(id) is None:
|
||||
agent_answers[id] = AgentAnswer(
|
||||
level=packet.level,
|
||||
level_question_num=packet.level_question_num,
|
||||
answer=packet.answer_piece,
|
||||
answer_type=packet.answer_type,
|
||||
)
|
||||
else:
|
||||
agent_answers[id].answer += packet.answer_piece
|
||||
elif isinstance(packet, SubQueryPiece):
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
sub_query_id = (
|
||||
packet.level,
|
||||
packet.level_question_num,
|
||||
packet.query_id,
|
||||
)
|
||||
if agent_sub_queries.get(sub_query_id) is None:
|
||||
agent_sub_queries[sub_query_id] = AgentSubQuery(
|
||||
level=packet.level,
|
||||
level_question_num=packet.level_question_num,
|
||||
sub_query=packet.sub_query,
|
||||
query_id=packet.query_id,
|
||||
)
|
||||
else:
|
||||
agent_sub_queries[sub_query_id].sub_query += packet.sub_query
|
||||
elif isinstance(packet, ExtendedToolResponse):
|
||||
# we shouldn't get this ... it gets intercepted and translated to QADocsResponse
|
||||
logger.warning(
|
||||
"_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!"
|
||||
)
|
||||
elif isinstance(packet, RefinedAnswerImprovement):
|
||||
response.agent_refined_answer_improvement = (
|
||||
packet.refined_answer_improvement
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}"
|
||||
)
|
||||
|
||||
response.final_context_doc_indices = _get_final_context_doc_indices(
|
||||
final_context_docs, response.top_documents
|
||||
)
|
||||
|
||||
# organize / sort agent metadata for output
|
||||
if len(agent_sub_questions) > 0:
|
||||
response.agent_sub_questions = cast(
|
||||
dict[int, list[AgentSubQuestion]],
|
||||
SubQuestionIdentifier.make_dict_by_level(agent_sub_questions),
|
||||
)
|
||||
|
||||
if len(agent_answers) > 0:
|
||||
# return the agent_level_answer from the first level or the last one depending
|
||||
# on agent_refined_answer_improvement
|
||||
response.agent_answers = cast(
|
||||
dict[int, list[AgentAnswer]],
|
||||
SubQuestionIdentifier.make_dict_by_level(agent_answers),
|
||||
)
|
||||
if response.agent_answers:
|
||||
selected_answer_level = (
|
||||
0
|
||||
if not response.agent_refined_answer_improvement
|
||||
else len(response.agent_answers) - 1
|
||||
)
|
||||
level_answers = response.agent_answers[selected_answer_level]
|
||||
for level_answer in level_answers:
|
||||
if level_answer.answer_type != "agent_level_answer":
|
||||
continue
|
||||
|
||||
answer = level_answer.answer
|
||||
break
|
||||
|
||||
if len(agent_sub_queries) > 0:
|
||||
# subqueries are often emitted with trailing whitespace ... clean it up here
|
||||
# perhaps fix at the source?
|
||||
for v in agent_sub_queries.values():
|
||||
v.sub_query = v.sub_query.strip()
|
||||
|
||||
response.agent_sub_queries = (
|
||||
AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries)
|
||||
)
|
||||
|
||||
response.answer = answer
|
||||
if answer:
|
||||
response.answer_citationless = remove_answer_citations(answer)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -11,7 +9,6 @@ from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import OnyxContexts
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import SubQuestionIdentifier
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
@@ -91,64 +88,6 @@ class SimpleDoc(BaseModel):
|
||||
metadata: dict | None
|
||||
|
||||
|
||||
class AgentSubQuestion(SubQuestionIdentifier):
|
||||
sub_question: str
|
||||
document_ids: list[str]
|
||||
|
||||
|
||||
class AgentAnswer(SubQuestionIdentifier):
|
||||
answer: str
|
||||
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
|
||||
|
||||
|
||||
class AgentSubQuery(SubQuestionIdentifier):
|
||||
sub_query: str
|
||||
query_id: int
|
||||
|
||||
@staticmethod
|
||||
def make_dict_by_level_and_question_index(
|
||||
original_dict: dict[tuple[int, int, int], "AgentSubQuery"]
|
||||
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
|
||||
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
|
||||
|
||||
returns a dict of level to dict[question num to list of query_id's]
|
||||
Ordering is asc for readability.
|
||||
"""
|
||||
# In this function, when we sort int | None, we deliberately push None to the end
|
||||
|
||||
# map entries to the level_question_dict
|
||||
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
|
||||
for k1, obj in original_dict.items():
|
||||
level = k1[0]
|
||||
question = k1[1]
|
||||
|
||||
if level not in level_question_dict:
|
||||
level_question_dict[level] = {}
|
||||
|
||||
if question not in level_question_dict[level]:
|
||||
level_question_dict[level][question] = []
|
||||
|
||||
level_question_dict[level][question].append(obj)
|
||||
|
||||
# sort each query_id list and question_index
|
||||
for key1, obj1 in level_question_dict.items():
|
||||
for key2, value2 in obj1.items():
|
||||
# sort the query_id list of each question_index
|
||||
level_question_dict[key1][key2] = sorted(
|
||||
value2, key=lambda o: o.query_id
|
||||
)
|
||||
# sort the question_index dict of level
|
||||
level_question_dict[key1] = OrderedDict(
|
||||
sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
|
||||
)
|
||||
|
||||
# sort the top dict of levels
|
||||
sorted_dict = OrderedDict(
|
||||
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
|
||||
)
|
||||
return sorted_dict
|
||||
|
||||
|
||||
class ChatBasicResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str | None = None
|
||||
@@ -168,12 +107,6 @@ class ChatBasicResponse(BaseModel):
|
||||
simple_search_docs: list[SimpleDoc] | None = None
|
||||
llm_chunks_indices: list[int] | None = None
|
||||
|
||||
# agentic fields
|
||||
agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None
|
||||
agent_answers: dict[int, list[AgentAnswer]] | None = None
|
||||
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
|
||||
agent_refined_answer_improvement: bool | None = None
|
||||
|
||||
|
||||
class OneShotQARequest(ChunkContext):
|
||||
# Supports simplier APIs that don't deal with chat histories or message edits
|
||||
|
||||
@@ -13,7 +13,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.api_key import is_api_key_email_address
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import TokenRateLimit
|
||||
@@ -28,21 +28,21 @@ from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_glob
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
|
||||
def _check_token_rate_limits(user: User | None) -> None:
|
||||
def _check_token_rate_limits(user: User | None, tenant_id: str | None) -> None:
|
||||
if user is None:
|
||||
# Unauthenticated users are only rate limited by global settings
|
||||
_user_is_rate_limited_by_global()
|
||||
_user_is_rate_limited_by_global(tenant_id)
|
||||
|
||||
elif is_api_key_email_address(user.email):
|
||||
# API keys are only rate limited by global settings
|
||||
_user_is_rate_limited_by_global()
|
||||
_user_is_rate_limited_by_global(tenant_id)
|
||||
|
||||
else:
|
||||
run_functions_tuples_in_parallel(
|
||||
[
|
||||
(_user_is_rate_limited, (user.id,)),
|
||||
(_user_is_rate_limited_by_group, (user.id,)),
|
||||
(_user_is_rate_limited_by_global, ()),
|
||||
(_user_is_rate_limited, (user.id, tenant_id)),
|
||||
(_user_is_rate_limited_by_group, (user.id, tenant_id)),
|
||||
(_user_is_rate_limited_by_global, (tenant_id,)),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -52,8 +52,8 @@ User rate limits
|
||||
"""
|
||||
|
||||
|
||||
def _user_is_rate_limited(user_id: UUID) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
def _user_is_rate_limited(user_id: UUID, tenant_id: str | None) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
user_rate_limits = fetch_all_user_token_rate_limits(
|
||||
db_session=db_session, enabled_only=True, ordered=False
|
||||
)
|
||||
@@ -93,8 +93,8 @@ User Group rate limits
|
||||
"""
|
||||
|
||||
|
||||
def _user_is_rate_limited_by_group(user_id: UUID) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)
|
||||
|
||||
if group_rate_limits:
|
||||
|
||||
@@ -2,7 +2,6 @@ import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -22,10 +21,8 @@ from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_display_email
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.configs.constants import SessionType
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
@@ -38,8 +35,6 @@ from onyx.server.query_and_chat.models import ChatSessionsResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
ONYX_ANONYMIZED_EMAIL = "anonymous@anonymous.invalid"
|
||||
|
||||
|
||||
def fetch_and_process_chat_session_history(
|
||||
db_session: Session,
|
||||
@@ -48,15 +43,10 @@ def fetch_and_process_chat_session_history(
|
||||
feedback_type: QAFeedbackType | None,
|
||||
limit: int | None = 500,
|
||||
) -> list[ChatSessionSnapshot]:
|
||||
# observed to be slow a scale of 8192 sessions and 4 messages per session
|
||||
|
||||
# this is a little slow (5 seconds)
|
||||
chat_sessions = fetch_chat_sessions_eagerly_by_time(
|
||||
start=start, end=end, db_session=db_session, limit=limit
|
||||
)
|
||||
|
||||
# this is VERY slow (80 seconds) due to create_chat_chain being called
|
||||
# for each session. Needs optimizing.
|
||||
chat_session_snapshots = [
|
||||
snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)
|
||||
for chat_session in chat_sessions
|
||||
@@ -117,17 +107,6 @@ def get_user_chat_sessions(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionsResponse:
|
||||
# we specifically don't allow this endpoint if "anonymized" since
|
||||
# this is a direct query on the user id
|
||||
if ONYX_QUERY_HISTORY_TYPE in [
|
||||
QueryHistoryType.DISABLED,
|
||||
QueryHistoryType.ANONYMIZED,
|
||||
]:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Per user query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
try:
|
||||
chat_sessions = get_chat_sessions_by_user(
|
||||
user_id=user_id, deleted=False, db_session=db_session, limit=0
|
||||
@@ -143,7 +122,6 @@ def get_user_chat_sessions(
|
||||
name=chat.description,
|
||||
persona_id=chat.persona_id,
|
||||
time_created=chat.time_created.isoformat(),
|
||||
time_updated=chat.time_updated.isoformat(),
|
||||
shared_status=chat.shared_status,
|
||||
folder_id=chat.folder_id,
|
||||
current_alternate_model=chat.current_alternate_model,
|
||||
@@ -163,12 +141,6 @@ def get_chat_session_history(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PaginatedReturn[ChatSessionMinimal]:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
page_of_chat_sessions = get_page_of_chat_sessions(
|
||||
page_num=page_num,
|
||||
page_size=page_size,
|
||||
@@ -185,16 +157,11 @@ def get_chat_session_history(
|
||||
feedback_filter=feedback_type,
|
||||
)
|
||||
|
||||
minimal_chat_sessions: list[ChatSessionMinimal] = []
|
||||
|
||||
for chat_session in page_of_chat_sessions:
|
||||
minimal_chat_session = ChatSessionMinimal.from_chat_session(chat_session)
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
minimal_chat_session.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
minimal_chat_sessions.append(minimal_chat_session)
|
||||
|
||||
return PaginatedReturn(
|
||||
items=minimal_chat_sessions,
|
||||
items=[
|
||||
ChatSessionMinimal.from_chat_session(chat_session)
|
||||
for chat_session in page_of_chat_sessions
|
||||
],
|
||||
total_items=total_filtered_chat_sessions_count,
|
||||
)
|
||||
|
||||
@@ -205,12 +172,6 @@ def get_chat_session_admin(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionSnapshot:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id,
|
||||
@@ -232,9 +193,6 @@ def get_chat_session_admin(
|
||||
f"Could not create snapshot for chat session with id '{chat_session_id}'",
|
||||
)
|
||||
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
|
||||
return snapshot
|
||||
|
||||
|
||||
@@ -245,14 +203,6 @@ def get_query_history_as_csv(
|
||||
end: datetime | None = None,
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
# this call is very expensive and is timing out via endpoint
|
||||
# TODO: optimize call and/or generate via background task
|
||||
complete_chat_session_history = fetch_and_process_chat_session_history(
|
||||
db_session=db_session,
|
||||
start=start or datetime.fromtimestamp(0, tz=timezone.utc),
|
||||
@@ -263,9 +213,6 @@ def get_query_history_as_csv(
|
||||
|
||||
question_answer_pairs: list[QuestionAnswerPairSnapshot] = []
|
||||
for chat_session_snapshot in complete_chat_session_history:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
chat_session_snapshot.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
|
||||
question_answer_pairs.extend(
|
||||
QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot)
|
||||
)
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
|
||||
from ee.onyx.auth.users import current_cloud_superuser
|
||||
from ee.onyx.server.tenants.models import ImpersonateRequest
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import get_redis_strategy
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/impersonate")
|
||||
async def impersonate_user(
|
||||
impersonate_request: ImpersonateRequest,
|
||||
_: User = Depends(current_cloud_superuser),
|
||||
) -> Response:
|
||||
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
|
||||
tenant_id = get_tenant_id_for_email(impersonate_request.email)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
|
||||
user_to_impersonate = get_user_by_email(
|
||||
impersonate_request.email, tenant_session
|
||||
)
|
||||
if user_to_impersonate is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
token = await get_redis_strategy().write_token(user_to_impersonate)
|
||||
|
||||
response = await auth_backend.transport.get_login_response(token)
|
||||
response.set_cookie(
|
||||
key="fastapiusersauth",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="lax",
|
||||
)
|
||||
return response
|
||||
@@ -1,98 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
|
||||
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
|
||||
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import (
|
||||
get_tenant_id_for_anonymous_user_path,
|
||||
)
|
||||
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
|
||||
from ee.onyx.server.tenants.models import AnonymousUserPath
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import optional_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.get("/anonymous-user-path")
|
||||
async def get_anonymous_user_path_api(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> AnonymousUserPath:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if tenant_id is None:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
current_path = get_anonymous_user_path(tenant_id, db_session)
|
||||
|
||||
return AnonymousUserPath(anonymous_user_path=current_path)
|
||||
|
||||
|
||||
@router.post("/anonymous-user-path")
|
||||
async def set_anonymous_user_path_api(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
validate_anonymous_user_path(anonymous_user_path)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
|
||||
except IntegrityError:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="The anonymous user path is already in use. Please choose a different path.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while modifying the anonymous user path",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/anonymous-user")
|
||||
async def login_as_anonymous_user(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(optional_user),
|
||||
) -> Response:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
tenant_id = get_tenant_id_for_anonymous_user_path(
|
||||
anonymous_user_path, db_session
|
||||
)
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
if not anonymous_user_enabled(tenant_id=tenant_id):
|
||||
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
|
||||
|
||||
token = generate_anonymous_user_jwt_token(tenant_id)
|
||||
|
||||
response = Response()
|
||||
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
response.set_cookie(
|
||||
key=ANONYMOUS_USER_COOKIE_NAME,
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="strict",
|
||||
)
|
||||
return response
|
||||
@@ -1,24 +1,261 @@
|
||||
import stripe
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.tenants.admin_api import router as admin_router
|
||||
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
|
||||
from ee.onyx.server.tenants.billing_api import router as billing_router
|
||||
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
|
||||
from ee.onyx.server.tenants.tenant_management_api import (
|
||||
router as tenant_management_router,
|
||||
)
|
||||
from ee.onyx.server.tenants.user_invitations_api import (
|
||||
router as user_invitations_router,
|
||||
from ee.onyx.auth.users import current_cloud_superuser
|
||||
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
|
||||
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import (
|
||||
get_tenant_id_for_anonymous_user_path,
|
||||
)
|
||||
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import AnonymousUserPath
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import ImpersonateRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.product_gating import store_product_gating
|
||||
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_redis_strategy
|
||||
from onyx.auth.users import optional_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
# Create a main router to include all sub-routers
|
||||
# Note: We don't add a prefix here as each router already has the /tenants prefix
|
||||
router = APIRouter()
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
# Include all the individual routers
|
||||
router.include_router(admin_router)
|
||||
router.include_router(anonymous_users_router)
|
||||
router.include_router(billing_router)
|
||||
router.include_router(team_membership_router)
|
||||
router.include_router(tenant_management_router)
|
||||
router.include_router(user_invitations_router)
|
||||
|
||||
@router.get("/anonymous-user-path")
|
||||
async def get_anonymous_user_path_api(
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> AnonymousUserPath:
|
||||
if tenant_id is None:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
current_path = get_anonymous_user_path(tenant_id, db_session)
|
||||
|
||||
return AnonymousUserPath(anonymous_user_path=current_path)
|
||||
|
||||
|
||||
@router.post("/anonymous-user-path")
|
||||
async def set_anonymous_user_path_api(
|
||||
anonymous_user_path: str,
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
try:
|
||||
validate_anonymous_user_path(anonymous_user_path)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
try:
|
||||
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
|
||||
except IntegrityError:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="The anonymous user path is already in use. Please choose a different path.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while modifying the anonymous user path",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/anonymous-user")
|
||||
async def login_as_anonymous_user(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(optional_user),
|
||||
) -> Response:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
tenant_id = get_tenant_id_for_anonymous_user_path(
|
||||
anonymous_user_path, db_session
|
||||
)
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
if not anonymous_user_enabled(tenant_id=tenant_id):
|
||||
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
|
||||
|
||||
token = generate_anonymous_user_jwt_token(tenant_id)
|
||||
|
||||
response = Response()
|
||||
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
response.set_cookie(
|
||||
key=ANONYMOUS_USER_COOKIE_NAME,
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="strict",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
|
||||
) -> ProductGatingResponse:
|
||||
"""
|
||||
Gating the product means that the product is not available to the tenant.
|
||||
They will be directed to the billing page.
|
||||
We gate the product when their subscription has ended.
|
||||
"""
|
||||
try:
|
||||
store_product_gating(
|
||||
product_gating_request.tenant_id, product_gating_request.application_status
|
||||
)
|
||||
return ProductGatingResponse(updated=True, error=None)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to gate product")
|
||||
return ProductGatingResponse(updated=False, error=str(e))
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
logger.info("Fetching billing information")
|
||||
return fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def create_customer_portal_session(_: User = Depends(current_admin_user)) -> dict:
|
||||
try:
|
||||
# Fetch tenant_id and current tenant's information
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
stripe_info = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_customer_id = stripe_info.get("stripe_customer_id")
|
||||
if not stripe_customer_id:
|
||||
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
|
||||
logger.info(stripe_customer_id)
|
||||
|
||||
portal_session = stripe.billing_portal.Session.create(
|
||||
customer=stripe_customer_id,
|
||||
return_url=f"{WEB_DOMAIN}/admin/billing",
|
||||
)
|
||||
logger.info(portal_session)
|
||||
return {"url": portal_session.url}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create customer portal session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
async def create_subscription_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
session_id = fetch_stripe_checkout_session(tenant_id)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create resubscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/impersonate")
|
||||
async def impersonate_user(
|
||||
impersonate_request: ImpersonateRequest,
|
||||
_: User = Depends(current_cloud_superuser),
|
||||
) -> Response:
|
||||
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
|
||||
tenant_id = get_tenant_id_for_email(impersonate_request.email)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as tenant_session:
|
||||
user_to_impersonate = get_user_by_email(
|
||||
impersonate_request.email, tenant_session
|
||||
)
|
||||
if user_to_impersonate is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
token = await get_redis_strategy().write_token(user_to_impersonate)
|
||||
|
||||
response = await auth_backend.transport.get_login_response(token)
|
||||
response.set_cookie(
|
||||
key="fastapiusersauth",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="lax",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/leave-organization")
|
||||
async def leave_organization(
|
||||
user_email: UserByEmail,
|
||||
current_user: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
) -> None:
|
||||
if current_user is None or current_user.email != user_email.user_email:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You can only leave the organization as yourself"
|
||||
)
|
||||
|
||||
user_to_delete = get_user_by_email(user_email.user_email, db_session)
|
||||
if user_to_delete is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
num_admin_users = await get_user_count(only_admin_users=True)
|
||||
|
||||
should_delete_tenant = num_admin_users == 1
|
||||
|
||||
if should_delete_tenant:
|
||||
logger.info(
|
||||
"Last admin user is leaving the organization. Deleting tenant from control plane."
|
||||
)
|
||||
try:
|
||||
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
|
||||
logger.debug("User deleted from control plane")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to remove user from control plane: {str(e)}",
|
||||
)
|
||||
|
||||
db_session.expunge(user_to_delete)
|
||||
delete_user_from_db(user_to_delete, db_session)
|
||||
|
||||
if should_delete_tenant:
|
||||
remove_all_users_from_tenant(tenant_id)
|
||||
else:
|
||||
remove_users_from_tenant([user_to_delete.email], tenant_id)
|
||||
|
||||
@@ -7,7 +7,6 @@ from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -42,9 +41,7 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
|
||||
return response.json()
|
||||
|
||||
|
||||
def fetch_billing_information(
|
||||
tenant_id: str,
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
def fetch_billing_information(tenant_id: str) -> BillingInformation:
|
||||
logger.info("Fetching billing information")
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
@@ -55,19 +52,8 @@ def fetch_billing_information(
|
||||
params = {"tenant_id": tenant_id}
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
# Check if the response indicates no subscription
|
||||
if (
|
||||
isinstance(response_data, dict)
|
||||
and "subscribed" in response_data
|
||||
and not response_data["subscribed"]
|
||||
):
|
||||
return SubscriptionStatusResponse(**response_data)
|
||||
|
||||
# Otherwise, parse as BillingInformation
|
||||
return BillingInformation(**response_data)
|
||||
billing_info = BillingInformation(**response.json())
|
||||
return billing_info
|
||||
|
||||
|
||||
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
import stripe
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.product_gating import store_product_gating
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
|
||||
) -> ProductGatingResponse:
|
||||
"""
|
||||
Gating the product means that the product is not available to the tenant.
|
||||
They will be directed to the billing page.
|
||||
We gate the product when their subscription has ended.
|
||||
"""
|
||||
try:
|
||||
store_product_gating(
|
||||
product_gating_request.tenant_id, product_gating_request.application_status
|
||||
)
|
||||
return ProductGatingResponse(updated=True, error=None)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to gate product")
|
||||
return ProductGatingResponse(updated=False, error=str(e))
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
logger.info("Fetching billing information")
|
||||
tenant_id = get_current_tenant_id()
|
||||
return fetch_billing_information(tenant_id)
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def create_customer_portal_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> dict:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
try:
|
||||
stripe_info = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_customer_id = stripe_info.get("stripe_customer_id")
|
||||
if not stripe_customer_id:
|
||||
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
|
||||
logger.info(stripe_customer_id)
|
||||
|
||||
portal_session = stripe.billing_portal.Session.create(
|
||||
customer=stripe_customer_id,
|
||||
return_url=f"{WEB_DOMAIN}/admin/billing",
|
||||
)
|
||||
logger.info(portal_session)
|
||||
return {"url": portal_session.url}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create customer portal session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
async def create_subscription_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
session_id = fetch_stripe_checkout_session(tenant_id)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create resubscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -67,30 +67,3 @@ class ProductGatingResponse(BaseModel):
|
||||
|
||||
class SubscriptionSessionResponse(BaseModel):
|
||||
sessionId: str
|
||||
|
||||
|
||||
class TenantByDomainResponse(BaseModel):
|
||||
tenant_id: str
|
||||
number_of_users: int
|
||||
creator_email: str
|
||||
|
||||
|
||||
class TenantByDomainRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class RequestInviteRequest(BaseModel):
|
||||
tenant_id: str
|
||||
|
||||
|
||||
class RequestInviteResponse(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class PendingUserSnapshot(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class ApproveUserRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
@@ -48,5 +48,4 @@ def store_product_gating(tenant_id: str, application_status: ApplicationStatus)
|
||||
|
||||
def get_gated_tenants() -> set[str]:
|
||||
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
|
||||
return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}
|
||||
return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))
|
||||
|
||||
@@ -4,7 +4,6 @@ import uuid
|
||||
|
||||
import aiohttp # Async HTTP client
|
||||
import httpx
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy import select
|
||||
@@ -15,7 +14,6 @@ from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
|
||||
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import TenantByDomainResponse
|
||||
from ee.onyx.server.tenants.models import TenantCreationPayload
|
||||
from ee.onyx.server.tenants.models import TenantDeletionPayload
|
||||
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
|
||||
@@ -28,12 +26,11 @@ from onyx.auth.users import exceptions
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_cloud_embedding_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import AvailableTenant
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import UserTenantMapping
|
||||
@@ -58,77 +55,43 @@ logger = logging.getLogger(__name__)
|
||||
async def get_or_provision_tenant(
|
||||
email: str, referral_source: str | None = None, request: Request | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Get existing tenant ID for an email or create a new tenant if none exists.
|
||||
This function should only be called after we have verified we want this user's tenant to exist.
|
||||
It returns the tenant ID associated with the email, creating a new tenant if necessary.
|
||||
"""
|
||||
# Early return for non-multi-tenant mode
|
||||
"""Get existing tenant ID for an email or create a new tenant if none exists."""
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
if referral_source and request:
|
||||
await submit_to_hubspot(email, referral_source, request)
|
||||
|
||||
# First, check if the user already has a tenant
|
||||
tenant_id: str | None = None
|
||||
try:
|
||||
tenant_id = get_tenant_id_for_email(email)
|
||||
return tenant_id
|
||||
except exceptions.UserNotExists:
|
||||
# User doesn't exist, so we need to create a new tenant or assign an existing one
|
||||
pass
|
||||
|
||||
try:
|
||||
# Try to get a pre-provisioned tenant
|
||||
tenant_id = await get_available_tenant()
|
||||
|
||||
if tenant_id:
|
||||
# If we have a pre-provisioned tenant, assign it to the user
|
||||
await assign_tenant_to_user(tenant_id, email, referral_source)
|
||||
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
|
||||
return tenant_id
|
||||
else:
|
||||
# If no pre-provisioned tenant is available, create a new one on-demand
|
||||
# If tenant does not exist and in Multi tenant mode, provision a new tenant
|
||||
try:
|
||||
tenant_id = await create_tenant(email, referral_source)
|
||||
return tenant_id
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
|
||||
|
||||
except Exception as e:
|
||||
# If we've encountered an error, log and raise an exception
|
||||
error_msg = "Failed to provision tenant"
|
||||
logger.error(error_msg, exc_info=e)
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to provision tenant. Please try again later.",
|
||||
status_code=401, detail="User does not belong to an organization"
|
||||
)
|
||||
|
||||
return tenant_id
|
||||
|
||||
|
||||
async def create_tenant(email: str, referral_source: str | None = None) -> str:
|
||||
"""
|
||||
Create a new tenant on-demand when no pre-provisioned tenants are available.
|
||||
This is the fallback method when we can't use a pre-provisioned tenant.
|
||||
|
||||
"""
|
||||
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
|
||||
logger.info(f"Creating new tenant {tenant_id} for user {email}")
|
||||
|
||||
try:
|
||||
# Provision tenant on data plane
|
||||
await provision_tenant(tenant_id, email)
|
||||
|
||||
# Notify control plane if not already done in provision_tenant
|
||||
if not DEV_MODE and referral_source:
|
||||
# Notify control plane
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Tenant provisioning failed: {str(e)}")
|
||||
# Attempt to rollback the tenant provisioning
|
||||
try:
|
||||
await rollback_tenant_provisioning(tenant_id)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to rollback tenant provisioning for {tenant_id}")
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
await rollback_tenant_provisioning(tenant_id)
|
||||
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
|
||||
|
||||
return tenant_id
|
||||
|
||||
|
||||
@@ -141,26 +104,55 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
|
||||
status_code=409, detail="User already belongs to an organization"
|
||||
)
|
||||
|
||||
logger.debug(f"Provisioning tenant {tenant_id} for user {email}")
|
||||
logger.info(f"Provisioning tenant: {tenant_id}")
|
||||
token = None
|
||||
|
||||
try:
|
||||
# Create the schema for the tenant
|
||||
if not create_schema_if_not_exists(tenant_id):
|
||||
logger.debug(f"Created schema for tenant {tenant_id}")
|
||||
logger.info(f"Created schema for tenant {tenant_id}")
|
||||
else:
|
||||
logger.debug(f"Schema already exists for tenant {tenant_id}")
|
||||
logger.info(f"Schema already exists for tenant {tenant_id}")
|
||||
|
||||
# Set up the tenant with all necessary configurations
|
||||
await setup_tenant(tenant_id)
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Assign the tenant to the user
|
||||
await assign_tenant_to_user(tenant_id, email)
|
||||
# Await the Alembic migrations
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
current_search_settings = (
|
||||
db_session.query(SearchSettings)
|
||||
.filter_by(status=IndexModelStatus.FUTURE)
|
||||
.first()
|
||||
)
|
||||
cohere_enabled = (
|
||||
current_search_settings is not None
|
||||
and current_search_settings.provider_type == EmbeddingProvider.COHERE
|
||||
)
|
||||
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
|
||||
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={
|
||||
"email": email,
|
||||
},
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create tenant {tenant_id}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create tenant: {str(e)}"
|
||||
)
|
||||
finally:
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
async def notify_control_plane(
|
||||
@@ -191,106 +183,31 @@ async def notify_control_plane(
|
||||
|
||||
|
||||
async def rollback_tenant_provisioning(tenant_id: str) -> None:
|
||||
"""
|
||||
Logic to rollback tenant provisioning on data plane.
|
||||
Handles each step independently to ensure maximum cleanup even if some steps fail.
|
||||
"""
|
||||
# Logic to rollback tenant provisioning on data plane
|
||||
logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
|
||||
|
||||
# Track if any part of the rollback fails
|
||||
rollback_errors = []
|
||||
|
||||
# 1. Try to drop the tenant's schema
|
||||
try:
|
||||
# Drop the tenant's schema to rollback provisioning
|
||||
drop_schema(tenant_id)
|
||||
logger.info(f"Successfully dropped schema for tenant {tenant_id}")
|
||||
|
||||
# Remove tenant mapping
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to drop schema for tenant {tenant_id}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
rollback_errors.append(error_msg)
|
||||
|
||||
# 2. Try to remove tenant mapping
|
||||
try:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
db_session.begin()
|
||||
try:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
logger.info(
|
||||
f"Successfully removed user mappings for tenant {tenant_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
raise e
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to remove user mappings for tenant {tenant_id}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
rollback_errors.append(error_msg)
|
||||
|
||||
# 3. If this tenant was in the available tenants table, remove it
|
||||
try:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
db_session.begin()
|
||||
try:
|
||||
available_tenant = (
|
||||
db_session.query(AvailableTenant)
|
||||
.filter(AvailableTenant.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if available_tenant:
|
||||
db_session.delete(available_tenant)
|
||||
db_session.commit()
|
||||
logger.info(
|
||||
f"Removed tenant {tenant_id} from available tenants table"
|
||||
)
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
raise e
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to remove tenant {tenant_id} from available tenants table: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
rollback_errors.append(error_msg)
|
||||
|
||||
# Log summary of rollback operation
|
||||
if rollback_errors:
|
||||
logger.error(f"Tenant rollback completed with {len(rollback_errors)} errors")
|
||||
else:
|
||||
logger.info(f"Tenant rollback completed successfully for tenant {tenant_id}")
|
||||
logger.error(f"Failed to rollback tenant provisioning: {e}")
|
||||
|
||||
|
||||
def configure_default_api_keys(db_session: Session) -> None:
|
||||
if ANTHROPIC_DEFAULT_API_KEY:
|
||||
anthropic_provider = LLMProviderUpsertRequest(
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name="claude-3-7-sonnet-20250219",
|
||||
fast_default_model_name="claude-3-5-sonnet-20241022",
|
||||
model_names=ANTHROPIC_MODEL_NAMES,
|
||||
display_model_names=["claude-3-5-sonnet-20241022"],
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(anthropic_provider, db_session)
|
||||
update_default_provider(full_provider.id, db_session)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure Anthropic provider: {e}")
|
||||
else:
|
||||
logger.error(
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
)
|
||||
|
||||
if OPENAI_DEFAULT_API_KEY:
|
||||
open_provider = LLMProviderUpsertRequest(
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name="gpt-4o",
|
||||
default_model_name="gpt-4",
|
||||
fast_default_model_name="gpt-4o-mini",
|
||||
model_names=OPEN_AI_MODEL_NAMES,
|
||||
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(open_provider, db_session)
|
||||
@@ -302,6 +219,25 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
|
||||
)
|
||||
|
||||
if ANTHROPIC_DEFAULT_API_KEY:
|
||||
anthropic_provider = LLMProviderUpsertRequest(
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name="claude-3-5-sonnet-20241022",
|
||||
fast_default_model_name="claude-3-5-sonnet-20241022",
|
||||
model_names=ANTHROPIC_MODEL_NAMES,
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(anthropic_provider, db_session)
|
||||
update_default_provider(full_provider.id, db_session)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure Anthropic provider: {e}")
|
||||
else:
|
||||
logger.error(
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
)
|
||||
|
||||
if COHERE_DEFAULT_API_KEY:
|
||||
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
|
||||
provider_type=EmbeddingProvider.COHERE,
|
||||
@@ -411,155 +347,3 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
|
||||
raise Exception(
|
||||
f"Failed to delete tenant on control plane: {error_text}"
|
||||
)
|
||||
|
||||
|
||||
def get_tenant_by_domain_from_control_plane(
|
||||
domain: str,
|
||||
tenant_id: str,
|
||||
) -> TenantByDomainResponse | None:
|
||||
"""
|
||||
Fetches tenant information from the control plane based on the email domain.
|
||||
|
||||
Args:
|
||||
domain: The email domain to search for (e.g., "example.com")
|
||||
|
||||
Returns:
|
||||
A dictionary containing tenant information if found, None otherwise
|
||||
"""
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{CONTROL_PLANE_API_BASE_URL}/tenant-by-domain",
|
||||
headers=headers,
|
||||
json={"domain": domain, "tenant_id": tenant_id},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Control plane tenant lookup failed: {response.text}")
|
||||
return None
|
||||
|
||||
response_data = response.json()
|
||||
if not response_data:
|
||||
return None
|
||||
|
||||
return TenantByDomainResponse(
|
||||
tenant_id=response_data.get("tenant_id"),
|
||||
number_of_users=response_data.get("number_of_users"),
|
||||
creator_email=response_data.get("creator_email"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching tenant by domain: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_available_tenant() -> str | None:
|
||||
"""
|
||||
Get an available pre-provisioned tenant from the NewAvailableTenant table.
|
||||
Returns the tenant_id if one is available, None otherwise.
|
||||
Uses row-level locking to prevent race conditions when multiple processes
|
||||
try to get an available tenant simultaneously.
|
||||
"""
|
||||
if not MULTI_TENANT:
|
||||
return None
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
db_session.begin()
|
||||
|
||||
# Get the oldest available tenant with FOR UPDATE lock to prevent race conditions
|
||||
available_tenant = (
|
||||
db_session.query(AvailableTenant)
|
||||
.order_by(AvailableTenant.date_created)
|
||||
.with_for_update(skip_locked=True) # Skip locked rows to avoid blocking
|
||||
.first()
|
||||
)
|
||||
|
||||
if available_tenant:
|
||||
tenant_id = available_tenant.tenant_id
|
||||
# Remove the tenant from the available tenants table
|
||||
db_session.delete(available_tenant)
|
||||
db_session.commit()
|
||||
logger.info(f"Using pre-provisioned tenant {tenant_id}")
|
||||
return tenant_id
|
||||
else:
|
||||
db_session.rollback()
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Error getting available tenant")
|
||||
db_session.rollback()
|
||||
return None
|
||||
|
||||
|
||||
async def setup_tenant(tenant_id: str) -> None:
|
||||
"""
|
||||
Set up a tenant with all necessary configurations.
|
||||
This is a centralized function that handles all tenant setup logic.
|
||||
"""
|
||||
token = None
|
||||
try:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Run Alembic migrations
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
|
||||
# Configure the tenant with default settings
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
# Configure default API keys
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
# Set up Onyx with appropriate settings
|
||||
current_search_settings = (
|
||||
db_session.query(SearchSettings)
|
||||
.filter_by(status=IndexModelStatus.FUTURE)
|
||||
.first()
|
||||
)
|
||||
cohere_enabled = (
|
||||
current_search_settings is not None
|
||||
and current_search_settings.provider_type == EmbeddingProvider.COHERE
|
||||
)
|
||||
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to set up tenant {tenant_id}")
|
||||
raise e
|
||||
finally:
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
async def assign_tenant_to_user(
|
||||
tenant_id: str, email: str, referral_source: str | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Assign a tenant to a user and perform necessary operations.
|
||||
Uses transaction handling to ensure atomicity and includes retry logic
|
||||
for control plane notifications.
|
||||
"""
|
||||
# First, add the user to the tenant in a transaction
|
||||
|
||||
try:
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
# Create milestone record in the same transaction context as the tenant assignment
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={
|
||||
"email": email,
|
||||
},
|
||||
db_session=db_session,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
|
||||
raise Exception("Failed to assign tenant to user")
|
||||
|
||||
# Notify control plane with retry logic
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
@@ -74,21 +74,3 @@ def drop_schema(tenant_id: str) -> None:
|
||||
text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"),
|
||||
{"schema_name": tenant_id},
|
||||
)
|
||||
|
||||
|
||||
def get_current_alembic_version(tenant_id: str) -> str:
|
||||
"""Get the current Alembic version for a tenant."""
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from sqlalchemy import text
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
# Set the search path to the tenant's schema
|
||||
with engine.connect() as connection:
|
||||
connection.execute(text(f'SET search_path TO "{tenant_id}"'))
|
||||
|
||||
# Get the current version from the alembic_version table
|
||||
context = MigrationContext.configure(connection)
|
||||
current_rev = context.get_current_revision()
|
||||
|
||||
return current_rev or "head"
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
|
||||
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/leave-team")
|
||||
async def leave_organization(
|
||||
user_email: UserByEmail,
|
||||
current_user: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if current_user is None or current_user.email != user_email.user_email:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You can only leave the organization as yourself"
|
||||
)
|
||||
|
||||
user_to_delete = get_user_by_email(user_email.user_email, db_session)
|
||||
if user_to_delete is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
num_admin_users = await get_user_count(only_admin_users=True)
|
||||
|
||||
should_delete_tenant = num_admin_users == 1
|
||||
|
||||
if should_delete_tenant:
|
||||
logger.info(
|
||||
"Last admin user is leaving the organization. Deleting tenant from control plane."
|
||||
)
|
||||
try:
|
||||
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
|
||||
logger.debug("User deleted from control plane")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to remove user from control plane: {str(e)}",
|
||||
)
|
||||
|
||||
db_session.expunge(user_to_delete)
|
||||
delete_user_from_db(user_to_delete, db_session)
|
||||
|
||||
if should_delete_tenant:
|
||||
remove_all_users_from_tenant(tenant_id)
|
||||
else:
|
||||
remove_users_from_tenant([user_to_delete.email], tenant_id)
|
||||
@@ -1,39 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
|
||||
from ee.onyx.server.tenants.models import TenantByDomainResponse
|
||||
from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
FORBIDDEN_COMMON_EMAIL_SUBSTRINGS = [
|
||||
"gmail",
|
||||
"outlook",
|
||||
"yahoo",
|
||||
"hotmail",
|
||||
"icloud",
|
||||
"msn",
|
||||
"hotmail",
|
||||
"hotmail.co.uk",
|
||||
]
|
||||
|
||||
|
||||
@router.get("/existing-team-by-domain")
|
||||
def get_existing_tenant_by_domain(
|
||||
user: User | None = Depends(current_user),
|
||||
) -> TenantByDomainResponse | None:
|
||||
if not user:
|
||||
return None
|
||||
domain = user.email.split("@")[1]
|
||||
if any(substring in domain for substring in FORBIDDEN_COMMON_EMAIL_SUBSTRINGS):
|
||||
return None
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
return get_tenant_by_domain_from_control_plane(domain, tenant_id)
|
||||
@@ -1,90 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.tenants.models import ApproveUserRequest
|
||||
from ee.onyx.server.tenants.models import PendingUserSnapshot
|
||||
from ee.onyx.server.tenants.models import RequestInviteRequest
|
||||
from ee.onyx.server.tenants.user_mapping import accept_user_invite
|
||||
from ee.onyx.server.tenants.user_mapping import approve_user_invite
|
||||
from ee.onyx.server.tenants.user_mapping import deny_user_invite
|
||||
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
|
||||
from onyx.auth.invited_users import get_pending_users
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/users/invite/request")
|
||||
async def request_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="User not authenticated")
|
||||
try:
|
||||
invite_self_to_tenant(user.email, invite_request.tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to invite self to tenant {invite_request.tenant_id}: {e}"
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/users/pending")
|
||||
def list_pending_users(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> list[PendingUserSnapshot]:
|
||||
pending_emails = get_pending_users()
|
||||
return [PendingUserSnapshot(email=email) for email in pending_emails]
|
||||
|
||||
|
||||
@router.post("/users/invite/approve")
|
||||
async def approve_user(
|
||||
approve_user_request: ApproveUserRequest,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
approve_user_invite(approve_user_request.email, tenant_id)
|
||||
|
||||
|
||||
@router.post("/users/invite/accept")
|
||||
async def accept_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
) -> None:
|
||||
"""
|
||||
Accept an invitation to join a tenant.
|
||||
"""
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
accept_user_invite(user.email, invite_request.tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to accept invite: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Failed to accept invitation")
|
||||
|
||||
|
||||
@router.post("/users/invite/deny")
|
||||
async def deny_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
) -> None:
|
||||
"""
|
||||
Deny an invitation to join a tenant.
|
||||
"""
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
deny_user_invite(user.email, invite_request.tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to deny invite: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Failed to deny invitation")
|
||||
@@ -1,63 +1,34 @@
|
||||
import logging
|
||||
|
||||
from fastapi_users import exceptions
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import get_pending_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
from onyx.auth.invited_users import write_pending_users
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.models import UserTenantMapping
|
||||
from onyx.server.manage.models import TenantSnapshot
|
||||
from onyx.setup import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_tenant_id_for_email(email: str) -> str:
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
# Implement logic to get tenant_id from the mapping table
|
||||
try:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# First try to get an active tenant
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping).where(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
mapping = result.scalar_one_or_none()
|
||||
tenant_id = mapping.tenant_id if mapping else None
|
||||
|
||||
# If no active tenant found, try to get the first inactive one
|
||||
if tenant_id is None:
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping).where(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
mapping = result.scalar_one_or_none()
|
||||
if mapping:
|
||||
# Mark this mapping as active
|
||||
mapping.active = True
|
||||
db_session.commit()
|
||||
tenant_id = mapping.tenant_id
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error getting tenant id for email {email}: {e}")
|
||||
raise exceptions.UserNotExists()
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
|
||||
)
|
||||
tenant_id = result.scalar_one_or_none()
|
||||
if tenant_id is None:
|
||||
raise exceptions.UserNotExists()
|
||||
return tenant_id
|
||||
|
||||
|
||||
def user_owns_a_tenant(email: str) -> bool:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
result = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(UserTenantMapping.email == email)
|
||||
@@ -67,43 +38,17 @@ def user_owns_a_tenant(email: str) -> bool:
|
||||
|
||||
|
||||
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
"""
|
||||
Add users to a tenant with proper transaction handling.
|
||||
Checks if users already have a tenant mapping to avoid duplicates.
|
||||
"""
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
# Start a transaction
|
||||
db_session.begin()
|
||||
|
||||
for email in emails:
|
||||
# Check if the user already has a mapping to this tenant
|
||||
existing_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_mapping:
|
||||
# Only add if mapping doesn't exist
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
|
||||
# Commit the transaction
|
||||
db_session.commit()
|
||||
logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
|
||||
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
except Exception:
|
||||
logger.exception(f"Failed to add users to tenant {tenant_id}")
|
||||
db_session.rollback()
|
||||
raise
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
mappings_to_delete = (
|
||||
db_session.query(UserTenantMapping)
|
||||
@@ -126,192 +71,8 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
|
||||
|
||||
def remove_all_users_from_tenant(tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
pending_users = get_pending_users()
|
||||
if email in pending_users:
|
||||
return
|
||||
write_pending_users(pending_users + [email])
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def approve_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Approve a user invite to a tenant.
|
||||
This will delete all existing records for this email and create a new mapping entry for the user in this tenant.
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Delete all existing records for this email
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.email == email
|
||||
).delete()
|
||||
|
||||
# Create a new mapping entry for the user in this tenant
|
||||
new_mapping = UserTenantMapping(email=email, tenant_id=tenant_id, active=True)
|
||||
db_session.add(new_mapping)
|
||||
db_session.commit()
|
||||
|
||||
# Also remove the user from pending users list
|
||||
# Remove from pending users
|
||||
pending_users = get_pending_users()
|
||||
if email in pending_users:
|
||||
pending_users.remove(email)
|
||||
write_pending_users(pending_users)
|
||||
|
||||
# Add to invited users
|
||||
invited_users = get_invited_users()
|
||||
if email not in invited_users:
|
||||
invited_users.append(email)
|
||||
write_invited_users(invited_users)
|
||||
|
||||
|
||||
def accept_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Accept an invitation to join a tenant.
|
||||
This activates the user's mapping to the tenant.
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
# First check if there's an active mapping for this user and tenant
|
||||
active_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# If an active mapping exists, delete it
|
||||
if active_mapping:
|
||||
db_session.delete(active_mapping)
|
||||
logger.info(
|
||||
f"Deleted existing active mapping for user {email} in tenant {tenant_id}"
|
||||
)
|
||||
|
||||
# Find the inactive mapping for this user and tenant
|
||||
mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
UserTenantMapping.active == False, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if mapping:
|
||||
# Set all other mappings for this user to inactive
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
).update({"active": False})
|
||||
|
||||
# Activate this mapping
|
||||
mapping.active = True
|
||||
db_session.commit()
|
||||
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"No invitation found for user {email} in tenant {tenant_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
logger.exception(
|
||||
f"Failed to accept invitation for user {email} to tenant {tenant_id}: {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def deny_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Deny an invitation to join a tenant.
|
||||
This removes the user's mapping to the tenant.
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Delete the mapping for this user and tenant
|
||||
result = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
UserTenantMapping.active == False, # noqa: E712
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
if result:
|
||||
logger.info(f"User {email} denied invitation to tenant {tenant_id}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"No invitation found for user {email} in tenant {tenant_id}"
|
||||
)
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
pending_users = get_invited_users()
|
||||
if email in pending_users:
|
||||
pending_users.remove(email)
|
||||
write_invited_users(pending_users)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def get_tenant_count(tenant_id: str) -> int:
|
||||
"""
|
||||
Get the number of active users for this tenant
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Count the number of active users for this tenant
|
||||
user_count = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
return user_count
|
||||
|
||||
|
||||
def get_tenant_invitation(email: str) -> TenantSnapshot | None:
|
||||
"""
|
||||
Get the first tenant invitation for this user
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Get the first tenant invitation for this user
|
||||
invitation = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == False, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if invitation:
|
||||
# Get the user count for this tenant
|
||||
user_count = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.tenant_id == invitation.tenant_id,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.count()
|
||||
)
|
||||
return TenantSnapshot(
|
||||
tenant_id=invitation.tenant_id, number_of_users=user_count
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@@ -6,7 +6,7 @@ MODEL_WARM_UP_STRING = "hi " * 512
|
||||
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
|
||||
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
|
||||
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
|
||||
DEFAULT_VERTEX_MODEL = "text-embedding-005"
|
||||
DEFAULT_VERTEX_MODEL = "text-embedding-004"
|
||||
|
||||
|
||||
class EmbeddingModelTextType:
|
||||
|
||||
@@ -5,7 +5,6 @@ from types import TracebackType
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
|
||||
import aioboto3 # type: ignore
|
||||
import httpx
|
||||
import openai
|
||||
import vertexai # type: ignore
|
||||
@@ -29,13 +28,11 @@ from model_server.constants import DEFAULT_VERTEX_MODEL
|
||||
from model_server.constants import DEFAULT_VOYAGE_MODEL
|
||||
from model_server.constants import EmbeddingModelTextType
|
||||
from model_server.constants import EmbeddingProvider
|
||||
from model_server.utils import pass_aws_key
|
||||
from model_server.utils import simple_log_function_time
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
|
||||
from shared_configs.configs import VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.enums import RerankerProvider
|
||||
from shared_configs.model_server_models import Embedding
|
||||
@@ -62,60 +59,6 @@ _OPENAI_MAX_INPUT_LEN = 2048
|
||||
# Cohere allows up to 96 embeddings in a single embedding calling
|
||||
_COHERE_MAX_INPUT_LEN = 96
|
||||
|
||||
# Authentication error string constants
|
||||
_AUTH_ERROR_401 = "401"
|
||||
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
|
||||
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
|
||||
_AUTH_ERROR_PERMISSION = "permission"
|
||||
|
||||
|
||||
def is_authentication_error(error: Exception) -> bool:
|
||||
"""Check if an exception is related to authentication issues.
|
||||
|
||||
Args:
|
||||
error: The exception to check
|
||||
|
||||
Returns:
|
||||
bool: True if the error appears to be authentication-related
|
||||
"""
|
||||
error_str = str(error).lower()
|
||||
return (
|
||||
_AUTH_ERROR_401 in error_str
|
||||
or _AUTH_ERROR_UNAUTHORIZED in error_str
|
||||
or _AUTH_ERROR_INVALID_API_KEY in error_str
|
||||
or _AUTH_ERROR_PERMISSION in error_str
|
||||
)
|
||||
|
||||
|
||||
def format_embedding_error(
|
||||
error: Exception,
|
||||
service_name: str,
|
||||
model: str | None,
|
||||
provider: EmbeddingProvider,
|
||||
status_code: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Format a standardized error string for embedding errors.
|
||||
"""
|
||||
detail = f"Status {status_code}" if status_code else f"{type(error)}"
|
||||
|
||||
return (
|
||||
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
|
||||
f"Model: {model} "
|
||||
f"Provider: {provider} "
|
||||
f"Exception: {error}"
|
||||
)
|
||||
|
||||
|
||||
# Custom exception for authentication errors
|
||||
class AuthenticationError(Exception):
|
||||
"""Raised when authentication fails with a provider."""
|
||||
|
||||
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
|
||||
self.provider = provider
|
||||
self.message = message
|
||||
super().__init__(f"{provider} authentication failed: {message}")
|
||||
|
||||
|
||||
class CloudEmbedding:
|
||||
def __init__(
|
||||
@@ -135,7 +78,7 @@ class CloudEmbedding:
|
||||
self._closed = False
|
||||
|
||||
async def _embed_openai(
|
||||
self, texts: list[str], model: str | None, reduced_dimension: int | None
|
||||
self, texts: list[str], model: str | None
|
||||
) -> list[Embedding]:
|
||||
if not model:
|
||||
model = DEFAULT_OPENAI_MODEL
|
||||
@@ -146,17 +89,22 @@ class CloudEmbedding:
|
||||
)
|
||||
|
||||
final_embeddings: list[Embedding] = []
|
||||
|
||||
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
|
||||
response = await client.embeddings.create(
|
||||
input=text_batch,
|
||||
model=model,
|
||||
dimensions=reduced_dimension or openai.NOT_GIVEN,
|
||||
try:
|
||||
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
|
||||
response = await client.embeddings.create(input=text_batch, model=model)
|
||||
final_embeddings.extend(
|
||||
[embedding.embedding for embedding in response.data]
|
||||
)
|
||||
return final_embeddings
|
||||
except Exception as e:
|
||||
error_string = (
|
||||
f"Error embedding text with OpenAI: {str(e)} \n"
|
||||
f"Model: {model} \n"
|
||||
f"Provider: {self.provider} \n"
|
||||
f"Texts: {texts}"
|
||||
)
|
||||
final_embeddings.extend(
|
||||
[embedding.embedding for embedding in response.data]
|
||||
)
|
||||
return final_embeddings
|
||||
logger.error(error_string)
|
||||
raise RuntimeError(error_string)
|
||||
|
||||
async def _embed_cohere(
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
@@ -195,6 +143,7 @@ class CloudEmbedding:
|
||||
input_type=embedding_type,
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
return response.embeddings
|
||||
|
||||
async def _embed_azure(
|
||||
@@ -224,24 +173,17 @@ class CloudEmbedding:
|
||||
vertexai.init(project=project_id, credentials=credentials)
|
||||
client = TextEmbeddingModel.from_pretrained(model)
|
||||
|
||||
inputs = [TextEmbeddingInput(text, embedding_type) for text in texts]
|
||||
|
||||
# Split into batches of 25 texts
|
||||
max_texts_per_batch = VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
|
||||
batches = [
|
||||
inputs[i : i + max_texts_per_batch]
|
||||
for i in range(0, len(inputs), max_texts_per_batch)
|
||||
]
|
||||
|
||||
# Dispatch all embedding calls asynchronously at once
|
||||
tasks = [
|
||||
client.get_embeddings_async(batch, auto_truncate=True) for batch in batches
|
||||
]
|
||||
|
||||
# Wait for all tasks to complete in parallel
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
return [embedding.values for batch in results for embedding in batch]
|
||||
embeddings = await client.get_embeddings_async(
|
||||
[
|
||||
TextEmbeddingInput(
|
||||
text,
|
||||
embedding_type,
|
||||
)
|
||||
for text in texts
|
||||
],
|
||||
auto_truncate=True, # This is the default
|
||||
)
|
||||
return [embedding.values for embedding in embeddings]
|
||||
|
||||
async def _embed_litellm_proxy(
|
||||
self, texts: list[str], model_name: str | None
|
||||
@@ -276,53 +218,23 @@ class CloudEmbedding:
|
||||
text_type: EmbedTextType,
|
||||
model_name: str | None = None,
|
||||
deployment_name: str | None = None,
|
||||
reduced_dimension: int | None = None,
|
||||
) -> list[Embedding]:
|
||||
try:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return await self._embed_openai(texts, model_name, reduced_dimension)
|
||||
elif self.provider == EmbeddingProvider.AZURE:
|
||||
return await self._embed_azure(texts, f"azure/{deployment_name}")
|
||||
elif self.provider == EmbeddingProvider.LITELLM:
|
||||
return await self._embed_litellm_proxy(texts, model_name)
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return await self._embed_openai(texts, model_name)
|
||||
elif self.provider == EmbeddingProvider.AZURE:
|
||||
return await self._embed_azure(texts, f"azure/{deployment_name}")
|
||||
elif self.provider == EmbeddingProvider.LITELLM:
|
||||
return await self._embed_litellm_proxy(texts, model_name)
|
||||
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return await self._embed_cohere(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return await self._embed_voyage(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return await self._embed_vertex(texts, model_name, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
except openai.AuthenticationError:
|
||||
raise AuthenticationError(provider="OpenAI")
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise AuthenticationError(provider=str(self.provider))
|
||||
|
||||
error_string = format_embedding_error(
|
||||
e,
|
||||
str(self.provider),
|
||||
model_name or deployment_name,
|
||||
self.provider,
|
||||
status_code=e.response.status_code,
|
||||
)
|
||||
logger.error(error_string)
|
||||
logger.debug(f"Exception texts: {texts}")
|
||||
|
||||
raise RuntimeError(error_string)
|
||||
except Exception as e:
|
||||
if is_authentication_error(e):
|
||||
raise AuthenticationError(provider=str(self.provider))
|
||||
|
||||
error_string = format_embedding_error(
|
||||
e, str(self.provider), model_name or deployment_name, self.provider
|
||||
)
|
||||
logger.error(error_string)
|
||||
logger.debug(f"Exception texts: {texts}")
|
||||
|
||||
raise RuntimeError(error_string)
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return await self._embed_cohere(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return await self._embed_voyage(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return await self._embed_vertex(texts, model_name, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
@@ -409,7 +321,6 @@ async def embed_text(
|
||||
prefix: str | None,
|
||||
api_url: str | None,
|
||||
api_version: str | None,
|
||||
reduced_dimension: int | None,
|
||||
gpu_type: str = "UNKNOWN",
|
||||
) -> list[Embedding]:
|
||||
if not all(texts):
|
||||
@@ -453,7 +364,6 @@ async def embed_text(
|
||||
model_name=model_name,
|
||||
deployment_name=deployment_name,
|
||||
text_type=text_type,
|
||||
reduced_dimension=reduced_dimension,
|
||||
)
|
||||
|
||||
if any(embedding is None for embedding in embeddings):
|
||||
@@ -525,7 +435,7 @@ async def local_rerank(query: str, docs: list[str], model_name: str) -> list[flo
|
||||
)
|
||||
|
||||
|
||||
async def cohere_rerank_api(
|
||||
async def cohere_rerank(
|
||||
query: str, docs: list[str], model_name: str, api_key: str
|
||||
) -> list[float]:
|
||||
cohere_client = CohereAsyncClient(api_key=api_key)
|
||||
@@ -535,45 +445,6 @@ async def cohere_rerank_api(
|
||||
return [result.relevance_score for result in sorted_results]
|
||||
|
||||
|
||||
async def cohere_rerank_aws(
|
||||
query: str,
|
||||
docs: list[str],
|
||||
model_name: str,
|
||||
region_name: str,
|
||||
aws_access_key_id: str,
|
||||
aws_secret_access_key: str,
|
||||
) -> list[float]:
|
||||
session = aioboto3.Session(
|
||||
aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key
|
||||
)
|
||||
async with session.client(
|
||||
"bedrock-runtime", region_name=region_name
|
||||
) as bedrock_client:
|
||||
body = json.dumps(
|
||||
{
|
||||
"query": query,
|
||||
"documents": docs,
|
||||
"api_version": 2,
|
||||
}
|
||||
)
|
||||
# Invoke the Bedrock model asynchronously
|
||||
response = await bedrock_client.invoke_model(
|
||||
modelId=model_name,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
body=body,
|
||||
)
|
||||
|
||||
# Read the response asynchronously
|
||||
response_body = json.loads(await response["body"].read())
|
||||
|
||||
# Extract and sort the results
|
||||
results = response_body.get("results", [])
|
||||
sorted_results = sorted(results, key=lambda item: item["index"])
|
||||
|
||||
return [result["relevance_score"] for result in sorted_results]
|
||||
|
||||
|
||||
async def litellm_rerank(
|
||||
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
|
||||
) -> list[float]:
|
||||
@@ -632,18 +503,10 @@ async def process_embed_request(
|
||||
text_type=embed_request.text_type,
|
||||
api_url=embed_request.api_url,
|
||||
api_version=embed_request.api_version,
|
||||
reduced_dimension=embed_request.reduced_dimension,
|
||||
prefix=prefix,
|
||||
gpu_type=gpu_type,
|
||||
)
|
||||
return EmbedResponse(embeddings=embeddings)
|
||||
except AuthenticationError as e:
|
||||
# Handle authentication errors consistently
|
||||
logger.error(f"Authentication error: {e.provider}")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Authentication failed: {e.message}",
|
||||
)
|
||||
except RateLimitError as e:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
@@ -696,32 +559,15 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
|
||||
elif rerank_request.provider_type == RerankerProvider.COHERE:
|
||||
if rerank_request.api_key is None:
|
||||
raise RuntimeError("Cohere Rerank Requires an API Key")
|
||||
sim_scores = await cohere_rerank_api(
|
||||
sim_scores = await cohere_rerank(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
api_key=rerank_request.api_key,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
|
||||
elif rerank_request.provider_type == RerankerProvider.BEDROCK:
|
||||
if rerank_request.api_key is None:
|
||||
raise RuntimeError("Bedrock Rerank Requires an API Key")
|
||||
aws_access_key_id, aws_secret_access_key, aws_region = pass_aws_key(
|
||||
rerank_request.api_key
|
||||
)
|
||||
sim_scores = await cohere_rerank_aws(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
region_name=aws_region,
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {rerank_request.provider_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error during reranking process:\n{str(e)}")
|
||||
raise HTTPException(
|
||||
|
||||
@@ -70,32 +70,3 @@ def get_gpu_type() -> str:
|
||||
return GPUStatus.MAC_MPS
|
||||
|
||||
return GPUStatus.NONE
|
||||
|
||||
|
||||
def pass_aws_key(api_key: str) -> tuple[str, str, str]:
|
||||
"""Parse AWS API key string into components.
|
||||
|
||||
Args:
|
||||
api_key: String in format 'aws_ACCESSKEY_SECRETKEY_REGION'
|
||||
|
||||
Returns:
|
||||
Tuple of (access_key, secret_key, region)
|
||||
|
||||
Raises:
|
||||
ValueError: If key format is invalid
|
||||
"""
|
||||
if not api_key.startswith("aws"):
|
||||
raise ValueError("API key must start with 'aws' prefix")
|
||||
|
||||
parts = api_key.split("_")
|
||||
if len(parts) != 4:
|
||||
raise ValueError(
|
||||
f"API key must be in format 'aws_ACCESSKEY_SECRETKEY_REGION', got {len(parts) - 1} parts"
|
||||
"this is an onyx specific format for formatting the aws secrets for bedrock"
|
||||
)
|
||||
|
||||
try:
|
||||
_, aws_access_key_id, aws_secret_access_key, aws_region = parts
|
||||
return aws_access_key_id, aws_secret_access_key, aws_region
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse AWS key components: {str(e)}")
|
||||
|
||||
@@ -5,14 +5,14 @@ from langgraph.graph import StateGraph
|
||||
from onyx.agents.agent_search.basic.states import BasicInput
|
||||
from onyx.agents.agent_search.basic.states import BasicOutput
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
|
||||
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
|
||||
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
|
||||
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
|
||||
prepare_tool_input,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -33,13 +33,13 @@ def basic_graph_builder() -> StateGraph:
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="choose_tool",
|
||||
action=choose_tool,
|
||||
node="llm_tool_choice",
|
||||
action=llm_tool_choice,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="call_tool",
|
||||
action=call_tool,
|
||||
node="tool_call",
|
||||
action=tool_call,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
@@ -51,12 +51,12 @@ def basic_graph_builder() -> StateGraph:
|
||||
|
||||
graph.add_edge(start_key=START, end_key="prepare_tool_input")
|
||||
|
||||
graph.add_edge(start_key="prepare_tool_input", end_key="choose_tool")
|
||||
graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
|
||||
|
||||
graph.add_conditional_edges("choose_tool", should_continue, ["call_tool", END])
|
||||
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
|
||||
|
||||
graph.add_edge(
|
||||
start_key="call_tool",
|
||||
start_key="tool_call",
|
||||
end_key="basic_use_tool_response",
|
||||
)
|
||||
|
||||
@@ -73,7 +73,7 @@ def should_continue(state: BasicState) -> str:
|
||||
# If there are no tool calls, basic graph already streamed the answer
|
||||
END
|
||||
if state.tool_choice is None
|
||||
else "call_tool"
|
||||
else "tool_call"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -31,15 +31,12 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -88,12 +85,9 @@ def check_sub_answer(
|
||||
agent_error: AgentErrorLog | None = None
|
||||
response: BaseMessage | None = None
|
||||
try:
|
||||
response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK,
|
||||
fast_llm.invoke,
|
||||
response = fast_llm.invoke(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
|
||||
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK,
|
||||
)
|
||||
|
||||
quality_str: str = cast(str, response.content)
|
||||
@@ -102,7 +96,7 @@ def check_sub_answer(
|
||||
)
|
||||
log_result = f"Answer quality: {quality_str}"
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import merge_message_runs
|
||||
@@ -46,14 +47,11 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -112,15 +110,15 @@ def generate_sub_answer(
|
||||
config=fast_llm.config,
|
||||
)
|
||||
|
||||
response: list[str | list[str | dict[str, Any]]] = []
|
||||
dispatch_timings: list[float] = []
|
||||
agent_error: AgentErrorLog | None = None
|
||||
response: list[str] = []
|
||||
|
||||
def stream_sub_answer() -> list[str]:
|
||||
agent_error: AgentErrorLog | None = None
|
||||
|
||||
try:
|
||||
for message in fast_llm.stream(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
@@ -144,15 +142,8 @@ def generate_sub_answer(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
response.append(content)
|
||||
return response
|
||||
|
||||
try:
|
||||
response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
|
||||
stream_sub_answer,
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
@@ -43,7 +44,6 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
@@ -60,16 +60,11 @@ from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
@@ -82,7 +77,6 @@ from onyx.prompts.agent_search import (
|
||||
)
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
@@ -155,9 +149,8 @@ def generate_initial_answer(
|
||||
)
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
|
||||
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
|
||||
get_final_context_sections=lambda: answer_generation_documents.context_documents,
|
||||
reranked_sections=answer_generation_documents.streaming_documents,
|
||||
final_context_sections=answer_generation_documents.context_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
@@ -237,11 +230,7 @@ def generate_initial_answer(
|
||||
|
||||
sub_questions = all_sub_questions # Replace the original assignment
|
||||
|
||||
model = (
|
||||
graph_config.tooling.fast_llm
|
||||
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
else graph_config.tooling.primary_llm
|
||||
)
|
||||
model = graph_config.tooling.fast_llm
|
||||
|
||||
doc_context = format_docs(answer_generation_documents.context_documents)
|
||||
doc_context = trim_prompt_piece(
|
||||
@@ -271,19 +260,15 @@ def generate_initial_answer(
|
||||
)
|
||||
]
|
||||
|
||||
streamed_tokens: list[str] = [""]
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
dispatch_timings: list[float] = []
|
||||
|
||||
agent_error: AgentErrorLog | None = None
|
||||
|
||||
def stream_initial_answer() -> list[str]:
|
||||
response: list[str] = []
|
||||
try:
|
||||
for message in model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
|
||||
if _should_restrict_tokens(model.config)
|
||||
else None,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
@@ -307,16 +292,9 @@ def generate_initial_answer(
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
response.append(content)
|
||||
return response
|
||||
streamed_tokens.append(content)
|
||||
|
||||
try:
|
||||
streamed_tokens = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
stream_initial_answer,
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
@@ -34,13 +34,9 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
@@ -51,7 +47,6 @@ from onyx.prompts.agent_search import (
|
||||
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -136,13 +131,10 @@ def decompose_orig_question(
|
||||
streamed_tokens: list[BaseMessage_Content] = []
|
||||
|
||||
try:
|
||||
streamed_tokens = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
|
||||
dispatch_separated,
|
||||
streamed_tokens = dispatch_separated(
|
||||
model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
|
||||
),
|
||||
dispatch_subquestion(0, writer),
|
||||
sep_callback=dispatch_subquestion_sep(0, writer),
|
||||
@@ -162,7 +154,7 @@ def decompose_orig_question(
|
||||
)
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError) as e:
|
||||
except LLMTimeoutError as e:
|
||||
logger.error("LLM Timeout Error - decompose orig question")
|
||||
raise e # fail loudly on this critical step
|
||||
except LLMRateLimitError as e:
|
||||
|
||||
@@ -25,7 +25,7 @@ logger = setup_logger()
|
||||
|
||||
def route_initial_tool_choice(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> Literal["call_tool", "start_agent_search", "logging_node"]:
|
||||
) -> Literal["tool_call", "start_agent_search", "logging_node"]:
|
||||
"""
|
||||
LangGraph edge to route to agent search.
|
||||
"""
|
||||
@@ -38,7 +38,7 @@ def route_initial_tool_choice(
|
||||
):
|
||||
return "start_agent_search"
|
||||
else:
|
||||
return "call_tool"
|
||||
return "tool_call"
|
||||
else:
|
||||
return "logging_node"
|
||||
|
||||
|
||||
@@ -43,14 +43,14 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import (
|
||||
answer_refined_query_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
|
||||
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
|
||||
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
|
||||
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
|
||||
prepare_tool_input,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -77,13 +77,13 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
# Choose the initial tool
|
||||
graph.add_node(
|
||||
node="initial_tool_choice",
|
||||
action=choose_tool,
|
||||
action=llm_tool_choice,
|
||||
)
|
||||
|
||||
# Call the tool, if required
|
||||
graph.add_node(
|
||||
node="call_tool",
|
||||
action=call_tool,
|
||||
node="tool_call",
|
||||
action=tool_call,
|
||||
)
|
||||
|
||||
# Use the tool response
|
||||
@@ -168,11 +168,11 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
graph.add_conditional_edges(
|
||||
"initial_tool_choice",
|
||||
route_initial_tool_choice,
|
||||
["call_tool", "start_agent_search", "logging_node"],
|
||||
["tool_call", "start_agent_search", "logging_node"],
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="call_tool",
|
||||
start_key="tool_call",
|
||||
end_key="basic_use_tool_response",
|
||||
)
|
||||
graph.add_edge(
|
||||
|
||||
@@ -33,16 +33,13 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -108,15 +105,11 @@ def compare_answers(
|
||||
refined_answer_improvement: bool | None = None
|
||||
# no need to stream this
|
||||
try:
|
||||
resp = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS,
|
||||
model.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
|
||||
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
|
||||
resp = model.invoke(
|
||||
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
@@ -43,12 +43,8 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
@@ -57,7 +53,6 @@ from onyx.prompts.agent_search import (
|
||||
)
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -139,18 +134,15 @@ def create_refined_sub_questions(
|
||||
agent_error: AgentErrorLog | None = None
|
||||
streamed_tokens: list[BaseMessage_Content] = []
|
||||
try:
|
||||
streamed_tokens = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
dispatch_separated,
|
||||
streamed_tokens = dispatch_separated(
|
||||
model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
),
|
||||
dispatch_subquestion(1, writer),
|
||||
sep_callback=dispatch_subquestion_sep(1, writer),
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
@@ -50,7 +50,13 @@ def decide_refinement_need(
|
||||
)
|
||||
]
|
||||
|
||||
return RequireRefinemenEvalUpdate(
|
||||
require_refined_answer_eval=graph_config.behavior.allow_refinement and decision,
|
||||
log_messages=log_messages,
|
||||
)
|
||||
if graph_config.behavior.allow_refinement:
|
||||
return RequireRefinemenEvalUpdate(
|
||||
require_refined_answer_eval=decision,
|
||||
log_messages=log_messages,
|
||||
)
|
||||
else:
|
||||
return RequireRefinemenEvalUpdate(
|
||||
require_refined_answer_eval=False,
|
||||
log_messages=log_messages,
|
||||
)
|
||||
|
||||
@@ -21,19 +21,12 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
|
||||
)
|
||||
from onyx.configs.constants import NUM_EXPLORATORY_DOCS
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
|
||||
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@@ -91,43 +84,30 @@ def extract_entities_terms(
|
||||
]
|
||||
fast_llm = graph_config.tooling.fast_llm
|
||||
# Grader
|
||||
llm_response = fast_llm.invoke(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
fast_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
max_tokens=AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION,
|
||||
entity_extraction_result = EntityExtractionResult.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
|
||||
try:
|
||||
entity_extraction_result = EntityExtractionResult.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
except ValueError:
|
||||
logger.error(
|
||||
"Failed to parse LLM response as JSON in Entity-Term Extraction"
|
||||
)
|
||||
entity_extraction_result = EntityExtractionResult(
|
||||
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
logger.error("LLM Timeout Error - extract entities terms")
|
||||
except ValueError:
|
||||
logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction")
|
||||
entity_extraction_result = EntityExtractionResult(
|
||||
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
|
||||
)
|
||||
|
||||
except LLMRateLimitError:
|
||||
logger.error("LLM Rate Limit Error - extract entities terms")
|
||||
entity_extraction_result = EntityExtractionResult(
|
||||
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
|
||||
retrieved_entities_relationships=EntityRelationshipTermExtraction(
|
||||
entities=[],
|
||||
relationships=[],
|
||||
terms=[],
|
||||
),
|
||||
)
|
||||
|
||||
return EntityTermExtractionUpdate(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
@@ -46,7 +47,6 @@ from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
@@ -66,23 +66,14 @@ from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
@@ -101,7 +92,6 @@ from onyx.prompts.agent_search import (
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -182,9 +172,8 @@ def generate_validate_refined_answer(
|
||||
)
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
|
||||
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
|
||||
get_final_context_sections=lambda: answer_generation_documents.context_documents,
|
||||
reranked_sections=answer_generation_documents.streaming_documents,
|
||||
final_context_sections=answer_generation_documents.context_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
@@ -264,12 +253,7 @@ def generate_validate_refined_answer(
|
||||
else REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS
|
||||
)
|
||||
|
||||
model = (
|
||||
graph_config.tooling.fast_llm
|
||||
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
else graph_config.tooling.primary_llm
|
||||
)
|
||||
|
||||
model = graph_config.tooling.fast_llm
|
||||
relevant_docs_str = format_docs(answer_generation_documents.context_documents)
|
||||
relevant_docs_str = trim_prompt_piece(
|
||||
model.config,
|
||||
@@ -300,17 +284,13 @@ def generate_validate_refined_answer(
|
||||
)
|
||||
]
|
||||
|
||||
streamed_tokens: list[str] = [""]
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
dispatch_timings: list[float] = []
|
||||
agent_error: AgentErrorLog | None = None
|
||||
|
||||
def stream_refined_answer() -> list[str]:
|
||||
try:
|
||||
for message in model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
|
||||
if _should_restrict_tokens(model.config)
|
||||
else None,
|
||||
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
@@ -335,15 +315,8 @@ def generate_validate_refined_answer(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
streamed_tokens.append(content)
|
||||
return streamed_tokens
|
||||
|
||||
try:
|
||||
streamed_tokens = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
|
||||
stream_refined_answer,
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
@@ -410,21 +383,16 @@ def generate_validate_refined_answer(
|
||||
)
|
||||
]
|
||||
|
||||
validation_model = graph_config.tooling.fast_llm
|
||||
try:
|
||||
validation_response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
|
||||
validation_model.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
|
||||
validation_response = model.invoke(
|
||||
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
|
||||
)
|
||||
refined_answer_quality = binary_string_test_after_answer_separator(
|
||||
text=cast(str, validation_response.content),
|
||||
positive_value=AGENT_POSITIVE_VALUE_STR,
|
||||
separator=AGENT_ANSWER_SEPARATOR,
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
refined_answer_quality = True
|
||||
logger.error("LLM Timeout Error - validate refined answer")
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -143,6 +144,8 @@ def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
|
||||
if result.query_info is not None:
|
||||
query_info = result.query_info
|
||||
break
|
||||
|
||||
assert query_info is not None, "must have query info"
|
||||
return query_info
|
||||
return query_info or SearchQueryInfo(
|
||||
predicted_search=None,
|
||||
final_filters=IndexFilters(access_control_list=None),
|
||||
recency_bias_multiplier=1.0,
|
||||
)
|
||||
|
||||
@@ -33,18 +33,15 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUERY_GENERATION
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
QUERY_REWRITING_PROMPT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -72,7 +69,7 @@ def expand_queries(
|
||||
node_start_time = datetime.now()
|
||||
question = state.question
|
||||
|
||||
model = graph_config.tooling.fast_llm
|
||||
llm = graph_config.tooling.fast_llm
|
||||
sub_question_id = state.sub_question_id
|
||||
if sub_question_id is None:
|
||||
level, question_num = 0, 0
|
||||
@@ -91,13 +88,10 @@ def expand_queries(
|
||||
rewritten_queries = []
|
||||
|
||||
try:
|
||||
llm_response_list = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION,
|
||||
dispatch_separated,
|
||||
model.stream(
|
||||
llm_response_list = dispatch_separated(
|
||||
llm.stream(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_SUBQUERY_GENERATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
|
||||
),
|
||||
dispatch_subquery(level, question_num, writer),
|
||||
)
|
||||
@@ -107,7 +101,7 @@ def expand_queries(
|
||||
rewritten_queries = llm_response.split("\n")
|
||||
log_result = f"Number of expanded queries: {len(rewritten_queries)}"
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
@@ -56,9 +56,8 @@ def format_results(
|
||||
relevance_list = relevance_from_docs(reranked_documents)
|
||||
for tool_response in yield_search_responses(
|
||||
query=state.question,
|
||||
get_retrieved_sections=lambda: reranked_documents,
|
||||
get_reranked_sections=lambda: state.retrieved_documents,
|
||||
get_final_context_sections=lambda: reranked_documents,
|
||||
reranked_sections=state.retrieved_documents,
|
||||
final_context_sections=reranked_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
|
||||
@@ -55,7 +55,6 @@ def rerank_documents(
|
||||
|
||||
# Note that these are passed in values from the API and are overrides which are typically None
|
||||
rerank_settings = graph_config.inputs.search_request.rerank_settings
|
||||
allow_agent_reranking = graph_config.behavior.allow_agent_reranking
|
||||
|
||||
if rerank_settings is None:
|
||||
with get_session_context_manager() as db_session:
|
||||
@@ -63,31 +62,23 @@ def rerank_documents(
|
||||
if not search_settings.disable_rerank_for_streaming:
|
||||
rerank_settings = RerankingDetails.from_db_model(search_settings)
|
||||
|
||||
# Initial default: no reranking. Will be overwritten below if reranking is warranted
|
||||
reranked_documents = verified_documents
|
||||
|
||||
if should_rerank(rerank_settings) and len(verified_documents) > 0:
|
||||
if len(verified_documents) > 1:
|
||||
if not allow_agent_reranking:
|
||||
logger.info("Use of local rerank model without GPU, skipping reranking")
|
||||
# No reranking, stay with verified_documents as default
|
||||
|
||||
else:
|
||||
# Reranking is warranted, use the rerank_sections functon
|
||||
reranked_documents = rerank_sections(
|
||||
query_str=question,
|
||||
# if runnable, then rerank_settings is not None
|
||||
rerank_settings=cast(RerankingDetails, rerank_settings),
|
||||
sections_to_rerank=verified_documents,
|
||||
)
|
||||
reranked_documents = rerank_sections(
|
||||
query_str=question,
|
||||
# if runnable, then rerank_settings is not None
|
||||
rerank_settings=cast(RerankingDetails, rerank_settings),
|
||||
sections_to_rerank=verified_documents,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"{len(verified_documents)} verified document(s) found, skipping reranking"
|
||||
)
|
||||
# No reranking, stay with verified_documents as default
|
||||
reranked_documents = verified_documents
|
||||
else:
|
||||
logger.warning("No reranking settings found, using unranked documents")
|
||||
# No reranking, stay with verified_documents as default
|
||||
reranked_documents = verified_documents
|
||||
|
||||
if AGENT_RERANKING_STATS:
|
||||
fit_scores = get_fit_scores(verified_documents, reranked_documents)
|
||||
else:
|
||||
|
||||
@@ -91,7 +91,7 @@ def retrieve_documents(
|
||||
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
|
||||
|
||||
if AGENT_RETRIEVAL_STATS:
|
||||
pre_rerank_docs = callback_container[0] if callback_container else []
|
||||
pre_rerank_docs = callback_container[0]
|
||||
fit_scores = get_fit_scores(
|
||||
pre_rerank_docs,
|
||||
retrieved_docs,
|
||||
|
||||
@@ -25,16 +25,13 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
DOCUMENT_VERIFICATION_PROMPT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -89,12 +86,8 @@ def verify_documents(
|
||||
] # default is to treat document as relevant
|
||||
|
||||
try:
|
||||
response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION,
|
||||
fast_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
|
||||
response = fast_llm.invoke(
|
||||
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
|
||||
)
|
||||
|
||||
assert isinstance(response.content, str)
|
||||
@@ -103,7 +96,7 @@ def verify_documents(
|
||||
):
|
||||
verified_documents = []
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
# In this case, we decide to continue and don't raise an error, as
|
||||
# little harm in letting some docs through that are less relevant.
|
||||
logger.error("LLM Timeout Error - verify documents")
|
||||
|
||||
@@ -67,7 +67,6 @@ class GraphSearchConfig(BaseModel):
|
||||
# Whether to allow creation of refinement questions (and entity extraction, etc.)
|
||||
allow_refinement: bool = True
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
allow_agent_reranking: bool = False
|
||||
|
||||
|
||||
class GraphConfig(BaseModel):
|
||||
|
||||
@@ -9,23 +9,18 @@ from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import OnyxContexts
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_utils import (
|
||||
context_from_inference_section,
|
||||
SEARCH_DOC_CONTENT_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search_like_tool_utils import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def basic_use_tool_response(
|
||||
state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> BasicOutput:
|
||||
@@ -55,13 +50,11 @@ def basic_use_tool_response(
|
||||
for yield_item in tool_call_responses:
|
||||
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_search_results = cast(list[LlmDoc], yield_item.response)
|
||||
elif yield_item.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
search_response_summary = cast(SearchResponseSummary, yield_item.response)
|
||||
for section in search_response_summary.top_sections:
|
||||
if section.center_chunk.document_id not in initial_search_results:
|
||||
initial_search_results.append(
|
||||
context_from_inference_section(section)
|
||||
)
|
||||
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
|
||||
search_contexts = cast(OnyxContexts, yield_item.response).contexts
|
||||
for doc in search_contexts:
|
||||
if doc.document_id not in initial_search_results:
|
||||
initial_search_results.append(doc)
|
||||
|
||||
new_tool_call_chunk = AIMessageChunk(content="")
|
||||
if not agent_config.behavior.skip_gen_ai_answer_generation:
|
||||
@@ -15,17 +15,8 @@ from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
|
||||
from onyx.chat.tool_handling.tool_response_handler import (
|
||||
get_tool_call_for_non_tool_calling_llm_impl,
|
||||
)
|
||||
from onyx.context.search.preprocessing.preprocessing import query_analysis
|
||||
from onyx.context.search.retrieval.search_runner import get_query_embedding
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import TimeoutThread
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -34,8 +25,7 @@ logger = setup_logger()
|
||||
# and a function that handles extracting the necessary fields
|
||||
# from the state and config
|
||||
# TODO: fan-out to multiple tool call nodes? Make this configurable?
|
||||
@log_function_time(print_only=True)
|
||||
def choose_tool(
|
||||
def llm_tool_choice(
|
||||
state: ToolChoiceState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
@@ -47,31 +37,6 @@ def choose_tool(
|
||||
should_stream_answer = state.should_stream_answer
|
||||
|
||||
agent_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
|
||||
force_use_tool = agent_config.tooling.force_use_tool
|
||||
|
||||
embedding_thread: TimeoutThread[Embedding] | None = None
|
||||
keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
|
||||
override_kwargs: SearchToolOverrideKwargs | None = None
|
||||
if (
|
||||
not agent_config.behavior.use_agentic_search
|
||||
and agent_config.tooling.search_tool is not None
|
||||
and (
|
||||
not force_use_tool.force_use or force_use_tool.tool_name == SearchTool.name
|
||||
)
|
||||
):
|
||||
override_kwargs = SearchToolOverrideKwargs()
|
||||
# Run in a background thread to avoid blocking the main thread
|
||||
embedding_thread = run_in_background(
|
||||
get_query_embedding,
|
||||
agent_config.inputs.search_request.query,
|
||||
agent_config.persistence.db_session,
|
||||
)
|
||||
keyword_thread = run_in_background(
|
||||
query_analysis,
|
||||
agent_config.inputs.search_request.query,
|
||||
)
|
||||
|
||||
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
|
||||
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
|
||||
|
||||
@@ -82,6 +47,7 @@ def choose_tool(
|
||||
tools = [
|
||||
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
|
||||
]
|
||||
force_use_tool = agent_config.tooling.force_use_tool
|
||||
|
||||
tool, tool_args = None, None
|
||||
if force_use_tool.force_use and force_use_tool.args is not None:
|
||||
@@ -105,22 +71,11 @@ def choose_tool(
|
||||
# If we have a tool and tool args, we are ready to request a tool call.
|
||||
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
|
||||
if tool and tool_args:
|
||||
if embedding_thread and tool.name == SearchTool._NAME:
|
||||
# Wait for the embedding thread to finish
|
||||
embedding = wait_on_background(embedding_thread)
|
||||
assert override_kwargs is not None, "must have override kwargs"
|
||||
override_kwargs.precomputed_query_embedding = embedding
|
||||
if keyword_thread and tool.name == SearchTool._NAME:
|
||||
is_keyword, keywords = wait_on_background(keyword_thread)
|
||||
assert override_kwargs is not None, "must have override kwargs"
|
||||
override_kwargs.precomputed_is_keyword = is_keyword
|
||||
override_kwargs.precomputed_keywords = keywords
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=ToolChoice(
|
||||
tool=tool,
|
||||
tool_args=tool_args,
|
||||
id=str(uuid4()),
|
||||
search_tool_override_kwargs=override_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -143,16 +98,8 @@ def choose_tool(
|
||||
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
|
||||
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
|
||||
prompt=built_prompt,
|
||||
tools=(
|
||||
[tool.tool_definition() for tool in tools] or None
|
||||
if using_tool_calling_llm
|
||||
else None
|
||||
),
|
||||
tool_choice=(
|
||||
"required"
|
||||
if tools and force_use_tool.force_use and using_tool_calling_llm
|
||||
else None
|
||||
),
|
||||
tools=[tool.tool_definition() for tool in tools] or None,
|
||||
tool_choice=("required" if tools and force_use_tool.force_use else None),
|
||||
structured_response_format=structured_response_format,
|
||||
)
|
||||
|
||||
@@ -198,22 +145,10 @@ def choose_tool(
|
||||
logger.debug(f"Selected tool: {selected_tool.name}")
|
||||
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
|
||||
|
||||
if embedding_thread and selected_tool.name == SearchTool._NAME:
|
||||
# Wait for the embedding thread to finish
|
||||
embedding = wait_on_background(embedding_thread)
|
||||
assert override_kwargs is not None, "must have override kwargs"
|
||||
override_kwargs.precomputed_query_embedding = embedding
|
||||
if keyword_thread and selected_tool.name == SearchTool._NAME:
|
||||
is_keyword, keywords = wait_on_background(keyword_thread)
|
||||
assert override_kwargs is not None, "must have override kwargs"
|
||||
override_kwargs.precomputed_is_keyword = is_keyword
|
||||
override_kwargs.precomputed_keywords = keywords
|
||||
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=ToolChoice(
|
||||
tool=selected_tool,
|
||||
tool_args=selected_tool_call_request["args"],
|
||||
id=selected_tool_call_request["id"],
|
||||
search_tool_override_kwargs=override_kwargs,
|
||||
),
|
||||
)
|
||||
@@ -28,7 +28,7 @@ def emit_packet(packet: AnswerPacket, writer: StreamWriter) -> None:
|
||||
write_custom_event("basic_response", packet, writer)
|
||||
|
||||
|
||||
def call_tool(
|
||||
def tool_call(
|
||||
state: ToolChoiceUpdate,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
@@ -44,9 +44,7 @@ def call_tool(
|
||||
tool = tool_choice.tool
|
||||
tool_args = tool_choice.tool_args
|
||||
tool_id = tool_choice.id
|
||||
tool_runner = ToolRunner(
|
||||
tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
|
||||
)
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
tool_kickoff = tool_runner.kickoff()
|
||||
|
||||
emit_packet(tool_kickoff, writer)
|
||||
@@ -2,7 +2,6 @@ from pydantic import BaseModel
|
||||
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.models import ToolCallFinalResult
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
@@ -36,7 +35,6 @@ class ToolChoice(BaseModel):
|
||||
tool: Tool
|
||||
tool_args: dict
|
||||
id: str | None
|
||||
search_tool_override_kwargs: SearchToolOverrideKwargs | None = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -13,11 +13,6 @@ AGENT_NEGATIVE_VALUE_STR = "no"
|
||||
AGENT_ANSWER_SEPARATOR = "Answer:"
|
||||
|
||||
|
||||
EMBEDDING_KEY = "embedding"
|
||||
IS_KEYWORD_KEY = "is_keyword"
|
||||
KEYWORDS_KEY = "keywords"
|
||||
|
||||
|
||||
class AgentLLMErrorType(str, Enum):
|
||||
TIMEOUT = "timeout"
|
||||
RATE_LIMIT = "rate_limit"
|
||||
|
||||
@@ -42,11 +42,9 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_HISTORY_SUMMARY
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
@@ -62,7 +60,6 @@ from onyx.db.persona import Persona
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
from onyx.prompts.agent_search import (
|
||||
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
|
||||
)
|
||||
@@ -83,7 +80,6 @@ from onyx.tools.tool_implementations.search.search_tool import SearchResponseSum
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -399,14 +395,11 @@ def summarize_history(
|
||||
)
|
||||
|
||||
try:
|
||||
history_response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
llm.invoke,
|
||||
history_response = llm.invoke(
|
||||
history_context_prompt,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_HISTORY_SUMMARY,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
logger.error("LLM Timeout Error - summarize history")
|
||||
return (
|
||||
history # this is what is done at this point anyway, so we default to this
|
||||
@@ -508,9 +501,3 @@ def get_deduplicated_structured_subquestion_documents(
|
||||
cited_documents=dedup_inference_section_list(cited_docs),
|
||||
context_documents=dedup_inference_section_list(context_docs),
|
||||
)
|
||||
|
||||
|
||||
def _should_restrict_tokens(llm_config: LLMConfig) -> bool:
|
||||
return not (
|
||||
llm_config.model_provider == "openai" and llm_config.model_name.startswith("o")
|
||||
)
|
||||
|
||||
@@ -10,7 +10,6 @@ from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import API_KEY_HASH_ROUNDS
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
_API_KEY_HEADER_NAME = "Authorization"
|
||||
@@ -36,7 +35,8 @@ class ApiKeyDescriptor(BaseModel):
|
||||
|
||||
|
||||
def generate_api_key(tenant_id: str | None = None) -> str:
|
||||
if not MULTI_TENANT or not tenant_id:
|
||||
# For backwards compatibility, if no tenant_id, generate old style key
|
||||
if not tenant_id:
|
||||
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
|
||||
|
||||
encoded_tenant = quote(tenant_id) # URL encode the tenant ID
|
||||
|
||||
@@ -2,8 +2,6 @@ import smtplib
|
||||
from datetime import datetime
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from email.utils import formatdate
|
||||
from email.utils import make_msgid
|
||||
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import EMAIL_FROM
|
||||
@@ -12,10 +10,8 @@ from onyx.configs.app_configs import SMTP_PORT
|
||||
from onyx.configs.app_configs import SMTP_SERVER
|
||||
from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
|
||||
from onyx.db.models import User
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
HTML_EMAIL_TEMPLATE = """\
|
||||
<!DOCTYPE html>
|
||||
@@ -155,8 +151,6 @@ def send_email(
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
msg["Date"] = formatdate(localtime=True)
|
||||
msg["Message-ID"] = make_msgid(domain="onyx.app")
|
||||
|
||||
part_text = MIMEText(text_body, "plain")
|
||||
part_html = MIMEText(html_body, "html")
|
||||
@@ -178,7 +172,7 @@ def send_subscription_cancellation_email(user_email: str) -> None:
|
||||
subject = "Your Onyx Subscription Has Been Canceled"
|
||||
heading = "Subscription Canceled"
|
||||
message = (
|
||||
"<p>We're sorry to see you go.</p>"
|
||||
"<p>We’re sorry to see you go.</p>"
|
||||
"<p>Your subscription has been canceled and will end on your next billing date.</p>"
|
||||
"<p>If you change your mind, you can always come back!</p>"
|
||||
)
|
||||
@@ -193,64 +187,36 @@ def send_subscription_cancellation_email(user_email: str) -> None:
|
||||
send_email(user_email, subject, html_content, text_content)
|
||||
|
||||
|
||||
def send_user_email_invite(
|
||||
user_email: str, current_user: User, auth_type: AuthType
|
||||
) -> None:
|
||||
def send_user_email_invite(user_email: str, current_user: User) -> None:
|
||||
subject = "Invitation to Join Onyx Organization"
|
||||
heading = "You've Been Invited!"
|
||||
|
||||
# the exact action taken by the user, and thus the message, depends on the auth type
|
||||
message = f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
|
||||
if auth_type == AuthType.CLOUD:
|
||||
message += (
|
||||
"<p>To join the organization, please click the button below to set a password "
|
||||
"or login with Google and complete your registration.</p>"
|
||||
)
|
||||
elif auth_type == AuthType.BASIC:
|
||||
message += (
|
||||
"<p>To join the organization, please click the button below to set a password "
|
||||
"and complete your registration.</p>"
|
||||
)
|
||||
elif auth_type == AuthType.GOOGLE_OAUTH:
|
||||
message += (
|
||||
"<p>To join the organization, please click the button below to login with Google "
|
||||
"and complete your registration.</p>"
|
||||
)
|
||||
elif auth_type == AuthType.OIDC or auth_type == AuthType.SAML:
|
||||
message += (
|
||||
"<p>To join the organization, please click the button below to"
|
||||
" complete your registration.</p>"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid auth type: {auth_type}")
|
||||
|
||||
message = (
|
||||
f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
|
||||
"<p>To join the organization, please click the button below to set a password "
|
||||
"or login with Google and complete your registration.</p>"
|
||||
)
|
||||
cta_text = "Join Organization"
|
||||
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
|
||||
html_content = build_html_email(heading, message, cta_text, cta_link)
|
||||
|
||||
# text content is the fallback for clients that don't support HTML
|
||||
# not as critical, so not having special cases for each auth type
|
||||
text_content = (
|
||||
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
|
||||
"To join the organization, please visit the following link:\n"
|
||||
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
|
||||
"You'll be asked to set a password or login with Google to complete your registration."
|
||||
)
|
||||
if auth_type == AuthType.CLOUD:
|
||||
text_content += "You'll be asked to set a password or login with Google to complete your registration."
|
||||
|
||||
send_email(user_email, subject, html_content, text_content)
|
||||
|
||||
|
||||
def send_forgot_password_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
tenant_id: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
tenant_id: str | None = None,
|
||||
) -> None:
|
||||
# Builds a forgot password email with or without fancy HTML
|
||||
subject = "Onyx Forgot Password"
|
||||
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
|
||||
if MULTI_TENANT:
|
||||
if tenant_id:
|
||||
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
|
||||
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
|
||||
html_content = build_html_email("Reset Your Password", message)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import cast
|
||||
|
||||
from onyx.configs.constants import KV_PENDING_USERS_KEY
|
||||
from onyx.configs.constants import KV_USER_STORE_KEY
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
@@ -19,17 +18,3 @@ def write_invited_users(emails: list[str]) -> int:
|
||||
store = get_kv_store()
|
||||
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
|
||||
return len(emails)
|
||||
|
||||
|
||||
def get_pending_users() -> list[str]:
|
||||
try:
|
||||
store = get_kv_store()
|
||||
return cast(list, store.load(KV_PENDING_USERS_KEY))
|
||||
except KvKeyNotFoundError:
|
||||
return list()
|
||||
|
||||
|
||||
def write_pending_users(emails: list[str]) -> int:
|
||||
store = get_kv_store()
|
||||
store.store(KV_PENDING_USERS_KEY, cast(JSON_ro, emails))
|
||||
return len(emails)
|
||||
|
||||
@@ -42,5 +42,4 @@ def fetch_no_auth_user(
|
||||
role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
|
||||
preferences=load_no_auth_user_preferences(store),
|
||||
is_anonymous_user=anonymous_user_enabled,
|
||||
password_configured=False,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import json
|
||||
import random
|
||||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
@@ -88,6 +86,7 @@ from onyx.db.auth import get_user_db
|
||||
from onyx.db.auth import SQLAlchemyUserAdminDB
|
||||
from onyx.db.engine import get_async_session
|
||||
from onyx.db.engine import get_async_session_with_tenant
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
@@ -95,22 +94,24 @@ from onyx.db.models import User
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.url import add_url_params
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import async_return_default_schema
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
@@ -142,30 +143,6 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
|
||||
return email or ""
|
||||
|
||||
|
||||
def generate_password() -> str:
|
||||
lowercase_letters = string.ascii_lowercase
|
||||
uppercase_letters = string.ascii_uppercase
|
||||
digits = string.digits
|
||||
special_characters = string.punctuation
|
||||
|
||||
# Ensure at least one of each required character type
|
||||
password = [
|
||||
secrets.choice(uppercase_letters),
|
||||
secrets.choice(digits),
|
||||
secrets.choice(special_characters),
|
||||
]
|
||||
|
||||
# Fill the rest with a mix of characters
|
||||
remaining_length = 12 - len(password)
|
||||
all_characters = lowercase_letters + uppercase_letters + digits + special_characters
|
||||
password.extend(secrets.choice(all_characters) for _ in range(remaining_length))
|
||||
|
||||
# Shuffle the password to randomize the position of the required characters
|
||||
random.shuffle(password)
|
||||
|
||||
return "".join(password)
|
||||
|
||||
|
||||
def user_needs_to_be_verified() -> bool:
|
||||
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
|
||||
return REQUIRE_EMAIL_VERIFICATION
|
||||
@@ -215,8 +192,8 @@ def verify_email_is_invited(email: str) -> None:
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
if not get_user_by_email(email, db_session):
|
||||
verify_email_is_invited(email)
|
||||
|
||||
@@ -412,7 +389,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
user: User | None = None
|
||||
user: User
|
||||
|
||||
try:
|
||||
# Attempt to get user by OAuth account
|
||||
@@ -421,20 +398,15 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
except exceptions.UserNotExists:
|
||||
try:
|
||||
# Attempt to get user by email
|
||||
user = await self.user_db.get_by_email(account_email)
|
||||
user = await self.get_by_email(account_email)
|
||||
if not associate_by_email:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
# Make sure user is not None before adding OAuth account
|
||||
if user is not None:
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
else:
|
||||
# This shouldn't happen since get_by_email would raise UserNotExists
|
||||
# but adding as a safeguard
|
||||
raise exceptions.UserNotExists()
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
|
||||
# If user not found by OAuth account or email, create a new user
|
||||
except exceptions.UserNotExists:
|
||||
password = self.password_helper.generate()
|
||||
user_dict = {
|
||||
@@ -445,36 +417,26 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
user = await self.user_db.create(user_dict)
|
||||
|
||||
# Add OAuth account only if user creation was successful
|
||||
if user is not None:
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
await self.on_after_register(user, request)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to create user account"
|
||||
)
|
||||
# Explicitly set the Postgres schema for this session to ensure
|
||||
# OAuth account creation happens in the correct tenant schema
|
||||
|
||||
# Add OAuth account
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
await self.on_after_register(user, request)
|
||||
|
||||
else:
|
||||
# User exists, update OAuth account if needed
|
||||
if user is not None: # Add explicit check
|
||||
for existing_oauth_account in user.oauth_accounts:
|
||||
if (
|
||||
existing_oauth_account.account_id == account_id
|
||||
and existing_oauth_account.oauth_name == oauth_name
|
||||
):
|
||||
user = await self.user_db.update_oauth_account(
|
||||
user,
|
||||
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
|
||||
# but the type checker doesn't know that :(
|
||||
existing_oauth_account, # type: ignore
|
||||
oauth_account_dict,
|
||||
)
|
||||
|
||||
# Ensure user is not None before proceeding
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to authenticate or create user"
|
||||
)
|
||||
for existing_oauth_account in user.oauth_accounts:
|
||||
if (
|
||||
existing_oauth_account.account_id == account_id
|
||||
and existing_oauth_account.oauth_name == oauth_name
|
||||
):
|
||||
user = await self.user_db.update_oauth_account(
|
||||
user,
|
||||
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
|
||||
# but the type checker doesn't know that :(
|
||||
existing_oauth_account, # type: ignore
|
||||
oauth_account_dict,
|
||||
)
|
||||
|
||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||
# re-authenticate that frequently, so by default this is disabled
|
||||
@@ -524,7 +486,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
user_count = await get_user_count()
|
||||
logger.debug(f"Current tenant user count: {user_count}")
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
if user_count == 1:
|
||||
@@ -546,7 +507,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
logger.debug(f"User {user.id} has registered.")
|
||||
logger.notice(f"User {user.id} has registered.")
|
||||
optional_telemetry(
|
||||
record_type=RecordType.SIGN_UP,
|
||||
data={"action": "create"},
|
||||
@@ -570,7 +531,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async_return_default_schema,
|
||||
)(email=user.email)
|
||||
|
||||
send_forgot_password_email(user.email, tenant_id=tenant_id, token=token)
|
||||
send_forgot_password_email(user.email, token, tenant_id=tenant_id)
|
||||
|
||||
async def on_after_request_verify(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
@@ -588,20 +549,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
) -> Optional[User]:
|
||||
email = credentials.username
|
||||
|
||||
tenant_id: str | None = None
|
||||
try:
|
||||
tenant_id = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_tenant_id_for_email",
|
||||
None,
|
||||
)(
|
||||
email=email,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"User attempted to login with invalid credentials: {str(e)}"
|
||||
)
|
||||
|
||||
# Get tenant_id from mapping table
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=email,
|
||||
)
|
||||
if not tenant_id:
|
||||
# User not found in mapping
|
||||
self.password_helper.hash(credentials.password)
|
||||
@@ -640,39 +595,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
return user
|
||||
|
||||
async def reset_password_as_admin(self, user_id: uuid.UUID) -> str:
|
||||
"""Admin-only. Generate a random password for a user and return it."""
|
||||
user = await self.get(user_id)
|
||||
new_password = generate_password()
|
||||
await self._update(user, {"password": new_password})
|
||||
return new_password
|
||||
|
||||
async def change_password_if_old_matches(
|
||||
self, user: User, old_password: str, new_password: str
|
||||
) -> None:
|
||||
"""
|
||||
For normal users to change password if they know the old one.
|
||||
Raises 400 if old password doesn't match.
|
||||
"""
|
||||
verified, updated_password_hash = self.password_helper.verify_and_update(
|
||||
old_password, user.hashed_password
|
||||
)
|
||||
if not verified:
|
||||
# Raise some HTTPException (or your custom exception) if old password is invalid:
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid current password",
|
||||
)
|
||||
|
||||
# If the hash was upgraded behind the scenes, we can keep it before setting the new password:
|
||||
if updated_password_hash:
|
||||
user.hashed_password = updated_password_hash
|
||||
|
||||
# Now apply and validate the new password
|
||||
await self._update(user, {"password": new_password})
|
||||
|
||||
|
||||
async def get_user_manager(
|
||||
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
|
||||
@@ -895,11 +817,10 @@ async def current_limited_user(
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_chat_accessible_user(
|
||||
async def current_chat_accesssible_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> User | None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
return await double_check_user(
|
||||
user, allow_anonymous_access=anonymous_user_enabled(tenant_id=tenant_id)
|
||||
)
|
||||
@@ -1096,12 +1017,6 @@ def get_oauth_router(
|
||||
|
||||
next_url = state_data.get("next_url", "/")
|
||||
referral_source = state_data.get("referral_source", None)
|
||||
try:
|
||||
tenant_id = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
|
||||
)(account_email)
|
||||
except exceptions.UserNotExists:
|
||||
tenant_id = None
|
||||
|
||||
request.state.referral_source = referral_source
|
||||
|
||||
@@ -1133,14 +1048,9 @@ def get_oauth_router(
|
||||
# Login user
|
||||
response = await backend.login(strategy, user)
|
||||
await user_manager.on_after_login(user, request, response)
|
||||
|
||||
# Prepare redirect response
|
||||
if tenant_id is None:
|
||||
# Use URL utility to add parameters
|
||||
redirect_url = add_url_params(next_url, {"new_team": "true"})
|
||||
redirect_response = RedirectResponse(redirect_url, status_code=302)
|
||||
else:
|
||||
# No parameters to add
|
||||
redirect_response = RedirectResponse(next_url, status_code=302)
|
||||
redirect_response = RedirectResponse(next_url, status_code=302)
|
||||
|
||||
# Copy headers and other attributes from 'response' to 'redirect_response'
|
||||
for header_name, header_value in response.headers.items():
|
||||
@@ -1152,7 +1062,6 @@ def get_oauth_router(
|
||||
redirect_response.status_code = response.status_code
|
||||
if hasattr(response, "media_type"):
|
||||
redirect_response.media_type = response.media_type
|
||||
|
||||
return redirect_response
|
||||
|
||||
return router
|
||||
|
||||
@@ -2,7 +2,6 @@ import logging
|
||||
import multiprocessing
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import sentry_sdk
|
||||
from celery import Task
|
||||
@@ -34,7 +33,6 @@ from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGrou
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import ColoredFormatter
|
||||
from onyx.utils.logger import PlainFormatter
|
||||
@@ -60,35 +58,13 @@ else:
|
||||
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
|
||||
|
||||
|
||||
class TenantAwareTask(Task):
|
||||
"""A custom base Task that sets tenant_id in a contextvar before running."""
|
||||
|
||||
abstract = True # So Celery knows not to register this as a real task.
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
# Grab tenant_id from the kwargs, or fallback to default if missing.
|
||||
tenant_id = kwargs.get("tenant_id", None) or POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
# Set the context var
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Actually run the task now
|
||||
try:
|
||||
return super().__call__(*args, **kwargs)
|
||||
finally:
|
||||
# Clear or reset after the task runs
|
||||
# so it does not leak into any subsequent tasks on the same worker process
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
|
||||
|
||||
|
||||
@task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple[Any, ...] | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
**other_kwargs: Any,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@@ -132,9 +108,9 @@ def on_task_postrun(
|
||||
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
|
||||
if not kwargs:
|
||||
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
tenant_id = None
|
||||
else:
|
||||
tenant_id = cast(str, kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA))
|
||||
tenant_id = kwargs.get("tenant_id")
|
||||
|
||||
task_logger.debug(
|
||||
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
|
||||
@@ -225,7 +201,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout
|
||||
is reached."""
|
||||
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
@@ -311,7 +287,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
# Set up variables for waiting on primary worker
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=None)
|
||||
time_start = time.monotonic()
|
||||
|
||||
logger.info("Waiting for primary worker to be ready...")
|
||||
@@ -463,6 +439,24 @@ class TenantContextFilter(logging.Filter):
|
||||
return True
|
||||
|
||||
|
||||
@task_prerun.connect
|
||||
def set_tenant_id(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple[Any, ...] | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
**other_kwargs: Any,
|
||||
) -> None:
|
||||
"""Signal handler to set tenant ID in context var before task starts."""
|
||||
tenant_id = (
|
||||
kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
if kwargs
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
|
||||
@task_postrun.connect
|
||||
def reset_tenant_id(
|
||||
sender: Any | None = None,
|
||||
|
||||
@@ -132,7 +132,6 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
f"Adding options to task {tenant_task_name}: {options}"
|
||||
)
|
||||
tenant_task["options"] = options
|
||||
|
||||
new_schedule[tenant_task_name] = tenant_task
|
||||
|
||||
return new_schedule
|
||||
@@ -257,4 +256,3 @@ def on_setup_logging(
|
||||
|
||||
|
||||
celery_app.conf.beat_scheduler = DynamicTenantScheduler
|
||||
celery_app.conf.task_default_base = app_base.TenantAwareTask
|
||||
|
||||
@@ -20,7 +20,6 @@ logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.heavy")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user