mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
193 Commits
cloud_debu
...
user-filte
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
089cfe7478 | ||
|
|
c9e0d77c93 | ||
|
|
7a750dc2ca | ||
|
|
44b70a87df | ||
|
|
a05addec19 | ||
|
|
8a4d762798 | ||
|
|
c9a420ec49 | ||
|
|
beccca5fa2 | ||
|
|
66d8b8bb10 | ||
|
|
76ca650972 | ||
|
|
eb70699c0b | ||
|
|
b401f83eb6 | ||
|
|
993a1a6caf | ||
|
|
c3481c7356 | ||
|
|
3b7695539f | ||
|
|
b1957737f2 | ||
|
|
5f462056f6 | ||
|
|
0de4d61b6d | ||
|
|
7a28a5c216 | ||
|
|
d8aa21ca3a | ||
|
|
c4323573d2 | ||
|
|
46cfaa96b7 | ||
|
|
a610b6bd8d | ||
|
|
cb66aadd80 | ||
|
|
9ea2ae267e | ||
|
|
7d86b28335 | ||
|
|
4f8e48df7c | ||
|
|
d96d2fc6e9 | ||
|
|
b6dd999c1b | ||
|
|
9a09222b7d | ||
|
|
be3cfdd4a6 | ||
|
|
f5bdf9d2c9 | ||
|
|
6afd27f9c9 | ||
|
|
ccef350287 | ||
|
|
4400a945e3 | ||
|
|
384a38418b | ||
|
|
2163a138ed | ||
|
|
b6c2ecfecb | ||
|
|
ac182c74b3 | ||
|
|
cab7e60542 | ||
|
|
8e25c3c412 | ||
|
|
1470b7e038 | ||
|
|
bf78fb79f8 | ||
|
|
d972a78f45 | ||
|
|
962240031f | ||
|
|
50131ba22c | ||
|
|
439217317f | ||
|
|
c55de28423 | ||
|
|
91e32e801d | ||
|
|
2ae91f0f2b | ||
|
|
d40fd82803 | ||
|
|
97a963b4bf | ||
|
|
7f6ef1ff57 | ||
|
|
d98746b988 | ||
|
|
a76f1b4c1b | ||
|
|
4c4ff46fe3 | ||
|
|
0f9842064f | ||
|
|
d7bc32c0ec | ||
|
|
1f48de9731 | ||
|
|
a22d02ff70 | ||
|
|
dcfc621a66 | ||
|
|
eac73a1bf1 | ||
|
|
717560872f | ||
|
|
ce2572134c | ||
|
|
02f72a5c86 | ||
|
|
eb916df139 | ||
|
|
fafad5e119 | ||
|
|
a314a08309 | ||
|
|
4ce24d68f7 | ||
|
|
a95f4298ad | ||
|
|
7cd76ec404 | ||
|
|
5b5c1166ca | ||
|
|
d9e9c6973d | ||
|
|
91903141cd | ||
|
|
e329b63b89 | ||
|
|
71c2559ea9 | ||
|
|
ceb34a41d9 | ||
|
|
82eab9d704 | ||
|
|
2b8d3a6ef5 | ||
|
|
4fb129e77b | ||
|
|
f16ca1b735 | ||
|
|
e3b2c9d944 | ||
|
|
6c9c25642d | ||
|
|
2862d8bbd3 | ||
|
|
143be6a524 | ||
|
|
c2444a5cff | ||
|
|
7f8194798a | ||
|
|
e3947e4b64 | ||
|
|
98005510ad | ||
|
|
ca54bd0b21 | ||
|
|
d26f8ce852 | ||
|
|
c8090ab75b | ||
|
|
e100a5e965 | ||
|
|
ddec239fef | ||
|
|
e83542f572 | ||
|
|
8750f14647 | ||
|
|
27699c8216 | ||
|
|
6fcd712a00 | ||
|
|
b027a08698 | ||
|
|
1db778baa8 | ||
|
|
f895e5f7d0 | ||
|
|
2fc58252f4 | ||
|
|
371d1ccd8f | ||
|
|
7fb92d42a0 | ||
|
|
af2061c4db | ||
|
|
ffec19645b | ||
|
|
67d2c86250 | ||
|
|
6c018cb53f | ||
|
|
62302e3faf | ||
|
|
0460531c72 | ||
|
|
6af07a888b | ||
|
|
ea75f5cd5d | ||
|
|
b92c183022 | ||
|
|
c191e23256 | ||
|
|
66f9124135 | ||
|
|
8f0fb70bbf | ||
|
|
ef5e5c80bb | ||
|
|
03acb6587a | ||
|
|
d1ec72b5e5 | ||
|
|
3b214133a8 | ||
|
|
2232702e99 | ||
|
|
8108ff0a4b | ||
|
|
f64e78e986 | ||
|
|
08312a4394 | ||
|
|
92add655e0 | ||
|
|
d64464ca7c | ||
|
|
ccd3983802 | ||
|
|
240f3e4fff | ||
|
|
1291b3d930 | ||
|
|
d05f1997b5 | ||
|
|
aa2e2a62b9 | ||
|
|
174e5968f8 | ||
|
|
1f27606e17 | ||
|
|
60355b84c1 | ||
|
|
680ab9ea30 | ||
|
|
c2447dbb1c | ||
|
|
52bad522f8 | ||
|
|
63e5e58313 | ||
|
|
2643782e30 | ||
|
|
3eb72e5c1d | ||
|
|
9b65c23a7e | ||
|
|
b43a8e48c6 | ||
|
|
1955c1d67b | ||
|
|
3f92ed9d29 | ||
|
|
618369f4a1 | ||
|
|
2783216781 | ||
|
|
bec0f9fb23 | ||
|
|
97a03e7fc8 | ||
|
|
8d6e8269b7 | ||
|
|
9ce2c6c517 | ||
|
|
2ad8bdbc65 | ||
|
|
a83c9b40d5 | ||
|
|
340fab1375 | ||
|
|
3ec338307f | ||
|
|
27acd3387a | ||
|
|
d14ef431a7 | ||
|
|
9bffeb65af | ||
|
|
f4806da653 | ||
|
|
e2700b2bbd | ||
|
|
fc81a3fb12 | ||
|
|
2203cfabea | ||
|
|
f4050306d6 | ||
|
|
2d960a477f | ||
|
|
8837b8ea79 | ||
|
|
3dfb214f73 | ||
|
|
18d7262608 | ||
|
|
09b879ee73 | ||
|
|
aaa668c963 | ||
|
|
edb877f4bc | ||
|
|
eb369caefb | ||
|
|
b9567eabd7 | ||
|
|
13bbf67091 | ||
|
|
457a4c73f0 | ||
|
|
ce37688b5b | ||
|
|
4e2c90f4af | ||
|
|
513dd8a319 | ||
|
|
71c5043832 | ||
|
|
64b6f15e95 | ||
|
|
35022f5f09 | ||
|
|
0d44014c16 | ||
|
|
1b9e9f48fa | ||
|
|
05fb5aa27b | ||
|
|
3b645b72a3 | ||
|
|
fe770b5c3a | ||
|
|
1eaf885f50 | ||
|
|
a187aa508c | ||
|
|
aa4bfa2a78 | ||
|
|
9011b8a139 | ||
|
|
59c774353a | ||
|
|
b458d504af | ||
|
|
f83e7bfcd9 | ||
|
|
4d2e26ce4b | ||
|
|
817fdc1f36 |
18
.github/pull_request_template.md
vendored
18
.github/pull_request_template.md
vendored
@@ -6,24 +6,6 @@
|
||||
[Describe the tests you ran to verify your changes]
|
||||
|
||||
|
||||
## Accepted Risk (provide if relevant)
|
||||
N/A
|
||||
|
||||
|
||||
## Related Issue(s) (provide if relevant)
|
||||
N/A
|
||||
|
||||
|
||||
## Mental Checklist:
|
||||
- All of the automated tests pass
|
||||
- All PR comments are addressed and marked resolved
|
||||
- If there are migrations, they have been rebased to latest main
|
||||
- If there are new dependencies, they are added to the requirements
|
||||
- If there are new environment variables, they are added to all of the deployment methods
|
||||
- If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
|
||||
- Docker images build and basic functionalities work
|
||||
- Author has done a final read through of the PR right before merge
|
||||
|
||||
## Backporting (check the box to trigger backport action)
|
||||
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
|
||||
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
|
||||
|
||||
@@ -66,6 +66,7 @@ jobs:
|
||||
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
NEXT_PUBLIC_GTM_ENABLED=true
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
@@ -118,6 +118,6 @@ jobs:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
image-ref: docker.io/onyxdotapp/onyx-model-server:${{ github.ref_name }}
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
timeout: "10m"
|
||||
|
||||
14
.github/workflows/pr-python-connector-tests.yml
vendored
14
.github/workflows/pr-python-connector-tests.yml
vendored
@@ -26,7 +26,19 @@ env:
|
||||
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
# Slab
|
||||
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
|
||||
|
||||
# Zendesk
|
||||
ZENDESK_SUBDOMAIN: ${{ secrets.ZENDESK_SUBDOMAIN }}
|
||||
ZENDESK_EMAIL: ${{ secrets.ZENDESK_EMAIL }}
|
||||
ZENDESK_TOKEN: ${{ secrets.ZENDESK_TOKEN }}
|
||||
# Salesforce
|
||||
SF_USERNAME: ${{ secrets.SF_USERNAME }}
|
||||
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
|
||||
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
|
||||
# Airtable
|
||||
AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
|
||||
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
|
||||
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
|
||||
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
|
||||
jobs:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
|
||||
2
.vscode/env_template.txt
vendored
2
.vscode/env_template.txt
vendored
@@ -5,6 +5,8 @@
|
||||
# For local dev, often user Authentication is not needed
|
||||
AUTH_TYPE=disabled
|
||||
|
||||
# Skip warm up for dev
|
||||
SKIP_WARM_UP=True
|
||||
|
||||
# Always keep these on for Dev
|
||||
# Logs all model prompts to stdout
|
||||
|
||||
44
.vscode/launch.template.jsonc
vendored
44
.vscode/launch.template.jsonc
vendored
@@ -28,6 +28,7 @@
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
@@ -51,7 +52,8 @@
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery beat"
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
@@ -269,6 +271,31 @@
|
||||
},
|
||||
"consoleTitle": "Celery indexing Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=solo",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery beat",
|
||||
"type": "debugpy",
|
||||
@@ -355,5 +382,20 @@
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "Install Python Requirements",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"-c",
|
||||
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
@@ -12,6 +12,10 @@ As an open source project in a rapidly changing space, we welcome all contributi
|
||||
|
||||
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.
|
||||
|
||||
To ensure that your contribution is aligned with the project's direction, please reach out to Hagen (or any other maintainer) on the Onyx team
|
||||
via [Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA) /
|
||||
[Discord](https://discord.gg/TDJ59cGV2X) or [email](mailto:founders@onyx.app).
|
||||
|
||||
Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
|
||||
will be marked with the `approved by maintainers` label.
|
||||
Issues marked `good first issue` are an especially great place to start.
|
||||
@@ -23,8 +27,8 @@ If you have a new/different contribution in mind, we'd love to hear about it!
|
||||
Your input is vital to making sure that Onyx moves in the right direction.
|
||||
Before starting on implementation, please raise a GitHub issue.
|
||||
|
||||
And always feel free to message us (Chris Weaver / Yuhong Sun) on
|
||||
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ) /
|
||||
Also, always feel free to message the founders (Chris Weaver / Yuhong Sun) on
|
||||
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA) /
|
||||
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
|
||||
|
||||
### Contributing Code
|
||||
@@ -42,7 +46,7 @@ Our goal is to make contributing as easy as possible. If you run into any issues
|
||||
That way we can help future contributors and users can avoid the same issue.
|
||||
|
||||
We also have support channels and generally interesting discussions on our
|
||||
[Slack](https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ)
|
||||
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA)
|
||||
and
|
||||
[Discord](https://discord.gg/TDJ59cGV2X).
|
||||
|
||||
@@ -123,7 +127,47 @@ Once the above is done, navigate to `onyx/web` run:
|
||||
npm i
|
||||
```
|
||||
|
||||
#### Docker containers for external software
|
||||
## Formatting and Linting
|
||||
|
||||
### Backend
|
||||
|
||||
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
|
||||
First, install pre-commit (if you don't have it already) following the instructions
|
||||
[here](https://pre-commit.com/#installation).
|
||||
|
||||
With the virtual environment active, install the pre-commit library with:
|
||||
|
||||
```bash
|
||||
pip install pre-commit
|
||||
```
|
||||
|
||||
Then, from the `onyx/backend` directory, run:
|
||||
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
Additionally, we use `mypy` for static type checking.
|
||||
Onyx is fully type-annotated, and we want to keep it that way!
|
||||
To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend` directory.
|
||||
|
||||
### Web
|
||||
|
||||
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `onyx/web` directory.
|
||||
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
|
||||
Please double check that prettier passes before creating a pull request.
|
||||
|
||||
# Running the application for development
|
||||
|
||||
## Developing using VSCode Debugger (recommended)
|
||||
|
||||
We highly recommend using VSCode debugger for development.
|
||||
See [CONTRIBUTING_VSCODE.md](./CONTRIBUTING_VSCODE.md) for more details.
|
||||
|
||||
Otherwise, you can follow the instructions below to run the application for development.
|
||||
|
||||
## Manually running the application for development
|
||||
### Docker containers for external software
|
||||
|
||||
You will need Docker installed to run these containers.
|
||||
|
||||
@@ -135,7 +179,7 @@ docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db
|
||||
|
||||
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
|
||||
|
||||
#### Running Onyx locally
|
||||
### Running Onyx locally
|
||||
|
||||
To start the frontend, navigate to `onyx/web` and run:
|
||||
|
||||
@@ -223,35 +267,6 @@ If you want to make changes to Onyx and run those changes in Docker, you can als
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d --build
|
||||
```
|
||||
|
||||
### Formatting and Linting
|
||||
|
||||
#### Backend
|
||||
|
||||
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
|
||||
First, install pre-commit (if you don't have it already) following the instructions
|
||||
[here](https://pre-commit.com/#installation).
|
||||
|
||||
With the virtual environment active, install the pre-commit library with:
|
||||
|
||||
```bash
|
||||
pip install pre-commit
|
||||
```
|
||||
|
||||
Then, from the `onyx/backend` directory, run:
|
||||
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
Additionally, we use `mypy` for static type checking.
|
||||
Onyx is fully type-annotated, and we want to keep it that way!
|
||||
To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend` directory.
|
||||
|
||||
#### Web
|
||||
|
||||
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `onyx/web` directory.
|
||||
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
|
||||
Please double check that prettier passes before creating a pull request.
|
||||
|
||||
### Release Process
|
||||
|
||||
|
||||
29
CONTRIBUTING_VSCODE.md
Normal file
29
CONTRIBUTING_VSCODE.md
Normal file
@@ -0,0 +1,29 @@
|
||||
# VSCode Debugging Setup
|
||||
|
||||
This guide explains how to set up and use VSCode's debugging capabilities with this project.
|
||||
|
||||
## Initial Setup
|
||||
|
||||
1. **Environment Setup**:
|
||||
- Copy `.vscode/.env.template` to `.vscode/.env`
|
||||
- Fill in the necessary environment variables in `.vscode/.env`
|
||||
2. **launch.json**:
|
||||
- Copy `.vscode/launch.template.jsonc` to `.vscode/launch.json`
|
||||
|
||||
## Using the Debugger
|
||||
|
||||
Before starting, make sure the Docker Daemon is running.
|
||||
|
||||
1. Open the Debug view in VSCode (Cmd+Shift+D on macOS)
|
||||
2. From the dropdown at the top, select "Clear and Restart External Volumes and Containers" and press the green play button
|
||||
3. From the dropdown at the top, select "Run All Onyx Services" and press the green play button
|
||||
4. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
|
||||
5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
|
||||
6. Use the debug toolbar to step through code, inspect variables, etc.
|
||||
|
||||
## Features
|
||||
|
||||
- Hot reload is enabled for the web server and API servers
|
||||
- Python debugging is configured with debugpy
|
||||
- Environment variables are loaded from `.vscode/.env`
|
||||
- Console output is organized in the integrated terminal with labeled tabs
|
||||
18
README.md
18
README.md
@@ -3,7 +3,7 @@
|
||||
<a name="readme-top"></a>
|
||||
|
||||
<h2 align="center">
|
||||
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/LogoOnyx.png?raw=true)" /></a>
|
||||
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/OnyxLogoCropped.jpg?raw=true)" /></a>
|
||||
</h2>
|
||||
|
||||
<p align="center">
|
||||
@@ -13,7 +13,7 @@
|
||||
<a href="https://docs.onyx.app/" target="_blank">
|
||||
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
|
||||
</a>
|
||||
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ" target="_blank">
|
||||
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
|
||||
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
|
||||
</a>
|
||||
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
|
||||
@@ -24,7 +24,7 @@
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<strong>[Onyx](https://www.onyx.app/)</strong> (Formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
|
||||
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
|
||||
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
|
||||
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
|
||||
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
|
||||
@@ -133,15 +133,3 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md
|
||||
## ⭐Star History
|
||||
|
||||
[](https://star-history.com/#onyx-dot-app/onyx&Date)
|
||||
|
||||
## ✨Contributors
|
||||
|
||||
<a href="https://github.com/onyx-dot-app/onyx/graphs/contributors">
|
||||
<img alt="contributors" src="https://contrib.rocks/image?repo=onyx-dot-app/onyx"/>
|
||||
</a>
|
||||
|
||||
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
|
||||
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
|
||||
↑ Back to Top ↑
|
||||
</a>
|
||||
</p>
|
||||
|
||||
1
backend/.gitignore
vendored
1
backend/.gitignore
vendored
@@ -9,3 +9,4 @@ api_keys.py
|
||||
vespa-app.zip
|
||||
dynamic_config_storage/
|
||||
celerybeat-schedule*
|
||||
onyx/connectors/salesforce/data/
|
||||
@@ -4,7 +4,7 @@ from onyx.configs.app_configs import USE_IAM_AUTH
|
||||
from onyx.configs.app_configs import POSTGRES_HOST
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.app_configs import AWS_REGION
|
||||
from onyx.configs.app_configs import AWS_REGION_NAME
|
||||
from onyx.db.engine import build_connection_string
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from sqlalchemy import event
|
||||
@@ -120,7 +120,7 @@ def provide_iam_token_for_alembic(
|
||||
) -> None:
|
||||
if USE_IAM_AUTH:
|
||||
# Database connection settings
|
||||
region = AWS_REGION
|
||||
region = AWS_REGION_NAME
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
"""add shortcut option for users
|
||||
|
||||
Revision ID: 027381bce97c
|
||||
Revises: 6fc7886d665d
|
||||
Create Date: 2025-01-14 12:14:00.814390
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "027381bce97c"
|
||||
down_revision = "6fc7886d665d"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"shortcut_enabled", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "shortcut_enabled")
|
||||
@@ -0,0 +1,36 @@
|
||||
"""add index to index_attempt.time_created
|
||||
|
||||
Revision ID: 0f7ff6d75b57
|
||||
Revises: 369644546676
|
||||
Create Date: 2025-01-10 14:01:14.067144
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0f7ff6d75b57"
|
||||
down_revision = "fec3db967bf7"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_index(
|
||||
op.f("ix_index_attempt_status"),
|
||||
"index_attempt",
|
||||
["status"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
op.f("ix_index_attempt_time_created"),
|
||||
"index_attempt",
|
||||
["time_created"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f("ix_index_attempt_time_created"), table_name="index_attempt")
|
||||
|
||||
op.drop_index(op.f("ix_index_attempt_status"), table_name="index_attempt")
|
||||
@@ -0,0 +1,24 @@
|
||||
"""add chunk count to document
|
||||
|
||||
Revision ID: 2955778aa44c
|
||||
Revises: c0aab6edb6dd
|
||||
Create Date: 2025-01-04 11:39:43.268612
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2955778aa44c"
|
||||
down_revision = "c0aab6edb6dd"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("document", sa.Column("chunk_count", sa.Integer(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("document", "chunk_count")
|
||||
@@ -0,0 +1,35 @@
|
||||
"""add composite index for index attempt time updated
|
||||
|
||||
Revision ID: 369644546676
|
||||
Revises: 2955778aa44c
|
||||
Create Date: 2025-01-08 15:38:17.224380
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
from sqlalchemy import text
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "369644546676"
|
||||
down_revision = "2955778aa44c"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_index(
|
||||
"ix_index_attempt_ccpair_search_settings_time_updated",
|
||||
"index_attempt",
|
||||
[
|
||||
"connector_credential_pair_id",
|
||||
"search_settings_id",
|
||||
text("time_updated DESC"),
|
||||
],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(
|
||||
"ix_index_attempt_ccpair_search_settings_time_updated",
|
||||
table_name="index_attempt",
|
||||
)
|
||||
@@ -0,0 +1,59 @@
|
||||
"""add back input prompts
|
||||
|
||||
Revision ID: 3c6531f32351
|
||||
Revises: aeda5f2df4f6
|
||||
Create Date: 2025-01-13 12:49:51.705235
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import fastapi_users_db_sqlalchemy
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3c6531f32351"
|
||||
down_revision = "aeda5f2df4f6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"inputprompt",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("prompt", sa.String(), nullable=False),
|
||||
sa.Column("content", sa.String(), nullable=False),
|
||||
sa.Column("active", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_public", sa.Boolean(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"inputprompt__user",
|
||||
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False
|
||||
),
|
||||
sa.Column("disabled", sa.Boolean(), nullable=False, default=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["input_prompt_id"],
|
||||
["inputprompt.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("inputprompt__user")
|
||||
op.drop_table("inputprompt")
|
||||
@@ -40,6 +40,6 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
|
||||
op.drop_constraint("persona_category_id_fkey", "persona", type_="foreignkey")
|
||||
op.drop_column("persona", "category_id")
|
||||
op.drop_table("persona_category")
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
"""make categories labels and many to many
|
||||
|
||||
Revision ID: 6fc7886d665d
|
||||
Revises: 3c6531f32351
|
||||
Create Date: 2025-01-13 18:12:18.029112
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6fc7886d665d"
|
||||
down_revision = "3c6531f32351"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Rename persona_category table to persona_label
|
||||
op.rename_table("persona_category", "persona_label")
|
||||
|
||||
# Create the new association table
|
||||
op.create_table(
|
||||
"persona__persona_label",
|
||||
sa.Column("persona_id", sa.Integer(), nullable=False),
|
||||
sa.Column("persona_label_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_label_id"],
|
||||
["persona_label.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("persona_id", "persona_label_id"),
|
||||
)
|
||||
|
||||
# Copy existing relationships to the new table
|
||||
op.execute(
|
||||
"""
|
||||
INSERT INTO persona__persona_label (persona_id, persona_label_id)
|
||||
SELECT id, category_id FROM persona WHERE category_id IS NOT NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# Remove the old category_id column from persona table
|
||||
op.drop_column("persona", "category_id")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Rename persona_label table back to persona_category
|
||||
op.rename_table("persona_label", "persona_category")
|
||||
|
||||
# Add back the category_id column to persona table
|
||||
op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
"persona_category_id_fkey",
|
||||
"persona",
|
||||
"persona_category",
|
||||
["category_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Copy the first label relationship back to the persona table
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET category_id = (
|
||||
SELECT persona_label_id
|
||||
FROM persona__persona_label
|
||||
WHERE persona__persona_label.persona_id = persona.id
|
||||
LIMIT 1
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop the association table
|
||||
op.drop_table("persona__persona_label")
|
||||
72
backend/alembic/versions/97dbb53fa8c8_add_syncrecord.py
Normal file
72
backend/alembic/versions/97dbb53fa8c8_add_syncrecord.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Add SyncRecord
|
||||
|
||||
Revision ID: 97dbb53fa8c8
|
||||
Revises: 369644546676
|
||||
Create Date: 2025-01-11 19:39:50.426302
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "97dbb53fa8c8"
|
||||
down_revision = "be2ab2aa50ee"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"sync_record",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("entity_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"sync_type",
|
||||
sa.Enum(
|
||||
"DOCUMENT_SET",
|
||||
"USER_GROUP",
|
||||
"CONNECTOR_DELETION",
|
||||
name="synctype",
|
||||
native_enum=False,
|
||||
length=40,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"sync_status",
|
||||
sa.Enum(
|
||||
"IN_PROGRESS",
|
||||
"SUCCESS",
|
||||
"FAILED",
|
||||
"CANCELED",
|
||||
name="syncstatus",
|
||||
native_enum=False,
|
||||
length=40,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("num_docs_synced", sa.Integer(), nullable=False),
|
||||
sa.Column("sync_start_time", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("sync_end_time", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Add index for fetch_latest_sync_record query
|
||||
op.create_index(
|
||||
"ix_sync_record_entity_id_sync_type_sync_start_time",
|
||||
"sync_record",
|
||||
["entity_id", "sync_type", "sync_start_time"],
|
||||
)
|
||||
|
||||
# Add index for cleanup_sync_records query
|
||||
op.create_index(
|
||||
"ix_sync_record_entity_id_sync_type_sync_status",
|
||||
"sync_record",
|
||||
["entity_id", "sync_type", "sync_status"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_sync_record_entity_id_sync_type_sync_status")
|
||||
op.drop_index("ix_sync_record_entity_id_sync_type_sync_start_time")
|
||||
op.drop_table("sync_record")
|
||||
@@ -0,0 +1,27 @@
|
||||
"""add pinned assistants
|
||||
|
||||
Revision ID: aeda5f2df4f6
|
||||
Revises: c5eae4a75a1b
|
||||
Create Date: 2025-01-09 16:04:10.770636
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "aeda5f2df4f6"
|
||||
down_revision = "c5eae4a75a1b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user", sa.Column("pinned_assistants", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
op.execute('UPDATE "user" SET pinned_assistants = chosen_assistants')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "pinned_assistants")
|
||||
38
backend/alembic/versions/be2ab2aa50ee_fix_capitalization.py
Normal file
38
backend/alembic/versions/be2ab2aa50ee_fix_capitalization.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""fix_capitalization
|
||||
|
||||
Revision ID: be2ab2aa50ee
|
||||
Revises: 369644546676
|
||||
Create Date: 2025-01-10 13:13:26.228960
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "be2ab2aa50ee"
|
||||
down_revision = "369644546676"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE document
|
||||
SET
|
||||
external_user_group_ids = ARRAY(
|
||||
SELECT LOWER(unnest(external_user_group_ids))
|
||||
),
|
||||
last_modified = NOW()
|
||||
WHERE
|
||||
external_user_group_ids IS NOT NULL
|
||||
AND external_user_group_ids::text[] <> ARRAY(
|
||||
SELECT LOWER(unnest(external_user_group_ids))
|
||||
)::text[]
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# No way to cleanly persist the bad state through an upgrade/downgrade
|
||||
# cycle, so we just pass
|
||||
pass
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Add chat_message__standard_answer table
|
||||
|
||||
Revision ID: c5eae4a75a1b
|
||||
Revises: 0f7ff6d75b57
|
||||
Create Date: 2025-01-15 14:08:49.688998
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c5eae4a75a1b"
|
||||
down_revision = "0f7ff6d75b57"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"chat_message__standard_answer",
|
||||
sa.Column("chat_message_id", sa.Integer(), nullable=False),
|
||||
sa.Column("standard_answer_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["chat_message_id"],
|
||||
["chat_message.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["standard_answer_id"],
|
||||
["standard_answer.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("chat_message_id", "standard_answer_id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("chat_message__standard_answer")
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Add time_updated to UserGroup and DocumentSet
|
||||
|
||||
Revision ID: fec3db967bf7
|
||||
Revises: 97dbb53fa8c8
|
||||
Create Date: 2025-01-12 15:49:02.289100
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "fec3db967bf7"
|
||||
down_revision = "97dbb53fa8c8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"document_set",
|
||||
sa.Column(
|
||||
"time_last_modified_by_user",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user_group",
|
||||
sa.Column(
|
||||
"time_last_modified_by_user",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user_group", "time_last_modified_by_user")
|
||||
op.drop_column("document_set", "time_last_modified_by_user")
|
||||
@@ -0,0 +1,31 @@
|
||||
"""mapping for anonymous user path
|
||||
|
||||
Revision ID: a4f6ee863c47
|
||||
Revises: 14a83a331951
|
||||
Create Date: 2025-01-04 14:16:58.697451
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a4f6ee863c47"
|
||||
down_revision = "14a83a331951"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"tenant_anonymous_user_path",
|
||||
sa.Column("tenant_id", sa.String(), primary_key=True, nullable=False),
|
||||
sa.Column("anonymous_user_path", sa.String(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("tenant_id"),
|
||||
sa.UniqueConstraint("anonymous_user_path"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("tenant_anonymous_user_path")
|
||||
@@ -3,6 +3,10 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.external_perm import fetch_external_groups_for_user
|
||||
from ee.onyx.db.user_group import fetch_user_groups_for_documents
|
||||
from ee.onyx.db.user_group import fetch_user_groups_for_user
|
||||
from ee.onyx.external_permissions.post_query_censoring import (
|
||||
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
|
||||
)
|
||||
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
||||
from onyx.access.access import (
|
||||
_get_access_for_documents as get_access_for_documents_without_groups,
|
||||
)
|
||||
@@ -10,6 +14,7 @@ from onyx.access.access import _get_acl_for_user as get_acl_for_user_without_gro
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.access.utils import prefix_external_group
|
||||
from onyx.access.utils import prefix_user_group
|
||||
from onyx.db.document import get_document_sources
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.models import User
|
||||
|
||||
@@ -52,9 +57,20 @@ def _get_access_for_documents(
|
||||
)
|
||||
doc_id_map = {doc.id: doc for doc in documents}
|
||||
|
||||
# Get all sources in one batch
|
||||
doc_id_to_source_map = get_document_sources(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
access_map = {}
|
||||
for document_id, non_ee_access in non_ee_access_dict.items():
|
||||
document = doc_id_map[document_id]
|
||||
source = doc_id_to_source_map.get(document_id)
|
||||
is_only_censored = (
|
||||
source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
|
||||
and source not in DOC_PERMISSIONS_FUNC_MAP
|
||||
)
|
||||
|
||||
ext_u_emails = (
|
||||
set(document.external_user_emails)
|
||||
@@ -70,7 +86,11 @@ def _get_access_for_documents(
|
||||
|
||||
# If the document is determined to be "public" externally (through a SYNC connector)
|
||||
# then it's given the same access level as if it were marked public within Onyx
|
||||
is_public_anywhere = document.is_public or non_ee_access.is_public
|
||||
# If its censored, then it's public anywhere during the search and then permissions are
|
||||
# applied after the search
|
||||
is_public_anywhere = (
|
||||
document.is_public or non_ee_access.is_public or is_only_censored
|
||||
)
|
||||
|
||||
# To avoid collisions of group namings between connectors, they need to be prefixed
|
||||
access_map[document_id] = DocumentAccess(
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
|
||||
import jwt
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
@@ -20,6 +22,7 @@ from ee.onyx.server.seeding import get_seed_config
|
||||
from ee.onyx.utils.secrets import extract_hashed_cookie
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -118,3 +121,17 @@ async def current_cloud_superuser(
|
||||
detail="Access denied. User must be a cloud superuser to perform this action.",
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
def generate_anonymous_user_jwt_token(tenant_id: str) -> str:
|
||||
payload = {
|
||||
"tenant_id": tenant_id,
|
||||
# Token does not expire
|
||||
"iat": datetime.utcnow(), # Issued at time
|
||||
}
|
||||
|
||||
return jwt.encode(payload, USER_AUTH_SECRET, algorithm="HS256")
|
||||
|
||||
|
||||
def decode_anonymous_user_jwt_token(token: str) -> dict:
|
||||
return jwt.decode(token, USER_AUTH_SECRET, algorithms=["HS256"])
|
||||
|
||||
@@ -6,7 +6,11 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.user_group import delete_user_group
|
||||
from ee.onyx.db.user_group import fetch_user_group
|
||||
from ee.onyx.db.user_group import mark_user_group_as_synced
|
||||
from ee.onyx.db.user_group import prepare_user_group_for_deletion
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -42,15 +46,59 @@ def monitor_usergroup_taskset(
|
||||
f"User group sync progress: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
sync_status=SyncStatus.IN_PROGRESS,
|
||||
num_docs_synced=count,
|
||||
)
|
||||
return
|
||||
|
||||
user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
|
||||
if user_group:
|
||||
if user_group.is_up_for_deletion:
|
||||
delete_user_group(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(f"Deleted usergroup. id='{usergroup_id}'")
|
||||
else:
|
||||
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(f"Synced usergroup. id='{usergroup_id}'")
|
||||
usergroup_name = user_group.name
|
||||
try:
|
||||
if user_group.is_up_for_deletion:
|
||||
# this prepare should have been run when the deletion was scheduled,
|
||||
# but run it again to be sure we're ready to go
|
||||
mark_user_group_as_synced(db_session, user_group)
|
||||
prepare_user_group_for_deletion(db_session, usergroup_id)
|
||||
delete_user_group(db_session=db_session, user_group=user_group)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=initial_count,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
|
||||
)
|
||||
else:
|
||||
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=initial_count,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
sync_status=SyncStatus.FAILED,
|
||||
num_docs_synced=initial_count,
|
||||
)
|
||||
raise e
|
||||
|
||||
rug.reset()
|
||||
|
||||
@@ -15,6 +15,12 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
|
||||
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
# This is a boolean that determines if anonymous access is public
|
||||
# Default behavior is to not make the page public and instead add a group
|
||||
# that contains all the users that we found in Confluence
|
||||
CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
|
||||
os.environ.get("CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC", "").lower() == "true"
|
||||
)
|
||||
# In seconds, default is 5 minutes
|
||||
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
@@ -55,3 +61,5 @@ POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
|
||||
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
|
||||
|
||||
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
|
||||
|
||||
ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user"
|
||||
|
||||
@@ -2,6 +2,7 @@ import datetime
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import case
|
||||
from sqlalchemy import cast
|
||||
from sqlalchemy import Date
|
||||
@@ -14,6 +15,9 @@ from onyx.configs.constants import MessageType
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatMessageFeedback
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
|
||||
|
||||
def fetch_query_analytics(
|
||||
@@ -234,3 +238,122 @@ def fetch_persona_unique_users(
|
||||
)
|
||||
|
||||
return [tuple(row) for row in db_session.execute(query).all()]
|
||||
|
||||
|
||||
def fetch_assistant_message_analytics(
|
||||
db_session: Session,
|
||||
assistant_id: int,
|
||||
start: datetime.datetime,
|
||||
end: datetime.datetime,
|
||||
) -> list[tuple[int, datetime.date]]:
|
||||
"""
|
||||
Gets the daily message counts for a specific assistant in the given time range.
|
||||
"""
|
||||
query = (
|
||||
select(
|
||||
func.count(ChatMessage.id),
|
||||
cast(ChatMessage.time_sent, Date),
|
||||
)
|
||||
.join(
|
||||
ChatSession,
|
||||
ChatMessage.chat_session_id == ChatSession.id,
|
||||
)
|
||||
.where(
|
||||
or_(
|
||||
ChatMessage.alternate_assistant_id == assistant_id,
|
||||
ChatSession.persona_id == assistant_id,
|
||||
),
|
||||
ChatMessage.time_sent >= start,
|
||||
ChatMessage.time_sent <= end,
|
||||
ChatMessage.message_type == MessageType.ASSISTANT,
|
||||
)
|
||||
.group_by(cast(ChatMessage.time_sent, Date))
|
||||
.order_by(cast(ChatMessage.time_sent, Date))
|
||||
)
|
||||
|
||||
return [tuple(row) for row in db_session.execute(query).all()]
|
||||
|
||||
|
||||
def fetch_assistant_unique_users(
|
||||
db_session: Session,
|
||||
assistant_id: int,
|
||||
start: datetime.datetime,
|
||||
end: datetime.datetime,
|
||||
) -> list[tuple[int, datetime.date]]:
|
||||
"""
|
||||
Gets the daily unique user counts for a specific assistant in the given time range.
|
||||
"""
|
||||
query = (
|
||||
select(
|
||||
func.count(func.distinct(ChatSession.user_id)),
|
||||
cast(ChatMessage.time_sent, Date),
|
||||
)
|
||||
.join(
|
||||
ChatSession,
|
||||
ChatMessage.chat_session_id == ChatSession.id,
|
||||
)
|
||||
.where(
|
||||
or_(
|
||||
ChatMessage.alternate_assistant_id == assistant_id,
|
||||
ChatSession.persona_id == assistant_id,
|
||||
),
|
||||
ChatMessage.time_sent >= start,
|
||||
ChatMessage.time_sent <= end,
|
||||
ChatMessage.message_type == MessageType.ASSISTANT,
|
||||
)
|
||||
.group_by(cast(ChatMessage.time_sent, Date))
|
||||
.order_by(cast(ChatMessage.time_sent, Date))
|
||||
)
|
||||
|
||||
return [tuple(row) for row in db_session.execute(query).all()]
|
||||
|
||||
|
||||
def fetch_assistant_unique_users_total(
|
||||
db_session: Session,
|
||||
assistant_id: int,
|
||||
start: datetime.datetime,
|
||||
end: datetime.datetime,
|
||||
) -> int:
|
||||
"""
|
||||
Gets the total number of distinct users who have sent or received messages from
|
||||
the specified assistant in the given time range.
|
||||
"""
|
||||
query = (
|
||||
select(func.count(func.distinct(ChatSession.user_id)))
|
||||
.select_from(ChatMessage)
|
||||
.join(
|
||||
ChatSession,
|
||||
ChatMessage.chat_session_id == ChatSession.id,
|
||||
)
|
||||
.where(
|
||||
or_(
|
||||
ChatMessage.alternate_assistant_id == assistant_id,
|
||||
ChatSession.persona_id == assistant_id,
|
||||
),
|
||||
ChatMessage.time_sent >= start,
|
||||
ChatMessage.time_sent <= end,
|
||||
ChatMessage.message_type == MessageType.ASSISTANT,
|
||||
)
|
||||
)
|
||||
|
||||
result = db_session.execute(query).scalar()
|
||||
return result if result else 0
|
||||
|
||||
|
||||
# Users can view assistant stats if they created the persona,
|
||||
# or if they are an admin
|
||||
def user_can_view_assistant_stats(
|
||||
db_session: Session, user: User | None, assistant_id: int
|
||||
) -> bool:
|
||||
# If user is None and auth is disabled, assume the user is an admin
|
||||
|
||||
if user is None or user.role == UserRole.ADMIN:
|
||||
return True
|
||||
|
||||
# Check if the user created the persona
|
||||
stmt = select(Persona).where(
|
||||
and_(Persona.id == assistant_id, Persona.user_id == user.id)
|
||||
)
|
||||
|
||||
persona = db_session.execute(stmt).scalar_one_or_none()
|
||||
return persona is not None
|
||||
|
||||
@@ -5,7 +5,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.access.utils import prefix_group_w_source
|
||||
from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import Document as DbDocument
|
||||
|
||||
@@ -25,7 +25,7 @@ def upsert_document_external_perms__no_commit(
|
||||
).first()
|
||||
|
||||
prefixed_external_groups = [
|
||||
prefix_group_w_source(
|
||||
build_ext_group_name_for_onyx(
|
||||
ext_group_name=group_id,
|
||||
source=source_type,
|
||||
)
|
||||
@@ -66,7 +66,7 @@ def upsert_document_external_perms(
|
||||
).first()
|
||||
|
||||
prefixed_external_groups: set[str] = {
|
||||
prefix_group_w_source(
|
||||
build_ext_group_name_for_onyx(
|
||||
ext_group_name=group_id,
|
||||
source=source_type,
|
||||
)
|
||||
|
||||
@@ -6,10 +6,12 @@ from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.utils import prefix_group_w_source
|
||||
from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__ExternalUserGroupId
|
||||
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -60,8 +62,10 @@ def replace_user__ext_group_for_cc_pair(
|
||||
all_group_member_emails.add(user_email)
|
||||
|
||||
# batch add users if they don't exist and get their ids
|
||||
all_group_members = batch_add_ext_perm_user_if_not_exists(
|
||||
db_session=db_session, emails=list(all_group_member_emails)
|
||||
all_group_members: list[User] = batch_add_ext_perm_user_if_not_exists(
|
||||
db_session=db_session,
|
||||
# NOTE: this function handles case sensitivity for emails
|
||||
emails=list(all_group_member_emails),
|
||||
)
|
||||
|
||||
delete_user__ext_group_for_cc_pair__no_commit(
|
||||
@@ -83,12 +87,14 @@ def replace_user__ext_group_for_cc_pair(
|
||||
f" with email {user_email} not found"
|
||||
)
|
||||
continue
|
||||
external_group_id = build_ext_group_name_for_onyx(
|
||||
ext_group_name=external_group.id,
|
||||
source=source,
|
||||
)
|
||||
new_external_permissions.append(
|
||||
User__ExternalUserGroupId(
|
||||
user_id=user_id,
|
||||
external_user_group_id=prefix_group_w_source(
|
||||
external_group.id, source
|
||||
),
|
||||
external_user_group_id=external_group_id,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
)
|
||||
@@ -106,3 +112,21 @@ def fetch_external_groups_for_user(
|
||||
User__ExternalUserGroupId.user_id == user_id
|
||||
)
|
||||
).all()
|
||||
|
||||
|
||||
def fetch_external_groups_for_user_email_and_group_ids(
|
||||
db_session: Session,
|
||||
user_email: str,
|
||||
group_ids: list[str],
|
||||
) -> list[User__ExternalUserGroupId]:
|
||||
user = get_user_by_email(db_session=db_session, email=user_email)
|
||||
if user is None:
|
||||
return []
|
||||
user_id = user.id
|
||||
user_ext_groups = db_session.scalars(
|
||||
select(User__ExternalUserGroupId).where(
|
||||
User__ExternalUserGroupId.user_id == user_id,
|
||||
User__ExternalUserGroupId.external_user_group_id.in_(group_ids),
|
||||
)
|
||||
).all()
|
||||
return list(user_ext_groups)
|
||||
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.constants import TokenRateLimitScope
|
||||
from onyx.db.models import TokenRateLimit
|
||||
from onyx.db.models import TokenRateLimit__UserGroup
|
||||
@@ -20,10 +21,11 @@ from onyx.server.token_rate_limits.models import TokenRateLimitArgs
|
||||
def _add_user_filters(
|
||||
stmt: Select, user: User | None, get_editable: bool = True
|
||||
) -> Select:
|
||||
# If user is None, assume the user is an admin or auth is disabled
|
||||
if user is None or user.role == UserRole.ADMIN:
|
||||
# If user is None and auth is disabled, assume the user is an admin
|
||||
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
|
||||
return stmt
|
||||
|
||||
stmt = stmt.distinct()
|
||||
TRLimit_UG = aliased(TokenRateLimit__UserGroup)
|
||||
User__UG = aliased(User__UserGroup)
|
||||
|
||||
@@ -46,6 +48,12 @@ def _add_user_filters(
|
||||
that the user isn't a curator for
|
||||
- if we are not editing, we show all token_rate_limits in the groups the user curates
|
||||
"""
|
||||
|
||||
# If user is None, this is an anonymous user and we should only show public token_rate_limits
|
||||
if user is None:
|
||||
where_clause = TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
|
||||
return stmt.where(where_clause)
|
||||
|
||||
where_clause = User__UG.user_id == user.id
|
||||
if user.role == UserRole.CURATOR and get_editable:
|
||||
where_clause &= User__UG.is_curator == True # noqa: E712
|
||||
@@ -103,10 +111,10 @@ def insert_user_group_token_rate_limit(
|
||||
return token_limit
|
||||
|
||||
|
||||
def fetch_user_group_token_rate_limits(
|
||||
def fetch_user_group_token_rate_limits_for_user(
|
||||
db_session: Session,
|
||||
group_id: int,
|
||||
user: User | None = None,
|
||||
user: User | None,
|
||||
enabled_only: bool = False,
|
||||
ordered: bool = True,
|
||||
get_editable: bool = True,
|
||||
|
||||
@@ -122,7 +122,7 @@ def _cleanup_document_set__user_group_relationships__no_commit(
|
||||
)
|
||||
|
||||
|
||||
def validate_user_creation_permissions(
|
||||
def validate_object_creation_for_user(
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
target_group_ids: list[int] | None = None,
|
||||
@@ -374,7 +374,9 @@ def _add_user_group__cc_pair_relationships__no_commit(
|
||||
|
||||
|
||||
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
|
||||
db_user_group = UserGroup(name=user_group.name)
|
||||
db_user_group = UserGroup(
|
||||
name=user_group.name, time_last_modified_by_user=func.now()
|
||||
)
|
||||
db_session.add(db_user_group)
|
||||
db_session.flush() # give the group an ID
|
||||
|
||||
@@ -440,32 +442,108 @@ def remove_curator_status__no_commit(db_session: Session, user: User) -> None:
|
||||
_validate_curator_status__no_commit(db_session, [user])
|
||||
|
||||
|
||||
def update_user_curator_relationship(
|
||||
def _validate_curator_relationship_update_requester(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
set_curator_request: SetCuratorRequest,
|
||||
user_making_change: User | None = None,
|
||||
) -> None:
|
||||
user = fetch_user_by_id(db_session, set_curator_request.user_id)
|
||||
if not user:
|
||||
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
|
||||
"""
|
||||
This function validates that the user making the change has the necessary permissions
|
||||
to update the curator relationship for the target user in the given user group.
|
||||
"""
|
||||
|
||||
if user.role == UserRole.ADMIN:
|
||||
if user_making_change is None or user_making_change.role == UserRole.ADMIN:
|
||||
return
|
||||
|
||||
# check if the user making the change is a curator in the group they are changing the curator relationship for
|
||||
user_making_change_curator_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user_making_change.id,
|
||||
# only check if the user making the change is a curator if they are a curator
|
||||
# otherwise, they are a global_curator and can update the curator relationship
|
||||
# for any group they are a member of
|
||||
only_curator_groups=user_making_change.role == UserRole.CURATOR,
|
||||
)
|
||||
requestor_curator_group_ids = [
|
||||
group.id for group in user_making_change_curator_groups
|
||||
]
|
||||
if user_group_id not in requestor_curator_group_ids:
|
||||
raise ValueError(
|
||||
f"User '{user.email}' is an admin and therefore has all permissions "
|
||||
f"user making change {user_making_change.email} is not a curator,"
|
||||
f" admin, or global_curator for group '{user_group_id}'"
|
||||
)
|
||||
|
||||
|
||||
def _validate_curator_relationship_update_request(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
target_user: User,
|
||||
) -> None:
|
||||
"""
|
||||
This function validates that the curator_relationship_update request itself is valid.
|
||||
"""
|
||||
if target_user.role == UserRole.ADMIN:
|
||||
raise ValueError(
|
||||
f"User '{target_user.email}' is an admin and therefore has all permissions "
|
||||
"of a curator. If you'd like this user to only have curator permissions, "
|
||||
"you must update their role to BASIC then assign them to be CURATOR in the "
|
||||
"appropriate groups."
|
||||
)
|
||||
elif target_user.role == UserRole.GLOBAL_CURATOR:
|
||||
raise ValueError(
|
||||
f"User '{target_user.email}' is a global_curator and therefore has all "
|
||||
"permissions of a curator for all groups. If you'd like this user to only "
|
||||
"have curator permissions for a specific group, you must update their role "
|
||||
"to BASIC then assign them to be CURATOR in the appropriate groups."
|
||||
)
|
||||
elif target_user.role not in [UserRole.CURATOR, UserRole.BASIC]:
|
||||
raise ValueError(
|
||||
f"This endpoint can only be used to update the curator relationship for "
|
||||
"users with the CURATOR or BASIC role. \n"
|
||||
f"Target user: {target_user.email} \n"
|
||||
f"Target user role: {target_user.role} \n"
|
||||
)
|
||||
|
||||
# check if the target user is in the group they are changing the curator relationship for
|
||||
requested_user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=set_curator_request.user_id,
|
||||
user_id=target_user.id,
|
||||
only_curator_groups=False,
|
||||
)
|
||||
|
||||
group_ids = [group.id for group in requested_user_groups]
|
||||
if user_group_id not in group_ids:
|
||||
raise ValueError(f"user is not in group '{user_group_id}'")
|
||||
raise ValueError(
|
||||
f"target user {target_user.email} is not in group '{user_group_id}'"
|
||||
)
|
||||
|
||||
|
||||
def update_user_curator_relationship(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
set_curator_request: SetCuratorRequest,
|
||||
user_making_change: User | None = None,
|
||||
) -> None:
|
||||
target_user = fetch_user_by_id(db_session, set_curator_request.user_id)
|
||||
if not target_user:
|
||||
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
|
||||
|
||||
_validate_curator_relationship_update_request(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
target_user=target_user,
|
||||
)
|
||||
|
||||
_validate_curator_relationship_update_requester(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
user_making_change=user_making_change,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"user_making_change={user_making_change.email if user_making_change else 'None'} is "
|
||||
f"updating the curator relationship for user={target_user.email} "
|
||||
f"in group={user_group_id} to is_curator={set_curator_request.is_curator}"
|
||||
)
|
||||
|
||||
relationship_to_update = (
|
||||
db_session.query(User__UserGroup)
|
||||
@@ -486,7 +564,7 @@ def update_user_curator_relationship(
|
||||
)
|
||||
db_session.add(relationship_to_update)
|
||||
|
||||
_validate_curator_status__no_commit(db_session, [user])
|
||||
_validate_curator_status__no_commit(db_session, [target_user])
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -554,6 +632,10 @@ def update_user_group(
|
||||
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
|
||||
).unique()
|
||||
_validate_curator_status__no_commit(db_session, list(removed_users))
|
||||
|
||||
# update "time_updated" to now
|
||||
db_user_group.time_last_modified_by_user = func.now()
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
@@ -623,7 +705,10 @@ def delete_user_group_cc_pair_relationship__no_commit(
|
||||
connector_credential_pair_id matches the given cc_pair_id.
|
||||
|
||||
Should be used very carefully (only for connectors that are being deleted)."""
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise ValueError(f"Connector Credential Pair '{cc_pair_id}' does not exist")
|
||||
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
# This is a group that we use to store all the users that we found in Confluence
|
||||
# Instead of setting a page to public, we just add this group so that the page
|
||||
# is only accessible to users who have confluence accounts.
|
||||
ALL_CONF_EMAILS_GROUP_NAME = "All_Confluence_Users_Found_By_Onyx"
|
||||
@@ -4,6 +4,8 @@ https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.htm
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
@@ -22,7 +24,9 @@ _REQUEST_PAGINATION_LIMIT = 5000
|
||||
def _get_server_space_permissions(
|
||||
confluence_client: OnyxConfluence, space_key: str
|
||||
) -> ExternalAccess:
|
||||
space_permissions = confluence_client.get_space_permissions(space_key=space_key)
|
||||
space_permissions = confluence_client.get_all_space_permissions_server(
|
||||
space_key=space_key
|
||||
)
|
||||
|
||||
viewspace_permissions = []
|
||||
for permission_category in space_permissions:
|
||||
@@ -31,14 +35,32 @@ def _get_server_space_permissions(
|
||||
permission_category.get("spacePermissions", [])
|
||||
)
|
||||
|
||||
is_public = False
|
||||
user_names = set()
|
||||
group_names = set()
|
||||
for permission in viewspace_permissions:
|
||||
if user_name := permission.get("userName"):
|
||||
user_name = permission.get("userName")
|
||||
if user_name:
|
||||
user_names.add(user_name)
|
||||
if group_name := permission.get("groupName"):
|
||||
group_name = permission.get("groupName")
|
||||
if group_name:
|
||||
group_names.add(group_name)
|
||||
|
||||
# It seems that if anonymous access is turned on for the site and space,
|
||||
# then the space is publicly accessible.
|
||||
# For confluence server, we make a group that contains all users
|
||||
# that exist in confluence and then just add that group to the space permissions
|
||||
# if anonymous access is turned on for the site and space or we set is_public = True
|
||||
# if they set the env variable CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC to True so
|
||||
# that we can support confluence server deployments that want anonymous access
|
||||
# to be public (we cant test this because its paywalled)
|
||||
if user_name is None and group_name is None:
|
||||
# Defaults to False
|
||||
if CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC:
|
||||
is_public = True
|
||||
else:
|
||||
group_names.add(ALL_CONF_EMAILS_GROUP_NAME)
|
||||
|
||||
user_emails = set()
|
||||
for user_name in user_names:
|
||||
user_email = get_user_email_from_username__server(confluence_client, user_name)
|
||||
@@ -47,14 +69,17 @@ def _get_server_space_permissions(
|
||||
else:
|
||||
logger.warning(f"Email for user {user_name} not found in Confluence")
|
||||
|
||||
if not user_emails and not group_names:
|
||||
logger.warning(
|
||||
"No user emails or group names found in Confluence space permissions"
|
||||
f"\nSpace key: {space_key}"
|
||||
f"\nSpace permissions: {space_permissions}"
|
||||
)
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
external_user_group_ids=group_names,
|
||||
# TODO: Check if the space is publicly accessible
|
||||
# Currently, we assume the space is not public
|
||||
# We need to check if anonymous access is turned on for the site and space
|
||||
# This information is paywalled so it remains unimplemented
|
||||
is_public=False,
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
|
||||
@@ -134,7 +159,7 @@ def _get_space_permissions(
|
||||
|
||||
def _extract_read_access_restrictions(
|
||||
confluence_client: OnyxConfluence, restrictions: dict[str, Any]
|
||||
) -> ExternalAccess | None:
|
||||
) -> tuple[set[str], set[str]]:
|
||||
"""
|
||||
Converts a page's restrictions dict into an ExternalAccess object.
|
||||
If there are no restrictions, then return None
|
||||
@@ -177,21 +202,57 @@ def _extract_read_access_restrictions(
|
||||
group["name"] for group in read_access_group_jsons if group.get("name")
|
||||
]
|
||||
|
||||
return set(read_access_user_emails), set(read_access_group_names)
|
||||
|
||||
|
||||
def _get_all_page_restrictions(
|
||||
confluence_client: OnyxConfluence,
|
||||
perm_sync_data: dict[str, Any],
|
||||
) -> ExternalAccess | None:
|
||||
"""
|
||||
This function gets the restrictions for a page by taking the intersection
|
||||
of the page's restrictions and the restrictions of all the ancestors
|
||||
of the page.
|
||||
If the page/ancestor has no restrictions, then it is ignored (no intersection).
|
||||
If no restrictions are found anywhere, then return None, indicating that the page
|
||||
should inherit the space's restrictions.
|
||||
"""
|
||||
found_user_emails: set[str] = set()
|
||||
found_group_names: set[str] = set()
|
||||
|
||||
found_user_emails, found_group_names = _extract_read_access_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
restrictions=perm_sync_data.get("restrictions", {}),
|
||||
)
|
||||
|
||||
ancestors: list[dict[str, Any]] = perm_sync_data.get("ancestors", [])
|
||||
for ancestor in ancestors:
|
||||
ancestor_user_emails, ancestor_group_names = _extract_read_access_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
restrictions=ancestor.get("restrictions", {}),
|
||||
)
|
||||
if not ancestor_user_emails and not ancestor_group_names:
|
||||
# This ancestor has no restrictions, so it has no effect on
|
||||
# the page's restrictions, so we ignore it
|
||||
continue
|
||||
|
||||
found_user_emails.intersection_update(ancestor_user_emails)
|
||||
found_group_names.intersection_update(ancestor_group_names)
|
||||
|
||||
# If there are no restrictions found, then the page
|
||||
# inherits the space's restrictions so return None
|
||||
is_space_public = read_access_user_emails == [] and read_access_group_names == []
|
||||
if is_space_public:
|
||||
if not found_user_emails and not found_group_names:
|
||||
return None
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=set(read_access_user_emails),
|
||||
external_user_group_ids=set(read_access_group_names),
|
||||
external_user_emails=found_user_emails,
|
||||
external_user_group_ids=found_group_names,
|
||||
# there is no way for a page to be individually public if the space isn't public
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
|
||||
def _fetch_all_page_restrictions_for_space(
|
||||
def _fetch_all_page_restrictions(
|
||||
confluence_client: OnyxConfluence,
|
||||
slim_docs: list[SlimDocument],
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess],
|
||||
@@ -208,11 +269,11 @@ def _fetch_all_page_restrictions_for_space(
|
||||
raise ValueError(
|
||||
f"No permission sync data found for document {slim_doc.id}"
|
||||
)
|
||||
restrictions = _extract_read_access_restrictions(
|
||||
|
||||
if restrictions := _get_all_page_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
restrictions=slim_doc.perm_sync_data.get("restrictions", {}),
|
||||
)
|
||||
if restrictions:
|
||||
perm_sync_data=slim_doc.perm_sync_data,
|
||||
):
|
||||
document_restrictions.append(
|
||||
DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
@@ -301,7 +362,7 @@ def confluence_doc_sync(
|
||||
slim_docs.extend(doc_batch)
|
||||
|
||||
logger.debug("Fetching all page restrictions for space")
|
||||
return _fetch_all_page_restrictions_for_space(
|
||||
return _fetch_all_page_restrictions(
|
||||
confluence_client=confluence_connector.confluence_client,
|
||||
slim_docs=slim_docs,
|
||||
space_permissions_by_space_key=space_permissions_by_space_key,
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ def _build_group_member_email_map(
|
||||
)
|
||||
if not email:
|
||||
# If we still don't have an email, skip this user
|
||||
logger.warning(f"user result missing email field: {user_result}")
|
||||
continue
|
||||
|
||||
for group in confluence_client.paginated_groups_by_user_retrieval(user):
|
||||
@@ -53,6 +54,7 @@ def confluence_group_sync(
|
||||
confluence_client=confluence_client,
|
||||
)
|
||||
onyx_groups: list[ExternalUserGroup] = []
|
||||
all_found_emails = set()
|
||||
for group_id, group_member_emails in group_member_email_map.items():
|
||||
onyx_groups.append(
|
||||
ExternalUserGroup(
|
||||
@@ -60,5 +62,15 @@ def confluence_group_sync(
|
||||
user_emails=list(group_member_emails),
|
||||
)
|
||||
)
|
||||
all_found_emails.update(group_member_emails)
|
||||
|
||||
# This is so that when we find a public confleunce server page, we can
|
||||
# give access to all users only in if they have an email in Confluence
|
||||
if cc_pair.connector.connector_specific_config.get("is_cloud", False):
|
||||
all_found_group = ExternalUserGroup(
|
||||
id=ALL_CONF_EMAILS_GROUP_NAME,
|
||||
user_emails=list(all_found_emails),
|
||||
)
|
||||
onyx_groups.append(all_found_group)
|
||||
|
||||
return onyx_groups
|
||||
|
||||
84
backend/ee/onyx/external_permissions/post_query_censoring.py
Normal file
84
backend/ee/onyx/external_permissions/post_query_censoring.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.external_permissions.salesforce.postprocessing import (
|
||||
censor_salesforce_chunks,
|
||||
)
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.pipeline import InferenceChunk
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION: dict[
|
||||
DocumentSource,
|
||||
# list of chunks to be censored and the user email. returns censored chunks
|
||||
Callable[[list[InferenceChunk], str], list[InferenceChunk]],
|
||||
] = {
|
||||
DocumentSource.SALESFORCE: censor_salesforce_chunks,
|
||||
}
|
||||
|
||||
|
||||
def _get_all_censoring_enabled_sources() -> set[DocumentSource]:
|
||||
"""
|
||||
Returns the set of sources that have censoring enabled.
|
||||
This is based on if the access_type is set to sync and the connector
|
||||
source is included in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION.
|
||||
|
||||
NOTE: This means if there is a source has a single cc_pair that is sync,
|
||||
all chunks for that source will be censored, even if the connector that
|
||||
indexed that chunk is not sync. This was done to avoid getting the cc_pair
|
||||
for every single chunk.
|
||||
"""
|
||||
with get_session_context_manager() as db_session:
|
||||
enabled_sync_connectors = get_all_auto_sync_cc_pairs(db_session)
|
||||
return {
|
||||
cc_pair.connector.source
|
||||
for cc_pair in enabled_sync_connectors
|
||||
if cc_pair.connector.source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
|
||||
}
|
||||
|
||||
|
||||
# NOTE: This is only called if ee is enabled.
|
||||
def _post_query_chunk_censoring(
|
||||
chunks: list[InferenceChunk],
|
||||
user: User | None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
This function checks all chunks to see if they need to be sent to a censoring
|
||||
function. If they do, it sends them to the censoring function and returns the
|
||||
censored chunks. If they don't, it returns the original chunks.
|
||||
"""
|
||||
if user is None:
|
||||
# if user is None, permissions are not enforced
|
||||
return chunks
|
||||
|
||||
chunks_to_keep = []
|
||||
chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}
|
||||
|
||||
sources_to_censor = _get_all_censoring_enabled_sources()
|
||||
for chunk in chunks:
|
||||
# Separate out chunks that require permission post-processing by source
|
||||
if chunk.source_type in sources_to_censor:
|
||||
chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
|
||||
else:
|
||||
chunks_to_keep.append(chunk)
|
||||
|
||||
# For each source, filter out the chunks using the permission
|
||||
# check function for that source
|
||||
# TODO: Use a threadpool/multiprocessing to process the sources in parallel
|
||||
for source, chunks_for_source in chunks_to_process.items():
|
||||
censor_chunks_for_source = DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION[source]
|
||||
try:
|
||||
censored_chunks = censor_chunks_for_source(chunks_for_source, user.email)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to censor chunks for source {source} so throwing out all"
|
||||
f" chunks for this source and continuing: {e}"
|
||||
)
|
||||
continue
|
||||
chunks_to_keep.extend(censored_chunks)
|
||||
|
||||
return chunks_to_keep
|
||||
@@ -0,0 +1,226 @@
|
||||
import time
|
||||
|
||||
from ee.onyx.db.external_perm import fetch_external_groups_for_user_email_and_group_ids
|
||||
from ee.onyx.external_permissions.salesforce.utils import (
|
||||
get_any_salesforce_client_for_doc_id,
|
||||
)
|
||||
from ee.onyx.external_permissions.salesforce.utils import get_objects_access_for_user_id
|
||||
from ee.onyx.external_permissions.salesforce.utils import (
|
||||
get_salesforce_user_id_from_email,
|
||||
)
|
||||
from onyx.configs.app_configs import BLURB_SIZE
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# Types
|
||||
ChunkKey = tuple[str, int] # (doc_id, chunk_id)
|
||||
ContentRange = tuple[int, int | None] # (start_index, end_index) None means to the end
|
||||
|
||||
|
||||
# NOTE: Used for testing timing
|
||||
def _get_dummy_object_access_map(
|
||||
object_ids: set[str], user_email: str, chunks: list[InferenceChunk]
|
||||
) -> dict[str, bool]:
|
||||
time.sleep(0.15)
|
||||
# return {object_id: True for object_id in object_ids}
|
||||
import random
|
||||
|
||||
return {object_id: random.choice([True, False]) for object_id in object_ids}
|
||||
|
||||
|
||||
def _get_objects_access_for_user_email_from_salesforce(
|
||||
object_ids: set[str],
|
||||
user_email: str,
|
||||
chunks: list[InferenceChunk],
|
||||
) -> dict[str, bool] | None:
|
||||
"""
|
||||
This function wraps the salesforce call as we may want to change how this
|
||||
is done in the future. (E.g. replace it with the above function)
|
||||
"""
|
||||
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
|
||||
# but subsequent queries for this source are essentially instant
|
||||
first_doc_id = chunks[0].document_id
|
||||
with get_session_context_manager() as db_session:
|
||||
salesforce_client = get_any_salesforce_client_for_doc_id(
|
||||
db_session, first_doc_id
|
||||
)
|
||||
|
||||
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
|
||||
# but subsequent queries by the same user are essentially instant
|
||||
start_time = time.time()
|
||||
user_id = get_salesforce_user_id_from_email(salesforce_client, user_email)
|
||||
end_time = time.time()
|
||||
logger.info(
|
||||
f"Time taken to get Salesforce user ID: {end_time - start_time} seconds"
|
||||
)
|
||||
if user_id is None:
|
||||
return None
|
||||
|
||||
# This is the only query that is not cached in the function
|
||||
# so it takes 0.1-0.2 seconds total
|
||||
object_id_to_access = get_objects_access_for_user_id(
|
||||
salesforce_client, user_id, list(object_ids)
|
||||
)
|
||||
return object_id_to_access
|
||||
|
||||
|
||||
def _extract_salesforce_object_id_from_url(url: str) -> str:
|
||||
return url.split("/")[-1]
|
||||
|
||||
|
||||
def _get_object_ranges_for_chunk(
|
||||
chunk: InferenceChunk,
|
||||
) -> dict[str, list[ContentRange]]:
|
||||
"""
|
||||
Given a chunk, return a dictionary of salesforce object ids and the content ranges
|
||||
for that object id in the current chunk
|
||||
"""
|
||||
if chunk.source_links is None:
|
||||
return {}
|
||||
|
||||
object_ranges: dict[str, list[ContentRange]] = {}
|
||||
end_index = None
|
||||
descending_source_links = sorted(
|
||||
chunk.source_links.items(), key=lambda x: x[0], reverse=True
|
||||
)
|
||||
for start_index, url in descending_source_links:
|
||||
object_id = _extract_salesforce_object_id_from_url(url)
|
||||
if object_id not in object_ranges:
|
||||
object_ranges[object_id] = []
|
||||
object_ranges[object_id].append((start_index, end_index))
|
||||
end_index = start_index
|
||||
return object_ranges
|
||||
|
||||
|
||||
def _create_empty_censored_chunk(uncensored_chunk: InferenceChunk) -> InferenceChunk:
|
||||
"""
|
||||
Create a copy of the unfiltered chunk where potentially sensitive content is removed
|
||||
to be added later if the user has access to each of the sub-objects
|
||||
"""
|
||||
empty_censored_chunk = InferenceChunk(
|
||||
**uncensored_chunk.model_dump(),
|
||||
)
|
||||
empty_censored_chunk.content = ""
|
||||
empty_censored_chunk.blurb = ""
|
||||
empty_censored_chunk.source_links = {}
|
||||
return empty_censored_chunk
|
||||
|
||||
|
||||
def _update_censored_chunk(
|
||||
censored_chunk: InferenceChunk,
|
||||
uncensored_chunk: InferenceChunk,
|
||||
content_range: ContentRange,
|
||||
) -> InferenceChunk:
|
||||
"""
|
||||
Update the filtered chunk with the content and source links from the unfiltered chunk using the content ranges
|
||||
"""
|
||||
start_index, end_index = content_range
|
||||
|
||||
# Update the content of the filtered chunk
|
||||
permitted_content = uncensored_chunk.content[start_index:end_index]
|
||||
permitted_section_start_index = len(censored_chunk.content)
|
||||
censored_chunk.content = permitted_content + censored_chunk.content
|
||||
|
||||
# Update the source links of the filtered chunk
|
||||
if uncensored_chunk.source_links is not None:
|
||||
if censored_chunk.source_links is None:
|
||||
censored_chunk.source_links = {}
|
||||
link_content = uncensored_chunk.source_links[start_index]
|
||||
censored_chunk.source_links[permitted_section_start_index] = link_content
|
||||
|
||||
# Update the blurb of the filtered chunk
|
||||
censored_chunk.blurb = censored_chunk.content[:BLURB_SIZE]
|
||||
|
||||
return censored_chunk
|
||||
|
||||
|
||||
# TODO: Generalize this to other sources
|
||||
def censor_salesforce_chunks(
|
||||
chunks: list[InferenceChunk],
|
||||
user_email: str,
|
||||
# This is so we can provide a mock access map for testing
|
||||
access_map: dict[str, bool] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
# object_id -> list[((doc_id, chunk_id), (start_index, end_index))]
|
||||
object_to_content_map: dict[str, list[tuple[ChunkKey, ContentRange]]] = {}
|
||||
|
||||
# (doc_id, chunk_id) -> chunk
|
||||
uncensored_chunks: dict[ChunkKey, InferenceChunk] = {}
|
||||
|
||||
# keep track of all object ids that we have seen to make it easier to get
|
||||
# the access for these object ids
|
||||
object_ids: set[str] = set()
|
||||
|
||||
for chunk in chunks:
|
||||
chunk_key = (chunk.document_id, chunk.chunk_id)
|
||||
# create a dictionary to quickly look up the unfiltered chunk
|
||||
uncensored_chunks[chunk_key] = chunk
|
||||
|
||||
# for each chunk, get a dictionary of object ids and the content ranges
|
||||
# for that object id in the current chunk
|
||||
object_ranges_for_chunk = _get_object_ranges_for_chunk(chunk)
|
||||
for object_id, ranges in object_ranges_for_chunk.items():
|
||||
object_ids.add(object_id)
|
||||
for start_index, end_index in ranges:
|
||||
object_to_content_map.setdefault(object_id, []).append(
|
||||
(chunk_key, (start_index, end_index))
|
||||
)
|
||||
|
||||
# This is so we can provide a mock access map for testing
|
||||
if access_map is None:
|
||||
access_map = _get_objects_access_for_user_email_from_salesforce(
|
||||
object_ids=object_ids,
|
||||
user_email=user_email,
|
||||
chunks=chunks,
|
||||
)
|
||||
if access_map is None:
|
||||
# If the user is not found in Salesforce, access_map will be None
|
||||
# so we should just return an empty list because no chunks will be
|
||||
# censored
|
||||
return []
|
||||
|
||||
censored_chunks: dict[ChunkKey, InferenceChunk] = {}
|
||||
for object_id, content_list in object_to_content_map.items():
|
||||
# if the user does not have access to the object, or the object is not in the
|
||||
# access_map, do not include its content in the filtered chunks
|
||||
if not access_map.get(object_id, False):
|
||||
continue
|
||||
|
||||
# if we got this far, the user has access to the object so we can create or update
|
||||
# the filtered chunk(s) for this object
|
||||
# NOTE: we only create a censored chunk if the user has access to some
|
||||
# part of the chunk
|
||||
for chunk_key, content_range in content_list:
|
||||
if chunk_key not in censored_chunks:
|
||||
censored_chunks[chunk_key] = _create_empty_censored_chunk(
|
||||
uncensored_chunks[chunk_key]
|
||||
)
|
||||
|
||||
uncensored_chunk = uncensored_chunks[chunk_key]
|
||||
censored_chunk = _update_censored_chunk(
|
||||
censored_chunk=censored_chunks[chunk_key],
|
||||
uncensored_chunk=uncensored_chunk,
|
||||
content_range=content_range,
|
||||
)
|
||||
censored_chunks[chunk_key] = censored_chunk
|
||||
|
||||
return list(censored_chunks.values())
|
||||
|
||||
|
||||
# NOTE: This is not used anywhere.
|
||||
def _get_objects_access_for_user_email(
|
||||
object_ids: set[str], user_email: str
|
||||
) -> dict[str, bool]:
|
||||
with get_session_context_manager() as db_session:
|
||||
external_groups = fetch_external_groups_for_user_email_and_group_ids(
|
||||
db_session=db_session,
|
||||
user_email=user_email,
|
||||
# Maybe make a function that adds a salesforce prefix to the group ids
|
||||
group_ids=list(object_ids),
|
||||
)
|
||||
external_group_ids = {group.external_user_group_id for group in external_groups}
|
||||
return {group_id: group_id in external_group_ids for group_id in object_ids}
|
||||
177
backend/ee/onyx/external_permissions/salesforce/utils.py
Normal file
177
backend/ee/onyx/external_permissions/salesforce/utils.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from simple_salesforce import Salesforce
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_user_id_by_email
|
||||
from onyx.connectors.salesforce.sqlite_functions import init_db
|
||||
from onyx.connectors.salesforce.sqlite_functions import NULL_ID_STRING
|
||||
from onyx.connectors.salesforce.sqlite_functions import update_email_to_id_table
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.document import get_cc_pairs_for_document
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_ANY_SALESFORCE_CLIENT: Salesforce | None = None
|
||||
|
||||
|
||||
def get_any_salesforce_client_for_doc_id(
|
||||
db_session: Session, doc_id: str
|
||||
) -> Salesforce:
|
||||
"""
|
||||
We create a salesforce client for the first cc_pair for the first doc_id where
|
||||
salesforce censoring is enabled. After that we just cache and reuse the same
|
||||
client for all queries.
|
||||
|
||||
We do this to reduce the number of postgres queries we make at query time.
|
||||
|
||||
This may be problematic if they are using multiple cc_pairs for salesforce.
|
||||
E.g. there are 2 different credential sets for 2 different salesforce cc_pairs
|
||||
but only one has the permissions to access the permissions needed for the query.
|
||||
"""
|
||||
global _ANY_SALESFORCE_CLIENT
|
||||
if _ANY_SALESFORCE_CLIENT is None:
|
||||
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
|
||||
first_cc_pair = cc_pairs[0]
|
||||
credential_json = first_cc_pair.credential.credential_json
|
||||
_ANY_SALESFORCE_CLIENT = Salesforce(
|
||||
username=credential_json["sf_username"],
|
||||
password=credential_json["sf_password"],
|
||||
security_token=credential_json["sf_security_token"],
|
||||
)
|
||||
return _ANY_SALESFORCE_CLIENT
|
||||
|
||||
|
||||
def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None:
|
||||
query = f"SELECT Id FROM User WHERE Email = '{user_email}'"
|
||||
result = sf_client.query(query)
|
||||
if len(result["records"]) == 0:
|
||||
return None
|
||||
return result["records"][0]["Id"]
|
||||
|
||||
|
||||
# This contains only the user_ids that we have found in Salesforce.
|
||||
# If we don't know their user_id, we don't store anything in the cache.
|
||||
_CACHED_SF_EMAIL_TO_ID_MAP: dict[str, str] = {}
|
||||
|
||||
|
||||
def get_salesforce_user_id_from_email(
|
||||
sf_client: Salesforce,
|
||||
user_email: str,
|
||||
) -> str | None:
|
||||
"""
|
||||
We cache this so we don't have to query Salesforce for every query and salesforce
|
||||
user IDs never change.
|
||||
Memory usage is fine because we just store 2 small strings per user.
|
||||
|
||||
If the email is not in the cache, we check the local salesforce database for the info.
|
||||
If the user is not found in the local salesforce database, we query Salesforce.
|
||||
Whatever we get back from Salesforce is added to the database.
|
||||
If no user_id is found, we add a NULL_ID_STRING to the database for that email so
|
||||
we don't query Salesforce again (which is slow) but we still check the local salesforce
|
||||
database every query until a user id is found. This is acceptable because the query time
|
||||
is quite fast.
|
||||
If a user_id is created in Salesforce, it will be added to the local salesforce database
|
||||
next time the connector is run. Then that value will be found in this function and cached.
|
||||
|
||||
NOTE: First time this runs, it may be slow if it hasn't already been updated in the local
|
||||
salesforce database. (Around 0.1-0.3 seconds)
|
||||
If it's cached or stored in the local salesforce database, it's fast (<0.001 seconds).
|
||||
"""
|
||||
global _CACHED_SF_EMAIL_TO_ID_MAP
|
||||
if user_email in _CACHED_SF_EMAIL_TO_ID_MAP:
|
||||
if _CACHED_SF_EMAIL_TO_ID_MAP[user_email] is not None:
|
||||
return _CACHED_SF_EMAIL_TO_ID_MAP[user_email]
|
||||
|
||||
db_exists = True
|
||||
try:
|
||||
# Check if the user is already in the database
|
||||
user_id = get_user_id_by_email(user_email)
|
||||
except Exception:
|
||||
init_db()
|
||||
try:
|
||||
user_id = get_user_id_by_email(user_email)
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking if user is in database: {e}")
|
||||
user_id = None
|
||||
db_exists = False
|
||||
|
||||
# If no entry is found in the database (indicated by user_id being None)...
|
||||
if user_id is None:
|
||||
# ...query Salesforce and store the result in the database
|
||||
user_id = _query_salesforce_user_id(sf_client, user_email)
|
||||
if db_exists:
|
||||
update_email_to_id_table(user_email, user_id)
|
||||
return user_id
|
||||
elif user_id is None:
|
||||
return None
|
||||
elif user_id == NULL_ID_STRING:
|
||||
return None
|
||||
# If the found user_id is real, cache it
|
||||
_CACHED_SF_EMAIL_TO_ID_MAP[user_email] = user_id
|
||||
return user_id
|
||||
|
||||
|
||||
_MAX_RECORD_IDS_PER_QUERY = 200
|
||||
|
||||
|
||||
def get_objects_access_for_user_id(
|
||||
salesforce_client: Salesforce,
|
||||
user_id: str,
|
||||
record_ids: list[str],
|
||||
) -> dict[str, bool]:
|
||||
"""
|
||||
Salesforce has a limit of 200 record ids per query. So we just truncate
|
||||
the list of record ids to 200. We only ever retrieve 50 chunks at a time
|
||||
so this should be fine (unlikely that we retrieve all 50 chunks contain
|
||||
4 unique objects).
|
||||
If we decide this isn't acceptable we can use multiple queries but they
|
||||
should be in parallel so query time doesn't get too long.
|
||||
"""
|
||||
truncated_record_ids = record_ids[:_MAX_RECORD_IDS_PER_QUERY]
|
||||
record_ids_str = "'" + "','".join(truncated_record_ids) + "'"
|
||||
access_query = f"""
|
||||
SELECT RecordId, HasReadAccess
|
||||
FROM UserRecordAccess
|
||||
WHERE RecordId IN ({record_ids_str})
|
||||
AND UserId = '{user_id}'
|
||||
"""
|
||||
result = salesforce_client.query_all(access_query)
|
||||
return {record["RecordId"]: record["HasReadAccess"] for record in result["records"]}
|
||||
|
||||
|
||||
_CC_PAIR_ID_SALESFORCE_CLIENT_MAP: dict[int, Salesforce] = {}
|
||||
_DOC_ID_TO_CC_PAIR_ID_MAP: dict[str, int] = {}
|
||||
|
||||
|
||||
# NOTE: This is not used anywhere.
|
||||
def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Salesforce:
|
||||
"""
|
||||
Uses a document id to get the cc_pair that indexed that document and uses the credentials
|
||||
for that cc_pair to create a Salesforce client.
|
||||
Problems:
|
||||
- There may be multiple cc_pairs for a document, and we don't know which one to use.
|
||||
- right now we just use the first one
|
||||
- Building a new Salesforce client for each document is slow.
|
||||
- Memory usage could be an issue as we build these dictionaries.
|
||||
"""
|
||||
if doc_id not in _DOC_ID_TO_CC_PAIR_ID_MAP:
|
||||
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
|
||||
first_cc_pair = cc_pairs[0]
|
||||
_DOC_ID_TO_CC_PAIR_ID_MAP[doc_id] = first_cc_pair.id
|
||||
|
||||
cc_pair_id = _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id]
|
||||
if cc_pair_id not in _CC_PAIR_ID_SALESFORCE_CLIENT_MAP:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise ValueError(f"CC pair {cc_pair_id} not found")
|
||||
credential_json = cc_pair.credential.credential_json
|
||||
_CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id] = Salesforce(
|
||||
username=credential_json["sf_username"],
|
||||
password=credential_json["sf_password"],
|
||||
security_token=credential_json["sf_security_token"],
|
||||
)
|
||||
|
||||
return _CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id]
|
||||
@@ -8,6 +8,9 @@ from ee.onyx.external_permissions.confluence.group_sync import confluence_group_
|
||||
from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync
|
||||
from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
|
||||
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
|
||||
from ee.onyx.external_permissions.post_query_censoring import (
|
||||
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
|
||||
)
|
||||
from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -71,4 +74,7 @@ EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
|
||||
|
||||
def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
|
||||
return source_type in DOC_PERMISSIONS_FUNC_MAP
|
||||
return (
|
||||
source_type in DOC_PERMISSIONS_FUNC_MAP
|
||||
or source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
|
||||
)
|
||||
|
||||
@@ -40,6 +40,7 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.main import get_application as get_application_base
|
||||
from onyx.main import include_auth_router_with_prefix
|
||||
from onyx.main import include_router_with_global_prefix_prepended
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
@@ -62,7 +63,7 @@ def get_application() -> FastAPI:
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
oauth_client,
|
||||
@@ -74,19 +75,17 @@ def get_application() -> FastAPI:
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
|
||||
),
|
||||
prefix="/auth/oauth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# Need basic auth router for `logout` endpoint
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_logout_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.OIDC:
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
|
||||
@@ -97,19 +96,20 @@ def get_application() -> FastAPI:
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
|
||||
),
|
||||
prefix="/auth/oidc",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# need basic auth router for `logout` endpoint
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_auth_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
elif AUTH_TYPE == AuthType.SAML:
|
||||
include_router_with_global_prefix_prepended(application, saml_router)
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
saml_router,
|
||||
)
|
||||
|
||||
# RBAC / group access control
|
||||
include_router_with_global_prefix_prepended(application, user_group_router)
|
||||
|
||||
@@ -150,9 +150,9 @@ def _handle_standard_answers(
|
||||
db_session=db_session,
|
||||
description="",
|
||||
user_id=None,
|
||||
persona_id=slack_channel_config.persona.id
|
||||
if slack_channel_config.persona
|
||||
else 0,
|
||||
persona_id=(
|
||||
slack_channel_config.persona.id if slack_channel_config.persona else 0
|
||||
),
|
||||
onyxbot_flow=True,
|
||||
slack_thread_id=slack_thread_id,
|
||||
)
|
||||
@@ -182,7 +182,7 @@ def _handle_standard_answers(
|
||||
formatted_answers.append(formatted_answer)
|
||||
answer_message = "\n\n".join(formatted_answers)
|
||||
|
||||
_ = create_new_chat_message(
|
||||
chat_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=new_user_message,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
@@ -191,8 +191,13 @@ def _handle_standard_answers(
|
||||
message_type=MessageType.ASSISTANT,
|
||||
error=None,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
commit=False,
|
||||
)
|
||||
# attach the standard answers to the chat message
|
||||
chat_message.standard_answers = [
|
||||
standard_answer for standard_answer, _ in matching_standard_answers
|
||||
]
|
||||
db_session.commit()
|
||||
|
||||
update_emote_react(
|
||||
emoji=DANSWER_REACT_EMOJI,
|
||||
|
||||
@@ -1,17 +1,24 @@
|
||||
import datetime
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.analytics import fetch_assistant_message_analytics
|
||||
from ee.onyx.db.analytics import fetch_assistant_unique_users
|
||||
from ee.onyx.db.analytics import fetch_assistant_unique_users_total
|
||||
from ee.onyx.db.analytics import fetch_onyxbot_analytics
|
||||
from ee.onyx.db.analytics import fetch_per_user_query_analytics
|
||||
from ee.onyx.db.analytics import fetch_persona_message_analytics
|
||||
from ee.onyx.db.analytics import fetch_persona_unique_users
|
||||
from ee.onyx.db.analytics import fetch_query_analytics
|
||||
from ee.onyx.db.analytics import user_can_view_assistant_stats
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
|
||||
@@ -191,3 +198,74 @@ def get_persona_unique_users(
|
||||
)
|
||||
)
|
||||
return unique_user_counts
|
||||
|
||||
|
||||
class AssistantDailyUsageResponse(BaseModel):
|
||||
date: datetime.date
|
||||
total_messages: int
|
||||
total_unique_users: int
|
||||
|
||||
|
||||
class AssistantStatsResponse(BaseModel):
|
||||
daily_stats: List[AssistantDailyUsageResponse]
|
||||
total_messages: int
|
||||
total_unique_users: int
|
||||
|
||||
|
||||
@router.get("/assistant/{assistant_id}/stats")
|
||||
def get_assistant_stats(
|
||||
assistant_id: int,
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AssistantStatsResponse:
|
||||
"""
|
||||
Returns daily message and unique user counts for a user's assistant,
|
||||
along with the overall total messages and total distinct users.
|
||||
"""
|
||||
start = start or (
|
||||
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
|
||||
)
|
||||
end = end or datetime.datetime.utcnow()
|
||||
|
||||
if not user_can_view_assistant_stats(db_session, user, assistant_id):
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Not allowed to access this assistant's stats."
|
||||
)
|
||||
|
||||
# Pull daily usage from the DB calls
|
||||
messages_data = fetch_assistant_message_analytics(
|
||||
db_session, assistant_id, start, end
|
||||
)
|
||||
unique_users_data = fetch_assistant_unique_users(
|
||||
db_session, assistant_id, start, end
|
||||
)
|
||||
|
||||
# Map each day => (messages, unique_users).
|
||||
daily_messages_map = {date: count for count, date in messages_data}
|
||||
daily_unique_users_map = {date: count for count, date in unique_users_data}
|
||||
all_dates = set(daily_messages_map.keys()) | set(daily_unique_users_map.keys())
|
||||
|
||||
# Merge both sets of metrics by date
|
||||
daily_results: list[AssistantDailyUsageResponse] = []
|
||||
for date in sorted(all_dates):
|
||||
daily_results.append(
|
||||
AssistantDailyUsageResponse(
|
||||
date=date,
|
||||
total_messages=daily_messages_map.get(date, 0),
|
||||
total_unique_users=daily_unique_users_map.get(date, 0),
|
||||
)
|
||||
)
|
||||
|
||||
# Now pull a single total distinct user count across the entire time range
|
||||
total_msgs = sum(d.total_messages for d in daily_results)
|
||||
total_users = fetch_assistant_unique_users_total(
|
||||
db_session, assistant_id, start, end
|
||||
)
|
||||
|
||||
return AssistantStatsResponse(
|
||||
daily_stats=daily_results,
|
||||
total_messages=total_msgs,
|
||||
total_unique_users=total_users,
|
||||
)
|
||||
|
||||
@@ -2,15 +2,16 @@ import logging
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Callable
|
||||
|
||||
import jwt
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
|
||||
from ee.onyx.auth.users import decode_anonymous_user_jwt_token
|
||||
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
|
||||
from onyx.auth.api_key import extract_tenant_from_api_key_header
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.db.engine import is_valid_schema_name
|
||||
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
@@ -22,11 +23,11 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
||||
request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||
) -> Response:
|
||||
try:
|
||||
tenant_id = (
|
||||
_get_tenant_id_from_request(request, logger)
|
||||
if MULTI_TENANT
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
if MULTI_TENANT:
|
||||
tenant_id = await _get_tenant_id_from_request(request, logger)
|
||||
else:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
return await call_next(request)
|
||||
|
||||
@@ -35,27 +36,46 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
||||
raise
|
||||
|
||||
|
||||
def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) -> str:
|
||||
# First check for API key
|
||||
async def _get_tenant_id_from_request(
|
||||
request: Request, logger: logging.LoggerAdapter
|
||||
) -> str:
|
||||
"""
|
||||
Attempt to extract tenant_id from:
|
||||
1) The API key header
|
||||
2) The Redis-based token (stored in Cookie: fastapiusersauth)
|
||||
Fallback: POSTGRES_DEFAULT_SCHEMA
|
||||
"""
|
||||
# Check for API key
|
||||
tenant_id = extract_tenant_from_api_key_header(request)
|
||||
if tenant_id is not None:
|
||||
if tenant_id:
|
||||
return tenant_id
|
||||
|
||||
# Check for cookie-based auth
|
||||
token = request.cookies.get("fastapiusersauth")
|
||||
if not token:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
# Check for anonymous user cookie
|
||||
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
|
||||
if anonymous_user_cookie:
|
||||
try:
|
||||
anonymous_user_data = decode_anonymous_user_jwt_token(anonymous_user_cookie)
|
||||
return anonymous_user_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
except Exception as e:
|
||||
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
|
||||
# Continue and attempt to authenticate
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
USER_AUTH_SECRET,
|
||||
audience=["fastapi-users:auth"],
|
||||
algorithms=["HS256"],
|
||||
)
|
||||
tenant_id_from_payload = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
# Look up token data in Redis
|
||||
token_data = await retrieve_auth_token_data_from_redis(request)
|
||||
|
||||
# Since payload.get() can return None, ensure we have a string
|
||||
if not token_data:
|
||||
logger.debug(
|
||||
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
|
||||
)
|
||||
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
|
||||
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
|
||||
# so we maintain consistency by returning it here when no valid tenant is found.
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
# Since token_data.get() can return None, ensure we have a string
|
||||
tenant_id = (
|
||||
str(tenant_id_from_payload)
|
||||
if tenant_id_from_payload is not None
|
||||
@@ -67,9 +87,6 @@ def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter)
|
||||
|
||||
return tenant_id
|
||||
|
||||
except jwt.InvalidTokenError:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in set_tenant_id_middleware: {str(e)}")
|
||||
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
@@ -10,11 +12,29 @@ from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
|
||||
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
@@ -62,14 +82,7 @@ class SlackOAuth:
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
url = (
|
||||
f"https://slack.com/oauth/v2/authorize"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={cls.REDIRECT_URI}"
|
||||
f"&scope={cls.BOT_SCOPE}"
|
||||
f"&state={state}"
|
||||
)
|
||||
return url
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
@@ -77,10 +90,14 @@ class SlackOAuth:
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
url = (
|
||||
f"https://slack.com/oauth/v2/authorize"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={cls.DEV_REDIRECT_URI}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&scope={cls.BOT_SCOPE}"
|
||||
f"&state={state}"
|
||||
)
|
||||
@@ -102,82 +119,151 @@ class SlackOAuth:
|
||||
return session
|
||||
|
||||
|
||||
# Work in progress
|
||||
# class ConfluenceCloudOAuth:
|
||||
# """work in progress"""
|
||||
class ConfluenceCloudOAuth:
|
||||
"""work in progress"""
|
||||
|
||||
# # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
|
||||
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
|
||||
|
||||
# class OAuthSession(BaseModel):
|
||||
# """Stored in redis to be looked up on callback"""
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
# email: str
|
||||
# redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
# CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
|
||||
# CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
# TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
|
||||
# # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
|
||||
# CONFLUENCE_OAUTH_SCOPE = (
|
||||
# "read:confluence-props%20"
|
||||
# "read:confluence-content.all%20"
|
||||
# "read:confluence-content.summary%20"
|
||||
# "read:confluence-content.permission%20"
|
||||
# "read:confluence-user%20"
|
||||
# "read:confluence-groups%20"
|
||||
# "readonly:content.attachment:confluence"
|
||||
# )
|
||||
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
|
||||
CONFLUENCE_OAUTH_SCOPE = (
|
||||
"read:confluence-props%20"
|
||||
"read:confluence-content.all%20"
|
||||
"read:confluence-content.summary%20"
|
||||
"read:confluence-content.permission%20"
|
||||
"read:confluence-user%20"
|
||||
"read:confluence-groups%20"
|
||||
"readonly:content.attachment:confluence"
|
||||
)
|
||||
|
||||
# REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
|
||||
# DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
# # eventually for Confluence Data Center
|
||||
# # oauth_url = (
|
||||
# # f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
|
||||
# # f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
|
||||
# # f"&redirect_uri={redirectme_uri}"
|
||||
# # )
|
||||
# eventually for Confluence Data Center
|
||||
# oauth_url = (
|
||||
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
|
||||
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
|
||||
# f"&redirect_uri={redirectme_uri}"
|
||||
# )
|
||||
|
||||
# @classmethod
|
||||
# def generate_oauth_url(cls, state: str) -> str:
|
||||
# return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
# @classmethod
|
||||
# def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
# """dev mode workaround for localhost testing
|
||||
# - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
# """
|
||||
# return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
# @classmethod
|
||||
# def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
# url = (
|
||||
# "https://auth.atlassian.com/authorize"
|
||||
# f"?audience=api.atlassian.com"
|
||||
# f"&client_id={cls.CLIENT_ID}"
|
||||
# f"&redirect_uri={redirect_uri}"
|
||||
# f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
|
||||
# f"&state={state}"
|
||||
# "&response_type=code"
|
||||
# "&prompt=consent"
|
||||
# )
|
||||
# return url
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
url = (
|
||||
"https://auth.atlassian.com/authorize"
|
||||
f"?audience=api.atlassian.com"
|
||||
f"&client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
|
||||
f"&state={state}"
|
||||
"&response_type=code"
|
||||
"&prompt=consent"
|
||||
)
|
||||
return url
|
||||
|
||||
# @classmethod
|
||||
# def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
# """Temporary state to store in redis. to be looked up on auth response.
|
||||
# Returns a json string.
|
||||
# """
|
||||
# session = ConfluenceCloudOAuth.OAuthSession(
|
||||
# email=email, redirect_on_success=redirect_on_success
|
||||
# )
|
||||
# return session.model_dump_json()
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = ConfluenceCloudOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
# @classmethod
|
||||
# def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
|
||||
# session = SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
# return session
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
|
||||
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
class GoogleDriveOAuth:
|
||||
# https://developers.google.com/identity/protocols/oauth2
|
||||
# https://developers.google.com/identity/protocols/oauth2/web-server
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
|
||||
TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
# SCOPE is per https://docs.onyx.app/connectors/google-drive
|
||||
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
|
||||
SCOPE = (
|
||||
"https://www.googleapis.com/auth/drive.readonly%20"
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
|
||||
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
|
||||
"https://www.googleapis.com/auth/admin.directory.group.readonly"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
# without prompt=consent, a refresh token is only issued the first time the user approves
|
||||
url = (
|
||||
f"https://accounts.google.com/o/oauth2/v2/auth"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
"&response_type=code"
|
||||
f"&scope={cls.SCOPE}"
|
||||
"&access_type=offline"
|
||||
f"&state={state}"
|
||||
"&prompt=consent"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = GoogleDriveOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> OAuthSession:
|
||||
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
@router.post("/prepare-authorization-request")
|
||||
@@ -192,8 +278,11 @@ def prepare_authorization_request(
|
||||
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
|
||||
"""
|
||||
|
||||
# create random oauth state param for security and to retrieve user data later
|
||||
oauth_uuid = uuid.uuid4()
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
# urlsafe b64 encode the uuid for the oauth url
|
||||
oauth_state = (
|
||||
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
@@ -203,6 +292,11 @@ def prepare_authorization_request(
|
||||
session = SlackOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
elif connector == DocumentSource.GOOGLE_DRIVE:
|
||||
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
|
||||
session = GoogleDriveOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
# elif connector == DocumentSource.CONFLUENCE:
|
||||
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
|
||||
# session = ConfluenceCloudOAuth.session_dump_json(
|
||||
@@ -210,8 +304,6 @@ def prepare_authorization_request(
|
||||
# )
|
||||
# elif connector == DocumentSource.JIRA:
|
||||
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
|
||||
# elif connector == DocumentSource.GOOGLE_DRIVE:
|
||||
# oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
|
||||
else:
|
||||
oauth_url = None
|
||||
|
||||
@@ -223,6 +315,7 @@ def prepare_authorization_request(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# store important session state to retrieve when the user is redirected back
|
||||
# 10 min is the max we want an oauth flow to be valid
|
||||
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
|
||||
|
||||
@@ -421,3 +514,116 @@ def handle_slack_oauth_callback(
|
||||
# "redirect_on_success": session.redirect_on_success,
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
@router.post("/connector/google-drive/callback")
|
||||
def handle_google_drive_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Google Drive client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
try:
|
||||
session = GoogleDriveOAuth.parse_session(session_json)
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
GoogleDriveOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": GoogleDriveOAuth.CLIENT_ID,
|
||||
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
authorization_response: dict[str, Any] = response.json()
|
||||
|
||||
# the connector wants us to store the json in its authorized_user_info format
|
||||
# returned from OAuthCredentials.get_authorized_user_info().
|
||||
# So refresh immediately via get_google_oauth_creds with the params filled in
|
||||
# from fields in authorization_response to get the json we need
|
||||
authorized_user_info = {}
|
||||
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
|
||||
|
||||
token_json_str = json.dumps(authorized_user_info)
|
||||
oauth_creds = get_google_oauth_creds(
|
||||
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
|
||||
)
|
||||
if not oauth_creds:
|
||||
raise RuntimeError("get_google_oauth_creds returned None.")
|
||||
|
||||
# save off the credentials
|
||||
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
|
||||
|
||||
credential_dict: dict[str, str] = {}
|
||||
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
|
||||
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
|
||||
credential_dict[
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD
|
||||
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
name="OAuth (interactive)",
|
||||
)
|
||||
|
||||
create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Google Drive OAuth completed successfully.",
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -13,9 +13,8 @@ from ee.onyx.db.usage_export import get_all_empty_chat_message_entries
|
||||
from ee.onyx.db.usage_export import write_usage_report
|
||||
from ee.onyx.server.reporting.usage_export_models import UsageReportMetadata
|
||||
from ee.onyx.server.reporting.usage_export_models import UserSkeleton
|
||||
from onyx.auth.schemas import UserStatus
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.users import list_users
|
||||
from onyx.db.users import get_all_users
|
||||
from onyx.file_store.constants import MAX_IN_MEMORY_SIZE
|
||||
from onyx.file_store.file_store import FileStore
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -84,15 +83,15 @@ def generate_user_report(
|
||||
max_size=MAX_IN_MEMORY_SIZE, mode="w+"
|
||||
) as temp_file:
|
||||
csvwriter = csv.writer(temp_file, delimiter=",")
|
||||
csvwriter.writerow(["user_id", "status"])
|
||||
csvwriter.writerow(["user_id", "is_active"])
|
||||
|
||||
users = list_users(db_session)
|
||||
users = get_all_users(db_session)
|
||||
for user in users:
|
||||
user_skeleton = UserSkeleton(
|
||||
user_id=str(user.id),
|
||||
status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED,
|
||||
is_active=user.is_active,
|
||||
)
|
||||
csvwriter.writerow([user_skeleton.user_id, user_skeleton.status])
|
||||
csvwriter.writerow([user_skeleton.user_id, user_skeleton.is_active])
|
||||
|
||||
temp_file.seek(0)
|
||||
file_store.save_file(
|
||||
|
||||
@@ -4,8 +4,6 @@ from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.schemas import UserStatus
|
||||
|
||||
|
||||
class FlowType(str, Enum):
|
||||
CHAT = "chat"
|
||||
@@ -22,7 +20,7 @@ class ChatMessageSkeleton(BaseModel):
|
||||
|
||||
class UserSkeleton(BaseModel):
|
||||
user_id: str
|
||||
status: UserStatus
|
||||
is_active: bool
|
||||
|
||||
|
||||
class UsageReportMetadata(BaseModel):
|
||||
|
||||
59
backend/ee/onyx/server/tenants/anonymous_user_path.py
Normal file
59
backend/ee/onyx/server/tenants/anonymous_user_path.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import TenantAnonymousUserPath
|
||||
|
||||
|
||||
def get_anonymous_user_path(tenant_id: str, db_session: Session) -> str | None:
|
||||
result = db_session.execute(
|
||||
select(TenantAnonymousUserPath).where(
|
||||
TenantAnonymousUserPath.tenant_id == tenant_id
|
||||
)
|
||||
)
|
||||
result_scalar = result.scalar_one_or_none()
|
||||
if result_scalar:
|
||||
return result_scalar.anonymous_user_path
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def modify_anonymous_user_path(
|
||||
tenant_id: str, anonymous_user_path: str, db_session: Session
|
||||
) -> None:
|
||||
# Enforce lowercase path at DB operation level
|
||||
anonymous_user_path = anonymous_user_path.lower()
|
||||
|
||||
existing_entry = (
|
||||
db_session.query(TenantAnonymousUserPath).filter_by(tenant_id=tenant_id).first()
|
||||
)
|
||||
|
||||
if existing_entry:
|
||||
existing_entry.anonymous_user_path = anonymous_user_path
|
||||
|
||||
else:
|
||||
new_entry = TenantAnonymousUserPath(
|
||||
tenant_id=tenant_id, anonymous_user_path=anonymous_user_path
|
||||
)
|
||||
db_session.add(new_entry)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_tenant_id_for_anonymous_user_path(
|
||||
anonymous_user_path: str, db_session: Session
|
||||
) -> str | None:
|
||||
result = db_session.execute(
|
||||
select(TenantAnonymousUserPath).where(
|
||||
TenantAnonymousUserPath.anonymous_user_path == anonymous_user_path
|
||||
)
|
||||
)
|
||||
result_scalar = result.scalar_one_or_none()
|
||||
if result_scalar:
|
||||
return result_scalar.tenant_id
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def validate_anonymous_user_path(path: str) -> None:
|
||||
if not path or "/" in path or not path.replace("-", "").isalnum():
|
||||
raise ValueError("Invalid path. Use only letters, numbers, and hyphens.")
|
||||
@@ -3,35 +3,124 @@ from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_cloud_superuser
|
||||
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
|
||||
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import (
|
||||
get_tenant_id_for_anonymous_user_path,
|
||||
)
|
||||
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import AnonymousUserPath
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import ImpersonateRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_jwt_strategy
|
||||
from onyx.auth.users import get_redis_strategy
|
||||
from onyx.auth.users import optional_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.settings.store import store_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
|
||||
logger = setup_logger()
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.get("/anonymous-user-path")
|
||||
async def get_anonymous_user_path_api(
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> AnonymousUserPath:
|
||||
if tenant_id is None:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
current_path = get_anonymous_user_path(tenant_id, db_session)
|
||||
|
||||
return AnonymousUserPath(anonymous_user_path=current_path)
|
||||
|
||||
|
||||
@router.post("/anonymous-user-path")
|
||||
async def set_anonymous_user_path_api(
|
||||
anonymous_user_path: str,
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
try:
|
||||
validate_anonymous_user_path(anonymous_user_path)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
try:
|
||||
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
|
||||
except IntegrityError:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="The anonymous user path is already in use. Please choose a different path.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while modifying the anonymous user path",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/anonymous-user")
|
||||
async def login_as_anonymous_user(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(optional_user),
|
||||
) -> Response:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
tenant_id = get_tenant_id_for_anonymous_user_path(
|
||||
anonymous_user_path, db_session
|
||||
)
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
if not anonymous_user_enabled(tenant_id=tenant_id):
|
||||
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
|
||||
|
||||
token = generate_anonymous_user_jwt_token(tenant_id)
|
||||
|
||||
response = Response()
|
||||
response.set_cookie(
|
||||
key=ANONYMOUS_USER_COOKIE_NAME,
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="strict",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
|
||||
@@ -103,7 +192,7 @@ async def impersonate_user(
|
||||
)
|
||||
if user_to_impersonate is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
token = await get_jwt_strategy().write_token(user_to_impersonate)
|
||||
token = await get_redis_strategy().write_token(user_to_impersonate)
|
||||
|
||||
response = await auth_backend.transport.get_login_response(token)
|
||||
response.set_cookie(
|
||||
@@ -114,3 +203,48 @@ async def impersonate_user(
|
||||
samesite="lax",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/leave-organization")
|
||||
async def leave_organization(
|
||||
user_email: UserByEmail,
|
||||
current_user: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
) -> None:
|
||||
if current_user is None or current_user.email != user_email.user_email:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You can only leave the organization as yourself"
|
||||
)
|
||||
|
||||
user_to_delete = get_user_by_email(user_email.user_email, db_session)
|
||||
if user_to_delete is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
num_admin_users = await get_user_count(only_admin_users=True)
|
||||
|
||||
should_delete_tenant = num_admin_users == 1
|
||||
|
||||
if should_delete_tenant:
|
||||
logger.info(
|
||||
"Last admin user is leaving the organization. Deleting tenant from control plane."
|
||||
)
|
||||
try:
|
||||
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
|
||||
logger.debug("User deleted from control plane")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to remove user from control plane: {str(e)}",
|
||||
)
|
||||
|
||||
db_session.expunge(user_to_delete)
|
||||
delete_user_from_db(user_to_delete, db_session)
|
||||
|
||||
if should_delete_tenant:
|
||||
remove_all_users_from_tenant(tenant_id)
|
||||
else:
|
||||
remove_users_from_tenant([user_to_delete.email], tenant_id)
|
||||
|
||||
@@ -46,6 +46,7 @@ def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscr
|
||||
"""
|
||||
Send a request to the control service to register the number of users for a tenant.
|
||||
"""
|
||||
|
||||
if not STRIPE_PRICE_ID:
|
||||
raise Exception("STRIPE_PRICE_ID is not set")
|
||||
|
||||
|
||||
@@ -39,3 +39,12 @@ class TenantCreationPayload(BaseModel):
|
||||
tenant_id: str
|
||||
email: str
|
||||
referral_source: str | None = None
|
||||
|
||||
|
||||
class TenantDeletionPayload(BaseModel):
|
||||
tenant_id: str
|
||||
email: str
|
||||
|
||||
|
||||
class AnonymousUserPath(BaseModel):
|
||||
anonymous_user_path: str | None
|
||||
|
||||
@@ -15,6 +15,7 @@ from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
|
||||
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import TenantCreationPayload
|
||||
from ee.onyx.server.tenants.models import TenantDeletionPayload
|
||||
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
|
||||
from ee.onyx.server.tenants.schema_management import drop_schema
|
||||
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
|
||||
@@ -185,6 +186,7 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
|
||||
try:
|
||||
# Drop the tenant's schema to rollback provisioning
|
||||
drop_schema(tenant_id)
|
||||
|
||||
# Remove tenant mapping
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
@@ -320,3 +322,26 @@ async def submit_to_hubspot(
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to submit to HubSpot: {response.text}")
|
||||
|
||||
|
||||
async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = TenantDeletionPayload(tenant_id=tenant_id, email=email)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.delete(
|
||||
f"{CONTROL_PLANE_API_BASE_URL}/tenants/delete",
|
||||
headers=headers,
|
||||
json=payload.model_dump(),
|
||||
) as response:
|
||||
print(response)
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Control plane tenant creation failed: {error_text}")
|
||||
raise Exception(
|
||||
f"Failed to delete tenant on control plane: {error_text}"
|
||||
)
|
||||
|
||||
@@ -68,3 +68,11 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
|
||||
)
|
||||
db_session.rollback()
|
||||
|
||||
|
||||
def remove_all_users_from_tenant(tenant_id: str) -> None:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
@@ -5,7 +5,7 @@ from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.token_limit import fetch_all_user_group_token_rate_limits_by_group
|
||||
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits
|
||||
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
|
||||
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
@@ -51,8 +51,10 @@ def get_group_token_limit_settings(
|
||||
) -> list[TokenRateLimitDisplay]:
|
||||
return [
|
||||
TokenRateLimitDisplay.from_db(token_rate_limit)
|
||||
for token_rate_limit in fetch_user_group_token_rate_limits(
|
||||
db_session, group_id, user
|
||||
for token_rate_limit in fetch_user_group_token_rate_limits_for_user(
|
||||
db_session=db_session,
|
||||
group_id=group_id,
|
||||
user=user,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ def patch_user_group(
|
||||
def set_user_curator(
|
||||
user_group_id: int,
|
||||
set_curator_request: SetCuratorRequest,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
@@ -91,6 +91,7 @@ def set_user_curator(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
set_curator_request=set_curator_request,
|
||||
user_making_change=user,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Error setting user curator: {e}")
|
||||
|
||||
@@ -10,6 +10,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def posthog_on_error(error: Any, items: Any) -> None:
|
||||
"""Log any PostHog delivery errors."""
|
||||
logger.error(f"PostHog error: {error}, items: {items}")
|
||||
|
||||
|
||||
@@ -24,15 +25,10 @@ posthog = Posthog(
|
||||
def event_telemetry(
|
||||
distinct_id: str, event: str, properties: dict | None = None
|
||||
) -> None:
|
||||
logger.info(f"Capturing Posthog event: {distinct_id} {event} {properties}")
|
||||
print("API KEY", POSTHOG_API_KEY)
|
||||
print("HOST", POSTHOG_HOST)
|
||||
"""Capture and send an event to PostHog, flushing immediately."""
|
||||
logger.info(f"Capturing PostHog event: {distinct_id} {event} {properties}")
|
||||
try:
|
||||
print(type(distinct_id))
|
||||
print(type(event))
|
||||
print(type(properties))
|
||||
response = posthog.capture(distinct_id, event, properties)
|
||||
posthog.capture(distinct_id, event, properties)
|
||||
posthog.flush()
|
||||
print(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Error capturing Posthog event: {e}")
|
||||
logger.error(f"Error capturing PostHog event: {e}")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from types import TracebackType
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
@@ -320,8 +321,6 @@ async def embed_text(
|
||||
api_url: str | None,
|
||||
api_version: str | None,
|
||||
) -> list[Embedding]:
|
||||
logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}")
|
||||
|
||||
if not all(texts):
|
||||
logger.error("Empty strings provided for embedding")
|
||||
raise ValueError("Empty strings are not allowed for embedding.")
|
||||
@@ -330,8 +329,17 @@ async def embed_text(
|
||||
logger.error("No texts provided for embedding")
|
||||
raise ValueError("No texts provided for embedding.")
|
||||
|
||||
start = time.monotonic()
|
||||
|
||||
total_chars = 0
|
||||
for text in texts:
|
||||
total_chars += len(text)
|
||||
|
||||
if provider_type is not None:
|
||||
logger.debug(f"Using cloud provider {provider_type} for embedding")
|
||||
logger.info(
|
||||
f"Embedding {len(texts)} texts with {total_chars} total characters with provider: {provider_type}"
|
||||
)
|
||||
|
||||
if api_key is None:
|
||||
logger.error("API key not provided for cloud model")
|
||||
raise RuntimeError("API key not provided for cloud model")
|
||||
@@ -363,8 +371,16 @@ async def embed_text(
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
logger.info(
|
||||
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
|
||||
f"with provider {provider_type} in {elapsed:.2f}"
|
||||
)
|
||||
elif model_name is not None:
|
||||
logger.debug(f"Using local model {model_name} for embedding")
|
||||
logger.info(
|
||||
f"Embedding {len(texts)} texts with {total_chars} total characters with local model: {model_name}"
|
||||
)
|
||||
|
||||
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
|
||||
|
||||
local_model = get_embedding_model(
|
||||
@@ -382,13 +398,17 @@ async def embed_text(
|
||||
for embedding in embeddings_vectors
|
||||
]
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
logger.info(
|
||||
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
|
||||
f"with local model {model_name} in {elapsed:.2f}"
|
||||
)
|
||||
else:
|
||||
logger.error("Neither model name nor provider specified for embedding")
|
||||
raise ValueError(
|
||||
"Either model name or provider must be provided to run embeddings."
|
||||
)
|
||||
|
||||
logger.info(f"Successfully embedded {len(texts)} texts")
|
||||
return embeddings
|
||||
|
||||
|
||||
@@ -440,7 +460,8 @@ async def process_embed_request(
|
||||
) -> EmbedResponse:
|
||||
if not embed_request.texts:
|
||||
raise HTTPException(status_code=400, detail="No texts to be embedded")
|
||||
elif not all(embed_request.texts):
|
||||
|
||||
if not all(embed_request.texts):
|
||||
raise ValueError("Empty strings are not allowed for embedding.")
|
||||
|
||||
try:
|
||||
@@ -471,9 +492,12 @@ async def process_embed_request(
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
exception_detail = f"Error during embedding process:\n{str(e)}"
|
||||
logger.exception(exception_detail)
|
||||
raise HTTPException(status_code=500, detail=exception_detail)
|
||||
logger.exception(
|
||||
f"Error during embedding process: provider={embed_request.provider_type} model={embed_request.model_name}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error during embedding process: {e}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/cross-encoder-scores")
|
||||
|
||||
@@ -44,6 +44,7 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
|
||||
the files in the existing huggingface cache that don't exist in the temp
|
||||
huggingface cache.
|
||||
"""
|
||||
|
||||
for item in source.iterdir():
|
||||
target_path = dest / item.relative_to(source)
|
||||
if item.is_dir():
|
||||
|
||||
@@ -19,6 +19,9 @@ def prefix_external_group(ext_group_name: str) -> str:
|
||||
return f"external_group:{ext_group_name}"
|
||||
|
||||
|
||||
def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
|
||||
"""External groups may collide across sources, every source needs its own prefix."""
|
||||
return f"{source.value.upper()}_{ext_group_name}"
|
||||
def build_ext_group_name_for_onyx(ext_group_name: str, source: DocumentSource) -> str:
|
||||
"""
|
||||
External groups may collide across sources, every source needs its own prefix.
|
||||
NOTE: the name is lowercased to handle case sensitivity for group names
|
||||
"""
|
||||
return f"{source.value}_{ext_group_name}".lower()
|
||||
|
||||
83
backend/onyx/auth/email_utils.py
Normal file
83
backend/onyx/auth/email_utils.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import smtplib
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from textwrap import dedent
|
||||
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import EMAIL_FROM
|
||||
from onyx.configs.app_configs import SMTP_PASS
|
||||
from onyx.configs.app_configs import SMTP_PORT
|
||||
from onyx.configs.app_configs import SMTP_SERVER
|
||||
from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
def send_email(
|
||||
user_email: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
if not EMAIL_CONFIGURED:
|
||||
raise ValueError("Email is not configured.")
|
||||
|
||||
msg = MIMEMultipart()
|
||||
msg["Subject"] = subject
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
|
||||
msg.attach(MIMEText(body))
|
||||
|
||||
try:
|
||||
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
|
||||
s.starttls()
|
||||
s.login(SMTP_USER, SMTP_PASS)
|
||||
s.send_message(msg)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def send_user_email_invite(user_email: str, current_user: User) -> None:
|
||||
subject = "Invitation to Join Onyx Organization"
|
||||
body = dedent(
|
||||
f"""\
|
||||
Hello,
|
||||
|
||||
You have been invited to join an organization on Onyx.
|
||||
|
||||
To join the organization, please visit the following link:
|
||||
|
||||
{WEB_DOMAIN}/auth/signup?email={user_email}
|
||||
|
||||
You'll be asked to set a password or login with Google to complete your registration.
|
||||
|
||||
Best regards,
|
||||
The Onyx Team
|
||||
"""
|
||||
)
|
||||
|
||||
send_email(user_email, subject, body, current_user.email)
|
||||
|
||||
|
||||
def send_forgot_password_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
subject = "Onyx Forgot Password"
|
||||
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
|
||||
body = f"Click the following link to reset your password: {link}"
|
||||
send_email(user_email, subject, body, mail_from)
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
subject = "Onyx Email Verification"
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
body = f"Click the following link to verify your email address: {link}"
|
||||
send_email(user_email, subject, body, mail_from)
|
||||
@@ -23,6 +23,7 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
|
||||
preferences_data = cast(
|
||||
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
|
||||
)
|
||||
print("preferences_data", preferences_data)
|
||||
return UserPreferences(**preferences_data)
|
||||
except KvKeyNotFoundError:
|
||||
return UserPreferences(
|
||||
@@ -30,13 +31,16 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
|
||||
)
|
||||
|
||||
|
||||
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
|
||||
def fetch_no_auth_user(
|
||||
store: KeyValueStore, *, anonymous_user_enabled: bool | None = None
|
||||
) -> UserInfo:
|
||||
return UserInfo(
|
||||
id=NO_AUTH_USER_ID,
|
||||
email=NO_AUTH_USER_EMAIL,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=UserRole.ADMIN,
|
||||
role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
|
||||
preferences=load_no_auth_user_preferences(store),
|
||||
is_anonymous_user=anonymous_user_enabled,
|
||||
)
|
||||
|
||||
@@ -33,12 +33,6 @@ class UserRole(str, Enum):
|
||||
]
|
||||
|
||||
|
||||
class UserStatus(str, Enum):
|
||||
LIVE = "live"
|
||||
INVITED = "invited"
|
||||
DEACTIVATED = "deactivated"
|
||||
|
||||
|
||||
class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
role: UserRole
|
||||
|
||||
@@ -49,4 +43,7 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
"""
|
||||
Role updates are not allowed through the user update endpoint for security reasons
|
||||
Role changes should be handled through a separate, admin-only process
|
||||
"""
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import smtplib
|
||||
import json
|
||||
import secrets
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
@@ -32,10 +31,8 @@ from fastapi_users import schemas
|
||||
from fastapi_users import UUIDIDMixin
|
||||
from fastapi_users.authentication import AuthenticationBackend
|
||||
from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.authentication import JWTStrategy
|
||||
from fastapi_users.authentication import RedisStrategy
|
||||
from fastapi_users.authentication import Strategy
|
||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
from fastapi_users.exceptions import UserAlreadyExists
|
||||
from fastapi_users.jwt import decode_jwt
|
||||
from fastapi_users.jwt import generate_jwt
|
||||
@@ -49,23 +46,22 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.email_utils import send_forgot_password_email
|
||||
from onyx.auth.email_utils import send_user_verification_email
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.schemas import UserUpdate
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.app_configs import EMAIL_FROM
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import REDIS_AUTH_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import SMTP_PASS
|
||||
from onyx.configs.app_configs import SMTP_PORT
|
||||
from onyx.configs.app_configs import SMTP_SERVER
|
||||
from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
@@ -74,22 +70,23 @@ from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
|
||||
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.api_key import fetch_user_for_api_key
|
||||
from onyx.db.auth import get_access_token_db
|
||||
from onyx.db.auth import get_default_admin_user_emails
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.auth import get_user_db
|
||||
from onyx.db.auth import SQLAlchemyUserAdminDB
|
||||
from onyx.db.engine import get_async_session
|
||||
from onyx.db.engine import get_async_session_with_tenant
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
@@ -103,6 +100,11 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
@@ -143,6 +145,17 @@ def user_needs_to_be_verified() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
assert isinstance(value, bytes)
|
||||
return int(value.decode("utf-8")) == 1
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
whitelist = get_invited_users()
|
||||
if not whitelist:
|
||||
@@ -193,30 +206,6 @@ def verify_email_domain(email: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
msg = MIMEMultipart()
|
||||
msg["Subject"] = "Onyx Email Verification"
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
|
||||
body = MIMEText(f"Click the following link to verify your email address: {link}")
|
||||
msg.attach(body)
|
||||
|
||||
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
|
||||
s.starttls()
|
||||
# If credentials fails with gmail, check (You need an app password, not just the basic email password)
|
||||
# https://support.google.com/accounts/answer/185833?sjid=8512343437447396151-NA
|
||||
s.login(SMTP_USER, SMTP_PASS)
|
||||
s.send_message(msg)
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
@@ -281,7 +270,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
if not user.role.is_web_login() and user_create.role.is_web_login():
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
@@ -405,11 +393,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# Explicitly set the Postgres schema for this session to ensure
|
||||
# OAuth account creation happens in the correct tenant schema
|
||||
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
|
||||
# Add OAuth account
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
|
||||
await self.on_after_register(user, request)
|
||||
|
||||
else:
|
||||
@@ -428,7 +414,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||
# re-authenticate that frequently, so by default this is disabled
|
||||
|
||||
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
|
||||
await self.user_db.update(
|
||||
@@ -506,7 +491,15 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async def on_after_forgot_password(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
) -> None:
|
||||
logger.notice(f"User {user.id} has forgot their password. Reset token: {token}")
|
||||
if not EMAIL_CONFIGURED:
|
||||
logger.error(
|
||||
"Email is not configured. Please configure email in the admin panel"
|
||||
)
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
"Your admin has not enbaled this feature.",
|
||||
)
|
||||
send_forgot_password_email(user.email, token)
|
||||
|
||||
async def on_after_request_verify(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
@@ -583,51 +576,70 @@ cookie_transport = CookieTransport(
|
||||
)
|
||||
|
||||
|
||||
# This strategy is used to add tenant_id to the JWT token
|
||||
class TenantAwareJWTStrategy(JWTStrategy):
|
||||
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
|
||||
def get_redis_strategy() -> RedisStrategy:
|
||||
return TenantAwareRedisStrategy()
|
||||
|
||||
|
||||
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
"""
|
||||
A custom strategy that fetches the actual async Redis connection inside each method.
|
||||
We do NOT pass a synchronous or "coroutine" redis object to the constructor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lifetime_seconds: Optional[int] = REDIS_AUTH_EXPIRE_TIME_SECONDS,
|
||||
key_prefix: str = REDIS_AUTH_KEY_PREFIX,
|
||||
):
|
||||
self.lifetime_seconds = lifetime_seconds
|
||||
self.key_prefix = key_prefix
|
||||
|
||||
async def write_token(self, user: User) -> str:
|
||||
redis = await get_async_redis_connection()
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user.email,
|
||||
)
|
||||
)(email=user.email)
|
||||
|
||||
data = {
|
||||
token_data = {
|
||||
"sub": str(user.id),
|
||||
"aud": self.token_audience,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
return data
|
||||
|
||||
async def write_token(self, user: User) -> str:
|
||||
data = await self._create_token_data(user)
|
||||
return generate_jwt(
|
||||
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
|
||||
token = secrets.token_urlsafe()
|
||||
await redis.set(
|
||||
f"{self.key_prefix}{token}",
|
||||
json.dumps(token_data),
|
||||
ex=self.lifetime_seconds,
|
||||
)
|
||||
return token
|
||||
|
||||
async def read_token(
|
||||
self, token: Optional[str], user_manager: BaseUserManager[User, uuid.UUID]
|
||||
) -> Optional[User]:
|
||||
redis = await get_async_redis_connection()
|
||||
token_data_str = await redis.get(f"{self.key_prefix}{token}")
|
||||
if not token_data_str:
|
||||
return None
|
||||
|
||||
def get_jwt_strategy() -> TenantAwareJWTStrategy:
|
||||
return TenantAwareJWTStrategy(
|
||||
secret=USER_AUTH_SECRET,
|
||||
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
|
||||
)
|
||||
try:
|
||||
token_data = json.loads(token_data_str)
|
||||
user_id = token_data["sub"]
|
||||
parsed_id = user_manager.parse_id(user_id)
|
||||
return await user_manager.get(parsed_id)
|
||||
except (exceptions.UserNotExists, exceptions.InvalidID, KeyError):
|
||||
return None
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> DatabaseStrategy:
|
||||
return DatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
|
||||
)
|
||||
async def destroy_token(self, token: str, user: User) -> None:
|
||||
"""Properly delete the token from async redis."""
|
||||
redis = await get_async_redis_connection()
|
||||
await redis.delete(f"{self.key_prefix}{token}")
|
||||
|
||||
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="jwt" if MULTI_TENANT else "database",
|
||||
transport=cookie_transport,
|
||||
get_strategy=get_jwt_strategy if MULTI_TENANT else get_database_strategy, # type: ignore
|
||||
) # type: ignore
|
||||
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
|
||||
)
|
||||
|
||||
|
||||
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
@@ -713,30 +725,36 @@ async def double_check_user(
|
||||
user: User | None,
|
||||
optional: bool = DISABLE_AUTH,
|
||||
include_expired: bool = False,
|
||||
allow_anonymous_access: bool = False,
|
||||
) -> User | None:
|
||||
if optional:
|
||||
return user
|
||||
|
||||
if user is not None:
|
||||
# If user attempted to authenticate, verify them, do not default
|
||||
# to anonymous access if it fails.
|
||||
if user_needs_to_be_verified() and not user.is_verified:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
if (
|
||||
user.oidc_expiry
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
if allow_anonymous_access:
|
||||
return None
|
||||
|
||||
if user is None:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
if user_needs_to_be_verified() and not user.is_verified:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
if (
|
||||
user.oidc_expiry
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
return user
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
|
||||
async def current_user_with_expired_token(
|
||||
@@ -751,6 +769,15 @@ async def current_limited_user(
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_chat_accesssible_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> User | None:
|
||||
return await double_check_user(
|
||||
user, allow_anonymous_access=anonymous_user_enabled(tenant_id=tenant_id)
|
||||
)
|
||||
|
||||
|
||||
async def current_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
|
||||
@@ -161,9 +161,34 @@ def on_task_postrun(
|
||||
return
|
||||
|
||||
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
"""The first signal sent on celery worker startup"""
|
||||
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
|
||||
|
||||
# NOTE(rkuo): start method "fork" is unsafe and we really need it to be "spawn"
|
||||
# But something is blocking set_start_method from working in the cloud unless
|
||||
# force=True. so we use force=True as a fallback.
|
||||
|
||||
all_start_methods: list[str] = multiprocessing.get_all_start_methods()
|
||||
logger.info(f"Multiprocessing all start methods: {all_start_methods}")
|
||||
|
||||
try:
|
||||
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
|
||||
except Exception:
|
||||
logger.info(
|
||||
"Multiprocessing set_start_method exceptioned. Trying force=True..."
|
||||
)
|
||||
try:
|
||||
multiprocessing.set_start_method(
|
||||
"spawn", force=True
|
||||
) # fork is unsafe, set to spawn
|
||||
except Exception:
|
||||
logger.info(
|
||||
"Multiprocessing set_start_method force=True exceptioned even with force=True."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Multiprocessing selected start method: {multiprocessing.get_start_method()}"
|
||||
)
|
||||
|
||||
|
||||
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
@@ -335,6 +360,10 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
if not celery_is_worker_primary(sender):
|
||||
return
|
||||
|
||||
if not hasattr(sender, "primary_worker_lock"):
|
||||
# primary_worker_lock will not exist when MULTI_TENANT is True
|
||||
return
|
||||
|
||||
if not sender.primary_worker_lock:
|
||||
return
|
||||
|
||||
@@ -414,11 +443,21 @@ def on_setup_logging(
|
||||
task_logger.setLevel(loglevel)
|
||||
task_logger.propagate = False
|
||||
|
||||
# Hide celery task received and succeeded/failed messages
|
||||
# hide celery task received spam
|
||||
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] received"
|
||||
strategy.logger.setLevel(logging.WARNING)
|
||||
|
||||
# uncomment this to hide celery task succeeded/failed spam
|
||||
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] succeeded in 0.03137450001668185s: None"
|
||||
trace.logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def set_task_finished_log_level(logLevel: int) -> None:
|
||||
"""call this to override the setLevel in on_setup_logging. We are interested
|
||||
in the task timings in the cloud but it can be spammy for self hosted."""
|
||||
trace.logger.setLevel(logLevel)
|
||||
|
||||
|
||||
class TenantContextFilter(logging.Filter):
|
||||
|
||||
"""Logging filter to inject tenant ID into the logger's name."""
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
@@ -49,17 +49,16 @@ def on_task_postrun(
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=4, max_overflow=12)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_process_init
|
||||
@@ -50,17 +50,21 @@ def on_task_postrun(
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
|
||||
|
||||
# rkuo: Transient errors keep happening in the indexing watchdog threads.
|
||||
# "SSL connection has been closed unexpectedly"
|
||||
# actually setting the spawn method in the cloud fixes 95% of these.
|
||||
# setting pre ping might help even more, but not worrying about that yet
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
@@ -15,7 +15,6 @@ from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
@@ -49,17 +48,18 @@ def on_task_postrun(
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
logger.info(f"Concurrency: {sender.concurrency}") # type: ignore
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
95
backend/onyx/background/celery/apps/monitoring.py
Normal file
95
backend/onyx/background/celery/apps/monitoring.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_MONITORING_APP_NAME
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.monitoring")
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_MONITORING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=3)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
]
|
||||
)
|
||||
@@ -1,4 +1,4 @@
|
||||
import multiprocessing
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -6,6 +6,7 @@ from celery import bootsteps # type: ignore
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
@@ -72,14 +73,13 @@ def on_task_postrun(
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
@@ -88,12 +88,12 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker. This is unnecessary in the multi tenant scenario
|
||||
r = get_redis_client(tenant_id=None)
|
||||
@@ -134,7 +134,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
raise WorkerShutdown("Primary worker lock could not be acquired!")
|
||||
|
||||
# tacking on our own user data to the sender
|
||||
sender.primary_worker_lock = lock
|
||||
sender.primary_worker_lock = lock # type: ignore
|
||||
|
||||
# As currently designed, when this worker starts as "primary", we reinitialize redis
|
||||
# to a clean state (for our purposes, anyway)
|
||||
@@ -194,6 +194,10 @@ def on_setup_logging(
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
# this can be spammy, so just enable it in the cloud for now
|
||||
if MULTI_TENANT:
|
||||
app_base.set_task_finished_log_level(logging.INFO)
|
||||
|
||||
|
||||
class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
"""Regularly reacquires the primary worker lock outside of the task queue.
|
||||
@@ -281,5 +285,6 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -3,12 +3,54 @@ import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.configs.base import CELERY_SEPARATOR
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
|
||||
|
||||
def celery_get_unacked_length(r: Redis) -> int:
|
||||
"""Checking the unacked queue is useful because a non-zero length tells us there
|
||||
may be prefetched tasks.
|
||||
|
||||
There can be other tasks in here besides indexing tasks, so this is mostly useful
|
||||
just to see if the task count is non zero.
|
||||
|
||||
ref: https://blog.hikaru.run/2022/08/29/get-waiting-tasks-count-in-celery.html
|
||||
"""
|
||||
length = cast(int, r.hlen("unacked"))
|
||||
return length
|
||||
|
||||
|
||||
def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]:
|
||||
"""Gets the set of task id's matching the given queue in the unacked hash.
|
||||
|
||||
Unacked entries belonging to the indexing queue are "prefetched", so this gives
|
||||
us crucial visibility as to what tasks are in that state.
|
||||
"""
|
||||
tasks: set[str] = set()
|
||||
|
||||
for _, v in r.hscan_iter("unacked"):
|
||||
v_bytes = cast(bytes, v)
|
||||
v_str = v_bytes.decode("utf-8")
|
||||
task = json.loads(v_str)
|
||||
|
||||
task_description = task[0]
|
||||
task_queue = task[2]
|
||||
|
||||
if task_queue != queue:
|
||||
continue
|
||||
|
||||
task_id = task_description.get("headers", {}).get("id")
|
||||
if not task_id:
|
||||
continue
|
||||
|
||||
# if the queue matches and we see the task_id, add it
|
||||
tasks.add(task_id)
|
||||
return tasks
|
||||
|
||||
|
||||
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
"""This is a redis specific way to get the length of a celery queue.
|
||||
It is priority aware and knows how to count across the multiple redis lists
|
||||
@@ -47,3 +89,74 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
|
||||
"""Returns a list of current workers containing name_filter, or all workers if
|
||||
name_filter is None.
|
||||
|
||||
We've empirically discovered that the celery inspect API is potentially unstable
|
||||
and may hang or return empty results when celery is under load. Suggest using this
|
||||
more to debug and troubleshoot than in production code.
|
||||
"""
|
||||
worker_names: list[str] = []
|
||||
|
||||
# filter for and create an indexing specific inspect object
|
||||
inspect = app.control.inspect()
|
||||
workers: dict[str, Any] = inspect.ping() # type: ignore
|
||||
if workers:
|
||||
for worker_name in list(workers.keys()):
|
||||
# if the name filter not set, return all worker names
|
||||
if not name_filter:
|
||||
worker_names.append(worker_name)
|
||||
continue
|
||||
|
||||
# if the name filter is set, return only worker names that contain the name filter
|
||||
if name_filter not in worker_name:
|
||||
continue
|
||||
|
||||
worker_names.append(worker_name)
|
||||
|
||||
return worker_names
|
||||
|
||||
|
||||
def celery_inspect_get_reserved(worker_names: list[str], app: Celery) -> set[str]:
|
||||
"""Returns a list of reserved tasks on the specified workers.
|
||||
|
||||
We've empirically discovered that the celery inspect API is potentially unstable
|
||||
and may hang or return empty results when celery is under load. Suggest using this
|
||||
more to debug and troubleshoot than in production code.
|
||||
"""
|
||||
reserved_task_ids: set[str] = set()
|
||||
|
||||
inspect = app.control.inspect(destination=worker_names)
|
||||
|
||||
# get the list of reserved tasks
|
||||
reserved_tasks: dict[str, list] | None = inspect.reserved() # type: ignore
|
||||
if reserved_tasks:
|
||||
for _, task_list in reserved_tasks.items():
|
||||
for task in task_list:
|
||||
reserved_task_ids.add(task["id"])
|
||||
|
||||
return reserved_task_ids
|
||||
|
||||
|
||||
def celery_inspect_get_active(worker_names: list[str], app: Celery) -> set[str]:
|
||||
"""Returns a list of active tasks on the specified workers.
|
||||
|
||||
We've empirically discovered that the celery inspect API is potentially unstable
|
||||
and may hang or return empty results when celery is under load. Suggest using this
|
||||
more to debug and troubleshoot than in production code.
|
||||
"""
|
||||
active_task_ids: set[str] = set()
|
||||
|
||||
inspect = app.control.inspect(destination=worker_names)
|
||||
|
||||
# get the list of reserved tasks
|
||||
active_tasks: dict[str, list] | None = inspect.active() # type: ignore
|
||||
if active_tasks:
|
||||
for _, task_list in active_tasks.items():
|
||||
for task in task_list:
|
||||
active_task_ids.add(task["id"])
|
||||
|
||||
return active_task_ids
|
||||
|
||||
@@ -14,6 +14,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.models import TaskQueueState
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@@ -41,14 +42,21 @@ def _get_deletion_status(
|
||||
return None
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
if not redis_connector.delete.fenced:
|
||||
return None
|
||||
if redis_connector.delete.fenced:
|
||||
return TaskQueueState(
|
||||
task_id="",
|
||||
task_name=redis_connector.delete.fence_key,
|
||||
status=TaskStatus.STARTED,
|
||||
)
|
||||
|
||||
return TaskQueueState(
|
||||
task_id="",
|
||||
task_name=redis_connector.delete.fence_key,
|
||||
status=TaskStatus.STARTED,
|
||||
)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return TaskQueueState(
|
||||
task_id="",
|
||||
task_name=redis_connector.delete.fence_key,
|
||||
status=TaskStatus.PENDING,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_deletion_attempt_snapshot(
|
||||
|
||||
@@ -16,6 +16,14 @@ result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
# Indexing worker specific ... this lets us track the transition to STARTED in redis
|
||||
# We don't currently rely on this but it has the potential to be useful and
|
||||
# indexing tasks are not high volume
|
||||
|
||||
# we don't turn this on yet because celery occasionally runs tasks more than once
|
||||
# which means a duplicate run might change the task state unexpectedly
|
||||
# task_track_started = True
|
||||
|
||||
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
|
||||
21
backend/onyx/background/celery/configs/monitoring.py
Normal file
21
backend/onyx/background/celery/configs/monitoring.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
# Monitoring worker specific settings
|
||||
worker_concurrency = 1 # Single worker is sufficient for monitoring
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
@@ -1,9 +1,17 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
|
||||
# choosing 15 minutes because it roughly gives us enough time to process many tasks
|
||||
# we might be able to reduce this greatly if we can run a unified
|
||||
# loop across all tenants rather than tasks per tenant
|
||||
|
||||
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
|
||||
|
||||
# we set expires because it isn't necessary to queue up these tasks
|
||||
# it's only important that they run relatively regularly
|
||||
tasks_to_schedule = [
|
||||
@@ -13,7 +21,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -22,7 +30,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -31,7 +39,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -40,7 +48,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -49,7 +57,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOWEST,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -58,7 +66,17 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-background-processes",
|
||||
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
"schedule": timedelta(minutes=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -67,7 +85,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -76,11 +94,25 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# Only add the LLM model update task if the API URL is configured
|
||||
if LLM_MODEL_UPDATE_API_URL:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "check-for-llm-model-update",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
"schedule": timedelta(hours=1), # Check every hour
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return tasks_to_schedule
|
||||
|
||||
@@ -17,7 +17,10 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.search_settings import get_all_search_settings
|
||||
from onyx.db.sync_record import cleanup_sync_records
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -34,7 +37,9 @@ class TaskDependencyError(RuntimeError):
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_connector_deletion_task(
|
||||
self: Task, *, tenant_id: str | None
|
||||
) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -42,11 +47,11 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
# collect cc_pair_ids
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -81,6 +86,8 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_generate_document_cc_pair_cleanup_tasks(
|
||||
app: Celery,
|
||||
@@ -109,11 +116,21 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
# we need to load the state of the object inside the fence
|
||||
# to avoid a race condition with db.commit/fence deletion
|
||||
# at the end of this taskset
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
|
||||
# there should be no in-progress sync records if this is up to date
|
||||
# clean it up just in case things got into a bad state
|
||||
cleanup_sync_records(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
)
|
||||
return None
|
||||
|
||||
# set a basic fence to start
|
||||
@@ -122,6 +139,13 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
submitted=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
)
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from time import sleep
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
@@ -14,10 +16,14 @@ from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.db.document import upsert_document_external_perms
|
||||
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
|
||||
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
||||
from ee.onyx.external_permissions.sync_params import (
|
||||
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
|
||||
)
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -88,19 +94,19 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
# get all cc pairs that need to be synced
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -128,6 +134,8 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_permissions_sync_task(
|
||||
app: Celery,
|
||||
@@ -219,6 +227,43 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
# before the primary worker can finalize the fence
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
|
||||
raise ValueError(
|
||||
f"connector_permission_sync_generator_task - timed out waiting for fence to be ready: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
|
||||
if not redis_connector.permissions.fenced: # The fence must exist
|
||||
raise ValueError(
|
||||
f"connector_permission_sync_generator_task - fence not found: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
|
||||
payload = redis_connector.permissions.payload # The payload must exist
|
||||
if not payload:
|
||||
raise ValueError(
|
||||
"connector_permission_sync_generator_task: payload invalid or not found"
|
||||
)
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
logger.info(
|
||||
f"connector_permission_sync_generator_task - Waiting for fence: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
sleep(1)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"connector_permission_sync_generator_task - Fence found, continuing...: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
break
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
@@ -234,7 +279,10 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
f"No connector credential pair found for id: {cc_pair_id}"
|
||||
@@ -244,6 +292,8 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if doc_sync_func is None:
|
||||
if source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION:
|
||||
return None
|
||||
raise ValueError(
|
||||
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
|
||||
)
|
||||
@@ -254,8 +304,11 @@ def connector_permission_sync_generator_task(
|
||||
if not payload:
|
||||
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
|
||||
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
new_payload = RedisConnectorPermissionSyncPayload(
|
||||
started=datetime.now(timezone.utc),
|
||||
celery_task_id=payload.celery_task_id,
|
||||
)
|
||||
redis_connector.permissions.set_fence(new_payload)
|
||||
|
||||
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
|
||||
|
||||
@@ -341,5 +394,7 @@ def update_external_document_permissions_task(
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Error Syncing Document Permissions")
|
||||
logger.exception(
|
||||
f"Error Syncing Document Permissions: connector_id={connector_id} doc_id={doc_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -94,19 +94,19 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
@@ -149,6 +149,8 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_external_group_sync_task(
|
||||
app: Celery,
|
||||
@@ -162,7 +164,7 @@ def try_creating_external_group_sync_task(
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
lock = r.lock(
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -248,7 +250,10 @@ def connector_external_group_sync_generator_task(
|
||||
return None
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
f"No connector credential pair found for id: {cc_pair_id}"
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
import sentry_sdk
|
||||
@@ -18,10 +20,12 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.indexing.job_client import SimpleJobClient
|
||||
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -29,6 +33,7 @@ from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
@@ -58,6 +63,8 @@ from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
@@ -69,14 +76,18 @@ logger = setup_logger()
|
||||
|
||||
|
||||
class IndexingCallback(IndexingHeartbeatInterface):
|
||||
PARENT_CHECK_INTERVAL = 60
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent_pid: int,
|
||||
stop_key: str,
|
||||
generator_progress_key: str,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
):
|
||||
super().__init__()
|
||||
self.parent_pid = parent_pid
|
||||
self.redis_lock: RedisLock = redis_lock
|
||||
self.stop_key: str = stop_key
|
||||
self.generator_progress_key: str = generator_progress_key
|
||||
@@ -86,26 +97,54 @@ class IndexingCallback(IndexingHeartbeatInterface):
|
||||
|
||||
self.last_tag: str = "IndexingCallback.__init__"
|
||||
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
self.last_parent_check = time.monotonic()
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
if self.redis_client.exists(self.stop_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
# rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
|
||||
# with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
|
||||
# so leave this code in until we're ready to test it.
|
||||
|
||||
# if self.parent_pid:
|
||||
# # check if the parent pid is alive so we aren't running as a zombie
|
||||
# now = time.monotonic()
|
||||
# if now - self.last_parent_check > IndexingCallback.PARENT_CHECK_INTERVAL:
|
||||
# try:
|
||||
# # this is unintuitive, but it checks if the parent pid is still running
|
||||
# os.kill(self.parent_pid, 0)
|
||||
# except Exception:
|
||||
# logger.exception("IndexingCallback - parent pid check exceptioned")
|
||||
# raise
|
||||
# self.last_parent_check = now
|
||||
|
||||
try:
|
||||
self.redis_lock.reacquire()
|
||||
current_time = time.monotonic()
|
||||
if current_time - self.last_lock_monotonic >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
self.redis_lock.reacquire()
|
||||
self.last_lock_reacquire = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
self.last_tag = tag
|
||||
self.last_lock_reacquire = datetime.now(timezone.utc)
|
||||
except LockError:
|
||||
logger.exception(
|
||||
f"IndexingCallback - lock.reacquire exceptioned. "
|
||||
f"IndexingCallback - lock.reacquire exceptioned: "
|
||||
f"lock_timeout={self.redis_lock.timeout} "
|
||||
f"start={self.started} "
|
||||
f"last_tag={self.last_tag} "
|
||||
f"last_reacquired={self.last_lock_reacquire} "
|
||||
f"now={datetime.now(timezone.utc)}"
|
||||
)
|
||||
|
||||
redis_lock_dump(self.redis_lock, self.redis_client)
|
||||
raise
|
||||
|
||||
self.redis_client.incrby(self.generator_progress_key, amount)
|
||||
@@ -167,6 +206,10 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
"""a lightweight task used to kick off indexing tasks.
|
||||
Occcasionally does some validation of existing state to clear up error conditions"""
|
||||
debug_tenants = {
|
||||
"tenant_i-043470d740845ec56",
|
||||
"tenant_82b497ce-88aa-4fbd-841a-92cae43529c8",
|
||||
}
|
||||
time_start = time.monotonic()
|
||||
|
||||
tasks_created = 0
|
||||
@@ -175,18 +218,18 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
# redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
locked = True
|
||||
|
||||
# check for search settings swap
|
||||
@@ -209,15 +252,25 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
)
|
||||
|
||||
# gather cc_pair_ids
|
||||
lock_beat.reacquire()
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
lock_beat.reacquire()
|
||||
cc_pairs = fetch_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
# kick off index attempts
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
# debugging logic - remove after we're done
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing cc_pair lock: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
@@ -226,22 +279,59 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
db_session
|
||||
)
|
||||
for search_settings_instance in search_settings_list:
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing cc_pair search settings lock: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings_instance.id
|
||||
)
|
||||
if redis_connector_index.fenced:
|
||||
continue
|
||||
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing get_connector_credential_pair_from_id: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id, db_session
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
continue
|
||||
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing get_last_attempt_for_cc_pair: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing cc_pair should index: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
search_settings_primary = False
|
||||
if search_settings_instance.id == search_settings_list[0].id:
|
||||
search_settings_primary = True
|
||||
@@ -274,6 +364,15 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
cc_pair.id, None, db_session
|
||||
)
|
||||
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing cc_pair try_creating_indexing_task: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
# using a task queue and only allowing one task per cc_pair/search_setting
|
||||
# prevents us from starving out certain attempts
|
||||
attempt_id = try_creating_indexing_task(
|
||||
@@ -294,13 +393,51 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
)
|
||||
tasks_created += 1
|
||||
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing cc_pair try_creating_indexing_task finished: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
# debugging logic - remove after we're done
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing unfenced lock: "
|
||||
f"tenant={tenant_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
# Fail any index attempts in the DB that don't have fences
|
||||
# This shouldn't ever happen!
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
unfenced_attempt_ids = get_unfenced_index_attempt_ids(
|
||||
db_session, redis_client
|
||||
)
|
||||
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing after get unfenced lock: "
|
||||
f"tenant={tenant_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
for attempt_id in unfenced_attempt_ids:
|
||||
# debugging logic - remove after we're done
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing unfenced attempt id lock: "
|
||||
f"tenant={tenant_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
attempt = get_index_attempt(db_session, attempt_id)
|
||||
@@ -318,24 +455,29 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
attempt.id, db_session, failure_reason=failure_reason
|
||||
)
|
||||
|
||||
# rkuo: The following code logically appears to work, but the celery inspect code may be unstable
|
||||
# turning off for the moment to see if it helps cloud stability
|
||||
# debugging logic - remove after we're done
|
||||
if tenant_id in debug_tenants:
|
||||
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
|
||||
task_logger.info(
|
||||
f"check_for_indexing validate fences lock: "
|
||||
f"tenant={tenant_id} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
# we want to run this less frequently than the overall task
|
||||
# if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
|
||||
# # clear any indexing fences that don't have associated celery tasks in progress
|
||||
# # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# # or be currently executing
|
||||
# try:
|
||||
# task_logger.info("Validating indexing fences...")
|
||||
# validate_indexing_fences(
|
||||
# tenant_id, self.app, redis_client, redis_client_celery, lock_beat
|
||||
# )
|
||||
# except Exception:
|
||||
# task_logger.exception("Exception while validating indexing fences")
|
||||
|
||||
# redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
|
||||
if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
|
||||
# clear any indexing fences that don't have associated celery tasks in progress
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
validate_indexing_fences(
|
||||
tenant_id, self.app, redis_client, redis_client_celery, lock_beat
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while validating indexing fences")
|
||||
|
||||
redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -351,9 +493,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
"check_for_indexing - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
redis_lock_dump(lock_beat, redis_client)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
|
||||
task_logger.debug(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
|
||||
return tasks_created
|
||||
|
||||
|
||||
@@ -364,56 +507,20 @@ def validate_indexing_fences(
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
reserved_indexing_tasks: set[str] = set()
|
||||
active_indexing_tasks: set[str] = set()
|
||||
indexing_worker_names: list[str] = []
|
||||
|
||||
# filter for and create an indexing specific inspect object
|
||||
inspect = celery_app.control.inspect()
|
||||
workers: dict[str, Any] = inspect.ping() # type: ignore
|
||||
if not workers:
|
||||
raise ValueError("No workers found!")
|
||||
|
||||
for worker_name in list(workers.keys()):
|
||||
if "indexing" in worker_name:
|
||||
indexing_worker_names.append(worker_name)
|
||||
|
||||
if len(indexing_worker_names) == 0:
|
||||
raise ValueError("No indexing workers found!")
|
||||
|
||||
inspect_indexing = celery_app.control.inspect(destination=indexing_worker_names)
|
||||
|
||||
# NOTE: each dict entry is a map of worker name to a list of tasks
|
||||
# we want sets for reserved task and active task id's to optimize
|
||||
# subsequent validation lookups
|
||||
|
||||
# get the list of reserved tasks
|
||||
reserved_tasks: dict[str, list] | None = inspect_indexing.reserved() # type: ignore
|
||||
if reserved_tasks is None:
|
||||
raise ValueError("inspect_indexing.reserved() returned None!")
|
||||
|
||||
for _, task_list in reserved_tasks.items():
|
||||
for task in task_list:
|
||||
reserved_indexing_tasks.add(task["id"])
|
||||
|
||||
# get the list of active tasks
|
||||
active_tasks: dict[str, list] | None = inspect_indexing.active() # type: ignore
|
||||
if active_tasks is None:
|
||||
raise ValueError("inspect_indexing.active() returned None!")
|
||||
|
||||
for _, task_list in active_tasks.items():
|
||||
for task in task_list:
|
||||
active_indexing_tasks.add(task["id"])
|
||||
reserved_indexing_tasks = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
# validate all existing indexing jobs
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
validate_indexing_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
reserved_indexing_tasks,
|
||||
active_indexing_tasks,
|
||||
r_celery,
|
||||
db_session,
|
||||
)
|
||||
@@ -424,7 +531,6 @@ def validate_indexing_fence(
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
active_tasks: set[str],
|
||||
r_celery: Redis,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
@@ -434,11 +540,15 @@ def validate_indexing_fence(
|
||||
gives the help.
|
||||
|
||||
How this works:
|
||||
1. Active signal is renewed with a 5 minute TTL
|
||||
1.1 When the fence is created
|
||||
1. This function renews the active signal with a 5 minute TTL under the following conditions
|
||||
1.2. When the task is seen in the redis queue
|
||||
1.3. When the task is seen in the reserved or active list for a worker
|
||||
2. The TTL allows us to get through the transitions on fence startup
|
||||
1.3. When the task is seen in the reserved / prefetched list
|
||||
|
||||
2. Externally, the active signal is renewed when:
|
||||
2.1. The fence is created
|
||||
2.2. The indexing watchdog checks the spawned task.
|
||||
|
||||
3. The TTL allows us to get through the transitions on fence startup
|
||||
and when the task starts executing.
|
||||
|
||||
More TTL clarification: it is seemingly impossible to exactly query Celery for
|
||||
@@ -466,6 +576,8 @@ def validate_indexing_fence(
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
# check to see if the fence/payload exists
|
||||
if not redis_connector_index.fenced:
|
||||
return
|
||||
|
||||
@@ -501,31 +613,32 @@ def validate_indexing_fence(
|
||||
redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
if payload.celery_task_id in active_tasks:
|
||||
# the celery task is active (aka currently executing)
|
||||
redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
# we may want to enable this check if using the active task list somehow isn't good enough
|
||||
# if redis_connector_index.generator_locked():
|
||||
# logger.info(f"{payload.celery_task_id} is currently executing.")
|
||||
|
||||
# we didn't find any direct indication that associated celery tasks exist, but they still might be there
|
||||
# due to gaps in our ability to check states during transitions
|
||||
# Rely on the active signal (which has a duration that allows us to bridge those gaps)
|
||||
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
|
||||
# but they still might be there due to gaps in our ability to check states during transitions
|
||||
# Checking the active signal safeguards us against these transition periods
|
||||
# (which has a duration that allows us to bridge those gaps)
|
||||
if redis_connector_index.active():
|
||||
return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
logger.warning(
|
||||
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: fence={fence_key}"
|
||||
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
|
||||
f"index_attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
if payload.index_attempt_id:
|
||||
try:
|
||||
mark_attempt_failed(
|
||||
payload.index_attempt_id,
|
||||
db_session,
|
||||
"validate_indexing_fence - Canceling index attempt due to missing celery tasks",
|
||||
f"validate_indexing_fence - Canceling index attempt due to missing celery tasks: "
|
||||
f"index_attempt={payload.index_attempt_id}",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
@@ -752,11 +865,14 @@ def connector_indexing_proxy_task(
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
|
||||
"""celery tasks are forked, but forking is unstable.
|
||||
This is a thread that proxies work to a spawned task."""
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - starting: attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
f"search_settings={search_settings_id} "
|
||||
f"mp_start_method={multiprocessing.get_start_method()}"
|
||||
)
|
||||
|
||||
if not self.request.id:
|
||||
@@ -783,7 +899,6 @@ def connector_indexing_proxy_task(
|
||||
return
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
|
||||
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
@@ -795,6 +910,58 @@ def connector_indexing_proxy_task(
|
||||
while True:
|
||||
sleep(5)
|
||||
|
||||
# renew active signal
|
||||
redis_connector_index.set_active()
|
||||
|
||||
# if the job is done, clean up and break
|
||||
if job.done():
|
||||
try:
|
||||
if job.status == "error":
|
||||
ignore_exitcode = False
|
||||
|
||||
exit_code: int | None = None
|
||||
if job.process:
|
||||
exit_code = job.process.exitcode
|
||||
|
||||
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
|
||||
# even though logging clearly indicates successful completion
|
||||
# to work around this, we ignore the job error state if the completion signal is OK
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
ignore_exitcode = True
|
||||
|
||||
if not ignore_exitcode:
|
||||
raise RuntimeError("Spawned task exceptioned.")
|
||||
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - spawned task has non-zero exit code "
|
||||
"but completion signal is OK. Continuing...: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code}"
|
||||
)
|
||||
except Exception:
|
||||
task_logger.error(
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code} "
|
||||
f"error={job.exception()}"
|
||||
)
|
||||
|
||||
raise
|
||||
finally:
|
||||
job.release()
|
||||
|
||||
break
|
||||
|
||||
# if a termination signal is detected, clean up and break
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - termination signal detected: "
|
||||
@@ -821,75 +988,33 @@ def connector_indexing_proxy_task(
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
job.cancel()
|
||||
|
||||
job.cancel()
|
||||
break
|
||||
|
||||
if not job.done():
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
ignore_exitcode = False
|
||||
|
||||
exit_code: int | None = None
|
||||
if job.process:
|
||||
exit_code = job.process.exitcode
|
||||
|
||||
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
|
||||
# even though logging clearly indicates that they completed successfully
|
||||
# to work around this, we ignore the job error state if the completion signal is OK
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
ignore_exitcode = True
|
||||
|
||||
if ignore_exitcode:
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - spawned task has non-zero exit code "
|
||||
"but completion signal is OK. Continuing...: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code}"
|
||||
)
|
||||
else:
|
||||
task_logger.error(
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code} "
|
||||
f"error={job.exception()}"
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
job.release()
|
||||
break
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - finished: attempt={index_attempt_id} "
|
||||
@@ -918,7 +1043,7 @@ def connector_indexing_task_wrapper(
|
||||
tenant_id,
|
||||
is_ee,
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"connector_indexing_task exceptioned: "
|
||||
f"tenant={tenant_id} "
|
||||
@@ -926,7 +1051,14 @@ def connector_indexing_task_wrapper(
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
raise
|
||||
|
||||
# There is a cloud related bug outside of our code
|
||||
# where spawned tasks return with an exit code of 1.
|
||||
# Unfortunately, exceptions also return with an exit code of 1,
|
||||
# so just raising an exception isn't informative
|
||||
# Exiting with 255 makes it possible to distinguish between normal exits
|
||||
# and exceptions.
|
||||
sys.exit(255)
|
||||
|
||||
return result
|
||||
|
||||
@@ -998,7 +1130,17 @@ def connector_indexing_task(
|
||||
f"fence={redis_connector.stop.fence_key}"
|
||||
)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
# before the primary worker can finalize the fence
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - timed out waiting for fence to be ready: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
|
||||
if not redis_connector_index.fenced: # The fence must exist
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
|
||||
@@ -1039,7 +1181,9 @@ def connector_indexing_task(
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
f"Indexing task already running, exiting...: "
|
||||
f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
|
||||
f"index_attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1056,8 +1200,8 @@ def connector_indexing_task(
|
||||
attempt_found = True
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
@@ -1075,6 +1219,7 @@ def connector_indexing_task(
|
||||
|
||||
# define a callback class
|
||||
callback = IndexingCallback(
|
||||
os.getppid(),
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector_index.generator_progress_key,
|
||||
lock,
|
||||
@@ -1108,8 +1253,19 @@ def connector_indexing_task(
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
if attempt_found:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_failed(index_attempt_id, db_session, failure_reason=str(e))
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id, db_session, failure_reason=str(e)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
raise e
|
||||
finally:
|
||||
|
||||
105
backend/onyx/background/celery/tasks/llm_model_update/tasks.py
Normal file
105
backend/onyx/background/celery/tasks/llm_model_update/tasks.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import LLMProvider
|
||||
|
||||
|
||||
def _process_model_list_response(model_list_json: Any) -> list[str]:
|
||||
# Handle case where response is wrapped in a "data" field
|
||||
if isinstance(model_list_json, dict) and "data" in model_list_json:
|
||||
model_list_json = model_list_json["data"]
|
||||
|
||||
if not isinstance(model_list_json, list):
|
||||
raise ValueError(
|
||||
f"Invalid response from API - expected list, got {type(model_list_json)}"
|
||||
)
|
||||
|
||||
# Handle both string list and object list cases
|
||||
model_names: list[str] = []
|
||||
for item in model_list_json:
|
||||
if isinstance(item, str):
|
||||
model_names.append(item)
|
||||
elif isinstance(item, dict) and "model_name" in item:
|
||||
model_names.append(item["model_name"])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid item in model list - expected string or dict with model_name, got {type(item)}"
|
||||
)
|
||||
|
||||
return model_names
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
if not LLM_MODEL_UPDATE_API_URL:
|
||||
raise ValueError("LLM model update API URL not configured")
|
||||
|
||||
# First fetch the models from the API
|
||||
try:
|
||||
response = requests.get(LLM_MODEL_UPDATE_API_URL)
|
||||
response.raise_for_status()
|
||||
available_models = _process_model_list_response(response.json())
|
||||
task_logger.info(f"Found available models: {available_models}")
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Failed to fetch models from API.")
|
||||
return None
|
||||
|
||||
# Then update the database with the fetched models
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Get the default LLM provider
|
||||
default_provider = (
|
||||
db_session.query(LLMProvider)
|
||||
.filter(LLMProvider.is_default_provider.is_(True))
|
||||
.first()
|
||||
)
|
||||
|
||||
if not default_provider:
|
||||
task_logger.warning("No default LLM provider found")
|
||||
return None
|
||||
|
||||
# log change if any
|
||||
old_models = set(default_provider.model_names or [])
|
||||
new_models = set(available_models)
|
||||
added_models = new_models - old_models
|
||||
removed_models = old_models - new_models
|
||||
|
||||
if added_models:
|
||||
task_logger.info(f"Adding models: {sorted(added_models)}")
|
||||
if removed_models:
|
||||
task_logger.info(f"Removing models: {sorted(removed_models)}")
|
||||
|
||||
# Update the provider's model list
|
||||
default_provider.model_names = available_models
|
||||
# if the default model is no longer available, set it to the first model in the list
|
||||
if default_provider.default_model_name not in available_models:
|
||||
task_logger.info(
|
||||
f"Default model {default_provider.default_model_name} not "
|
||||
f"available, setting to first model in list: {available_models[0]}"
|
||||
)
|
||||
default_provider.default_model_name = available_models[0]
|
||||
if default_provider.fast_default_model_name not in available_models:
|
||||
task_logger.info(
|
||||
f"Fast default model {default_provider.fast_default_model_name} "
|
||||
f"not available, setting to first model in list: {available_models[0]}"
|
||||
)
|
||||
default_provider.fast_default_model_name = available_models[0]
|
||||
db_session.commit()
|
||||
|
||||
if added_models or removed_models:
|
||||
task_logger.info("Updated model list for default provider.")
|
||||
|
||||
return True
|
||||
452
backend/onyx/background/celery/tasks/monitoring/tasks.py
Normal file
452
backend/onyx/background/celery/tasks/monitoring/tasks.py
Normal file
@@ -0,0 +1,452 @@
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine import get_db_current_time
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import SyncRecord
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
|
||||
_MONITORING_TIME_LIMIT = _MONITORING_SOFT_TIME_LIMIT + 60 # 6 minutes
|
||||
|
||||
_CONNECTOR_INDEX_ATTEMPT_START_LATENCY_KEY_FMT = (
|
||||
"monitoring_connector_index_attempt_start_latency:{cc_pair_id}:{index_attempt_id}"
|
||||
)
|
||||
|
||||
_CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT = (
|
||||
"monitoring_connector_index_attempt_run_success:{cc_pair_id}:{index_attempt_id}"
|
||||
)
|
||||
|
||||
|
||||
def _mark_metric_as_emitted(redis_std: Redis, key: str) -> None:
|
||||
"""Mark a metric as having been emitted by setting a Redis key with expiration"""
|
||||
redis_std.set(key, "1", ex=24 * 60 * 60) # Expire after 1 day
|
||||
|
||||
|
||||
def _has_metric_been_emitted(redis_std: Redis, key: str) -> bool:
|
||||
"""Check if a metric has been emitted by checking for existence of Redis key"""
|
||||
return bool(redis_std.exists(key))
|
||||
|
||||
|
||||
class Metric(BaseModel):
|
||||
key: str | None # only required if we need to store that we have emitted this metric
|
||||
name: str
|
||||
value: Any
|
||||
tags: dict[str, str]
|
||||
|
||||
def log(self) -> None:
|
||||
"""Log the metric in a standardized format"""
|
||||
data = {
|
||||
"metric": self.name,
|
||||
"value": self.value,
|
||||
"tags": self.tags,
|
||||
}
|
||||
task_logger.info(json.dumps(data))
|
||||
|
||||
def emit(self, tenant_id: str | None) -> None:
|
||||
# Convert value to appropriate type
|
||||
float_value = (
|
||||
float(self.value) if isinstance(self.value, (int, float)) else None
|
||||
)
|
||||
int_value = int(self.value) if isinstance(self.value, int) else None
|
||||
string_value = str(self.value) if isinstance(self.value, str) else None
|
||||
bool_value = bool(self.value) if isinstance(self.value, bool) else None
|
||||
|
||||
if (
|
||||
float_value is None
|
||||
and int_value is None
|
||||
and string_value is None
|
||||
and bool_value is None
|
||||
):
|
||||
task_logger.error(
|
||||
f"Invalid metric value type: {type(self.value)} "
|
||||
f"({self.value}) for metric {self.name}."
|
||||
)
|
||||
return
|
||||
|
||||
# don't send None values over the wire
|
||||
data = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"metric_name": self.name,
|
||||
"float_value": float_value,
|
||||
"int_value": int_value,
|
||||
"string_value": string_value,
|
||||
"bool_value": bool_value,
|
||||
"tags": self.tags,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
optional_telemetry(
|
||||
record_type=RecordType.METRIC,
|
||||
data=data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
|
||||
def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
|
||||
"""Collect metrics about queue lengths for different Celery queues"""
|
||||
metrics = []
|
||||
queue_mappings = {
|
||||
"celery_queue_length": "celery",
|
||||
"indexing_queue_length": "indexing",
|
||||
"sync_queue_length": "sync",
|
||||
"deletion_queue_length": "deletion",
|
||||
"pruning_queue_length": "pruning",
|
||||
"permissions_sync_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
|
||||
"external_group_sync_queue_length": OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
|
||||
"permissions_upsert_queue_length": OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
|
||||
}
|
||||
|
||||
for name, queue in queue_mappings.items():
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=None,
|
||||
name=name,
|
||||
value=celery_get_queue_length(queue, redis_celery),
|
||||
tags={"queue": name},
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def _build_connector_start_latency_metric(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
recent_attempt: IndexAttempt,
|
||||
second_most_recent_attempt: IndexAttempt | None,
|
||||
redis_std: Redis,
|
||||
) -> Metric | None:
|
||||
if not recent_attempt.time_started:
|
||||
return None
|
||||
|
||||
# check if we already emitted a metric for this index attempt
|
||||
metric_key = _CONNECTOR_INDEX_ATTEMPT_START_LATENCY_KEY_FMT.format(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempt_id=recent_attempt.id,
|
||||
)
|
||||
if _has_metric_been_emitted(redis_std, metric_key):
|
||||
task_logger.info(
|
||||
f"Skipping metric for connector {cc_pair.connector.id} "
|
||||
f"index attempt {recent_attempt.id} because it has already been "
|
||||
"emitted"
|
||||
)
|
||||
return None
|
||||
|
||||
# Connector start latency
|
||||
# first run case - we should start as soon as it's created
|
||||
if not second_most_recent_attempt:
|
||||
desired_start_time = cc_pair.connector.time_created
|
||||
else:
|
||||
if not cc_pair.connector.refresh_freq:
|
||||
task_logger.error(
|
||||
"Found non-initial index attempt for connector "
|
||||
"without refresh_freq. This should never happen."
|
||||
)
|
||||
return None
|
||||
|
||||
desired_start_time = second_most_recent_attempt.time_updated + timedelta(
|
||||
seconds=cc_pair.connector.refresh_freq
|
||||
)
|
||||
|
||||
start_latency = (recent_attempt.time_started - desired_start_time).total_seconds()
|
||||
|
||||
return Metric(
|
||||
key=metric_key,
|
||||
name="connector_start_latency",
|
||||
value=start_latency,
|
||||
tags={},
|
||||
)
|
||||
|
||||
|
||||
def _build_run_success_metric(
|
||||
cc_pair: ConnectorCredentialPair, recent_attempt: IndexAttempt, redis_std: Redis
|
||||
) -> Metric | None:
|
||||
metric_key = _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT.format(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempt_id=recent_attempt.id,
|
||||
)
|
||||
|
||||
if _has_metric_been_emitted(redis_std, metric_key):
|
||||
task_logger.info(
|
||||
f"Skipping metric for connector {cc_pair.connector.id} "
|
||||
f"index attempt {recent_attempt.id} because it has already been "
|
||||
"emitted"
|
||||
)
|
||||
return None
|
||||
|
||||
if recent_attempt.status in [
|
||||
IndexingStatus.SUCCESS,
|
||||
IndexingStatus.FAILED,
|
||||
IndexingStatus.CANCELED,
|
||||
]:
|
||||
return Metric(
|
||||
key=metric_key,
|
||||
name="connector_run_succeeded",
|
||||
value=recent_attempt.status == IndexingStatus.SUCCESS,
|
||||
tags={"source": str(cc_pair.connector.source)},
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
|
||||
"""Collect metrics about connector runs from the past hour"""
|
||||
# NOTE: use get_db_current_time since the IndexAttempt times are set based on DB time
|
||||
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
|
||||
|
||||
# Get all connector credential pairs
|
||||
cc_pairs = db_session.scalars(select(ConnectorCredentialPair)).all()
|
||||
|
||||
metrics = []
|
||||
for cc_pair in cc_pairs:
|
||||
# Get most recent attempt in the last hour
|
||||
recent_attempts = (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(
|
||||
IndexAttempt.connector_credential_pair_id == cc_pair.id,
|
||||
IndexAttempt.time_created >= one_hour_ago,
|
||||
)
|
||||
.order_by(IndexAttempt.time_created.desc())
|
||||
.limit(2)
|
||||
.all()
|
||||
)
|
||||
recent_attempt = recent_attempts[0] if recent_attempts else None
|
||||
second_most_recent_attempt = (
|
||||
recent_attempts[1] if len(recent_attempts) > 1 else None
|
||||
)
|
||||
|
||||
# if no metric to emit, skip
|
||||
if not recent_attempt:
|
||||
continue
|
||||
|
||||
# Connector start latency
|
||||
start_latency_metric = _build_connector_start_latency_metric(
|
||||
cc_pair, recent_attempt, second_most_recent_attempt, redis_std
|
||||
)
|
||||
if start_latency_metric:
|
||||
metrics.append(start_latency_metric)
|
||||
|
||||
# Connector run success/failure
|
||||
run_success_metric = _build_run_success_metric(
|
||||
cc_pair, recent_attempt, redis_std
|
||||
)
|
||||
if run_success_metric:
|
||||
metrics.append(run_success_metric)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
|
||||
"""Collect metrics about document set and group syncing speed"""
|
||||
# NOTE: use get_db_current_time since the SyncRecord times are set based on DB time
|
||||
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
|
||||
|
||||
# Get all sync records from the last hour
|
||||
recent_sync_records = db_session.scalars(
|
||||
select(SyncRecord)
|
||||
.where(SyncRecord.sync_start_time >= one_hour_ago)
|
||||
.order_by(SyncRecord.sync_start_time.desc())
|
||||
).all()
|
||||
|
||||
metrics = []
|
||||
for sync_record in recent_sync_records:
|
||||
# Skip if no end time (sync still in progress)
|
||||
if not sync_record.sync_end_time:
|
||||
continue
|
||||
|
||||
# Check if we already emitted a metric for this sync record
|
||||
metric_key = (
|
||||
f"sync_speed:{sync_record.sync_type}:"
|
||||
f"{sync_record.entity_id}:{sync_record.id}"
|
||||
)
|
||||
if _has_metric_been_emitted(redis_std, metric_key):
|
||||
task_logger.debug(
|
||||
f"Skipping metric for sync record {sync_record.id} "
|
||||
"because it has already been emitted"
|
||||
)
|
||||
continue
|
||||
|
||||
# Calculate sync duration in minutes
|
||||
sync_duration_mins = (
|
||||
sync_record.sync_end_time - sync_record.sync_start_time
|
||||
).total_seconds() / 60.0
|
||||
|
||||
# Calculate sync speed (docs/min) - avoid division by zero
|
||||
sync_speed = (
|
||||
sync_record.num_docs_synced / sync_duration_mins
|
||||
if sync_duration_mins > 0
|
||||
else None
|
||||
)
|
||||
|
||||
if sync_speed is None:
|
||||
task_logger.error(
|
||||
"Something went wrong with sync speed calculation. "
|
||||
f"Sync record: {sync_record.id}"
|
||||
)
|
||||
continue
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=metric_key,
|
||||
name="sync_speed_docs_per_min",
|
||||
value=sync_speed,
|
||||
tags={
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
"status": str(sync_record.sync_status),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Add sync start latency metric
|
||||
start_latency_key = (
|
||||
f"sync_start_latency:{sync_record.sync_type}"
|
||||
f":{sync_record.entity_id}:{sync_record.id}"
|
||||
)
|
||||
if _has_metric_been_emitted(redis_std, start_latency_key):
|
||||
task_logger.debug(
|
||||
f"Skipping start latency metric for sync record {sync_record.id} "
|
||||
"because it has already been emitted"
|
||||
)
|
||||
continue
|
||||
|
||||
# Get the entity's last update time based on sync type
|
||||
entity: DocumentSet | UserGroup | None = None
|
||||
if sync_record.sync_type == SyncType.DOCUMENT_SET:
|
||||
entity = db_session.scalar(
|
||||
select(DocumentSet).where(DocumentSet.id == sync_record.entity_id)
|
||||
)
|
||||
elif sync_record.sync_type == SyncType.USER_GROUP:
|
||||
entity = db_session.scalar(
|
||||
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
|
||||
)
|
||||
else:
|
||||
# Skip other sync types
|
||||
task_logger.debug(
|
||||
f"Skipping sync record {sync_record.id} "
|
||||
f"with type {sync_record.sync_type} "
|
||||
f"and id {sync_record.entity_id} "
|
||||
"because it is not a document set or user group"
|
||||
)
|
||||
continue
|
||||
|
||||
if entity is None:
|
||||
task_logger.error(
|
||||
f"Could not find entity for sync record {sync_record.id} "
|
||||
f"with type {sync_record.sync_type} and id {sync_record.entity_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Calculate start latency in seconds
|
||||
start_latency = (
|
||||
sync_record.sync_start_time - entity.time_last_modified_by_user
|
||||
).total_seconds()
|
||||
if start_latency < 0:
|
||||
task_logger.error(
|
||||
f"Start latency is negative for sync record {sync_record.id} "
|
||||
f"with type {sync_record.sync_type} and id {sync_record.entity_id}."
|
||||
"This is likely because the entity was updated between the time the "
|
||||
"time the sync finished and this job ran. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=start_latency_key,
|
||||
name="sync_start_latency_seconds",
|
||||
value=start_latency,
|
||||
tags={
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
soft_time_limit=_MONITORING_SOFT_TIME_LIMIT,
|
||||
time_limit=_MONITORING_TIME_LIMIT,
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
bind=True,
|
||||
)
|
||||
def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
"""Collect and emit metrics about background processes.
|
||||
This task runs periodically to gather metrics about:
|
||||
- Queue lengths for different Celery queues
|
||||
- Connector run metrics (start latency, success rate)
|
||||
- Syncing speed metrics
|
||||
- Worker status and task counts
|
||||
"""
|
||||
task_logger.info("Starting background monitoring")
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_monitoring: RedisLock = r.lock(
|
||||
OnyxRedisLocks.MONITOR_BACKGROUND_PROCESSES_LOCK,
|
||||
timeout=_MONITORING_SOFT_TIME_LIMIT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
if not lock_monitoring.acquire(blocking=False):
|
||||
task_logger.info("Skipping monitoring task because it is already running")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Get Redis client for Celery broker
|
||||
redis_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
redis_std = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Define metric collection functions and their dependencies
|
||||
metric_functions: list[Callable[[], list[Metric]]] = [
|
||||
lambda: _collect_queue_metrics(redis_celery),
|
||||
lambda: _collect_connector_metrics(db_session, redis_std),
|
||||
lambda: _collect_sync_metrics(db_session, redis_std),
|
||||
]
|
||||
# Collect and log each metric
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
for metric_fn in metric_functions:
|
||||
metrics = metric_fn()
|
||||
for metric in metrics:
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
if metric.key:
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
|
||||
task_logger.info("Successfully collected background metrics")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception("Error collecting background process metrics")
|
||||
raise e
|
||||
finally:
|
||||
if lock_monitoring.owned():
|
||||
lock_monitoring.release()
|
||||
|
||||
task_logger.info("Background monitoring task finished")
|
||||
@@ -81,19 +81,19 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
@@ -103,7 +103,10 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
continue
|
||||
|
||||
@@ -127,6 +130,8 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_prune_generator_task(
|
||||
celery_app: Celery,
|
||||
@@ -283,6 +288,7 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
|
||||
callback = IndexingCallback(
|
||||
0,
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector.prune.generator_progress_key,
|
||||
lock,
|
||||
|
||||
@@ -28,13 +28,35 @@ class RetryDocumentIndex:
|
||||
wait=wait_random_exponential(multiplier=1, max=MAX_WAIT),
|
||||
stop=stop_after_delay(STOP_AFTER),
|
||||
)
|
||||
def delete_single(self, doc_id: str) -> int:
|
||||
return self.index.delete_single(doc_id)
|
||||
def delete_single(
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
return self.index.delete_single(
|
||||
doc_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type(httpx.ReadTimeout),
|
||||
wait=wait_random_exponential(multiplier=1, max=MAX_WAIT),
|
||||
stop=stop_after_delay(STOP_AFTER),
|
||||
)
|
||||
def update_single(self, doc_id: str, fields: VespaDocumentFields) -> int:
|
||||
return self.index.update_single(doc_id, fields)
|
||||
def update_single(
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
chunk_count: int | None,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
return self.index.update_single(
|
||||
doc_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocument
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from onyx.db.document import delete_documents_complete__no_commit
|
||||
from onyx.db.document import fetch_chunk_count_for_document
|
||||
from onyx.db.document import get_document
|
||||
from onyx.db.document import get_document_connector_count
|
||||
from onyx.db.document import mark_document_as_modified
|
||||
@@ -80,7 +81,13 @@ def document_by_cc_pair_cleanup_task(
|
||||
# delete it from vespa and the db
|
||||
action = "delete"
|
||||
|
||||
chunks_affected = retry_index.delete_single(document_id)
|
||||
chunk_count = fetch_chunk_count_for_document(document_id, db_session)
|
||||
|
||||
chunks_affected = retry_index.delete_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
delete_documents_complete__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=[document_id],
|
||||
@@ -110,7 +117,12 @@ def document_by_cc_pair_cleanup_task(
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = retry_index.update_single(document_id, fields=fields)
|
||||
chunks_affected = retry_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
# there are still other cc_pair references to the doc, so just resync to Vespa
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
@@ -20,10 +23,12 @@ from tenacity import RetryError
|
||||
from onyx.access.access import get_access_for_document
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -49,10 +54,16 @@ from onyx.db.document_set import get_document_set_by_id
|
||||
from onyx.db.document_set import mark_document_set_as_synced
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.index_attempt import delete_index_attempts
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.sync_record import cleanup_sync_records
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.document_index.document_index_utils import get_both_index_names
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
@@ -67,6 +78,8 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
@@ -75,6 +88,7 @@ from onyx.utils.variable_functionality import (
|
||||
)
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -87,7 +101,7 @@ logger = setup_logger()
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
time_start = time.monotonic()
|
||||
@@ -99,17 +113,18 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_stale_document_sync_tasks(
|
||||
self.app, db_session, r, lock_beat, tenant_id
|
||||
self.app, VESPA_SYNC_MAX_TASKS, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
|
||||
# region document set scan
|
||||
lock_beat.reacquire()
|
||||
document_set_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# check if any document sets are not synced
|
||||
@@ -121,6 +136,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
document_set_ids.append(document_set.id)
|
||||
|
||||
for document_set_id in document_set_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_document_set_sync_tasks(
|
||||
self.app, document_set_id, db_session, r, lock_beat, tenant_id
|
||||
@@ -129,6 +145,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
|
||||
# check if any user groups are not synced
|
||||
if global_version.is_ee_version():
|
||||
lock_beat.reacquire()
|
||||
|
||||
try:
|
||||
fetch_user_groups = fetch_versioned_implementation(
|
||||
"onyx.db.user_group", "fetch_user_groups"
|
||||
@@ -148,6 +166,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
usergroup_ids.append(usergroup.id)
|
||||
|
||||
for usergroup_id in usergroup_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_user_group_sync_tasks(
|
||||
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
|
||||
@@ -162,14 +181,21 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"check_for_vespa_sync_task - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
redis_lock_dump(lock_beat, r)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
|
||||
return
|
||||
task_logger.debug(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
|
||||
return True
|
||||
|
||||
|
||||
def try_generate_stale_document_sync_tasks(
|
||||
celery_app: Celery,
|
||||
max_tasks: int,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
@@ -200,11 +226,16 @@ def try_generate_stale_document_sync_tasks(
|
||||
# rkuo: we could technically sync all stale docs in one big pass.
|
||||
# but I feel it's more understandable to group the docs by cc_pair
|
||||
total_tasks_generated = 0
|
||||
tasks_remaining = max_tasks
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
|
||||
rc.set_skip_docs(docs_to_skip)
|
||||
result = rc.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
|
||||
result = rc.generate_tasks(
|
||||
tasks_remaining, celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
|
||||
if result is None:
|
||||
continue
|
||||
@@ -218,10 +249,19 @@ def try_generate_stale_document_sync_tasks(
|
||||
)
|
||||
|
||||
total_tasks_generated += result[0]
|
||||
tasks_remaining -= result[0]
|
||||
if tasks_remaining <= 0:
|
||||
break
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
|
||||
)
|
||||
if tasks_remaining <= 0:
|
||||
task_logger.info(
|
||||
f"RedisConnector.generate_tasks reached the task generation limit: "
|
||||
f"total_tasks_generated={total_tasks_generated} max_tasks={max_tasks}"
|
||||
)
|
||||
else:
|
||||
task_logger.info(
|
||||
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
|
||||
)
|
||||
|
||||
r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
|
||||
return total_tasks_generated
|
||||
@@ -245,11 +285,21 @@ def try_generate_document_set_sync_tasks(
|
||||
|
||||
# don't generate sync tasks if we're up to date
|
||||
# race condition with the monitor/cleanup function if we use a cached result!
|
||||
document_set = get_document_set_by_id(db_session, document_set_id)
|
||||
document_set = get_document_set_by_id(
|
||||
db_session=db_session,
|
||||
document_set_id=document_set_id,
|
||||
)
|
||||
if not document_set:
|
||||
return None
|
||||
|
||||
if document_set.is_up_to_date:
|
||||
# there should be no in-progress sync records if this is up to date
|
||||
# clean it up just in case things got into a bad state
|
||||
cleanup_sync_records(
|
||||
db_session=db_session,
|
||||
entity_id=document_set_id,
|
||||
sync_type=SyncType.DOCUMENT_SET,
|
||||
)
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
@@ -260,7 +310,9 @@ def try_generate_document_set_sync_tasks(
|
||||
)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
result = rds.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
|
||||
result = rds.generate_tasks(
|
||||
VESPA_SYNC_MAX_TASKS, celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
@@ -276,6 +328,13 @@ def try_generate_document_set_sync_tasks(
|
||||
f"document_set={document_set.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=document_set_id,
|
||||
sync_type=SyncType.DOCUMENT_SET,
|
||||
)
|
||||
# set this only after all tasks have been added
|
||||
rds.set_fence(tasks_generated)
|
||||
return tasks_generated
|
||||
@@ -297,8 +356,9 @@ def try_generate_user_group_sync_tasks(
|
||||
return None
|
||||
|
||||
# race condition with the monitor/cleanup function if we use a cached result!
|
||||
fetch_user_group = fetch_versioned_implementation(
|
||||
"onyx.db.user_group", "fetch_user_group"
|
||||
fetch_user_group = cast(
|
||||
Callable[[Session, int], UserGroup | None],
|
||||
fetch_versioned_implementation("onyx.db.user_group", "fetch_user_group"),
|
||||
)
|
||||
|
||||
usergroup = fetch_user_group(db_session, usergroup_id)
|
||||
@@ -306,6 +366,13 @@ def try_generate_user_group_sync_tasks(
|
||||
return None
|
||||
|
||||
if usergroup.is_up_to_date:
|
||||
# there should be no in-progress sync records if this is up to date
|
||||
# clean it up just in case things got into a bad state
|
||||
cleanup_sync_records(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
)
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
@@ -315,7 +382,9 @@ def try_generate_user_group_sync_tasks(
|
||||
task_logger.info(
|
||||
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
|
||||
)
|
||||
result = rug.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
|
||||
result = rug.generate_tasks(
|
||||
VESPA_SYNC_MAX_TASKS, celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
@@ -331,8 +400,16 @@ def try_generate_user_group_sync_tasks(
|
||||
f"usergroup={usergroup.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
)
|
||||
# set this only after all tasks have been added
|
||||
rug.set_fence(tasks_generated)
|
||||
|
||||
return tasks_generated
|
||||
|
||||
|
||||
@@ -382,6 +459,13 @@ def monitor_document_set_taskset(
|
||||
f"remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=document_set_id,
|
||||
sync_type=SyncType.DOCUMENT_SET,
|
||||
sync_status=SyncStatus.IN_PROGRESS,
|
||||
num_docs_synced=count,
|
||||
)
|
||||
return
|
||||
|
||||
document_set = cast(
|
||||
@@ -400,6 +484,13 @@ def monitor_document_set_taskset(
|
||||
task_logger.info(
|
||||
f"Successfully synced document set: document_set={document_set_id}"
|
||||
)
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=document_set_id,
|
||||
sync_type=SyncType.DOCUMENT_SET,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=initial_count,
|
||||
)
|
||||
|
||||
rds.reset()
|
||||
|
||||
@@ -433,10 +524,21 @@ def monitor_connector_deletion_taskset(
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
|
||||
)
|
||||
if remaining > 0:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.IN_PROGRESS,
|
||||
num_docs_synced=remaining,
|
||||
)
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
|
||||
@@ -508,11 +610,29 @@ def monitor_connector_deletion_taskset(
|
||||
)
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
stack_trace = traceback.format_exc()
|
||||
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
|
||||
add_deletion_failure_message(db_session, cc_pair_id, error_message)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.FAILED,
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
task_logger.exception(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
|
||||
@@ -636,15 +756,23 @@ def monitor_ccpair_indexing_taskset(
|
||||
if not payload:
|
||||
return
|
||||
|
||||
elapsed_started_str = None
|
||||
if payload.started:
|
||||
elapsed_started = datetime.now(timezone.utc) - payload.started
|
||||
elapsed_started_str = f"{elapsed_started.total_seconds():.2f}"
|
||||
|
||||
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
|
||||
|
||||
progress = redis_connector_index.get_progress()
|
||||
if progress is not None:
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: cc_pair={cc_pair_id} "
|
||||
f"Connector indexing progress: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
@@ -715,18 +843,21 @@ def monitor_ccpair_indexing_taskset(
|
||||
status_enum = HTTPStatus(status_int)
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: cc_pair={cc_pair_id} "
|
||||
f"Connector indexing finished: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"status={status_enum.name} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
@shared_task(name=OnyxCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
If the count is 0, that means all tasks finished and we should clean up.
|
||||
@@ -736,7 +867,13 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
Returns True if the task actually did work, False if it exited early to prevent overlap
|
||||
"""
|
||||
task_logger.info(f"monitor_vespa_sync starting: tenant={tenant_id}")
|
||||
|
||||
time_start = time.monotonic()
|
||||
|
||||
timings: dict[str, Any] = {}
|
||||
timings["start"] = time_start
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -744,55 +881,94 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# prevent overlapping tasks
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
# prevent overlapping tasks
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return False
|
||||
|
||||
# print current queue lengths
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r_celery)
|
||||
n_indexing = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
|
||||
n_deletion = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
|
||||
)
|
||||
n_pruning = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
|
||||
)
|
||||
n_permissions_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
|
||||
)
|
||||
phase_start = time.monotonic()
|
||||
# we don't need every tenant polling redis for this info.
|
||||
if not MULTI_TENANT or random.randint(1, 10) == 10:
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r_celery)
|
||||
n_indexing = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
n_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery
|
||||
)
|
||||
n_deletion = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
|
||||
)
|
||||
n_pruning = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
|
||||
)
|
||||
n_permissions_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
|
||||
)
|
||||
n_external_group_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
|
||||
)
|
||||
n_permissions_upsert = celery_get_queue_length(
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Queue lengths: celery={n_celery} "
|
||||
f"indexing={n_indexing} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning} "
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
)
|
||||
prefetched = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Queue lengths: celery={n_celery} "
|
||||
f"indexing={n_indexing} "
|
||||
f"indexing_prefetched={len(prefetched)} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning} "
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
f"external_group_sync={n_external_group_sync} "
|
||||
f"permissions_upsert={n_permissions_upsert} "
|
||||
)
|
||||
timings["queues"] = time.monotonic() - phase_start
|
||||
timings["queues_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
# scan and monitor activity to completion
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
timings["connector"] = time.monotonic() - phase_start
|
||||
timings["connector_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorDelete.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
|
||||
timings["connector_deletion"] = time.monotonic() - phase_start
|
||||
timings["connector_deletion_ttl"] = r.ttl(
|
||||
OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK
|
||||
)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisDocumentSet.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
timings["documentset"] = time.monotonic() - phase_start
|
||||
timings["documentset_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisUserGroup.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
|
||||
"onyx.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
@@ -800,29 +976,45 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
timings["usergroup"] = time.monotonic() - phase_start
|
||||
timings["usergroup_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
timings["pruning"] = time.monotonic() - phase_start
|
||||
timings["pruning_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
timings["indexing"] = time.monotonic() - phase_start
|
||||
timings["indexing_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
|
||||
count=SCAN_ITER_COUNT_DEFAULT,
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
|
||||
timings["permissions"] = time.monotonic() - phase_start
|
||||
timings["permissions_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
# uncomment for debugging if needed
|
||||
# r_celery = celery_app.broker_connection().channel().client
|
||||
# length = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
|
||||
# task_logger.warning(f"queue={OnyxCeleryQueues.VESPA_METADATA_SYNC} length={length}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -830,6 +1022,13 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"monitor_vespa_sync - Lock not owned on completion: "
|
||||
f"tenant={tenant_id} "
|
||||
f"timings={timings}"
|
||||
)
|
||||
redis_lock_dump(lock_beat, r)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
|
||||
@@ -876,12 +1075,26 @@ def vespa_metadata_sync_task(
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = retry_index.update_single(document_id, fields)
|
||||
chunks_affected = retry_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
# update db last. Worst case = we crash right before this and
|
||||
# the sync might repeat again later
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
|
||||
# this code checks for and removes a per document sync key that is
|
||||
# used to block out the same doc from continualy resyncing
|
||||
# a quick hack that is only needed for production issues
|
||||
# redis_syncing_key = RedisConnectorCredentialPair.make_redis_syncing_key(
|
||||
# document_id
|
||||
# )
|
||||
# r = get_redis_client(tenant_id=tenant_id)
|
||||
# r.delete(redis_syncing_key)
|
||||
|
||||
task_logger.info(f"doc={document_id} action=sync chunks={chunks_affected}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
|
||||
15
backend/onyx/background/celery/versioned_apps/monitoring.py
Normal file
15
backend/onyx/background/celery/versioned_apps/monitoring.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
from celery import Celery
|
||||
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
|
||||
def get_app() -> Celery:
|
||||
from onyx.background.celery.apps.monitoring import celery_app
|
||||
|
||||
return celery_app
|
||||
|
||||
|
||||
app = get_app()
|
||||
@@ -4,9 +4,10 @@ not follow the expected behavior, etc.
|
||||
|
||||
NOTE: cannot use Celery directly due to
|
||||
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
|
||||
import multiprocessing as mp
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import Process
|
||||
from multiprocessing.context import SpawnProcess
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
@@ -46,7 +47,9 @@ def _initializer(
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
|
||||
|
||||
# Initialize a new engine with desired parameters
|
||||
SqlEngine.init_engine(pool_size=4, max_overflow=12, pool_recycle=60)
|
||||
SqlEngine.init_engine(
|
||||
pool_size=4, max_overflow=12, pool_recycle=60, pool_pre_ping=True
|
||||
)
|
||||
|
||||
# Proceed with executing the target function
|
||||
return func(*args, **kwargs)
|
||||
@@ -63,7 +66,7 @@ class SimpleJob:
|
||||
"""Drop in replacement for `dask.distributed.Future`"""
|
||||
|
||||
id: int
|
||||
process: Optional["Process"] = None
|
||||
process: Optional["SpawnProcess"] = None
|
||||
|
||||
def cancel(self) -> bool:
|
||||
return self.release()
|
||||
@@ -131,7 +134,10 @@ class SimpleJobClient:
|
||||
job_id = self.job_id_counter
|
||||
self.job_id_counter += 1
|
||||
|
||||
process = Process(target=_run_in_process, args=(func, args), daemon=True)
|
||||
# this approach allows us to always "spawn" a new process regardless of
|
||||
# get_start_method's current setting
|
||||
ctx = mp.get_context("spawn")
|
||||
process = ctx.Process(target=_run_in_process, args=(func, args), daemon=True)
|
||||
job = SimpleJob(id=job_id, process=process)
|
||||
process.start()
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.indexing.checkpointing import get_time_windows_for_index_attempt
|
||||
@@ -11,21 +12,25 @@ from onyx.background.indexing.tracer import OnyxTracer
|
||||
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
|
||||
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
|
||||
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from onyx.db.index_attempt import mark_attempt_succeeded
|
||||
from onyx.db.index_attempt import transition_attempt_to_in_progress
|
||||
from onyx.db.index_attempt import update_docs_indexed
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import IndexModelStatus
|
||||
@@ -74,7 +79,8 @@ def _get_connector_runner(
|
||||
# it will never succeed
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
attempt.connector_credential_pair.id, db_session
|
||||
db_session=db_session,
|
||||
cc_pair_id=attempt.connector_credential_pair.id,
|
||||
)
|
||||
if cc_pair and cc_pair.status == ConnectorCredentialPairStatus.ACTIVE:
|
||||
update_connector_credential_pair(
|
||||
@@ -90,13 +96,54 @@ def _get_connector_runner(
|
||||
)
|
||||
|
||||
|
||||
def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
|
||||
cleaned_batch = []
|
||||
for doc in doc_batch:
|
||||
cleaned_doc = doc.model_copy()
|
||||
|
||||
if "\x00" in cleaned_doc.id:
|
||||
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
|
||||
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
|
||||
|
||||
if "\x00" in cleaned_doc.semantic_identifier:
|
||||
logger.warning(
|
||||
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
|
||||
)
|
||||
cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
|
||||
"\x00", ""
|
||||
)
|
||||
|
||||
for section in cleaned_doc.sections:
|
||||
if section.link and "\x00" in section.link:
|
||||
logger.warning(
|
||||
f"NUL characters found in document link for document: {cleaned_doc.id}"
|
||||
)
|
||||
section.link = section.link.replace("\x00", "")
|
||||
|
||||
cleaned_batch.append(cleaned_doc)
|
||||
|
||||
return cleaned_batch
|
||||
|
||||
|
||||
class ConnectorStopSignal(Exception):
|
||||
"""A custom exception used to signal a stop in processing."""
|
||||
|
||||
|
||||
class RunIndexingContext(BaseModel):
|
||||
index_name: str
|
||||
cc_pair_id: int
|
||||
connector_id: int
|
||||
credential_id: int
|
||||
source: DocumentSource
|
||||
earliest_index_time: float
|
||||
from_beginning: bool
|
||||
is_primary: bool
|
||||
search_settings_status: IndexModelStatus
|
||||
|
||||
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt: IndexAttempt,
|
||||
index_attempt_id: int,
|
||||
tenant_id: str | None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> None:
|
||||
@@ -110,61 +157,76 @@ def _run_indexing(
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
if index_attempt.search_settings is None:
|
||||
raise ValueError(
|
||||
"Search settings must be set for indexing. This should not be possible."
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
if not index_attempt_start:
|
||||
raise ValueError(
|
||||
f"Index attempt {index_attempt_id} does not exist in DB. This should not be possible."
|
||||
)
|
||||
|
||||
if index_attempt_start.search_settings is None:
|
||||
raise ValueError(
|
||||
"Search settings must be set for indexing. This should not be possible."
|
||||
)
|
||||
|
||||
# search_settings = index_attempt_start.search_settings
|
||||
db_connector = index_attempt_start.connector_credential_pair.connector
|
||||
db_credential = index_attempt_start.connector_credential_pair.credential
|
||||
ctx = RunIndexingContext(
|
||||
index_name=index_attempt_start.search_settings.index_name,
|
||||
cc_pair_id=index_attempt_start.connector_credential_pair.id,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
source=db_connector.source,
|
||||
earliest_index_time=(
|
||||
db_connector.indexing_start.timestamp()
|
||||
if db_connector.indexing_start
|
||||
else 0
|
||||
),
|
||||
from_beginning=index_attempt_start.from_beginning,
|
||||
# Only update cc-pair status for primary index jobs
|
||||
# Secondary index syncs at the end when swapping
|
||||
is_primary=(
|
||||
index_attempt_start.search_settings.status == IndexModelStatus.PRESENT
|
||||
),
|
||||
search_settings_status=index_attempt_start.search_settings.status,
|
||||
)
|
||||
|
||||
search_settings = index_attempt.search_settings
|
||||
last_successful_index_time = (
|
||||
ctx.earliest_index_time
|
||||
if ctx.from_beginning
|
||||
else get_last_successful_attempt_time(
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
earliest_index=ctx.earliest_index_time,
|
||||
search_settings=index_attempt_start.search_settings,
|
||||
db_session=db_session_temp,
|
||||
)
|
||||
)
|
||||
|
||||
index_name = search_settings.index_name
|
||||
|
||||
# Only update cc-pair status for primary index jobs
|
||||
# Secondary index syncs at the end when swapping
|
||||
is_primary = search_settings.status == IndexModelStatus.PRESENT
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=index_attempt_start.search_settings,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
# Indexing is only done into one index at a time
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=index_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=search_settings,
|
||||
callback=callback,
|
||||
primary_index_name=ctx.index_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
attempt_id=index_attempt.id,
|
||||
attempt_id=index_attempt_id,
|
||||
embedder=embedding_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=(
|
||||
index_attempt.from_beginning
|
||||
or (search_settings.status == IndexModelStatus.FUTURE)
|
||||
ctx.from_beginning
|
||||
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
|
||||
),
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
db_cc_pair = index_attempt.connector_credential_pair
|
||||
db_connector = index_attempt.connector_credential_pair.connector
|
||||
db_credential = index_attempt.connector_credential_pair.credential
|
||||
earliest_index_time = (
|
||||
db_connector.indexing_start.timestamp() if db_connector.indexing_start else 0
|
||||
)
|
||||
|
||||
last_successful_index_time = (
|
||||
earliest_index_time
|
||||
if index_attempt.from_beginning
|
||||
else get_last_successful_attempt_time(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
earliest_index=earliest_index_time,
|
||||
search_settings=index_attempt.search_settings,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
|
||||
tracer = OnyxTracer()
|
||||
@@ -172,8 +234,8 @@ def _run_indexing(
|
||||
tracer.snap()
|
||||
|
||||
index_attempt_md = IndexAttemptMetadata(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
)
|
||||
|
||||
batch_num = 0
|
||||
@@ -189,19 +251,31 @@ def _run_indexing(
|
||||
source_type=db_connector.source,
|
||||
)
|
||||
):
|
||||
cc_pair_loop: ConnectorCredentialPair | None = None
|
||||
index_attempt_loop: IndexAttempt | None = None
|
||||
|
||||
try:
|
||||
window_start = max(
|
||||
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
|
||||
datetime(1970, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
connector_runner = _get_connector_runner(
|
||||
db_session=db_session,
|
||||
attempt=index_attempt,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
index_attempt_loop_start = get_index_attempt(
|
||||
db_session_temp, index_attempt_id
|
||||
)
|
||||
if not index_attempt_loop_start:
|
||||
raise RuntimeError(
|
||||
f"Index attempt {index_attempt_id} not found in DB."
|
||||
)
|
||||
|
||||
connector_runner = _get_connector_runner(
|
||||
db_session=db_session_temp,
|
||||
attempt=index_attempt_loop_start,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
|
||||
@@ -218,27 +292,43 @@ def _run_indexing(
|
||||
raise ConnectorStopSignal("Connector stop signal detected")
|
||||
|
||||
# TODO: should we move this into the above callback instead?
|
||||
db_session.refresh(db_cc_pair)
|
||||
if (
|
||||
(
|
||||
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
|
||||
and search_settings.status != IndexModelStatus.FUTURE
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
cc_pair_loop = get_connector_credential_pair_from_id(
|
||||
db_session_temp,
|
||||
ctx.cc_pair_id,
|
||||
)
|
||||
# if it's deleting, we don't care if this is a secondary index
|
||||
or db_cc_pair.status == ConnectorCredentialPairStatus.DELETING
|
||||
):
|
||||
# let the `except` block handle this
|
||||
raise RuntimeError("Connector was disabled mid run")
|
||||
if not cc_pair_loop:
|
||||
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
|
||||
|
||||
db_session.refresh(index_attempt)
|
||||
if index_attempt.status != IndexingStatus.IN_PROGRESS:
|
||||
# Likely due to user manually disabling it or model swap
|
||||
raise RuntimeError(
|
||||
f"Index Attempt was canceled, status is {index_attempt.status}"
|
||||
if (
|
||||
(
|
||||
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
|
||||
and ctx.search_settings_status != IndexModelStatus.FUTURE
|
||||
)
|
||||
# if it's deleting, we don't care if this is a secondary index
|
||||
or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING
|
||||
):
|
||||
# let the `except` block handle this
|
||||
raise RuntimeError("Connector was disabled mid run")
|
||||
|
||||
index_attempt_loop = get_index_attempt(
|
||||
db_session_temp, index_attempt_id
|
||||
)
|
||||
if not index_attempt_loop:
|
||||
raise RuntimeError(
|
||||
f"Index attempt {index_attempt_id} not found in DB."
|
||||
)
|
||||
|
||||
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
|
||||
# Likely due to user manually disabling it or model swap
|
||||
raise RuntimeError(
|
||||
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
|
||||
)
|
||||
|
||||
batch_description = []
|
||||
for doc in doc_batch:
|
||||
|
||||
doc_batch_cleaned = strip_null_characters(doc_batch)
|
||||
for doc in doc_batch_cleaned:
|
||||
batch_description.append(doc.to_short_descriptor())
|
||||
|
||||
doc_size = 0
|
||||
@@ -258,15 +348,15 @@ def _run_indexing(
|
||||
|
||||
# real work happens here!
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
document_batch=doc_batch,
|
||||
document_batch=doc_batch_cleaned,
|
||||
index_attempt_metadata=index_attempt_md,
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
net_doc_change += new_docs
|
||||
chunk_count += total_batch_chunks
|
||||
document_count += len(doc_batch)
|
||||
all_connector_doc_ids.update(doc.id for doc in doc_batch)
|
||||
document_count += len(doc_batch_cleaned)
|
||||
all_connector_doc_ids.update(doc.id for doc in doc_batch_cleaned)
|
||||
|
||||
# commit transaction so that the `update` below begins
|
||||
# with a brand new transaction. Postgres uses the start
|
||||
@@ -276,16 +366,17 @@ def _run_indexing(
|
||||
db_session.commit()
|
||||
|
||||
if callback:
|
||||
callback.progress("_run_indexing", len(doc_batch))
|
||||
callback.progress("_run_indexing", len(doc_batch_cleaned))
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
update_docs_indexed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
total_docs_indexed=document_count,
|
||||
new_docs_indexed=net_doc_change,
|
||||
docs_removed_from_index=0,
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
update_docs_indexed(
|
||||
db_session=db_session_temp,
|
||||
index_attempt_id=index_attempt_id,
|
||||
total_docs_indexed=document_count,
|
||||
new_docs_indexed=net_doc_change,
|
||||
docs_removed_from_index=0,
|
||||
)
|
||||
|
||||
tracer_counter += 1
|
||||
if (
|
||||
@@ -299,34 +390,36 @@ def _run_indexing(
|
||||
tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
|
||||
|
||||
run_end_dt = window_end
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
net_docs=net_doc_change,
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
if ctx.is_primary:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
net_docs=net_doc_change,
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
|
||||
)
|
||||
|
||||
if isinstance(e, ConnectorStopSignal):
|
||||
mark_attempt_canceled(
|
||||
index_attempt.id,
|
||||
db_session,
|
||||
reason=str(e),
|
||||
)
|
||||
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
net_docs=net_doc_change,
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
reason=str(e),
|
||||
)
|
||||
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.stop()
|
||||
raise e
|
||||
@@ -340,24 +433,30 @@ def _run_indexing(
|
||||
# to give better clarity in the UI, as the next run will never happen.
|
||||
if (
|
||||
ind == 0
|
||||
or not db_cc_pair.status.is_active()
|
||||
or index_attempt.status != IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt.id,
|
||||
db_session,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
or (
|
||||
cc_pair_loop is not None and not cc_pair_loop.status.is_active()
|
||||
)
|
||||
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
net_docs=net_doc_change,
|
||||
or (
|
||||
index_attempt_loop is not None
|
||||
and index_attempt_loop.status != IndexingStatus.IN_PROGRESS
|
||||
)
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.stop()
|
||||
raise e
|
||||
@@ -379,56 +478,58 @@ def _run_indexing(
|
||||
index_attempt_md.num_exceptions > 0
|
||||
and index_attempt_md.num_exceptions >= batch_num
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt.id,
|
||||
db_session,
|
||||
failure_reason="All batches exceptioned.",
|
||||
)
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=index_attempt.connector_credential_pair.connector.id,
|
||||
credential_id=index_attempt.connector_credential_pair.credential.id,
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
failure_reason="All batches exceptioned.",
|
||||
)
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
)
|
||||
raise Exception(
|
||||
f"Connector failed - All batches exceptioned: batches={batch_num}"
|
||||
)
|
||||
raise Exception(
|
||||
f"Connector failed - All batches exceptioned: batches={batch_num}"
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
if index_attempt_md.num_exceptions == 0:
|
||||
mark_attempt_succeeded(index_attempt, db_session)
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
if index_attempt_md.num_exceptions == 0:
|
||||
mark_attempt_succeeded(index_attempt_id, db_session_temp)
|
||||
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id or "N/A",
|
||||
event_type=MilestoneRecordType.CONNECTOR_SUCCEEDED,
|
||||
properties=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id or "N/A",
|
||||
event_type=MilestoneRecordType.CONNECTOR_SUCCEEDED,
|
||||
properties=None,
|
||||
db_session=db_session_temp,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Connector succeeded: "
|
||||
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
else:
|
||||
mark_attempt_partially_succeeded(index_attempt, db_session)
|
||||
logger.info(
|
||||
f"Connector completed with some errors: "
|
||||
f"exceptions={index_attempt_md.num_exceptions} "
|
||||
f"batches={batch_num} "
|
||||
f"docs={document_count} "
|
||||
f"chunks={chunk_count} "
|
||||
f"elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
logger.info(
|
||||
f"Connector succeeded: "
|
||||
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
else:
|
||||
mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
|
||||
logger.info(
|
||||
f"Connector completed with some errors: "
|
||||
f"exceptions={index_attempt_md.num_exceptions} "
|
||||
f"batches={batch_num} "
|
||||
f"docs={document_count} "
|
||||
f"chunks={chunk_count} "
|
||||
f"elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
|
||||
|
||||
def run_indexing_entrypoint(
|
||||
@@ -448,27 +549,35 @@ def run_indexing_entrypoint(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# TODO: remove long running session entirely
|
||||
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
|
||||
|
||||
tenant_str = ""
|
||||
if tenant_id is not None:
|
||||
tenant_str = f" for tenant {tenant_id}"
|
||||
|
||||
logger.info(
|
||||
f"Indexing starting{tenant_str}: "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
connector_name = attempt.connector_credential_pair.connector.name
|
||||
connector_config = (
|
||||
attempt.connector_credential_pair.connector.connector_specific_config
|
||||
)
|
||||
credential_id = attempt.connector_credential_pair.credential_id
|
||||
|
||||
_run_indexing(db_session, attempt, tenant_id, callback)
|
||||
logger.info(
|
||||
f"Indexing starting{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
f"config='{connector_config}' "
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished{tenant_str}: "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
f"config='{connector_config}' "
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}"
|
||||
|
||||
@@ -12,17 +12,19 @@ from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.prompt_builder.build import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.build import default_build_system_message
|
||||
from onyx.chat.prompt_builder.build import default_build_user_message
|
||||
from onyx.chat.prompt_builder.build import LLMCall
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
CitationResponseHandler,
|
||||
)
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
DummyAnswerResponseHandler,
|
||||
)
|
||||
from onyx.chat.stream_processing.utils import map_document_id_order
|
||||
from onyx.chat.stream_processing.utils import (
|
||||
map_document_id_order,
|
||||
)
|
||||
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
||||
from onyx.file_store.utils import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -206,27 +208,14 @@ class Answer:
|
||||
# + figure out what the next LLM call should be
|
||||
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
|
||||
|
||||
search_result, displayed_search_results_map = SearchTool.get_search_result(
|
||||
final_search_results, displayed_search_results = SearchTool.get_search_result(
|
||||
current_llm_call
|
||||
) or ([], {})
|
||||
) or ([], [])
|
||||
|
||||
# Quotes are no longer supported
|
||||
# answer_handler: AnswerResponseHandler
|
||||
# if self.answer_style_config.citation_config:
|
||||
# answer_handler = CitationResponseHandler(
|
||||
# context_docs=search_result,
|
||||
# doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
# )
|
||||
# elif self.answer_style_config.quotes_config:
|
||||
# answer_handler = QuotesResponseHandler(
|
||||
# context_docs=search_result,
|
||||
# )
|
||||
# else:
|
||||
# raise ValueError("No answer style config provided")
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=search_result,
|
||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
display_doc_order_dict=displayed_search_results_map,
|
||||
context_docs=final_search_results,
|
||||
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
|
||||
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
|
||||
)
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
@@ -263,11 +252,13 @@ class Answer:
|
||||
user_query=self.question,
|
||||
prompt_config=self.prompt_config,
|
||||
files=self.latest_query_files,
|
||||
single_message_history=self.single_message_history,
|
||||
),
|
||||
message_history=self.message_history,
|
||||
llm_config=self.llm.config,
|
||||
raw_user_query=self.question,
|
||||
raw_user_uploaded_files=self.latest_query_files or [],
|
||||
single_message_history=self.single_message_history,
|
||||
raw_user_text=self.question,
|
||||
)
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
|
||||
@@ -7,7 +7,7 @@ from langchain_core.messages import BaseMessage
|
||||
from onyx.chat.models import ResponsePart
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.prompt_builder.build import LLMCall
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
|
||||
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
@@ -261,13 +260,8 @@ class CitationConfig(BaseModel):
|
||||
all_docs_useful: bool = False
|
||||
|
||||
|
||||
class QuotesConfig(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class AnswerStyleConfig(BaseModel):
|
||||
citation_config: CitationConfig | None = None
|
||||
quotes_config: QuotesConfig | None = None
|
||||
citation_config: CitationConfig
|
||||
document_pruning_config: DocumentPruningConfig = Field(
|
||||
default_factory=DocumentPruningConfig
|
||||
)
|
||||
@@ -276,20 +270,6 @@ class AnswerStyleConfig(BaseModel):
|
||||
# right now, only used by the simple chat API
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
|
||||
if self.citation_config is None and self.quotes_config is None:
|
||||
raise ValueError(
|
||||
"One of `citation_config` or `quotes_config` must be provided"
|
||||
)
|
||||
|
||||
if self.citation_config is not None and self.quotes_config is not None:
|
||||
raise ValueError(
|
||||
"Only one of `citation_config` or `quotes_config` must be provided"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
"""Final representation of the Prompt configuration passed
|
||||
|
||||
@@ -302,6 +302,11 @@ def stream_chat_message_objects(
|
||||
enforce_chat_session_id_for_search_docs: bool = True,
|
||||
bypass_acl: bool = False,
|
||||
include_contexts: bool = False,
|
||||
# a string which represents the history of a conversation. Used in cases like
|
||||
# Slack threads where the conversation cannot be represented by a chain of User/Assistant
|
||||
# messages.
|
||||
# NOTE: is not stored in the database at all.
|
||||
single_message_history: str | None = None,
|
||||
) -> ChatPacketStream:
|
||||
"""Streams in order:
|
||||
1. [conditional] Retrieved documents if a search needs to be run
|
||||
@@ -707,6 +712,7 @@ def stream_chat_message_objects(
|
||||
],
|
||||
tools=tools,
|
||||
force_use_tool=_get_force_search_settings(new_msg_req, tools),
|
||||
single_message_history=single_message_history,
|
||||
)
|
||||
|
||||
reference_db_search_docs = None
|
||||
|
||||
@@ -17,6 +17,7 @@ from onyx.llm.utils import check_message_tokens
|
||||
from onyx.llm.utils import message_to_prompt_and_imgs
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from onyx.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from onyx.prompts.prompt_utils import drop_messages_history_overflow
|
||||
from onyx.tools.force import ForceUseTool
|
||||
@@ -42,11 +43,22 @@ def default_build_system_message(
|
||||
|
||||
|
||||
def default_build_user_message(
|
||||
user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = []
|
||||
user_query: str,
|
||||
prompt_config: PromptConfig,
|
||||
files: list[InMemoryChatFile] = [],
|
||||
single_message_history: str | None = None,
|
||||
) -> HumanMessage:
|
||||
history_block = (
|
||||
HISTORY_BLOCK.format(history_str=single_message_history)
|
||||
if single_message_history
|
||||
else ""
|
||||
)
|
||||
|
||||
user_prompt = (
|
||||
CHAT_USER_CONTEXT_FREE_PROMPT.format(
|
||||
task_prompt=prompt_config.task_prompt, user_query=user_query
|
||||
history_block=history_block,
|
||||
task_prompt=prompt_config.task_prompt,
|
||||
user_query=user_query,
|
||||
)
|
||||
if prompt_config.task_prompt
|
||||
else user_query
|
||||
@@ -64,7 +76,8 @@ class AnswerPromptBuilder:
|
||||
user_message: HumanMessage,
|
||||
message_history: list[PreviousMessage],
|
||||
llm_config: LLMConfig,
|
||||
raw_user_text: str,
|
||||
raw_user_query: str,
|
||||
raw_user_uploaded_files: list[InMemoryChatFile],
|
||||
single_message_history: str | None = None,
|
||||
) -> None:
|
||||
self.max_tokens = compute_max_llm_input_tokens(llm_config)
|
||||
@@ -83,10 +96,6 @@ class AnswerPromptBuilder:
|
||||
self.history_token_cnts,
|
||||
) = translate_history_to_basemessages(message_history)
|
||||
|
||||
# for cases where like the QA flow where we want to condense the chat history
|
||||
# into a single message rather than a sequence of User / Assistant messages
|
||||
self.single_message_history = single_message_history
|
||||
|
||||
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
|
||||
self.user_message_and_token_cnt = (
|
||||
user_message,
|
||||
@@ -95,7 +104,10 @@ class AnswerPromptBuilder:
|
||||
|
||||
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
|
||||
|
||||
self.raw_user_message = raw_user_text
|
||||
# used for building a new prompt after a tool-call
|
||||
self.raw_user_query = raw_user_query
|
||||
self.raw_user_uploaded_files = raw_user_uploaded_files
|
||||
self.single_message_history = single_message_history
|
||||
|
||||
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
|
||||
if not system_message:
|
||||
@@ -144,9 +144,7 @@ def build_citations_user_message(
|
||||
)
|
||||
|
||||
history_block = (
|
||||
HISTORY_BLOCK.format(history_str=history_message) + "\n"
|
||||
if history_message
|
||||
else ""
|
||||
HISTORY_BLOCK.format(history_str=history_message) if history_message else ""
|
||||
)
|
||||
query, img_urls = message_to_prompt_and_imgs(message)
|
||||
|
||||
|
||||
@@ -37,22 +37,22 @@ class CitationResponseHandler(AnswerResponseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
display_doc_order_dict: dict[str, int],
|
||||
final_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
display_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
||||
self.display_doc_order_dict = display_doc_order_dict
|
||||
self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
|
||||
self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
|
||||
self.citation_processor = CitationProcessor(
|
||||
context_docs=self.context_docs,
|
||||
doc_id_to_rank_map=self.doc_id_to_rank_map,
|
||||
display_doc_order_dict=self.display_doc_order_dict,
|
||||
final_doc_id_to_rank_map=self.final_doc_id_to_rank_map,
|
||||
display_doc_id_to_rank_map=self.display_doc_id_to_rank_map,
|
||||
)
|
||||
self.processed_text = ""
|
||||
self.citations: list[CitationInfo] = []
|
||||
|
||||
# TODO remove this after citation issue is resolved
|
||||
logger.debug(f"Document to ranking map {self.doc_id_to_rank_map}")
|
||||
logger.debug(f"Document to ranking map {self.final_doc_id_to_rank_map}")
|
||||
|
||||
def handle_response_part(
|
||||
self,
|
||||
|
||||
@@ -21,20 +21,19 @@ class CitationProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
display_doc_order_dict: dict[str, int],
|
||||
final_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
display_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
||||
self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
|
||||
self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
|
||||
self.stop_stream = stop_stream
|
||||
self.order_mapping = doc_id_to_rank_map.order_mapping
|
||||
self.display_doc_order_dict = (
|
||||
display_doc_order_dict # original order of docs to displayed to user
|
||||
)
|
||||
self.final_order_mapping = final_doc_id_to_rank_map.order_mapping
|
||||
self.display_order_mapping = display_doc_id_to_rank_map.order_mapping
|
||||
self.llm_out = ""
|
||||
self.max_citation_num = len(context_docs)
|
||||
self.citation_order: list[int] = []
|
||||
self.citation_order: list[int] = [] # order of citations in the LLM output
|
||||
self.curr_segment = ""
|
||||
self.cited_inds: set[int] = set()
|
||||
self.hold = ""
|
||||
@@ -93,29 +92,31 @@ class CitationProcessor:
|
||||
|
||||
if 1 <= numerical_value <= self.max_citation_num:
|
||||
context_llm_doc = self.context_docs[numerical_value - 1]
|
||||
real_citation_num = self.order_mapping[context_llm_doc.document_id]
|
||||
final_citation_num = self.final_order_mapping[
|
||||
context_llm_doc.document_id
|
||||
]
|
||||
|
||||
if real_citation_num not in self.citation_order:
|
||||
self.citation_order.append(real_citation_num)
|
||||
if final_citation_num not in self.citation_order:
|
||||
self.citation_order.append(final_citation_num)
|
||||
|
||||
target_citation_num = (
|
||||
self.citation_order.index(real_citation_num) + 1
|
||||
citation_order_idx = (
|
||||
self.citation_order.index(final_citation_num) + 1
|
||||
)
|
||||
|
||||
# get the value that was displayed to user, should always
|
||||
# be in the display_doc_order_dict. But check anyways
|
||||
if context_llm_doc.document_id in self.display_doc_order_dict:
|
||||
displayed_citation_num = self.display_doc_order_dict[
|
||||
if context_llm_doc.document_id in self.display_order_mapping:
|
||||
displayed_citation_num = self.display_order_mapping[
|
||||
context_llm_doc.document_id
|
||||
]
|
||||
else:
|
||||
displayed_citation_num = real_citation_num
|
||||
displayed_citation_num = final_citation_num
|
||||
logger.warning(
|
||||
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
|
||||
)
|
||||
|
||||
# Skip consecutive citations of the same work
|
||||
if target_citation_num in self.current_citations:
|
||||
if final_citation_num in self.current_citations:
|
||||
start, end = citation.span()
|
||||
real_start = length_to_add + start
|
||||
diff = end - start
|
||||
@@ -134,8 +135,8 @@ class CitationProcessor:
|
||||
doc_id = int(match.group(1))
|
||||
context_llm_doc = self.context_docs[doc_id - 1]
|
||||
yield CitationInfo(
|
||||
# stay with the original for now (order of LLM cites)
|
||||
citation_num=target_citation_num,
|
||||
# citation_num is now the number post initial ranking, i.e. as displayed to user
|
||||
citation_num=displayed_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -151,13 +152,13 @@ class CitationProcessor:
|
||||
link = context_llm_doc.link
|
||||
|
||||
self.past_cite_count = len(self.llm_out)
|
||||
self.current_citations.append(target_citation_num)
|
||||
self.current_citations.append(final_citation_num)
|
||||
|
||||
if target_citation_num not in self.cited_inds:
|
||||
self.cited_inds.add(target_citation_num)
|
||||
if citation_order_idx not in self.cited_inds:
|
||||
self.cited_inds.add(citation_order_idx)
|
||||
yield CitationInfo(
|
||||
# stay with the original for now (order of LLM cites)
|
||||
citation_num=target_citation_num,
|
||||
# citation number is now the one that was displayed to user
|
||||
citation_num=displayed_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
|
||||
@@ -167,7 +168,6 @@ class CitationProcessor:
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
|
||||
# + f"[[{target_citation_num}]]({link})"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
@@ -176,7 +176,6 @@ class CitationProcessor:
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
|
||||
# + f"[[{target_citation_num}]]()"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
|
||||
@@ -5,7 +5,7 @@ from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import ToolCall
|
||||
|
||||
from onyx.chat.models import ResponsePart
|
||||
from onyx.chat.prompt_builder.build import LLMCall
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.message import build_tool_message
|
||||
@@ -62,7 +62,7 @@ class ToolResponseHandler:
|
||||
llm_call.force_use_tool.args
|
||||
if llm_call.force_use_tool.args is not None
|
||||
else tool.get_args_for_non_tool_calling_llm(
|
||||
query=llm_call.prompt_builder.raw_user_message,
|
||||
query=llm_call.prompt_builder.raw_user_query,
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
llm=llm,
|
||||
force_run=True,
|
||||
@@ -76,7 +76,7 @@ class ToolResponseHandler:
|
||||
else:
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools=llm_call.tools,
|
||||
query=llm_call.prompt_builder.raw_user_message,
|
||||
query=llm_call.prompt_builder.raw_user_query,
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -95,7 +95,7 @@ class ToolResponseHandler:
|
||||
select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args=available_tools_and_args,
|
||||
history=llm_call.prompt_builder.raw_message_history,
|
||||
query=llm_call.prompt_builder.raw_user_message,
|
||||
query=llm_call.prompt_builder.raw_user_query,
|
||||
llm=llm,
|
||||
)
|
||||
if available_tools_and_args
|
||||
|
||||
@@ -17,6 +17,7 @@ APP_PORT = 8080
|
||||
# prefix from requests directed towards the API server. In these cases, set this to `/api`
|
||||
APP_API_PREFIX = os.environ.get("API_PREFIX", "")
|
||||
|
||||
SKIP_WARM_UP = os.environ.get("SKIP_WARM_UP", "").lower() == "true"
|
||||
|
||||
#####
|
||||
# User Facing Features Configs
|
||||
@@ -54,10 +55,17 @@ MASK_CREDENTIAL_PREFIX = (
|
||||
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
|
||||
)
|
||||
|
||||
REDIS_AUTH_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
||||
SESSION_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
||||
# Default request timeout, mostly used by connectors
|
||||
REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS") or 60)
|
||||
|
||||
# set `VALID_EMAIL_DOMAINS` to a comma seperated list of domains in order to
|
||||
# restrict access to Onyx to only users with emails from those domains.
|
||||
# E.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will restrict Onyx
|
||||
@@ -92,6 +100,7 @@ SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
|
||||
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
|
||||
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
|
||||
EMAIL_CONFIGURED = all([SMTP_SERVER, SMTP_USER, SMTP_PASS])
|
||||
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
|
||||
|
||||
# If set, Onyx will listen to the `expires_at` returned by the identity
|
||||
@@ -145,7 +154,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
|
||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
AWS_REGION = os.environ.get("AWS_REGION") or "us-east-2"
|
||||
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
|
||||
|
||||
POSTGRES_API_SERVER_POOL_SIZE = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
|
||||
@@ -184,6 +193,27 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
|
||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
||||
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
||||
|
||||
|
||||
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"
|
||||
|
||||
# Rate limiting for auth endpoints
|
||||
RATE_LIMIT_WINDOW_SECONDS: int | None = None
|
||||
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
|
||||
if _rate_limit_window_seconds_str is not None:
|
||||
try:
|
||||
RATE_LIMIT_WINDOW_SECONDS = int(_rate_limit_window_seconds_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
RATE_LIMIT_MAX_REQUESTS: int | None = None
|
||||
_rate_limit_max_requests_str = os.environ.get("RATE_LIMIT_MAX_REQUESTS")
|
||||
if _rate_limit_max_requests_str is not None:
|
||||
try:
|
||||
RATE_LIMIT_MAX_REQUESTS = int(_rate_limit_max_requests_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
AUTH_RATE_LIMITING_ENABLED = RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS
|
||||
# Used for general redis things
|
||||
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
|
||||
|
||||
@@ -251,6 +281,11 @@ try:
|
||||
except ValueError:
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY = CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT
|
||||
|
||||
# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
|
||||
VESPA_SYNC_MAX_TASKS = 1024
|
||||
|
||||
DB_YIELD_PER_DEFAULT = 64
|
||||
|
||||
#####
|
||||
# Connector Configs
|
||||
#####
|
||||
@@ -347,12 +382,17 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
|
||||
os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Typically set to http://localhost:3000 for OAuth connector development
|
||||
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")
|
||||
|
||||
# Egnyte specific configs
|
||||
EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
|
||||
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
|
||||
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
|
||||
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
|
||||
|
||||
# Linear specific configs
|
||||
LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
|
||||
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
|
||||
|
||||
DASK_JOB_CLIENT_ENABLED = (
|
||||
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
@@ -503,6 +543,9 @@ try:
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# LLM Model Update API endpoint
|
||||
LLM_MODEL_UPDATE_API_URL = os.environ.get("LLM_MODEL_UPDATE_API_URL")
|
||||
|
||||
#####
|
||||
# Enterprise Edition Configs
|
||||
#####
|
||||
@@ -542,7 +585,6 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
|
||||
# JWT configuration
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
|
||||
#####
|
||||
# API Key Configs
|
||||
#####
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
|
||||
|
||||
INPUT_PROMPT_YAML = "./onyx/seeding/input_prompts.yaml"
|
||||
PROMPTS_YAML = "./onyx/seeding/prompts.yaml"
|
||||
PERSONAS_YAML = "./onyx/seeding/personas.yaml"
|
||||
|
||||
|
||||
@@ -36,6 +36,8 @@ DISABLED_GEN_AI_MSG = (
|
||||
|
||||
DEFAULT_PERSONA_ID = 0
|
||||
|
||||
DEFAULT_CC_PAIR_ID = 1
|
||||
|
||||
# Postgres connection constants for application_name
|
||||
POSTGRES_WEB_APP_NAME = "web"
|
||||
POSTGRES_INDEXER_APP_NAME = "indexer"
|
||||
@@ -45,6 +47,7 @@ POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
|
||||
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
|
||||
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
|
||||
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
||||
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
||||
@@ -74,13 +77,19 @@ KV_ENTERPRISE_SETTINGS_KEY = "onyx_enterprise_settings"
|
||||
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
|
||||
KV_DOCUMENTS_SEEDED_KEY = "documents_seeded"
|
||||
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
|
||||
# NOTE: we use this timeout / 4 in various places to refresh a lock
|
||||
# might be worth separating this timeout into separate timeouts for each situation
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
|
||||
|
||||
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
|
||||
|
||||
# needs to be long enough to cover the maximum time it takes to download an object
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 60 min
|
||||
|
||||
# how long a task should wait for associated fence to be ready
|
||||
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
|
||||
|
||||
# needs to be long enough to cover the maximum time it takes to download an object
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
|
||||
@@ -134,9 +143,11 @@ class DocumentSource(str, Enum):
|
||||
OCI_STORAGE = "oci_storage"
|
||||
XENFORO = "xenforo"
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
DISCORD = "discord"
|
||||
FRESHDESK = "freshdesk"
|
||||
FIREFLIES = "fireflies"
|
||||
EGNYTE = "egnyte"
|
||||
AIRTABLE = "airtable"
|
||||
|
||||
|
||||
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
|
||||
@@ -240,6 +251,7 @@ class OnyxCeleryQueues:
|
||||
VESPA_METADATA_SYNC = "vespa_metadata_sync"
|
||||
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
|
||||
CONNECTOR_DELETION = "connector_deletion"
|
||||
LLM_MODEL_UPDATE = "llm_model_update"
|
||||
|
||||
# Heavy queue
|
||||
CONNECTOR_PRUNING = "connector_pruning"
|
||||
@@ -249,6 +261,9 @@ class OnyxCeleryQueues:
|
||||
# Indexing queue
|
||||
CONNECTOR_INDEXING = "connector_indexing"
|
||||
|
||||
# Monitoring queue
|
||||
MONITORING = "monitoring"
|
||||
|
||||
|
||||
class OnyxRedisLocks:
|
||||
PRIMARY_WORKER = "da_lock:primary_worker"
|
||||
@@ -263,6 +278,7 @@ class OnyxRedisLocks:
|
||||
"da_lock:check_connector_external_group_sync_beat"
|
||||
)
|
||||
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
|
||||
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
|
||||
|
||||
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
|
||||
"da_lock:connector_doc_permissions_sync"
|
||||
@@ -273,6 +289,7 @@ class OnyxRedisLocks:
|
||||
|
||||
SLACK_BOT_LOCK = "da_lock:slack_bot"
|
||||
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
|
||||
ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
@@ -294,7 +311,9 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_PRUNING = "check_for_pruning"
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
|
||||
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
|
||||
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
|
||||
289
backend/onyx/connectors/airtable/airtable_connector.py
Normal file
289
backend/onyx/connectors/airtable/airtable_connector.py
Normal file
@@ -0,0 +1,289 @@
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from pyairtable import Api as AirtableApi
|
||||
from pyairtable.api.types import RecordDict
|
||||
from pyairtable.models.schema import TableSchema
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# NOTE: all are made lowercase to avoid case sensitivity issues
|
||||
# these are the field types that are considered metadata rather
|
||||
# than sections
|
||||
_METADATA_FIELD_TYPES = {
|
||||
"singlecollaborator",
|
||||
"collaborator",
|
||||
"createdby",
|
||||
"singleselect",
|
||||
"multipleselects",
|
||||
"checkbox",
|
||||
"date",
|
||||
"datetime",
|
||||
"email",
|
||||
"phone",
|
||||
"url",
|
||||
"number",
|
||||
"currency",
|
||||
"duration",
|
||||
"percent",
|
||||
"rating",
|
||||
"createdtime",
|
||||
"lastmodifiedtime",
|
||||
"autonumber",
|
||||
"rollup",
|
||||
"lookup",
|
||||
"count",
|
||||
"formula",
|
||||
"date",
|
||||
}
|
||||
|
||||
|
||||
class AirtableClientNotSetUpError(PermissionError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Airtable Client is not set up, was load_credentials called?")
|
||||
|
||||
|
||||
class AirtableConnector(LoadConnector):
|
||||
def __init__(
|
||||
self,
|
||||
base_id: str,
|
||||
table_name_or_id: str,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.base_id = base_id
|
||||
self.table_name_or_id = table_name_or_id
|
||||
self.batch_size = batch_size
|
||||
self.airtable_client: AirtableApi | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.airtable_client = AirtableApi(credentials["airtable_access_token"])
|
||||
return None
|
||||
|
||||
def _get_field_value(self, field_info: Any, field_type: str) -> list[str]:
|
||||
"""
|
||||
Extract value(s) from a field regardless of its type.
|
||||
Returns either a single string or list of strings for attachments.
|
||||
"""
|
||||
if field_info is None:
|
||||
return []
|
||||
|
||||
# skip references to other records for now (would need to do another
|
||||
# request to get the actual record name/type)
|
||||
# TODO: support this
|
||||
if field_type == "multipleRecordLinks":
|
||||
return []
|
||||
|
||||
if field_type == "multipleAttachments":
|
||||
attachment_texts: list[str] = []
|
||||
for attachment in field_info:
|
||||
url = attachment.get("url")
|
||||
filename = attachment.get("filename", "")
|
||||
if not url:
|
||||
continue
|
||||
|
||||
@retry(
|
||||
tries=5,
|
||||
delay=1,
|
||||
backoff=2,
|
||||
max_delay=10,
|
||||
)
|
||||
def get_attachment_with_retry(url: str) -> bytes | None:
|
||||
attachment_response = requests.get(url)
|
||||
if attachment_response.status_code == 200:
|
||||
return attachment_response.content
|
||||
return None
|
||||
|
||||
attachment_content = get_attachment_with_retry(url)
|
||||
if attachment_content:
|
||||
try:
|
||||
file_ext = get_file_ext(filename)
|
||||
attachment_text = extract_file_text(
|
||||
BytesIO(attachment_content),
|
||||
filename,
|
||||
break_on_unprocessable=False,
|
||||
extension=file_ext,
|
||||
)
|
||||
if attachment_text:
|
||||
attachment_texts.append(f"{filename}:\n{attachment_text}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to process attachment {filename}: {str(e)}"
|
||||
)
|
||||
return attachment_texts
|
||||
|
||||
if field_type in ["singleCollaborator", "collaborator", "createdBy"]:
|
||||
combined = []
|
||||
collab_name = field_info.get("name")
|
||||
collab_email = field_info.get("email")
|
||||
if collab_name:
|
||||
combined.append(collab_name)
|
||||
if collab_email:
|
||||
combined.append(f"({collab_email})")
|
||||
return [" ".join(combined) if combined else str(field_info)]
|
||||
|
||||
if isinstance(field_info, list):
|
||||
return [str(item) for item in field_info]
|
||||
|
||||
return [str(field_info)]
|
||||
|
||||
def _should_be_metadata(self, field_type: str) -> bool:
|
||||
"""Determine if a field type should be treated as metadata."""
|
||||
return field_type.lower() in _METADATA_FIELD_TYPES
|
||||
|
||||
def _process_field(
|
||||
self,
|
||||
field_name: str,
|
||||
field_info: Any,
|
||||
field_type: str,
|
||||
table_id: str,
|
||||
record_id: str,
|
||||
) -> tuple[list[Section], dict[str, Any]]:
|
||||
"""
|
||||
Process a single Airtable field and return sections or metadata.
|
||||
|
||||
Args:
|
||||
field_name: Name of the field
|
||||
field_info: Raw field information from Airtable
|
||||
field_type: Airtable field type
|
||||
|
||||
Returns:
|
||||
(list of Sections, dict of metadata)
|
||||
"""
|
||||
if field_info is None:
|
||||
return [], {}
|
||||
|
||||
# Get the value(s) for the field
|
||||
field_values = self._get_field_value(field_info, field_type)
|
||||
if len(field_values) == 0:
|
||||
return [], {}
|
||||
|
||||
# Determine if it should be metadata or a section
|
||||
if self._should_be_metadata(field_type):
|
||||
if len(field_values) > 1:
|
||||
return [], {field_name: field_values}
|
||||
return [], {field_name: field_values[0]}
|
||||
|
||||
# Otherwise, create relevant sections
|
||||
sections = [
|
||||
Section(
|
||||
link=f"https://airtable.com/{self.base_id}/{table_id}/{record_id}",
|
||||
text=(
|
||||
f"{field_name}:\n"
|
||||
"------------------------\n"
|
||||
f"{text}\n"
|
||||
"------------------------"
|
||||
),
|
||||
)
|
||||
for text in field_values
|
||||
]
|
||||
return sections, {}
|
||||
|
||||
def _process_record(
|
||||
self,
|
||||
record: RecordDict,
|
||||
table_schema: TableSchema,
|
||||
primary_field_name: str | None,
|
||||
) -> Document:
|
||||
"""Process a single Airtable record into a Document.
|
||||
|
||||
Args:
|
||||
record: The Airtable record to process
|
||||
table_schema: Schema information for the table
|
||||
table_name: Name of the table
|
||||
table_id: ID of the table
|
||||
primary_field_name: Name of the primary field, if any
|
||||
|
||||
Returns:
|
||||
Document object representing the record
|
||||
"""
|
||||
table_id = table_schema.id
|
||||
table_name = table_schema.name
|
||||
record_id = record["id"]
|
||||
fields = record["fields"]
|
||||
sections: list[Section] = []
|
||||
metadata: dict[str, Any] = {}
|
||||
|
||||
# Get primary field value if it exists
|
||||
primary_field_value = (
|
||||
fields.get(primary_field_name) if primary_field_name else None
|
||||
)
|
||||
|
||||
for field_schema in table_schema.fields:
|
||||
field_name = field_schema.name
|
||||
field_val = fields.get(field_name)
|
||||
field_type = field_schema.type
|
||||
|
||||
field_sections, field_metadata = self._process_field(
|
||||
field_name=field_name,
|
||||
field_info=field_val,
|
||||
field_type=field_type,
|
||||
table_id=table_id,
|
||||
record_id=record_id,
|
||||
)
|
||||
|
||||
sections.extend(field_sections)
|
||||
metadata.update(field_metadata)
|
||||
|
||||
semantic_id = (
|
||||
f"{table_name}: {primary_field_value}"
|
||||
if primary_field_value
|
||||
else table_name
|
||||
)
|
||||
|
||||
return Document(
|
||||
id=f"airtable__{record_id}",
|
||||
sections=sections,
|
||||
source=DocumentSource.AIRTABLE,
|
||||
semantic_identifier=semantic_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Fetch all records from the table.
|
||||
|
||||
NOTE: Airtable does not support filtering by time updated, so
|
||||
we have to fetch all records every time.
|
||||
"""
|
||||
if not self.airtable_client:
|
||||
raise AirtableClientNotSetUpError()
|
||||
|
||||
table = self.airtable_client.table(self.base_id, self.table_name_or_id)
|
||||
records = table.all()
|
||||
|
||||
table_schema = table.schema()
|
||||
primary_field_name = None
|
||||
|
||||
# Find a primary field from the schema
|
||||
for field in table_schema.fields:
|
||||
if field.id == table_schema.primary_field_id:
|
||||
primary_field_name = field.name
|
||||
break
|
||||
|
||||
record_documents: list[Document] = []
|
||||
for record in records:
|
||||
document = self._process_record(
|
||||
record=record,
|
||||
table_schema=table_schema,
|
||||
primary_field_name=primary_field_name,
|
||||
)
|
||||
record_documents.append(document)
|
||||
|
||||
if len(record_documents) >= self.batch_size:
|
||||
yield record_documents
|
||||
record_documents = []
|
||||
|
||||
if record_documents:
|
||||
yield record_documents
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user