Compare commits

..

3 Commits

Author SHA1 Message Date
pablodanswer
09e6bd3c9c k 2024-12-18 20:01:44 -08:00
pablodanswer
c1803cdd56 log 2024-12-18 19:20:55 -08:00
pablodanswer
a5b9c76012 validation 2024-12-18 19:13:09 -08:00
919 changed files with 19103 additions and 62680 deletions

View File

@@ -1,14 +1,29 @@
## Description
[Provide a brief description of the changes in this PR]
## How Has This Been Tested?
## How Has This Been Tested?
[Describe the tests you ran to verify your changes]
## Accepted Risk (provide if relevant)
N/A
## Related Issue(s) (provide if relevant)
N/A
## Mental Checklist:
- All of the automated tests pass
- All PR comments are addressed and marked resolved
- If there are migrations, they have been rebased to latest main
- If there are new dependencies, they are added to the requirements
- If there are new environment variables, they are added to all of the deployment methods
- If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- Docker images build and basic functionalities work
- Author has done a final read through of the PR right before merge
## Backporting (check the box to trigger backport action)
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
- [ ] [Optional] Override Linear Check

View File

@@ -66,9 +66,6 @@ jobs:
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -4,9 +4,6 @@ on:
push:
tags:
- "*"
paths:
- 'backend/model_server/**'
- 'backend/Dockerfile.model_server'
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
@@ -121,6 +118,6 @@ jobs:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
image-ref: docker.io/onyxdotapp/onyx-model-server:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
timeout: "10m"

View File

@@ -60,8 +60,6 @@ jobs:
push: true
build-args: |
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -8,8 +8,6 @@ on: push
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
MOCK_LLM_RESPONSE: true
jobs:
playwright-tests:

View File

@@ -21,10 +21,10 @@ jobs:
- name: Set up Helm
uses: azure/setup-helm@v4.2.0
with:
version: v3.17.0
version: v3.14.4
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.7.0
uses: helm/chart-testing-action@v2.6.1
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
@@ -37,6 +37,22 @@ jobs:
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# rkuo: I don't think we need python?
# - name: Set up Python
# uses: actions/setup-python@v5
# with:
# python-version: '3.11'
# cache: 'pip'
# cache-dependency-path: |
# backend/requirements/default.txt
# backend/requirements/dev.txt
# backend/requirements/model_server.txt
# - run: |
# python -m pip install --upgrade pip
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
if: steps.list-changed.outputs.changed == 'true'
@@ -46,7 +62,7 @@ jobs:
- name: Create kind cluster
if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.12.0
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
if: steps.list-changed.outputs.changed == 'true'

View File

@@ -94,19 +94,16 @@ jobs:
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
MULTI_TENANT=true \
AUTH_TYPE=cloud \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
DEV_MODE=true \
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack up -d
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker_multi_tenant
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
- name: Run Multi-Tenant Integration Tests
run: |
echo "Waiting for 3 minutes to ensure API server is ready..."
sleep 180
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
--name test-runner \
@@ -122,10 +119,6 @@ jobs:
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \
-e REQUIRE_EMAIL_VERIFICATION=false \
-e DISABLE_TELEMETRY=true \
-e IMAGE_TAG=test \
-e DEV_MODE=true \
onyxdotapp/onyx-integration:test \
/app/tests/integration/multitenant_tests
continue-on-error: true
@@ -133,17 +126,17 @@ jobs:
- name: Check multi-tenant test results
run: |
if [ ${{ steps.run_multitenant_tests.outcome }} == 'failure' ]; then
echo "Multi-tenant integration tests failed. Exiting with error."
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All multi-tenant integration tests passed successfully."
echo "All integration tests passed successfully."
fi
- name: Stop multi-tenant Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack down -v
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Start Docker containers
run: |
@@ -223,30 +216,27 @@ jobs:
echo "All integration tests passed successfully."
fi
# ------------------------------------------------------------
# Always gather logs BEFORE "down":
- name: Dump API server logs
if: always()
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v4
with:
name: docker-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v

View File

@@ -1,29 +0,0 @@
name: Ensure PR references Linear
on:
pull_request:
types: [opened, edited, reopened, synchronize]
jobs:
linear-check:
runs-on: ubuntu-latest
steps:
- name: Check PR body for Linear link or override
env:
PR_BODY: ${{ github.event.pull_request.body }}
run: |
# Looking for "https://linear.app" in the body
if echo "$PR_BODY" | grep -qE "https://linear\.app"; then
echo "Found a Linear link. Check passed."
exit 0
fi
# Looking for a checked override: "[x] Override Linear Check"
if echo "$PR_BODY" | grep -q "\[x\].*Override Linear Check"; then
echo "Override box is checked. Check passed."
exit 0
fi
# Otherwise, fail the run
echo "No Linear link or override found in the PR description."
exit 1

View File

@@ -26,24 +26,6 @@ env:
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
# Slab
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
# Zendesk
ZENDESK_SUBDOMAIN: ${{ secrets.ZENDESK_SUBDOMAIN }}
ZENDESK_EMAIL: ${{ secrets.ZENDESK_EMAIL }}
ZENDESK_TOKEN: ${{ secrets.ZENDESK_TOKEN }}
# Salesforce
SF_USERNAME: ${{ secrets.SF_USERNAME }}
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
# Airtable
AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
# Sharepoint
SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
jobs:
connectors-check:

4
.gitignore vendored
View File

@@ -7,6 +7,4 @@
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
/web/test-results/
backend/onyx/agent_search/main/test_data.json
backend/tests/regression/answer_quality/test_data.json
/web/test-results/

View File

@@ -5,8 +5,6 @@
# For local dev, often user Authentication is not needed
AUTH_TYPE=disabled
# Skip warm up for dev
SKIP_WARM_UP=True
# Always keep these on for Dev
# Logs all model prompts to stdout
@@ -29,7 +27,6 @@ REQUIRE_EMAIL_VERIFICATION=False
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
GEN_AI_API_KEY=<REPLACE THIS>
OPENAI_API_KEY=<REPLACE THIS>
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
GEN_AI_MODEL_VERSION=gpt-4o
FAST_GEN_AI_MODEL_VERSION=gpt-4o
@@ -52,9 +49,3 @@ BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
# Agent Search configs # TODO: Remove give proper namings
AGENT_RETRIEVAL_STATS=False # Note: This setting will incur substantial re-ranking effort
AGENT_RERANKING_STATS=True
AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20

View File

@@ -28,7 +28,6 @@
"Celery heavy",
"Celery indexing",
"Celery beat",
"Celery monitoring",
],
"presentation": {
"group": "1",
@@ -52,8 +51,7 @@
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery beat",
"Celery monitoring",
"Celery beat"
],
"presentation": {
"group": "1",
@@ -271,31 +269,6 @@
},
"consoleTitle": "Celery indexing Console"
},
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=solo",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery monitoring Console"
},
{
"name": "Celery beat",
"type": "debugpy",
@@ -382,20 +355,5 @@
"PYTHONPATH": "."
},
},
{
"name": "Install Python Requirements",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"-c",
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
},
]
}

View File

@@ -12,10 +12,6 @@ As an open source project in a rapidly changing space, we welcome all contributi
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.
To ensure that your contribution is aligned with the project's direction, please reach out to Hagen (or any other maintainer) on the Onyx team
via [Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA) /
[Discord](https://discord.gg/TDJ59cGV2X) or [email](mailto:founders@onyx.app).
Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
will be marked with the `approved by maintainers` label.
Issues marked `good first issue` are an especially great place to start.
@@ -27,8 +23,8 @@ If you have a new/different contribution in mind, we'd love to hear about it!
Your input is vital to making sure that Onyx moves in the right direction.
Before starting on implementation, please raise a GitHub issue.
Also, always feel free to message the founders (Chris Weaver / Yuhong Sun) on
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA) /
And always feel free to message us (Chris Weaver / Yuhong Sun) on
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ) /
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
### Contributing Code
@@ -46,7 +42,7 @@ Our goal is to make contributing as easy as possible. If you run into any issues
That way we can help future contributors and users can avoid the same issue.
We also have support channels and generally interesting discussions on our
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA)
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ)
and
[Discord](https://discord.gg/TDJ59cGV2X).
@@ -127,47 +123,7 @@ Once the above is done, navigate to `onyx/web` run:
npm i
```
## Formatting and Linting
### Backend
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
First, install pre-commit (if you don't have it already) following the instructions
[here](https://pre-commit.com/#installation).
With the virtual environment active, install the pre-commit library with:
```bash
pip install pre-commit
```
Then, from the `onyx/backend` directory, run:
```bash
pre-commit install
```
Additionally, we use `mypy` for static type checking.
Onyx is fully type-annotated, and we want to keep it that way!
To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend` directory.
### Web
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `onyx/web` directory.
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
Please double check that prettier passes before creating a pull request.
# Running the application for development
## Developing using VSCode Debugger (recommended)
We highly recommend using VSCode debugger for development.
See [CONTRIBUTING_VSCODE.md](./CONTRIBUTING_VSCODE.md) for more details.
Otherwise, you can follow the instructions below to run the application for development.
## Manually running the application for development
### Docker containers for external software
#### Docker containers for external software
You will need Docker installed to run these containers.
@@ -179,7 +135,7 @@ docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
### Running Onyx locally
#### Running Onyx locally
To start the frontend, navigate to `onyx/web` and run:
@@ -267,6 +223,35 @@ If you want to make changes to Onyx and run those changes in Docker, you can als
docker compose -f docker-compose.dev.yml -p onyx-stack up -d --build
```
### Formatting and Linting
#### Backend
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
First, install pre-commit (if you don't have it already) following the instructions
[here](https://pre-commit.com/#installation).
With the virtual environment active, install the pre-commit library with:
```bash
pip install pre-commit
```
Then, from the `onyx/backend` directory, run:
```bash
pre-commit install
```
Additionally, we use `mypy` for static type checking.
Onyx is fully type-annotated, and we want to keep it that way!
To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend` directory.
#### Web
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `onyx/web` directory.
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
Please double check that prettier passes before creating a pull request.
### Release Process

View File

@@ -1,30 +0,0 @@
# VSCode Debugging Setup
This guide explains how to set up and use VSCode's debugging capabilities with this project.
## Initial Setup
1. **Environment Setup**:
- Copy `.vscode/.env.template` to `.vscode/.env`
- Fill in the necessary environment variables in `.vscode/.env`
2. **launch.json**:
- Copy `.vscode/launch.template.jsonc` to `.vscode/launch.json`
## Using the Debugger
Before starting, make sure the Docker Daemon is running.
1. Open the Debug view in VSCode (Cmd+Shift+D on macOS)
2. From the dropdown at the top, select "Clear and Restart External Volumes and Containers" and press the green play button
3. From the dropdown at the top, select "Run All Onyx Services" and press the green play button
4. CD into web, run "npm i" followed by npm run dev.
5. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
6. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
7. Use the debug toolbar to step through code, inspect variables, etc.
## Features
- Hot reload is enabled for the web server and API servers
- Python debugging is configured with debugpy
- Environment variables are loaded from `.vscode/.env`
- Console output is organized in the integrated terminal with labeled tabs

View File

@@ -3,7 +3,7 @@
<a name="readme-top"></a>
<h2 align="center">
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/OnyxLogoCropped.jpg?raw=true)" /></a>
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/LogoOnyx.png?raw=true)" /></a>
</h2>
<p align="center">
@@ -13,7 +13,7 @@
<a href="https://docs.onyx.app/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
@@ -24,7 +24,7 @@
</a>
</p>
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
<strong>[Onyx](https://www.onyx.app/)</strong> (Formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
@@ -119,12 +119,12 @@ There are two editions of Onyx:
- Whitelabeling
- API key authentication
- Encryption of secrets
- And many more! Checkout [our website](https://www.onyx.app/) for the latest.
- Any many more! Checkout [our website](https://www.onyx.app/) for the latest.
To try the Onyx Enterprise Edition:
1. Checkout our [Cloud product](https://cloud.onyx.app/signup).
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
## 💡 Contributing
@@ -134,3 +134,14 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md
[![Star History Chart](https://api.star-history.com/svg?repos=onyx-dot-app/onyx&type=Date)](https://star-history.com/#onyx-dot-app/onyx&Date)
## ✨Contributors
<a href="https://github.com/onyx-dot-app/onyx/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=onyx-dot-app/onyx"/>
</a>
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
↑ Back to Top ↑
</a>
</p>

1
backend/.gitignore vendored
View File

@@ -9,4 +9,3 @@ api_keys.py
vespa-app.zip
dynamic_config_storage/
celerybeat-schedule*
onyx/connectors/salesforce/data/

View File

@@ -9,10 +9,8 @@ founders@onyx.app for more information. Please visit https://github.com/onyx-dot
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.8-dev
# DO_NOT_TRACK is used to disable telemetry for Unstructured
ENV ONYX_VERSION=${ONYX_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true" \
DO_NOT_TRACK="true"
DANSWER_RUNNING_IN_DOCKER="true"
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
@@ -101,8 +99,7 @@ COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
# Escape hatch scripts
COPY ./scripts/debugging /app/scripts/debugging
# Escape hatch
COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
# Put logo in assets

View File

@@ -4,7 +4,7 @@ from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import AWS_REGION_NAME
from onyx.configs.app_configs import AWS_REGION
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_all_tenant_ids
from sqlalchemy import event
@@ -120,7 +120,7 @@ def provide_iam_token_for_alembic(
) -> None:
if USE_IAM_AUTH:
# Database connection settings
region = AWS_REGION_NAME
region = AWS_REGION
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER

View File

@@ -1,29 +0,0 @@
"""add shortcut option for users
Revision ID: 027381bce97c
Revises: 6fc7886d665d
Create Date: 2025-01-14 12:14:00.814390
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "027381bce97c"
down_revision = "6fc7886d665d"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"shortcut_enabled", sa.Boolean(), nullable=False, server_default="false"
),
)
def downgrade() -> None:
op.drop_column("user", "shortcut_enabled")

View File

@@ -1,36 +0,0 @@
"""add index to index_attempt.time_created
Revision ID: 0f7ff6d75b57
Revises: 369644546676
Create Date: 2025-01-10 14:01:14.067144
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "0f7ff6d75b57"
down_revision = "fec3db967bf7"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_index(
op.f("ix_index_attempt_status"),
"index_attempt",
["status"],
unique=False,
)
op.create_index(
op.f("ix_index_attempt_time_created"),
"index_attempt",
["time_created"],
unique=False,
)
def downgrade() -> None:
op.drop_index(op.f("ix_index_attempt_time_created"), table_name="index_attempt")
op.drop_index(op.f("ix_index_attempt_status"), table_name="index_attempt")

View File

@@ -1,24 +0,0 @@
"""add chunk count to document
Revision ID: 2955778aa44c
Revises: c0aab6edb6dd
Create Date: 2025-01-04 11:39:43.268612
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2955778aa44c"
down_revision = "c0aab6edb6dd"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column("document", sa.Column("chunk_count", sa.Integer(), nullable=True))
def downgrade() -> None:
op.drop_column("document", "chunk_count")

View File

@@ -1,32 +0,0 @@
"""set built in to default
Revision ID: 2cdeff6d8c93
Revises: f5437cc136c5
Create Date: 2025-02-11 14:57:51.308775
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "2cdeff6d8c93"
down_revision = "f5437cc136c5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Prior to this migration / point in the codebase history,
# built in personas were implicitly treated as default personas (with no option to change this)
# This migration makes that explicit
op.execute(
"""
UPDATE persona
SET is_default_persona = TRUE
WHERE builtin_persona = TRUE
"""
)
def downgrade() -> None:
pass

View File

@@ -1,36 +0,0 @@
"""add chat session specific temperature override
Revision ID: 2f80c6a2550f
Revises: 33ea50e88f24
Create Date: 2025-01-31 10:30:27.289646
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2f80c6a2550f"
down_revision = "33ea50e88f24"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_session", sa.Column("temperature_override", sa.Float(), nullable=True)
)
op.add_column(
"user",
sa.Column(
"temperature_override_enabled",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
def downgrade() -> None:
op.drop_column("chat_session", "temperature_override")
op.drop_column("user", "temperature_override_enabled")

View File

@@ -1,80 +0,0 @@
"""foreign key input prompts
Revision ID: 33ea50e88f24
Revises: a6df6b88ef81
Create Date: 2025-01-29 10:54:22.141765
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "33ea50e88f24"
down_revision = "a6df6b88ef81"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Safely drop constraints if exists
op.execute(
"""
ALTER TABLE inputprompt__user
DROP CONSTRAINT IF EXISTS inputprompt__user_input_prompt_id_fkey
"""
)
op.execute(
"""
ALTER TABLE inputprompt__user
DROP CONSTRAINT IF EXISTS inputprompt__user_user_id_fkey
"""
)
# Recreate with ON DELETE CASCADE
op.create_foreign_key(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
"inputprompt",
["input_prompt_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
# Drop the new FKs with ondelete
op.drop_constraint(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
type_="foreignkey",
)
op.drop_constraint(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
type_="foreignkey",
)
# Recreate them without cascading
op.create_foreign_key(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
"inputprompt",
["input_prompt_id"],
["id"],
)
op.create_foreign_key(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
"user",
["user_id"],
["id"],
)

View File

@@ -1,35 +0,0 @@
"""add composite index for index attempt time updated
Revision ID: 369644546676
Revises: 2955778aa44c
Create Date: 2025-01-08 15:38:17.224380
"""
from alembic import op
from sqlalchemy import text
# revision identifiers, used by Alembic.
revision = "369644546676"
down_revision = "2955778aa44c"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_index(
"ix_index_attempt_ccpair_search_settings_time_updated",
"index_attempt",
[
"connector_credential_pair_id",
"search_settings_id",
text("time_updated DESC"),
],
unique=False,
)
def downgrade() -> None:
op.drop_index(
"ix_index_attempt_ccpair_search_settings_time_updated",
table_name="index_attempt",
)

View File

@@ -1,59 +0,0 @@
"""add back input prompts
Revision ID: 3c6531f32351
Revises: aeda5f2df4f6
Create Date: 2025-01-13 12:49:51.705235
"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
# revision identifiers, used by Alembic.
revision = "3c6531f32351"
down_revision = "aeda5f2df4f6"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"inputprompt",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("prompt", sa.String(), nullable=False),
sa.Column("content", sa.String(), nullable=False),
sa.Column("active", sa.Boolean(), nullable=False),
sa.Column("is_public", sa.Boolean(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"inputprompt__user",
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
sa.Column(
"user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False
),
sa.Column("disabled", sa.Boolean(), nullable=False, default=False),
sa.ForeignKeyConstraint(
["input_prompt_id"],
["inputprompt.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
)
def downgrade() -> None:
op.drop_table("inputprompt__user")
op.drop_table("inputprompt")

View File

@@ -40,6 +40,6 @@ def upgrade() -> None:
def downgrade() -> None:
op.drop_constraint("persona_category_id_fkey", "persona", type_="foreignkey")
op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
op.drop_column("persona", "category_id")
op.drop_table("persona_category")

View File

@@ -1,37 +0,0 @@
"""lowercase_user_emails
Revision ID: 4d58345da04a
Revises: f1ca58b2f2ec
Create Date: 2025-01-29 07:48:46.784041
"""
from alembic import op
from sqlalchemy.sql import text
# revision identifiers, used by Alembic.
revision = "4d58345da04a"
down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Get database connection
connection = op.get_bind()
# Update all user emails to lowercase
connection.execute(
text(
"""
UPDATE "user"
SET email = LOWER(email)
WHERE email != LOWER(email)
"""
)
)
def downgrade() -> None:
# Cannot restore original case of emails
pass

View File

@@ -5,6 +5,7 @@ Revises: 47e5bef3a1d7
Create Date: 2024-11-06 13:15:53.302644
"""
import logging
from typing import cast
from alembic import op
import sqlalchemy as sa
@@ -19,8 +20,13 @@ down_revision = "47e5bef3a1d7"
branch_labels: None = None
depends_on: None = None
# Configure logging
logger = logging.getLogger("alembic.runtime.migration")
logger.setLevel(logging.INFO)
def upgrade() -> None:
logger.info(f"{revision}: create_table: slack_bot")
# Create new slack_bot table
op.create_table(
"slack_bot",
@@ -57,6 +63,7 @@ def upgrade() -> None:
)
# Handle existing Slack bot tokens first
logger.info(f"{revision}: Checking for existing Slack bot.")
bot_token = None
app_token = None
first_row_id = None
@@ -64,12 +71,15 @@ def upgrade() -> None:
try:
tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
except Exception:
logger.warning("No existing Slack bot tokens found.")
tokens = {}
bot_token = tokens.get("bot_token")
app_token = tokens.get("app_token")
if bot_token and app_token:
logger.info(f"{revision}: Found bot and app tokens.")
session = Session(bind=op.get_bind())
new_slack_bot = SlackBot(
name="Slack Bot (Migrated)",
@@ -160,9 +170,10 @@ def upgrade() -> None:
# Clean up old tokens if they existed
try:
if bot_token and app_token:
logger.info(f"{revision}: Removing old bot and app tokens.")
get_kv_store().delete("slack_bot_tokens_config_key")
except Exception:
pass
logger.warning("tried to delete tokens in dynamic config but failed")
# Rename the table
op.rename_table(
"slack_bot_config__standard_answer_category",
@@ -179,6 +190,8 @@ def upgrade() -> None:
# Drop the table with CASCADE to handle dependent objects
op.execute("DROP TABLE slack_bot_config CASCADE")
logger.info(f"{revision}: Migration complete.")
def downgrade() -> None:
# Recreate the old slack_bot_config table
@@ -260,7 +273,7 @@ def downgrade() -> None:
}
get_kv_store().store("slack_bot_tokens_config_key", tokens)
except Exception:
pass
logger.warning("Failed to save tokens back to KV store")
# Drop the new tables in reverse order
op.drop_table("slack_channel_config")

View File

@@ -1,80 +0,0 @@
"""make categories labels and many to many
Revision ID: 6fc7886d665d
Revises: 3c6531f32351
Create Date: 2025-01-13 18:12:18.029112
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "6fc7886d665d"
down_revision = "3c6531f32351"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Rename persona_category table to persona_label
op.rename_table("persona_category", "persona_label")
# Create the new association table
op.create_table(
"persona__persona_label",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column("persona_label_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.ForeignKeyConstraint(
["persona_label_id"],
["persona_label.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("persona_id", "persona_label_id"),
)
# Copy existing relationships to the new table
op.execute(
"""
INSERT INTO persona__persona_label (persona_id, persona_label_id)
SELECT id, category_id FROM persona WHERE category_id IS NOT NULL
"""
)
# Remove the old category_id column from persona table
op.drop_column("persona", "category_id")
def downgrade() -> None:
# Rename persona_label table back to persona_category
op.rename_table("persona_label", "persona_category")
# Add back the category_id column to persona table
op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
op.create_foreign_key(
"persona_category_id_fkey",
"persona",
"persona_category",
["category_id"],
["id"],
)
# Copy the first label relationship back to the persona table
op.execute(
"""
UPDATE persona
SET category_id = (
SELECT persona_label_id
FROM persona__persona_label
WHERE persona__persona_label.persona_id = persona.id
LIMIT 1
)
"""
)
# Drop the association table
op.drop_table("persona__persona_label")

View File

@@ -1,72 +0,0 @@
"""Add SyncRecord
Revision ID: 97dbb53fa8c8
Revises: 369644546676
Create Date: 2025-01-11 19:39:50.426302
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "97dbb53fa8c8"
down_revision = "be2ab2aa50ee"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"sync_record",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("entity_id", sa.Integer(), nullable=False),
sa.Column(
"sync_type",
sa.Enum(
"DOCUMENT_SET",
"USER_GROUP",
"CONNECTOR_DELETION",
name="synctype",
native_enum=False,
length=40,
),
nullable=False,
),
sa.Column(
"sync_status",
sa.Enum(
"IN_PROGRESS",
"SUCCESS",
"FAILED",
"CANCELED",
name="syncstatus",
native_enum=False,
length=40,
),
nullable=False,
),
sa.Column("num_docs_synced", sa.Integer(), nullable=False),
sa.Column("sync_start_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("sync_end_time", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
# Add index for fetch_latest_sync_record query
op.create_index(
"ix_sync_record_entity_id_sync_type_sync_start_time",
"sync_record",
["entity_id", "sync_type", "sync_start_time"],
)
# Add index for cleanup_sync_records query
op.create_index(
"ix_sync_record_entity_id_sync_type_sync_status",
"sync_record",
["entity_id", "sync_type", "sync_status"],
)
def downgrade() -> None:
op.drop_index("ix_sync_record_entity_id_sync_type_sync_status")
op.drop_index("ix_sync_record_entity_id_sync_type_sync_start_time")
op.drop_table("sync_record")

View File

@@ -1,107 +0,0 @@
"""agent_tracking
Revision ID: 98a5008d8711
Revises: 2f80c6a2550f
Create Date: 2025-01-29 17:00:00.000001
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision = "98a5008d8711"
down_revision = "2f80c6a2550f"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"agent__search_metrics",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("agent_type", sa.String(), nullable=False),
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("base_duration_s", sa.Float(), nullable=False),
sa.Column("full_duration_s", sa.Float(), nullable=False),
sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
# Create sub_question table
op.create_table(
"agent__sub_question",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("primary_question_id", sa.Integer, sa.ForeignKey("chat_message.id")),
sa.Column(
"chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
),
sa.Column("sub_question", sa.Text),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
),
sa.Column("sub_answer", sa.Text),
sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=True),
sa.Column("level", sa.Integer(), nullable=False),
sa.Column("level_question_num", sa.Integer(), nullable=False),
)
# Create sub_query table
op.create_table(
"agent__sub_query",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column(
"parent_question_id", sa.Integer, sa.ForeignKey("agent__sub_question.id")
),
sa.Column(
"chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
),
sa.Column("sub_query", sa.Text),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
),
)
# Create sub_query__search_doc association table
op.create_table(
"agent__sub_query__search_doc",
sa.Column(
"sub_query_id",
sa.Integer,
sa.ForeignKey("agent__sub_query.id"),
primary_key=True,
),
sa.Column(
"search_doc_id",
sa.Integer,
sa.ForeignKey("search_doc.id"),
primary_key=True,
),
)
op.add_column(
"chat_message",
sa.Column(
"refined_answer_improvement",
sa.Boolean(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("chat_message", "refined_answer_improvement")
op.drop_table("agent__sub_query__search_doc")
op.drop_table("agent__sub_query")
op.drop_table("agent__sub_question")
op.drop_table("agent__search_metrics")

View File

@@ -1,29 +0,0 @@
"""remove recent assistants
Revision ID: a6df6b88ef81
Revises: 4d58345da04a
Create Date: 2025-01-29 10:25:52.790407
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a6df6b88ef81"
down_revision = "4d58345da04a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_column("user", "recent_assistants")
def downgrade() -> None:
op.add_column(
"user",
sa.Column(
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
),
)

View File

@@ -1,27 +0,0 @@
"""add pinned assistants
Revision ID: aeda5f2df4f6
Revises: c5eae4a75a1b
Create Date: 2025-01-09 16:04:10.770636
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "aeda5f2df4f6"
down_revision = "c5eae4a75a1b"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user", sa.Column("pinned_assistants", postgresql.JSONB(), nullable=True)
)
op.execute('UPDATE "user" SET pinned_assistants = chosen_assistants')
def downgrade() -> None:
op.drop_column("user", "pinned_assistants")

View File

@@ -1,38 +0,0 @@
"""fix_capitalization
Revision ID: be2ab2aa50ee
Revises: 369644546676
Create Date: 2025-01-10 13:13:26.228960
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "be2ab2aa50ee"
down_revision = "369644546676"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
"""
UPDATE document
SET
external_user_group_ids = ARRAY(
SELECT LOWER(unnest(external_user_group_ids))
),
last_modified = NOW()
WHERE
external_user_group_ids IS NOT NULL
AND external_user_group_ids::text[] <> ARRAY(
SELECT LOWER(unnest(external_user_group_ids))
)::text[]
"""
)
def downgrade() -> None:
# No way to cleanly persist the bad state through an upgrade/downgrade
# cycle, so we just pass
pass

View File

@@ -1,36 +0,0 @@
"""Add chat_message__standard_answer table
Revision ID: c5eae4a75a1b
Revises: 0f7ff6d75b57
Create Date: 2025-01-15 14:08:49.688998
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c5eae4a75a1b"
down_revision = "0f7ff6d75b57"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"chat_message__standard_answer",
sa.Column("chat_message_id", sa.Integer(), nullable=False),
sa.Column("standard_answer_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["chat_message_id"],
["chat_message.id"],
),
sa.ForeignKeyConstraint(
["standard_answer_id"],
["standard_answer.id"],
),
sa.PrimaryKeyConstraint("chat_message_id", "standard_answer_id"),
)
def downgrade() -> None:
op.drop_table("chat_message__standard_answer")

View File

@@ -1,48 +0,0 @@
"""Add has_been_indexed to DocumentByConnectorCredentialPair
Revision ID: c7bf5721733e
Revises: fec3db967bf7
Create Date: 2025-01-13 12:39:05.831693
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c7bf5721733e"
down_revision = "027381bce97c"
branch_labels = None
depends_on = None
def upgrade() -> None:
# assume all existing rows have been indexed, no better approach
op.add_column(
"document_by_connector_credential_pair",
sa.Column("has_been_indexed", sa.Boolean(), nullable=True),
)
op.execute(
"UPDATE document_by_connector_credential_pair SET has_been_indexed = TRUE"
)
op.alter_column(
"document_by_connector_credential_pair",
"has_been_indexed",
nullable=False,
)
# Add index to optimize get_document_counts_for_cc_pairs query pattern
op.create_index(
"idx_document_cc_pair_counts",
"document_by_connector_credential_pair",
["connector_id", "credential_id", "has_been_indexed"],
unique=False,
)
def downgrade() -> None:
# Remove the index first before removing the column
op.drop_index(
"idx_document_cc_pair_counts",
table_name="document_by_connector_credential_pair",
)
op.drop_column("document_by_connector_credential_pair", "has_been_indexed")

View File

@@ -1,80 +0,0 @@
"""add default slack channel config
Revision ID: eaa3b5593925
Revises: 98a5008d8711
Create Date: 2025-02-03 18:07:56.552526
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "eaa3b5593925"
down_revision = "98a5008d8711"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add is_default column
op.add_column(
"slack_channel_config",
sa.Column("is_default", sa.Boolean(), nullable=False, server_default="false"),
)
op.create_index(
"ix_slack_channel_config_slack_bot_id_default",
"slack_channel_config",
["slack_bot_id", "is_default"],
unique=True,
postgresql_where=sa.text("is_default IS TRUE"),
)
# Create default channel configs for existing slack bots without one
conn = op.get_bind()
slack_bots = conn.execute(sa.text("SELECT id FROM slack_bot")).fetchall()
for slack_bot in slack_bots:
slack_bot_id = slack_bot[0]
existing_default = conn.execute(
sa.text(
"SELECT id FROM slack_channel_config WHERE slack_bot_id = :bot_id AND is_default = TRUE"
),
{"bot_id": slack_bot_id},
).fetchone()
if not existing_default:
conn.execute(
sa.text(
"""
INSERT INTO slack_channel_config (
slack_bot_id, persona_id, channel_config, enable_auto_filters, is_default
) VALUES (
:bot_id, NULL,
'{"channel_name": null, '
'"respond_member_group_list": [], '
'"answer_filters": [], '
'"follow_up_tags": [], '
'"respond_tag_only": true}',
FALSE, TRUE
)
"""
),
{"bot_id": slack_bot_id},
)
def downgrade() -> None:
# Delete default slack channel configs
conn = op.get_bind()
conn.execute(sa.text("DELETE FROM slack_channel_config WHERE is_default = TRUE"))
# Remove index
op.drop_index(
"ix_slack_channel_config_slack_bot_id_default",
table_name="slack_channel_config",
)
# Remove is_default column
op.drop_column("slack_channel_config", "is_default")

View File

@@ -1,33 +0,0 @@
"""add passthrough auth to tool
Revision ID: f1ca58b2f2ec
Revises: c7bf5721733e
Create Date: 2024-03-19
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "f1ca58b2f2ec"
down_revision: Union[str, None] = "c7bf5721733e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add passthrough_auth column to tool table with default value of False
op.add_column(
"tool",
sa.Column(
"passthrough_auth", sa.Boolean(), nullable=False, server_default=sa.false()
),
)
def downgrade() -> None:
# Remove passthrough_auth column from tool table
op.drop_column("tool", "passthrough_auth")

View File

@@ -1,53 +0,0 @@
"""delete non-search assistants
Revision ID: f5437cc136c5
Revises: eaa3b5593925
Create Date: 2025-02-04 16:17:15.677256
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "f5437cc136c5"
down_revision = "eaa3b5593925"
branch_labels = None
depends_on = None
def upgrade() -> None:
pass
def downgrade() -> None:
# Fix: split the statements into multiple op.execute() calls
op.execute(
"""
WITH personas_without_search AS (
SELECT p.id
FROM persona p
LEFT JOIN persona__tool pt ON p.id = pt.persona_id
LEFT JOIN tool t ON pt.tool_id = t.id
GROUP BY p.id
HAVING COUNT(CASE WHEN t.in_code_tool_id = 'run_search' THEN 1 END) = 0
)
UPDATE slack_channel_config
SET persona_id = NULL
WHERE is_default = TRUE AND persona_id IN (SELECT id FROM personas_without_search)
"""
)
op.execute(
"""
WITH personas_without_search AS (
SELECT p.id
FROM persona p
LEFT JOIN persona__tool pt ON p.id = pt.persona_id
LEFT JOIN tool t ON pt.tool_id = t.id
GROUP BY p.id
HAVING COUNT(CASE WHEN t.in_code_tool_id = 'run_search' THEN 1 END) = 0
)
DELETE FROM slack_channel_config
WHERE is_default = FALSE AND persona_id IN (SELECT id FROM personas_without_search)
"""
)

View File

@@ -1,41 +0,0 @@
"""Add time_updated to UserGroup and DocumentSet
Revision ID: fec3db967bf7
Revises: 97dbb53fa8c8
Create Date: 2025-01-12 15:49:02.289100
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "fec3db967bf7"
down_revision = "97dbb53fa8c8"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"document_set",
sa.Column(
"time_last_modified_by_user",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
op.add_column(
"user_group",
sa.Column(
"time_last_modified_by_user",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
def downgrade() -> None:
op.drop_column("user_group", "time_last_modified_by_user")
op.drop_column("document_set", "time_last_modified_by_user")

View File

@@ -1,31 +0,0 @@
"""mapping for anonymous user path
Revision ID: a4f6ee863c47
Revises: 14a83a331951
Create Date: 2025-01-04 14:16:58.697451
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "a4f6ee863c47"
down_revision = "14a83a331951"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"tenant_anonymous_user_path",
sa.Column("tenant_id", sa.String(), primary_key=True, nullable=False),
sa.Column("anonymous_user_path", sa.String(), nullable=False),
sa.PrimaryKeyConstraint("tenant_id"),
sa.UniqueConstraint("anonymous_user_path"),
)
def downgrade() -> None:
op.drop_table("tenant_anonymous_user_path")

View File

@@ -3,10 +3,6 @@ from sqlalchemy.orm import Session
from ee.onyx.db.external_perm import fetch_external_groups_for_user
from ee.onyx.db.user_group import fetch_user_groups_for_documents
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.external_permissions.post_query_censoring import (
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from onyx.access.access import (
_get_access_for_documents as get_access_for_documents_without_groups,
)
@@ -14,7 +10,6 @@ from onyx.access.access import _get_acl_for_user as get_acl_for_user_without_gro
from onyx.access.models import DocumentAccess
from onyx.access.utils import prefix_external_group
from onyx.access.utils import prefix_user_group
from onyx.db.document import get_document_sources
from onyx.db.document import get_documents_by_ids
from onyx.db.models import User
@@ -57,20 +52,9 @@ def _get_access_for_documents(
)
doc_id_map = {doc.id: doc for doc in documents}
# Get all sources in one batch
doc_id_to_source_map = get_document_sources(
db_session=db_session,
document_ids=document_ids,
)
access_map = {}
for document_id, non_ee_access in non_ee_access_dict.items():
document = doc_id_map[document_id]
source = doc_id_to_source_map.get(document_id)
is_only_censored = (
source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
and source not in DOC_PERMISSIONS_FUNC_MAP
)
ext_u_emails = (
set(document.external_user_emails)
@@ -86,11 +70,7 @@ def _get_access_for_documents(
# If the document is determined to be "public" externally (through a SYNC connector)
# then it's given the same access level as if it were marked public within Onyx
# If its censored, then it's public anywhere during the search and then permissions are
# applied after the search
is_public_anywhere = (
document.is_public or non_ee_access.is_public or is_only_censored
)
is_public_anywhere = document.is_public or non_ee_access.is_public
# To avoid collisions of group namings between connectors, they need to be prefixed
access_map[document_id] = DocumentAccess(

View File

@@ -1,7 +1,5 @@
from datetime import datetime
from functools import lru_cache
import jwt
import requests
from fastapi import Depends
from fastapi import HTTPException
@@ -22,7 +20,6 @@ from ee.onyx.server.seeding import get_seed_config
from ee.onyx.utils.secrets import extract_hashed_cookie
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.constants import AuthType
from onyx.db.models import User
from onyx.utils.logger import setup_logger
@@ -121,17 +118,3 @@ async def current_cloud_superuser(
detail="Access denied. User must be a cloud superuser to perform this action.",
)
return user
def generate_anonymous_user_jwt_token(tenant_id: str) -> str:
payload = {
"tenant_id": tenant_id,
# Token does not expire
"iat": datetime.utcnow(), # Issued at time
}
return jwt.encode(payload, USER_AUTH_SECRET, algorithm="HS256")
def decode_anonymous_user_jwt_token(token: str) -> dict:
return jwt.decode(token, USER_AUTH_SECRET, algorithms=["HS256"])

View File

@@ -32,7 +32,6 @@ def perform_ttl_management_task(
@celery_app.task(
name="check_ttl_management_task",
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str | None) -> None:
@@ -57,7 +56,6 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
@celery_app.task(
name="autogenerate_usage_report_task",
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:

View File

@@ -1,80 +1,24 @@
from datetime import timedelta
from typing import Any
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
beat_system_tasks as base_beat_system_tasks,
tasks_to_schedule as base_tasks_to_schedule,
)
from onyx.background.celery.tasks.beat_schedule import (
beat_task_templates as base_beat_task_templates,
)
from onyx.background.celery.tasks.beat_schedule import generate_cloud_tasks
from onyx.background.celery.tasks.beat_schedule import (
get_tasks_to_schedule as base_get_tasks_to_schedule,
)
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from shared_configs.configs import MULTI_TENANT
ee_beat_system_tasks: list[dict] = []
ee_beat_task_templates: list[dict] = []
ee_beat_task_templates.extend(
[
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
)
ee_tasks_to_schedule: list[dict] = []
if not MULTI_TENANT:
ee_tasks_to_schedule = [
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30), # TODO: change this to config flag
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
beat_system_tasks = ee_beat_system_tasks + base_beat_system_tasks
beat_task_templates = ee_beat_task_templates + base_beat_task_templates
cloud_tasks = generate_cloud_tasks(
beat_system_tasks, beat_task_templates, beat_multiplier
)
return cloud_tasks
ee_tasks_to_schedule = [
{
"name": "autogenerate_usage_report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30), # TODO: change this to config flag
},
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
},
]
def get_tasks_to_schedule() -> list[dict[str, Any]]:
return ee_tasks_to_schedule + base_get_tasks_to_schedule()
return ee_tasks_to_schedule + base_tasks_to_schedule

View File

@@ -6,11 +6,7 @@ from sqlalchemy.orm import Session
from ee.onyx.db.user_group import delete_user_group
from ee.onyx.db.user_group import fetch_user_group
from ee.onyx.db.user_group import mark_user_group_as_synced
from ee.onyx.db.user_group import prepare_user_group_for_deletion
from onyx.background.celery.apps.app_base import task_logger
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
@@ -46,59 +42,15 @@ def monitor_usergroup_taskset(
f"User group sync progress: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
)
if count > 0:
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.IN_PROGRESS,
num_docs_synced=count,
)
return
user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
if user_group:
usergroup_name = user_group.name
try:
if user_group.is_up_for_deletion:
# this prepare should have been run when the deletion was scheduled,
# but run it again to be sure we're ready to go
mark_user_group_as_synced(db_session, user_group)
prepare_user_group_for_deletion(db_session, usergroup_id)
delete_user_group(db_session=db_session, user_group=user_group)
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial_count,
)
task_logger.info(
f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
)
else:
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial_count,
)
task_logger.info(
f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
)
except Exception as e:
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.FAILED,
num_docs_synced=initial_count,
)
raise e
if user_group.is_up_for_deletion:
delete_user_group(db_session=db_session, user_group=user_group)
task_logger.info(f"Deleted usergroup. id='{usergroup_id}'")
else:
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
task_logger.info(f"Synced usergroup. id='{usergroup_id}'")
rug.reset()

View File

@@ -4,20 +4,6 @@ import os
# Applicable for OIDC Auth
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
# Applicable for OIDC Auth, allows you to override the scopes that
# are requested from the OIDC provider. Currently used when passing
# over access tokens to tool calls and the tool needs more scopes
OIDC_SCOPE_OVERRIDE: list[str] | None = None
_OIDC_SCOPE_OVERRIDE = os.environ.get("OIDC_SCOPE_OVERRIDE")
if _OIDC_SCOPE_OVERRIDE:
try:
OIDC_SCOPE_OVERRIDE = [
scope.strip() for scope in _OIDC_SCOPE_OVERRIDE.split(",")
]
except Exception:
pass
# Applicable for SAML Auth
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_config"
@@ -29,12 +15,6 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
# This is a boolean that determines if anonymous access is public
# Default behavior is to not make the page public and instead add a group
# that contains all the users that we found in Confluence
CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
os.environ.get("CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC", "").lower() == "true"
)
# In seconds, default is 5 minutes
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
@@ -75,5 +55,3 @@ POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user"

View File

@@ -2,7 +2,6 @@ import datetime
from collections.abc import Sequence
from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import case
from sqlalchemy import cast
from sqlalchemy import Date
@@ -15,9 +14,6 @@ from onyx.configs.constants import MessageType
from onyx.db.models import ChatMessage
from onyx.db.models import ChatMessageFeedback
from onyx.db.models import ChatSession
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.models import UserRole
def fetch_query_analytics(
@@ -238,122 +234,3 @@ def fetch_persona_unique_users(
)
return [tuple(row) for row in db_session.execute(query).all()]
def fetch_assistant_message_analytics(
db_session: Session,
assistant_id: int,
start: datetime.datetime,
end: datetime.datetime,
) -> list[tuple[int, datetime.date]]:
"""
Gets the daily message counts for a specific assistant in the given time range.
"""
query = (
select(
func.count(ChatMessage.id),
cast(ChatMessage.time_sent, Date),
)
.join(
ChatSession,
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
)
.group_by(cast(ChatMessage.time_sent, Date))
.order_by(cast(ChatMessage.time_sent, Date))
)
return [tuple(row) for row in db_session.execute(query).all()]
def fetch_assistant_unique_users(
db_session: Session,
assistant_id: int,
start: datetime.datetime,
end: datetime.datetime,
) -> list[tuple[int, datetime.date]]:
"""
Gets the daily unique user counts for a specific assistant in the given time range.
"""
query = (
select(
func.count(func.distinct(ChatSession.user_id)),
cast(ChatMessage.time_sent, Date),
)
.join(
ChatSession,
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
)
.group_by(cast(ChatMessage.time_sent, Date))
.order_by(cast(ChatMessage.time_sent, Date))
)
return [tuple(row) for row in db_session.execute(query).all()]
def fetch_assistant_unique_users_total(
db_session: Session,
assistant_id: int,
start: datetime.datetime,
end: datetime.datetime,
) -> int:
"""
Gets the total number of distinct users who have sent or received messages from
the specified assistant in the given time range.
"""
query = (
select(func.count(func.distinct(ChatSession.user_id)))
.select_from(ChatMessage)
.join(
ChatSession,
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == assistant_id,
ChatSession.persona_id == assistant_id,
),
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
)
)
result = db_session.execute(query).scalar()
return result if result else 0
# Users can view assistant stats if they created the persona,
# or if they are an admin
def user_can_view_assistant_stats(
db_session: Session, user: User | None, assistant_id: int
) -> bool:
# If user is None and auth is disabled, assume the user is an admin
if user is None or user.role == UserRole.ADMIN:
return True
# Check if the user created the persona
stmt = select(Persona).where(
and_(Persona.id == assistant_id, Persona.user_id == user.id)
)
persona = db_session.execute(stmt).scalar_one_or_none()
return persona is not None

View File

@@ -5,7 +5,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.access.utils import prefix_group_w_source
from onyx.configs.constants import DocumentSource
from onyx.db.models import Document as DbDocument
@@ -25,7 +25,7 @@ def upsert_document_external_perms__no_commit(
).first()
prefixed_external_groups = [
build_ext_group_name_for_onyx(
prefix_group_w_source(
ext_group_name=group_id,
source=source_type,
)
@@ -66,7 +66,7 @@ def upsert_document_external_perms(
).first()
prefixed_external_groups: set[str] = {
build_ext_group_name_for_onyx(
prefix_group_w_source(
ext_group_name=group_id,
source=source_type,
)

View File

@@ -6,12 +6,10 @@ from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.access.utils import prefix_group_w_source
from onyx.configs.constants import DocumentSource
from onyx.db.models import User
from onyx.db.models import User__ExternalUserGroupId
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -62,10 +60,8 @@ def replace_user__ext_group_for_cc_pair(
all_group_member_emails.add(user_email)
# batch add users if they don't exist and get their ids
all_group_members: list[User] = batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
# NOTE: this function handles case sensitivity for emails
emails=list(all_group_member_emails),
all_group_members = batch_add_ext_perm_user_if_not_exists(
db_session=db_session, emails=list(all_group_member_emails)
)
delete_user__ext_group_for_cc_pair__no_commit(
@@ -87,14 +83,12 @@ def replace_user__ext_group_for_cc_pair(
f" with email {user_email} not found"
)
continue
external_group_id = build_ext_group_name_for_onyx(
ext_group_name=external_group.id,
source=source,
)
new_external_permissions.append(
User__ExternalUserGroupId(
user_id=user_id,
external_user_group_id=external_group_id,
external_user_group_id=prefix_group_w_source(
external_group.id, source
),
cc_pair_id=cc_pair_id,
)
)
@@ -112,21 +106,3 @@ def fetch_external_groups_for_user(
User__ExternalUserGroupId.user_id == user_id
)
).all()
def fetch_external_groups_for_user_email_and_group_ids(
db_session: Session,
user_email: str,
group_ids: list[str],
) -> list[User__ExternalUserGroupId]:
user = get_user_by_email(db_session=db_session, email=user_email)
if user is None:
return []
user_id = user.id
user_ext_groups = db_session.scalars(
select(User__ExternalUserGroupId).where(
User__ExternalUserGroupId.user_id == user_id,
User__ExternalUserGroupId.external_user_group_id.in_(group_ids),
)
).all()
return list(user_ext_groups)

View File

@@ -2,11 +2,8 @@ from uuid import UUID
from sqlalchemy.orm import Session
from onyx.configs.constants import NotificationType
from onyx.db.models import Persona__User
from onyx.db.models import Persona__UserGroup
from onyx.db.notification import create_notification
from onyx.server.features.persona.models import PersonaSharedNotificationData
def make_persona_private(
@@ -15,9 +12,6 @@ def make_persona_private(
group_ids: list[int] | None,
db_session: Session,
) -> None:
"""NOTE(rkuo): This function batches all updates into a single commit. If we don't
dedupe the inputs, the commit will exception."""
db_session.query(Persona__User).filter(
Persona__User.persona_id == persona_id
).delete(synchronize_session="fetch")
@@ -26,22 +20,11 @@ def make_persona_private(
).delete(synchronize_session="fetch")
if user_ids:
user_ids_set = set(user_ids)
for user_id in user_ids_set:
db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))
create_notification(
user_id=user_id,
notif_type=NotificationType.PERSONA_SHARED,
db_session=db_session,
additional_data=PersonaSharedNotificationData(
persona_id=persona_id,
).model_dump(),
)
for user_uuid in user_ids:
db_session.add(Persona__User(persona_id=persona_id, user_id=user_uuid))
if group_ids:
group_ids_set = set(group_ids)
for group_id in group_ids_set:
for group_id in group_ids:
db_session.add(
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
)

View File

@@ -1,138 +1,27 @@
from collections.abc import Sequence
from datetime import datetime
import datetime
from typing import Literal
from sqlalchemy import asc
from sqlalchemy import BinaryExpression
from sqlalchemy import ColumnElement
from sqlalchemy import desc
from sqlalchemy import distinct
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import case
from sqlalchemy.sql import func
from sqlalchemy.sql import select
from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.expression import UnaryExpression
from onyx.configs.constants import QAFeedbackType
from onyx.db.models import ChatMessage
from onyx.db.models import ChatMessageFeedback
from onyx.db.models import ChatSession
def _build_filter_conditions(
start_time: datetime | None,
end_time: datetime | None,
feedback_filter: QAFeedbackType | None,
) -> list[ColumnElement]:
"""
Helper function to build all filter conditions for chat sessions.
Filters by start and end time, feedback type, and any sessions without messages.
start_time: Date from which to filter
end_time: Date to which to filter
feedback_filter: Feedback type to filter by
Returns: List of filter conditions
"""
conditions = []
if start_time is not None:
conditions.append(ChatSession.time_created >= start_time)
if end_time is not None:
conditions.append(ChatSession.time_created <= end_time)
if feedback_filter is not None:
feedback_subq = (
select(ChatMessage.chat_session_id)
.join(ChatMessageFeedback)
.group_by(ChatMessage.chat_session_id)
.having(
case(
(
case(
{literal(feedback_filter == QAFeedbackType.LIKE): True},
else_=False,
),
func.bool_and(ChatMessageFeedback.is_positive),
),
(
case(
{literal(feedback_filter == QAFeedbackType.DISLIKE): True},
else_=False,
),
func.bool_and(func.not_(ChatMessageFeedback.is_positive)),
),
else_=func.bool_or(ChatMessageFeedback.is_positive)
& func.bool_or(func.not_(ChatMessageFeedback.is_positive)),
)
)
)
conditions.append(ChatSession.id.in_(feedback_subq))
return conditions
def get_total_filtered_chat_sessions_count(
db_session: Session,
start_time: datetime | None,
end_time: datetime | None,
feedback_filter: QAFeedbackType | None,
) -> int:
conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
stmt = (
select(func.count(distinct(ChatSession.id)))
.select_from(ChatSession)
.filter(*conditions)
)
return db_session.scalar(stmt) or 0
def get_page_of_chat_sessions(
start_time: datetime | None,
end_time: datetime | None,
db_session: Session,
page_num: int,
page_size: int,
feedback_filter: QAFeedbackType | None = None,
) -> Sequence[ChatSession]:
conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
subquery = (
select(ChatSession.id)
.filter(*conditions)
.order_by(desc(ChatSession.time_created), ChatSession.id)
.limit(page_size)
.offset(page_num * page_size)
.subquery()
)
stmt = (
select(ChatSession)
.join(subquery, ChatSession.id == subquery.c.id)
.outerjoin(ChatMessage, ChatSession.id == ChatMessage.chat_session_id)
.options(
joinedload(ChatSession.user),
joinedload(ChatSession.persona),
contains_eager(ChatSession.messages).joinedload(
ChatMessage.chat_message_feedbacks
),
)
.order_by(
desc(ChatSession.time_created),
ChatSession.id,
asc(ChatMessage.id), # Ensure chronological message order
)
)
return db_session.scalars(stmt).unique().all()
SortByOptions = Literal["time_sent"]
def fetch_chat_sessions_eagerly_by_time(
start: datetime,
end: datetime,
start: datetime.datetime,
end: datetime.datetime,
db_session: Session,
limit: int | None = 500,
initial_time: datetime | None = None,
initial_time: datetime.datetime | None = None,
) -> list[ChatSession]:
time_order: UnaryExpression = desc(ChatSession.time_created)
message_order: UnaryExpression = asc(ChatMessage.id)

View File

@@ -7,7 +7,6 @@ from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.constants import TokenRateLimitScope
from onyx.db.models import TokenRateLimit
from onyx.db.models import TokenRateLimit__UserGroup
@@ -21,11 +20,10 @@ from onyx.server.token_rate_limits.models import TokenRateLimitArgs
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
# If user is None, assume the user is an admin or auth is disabled
if user is None or user.role == UserRole.ADMIN:
return stmt
stmt = stmt.distinct()
TRLimit_UG = aliased(TokenRateLimit__UserGroup)
User__UG = aliased(User__UserGroup)
@@ -48,12 +46,6 @@ def _add_user_filters(
that the user isn't a curator for
- if we are not editing, we show all token_rate_limits in the groups the user curates
"""
# If user is None, this is an anonymous user and we should only show public token_rate_limits
if user is None:
where_clause = TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
return stmt.where(where_clause)
where_clause = User__UG.user_id == user.id
if user.role == UserRole.CURATOR and get_editable:
where_clause &= User__UG.is_curator == True # noqa: E712
@@ -111,10 +103,10 @@ def insert_user_group_token_rate_limit(
return token_limit
def fetch_user_group_token_rate_limits_for_user(
def fetch_user_group_token_rate_limits(
db_session: Session,
group_id: int,
user: User | None,
user: User | None = None,
enabled_only: bool = False,
ordered: bool = True,
get_editable: bool = True,

View File

@@ -122,7 +122,7 @@ def _cleanup_document_set__user_group_relationships__no_commit(
)
def validate_object_creation_for_user(
def validate_user_creation_permissions(
db_session: Session,
user: User | None,
target_group_ids: list[int] | None = None,
@@ -218,14 +218,14 @@ def fetch_user_groups_for_user(
return db_session.scalars(stmt).all()
def construct_document_id_select_by_usergroup(
def construct_document_select_by_usergroup(
user_group_id: int,
) -> Select:
"""This returns a statement that should be executed using
.yield_per() to minimize overhead. The primary consumers of this function
are background processing task generators."""
stmt = (
select(Document.id)
select(Document)
.join(
DocumentByConnectorCredentialPair,
Document.id == DocumentByConnectorCredentialPair.id,
@@ -374,9 +374,7 @@ def _add_user_group__cc_pair_relationships__no_commit(
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
db_user_group = UserGroup(
name=user_group.name, time_last_modified_by_user=func.now()
)
db_user_group = UserGroup(name=user_group.name)
db_session.add(db_user_group)
db_session.flush() # give the group an ID
@@ -442,108 +440,32 @@ def remove_curator_status__no_commit(db_session: Session, user: User) -> None:
_validate_curator_status__no_commit(db_session, [user])
def _validate_curator_relationship_update_requester(
db_session: Session,
user_group_id: int,
user_making_change: User | None = None,
) -> None:
"""
This function validates that the user making the change has the necessary permissions
to update the curator relationship for the target user in the given user group.
"""
if user_making_change is None or user_making_change.role == UserRole.ADMIN:
return
# check if the user making the change is a curator in the group they are changing the curator relationship for
user_making_change_curator_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user_making_change.id,
# only check if the user making the change is a curator if they are a curator
# otherwise, they are a global_curator and can update the curator relationship
# for any group they are a member of
only_curator_groups=user_making_change.role == UserRole.CURATOR,
)
requestor_curator_group_ids = [
group.id for group in user_making_change_curator_groups
]
if user_group_id not in requestor_curator_group_ids:
raise ValueError(
f"user making change {user_making_change.email} is not a curator,"
f" admin, or global_curator for group '{user_group_id}'"
)
def _validate_curator_relationship_update_request(
db_session: Session,
user_group_id: int,
target_user: User,
) -> None:
"""
This function validates that the curator_relationship_update request itself is valid.
"""
if target_user.role == UserRole.ADMIN:
raise ValueError(
f"User '{target_user.email}' is an admin and therefore has all permissions "
"of a curator. If you'd like this user to only have curator permissions, "
"you must update their role to BASIC then assign them to be CURATOR in the "
"appropriate groups."
)
elif target_user.role == UserRole.GLOBAL_CURATOR:
raise ValueError(
f"User '{target_user.email}' is a global_curator and therefore has all "
"permissions of a curator for all groups. If you'd like this user to only "
"have curator permissions for a specific group, you must update their role "
"to BASIC then assign them to be CURATOR in the appropriate groups."
)
elif target_user.role not in [UserRole.CURATOR, UserRole.BASIC]:
raise ValueError(
f"This endpoint can only be used to update the curator relationship for "
"users with the CURATOR or BASIC role. \n"
f"Target user: {target_user.email} \n"
f"Target user role: {target_user.role} \n"
)
# check if the target user is in the group they are changing the curator relationship for
requested_user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=target_user.id,
only_curator_groups=False,
)
group_ids = [group.id for group in requested_user_groups]
if user_group_id not in group_ids:
raise ValueError(
f"target user {target_user.email} is not in group '{user_group_id}'"
)
def update_user_curator_relationship(
db_session: Session,
user_group_id: int,
set_curator_request: SetCuratorRequest,
user_making_change: User | None = None,
) -> None:
target_user = fetch_user_by_id(db_session, set_curator_request.user_id)
if not target_user:
user = fetch_user_by_id(db_session, set_curator_request.user_id)
if not user:
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
_validate_curator_relationship_update_request(
if user.role == UserRole.ADMIN:
raise ValueError(
f"User '{user.email}' is an admin and therefore has all permissions "
"of a curator. If you'd like this user to only have curator permissions, "
"you must update their role to BASIC then assign them to be CURATOR in the "
"appropriate groups."
)
requested_user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_group_id=user_group_id,
target_user=target_user,
user_id=set_curator_request.user_id,
only_curator_groups=False,
)
_validate_curator_relationship_update_requester(
db_session=db_session,
user_group_id=user_group_id,
user_making_change=user_making_change,
)
logger.info(
f"user_making_change={user_making_change.email if user_making_change else 'None'} is "
f"updating the curator relationship for user={target_user.email} "
f"in group={user_group_id} to is_curator={set_curator_request.is_curator}"
)
group_ids = [group.id for group in requested_user_groups]
if user_group_id not in group_ids:
raise ValueError(f"user is not in group '{user_group_id}'")
relationship_to_update = (
db_session.query(User__UserGroup)
@@ -564,7 +486,7 @@ def update_user_curator_relationship(
)
db_session.add(relationship_to_update)
_validate_curator_status__no_commit(db_session, [target_user])
_validate_curator_status__no_commit(db_session, [user])
db_session.commit()
@@ -632,10 +554,6 @@ def update_user_group(
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
).unique()
_validate_curator_status__no_commit(db_session, list(removed_users))
# update "time_updated" to now
db_user_group.time_last_modified_by_user = func.now()
db_session.commit()
return db_user_group
@@ -705,10 +623,7 @@ def delete_user_group_cc_pair_relationship__no_commit(
connector_credential_pair_id matches the given cc_pair_id.
Should be used very carefully (only for connectors that are being deleted)."""
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
raise ValueError(f"Connector Credential Pair '{cc_pair_id}' does not exist")

View File

@@ -1,4 +0,0 @@
# This is a group that we use to store all the users that we found in Confluence
# Instead of setting a page to public, we just add this group so that the page
# is only accessible to users who have confluence accounts.
ALL_CONF_EMAILS_GROUP_NAME = "All_Confluence_Users_Found_By_Onyx"

View File

@@ -4,8 +4,6 @@ https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.htm
"""
from typing import Any
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.confluence.connector import ConfluenceConnector
@@ -13,7 +11,6 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -25,9 +22,7 @@ _REQUEST_PAGINATION_LIMIT = 5000
def _get_server_space_permissions(
confluence_client: OnyxConfluence, space_key: str
) -> ExternalAccess:
space_permissions = confluence_client.get_all_space_permissions_server(
space_key=space_key
)
space_permissions = confluence_client.get_space_permissions(space_key=space_key)
viewspace_permissions = []
for permission_category in space_permissions:
@@ -36,32 +31,14 @@ def _get_server_space_permissions(
permission_category.get("spacePermissions", [])
)
is_public = False
user_names = set()
group_names = set()
for permission in viewspace_permissions:
user_name = permission.get("userName")
if user_name:
if user_name := permission.get("userName"):
user_names.add(user_name)
group_name = permission.get("groupName")
if group_name:
if group_name := permission.get("groupName"):
group_names.add(group_name)
# It seems that if anonymous access is turned on for the site and space,
# then the space is publicly accessible.
# For confluence server, we make a group that contains all users
# that exist in confluence and then just add that group to the space permissions
# if anonymous access is turned on for the site and space or we set is_public = True
# if they set the env variable CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC to True so
# that we can support confluence server deployments that want anonymous access
# to be public (we cant test this because its paywalled)
if user_name is None and group_name is None:
# Defaults to False
if CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC:
is_public = True
else:
group_names.add(ALL_CONF_EMAILS_GROUP_NAME)
user_emails = set()
for user_name in user_names:
user_email = get_user_email_from_username__server(confluence_client, user_name)
@@ -70,17 +47,14 @@ def _get_server_space_permissions(
else:
logger.warning(f"Email for user {user_name} not found in Confluence")
if not user_emails and not group_names:
logger.warning(
"No user emails or group names found in Confluence space permissions"
f"\nSpace key: {space_key}"
f"\nSpace permissions: {space_permissions}"
)
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_names,
is_public=is_public,
# TODO: Check if the space is publicly accessible
# Currently, we assume the space is not public
# We need to check if anonymous access is turned on for the site and space
# This information is paywalled so it remains unimplemented
is_public=False,
)
@@ -160,7 +134,7 @@ def _get_space_permissions(
def _extract_read_access_restrictions(
confluence_client: OnyxConfluence, restrictions: dict[str, Any]
) -> tuple[set[str], set[str]]:
) -> ExternalAccess | None:
"""
Converts a page's restrictions dict into an ExternalAccess object.
If there are no restrictions, then return None
@@ -203,62 +177,25 @@ def _extract_read_access_restrictions(
group["name"] for group in read_access_group_jsons if group.get("name")
]
return set(read_access_user_emails), set(read_access_group_names)
def _get_all_page_restrictions(
confluence_client: OnyxConfluence,
perm_sync_data: dict[str, Any],
) -> ExternalAccess | None:
"""
This function gets the restrictions for a page by taking the intersection
of the page's restrictions and the restrictions of all the ancestors
of the page.
If the page/ancestor has no restrictions, then it is ignored (no intersection).
If no restrictions are found anywhere, then return None, indicating that the page
should inherit the space's restrictions.
"""
found_user_emails: set[str] = set()
found_group_names: set[str] = set()
found_user_emails, found_group_names = _extract_read_access_restrictions(
confluence_client=confluence_client,
restrictions=perm_sync_data.get("restrictions", {}),
)
ancestors: list[dict[str, Any]] = perm_sync_data.get("ancestors", [])
for ancestor in ancestors:
ancestor_user_emails, ancestor_group_names = _extract_read_access_restrictions(
confluence_client=confluence_client,
restrictions=ancestor.get("restrictions", {}),
)
if not ancestor_user_emails and not ancestor_group_names:
# This ancestor has no restrictions, so it has no effect on
# the page's restrictions, so we ignore it
continue
found_user_emails.intersection_update(ancestor_user_emails)
found_group_names.intersection_update(ancestor_group_names)
# If there are no restrictions found, then the page
# inherits the space's restrictions so return None
if not found_user_emails and not found_group_names:
is_space_public = read_access_user_emails == [] and read_access_group_names == []
if is_space_public:
return None
return ExternalAccess(
external_user_emails=found_user_emails,
external_user_group_ids=found_group_names,
external_user_emails=set(read_access_user_emails),
external_user_group_ids=set(read_access_group_names),
# there is no way for a page to be individually public if the space isn't public
is_public=False,
)
def _fetch_all_page_restrictions(
def _fetch_all_page_restrictions_for_space(
confluence_client: OnyxConfluence,
slim_docs: list[SlimDocument],
space_permissions_by_space_key: dict[str, ExternalAccess],
is_cloud: bool,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
For all pages, if a page has restrictions, then use those restrictions.
@@ -267,21 +204,15 @@ def _fetch_all_page_restrictions(
document_restrictions: list[DocExternalAccess] = []
for slim_doc in slim_docs:
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)
if slim_doc.perm_sync_data is None:
raise ValueError(
f"No permission sync data found for document {slim_doc.id}"
)
if restrictions := _get_all_page_restrictions(
restrictions = _extract_read_access_restrictions(
confluence_client=confluence_client,
perm_sync_data=slim_doc.perm_sync_data,
):
restrictions=slim_doc.perm_sync_data.get("restrictions", {}),
)
if restrictions:
document_restrictions.append(
DocExternalAccess(
doc_id=slim_doc.id,
@@ -342,7 +273,7 @@ def _fetch_all_page_restrictions(
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -365,23 +296,14 @@ def confluence_doc_sync(
slim_docs = []
logger.debug("Fetching all slim documents from confluence")
for doc_batch in confluence_connector.retrieve_all_slim_documents(
callback=callback
):
for doc_batch in confluence_connector.retrieve_all_slim_documents():
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync", 1)
slim_docs.extend(doc_batch)
logger.debug("Fetching all page restrictions for space")
return _fetch_all_page_restrictions(
return _fetch_all_page_restrictions_for_space(
confluence_client=confluence_connector.confluence_client,
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,
is_cloud=is_cloud,
callback=callback,
)

View File

@@ -1,11 +1,11 @@
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -14,8 +14,6 @@ def _build_group_member_email_map(
) -> dict[str, set[str]]:
group_member_emails: dict[str, set[str]] = {}
for user_result in confluence_client.paginated_cql_user_retrieval():
logger.debug(f"Processing groups for user: {user_result}")
user = user_result.get("user", {})
if not user:
logger.warning(f"user result missing user field: {user_result}")
@@ -32,20 +30,12 @@ def _build_group_member_email_map(
)
if not email:
# If we still don't have an email, skip this user
logger.warning(f"user result missing email field: {user_result}")
continue
all_users_groups: set[str] = set()
for group in confluence_client.paginated_groups_by_user_retrieval(user):
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
group_id = group["name"]
group_member_emails.setdefault(group_id, set()).add(email)
all_users_groups.add(group_id)
if not group_member_emails:
logger.warning(f"No groups found for user with email: {email}")
else:
logger.debug(f"Found groups {all_users_groups} for user with email {email}")
return group_member_emails
@@ -63,7 +53,6 @@ def confluence_group_sync(
confluence_client=confluence_client,
)
onyx_groups: list[ExternalUserGroup] = []
all_found_emails = set()
for group_id, group_member_emails in group_member_email_map.items():
onyx_groups.append(
ExternalUserGroup(
@@ -71,15 +60,5 @@ def confluence_group_sync(
user_emails=list(group_member_emails),
)
)
all_found_emails.update(group_member_emails)
# This is so that when we find a public confleunce server page, we can
# give access to all users only in if they have an email in Confluence
if cc_pair.connector.connector_specific_config.get("is_cloud", False):
all_found_group = ExternalUserGroup(
id=ALL_CONF_EMAILS_GROUP_NAME,
user_emails=list(all_found_emails),
)
onyx_groups.append(all_found_group)
return onyx_groups

View File

@@ -6,7 +6,6 @@ from onyx.access.models import ExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -15,7 +14,6 @@ logger = setup_logger()
def _get_slim_doc_generator(
cc_pair: ConnectorCredentialPair,
gmail_connector: GmailConnector,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
current_time = datetime.now(timezone.utc)
start_time = (
@@ -25,14 +23,12 @@ def _get_slim_doc_generator(
)
return gmail_connector.retrieve_all_slim_documents(
start=start_time,
end=current_time.timestamp(),
callback=callback,
start=start_time, end=current_time.timestamp()
)
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -43,19 +39,11 @@ def gmail_doc_sync(
gmail_connector = GmailConnector(**cc_pair.connector.connector_specific_config)
gmail_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = _get_slim_doc_generator(
cc_pair, gmail_connector, callback=callback
)
slim_doc_generator = _get_slim_doc_generator(cc_pair, gmail_connector)
document_external_access: list[DocExternalAccess] = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
if callback.should_stop():
raise RuntimeError("gmail_doc_sync: Stop signal detected")
callback.progress("gmail_doc_sync", 1)
if slim_doc.perm_sync_data is None:
logger.warning(f"No permissions found for document {slim_doc.id}")
continue

View File

@@ -10,7 +10,6 @@ from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -21,7 +20,6 @@ _PERMISSION_ID_PERMISSION_MAP: dict[str, dict[str, Any]] = {}
def _get_slim_doc_generator(
cc_pair: ConnectorCredentialPair,
google_drive_connector: GoogleDriveConnector,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
current_time = datetime.now(timezone.utc)
start_time = (
@@ -31,9 +29,7 @@ def _get_slim_doc_generator(
)
return google_drive_connector.retrieve_all_slim_documents(
start=start_time,
end=current_time.timestamp(),
callback=callback,
start=start_time, end=current_time.timestamp()
)
@@ -46,22 +42,24 @@ def _fetch_permissions_for_permission_ids(
if not permission_info or not doc_id:
return []
# Check cache first for all permission IDs
permissions = [
_PERMISSION_ID_PERMISSION_MAP[pid]
for pid in permission_ids
if pid in _PERMISSION_ID_PERMISSION_MAP
]
# If we found all permissions in cache, return them
if len(permissions) == len(permission_ids):
return permissions
owner_email = permission_info.get("owner_email")
drive_service = get_drive_service(
creds=google_drive_connector.creds,
user_email=(owner_email or google_drive_connector.primary_admin_email),
)
# Otherwise, fetch all permissions and update cache
fetched_permissions = execute_paginated_retrieval(
retrieval_function=drive_service.permissions().list,
list_key="permissions",
@@ -71,6 +69,7 @@ def _fetch_permissions_for_permission_ids(
)
permissions_for_doc_id = []
# Update cache and return all permissions
for permission in fetched_permissions:
permissions_for_doc_id.append(permission)
_PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission
@@ -121,18 +120,15 @@ def _get_permissions_from_slim_doc(
elif permission_type == "anyone":
public = True
drive_id = permission_info.get("drive_id")
group_ids = group_emails | ({drive_id} if drive_id is not None else set())
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_ids,
external_user_group_ids=group_emails,
is_public=public,
)
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -150,12 +146,6 @@ def gdrive_doc_sync(
document_external_accesses = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
if callback.should_stop():
raise RuntimeError("gdrive_doc_sync: Stop signal detected")
callback.progress("gdrive_doc_sync", 1)
ext_access = _get_permissions_from_slim_doc(
google_drive_connector=google_drive_connector,
slim_doc=slim_doc,

View File

@@ -1,127 +1,16 @@
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.resources import AdminService
from onyx.connectors.google_utils.resources import get_admin_service
from onyx.connectors.google_utils.resources import get_drive_service
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _get_drive_members(
google_drive_connector: GoogleDriveConnector,
) -> dict[str, tuple[set[str], set[str]]]:
"""
This builds a map of drive ids to their members (group and user emails).
E.g. {
"drive_id_1": ({"group_email_1"}, {"user_email_1", "user_email_2"}),
"drive_id_2": ({"group_email_3"}, {"user_email_3"}),
}
"""
drive_ids = google_drive_connector.get_all_drive_ids()
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]] = {}
drive_service = get_drive_service(
google_drive_connector.creds,
google_drive_connector.primary_admin_email,
)
for drive_id in drive_ids:
group_emails: set[str] = set()
user_emails: set[str] = set()
for permission in execute_paginated_retrieval(
drive_service.permissions().list,
list_key="permissions",
fileId=drive_id,
fields="permissions(emailAddress, type)",
supportsAllDrives=True,
):
if permission["type"] == "group":
group_emails.add(permission["emailAddress"])
elif permission["type"] == "user":
user_emails.add(permission["emailAddress"])
drive_id_to_members_map[drive_id] = (group_emails, user_emails)
return drive_id_to_members_map
def _get_all_groups(
admin_service: AdminService,
google_domain: str,
) -> set[str]:
"""
This gets all the group emails.
"""
group_emails: set[str] = set()
for group in execute_paginated_retrieval(
admin_service.groups().list,
list_key="groups",
domain=google_domain,
fields="groups(email)",
):
group_emails.add(group["email"])
return group_emails
def _map_group_email_to_member_emails(
admin_service: AdminService,
group_emails: set[str],
) -> dict[str, set[str]]:
"""
This maps group emails to their member emails.
"""
group_to_member_map: dict[str, set[str]] = {}
for group_email in group_emails:
group_member_emails: set[str] = set()
for member in execute_paginated_retrieval(
admin_service.members().list,
list_key="members",
groupKey=group_email,
fields="members(email)",
):
group_member_emails.add(member["email"])
group_to_member_map[group_email] = group_member_emails
return group_to_member_map
def _build_onyx_groups(
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]],
group_email_to_member_emails_map: dict[str, set[str]],
) -> list[ExternalUserGroup]:
onyx_groups: list[ExternalUserGroup] = []
# Convert all drive member definitions to onyx groups
# This is because having drive level access means you have
# irrevocable access to all the files in the drive.
for drive_id, (group_emails, user_emails) in drive_id_to_members_map.items():
all_member_emails: set[str] = user_emails
for group_email in group_emails:
all_member_emails.update(group_email_to_member_emails_map[group_email])
onyx_groups.append(
ExternalUserGroup(
id=drive_id,
user_emails=list(all_member_emails),
)
)
# Convert all group member definitions to onyx groups
for group_email, member_emails in group_email_to_member_emails_map.items():
onyx_groups.append(
ExternalUserGroup(
id=group_email,
user_emails=list(member_emails),
)
)
return onyx_groups
def gdrive_group_sync(
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
# Initialize connector and build credential/service objects
google_drive_connector = GoogleDriveConnector(
**cc_pair.connector.connector_specific_config
)
@@ -130,23 +19,34 @@ def gdrive_group_sync(
google_drive_connector.creds, google_drive_connector.primary_admin_email
)
# Get all drive members
drive_id_to_members_map = _get_drive_members(google_drive_connector)
onyx_groups: list[ExternalUserGroup] = []
for group in execute_paginated_retrieval(
admin_service.groups().list,
list_key="groups",
domain=google_drive_connector.google_domain,
fields="groups(email)",
):
# The id is the group email
group_email = group["email"]
# Get all group emails
all_group_emails = _get_all_groups(
admin_service, google_drive_connector.google_domain
)
# Gather group member emails
group_member_emails: list[str] = []
for member in execute_paginated_retrieval(
admin_service.members().list,
list_key="members",
groupKey=group_email,
fields="members(email)",
):
group_member_emails.append(member["email"])
# Map group emails to their members
group_email_to_member_emails_map = _map_group_email_to_member_emails(
admin_service, all_group_emails
)
if not group_member_emails:
continue
# Convert the maps to onyx groups
onyx_groups = _build_onyx_groups(
drive_id_to_members_map=drive_id_to_members_map,
group_email_to_member_emails_map=group_email_to_member_emails_map,
)
onyx_groups.append(
ExternalUserGroup(
id=group_email,
user_emails=list(group_member_emails),
)
)
return onyx_groups

View File

@@ -1,84 +0,0 @@
from collections.abc import Callable
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.external_permissions.salesforce.postprocessing import (
censor_salesforce_chunks,
)
from onyx.configs.constants import DocumentSource
from onyx.context.search.pipeline import InferenceChunk
from onyx.db.engine import get_session_context_manager
from onyx.db.models import User
from onyx.utils.logger import setup_logger
logger = setup_logger()
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION: dict[
DocumentSource,
# list of chunks to be censored and the user email. returns censored chunks
Callable[[list[InferenceChunk], str], list[InferenceChunk]],
] = {
DocumentSource.SALESFORCE: censor_salesforce_chunks,
}
def _get_all_censoring_enabled_sources() -> set[DocumentSource]:
"""
Returns the set of sources that have censoring enabled.
This is based on if the access_type is set to sync and the connector
source is included in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION.
NOTE: This means if there is a source has a single cc_pair that is sync,
all chunks for that source will be censored, even if the connector that
indexed that chunk is not sync. This was done to avoid getting the cc_pair
for every single chunk.
"""
with get_session_context_manager() as db_session:
enabled_sync_connectors = get_all_auto_sync_cc_pairs(db_session)
return {
cc_pair.connector.source
for cc_pair in enabled_sync_connectors
if cc_pair.connector.source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
}
# NOTE: This is only called if ee is enabled.
def _post_query_chunk_censoring(
chunks: list[InferenceChunk],
user: User | None,
) -> list[InferenceChunk]:
"""
This function checks all chunks to see if they need to be sent to a censoring
function. If they do, it sends them to the censoring function and returns the
censored chunks. If they don't, it returns the original chunks.
"""
if user is None:
# if user is None, permissions are not enforced
return chunks
chunks_to_keep = []
chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}
sources_to_censor = _get_all_censoring_enabled_sources()
for chunk in chunks:
# Separate out chunks that require permission post-processing by source
if chunk.source_type in sources_to_censor:
chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
else:
chunks_to_keep.append(chunk)
# For each source, filter out the chunks using the permission
# check function for that source
# TODO: Use a threadpool/multiprocessing to process the sources in parallel
for source, chunks_for_source in chunks_to_process.items():
censor_chunks_for_source = DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION[source]
try:
censored_chunks = censor_chunks_for_source(chunks_for_source, user.email)
except Exception as e:
logger.exception(
f"Failed to censor chunks for source {source} so throwing out all"
f" chunks for this source and continuing: {e}"
)
continue
chunks_to_keep.extend(censored_chunks)
return chunks_to_keep

View File

@@ -1,226 +0,0 @@
import time
from ee.onyx.db.external_perm import fetch_external_groups_for_user_email_and_group_ids
from ee.onyx.external_permissions.salesforce.utils import (
get_any_salesforce_client_for_doc_id,
)
from ee.onyx.external_permissions.salesforce.utils import get_objects_access_for_user_id
from ee.onyx.external_permissions.salesforce.utils import (
get_salesforce_user_id_from_email,
)
from onyx.configs.app_configs import BLURB_SIZE
from onyx.context.search.models import InferenceChunk
from onyx.db.engine import get_session_context_manager
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Types
ChunkKey = tuple[str, int] # (doc_id, chunk_id)
ContentRange = tuple[int, int | None] # (start_index, end_index) None means to the end
# NOTE: Used for testing timing
def _get_dummy_object_access_map(
object_ids: set[str], user_email: str, chunks: list[InferenceChunk]
) -> dict[str, bool]:
time.sleep(0.15)
# return {object_id: True for object_id in object_ids}
import random
return {object_id: random.choice([True, False]) for object_id in object_ids}
def _get_objects_access_for_user_email_from_salesforce(
object_ids: set[str],
user_email: str,
chunks: list[InferenceChunk],
) -> dict[str, bool] | None:
"""
This function wraps the salesforce call as we may want to change how this
is done in the future. (E.g. replace it with the above function)
"""
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
# but subsequent queries for this source are essentially instant
first_doc_id = chunks[0].document_id
with get_session_context_manager() as db_session:
salesforce_client = get_any_salesforce_client_for_doc_id(
db_session, first_doc_id
)
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
# but subsequent queries by the same user are essentially instant
start_time = time.time()
user_id = get_salesforce_user_id_from_email(salesforce_client, user_email)
end_time = time.time()
logger.info(
f"Time taken to get Salesforce user ID: {end_time - start_time} seconds"
)
if user_id is None:
return None
# This is the only query that is not cached in the function
# so it takes 0.1-0.2 seconds total
object_id_to_access = get_objects_access_for_user_id(
salesforce_client, user_id, list(object_ids)
)
return object_id_to_access
def _extract_salesforce_object_id_from_url(url: str) -> str:
return url.split("/")[-1]
def _get_object_ranges_for_chunk(
chunk: InferenceChunk,
) -> dict[str, list[ContentRange]]:
"""
Given a chunk, return a dictionary of salesforce object ids and the content ranges
for that object id in the current chunk
"""
if chunk.source_links is None:
return {}
object_ranges: dict[str, list[ContentRange]] = {}
end_index = None
descending_source_links = sorted(
chunk.source_links.items(), key=lambda x: x[0], reverse=True
)
for start_index, url in descending_source_links:
object_id = _extract_salesforce_object_id_from_url(url)
if object_id not in object_ranges:
object_ranges[object_id] = []
object_ranges[object_id].append((start_index, end_index))
end_index = start_index
return object_ranges
def _create_empty_censored_chunk(uncensored_chunk: InferenceChunk) -> InferenceChunk:
"""
Create a copy of the unfiltered chunk where potentially sensitive content is removed
to be added later if the user has access to each of the sub-objects
"""
empty_censored_chunk = InferenceChunk(
**uncensored_chunk.model_dump(),
)
empty_censored_chunk.content = ""
empty_censored_chunk.blurb = ""
empty_censored_chunk.source_links = {}
return empty_censored_chunk
def _update_censored_chunk(
censored_chunk: InferenceChunk,
uncensored_chunk: InferenceChunk,
content_range: ContentRange,
) -> InferenceChunk:
"""
Update the filtered chunk with the content and source links from the unfiltered chunk using the content ranges
"""
start_index, end_index = content_range
# Update the content of the filtered chunk
permitted_content = uncensored_chunk.content[start_index:end_index]
permitted_section_start_index = len(censored_chunk.content)
censored_chunk.content = permitted_content + censored_chunk.content
# Update the source links of the filtered chunk
if uncensored_chunk.source_links is not None:
if censored_chunk.source_links is None:
censored_chunk.source_links = {}
link_content = uncensored_chunk.source_links[start_index]
censored_chunk.source_links[permitted_section_start_index] = link_content
# Update the blurb of the filtered chunk
censored_chunk.blurb = censored_chunk.content[:BLURB_SIZE]
return censored_chunk
# TODO: Generalize this to other sources
def censor_salesforce_chunks(
chunks: list[InferenceChunk],
user_email: str,
# This is so we can provide a mock access map for testing
access_map: dict[str, bool] | None = None,
) -> list[InferenceChunk]:
# object_id -> list[((doc_id, chunk_id), (start_index, end_index))]
object_to_content_map: dict[str, list[tuple[ChunkKey, ContentRange]]] = {}
# (doc_id, chunk_id) -> chunk
uncensored_chunks: dict[ChunkKey, InferenceChunk] = {}
# keep track of all object ids that we have seen to make it easier to get
# the access for these object ids
object_ids: set[str] = set()
for chunk in chunks:
chunk_key = (chunk.document_id, chunk.chunk_id)
# create a dictionary to quickly look up the unfiltered chunk
uncensored_chunks[chunk_key] = chunk
# for each chunk, get a dictionary of object ids and the content ranges
# for that object id in the current chunk
object_ranges_for_chunk = _get_object_ranges_for_chunk(chunk)
for object_id, ranges in object_ranges_for_chunk.items():
object_ids.add(object_id)
for start_index, end_index in ranges:
object_to_content_map.setdefault(object_id, []).append(
(chunk_key, (start_index, end_index))
)
# This is so we can provide a mock access map for testing
if access_map is None:
access_map = _get_objects_access_for_user_email_from_salesforce(
object_ids=object_ids,
user_email=user_email,
chunks=chunks,
)
if access_map is None:
# If the user is not found in Salesforce, access_map will be None
# so we should just return an empty list because no chunks will be
# censored
return []
censored_chunks: dict[ChunkKey, InferenceChunk] = {}
for object_id, content_list in object_to_content_map.items():
# if the user does not have access to the object, or the object is not in the
# access_map, do not include its content in the filtered chunks
if not access_map.get(object_id, False):
continue
# if we got this far, the user has access to the object so we can create or update
# the filtered chunk(s) for this object
# NOTE: we only create a censored chunk if the user has access to some
# part of the chunk
for chunk_key, content_range in content_list:
if chunk_key not in censored_chunks:
censored_chunks[chunk_key] = _create_empty_censored_chunk(
uncensored_chunks[chunk_key]
)
uncensored_chunk = uncensored_chunks[chunk_key]
censored_chunk = _update_censored_chunk(
censored_chunk=censored_chunks[chunk_key],
uncensored_chunk=uncensored_chunk,
content_range=content_range,
)
censored_chunks[chunk_key] = censored_chunk
return list(censored_chunks.values())
# NOTE: This is not used anywhere.
def _get_objects_access_for_user_email(
object_ids: set[str], user_email: str
) -> dict[str, bool]:
with get_session_context_manager() as db_session:
external_groups = fetch_external_groups_for_user_email_and_group_ids(
db_session=db_session,
user_email=user_email,
# Maybe make a function that adds a salesforce prefix to the group ids
group_ids=list(object_ids),
)
external_group_ids = {group.external_user_group_id for group in external_groups}
return {group_id: group_id in external_group_ids for group_id in object_ids}

View File

@@ -1,177 +0,0 @@
from simple_salesforce import Salesforce
from sqlalchemy.orm import Session
from onyx.connectors.salesforce.sqlite_functions import get_user_id_by_email
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import NULL_ID_STRING
from onyx.connectors.salesforce.sqlite_functions import update_email_to_id_table
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import get_cc_pairs_for_document
from onyx.utils.logger import setup_logger
logger = setup_logger()
_ANY_SALESFORCE_CLIENT: Salesforce | None = None
def get_any_salesforce_client_for_doc_id(
db_session: Session, doc_id: str
) -> Salesforce:
"""
We create a salesforce client for the first cc_pair for the first doc_id where
salesforce censoring is enabled. After that we just cache and reuse the same
client for all queries.
We do this to reduce the number of postgres queries we make at query time.
This may be problematic if they are using multiple cc_pairs for salesforce.
E.g. there are 2 different credential sets for 2 different salesforce cc_pairs
but only one has the permissions to access the permissions needed for the query.
"""
global _ANY_SALESFORCE_CLIENT
if _ANY_SALESFORCE_CLIENT is None:
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
first_cc_pair = cc_pairs[0]
credential_json = first_cc_pair.credential.credential_json
_ANY_SALESFORCE_CLIENT = Salesforce(
username=credential_json["sf_username"],
password=credential_json["sf_password"],
security_token=credential_json["sf_security_token"],
)
return _ANY_SALESFORCE_CLIENT
def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None:
query = f"SELECT Id FROM User WHERE Email = '{user_email}'"
result = sf_client.query(query)
if len(result["records"]) == 0:
return None
return result["records"][0]["Id"]
# This contains only the user_ids that we have found in Salesforce.
# If we don't know their user_id, we don't store anything in the cache.
_CACHED_SF_EMAIL_TO_ID_MAP: dict[str, str] = {}
def get_salesforce_user_id_from_email(
sf_client: Salesforce,
user_email: str,
) -> str | None:
"""
We cache this so we don't have to query Salesforce for every query and salesforce
user IDs never change.
Memory usage is fine because we just store 2 small strings per user.
If the email is not in the cache, we check the local salesforce database for the info.
If the user is not found in the local salesforce database, we query Salesforce.
Whatever we get back from Salesforce is added to the database.
If no user_id is found, we add a NULL_ID_STRING to the database for that email so
we don't query Salesforce again (which is slow) but we still check the local salesforce
database every query until a user id is found. This is acceptable because the query time
is quite fast.
If a user_id is created in Salesforce, it will be added to the local salesforce database
next time the connector is run. Then that value will be found in this function and cached.
NOTE: First time this runs, it may be slow if it hasn't already been updated in the local
salesforce database. (Around 0.1-0.3 seconds)
If it's cached or stored in the local salesforce database, it's fast (<0.001 seconds).
"""
global _CACHED_SF_EMAIL_TO_ID_MAP
if user_email in _CACHED_SF_EMAIL_TO_ID_MAP:
if _CACHED_SF_EMAIL_TO_ID_MAP[user_email] is not None:
return _CACHED_SF_EMAIL_TO_ID_MAP[user_email]
db_exists = True
try:
# Check if the user is already in the database
user_id = get_user_id_by_email(user_email)
except Exception:
init_db()
try:
user_id = get_user_id_by_email(user_email)
except Exception as e:
logger.error(f"Error checking if user is in database: {e}")
user_id = None
db_exists = False
# If no entry is found in the database (indicated by user_id being None)...
if user_id is None:
# ...query Salesforce and store the result in the database
user_id = _query_salesforce_user_id(sf_client, user_email)
if db_exists:
update_email_to_id_table(user_email, user_id)
return user_id
elif user_id is None:
return None
elif user_id == NULL_ID_STRING:
return None
# If the found user_id is real, cache it
_CACHED_SF_EMAIL_TO_ID_MAP[user_email] = user_id
return user_id
_MAX_RECORD_IDS_PER_QUERY = 200
def get_objects_access_for_user_id(
salesforce_client: Salesforce,
user_id: str,
record_ids: list[str],
) -> dict[str, bool]:
"""
Salesforce has a limit of 200 record ids per query. So we just truncate
the list of record ids to 200. We only ever retrieve 50 chunks at a time
so this should be fine (unlikely that we retrieve all 50 chunks contain
4 unique objects).
If we decide this isn't acceptable we can use multiple queries but they
should be in parallel so query time doesn't get too long.
"""
truncated_record_ids = record_ids[:_MAX_RECORD_IDS_PER_QUERY]
record_ids_str = "'" + "','".join(truncated_record_ids) + "'"
access_query = f"""
SELECT RecordId, HasReadAccess
FROM UserRecordAccess
WHERE RecordId IN ({record_ids_str})
AND UserId = '{user_id}'
"""
result = salesforce_client.query_all(access_query)
return {record["RecordId"]: record["HasReadAccess"] for record in result["records"]}
_CC_PAIR_ID_SALESFORCE_CLIENT_MAP: dict[int, Salesforce] = {}
_DOC_ID_TO_CC_PAIR_ID_MAP: dict[str, int] = {}
# NOTE: This is not used anywhere.
def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Salesforce:
"""
Uses a document id to get the cc_pair that indexed that document and uses the credentials
for that cc_pair to create a Salesforce client.
Problems:
- There may be multiple cc_pairs for a document, and we don't know which one to use.
- right now we just use the first one
- Building a new Salesforce client for each document is slow.
- Memory usage could be an issue as we build these dictionaries.
"""
if doc_id not in _DOC_ID_TO_CC_PAIR_ID_MAP:
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
first_cc_pair = cc_pairs[0]
_DOC_ID_TO_CC_PAIR_ID_MAP[doc_id] = first_cc_pair.id
cc_pair_id = _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id]
if cc_pair_id not in _CC_PAIR_ID_SALESFORCE_CLIENT_MAP:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if cc_pair is None:
raise ValueError(f"CC pair {cc_pair_id} not found")
credential_json = cc_pair.credential.credential_json
_CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id] = Salesforce(
username=credential_json["sf_username"],
password=credential_json["sf_password"],
security_token=credential_json["sf_security_token"],
)
return _CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id]

View File

@@ -7,7 +7,6 @@ from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -15,12 +14,12 @@ logger = setup_logger()
def _get_slack_document_ids_and_channels(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> dict[str, list[str]]:
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
slim_doc_generator = slack_connector.retrieve_all_slim_documents()
channel_doc_map: dict[str, list[str]] = {}
for doc_metadata_batch in slim_doc_generator:
@@ -32,14 +31,6 @@ def _get_slack_document_ids_and_channels(
channel_doc_map[channel_id] = []
channel_doc_map[channel_id].append(doc_metadata.id)
if callback:
if callback.should_stop():
raise RuntimeError(
"_get_slack_document_ids_and_channels: Stop signal detected"
)
callback.progress("_get_slack_document_ids_and_channels", 1)
return channel_doc_map
@@ -123,7 +114,7 @@ def _fetch_channel_permissions(
def slack_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -136,7 +127,7 @@ def slack_doc_sync(
)
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
channel_doc_map = _get_slack_document_ids_and_channels(
cc_pair=cc_pair, callback=callback
cc_pair=cc_pair,
)
workspace_permissions = _fetch_workspace_permissions(
user_id_to_email_map=user_id_to_email_map,

View File

@@ -8,20 +8,15 @@ from ee.onyx.external_permissions.confluence.group_sync import confluence_group_
from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync
from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
from ee.onyx.external_permissions.post_query_censoring import (
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
# Defining the input/output types for the sync functions
DocSyncFuncType = Callable[
[
ConnectorCredentialPair,
IndexingHeartbeatInterface | None,
],
list[DocExternalAccess],
]
@@ -76,7 +71,4 @@ EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
return (
source_type in DOC_PERMISSIONS_FUNC_MAP
or source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
)
return source_type in DOC_PERMISSIONS_FUNC_MAP

View File

@@ -1,9 +1,7 @@
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import BASE_SCOPES
from httpx_oauth.clients.openid import OpenID
from ee.onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE
from ee.onyx.configs.app_configs import OPENID_CONFIG_URL
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
@@ -42,7 +40,6 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.main import get_application as get_application_base
from onyx.main import include_auth_router_with_prefix
from onyx.main import include_router_with_global_prefix_prepended
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
@@ -65,7 +62,7 @@ def get_application() -> FastAPI:
if AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_auth_router_with_prefix(
include_router_with_global_prefix_prepended(
application,
create_onyx_oauth_router(
oauth_client,
@@ -77,26 +74,22 @@ def get_application() -> FastAPI:
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
),
prefix="/auth/oauth",
tags=["auth"],
)
# Need basic auth router for `logout` endpoint
include_auth_router_with_prefix(
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_logout_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
if AUTH_TYPE == AuthType.OIDC:
include_auth_router_with_prefix(
include_router_with_global_prefix_prepended(
application,
create_onyx_oauth_router(
OpenID(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
OPENID_CONFIG_URL,
# BASE_SCOPES is the same as not setting this
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
),
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
auth_backend,
USER_AUTH_SECRET,
associate_by_email=True,
@@ -104,20 +97,19 @@ def get_application() -> FastAPI:
redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
),
prefix="/auth/oidc",
tags=["auth"],
)
# need basic auth router for `logout` endpoint
include_auth_router_with_prefix(
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_auth_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
elif AUTH_TYPE == AuthType.SAML:
include_auth_router_with_prefix(
application,
saml_router,
)
include_router_with_global_prefix_prepended(application, saml_router)
# RBAC / group access control
include_router_with_global_prefix_prepended(application, user_group_router)

View File

@@ -80,7 +80,7 @@ def oneoff_standard_answers(
def _handle_standard_answers(
message_info: SlackMessageInfo,
receiver_ids: list[str] | None,
slack_channel_config: SlackChannelConfig,
slack_channel_config: SlackChannelConfig | None,
prompt: Prompt | None,
logger: OnyxLoggingAdapter,
client: WebClient,
@@ -94,10 +94,13 @@ def _handle_standard_answers(
Returns True if standard answers are found to match the user's message and therefore,
we still need to respond to the users.
"""
# if no channel config, then no standard answers are configured
if not slack_channel_config:
return False
slack_thread_id = message_info.thread_to_respond
configured_standard_answer_categories = (
slack_channel_config.standard_answer_categories
slack_channel_config.standard_answer_categories if slack_channel_config else []
)
configured_standard_answers = set(
[
@@ -147,9 +150,9 @@ def _handle_standard_answers(
db_session=db_session,
description="",
user_id=None,
persona_id=(
slack_channel_config.persona.id if slack_channel_config.persona else 0
),
persona_id=slack_channel_config.persona.id
if slack_channel_config.persona
else 0,
onyxbot_flow=True,
slack_thread_id=slack_thread_id,
)
@@ -179,7 +182,7 @@ def _handle_standard_answers(
formatted_answers.append(formatted_answer)
answer_message = "\n\n".join(formatted_answers)
chat_message = create_new_chat_message(
_ = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=new_user_message,
prompt_id=prompt.id if prompt else None,
@@ -188,13 +191,8 @@ def _handle_standard_answers(
message_type=MessageType.ASSISTANT,
error=None,
db_session=db_session,
commit=False,
commit=True,
)
# attach the standard answers to the chat message
chat_message.standard_answers = [
standard_answer for standard_answer, _ in matching_standard_answers
]
db_session.commit()
update_emote_react(
emoji=DANSWER_REACT_EMOJI,

View File

@@ -1,24 +1,17 @@
import datetime
from collections import defaultdict
from typing import List
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.db.analytics import fetch_assistant_message_analytics
from ee.onyx.db.analytics import fetch_assistant_unique_users
from ee.onyx.db.analytics import fetch_assistant_unique_users_total
from ee.onyx.db.analytics import fetch_onyxbot_analytics
from ee.onyx.db.analytics import fetch_per_user_query_analytics
from ee.onyx.db.analytics import fetch_persona_message_analytics
from ee.onyx.db.analytics import fetch_persona_unique_users
from ee.onyx.db.analytics import fetch_query_analytics
from ee.onyx.db.analytics import user_can_view_assistant_stats
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.db.engine import get_session
from onyx.db.models import User
@@ -198,74 +191,3 @@ def get_persona_unique_users(
)
)
return unique_user_counts
class AssistantDailyUsageResponse(BaseModel):
date: datetime.date
total_messages: int
total_unique_users: int
class AssistantStatsResponse(BaseModel):
daily_stats: List[AssistantDailyUsageResponse]
total_messages: int
total_unique_users: int
@router.get("/assistant/{assistant_id}/stats")
def get_assistant_stats(
assistant_id: int,
start: datetime.datetime | None = None,
end: datetime.datetime | None = None,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> AssistantStatsResponse:
"""
Returns daily message and unique user counts for a user's assistant,
along with the overall total messages and total distinct users.
"""
start = start or (
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
)
end = end or datetime.datetime.utcnow()
if not user_can_view_assistant_stats(db_session, user, assistant_id):
raise HTTPException(
status_code=403, detail="Not allowed to access this assistant's stats."
)
# Pull daily usage from the DB calls
messages_data = fetch_assistant_message_analytics(
db_session, assistant_id, start, end
)
unique_users_data = fetch_assistant_unique_users(
db_session, assistant_id, start, end
)
# Map each day => (messages, unique_users).
daily_messages_map = {date: count for count, date in messages_data}
daily_unique_users_map = {date: count for count, date in unique_users_data}
all_dates = set(daily_messages_map.keys()) | set(daily_unique_users_map.keys())
# Merge both sets of metrics by date
daily_results: list[AssistantDailyUsageResponse] = []
for date in sorted(all_dates):
daily_results.append(
AssistantDailyUsageResponse(
date=date,
total_messages=daily_messages_map.get(date, 0),
total_unique_users=daily_unique_users_map.get(date, 0),
)
)
# Now pull a single total distinct user count across the entire time range
total_msgs = sum(d.total_messages for d in daily_results)
total_users = fetch_assistant_unique_users_total(
db_session, assistant_id, start, end
)
return AssistantStatsResponse(
daily_stats=daily_results,
total_messages=total_msgs,
total_unique_users=total_users,
)

View File

@@ -2,17 +2,15 @@ import logging
from collections.abc import Awaitable
from collections.abc import Callable
import jwt
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from ee.onyx.auth.users import decode_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from onyx.auth.api_key import extract_tenant_from_api_key_header
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.db.engine import is_valid_schema_name
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
@@ -24,11 +22,11 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
try:
if MULTI_TENANT:
tenant_id = await _get_tenant_id_from_request(request, logger)
else:
tenant_id = POSTGRES_DEFAULT_SCHEMA
tenant_id = (
_get_tenant_id_from_request(request, logger)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return await call_next(request)
@@ -37,48 +35,27 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
raise
async def _get_tenant_id_from_request(
request: Request, logger: logging.LoggerAdapter
) -> str:
"""
Attempt to extract tenant_id from:
1) The API key header
2) The Redis-based token (stored in Cookie: fastapiusersauth)
3) Reset token cookie
Fallback: POSTGRES_DEFAULT_SCHEMA
"""
# Check for API key
def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) -> str:
# First check for API key
tenant_id = extract_tenant_from_api_key_header(request)
if tenant_id:
if tenant_id is not None:
return tenant_id
# Check for anonymous user cookie
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
if anonymous_user_cookie:
try:
anonymous_user_data = decode_anonymous_user_jwt_token(anonymous_user_cookie)
return anonymous_user_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
except Exception as e:
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
# Continue and attempt to authenticate
# Check for cookie-based auth
token = request.cookies.get("fastapiusersauth")
if not token:
return POSTGRES_DEFAULT_SCHEMA
try:
# Look up token data in Redis
payload = jwt.decode(
token,
USER_AUTH_SECRET,
audience=["fastapi-users:auth"],
algorithms=["HS256"],
)
tenant_id_from_payload = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
token_data = await retrieve_auth_token_data_from_redis(request)
if not token_data:
logger.debug(
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
)
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
# so we maintain consistency by returning it here when no valid tenant is found.
return POSTGRES_DEFAULT_SCHEMA
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
# Since token_data.get() can return None, ensure we have a string
# Since payload.get() can return None, ensure we have a string
tenant_id = (
str(tenant_id_from_payload)
if tenant_id_from_payload is not None
@@ -88,18 +65,11 @@ async def _get_tenant_id_from_request(
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
except Exception as e:
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
return tenant_id
finally:
if tenant_id:
return tenant_id
# As a final step, check for explicit tenant_id cookie
tenant_id_cookie = request.cookies.get(TENANT_ID_COOKIE_NAME)
if tenant_id_cookie and is_valid_schema_name(tenant_id_cookie):
return tenant_id_cookie
# If we've reached this point, return the default schema
except jwt.InvalidTokenError:
return POSTGRES_DEFAULT_SCHEMA
except Exception as e:
logger.error(f"Unexpected error in set_tenant_id_middleware: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -1,7 +1,5 @@
import base64
import json
import uuid
from typing import Any
from typing import cast
import requests
@@ -12,29 +10,11 @@ from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from onyx.auth.users import current_user
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
@@ -82,7 +62,14 @@ class SlackOAuth:
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={cls.REDIRECT_URI}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
@@ -90,14 +77,10 @@ class SlackOAuth:
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&redirect_uri={cls.DEV_REDIRECT_URI}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
@@ -119,151 +102,82 @@ class SlackOAuth:
return session
class ConfluenceCloudOAuth:
"""work in progress"""
# Work in progress
# class ConfluenceCloudOAuth:
# """work in progress"""
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
# # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
# class OAuthSession(BaseModel):
# """Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
# email: str
# redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
# CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
# CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
# TOKEN_URL = "https://auth.atlassian.com/oauth/token"
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
CONFLUENCE_OAUTH_SCOPE = (
"read:confluence-props%20"
"read:confluence-content.all%20"
"read:confluence-content.summary%20"
"read:confluence-content.permission%20"
"read:confluence-user%20"
"read:confluence-groups%20"
"readonly:content.attachment:confluence"
)
# # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
# CONFLUENCE_OAUTH_SCOPE = (
# "read:confluence-props%20"
# "read:confluence-content.all%20"
# "read:confluence-content.summary%20"
# "read:confluence-content.permission%20"
# "read:confluence-user%20"
# "read:confluence-groups%20"
# "readonly:content.attachment:confluence"
# )
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
# REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
# DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
# eventually for Confluence Data Center
# oauth_url = (
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# f"&redirect_uri={redirectme_uri}"
# )
# # eventually for Confluence Data Center
# # oauth_url = (
# # f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# # f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# # f"&redirect_uri={redirectme_uri}"
# # )
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
# @classmethod
# def generate_oauth_url(cls, state: str) -> str:
# return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
# @classmethod
# def generate_dev_oauth_url(cls, state: str) -> str:
# """dev mode workaround for localhost testing
# - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
# """
# return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
"https://auth.atlassian.com/authorize"
f"?audience=api.atlassian.com"
f"&client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
f"&state={state}"
"&response_type=code"
"&prompt=consent"
)
return url
# @classmethod
# def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# url = (
# "https://auth.atlassian.com/authorize"
# f"?audience=api.atlassian.com"
# f"&client_id={cls.CLIENT_ID}"
# f"&redirect_uri={redirect_uri}"
# f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
# f"&state={state}"
# "&response_type=code"
# "&prompt=consent"
# )
# return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = ConfluenceCloudOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
# @classmethod
# def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
# """Temporary state to store in redis. to be looked up on auth response.
# Returns a json string.
# """
# session = ConfluenceCloudOAuth.OAuthSession(
# email=email, redirect_on_success=redirect_on_success
# )
# return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
class GoogleDriveOAuth:
# https://developers.google.com/identity/protocols/oauth2
# https://developers.google.com/identity/protocols/oauth2/web-server
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
TOKEN_URL = "https://oauth2.googleapis.com/token"
# SCOPE is per https://docs.onyx.app/connectors/google-drive
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
SCOPE = (
"https://www.googleapis.com/auth/drive.readonly%20"
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
"https://www.googleapis.com/auth/admin.directory.group.readonly"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# without prompt=consent, a refresh token is only issued the first time the user approves
url = (
f"https://accounts.google.com/o/oauth2/v2/auth"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
"&response_type=code"
f"&scope={cls.SCOPE}"
"&access_type=offline"
f"&state={state}"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = GoogleDriveOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
return session
# @classmethod
# def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
# session = SlackOAuth.OAuthSession.model_validate_json(session_json)
# return session
@router.post("/prepare-authorization-request")
@@ -278,26 +192,17 @@ def prepare_authorization_request(
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
oauth_uuid_str = str(oauth_uuid)
# urlsafe b64 encode the uuid for the oauth url
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
session: str
if connector == DocumentSource.SLACK:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
session = SlackOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.GOOGLE_DRIVE:
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
session = GoogleDriveOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
# elif connector == DocumentSource.CONFLUENCE:
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
# session = ConfluenceCloudOAuth.session_dump_json(
@@ -305,6 +210,8 @@ def prepare_authorization_request(
# )
# elif connector == DocumentSource.JIRA:
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
# elif connector == DocumentSource.GOOGLE_DRIVE:
# oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
else:
oauth_url = None
@@ -316,7 +223,6 @@ def prepare_authorization_request(
r = get_redis_client(tenant_id=tenant_id)
# store important session state to retrieve when the user is redirected back
# 10 min is the max we want an oauth flow to be valid
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
@@ -515,117 +421,3 @@ def handle_slack_oauth_callback(
# "redirect_on_success": session.redirect_on_success,
# }
# )
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
code: str,
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
session: GoogleDriveOAuth.OAuthSession
try:
session = GoogleDriveOAuth.parse_session(session_json)
# Exchange the authorization code for an access token
response = requests.post(
GoogleDriveOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": GoogleDriveOAuth.CLIENT_ID,
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
"grant_type": "authorization_code",
},
)
response.raise_for_status()
authorization_response: dict[str, Any] = response.json()
# the connector wants us to store the json in its authorized_user_info format
# returned from OAuthCredentials.get_authorized_user_info().
# So refresh immediately via get_google_oauth_creds with the params filled in
# from fields in authorization_response to get the json we need
authorized_user_info = {}
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
)
if not oauth_creds:
raise RuntimeError("get_google_oauth_creds returned None.")
# save off the credentials
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
credential_dict: dict[str, str] = {}
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
credential_dict[
DB_CREDENTIALS_AUTHENTICATION_METHOD
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
credential_info = CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
name="OAuth (interactive)",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Google Drive OAuth completed successfully.",
"redirect_on_success": session.redirect_on_success,
}
)

View File

@@ -179,7 +179,6 @@ def handle_simplified_chat_message(
chunks_below=0,
full_doc=chat_message_req.full_doc,
structured_response_format=chat_message_req.structured_response_format,
use_agentic_search=chat_message_req.use_agentic_search,
)
packets = stream_chat_message_objects(
@@ -302,7 +301,6 @@ def handle_send_message_simple_with_history(
chunks_below=0,
full_doc=req.full_doc,
structured_response_format=req.structured_response_format,
use_agentic_search=req.use_agentic_search,
)
packets = stream_chat_message_objects(

View File

@@ -57,9 +57,6 @@ class BasicCreateChatMessageRequest(ChunkContext):
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# Last element is the new query. All previous elements are historical context
@@ -74,8 +71,6 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# only works if using an OpenAI model. See the following for more details:
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
class SimpleDoc(BaseModel):
@@ -125,12 +120,9 @@ class OneShotQARequest(ChunkContext):
# will also disable Thread-based Rewording if specified
query_override: str | None = None
# If True, skips generating an AI response to the search query
# If True, skips generative an AI response to the search query
skip_gen_ai_answer_generation: bool = False
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
@model_validator(mode="after")
def check_persona_fields(self) -> "OneShotQARequest":
if self.persona_override_config is None and self.persona_id is None:

View File

@@ -196,8 +196,6 @@ def get_answer_stream(
retrieval_details=query_request.retrieval_options,
rerank_settings=query_request.rerank_settings,
db_session=db_session,
use_agentic_search=query_request.use_agentic_search,
skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation,
)
packets = stream_chat_message_objects(

View File

@@ -1,23 +1,19 @@
import csv
import io
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Literal
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.db.query_history import fetch_chat_sessions_eagerly_by_time
from ee.onyx.db.query_history import get_page_of_chat_sessions
from ee.onyx.db.query_history import get_total_filtered_chat_sessions_count
from ee.onyx.server.query_history.models import ChatSessionMinimal
from ee.onyx.server.query_history.models import ChatSessionSnapshot
from ee.onyx.server.query_history.models import MessageSnapshot
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.chat.chat_utils import create_chat_chain
@@ -27,15 +23,257 @@ from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.engine import get_session
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import User
from onyx.server.documents.models import PaginatedReturn
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionsResponse
router = APIRouter()
class AbridgedSearchDoc(BaseModel):
"""A subset of the info present in `SearchDoc`"""
document_id: str
semantic_identifier: str
link: str | None
class MessageSnapshot(BaseModel):
message: str
message_type: MessageType
documents: list[AbridgedSearchDoc]
feedback_type: QAFeedbackType | None
feedback_text: str | None
time_created: datetime
@classmethod
def build(cls, message: ChatMessage) -> "MessageSnapshot":
latest_messages_feedback_obj = (
message.chat_message_feedbacks[-1]
if len(message.chat_message_feedbacks) > 0
else None
)
feedback_type = (
(
QAFeedbackType.LIKE
if latest_messages_feedback_obj.is_positive
else QAFeedbackType.DISLIKE
)
if latest_messages_feedback_obj
else None
)
feedback_text = (
latest_messages_feedback_obj.feedback_text
if latest_messages_feedback_obj
else None
)
return cls(
message=message.message,
message_type=message.message_type,
documents=[
AbridgedSearchDoc(
document_id=document.document_id,
semantic_identifier=document.semantic_id,
link=document.link,
)
for document in message.search_docs
],
feedback_type=feedback_type,
feedback_text=feedback_text,
time_created=message.time_sent,
)
class ChatSessionMinimal(BaseModel):
id: UUID
user_email: str
name: str | None
first_user_message: str
first_ai_message: str
assistant_id: int | None
assistant_name: str | None
time_created: datetime
feedback_type: QAFeedbackType | Literal["mixed"] | None
flow_type: SessionType
conversation_length: int
class ChatSessionSnapshot(BaseModel):
id: UUID
user_email: str
name: str | None
messages: list[MessageSnapshot]
assistant_id: int | None
assistant_name: str | None
time_created: datetime
flow_type: SessionType
class QuestionAnswerPairSnapshot(BaseModel):
chat_session_id: UUID
# 1-indexed message number in the chat_session
# e.g. the first message pair in the chat_session is 1, the second is 2, etc.
message_pair_num: int
user_message: str
ai_response: str
retrieved_documents: list[AbridgedSearchDoc]
feedback_type: QAFeedbackType | None
feedback_text: str | None
persona_name: str | None
user_email: str
time_created: datetime
flow_type: SessionType
@classmethod
def from_chat_session_snapshot(
cls,
chat_session_snapshot: ChatSessionSnapshot,
) -> list["QuestionAnswerPairSnapshot"]:
message_pairs: list[tuple[MessageSnapshot, MessageSnapshot]] = []
for ind in range(1, len(chat_session_snapshot.messages), 2):
message_pairs.append(
(
chat_session_snapshot.messages[ind - 1],
chat_session_snapshot.messages[ind],
)
)
return [
cls(
chat_session_id=chat_session_snapshot.id,
message_pair_num=ind + 1,
user_message=user_message.message,
ai_response=ai_message.message,
retrieved_documents=ai_message.documents,
feedback_type=ai_message.feedback_type,
feedback_text=ai_message.feedback_text,
persona_name=chat_session_snapshot.assistant_name,
user_email=get_display_email(chat_session_snapshot.user_email),
time_created=user_message.time_created,
flow_type=chat_session_snapshot.flow_type,
)
for ind, (user_message, ai_message) in enumerate(message_pairs)
]
def to_json(self) -> dict[str, str | None]:
return {
"chat_session_id": str(self.chat_session_id),
"message_pair_num": str(self.message_pair_num),
"user_message": self.user_message,
"ai_response": self.ai_response,
"retrieved_documents": "|".join(
[
doc.link or doc.semantic_identifier
for doc in self.retrieved_documents
]
),
"feedback_type": self.feedback_type.value if self.feedback_type else "",
"feedback_text": self.feedback_text or "",
"persona_name": self.persona_name,
"user_email": self.user_email,
"time_created": str(self.time_created),
"flow_type": self.flow_type,
}
def determine_flow_type(chat_session: ChatSession) -> SessionType:
return SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT
def fetch_and_process_chat_session_history_minimal(
db_session: Session,
start: datetime,
end: datetime,
feedback_filter: QAFeedbackType | None = None,
limit: int | None = 500,
) -> list[ChatSessionMinimal]:
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=start, end=end, db_session=db_session, limit=limit
)
minimal_sessions = []
for chat_session in chat_sessions:
if not chat_session.messages:
continue
first_user_message = next(
(
message.message
for message in chat_session.messages
if message.message_type == MessageType.USER
),
"",
)
first_ai_message = next(
(
message.message
for message in chat_session.messages
if message.message_type == MessageType.ASSISTANT
),
"",
)
has_positive_feedback = any(
feedback.is_positive
for message in chat_session.messages
for feedback in message.chat_message_feedbacks
)
has_negative_feedback = any(
not feedback.is_positive
for message in chat_session.messages
for feedback in message.chat_message_feedbacks
)
feedback_type: QAFeedbackType | Literal["mixed"] | None = (
"mixed"
if has_positive_feedback and has_negative_feedback
else QAFeedbackType.LIKE
if has_positive_feedback
else QAFeedbackType.DISLIKE
if has_negative_feedback
else None
)
if feedback_filter:
if feedback_filter == QAFeedbackType.LIKE and not has_positive_feedback:
continue
if feedback_filter == QAFeedbackType.DISLIKE and not has_negative_feedback:
continue
flow_type = determine_flow_type(chat_session)
minimal_sessions.append(
ChatSessionMinimal(
id=chat_session.id,
user_email=get_display_email(
chat_session.user.email if chat_session.user else None
),
name=chat_session.description,
first_user_message=first_user_message,
first_ai_message=first_ai_message,
assistant_id=chat_session.persona_id,
assistant_name=(
chat_session.persona.name if chat_session.persona else None
),
time_created=chat_session.time_created,
feedback_type=feedback_type,
flow_type=flow_type,
conversation_length=len(
[
m
for m in chat_session.messages
if m.message_type != MessageType.SYSTEM
]
),
)
)
return minimal_sessions
def fetch_and_process_chat_session_history(
db_session: Session,
start: datetime,
@@ -81,7 +319,7 @@ def snapshot_from_chat_session(
except RuntimeError:
return None
flow_type = SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT
flow_type = determine_flow_type(chat_session)
return ChatSessionSnapshot(
id=chat_session.id,
@@ -133,38 +371,22 @@ def get_user_chat_sessions(
@router.get("/admin/chat-session-history")
def get_chat_session_history(
page_num: int = Query(0, ge=0),
page_size: int = Query(10, ge=1),
feedback_type: QAFeedbackType | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
start: datetime | None = None,
end: datetime | None = None,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> PaginatedReturn[ChatSessionMinimal]:
page_of_chat_sessions = get_page_of_chat_sessions(
page_num=page_num,
page_size=page_size,
) -> list[ChatSessionMinimal]:
return fetch_and_process_chat_session_history_minimal(
db_session=db_session,
start_time=start_time,
end_time=end_time,
start=start
or (
datetime.now(tz=timezone.utc) - timedelta(days=30)
), # default is 30d lookback
end=end or datetime.now(tz=timezone.utc),
feedback_filter=feedback_type,
)
total_filtered_chat_sessions_count = get_total_filtered_chat_sessions_count(
db_session=db_session,
start_time=start_time,
end_time=end_time,
feedback_filter=feedback_type,
)
return PaginatedReturn(
items=[
ChatSessionMinimal.from_chat_session(chat_session)
for chat_session in page_of_chat_sessions
],
total_items=total_filtered_chat_sessions_count,
)
@router.get("/admin/chat-session-history/{chat_session_id}")
def get_chat_session_admin(

View File

@@ -1,218 +0,0 @@
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel
from onyx.auth.users import get_display_email
from onyx.configs.constants import MessageType
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import SessionType
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
class AbridgedSearchDoc(BaseModel):
"""A subset of the info present in `SearchDoc`"""
document_id: str
semantic_identifier: str
link: str | None
class MessageSnapshot(BaseModel):
id: int
message: str
message_type: MessageType
documents: list[AbridgedSearchDoc]
feedback_type: QAFeedbackType | None
feedback_text: str | None
time_created: datetime
@classmethod
def build(cls, message: ChatMessage) -> "MessageSnapshot":
latest_messages_feedback_obj = (
message.chat_message_feedbacks[-1]
if len(message.chat_message_feedbacks) > 0
else None
)
feedback_type = (
(
QAFeedbackType.LIKE
if latest_messages_feedback_obj.is_positive
else QAFeedbackType.DISLIKE
)
if latest_messages_feedback_obj
else None
)
feedback_text = (
latest_messages_feedback_obj.feedback_text
if latest_messages_feedback_obj
else None
)
return cls(
id=message.id,
message=message.message,
message_type=message.message_type,
documents=[
AbridgedSearchDoc(
document_id=document.document_id,
semantic_identifier=document.semantic_id,
link=document.link,
)
for document in message.search_docs
],
feedback_type=feedback_type,
feedback_text=feedback_text,
time_created=message.time_sent,
)
class ChatSessionMinimal(BaseModel):
id: UUID
user_email: str
name: str | None
first_user_message: str
first_ai_message: str
assistant_id: int | None
assistant_name: str | None
time_created: datetime
feedback_type: QAFeedbackType | None
flow_type: SessionType
conversation_length: int
@classmethod
def from_chat_session(cls, chat_session: ChatSession) -> "ChatSessionMinimal":
first_user_message = next(
(
message.message
for message in chat_session.messages
if message.message_type == MessageType.USER
),
"",
)
first_ai_message = next(
(
message.message
for message in chat_session.messages
if message.message_type == MessageType.ASSISTANT
),
"",
)
list_of_message_feedbacks = [
feedback.is_positive
for message in chat_session.messages
for feedback in message.chat_message_feedbacks
]
session_feedback_type = None
if list_of_message_feedbacks:
if all(list_of_message_feedbacks):
session_feedback_type = QAFeedbackType.LIKE
elif not any(list_of_message_feedbacks):
session_feedback_type = QAFeedbackType.DISLIKE
else:
session_feedback_type = QAFeedbackType.MIXED
return cls(
id=chat_session.id,
user_email=get_display_email(
chat_session.user.email if chat_session.user else None
),
name=chat_session.description,
first_user_message=first_user_message,
first_ai_message=first_ai_message,
assistant_id=chat_session.persona_id,
assistant_name=(
chat_session.persona.name if chat_session.persona else None
),
time_created=chat_session.time_created,
feedback_type=session_feedback_type,
flow_type=SessionType.SLACK
if chat_session.onyxbot_flow
else SessionType.CHAT,
conversation_length=len(
[
message
for message in chat_session.messages
if message.message_type != MessageType.SYSTEM
]
),
)
class ChatSessionSnapshot(BaseModel):
id: UUID
user_email: str
name: str | None
messages: list[MessageSnapshot]
assistant_id: int | None
assistant_name: str | None
time_created: datetime
flow_type: SessionType
class QuestionAnswerPairSnapshot(BaseModel):
chat_session_id: UUID
# 1-indexed message number in the chat_session
# e.g. the first message pair in the chat_session is 1, the second is 2, etc.
message_pair_num: int
user_message: str
ai_response: str
retrieved_documents: list[AbridgedSearchDoc]
feedback_type: QAFeedbackType | None
feedback_text: str | None
persona_name: str | None
user_email: str
time_created: datetime
flow_type: SessionType
@classmethod
def from_chat_session_snapshot(
cls,
chat_session_snapshot: ChatSessionSnapshot,
) -> list["QuestionAnswerPairSnapshot"]:
message_pairs: list[tuple[MessageSnapshot, MessageSnapshot]] = []
for ind in range(1, len(chat_session_snapshot.messages), 2):
message_pairs.append(
(
chat_session_snapshot.messages[ind - 1],
chat_session_snapshot.messages[ind],
)
)
return [
cls(
chat_session_id=chat_session_snapshot.id,
message_pair_num=ind + 1,
user_message=user_message.message,
ai_response=ai_message.message,
retrieved_documents=ai_message.documents,
feedback_type=ai_message.feedback_type,
feedback_text=ai_message.feedback_text,
persona_name=chat_session_snapshot.assistant_name,
user_email=get_display_email(chat_session_snapshot.user_email),
time_created=user_message.time_created,
flow_type=chat_session_snapshot.flow_type,
)
for ind, (user_message, ai_message) in enumerate(message_pairs)
]
def to_json(self) -> dict[str, str | None]:
return {
"chat_session_id": str(self.chat_session_id),
"message_pair_num": str(self.message_pair_num),
"user_message": self.user_message,
"ai_response": self.ai_response,
"retrieved_documents": "|".join(
[
doc.link or doc.semantic_identifier
for doc in self.retrieved_documents
]
),
"feedback_type": self.feedback_type.value if self.feedback_type else "",
"feedback_text": self.feedback_text or "",
"persona_name": self.persona_name,
"user_email": self.user_email,
"time_created": str(self.time_created),
"flow_type": self.flow_type,
}

View File

@@ -13,8 +13,9 @@ from ee.onyx.db.usage_export import get_all_empty_chat_message_entries
from ee.onyx.db.usage_export import write_usage_report
from ee.onyx.server.reporting.usage_export_models import UsageReportMetadata
from ee.onyx.server.reporting.usage_export_models import UserSkeleton
from onyx.auth.schemas import UserStatus
from onyx.configs.constants import FileOrigin
from onyx.db.users import get_all_users
from onyx.db.users import list_users
from onyx.file_store.constants import MAX_IN_MEMORY_SIZE
from onyx.file_store.file_store import FileStore
from onyx.file_store.file_store import get_default_file_store
@@ -83,15 +84,15 @@ def generate_user_report(
max_size=MAX_IN_MEMORY_SIZE, mode="w+"
) as temp_file:
csvwriter = csv.writer(temp_file, delimiter=",")
csvwriter.writerow(["user_id", "is_active"])
csvwriter.writerow(["user_id", "status"])
users = get_all_users(db_session)
users = list_users(db_session)
for user in users:
user_skeleton = UserSkeleton(
user_id=str(user.id),
is_active=user.is_active,
status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED,
)
csvwriter.writerow([user_skeleton.user_id, user_skeleton.is_active])
csvwriter.writerow([user_skeleton.user_id, user_skeleton.status])
temp_file.seek(0)
file_store.save_file(

View File

@@ -4,6 +4,8 @@ from uuid import UUID
from pydantic import BaseModel
from onyx.auth.schemas import UserStatus
class FlowType(str, Enum):
CHAT = "chat"
@@ -20,7 +22,7 @@ class ChatMessageSkeleton(BaseModel):
class UserSkeleton(BaseModel):
user_id: str
is_active: bool
status: UserStatus
class UsageReportMetadata(BaseModel):

View File

@@ -24,7 +24,7 @@ from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
from onyx.db.persona import upsert_persona
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.server.features.persona.models import CreatePersonaRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.settings.models import Settings
from onyx.server.settings.store import store_settings as store_base_settings
@@ -57,7 +57,7 @@ class SeedConfiguration(BaseModel):
llms: list[LLMProviderUpsertRequest] | None = None
admin_user_emails: list[str] | None = None
seeded_logo_path: str | None = None
personas: list[PersonaUpsertRequest] | None = None
personas: list[CreatePersonaRequest] | None = None
settings: Settings | None = None
enterprise_settings: EnterpriseSettings | None = None
@@ -128,7 +128,7 @@ def _seed_llms(
)
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None:
if personas:
logger.notice("Seeding Personas")
for persona in personas:

View File

@@ -1,59 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.models import TenantAnonymousUserPath
def get_anonymous_user_path(tenant_id: str, db_session: Session) -> str | None:
result = db_session.execute(
select(TenantAnonymousUserPath).where(
TenantAnonymousUserPath.tenant_id == tenant_id
)
)
result_scalar = result.scalar_one_or_none()
if result_scalar:
return result_scalar.anonymous_user_path
else:
return None
def modify_anonymous_user_path(
tenant_id: str, anonymous_user_path: str, db_session: Session
) -> None:
# Enforce lowercase path at DB operation level
anonymous_user_path = anonymous_user_path.lower()
existing_entry = (
db_session.query(TenantAnonymousUserPath).filter_by(tenant_id=tenant_id).first()
)
if existing_entry:
existing_entry.anonymous_user_path = anonymous_user_path
else:
new_entry = TenantAnonymousUserPath(
tenant_id=tenant_id, anonymous_user_path=anonymous_user_path
)
db_session.add(new_entry)
db_session.commit()
def get_tenant_id_for_anonymous_user_path(
anonymous_user_path: str, db_session: Session
) -> str | None:
result = db_session.execute(
select(TenantAnonymousUserPath).where(
TenantAnonymousUserPath.anonymous_user_path == anonymous_user_path
)
)
result_scalar = result.scalar_one_or_none()
if result_scalar:
return result_scalar.tenant_id
else:
return None
def validate_anonymous_user_path(path: str) -> None:
if not path or "/" in path or not path.replace("-", "").isalnum():
raise ValueError("Invalid path. Use only letters, numbers, and hyphens.")

View File

@@ -3,126 +3,35 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import AnonymousUserPath
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import auth_backend
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import optional_user
from onyx.auth.users import get_jwt_strategy
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.auth import get_user_count
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.notification import create_notification
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
tenant_id: str | None = Depends(get_current_tenant_id),
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_tenant(tenant_id=None) as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
tenant_id: str = Depends(get_current_tenant_id),
_: User | None = Depends(current_admin_user),
) -> None:
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_tenant(tenant_id=None) as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_tenant(tenant_id=None) as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
@@ -194,7 +103,7 @@ async def impersonate_user(
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
token = await get_jwt_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
@@ -205,48 +114,3 @@ async def impersonate_user(
samesite="lax",
)
return response
@router.post("/leave-organization")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> None:
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@@ -46,7 +46,6 @@ def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscr
"""
Send a request to the control service to register the number of users for a tenant.
"""
if not STRIPE_PRICE_ID:
raise Exception("STRIPE_PRICE_ID is not set")

View File

@@ -39,12 +39,3 @@ class TenantCreationPayload(BaseModel):
tenant_id: str
email: str
referral_source: str | None = None
class TenantDeletionPayload(BaseModel):
tenant_id: str
email: str
class AnonymousUserPath(BaseModel):
anonymous_user_path: str | None

View File

@@ -15,7 +15,6 @@ from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import TenantCreationPayload
from ee.onyx.server.tenants.models import TenantDeletionPayload
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import drop_schema
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
@@ -24,7 +23,6 @@ from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
@@ -86,8 +84,7 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)
await notify_control_plane(tenant_id, email, referral_source)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)
@@ -188,7 +185,6 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
try:
# Drop the tenant's schema to rollback provisioning
drop_schema(tenant_id)
# Remove tenant mapping
with Session(get_sqlalchemy_engine()) as db_session:
db_session.query(UserTenantMapping).filter(
@@ -324,26 +320,3 @@ async def submit_to_hubspot(
if response.status_code != 200:
logger.error(f"Failed to submit to HubSpot: {response.text}")
async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
payload = TenantDeletionPayload(tenant_id=tenant_id, email=email)
async with aiohttp.ClientSession() as session:
async with session.delete(
f"{CONTROL_PLANE_API_BASE_URL}/tenants/delete",
headers=headers,
json=payload.model_dump(),
) as response:
print(response)
if response.status != 200:
error_text = await response.text()
logger.error(f"Control plane tenant creation failed: {error_text}")
raise Exception(
f"Failed to delete tenant on control plane: {error_text}"
)

View File

@@ -68,11 +68,3 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
)
db_session.rollback()
def remove_all_users_from_tenant(tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()

View File

@@ -5,7 +5,7 @@ from fastapi import Depends
from sqlalchemy.orm import Session
from ee.onyx.db.token_limit import fetch_all_user_group_token_rate_limits_by_group
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
@@ -51,10 +51,8 @@ def get_group_token_limit_settings(
) -> list[TokenRateLimitDisplay]:
return [
TokenRateLimitDisplay.from_db(token_rate_limit)
for token_rate_limit in fetch_user_group_token_rate_limits_for_user(
db_session=db_session,
group_id=group_id,
user=user,
for token_rate_limit in fetch_user_group_token_rate_limits(
db_session, group_id, user
)
]

View File

@@ -83,7 +83,7 @@ def patch_user_group(
def set_user_curator(
user_group_id: int,
set_curator_request: SetCuratorRequest,
user: User | None = Depends(current_curator_or_admin_user),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
try:
@@ -91,7 +91,6 @@ def set_user_curator(
db_session=db_session,
user_group_id=user_group_id,
set_curator_request=set_curator_request,
user_making_change=user,
)
except ValueError as e:
logger.error(f"Error setting user curator: {e}")

View File

@@ -58,7 +58,6 @@ class UserGroup(BaseModel):
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_relationship.cc_pair.credential
),
access_type=cc_pair_relationship.cc_pair.access_type,
)
for cc_pair_relationship in user_group_model.cc_pair_relationships
if cc_pair_relationship.is_current

View File

@@ -10,7 +10,6 @@ logger = setup_logger()
def posthog_on_error(error: Any, items: Any) -> None:
"""Log any PostHog delivery errors."""
logger.error(f"PostHog error: {error}, items: {items}")
@@ -25,10 +24,15 @@ posthog = Posthog(
def event_telemetry(
distinct_id: str, event: str, properties: dict | None = None
) -> None:
"""Capture and send an event to PostHog, flushing immediately."""
logger.info(f"Capturing PostHog event: {distinct_id} {event} {properties}")
logger.info(f"Capturing Posthog event: {distinct_id} {event} {properties}")
print("API KEY", POSTHOG_API_KEY)
print("HOST", POSTHOG_HOST)
try:
posthog.capture(distinct_id, event, properties)
print(type(distinct_id))
print(type(event))
print(type(properties))
response = posthog.capture(distinct_id, event, properties)
posthog.flush()
print(response)
except Exception as e:
logger.error(f"Error capturing PostHog event: {e}")
logger.error(f"Error capturing Posthog event: {e}")

View File

@@ -28,9 +28,3 @@ class EmbeddingModelTextType:
@staticmethod
def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
class GPUStatus:
CUDA = "cuda"
MAC_MPS = "mps"
NONE = "none"

View File

@@ -1,6 +1,5 @@
import asyncio
import json
import time
from types import TracebackType
from typing import cast
from typing import Optional
@@ -12,7 +11,6 @@ import voyageai # type: ignore
from cohere import AsyncClient as CohereAsyncClient
from fastapi import APIRouter
from fastapi import HTTPException
from fastapi import Request
from google.oauth2 import service_account # type: ignore
from litellm import aembedding
from litellm.exceptions import RateLimitError
@@ -321,8 +319,9 @@ async def embed_text(
prefix: str | None,
api_url: str | None,
api_version: str | None,
gpu_type: str = "UNKNOWN",
) -> list[Embedding]:
logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}")
if not all(texts):
logger.error("Empty strings provided for embedding")
raise ValueError("Empty strings are not allowed for embedding.")
@@ -331,17 +330,8 @@ async def embed_text(
logger.error("No texts provided for embedding")
raise ValueError("No texts provided for embedding.")
start = time.monotonic()
total_chars = 0
for text in texts:
total_chars += len(text)
if provider_type is not None:
logger.info(
f"Embedding {len(texts)} texts with {total_chars} total characters with provider: {provider_type}"
)
logger.debug(f"Using cloud provider {provider_type} for embedding")
if api_key is None:
logger.error("API key not provided for cloud model")
raise RuntimeError("API key not provided for cloud model")
@@ -373,19 +363,8 @@ async def embed_text(
logger.error(error_message)
raise ValueError(error_message)
elapsed = time.monotonic() - start
logger.info(
f"event=embedding_provider "
f"texts={len(texts)} "
f"chars={total_chars} "
f"provider={provider_type} "
f"elapsed={elapsed:.2f}"
)
elif model_name is not None:
logger.info(
f"Embedding {len(texts)} texts with {total_chars} total characters with local model: {model_name}"
)
logger.debug(f"Using local model {model_name} for embedding")
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
local_model = get_embedding_model(
@@ -403,25 +382,13 @@ async def embed_text(
for embedding in embeddings_vectors
]
elapsed = time.monotonic() - start
logger.info(
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
f"with local model {model_name} in {elapsed:.2f}"
)
logger.info(
f"event=embedding_model "
f"texts={len(texts)} "
f"chars={total_chars} "
f"model={model_name} "
f"gpu={gpu_type} "
f"elapsed={elapsed:.2f}"
)
else:
logger.error("Neither model name nor provider specified for embedding")
raise ValueError(
"Either model name or provider must be provided to run embeddings."
)
logger.info(f"Successfully embedded {len(texts)} texts")
return embeddings
@@ -468,20 +435,12 @@ async def litellm_rerank(
@router.post("/bi-encoder-embed")
async def route_bi_encoder_embed(
request: Request,
embed_request: EmbedRequest,
) -> EmbedResponse:
return await process_embed_request(embed_request, request.app.state.gpu_type)
async def process_embed_request(
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
embed_request: EmbedRequest,
) -> EmbedResponse:
if not embed_request.texts:
raise HTTPException(status_code=400, detail="No texts to be embedded")
if not all(embed_request.texts):
elif not all(embed_request.texts):
raise ValueError("Empty strings are not allowed for embedding.")
try:
@@ -504,7 +463,6 @@ async def process_embed_request(
api_url=embed_request.api_url,
api_version=embed_request.api_version,
prefix=prefix,
gpu_type=gpu_type,
)
return EmbedResponse(embeddings=embeddings)
except RateLimitError as e:
@@ -513,12 +471,9 @@ async def process_embed_request(
detail=str(e),
)
except Exception as e:
logger.exception(
f"Error during embedding process: provider={embed_request.provider_type} model={embed_request.model_name}"
)
raise HTTPException(
status_code=500, detail=f"Error during embedding process: {e}"
)
exception_detail = f"Error during embedding process:\n{str(e)}"
logger.exception(exception_detail)
raise HTTPException(status_code=500, detail=exception_detail)
@router.post("/cross-encoder-scores")

View File

@@ -16,7 +16,6 @@ from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.management_endpoints import router as management_router
from model_server.utils import get_gpu_type
from onyx import __version__
from onyx.utils.logger import setup_logger
from shared_configs.configs import INDEXING_ONLY
@@ -45,7 +44,6 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
the files in the existing huggingface cache that don't exist in the temp
huggingface cache.
"""
for item in source.iterdir():
target_path = dest / item.relative_to(source)
if item.is_dir():
@@ -59,10 +57,12 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
gpu_type = get_gpu_type()
logger.notice(f"Torch GPU Detection: gpu_type={gpu_type}")
app.state.gpu_type = gpu_type
if torch.cuda.is_available():
logger.notice("CUDA GPU is available")
elif torch.backends.mps.is_available():
logger.notice("Mac MPS is available")
else:
logger.notice("GPU is not available, using CPU")
if TEMP_HF_CACHE_PATH.is_dir():
logger.notice("Moving contents of temp_huggingface to huggingface cache.")

View File

@@ -1,9 +1,7 @@
import torch
from fastapi import APIRouter
from fastapi import Response
from model_server.constants import GPUStatus
from model_server.utils import get_gpu_type
router = APIRouter(prefix="/api")
@@ -13,7 +11,10 @@ async def healthcheck() -> Response:
@router.get("/gpu-status")
async def route_gpu_status() -> dict[str, bool | str]:
gpu_type = get_gpu_type()
gpu_available = gpu_type != GPUStatus.NONE
return {"gpu_available": gpu_available, "type": gpu_type}
async def gpu_status() -> dict[str, bool | str]:
if torch.cuda.is_available():
return {"gpu_available": True, "type": "cuda"}
elif torch.backends.mps.is_available():
return {"gpu_available": True, "type": "mps"}
else:
return {"gpu_available": False, "type": "none"}

View File

@@ -8,9 +8,6 @@ from typing import Any
from typing import cast
from typing import TypeVar
import torch
from model_server.constants import GPUStatus
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -61,12 +58,3 @@ def simple_log_function_time(
return cast(F, wrapped_sync_func)
return decorator
def get_gpu_type() -> str:
if torch.cuda.is_available():
return GPUStatus.CUDA
if torch.backends.mps.is_available():
return GPUStatus.MAC_MPS
return GPUStatus.NONE

View File

@@ -19,9 +19,6 @@ def prefix_external_group(ext_group_name: str) -> str:
return f"external_group:{ext_group_name}"
def build_ext_group_name_for_onyx(ext_group_name: str, source: DocumentSource) -> str:
"""
External groups may collide across sources, every source needs its own prefix.
NOTE: the name is lowercased to handle case sensitivity for group names
"""
return f"{source.value}_{ext_group_name}".lower()
def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
"""External groups may collide across sources, every source needs its own prefix."""
return f"{source.value.upper()}_{ext_group_name}"

View File

@@ -1,97 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=BasicState,
input=BasicInput,
output=BasicOutput,
)
### Add nodes ###
graph.add_node(
node="prepare_tool_input",
action=prepare_tool_input,
)
graph.add_node(
node="llm_tool_choice",
action=llm_tool_choice,
)
graph.add_node(
node="tool_call",
action=tool_call,
)
graph.add_node(
node="basic_use_tool_response",
action=basic_use_tool_response,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
graph.add_edge(
start_key="tool_call",
end_key="basic_use_tool_response",
)
graph.add_edge(
start_key="basic_use_tool_response",
end_key=END,
)
return graph
def should_continue(state: BasicState) -> str:
return (
# If there are no tool calls, basic graph already streamed the answer
END
if state.tool_choice is None
else "tool_call"
)
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.context.search.models import SearchRequest
from onyx.llm.factory import get_default_llms
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
graph = basic_graph_builder()
compiled_graph = graph.compile()
input = BasicInput(_unused=True)
primary_llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
config, _ = get_test_config(
db_session=db_session,
primary_llm=primary_llm,
fast_llm=fast_llm,
search_request=SearchRequest(query="How does onyx use FastAPI?"),
)
compiled_graph.invoke(input, config={"metadata": {"config": config}})

View File

@@ -1,35 +0,0 @@
from typing import TypedDict
from langchain_core.messages import AIMessageChunk
from pydantic import BaseModel
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
# States contain values that change over the course of graph execution,
# Config is for values that are set at the start and never change.
# If you are using a value from the config and realize it needs to change,
# you should add it to the state and use/update the version in the state.
## Graph Input State
class BasicInput(BaseModel):
# Langgraph needs a nonempty input, but we pass in all static
# data through a RunnableConfig.
_unused: bool = True
## Graph Output State
class BasicOutput(TypedDict):
tool_call_chunk: AIMessageChunk
## Graph State
class BasicState(
BasicInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
):
pass

View File

@@ -1,64 +0,0 @@
from collections.abc import Iterator
from typing import cast
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
PassThroughAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.utils.logger import setup_logger
logger = setup_logger()
def process_llm_stream(
messages: Iterator[BaseMessage],
should_stream_answer: bool,
writer: StreamWriter,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[OnyxContext] | list[LlmDoc] | None = None,
) -> AIMessageChunk:
tool_call_chunk = AIMessageChunk(content="")
if final_search_results and displayed_search_results:
answer_handler: AnswerResponseHandler = CitationResponseHandler(
context_docs=final_search_results,
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
)
else:
answer_handler = PassThroughAnswerResponseHandler()
full_answer = ""
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for message in messages:
answer_piece = message.content
if not isinstance(answer_piece, str):
# this is only used for logging, so fine to
# just add the string representation
answer_piece = str(answer_piece)
full_answer += answer_piece
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
):
tool_call_chunk += message # type: ignore
elif should_stream_answer:
for response_part in answer_handler.handle_response_part(message, []):
write_custom_event(
"basic_response",
response_part,
writer,
)
logger.debug(f"Full answer: {full_answer}")
return cast(AIMessageChunk, tool_call_chunk)

View File

@@ -1,21 +0,0 @@
from operator import add
from typing import Annotated
from pydantic import BaseModel
class CoreState(BaseModel):
"""
This is the core state that is shared across all subgraphs.
"""
base_question: str = ""
log_messages: Annotated[list[str], add] = []
class SubgraphCoreState(BaseModel):
"""
This is the core state that is shared across all subgraphs.
"""
log_messages: Annotated[list[str], add]

View File

@@ -1,31 +0,0 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def send_to_expanded_retrieval(state: SubQuestionAnsweringInput) -> Send | Hashable:
"""
LangGraph edge to send a sub-question to the expanded retrieval.
"""
edge_start_time = datetime.now()
return Send(
"initial_sub_question_expanded_retrieval",
ExpandedRetrievalInput(
question=state.question,
base_search=False,
sub_question_id=state.question_id,
log_messages=[f"{edge_start_time} -- Sending to expanded retrieval"],
),
)

Some files were not shown because too many files have changed in this diff Show More