mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
152 Commits
cloud_debu
...
feature/im
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
87ea5f87e1 | ||
|
|
1470b7e038 | ||
|
|
bf78fb79f8 | ||
|
|
d972a78f45 | ||
|
|
50131ba22c | ||
|
|
439217317f | ||
|
|
c55de28423 | ||
|
|
91e32e801d | ||
|
|
2ae91f0f2b | ||
|
|
d40fd82803 | ||
|
|
97a963b4bf | ||
|
|
7f6ef1ff57 | ||
|
|
d98746b988 | ||
|
|
a76f1b4c1b | ||
|
|
4c4ff46fe3 | ||
|
|
0f9842064f | ||
|
|
d7bc32c0ec | ||
|
|
1f48de9731 | ||
|
|
a22d02ff70 | ||
|
|
dcfc621a66 | ||
|
|
eac73a1bf1 | ||
|
|
717560872f | ||
|
|
ce2572134c | ||
|
|
02f72a5c86 | ||
|
|
eb916df139 | ||
|
|
fafad5e119 | ||
|
|
a314a08309 | ||
|
|
4ce24d68f7 | ||
|
|
a95f4298ad | ||
|
|
7cd76ec404 | ||
|
|
5b5c1166ca | ||
|
|
d9e9c6973d | ||
|
|
91903141cd | ||
|
|
e329b63b89 | ||
|
|
71c2559ea9 | ||
|
|
ceb34a41d9 | ||
|
|
82eab9d704 | ||
|
|
2b8d3a6ef5 | ||
|
|
4fb129e77b | ||
|
|
f16ca1b735 | ||
|
|
e3b2c9d944 | ||
|
|
6c9c25642d | ||
|
|
2862d8bbd3 | ||
|
|
143be6a524 | ||
|
|
c2444a5cff | ||
|
|
7f8194798a | ||
|
|
e3947e4b64 | ||
|
|
98005510ad | ||
|
|
ca54bd0b21 | ||
|
|
d26f8ce852 | ||
|
|
c8090ab75b | ||
|
|
e100a5e965 | ||
|
|
ddec239fef | ||
|
|
e83542f572 | ||
|
|
8750f14647 | ||
|
|
27699c8216 | ||
|
|
6fcd712a00 | ||
|
|
b027a08698 | ||
|
|
1db778baa8 | ||
|
|
f895e5f7d0 | ||
|
|
2fc58252f4 | ||
|
|
371d1ccd8f | ||
|
|
7fb92d42a0 | ||
|
|
af2061c4db | ||
|
|
ffec19645b | ||
|
|
67d2c86250 | ||
|
|
6c018cb53f | ||
|
|
62302e3faf | ||
|
|
0460531c72 | ||
|
|
6af07a888b | ||
|
|
ea75f5cd5d | ||
|
|
b92c183022 | ||
|
|
c191e23256 | ||
|
|
66f9124135 | ||
|
|
8f0fb70bbf | ||
|
|
ef5e5c80bb | ||
|
|
03acb6587a | ||
|
|
d1ec72b5e5 | ||
|
|
3b214133a8 | ||
|
|
2232702e99 | ||
|
|
8108ff0a4b | ||
|
|
f64e78e986 | ||
|
|
08312a4394 | ||
|
|
92add655e0 | ||
|
|
d64464ca7c | ||
|
|
ccd3983802 | ||
|
|
240f3e4fff | ||
|
|
1291b3d930 | ||
|
|
d05f1997b5 | ||
|
|
aa2e2a62b9 | ||
|
|
174e5968f8 | ||
|
|
1f27606e17 | ||
|
|
60355b84c1 | ||
|
|
680ab9ea30 | ||
|
|
c2447dbb1c | ||
|
|
52bad522f8 | ||
|
|
63e5e58313 | ||
|
|
2643782e30 | ||
|
|
3eb72e5c1d | ||
|
|
9b65c23a7e | ||
|
|
b43a8e48c6 | ||
|
|
1955c1d67b | ||
|
|
3f92ed9d29 | ||
|
|
618369f4a1 | ||
|
|
2783216781 | ||
|
|
bec0f9fb23 | ||
|
|
97a03e7fc8 | ||
|
|
8d6e8269b7 | ||
|
|
9ce2c6c517 | ||
|
|
2ad8bdbc65 | ||
|
|
a83c9b40d5 | ||
|
|
340fab1375 | ||
|
|
3ec338307f | ||
|
|
27acd3387a | ||
|
|
d14ef431a7 | ||
|
|
9bffeb65af | ||
|
|
f4806da653 | ||
|
|
e2700b2bbd | ||
|
|
fc81a3fb12 | ||
|
|
2203cfabea | ||
|
|
f4050306d6 | ||
|
|
2d960a477f | ||
|
|
8837b8ea79 | ||
|
|
3dfb214f73 | ||
|
|
18d7262608 | ||
|
|
09b879ee73 | ||
|
|
aaa668c963 | ||
|
|
edb877f4bc | ||
|
|
eb369caefb | ||
|
|
b9567eabd7 | ||
|
|
13bbf67091 | ||
|
|
457a4c73f0 | ||
|
|
ce37688b5b | ||
|
|
4e2c90f4af | ||
|
|
513dd8a319 | ||
|
|
71c5043832 | ||
|
|
64b6f15e95 | ||
|
|
35022f5f09 | ||
|
|
0d44014c16 | ||
|
|
1b9e9f48fa | ||
|
|
05fb5aa27b | ||
|
|
3b645b72a3 | ||
|
|
fe770b5c3a | ||
|
|
1eaf885f50 | ||
|
|
a187aa508c | ||
|
|
aa4bfa2a78 | ||
|
|
9011b8a139 | ||
|
|
59c774353a | ||
|
|
b458d504af | ||
|
|
f83e7bfcd9 | ||
|
|
4d2e26ce4b | ||
|
|
817fdc1f36 |
18
.github/pull_request_template.md
vendored
18
.github/pull_request_template.md
vendored
@@ -6,24 +6,6 @@
|
||||
[Describe the tests you ran to verify your changes]
|
||||
|
||||
|
||||
## Accepted Risk (provide if relevant)
|
||||
N/A
|
||||
|
||||
|
||||
## Related Issue(s) (provide if relevant)
|
||||
N/A
|
||||
|
||||
|
||||
## Mental Checklist:
|
||||
- All of the automated tests pass
|
||||
- All PR comments are addressed and marked resolved
|
||||
- If there are migrations, they have been rebased to latest main
|
||||
- If there are new dependencies, they are added to the requirements
|
||||
- If there are new environment variables, they are added to all of the deployment methods
|
||||
- If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
|
||||
- Docker images build and basic functionalities work
|
||||
- Author has done a final read through of the PR right before merge
|
||||
|
||||
## Backporting (check the box to trigger backport action)
|
||||
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
|
||||
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
|
||||
|
||||
@@ -66,6 +66,7 @@ jobs:
|
||||
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
NEXT_PUBLIC_GTM_ENABLED=true
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
@@ -118,6 +118,6 @@ jobs:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
image-ref: docker.io/onyxdotapp/onyx-model-server:${{ github.ref_name }}
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
timeout: "10m"
|
||||
|
||||
14
.github/workflows/pr-python-connector-tests.yml
vendored
14
.github/workflows/pr-python-connector-tests.yml
vendored
@@ -26,7 +26,19 @@ env:
|
||||
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
# Slab
|
||||
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
|
||||
|
||||
# Zendesk
|
||||
ZENDESK_SUBDOMAIN: ${{ secrets.ZENDESK_SUBDOMAIN }}
|
||||
ZENDESK_EMAIL: ${{ secrets.ZENDESK_EMAIL }}
|
||||
ZENDESK_TOKEN: ${{ secrets.ZENDESK_TOKEN }}
|
||||
# Salesforce
|
||||
SF_USERNAME: ${{ secrets.SF_USERNAME }}
|
||||
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
|
||||
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
|
||||
# Airtable
|
||||
AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
|
||||
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
|
||||
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
|
||||
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
|
||||
jobs:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
|
||||
2
.vscode/env_template.txt
vendored
2
.vscode/env_template.txt
vendored
@@ -5,6 +5,8 @@
|
||||
# For local dev, often user Authentication is not needed
|
||||
AUTH_TYPE=disabled
|
||||
|
||||
# Skip warm up for dev
|
||||
SKIP_WARM_UP=True
|
||||
|
||||
# Always keep these on for Dev
|
||||
# Logs all model prompts to stdout
|
||||
|
||||
15
.vscode/launch.template.jsonc
vendored
15
.vscode/launch.template.jsonc
vendored
@@ -355,5 +355,20 @@
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "Install Python Requirements",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"-c",
|
||||
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
@@ -12,6 +12,10 @@ As an open source project in a rapidly changing space, we welcome all contributi
|
||||
|
||||
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.
|
||||
|
||||
To ensure that your contribution is aligned with the project's direction, please reach out to Hagen (or any other maintainer) on the Onyx team
|
||||
via [Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA) /
|
||||
[Discord](https://discord.gg/TDJ59cGV2X) or [email](mailto:founders@onyx.app).
|
||||
|
||||
Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
|
||||
will be marked with the `approved by maintainers` label.
|
||||
Issues marked `good first issue` are an especially great place to start.
|
||||
@@ -23,8 +27,8 @@ If you have a new/different contribution in mind, we'd love to hear about it!
|
||||
Your input is vital to making sure that Onyx moves in the right direction.
|
||||
Before starting on implementation, please raise a GitHub issue.
|
||||
|
||||
And always feel free to message us (Chris Weaver / Yuhong Sun) on
|
||||
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ) /
|
||||
Also, always feel free to message the founders (Chris Weaver / Yuhong Sun) on
|
||||
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA) /
|
||||
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
|
||||
|
||||
### Contributing Code
|
||||
@@ -42,7 +46,7 @@ Our goal is to make contributing as easy as possible. If you run into any issues
|
||||
That way we can help future contributors and users can avoid the same issue.
|
||||
|
||||
We also have support channels and generally interesting discussions on our
|
||||
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ)
|
||||
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA)
|
||||
and
|
||||
[Discord](https://discord.gg/TDJ59cGV2X).
|
||||
|
||||
@@ -123,7 +127,47 @@ Once the above is done, navigate to `onyx/web` run:
|
||||
npm i
|
||||
```
|
||||
|
||||
#### Docker containers for external software
|
||||
## Formatting and Linting
|
||||
|
||||
### Backend
|
||||
|
||||
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
|
||||
First, install pre-commit (if you don't have it already) following the instructions
|
||||
[here](https://pre-commit.com/#installation).
|
||||
|
||||
With the virtual environment active, install the pre-commit library with:
|
||||
|
||||
```bash
|
||||
pip install pre-commit
|
||||
```
|
||||
|
||||
Then, from the `onyx/backend` directory, run:
|
||||
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
Additionally, we use `mypy` for static type checking.
|
||||
Onyx is fully type-annotated, and we want to keep it that way!
|
||||
To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend` directory.
|
||||
|
||||
### Web
|
||||
|
||||
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `onyx/web` directory.
|
||||
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
|
||||
Please double check that prettier passes before creating a pull request.
|
||||
|
||||
# Running the application for development
|
||||
|
||||
## Developing using VSCode Debugger (recommended)
|
||||
|
||||
We highly recommend using VSCode debugger for development.
|
||||
See [CONTRIBUTING_VSCODE.md](./CONTRIBUTING_VSCODE.md) for more details.
|
||||
|
||||
Otherwise, you can follow the instructions below to run the application for development.
|
||||
|
||||
## Manually running the application for development
|
||||
### Docker containers for external software
|
||||
|
||||
You will need Docker installed to run these containers.
|
||||
|
||||
@@ -135,7 +179,7 @@ docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db
|
||||
|
||||
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
|
||||
|
||||
#### Running Onyx locally
|
||||
### Running Onyx locally
|
||||
|
||||
To start the frontend, navigate to `onyx/web` and run:
|
||||
|
||||
@@ -223,35 +267,6 @@ If you want to make changes to Onyx and run those changes in Docker, you can als
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d --build
|
||||
```
|
||||
|
||||
### Formatting and Linting
|
||||
|
||||
#### Backend
|
||||
|
||||
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
|
||||
First, install pre-commit (if you don't have it already) following the instructions
|
||||
[here](https://pre-commit.com/#installation).
|
||||
|
||||
With the virtual environment active, install the pre-commit library with:
|
||||
|
||||
```bash
|
||||
pip install pre-commit
|
||||
```
|
||||
|
||||
Then, from the `onyx/backend` directory, run:
|
||||
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
Additionally, we use `mypy` for static type checking.
|
||||
Onyx is fully type-annotated, and we want to keep it that way!
|
||||
To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend` directory.
|
||||
|
||||
#### Web
|
||||
|
||||
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `onyx/web` directory.
|
||||
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
|
||||
Please double check that prettier passes before creating a pull request.
|
||||
|
||||
### Release Process
|
||||
|
||||
|
||||
29
CONTRIBUTING_VSCODE.md
Normal file
29
CONTRIBUTING_VSCODE.md
Normal file
@@ -0,0 +1,29 @@
|
||||
# VSCode Debugging Setup
|
||||
|
||||
This guide explains how to set up and use VSCode's debugging capabilities with this project.
|
||||
|
||||
## Initial Setup
|
||||
|
||||
1. **Environment Setup**:
|
||||
- Copy `.vscode/.env.template` to `.vscode/.env`
|
||||
- Fill in the necessary environment variables in `.vscode/.env`
|
||||
2. **launch.json**:
|
||||
- Copy `.vscode/launch.template.jsonc` to `.vscode/launch.json`
|
||||
|
||||
## Using the Debugger
|
||||
|
||||
Before starting, make sure the Docker Daemon is running.
|
||||
|
||||
1. Open the Debug view in VSCode (Cmd+Shift+D on macOS)
|
||||
2. From the dropdown at the top, select "Clear and Restart External Volumes and Containers" and press the green play button
|
||||
3. From the dropdown at the top, select "Run All Onyx Services" and press the green play button
|
||||
4. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
|
||||
5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
|
||||
6. Use the debug toolbar to step through code, inspect variables, etc.
|
||||
|
||||
## Features
|
||||
|
||||
- Hot reload is enabled for the web server and API servers
|
||||
- Python debugging is configured with debugpy
|
||||
- Environment variables are loaded from `.vscode/.env`
|
||||
- Console output is organized in the integrated terminal with labeled tabs
|
||||
18
README.md
18
README.md
@@ -3,7 +3,7 @@
|
||||
<a name="readme-top"></a>
|
||||
|
||||
<h2 align="center">
|
||||
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/LogoOnyx.png?raw=true)" /></a>
|
||||
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/OnyxLogoCropped.jpg?raw=true)" /></a>
|
||||
</h2>
|
||||
|
||||
<p align="center">
|
||||
@@ -13,7 +13,7 @@
|
||||
<a href="https://docs.onyx.app/" target="_blank">
|
||||
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
|
||||
</a>
|
||||
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ" target="_blank">
|
||||
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
|
||||
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
|
||||
</a>
|
||||
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
|
||||
@@ -24,7 +24,7 @@
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<strong>[Onyx](https://www.onyx.app/)</strong> (Formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
|
||||
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
|
||||
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
|
||||
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
|
||||
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
|
||||
@@ -133,15 +133,3 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md
|
||||
## ⭐Star History
|
||||
|
||||
[](https://star-history.com/#onyx-dot-app/onyx&Date)
|
||||
|
||||
## ✨Contributors
|
||||
|
||||
<a href="https://github.com/onyx-dot-app/onyx/graphs/contributors">
|
||||
<img alt="contributors" src="https://contrib.rocks/image?repo=onyx-dot-app/onyx"/>
|
||||
</a>
|
||||
|
||||
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
|
||||
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
|
||||
↑ Back to Top ↑
|
||||
</a>
|
||||
</p>
|
||||
|
||||
1
backend/.gitignore
vendored
1
backend/.gitignore
vendored
@@ -9,3 +9,4 @@ api_keys.py
|
||||
vespa-app.zip
|
||||
dynamic_config_storage/
|
||||
celerybeat-schedule*
|
||||
onyx/connectors/salesforce/data/
|
||||
@@ -4,7 +4,7 @@ from onyx.configs.app_configs import USE_IAM_AUTH
|
||||
from onyx.configs.app_configs import POSTGRES_HOST
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.app_configs import AWS_REGION
|
||||
from onyx.configs.app_configs import AWS_REGION_NAME
|
||||
from onyx.db.engine import build_connection_string
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from sqlalchemy import event
|
||||
@@ -120,7 +120,7 @@ def provide_iam_token_for_alembic(
|
||||
) -> None:
|
||||
if USE_IAM_AUTH:
|
||||
# Database connection settings
|
||||
region = AWS_REGION
|
||||
region = AWS_REGION_NAME
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
"""add chunk count to document
|
||||
|
||||
Revision ID: 2955778aa44c
|
||||
Revises: c0aab6edb6dd
|
||||
Create Date: 2025-01-04 11:39:43.268612
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2955778aa44c"
|
||||
down_revision = "c0aab6edb6dd"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("document", sa.Column("chunk_count", sa.Integer(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("document", "chunk_count")
|
||||
@@ -0,0 +1,35 @@
|
||||
"""add composite index for index attempt time updated
|
||||
|
||||
Revision ID: 369644546676
|
||||
Revises: 2955778aa44c
|
||||
Create Date: 2025-01-08 15:38:17.224380
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
from sqlalchemy import text
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "369644546676"
|
||||
down_revision = "2955778aa44c"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_index(
|
||||
"ix_index_attempt_ccpair_search_settings_time_updated",
|
||||
"index_attempt",
|
||||
[
|
||||
"connector_credential_pair_id",
|
||||
"search_settings_id",
|
||||
text("time_updated DESC"),
|
||||
],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(
|
||||
"ix_index_attempt_ccpair_search_settings_time_updated",
|
||||
table_name="index_attempt",
|
||||
)
|
||||
@@ -0,0 +1,31 @@
|
||||
"""mapping for anonymous user path
|
||||
|
||||
Revision ID: a4f6ee863c47
|
||||
Revises: 14a83a331951
|
||||
Create Date: 2025-01-04 14:16:58.697451
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a4f6ee863c47"
|
||||
down_revision = "14a83a331951"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"tenant_anonymous_user_path",
|
||||
sa.Column("tenant_id", sa.String(), primary_key=True, nullable=False),
|
||||
sa.Column("anonymous_user_path", sa.String(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("tenant_id"),
|
||||
sa.UniqueConstraint("anonymous_user_path"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("tenant_anonymous_user_path")
|
||||
@@ -3,6 +3,10 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.external_perm import fetch_external_groups_for_user
|
||||
from ee.onyx.db.user_group import fetch_user_groups_for_documents
|
||||
from ee.onyx.db.user_group import fetch_user_groups_for_user
|
||||
from ee.onyx.external_permissions.post_query_censoring import (
|
||||
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
|
||||
)
|
||||
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
||||
from onyx.access.access import (
|
||||
_get_access_for_documents as get_access_for_documents_without_groups,
|
||||
)
|
||||
@@ -10,6 +14,7 @@ from onyx.access.access import _get_acl_for_user as get_acl_for_user_without_gro
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.access.utils import prefix_external_group
|
||||
from onyx.access.utils import prefix_user_group
|
||||
from onyx.db.document import get_document_sources
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.models import User
|
||||
|
||||
@@ -52,9 +57,20 @@ def _get_access_for_documents(
|
||||
)
|
||||
doc_id_map = {doc.id: doc for doc in documents}
|
||||
|
||||
# Get all sources in one batch
|
||||
doc_id_to_source_map = get_document_sources(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
access_map = {}
|
||||
for document_id, non_ee_access in non_ee_access_dict.items():
|
||||
document = doc_id_map[document_id]
|
||||
source = doc_id_to_source_map.get(document_id)
|
||||
is_only_censored = (
|
||||
source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
|
||||
and source not in DOC_PERMISSIONS_FUNC_MAP
|
||||
)
|
||||
|
||||
ext_u_emails = (
|
||||
set(document.external_user_emails)
|
||||
@@ -70,7 +86,11 @@ def _get_access_for_documents(
|
||||
|
||||
# If the document is determined to be "public" externally (through a SYNC connector)
|
||||
# then it's given the same access level as if it were marked public within Onyx
|
||||
is_public_anywhere = document.is_public or non_ee_access.is_public
|
||||
# If its censored, then it's public anywhere during the search and then permissions are
|
||||
# applied after the search
|
||||
is_public_anywhere = (
|
||||
document.is_public or non_ee_access.is_public or is_only_censored
|
||||
)
|
||||
|
||||
# To avoid collisions of group namings between connectors, they need to be prefixed
|
||||
access_map[document_id] = DocumentAccess(
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
|
||||
import jwt
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
@@ -20,6 +22,7 @@ from ee.onyx.server.seeding import get_seed_config
|
||||
from ee.onyx.utils.secrets import extract_hashed_cookie
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -118,3 +121,17 @@ async def current_cloud_superuser(
|
||||
detail="Access denied. User must be a cloud superuser to perform this action.",
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
def generate_anonymous_user_jwt_token(tenant_id: str) -> str:
|
||||
payload = {
|
||||
"tenant_id": tenant_id,
|
||||
# Token does not expire
|
||||
"iat": datetime.utcnow(), # Issued at time
|
||||
}
|
||||
|
||||
return jwt.encode(payload, USER_AUTH_SECRET, algorithm="HS256")
|
||||
|
||||
|
||||
def decode_anonymous_user_jwt_token(token: str) -> dict:
|
||||
return jwt.decode(token, USER_AUTH_SECRET, algorithms=["HS256"])
|
||||
|
||||
@@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.user_group import delete_user_group
|
||||
from ee.onyx.db.user_group import fetch_user_group
|
||||
from ee.onyx.db.user_group import mark_user_group_as_synced
|
||||
from ee.onyx.db.user_group import prepare_user_group_for_deletion
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -46,11 +47,20 @@ def monitor_usergroup_taskset(
|
||||
|
||||
user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
|
||||
if user_group:
|
||||
usergroup_name = user_group.name
|
||||
if user_group.is_up_for_deletion:
|
||||
# this prepare should have been run when the deletion was scheduled,
|
||||
# but run it again to be sure we're ready to go
|
||||
mark_user_group_as_synced(db_session, user_group)
|
||||
prepare_user_group_for_deletion(db_session, usergroup_id)
|
||||
delete_user_group(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(f"Deleted usergroup. id='{usergroup_id}'")
|
||||
task_logger.info(
|
||||
f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
|
||||
)
|
||||
else:
|
||||
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(f"Synced usergroup. id='{usergroup_id}'")
|
||||
task_logger.info(
|
||||
f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
|
||||
)
|
||||
|
||||
rug.reset()
|
||||
|
||||
@@ -15,6 +15,12 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
|
||||
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
# This is a boolean that determines if anonymous access is public
|
||||
# Default behavior is to not make the page public and instead add a group
|
||||
# that contains all the users that we found in Confluence
|
||||
CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
|
||||
os.environ.get("CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC", "").lower() == "true"
|
||||
)
|
||||
# In seconds, default is 5 minutes
|
||||
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
@@ -55,3 +61,5 @@ POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
|
||||
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
|
||||
|
||||
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
|
||||
|
||||
ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user"
|
||||
|
||||
@@ -2,6 +2,7 @@ import datetime
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import case
|
||||
from sqlalchemy import cast
|
||||
from sqlalchemy import Date
|
||||
@@ -14,6 +15,9 @@ from onyx.configs.constants import MessageType
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatMessageFeedback
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
|
||||
|
||||
def fetch_query_analytics(
|
||||
@@ -234,3 +238,121 @@ def fetch_persona_unique_users(
|
||||
)
|
||||
|
||||
return [tuple(row) for row in db_session.execute(query).all()]
|
||||
|
||||
|
||||
def fetch_assistant_message_analytics(
|
||||
db_session: Session,
|
||||
assistant_id: int,
|
||||
start: datetime.datetime,
|
||||
end: datetime.datetime,
|
||||
) -> list[tuple[int, datetime.date]]:
|
||||
"""
|
||||
Gets the daily message counts for a specific assistant in the given time range.
|
||||
"""
|
||||
query = (
|
||||
select(
|
||||
func.count(ChatMessage.id),
|
||||
cast(ChatMessage.time_sent, Date),
|
||||
)
|
||||
.join(
|
||||
ChatSession,
|
||||
ChatMessage.chat_session_id == ChatSession.id,
|
||||
)
|
||||
.where(
|
||||
or_(
|
||||
ChatMessage.alternate_assistant_id == assistant_id,
|
||||
ChatSession.persona_id == assistant_id,
|
||||
),
|
||||
ChatMessage.time_sent >= start,
|
||||
ChatMessage.time_sent <= end,
|
||||
ChatMessage.message_type == MessageType.ASSISTANT,
|
||||
)
|
||||
.group_by(cast(ChatMessage.time_sent, Date))
|
||||
.order_by(cast(ChatMessage.time_sent, Date))
|
||||
)
|
||||
|
||||
return [tuple(row) for row in db_session.execute(query).all()]
|
||||
|
||||
|
||||
def fetch_assistant_unique_users(
|
||||
db_session: Session,
|
||||
assistant_id: int,
|
||||
start: datetime.datetime,
|
||||
end: datetime.datetime,
|
||||
) -> list[tuple[int, datetime.date]]:
|
||||
"""
|
||||
Gets the daily unique user counts for a specific assistant in the given time range.
|
||||
"""
|
||||
query = (
|
||||
select(
|
||||
func.count(func.distinct(ChatSession.user_id)),
|
||||
cast(ChatMessage.time_sent, Date),
|
||||
)
|
||||
.join(
|
||||
ChatSession,
|
||||
ChatMessage.chat_session_id == ChatSession.id,
|
||||
)
|
||||
.where(
|
||||
or_(
|
||||
ChatMessage.alternate_assistant_id == assistant_id,
|
||||
ChatSession.persona_id == assistant_id,
|
||||
),
|
||||
ChatMessage.time_sent >= start,
|
||||
ChatMessage.time_sent <= end,
|
||||
ChatMessage.message_type == MessageType.ASSISTANT,
|
||||
)
|
||||
.group_by(cast(ChatMessage.time_sent, Date))
|
||||
.order_by(cast(ChatMessage.time_sent, Date))
|
||||
)
|
||||
|
||||
return [tuple(row) for row in db_session.execute(query).all()]
|
||||
|
||||
|
||||
def fetch_assistant_unique_users_total(
|
||||
db_session: Session,
|
||||
assistant_id: int,
|
||||
start: datetime.datetime,
|
||||
end: datetime.datetime,
|
||||
) -> int:
|
||||
"""
|
||||
Gets the total number of distinct users who have sent or received messages from
|
||||
the specified assistant in the given time range.
|
||||
"""
|
||||
query = (
|
||||
select(func.count(func.distinct(ChatSession.user_id)))
|
||||
.select_from(ChatMessage)
|
||||
.join(
|
||||
ChatSession,
|
||||
ChatMessage.chat_session_id == ChatSession.id,
|
||||
)
|
||||
.where(
|
||||
or_(
|
||||
ChatMessage.alternate_assistant_id == assistant_id,
|
||||
ChatSession.persona_id == assistant_id,
|
||||
),
|
||||
ChatMessage.time_sent >= start,
|
||||
ChatMessage.time_sent <= end,
|
||||
ChatMessage.message_type == MessageType.ASSISTANT,
|
||||
)
|
||||
)
|
||||
|
||||
result = db_session.execute(query).scalar()
|
||||
return result if result else 0
|
||||
|
||||
|
||||
# Users can view assistant stats if they created the persona,
|
||||
# or if they are an admin
|
||||
def user_can_view_assistant_stats(
|
||||
db_session: Session, user: User | None, assistant_id: int
|
||||
) -> bool:
|
||||
# If user is None, assume the user is an admin or auth is disabled
|
||||
if user is None or user.role == UserRole.ADMIN:
|
||||
return True
|
||||
|
||||
# Check if the user created the persona
|
||||
stmt = select(Persona).where(
|
||||
and_(Persona.id == assistant_id, Persona.user_id == user.id)
|
||||
)
|
||||
|
||||
persona = db_session.execute(stmt).scalar_one_or_none()
|
||||
return persona is not None
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.access.utils import prefix_group_w_source
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import User__ExternalUserGroupId
|
||||
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -106,3 +107,21 @@ def fetch_external_groups_for_user(
|
||||
User__ExternalUserGroupId.user_id == user_id
|
||||
)
|
||||
).all()
|
||||
|
||||
|
||||
def fetch_external_groups_for_user_email_and_group_ids(
|
||||
db_session: Session,
|
||||
user_email: str,
|
||||
group_ids: list[str],
|
||||
) -> list[User__ExternalUserGroupId]:
|
||||
user = get_user_by_email(db_session=db_session, email=user_email)
|
||||
if user is None:
|
||||
return []
|
||||
user_id = user.id
|
||||
user_ext_groups = db_session.scalars(
|
||||
select(User__ExternalUserGroupId).where(
|
||||
User__ExternalUserGroupId.user_id == user_id,
|
||||
User__ExternalUserGroupId.external_user_group_id.in_(group_ids),
|
||||
)
|
||||
).all()
|
||||
return list(user_ext_groups)
|
||||
|
||||
@@ -24,6 +24,7 @@ def _add_user_filters(
|
||||
if user is None or user.role == UserRole.ADMIN:
|
||||
return stmt
|
||||
|
||||
stmt = stmt.distinct()
|
||||
TRLimit_UG = aliased(TokenRateLimit__UserGroup)
|
||||
User__UG = aliased(User__UserGroup)
|
||||
|
||||
|
||||
@@ -122,7 +122,7 @@ def _cleanup_document_set__user_group_relationships__no_commit(
|
||||
)
|
||||
|
||||
|
||||
def validate_user_creation_permissions(
|
||||
def validate_object_creation_for_user(
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
target_group_ids: list[int] | None = None,
|
||||
@@ -440,32 +440,108 @@ def remove_curator_status__no_commit(db_session: Session, user: User) -> None:
|
||||
_validate_curator_status__no_commit(db_session, [user])
|
||||
|
||||
|
||||
def update_user_curator_relationship(
|
||||
def _validate_curator_relationship_update_requester(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
set_curator_request: SetCuratorRequest,
|
||||
user_making_change: User | None = None,
|
||||
) -> None:
|
||||
user = fetch_user_by_id(db_session, set_curator_request.user_id)
|
||||
if not user:
|
||||
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
|
||||
"""
|
||||
This function validates that the user making the change has the necessary permissions
|
||||
to update the curator relationship for the target user in the given user group.
|
||||
"""
|
||||
|
||||
if user.role == UserRole.ADMIN:
|
||||
if user_making_change is None or user_making_change.role == UserRole.ADMIN:
|
||||
return
|
||||
|
||||
# check if the user making the change is a curator in the group they are changing the curator relationship for
|
||||
user_making_change_curator_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user_making_change.id,
|
||||
# only check if the user making the change is a curator if they are a curator
|
||||
# otherwise, they are a global_curator and can update the curator relationship
|
||||
# for any group they are a member of
|
||||
only_curator_groups=user_making_change.role == UserRole.CURATOR,
|
||||
)
|
||||
requestor_curator_group_ids = [
|
||||
group.id for group in user_making_change_curator_groups
|
||||
]
|
||||
if user_group_id not in requestor_curator_group_ids:
|
||||
raise ValueError(
|
||||
f"User '{user.email}' is an admin and therefore has all permissions "
|
||||
f"user making change {user_making_change.email} is not a curator,"
|
||||
f" admin, or global_curator for group '{user_group_id}'"
|
||||
)
|
||||
|
||||
|
||||
def _validate_curator_relationship_update_request(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
target_user: User,
|
||||
) -> None:
|
||||
"""
|
||||
This function validates that the curator_relationship_update request itself is valid.
|
||||
"""
|
||||
if target_user.role == UserRole.ADMIN:
|
||||
raise ValueError(
|
||||
f"User '{target_user.email}' is an admin and therefore has all permissions "
|
||||
"of a curator. If you'd like this user to only have curator permissions, "
|
||||
"you must update their role to BASIC then assign them to be CURATOR in the "
|
||||
"appropriate groups."
|
||||
)
|
||||
elif target_user.role == UserRole.GLOBAL_CURATOR:
|
||||
raise ValueError(
|
||||
f"User '{target_user.email}' is a global_curator and therefore has all "
|
||||
"permissions of a curator for all groups. If you'd like this user to only "
|
||||
"have curator permissions for a specific group, you must update their role "
|
||||
"to BASIC then assign them to be CURATOR in the appropriate groups."
|
||||
)
|
||||
elif target_user.role not in [UserRole.CURATOR, UserRole.BASIC]:
|
||||
raise ValueError(
|
||||
f"This endpoint can only be used to update the curator relationship for "
|
||||
"users with the CURATOR or BASIC role. \n"
|
||||
f"Target user: {target_user.email} \n"
|
||||
f"Target user role: {target_user.role} \n"
|
||||
)
|
||||
|
||||
# check if the target user is in the group they are changing the curator relationship for
|
||||
requested_user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=set_curator_request.user_id,
|
||||
user_id=target_user.id,
|
||||
only_curator_groups=False,
|
||||
)
|
||||
|
||||
group_ids = [group.id for group in requested_user_groups]
|
||||
if user_group_id not in group_ids:
|
||||
raise ValueError(f"user is not in group '{user_group_id}'")
|
||||
raise ValueError(
|
||||
f"target user {target_user.email} is not in group '{user_group_id}'"
|
||||
)
|
||||
|
||||
|
||||
def update_user_curator_relationship(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
set_curator_request: SetCuratorRequest,
|
||||
user_making_change: User | None = None,
|
||||
) -> None:
|
||||
target_user = fetch_user_by_id(db_session, set_curator_request.user_id)
|
||||
if not target_user:
|
||||
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
|
||||
|
||||
_validate_curator_relationship_update_request(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
target_user=target_user,
|
||||
)
|
||||
|
||||
_validate_curator_relationship_update_requester(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
user_making_change=user_making_change,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"user_making_change={user_making_change.email if user_making_change else 'None'} is "
|
||||
f"updating the curator relationship for user={target_user.email} "
|
||||
f"in group={user_group_id} to is_curator={set_curator_request.is_curator}"
|
||||
)
|
||||
|
||||
relationship_to_update = (
|
||||
db_session.query(User__UserGroup)
|
||||
@@ -486,7 +562,7 @@ def update_user_curator_relationship(
|
||||
)
|
||||
db_session.add(relationship_to_update)
|
||||
|
||||
_validate_curator_status__no_commit(db_session, [user])
|
||||
_validate_curator_status__no_commit(db_session, [target_user])
|
||||
db_session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
# This is a group that we use to store all the users that we found in Confluence
|
||||
# Instead of setting a page to public, we just add this group so that the page
|
||||
# is only accessible to users who have confluence accounts.
|
||||
ALL_CONF_EMAILS_GROUP_NAME = "All_Confluence_Users_Found_By_Onyx"
|
||||
@@ -4,6 +4,8 @@ https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.htm
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
@@ -22,7 +24,9 @@ _REQUEST_PAGINATION_LIMIT = 5000
|
||||
def _get_server_space_permissions(
|
||||
confluence_client: OnyxConfluence, space_key: str
|
||||
) -> ExternalAccess:
|
||||
space_permissions = confluence_client.get_space_permissions(space_key=space_key)
|
||||
space_permissions = confluence_client.get_all_space_permissions_server(
|
||||
space_key=space_key
|
||||
)
|
||||
|
||||
viewspace_permissions = []
|
||||
for permission_category in space_permissions:
|
||||
@@ -31,14 +35,32 @@ def _get_server_space_permissions(
|
||||
permission_category.get("spacePermissions", [])
|
||||
)
|
||||
|
||||
is_public = False
|
||||
user_names = set()
|
||||
group_names = set()
|
||||
for permission in viewspace_permissions:
|
||||
if user_name := permission.get("userName"):
|
||||
user_name = permission.get("userName")
|
||||
if user_name:
|
||||
user_names.add(user_name)
|
||||
if group_name := permission.get("groupName"):
|
||||
group_name = permission.get("groupName")
|
||||
if group_name:
|
||||
group_names.add(group_name)
|
||||
|
||||
# It seems that if anonymous access is turned on for the site and space,
|
||||
# then the space is publicly accessible.
|
||||
# For confluence server, we make a group that contains all users
|
||||
# that exist in confluence and then just add that group to the space permissions
|
||||
# if anonymous access is turned on for the site and space or we set is_public = True
|
||||
# if they set the env variable CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC to True so
|
||||
# that we can support confluence server deployments that want anonymous access
|
||||
# to be public (we cant test this because its paywalled)
|
||||
if user_name is None and group_name is None:
|
||||
# Defaults to False
|
||||
if CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC:
|
||||
is_public = True
|
||||
else:
|
||||
group_names.add(ALL_CONF_EMAILS_GROUP_NAME)
|
||||
|
||||
user_emails = set()
|
||||
for user_name in user_names:
|
||||
user_email = get_user_email_from_username__server(confluence_client, user_name)
|
||||
@@ -47,14 +69,17 @@ def _get_server_space_permissions(
|
||||
else:
|
||||
logger.warning(f"Email for user {user_name} not found in Confluence")
|
||||
|
||||
if not user_emails and not group_names:
|
||||
logger.warning(
|
||||
"No user emails or group names found in Confluence space permissions"
|
||||
f"\nSpace key: {space_key}"
|
||||
f"\nSpace permissions: {space_permissions}"
|
||||
)
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
external_user_group_ids=group_names,
|
||||
# TODO: Check if the space is publicly accessible
|
||||
# Currently, we assume the space is not public
|
||||
# We need to check if anonymous access is turned on for the site and space
|
||||
# This information is paywalled so it remains unimplemented
|
||||
is_public=False,
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
|
||||
@@ -134,7 +159,7 @@ def _get_space_permissions(
|
||||
|
||||
def _extract_read_access_restrictions(
|
||||
confluence_client: OnyxConfluence, restrictions: dict[str, Any]
|
||||
) -> ExternalAccess | None:
|
||||
) -> tuple[set[str], set[str]]:
|
||||
"""
|
||||
Converts a page's restrictions dict into an ExternalAccess object.
|
||||
If there are no restrictions, then return None
|
||||
@@ -177,21 +202,57 @@ def _extract_read_access_restrictions(
|
||||
group["name"] for group in read_access_group_jsons if group.get("name")
|
||||
]
|
||||
|
||||
return set(read_access_user_emails), set(read_access_group_names)
|
||||
|
||||
|
||||
def _get_all_page_restrictions(
|
||||
confluence_client: OnyxConfluence,
|
||||
perm_sync_data: dict[str, Any],
|
||||
) -> ExternalAccess | None:
|
||||
"""
|
||||
This function gets the restrictions for a page by taking the intersection
|
||||
of the page's restrictions and the restrictions of all the ancestors
|
||||
of the page.
|
||||
If the page/ancestor has no restrictions, then it is ignored (no intersection).
|
||||
If no restrictions are found anywhere, then return None, indicating that the page
|
||||
should inherit the space's restrictions.
|
||||
"""
|
||||
found_user_emails: set[str] = set()
|
||||
found_group_names: set[str] = set()
|
||||
|
||||
found_user_emails, found_group_names = _extract_read_access_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
restrictions=perm_sync_data.get("restrictions", {}),
|
||||
)
|
||||
|
||||
ancestors: list[dict[str, Any]] = perm_sync_data.get("ancestors", [])
|
||||
for ancestor in ancestors:
|
||||
ancestor_user_emails, ancestor_group_names = _extract_read_access_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
restrictions=ancestor.get("restrictions", {}),
|
||||
)
|
||||
if not ancestor_user_emails and not ancestor_group_names:
|
||||
# This ancestor has no restrictions, so it has no effect on
|
||||
# the page's restrictions, so we ignore it
|
||||
continue
|
||||
|
||||
found_user_emails.intersection_update(ancestor_user_emails)
|
||||
found_group_names.intersection_update(ancestor_group_names)
|
||||
|
||||
# If there are no restrictions found, then the page
|
||||
# inherits the space's restrictions so return None
|
||||
is_space_public = read_access_user_emails == [] and read_access_group_names == []
|
||||
if is_space_public:
|
||||
if not found_user_emails and not found_group_names:
|
||||
return None
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=set(read_access_user_emails),
|
||||
external_user_group_ids=set(read_access_group_names),
|
||||
external_user_emails=found_user_emails,
|
||||
external_user_group_ids=found_group_names,
|
||||
# there is no way for a page to be individually public if the space isn't public
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
|
||||
def _fetch_all_page_restrictions_for_space(
|
||||
def _fetch_all_page_restrictions(
|
||||
confluence_client: OnyxConfluence,
|
||||
slim_docs: list[SlimDocument],
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess],
|
||||
@@ -208,11 +269,11 @@ def _fetch_all_page_restrictions_for_space(
|
||||
raise ValueError(
|
||||
f"No permission sync data found for document {slim_doc.id}"
|
||||
)
|
||||
restrictions = _extract_read_access_restrictions(
|
||||
|
||||
if restrictions := _get_all_page_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
restrictions=slim_doc.perm_sync_data.get("restrictions", {}),
|
||||
)
|
||||
if restrictions:
|
||||
perm_sync_data=slim_doc.perm_sync_data,
|
||||
):
|
||||
document_restrictions.append(
|
||||
DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
@@ -301,7 +362,7 @@ def confluence_doc_sync(
|
||||
slim_docs.extend(doc_batch)
|
||||
|
||||
logger.debug("Fetching all page restrictions for space")
|
||||
return _fetch_all_page_restrictions_for_space(
|
||||
return _fetch_all_page_restrictions(
|
||||
confluence_client=confluence_connector.confluence_client,
|
||||
slim_docs=slim_docs,
|
||||
space_permissions_by_space_key=space_permissions_by_space_key,
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ def _build_group_member_email_map(
|
||||
)
|
||||
if not email:
|
||||
# If we still don't have an email, skip this user
|
||||
logger.warning(f"user result missing email field: {user_result}")
|
||||
continue
|
||||
|
||||
for group in confluence_client.paginated_groups_by_user_retrieval(user):
|
||||
@@ -53,6 +54,7 @@ def confluence_group_sync(
|
||||
confluence_client=confluence_client,
|
||||
)
|
||||
onyx_groups: list[ExternalUserGroup] = []
|
||||
all_found_emails = set()
|
||||
for group_id, group_member_emails in group_member_email_map.items():
|
||||
onyx_groups.append(
|
||||
ExternalUserGroup(
|
||||
@@ -60,5 +62,15 @@ def confluence_group_sync(
|
||||
user_emails=list(group_member_emails),
|
||||
)
|
||||
)
|
||||
all_found_emails.update(group_member_emails)
|
||||
|
||||
# This is so that when we find a public confleunce server page, we can
|
||||
# give access to all users only in if they have an email in Confluence
|
||||
if cc_pair.connector.connector_specific_config.get("is_cloud", False):
|
||||
all_found_group = ExternalUserGroup(
|
||||
id=ALL_CONF_EMAILS_GROUP_NAME,
|
||||
user_emails=list(all_found_emails),
|
||||
)
|
||||
onyx_groups.append(all_found_group)
|
||||
|
||||
return onyx_groups
|
||||
|
||||
84
backend/ee/onyx/external_permissions/post_query_censoring.py
Normal file
84
backend/ee/onyx/external_permissions/post_query_censoring.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.external_permissions.salesforce.postprocessing import (
|
||||
censor_salesforce_chunks,
|
||||
)
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.pipeline import InferenceChunk
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION: dict[
|
||||
DocumentSource,
|
||||
# list of chunks to be censored and the user email. returns censored chunks
|
||||
Callable[[list[InferenceChunk], str], list[InferenceChunk]],
|
||||
] = {
|
||||
DocumentSource.SALESFORCE: censor_salesforce_chunks,
|
||||
}
|
||||
|
||||
|
||||
def _get_all_censoring_enabled_sources() -> set[DocumentSource]:
|
||||
"""
|
||||
Returns the set of sources that have censoring enabled.
|
||||
This is based on if the access_type is set to sync and the connector
|
||||
source is included in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION.
|
||||
|
||||
NOTE: This means if there is a source has a single cc_pair that is sync,
|
||||
all chunks for that source will be censored, even if the connector that
|
||||
indexed that chunk is not sync. This was done to avoid getting the cc_pair
|
||||
for every single chunk.
|
||||
"""
|
||||
with get_session_context_manager() as db_session:
|
||||
enabled_sync_connectors = get_all_auto_sync_cc_pairs(db_session)
|
||||
return {
|
||||
cc_pair.connector.source
|
||||
for cc_pair in enabled_sync_connectors
|
||||
if cc_pair.connector.source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
|
||||
}
|
||||
|
||||
|
||||
# NOTE: This is only called if ee is enabled.
|
||||
def _post_query_chunk_censoring(
|
||||
chunks: list[InferenceChunk],
|
||||
user: User | None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
This function checks all chunks to see if they need to be sent to a censoring
|
||||
function. If they do, it sends them to the censoring function and returns the
|
||||
censored chunks. If they don't, it returns the original chunks.
|
||||
"""
|
||||
if user is None:
|
||||
# if user is None, permissions are not enforced
|
||||
return chunks
|
||||
|
||||
chunks_to_keep = []
|
||||
chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}
|
||||
|
||||
sources_to_censor = _get_all_censoring_enabled_sources()
|
||||
for chunk in chunks:
|
||||
# Separate out chunks that require permission post-processing by source
|
||||
if chunk.source_type in sources_to_censor:
|
||||
chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
|
||||
else:
|
||||
chunks_to_keep.append(chunk)
|
||||
|
||||
# For each source, filter out the chunks using the permission
|
||||
# check function for that source
|
||||
# TODO: Use a threadpool/multiprocessing to process the sources in parallel
|
||||
for source, chunks_for_source in chunks_to_process.items():
|
||||
censor_chunks_for_source = DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION[source]
|
||||
try:
|
||||
censored_chunks = censor_chunks_for_source(chunks_for_source, user.email)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to censor chunks for source {source} so throwing out all"
|
||||
f" chunks for this source and continuing: {e}"
|
||||
)
|
||||
continue
|
||||
chunks_to_keep.extend(censored_chunks)
|
||||
|
||||
return chunks_to_keep
|
||||
@@ -0,0 +1,226 @@
|
||||
import time
|
||||
|
||||
from ee.onyx.db.external_perm import fetch_external_groups_for_user_email_and_group_ids
|
||||
from ee.onyx.external_permissions.salesforce.utils import (
|
||||
get_any_salesforce_client_for_doc_id,
|
||||
)
|
||||
from ee.onyx.external_permissions.salesforce.utils import get_objects_access_for_user_id
|
||||
from ee.onyx.external_permissions.salesforce.utils import (
|
||||
get_salesforce_user_id_from_email,
|
||||
)
|
||||
from onyx.configs.app_configs import BLURB_SIZE
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# Types
|
||||
ChunkKey = tuple[str, int] # (doc_id, chunk_id)
|
||||
ContentRange = tuple[int, int | None] # (start_index, end_index) None means to the end
|
||||
|
||||
|
||||
# NOTE: Used for testing timing
|
||||
def _get_dummy_object_access_map(
|
||||
object_ids: set[str], user_email: str, chunks: list[InferenceChunk]
|
||||
) -> dict[str, bool]:
|
||||
time.sleep(0.15)
|
||||
# return {object_id: True for object_id in object_ids}
|
||||
import random
|
||||
|
||||
return {object_id: random.choice([True, False]) for object_id in object_ids}
|
||||
|
||||
|
||||
def _get_objects_access_for_user_email_from_salesforce(
|
||||
object_ids: set[str],
|
||||
user_email: str,
|
||||
chunks: list[InferenceChunk],
|
||||
) -> dict[str, bool] | None:
|
||||
"""
|
||||
This function wraps the salesforce call as we may want to change how this
|
||||
is done in the future. (E.g. replace it with the above function)
|
||||
"""
|
||||
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
|
||||
# but subsequent queries for this source are essentially instant
|
||||
first_doc_id = chunks[0].document_id
|
||||
with get_session_context_manager() as db_session:
|
||||
salesforce_client = get_any_salesforce_client_for_doc_id(
|
||||
db_session, first_doc_id
|
||||
)
|
||||
|
||||
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
|
||||
# but subsequent queries by the same user are essentially instant
|
||||
start_time = time.time()
|
||||
user_id = get_salesforce_user_id_from_email(salesforce_client, user_email)
|
||||
end_time = time.time()
|
||||
logger.info(
|
||||
f"Time taken to get Salesforce user ID: {end_time - start_time} seconds"
|
||||
)
|
||||
if user_id is None:
|
||||
return None
|
||||
|
||||
# This is the only query that is not cached in the function
|
||||
# so it takes 0.1-0.2 seconds total
|
||||
object_id_to_access = get_objects_access_for_user_id(
|
||||
salesforce_client, user_id, list(object_ids)
|
||||
)
|
||||
return object_id_to_access
|
||||
|
||||
|
||||
def _extract_salesforce_object_id_from_url(url: str) -> str:
|
||||
return url.split("/")[-1]
|
||||
|
||||
|
||||
def _get_object_ranges_for_chunk(
|
||||
chunk: InferenceChunk,
|
||||
) -> dict[str, list[ContentRange]]:
|
||||
"""
|
||||
Given a chunk, return a dictionary of salesforce object ids and the content ranges
|
||||
for that object id in the current chunk
|
||||
"""
|
||||
if chunk.source_links is None:
|
||||
return {}
|
||||
|
||||
object_ranges: dict[str, list[ContentRange]] = {}
|
||||
end_index = None
|
||||
descending_source_links = sorted(
|
||||
chunk.source_links.items(), key=lambda x: x[0], reverse=True
|
||||
)
|
||||
for start_index, url in descending_source_links:
|
||||
object_id = _extract_salesforce_object_id_from_url(url)
|
||||
if object_id not in object_ranges:
|
||||
object_ranges[object_id] = []
|
||||
object_ranges[object_id].append((start_index, end_index))
|
||||
end_index = start_index
|
||||
return object_ranges
|
||||
|
||||
|
||||
def _create_empty_censored_chunk(uncensored_chunk: InferenceChunk) -> InferenceChunk:
|
||||
"""
|
||||
Create a copy of the unfiltered chunk where potentially sensitive content is removed
|
||||
to be added later if the user has access to each of the sub-objects
|
||||
"""
|
||||
empty_censored_chunk = InferenceChunk(
|
||||
**uncensored_chunk.model_dump(),
|
||||
)
|
||||
empty_censored_chunk.content = ""
|
||||
empty_censored_chunk.blurb = ""
|
||||
empty_censored_chunk.source_links = {}
|
||||
return empty_censored_chunk
|
||||
|
||||
|
||||
def _update_censored_chunk(
|
||||
censored_chunk: InferenceChunk,
|
||||
uncensored_chunk: InferenceChunk,
|
||||
content_range: ContentRange,
|
||||
) -> InferenceChunk:
|
||||
"""
|
||||
Update the filtered chunk with the content and source links from the unfiltered chunk using the content ranges
|
||||
"""
|
||||
start_index, end_index = content_range
|
||||
|
||||
# Update the content of the filtered chunk
|
||||
permitted_content = uncensored_chunk.content[start_index:end_index]
|
||||
permitted_section_start_index = len(censored_chunk.content)
|
||||
censored_chunk.content = permitted_content + censored_chunk.content
|
||||
|
||||
# Update the source links of the filtered chunk
|
||||
if uncensored_chunk.source_links is not None:
|
||||
if censored_chunk.source_links is None:
|
||||
censored_chunk.source_links = {}
|
||||
link_content = uncensored_chunk.source_links[start_index]
|
||||
censored_chunk.source_links[permitted_section_start_index] = link_content
|
||||
|
||||
# Update the blurb of the filtered chunk
|
||||
censored_chunk.blurb = censored_chunk.content[:BLURB_SIZE]
|
||||
|
||||
return censored_chunk
|
||||
|
||||
|
||||
# TODO: Generalize this to other sources
|
||||
def censor_salesforce_chunks(
|
||||
chunks: list[InferenceChunk],
|
||||
user_email: str,
|
||||
# This is so we can provide a mock access map for testing
|
||||
access_map: dict[str, bool] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
# object_id -> list[((doc_id, chunk_id), (start_index, end_index))]
|
||||
object_to_content_map: dict[str, list[tuple[ChunkKey, ContentRange]]] = {}
|
||||
|
||||
# (doc_id, chunk_id) -> chunk
|
||||
uncensored_chunks: dict[ChunkKey, InferenceChunk] = {}
|
||||
|
||||
# keep track of all object ids that we have seen to make it easier to get
|
||||
# the access for these object ids
|
||||
object_ids: set[str] = set()
|
||||
|
||||
for chunk in chunks:
|
||||
chunk_key = (chunk.document_id, chunk.chunk_id)
|
||||
# create a dictionary to quickly look up the unfiltered chunk
|
||||
uncensored_chunks[chunk_key] = chunk
|
||||
|
||||
# for each chunk, get a dictionary of object ids and the content ranges
|
||||
# for that object id in the current chunk
|
||||
object_ranges_for_chunk = _get_object_ranges_for_chunk(chunk)
|
||||
for object_id, ranges in object_ranges_for_chunk.items():
|
||||
object_ids.add(object_id)
|
||||
for start_index, end_index in ranges:
|
||||
object_to_content_map.setdefault(object_id, []).append(
|
||||
(chunk_key, (start_index, end_index))
|
||||
)
|
||||
|
||||
# This is so we can provide a mock access map for testing
|
||||
if access_map is None:
|
||||
access_map = _get_objects_access_for_user_email_from_salesforce(
|
||||
object_ids=object_ids,
|
||||
user_email=user_email,
|
||||
chunks=chunks,
|
||||
)
|
||||
if access_map is None:
|
||||
# If the user is not found in Salesforce, access_map will be None
|
||||
# so we should just return an empty list because no chunks will be
|
||||
# censored
|
||||
return []
|
||||
|
||||
censored_chunks: dict[ChunkKey, InferenceChunk] = {}
|
||||
for object_id, content_list in object_to_content_map.items():
|
||||
# if the user does not have access to the object, or the object is not in the
|
||||
# access_map, do not include its content in the filtered chunks
|
||||
if not access_map.get(object_id, False):
|
||||
continue
|
||||
|
||||
# if we got this far, the user has access to the object so we can create or update
|
||||
# the filtered chunk(s) for this object
|
||||
# NOTE: we only create a censored chunk if the user has access to some
|
||||
# part of the chunk
|
||||
for chunk_key, content_range in content_list:
|
||||
if chunk_key not in censored_chunks:
|
||||
censored_chunks[chunk_key] = _create_empty_censored_chunk(
|
||||
uncensored_chunks[chunk_key]
|
||||
)
|
||||
|
||||
uncensored_chunk = uncensored_chunks[chunk_key]
|
||||
censored_chunk = _update_censored_chunk(
|
||||
censored_chunk=censored_chunks[chunk_key],
|
||||
uncensored_chunk=uncensored_chunk,
|
||||
content_range=content_range,
|
||||
)
|
||||
censored_chunks[chunk_key] = censored_chunk
|
||||
|
||||
return list(censored_chunks.values())
|
||||
|
||||
|
||||
# NOTE: This is not used anywhere.
|
||||
def _get_objects_access_for_user_email(
|
||||
object_ids: set[str], user_email: str
|
||||
) -> dict[str, bool]:
|
||||
with get_session_context_manager() as db_session:
|
||||
external_groups = fetch_external_groups_for_user_email_and_group_ids(
|
||||
db_session=db_session,
|
||||
user_email=user_email,
|
||||
# Maybe make a function that adds a salesforce prefix to the group ids
|
||||
group_ids=list(object_ids),
|
||||
)
|
||||
external_group_ids = {group.external_user_group_id for group in external_groups}
|
||||
return {group_id: group_id in external_group_ids for group_id in object_ids}
|
||||
174
backend/ee/onyx/external_permissions/salesforce/utils.py
Normal file
174
backend/ee/onyx/external_permissions/salesforce/utils.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from simple_salesforce import Salesforce
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_user_id_by_email
|
||||
from onyx.connectors.salesforce.sqlite_functions import init_db
|
||||
from onyx.connectors.salesforce.sqlite_functions import NULL_ID_STRING
|
||||
from onyx.connectors.salesforce.sqlite_functions import update_email_to_id_table
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.document import get_cc_pairs_for_document
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_ANY_SALESFORCE_CLIENT: Salesforce | None = None
|
||||
|
||||
|
||||
def get_any_salesforce_client_for_doc_id(
|
||||
db_session: Session, doc_id: str
|
||||
) -> Salesforce:
|
||||
"""
|
||||
We create a salesforce client for the first cc_pair for the first doc_id where
|
||||
salesforce censoring is enabled. After that we just cache and reuse the same
|
||||
client for all queries.
|
||||
|
||||
We do this to reduce the number of postgres queries we make at query time.
|
||||
|
||||
This may be problematic if they are using multiple cc_pairs for salesforce.
|
||||
E.g. there are 2 different credential sets for 2 different salesforce cc_pairs
|
||||
but only one has the permissions to access the permissions needed for the query.
|
||||
"""
|
||||
global _ANY_SALESFORCE_CLIENT
|
||||
if _ANY_SALESFORCE_CLIENT is None:
|
||||
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
|
||||
first_cc_pair = cc_pairs[0]
|
||||
credential_json = first_cc_pair.credential.credential_json
|
||||
_ANY_SALESFORCE_CLIENT = Salesforce(
|
||||
username=credential_json["sf_username"],
|
||||
password=credential_json["sf_password"],
|
||||
security_token=credential_json["sf_security_token"],
|
||||
)
|
||||
return _ANY_SALESFORCE_CLIENT
|
||||
|
||||
|
||||
def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None:
|
||||
query = f"SELECT Id FROM User WHERE Email = '{user_email}'"
|
||||
result = sf_client.query(query)
|
||||
if len(result["records"]) == 0:
|
||||
return None
|
||||
return result["records"][0]["Id"]
|
||||
|
||||
|
||||
# This contains only the user_ids that we have found in Salesforce.
|
||||
# If we don't know their user_id, we don't store anything in the cache.
|
||||
_CACHED_SF_EMAIL_TO_ID_MAP: dict[str, str] = {}
|
||||
|
||||
|
||||
def get_salesforce_user_id_from_email(
|
||||
sf_client: Salesforce,
|
||||
user_email: str,
|
||||
) -> str | None:
|
||||
"""
|
||||
We cache this so we don't have to query Salesforce for every query and salesforce
|
||||
user IDs never change.
|
||||
Memory usage is fine because we just store 2 small strings per user.
|
||||
|
||||
If the email is not in the cache, we check the local salesforce database for the info.
|
||||
If the user is not found in the local salesforce database, we query Salesforce.
|
||||
Whatever we get back from Salesforce is added to the database.
|
||||
If no user_id is found, we add a NULL_ID_STRING to the database for that email so
|
||||
we don't query Salesforce again (which is slow) but we still check the local salesforce
|
||||
database every query until a user id is found. This is acceptable because the query time
|
||||
is quite fast.
|
||||
If a user_id is created in Salesforce, it will be added to the local salesforce database
|
||||
next time the connector is run. Then that value will be found in this function and cached.
|
||||
|
||||
NOTE: First time this runs, it may be slow if it hasn't already been updated in the local
|
||||
salesforce database. (Around 0.1-0.3 seconds)
|
||||
If it's cached or stored in the local salesforce database, it's fast (<0.001 seconds).
|
||||
"""
|
||||
global _CACHED_SF_EMAIL_TO_ID_MAP
|
||||
if user_email in _CACHED_SF_EMAIL_TO_ID_MAP:
|
||||
if _CACHED_SF_EMAIL_TO_ID_MAP[user_email] is not None:
|
||||
return _CACHED_SF_EMAIL_TO_ID_MAP[user_email]
|
||||
|
||||
db_exists = True
|
||||
try:
|
||||
# Check if the user is already in the database
|
||||
user_id = get_user_id_by_email(user_email)
|
||||
except Exception:
|
||||
init_db()
|
||||
try:
|
||||
user_id = get_user_id_by_email(user_email)
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking if user is in database: {e}")
|
||||
user_id = None
|
||||
db_exists = False
|
||||
|
||||
# If no entry is found in the database (indicated by user_id being None)...
|
||||
if user_id is None:
|
||||
# ...query Salesforce and store the result in the database
|
||||
user_id = _query_salesforce_user_id(sf_client, user_email)
|
||||
if db_exists:
|
||||
update_email_to_id_table(user_email, user_id)
|
||||
return user_id
|
||||
elif user_id is None:
|
||||
return None
|
||||
elif user_id == NULL_ID_STRING:
|
||||
return None
|
||||
# If the found user_id is real, cache it
|
||||
_CACHED_SF_EMAIL_TO_ID_MAP[user_email] = user_id
|
||||
return user_id
|
||||
|
||||
|
||||
_MAX_RECORD_IDS_PER_QUERY = 200
|
||||
|
||||
|
||||
def get_objects_access_for_user_id(
|
||||
salesforce_client: Salesforce,
|
||||
user_id: str,
|
||||
record_ids: list[str],
|
||||
) -> dict[str, bool]:
|
||||
"""
|
||||
Salesforce has a limit of 200 record ids per query. So we just truncate
|
||||
the list of record ids to 200. We only ever retrieve 50 chunks at a time
|
||||
so this should be fine (unlikely that we retrieve all 50 chunks contain
|
||||
4 unique objects).
|
||||
If we decide this isn't acceptable we can use multiple queries but they
|
||||
should be in parallel so query time doesn't get too long.
|
||||
"""
|
||||
truncated_record_ids = record_ids[:_MAX_RECORD_IDS_PER_QUERY]
|
||||
record_ids_str = "'" + "','".join(truncated_record_ids) + "'"
|
||||
access_query = f"""
|
||||
SELECT RecordId, HasReadAccess
|
||||
FROM UserRecordAccess
|
||||
WHERE RecordId IN ({record_ids_str})
|
||||
AND UserId = '{user_id}'
|
||||
"""
|
||||
result = salesforce_client.query_all(access_query)
|
||||
return {record["RecordId"]: record["HasReadAccess"] for record in result["records"]}
|
||||
|
||||
|
||||
_CC_PAIR_ID_SALESFORCE_CLIENT_MAP: dict[int, Salesforce] = {}
|
||||
_DOC_ID_TO_CC_PAIR_ID_MAP: dict[str, int] = {}
|
||||
|
||||
|
||||
# NOTE: This is not used anywhere.
|
||||
def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Salesforce:
|
||||
"""
|
||||
Uses a document id to get the cc_pair that indexed that document and uses the credentials
|
||||
for that cc_pair to create a Salesforce client.
|
||||
Problems:
|
||||
- There may be multiple cc_pairs for a document, and we don't know which one to use.
|
||||
- right now we just use the first one
|
||||
- Building a new Salesforce client for each document is slow.
|
||||
- Memory usage could be an issue as we build these dictionaries.
|
||||
"""
|
||||
if doc_id not in _DOC_ID_TO_CC_PAIR_ID_MAP:
|
||||
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
|
||||
first_cc_pair = cc_pairs[0]
|
||||
_DOC_ID_TO_CC_PAIR_ID_MAP[doc_id] = first_cc_pair.id
|
||||
|
||||
cc_pair_id = _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id]
|
||||
if cc_pair_id not in _CC_PAIR_ID_SALESFORCE_CLIENT_MAP:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if cc_pair is None:
|
||||
raise ValueError(f"CC pair {cc_pair_id} not found")
|
||||
credential_json = cc_pair.credential.credential_json
|
||||
_CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id] = Salesforce(
|
||||
username=credential_json["sf_username"],
|
||||
password=credential_json["sf_password"],
|
||||
security_token=credential_json["sf_security_token"],
|
||||
)
|
||||
|
||||
return _CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id]
|
||||
@@ -8,6 +8,9 @@ from ee.onyx.external_permissions.confluence.group_sync import confluence_group_
|
||||
from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync
|
||||
from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
|
||||
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
|
||||
from ee.onyx.external_permissions.post_query_censoring import (
|
||||
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
|
||||
)
|
||||
from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -71,4 +74,7 @@ EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
|
||||
|
||||
def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
|
||||
return source_type in DOC_PERMISSIONS_FUNC_MAP
|
||||
return (
|
||||
source_type in DOC_PERMISSIONS_FUNC_MAP
|
||||
or source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
|
||||
)
|
||||
|
||||
@@ -40,6 +40,7 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.main import get_application as get_application_base
|
||||
from onyx.main import include_auth_router_with_prefix
|
||||
from onyx.main import include_router_with_global_prefix_prepended
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
@@ -62,7 +63,7 @@ def get_application() -> FastAPI:
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
oauth_client,
|
||||
@@ -74,19 +75,17 @@ def get_application() -> FastAPI:
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
|
||||
),
|
||||
prefix="/auth/oauth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# Need basic auth router for `logout` endpoint
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_logout_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.OIDC:
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
|
||||
@@ -97,19 +96,20 @@ def get_application() -> FastAPI:
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
|
||||
),
|
||||
prefix="/auth/oidc",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# need basic auth router for `logout` endpoint
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_auth_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
elif AUTH_TYPE == AuthType.SAML:
|
||||
include_router_with_global_prefix_prepended(application, saml_router)
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
saml_router,
|
||||
)
|
||||
|
||||
# RBAC / group access control
|
||||
include_router_with_global_prefix_prepended(application, user_group_router)
|
||||
|
||||
@@ -1,17 +1,24 @@
|
||||
import datetime
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.analytics import fetch_assistant_message_analytics
|
||||
from ee.onyx.db.analytics import fetch_assistant_unique_users
|
||||
from ee.onyx.db.analytics import fetch_assistant_unique_users_total
|
||||
from ee.onyx.db.analytics import fetch_onyxbot_analytics
|
||||
from ee.onyx.db.analytics import fetch_per_user_query_analytics
|
||||
from ee.onyx.db.analytics import fetch_persona_message_analytics
|
||||
from ee.onyx.db.analytics import fetch_persona_unique_users
|
||||
from ee.onyx.db.analytics import fetch_query_analytics
|
||||
from ee.onyx.db.analytics import user_can_view_assistant_stats
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
|
||||
@@ -191,3 +198,74 @@ def get_persona_unique_users(
|
||||
)
|
||||
)
|
||||
return unique_user_counts
|
||||
|
||||
|
||||
class AssistantDailyUsageResponse(BaseModel):
|
||||
date: datetime.date
|
||||
total_messages: int
|
||||
total_unique_users: int
|
||||
|
||||
|
||||
class AssistantStatsResponse(BaseModel):
|
||||
daily_stats: List[AssistantDailyUsageResponse]
|
||||
total_messages: int
|
||||
total_unique_users: int
|
||||
|
||||
|
||||
@router.get("/assistant/{assistant_id}/stats")
|
||||
def get_assistant_stats(
|
||||
assistant_id: int,
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AssistantStatsResponse:
|
||||
"""
|
||||
Returns daily message and unique user counts for a user's assistant,
|
||||
along with the overall total messages and total distinct users.
|
||||
"""
|
||||
start = start or (
|
||||
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
|
||||
)
|
||||
end = end or datetime.datetime.utcnow()
|
||||
|
||||
if not user_can_view_assistant_stats(db_session, user, assistant_id):
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Not allowed to access this assistant's stats."
|
||||
)
|
||||
|
||||
# Pull daily usage from the DB calls
|
||||
messages_data = fetch_assistant_message_analytics(
|
||||
db_session, assistant_id, start, end
|
||||
)
|
||||
unique_users_data = fetch_assistant_unique_users(
|
||||
db_session, assistant_id, start, end
|
||||
)
|
||||
|
||||
# Map each day => (messages, unique_users).
|
||||
daily_messages_map = {date: count for count, date in messages_data}
|
||||
daily_unique_users_map = {date: count for count, date in unique_users_data}
|
||||
all_dates = set(daily_messages_map.keys()) | set(daily_unique_users_map.keys())
|
||||
|
||||
# Merge both sets of metrics by date
|
||||
daily_results: list[AssistantDailyUsageResponse] = []
|
||||
for date in sorted(all_dates):
|
||||
daily_results.append(
|
||||
AssistantDailyUsageResponse(
|
||||
date=date,
|
||||
total_messages=daily_messages_map.get(date, 0),
|
||||
total_unique_users=daily_unique_users_map.get(date, 0),
|
||||
)
|
||||
)
|
||||
|
||||
# Now pull a single total distinct user count across the entire time range
|
||||
total_msgs = sum(d.total_messages for d in daily_results)
|
||||
total_users = fetch_assistant_unique_users_total(
|
||||
db_session, assistant_id, start, end
|
||||
)
|
||||
|
||||
return AssistantStatsResponse(
|
||||
daily_stats=daily_results,
|
||||
total_messages=total_msgs,
|
||||
total_unique_users=total_users,
|
||||
)
|
||||
|
||||
@@ -2,15 +2,16 @@ import logging
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Callable
|
||||
|
||||
import jwt
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
|
||||
from ee.onyx.auth.users import decode_anonymous_user_jwt_token
|
||||
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
|
||||
from onyx.auth.api_key import extract_tenant_from_api_key_header
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.db.engine import is_valid_schema_name
|
||||
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
@@ -22,11 +23,11 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
||||
request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||
) -> Response:
|
||||
try:
|
||||
tenant_id = (
|
||||
_get_tenant_id_from_request(request, logger)
|
||||
if MULTI_TENANT
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
if MULTI_TENANT:
|
||||
tenant_id = await _get_tenant_id_from_request(request, logger)
|
||||
else:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
return await call_next(request)
|
||||
|
||||
@@ -35,27 +36,46 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
||||
raise
|
||||
|
||||
|
||||
def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) -> str:
|
||||
# First check for API key
|
||||
async def _get_tenant_id_from_request(
|
||||
request: Request, logger: logging.LoggerAdapter
|
||||
) -> str:
|
||||
"""
|
||||
Attempt to extract tenant_id from:
|
||||
1) The API key header
|
||||
2) The Redis-based token (stored in Cookie: fastapiusersauth)
|
||||
Fallback: POSTGRES_DEFAULT_SCHEMA
|
||||
"""
|
||||
# Check for API key
|
||||
tenant_id = extract_tenant_from_api_key_header(request)
|
||||
if tenant_id is not None:
|
||||
if tenant_id:
|
||||
return tenant_id
|
||||
|
||||
# Check for cookie-based auth
|
||||
token = request.cookies.get("fastapiusersauth")
|
||||
if not token:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
# Check for anonymous user cookie
|
||||
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
|
||||
if anonymous_user_cookie:
|
||||
try:
|
||||
anonymous_user_data = decode_anonymous_user_jwt_token(anonymous_user_cookie)
|
||||
return anonymous_user_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
except Exception as e:
|
||||
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
|
||||
# Continue and attempt to authenticate
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
USER_AUTH_SECRET,
|
||||
audience=["fastapi-users:auth"],
|
||||
algorithms=["HS256"],
|
||||
)
|
||||
tenant_id_from_payload = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
# Look up token data in Redis
|
||||
token_data = await retrieve_auth_token_data_from_redis(request)
|
||||
|
||||
# Since payload.get() can return None, ensure we have a string
|
||||
if not token_data:
|
||||
logger.debug(
|
||||
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
|
||||
)
|
||||
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
|
||||
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
|
||||
# so we maintain consistency by returning it here when no valid tenant is found.
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
# Since token_data.get() can return None, ensure we have a string
|
||||
tenant_id = (
|
||||
str(tenant_id_from_payload)
|
||||
if tenant_id_from_payload is not None
|
||||
@@ -67,9 +87,6 @@ def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter)
|
||||
|
||||
return tenant_id
|
||||
|
||||
except jwt.InvalidTokenError:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in set_tenant_id_middleware: {str(e)}")
|
||||
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
@@ -10,11 +12,29 @@ from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
|
||||
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
@@ -62,14 +82,7 @@ class SlackOAuth:
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
url = (
|
||||
f"https://slack.com/oauth/v2/authorize"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={cls.REDIRECT_URI}"
|
||||
f"&scope={cls.BOT_SCOPE}"
|
||||
f"&state={state}"
|
||||
)
|
||||
return url
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
@@ -77,10 +90,14 @@ class SlackOAuth:
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
url = (
|
||||
f"https://slack.com/oauth/v2/authorize"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={cls.DEV_REDIRECT_URI}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&scope={cls.BOT_SCOPE}"
|
||||
f"&state={state}"
|
||||
)
|
||||
@@ -102,82 +119,151 @@ class SlackOAuth:
|
||||
return session
|
||||
|
||||
|
||||
# Work in progress
|
||||
# class ConfluenceCloudOAuth:
|
||||
# """work in progress"""
|
||||
class ConfluenceCloudOAuth:
|
||||
"""work in progress"""
|
||||
|
||||
# # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
|
||||
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
|
||||
|
||||
# class OAuthSession(BaseModel):
|
||||
# """Stored in redis to be looked up on callback"""
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
# email: str
|
||||
# redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
# CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
|
||||
# CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
# TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
|
||||
# # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
|
||||
# CONFLUENCE_OAUTH_SCOPE = (
|
||||
# "read:confluence-props%20"
|
||||
# "read:confluence-content.all%20"
|
||||
# "read:confluence-content.summary%20"
|
||||
# "read:confluence-content.permission%20"
|
||||
# "read:confluence-user%20"
|
||||
# "read:confluence-groups%20"
|
||||
# "readonly:content.attachment:confluence"
|
||||
# )
|
||||
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
|
||||
CONFLUENCE_OAUTH_SCOPE = (
|
||||
"read:confluence-props%20"
|
||||
"read:confluence-content.all%20"
|
||||
"read:confluence-content.summary%20"
|
||||
"read:confluence-content.permission%20"
|
||||
"read:confluence-user%20"
|
||||
"read:confluence-groups%20"
|
||||
"readonly:content.attachment:confluence"
|
||||
)
|
||||
|
||||
# REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
|
||||
# DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
# # eventually for Confluence Data Center
|
||||
# # oauth_url = (
|
||||
# # f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
|
||||
# # f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
|
||||
# # f"&redirect_uri={redirectme_uri}"
|
||||
# # )
|
||||
# eventually for Confluence Data Center
|
||||
# oauth_url = (
|
||||
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
|
||||
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
|
||||
# f"&redirect_uri={redirectme_uri}"
|
||||
# )
|
||||
|
||||
# @classmethod
|
||||
# def generate_oauth_url(cls, state: str) -> str:
|
||||
# return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
# @classmethod
|
||||
# def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
# """dev mode workaround for localhost testing
|
||||
# - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
# """
|
||||
# return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
# @classmethod
|
||||
# def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
# url = (
|
||||
# "https://auth.atlassian.com/authorize"
|
||||
# f"?audience=api.atlassian.com"
|
||||
# f"&client_id={cls.CLIENT_ID}"
|
||||
# f"&redirect_uri={redirect_uri}"
|
||||
# f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
|
||||
# f"&state={state}"
|
||||
# "&response_type=code"
|
||||
# "&prompt=consent"
|
||||
# )
|
||||
# return url
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
url = (
|
||||
"https://auth.atlassian.com/authorize"
|
||||
f"?audience=api.atlassian.com"
|
||||
f"&client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
|
||||
f"&state={state}"
|
||||
"&response_type=code"
|
||||
"&prompt=consent"
|
||||
)
|
||||
return url
|
||||
|
||||
# @classmethod
|
||||
# def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
# """Temporary state to store in redis. to be looked up on auth response.
|
||||
# Returns a json string.
|
||||
# """
|
||||
# session = ConfluenceCloudOAuth.OAuthSession(
|
||||
# email=email, redirect_on_success=redirect_on_success
|
||||
# )
|
||||
# return session.model_dump_json()
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = ConfluenceCloudOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
# @classmethod
|
||||
# def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
|
||||
# session = SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
# return session
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
|
||||
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
class GoogleDriveOAuth:
|
||||
# https://developers.google.com/identity/protocols/oauth2
|
||||
# https://developers.google.com/identity/protocols/oauth2/web-server
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
|
||||
TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
# SCOPE is per https://docs.onyx.app/connectors/google-drive
|
||||
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
|
||||
SCOPE = (
|
||||
"https://www.googleapis.com/auth/drive.readonly%20"
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
|
||||
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
|
||||
"https://www.googleapis.com/auth/admin.directory.group.readonly"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
# without prompt=consent, a refresh token is only issued the first time the user approves
|
||||
url = (
|
||||
f"https://accounts.google.com/o/oauth2/v2/auth"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
"&response_type=code"
|
||||
f"&scope={cls.SCOPE}"
|
||||
"&access_type=offline"
|
||||
f"&state={state}"
|
||||
"&prompt=consent"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = GoogleDriveOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> OAuthSession:
|
||||
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
@router.post("/prepare-authorization-request")
|
||||
@@ -192,8 +278,11 @@ def prepare_authorization_request(
|
||||
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
|
||||
"""
|
||||
|
||||
# create random oauth state param for security and to retrieve user data later
|
||||
oauth_uuid = uuid.uuid4()
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
# urlsafe b64 encode the uuid for the oauth url
|
||||
oauth_state = (
|
||||
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
@@ -203,6 +292,11 @@ def prepare_authorization_request(
|
||||
session = SlackOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
elif connector == DocumentSource.GOOGLE_DRIVE:
|
||||
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
|
||||
session = GoogleDriveOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
# elif connector == DocumentSource.CONFLUENCE:
|
||||
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
|
||||
# session = ConfluenceCloudOAuth.session_dump_json(
|
||||
@@ -210,8 +304,6 @@ def prepare_authorization_request(
|
||||
# )
|
||||
# elif connector == DocumentSource.JIRA:
|
||||
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
|
||||
# elif connector == DocumentSource.GOOGLE_DRIVE:
|
||||
# oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
|
||||
else:
|
||||
oauth_url = None
|
||||
|
||||
@@ -223,6 +315,7 @@ def prepare_authorization_request(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# store important session state to retrieve when the user is redirected back
|
||||
# 10 min is the max we want an oauth flow to be valid
|
||||
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
|
||||
|
||||
@@ -421,3 +514,116 @@ def handle_slack_oauth_callback(
|
||||
# "redirect_on_success": session.redirect_on_success,
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
@router.post("/connector/google-drive/callback")
|
||||
def handle_google_drive_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Google Drive client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
try:
|
||||
session = GoogleDriveOAuth.parse_session(session_json)
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
GoogleDriveOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": GoogleDriveOAuth.CLIENT_ID,
|
||||
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
authorization_response: dict[str, Any] = response.json()
|
||||
|
||||
# the connector wants us to store the json in its authorized_user_info format
|
||||
# returned from OAuthCredentials.get_authorized_user_info().
|
||||
# So refresh immediately via get_google_oauth_creds with the params filled in
|
||||
# from fields in authorization_response to get the json we need
|
||||
authorized_user_info = {}
|
||||
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
|
||||
|
||||
token_json_str = json.dumps(authorized_user_info)
|
||||
oauth_creds = get_google_oauth_creds(
|
||||
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
|
||||
)
|
||||
if not oauth_creds:
|
||||
raise RuntimeError("get_google_oauth_creds returned None.")
|
||||
|
||||
# save off the credentials
|
||||
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
|
||||
|
||||
credential_dict: dict[str, str] = {}
|
||||
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
|
||||
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
|
||||
credential_dict[
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD
|
||||
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
name="OAuth (interactive)",
|
||||
)
|
||||
|
||||
create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Google Drive OAuth completed successfully.",
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -13,9 +13,8 @@ from ee.onyx.db.usage_export import get_all_empty_chat_message_entries
|
||||
from ee.onyx.db.usage_export import write_usage_report
|
||||
from ee.onyx.server.reporting.usage_export_models import UsageReportMetadata
|
||||
from ee.onyx.server.reporting.usage_export_models import UserSkeleton
|
||||
from onyx.auth.schemas import UserStatus
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.users import list_users
|
||||
from onyx.db.users import get_all_users
|
||||
from onyx.file_store.constants import MAX_IN_MEMORY_SIZE
|
||||
from onyx.file_store.file_store import FileStore
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -84,15 +83,15 @@ def generate_user_report(
|
||||
max_size=MAX_IN_MEMORY_SIZE, mode="w+"
|
||||
) as temp_file:
|
||||
csvwriter = csv.writer(temp_file, delimiter=",")
|
||||
csvwriter.writerow(["user_id", "status"])
|
||||
csvwriter.writerow(["user_id", "is_active"])
|
||||
|
||||
users = list_users(db_session)
|
||||
users = get_all_users(db_session)
|
||||
for user in users:
|
||||
user_skeleton = UserSkeleton(
|
||||
user_id=str(user.id),
|
||||
status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED,
|
||||
is_active=user.is_active,
|
||||
)
|
||||
csvwriter.writerow([user_skeleton.user_id, user_skeleton.status])
|
||||
csvwriter.writerow([user_skeleton.user_id, user_skeleton.is_active])
|
||||
|
||||
temp_file.seek(0)
|
||||
file_store.save_file(
|
||||
|
||||
@@ -4,8 +4,6 @@ from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.schemas import UserStatus
|
||||
|
||||
|
||||
class FlowType(str, Enum):
|
||||
CHAT = "chat"
|
||||
@@ -22,7 +20,7 @@ class ChatMessageSkeleton(BaseModel):
|
||||
|
||||
class UserSkeleton(BaseModel):
|
||||
user_id: str
|
||||
status: UserStatus
|
||||
is_active: bool
|
||||
|
||||
|
||||
class UsageReportMetadata(BaseModel):
|
||||
|
||||
59
backend/ee/onyx/server/tenants/anonymous_user_path.py
Normal file
59
backend/ee/onyx/server/tenants/anonymous_user_path.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import TenantAnonymousUserPath
|
||||
|
||||
|
||||
def get_anonymous_user_path(tenant_id: str, db_session: Session) -> str | None:
|
||||
result = db_session.execute(
|
||||
select(TenantAnonymousUserPath).where(
|
||||
TenantAnonymousUserPath.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
result_scalar = result.scalar_one_or_none()
|
||||
if result_scalar:
|
||||
return result_scalar.anonymous_user_path
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def modify_anonymous_user_path(
|
||||
tenant_id: str, anonymous_user_path: str, db_session: Session
|
||||
) -> None:
|
||||
# Enforce lowercase path at DB operation level
|
||||
anonymous_user_path = anonymous_user_path.lower()
|
||||
|
||||
existing_entry = (
|
||||
db_session.query(TenantAnonymousUserPath).filter_by(tenant_id=tenant_id).first()
|
||||
)
|
||||
|
||||
if existing_entry:
|
||||
existing_entry.anonymous_user_path = anonymous_user_path
|
||||
|
||||
else:
|
||||
new_entry = TenantAnonymousUserPath(
|
||||
tenant_id=tenant_id, anonymous_user_path=anonymous_user_path
|
||||
)
|
||||
db_session.add(new_entry)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_tenant_id_for_anonymous_user_path(
|
||||
anonymous_user_path: str, db_session: Session
|
||||
) -> str | None:
|
||||
result = db_session.execute(
|
||||
select(TenantAnonymousUserPath).where(
|
||||
TenantAnonymousUserPath.anonymous_user_path == anonymous_user_path
|
||||
)
|
||||
)
|
||||
result_scalar = result.scalar_one_or_none()
|
||||
if result_scalar:
|
||||
return result_scalar.tenant_id
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def validate_anonymous_user_path(path: str) -> None:
|
||||
if not path or "/" in path or not path.replace("-", "").isalnum():
|
||||
raise ValueError("Invalid path. Use only letters, numbers, and hyphens.")
|
||||
@@ -3,35 +3,124 @@ from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_cloud_superuser
|
||||
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
|
||||
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import (
|
||||
get_tenant_id_for_anonymous_user_path,
|
||||
)
|
||||
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import AnonymousUserPath
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import ImpersonateRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_jwt_strategy
|
||||
from onyx.auth.users import get_redis_strategy
|
||||
from onyx.auth.users import optional_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.settings.store import store_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
|
||||
logger = setup_logger()
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.get("/anonymous-user-path")
|
||||
async def get_anonymous_user_path_api(
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> AnonymousUserPath:
|
||||
if tenant_id is None:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
current_path = get_anonymous_user_path(tenant_id, db_session)
|
||||
|
||||
return AnonymousUserPath(anonymous_user_path=current_path)
|
||||
|
||||
|
||||
@router.post("/anonymous-user-path")
|
||||
async def set_anonymous_user_path_api(
|
||||
anonymous_user_path: str,
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
try:
|
||||
validate_anonymous_user_path(anonymous_user_path)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
try:
|
||||
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
|
||||
except IntegrityError:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="The anonymous user path is already in use. Please choose a different path.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while modifying the anonymous user path",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/anonymous-user")
|
||||
async def login_as_anonymous_user(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(optional_user),
|
||||
) -> Response:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
tenant_id = get_tenant_id_for_anonymous_user_path(
|
||||
anonymous_user_path, db_session
|
||||
)
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
if not anonymous_user_enabled(tenant_id=tenant_id):
|
||||
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
|
||||
|
||||
token = generate_anonymous_user_jwt_token(tenant_id)
|
||||
|
||||
response = Response()
|
||||
response.set_cookie(
|
||||
key=ANONYMOUS_USER_COOKIE_NAME,
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="strict",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
|
||||
@@ -103,7 +192,7 @@ async def impersonate_user(
|
||||
)
|
||||
if user_to_impersonate is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
token = await get_jwt_strategy().write_token(user_to_impersonate)
|
||||
token = await get_redis_strategy().write_token(user_to_impersonate)
|
||||
|
||||
response = await auth_backend.transport.get_login_response(token)
|
||||
response.set_cookie(
|
||||
@@ -114,3 +203,48 @@ async def impersonate_user(
|
||||
samesite="lax",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/leave-organization")
|
||||
async def leave_organization(
|
||||
user_email: UserByEmail,
|
||||
current_user: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
) -> None:
|
||||
if current_user is None or current_user.email != user_email.user_email:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You can only leave the organization as yourself"
|
||||
)
|
||||
|
||||
user_to_delete = get_user_by_email(user_email.user_email, db_session)
|
||||
if user_to_delete is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
num_admin_users = await get_user_count(only_admin_users=True)
|
||||
|
||||
should_delete_tenant = num_admin_users == 1
|
||||
|
||||
if should_delete_tenant:
|
||||
logger.info(
|
||||
"Last admin user is leaving the organization. Deleting tenant from control plane."
|
||||
)
|
||||
try:
|
||||
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
|
||||
logger.debug("User deleted from control plane")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to remove user from control plane: {str(e)}",
|
||||
)
|
||||
|
||||
db_session.expunge(user_to_delete)
|
||||
delete_user_from_db(user_to_delete, db_session)
|
||||
|
||||
if should_delete_tenant:
|
||||
remove_all_users_from_tenant(tenant_id)
|
||||
else:
|
||||
remove_users_from_tenant([user_to_delete.email], tenant_id)
|
||||
|
||||
@@ -46,6 +46,7 @@ def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscr
|
||||
"""
|
||||
Send a request to the control service to register the number of users for a tenant.
|
||||
"""
|
||||
|
||||
if not STRIPE_PRICE_ID:
|
||||
raise Exception("STRIPE_PRICE_ID is not set")
|
||||
|
||||
|
||||
@@ -39,3 +39,12 @@ class TenantCreationPayload(BaseModel):
|
||||
tenant_id: str
|
||||
email: str
|
||||
referral_source: str | None = None
|
||||
|
||||
|
||||
class TenantDeletionPayload(BaseModel):
|
||||
tenant_id: str
|
||||
email: str
|
||||
|
||||
|
||||
class AnonymousUserPath(BaseModel):
|
||||
anonymous_user_path: str | None
|
||||
|
||||
@@ -15,6 +15,7 @@ from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
|
||||
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import TenantCreationPayload
|
||||
from ee.onyx.server.tenants.models import TenantDeletionPayload
|
||||
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
|
||||
from ee.onyx.server.tenants.schema_management import drop_schema
|
||||
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
|
||||
@@ -185,6 +186,7 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
|
||||
try:
|
||||
# Drop the tenant's schema to rollback provisioning
|
||||
drop_schema(tenant_id)
|
||||
|
||||
# Remove tenant mapping
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
@@ -320,3 +322,26 @@ async def submit_to_hubspot(
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to submit to HubSpot: {response.text}")
|
||||
|
||||
|
||||
async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = TenantDeletionPayload(tenant_id=tenant_id, email=email)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.delete(
|
||||
f"{CONTROL_PLANE_API_BASE_URL}/tenants/delete",
|
||||
headers=headers,
|
||||
json=payload.model_dump(),
|
||||
) as response:
|
||||
print(response)
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Control plane tenant creation failed: {error_text}")
|
||||
raise Exception(
|
||||
f"Failed to delete tenant on control plane: {error_text}"
|
||||
)
|
||||
|
||||
@@ -68,3 +68,11 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
|
||||
)
|
||||
db_session.rollback()
|
||||
|
||||
|
||||
def remove_all_users_from_tenant(tenant_id: str) -> None:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
@@ -83,7 +83,7 @@ def patch_user_group(
|
||||
def set_user_curator(
|
||||
user_group_id: int,
|
||||
set_curator_request: SetCuratorRequest,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
@@ -91,6 +91,7 @@ def set_user_curator(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
set_curator_request=set_curator_request,
|
||||
user_making_change=user,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Error setting user curator: {e}")
|
||||
|
||||
@@ -10,6 +10,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def posthog_on_error(error: Any, items: Any) -> None:
|
||||
"""Log any PostHog delivery errors."""
|
||||
logger.error(f"PostHog error: {error}, items: {items}")
|
||||
|
||||
|
||||
@@ -24,15 +25,10 @@ posthog = Posthog(
|
||||
def event_telemetry(
|
||||
distinct_id: str, event: str, properties: dict | None = None
|
||||
) -> None:
|
||||
logger.info(f"Capturing Posthog event: {distinct_id} {event} {properties}")
|
||||
print("API KEY", POSTHOG_API_KEY)
|
||||
print("HOST", POSTHOG_HOST)
|
||||
"""Capture and send an event to PostHog, flushing immediately."""
|
||||
logger.info(f"Capturing PostHog event: {distinct_id} {event} {properties}")
|
||||
try:
|
||||
print(type(distinct_id))
|
||||
print(type(event))
|
||||
print(type(properties))
|
||||
response = posthog.capture(distinct_id, event, properties)
|
||||
posthog.capture(distinct_id, event, properties)
|
||||
posthog.flush()
|
||||
print(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Error capturing Posthog event: {e}")
|
||||
logger.error(f"Error capturing PostHog event: {e}")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from types import TracebackType
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
@@ -320,8 +321,6 @@ async def embed_text(
|
||||
api_url: str | None,
|
||||
api_version: str | None,
|
||||
) -> list[Embedding]:
|
||||
logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}")
|
||||
|
||||
if not all(texts):
|
||||
logger.error("Empty strings provided for embedding")
|
||||
raise ValueError("Empty strings are not allowed for embedding.")
|
||||
@@ -330,8 +329,17 @@ async def embed_text(
|
||||
logger.error("No texts provided for embedding")
|
||||
raise ValueError("No texts provided for embedding.")
|
||||
|
||||
start = time.monotonic()
|
||||
|
||||
total_chars = 0
|
||||
for text in texts:
|
||||
total_chars += len(text)
|
||||
|
||||
if provider_type is not None:
|
||||
logger.debug(f"Using cloud provider {provider_type} for embedding")
|
||||
logger.info(
|
||||
f"Embedding {len(texts)} texts with {total_chars} total characters with provider: {provider_type}"
|
||||
)
|
||||
|
||||
if api_key is None:
|
||||
logger.error("API key not provided for cloud model")
|
||||
raise RuntimeError("API key not provided for cloud model")
|
||||
@@ -363,8 +371,16 @@ async def embed_text(
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
logger.info(
|
||||
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
|
||||
f"with provider {provider_type} in {elapsed:.2f}"
|
||||
)
|
||||
elif model_name is not None:
|
||||
logger.debug(f"Using local model {model_name} for embedding")
|
||||
logger.info(
|
||||
f"Embedding {len(texts)} texts with {total_chars} total characters with local model: {model_name}"
|
||||
)
|
||||
|
||||
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
|
||||
|
||||
local_model = get_embedding_model(
|
||||
@@ -382,13 +398,17 @@ async def embed_text(
|
||||
for embedding in embeddings_vectors
|
||||
]
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
logger.info(
|
||||
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
|
||||
f"with local model {model_name} in {elapsed:.2f}"
|
||||
)
|
||||
else:
|
||||
logger.error("Neither model name nor provider specified for embedding")
|
||||
raise ValueError(
|
||||
"Either model name or provider must be provided to run embeddings."
|
||||
)
|
||||
|
||||
logger.info(f"Successfully embedded {len(texts)} texts")
|
||||
return embeddings
|
||||
|
||||
|
||||
@@ -440,7 +460,8 @@ async def process_embed_request(
|
||||
) -> EmbedResponse:
|
||||
if not embed_request.texts:
|
||||
raise HTTPException(status_code=400, detail="No texts to be embedded")
|
||||
elif not all(embed_request.texts):
|
||||
|
||||
if not all(embed_request.texts):
|
||||
raise ValueError("Empty strings are not allowed for embedding.")
|
||||
|
||||
try:
|
||||
@@ -471,9 +492,12 @@ async def process_embed_request(
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
exception_detail = f"Error during embedding process:\n{str(e)}"
|
||||
logger.exception(exception_detail)
|
||||
raise HTTPException(status_code=500, detail=exception_detail)
|
||||
logger.exception(
|
||||
f"Error during embedding process: provider={embed_request.provider_type} model={embed_request.model_name}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error during embedding process: {e}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/cross-encoder-scores")
|
||||
|
||||
@@ -44,6 +44,7 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
|
||||
the files in the existing huggingface cache that don't exist in the temp
|
||||
huggingface cache.
|
||||
"""
|
||||
|
||||
for item in source.iterdir():
|
||||
target_path = dest / item.relative_to(source)
|
||||
if item.is_dir():
|
||||
|
||||
83
backend/onyx/auth/email_utils.py
Normal file
83
backend/onyx/auth/email_utils.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import smtplib
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from textwrap import dedent
|
||||
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import EMAIL_FROM
|
||||
from onyx.configs.app_configs import SMTP_PASS
|
||||
from onyx.configs.app_configs import SMTP_PORT
|
||||
from onyx.configs.app_configs import SMTP_SERVER
|
||||
from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
def send_email(
|
||||
user_email: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
if not EMAIL_CONFIGURED:
|
||||
raise ValueError("Email is not configured.")
|
||||
|
||||
msg = MIMEMultipart()
|
||||
msg["Subject"] = subject
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
|
||||
msg.attach(MIMEText(body))
|
||||
|
||||
try:
|
||||
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
|
||||
s.starttls()
|
||||
s.login(SMTP_USER, SMTP_PASS)
|
||||
s.send_message(msg)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def send_user_email_invite(user_email: str, current_user: User) -> None:
|
||||
subject = "Invitation to Join Onyx Organization"
|
||||
body = dedent(
|
||||
f"""\
|
||||
Hello,
|
||||
|
||||
You have been invited to join an organization on Onyx.
|
||||
|
||||
To join the organization, please visit the following link:
|
||||
|
||||
{WEB_DOMAIN}/auth/signup?email={user_email}
|
||||
|
||||
You'll be asked to set a password or login with Google to complete your registration.
|
||||
|
||||
Best regards,
|
||||
The Onyx Team
|
||||
"""
|
||||
)
|
||||
|
||||
send_email(user_email, subject, body, current_user.email)
|
||||
|
||||
|
||||
def send_forgot_password_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
subject = "Onyx Forgot Password"
|
||||
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
|
||||
body = f"Click the following link to reset your password: {link}"
|
||||
send_email(user_email, subject, body, mail_from)
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
subject = "Onyx Email Verification"
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
body = f"Click the following link to verify your email address: {link}"
|
||||
send_email(user_email, subject, body, mail_from)
|
||||
@@ -30,13 +30,16 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
|
||||
)
|
||||
|
||||
|
||||
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
|
||||
def fetch_no_auth_user(
|
||||
store: KeyValueStore, *, anonymous_user_enabled: bool | None = None
|
||||
) -> UserInfo:
|
||||
return UserInfo(
|
||||
id=NO_AUTH_USER_ID,
|
||||
email=NO_AUTH_USER_EMAIL,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=UserRole.ADMIN,
|
||||
role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
|
||||
preferences=load_no_auth_user_preferences(store),
|
||||
is_anonymous_user=anonymous_user_enabled,
|
||||
)
|
||||
|
||||
@@ -33,12 +33,6 @@ class UserRole(str, Enum):
|
||||
]
|
||||
|
||||
|
||||
class UserStatus(str, Enum):
|
||||
LIVE = "live"
|
||||
INVITED = "invited"
|
||||
DEACTIVATED = "deactivated"
|
||||
|
||||
|
||||
class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
role: UserRole
|
||||
|
||||
@@ -49,4 +43,7 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
"""
|
||||
Role updates are not allowed through the user update endpoint for security reasons
|
||||
Role changes should be handled through a separate, admin-only process
|
||||
"""
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import smtplib
|
||||
import json
|
||||
import secrets
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
@@ -32,10 +31,8 @@ from fastapi_users import schemas
|
||||
from fastapi_users import UUIDIDMixin
|
||||
from fastapi_users.authentication import AuthenticationBackend
|
||||
from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.authentication import JWTStrategy
|
||||
from fastapi_users.authentication import RedisStrategy
|
||||
from fastapi_users.authentication import Strategy
|
||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
from fastapi_users.exceptions import UserAlreadyExists
|
||||
from fastapi_users.jwt import decode_jwt
|
||||
from fastapi_users.jwt import generate_jwt
|
||||
@@ -49,23 +46,22 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.email_utils import send_forgot_password_email
|
||||
from onyx.auth.email_utils import send_user_verification_email
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.schemas import UserUpdate
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.app_configs import EMAIL_FROM
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import REDIS_AUTH_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import SMTP_PASS
|
||||
from onyx.configs.app_configs import SMTP_PORT
|
||||
from onyx.configs.app_configs import SMTP_SERVER
|
||||
from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
@@ -74,22 +70,23 @@ from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
|
||||
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.api_key import fetch_user_for_api_key
|
||||
from onyx.db.auth import get_access_token_db
|
||||
from onyx.db.auth import get_default_admin_user_emails
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.auth import get_user_db
|
||||
from onyx.db.auth import SQLAlchemyUserAdminDB
|
||||
from onyx.db.engine import get_async_session
|
||||
from onyx.db.engine import get_async_session_with_tenant
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
@@ -103,6 +100,11 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
@@ -143,6 +145,17 @@ def user_needs_to_be_verified() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
assert isinstance(value, bytes)
|
||||
return int(value.decode("utf-8")) == 1
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
whitelist = get_invited_users()
|
||||
if not whitelist:
|
||||
@@ -193,30 +206,6 @@ def verify_email_domain(email: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
msg = MIMEMultipart()
|
||||
msg["Subject"] = "Onyx Email Verification"
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
|
||||
body = MIMEText(f"Click the following link to verify your email address: {link}")
|
||||
msg.attach(body)
|
||||
|
||||
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
|
||||
s.starttls()
|
||||
# If credentials fails with gmail, check (You need an app password, not just the basic email password)
|
||||
# https://support.google.com/accounts/answer/185833?sjid=8512343437447396151-NA
|
||||
s.login(SMTP_USER, SMTP_PASS)
|
||||
s.send_message(msg)
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
@@ -281,7 +270,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
if not user.role.is_web_login() and user_create.role.is_web_login():
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
@@ -405,11 +393,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# Explicitly set the Postgres schema for this session to ensure
|
||||
# OAuth account creation happens in the correct tenant schema
|
||||
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
|
||||
# Add OAuth account
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
|
||||
await self.on_after_register(user, request)
|
||||
|
||||
else:
|
||||
@@ -428,7 +414,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||
# re-authenticate that frequently, so by default this is disabled
|
||||
|
||||
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
|
||||
await self.user_db.update(
|
||||
@@ -506,7 +491,15 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async def on_after_forgot_password(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
) -> None:
|
||||
logger.notice(f"User {user.id} has forgot their password. Reset token: {token}")
|
||||
if not EMAIL_CONFIGURED:
|
||||
logger.error(
|
||||
"Email is not configured. Please configure email in the admin panel"
|
||||
)
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
"Your admin has not enbaled this feature.",
|
||||
)
|
||||
send_forgot_password_email(user.email, token)
|
||||
|
||||
async def on_after_request_verify(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
@@ -583,51 +576,70 @@ cookie_transport = CookieTransport(
|
||||
)
|
||||
|
||||
|
||||
# This strategy is used to add tenant_id to the JWT token
|
||||
class TenantAwareJWTStrategy(JWTStrategy):
|
||||
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
|
||||
def get_redis_strategy() -> RedisStrategy:
|
||||
return TenantAwareRedisStrategy()
|
||||
|
||||
|
||||
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
"""
|
||||
A custom strategy that fetches the actual async Redis connection inside each method.
|
||||
We do NOT pass a synchronous or "coroutine" redis object to the constructor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lifetime_seconds: Optional[int] = REDIS_AUTH_EXPIRE_TIME_SECONDS,
|
||||
key_prefix: str = REDIS_AUTH_KEY_PREFIX,
|
||||
):
|
||||
self.lifetime_seconds = lifetime_seconds
|
||||
self.key_prefix = key_prefix
|
||||
|
||||
async def write_token(self, user: User) -> str:
|
||||
redis = await get_async_redis_connection()
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user.email,
|
||||
)
|
||||
)(email=user.email)
|
||||
|
||||
data = {
|
||||
token_data = {
|
||||
"sub": str(user.id),
|
||||
"aud": self.token_audience,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
return data
|
||||
|
||||
async def write_token(self, user: User) -> str:
|
||||
data = await self._create_token_data(user)
|
||||
return generate_jwt(
|
||||
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
|
||||
token = secrets.token_urlsafe()
|
||||
await redis.set(
|
||||
f"{self.key_prefix}{token}",
|
||||
json.dumps(token_data),
|
||||
ex=self.lifetime_seconds,
|
||||
)
|
||||
return token
|
||||
|
||||
async def read_token(
|
||||
self, token: Optional[str], user_manager: BaseUserManager[User, uuid.UUID]
|
||||
) -> Optional[User]:
|
||||
redis = await get_async_redis_connection()
|
||||
token_data_str = await redis.get(f"{self.key_prefix}{token}")
|
||||
if not token_data_str:
|
||||
return None
|
||||
|
||||
def get_jwt_strategy() -> TenantAwareJWTStrategy:
|
||||
return TenantAwareJWTStrategy(
|
||||
secret=USER_AUTH_SECRET,
|
||||
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
|
||||
)
|
||||
try:
|
||||
token_data = json.loads(token_data_str)
|
||||
user_id = token_data["sub"]
|
||||
parsed_id = user_manager.parse_id(user_id)
|
||||
return await user_manager.get(parsed_id)
|
||||
except (exceptions.UserNotExists, exceptions.InvalidID, KeyError):
|
||||
return None
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> DatabaseStrategy:
|
||||
return DatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
|
||||
)
|
||||
async def destroy_token(self, token: str, user: User) -> None:
|
||||
"""Properly delete the token from async redis."""
|
||||
redis = await get_async_redis_connection()
|
||||
await redis.delete(f"{self.key_prefix}{token}")
|
||||
|
||||
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="jwt" if MULTI_TENANT else "database",
|
||||
transport=cookie_transport,
|
||||
get_strategy=get_jwt_strategy if MULTI_TENANT else get_database_strategy, # type: ignore
|
||||
) # type: ignore
|
||||
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
|
||||
)
|
||||
|
||||
|
||||
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
@@ -713,30 +725,36 @@ async def double_check_user(
|
||||
user: User | None,
|
||||
optional: bool = DISABLE_AUTH,
|
||||
include_expired: bool = False,
|
||||
allow_anonymous_access: bool = False,
|
||||
) -> User | None:
|
||||
if optional:
|
||||
return user
|
||||
|
||||
if user is not None:
|
||||
# If user attempted to authenticate, verify them, do not default
|
||||
# to anonymous access if it fails.
|
||||
if user_needs_to_be_verified() and not user.is_verified:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
if (
|
||||
user.oidc_expiry
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
if allow_anonymous_access:
|
||||
return None
|
||||
|
||||
if user is None:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
if user_needs_to_be_verified() and not user.is_verified:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
if (
|
||||
user.oidc_expiry
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
return user
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
|
||||
async def current_user_with_expired_token(
|
||||
@@ -751,6 +769,15 @@ async def current_limited_user(
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_chat_accesssible_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> User | None:
|
||||
return await double_check_user(
|
||||
user, allow_anonymous_access=anonymous_user_enabled(tenant_id=tenant_id)
|
||||
)
|
||||
|
||||
|
||||
async def current_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
|
||||
@@ -335,6 +335,10 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
if not celery_is_worker_primary(sender):
|
||||
return
|
||||
|
||||
if not hasattr(sender, "primary_worker_lock"):
|
||||
# primary_worker_lock will not exist when MULTI_TENANT is True
|
||||
return
|
||||
|
||||
if not sender.primary_worker_lock:
|
||||
return
|
||||
|
||||
@@ -414,11 +418,21 @@ def on_setup_logging(
|
||||
task_logger.setLevel(loglevel)
|
||||
task_logger.propagate = False
|
||||
|
||||
# Hide celery task received and succeeded/failed messages
|
||||
# hide celery task received spam
|
||||
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] received"
|
||||
strategy.logger.setLevel(logging.WARNING)
|
||||
|
||||
# uncomment this to hide celery task succeeded/failed spam
|
||||
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] succeeded in 0.03137450001668185s: None"
|
||||
trace.logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def set_task_finished_log_level(logLevel: int) -> None:
|
||||
"""call this to override the setLevel in on_setup_logging. We are interested
|
||||
in the task timings in the cloud but it can be spammy for self hosted."""
|
||||
trace.logger.setLevel(logLevel)
|
||||
|
||||
|
||||
class TenantContextFilter(logging.Filter):
|
||||
|
||||
"""Logging filter to inject tenant ID into the logger's name."""
|
||||
|
||||
@@ -60,7 +60,12 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
|
||||
|
||||
# rkuo: been seeing transient connection exceptions here, so upping the connection count
|
||||
# from just concurrency/concurrency to concurrency/concurrency*2
|
||||
SqlEngine.init_engine(
|
||||
pool_size=sender.concurrency, max_overflow=sender.concurrency * 2
|
||||
)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -88,12 +89,12 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker. This is unnecessary in the multi tenant scenario
|
||||
r = get_redis_client(tenant_id=None)
|
||||
@@ -194,6 +195,10 @@ def on_setup_logging(
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
# this can be spammy, so just enable it in the cloud for now
|
||||
if MULTI_TENANT:
|
||||
app_base.set_task_finished_log_level(logging.INFO)
|
||||
|
||||
|
||||
class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
"""Regularly reacquires the primary worker lock outside of the task queue.
|
||||
@@ -281,5 +286,6 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -3,12 +3,54 @@ import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.configs.base import CELERY_SEPARATOR
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
|
||||
|
||||
def celery_get_unacked_length(r: Redis) -> int:
|
||||
"""Checking the unacked queue is useful because a non-zero length tells us there
|
||||
may be prefetched tasks.
|
||||
|
||||
There can be other tasks in here besides indexing tasks, so this is mostly useful
|
||||
just to see if the task count is non zero.
|
||||
|
||||
ref: https://blog.hikaru.run/2022/08/29/get-waiting-tasks-count-in-celery.html
|
||||
"""
|
||||
length = cast(int, r.hlen("unacked"))
|
||||
return length
|
||||
|
||||
|
||||
def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]:
|
||||
"""Gets the set of task id's matching the given queue in the unacked hash.
|
||||
|
||||
Unacked entries belonging to the indexing queue are "prefetched", so this gives
|
||||
us crucial visibility as to what tasks are in that state.
|
||||
"""
|
||||
tasks: set[str] = set()
|
||||
|
||||
for _, v in r.hscan_iter("unacked"):
|
||||
v_bytes = cast(bytes, v)
|
||||
v_str = v_bytes.decode("utf-8")
|
||||
task = json.loads(v_str)
|
||||
|
||||
task_description = task[0]
|
||||
task_queue = task[2]
|
||||
|
||||
if task_queue != queue:
|
||||
continue
|
||||
|
||||
task_id = task_description.get("headers", {}).get("id")
|
||||
if not task_id:
|
||||
continue
|
||||
|
||||
# if the queue matches and we see the task_id, add it
|
||||
tasks.add(task_id)
|
||||
return tasks
|
||||
|
||||
|
||||
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
"""This is a redis specific way to get the length of a celery queue.
|
||||
It is priority aware and knows how to count across the multiple redis lists
|
||||
@@ -47,3 +89,74 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
|
||||
"""Returns a list of current workers containing name_filter, or all workers if
|
||||
name_filter is None.
|
||||
|
||||
We've empirically discovered that the celery inspect API is potentially unstable
|
||||
and may hang or return empty results when celery is under load. Suggest using this
|
||||
more to debug and troubleshoot than in production code.
|
||||
"""
|
||||
worker_names: list[str] = []
|
||||
|
||||
# filter for and create an indexing specific inspect object
|
||||
inspect = app.control.inspect()
|
||||
workers: dict[str, Any] = inspect.ping() # type: ignore
|
||||
if workers:
|
||||
for worker_name in list(workers.keys()):
|
||||
# if the name filter not set, return all worker names
|
||||
if not name_filter:
|
||||
worker_names.append(worker_name)
|
||||
continue
|
||||
|
||||
# if the name filter is set, return only worker names that contain the name filter
|
||||
if name_filter not in worker_name:
|
||||
continue
|
||||
|
||||
worker_names.append(worker_name)
|
||||
|
||||
return worker_names
|
||||
|
||||
|
||||
def celery_inspect_get_reserved(worker_names: list[str], app: Celery) -> set[str]:
|
||||
"""Returns a list of reserved tasks on the specified workers.
|
||||
|
||||
We've empirically discovered that the celery inspect API is potentially unstable
|
||||
and may hang or return empty results when celery is under load. Suggest using this
|
||||
more to debug and troubleshoot than in production code.
|
||||
"""
|
||||
reserved_task_ids: set[str] = set()
|
||||
|
||||
inspect = app.control.inspect(destination=worker_names)
|
||||
|
||||
# get the list of reserved tasks
|
||||
reserved_tasks: dict[str, list] | None = inspect.reserved() # type: ignore
|
||||
if reserved_tasks:
|
||||
for _, task_list in reserved_tasks.items():
|
||||
for task in task_list:
|
||||
reserved_task_ids.add(task["id"])
|
||||
|
||||
return reserved_task_ids
|
||||
|
||||
|
||||
def celery_inspect_get_active(worker_names: list[str], app: Celery) -> set[str]:
|
||||
"""Returns a list of active tasks on the specified workers.
|
||||
|
||||
We've empirically discovered that the celery inspect API is potentially unstable
|
||||
and may hang or return empty results when celery is under load. Suggest using this
|
||||
more to debug and troubleshoot than in production code.
|
||||
"""
|
||||
active_task_ids: set[str] = set()
|
||||
|
||||
inspect = app.control.inspect(destination=worker_names)
|
||||
|
||||
# get the list of reserved tasks
|
||||
active_tasks: dict[str, list] | None = inspect.active() # type: ignore
|
||||
if active_tasks:
|
||||
for _, task_list in active_tasks.items():
|
||||
for task in task_list:
|
||||
active_task_ids.add(task["id"])
|
||||
|
||||
return active_task_ids
|
||||
|
||||
@@ -14,6 +14,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.models import TaskQueueState
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@@ -41,14 +42,21 @@ def _get_deletion_status(
|
||||
return None
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
if not redis_connector.delete.fenced:
|
||||
return None
|
||||
if redis_connector.delete.fenced:
|
||||
return TaskQueueState(
|
||||
task_id="",
|
||||
task_name=redis_connector.delete.fence_key,
|
||||
status=TaskStatus.STARTED,
|
||||
)
|
||||
|
||||
return TaskQueueState(
|
||||
task_id="",
|
||||
task_name=redis_connector.delete.fence_key,
|
||||
status=TaskStatus.STARTED,
|
||||
)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return TaskQueueState(
|
||||
task_id="",
|
||||
task_name=redis_connector.delete.fence_key,
|
||||
status=TaskStatus.PENDING,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_deletion_attempt_snapshot(
|
||||
|
||||
@@ -16,6 +16,14 @@ result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
# Indexing worker specific ... this lets us track the transition to STARTED in redis
|
||||
# We don't currently rely on this but it has the potential to be useful and
|
||||
# indexing tasks are not high volume
|
||||
|
||||
# we don't turn this on yet because celery occasionally runs tasks more than once
|
||||
# which means a duplicate run might change the task state unexpectedly
|
||||
# task_track_started = True
|
||||
|
||||
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
|
||||
# choosing 15 minutes because it roughly gives us enough time to process many tasks
|
||||
# we might be able to reduce this greatly if we can run a unified
|
||||
# loop across all tenants rather than tasks per tenant
|
||||
|
||||
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
|
||||
|
||||
# we set expires because it isn't necessary to queue up these tasks
|
||||
# it's only important that they run relatively regularly
|
||||
tasks_to_schedule = [
|
||||
@@ -13,7 +20,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -22,7 +29,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -31,7 +38,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -40,7 +47,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -49,7 +56,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOWEST,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -58,7 +65,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -67,7 +74,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -76,11 +83,25 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# Only add the LLM model update task if the API URL is configured
|
||||
if LLM_MODEL_UPDATE_API_URL:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "check-for-llm-model-update",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
"schedule": timedelta(hours=1), # Check every hour
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return tasks_to_schedule
|
||||
|
||||
@@ -34,7 +34,9 @@ class TaskDependencyError(RuntimeError):
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_connector_deletion_task(
|
||||
self: Task, *, tenant_id: str | None
|
||||
) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -42,11 +44,11 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
# collect cc_pair_ids
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -81,6 +83,8 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_generate_document_cc_pair_cleanup_tasks(
|
||||
app: Celery,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from time import sleep
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
@@ -14,10 +16,14 @@ from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.db.document import upsert_document_external_perms
|
||||
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
|
||||
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
||||
from ee.onyx.external_permissions.sync_params import (
|
||||
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
|
||||
)
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -88,19 +94,19 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
# get all cc pairs that need to be synced
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -128,6 +134,8 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_permissions_sync_task(
|
||||
app: Celery,
|
||||
@@ -219,6 +227,43 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
# before the primary worker can finalize the fence
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
|
||||
raise ValueError(
|
||||
f"connector_permission_sync_generator_task - timed out waiting for fence to be ready: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
|
||||
if not redis_connector.permissions.fenced: # The fence must exist
|
||||
raise ValueError(
|
||||
f"connector_permission_sync_generator_task - fence not found: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
|
||||
payload = redis_connector.permissions.payload # The payload must exist
|
||||
if not payload:
|
||||
raise ValueError(
|
||||
"connector_permission_sync_generator_task: payload invalid or not found"
|
||||
)
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
logger.info(
|
||||
f"connector_permission_sync_generator_task - Waiting for fence: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
sleep(1)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"connector_permission_sync_generator_task - Fence found, continuing...: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
break
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
@@ -244,6 +289,8 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if doc_sync_func is None:
|
||||
if source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION:
|
||||
return None
|
||||
raise ValueError(
|
||||
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
|
||||
)
|
||||
@@ -254,8 +301,11 @@ def connector_permission_sync_generator_task(
|
||||
if not payload:
|
||||
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
|
||||
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
new_payload = RedisConnectorPermissionSyncPayload(
|
||||
started=datetime.now(timezone.utc),
|
||||
celery_task_id=payload.celery_task_id,
|
||||
)
|
||||
redis_connector.permissions.set_fence(new_payload)
|
||||
|
||||
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
|
||||
|
||||
|
||||
@@ -94,19 +94,19 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
@@ -149,6 +149,8 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_external_group_sync_task(
|
||||
app: Celery,
|
||||
@@ -162,7 +164,7 @@ def try_creating_external_group_sync_task(
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
lock = r.lock(
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
import sentry_sdk
|
||||
@@ -18,10 +19,12 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.indexing.job_client import SimpleJobClient
|
||||
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -29,6 +32,7 @@ from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
@@ -58,6 +62,8 @@ from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
@@ -69,14 +75,18 @@ logger = setup_logger()
|
||||
|
||||
|
||||
class IndexingCallback(IndexingHeartbeatInterface):
|
||||
PARENT_CHECK_INTERVAL = 60
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent_pid: int,
|
||||
stop_key: str,
|
||||
generator_progress_key: str,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
):
|
||||
super().__init__()
|
||||
self.parent_pid = parent_pid
|
||||
self.redis_lock: RedisLock = redis_lock
|
||||
self.stop_key: str = stop_key
|
||||
self.generator_progress_key: str = generator_progress_key
|
||||
@@ -86,26 +96,54 @@ class IndexingCallback(IndexingHeartbeatInterface):
|
||||
|
||||
self.last_tag: str = "IndexingCallback.__init__"
|
||||
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
self.last_parent_check = time.monotonic()
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
if self.redis_client.exists(self.stop_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
# rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
|
||||
# with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
|
||||
# so leave this code in until we're ready to test it.
|
||||
|
||||
# if self.parent_pid:
|
||||
# # check if the parent pid is alive so we aren't running as a zombie
|
||||
# now = time.monotonic()
|
||||
# if now - self.last_parent_check > IndexingCallback.PARENT_CHECK_INTERVAL:
|
||||
# try:
|
||||
# # this is unintuitive, but it checks if the parent pid is still running
|
||||
# os.kill(self.parent_pid, 0)
|
||||
# except Exception:
|
||||
# logger.exception("IndexingCallback - parent pid check exceptioned")
|
||||
# raise
|
||||
# self.last_parent_check = now
|
||||
|
||||
try:
|
||||
self.redis_lock.reacquire()
|
||||
current_time = time.monotonic()
|
||||
if current_time - self.last_lock_monotonic >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
self.redis_lock.reacquire()
|
||||
self.last_lock_reacquire = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
self.last_tag = tag
|
||||
self.last_lock_reacquire = datetime.now(timezone.utc)
|
||||
except LockError:
|
||||
logger.exception(
|
||||
f"IndexingCallback - lock.reacquire exceptioned. "
|
||||
f"IndexingCallback - lock.reacquire exceptioned: "
|
||||
f"lock_timeout={self.redis_lock.timeout} "
|
||||
f"start={self.started} "
|
||||
f"last_tag={self.last_tag} "
|
||||
f"last_reacquired={self.last_lock_reacquire} "
|
||||
f"now={datetime.now(timezone.utc)}"
|
||||
)
|
||||
|
||||
redis_lock_dump(self.redis_lock, self.redis_client)
|
||||
raise
|
||||
|
||||
self.redis_client.incrby(self.generator_progress_key, amount)
|
||||
@@ -167,6 +205,10 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
"""a lightweight task used to kick off indexing tasks.
|
||||
Occcasionally does some validation of existing state to clear up error conditions"""
|
||||
debug_tenants = {
|
||||
"tenant_i-043470d740845ec56",
|
||||
"tenant_82b497ce-88aa-4fbd-841a-92cae43529c8",
|
||||
}
|
||||
time_start = time.monotonic()
|
||||
|
||||
tasks_created = 0
|
||||
@@ -175,18 +217,18 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
# redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
locked = True
|
||||
|
||||
# check for search settings swap
|
||||
@@ -209,15 +251,25 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
)
|
||||
|
||||
# gather cc_pair_ids
|
||||
lock_beat.reacquire()
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
lock_beat.reacquire()
|
||||
cc_pairs = fetch_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
# kick off index attempts
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
# debugging logic - remove after we're done
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing cc_pair lock: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
@@ -226,22 +278,58 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
db_session
|
||||
)
|
||||
for search_settings_instance in search_settings_list:
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing cc_pair search settings lock: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings_instance.id
|
||||
)
|
||||
if redis_connector_index.fenced:
|
||||
continue
|
||||
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing get_connector_credential_pair_from_id: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id, db_session
|
||||
)
|
||||
if not cc_pair:
|
||||
continue
|
||||
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing get_last_attempt_for_cc_pair: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing cc_pair should index: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
search_settings_primary = False
|
||||
if search_settings_instance.id == search_settings_list[0].id:
|
||||
search_settings_primary = True
|
||||
@@ -274,6 +362,15 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
cc_pair.id, None, db_session
|
||||
)
|
||||
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing cc_pair try_creating_indexing_task: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
# using a task queue and only allowing one task per cc_pair/search_setting
|
||||
# prevents us from starving out certain attempts
|
||||
attempt_id = try_creating_indexing_task(
|
||||
@@ -294,6 +391,26 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
)
|
||||
tasks_created += 1
|
||||
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing cc_pair try_creating_indexing_task finished: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
# debugging logic - remove after we're done
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing unfenced lock: "
|
||||
f"tenant={tenant_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
# Fail any index attempts in the DB that don't have fences
|
||||
# This shouldn't ever happen!
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -301,6 +418,15 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
db_session, redis_client
|
||||
)
|
||||
for attempt_id in unfenced_attempt_ids:
|
||||
# debugging logic - remove after we're done
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing unfenced attempt id lock: "
|
||||
f"tenant={tenant_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
attempt = get_index_attempt(db_session, attempt_id)
|
||||
@@ -318,24 +444,29 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
attempt.id, db_session, failure_reason=failure_reason
|
||||
)
|
||||
|
||||
# rkuo: The following code logically appears to work, but the celery inspect code may be unstable
|
||||
# turning off for the moment to see if it helps cloud stability
|
||||
# debugging logic - remove after we're done
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing validate fences lock: "
|
||||
f"tenant={tenant_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
# we want to run this less frequently than the overall task
|
||||
# if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
|
||||
# # clear any indexing fences that don't have associated celery tasks in progress
|
||||
# # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# # or be currently executing
|
||||
# try:
|
||||
# task_logger.info("Validating indexing fences...")
|
||||
# validate_indexing_fences(
|
||||
# tenant_id, self.app, redis_client, redis_client_celery, lock_beat
|
||||
# )
|
||||
# except Exception:
|
||||
# task_logger.exception("Exception while validating indexing fences")
|
||||
|
||||
# redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
|
||||
if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
|
||||
# clear any indexing fences that don't have associated celery tasks in progress
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
validate_indexing_fences(
|
||||
tenant_id, self.app, redis_client, redis_client_celery, lock_beat
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while validating indexing fences")
|
||||
|
||||
redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -351,9 +482,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
"check_for_indexing - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
redis_lock_dump(lock_beat, redis_client)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
|
||||
task_logger.debug(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
|
||||
return tasks_created
|
||||
|
||||
|
||||
@@ -364,56 +496,20 @@ def validate_indexing_fences(
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
reserved_indexing_tasks: set[str] = set()
|
||||
active_indexing_tasks: set[str] = set()
|
||||
indexing_worker_names: list[str] = []
|
||||
|
||||
# filter for and create an indexing specific inspect object
|
||||
inspect = celery_app.control.inspect()
|
||||
workers: dict[str, Any] = inspect.ping() # type: ignore
|
||||
if not workers:
|
||||
raise ValueError("No workers found!")
|
||||
|
||||
for worker_name in list(workers.keys()):
|
||||
if "indexing" in worker_name:
|
||||
indexing_worker_names.append(worker_name)
|
||||
|
||||
if len(indexing_worker_names) == 0:
|
||||
raise ValueError("No indexing workers found!")
|
||||
|
||||
inspect_indexing = celery_app.control.inspect(destination=indexing_worker_names)
|
||||
|
||||
# NOTE: each dict entry is a map of worker name to a list of tasks
|
||||
# we want sets for reserved task and active task id's to optimize
|
||||
# subsequent validation lookups
|
||||
|
||||
# get the list of reserved tasks
|
||||
reserved_tasks: dict[str, list] | None = inspect_indexing.reserved() # type: ignore
|
||||
if reserved_tasks is None:
|
||||
raise ValueError("inspect_indexing.reserved() returned None!")
|
||||
|
||||
for _, task_list in reserved_tasks.items():
|
||||
for task in task_list:
|
||||
reserved_indexing_tasks.add(task["id"])
|
||||
|
||||
# get the list of active tasks
|
||||
active_tasks: dict[str, list] | None = inspect_indexing.active() # type: ignore
|
||||
if active_tasks is None:
|
||||
raise ValueError("inspect_indexing.active() returned None!")
|
||||
|
||||
for _, task_list in active_tasks.items():
|
||||
for task in task_list:
|
||||
active_indexing_tasks.add(task["id"])
|
||||
reserved_indexing_tasks = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
# validate all existing indexing jobs
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
validate_indexing_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
reserved_indexing_tasks,
|
||||
active_indexing_tasks,
|
||||
r_celery,
|
||||
db_session,
|
||||
)
|
||||
@@ -424,7 +520,6 @@ def validate_indexing_fence(
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
active_tasks: set[str],
|
||||
r_celery: Redis,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
@@ -434,11 +529,15 @@ def validate_indexing_fence(
|
||||
gives the help.
|
||||
|
||||
How this works:
|
||||
1. Active signal is renewed with a 5 minute TTL
|
||||
1.1 When the fence is created
|
||||
1. This function renews the active signal with a 5 minute TTL under the following conditions
|
||||
1.2. When the task is seen in the redis queue
|
||||
1.3. When the task is seen in the reserved or active list for a worker
|
||||
2. The TTL allows us to get through the transitions on fence startup
|
||||
1.3. When the task is seen in the reserved / prefetched list
|
||||
|
||||
2. Externally, the active signal is renewed when:
|
||||
2.1. The fence is created
|
||||
2.2. The indexing watchdog checks the spawned task.
|
||||
|
||||
3. The TTL allows us to get through the transitions on fence startup
|
||||
and when the task starts executing.
|
||||
|
||||
More TTL clarification: it is seemingly impossible to exactly query Celery for
|
||||
@@ -466,6 +565,8 @@ def validate_indexing_fence(
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
# check to see if the fence/payload exists
|
||||
if not redis_connector_index.fenced:
|
||||
return
|
||||
|
||||
@@ -501,24 +602,24 @@ def validate_indexing_fence(
|
||||
redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
if payload.celery_task_id in active_tasks:
|
||||
# the celery task is active (aka currently executing)
|
||||
redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
# we may want to enable this check if using the active task list somehow isn't good enough
|
||||
# if redis_connector_index.generator_locked():
|
||||
# logger.info(f"{payload.celery_task_id} is currently executing.")
|
||||
|
||||
# we didn't find any direct indication that associated celery tasks exist, but they still might be there
|
||||
# due to gaps in our ability to check states during transitions
|
||||
# Rely on the active signal (which has a duration that allows us to bridge those gaps)
|
||||
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
|
||||
# but they still might be there due to gaps in our ability to check states during transitions
|
||||
# Checking the active signal safeguards us against these transition periods
|
||||
# (which has a duration that allows us to bridge those gaps)
|
||||
if redis_connector_index.active():
|
||||
return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
logger.warning(
|
||||
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: fence={fence_key}"
|
||||
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
|
||||
f"index_attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
if payload.index_attempt_id:
|
||||
try:
|
||||
@@ -783,7 +884,6 @@ def connector_indexing_proxy_task(
|
||||
return
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
|
||||
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
@@ -795,6 +895,58 @@ def connector_indexing_proxy_task(
|
||||
while True:
|
||||
sleep(5)
|
||||
|
||||
# renew active signal
|
||||
redis_connector_index.set_active()
|
||||
|
||||
# if the job is done, clean up and break
|
||||
if job.done():
|
||||
try:
|
||||
if job.status == "error":
|
||||
ignore_exitcode = False
|
||||
|
||||
exit_code: int | None = None
|
||||
if job.process:
|
||||
exit_code = job.process.exitcode
|
||||
|
||||
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
|
||||
# even though logging clearly indicates successful completion
|
||||
# to work around this, we ignore the job error state if the completion signal is OK
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
ignore_exitcode = True
|
||||
|
||||
if not ignore_exitcode:
|
||||
raise RuntimeError("Spawned task exceptioned.")
|
||||
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - spawned task has non-zero exit code "
|
||||
"but completion signal is OK. Continuing...: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code}"
|
||||
)
|
||||
except Exception:
|
||||
task_logger.error(
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code} "
|
||||
f"error={job.exception()}"
|
||||
)
|
||||
|
||||
raise
|
||||
finally:
|
||||
job.release()
|
||||
|
||||
break
|
||||
|
||||
# if a termination signal is detected, clean up and break
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - termination signal detected: "
|
||||
@@ -821,75 +973,33 @@ def connector_indexing_proxy_task(
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
job.cancel()
|
||||
|
||||
job.cancel()
|
||||
break
|
||||
|
||||
if not job.done():
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
ignore_exitcode = False
|
||||
|
||||
exit_code: int | None = None
|
||||
if job.process:
|
||||
exit_code = job.process.exitcode
|
||||
|
||||
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
|
||||
# even though logging clearly indicates that they completed successfully
|
||||
# to work around this, we ignore the job error state if the completion signal is OK
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
ignore_exitcode = True
|
||||
|
||||
if ignore_exitcode:
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - spawned task has non-zero exit code "
|
||||
"but completion signal is OK. Continuing...: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code}"
|
||||
)
|
||||
else:
|
||||
task_logger.error(
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code} "
|
||||
f"error={job.exception()}"
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
job.release()
|
||||
break
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - finished: attempt={index_attempt_id} "
|
||||
@@ -918,7 +1028,7 @@ def connector_indexing_task_wrapper(
|
||||
tenant_id,
|
||||
is_ee,
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"connector_indexing_task exceptioned: "
|
||||
f"tenant={tenant_id} "
|
||||
@@ -926,7 +1036,14 @@ def connector_indexing_task_wrapper(
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
raise
|
||||
|
||||
# There is a cloud related bug outside of our code
|
||||
# where spawned tasks return with an exit code of 1.
|
||||
# Unfortunately, exceptions also return with an exit code of 1,
|
||||
# so just raising an exception isn't informative
|
||||
# Exiting with 255 makes it possible to distinguish between normal exits
|
||||
# and exceptions.
|
||||
sys.exit(255)
|
||||
|
||||
return result
|
||||
|
||||
@@ -998,7 +1115,17 @@ def connector_indexing_task(
|
||||
f"fence={redis_connector.stop.fence_key}"
|
||||
)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
# before the primary worker can finalize the fence
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - timed out waiting for fence to be ready: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
|
||||
if not redis_connector_index.fenced: # The fence must exist
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
|
||||
@@ -1039,7 +1166,9 @@ def connector_indexing_task(
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
f"Indexing task already running, exiting...: "
|
||||
f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
|
||||
f"index_attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1075,17 +1204,22 @@ def connector_indexing_task(
|
||||
|
||||
# define a callback class
|
||||
callback = IndexingCallback(
|
||||
os.getppid(),
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector_index.generator_progress_key,
|
||||
lock,
|
||||
r,
|
||||
)
|
||||
|
||||
queue_delay_datetime = payload.started - payload.submitted
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
|
||||
f"Indexing spawned task running entrypoint: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
f"queue_delay={queue_delay_datetime.total_seconds()}"
|
||||
)
|
||||
|
||||
# This is where the heavy/real work happens
|
||||
@@ -1108,8 +1242,19 @@ def connector_indexing_task(
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
if attempt_found:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_failed(index_attempt_id, db_session, failure_reason=str(e))
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id, db_session, failure_reason=str(e)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
raise e
|
||||
finally:
|
||||
|
||||
105
backend/onyx/background/celery/tasks/llm_model_update/tasks.py
Normal file
105
backend/onyx/background/celery/tasks/llm_model_update/tasks.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import LLMProvider
|
||||
|
||||
|
||||
def _process_model_list_response(model_list_json: Any) -> list[str]:
|
||||
# Handle case where response is wrapped in a "data" field
|
||||
if isinstance(model_list_json, dict) and "data" in model_list_json:
|
||||
model_list_json = model_list_json["data"]
|
||||
|
||||
if not isinstance(model_list_json, list):
|
||||
raise ValueError(
|
||||
f"Invalid response from API - expected list, got {type(model_list_json)}"
|
||||
)
|
||||
|
||||
# Handle both string list and object list cases
|
||||
model_names: list[str] = []
|
||||
for item in model_list_json:
|
||||
if isinstance(item, str):
|
||||
model_names.append(item)
|
||||
elif isinstance(item, dict) and "model_name" in item:
|
||||
model_names.append(item["model_name"])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid item in model list - expected string or dict with model_name, got {type(item)}"
|
||||
)
|
||||
|
||||
return model_names
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
if not LLM_MODEL_UPDATE_API_URL:
|
||||
raise ValueError("LLM model update API URL not configured")
|
||||
|
||||
# First fetch the models from the API
|
||||
try:
|
||||
response = requests.get(LLM_MODEL_UPDATE_API_URL)
|
||||
response.raise_for_status()
|
||||
available_models = _process_model_list_response(response.json())
|
||||
task_logger.info(f"Found available models: {available_models}")
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Failed to fetch models from API.")
|
||||
return None
|
||||
|
||||
# Then update the database with the fetched models
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Get the default LLM provider
|
||||
default_provider = (
|
||||
db_session.query(LLMProvider)
|
||||
.filter(LLMProvider.is_default_provider.is_(True))
|
||||
.first()
|
||||
)
|
||||
|
||||
if not default_provider:
|
||||
task_logger.warning("No default LLM provider found")
|
||||
return None
|
||||
|
||||
# log change if any
|
||||
old_models = set(default_provider.model_names or [])
|
||||
new_models = set(available_models)
|
||||
added_models = new_models - old_models
|
||||
removed_models = old_models - new_models
|
||||
|
||||
if added_models:
|
||||
task_logger.info(f"Adding models: {sorted(added_models)}")
|
||||
if removed_models:
|
||||
task_logger.info(f"Removing models: {sorted(removed_models)}")
|
||||
|
||||
# Update the provider's model list
|
||||
default_provider.model_names = available_models
|
||||
# if the default model is no longer available, set it to the first model in the list
|
||||
if default_provider.default_model_name not in available_models:
|
||||
task_logger.info(
|
||||
f"Default model {default_provider.default_model_name} not "
|
||||
f"available, setting to first model in list: {available_models[0]}"
|
||||
)
|
||||
default_provider.default_model_name = available_models[0]
|
||||
if default_provider.fast_default_model_name not in available_models:
|
||||
task_logger.info(
|
||||
f"Fast default model {default_provider.fast_default_model_name} "
|
||||
f"not available, setting to first model in list: {available_models[0]}"
|
||||
)
|
||||
default_provider.fast_default_model_name = available_models[0]
|
||||
db_session.commit()
|
||||
|
||||
if added_models or removed_models:
|
||||
task_logger.info("Updated model list for default provider.")
|
||||
|
||||
return True
|
||||
@@ -81,19 +81,19 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
@@ -127,6 +127,8 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_prune_generator_task(
|
||||
celery_app: Celery,
|
||||
@@ -283,6 +285,7 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
|
||||
callback = IndexingCallback(
|
||||
0,
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector.prune.generator_progress_key,
|
||||
lock,
|
||||
|
||||
@@ -28,13 +28,35 @@ class RetryDocumentIndex:
|
||||
wait=wait_random_exponential(multiplier=1, max=MAX_WAIT),
|
||||
stop=stop_after_delay(STOP_AFTER),
|
||||
)
|
||||
def delete_single(self, doc_id: str) -> int:
|
||||
return self.index.delete_single(doc_id)
|
||||
def delete_single(
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
return self.index.delete_single(
|
||||
doc_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type(httpx.ReadTimeout),
|
||||
wait=wait_random_exponential(multiplier=1, max=MAX_WAIT),
|
||||
stop=stop_after_delay(STOP_AFTER),
|
||||
)
|
||||
def update_single(self, doc_id: str, fields: VespaDocumentFields) -> int:
|
||||
return self.index.update_single(doc_id, fields)
|
||||
def update_single(
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
chunk_count: int | None,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
return self.index.update_single(
|
||||
doc_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocument
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from onyx.db.document import delete_documents_complete__no_commit
|
||||
from onyx.db.document import fetch_chunk_count_for_document
|
||||
from onyx.db.document import get_document
|
||||
from onyx.db.document import get_document_connector_count
|
||||
from onyx.db.document import mark_document_as_modified
|
||||
@@ -80,7 +81,13 @@ def document_by_cc_pair_cleanup_task(
|
||||
# delete it from vespa and the db
|
||||
action = "delete"
|
||||
|
||||
chunks_affected = retry_index.delete_single(document_id)
|
||||
chunk_count = fetch_chunk_count_for_document(document_id, db_session)
|
||||
|
||||
chunks_affected = retry_index.delete_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
delete_documents_complete__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=[document_id],
|
||||
@@ -110,7 +117,12 @@ def document_by_cc_pair_cleanup_task(
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = retry_index.update_single(document_id, fields=fields)
|
||||
chunks_affected = retry_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
# there are still other cc_pair references to the doc, so just resync to Vespa
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
@@ -20,10 +22,12 @@ from tenacity import RetryError
|
||||
from onyx.access.access import get_access_for_document
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -67,6 +71,8 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
@@ -75,6 +81,7 @@ from onyx.utils.variable_functionality import (
|
||||
)
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -87,7 +94,7 @@ logger = setup_logger()
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
time_start = time.monotonic()
|
||||
@@ -99,17 +106,18 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_stale_document_sync_tasks(
|
||||
self.app, db_session, r, lock_beat, tenant_id
|
||||
self.app, VESPA_SYNC_MAX_TASKS, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
|
||||
# region document set scan
|
||||
lock_beat.reacquire()
|
||||
document_set_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# check if any document sets are not synced
|
||||
@@ -121,6 +129,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
document_set_ids.append(document_set.id)
|
||||
|
||||
for document_set_id in document_set_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_document_set_sync_tasks(
|
||||
self.app, document_set_id, db_session, r, lock_beat, tenant_id
|
||||
@@ -129,6 +138,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
|
||||
# check if any user groups are not synced
|
||||
if global_version.is_ee_version():
|
||||
lock_beat.reacquire()
|
||||
|
||||
try:
|
||||
fetch_user_groups = fetch_versioned_implementation(
|
||||
"onyx.db.user_group", "fetch_user_groups"
|
||||
@@ -148,6 +159,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
usergroup_ids.append(usergroup.id)
|
||||
|
||||
for usergroup_id in usergroup_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_user_group_sync_tasks(
|
||||
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
|
||||
@@ -162,14 +174,21 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"check_for_vespa_sync_task - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
redis_lock_dump(lock_beat, r)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
|
||||
return
|
||||
task_logger.debug(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
|
||||
return True
|
||||
|
||||
|
||||
def try_generate_stale_document_sync_tasks(
|
||||
celery_app: Celery,
|
||||
max_tasks: int,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
@@ -200,11 +219,16 @@ def try_generate_stale_document_sync_tasks(
|
||||
# rkuo: we could technically sync all stale docs in one big pass.
|
||||
# but I feel it's more understandable to group the docs by cc_pair
|
||||
total_tasks_generated = 0
|
||||
tasks_remaining = max_tasks
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
|
||||
rc.set_skip_docs(docs_to_skip)
|
||||
result = rc.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
|
||||
result = rc.generate_tasks(
|
||||
tasks_remaining, celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
|
||||
if result is None:
|
||||
continue
|
||||
@@ -218,10 +242,19 @@ def try_generate_stale_document_sync_tasks(
|
||||
)
|
||||
|
||||
total_tasks_generated += result[0]
|
||||
tasks_remaining -= result[0]
|
||||
if tasks_remaining <= 0:
|
||||
break
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
|
||||
)
|
||||
if tasks_remaining <= 0:
|
||||
task_logger.info(
|
||||
f"RedisConnector.generate_tasks reached the task generation limit: "
|
||||
f"total_tasks_generated={total_tasks_generated} max_tasks={max_tasks}"
|
||||
)
|
||||
else:
|
||||
task_logger.info(
|
||||
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
|
||||
)
|
||||
|
||||
r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
|
||||
return total_tasks_generated
|
||||
@@ -260,7 +293,9 @@ def try_generate_document_set_sync_tasks(
|
||||
)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
result = rds.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
|
||||
result = rds.generate_tasks(
|
||||
VESPA_SYNC_MAX_TASKS, celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
@@ -315,7 +350,9 @@ def try_generate_user_group_sync_tasks(
|
||||
task_logger.info(
|
||||
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
|
||||
)
|
||||
result = rug.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
|
||||
result = rug.generate_tasks(
|
||||
VESPA_SYNC_MAX_TASKS, celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
@@ -636,15 +673,23 @@ def monitor_ccpair_indexing_taskset(
|
||||
if not payload:
|
||||
return
|
||||
|
||||
elapsed_started_str = None
|
||||
if payload.started:
|
||||
elapsed_started = datetime.now(timezone.utc) - payload.started
|
||||
elapsed_started_str = f"{elapsed_started.total_seconds():.2f}"
|
||||
|
||||
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
|
||||
|
||||
progress = redis_connector_index.get_progress()
|
||||
if progress is not None:
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: cc_pair={cc_pair_id} "
|
||||
f"Connector indexing progress: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
@@ -715,18 +760,21 @@ def monitor_ccpair_indexing_taskset(
|
||||
status_enum = HTTPStatus(status_int)
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: cc_pair={cc_pair_id} "
|
||||
f"Connector indexing finished: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"status={status_enum.name} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
@shared_task(name=OnyxCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
If the count is 0, that means all tasks finished and we should clean up.
|
||||
@@ -736,7 +784,13 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
Returns True if the task actually did work, False if it exited early to prevent overlap
|
||||
"""
|
||||
task_logger.info(f"monitor_vespa_sync starting: tenant={tenant_id}")
|
||||
|
||||
time_start = time.monotonic()
|
||||
|
||||
timings: dict[str, Any] = {}
|
||||
timings["start"] = time_start
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -744,55 +798,94 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# prevent overlapping tasks
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
# prevent overlapping tasks
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return False
|
||||
|
||||
# print current queue lengths
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r_celery)
|
||||
n_indexing = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
|
||||
n_deletion = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
|
||||
)
|
||||
n_pruning = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
|
||||
)
|
||||
n_permissions_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
|
||||
)
|
||||
phase_start = time.monotonic()
|
||||
# we don't need every tenant polling redis for this info.
|
||||
if not MULTI_TENANT or random.randint(1, 10) == 10:
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r_celery)
|
||||
n_indexing = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
n_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery
|
||||
)
|
||||
n_deletion = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
|
||||
)
|
||||
n_pruning = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
|
||||
)
|
||||
n_permissions_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
|
||||
)
|
||||
n_external_group_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
|
||||
)
|
||||
n_permissions_upsert = celery_get_queue_length(
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Queue lengths: celery={n_celery} "
|
||||
f"indexing={n_indexing} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning} "
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
)
|
||||
prefetched = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Queue lengths: celery={n_celery} "
|
||||
f"indexing={n_indexing} "
|
||||
f"indexing_prefetched={len(prefetched)} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning} "
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
f"external_group_sync={n_external_group_sync} "
|
||||
f"permissions_upsert={n_permissions_upsert} "
|
||||
)
|
||||
timings["queues"] = time.monotonic() - phase_start
|
||||
timings["queues_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
# scan and monitor activity to completion
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
timings["connector"] = time.monotonic() - phase_start
|
||||
timings["connector_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorDelete.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
|
||||
timings["connector_deletion"] = time.monotonic() - phase_start
|
||||
timings["connector_deletion_ttl"] = r.ttl(
|
||||
OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK
|
||||
)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisDocumentSet.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
timings["documentset"] = time.monotonic() - phase_start
|
||||
timings["documentset_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisUserGroup.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
|
||||
"onyx.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
@@ -800,29 +893,45 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
timings["usergroup"] = time.monotonic() - phase_start
|
||||
timings["usergroup_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
timings["pruning"] = time.monotonic() - phase_start
|
||||
timings["pruning_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
timings["indexing"] = time.monotonic() - phase_start
|
||||
timings["indexing_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
|
||||
count=SCAN_ITER_COUNT_DEFAULT,
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
|
||||
timings["permissions"] = time.monotonic() - phase_start
|
||||
timings["permissions_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
# uncomment for debugging if needed
|
||||
# r_celery = celery_app.broker_connection().channel().client
|
||||
# length = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
|
||||
# task_logger.warning(f"queue={OnyxCeleryQueues.VESPA_METADATA_SYNC} length={length}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -830,6 +939,13 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"monitor_vespa_sync - Lock not owned on completion: "
|
||||
f"tenant={tenant_id} "
|
||||
f"timings={timings}"
|
||||
)
|
||||
redis_lock_dump(lock_beat, r)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
|
||||
@@ -876,12 +992,26 @@ def vespa_metadata_sync_task(
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = retry_index.update_single(document_id, fields)
|
||||
chunks_affected = retry_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
# update db last. Worst case = we crash right before this and
|
||||
# the sync might repeat again later
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
|
||||
# this code checks for and removes a per document sync key that is
|
||||
# used to block out the same doc from continualy resyncing
|
||||
# a quick hack that is only needed for production issues
|
||||
# redis_syncing_key = RedisConnectorCredentialPair.make_redis_syncing_key(
|
||||
# document_id
|
||||
# )
|
||||
# r = get_redis_client(tenant_id=tenant_id)
|
||||
# r.delete(redis_syncing_key)
|
||||
|
||||
task_logger.info(f"doc={document_id} action=sync chunks={chunks_affected}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
|
||||
@@ -14,6 +14,7 @@ from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
@@ -90,6 +91,35 @@ def _get_connector_runner(
|
||||
)
|
||||
|
||||
|
||||
def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
|
||||
cleaned_batch = []
|
||||
for doc in doc_batch:
|
||||
cleaned_doc = doc.model_copy()
|
||||
|
||||
if "\x00" in cleaned_doc.id:
|
||||
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
|
||||
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
|
||||
|
||||
if "\x00" in cleaned_doc.semantic_identifier:
|
||||
logger.warning(
|
||||
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
|
||||
)
|
||||
cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
|
||||
"\x00", ""
|
||||
)
|
||||
|
||||
for section in cleaned_doc.sections:
|
||||
if section.link and "\x00" in section.link:
|
||||
logger.warning(
|
||||
f"NUL characters found in document link for document: {cleaned_doc.id}"
|
||||
)
|
||||
section.link = section.link.replace("\x00", "")
|
||||
|
||||
cleaned_batch.append(cleaned_doc)
|
||||
|
||||
return cleaned_batch
|
||||
|
||||
|
||||
class ConnectorStopSignal(Exception):
|
||||
"""A custom exception used to signal a stop in processing."""
|
||||
|
||||
@@ -238,7 +268,9 @@ def _run_indexing(
|
||||
)
|
||||
|
||||
batch_description = []
|
||||
for doc in doc_batch:
|
||||
|
||||
doc_batch_cleaned = strip_null_characters(doc_batch)
|
||||
for doc in doc_batch_cleaned:
|
||||
batch_description.append(doc.to_short_descriptor())
|
||||
|
||||
doc_size = 0
|
||||
@@ -258,15 +290,15 @@ def _run_indexing(
|
||||
|
||||
# real work happens here!
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
document_batch=doc_batch,
|
||||
document_batch=doc_batch_cleaned,
|
||||
index_attempt_metadata=index_attempt_md,
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
net_doc_change += new_docs
|
||||
chunk_count += total_batch_chunks
|
||||
document_count += len(doc_batch)
|
||||
all_connector_doc_ids.update(doc.id for doc in doc_batch)
|
||||
document_count += len(doc_batch_cleaned)
|
||||
all_connector_doc_ids.update(doc.id for doc in doc_batch_cleaned)
|
||||
|
||||
# commit transaction so that the `update` below begins
|
||||
# with a brand new transaction. Postgres uses the start
|
||||
@@ -276,7 +308,7 @@ def _run_indexing(
|
||||
db_session.commit()
|
||||
|
||||
if callback:
|
||||
callback.progress("_run_indexing", len(doc_batch))
|
||||
callback.progress("_run_indexing", len(doc_batch_cleaned))
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
update_docs_indexed(
|
||||
|
||||
@@ -22,7 +22,9 @@ from onyx.chat.stream_processing.answer_response_handler import (
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
DummyAnswerResponseHandler,
|
||||
)
|
||||
from onyx.chat.stream_processing.utils import map_document_id_order
|
||||
from onyx.chat.stream_processing.utils import (
|
||||
map_document_id_order,
|
||||
)
|
||||
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
||||
from onyx.file_store.utils import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -206,9 +208,9 @@ class Answer:
|
||||
# + figure out what the next LLM call should be
|
||||
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
|
||||
|
||||
search_result, displayed_search_results_map = SearchTool.get_search_result(
|
||||
final_search_results, displayed_search_results = SearchTool.get_search_result(
|
||||
current_llm_call
|
||||
) or ([], {})
|
||||
) or ([], [])
|
||||
|
||||
# Quotes are no longer supported
|
||||
# answer_handler: AnswerResponseHandler
|
||||
@@ -224,9 +226,9 @@ class Answer:
|
||||
# else:
|
||||
# raise ValueError("No answer style config provided")
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=search_result,
|
||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
display_doc_order_dict=displayed_search_results_map,
|
||||
context_docs=final_search_results,
|
||||
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
|
||||
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
|
||||
)
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
|
||||
@@ -37,22 +37,22 @@ class CitationResponseHandler(AnswerResponseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
display_doc_order_dict: dict[str, int],
|
||||
final_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
display_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
||||
self.display_doc_order_dict = display_doc_order_dict
|
||||
self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
|
||||
self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
|
||||
self.citation_processor = CitationProcessor(
|
||||
context_docs=self.context_docs,
|
||||
doc_id_to_rank_map=self.doc_id_to_rank_map,
|
||||
display_doc_order_dict=self.display_doc_order_dict,
|
||||
final_doc_id_to_rank_map=self.final_doc_id_to_rank_map,
|
||||
display_doc_id_to_rank_map=self.display_doc_id_to_rank_map,
|
||||
)
|
||||
self.processed_text = ""
|
||||
self.citations: list[CitationInfo] = []
|
||||
|
||||
# TODO remove this after citation issue is resolved
|
||||
logger.debug(f"Document to ranking map {self.doc_id_to_rank_map}")
|
||||
logger.debug(f"Document to ranking map {self.final_doc_id_to_rank_map}")
|
||||
|
||||
def handle_response_part(
|
||||
self,
|
||||
|
||||
@@ -21,20 +21,19 @@ class CitationProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
display_doc_order_dict: dict[str, int],
|
||||
final_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
display_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
||||
self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
|
||||
self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
|
||||
self.stop_stream = stop_stream
|
||||
self.order_mapping = doc_id_to_rank_map.order_mapping
|
||||
self.display_doc_order_dict = (
|
||||
display_doc_order_dict # original order of docs to displayed to user
|
||||
)
|
||||
self.final_order_mapping = final_doc_id_to_rank_map.order_mapping
|
||||
self.display_order_mapping = display_doc_id_to_rank_map.order_mapping
|
||||
self.llm_out = ""
|
||||
self.max_citation_num = len(context_docs)
|
||||
self.citation_order: list[int] = []
|
||||
self.citation_order: list[int] = [] # order of citations in the LLM output
|
||||
self.curr_segment = ""
|
||||
self.cited_inds: set[int] = set()
|
||||
self.hold = ""
|
||||
@@ -93,29 +92,31 @@ class CitationProcessor:
|
||||
|
||||
if 1 <= numerical_value <= self.max_citation_num:
|
||||
context_llm_doc = self.context_docs[numerical_value - 1]
|
||||
real_citation_num = self.order_mapping[context_llm_doc.document_id]
|
||||
final_citation_num = self.final_order_mapping[
|
||||
context_llm_doc.document_id
|
||||
]
|
||||
|
||||
if real_citation_num not in self.citation_order:
|
||||
self.citation_order.append(real_citation_num)
|
||||
if final_citation_num not in self.citation_order:
|
||||
self.citation_order.append(final_citation_num)
|
||||
|
||||
target_citation_num = (
|
||||
self.citation_order.index(real_citation_num) + 1
|
||||
citation_order_idx = (
|
||||
self.citation_order.index(final_citation_num) + 1
|
||||
)
|
||||
|
||||
# get the value that was displayed to user, should always
|
||||
# be in the display_doc_order_dict. But check anyways
|
||||
if context_llm_doc.document_id in self.display_doc_order_dict:
|
||||
displayed_citation_num = self.display_doc_order_dict[
|
||||
if context_llm_doc.document_id in self.display_order_mapping:
|
||||
displayed_citation_num = self.display_order_mapping[
|
||||
context_llm_doc.document_id
|
||||
]
|
||||
else:
|
||||
displayed_citation_num = real_citation_num
|
||||
displayed_citation_num = final_citation_num
|
||||
logger.warning(
|
||||
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
|
||||
)
|
||||
|
||||
# Skip consecutive citations of the same work
|
||||
if target_citation_num in self.current_citations:
|
||||
if final_citation_num in self.current_citations:
|
||||
start, end = citation.span()
|
||||
real_start = length_to_add + start
|
||||
diff = end - start
|
||||
@@ -134,8 +135,8 @@ class CitationProcessor:
|
||||
doc_id = int(match.group(1))
|
||||
context_llm_doc = self.context_docs[doc_id - 1]
|
||||
yield CitationInfo(
|
||||
# stay with the original for now (order of LLM cites)
|
||||
citation_num=target_citation_num,
|
||||
# citation_num is now the number post initial ranking, i.e. as displayed to user
|
||||
citation_num=displayed_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -151,13 +152,13 @@ class CitationProcessor:
|
||||
link = context_llm_doc.link
|
||||
|
||||
self.past_cite_count = len(self.llm_out)
|
||||
self.current_citations.append(target_citation_num)
|
||||
self.current_citations.append(final_citation_num)
|
||||
|
||||
if target_citation_num not in self.cited_inds:
|
||||
self.cited_inds.add(target_citation_num)
|
||||
if citation_order_idx not in self.cited_inds:
|
||||
self.cited_inds.add(citation_order_idx)
|
||||
yield CitationInfo(
|
||||
# stay with the original for now (order of LLM cites)
|
||||
citation_num=target_citation_num,
|
||||
# citation number is now the one that was displayed to user
|
||||
citation_num=displayed_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
|
||||
@@ -167,7 +168,6 @@ class CitationProcessor:
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
|
||||
# + f"[[{target_citation_num}]]({link})"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
@@ -176,7 +176,6 @@ class CitationProcessor:
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
|
||||
# + f"[[{target_citation_num}]]()"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
|
||||
@@ -17,6 +17,7 @@ APP_PORT = 8080
|
||||
# prefix from requests directed towards the API server. In these cases, set this to `/api`
|
||||
APP_API_PREFIX = os.environ.get("API_PREFIX", "")
|
||||
|
||||
SKIP_WARM_UP = os.environ.get("SKIP_WARM_UP", "").lower() == "true"
|
||||
|
||||
#####
|
||||
# User Facing Features Configs
|
||||
@@ -54,10 +55,17 @@ MASK_CREDENTIAL_PREFIX = (
|
||||
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
|
||||
)
|
||||
|
||||
REDIS_AUTH_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
||||
SESSION_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
||||
# Default request timeout, mostly used by connectors
|
||||
REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS") or 60)
|
||||
|
||||
# set `VALID_EMAIL_DOMAINS` to a comma seperated list of domains in order to
|
||||
# restrict access to Onyx to only users with emails from those domains.
|
||||
# E.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will restrict Onyx
|
||||
@@ -92,6 +100,7 @@ SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
|
||||
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
|
||||
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
|
||||
EMAIL_CONFIGURED = all([SMTP_SERVER, SMTP_USER, SMTP_PASS])
|
||||
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
|
||||
|
||||
# If set, Onyx will listen to the `expires_at` returned by the identity
|
||||
@@ -145,7 +154,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
|
||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
AWS_REGION = os.environ.get("AWS_REGION") or "us-east-2"
|
||||
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
|
||||
|
||||
POSTGRES_API_SERVER_POOL_SIZE = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
|
||||
@@ -184,6 +193,27 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
|
||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
||||
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
||||
|
||||
|
||||
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"
|
||||
|
||||
# Rate limiting for auth endpoints
|
||||
RATE_LIMIT_WINDOW_SECONDS: int | None = None
|
||||
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
|
||||
if _rate_limit_window_seconds_str is not None:
|
||||
try:
|
||||
RATE_LIMIT_WINDOW_SECONDS = int(_rate_limit_window_seconds_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
RATE_LIMIT_MAX_REQUESTS: int | None = None
|
||||
_rate_limit_max_requests_str = os.environ.get("RATE_LIMIT_MAX_REQUESTS")
|
||||
if _rate_limit_max_requests_str is not None:
|
||||
try:
|
||||
RATE_LIMIT_MAX_REQUESTS = int(_rate_limit_max_requests_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
AUTH_RATE_LIMITING_ENABLED = RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS
|
||||
# Used for general redis things
|
||||
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
|
||||
|
||||
@@ -251,6 +281,11 @@ try:
|
||||
except ValueError:
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY = CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT
|
||||
|
||||
# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
|
||||
VESPA_SYNC_MAX_TASKS = 1024
|
||||
|
||||
DB_YIELD_PER_DEFAULT = 64
|
||||
|
||||
#####
|
||||
# Connector Configs
|
||||
#####
|
||||
@@ -347,12 +382,17 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
|
||||
os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Typically set to http://localhost:3000 for OAuth connector development
|
||||
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")
|
||||
|
||||
# Egnyte specific configs
|
||||
EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
|
||||
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
|
||||
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
|
||||
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
|
||||
|
||||
# Linear specific configs
|
||||
LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
|
||||
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
|
||||
|
||||
DASK_JOB_CLIENT_ENABLED = (
|
||||
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
@@ -503,6 +543,9 @@ try:
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# LLM Model Update API endpoint
|
||||
LLM_MODEL_UPDATE_API_URL = os.environ.get("LLM_MODEL_UPDATE_API_URL")
|
||||
|
||||
#####
|
||||
# Enterprise Edition Configs
|
||||
#####
|
||||
@@ -542,7 +585,6 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
|
||||
# JWT configuration
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
|
||||
#####
|
||||
# API Key Configs
|
||||
#####
|
||||
|
||||
@@ -36,6 +36,8 @@ DISABLED_GEN_AI_MSG = (
|
||||
|
||||
DEFAULT_PERSONA_ID = 0
|
||||
|
||||
DEFAULT_CC_PAIR_ID = 1
|
||||
|
||||
# Postgres connection constants for application_name
|
||||
POSTGRES_WEB_APP_NAME = "web"
|
||||
POSTGRES_INDEXER_APP_NAME = "indexer"
|
||||
@@ -74,13 +76,19 @@ KV_ENTERPRISE_SETTINGS_KEY = "onyx_enterprise_settings"
|
||||
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
|
||||
KV_DOCUMENTS_SEEDED_KEY = "documents_seeded"
|
||||
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
|
||||
# NOTE: we use this timeout / 4 in various places to refresh a lock
|
||||
# might be worth separating this timeout into separate timeouts for each situation
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
|
||||
|
||||
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
|
||||
|
||||
# needs to be long enough to cover the maximum time it takes to download an object
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 60 min
|
||||
|
||||
# how long a task should wait for associated fence to be ready
|
||||
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
|
||||
|
||||
# needs to be long enough to cover the maximum time it takes to download an object
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
|
||||
@@ -134,9 +142,11 @@ class DocumentSource(str, Enum):
|
||||
OCI_STORAGE = "oci_storage"
|
||||
XENFORO = "xenforo"
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
DISCORD = "discord"
|
||||
FRESHDESK = "freshdesk"
|
||||
FIREFLIES = "fireflies"
|
||||
EGNYTE = "egnyte"
|
||||
AIRTABLE = "airtable"
|
||||
|
||||
|
||||
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
|
||||
@@ -240,6 +250,7 @@ class OnyxCeleryQueues:
|
||||
VESPA_METADATA_SYNC = "vespa_metadata_sync"
|
||||
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
|
||||
CONNECTOR_DELETION = "connector_deletion"
|
||||
LLM_MODEL_UPDATE = "llm_model_update"
|
||||
|
||||
# Heavy queue
|
||||
CONNECTOR_PRUNING = "connector_pruning"
|
||||
@@ -273,6 +284,7 @@ class OnyxRedisLocks:
|
||||
|
||||
SLACK_BOT_LOCK = "da_lock:slack_bot"
|
||||
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
|
||||
ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
@@ -294,6 +306,7 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_PRUNING = "check_for_pruning"
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
|
||||
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
|
||||
289
backend/onyx/connectors/airtable/airtable_connector.py
Normal file
289
backend/onyx/connectors/airtable/airtable_connector.py
Normal file
@@ -0,0 +1,289 @@
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from pyairtable import Api as AirtableApi
|
||||
from pyairtable.api.types import RecordDict
|
||||
from pyairtable.models.schema import TableSchema
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# NOTE: all are made lowercase to avoid case sensitivity issues
|
||||
# these are the field types that are considered metadata rather
|
||||
# than sections
|
||||
_METADATA_FIELD_TYPES = {
|
||||
"singlecollaborator",
|
||||
"collaborator",
|
||||
"createdby",
|
||||
"singleselect",
|
||||
"multipleselects",
|
||||
"checkbox",
|
||||
"date",
|
||||
"datetime",
|
||||
"email",
|
||||
"phone",
|
||||
"url",
|
||||
"number",
|
||||
"currency",
|
||||
"duration",
|
||||
"percent",
|
||||
"rating",
|
||||
"createdtime",
|
||||
"lastmodifiedtime",
|
||||
"autonumber",
|
||||
"rollup",
|
||||
"lookup",
|
||||
"count",
|
||||
"formula",
|
||||
"date",
|
||||
}
|
||||
|
||||
|
||||
class AirtableClientNotSetUpError(PermissionError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Airtable Client is not set up, was load_credentials called?")
|
||||
|
||||
|
||||
class AirtableConnector(LoadConnector):
|
||||
def __init__(
|
||||
self,
|
||||
base_id: str,
|
||||
table_name_or_id: str,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.base_id = base_id
|
||||
self.table_name_or_id = table_name_or_id
|
||||
self.batch_size = batch_size
|
||||
self.airtable_client: AirtableApi | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.airtable_client = AirtableApi(credentials["airtable_access_token"])
|
||||
return None
|
||||
|
||||
def _get_field_value(self, field_info: Any, field_type: str) -> list[str]:
|
||||
"""
|
||||
Extract value(s) from a field regardless of its type.
|
||||
Returns either a single string or list of strings for attachments.
|
||||
"""
|
||||
if field_info is None:
|
||||
return []
|
||||
|
||||
# skip references to other records for now (would need to do another
|
||||
# request to get the actual record name/type)
|
||||
# TODO: support this
|
||||
if field_type == "multipleRecordLinks":
|
||||
return []
|
||||
|
||||
if field_type == "multipleAttachments":
|
||||
attachment_texts: list[str] = []
|
||||
for attachment in field_info:
|
||||
url = attachment.get("url")
|
||||
filename = attachment.get("filename", "")
|
||||
if not url:
|
||||
continue
|
||||
|
||||
@retry(
|
||||
tries=5,
|
||||
delay=1,
|
||||
backoff=2,
|
||||
max_delay=10,
|
||||
)
|
||||
def get_attachment_with_retry(url: str) -> bytes | None:
|
||||
attachment_response = requests.get(url)
|
||||
if attachment_response.status_code == 200:
|
||||
return attachment_response.content
|
||||
return None
|
||||
|
||||
attachment_content = get_attachment_with_retry(url)
|
||||
if attachment_content:
|
||||
try:
|
||||
file_ext = get_file_ext(filename)
|
||||
attachment_text = extract_file_text(
|
||||
BytesIO(attachment_content),
|
||||
filename,
|
||||
break_on_unprocessable=False,
|
||||
extension=file_ext,
|
||||
)
|
||||
if attachment_text:
|
||||
attachment_texts.append(f"{filename}:\n{attachment_text}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to process attachment {filename}: {str(e)}"
|
||||
)
|
||||
return attachment_texts
|
||||
|
||||
if field_type in ["singleCollaborator", "collaborator", "createdBy"]:
|
||||
combined = []
|
||||
collab_name = field_info.get("name")
|
||||
collab_email = field_info.get("email")
|
||||
if collab_name:
|
||||
combined.append(collab_name)
|
||||
if collab_email:
|
||||
combined.append(f"({collab_email})")
|
||||
return [" ".join(combined) if combined else str(field_info)]
|
||||
|
||||
if isinstance(field_info, list):
|
||||
return [str(item) for item in field_info]
|
||||
|
||||
return [str(field_info)]
|
||||
|
||||
def _should_be_metadata(self, field_type: str) -> bool:
|
||||
"""Determine if a field type should be treated as metadata."""
|
||||
return field_type.lower() in _METADATA_FIELD_TYPES
|
||||
|
||||
def _process_field(
|
||||
self,
|
||||
field_name: str,
|
||||
field_info: Any,
|
||||
field_type: str,
|
||||
table_id: str,
|
||||
record_id: str,
|
||||
) -> tuple[list[Section], dict[str, Any]]:
|
||||
"""
|
||||
Process a single Airtable field and return sections or metadata.
|
||||
|
||||
Args:
|
||||
field_name: Name of the field
|
||||
field_info: Raw field information from Airtable
|
||||
field_type: Airtable field type
|
||||
|
||||
Returns:
|
||||
(list of Sections, dict of metadata)
|
||||
"""
|
||||
if field_info is None:
|
||||
return [], {}
|
||||
|
||||
# Get the value(s) for the field
|
||||
field_values = self._get_field_value(field_info, field_type)
|
||||
if len(field_values) == 0:
|
||||
return [], {}
|
||||
|
||||
# Determine if it should be metadata or a section
|
||||
if self._should_be_metadata(field_type):
|
||||
if len(field_values) > 1:
|
||||
return [], {field_name: field_values}
|
||||
return [], {field_name: field_values[0]}
|
||||
|
||||
# Otherwise, create relevant sections
|
||||
sections = [
|
||||
Section(
|
||||
link=f"https://airtable.com/{self.base_id}/{table_id}/{record_id}",
|
||||
text=(
|
||||
f"{field_name}:\n"
|
||||
"------------------------\n"
|
||||
f"{text}\n"
|
||||
"------------------------"
|
||||
),
|
||||
)
|
||||
for text in field_values
|
||||
]
|
||||
return sections, {}
|
||||
|
||||
def _process_record(
|
||||
self,
|
||||
record: RecordDict,
|
||||
table_schema: TableSchema,
|
||||
primary_field_name: str | None,
|
||||
) -> Document:
|
||||
"""Process a single Airtable record into a Document.
|
||||
|
||||
Args:
|
||||
record: The Airtable record to process
|
||||
table_schema: Schema information for the table
|
||||
table_name: Name of the table
|
||||
table_id: ID of the table
|
||||
primary_field_name: Name of the primary field, if any
|
||||
|
||||
Returns:
|
||||
Document object representing the record
|
||||
"""
|
||||
table_id = table_schema.id
|
||||
table_name = table_schema.name
|
||||
record_id = record["id"]
|
||||
fields = record["fields"]
|
||||
sections: list[Section] = []
|
||||
metadata: dict[str, Any] = {}
|
||||
|
||||
# Get primary field value if it exists
|
||||
primary_field_value = (
|
||||
fields.get(primary_field_name) if primary_field_name else None
|
||||
)
|
||||
|
||||
for field_schema in table_schema.fields:
|
||||
field_name = field_schema.name
|
||||
field_val = fields.get(field_name)
|
||||
field_type = field_schema.type
|
||||
|
||||
field_sections, field_metadata = self._process_field(
|
||||
field_name=field_name,
|
||||
field_info=field_val,
|
||||
field_type=field_type,
|
||||
table_id=table_id,
|
||||
record_id=record_id,
|
||||
)
|
||||
|
||||
sections.extend(field_sections)
|
||||
metadata.update(field_metadata)
|
||||
|
||||
semantic_id = (
|
||||
f"{table_name}: {primary_field_value}"
|
||||
if primary_field_value
|
||||
else table_name
|
||||
)
|
||||
|
||||
return Document(
|
||||
id=f"airtable__{record_id}",
|
||||
sections=sections,
|
||||
source=DocumentSource.AIRTABLE,
|
||||
semantic_identifier=semantic_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Fetch all records from the table.
|
||||
|
||||
NOTE: Airtable does not support filtering by time updated, so
|
||||
we have to fetch all records every time.
|
||||
"""
|
||||
if not self.airtable_client:
|
||||
raise AirtableClientNotSetUpError()
|
||||
|
||||
table = self.airtable_client.table(self.base_id, self.table_name_or_id)
|
||||
records = table.all()
|
||||
|
||||
table_schema = table.schema()
|
||||
primary_field_name = None
|
||||
|
||||
# Find a primary field from the schema
|
||||
for field in table_schema.fields:
|
||||
if field.id == table_schema.primary_field_id:
|
||||
primary_field_name = field.name
|
||||
break
|
||||
|
||||
record_documents: list[Document] = []
|
||||
for record in records:
|
||||
document = self._process_record(
|
||||
record=record,
|
||||
table_schema=table_schema,
|
||||
primary_field_name=primary_field_name,
|
||||
)
|
||||
record_documents.append(document)
|
||||
|
||||
if len(record_documents) >= self.batch_size:
|
||||
yield record_documents
|
||||
record_documents = []
|
||||
|
||||
if record_documents:
|
||||
yield record_documents
|
||||
@@ -52,10 +52,29 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
|
||||
"space",
|
||||
"restrictions.read.restrictions.user",
|
||||
"restrictions.read.restrictions.group",
|
||||
"ancestors.restrictions.read.restrictions.user",
|
||||
"ancestors.restrictions.read.restrictions.group",
|
||||
]
|
||||
|
||||
_SLIM_DOC_BATCH_SIZE = 5000
|
||||
|
||||
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
|
||||
"png",
|
||||
"jpg",
|
||||
"jpeg",
|
||||
"gif",
|
||||
"mp4",
|
||||
"mov",
|
||||
"mp3",
|
||||
"wav",
|
||||
]
|
||||
_FULL_EXTENSION_FILTER_STRING = "".join(
|
||||
[
|
||||
f" and title!~'*.{extension}'"
|
||||
for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
@@ -64,7 +83,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
is_cloud: bool,
|
||||
space: str = "",
|
||||
page_id: str = "",
|
||||
index_recursively: bool = True,
|
||||
index_recursively: bool = False,
|
||||
cql_query: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
@@ -82,23 +101,25 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# Remove trailing slash from wiki_base if present
|
||||
self.wiki_base = wiki_base.rstrip("/")
|
||||
|
||||
# if nothing is provided, we will fetch all pages
|
||||
cql_page_query = "type=page"
|
||||
"""
|
||||
If nothing is provided, we default to fetching all pages
|
||||
Only one or none of the following options should be specified so
|
||||
the order shouldn't matter
|
||||
However, we use elif to ensure that only of the following is enforced
|
||||
"""
|
||||
base_cql_page_query = "type=page"
|
||||
if cql_query:
|
||||
# if a cql_query is provided, we will use it to fetch the pages
|
||||
cql_page_query = cql_query
|
||||
base_cql_page_query = cql_query
|
||||
elif page_id:
|
||||
# if a cql_query is not provided, we will use the page_id to fetch the page
|
||||
if index_recursively:
|
||||
cql_page_query += f" and ancestor='{page_id}'"
|
||||
base_cql_page_query += f" and (ancestor='{page_id}' or id='{page_id}')"
|
||||
else:
|
||||
cql_page_query += f" and id='{page_id}'"
|
||||
base_cql_page_query += f" and id='{page_id}'"
|
||||
elif space:
|
||||
# if no cql_query or page_id is provided, we will use the space to fetch the pages
|
||||
cql_page_query += f" and space='{quote(space)}'"
|
||||
uri_safe_space = quote(space)
|
||||
base_cql_page_query += f" and space='{uri_safe_space}'"
|
||||
|
||||
self.cql_page_query = cql_page_query
|
||||
self.cql_time_filter = ""
|
||||
self.base_cql_page_query = base_cql_page_query
|
||||
|
||||
self.cql_label_filter = ""
|
||||
if labels_to_skip:
|
||||
@@ -126,6 +147,33 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
return None
|
||||
|
||||
def _construct_page_query(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> str:
|
||||
page_query = self.base_cql_page_query + self.cql_label_filter
|
||||
|
||||
# Add time filters
|
||||
if start:
|
||||
formatted_start_time = datetime.fromtimestamp(
|
||||
start, tz=self.timezone
|
||||
).strftime("%Y-%m-%d %H:%M")
|
||||
page_query += f" and lastmodified >= '{formatted_start_time}'"
|
||||
if end:
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
page_query += f" and lastmodified <= '{formatted_end_time}'"
|
||||
|
||||
return page_query
|
||||
|
||||
def _construct_attachment_query(self, confluence_page_id: str) -> str:
|
||||
attachment_query = f"type=attachment and container='{confluence_page_id}'"
|
||||
attachment_query += self.cql_label_filter
|
||||
attachment_query += _FULL_EXTENSION_FILTER_STRING
|
||||
return attachment_query
|
||||
|
||||
def _get_comment_string_for_page_id(self, page_id: str) -> str:
|
||||
comment_string = ""
|
||||
|
||||
@@ -205,11 +253,15 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
metadata=doc_metadata,
|
||||
)
|
||||
|
||||
def _fetch_document_batches(self) -> GenerateDocumentsOutput:
|
||||
def _fetch_document_batches(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch: list[Document] = []
|
||||
confluence_page_ids: list[str] = []
|
||||
|
||||
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
|
||||
page_query = self._construct_page_query(start, end)
|
||||
logger.debug(f"page_query: {page_query}")
|
||||
# Fetch pages as Documents
|
||||
for page in self.confluence_client.paginated_cql_retrieval(
|
||||
@@ -228,11 +280,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
# Fetch attachments as Documents
|
||||
for confluence_page_id in confluence_page_ids:
|
||||
attachment_cql = f"type=attachment and container='{confluence_page_id}'"
|
||||
attachment_cql += self.cql_label_filter
|
||||
attachment_query = self._construct_attachment_query(confluence_page_id)
|
||||
# TODO: maybe should add time filter as well?
|
||||
for attachment in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=attachment_cql,
|
||||
cql=attachment_query,
|
||||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
doc = self._convert_object_to_document(attachment)
|
||||
@@ -248,17 +299,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_document_batches()
|
||||
|
||||
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
|
||||
# Add time filters
|
||||
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
|
||||
self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
|
||||
return self._fetch_document_batches()
|
||||
def poll_source(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
return self._fetch_document_batches(start, end)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
@@ -269,7 +315,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
|
||||
|
||||
page_query = self.cql_page_query + self.cql_label_filter
|
||||
page_query = self.base_cql_page_query + self.cql_label_filter
|
||||
for page in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=page_query,
|
||||
expand=restrictions_expand,
|
||||
@@ -279,9 +325,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# These will be used by doc_sync.py to sync permissions
|
||||
page_restrictions = page.get("restrictions")
|
||||
page_space_key = page.get("space", {}).get("key")
|
||||
page_ancestors = page.get("ancestors", [])
|
||||
page_perm_sync_data = {
|
||||
"restrictions": page_restrictions or {},
|
||||
"space_key": page_space_key,
|
||||
"ancestors": page_ancestors or [],
|
||||
}
|
||||
|
||||
doc_metadata_list.append(
|
||||
@@ -294,10 +342,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
perm_sync_data=page_perm_sync_data,
|
||||
)
|
||||
)
|
||||
attachment_cql = f"type=attachment and container='{page['id']}'"
|
||||
attachment_cql += self.cql_label_filter
|
||||
attachment_query = self._construct_attachment_query(page["id"])
|
||||
for attachment in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_cql,
|
||||
cql=attachment_query,
|
||||
expand=restrictions_expand,
|
||||
limit=_SLIM_DOC_BATCH_SIZE,
|
||||
):
|
||||
|
||||
@@ -121,6 +121,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
|
||||
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
_MINIMUM_PAGINATION_LIMIT = 50
|
||||
|
||||
|
||||
class OnyxConfluence(Confluence):
|
||||
@@ -186,27 +187,67 @@ class OnyxConfluence(Confluence):
|
||||
url_suffix += f"{connection_char}limit={limit}"
|
||||
|
||||
while url_suffix:
|
||||
logger.debug(f"Making confluence call to {url_suffix}")
|
||||
try:
|
||||
logger.debug(f"Making confluence call to {url_suffix}")
|
||||
next_response = self.get(url_suffix)
|
||||
raw_response = self.get(
|
||||
path=url_suffix,
|
||||
advanced_mode=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in confluence call to {url_suffix}")
|
||||
raise e
|
||||
|
||||
try:
|
||||
raw_response.raise_for_status()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in confluence call to {url_suffix}")
|
||||
|
||||
# If the problematic expansion is in the url, replace it
|
||||
# with the replacement expansion and try again
|
||||
# If that fails, raise the error
|
||||
if _PROBLEMATIC_EXPANSIONS not in url_suffix:
|
||||
logger.exception(f"Error in confluence call to {url_suffix}")
|
||||
raise e
|
||||
logger.warning(
|
||||
f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
|
||||
" and trying again."
|
||||
if _PROBLEMATIC_EXPANSIONS in url_suffix:
|
||||
logger.warning(
|
||||
f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
|
||||
" and trying again."
|
||||
)
|
||||
url_suffix = url_suffix.replace(
|
||||
_PROBLEMATIC_EXPANSIONS,
|
||||
_REPLACEMENT_EXPANSIONS,
|
||||
)
|
||||
continue
|
||||
if (
|
||||
raw_response.status_code == 500
|
||||
and limit > _MINIMUM_PAGINATION_LIMIT
|
||||
):
|
||||
new_limit = limit // 2
|
||||
logger.warning(
|
||||
f"Error in confluence call to {url_suffix} \n"
|
||||
f"Raw Response Text: {raw_response.text} \n"
|
||||
f"Full Response: {raw_response.__dict__} \n"
|
||||
f"Error: {e} \n"
|
||||
f"Reducing limit from {limit} to {new_limit} and trying again."
|
||||
)
|
||||
url_suffix = url_suffix.replace(
|
||||
f"limit={limit}", f"limit={new_limit}"
|
||||
)
|
||||
limit = new_limit
|
||||
continue
|
||||
|
||||
logger.exception(
|
||||
f"Error in confluence call to {url_suffix} \n"
|
||||
f"Raw Response Text: {raw_response.text} \n"
|
||||
f"Full Response: {raw_response.__dict__} \n"
|
||||
f"Error: {e} \n"
|
||||
)
|
||||
url_suffix = url_suffix.replace(
|
||||
_PROBLEMATIC_EXPANSIONS,
|
||||
_REPLACEMENT_EXPANSIONS,
|
||||
raise e
|
||||
|
||||
try:
|
||||
next_response = raw_response.json()
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to parse response as JSON. Response: {raw_response.__dict__}"
|
||||
)
|
||||
continue
|
||||
raise e
|
||||
|
||||
# yield the results individually
|
||||
yield from next_response.get("results", [])
|
||||
@@ -313,6 +354,33 @@ class OnyxConfluence(Confluence):
|
||||
group_name = quote(group_name)
|
||||
yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit)
|
||||
|
||||
def get_all_space_permissions_server(
|
||||
self,
|
||||
space_key: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
This is a confluence server specific method that can be used to
|
||||
fetch the permissions of a space.
|
||||
This is better logging than calling the get_space_permissions method
|
||||
because it returns a jsonrpc response.
|
||||
"""
|
||||
url = "rpc/json-rpc/confluenceservice-v2"
|
||||
data = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "getSpacePermissionSets",
|
||||
"id": 7,
|
||||
"params": [space_key],
|
||||
}
|
||||
response = self.post(url, data=data)
|
||||
logger.debug(f"jsonrpc response: {response}")
|
||||
if not response.get("result"):
|
||||
logger.warning(
|
||||
f"No jsonrpc response for space permissions for space {space_key}"
|
||||
f"\nResponse: {response}"
|
||||
)
|
||||
|
||||
return response.get("result", [])
|
||||
|
||||
|
||||
def _validate_connector_configuration(
|
||||
credentials: dict[str, Any],
|
||||
|
||||
@@ -32,11 +32,13 @@ def get_user_email_from_username__server(
|
||||
response = confluence_client.get_mobile_parameters(user_name)
|
||||
email = response.get("email")
|
||||
except Exception:
|
||||
# For now, we'll just return a string that indicates failure
|
||||
# We may want to revert to returning None in the future
|
||||
# email = None
|
||||
email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
|
||||
logger.warning(f"failed to get confluence email for {user_name}")
|
||||
# For now, we'll just return None and log a warning. This means
|
||||
# we will keep retrying to get the email every group sync.
|
||||
email = None
|
||||
# We may want to just return a string that indicates failure so we dont
|
||||
# keep retrying
|
||||
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
|
||||
_USER_EMAIL_CACHE[user_name] = email
|
||||
return _USER_EMAIL_CACHE[user_name]
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import TypeVar
|
||||
|
||||
from dateutil.parser import parse
|
||||
|
||||
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
|
||||
from onyx.configs.constants import IGNORE_FOR_QA
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.utils.text_processing import is_valid_email
|
||||
@@ -71,3 +72,10 @@ def process_in_batches(
|
||||
|
||||
def get_metadata_keys_to_ignore() -> list[str]:
|
||||
return [IGNORE_FOR_QA]
|
||||
|
||||
|
||||
def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
|
||||
if CONNECTOR_LOCALHOST_OVERRIDE:
|
||||
# Used for development
|
||||
base_domain = CONNECTOR_LOCALHOST_OVERRIDE
|
||||
return f"{base_domain.strip('/')}/connector/oauth/callback/{connector_id}"
|
||||
|
||||
321
backend/onyx/connectors/discord/connector.py
Normal file
321
backend/onyx/connectors/discord/connector.py
Normal file
@@ -0,0 +1,321 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from discord import Client
|
||||
from discord.channel import TextChannel
|
||||
from discord.channel import Thread
|
||||
from discord.enums import MessageType
|
||||
from discord.flags import Intents
|
||||
from discord.message import Message as DiscordMessage
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_DISCORD_DOC_ID_PREFIX = "DISCORD_"
|
||||
_SNIPPET_LENGTH = 30
|
||||
|
||||
|
||||
def _convert_message_to_document(
|
||||
message: DiscordMessage,
|
||||
sections: list[Section],
|
||||
) -> Document:
|
||||
"""
|
||||
Convert a discord message to a document
|
||||
Sections are collected before calling this function because it relies on async
|
||||
calls to fetch the thread history if there is one
|
||||
"""
|
||||
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
semantic_substring = ""
|
||||
|
||||
# Only messages from TextChannels will make it here but we have to check for it anyways
|
||||
if isinstance(message.channel, TextChannel) and (
|
||||
channel_name := message.channel.name
|
||||
):
|
||||
metadata["Channel"] = channel_name
|
||||
semantic_substring += f" in Channel: #{channel_name}"
|
||||
|
||||
# Single messages dont have a title
|
||||
title = ""
|
||||
|
||||
# If there is a thread, add more detail to the metadata, title, and semantic identifier
|
||||
if isinstance(message.channel, Thread):
|
||||
# Threads do have a title
|
||||
title = message.channel.name
|
||||
|
||||
# If its a thread, update the metadata, title, and semantic_substring
|
||||
metadata["Thread"] = title
|
||||
|
||||
# Add more detail to the semantic identifier if available
|
||||
semantic_substring += f" in Thread: {title}"
|
||||
|
||||
snippet: str = (
|
||||
message.content[:_SNIPPET_LENGTH].rstrip() + "..."
|
||||
if len(message.content) > _SNIPPET_LENGTH
|
||||
else message.content
|
||||
)
|
||||
|
||||
semantic_identifier = f"{message.author.name} said{semantic_substring}: {snippet}"
|
||||
|
||||
return Document(
|
||||
id=f"{_DISCORD_DOC_ID_PREFIX}{message.id}",
|
||||
source=DocumentSource.DISCORD,
|
||||
semantic_identifier=semantic_identifier,
|
||||
doc_updated_at=message.edited_at,
|
||||
title=title,
|
||||
sections=sections,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
async def _fetch_filtered_channels(
|
||||
discord_client: Client,
|
||||
server_ids: list[int] | None,
|
||||
channel_names: list[str] | None,
|
||||
) -> list[TextChannel]:
|
||||
filtered_channels: list[TextChannel] = []
|
||||
|
||||
for channel in discord_client.get_all_channels():
|
||||
if not channel.permissions_for(channel.guild.me).read_message_history:
|
||||
continue
|
||||
if not isinstance(channel, TextChannel):
|
||||
continue
|
||||
if server_ids and len(server_ids) > 0 and channel.guild.id not in server_ids:
|
||||
continue
|
||||
if channel_names and channel.name not in channel_names:
|
||||
continue
|
||||
filtered_channels.append(channel)
|
||||
|
||||
logger.info(f"Found {len(filtered_channels)} channels for the authenticated user")
|
||||
return filtered_channels
|
||||
|
||||
|
||||
async def _fetch_documents_from_channel(
|
||||
channel: TextChannel,
|
||||
start_time: datetime | None,
|
||||
end_time: datetime | None,
|
||||
) -> AsyncIterable[Document]:
|
||||
# Discord's epoch starts at 2015-01-01
|
||||
discord_epoch = datetime(2015, 1, 1, tzinfo=timezone.utc)
|
||||
if start_time and start_time < discord_epoch:
|
||||
start_time = discord_epoch
|
||||
|
||||
async for channel_message in channel.history(
|
||||
after=start_time,
|
||||
before=end_time,
|
||||
):
|
||||
# Skip messages that are not the default type
|
||||
if channel_message.type != MessageType.default:
|
||||
continue
|
||||
|
||||
sections: list[Section] = [
|
||||
Section(
|
||||
text=channel_message.content,
|
||||
link=channel_message.jump_url,
|
||||
)
|
||||
]
|
||||
|
||||
yield _convert_message_to_document(channel_message, sections)
|
||||
|
||||
for active_thread in channel.threads:
|
||||
async for thread_message in active_thread.history(
|
||||
after=start_time,
|
||||
before=end_time,
|
||||
):
|
||||
# Skip messages that are not the default type
|
||||
if thread_message.type != MessageType.default:
|
||||
continue
|
||||
|
||||
sections = [
|
||||
Section(
|
||||
text=thread_message.content,
|
||||
link=thread_message.jump_url,
|
||||
)
|
||||
]
|
||||
|
||||
yield _convert_message_to_document(thread_message, sections)
|
||||
|
||||
async for archived_thread in channel.archived_threads():
|
||||
async for thread_message in archived_thread.history(
|
||||
after=start_time,
|
||||
before=end_time,
|
||||
):
|
||||
# Skip messages that are not the default type
|
||||
if thread_message.type != MessageType.default:
|
||||
continue
|
||||
|
||||
sections = [
|
||||
Section(
|
||||
text=thread_message.content,
|
||||
link=thread_message.jump_url,
|
||||
)
|
||||
]
|
||||
|
||||
yield _convert_message_to_document(thread_message, sections)
|
||||
|
||||
|
||||
def _manage_async_retrieval(
|
||||
token: str,
|
||||
requested_start_date_string: str,
|
||||
channel_names: list[str],
|
||||
server_ids: list[int],
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> Iterable[Document]:
|
||||
# parse requested_start_date_string to datetime
|
||||
pull_date: datetime | None = (
|
||||
datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(
|
||||
tzinfo=timezone.utc
|
||||
)
|
||||
if requested_start_date_string
|
||||
else None
|
||||
)
|
||||
|
||||
# Set start_time to the later of start and pull_date, or whichever is provided
|
||||
start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
|
||||
|
||||
end_time: datetime | None = end
|
||||
|
||||
async def _async_fetch() -> AsyncIterable[Document]:
|
||||
intents = Intents.default()
|
||||
intents.message_content = True
|
||||
async with Client(intents=intents) as discord_client:
|
||||
asyncio.create_task(discord_client.start(token))
|
||||
await discord_client.wait_until_ready()
|
||||
|
||||
filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
|
||||
discord_client=discord_client,
|
||||
server_ids=server_ids,
|
||||
channel_names=channel_names,
|
||||
)
|
||||
|
||||
for channel in filtered_channels:
|
||||
async for doc in _fetch_documents_from_channel(
|
||||
channel=channel,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
):
|
||||
yield doc
|
||||
|
||||
def run_and_yield() -> Iterable[Document]:
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
# Get the async generator
|
||||
async_gen = _async_fetch()
|
||||
# Convert to AsyncIterator
|
||||
async_iter = async_gen.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
# Create a coroutine by calling anext with the async iterator
|
||||
next_coro = anext(async_iter)
|
||||
# Run the coroutine to get the next document
|
||||
doc = loop.run_until_complete(next_coro)
|
||||
yield doc
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
return run_and_yield()
|
||||
|
||||
|
||||
class DiscordConnector(PollConnector, LoadConnector):
|
||||
def __init__(
|
||||
self,
|
||||
server_ids: list[str] = [],
|
||||
channel_names: list[str] = [],
|
||||
# YYYY-MM-DD
|
||||
start_date: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
):
|
||||
self.batch_size = batch_size
|
||||
self.channel_names: list[str] = channel_names if channel_names else []
|
||||
self.server_ids: list[int] = (
|
||||
[int(server_id) for server_id in server_ids] if server_ids else []
|
||||
)
|
||||
self._discord_bot_token: str | None = None
|
||||
self.requested_start_date_string: str = start_date or ""
|
||||
|
||||
@property
|
||||
def discord_bot_token(self) -> str:
|
||||
if self._discord_bot_token is None:
|
||||
raise ConnectorMissingCredentialError("Discord")
|
||||
return self._discord_bot_token
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._discord_bot_token = credentials["discord_bot_token"]
|
||||
return None
|
||||
|
||||
def _manage_doc_batching(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch = []
|
||||
for doc in _manage_async_retrieval(
|
||||
token=self.discord_bot_token,
|
||||
requested_start_date_string=self.requested_start_date_string,
|
||||
channel_names=self.channel_names,
|
||||
server_ids=self.server_ids,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
return self._manage_doc_batching(
|
||||
datetime.fromtimestamp(start, tz=timezone.utc),
|
||||
datetime.fromtimestamp(end, tz=timezone.utc),
|
||||
)
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._manage_doc_batching(None, None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
import time
|
||||
|
||||
end = time.time()
|
||||
# 1 day
|
||||
start = end - 24 * 60 * 60 * 1
|
||||
# "1,2,3"
|
||||
server_ids: str | None = os.environ.get("server_ids", None)
|
||||
# "channel1,channel2"
|
||||
channel_names: str | None = os.environ.get("channel_names", None)
|
||||
|
||||
connector = DiscordConnector(
|
||||
server_ids=server_ids.split(",") if server_ids else [],
|
||||
channel_names=channel_names.split(",") if channel_names else [],
|
||||
start_date=os.environ.get("start_date", None),
|
||||
)
|
||||
connector.load_credentials(
|
||||
{"discord_bot_token": os.environ.get("discord_bot_token")}
|
||||
)
|
||||
|
||||
for doc_batch in connector.poll_source(start, end):
|
||||
for doc in doc_batch:
|
||||
print(doc)
|
||||
@@ -190,7 +190,7 @@ class DiscourseConnector(PollConnector):
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
) -> GenerateDocumentsOutput:
|
||||
page = 1
|
||||
page = 0
|
||||
while topic_ids := self._get_latest_topics(start, end, page):
|
||||
doc_batch: list[Document] = []
|
||||
for topic_id in topic_ids:
|
||||
|
||||
@@ -3,20 +3,19 @@ import os
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from logging import Logger
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import IO
|
||||
from urllib.parse import quote
|
||||
|
||||
import requests
|
||||
from retry import retry
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN
|
||||
from onyx.configs.app_configs import EGNYTE_CLIENT_ID
|
||||
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_oauth_callback_uri,
|
||||
)
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import OAuthConnector
|
||||
@@ -33,53 +32,13 @@ from onyx.file_processing.extract_file_text import is_text_file_extension
|
||||
from onyx.file_processing.extract_file_text import is_valid_file_ext
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import request_with_retries
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
|
||||
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
|
||||
_TIMEOUT = 60
|
||||
|
||||
|
||||
def _request_with_retries(
|
||||
method: str,
|
||||
url: str,
|
||||
data: dict[str, Any] | None = None,
|
||||
headers: dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
timeout: int = _TIMEOUT,
|
||||
stream: bool = False,
|
||||
tries: int = 8,
|
||||
delay: float = 1,
|
||||
backoff: float = 2,
|
||||
) -> requests.Response:
|
||||
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
|
||||
def _make_request() -> requests.Response:
|
||||
response = requests.request(
|
||||
method,
|
||||
url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=timeout,
|
||||
stream=stream,
|
||||
)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code != 403:
|
||||
logger.exception(
|
||||
f"Failed to call Egnyte API.\n"
|
||||
f"URL: {url}\n"
|
||||
f"Headers: {headers}\n"
|
||||
f"Data: {data}\n"
|
||||
f"Params: {params}"
|
||||
)
|
||||
raise e
|
||||
return response
|
||||
|
||||
return _make_request()
|
||||
|
||||
|
||||
def _parse_last_modified(last_modified: str) -> datetime:
|
||||
@@ -166,6 +125,15 @@ def _process_egnyte_file(
|
||||
|
||||
|
||||
class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
class AdditionalOauthKwargs(OAuthConnector.AdditionalOauthKwargs):
|
||||
egnyte_domain: str = Field(
|
||||
title="Egnyte Domain",
|
||||
description=(
|
||||
"The domain for the Egnyte instance "
|
||||
"(e.g. 'company' for company.egnyte.com)"
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
folder_path: str | None = None,
|
||||
@@ -181,18 +149,20 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
return DocumentSource.EGNYTE
|
||||
|
||||
@classmethod
|
||||
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
|
||||
def oauth_authorization_url(
|
||||
cls,
|
||||
base_domain: str,
|
||||
state: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> str:
|
||||
if not EGNYTE_CLIENT_ID:
|
||||
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
|
||||
if not EGNYTE_BASE_DOMAIN:
|
||||
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
|
||||
|
||||
if EGNYTE_LOCALHOST_OVERRIDE:
|
||||
base_domain = EGNYTE_LOCALHOST_OVERRIDE
|
||||
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
|
||||
|
||||
callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte"
|
||||
callback_uri = get_oauth_callback_uri(base_domain, "egnyte")
|
||||
return (
|
||||
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
||||
f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
|
||||
f"?client_id={EGNYTE_CLIENT_ID}"
|
||||
f"&redirect_uri={callback_uri}"
|
||||
f"&scope=Egnyte.filesystem"
|
||||
@@ -201,17 +171,23 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
|
||||
def oauth_code_to_token(
|
||||
cls,
|
||||
base_domain: str,
|
||||
code: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> dict[str, Any]:
|
||||
if not EGNYTE_CLIENT_ID:
|
||||
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
|
||||
if not EGNYTE_CLIENT_SECRET:
|
||||
raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
|
||||
if not EGNYTE_BASE_DOMAIN:
|
||||
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
|
||||
|
||||
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
|
||||
|
||||
# Exchange code for token
|
||||
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
||||
redirect_uri = f"{EGNYTE_LOCALHOST_OVERRIDE or base_domain}/connector/oauth/callback/egnyte"
|
||||
url = f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
|
||||
redirect_uri = get_oauth_callback_uri(base_domain, "egnyte")
|
||||
|
||||
data = {
|
||||
"client_id": EGNYTE_CLIENT_ID,
|
||||
"client_secret": EGNYTE_CLIENT_SECRET,
|
||||
@@ -222,7 +198,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
}
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
|
||||
response = _request_with_retries(
|
||||
response = request_with_retries(
|
||||
method="POST",
|
||||
url=url,
|
||||
data=data,
|
||||
@@ -236,7 +212,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
|
||||
token_data = response.json()
|
||||
return {
|
||||
"domain": EGNYTE_BASE_DOMAIN,
|
||||
"domain": oauth_kwargs.egnyte_domain,
|
||||
"access_token": token_data["access_token"],
|
||||
}
|
||||
|
||||
@@ -248,7 +224,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
def _get_files_list(
|
||||
self,
|
||||
path: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
if not self.access_token or not self.domain:
|
||||
raise ConnectorMissingCredentialError("Egnyte")
|
||||
|
||||
@@ -260,67 +236,66 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
"list_content": True,
|
||||
}
|
||||
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{path or ''}"
|
||||
response = _request_with_retries(
|
||||
method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
|
||||
url_encoded_path = quote(path or "")
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
|
||||
response = request_with_retries(
|
||||
method="GET", url=url, headers=headers, params=params
|
||||
)
|
||||
if not response.ok:
|
||||
raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")
|
||||
|
||||
data = response.json()
|
||||
all_files: list[dict[str, Any]] = []
|
||||
|
||||
# Add files from current directory
|
||||
all_files.extend(data.get("files", []))
|
||||
# Yield files from current directory
|
||||
for file in data.get("files", []):
|
||||
yield file
|
||||
|
||||
# Recursively traverse folders
|
||||
for item in data.get("folders", []):
|
||||
all_files.extend(self._get_files_list(item["path"]))
|
||||
for folder in data.get("folders", []):
|
||||
yield from self._get_files_list(folder["path"])
|
||||
|
||||
return all_files
|
||||
|
||||
def _filter_files(
|
||||
def _should_index_file(
|
||||
self,
|
||||
files: list[dict[str, Any]],
|
||||
file: dict[str, Any],
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
filtered_files = []
|
||||
for file in files:
|
||||
if file["is_folder"]:
|
||||
continue
|
||||
) -> bool:
|
||||
"""Return True if file should be included based on filters."""
|
||||
if file["is_folder"]:
|
||||
return False
|
||||
|
||||
file_modified = _parse_last_modified(file["last_modified"])
|
||||
if start_time and file_modified < start_time:
|
||||
continue
|
||||
if end_time and file_modified > end_time:
|
||||
continue
|
||||
file_modified = _parse_last_modified(file["last_modified"])
|
||||
if start_time and file_modified < start_time:
|
||||
return False
|
||||
if end_time and file_modified > end_time:
|
||||
return False
|
||||
|
||||
filtered_files.append(file)
|
||||
|
||||
return filtered_files
|
||||
return True
|
||||
|
||||
def _process_files(
|
||||
self,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
) -> Generator[list[Document], None, None]:
|
||||
files = self._get_files_list(self.folder_path)
|
||||
files = self._filter_files(files, start_time, end_time)
|
||||
|
||||
current_batch: list[Document] = []
|
||||
for file in files:
|
||||
|
||||
# Iterate through yielded files and filter them
|
||||
for file in self._get_files_list(self.folder_path):
|
||||
if not self._should_index_file(file, start_time, end_time):
|
||||
logger.debug(f"Skipping file '{file['path']}'.")
|
||||
continue
|
||||
|
||||
try:
|
||||
# Set up request with streaming enabled
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.access_token}",
|
||||
}
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}"
|
||||
response = _request_with_retries(
|
||||
url_encoded_path = quote(file["path"])
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
|
||||
response = request_with_retries(
|
||||
method="GET",
|
||||
url=url,
|
||||
headers=headers,
|
||||
timeout=_TIMEOUT,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -5,12 +5,14 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSourceRequiringTenantContext
|
||||
from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
from onyx.connectors.axero.connector import AxeroConnector
|
||||
from onyx.connectors.blob.connector import BlobStorageConnector
|
||||
from onyx.connectors.bookstack.connector import BookstackConnector
|
||||
from onyx.connectors.clickup.connector import ClickupConnector
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.discord.connector import DiscordConnector
|
||||
from onyx.connectors.discourse.connector import DiscourseConnector
|
||||
from onyx.connectors.document360.connector import Document360Connector
|
||||
from onyx.connectors.dropbox.connector import DropboxConnector
|
||||
@@ -100,9 +102,11 @@ def identify_connector_class(
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.OCI_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.XENFORO: XenforoConnector,
|
||||
DocumentSource.DISCORD: DiscordConnector,
|
||||
DocumentSource.FRESHDESK: FreshdeskConnector,
|
||||
DocumentSource.FIREFLIES: FirefliesConnector,
|
||||
DocumentSource.EGNYTE: EgnyteConnector,
|
||||
DocumentSource.AIRTABLE: AirtableConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Dict
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -249,17 +250,36 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
return new_creds_dict
|
||||
|
||||
def _get_all_user_emails(self) -> list[str]:
|
||||
admin_service = get_admin_service(self.creds, self.primary_admin_email)
|
||||
emails = []
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
emails.append(email)
|
||||
return emails
|
||||
"""
|
||||
List all user emails if we are on a Google Workspace domain.
|
||||
If the domain is gmail.com, or if we attempt to call the Admin SDK and
|
||||
get a 404, fall back to using the single user.
|
||||
"""
|
||||
|
||||
try:
|
||||
admin_service = get_admin_service(self.creds, self.primary_admin_email)
|
||||
emails = []
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
emails.append(email)
|
||||
return emails
|
||||
|
||||
except HttpError as e:
|
||||
if e.resp.status == 404:
|
||||
logger.warning(
|
||||
"Received 404 from Admin SDK; this may indicate a personal Gmail account "
|
||||
"with no Workspace domain. Falling back to single user."
|
||||
)
|
||||
return [self.primary_admin_email]
|
||||
raise
|
||||
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def _fetch_threads(
|
||||
self,
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import cast
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
|
||||
@@ -20,6 +21,7 @@ from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
|
||||
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
|
||||
from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
|
||||
from onyx.connectors.google_drive.models import GoogleDriveFileType
|
||||
from onyx.connectors.google_utils.google_auth import get_google_creds
|
||||
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
@@ -41,6 +43,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
# TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
|
||||
@@ -286,13 +289,30 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
logger.info(f"Impersonating user {user_email}")
|
||||
|
||||
drive_service = get_drive_service(self.creds, user_email)
|
||||
|
||||
# validate that the user has access to the drive APIs by performing a simple
|
||||
# request and checking for a 401
|
||||
try:
|
||||
retry_builder()(get_root_folder_id)(drive_service)
|
||||
except HttpError as e:
|
||||
if e.status_code == 401:
|
||||
# fail gracefully, let the other impersonations continue
|
||||
# one user without access shouldn't block the entire connector
|
||||
logger.exception(
|
||||
f"User '{user_email}' does not have access to the drive APIs."
|
||||
)
|
||||
return
|
||||
raise
|
||||
|
||||
# if we are including my drives, try to get the current user's my
|
||||
# drive if any of the following are true:
|
||||
# - include_my_drives is true
|
||||
# - the current user's email is in the requested emails
|
||||
if self.include_my_drives or user_email in self._requested_my_drive_emails:
|
||||
logger.info(f"Getting all files in my drive as '{user_email}'")
|
||||
yield from get_all_files_in_my_drive(
|
||||
service=drive_service,
|
||||
update_traversed_ids_func=self._update_traversed_parent_ids,
|
||||
@@ -303,6 +323,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
remaining_drive_ids = filtered_drive_ids - self._retrieved_ids
|
||||
for drive_id in remaining_drive_ids:
|
||||
logger.info(f"Getting files in shared drive '{drive_id}' as '{user_email}'")
|
||||
yield from get_files_in_shared_drive(
|
||||
service=drive_service,
|
||||
drive_id=drive_id,
|
||||
@@ -314,6 +335,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
remaining_folders = filtered_folder_ids - self._retrieved_ids
|
||||
for folder_id in remaining_folders:
|
||||
logger.info(f"Getting files in folder '{folder_id}' as '{user_email}'")
|
||||
yield from crawl_folders_for_files(
|
||||
service=drive_service,
|
||||
parent_id=folder_id,
|
||||
@@ -344,6 +366,15 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
elif self.include_shared_drives:
|
||||
drive_ids_to_retrieve = all_drive_ids
|
||||
|
||||
# checkpoint - we've found all users and drives, now time to actually start
|
||||
# fetching stuff
|
||||
logger.info(f"Found {len(all_org_emails)} users to impersonate")
|
||||
logger.debug(f"Users: {all_org_emails}")
|
||||
logger.info(f"Found {len(drive_ids_to_retrieve)} drives to retrieve")
|
||||
logger.debug(f"Drives: {drive_ids_to_retrieve}")
|
||||
logger.info(f"Found {len(folder_ids_to_retrieve)} folders to retrieve")
|
||||
logger.debug(f"Folders: {folder_ids_to_retrieve}")
|
||||
|
||||
# Process users in parallel using ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
future_to_email = {
|
||||
@@ -380,6 +411,13 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
drive_service = get_drive_service(self.creds, self.primary_admin_email)
|
||||
|
||||
if self.include_files_shared_with_me or self.include_my_drives:
|
||||
logger.info(
|
||||
f"Getting shared files/my drive files for OAuth "
|
||||
f"with include_files_shared_with_me={self.include_files_shared_with_me}, "
|
||||
f"include_my_drives={self.include_my_drives}, "
|
||||
f"include_shared_drives={self.include_shared_drives}."
|
||||
f"Using '{self.primary_admin_email}' as the account."
|
||||
)
|
||||
yield from get_all_files_for_oauth(
|
||||
service=drive_service,
|
||||
include_files_shared_with_me=self.include_files_shared_with_me,
|
||||
@@ -412,6 +450,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
drive_ids_to_retrieve = all_drive_ids
|
||||
|
||||
for drive_id in drive_ids_to_retrieve:
|
||||
logger.info(
|
||||
f"Getting files in shared drive '{drive_id}' as '{self.primary_admin_email}'"
|
||||
)
|
||||
yield from get_files_in_shared_drive(
|
||||
service=drive_service,
|
||||
drive_id=drive_id,
|
||||
@@ -425,6 +466,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# that could be folders.
|
||||
remaining_folders = folder_ids_to_retrieve - self._retrieved_ids
|
||||
for folder_id in remaining_folders:
|
||||
logger.info(
|
||||
f"Getting files in folder '{folder_id}' as '{self.primary_admin_email}'"
|
||||
)
|
||||
yield from crawl_folders_for_files(
|
||||
service=drive_service,
|
||||
parent_id=folder_id,
|
||||
|
||||
@@ -17,6 +17,9 @@ from onyx.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
|
||||
from onyx.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
from onyx.connectors.google_utils.resources import get_drive_service
|
||||
from onyx.connectors.google_utils.resources import get_gmail_service
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
@@ -29,6 +32,9 @@ from onyx.connectors.google_utils.shared_constants import (
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GOOGLE_SCOPES,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
MISSING_SCOPES_ERROR_STR,
|
||||
)
|
||||
@@ -96,6 +102,7 @@ def update_credential_access_tokens(
|
||||
user: User,
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
auth_method: GoogleOAuthAuthenticationMethod,
|
||||
) -> OAuthCredentials | None:
|
||||
app_credentials = get_google_app_cred(source)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
@@ -119,6 +126,7 @@ def update_credential_access_tokens(
|
||||
new_creds_dict = {
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD: auth_method.value,
|
||||
}
|
||||
|
||||
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
|
||||
@@ -129,6 +137,7 @@ def update_credential_access_tokens(
|
||||
def build_service_account_creds(
|
||||
source: DocumentSource,
|
||||
primary_admin_email: str | None = None,
|
||||
name: str | None = None,
|
||||
) -> CredentialBase:
|
||||
service_account_key = get_service_account_key(source=source)
|
||||
|
||||
@@ -138,10 +147,15 @@ def build_service_account_creds(
|
||||
if primary_admin_email:
|
||||
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = primary_admin_email
|
||||
|
||||
credential_dict[
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD
|
||||
] = GoogleOAuthAuthenticationMethod.UPLOADED.value
|
||||
|
||||
return CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
source=source,
|
||||
name=name,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@ import abc
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
@@ -66,6 +68,10 @@ class SlimConnector(BaseConnector):
|
||||
|
||||
|
||||
class OAuthConnector(BaseConnector):
|
||||
class AdditionalOauthKwargs(BaseModel):
|
||||
# if overridden, all fields should be str type
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def oauth_id(cls) -> DocumentSource:
|
||||
@@ -73,12 +79,22 @@ class OAuthConnector(BaseConnector):
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
|
||||
def oauth_authorization_url(
|
||||
cls,
|
||||
base_domain: str,
|
||||
state: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
|
||||
def oauth_code_to_token(
|
||||
cls,
|
||||
base_domain: str,
|
||||
code: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -7,16 +7,23 @@ from typing import cast
|
||||
import requests
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import LINEAR_CLIENT_ID
|
||||
from onyx.configs.app_configs import LINEAR_CLIENT_SECRET
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_oauth_callback_uri,
|
||||
)
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import OAuthConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import request_with_retries
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -57,7 +64,7 @@ def _make_query(request_body: dict[str, Any], api_key: str) -> requests.Response
|
||||
)
|
||||
|
||||
|
||||
class LinearConnector(LoadConnector, PollConnector):
|
||||
class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
@@ -65,8 +72,68 @@ class LinearConnector(LoadConnector, PollConnector):
|
||||
self.batch_size = batch_size
|
||||
self.linear_api_key: str | None = None
|
||||
|
||||
@classmethod
|
||||
def oauth_id(cls) -> DocumentSource:
|
||||
return DocumentSource.LINEAR
|
||||
|
||||
@classmethod
|
||||
def oauth_authorization_url(
|
||||
cls, base_domain: str, state: str, additional_kwargs: dict[str, str]
|
||||
) -> str:
|
||||
if not LINEAR_CLIENT_ID:
|
||||
raise ValueError("LINEAR_CLIENT_ID environment variable must be set")
|
||||
|
||||
callback_uri = get_oauth_callback_uri(base_domain, DocumentSource.LINEAR.value)
|
||||
return (
|
||||
f"https://linear.app/oauth/authorize"
|
||||
f"?client_id={LINEAR_CLIENT_ID}"
|
||||
f"&redirect_uri={callback_uri}"
|
||||
f"&response_type=code"
|
||||
f"&scope=read"
|
||||
f"&state={state}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def oauth_code_to_token(
|
||||
cls, base_domain: str, code: str, additional_kwargs: dict[str, str]
|
||||
) -> dict[str, Any]:
|
||||
data = {
|
||||
"code": code,
|
||||
"redirect_uri": get_oauth_callback_uri(
|
||||
base_domain, DocumentSource.LINEAR.value
|
||||
),
|
||||
"client_id": LINEAR_CLIENT_ID,
|
||||
"client_secret": LINEAR_CLIENT_SECRET,
|
||||
"grant_type": "authorization_code",
|
||||
}
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
|
||||
response = request_with_retries(
|
||||
method="POST",
|
||||
url="https://api.linear.app/oauth/token",
|
||||
data=data,
|
||||
headers=headers,
|
||||
backoff=0,
|
||||
delay=0.1,
|
||||
)
|
||||
if not response.ok:
|
||||
raise RuntimeError(f"Failed to exchange code for token: {response.text}")
|
||||
|
||||
token_data = response.json()
|
||||
|
||||
return {
|
||||
"access_token": token_data["access_token"],
|
||||
}
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.linear_api_key = cast(str, credentials["linear_api_key"])
|
||||
if "linear_api_key" in credentials:
|
||||
self.linear_api_key = cast(str, credentials["linear_api_key"])
|
||||
elif "access_token" in credentials:
|
||||
self.linear_api_key = "Bearer " + cast(str, credentials["access_token"])
|
||||
else:
|
||||
# May need to handle case in the future if the OAuth flow expires
|
||||
raise ConnectorMissingCredentialError("Linear")
|
||||
|
||||
return None
|
||||
|
||||
def _process_issues(
|
||||
|
||||
@@ -101,8 +101,11 @@ class DocumentBase(BaseModel):
|
||||
source: DocumentSource | None = None
|
||||
semantic_identifier: str # displayed in the UI as the main identifier for the doc
|
||||
metadata: dict[str, str | list[str]]
|
||||
|
||||
# UTC time
|
||||
doc_updated_at: datetime | None = None
|
||||
chunk_count: int | None = None
|
||||
|
||||
# Owner, creator, etc.
|
||||
primary_owners: list[BasicExpertInfo] | None = None
|
||||
# Assignee, space owner, etc.
|
||||
|
||||
@@ -1,42 +1,34 @@
|
||||
import os
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from simple_salesforce import Salesforce
|
||||
from simple_salesforce import SFType
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.salesforce.utils import extract_dict_text
|
||||
from onyx.connectors.salesforce.doc_conversion import convert_sf_object_to_doc
|
||||
from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
|
||||
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
|
||||
from onyx.connectors.salesforce.salesforce_calls import get_all_children_of_sf_type
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_record
|
||||
from onyx.connectors.salesforce.sqlite_functions import init_db
|
||||
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
# TODO: this connector does not work well at large scales
|
||||
# the large query against a large Salesforce instance has been reported to take 1.5 hours.
|
||||
# Additionally it seems to eat up more memory over time if the connection is long running (again a scale issue).
|
||||
|
||||
|
||||
DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
|
||||
MAX_QUERY_LENGTH = 10000 # max query length is 20,000 characters
|
||||
ID_PREFIX = "SALESFORCE_"
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
|
||||
|
||||
|
||||
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -44,200 +36,133 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
requested_objects: list[str] = [],
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.sf_client: Salesforce | None = None
|
||||
self._sf_client: Salesforce | None = None
|
||||
self.parent_object_list = (
|
||||
[obj.capitalize() for obj in requested_objects]
|
||||
if requested_objects
|
||||
else DEFAULT_PARENT_OBJECT_TYPES
|
||||
else _DEFAULT_PARENT_OBJECT_TYPES
|
||||
)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.sf_client = Salesforce(
|
||||
def load_credentials(
|
||||
self,
|
||||
credentials: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
self._sf_client = Salesforce(
|
||||
username=credentials["sf_username"],
|
||||
password=credentials["sf_password"],
|
||||
security_token=credentials["sf_security_token"],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _get_sf_type_object_json(self, type_name: str) -> Any:
|
||||
if self.sf_client is None:
|
||||
@property
|
||||
def sf_client(self) -> Salesforce:
|
||||
if self._sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
sf_object = SFType(
|
||||
type_name, self.sf_client.session_id, self.sf_client.sf_instance
|
||||
)
|
||||
return sf_object.describe()
|
||||
|
||||
def _get_name_from_id(self, id: str) -> str:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
try:
|
||||
user_object_info = self.sf_client.query(
|
||||
f"SELECT Name FROM User WHERE Id = '{id}'"
|
||||
)
|
||||
name = user_object_info.get("Records", [{}])[0].get("Name", "Null User")
|
||||
return name
|
||||
except Exception:
|
||||
logger.warning(f"Couldnt find name for object id: {id}")
|
||||
return "Null User"
|
||||
|
||||
def _convert_object_instance_to_document(
|
||||
self, object_dict: dict[str, Any]
|
||||
) -> Document:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
|
||||
salesforce_id = object_dict["Id"]
|
||||
onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
|
||||
extracted_link = f"https://{self.sf_client.sf_instance}/{salesforce_id}"
|
||||
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
|
||||
extracted_object_text = extract_dict_text(object_dict)
|
||||
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
|
||||
extracted_primary_owners = [
|
||||
BasicExpertInfo(
|
||||
display_name=self._get_name_from_id(object_dict["LastModifiedById"])
|
||||
)
|
||||
]
|
||||
|
||||
doc = Document(
|
||||
id=onyx_salesforce_id,
|
||||
sections=[Section(link=extracted_link, text=extracted_object_text)],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier=extracted_semantic_identifier,
|
||||
doc_updated_at=extracted_doc_updated_at,
|
||||
primary_owners=extracted_primary_owners,
|
||||
metadata={},
|
||||
)
|
||||
return doc
|
||||
|
||||
def _is_valid_child_object(self, child_relationship: dict) -> bool:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
|
||||
if not child_relationship["childSObject"]:
|
||||
return False
|
||||
if not child_relationship["relationshipName"]:
|
||||
return False
|
||||
|
||||
sf_type = child_relationship["childSObject"]
|
||||
object_description = self._get_sf_type_object_json(sf_type)
|
||||
if not object_description["queryable"]:
|
||||
return False
|
||||
|
||||
try:
|
||||
query = f"SELECT Count() FROM {sf_type} LIMIT 1"
|
||||
result = self.sf_client.query(query)
|
||||
if result["totalSize"] == 0:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
|
||||
return False
|
||||
|
||||
if child_relationship["field"]:
|
||||
if child_relationship["field"] == "RelatedToId":
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _get_all_children_of_sf_type(self, sf_type: str) -> list[dict]:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
|
||||
object_description = self._get_sf_type_object_json(sf_type)
|
||||
|
||||
children_objects: list[dict] = []
|
||||
for child_relationship in object_description["childRelationships"]:
|
||||
if self._is_valid_child_object(child_relationship):
|
||||
children_objects.append(
|
||||
{
|
||||
"relationship_name": child_relationship["relationshipName"],
|
||||
"object_type": child_relationship["childSObject"],
|
||||
}
|
||||
)
|
||||
return children_objects
|
||||
|
||||
def _get_all_fields_for_sf_type(self, sf_type: str) -> list[str]:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
|
||||
object_description = self._get_sf_type_object_json(sf_type)
|
||||
|
||||
fields = [
|
||||
field.get("name")
|
||||
for field in object_description["fields"]
|
||||
if field.get("type", "base64") != "base64"
|
||||
]
|
||||
|
||||
return fields
|
||||
|
||||
def _generate_query_per_parent_type(self, parent_sf_type: str) -> Iterator[str]:
|
||||
"""
|
||||
This function takes in an object_type and generates query(s) designed to grab
|
||||
information associated to objects of that type.
|
||||
It does that by getting all the fields of the parent object type.
|
||||
Then it gets all the child objects of that object type and all the fields of
|
||||
those children as well.
|
||||
"""
|
||||
parent_fields = self._get_all_fields_for_sf_type(parent_sf_type)
|
||||
child_sf_types = self._get_all_children_of_sf_type(parent_sf_type)
|
||||
|
||||
query = f"SELECT {', '.join(parent_fields)}"
|
||||
for child_object_dict in child_sf_types:
|
||||
fields = self._get_all_fields_for_sf_type(child_object_dict["object_type"])
|
||||
query_addition = f", \n(SELECT {', '.join(fields)} FROM {child_object_dict['relationship_name']})"
|
||||
|
||||
if len(query_addition) + len(query) > MAX_QUERY_LENGTH:
|
||||
query += f"\n FROM {parent_sf_type}"
|
||||
yield query
|
||||
query = "SELECT Id" + query_addition
|
||||
else:
|
||||
query += query_addition
|
||||
|
||||
query += f"\n FROM {parent_sf_type}"
|
||||
|
||||
yield query
|
||||
return self._sf_client
|
||||
|
||||
def _fetch_from_salesforce(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
init_db()
|
||||
all_object_types: set[str] = set(self.parent_object_list)
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
logger.info(f"Starting with {len(self.parent_object_list)} parent object types")
|
||||
logger.debug(f"Parent object types: {self.parent_object_list}")
|
||||
|
||||
# This takes like 20 seconds
|
||||
for parent_object_type in self.parent_object_list:
|
||||
logger.debug(f"Processing: {parent_object_type}")
|
||||
|
||||
query_results: dict = {}
|
||||
for query in self._generate_query_per_parent_type(parent_object_type):
|
||||
if start is not None and end is not None:
|
||||
if start and start.tzinfo is None:
|
||||
start = start.replace(tzinfo=timezone.utc)
|
||||
if end and end.tzinfo is None:
|
||||
end = end.replace(tzinfo=timezone.utc)
|
||||
query += f" WHERE LastModifiedDate > {start.isoformat()} AND LastModifiedDate < {end.isoformat()}"
|
||||
|
||||
query_result = self.sf_client.query_all(query)
|
||||
|
||||
for record_dict in query_result["records"]:
|
||||
query_results.setdefault(record_dict["Id"], {}).update(record_dict)
|
||||
|
||||
logger.info(
|
||||
f"Number of {parent_object_type} Objects processed: {len(query_results)}"
|
||||
child_types = get_all_children_of_sf_type(
|
||||
self.sf_client, parent_object_type
|
||||
)
|
||||
all_object_types.update(child_types)
|
||||
logger.debug(
|
||||
f"Found {len(child_types)} child types for {parent_object_type}"
|
||||
)
|
||||
|
||||
for combined_object_dict in query_results.values():
|
||||
doc_batch.append(
|
||||
self._convert_object_instance_to_document(combined_object_dict)
|
||||
# Always want to make sure user is grabbed for permissioning purposes
|
||||
all_object_types.add("User")
|
||||
|
||||
logger.info(f"Found total of {len(all_object_types)} object types to fetch")
|
||||
logger.debug(f"All object types: {all_object_types}")
|
||||
|
||||
# checkpoint - we've found all object types, now time to fetch the data
|
||||
logger.info("Starting to fetch CSVs for all object types")
|
||||
# This takes like 30 minutes first time and <2 minutes for updates
|
||||
object_type_to_csv_path = fetch_all_csvs_in_parallel(
|
||||
sf_client=self.sf_client,
|
||||
object_types=all_object_types,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
updated_ids: set[str] = set()
|
||||
# This takes like 10 seconds
|
||||
# This is for testing the rest of the functionality if data has
|
||||
# already been fetched and put in sqlite
|
||||
# from import onyx.connectors.salesforce.sf_db.sqlite_functions find_ids_by_type
|
||||
# for object_type in self.parent_object_list:
|
||||
# updated_ids.update(list(find_ids_by_type(object_type)))
|
||||
|
||||
# This takes 10-70 minutes first time (idk why the range is so big)
|
||||
total_types = len(object_type_to_csv_path)
|
||||
logger.info(f"Starting to process {total_types} object types")
|
||||
|
||||
for i, (object_type, csv_paths) in enumerate(
|
||||
object_type_to_csv_path.items(), 1
|
||||
):
|
||||
logger.info(f"Processing object type {object_type} ({i}/{total_types})")
|
||||
# If path is None, it means it failed to fetch the csv
|
||||
if csv_paths is None:
|
||||
continue
|
||||
# Go through each csv path and use it to update the db
|
||||
for csv_path in csv_paths:
|
||||
logger.debug(f"Updating {object_type} with {csv_path}")
|
||||
new_ids = update_sf_db_with_csv(
|
||||
object_type=object_type,
|
||||
csv_download_path=csv_path,
|
||||
)
|
||||
updated_ids.update(new_ids)
|
||||
logger.debug(
|
||||
f"Added {len(new_ids)} new/updated records for {object_type}"
|
||||
)
|
||||
|
||||
if len(doc_batch) > self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
yield doc_batch
|
||||
logger.info(f"Found {len(updated_ids)} total updated records")
|
||||
logger.info(
|
||||
f"Starting to process parent objects of types: {self.parent_object_list}"
|
||||
)
|
||||
|
||||
docs_to_yield: list[Document] = []
|
||||
docs_processed = 0
|
||||
# Takes 15-20 seconds per batch
|
||||
for parent_type, parent_id_batch in get_affected_parent_ids_by_type(
|
||||
updated_ids=list(updated_ids),
|
||||
parent_types=self.parent_object_list,
|
||||
):
|
||||
logger.info(
|
||||
f"Processing batch of {len(parent_id_batch)} {parent_type} objects"
|
||||
)
|
||||
for parent_id in parent_id_batch:
|
||||
if not (parent_object := get_record(parent_id, parent_type)):
|
||||
logger.warning(
|
||||
f"Failed to get parent object {parent_id} for {parent_type}"
|
||||
)
|
||||
continue
|
||||
|
||||
docs_to_yield.append(
|
||||
convert_sf_object_to_doc(
|
||||
sf_object=parent_object,
|
||||
sf_instance=self.sf_client.sf_instance,
|
||||
)
|
||||
)
|
||||
docs_processed += 1
|
||||
|
||||
if len(docs_to_yield) >= self.batch_size:
|
||||
yield docs_to_yield
|
||||
docs_to_yield = []
|
||||
|
||||
yield docs_to_yield
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_from_salesforce()
|
||||
@@ -245,19 +170,13 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
start_datetime = datetime.utcfromtimestamp(start)
|
||||
end_datetime = datetime.utcfromtimestamp(end)
|
||||
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
|
||||
return self._fetch_from_salesforce(start=start, end=end)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
doc_metadata_list: list[SlimDocument] = []
|
||||
for parent_object_type in self.parent_object_list:
|
||||
query = f"SELECT Id FROM {parent_object_type}"
|
||||
@@ -274,9 +193,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = SalesforceConnector(
|
||||
requested_objects=os.environ["REQUESTED_OBJECTS"].split(",")
|
||||
)
|
||||
import time
|
||||
|
||||
connector = SalesforceConnector(requested_objects=["Account"])
|
||||
|
||||
connector.load_credentials(
|
||||
{
|
||||
@@ -285,5 +204,20 @@ if __name__ == "__main__":
|
||||
"sf_security_token": os.environ["SF_SECURITY_TOKEN"],
|
||||
}
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
start_time = time.time()
|
||||
doc_count = 0
|
||||
section_count = 0
|
||||
text_count = 0
|
||||
for doc_batch in connector.load_from_state():
|
||||
doc_count += len(doc_batch)
|
||||
print(f"doc_count: {doc_count}")
|
||||
for doc in doc_batch:
|
||||
section_count += len(doc.sections)
|
||||
for section in doc.sections:
|
||||
text_count += len(section.text)
|
||||
end_time = time.time()
|
||||
|
||||
print(f"Doc count: {doc_count}")
|
||||
print(f"Section count: {section_count}")
|
||||
print(f"Text count: {text_count}")
|
||||
print(f"Time taken: {end_time - start_time}")
|
||||
|
||||
184
backend/onyx/connectors/salesforce/doc_conversion.py
Normal file
184
backend/onyx/connectors/salesforce/doc_conversion.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import re
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_record
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
ID_PREFIX = "SALESFORCE_"
|
||||
|
||||
# All of these types of keys are handled by specific fields in the doc
|
||||
# conversion process (E.g. URLs) or are not useful for the user (E.g. UUIDs)
|
||||
_SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"
|
||||
|
||||
|
||||
def _clean_salesforce_dict(data: dict | list) -> dict | list:
|
||||
"""Clean and transform Salesforce API response data by recursively:
|
||||
1. Extracting records from the response if present
|
||||
2. Merging attributes into the main dictionary
|
||||
3. Filtering out keys matching certain patterns (Id, Date, stamp, url)
|
||||
4. Removing '__c' suffix from custom field names
|
||||
5. Removing None values and empty containers
|
||||
|
||||
Args:
|
||||
data: A dictionary or list from Salesforce API response
|
||||
|
||||
Returns:
|
||||
Cleaned dictionary or list with transformed keys and filtered values
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
if "records" in data.keys():
|
||||
data = data["records"]
|
||||
if isinstance(data, dict):
|
||||
if "attributes" in data.keys():
|
||||
if isinstance(data["attributes"], dict):
|
||||
data.update(data.pop("attributes"))
|
||||
|
||||
if isinstance(data, dict):
|
||||
filtered_dict = {}
|
||||
for key, value in data.items():
|
||||
if not re.search(_SF_JSON_FILTER, key, re.IGNORECASE):
|
||||
# remove the custom object indicator for display
|
||||
if "__c" in key:
|
||||
key = key[:-3]
|
||||
if isinstance(value, (dict, list)):
|
||||
filtered_value = _clean_salesforce_dict(value)
|
||||
# Only add non-empty dictionaries or lists
|
||||
if filtered_value:
|
||||
filtered_dict[key] = filtered_value
|
||||
elif value is not None:
|
||||
filtered_dict[key] = value
|
||||
return filtered_dict
|
||||
elif isinstance(data, list):
|
||||
filtered_list = []
|
||||
for item in data:
|
||||
if isinstance(item, (dict, list)):
|
||||
filtered_item = _clean_salesforce_dict(item)
|
||||
# Only add non-empty dictionaries or lists
|
||||
if filtered_item:
|
||||
filtered_list.append(filtered_item)
|
||||
elif item is not None:
|
||||
filtered_list.append(filtered_item)
|
||||
return filtered_list
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
def _json_to_natural_language(data: dict | list, indent: int = 0) -> str:
|
||||
"""Convert a nested dictionary or list into a human-readable string format.
|
||||
|
||||
Recursively traverses the data structure and formats it with:
|
||||
- Key-value pairs on separate lines
|
||||
- Nested structures indented for readability
|
||||
- Lists and dictionaries handled with appropriate formatting
|
||||
|
||||
Args:
|
||||
data: The dictionary or list to convert
|
||||
indent: Number of spaces to indent (default: 0)
|
||||
|
||||
Returns:
|
||||
A formatted string representation of the data structure
|
||||
"""
|
||||
result = []
|
||||
indent_str = " " * indent
|
||||
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
result.append(f"{indent_str}{key}:")
|
||||
result.append(_json_to_natural_language(value, indent + 2))
|
||||
else:
|
||||
result.append(f"{indent_str}{key}: {value}")
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
result.append(_json_to_natural_language(item, indent + 2))
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
|
||||
def _extract_dict_text(raw_dict: dict) -> str:
|
||||
"""Extract text from a Salesforce API response dictionary by:
|
||||
1. Cleaning the dictionary
|
||||
2. Converting the cleaned dictionary to natural language
|
||||
"""
|
||||
processed_dict = _clean_salesforce_dict(raw_dict)
|
||||
natural_language_for_dict = _json_to_natural_language(processed_dict)
|
||||
return natural_language_for_dict
|
||||
|
||||
|
||||
def _extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
|
||||
return Section(
|
||||
text=_extract_dict_text(salesforce_object.data),
|
||||
link=f"{base_url}/{salesforce_object.id}",
|
||||
)
|
||||
|
||||
|
||||
def _extract_primary_owners(
|
||||
sf_object: SalesforceObject,
|
||||
) -> list[BasicExpertInfo] | None:
|
||||
object_dict = sf_object.data
|
||||
if not (last_modified_by_id := object_dict.get("LastModifiedById")):
|
||||
logger.warning(f"No LastModifiedById found for {sf_object.id}")
|
||||
return None
|
||||
if not (last_modified_by := get_record(last_modified_by_id)):
|
||||
logger.warning(f"No LastModifiedBy found for {last_modified_by_id}")
|
||||
return None
|
||||
|
||||
user_data = last_modified_by.data
|
||||
expert_info = BasicExpertInfo(
|
||||
first_name=user_data.get("FirstName"),
|
||||
last_name=user_data.get("LastName"),
|
||||
email=user_data.get("Email"),
|
||||
display_name=user_data.get("Name"),
|
||||
)
|
||||
|
||||
# Check if all fields are None
|
||||
if all(
|
||||
value is None
|
||||
for value in [
|
||||
expert_info.first_name,
|
||||
expert_info.last_name,
|
||||
expert_info.email,
|
||||
expert_info.display_name,
|
||||
]
|
||||
):
|
||||
logger.warning(f"No identifying information found for user {user_data}")
|
||||
return None
|
||||
|
||||
return [expert_info]
|
||||
|
||||
|
||||
def convert_sf_object_to_doc(
|
||||
sf_object: SalesforceObject,
|
||||
sf_instance: str,
|
||||
) -> Document:
|
||||
object_dict = sf_object.data
|
||||
salesforce_id = object_dict["Id"]
|
||||
onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
|
||||
base_url = f"https://{sf_instance}"
|
||||
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
|
||||
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
|
||||
|
||||
sections = [_extract_section(sf_object, base_url)]
|
||||
for id in get_child_ids(sf_object.id):
|
||||
if not (child_object := get_record(id)):
|
||||
continue
|
||||
sections.append(_extract_section(child_object, base_url))
|
||||
|
||||
doc = Document(
|
||||
id=onyx_salesforce_id,
|
||||
sections=sections,
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier=extracted_semantic_identifier,
|
||||
doc_updated_at=extracted_doc_updated_at,
|
||||
primary_owners=_extract_primary_owners(sf_object),
|
||||
metadata={},
|
||||
)
|
||||
return doc
|
||||
213
backend/onyx/connectors/salesforce/salesforce_calls.py
Normal file
213
backend/onyx/connectors/salesforce/salesforce_calls.py
Normal file
@@ -0,0 +1,213 @@
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pytz import UTC
|
||||
from simple_salesforce import Salesforce
|
||||
from simple_salesforce import SFType
|
||||
from simple_salesforce.bulk2 import SFBulk2Handler
|
||||
from simple_salesforce.bulk2 import SFBulk2Type
|
||||
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.salesforce.sqlite_functions import has_at_least_one_object_of_type
|
||||
from onyx.connectors.salesforce.utils import get_object_type_path
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _build_time_filter_for_salesforce(
|
||||
start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
|
||||
) -> str:
|
||||
if start is None or end is None:
|
||||
return ""
|
||||
start_datetime = datetime.fromtimestamp(start, UTC)
|
||||
end_datetime = datetime.fromtimestamp(end, UTC)
|
||||
return (
|
||||
f" WHERE LastModifiedDate > {start_datetime.isoformat()} "
|
||||
f"AND LastModifiedDate < {end_datetime.isoformat()}"
|
||||
)
|
||||
|
||||
|
||||
def _get_sf_type_object_json(sf_client: Salesforce, type_name: str) -> Any:
|
||||
sf_object = SFType(type_name, sf_client.session_id, sf_client.sf_instance)
|
||||
return sf_object.describe()
|
||||
|
||||
|
||||
def _is_valid_child_object(
|
||||
sf_client: Salesforce, child_relationship: dict[str, Any]
|
||||
) -> bool:
|
||||
if not child_relationship["childSObject"]:
|
||||
return False
|
||||
if not child_relationship["relationshipName"]:
|
||||
return False
|
||||
|
||||
sf_type = child_relationship["childSObject"]
|
||||
object_description = _get_sf_type_object_json(sf_client, sf_type)
|
||||
if not object_description["queryable"]:
|
||||
return False
|
||||
|
||||
if child_relationship["field"]:
|
||||
if child_relationship["field"] == "RelatedToId":
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_all_children_of_sf_type(sf_client: Salesforce, sf_type: str) -> set[str]:
|
||||
object_description = _get_sf_type_object_json(sf_client, sf_type)
|
||||
|
||||
child_object_types = set()
|
||||
for child_relationship in object_description["childRelationships"]:
|
||||
if _is_valid_child_object(sf_client, child_relationship):
|
||||
logger.debug(
|
||||
f"Found valid child object {child_relationship['childSObject']}"
|
||||
)
|
||||
child_object_types.add(child_relationship["childSObject"])
|
||||
return child_object_types
|
||||
|
||||
|
||||
def _get_all_queryable_fields_of_sf_type(
|
||||
sf_client: Salesforce,
|
||||
sf_type: str,
|
||||
) -> list[str]:
|
||||
object_description = _get_sf_type_object_json(sf_client, sf_type)
|
||||
fields: list[dict[str, Any]] = object_description["fields"]
|
||||
valid_fields: set[str] = set()
|
||||
field_names_to_remove: set[str] = set()
|
||||
for field in fields:
|
||||
if compound_field_name := field.get("compoundFieldName"):
|
||||
# We do want to get name fields even if they are compound
|
||||
if not field.get("nameField"):
|
||||
field_names_to_remove.add(compound_field_name)
|
||||
if field.get("type", "base64") == "base64":
|
||||
continue
|
||||
if field_name := field.get("name"):
|
||||
valid_fields.add(field_name)
|
||||
|
||||
return list(valid_fields - field_names_to_remove)
|
||||
|
||||
|
||||
def _check_if_object_type_is_empty(
|
||||
sf_client: Salesforce, sf_type: str, time_filter: str
|
||||
) -> bool:
|
||||
"""
|
||||
Use the rest api to check to make sure the query will result in a non-empty response
|
||||
"""
|
||||
try:
|
||||
query = f"SELECT Count() FROM {sf_type}{time_filter} LIMIT 1"
|
||||
result = sf_client.query(query)
|
||||
if result["totalSize"] == 0:
|
||||
return False
|
||||
except Exception as e:
|
||||
if "OPERATION_TOO_LARGE" not in str(e):
|
||||
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _check_for_existing_csvs(sf_type: str) -> list[str] | None:
|
||||
# Check if the csv already exists
|
||||
if os.path.exists(get_object_type_path(sf_type)):
|
||||
existing_csvs = [
|
||||
os.path.join(get_object_type_path(sf_type), f)
|
||||
for f in os.listdir(get_object_type_path(sf_type))
|
||||
if f.endswith(".csv")
|
||||
]
|
||||
# If the csv already exists, return the path
|
||||
# This is likely due to a previous run that failed
|
||||
# after downloading the csv but before the data was
|
||||
# written to the db
|
||||
if existing_csvs:
|
||||
return existing_csvs
|
||||
return None
|
||||
|
||||
|
||||
def _build_bulk_query(sf_client: Salesforce, sf_type: str, time_filter: str) -> str:
|
||||
queryable_fields = _get_all_queryable_fields_of_sf_type(sf_client, sf_type)
|
||||
query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}"
|
||||
return query
|
||||
|
||||
|
||||
def _bulk_retrieve_from_salesforce(
|
||||
sf_client: Salesforce,
|
||||
sf_type: str,
|
||||
time_filter: str,
|
||||
) -> tuple[str, list[str] | None]:
|
||||
if not _check_if_object_type_is_empty(sf_client, sf_type, time_filter):
|
||||
return sf_type, None
|
||||
|
||||
if existing_csvs := _check_for_existing_csvs(sf_type):
|
||||
return sf_type, existing_csvs
|
||||
|
||||
query = _build_bulk_query(sf_client, sf_type, time_filter)
|
||||
|
||||
bulk_2_handler = SFBulk2Handler(
|
||||
session_id=sf_client.session_id,
|
||||
bulk2_url=sf_client.bulk2_url,
|
||||
proxies=sf_client.proxies,
|
||||
session=sf_client.session,
|
||||
)
|
||||
bulk_2_type = SFBulk2Type(
|
||||
object_name=sf_type,
|
||||
bulk2_url=bulk_2_handler.bulk2_url,
|
||||
headers=bulk_2_handler.headers,
|
||||
session=bulk_2_handler.session,
|
||||
)
|
||||
|
||||
logger.info(f"Downloading {sf_type}")
|
||||
logger.info(f"Query: {query}")
|
||||
|
||||
try:
|
||||
# This downloads the file to a file in the target path with a random name
|
||||
results = bulk_2_type.download(
|
||||
query=query,
|
||||
path=get_object_type_path(sf_type),
|
||||
max_records=1000000,
|
||||
)
|
||||
all_download_paths = [result["file"] for result in results]
|
||||
logger.info(f"Downloaded {sf_type} to {all_download_paths}")
|
||||
return sf_type, all_download_paths
|
||||
except Exception as e:
|
||||
logger.info(f"Failed to download salesforce csv for object type {sf_type}: {e}")
|
||||
return sf_type, None
|
||||
|
||||
|
||||
def fetch_all_csvs_in_parallel(
|
||||
sf_client: Salesforce,
|
||||
object_types: set[str],
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
) -> dict[str, list[str] | None]:
|
||||
"""
|
||||
Fetches all the csvs in parallel for the given object types
|
||||
Returns a dict of (sf_type, full_download_path)
|
||||
"""
|
||||
time_filter = _build_time_filter_for_salesforce(start, end)
|
||||
time_filter_for_each_object_type = {}
|
||||
# We do this outside of the thread pool executor because this requires
|
||||
# a database connection and we don't want to block the thread pool
|
||||
# executor from running
|
||||
for sf_type in object_types:
|
||||
"""Only add time filter if there is at least one object of the type
|
||||
in the database. We aren't worried about partially completed object update runs
|
||||
because this occurs after we check for existing csvs which covers this case"""
|
||||
if has_at_least_one_object_of_type(sf_type):
|
||||
time_filter_for_each_object_type[sf_type] = time_filter
|
||||
else:
|
||||
time_filter_for_each_object_type[sf_type] = ""
|
||||
|
||||
# Run the bulk retrieve in parallel
|
||||
with ThreadPoolExecutor() as executor:
|
||||
results = executor.map(
|
||||
lambda object_type: _bulk_retrieve_from_salesforce(
|
||||
sf_client=sf_client,
|
||||
sf_type=object_type,
|
||||
time_filter=time_filter_for_each_object_type[object_type],
|
||||
),
|
||||
object_types,
|
||||
)
|
||||
return dict(results)
|
||||
@@ -0,0 +1,737 @@
|
||||
import csv
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import find_ids_by_type
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
|
||||
get_affected_parent_ids_by_type,
|
||||
)
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_child_ids
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_record
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
|
||||
update_sf_db_with_csv,
|
||||
)
|
||||
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
|
||||
from onyx.connectors.salesforce.utils import get_object_type_path
|
||||
|
||||
_VALID_SALESFORCE_IDS = [
|
||||
"001bm00000fd9Z3AAI",
|
||||
"001bm00000fdYTdAAM",
|
||||
"001bm00000fdYTeAAM",
|
||||
"001bm00000fdYTfAAM",
|
||||
"001bm00000fdYTgAAM",
|
||||
"001bm00000fdYThAAM",
|
||||
"001bm00000fdYTiAAM",
|
||||
"001bm00000fdYTjAAM",
|
||||
"001bm00000fdYTkAAM",
|
||||
"001bm00000fdYTlAAM",
|
||||
"001bm00000fdYTmAAM",
|
||||
"001bm00000fdYTnAAM",
|
||||
"001bm00000fdYToAAM",
|
||||
"500bm00000XoOxtAAF",
|
||||
"500bm00000XoOxuAAF",
|
||||
"500bm00000XoOxvAAF",
|
||||
"500bm00000XoOxwAAF",
|
||||
"500bm00000XoOxxAAF",
|
||||
"500bm00000XoOxyAAF",
|
||||
"500bm00000XoOxzAAF",
|
||||
"500bm00000XoOy0AAF",
|
||||
"500bm00000XoOy1AAF",
|
||||
"500bm00000XoOy2AAF",
|
||||
"500bm00000XoOy3AAF",
|
||||
"500bm00000XoOy4AAF",
|
||||
"500bm00000XoOy5AAF",
|
||||
"500bm00000XoOy6AAF",
|
||||
"500bm00000XoOy7AAF",
|
||||
"500bm00000XoOy8AAF",
|
||||
"500bm00000XoOy9AAF",
|
||||
"500bm00000XoOyAAAV",
|
||||
"500bm00000XoOyBAAV",
|
||||
"500bm00000XoOyCAAV",
|
||||
"500bm00000XoOyDAAV",
|
||||
"500bm00000XoOyEAAV",
|
||||
"500bm00000XoOyFAAV",
|
||||
"500bm00000XoOyGAAV",
|
||||
"500bm00000XoOyHAAV",
|
||||
"500bm00000XoOyIAAV",
|
||||
"003bm00000EjHCjAAN",
|
||||
"003bm00000EjHCkAAN",
|
||||
"003bm00000EjHClAAN",
|
||||
"003bm00000EjHCmAAN",
|
||||
"003bm00000EjHCnAAN",
|
||||
"003bm00000EjHCoAAN",
|
||||
"003bm00000EjHCpAAN",
|
||||
"003bm00000EjHCqAAN",
|
||||
"003bm00000EjHCrAAN",
|
||||
"003bm00000EjHCsAAN",
|
||||
"003bm00000EjHCtAAN",
|
||||
"003bm00000EjHCuAAN",
|
||||
"003bm00000EjHCvAAN",
|
||||
"003bm00000EjHCwAAN",
|
||||
"003bm00000EjHCxAAN",
|
||||
"003bm00000EjHCyAAN",
|
||||
"003bm00000EjHCzAAN",
|
||||
"003bm00000EjHD0AAN",
|
||||
"003bm00000EjHD1AAN",
|
||||
"003bm00000EjHD2AAN",
|
||||
"550bm00000EXc2tAAD",
|
||||
"006bm000006kyDpAAI",
|
||||
"006bm000006kyDqAAI",
|
||||
"006bm000006kyDrAAI",
|
||||
"006bm000006kyDsAAI",
|
||||
"006bm000006kyDtAAI",
|
||||
"006bm000006kyDuAAI",
|
||||
"006bm000006kyDvAAI",
|
||||
"006bm000006kyDwAAI",
|
||||
"006bm000006kyDxAAI",
|
||||
"006bm000006kyDyAAI",
|
||||
"006bm000006kyDzAAI",
|
||||
"006bm000006kyE0AAI",
|
||||
"006bm000006kyE1AAI",
|
||||
"006bm000006kyE2AAI",
|
||||
"006bm000006kyE3AAI",
|
||||
"006bm000006kyE4AAI",
|
||||
"006bm000006kyE5AAI",
|
||||
"006bm000006kyE6AAI",
|
||||
"006bm000006kyE7AAI",
|
||||
"006bm000006kyE8AAI",
|
||||
"006bm000006kyE9AAI",
|
||||
"006bm000006kyEAAAY",
|
||||
"006bm000006kyEBAAY",
|
||||
"006bm000006kyECAAY",
|
||||
"006bm000006kyEDAAY",
|
||||
"006bm000006kyEEAAY",
|
||||
"006bm000006kyEFAAY",
|
||||
"006bm000006kyEGAAY",
|
||||
"006bm000006kyEHAAY",
|
||||
"006bm000006kyEIAAY",
|
||||
"006bm000006kyEJAAY",
|
||||
"005bm000009zy0TAAQ",
|
||||
"005bm000009zy25AAA",
|
||||
"005bm000009zy26AAA",
|
||||
"005bm000009zy28AAA",
|
||||
"005bm000009zy29AAA",
|
||||
"005bm000009zy2AAAQ",
|
||||
"005bm000009zy2BAAQ",
|
||||
]
|
||||
|
||||
|
||||
def clear_sf_db() -> None:
|
||||
"""
|
||||
Clears the SF DB by deleting all files in the data directory.
|
||||
"""
|
||||
shutil.rmtree(BASE_DATA_PATH)
|
||||
|
||||
|
||||
def create_csv_file(
|
||||
object_type: str, records: list[dict], filename: str = "test_data.csv"
|
||||
) -> None:
|
||||
"""
|
||||
Creates a CSV file for the given object type and records.
|
||||
|
||||
Args:
|
||||
object_type: The Salesforce object type (e.g. "Account", "Contact")
|
||||
records: List of dictionaries containing the record data
|
||||
filename: Name of the CSV file to create (default: test_data.csv)
|
||||
"""
|
||||
if not records:
|
||||
return
|
||||
|
||||
# Get all unique fields from records
|
||||
fields: set[str] = set()
|
||||
for record in records:
|
||||
fields.update(record.keys())
|
||||
fields = set(sorted(list(fields))) # Sort for consistent order
|
||||
|
||||
# Create CSV file
|
||||
csv_path = os.path.join(get_object_type_path(object_type), filename)
|
||||
with open(csv_path, "w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fields)
|
||||
writer.writeheader()
|
||||
for record in records:
|
||||
writer.writerow(record)
|
||||
|
||||
# Update the database with the CSV
|
||||
update_sf_db_with_csv(object_type, csv_path)
|
||||
|
||||
|
||||
def create_csv_with_example_data() -> None:
|
||||
"""
|
||||
Creates CSV files with example data, organized by object type.
|
||||
"""
|
||||
example_data: dict[str, list[dict]] = {
|
||||
"Account": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[0],
|
||||
"Name": "Acme Inc.",
|
||||
"BillingCity": "New York",
|
||||
"Industry": "Technology",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[1],
|
||||
"Name": "Globex Corp",
|
||||
"BillingCity": "Los Angeles",
|
||||
"Industry": "Manufacturing",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[2],
|
||||
"Name": "Initech",
|
||||
"BillingCity": "Austin",
|
||||
"Industry": "Software",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[3],
|
||||
"Name": "TechCorp Solutions",
|
||||
"BillingCity": "San Francisco",
|
||||
"Industry": "Software",
|
||||
"AnnualRevenue": 5000000,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[4],
|
||||
"Name": "BioMed Research",
|
||||
"BillingCity": "Boston",
|
||||
"Industry": "Healthcare",
|
||||
"AnnualRevenue": 12000000,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[5],
|
||||
"Name": "Green Energy Co",
|
||||
"BillingCity": "Portland",
|
||||
"Industry": "Energy",
|
||||
"AnnualRevenue": 8000000,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[6],
|
||||
"Name": "DataFlow Analytics",
|
||||
"BillingCity": "Seattle",
|
||||
"Industry": "Technology",
|
||||
"AnnualRevenue": 3000000,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[7],
|
||||
"Name": "Cloud Nine Services",
|
||||
"BillingCity": "Denver",
|
||||
"Industry": "Cloud Computing",
|
||||
"AnnualRevenue": 7000000,
|
||||
},
|
||||
],
|
||||
"Contact": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[40],
|
||||
"FirstName": "John",
|
||||
"LastName": "Doe",
|
||||
"Email": "john.doe@acme.com",
|
||||
"Title": "CEO",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[41],
|
||||
"FirstName": "Jane",
|
||||
"LastName": "Smith",
|
||||
"Email": "jane.smith@acme.com",
|
||||
"Title": "CTO",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[42],
|
||||
"FirstName": "Bob",
|
||||
"LastName": "Johnson",
|
||||
"Email": "bob.j@globex.com",
|
||||
"Title": "Sales Director",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[43],
|
||||
"FirstName": "Sarah",
|
||||
"LastName": "Chen",
|
||||
"Email": "sarah.chen@techcorp.com",
|
||||
"Title": "Product Manager",
|
||||
"Phone": "415-555-0101",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[44],
|
||||
"FirstName": "Michael",
|
||||
"LastName": "Rodriguez",
|
||||
"Email": "m.rodriguez@biomed.com",
|
||||
"Title": "Research Director",
|
||||
"Phone": "617-555-0202",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[45],
|
||||
"FirstName": "Emily",
|
||||
"LastName": "Green",
|
||||
"Email": "emily.g@greenenergy.com",
|
||||
"Title": "Sustainability Lead",
|
||||
"Phone": "503-555-0303",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[46],
|
||||
"FirstName": "David",
|
||||
"LastName": "Kim",
|
||||
"Email": "david.kim@dataflow.com",
|
||||
"Title": "Data Scientist",
|
||||
"Phone": "206-555-0404",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[47],
|
||||
"FirstName": "Rachel",
|
||||
"LastName": "Taylor",
|
||||
"Email": "r.taylor@cloudnine.com",
|
||||
"Title": "Cloud Architect",
|
||||
"Phone": "303-555-0505",
|
||||
},
|
||||
],
|
||||
"Opportunity": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[62],
|
||||
"Name": "Acme Server Upgrade",
|
||||
"Amount": 50000,
|
||||
"Stage": "Prospecting",
|
||||
"CloseDate": "2024-06-30",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[63],
|
||||
"Name": "Globex Manufacturing Line",
|
||||
"Amount": 150000,
|
||||
"Stage": "Negotiation",
|
||||
"CloseDate": "2024-03-15",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[64],
|
||||
"Name": "Initech Software License",
|
||||
"Amount": 75000,
|
||||
"Stage": "Closed Won",
|
||||
"CloseDate": "2024-01-30",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[65],
|
||||
"Name": "TechCorp AI Implementation",
|
||||
"Amount": 250000,
|
||||
"Stage": "Needs Analysis",
|
||||
"CloseDate": "2024-08-15",
|
||||
"Probability": 60,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[66],
|
||||
"Name": "BioMed Lab Equipment",
|
||||
"Amount": 500000,
|
||||
"Stage": "Value Proposition",
|
||||
"CloseDate": "2024-09-30",
|
||||
"Probability": 75,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[67],
|
||||
"Name": "Green Energy Solar Project",
|
||||
"Amount": 750000,
|
||||
"Stage": "Proposal",
|
||||
"CloseDate": "2024-07-15",
|
||||
"Probability": 80,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[68],
|
||||
"Name": "DataFlow Analytics Platform",
|
||||
"Amount": 180000,
|
||||
"Stage": "Negotiation",
|
||||
"CloseDate": "2024-05-30",
|
||||
"Probability": 90,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[69],
|
||||
"Name": "Cloud Nine Infrastructure",
|
||||
"Amount": 300000,
|
||||
"Stage": "Qualification",
|
||||
"CloseDate": "2024-10-15",
|
||||
"Probability": 40,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Create CSV files for each object type
|
||||
for object_type, records in example_data.items():
|
||||
create_csv_file(object_type, records)
|
||||
|
||||
|
||||
def test_query() -> None:
|
||||
"""
|
||||
Tests querying functionality by verifying:
|
||||
1. All expected Account IDs are found
|
||||
2. Each Account's data matches what was inserted
|
||||
"""
|
||||
# Expected test data for verification
|
||||
expected_accounts: dict[str, dict[str, str | int]] = {
|
||||
_VALID_SALESFORCE_IDS[0]: {
|
||||
"Name": "Acme Inc.",
|
||||
"BillingCity": "New York",
|
||||
"Industry": "Technology",
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[1]: {
|
||||
"Name": "Globex Corp",
|
||||
"BillingCity": "Los Angeles",
|
||||
"Industry": "Manufacturing",
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[2]: {
|
||||
"Name": "Initech",
|
||||
"BillingCity": "Austin",
|
||||
"Industry": "Software",
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[3]: {
|
||||
"Name": "TechCorp Solutions",
|
||||
"BillingCity": "San Francisco",
|
||||
"Industry": "Software",
|
||||
"AnnualRevenue": 5000000,
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[4]: {
|
||||
"Name": "BioMed Research",
|
||||
"BillingCity": "Boston",
|
||||
"Industry": "Healthcare",
|
||||
"AnnualRevenue": 12000000,
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[5]: {
|
||||
"Name": "Green Energy Co",
|
||||
"BillingCity": "Portland",
|
||||
"Industry": "Energy",
|
||||
"AnnualRevenue": 8000000,
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[6]: {
|
||||
"Name": "DataFlow Analytics",
|
||||
"BillingCity": "Seattle",
|
||||
"Industry": "Technology",
|
||||
"AnnualRevenue": 3000000,
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[7]: {
|
||||
"Name": "Cloud Nine Services",
|
||||
"BillingCity": "Denver",
|
||||
"Industry": "Cloud Computing",
|
||||
"AnnualRevenue": 7000000,
|
||||
},
|
||||
}
|
||||
|
||||
# Get all Account IDs
|
||||
account_ids = find_ids_by_type("Account")
|
||||
|
||||
# Verify we found all expected accounts
|
||||
assert len(account_ids) == len(
|
||||
expected_accounts
|
||||
), f"Expected {len(expected_accounts)} accounts, found {len(account_ids)}"
|
||||
assert set(account_ids) == set(
|
||||
expected_accounts.keys()
|
||||
), "Found account IDs don't match expected IDs"
|
||||
|
||||
# Verify each account's data
|
||||
for acc_id in account_ids:
|
||||
combined = get_record(acc_id)
|
||||
assert combined is not None, f"Could not find account {acc_id}"
|
||||
|
||||
expected = expected_accounts[acc_id]
|
||||
|
||||
# Verify account data matches
|
||||
for key, value in expected.items():
|
||||
value = str(value)
|
||||
assert (
|
||||
combined.data[key] == value
|
||||
), f"Account {acc_id} field {key} expected {value}, got {combined.data[key]}"
|
||||
|
||||
print("All query tests passed successfully!")
|
||||
|
||||
|
||||
def test_upsert() -> None:
|
||||
"""
|
||||
Tests upsert functionality by:
|
||||
1. Updating an existing account
|
||||
2. Creating a new account
|
||||
3. Verifying both operations were successful
|
||||
"""
|
||||
# Create CSV for updating an existing account and adding a new one
|
||||
update_data: list[dict[str, str | int]] = [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[0],
|
||||
"Name": "Acme Inc. Updated",
|
||||
"BillingCity": "New York",
|
||||
"Industry": "Technology",
|
||||
"Description": "Updated company info",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[2],
|
||||
"Name": "New Company Inc.",
|
||||
"BillingCity": "Miami",
|
||||
"Industry": "Finance",
|
||||
"AnnualRevenue": 1000000,
|
||||
},
|
||||
]
|
||||
|
||||
create_csv_file("Account", update_data, "update_data.csv")
|
||||
|
||||
# Verify the update worked
|
||||
updated_record = get_record(_VALID_SALESFORCE_IDS[0])
|
||||
assert updated_record is not None, "Updated record not found"
|
||||
assert updated_record.data["Name"] == "Acme Inc. Updated", "Name not updated"
|
||||
assert (
|
||||
updated_record.data["Description"] == "Updated company info"
|
||||
), "Description not added"
|
||||
|
||||
# Verify the new record was created
|
||||
new_record = get_record(_VALID_SALESFORCE_IDS[2])
|
||||
assert new_record is not None, "New record not found"
|
||||
assert new_record.data["Name"] == "New Company Inc.", "New record name incorrect"
|
||||
assert new_record.data["AnnualRevenue"] == "1000000", "New record revenue incorrect"
|
||||
|
||||
print("All upsert tests passed successfully!")
|
||||
|
||||
|
||||
def test_relationships() -> None:
|
||||
"""
|
||||
Tests relationship shelf updates and queries by:
|
||||
1. Creating test data with relationships
|
||||
2. Verifying the relationships are correctly stored
|
||||
3. Testing relationship queries
|
||||
"""
|
||||
# Create test data for each object type
|
||||
test_data: dict[str, list[dict[str, str | int]]] = {
|
||||
"Case": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[13],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"Subject": "Test Case 1",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[14],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"Subject": "Test Case 2",
|
||||
},
|
||||
],
|
||||
"Contact": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[48],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"FirstName": "Test",
|
||||
"LastName": "Contact",
|
||||
}
|
||||
],
|
||||
"Opportunity": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[62],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"Name": "Test Opportunity",
|
||||
"Amount": 100000,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Create and update CSV files for each object type
|
||||
for object_type, records in test_data.items():
|
||||
create_csv_file(object_type, records, "relationship_test.csv")
|
||||
|
||||
# Test relationship queries
|
||||
# All these objects should be children of Acme Inc.
|
||||
child_ids = get_child_ids(_VALID_SALESFORCE_IDS[0])
|
||||
assert len(child_ids) == 4, f"Expected 4 child objects, found {len(child_ids)}"
|
||||
assert _VALID_SALESFORCE_IDS[13] in child_ids, "Case 1 not found in relationship"
|
||||
assert _VALID_SALESFORCE_IDS[14] in child_ids, "Case 2 not found in relationship"
|
||||
assert _VALID_SALESFORCE_IDS[48] in child_ids, "Contact not found in relationship"
|
||||
assert (
|
||||
_VALID_SALESFORCE_IDS[62] in child_ids
|
||||
), "Opportunity not found in relationship"
|
||||
|
||||
# Test querying relationships for a different account (should be empty)
|
||||
other_account_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
|
||||
assert (
|
||||
len(other_account_children) == 0
|
||||
), "Expected no children for different account"
|
||||
|
||||
print("All relationship tests passed successfully!")
|
||||
|
||||
|
||||
def test_account_with_children() -> None:
|
||||
"""
|
||||
Tests querying all accounts and retrieving their child objects.
|
||||
This test verifies that:
|
||||
1. All accounts can be retrieved
|
||||
2. Child objects are correctly linked
|
||||
3. Child object data is complete and accurate
|
||||
"""
|
||||
# First get all account IDs
|
||||
account_ids = find_ids_by_type("Account")
|
||||
assert len(account_ids) > 0, "No accounts found"
|
||||
|
||||
# For each account, get its children and verify the data
|
||||
for account_id in account_ids:
|
||||
account = get_record(account_id)
|
||||
assert account is not None, f"Could not find account {account_id}"
|
||||
|
||||
# Get all child objects
|
||||
child_ids = get_child_ids(account_id)
|
||||
|
||||
# For Acme Inc., verify specific relationships
|
||||
if account_id == _VALID_SALESFORCE_IDS[0]: # Acme Inc.
|
||||
assert (
|
||||
len(child_ids) == 4
|
||||
), f"Expected 4 children for Acme Inc., found {len(child_ids)}"
|
||||
|
||||
# Get all child records
|
||||
child_records = []
|
||||
for child_id in child_ids:
|
||||
child_record = get_record(child_id)
|
||||
if child_record is not None:
|
||||
child_records.append(child_record)
|
||||
# Verify Cases
|
||||
cases = [r for r in child_records if r.type == "Case"]
|
||||
assert (
|
||||
len(cases) == 2
|
||||
), f"Expected 2 cases for Acme Inc., found {len(cases)}"
|
||||
case_subjects = {case.data["Subject"] for case in cases}
|
||||
assert "Test Case 1" in case_subjects, "Test Case 1 not found"
|
||||
assert "Test Case 2" in case_subjects, "Test Case 2 not found"
|
||||
|
||||
# Verify Contacts
|
||||
contacts = [r for r in child_records if r.type == "Contact"]
|
||||
assert (
|
||||
len(contacts) == 1
|
||||
), f"Expected 1 contact for Acme Inc., found {len(contacts)}"
|
||||
contact = contacts[0]
|
||||
assert contact.data["FirstName"] == "Test", "Contact first name mismatch"
|
||||
assert contact.data["LastName"] == "Contact", "Contact last name mismatch"
|
||||
|
||||
# Verify Opportunities
|
||||
opportunities = [r for r in child_records if r.type == "Opportunity"]
|
||||
assert (
|
||||
len(opportunities) == 1
|
||||
), f"Expected 1 opportunity for Acme Inc., found {len(opportunities)}"
|
||||
opportunity = opportunities[0]
|
||||
assert (
|
||||
opportunity.data["Name"] == "Test Opportunity"
|
||||
), "Opportunity name mismatch"
|
||||
assert opportunity.data["Amount"] == "100000", "Opportunity amount mismatch"
|
||||
|
||||
print("All account with children tests passed successfully!")
|
||||
|
||||
|
||||
def test_relationship_updates() -> None:
|
||||
"""
|
||||
Tests that relationships are properly updated when a child object's parent reference changes.
|
||||
This test verifies:
|
||||
1. Initial relationship is created correctly
|
||||
2. When parent reference is updated, old relationship is removed
|
||||
3. New relationship is created correctly
|
||||
"""
|
||||
# Create initial test data - Contact linked to Acme Inc.
|
||||
initial_contact = [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[40],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"FirstName": "Test",
|
||||
"LastName": "Contact",
|
||||
}
|
||||
]
|
||||
create_csv_file("Contact", initial_contact, "initial_contact.csv")
|
||||
|
||||
# Verify initial relationship
|
||||
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
|
||||
assert (
|
||||
_VALID_SALESFORCE_IDS[40] in acme_children
|
||||
), "Initial relationship not created"
|
||||
|
||||
# Update contact to be linked to Globex Corp instead
|
||||
updated_contact = [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[40],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[1],
|
||||
"FirstName": "Test",
|
||||
"LastName": "Contact",
|
||||
}
|
||||
]
|
||||
create_csv_file("Contact", updated_contact, "updated_contact.csv")
|
||||
|
||||
# Verify old relationship is removed
|
||||
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
|
||||
assert (
|
||||
_VALID_SALESFORCE_IDS[40] not in acme_children
|
||||
), "Old relationship not removed"
|
||||
|
||||
# Verify new relationship is created
|
||||
globex_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
|
||||
assert _VALID_SALESFORCE_IDS[40] in globex_children, "New relationship not created"
|
||||
|
||||
print("All relationship update tests passed successfully!")
|
||||
|
||||
|
||||
def test_get_affected_parent_ids() -> None:
|
||||
"""
|
||||
Tests get_affected_parent_ids functionality by verifying:
|
||||
1. IDs that are directly in the parent_types list are included
|
||||
2. IDs that have children in the updated_ids list are included
|
||||
3. IDs that are neither of the above are not included
|
||||
"""
|
||||
# Create test data with relationships
|
||||
test_data = {
|
||||
"Account": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[0],
|
||||
"Name": "Parent Account 1",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[1],
|
||||
"Name": "Parent Account 2",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[2],
|
||||
"Name": "Not Affected Account",
|
||||
},
|
||||
],
|
||||
"Contact": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[40],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"FirstName": "Child",
|
||||
"LastName": "Contact",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Create and update CSV files for test data
|
||||
for object_type, records in test_data.items():
|
||||
create_csv_file(object_type, records)
|
||||
|
||||
# Test Case 1: Account directly in updated_ids and parent_types
|
||||
updated_ids = {_VALID_SALESFORCE_IDS[1]} # Parent Account 2
|
||||
parent_types = ["Account"]
|
||||
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
|
||||
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
|
||||
|
||||
# Test Case 2: Account with child in updated_ids
|
||||
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
|
||||
parent_types = ["Account"]
|
||||
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
|
||||
assert (
|
||||
_VALID_SALESFORCE_IDS[0] in affected_ids
|
||||
), "Parent of updated child not included"
|
||||
|
||||
# Test Case 3: Both direct and indirect affects
|
||||
updated_ids = {_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]} # Both cases
|
||||
parent_types = ["Account"]
|
||||
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
|
||||
assert len(affected_ids) == 2, "Expected exactly two affected parent IDs"
|
||||
assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included"
|
||||
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
|
||||
assert (
|
||||
_VALID_SALESFORCE_IDS[2] not in affected_ids
|
||||
), "Unaffected ID incorrectly included"
|
||||
|
||||
# Test Case 4: No matches
|
||||
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
|
||||
parent_types = ["Opportunity"] # Wrong type
|
||||
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
|
||||
assert len(affected_ids) == 0, "Should return empty list when no matches"
|
||||
|
||||
print("All get_affected_parent_ids tests passed successfully!")
|
||||
|
||||
|
||||
def main_build() -> None:
|
||||
clear_sf_db()
|
||||
create_csv_with_example_data()
|
||||
test_query()
|
||||
test_upsert()
|
||||
test_relationships()
|
||||
test_account_with_children()
|
||||
test_relationship_updates()
|
||||
test_get_affected_parent_ids()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main_build()
|
||||
@@ -0,0 +1,209 @@
|
||||
import csv
|
||||
import shelve
|
||||
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
|
||||
get_child_to_parent_shelf_path,
|
||||
)
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_id_type_shelf_path
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_object_shelf_path
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
|
||||
get_parent_to_child_shelf_path,
|
||||
)
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.connectors.salesforce.utils import validate_salesforce_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _update_relationship_shelves(
|
||||
child_id: str,
|
||||
parent_ids: set[str],
|
||||
) -> None:
|
||||
"""Update the relationship shelf when a record is updated."""
|
||||
try:
|
||||
# Convert child_id to string once
|
||||
str_child_id = str(child_id)
|
||||
|
||||
# First update child to parent mapping
|
||||
with shelve.open(
|
||||
get_child_to_parent_shelf_path(),
|
||||
flag="c",
|
||||
protocol=None,
|
||||
writeback=True,
|
||||
) as child_to_parent_db:
|
||||
old_parent_ids = set(child_to_parent_db.get(str_child_id, []))
|
||||
child_to_parent_db[str_child_id] = list(parent_ids)
|
||||
|
||||
# Calculate differences outside the next context manager
|
||||
parent_ids_to_remove = old_parent_ids - parent_ids
|
||||
parent_ids_to_add = parent_ids - old_parent_ids
|
||||
|
||||
# Only sync once at the end
|
||||
child_to_parent_db.sync()
|
||||
|
||||
# Then update parent to child mapping in a single transaction
|
||||
if not parent_ids_to_remove and not parent_ids_to_add:
|
||||
return
|
||||
with shelve.open(
|
||||
get_parent_to_child_shelf_path(),
|
||||
flag="c",
|
||||
protocol=None,
|
||||
writeback=True,
|
||||
) as parent_to_child_db:
|
||||
# Process all removals first
|
||||
for parent_id in parent_ids_to_remove:
|
||||
str_parent_id = str(parent_id)
|
||||
existing_children = set(parent_to_child_db.get(str_parent_id, []))
|
||||
if str_child_id in existing_children:
|
||||
existing_children.remove(str_child_id)
|
||||
parent_to_child_db[str_parent_id] = list(existing_children)
|
||||
|
||||
# Then process all additions
|
||||
for parent_id in parent_ids_to_add:
|
||||
str_parent_id = str(parent_id)
|
||||
existing_children = set(parent_to_child_db.get(str_parent_id, []))
|
||||
existing_children.add(str_child_id)
|
||||
parent_to_child_db[str_parent_id] = list(existing_children)
|
||||
|
||||
# Single sync at the end
|
||||
parent_to_child_db.sync()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating relationship shelves: {e}")
|
||||
logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
|
||||
raise
|
||||
|
||||
|
||||
def get_child_ids(parent_id: str) -> set[str]:
|
||||
"""Get all child IDs for a given parent ID.
|
||||
|
||||
Args:
|
||||
parent_id: The ID of the parent object
|
||||
|
||||
Returns:
|
||||
A set of child object IDs
|
||||
"""
|
||||
with shelve.open(get_parent_to_child_shelf_path()) as parent_to_child_db:
|
||||
return set(parent_to_child_db.get(parent_id, []))
|
||||
|
||||
|
||||
def update_sf_db_with_csv(
|
||||
object_type: str,
|
||||
csv_download_path: str,
|
||||
) -> list[str]:
|
||||
"""Update the SF DB with a CSV file using shelve storage."""
|
||||
updated_ids = []
|
||||
shelf_path = get_object_shelf_path(object_type)
|
||||
|
||||
# First read the CSV to get all the data
|
||||
with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
id = row["Id"]
|
||||
parent_ids = set()
|
||||
field_to_remove: set[str] = set()
|
||||
# Update relationship shelves for any parent references
|
||||
for field, value in row.items():
|
||||
if validate_salesforce_id(value) and field != "Id":
|
||||
parent_ids.add(value)
|
||||
field_to_remove.add(field)
|
||||
if not value:
|
||||
field_to_remove.add(field)
|
||||
_update_relationship_shelves(id, parent_ids)
|
||||
for field in field_to_remove:
|
||||
# We use this to extract the Primary Owner later
|
||||
if field != "LastModifiedById":
|
||||
del row[field]
|
||||
|
||||
# Update the main object shelf
|
||||
with shelve.open(shelf_path) as object_type_db:
|
||||
object_type_db[id] = row
|
||||
# Update the ID-to-type mapping shelf
|
||||
with shelve.open(get_id_type_shelf_path()) as id_type_db:
|
||||
id_type_db[id] = object_type
|
||||
|
||||
updated_ids.append(id)
|
||||
|
||||
# os.remove(csv_download_path)
|
||||
return updated_ids
|
||||
|
||||
|
||||
def get_type_from_id(object_id: str) -> str | None:
|
||||
"""Get the type of an object from its ID."""
|
||||
# Look up the object type from the ID-to-type mapping
|
||||
with shelve.open(get_id_type_shelf_path()) as id_type_db:
|
||||
if object_id not in id_type_db:
|
||||
logger.warning(f"Object ID {object_id} not found in ID-to-type mapping")
|
||||
return None
|
||||
return id_type_db[object_id]
|
||||
|
||||
|
||||
def get_record(
|
||||
object_id: str, object_type: str | None = None
|
||||
) -> SalesforceObject | None:
|
||||
"""
|
||||
Retrieve the record and return it as a SalesforceObject.
|
||||
The object type will be looked up from the ID-to-type mapping shelf.
|
||||
"""
|
||||
if object_type is None:
|
||||
if not (object_type := get_type_from_id(object_id)):
|
||||
return None
|
||||
|
||||
shelf_path = get_object_shelf_path(object_type)
|
||||
with shelve.open(shelf_path) as db:
|
||||
if object_id not in db:
|
||||
logger.warning(f"Object ID {object_id} not found in {shelf_path}")
|
||||
return None
|
||||
data = db[object_id]
|
||||
return SalesforceObject(
|
||||
id=object_id,
|
||||
type=object_type,
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
||||
def find_ids_by_type(object_type: str) -> list[str]:
|
||||
"""
|
||||
Find all object IDs for rows of the specified type.
|
||||
"""
|
||||
shelf_path = get_object_shelf_path(object_type)
|
||||
try:
|
||||
with shelve.open(shelf_path) as db:
|
||||
return list(db.keys())
|
||||
except FileNotFoundError:
|
||||
return []
|
||||
|
||||
|
||||
def get_affected_parent_ids_by_type(
|
||||
updated_ids: set[str], parent_types: list[str]
|
||||
) -> dict[str, set[str]]:
|
||||
"""Get IDs of objects that are of the specified parent types and are either in the updated_ids
|
||||
or have children in the updated_ids.
|
||||
|
||||
Args:
|
||||
updated_ids: List of IDs that were updated
|
||||
parent_types: List of object types to filter by
|
||||
|
||||
Returns:
|
||||
A dictionary of IDs that match the criteria
|
||||
"""
|
||||
affected_ids_by_type: dict[str, set[str]] = {}
|
||||
|
||||
# Check each updated ID
|
||||
for updated_id in updated_ids:
|
||||
# Add the ID itself if it's of a parent type
|
||||
updated_type = get_type_from_id(updated_id)
|
||||
if updated_type in parent_types:
|
||||
affected_ids_by_type.setdefault(updated_type, set()).add(updated_id)
|
||||
continue
|
||||
|
||||
# Get parents of this ID and add them if they're of a parent type
|
||||
with shelve.open(get_child_to_parent_shelf_path()) as child_to_parent_db:
|
||||
parent_ids = child_to_parent_db.get(updated_id, [])
|
||||
for parent_id in parent_ids:
|
||||
parent_type = get_type_from_id(parent_id)
|
||||
if parent_type in parent_types:
|
||||
affected_ids_by_type.setdefault(parent_type, set()).add(parent_id)
|
||||
|
||||
return affected_ids_by_type
|
||||
@@ -0,0 +1,29 @@
|
||||
import os
|
||||
|
||||
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
|
||||
from onyx.connectors.salesforce.utils import get_object_type_path
|
||||
|
||||
|
||||
def get_object_shelf_path(object_type: str) -> str:
|
||||
"""Get the path to the shelf file for a specific object type."""
|
||||
base_path = get_object_type_path(object_type)
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
return os.path.join(base_path, "data.shelf")
|
||||
|
||||
|
||||
def get_id_type_shelf_path() -> str:
|
||||
"""Get the path to the ID-to-type mapping shelf."""
|
||||
os.makedirs(BASE_DATA_PATH, exist_ok=True)
|
||||
return os.path.join(BASE_DATA_PATH, "id_type_mapping.shelf.4g")
|
||||
|
||||
|
||||
def get_parent_to_child_shelf_path() -> str:
|
||||
"""Get the path to the parent-to-child mapping shelf."""
|
||||
os.makedirs(BASE_DATA_PATH, exist_ok=True)
|
||||
return os.path.join(BASE_DATA_PATH, "parent_to_child_mapping.shelf.4g")
|
||||
|
||||
|
||||
def get_child_to_parent_shelf_path() -> str:
|
||||
"""Get the path to the child-to-parent mapping shelf."""
|
||||
os.makedirs(BASE_DATA_PATH, exist_ok=True)
|
||||
return os.path.join(BASE_DATA_PATH, "child_to_parent_mapping.shelf.4g")
|
||||
@@ -0,0 +1,737 @@
|
||||
import csv
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import find_ids_by_type
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
|
||||
get_affected_parent_ids_by_type,
|
||||
)
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_child_ids
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_record
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
|
||||
update_sf_db_with_csv,
|
||||
)
|
||||
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
|
||||
from onyx.connectors.salesforce.utils import get_object_type_path
|
||||
|
||||
_VALID_SALESFORCE_IDS = [
|
||||
"001bm00000fd9Z3AAI",
|
||||
"001bm00000fdYTdAAM",
|
||||
"001bm00000fdYTeAAM",
|
||||
"001bm00000fdYTfAAM",
|
||||
"001bm00000fdYTgAAM",
|
||||
"001bm00000fdYThAAM",
|
||||
"001bm00000fdYTiAAM",
|
||||
"001bm00000fdYTjAAM",
|
||||
"001bm00000fdYTkAAM",
|
||||
"001bm00000fdYTlAAM",
|
||||
"001bm00000fdYTmAAM",
|
||||
"001bm00000fdYTnAAM",
|
||||
"001bm00000fdYToAAM",
|
||||
"500bm00000XoOxtAAF",
|
||||
"500bm00000XoOxuAAF",
|
||||
"500bm00000XoOxvAAF",
|
||||
"500bm00000XoOxwAAF",
|
||||
"500bm00000XoOxxAAF",
|
||||
"500bm00000XoOxyAAF",
|
||||
"500bm00000XoOxzAAF",
|
||||
"500bm00000XoOy0AAF",
|
||||
"500bm00000XoOy1AAF",
|
||||
"500bm00000XoOy2AAF",
|
||||
"500bm00000XoOy3AAF",
|
||||
"500bm00000XoOy4AAF",
|
||||
"500bm00000XoOy5AAF",
|
||||
"500bm00000XoOy6AAF",
|
||||
"500bm00000XoOy7AAF",
|
||||
"500bm00000XoOy8AAF",
|
||||
"500bm00000XoOy9AAF",
|
||||
"500bm00000XoOyAAAV",
|
||||
"500bm00000XoOyBAAV",
|
||||
"500bm00000XoOyCAAV",
|
||||
"500bm00000XoOyDAAV",
|
||||
"500bm00000XoOyEAAV",
|
||||
"500bm00000XoOyFAAV",
|
||||
"500bm00000XoOyGAAV",
|
||||
"500bm00000XoOyHAAV",
|
||||
"500bm00000XoOyIAAV",
|
||||
"003bm00000EjHCjAAN",
|
||||
"003bm00000EjHCkAAN",
|
||||
"003bm00000EjHClAAN",
|
||||
"003bm00000EjHCmAAN",
|
||||
"003bm00000EjHCnAAN",
|
||||
"003bm00000EjHCoAAN",
|
||||
"003bm00000EjHCpAAN",
|
||||
"003bm00000EjHCqAAN",
|
||||
"003bm00000EjHCrAAN",
|
||||
"003bm00000EjHCsAAN",
|
||||
"003bm00000EjHCtAAN",
|
||||
"003bm00000EjHCuAAN",
|
||||
"003bm00000EjHCvAAN",
|
||||
"003bm00000EjHCwAAN",
|
||||
"003bm00000EjHCxAAN",
|
||||
"003bm00000EjHCyAAN",
|
||||
"003bm00000EjHCzAAN",
|
||||
"003bm00000EjHD0AAN",
|
||||
"003bm00000EjHD1AAN",
|
||||
"003bm00000EjHD2AAN",
|
||||
"550bm00000EXc2tAAD",
|
||||
"006bm000006kyDpAAI",
|
||||
"006bm000006kyDqAAI",
|
||||
"006bm000006kyDrAAI",
|
||||
"006bm000006kyDsAAI",
|
||||
"006bm000006kyDtAAI",
|
||||
"006bm000006kyDuAAI",
|
||||
"006bm000006kyDvAAI",
|
||||
"006bm000006kyDwAAI",
|
||||
"006bm000006kyDxAAI",
|
||||
"006bm000006kyDyAAI",
|
||||
"006bm000006kyDzAAI",
|
||||
"006bm000006kyE0AAI",
|
||||
"006bm000006kyE1AAI",
|
||||
"006bm000006kyE2AAI",
|
||||
"006bm000006kyE3AAI",
|
||||
"006bm000006kyE4AAI",
|
||||
"006bm000006kyE5AAI",
|
||||
"006bm000006kyE6AAI",
|
||||
"006bm000006kyE7AAI",
|
||||
"006bm000006kyE8AAI",
|
||||
"006bm000006kyE9AAI",
|
||||
"006bm000006kyEAAAY",
|
||||
"006bm000006kyEBAAY",
|
||||
"006bm000006kyECAAY",
|
||||
"006bm000006kyEDAAY",
|
||||
"006bm000006kyEEAAY",
|
||||
"006bm000006kyEFAAY",
|
||||
"006bm000006kyEGAAY",
|
||||
"006bm000006kyEHAAY",
|
||||
"006bm000006kyEIAAY",
|
||||
"006bm000006kyEJAAY",
|
||||
"005bm000009zy0TAAQ",
|
||||
"005bm000009zy25AAA",
|
||||
"005bm000009zy26AAA",
|
||||
"005bm000009zy28AAA",
|
||||
"005bm000009zy29AAA",
|
||||
"005bm000009zy2AAAQ",
|
||||
"005bm000009zy2BAAQ",
|
||||
]
|
||||
|
||||
|
||||
def clear_sf_db() -> None:
|
||||
"""
|
||||
Clears the SF DB by deleting all files in the data directory.
|
||||
"""
|
||||
shutil.rmtree(BASE_DATA_PATH)
|
||||
|
||||
|
||||
def create_csv_file(
|
||||
object_type: str, records: list[dict], filename: str = "test_data.csv"
|
||||
) -> None:
|
||||
"""
|
||||
Creates a CSV file for the given object type and records.
|
||||
|
||||
Args:
|
||||
object_type: The Salesforce object type (e.g. "Account", "Contact")
|
||||
records: List of dictionaries containing the record data
|
||||
filename: Name of the CSV file to create (default: test_data.csv)
|
||||
"""
|
||||
if not records:
|
||||
return
|
||||
|
||||
# Get all unique fields from records
|
||||
fields: set[str] = set()
|
||||
for record in records:
|
||||
fields.update(record.keys())
|
||||
fields = set(sorted(list(fields))) # Sort for consistent order
|
||||
|
||||
# Create CSV file
|
||||
csv_path = os.path.join(get_object_type_path(object_type), filename)
|
||||
with open(csv_path, "w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fields)
|
||||
writer.writeheader()
|
||||
for record in records:
|
||||
writer.writerow(record)
|
||||
|
||||
# Update the database with the CSV
|
||||
update_sf_db_with_csv(object_type, csv_path)
|
||||
|
||||
|
||||
def create_csv_with_example_data() -> None:
|
||||
"""
|
||||
Creates CSV files with example data, organized by object type.
|
||||
"""
|
||||
example_data: dict[str, list[dict]] = {
|
||||
"Account": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[0],
|
||||
"Name": "Acme Inc.",
|
||||
"BillingCity": "New York",
|
||||
"Industry": "Technology",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[1],
|
||||
"Name": "Globex Corp",
|
||||
"BillingCity": "Los Angeles",
|
||||
"Industry": "Manufacturing",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[2],
|
||||
"Name": "Initech",
|
||||
"BillingCity": "Austin",
|
||||
"Industry": "Software",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[3],
|
||||
"Name": "TechCorp Solutions",
|
||||
"BillingCity": "San Francisco",
|
||||
"Industry": "Software",
|
||||
"AnnualRevenue": 5000000,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[4],
|
||||
"Name": "BioMed Research",
|
||||
"BillingCity": "Boston",
|
||||
"Industry": "Healthcare",
|
||||
"AnnualRevenue": 12000000,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[5],
|
||||
"Name": "Green Energy Co",
|
||||
"BillingCity": "Portland",
|
||||
"Industry": "Energy",
|
||||
"AnnualRevenue": 8000000,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[6],
|
||||
"Name": "DataFlow Analytics",
|
||||
"BillingCity": "Seattle",
|
||||
"Industry": "Technology",
|
||||
"AnnualRevenue": 3000000,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[7],
|
||||
"Name": "Cloud Nine Services",
|
||||
"BillingCity": "Denver",
|
||||
"Industry": "Cloud Computing",
|
||||
"AnnualRevenue": 7000000,
|
||||
},
|
||||
],
|
||||
"Contact": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[40],
|
||||
"FirstName": "John",
|
||||
"LastName": "Doe",
|
||||
"Email": "john.doe@acme.com",
|
||||
"Title": "CEO",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[41],
|
||||
"FirstName": "Jane",
|
||||
"LastName": "Smith",
|
||||
"Email": "jane.smith@acme.com",
|
||||
"Title": "CTO",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[42],
|
||||
"FirstName": "Bob",
|
||||
"LastName": "Johnson",
|
||||
"Email": "bob.j@globex.com",
|
||||
"Title": "Sales Director",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[43],
|
||||
"FirstName": "Sarah",
|
||||
"LastName": "Chen",
|
||||
"Email": "sarah.chen@techcorp.com",
|
||||
"Title": "Product Manager",
|
||||
"Phone": "415-555-0101",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[44],
|
||||
"FirstName": "Michael",
|
||||
"LastName": "Rodriguez",
|
||||
"Email": "m.rodriguez@biomed.com",
|
||||
"Title": "Research Director",
|
||||
"Phone": "617-555-0202",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[45],
|
||||
"FirstName": "Emily",
|
||||
"LastName": "Green",
|
||||
"Email": "emily.g@greenenergy.com",
|
||||
"Title": "Sustainability Lead",
|
||||
"Phone": "503-555-0303",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[46],
|
||||
"FirstName": "David",
|
||||
"LastName": "Kim",
|
||||
"Email": "david.kim@dataflow.com",
|
||||
"Title": "Data Scientist",
|
||||
"Phone": "206-555-0404",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[47],
|
||||
"FirstName": "Rachel",
|
||||
"LastName": "Taylor",
|
||||
"Email": "r.taylor@cloudnine.com",
|
||||
"Title": "Cloud Architect",
|
||||
"Phone": "303-555-0505",
|
||||
},
|
||||
],
|
||||
"Opportunity": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[62],
|
||||
"Name": "Acme Server Upgrade",
|
||||
"Amount": 50000,
|
||||
"Stage": "Prospecting",
|
||||
"CloseDate": "2024-06-30",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[63],
|
||||
"Name": "Globex Manufacturing Line",
|
||||
"Amount": 150000,
|
||||
"Stage": "Negotiation",
|
||||
"CloseDate": "2024-03-15",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[64],
|
||||
"Name": "Initech Software License",
|
||||
"Amount": 75000,
|
||||
"Stage": "Closed Won",
|
||||
"CloseDate": "2024-01-30",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[65],
|
||||
"Name": "TechCorp AI Implementation",
|
||||
"Amount": 250000,
|
||||
"Stage": "Needs Analysis",
|
||||
"CloseDate": "2024-08-15",
|
||||
"Probability": 60,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[66],
|
||||
"Name": "BioMed Lab Equipment",
|
||||
"Amount": 500000,
|
||||
"Stage": "Value Proposition",
|
||||
"CloseDate": "2024-09-30",
|
||||
"Probability": 75,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[67],
|
||||
"Name": "Green Energy Solar Project",
|
||||
"Amount": 750000,
|
||||
"Stage": "Proposal",
|
||||
"CloseDate": "2024-07-15",
|
||||
"Probability": 80,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[68],
|
||||
"Name": "DataFlow Analytics Platform",
|
||||
"Amount": 180000,
|
||||
"Stage": "Negotiation",
|
||||
"CloseDate": "2024-05-30",
|
||||
"Probability": 90,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[69],
|
||||
"Name": "Cloud Nine Infrastructure",
|
||||
"Amount": 300000,
|
||||
"Stage": "Qualification",
|
||||
"CloseDate": "2024-10-15",
|
||||
"Probability": 40,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Create CSV files for each object type
|
||||
for object_type, records in example_data.items():
|
||||
create_csv_file(object_type, records)
|
||||
|
||||
|
||||
def test_query() -> None:
|
||||
"""
|
||||
Tests querying functionality by verifying:
|
||||
1. All expected Account IDs are found
|
||||
2. Each Account's data matches what was inserted
|
||||
"""
|
||||
# Expected test data for verification
|
||||
expected_accounts: dict[str, dict[str, str | int]] = {
|
||||
_VALID_SALESFORCE_IDS[0]: {
|
||||
"Name": "Acme Inc.",
|
||||
"BillingCity": "New York",
|
||||
"Industry": "Technology",
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[1]: {
|
||||
"Name": "Globex Corp",
|
||||
"BillingCity": "Los Angeles",
|
||||
"Industry": "Manufacturing",
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[2]: {
|
||||
"Name": "Initech",
|
||||
"BillingCity": "Austin",
|
||||
"Industry": "Software",
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[3]: {
|
||||
"Name": "TechCorp Solutions",
|
||||
"BillingCity": "San Francisco",
|
||||
"Industry": "Software",
|
||||
"AnnualRevenue": 5000000,
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[4]: {
|
||||
"Name": "BioMed Research",
|
||||
"BillingCity": "Boston",
|
||||
"Industry": "Healthcare",
|
||||
"AnnualRevenue": 12000000,
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[5]: {
|
||||
"Name": "Green Energy Co",
|
||||
"BillingCity": "Portland",
|
||||
"Industry": "Energy",
|
||||
"AnnualRevenue": 8000000,
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[6]: {
|
||||
"Name": "DataFlow Analytics",
|
||||
"BillingCity": "Seattle",
|
||||
"Industry": "Technology",
|
||||
"AnnualRevenue": 3000000,
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[7]: {
|
||||
"Name": "Cloud Nine Services",
|
||||
"BillingCity": "Denver",
|
||||
"Industry": "Cloud Computing",
|
||||
"AnnualRevenue": 7000000,
|
||||
},
|
||||
}
|
||||
|
||||
# Get all Account IDs
|
||||
account_ids = find_ids_by_type("Account")
|
||||
|
||||
# Verify we found all expected accounts
|
||||
assert len(account_ids) == len(
|
||||
expected_accounts
|
||||
), f"Expected {len(expected_accounts)} accounts, found {len(account_ids)}"
|
||||
assert set(account_ids) == set(
|
||||
expected_accounts.keys()
|
||||
), "Found account IDs don't match expected IDs"
|
||||
|
||||
# Verify each account's data
|
||||
for acc_id in account_ids:
|
||||
combined = get_record(acc_id)
|
||||
assert combined is not None, f"Could not find account {acc_id}"
|
||||
|
||||
expected = expected_accounts[acc_id]
|
||||
|
||||
# Verify account data matches
|
||||
for key, value in expected.items():
|
||||
value = str(value)
|
||||
assert (
|
||||
combined.data[key] == value
|
||||
), f"Account {acc_id} field {key} expected {value}, got {combined.data[key]}"
|
||||
|
||||
print("All query tests passed successfully!")
|
||||
|
||||
|
||||
def test_upsert() -> None:
|
||||
"""
|
||||
Tests upsert functionality by:
|
||||
1. Updating an existing account
|
||||
2. Creating a new account
|
||||
3. Verifying both operations were successful
|
||||
"""
|
||||
# Create CSV for updating an existing account and adding a new one
|
||||
update_data: list[dict[str, str | int]] = [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[0],
|
||||
"Name": "Acme Inc. Updated",
|
||||
"BillingCity": "New York",
|
||||
"Industry": "Technology",
|
||||
"Description": "Updated company info",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[2],
|
||||
"Name": "New Company Inc.",
|
||||
"BillingCity": "Miami",
|
||||
"Industry": "Finance",
|
||||
"AnnualRevenue": 1000000,
|
||||
},
|
||||
]
|
||||
|
||||
create_csv_file("Account", update_data, "update_data.csv")
|
||||
|
||||
# Verify the update worked
|
||||
updated_record = get_record(_VALID_SALESFORCE_IDS[0])
|
||||
assert updated_record is not None, "Updated record not found"
|
||||
assert updated_record.data["Name"] == "Acme Inc. Updated", "Name not updated"
|
||||
assert (
|
||||
updated_record.data["Description"] == "Updated company info"
|
||||
), "Description not added"
|
||||
|
||||
# Verify the new record was created
|
||||
new_record = get_record(_VALID_SALESFORCE_IDS[2])
|
||||
assert new_record is not None, "New record not found"
|
||||
assert new_record.data["Name"] == "New Company Inc.", "New record name incorrect"
|
||||
assert new_record.data["AnnualRevenue"] == "1000000", "New record revenue incorrect"
|
||||
|
||||
print("All upsert tests passed successfully!")
|
||||
|
||||
|
||||
def test_relationships() -> None:
|
||||
"""
|
||||
Tests relationship shelf updates and queries by:
|
||||
1. Creating test data with relationships
|
||||
2. Verifying the relationships are correctly stored
|
||||
3. Testing relationship queries
|
||||
"""
|
||||
# Create test data for each object type
|
||||
test_data: dict[str, list[dict[str, str | int]]] = {
|
||||
"Case": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[13],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"Subject": "Test Case 1",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[14],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"Subject": "Test Case 2",
|
||||
},
|
||||
],
|
||||
"Contact": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[48],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"FirstName": "Test",
|
||||
"LastName": "Contact",
|
||||
}
|
||||
],
|
||||
"Opportunity": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[62],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"Name": "Test Opportunity",
|
||||
"Amount": 100000,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Create and update CSV files for each object type
|
||||
for object_type, records in test_data.items():
|
||||
create_csv_file(object_type, records, "relationship_test.csv")
|
||||
|
||||
# Test relationship queries
|
||||
# All these objects should be children of Acme Inc.
|
||||
child_ids = get_child_ids(_VALID_SALESFORCE_IDS[0])
|
||||
assert len(child_ids) == 4, f"Expected 4 child objects, found {len(child_ids)}"
|
||||
assert _VALID_SALESFORCE_IDS[13] in child_ids, "Case 1 not found in relationship"
|
||||
assert _VALID_SALESFORCE_IDS[14] in child_ids, "Case 2 not found in relationship"
|
||||
assert _VALID_SALESFORCE_IDS[48] in child_ids, "Contact not found in relationship"
|
||||
assert (
|
||||
_VALID_SALESFORCE_IDS[62] in child_ids
|
||||
), "Opportunity not found in relationship"
|
||||
|
||||
# Test querying relationships for a different account (should be empty)
|
||||
other_account_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
|
||||
assert (
|
||||
len(other_account_children) == 0
|
||||
), "Expected no children for different account"
|
||||
|
||||
print("All relationship tests passed successfully!")
|
||||
|
||||
|
||||
def test_account_with_children() -> None:
|
||||
"""
|
||||
Tests querying all accounts and retrieving their child objects.
|
||||
This test verifies that:
|
||||
1. All accounts can be retrieved
|
||||
2. Child objects are correctly linked
|
||||
3. Child object data is complete and accurate
|
||||
"""
|
||||
# First get all account IDs
|
||||
account_ids = find_ids_by_type("Account")
|
||||
assert len(account_ids) > 0, "No accounts found"
|
||||
|
||||
# For each account, get its children and verify the data
|
||||
for account_id in account_ids:
|
||||
account = get_record(account_id)
|
||||
assert account is not None, f"Could not find account {account_id}"
|
||||
|
||||
# Get all child objects
|
||||
child_ids = get_child_ids(account_id)
|
||||
|
||||
# For Acme Inc., verify specific relationships
|
||||
if account_id == _VALID_SALESFORCE_IDS[0]: # Acme Inc.
|
||||
assert (
|
||||
len(child_ids) == 4
|
||||
), f"Expected 4 children for Acme Inc., found {len(child_ids)}"
|
||||
|
||||
# Get all child records
|
||||
child_records = []
|
||||
for child_id in child_ids:
|
||||
child_record = get_record(child_id)
|
||||
if child_record is not None:
|
||||
child_records.append(child_record)
|
||||
# Verify Cases
|
||||
cases = [r for r in child_records if r.type == "Case"]
|
||||
assert (
|
||||
len(cases) == 2
|
||||
), f"Expected 2 cases for Acme Inc., found {len(cases)}"
|
||||
case_subjects = {case.data["Subject"] for case in cases}
|
||||
assert "Test Case 1" in case_subjects, "Test Case 1 not found"
|
||||
assert "Test Case 2" in case_subjects, "Test Case 2 not found"
|
||||
|
||||
# Verify Contacts
|
||||
contacts = [r for r in child_records if r.type == "Contact"]
|
||||
assert (
|
||||
len(contacts) == 1
|
||||
), f"Expected 1 contact for Acme Inc., found {len(contacts)}"
|
||||
contact = contacts[0]
|
||||
assert contact.data["FirstName"] == "Test", "Contact first name mismatch"
|
||||
assert contact.data["LastName"] == "Contact", "Contact last name mismatch"
|
||||
|
||||
# Verify Opportunities
|
||||
opportunities = [r for r in child_records if r.type == "Opportunity"]
|
||||
assert (
|
||||
len(opportunities) == 1
|
||||
), f"Expected 1 opportunity for Acme Inc., found {len(opportunities)}"
|
||||
opportunity = opportunities[0]
|
||||
assert (
|
||||
opportunity.data["Name"] == "Test Opportunity"
|
||||
), "Opportunity name mismatch"
|
||||
assert opportunity.data["Amount"] == "100000", "Opportunity amount mismatch"
|
||||
|
||||
print("All account with children tests passed successfully!")
|
||||
|
||||
|
||||
def test_relationship_updates() -> None:
|
||||
"""
|
||||
Tests that relationships are properly updated when a child object's parent reference changes.
|
||||
This test verifies:
|
||||
1. Initial relationship is created correctly
|
||||
2. When parent reference is updated, old relationship is removed
|
||||
3. New relationship is created correctly
|
||||
"""
|
||||
# Create initial test data - Contact linked to Acme Inc.
|
||||
initial_contact = [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[40],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"FirstName": "Test",
|
||||
"LastName": "Contact",
|
||||
}
|
||||
]
|
||||
create_csv_file("Contact", initial_contact, "initial_contact.csv")
|
||||
|
||||
# Verify initial relationship
|
||||
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
|
||||
assert (
|
||||
_VALID_SALESFORCE_IDS[40] in acme_children
|
||||
), "Initial relationship not created"
|
||||
|
||||
# Update contact to be linked to Globex Corp instead
|
||||
updated_contact = [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[40],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[1],
|
||||
"FirstName": "Test",
|
||||
"LastName": "Contact",
|
||||
}
|
||||
]
|
||||
create_csv_file("Contact", updated_contact, "updated_contact.csv")
|
||||
|
||||
# Verify old relationship is removed
|
||||
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
|
||||
assert (
|
||||
_VALID_SALESFORCE_IDS[40] not in acme_children
|
||||
), "Old relationship not removed"
|
||||
|
||||
# Verify new relationship is created
|
||||
globex_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
|
||||
assert _VALID_SALESFORCE_IDS[40] in globex_children, "New relationship not created"
|
||||
|
||||
print("All relationship update tests passed successfully!")
|
||||
|
||||
|
||||
def test_get_affected_parent_ids() -> None:
|
||||
"""
|
||||
Tests get_affected_parent_ids functionality by verifying:
|
||||
1. IDs that are directly in the parent_types list are included
|
||||
2. IDs that have children in the updated_ids list are included
|
||||
3. IDs that are neither of the above are not included
|
||||
"""
|
||||
# Create test data with relationships
|
||||
test_data = {
|
||||
"Account": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[0],
|
||||
"Name": "Parent Account 1",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[1],
|
||||
"Name": "Parent Account 2",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[2],
|
||||
"Name": "Not Affected Account",
|
||||
},
|
||||
],
|
||||
"Contact": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[40],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"FirstName": "Child",
|
||||
"LastName": "Contact",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Create and update CSV files for test data
|
||||
for object_type, records in test_data.items():
|
||||
create_csv_file(object_type, records)
|
||||
|
||||
# Test Case 1: Account directly in updated_ids and parent_types
|
||||
updated_ids = {_VALID_SALESFORCE_IDS[1]} # Parent Account 2
|
||||
parent_types = ["Account"]
|
||||
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
|
||||
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
|
||||
|
||||
# Test Case 2: Account with child in updated_ids
|
||||
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
|
||||
parent_types = ["Account"]
|
||||
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
|
||||
assert (
|
||||
_VALID_SALESFORCE_IDS[0] in affected_ids
|
||||
), "Parent of updated child not included"
|
||||
|
||||
# Test Case 3: Both direct and indirect affects
|
||||
updated_ids = {_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]} # Both cases
|
||||
parent_types = ["Account"]
|
||||
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
|
||||
assert len(affected_ids) == 2, "Expected exactly two affected parent IDs"
|
||||
assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included"
|
||||
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
|
||||
assert (
|
||||
_VALID_SALESFORCE_IDS[2] not in affected_ids
|
||||
), "Unaffected ID incorrectly included"
|
||||
|
||||
# Test Case 4: No matches
|
||||
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
|
||||
parent_types = ["Opportunity"] # Wrong type
|
||||
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
|
||||
assert len(affected_ids) == 0, "Should return empty list when no matches"
|
||||
|
||||
print("All get_affected_parent_ids tests passed successfully!")
|
||||
|
||||
|
||||
def main_build() -> None:
|
||||
clear_sf_db()
|
||||
create_csv_with_example_data()
|
||||
test_query()
|
||||
test_upsert()
|
||||
test_relationships()
|
||||
test_account_with_children()
|
||||
test_relationship_updates()
|
||||
test_get_affected_parent_ids()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main_build()
|
||||
475
backend/onyx/connectors/salesforce/sqlite_functions.py
Normal file
475
backend/onyx/connectors/salesforce/sqlite_functions.py
Normal file
@@ -0,0 +1,475 @@
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
|
||||
from onyx.connectors.salesforce.utils import get_sqlite_db_path
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.connectors.salesforce.utils import validate_salesforce_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.utils import batch_list
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_connection(
|
||||
isolation_level: str | None = None,
|
||||
) -> Iterator[sqlite3.Connection]:
|
||||
"""Get a database connection with proper isolation level and error handling.
|
||||
|
||||
Args:
|
||||
isolation_level: SQLite isolation level. None = default "DEFERRED",
|
||||
can be "IMMEDIATE" or "EXCLUSIVE" for more strict isolation.
|
||||
"""
|
||||
# 60 second timeout for locks
|
||||
conn = sqlite3.connect(get_sqlite_db_path(), timeout=60.0)
|
||||
|
||||
if isolation_level is not None:
|
||||
conn.isolation_level = isolation_level
|
||||
try:
|
||||
yield conn
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def init_db() -> None:
|
||||
"""Initialize the SQLite database with required tables if they don't exist."""
|
||||
# Create database directory if it doesn't exist
|
||||
os.makedirs(os.path.dirname(get_sqlite_db_path()), exist_ok=True)
|
||||
|
||||
with get_db_connection("EXCLUSIVE") as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
db_exists = os.path.exists(get_sqlite_db_path())
|
||||
|
||||
if not db_exists:
|
||||
# Enable WAL mode for better concurrent access and write performance
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.execute("PRAGMA synchronous=NORMAL")
|
||||
cursor.execute("PRAGMA temp_store=MEMORY")
|
||||
cursor.execute("PRAGMA cache_size=-2000000") # Use 2GB memory for cache
|
||||
|
||||
# Main table for storing Salesforce objects
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS salesforce_objects (
|
||||
id TEXT PRIMARY KEY,
|
||||
object_type TEXT NOT NULL,
|
||||
data TEXT NOT NULL, -- JSON serialized data
|
||||
last_modified INTEGER DEFAULT (strftime('%s', 'now')) -- Add timestamp for better cache management
|
||||
) WITHOUT ROWID -- Optimize for primary key lookups
|
||||
"""
|
||||
)
|
||||
|
||||
# Table for parent-child relationships with covering index
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS relationships (
|
||||
child_id TEXT NOT NULL,
|
||||
parent_id TEXT NOT NULL,
|
||||
PRIMARY KEY (child_id, parent_id)
|
||||
) WITHOUT ROWID -- Optimize for primary key lookups
|
||||
"""
|
||||
)
|
||||
|
||||
# New table for caching parent-child relationships with object types
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS relationship_types (
|
||||
child_id TEXT NOT NULL,
|
||||
parent_id TEXT NOT NULL,
|
||||
parent_type TEXT NOT NULL,
|
||||
PRIMARY KEY (child_id, parent_id, parent_type)
|
||||
) WITHOUT ROWID
|
||||
"""
|
||||
)
|
||||
|
||||
# Create a table for User email to ID mapping if it doesn't exist
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS user_email_map (
|
||||
email TEXT PRIMARY KEY,
|
||||
user_id TEXT, -- Nullable to allow for users without IDs
|
||||
FOREIGN KEY (user_id) REFERENCES salesforce_objects(id)
|
||||
) WITHOUT ROWID
|
||||
"""
|
||||
)
|
||||
|
||||
# Create indexes if they don't exist (SQLite ignores IF NOT EXISTS for indexes)
|
||||
def create_index_if_not_exists(index_name: str, create_statement: str) -> None:
|
||||
cursor.execute(
|
||||
f"SELECT name FROM sqlite_master WHERE type='index' AND name='{index_name}'"
|
||||
)
|
||||
if not cursor.fetchone():
|
||||
cursor.execute(create_statement)
|
||||
|
||||
create_index_if_not_exists(
|
||||
"idx_object_type",
|
||||
"""
|
||||
CREATE INDEX idx_object_type
|
||||
ON salesforce_objects(object_type, id)
|
||||
WHERE object_type IS NOT NULL
|
||||
""",
|
||||
)
|
||||
|
||||
create_index_if_not_exists(
|
||||
"idx_parent_id",
|
||||
"""
|
||||
CREATE INDEX idx_parent_id
|
||||
ON relationships(parent_id, child_id)
|
||||
""",
|
||||
)
|
||||
|
||||
create_index_if_not_exists(
|
||||
"idx_child_parent",
|
||||
"""
|
||||
CREATE INDEX idx_child_parent
|
||||
ON relationships(child_id)
|
||||
WHERE child_id IS NOT NULL
|
||||
""",
|
||||
)
|
||||
|
||||
create_index_if_not_exists(
|
||||
"idx_relationship_types_lookup",
|
||||
"""
|
||||
CREATE INDEX idx_relationship_types_lookup
|
||||
ON relationship_types(parent_type, child_id, parent_id)
|
||||
""",
|
||||
)
|
||||
|
||||
# Analyze tables to help query planner
|
||||
cursor.execute("ANALYZE relationships")
|
||||
cursor.execute("ANALYZE salesforce_objects")
|
||||
cursor.execute("ANALYZE relationship_types")
|
||||
cursor.execute("ANALYZE user_email_map")
|
||||
|
||||
# If database already existed but user_email_map needs to be populated
|
||||
cursor.execute("SELECT COUNT(*) FROM user_email_map")
|
||||
if cursor.fetchone()[0] == 0:
|
||||
_update_user_email_map(conn)
|
||||
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _update_relationship_tables(
|
||||
conn: sqlite3.Connection, child_id: str, parent_ids: set[str]
|
||||
) -> None:
|
||||
"""Update the relationship tables when a record is updated.
|
||||
|
||||
Args:
|
||||
conn: The database connection to use (must be in a transaction)
|
||||
child_id: The ID of the child record
|
||||
parent_ids: Set of parent IDs to link to
|
||||
"""
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get existing parent IDs
|
||||
cursor.execute(
|
||||
"SELECT parent_id FROM relationships WHERE child_id = ?", (child_id,)
|
||||
)
|
||||
old_parent_ids = {row[0] for row in cursor.fetchall()}
|
||||
|
||||
# Calculate differences
|
||||
parent_ids_to_remove = old_parent_ids - parent_ids
|
||||
parent_ids_to_add = parent_ids - old_parent_ids
|
||||
|
||||
# Remove old relationships
|
||||
if parent_ids_to_remove:
|
||||
cursor.executemany(
|
||||
"DELETE FROM relationships WHERE child_id = ? AND parent_id = ?",
|
||||
[(child_id, pid) for pid in parent_ids_to_remove],
|
||||
)
|
||||
# Also remove from relationship_types
|
||||
cursor.executemany(
|
||||
"DELETE FROM relationship_types WHERE child_id = ? AND parent_id = ?",
|
||||
[(child_id, pid) for pid in parent_ids_to_remove],
|
||||
)
|
||||
|
||||
# Add new relationships
|
||||
if parent_ids_to_add:
|
||||
# First add to relationships table
|
||||
cursor.executemany(
|
||||
"INSERT INTO relationships (child_id, parent_id) VALUES (?, ?)",
|
||||
[(child_id, pid) for pid in parent_ids_to_add],
|
||||
)
|
||||
|
||||
# Then get the types of the parent objects and add to relationship_types
|
||||
for parent_id in parent_ids_to_add:
|
||||
cursor.execute(
|
||||
"SELECT object_type FROM salesforce_objects WHERE id = ?",
|
||||
(parent_id,),
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
parent_type = result[0]
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO relationship_types (child_id, parent_id, parent_type)
|
||||
VALUES (?, ?, ?)
|
||||
""",
|
||||
(child_id, parent_id, parent_type),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating relationship tables: {e}")
|
||||
logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
|
||||
raise
|
||||
|
||||
|
||||
def _update_user_email_map(conn: sqlite3.Connection) -> None:
|
||||
"""Update the user_email_map table with current User objects.
|
||||
Called internally by update_sf_db_with_csv when User objects are updated.
|
||||
"""
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO user_email_map (email, user_id)
|
||||
SELECT json_extract(data, '$.Email'), id
|
||||
FROM salesforce_objects
|
||||
WHERE object_type = 'User'
|
||||
AND json_extract(data, '$.Email') IS NOT NULL
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def update_sf_db_with_csv(
|
||||
object_type: str,
|
||||
csv_download_path: str,
|
||||
delete_csv_after_use: bool = True,
|
||||
) -> list[str]:
|
||||
"""Update the SF DB with a CSV file using SQLite storage."""
|
||||
updated_ids = []
|
||||
|
||||
# Use IMMEDIATE to get a write lock at the start of the transaction
|
||||
with get_db_connection("IMMEDIATE") as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
if "Id" not in row:
|
||||
logger.warning(
|
||||
f"Row {row} does not have an Id field in {csv_download_path}"
|
||||
)
|
||||
continue
|
||||
id = row["Id"]
|
||||
parent_ids = set()
|
||||
field_to_remove: set[str] = set()
|
||||
|
||||
# Process relationships and clean data
|
||||
for field, value in row.items():
|
||||
if validate_salesforce_id(value) and field != "Id":
|
||||
parent_ids.add(value)
|
||||
field_to_remove.add(field)
|
||||
if not value:
|
||||
field_to_remove.add(field)
|
||||
|
||||
# Remove unwanted fields
|
||||
for field in field_to_remove:
|
||||
if field != "LastModifiedById":
|
||||
del row[field]
|
||||
|
||||
# Update main object data
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO salesforce_objects (id, object_type, data)
|
||||
VALUES (?, ?, ?)
|
||||
""",
|
||||
(id, object_type, json.dumps(row)),
|
||||
)
|
||||
|
||||
# Update relationships using the same connection
|
||||
_update_relationship_tables(conn, id, parent_ids)
|
||||
updated_ids.append(id)
|
||||
|
||||
# If we're updating User objects, update the email map
|
||||
if object_type == "User":
|
||||
_update_user_email_map(conn)
|
||||
|
||||
conn.commit()
|
||||
|
||||
if delete_csv_after_use:
|
||||
# Remove the csv file after it has been used
|
||||
# to successfully update the db
|
||||
os.remove(csv_download_path)
|
||||
|
||||
return updated_ids
|
||||
|
||||
|
||||
def get_child_ids(parent_id: str) -> set[str]:
|
||||
"""Get all child IDs for a given parent ID."""
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Force index usage with INDEXED BY
|
||||
cursor.execute(
|
||||
"SELECT child_id FROM relationships INDEXED BY idx_parent_id WHERE parent_id = ?",
|
||||
(parent_id,),
|
||||
)
|
||||
child_ids = {row[0] for row in cursor.fetchall()}
|
||||
return child_ids
|
||||
|
||||
|
||||
def get_type_from_id(object_id: str) -> str | None:
|
||||
"""Get the type of an object from its ID."""
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT object_type FROM salesforce_objects WHERE id = ?", (object_id,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if not result:
|
||||
logger.warning(f"Object ID {object_id} not found")
|
||||
return None
|
||||
return result[0]
|
||||
|
||||
|
||||
def get_record(
|
||||
object_id: str, object_type: str | None = None
|
||||
) -> SalesforceObject | None:
|
||||
"""Retrieve the record and return it as a SalesforceObject."""
|
||||
if object_type is None:
|
||||
object_type = get_type_from_id(object_id)
|
||||
if not object_type:
|
||||
return None
|
||||
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT data FROM salesforce_objects WHERE id = ?", (object_id,))
|
||||
result = cursor.fetchone()
|
||||
if not result:
|
||||
logger.warning(f"Object ID {object_id} not found")
|
||||
return None
|
||||
|
||||
data = json.loads(result[0])
|
||||
return SalesforceObject(id=object_id, type=object_type, data=data)
|
||||
|
||||
|
||||
def find_ids_by_type(object_type: str) -> list[str]:
|
||||
"""Find all object IDs for rows of the specified type."""
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT id FROM salesforce_objects WHERE object_type = ?", (object_type,)
|
||||
)
|
||||
return [row[0] for row in cursor.fetchall()]
|
||||
|
||||
|
||||
def get_affected_parent_ids_by_type(
|
||||
updated_ids: list[str],
|
||||
parent_types: list[str],
|
||||
batch_size: int = 500,
|
||||
) -> Iterator[tuple[str, set[str]]]:
|
||||
"""Get IDs of objects that are of the specified parent types and are either in the
|
||||
updated_ids or have children in the updated_ids. Yields tuples of (parent_type, affected_ids).
|
||||
"""
|
||||
# SQLite typically has a limit of 999 variables
|
||||
updated_ids_batches = batch_list(updated_ids, batch_size)
|
||||
updated_parent_ids: set[str] = set()
|
||||
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
for batch_ids in updated_ids_batches:
|
||||
batch_ids = list(set(batch_ids) - updated_parent_ids)
|
||||
if not batch_ids:
|
||||
continue
|
||||
id_placeholders = ",".join(["?" for _ in batch_ids])
|
||||
|
||||
for parent_type in parent_types:
|
||||
affected_ids: set[str] = set()
|
||||
|
||||
# Get directly updated objects of parent types - using index on object_type
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT id FROM salesforce_objects
|
||||
WHERE id IN ({id_placeholders})
|
||||
AND object_type = ?
|
||||
""",
|
||||
batch_ids + [parent_type],
|
||||
)
|
||||
affected_ids.update(row[0] for row in cursor.fetchall())
|
||||
|
||||
# Get parent objects of updated objects - using optimized relationship_types table
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT DISTINCT parent_id
|
||||
FROM relationship_types
|
||||
INDEXED BY idx_relationship_types_lookup
|
||||
WHERE parent_type = ?
|
||||
AND child_id IN ({id_placeholders})
|
||||
""",
|
||||
[parent_type] + batch_ids,
|
||||
)
|
||||
affected_ids.update(row[0] for row in cursor.fetchall())
|
||||
|
||||
# Remove any parent IDs that have already been processed
|
||||
new_affected_ids = affected_ids - updated_parent_ids
|
||||
# Add the new affected IDs to the set of updated parent IDs
|
||||
updated_parent_ids.update(new_affected_ids)
|
||||
|
||||
if new_affected_ids:
|
||||
yield parent_type, new_affected_ids
|
||||
|
||||
|
||||
def has_at_least_one_object_of_type(object_type: str) -> bool:
|
||||
"""Check if there is at least one object of the specified type in the database.
|
||||
|
||||
Args:
|
||||
object_type: The Salesforce object type to check
|
||||
|
||||
Returns:
|
||||
bool: True if at least one object exists, False otherwise
|
||||
"""
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) FROM salesforce_objects WHERE object_type = ?",
|
||||
(object_type,),
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
return count > 0
|
||||
|
||||
|
||||
# NULL_ID_STRING is used to indicate that the user ID was queried but not found
|
||||
# As opposed to None because it has yet to be queried at all
|
||||
NULL_ID_STRING = "N/A"
|
||||
|
||||
|
||||
def get_user_id_by_email(email: str) -> str | None:
|
||||
"""Get the Salesforce User ID for a given email address.
|
||||
|
||||
Args:
|
||||
email: The email address to look up
|
||||
|
||||
Returns:
|
||||
A tuple of (was_found, user_id):
|
||||
- was_found: True if the email exists in the table, False if not found
|
||||
- user_id: The Salesforce User ID if exists, None otherwise
|
||||
"""
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT user_id FROM user_email_map WHERE email = ?", (email,))
|
||||
result = cursor.fetchone()
|
||||
if result is None:
|
||||
return None
|
||||
return result[0]
|
||||
|
||||
|
||||
def update_email_to_id_table(email: str, id: str | None) -> None:
|
||||
"""Update the email to ID map table with a new email and ID."""
|
||||
id_to_use = id or NULL_ID_STRING
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"INSERT OR REPLACE INTO user_email_map (email, user_id) VALUES (?, ?)",
|
||||
(email, id_to_use),
|
||||
)
|
||||
conn.commit()
|
||||
@@ -1,66 +1,72 @@
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _clean_salesforce_dict(data: Union[dict, list]) -> Union[dict, list]:
|
||||
if isinstance(data, dict):
|
||||
if "records" in data.keys():
|
||||
data = data["records"]
|
||||
if isinstance(data, dict):
|
||||
if "attributes" in data.keys():
|
||||
if isinstance(data["attributes"], dict):
|
||||
data.update(data.pop("attributes"))
|
||||
@dataclass
|
||||
class SalesforceObject:
|
||||
id: str
|
||||
type: str
|
||||
data: dict[str, Any]
|
||||
|
||||
if isinstance(data, dict):
|
||||
filtered_dict = {}
|
||||
for key, value in data.items():
|
||||
if not re.search(SF_JSON_FILTER, key, re.IGNORECASE):
|
||||
if "__c" in key: # remove the custom object indicator for display
|
||||
key = key[:-3]
|
||||
if isinstance(value, (dict, list)):
|
||||
filtered_value = _clean_salesforce_dict(value)
|
||||
if filtered_value: # Only add non-empty dictionaries or lists
|
||||
filtered_dict[key] = filtered_value
|
||||
elif value is not None:
|
||||
filtered_dict[key] = value
|
||||
return filtered_dict
|
||||
elif isinstance(data, list):
|
||||
filtered_list = []
|
||||
for item in data:
|
||||
if isinstance(item, (dict, list)):
|
||||
filtered_item = _clean_salesforce_dict(item)
|
||||
if filtered_item: # Only add non-empty dictionaries or lists
|
||||
filtered_list.append(filtered_item)
|
||||
elif item is not None:
|
||||
filtered_list.append(filtered_item)
|
||||
return filtered_list
|
||||
else:
|
||||
return data
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"ID": self.id,
|
||||
"Type": self.type,
|
||||
"Data": self.data,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SalesforceObject":
|
||||
return cls(
|
||||
id=data["Id"],
|
||||
type=data["Type"],
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
||||
def _json_to_natural_language(data: Union[dict, list], indent: int = 0) -> str:
|
||||
result = []
|
||||
indent_str = " " * indent
|
||||
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
result.append(f"{indent_str}{key}:")
|
||||
result.append(_json_to_natural_language(value, indent + 2))
|
||||
else:
|
||||
result.append(f"{indent_str}{key}: {value}")
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
result.append(_json_to_natural_language(item, indent))
|
||||
else:
|
||||
result.append(f"{indent_str}{data}")
|
||||
|
||||
return "\n".join(result)
|
||||
# This defines the base path for all data files relative to this file
|
||||
# AKA BE CAREFUL WHEN MOVING THIS FILE
|
||||
BASE_DATA_PATH = os.path.join(os.path.dirname(__file__), "data")
|
||||
|
||||
|
||||
def extract_dict_text(raw_dict: dict) -> str:
|
||||
processed_dict = _clean_salesforce_dict(raw_dict)
|
||||
natural_language_dict = _json_to_natural_language(processed_dict)
|
||||
return natural_language_dict
|
||||
def get_sqlite_db_path() -> str:
|
||||
"""Get the path to the sqlite db file."""
|
||||
return os.path.join(BASE_DATA_PATH, "salesforce_db.sqlite")
|
||||
|
||||
|
||||
def get_object_type_path(object_type: str) -> str:
|
||||
"""Get the directory path for a specific object type."""
|
||||
type_dir = os.path.join(BASE_DATA_PATH, object_type)
|
||||
os.makedirs(type_dir, exist_ok=True)
|
||||
return type_dir
|
||||
|
||||
|
||||
_CHECKSUM_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"
|
||||
_LOOKUP = {format(i, "05b"): _CHECKSUM_CHARS[i] for i in range(32)}
|
||||
|
||||
|
||||
def validate_salesforce_id(salesforce_id: str) -> bool:
|
||||
"""Validate the checksum portion of an 18-character Salesforce ID.
|
||||
|
||||
Args:
|
||||
salesforce_id: An 18-character Salesforce ID
|
||||
|
||||
Returns:
|
||||
bool: True if the checksum is valid, False otherwise
|
||||
"""
|
||||
if len(salesforce_id) != 18:
|
||||
return False
|
||||
|
||||
chunks = [salesforce_id[0:5], salesforce_id[5:10], salesforce_id[10:15]]
|
||||
|
||||
checksum = salesforce_id[15:18]
|
||||
calculated_checksum = ""
|
||||
|
||||
for chunk in chunks:
|
||||
result_string = "".join(
|
||||
"1" if char.isupper() else "0" for char in reversed(chunk)
|
||||
)
|
||||
calculated_checksum += _LOOKUP[result_string]
|
||||
|
||||
return checksum == calculated_checksum
|
||||
|
||||
@@ -264,24 +264,6 @@ class SlackTextCleaner:
|
||||
message = message.replace("<!everyone>", "@everyone")
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def replace_links(message: str) -> str:
|
||||
"""Replaces slack links e.g. `<URL>` -> `URL` and `<URL|DISPLAY>` -> `DISPLAY`"""
|
||||
# Find user IDs in the message
|
||||
possible_link_matches = re.findall(r"<(.*?)>", message)
|
||||
for possible_link in possible_link_matches:
|
||||
if not possible_link:
|
||||
continue
|
||||
# Special slack patterns that aren't for links
|
||||
if possible_link[0] not in ["#", "@", "!"]:
|
||||
link_display = (
|
||||
possible_link
|
||||
if "|" not in possible_link
|
||||
else possible_link.split("|")[1]
|
||||
)
|
||||
message = message.replace(f"<{possible_link}>", link_display)
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def replace_special_catchall(message: str) -> str:
|
||||
"""Replaces pattern of <!something|another-thing> with another-thing
|
||||
|
||||
@@ -33,6 +33,7 @@ from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.html_utils import web_html_cleanup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.sitemap import list_pages_for_site
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -241,6 +242,12 @@ class WebConnector(LoadConnector):
|
||||
self.to_visit_list = extract_urls_from_sitemap(_ensure_valid_url(base_url))
|
||||
|
||||
elif web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.UPLOAD:
|
||||
# Explicitly check if running in multi-tenant mode to prevent potential security risks
|
||||
if MULTI_TENANT:
|
||||
raise ValueError(
|
||||
"Upload input for web connector is not supported in cloud environments"
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
"This is not a UI supported Web Connector flow, "
|
||||
"are you sure you want to do this?"
|
||||
|
||||
@@ -10,17 +10,21 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
time_str_to_utc,
|
||||
)
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
MAX_PAGE_SIZE = 30 # Zendesk API maximum
|
||||
_SLIM_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
class ZendeskCredentialsNotSetUpError(PermissionError):
|
||||
@@ -40,6 +44,13 @@ class ZendeskClient:
|
||||
response = requests.get(
|
||||
f"{self.base_url}/{endpoint}", auth=self.auth, params=params
|
||||
)
|
||||
|
||||
if response.status_code == 429:
|
||||
retry_after = response.headers.get("Retry-After")
|
||||
if retry_after is not None:
|
||||
# Sleep for the duration indicated by the Retry-After header
|
||||
time.sleep(int(retry_after))
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@@ -265,7 +276,7 @@ def _ticket_to_document(
|
||||
)
|
||||
|
||||
|
||||
class ZendeskConnector(LoadConnector, PollConnector):
|
||||
class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
@@ -390,6 +401,43 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
if self.content_type == "articles":
|
||||
articles = _get_articles(
|
||||
self.client, start_time=int(start) if start else None
|
||||
)
|
||||
for article in articles:
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=f"article:{article['id']}",
|
||||
)
|
||||
)
|
||||
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
|
||||
yield slim_doc_batch
|
||||
slim_doc_batch = []
|
||||
elif self.content_type == "tickets":
|
||||
tickets = _get_tickets(
|
||||
self.client, start_time=int(start) if start else None
|
||||
)
|
||||
for ticket in tickets:
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=f"zendesk_ticket_{ticket['id']}",
|
||||
)
|
||||
)
|
||||
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
|
||||
yield slim_doc_batch
|
||||
slim_doc_batch = []
|
||||
else:
|
||||
raise ValueError(f"Unsupported content_type: {self.content_type}")
|
||||
if slim_doc_batch:
|
||||
yield slim_doc_batch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
@@ -37,6 +37,7 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import FunctionCall
|
||||
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -163,6 +164,17 @@ class SearchPipeline:
|
||||
# These chunks are ordered, deduped, and contain no large chunks
|
||||
retrieved_chunks = self._get_chunks()
|
||||
|
||||
# If ee is enabled, censor the chunk sections based on user access
|
||||
# Otherwise, return the retrieved chunks
|
||||
censored_chunks = fetch_ee_implementation_or_noop(
|
||||
"onyx.external_permissions.post_query_censoring",
|
||||
"_post_query_chunk_censoring",
|
||||
retrieved_chunks,
|
||||
)(
|
||||
chunks=retrieved_chunks,
|
||||
user=self.user,
|
||||
)
|
||||
|
||||
above = self.search_query.chunks_above
|
||||
below = self.search_query.chunks_below
|
||||
|
||||
@@ -175,7 +187,7 @@ class SearchPipeline:
|
||||
seen_document_ids = set()
|
||||
|
||||
# This preserves the ordering since the chunks are retrieved in score order
|
||||
for chunk in retrieved_chunks:
|
||||
for chunk in censored_chunks:
|
||||
if chunk.document_id not in seen_document_ids:
|
||||
seen_document_ids.add(chunk.document_id)
|
||||
chunk_requests.append(
|
||||
@@ -225,7 +237,7 @@ class SearchPipeline:
|
||||
# This maintains the original chunks ordering. Note, we cannot simply sort by score here
|
||||
# as reranking flow may wipe the scores for a lot of the chunks.
|
||||
doc_chunk_ranges_map = defaultdict(list)
|
||||
for chunk in retrieved_chunks:
|
||||
for chunk in censored_chunks:
|
||||
# The list of ranges for each document is ordered by score
|
||||
doc_chunk_ranges_map[chunk.document_id].append(
|
||||
ChunkRange(
|
||||
@@ -274,11 +286,11 @@ class SearchPipeline:
|
||||
|
||||
# In case of failed parallel calls to Vespa, at least we should have the initial retrieved chunks
|
||||
doc_chunk_ind_to_chunk.update(
|
||||
{(chunk.document_id, chunk.chunk_id): chunk for chunk in retrieved_chunks}
|
||||
{(chunk.document_id, chunk.chunk_id): chunk for chunk in censored_chunks}
|
||||
)
|
||||
|
||||
# Build the surroundings for all of the initial retrieved chunks
|
||||
for chunk in retrieved_chunks:
|
||||
for chunk in censored_chunks:
|
||||
start_ind = max(0, chunk.chunk_id - above)
|
||||
end_ind = chunk.chunk_id + below
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user