mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 15:55:45 +00:00
Compare commits
137 Commits
folder
...
my_documen
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
da86610022 | ||
|
|
0027759dbf | ||
|
|
595ef152d2 | ||
|
|
083d669d1b | ||
|
|
3ac31136b2 | ||
|
|
a73a438d95 | ||
|
|
c0770481e8 | ||
|
|
c27d13c07f | ||
|
|
ab34c4e772 | ||
|
|
66f9124135 | ||
|
|
8f0fb70bbf | ||
|
|
ef5e5c80bb | ||
|
|
03acb6587a | ||
|
|
d1ec72b5e5 | ||
|
|
3b214133a8 | ||
|
|
2232702e99 | ||
|
|
8108ff0a4b | ||
|
|
f64e78e986 | ||
|
|
08312a4394 | ||
|
|
92add655e0 | ||
|
|
d64464ca7c | ||
|
|
ccd3983802 | ||
|
|
240f3e4fff | ||
|
|
1291b3d930 | ||
|
|
d05f1997b5 | ||
|
|
aa2e2a62b9 | ||
|
|
174e5968f8 | ||
|
|
1f27606e17 | ||
|
|
60355b84c1 | ||
|
|
680ab9ea30 | ||
|
|
c2447dbb1c | ||
|
|
52bad522f8 | ||
|
|
63e5e58313 | ||
|
|
2643782e30 | ||
|
|
3eb72e5c1d | ||
|
|
9b65c23a7e | ||
|
|
b43a8e48c6 | ||
|
|
1955c1d67b | ||
|
|
3f92ed9d29 | ||
|
|
618369f4a1 | ||
|
|
2783216781 | ||
|
|
bec0f9fb23 | ||
|
|
97a03e7fc8 | ||
|
|
8d6e8269b7 | ||
|
|
9ce2c6c517 | ||
|
|
2ad8bdbc65 | ||
|
|
a83c9b40d5 | ||
|
|
340fab1375 | ||
|
|
3ec338307f | ||
|
|
27acd3387a | ||
|
|
d14ef431a7 | ||
|
|
9bffeb65af | ||
|
|
f4806da653 | ||
|
|
e2700b2bbd | ||
|
|
fc81a3fb12 | ||
|
|
2203cfabea | ||
|
|
f4050306d6 | ||
|
|
2d960a477f | ||
|
|
8837b8ea79 | ||
|
|
3dfb214f73 | ||
|
|
18d7262608 | ||
|
|
09b879ee73 | ||
|
|
aaa668c963 | ||
|
|
edb877f4bc | ||
|
|
eb369caefb | ||
|
|
b9567eabd7 | ||
|
|
13bbf67091 | ||
|
|
457a4c73f0 | ||
|
|
ce37688b5b | ||
|
|
4e2c90f4af | ||
|
|
513dd8a319 | ||
|
|
71c5043832 | ||
|
|
64b6f15e95 | ||
|
|
35022f5f09 | ||
|
|
0d44014c16 | ||
|
|
1b9e9f48fa | ||
|
|
05fb5aa27b | ||
|
|
3b645b72a3 | ||
|
|
fe770b5c3a | ||
|
|
1eaf885f50 | ||
|
|
a187aa508c | ||
|
|
aa4bfa2a78 | ||
|
|
9011b8a139 | ||
|
|
59c774353a | ||
|
|
b458d504af | ||
|
|
f83e7bfcd9 | ||
|
|
4d2e26ce4b | ||
|
|
817fdc1f36 | ||
|
|
e9b10e8b41 | ||
|
|
a0fa4adb60 | ||
|
|
ca9ba925bd | ||
|
|
833cc5c97c | ||
|
|
23ecf654ed | ||
|
|
ddc6a6d2b3 | ||
|
|
571c8ece32 | ||
|
|
884bdb4b01 | ||
|
|
b3ecf0d59f | ||
|
|
f56fda27c9 | ||
|
|
b1e4d4ea8d | ||
|
|
8db6d49fe5 | ||
|
|
28598694b1 | ||
|
|
b5d0df90b9 | ||
|
|
48be6338ec | ||
|
|
ed9014f03d | ||
|
|
2dd51230ed | ||
|
|
8b249cbe63 | ||
|
|
6b50f86cd2 | ||
|
|
bd2805b6df | ||
|
|
2847ab003e | ||
|
|
1df6a506ec | ||
|
|
f1541d1fbe | ||
|
|
dd0c4b64df | ||
|
|
788b3015bc | ||
|
|
cbbf10f450 | ||
|
|
d954914a0a | ||
|
|
bee74ac360 | ||
|
|
29ef64272a | ||
|
|
01bf6ee4b7 | ||
|
|
0502417cbe | ||
|
|
d0483dd269 | ||
|
|
eefa872d60 | ||
|
|
3f3d4da611 | ||
|
|
469068052e | ||
|
|
9032b05606 | ||
|
|
334bc6be8c | ||
|
|
814f97c2c7 | ||
|
|
4f5a2b47c4 | ||
|
|
f545508268 | ||
|
|
590986ec65 | ||
|
|
531bab5409 | ||
|
|
29c44007c4 | ||
|
|
d388643a04 | ||
|
|
8a422683e3 | ||
|
|
ddc0230d68 | ||
|
|
6711e91dbf | ||
|
|
cff2346db5 | ||
|
|
8d3fad1f12 |
18
.github/pull_request_template.md
vendored
18
.github/pull_request_template.md
vendored
@@ -6,24 +6,6 @@
|
||||
[Describe the tests you ran to verify your changes]
|
||||
|
||||
|
||||
## Accepted Risk (provide if relevant)
|
||||
N/A
|
||||
|
||||
|
||||
## Related Issue(s) (provide if relevant)
|
||||
N/A
|
||||
|
||||
|
||||
## Mental Checklist:
|
||||
- All of the automated tests pass
|
||||
- All PR comments are addressed and marked resolved
|
||||
- If there are migrations, they have been rebased to latest main
|
||||
- If there are new dependencies, they are added to the requirements
|
||||
- If there are new environment variables, they are added to all of the deployment methods
|
||||
- If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
|
||||
- Docker images build and basic functionalities work
|
||||
- Author has done a final read through of the PR right before merge
|
||||
|
||||
## Backporting (check the box to trigger backport action)
|
||||
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
|
||||
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
|
||||
|
||||
@@ -66,6 +66,7 @@ jobs:
|
||||
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
NEXT_PUBLIC_GTM_ENABLED=true
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
@@ -8,18 +8,29 @@ on:
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
DOCKER_BUILDKIT: 1
|
||||
BUILDKIT_PROGRESS: plain
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
|
||||
build-amd64:
|
||||
runs-on:
|
||||
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-amd64"]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: System Info
|
||||
run: |
|
||||
df -h
|
||||
free -h
|
||||
docker system prune -af --volumes
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver-opts: |
|
||||
image=moby/buildkit:latest
|
||||
network=host
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
@@ -27,24 +38,80 @@ jobs:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Model Server Image Docker Build and Push
|
||||
- name: Build and Push AMD64
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64,linux/arm64
|
||||
platforms: linux/amd64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
tags: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
|
||||
build-arm64:
|
||||
runs-on:
|
||||
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: System Info
|
||||
run: |
|
||||
df -h
|
||||
free -h
|
||||
docker system prune -af --volumes
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver-opts: |
|
||||
image=moby/buildkit:latest
|
||||
network=host
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and Push ARM64
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/arm64
|
||||
push: true
|
||||
tags: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
|
||||
merge-and-scan:
|
||||
needs: [build-amd64, build-arm64]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create and Push Multi-arch Manifest
|
||||
run: |
|
||||
docker buildx create --use
|
||||
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
if [[ "${{ env.LATEST_TAG }}" == "true" ]]; then
|
||||
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:latest \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
fi
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
@@ -53,3 +120,4 @@ jobs:
|
||||
with:
|
||||
image-ref: docker.io/onyxdotapp/onyx-model-server:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
timeout: "10m"
|
||||
|
||||
14
.github/workflows/pr-chromatic-tests.yml
vendored
14
.github/workflows/pr-chromatic-tests.yml
vendored
@@ -15,7 +15,12 @@ jobs:
|
||||
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on:
|
||||
[runs-on, runner=8cpu-linux-x64, ram=16, "run-id=${{ github.run_id }}"]
|
||||
[
|
||||
runs-on,
|
||||
runner=32cpu-linux-x64,
|
||||
disk=large,
|
||||
"run-id=${{ github.run_id }}",
|
||||
]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
@@ -196,7 +201,12 @@ jobs:
|
||||
|
||||
needs: playwright-tests
|
||||
runs-on:
|
||||
[runs-on, runner=8cpu-linux-x64, ram=16, "run-id=${{ github.run_id }}"]
|
||||
[
|
||||
runs-on,
|
||||
runner=32cpu-linux-x64,
|
||||
disk=large,
|
||||
"run-id=${{ github.run_id }}",
|
||||
]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
3
.github/workflows/pr-integration-tests.yml
vendored
3
.github/workflows/pr-integration-tests.yml
vendored
@@ -20,8 +20,7 @@ env:
|
||||
jobs:
|
||||
integration-tests:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on:
|
||||
[runs-on, runner=8cpu-linux-x64, ram=16, "run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
10
.github/workflows/pr-python-connector-tests.yml
vendored
10
.github/workflows/pr-python-connector-tests.yml
vendored
@@ -26,7 +26,15 @@ env:
|
||||
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
# Slab
|
||||
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
|
||||
|
||||
# Salesforce
|
||||
SF_USERNAME: ${{ secrets.SF_USERNAME }}
|
||||
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
|
||||
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
|
||||
# Airtable
|
||||
AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
|
||||
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
|
||||
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
|
||||
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
|
||||
jobs:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
|
||||
18
README.md
18
README.md
@@ -3,7 +3,7 @@
|
||||
<a name="readme-top"></a>
|
||||
|
||||
<h2 align="center">
|
||||
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/LogoOnyx.png?raw=true)" /></a>
|
||||
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/OnyxLogoCropped.jpg?raw=true)" /></a>
|
||||
</h2>
|
||||
|
||||
<p align="center">
|
||||
@@ -13,7 +13,7 @@
|
||||
<a href="https://docs.onyx.app/" target="_blank">
|
||||
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
|
||||
</a>
|
||||
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ" target="_blank">
|
||||
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
|
||||
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
|
||||
</a>
|
||||
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
|
||||
@@ -24,7 +24,7 @@
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<strong>[Onyx](https://www.onyx.app/)</strong> (Formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
|
||||
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
|
||||
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
|
||||
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
|
||||
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
|
||||
@@ -133,15 +133,3 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md
|
||||
## ⭐Star History
|
||||
|
||||
[](https://star-history.com/#onyx-dot-app/onyx&Date)
|
||||
|
||||
## ✨Contributors
|
||||
|
||||
<a href="https://github.com/onyx-dot-app/onyx/graphs/contributors">
|
||||
<img alt="contributors" src="https://contrib.rocks/image?repo=onyx-dot-app/onyx"/>
|
||||
</a>
|
||||
|
||||
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
|
||||
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
|
||||
↑ Back to Top ↑
|
||||
</a>
|
||||
</p>
|
||||
|
||||
1
backend/.gitignore
vendored
1
backend/.gitignore
vendored
@@ -9,3 +9,4 @@ api_keys.py
|
||||
vespa-app.zip
|
||||
dynamic_config_storage/
|
||||
celerybeat-schedule*
|
||||
onyx/connectors/salesforce/data/
|
||||
@@ -1,39 +1,49 @@
|
||||
from typing import Any, Literal
|
||||
from onyx.db.engine import get_iam_auth_token
|
||||
from onyx.configs.app_configs import USE_IAM_AUTH
|
||||
from onyx.configs.app_configs import POSTGRES_HOST
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.app_configs import AWS_REGION_NAME
|
||||
from onyx.db.engine import build_connection_string
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from typing import Literal
|
||||
import os
|
||||
import ssl
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
import logging
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql import text
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from onyx.db.engine import build_connection_string
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
|
||||
from onyx.db.models import Base
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
# Alembic Config object
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
if config.config_file_name is not None and config.attributes.get(
|
||||
"configure_logger", True
|
||||
):
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# Add your model's MetaData object here for 'autogenerate' support
|
||||
target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ssl_context: ssl.SSLContext | None = None
|
||||
if USE_IAM_AUTH:
|
||||
if not os.path.exists(SSL_CERT_FILE):
|
||||
raise FileNotFoundError(f"Expected {SSL_CERT_FILE} when USE_IAM_AUTH is true.")
|
||||
ssl_context = ssl.create_default_context(cafile=SSL_CERT_FILE)
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
@@ -49,20 +59,12 @@ def include_object(
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
"""
|
||||
Determines whether a database object should be included in migrations.
|
||||
Excludes specified tables from migrations.
|
||||
"""
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_schema_options() -> tuple[str, bool, bool]:
|
||||
"""
|
||||
Parses command-line options passed via '-x' in Alembic commands.
|
||||
Recognizes 'schema', 'create_schema', and 'upgrade_all_tenants' options.
|
||||
"""
|
||||
x_args_raw = context.get_x_argument()
|
||||
x_args = {}
|
||||
for arg in x_args_raw:
|
||||
@@ -90,16 +92,12 @@ def get_schema_options() -> tuple[str, bool, bool]:
|
||||
def do_run_migrations(
|
||||
connection: Connection, schema_name: str, create_schema: bool
|
||||
) -> None:
|
||||
"""
|
||||
Executes migrations in the specified schema.
|
||||
"""
|
||||
logger.info(f"About to migrate schema: {schema_name}")
|
||||
|
||||
if create_schema:
|
||||
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
|
||||
connection.execute(text("COMMIT"))
|
||||
|
||||
# Set search_path to the target schema
|
||||
connection.execute(text(f'SET search_path TO "{schema_name}"'))
|
||||
|
||||
context.configure(
|
||||
@@ -117,11 +115,25 @@ def do_run_migrations(
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def provide_iam_token_for_alembic(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
|
||||
) -> None:
|
||||
if USE_IAM_AUTH:
|
||||
# Database connection settings
|
||||
region = AWS_REGION_NAME
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
|
||||
# Get IAM authentication token
|
||||
token = get_iam_auth_token(host, port, user, region)
|
||||
|
||||
# For Alembic / SQLAlchemy in this context, set SSL and password
|
||||
cparams["password"] = token
|
||||
cparams["ssl"] = ssl_context
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""
|
||||
Determines whether to run migrations for a single schema or all schemas,
|
||||
and executes migrations accordingly.
|
||||
"""
|
||||
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
|
||||
|
||||
engine = create_async_engine(
|
||||
@@ -129,10 +141,16 @@ async def run_async_migrations() -> None:
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
if upgrade_all_tenants:
|
||||
# Run migrations for all tenant schemas sequentially
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
if USE_IAM_AUTH:
|
||||
|
||||
@event.listens_for(engine.sync_engine, "do_connect")
|
||||
def event_provide_iam_token_for_alembic(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
|
||||
) -> None:
|
||||
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
|
||||
|
||||
if upgrade_all_tenants:
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
for schema in tenant_schemas:
|
||||
try:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
@@ -162,15 +180,20 @@ async def run_async_migrations() -> None:
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""
|
||||
Run migrations in 'offline' mode.
|
||||
"""
|
||||
schema_name, _, upgrade_all_tenants = get_schema_options()
|
||||
url = build_connection_string()
|
||||
|
||||
if upgrade_all_tenants:
|
||||
# Run offline migrations for all tenant schemas
|
||||
engine = create_async_engine(url)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
|
||||
@event.listens_for(engine.sync_engine, "do_connect")
|
||||
def event_provide_iam_token_for_alembic_offline(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
|
||||
) -> None:
|
||||
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
|
||||
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
engine.sync_engine.dispose()
|
||||
|
||||
@@ -207,9 +230,6 @@ def run_migrations_offline() -> None:
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""
|
||||
Runs migrations in 'online' mode using an asynchronous engine.
|
||||
"""
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
|
||||
129
backend/alembic/versions/25d86cbfce78_add_my_documents.py
Normal file
129
backend/alembic/versions/25d86cbfce78_add_my_documents.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import datetime
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "25d86cbfce78"
|
||||
down_revision = "c0aab6edb6dd"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create user_folder table with additional 'display_priority' field
|
||||
op.create_table(
|
||||
"user_folder",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
|
||||
sa.Column(
|
||||
"parent_id", sa.Integer(), sa.ForeignKey("user_folder.id"), nullable=True
|
||||
),
|
||||
sa.Column("name", sa.String(length=255), nullable=True),
|
||||
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
|
||||
sa.Column("created_at", sa.DateTime(), default=datetime.datetime.utcnow),
|
||||
)
|
||||
|
||||
# Migrate data from chat_folder to user_folder
|
||||
op.execute(
|
||||
"""
|
||||
INSERT INTO user_folder (id, user_id, name, display_priority, created_at)
|
||||
SELECT id, user_id, name, display_priority, CURRENT_TIMESTAMP FROM chat_folder
|
||||
"""
|
||||
)
|
||||
|
||||
# Update chat_session table to reference user_folder instead of chat_folder
|
||||
op.drop_constraint(
|
||||
"chat_session_chat_folder_fk", "chat_session", type_="foreignkey"
|
||||
)
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"folder_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=True,
|
||||
existing_nullable=True,
|
||||
existing_server_default=None,
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_chat_session_folder_id_user_folder",
|
||||
"chat_session",
|
||||
"user_folder",
|
||||
["folder_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
# Drop the chat_folder table
|
||||
op.drop_table("chat_folder")
|
||||
|
||||
# Create user_file table
|
||||
op.create_table(
|
||||
"user_file",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
|
||||
sa.Column(
|
||||
"parent_folder_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("user_folder.id"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("file_type", sa.String(), nullable=True),
|
||||
sa.Column("file_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("document_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(),
|
||||
default=datetime.datetime.utcnow,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Recreate chat_folder table
|
||||
op.create_table(
|
||||
"chat_folder",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
sa.UUID(),
|
||||
sa.ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("name", sa.String(length=255), nullable=True),
|
||||
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
|
||||
)
|
||||
|
||||
# Migrate data back from user_folder to chat_folder
|
||||
op.execute(
|
||||
"""
|
||||
INSERT INTO chat_folder (id, user_id, name, display_priority)
|
||||
SELECT id, user_id, name, display_priority FROM user_folder
|
||||
WHERE id IN (SELECT DISTINCT folder_id FROM chat_session WHERE folder_id IS NOT NULL)
|
||||
"""
|
||||
)
|
||||
|
||||
# Update chat_session table to reference chat_folder again
|
||||
op.drop_constraint(
|
||||
"fk_chat_session_folder_id_user_folder", "chat_session", type_="foreignkey"
|
||||
)
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"folder_id",
|
||||
existing_type=sa.Integer(),
|
||||
nullable=True,
|
||||
existing_nullable=True,
|
||||
existing_server_default=None,
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_session_chat_folder_fk",
|
||||
"chat_session",
|
||||
"chat_folder",
|
||||
["folder_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
# Drop the user_file table
|
||||
op.drop_table("user_file")
|
||||
# Drop the user_folder table
|
||||
op.drop_table("user_folder")
|
||||
121
backend/alembic/versions/35e518e0ddf4_properly_cascade.py
Normal file
121
backend/alembic/versions/35e518e0ddf4_properly_cascade.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""properly_cascade
|
||||
|
||||
Revision ID: 35e518e0ddf4
|
||||
Revises: 91a0a4d62b14
|
||||
Create Date: 2024-09-20 21:24:04.891018
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "35e518e0ddf4"
|
||||
down_revision = "91a0a4d62b14"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Update chat_message foreign key constraint
|
||||
op.drop_constraint(
|
||||
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message_chat_session_id_fkey",
|
||||
"chat_message",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Update chat_message__search_doc foreign key constraints
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"search_doc",
|
||||
["search_doc_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Add CASCADE delete for tool_call foreign key
|
||||
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"tool_call_message_id_fkey",
|
||||
"tool_call",
|
||||
"chat_message",
|
||||
["message_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert chat_message foreign key constraint
|
||||
op.drop_constraint(
|
||||
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message_chat_session_id_fkey",
|
||||
"chat_message",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Revert chat_message__search_doc foreign key constraints
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"search_doc",
|
||||
["search_doc_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Revert tool_call foreign key constraint
|
||||
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"tool_call_message_id_fkey",
|
||||
"tool_call",
|
||||
"chat_message",
|
||||
["message_id"],
|
||||
["id"],
|
||||
)
|
||||
45
backend/alembic/versions/91a0a4d62b14_milestone.py
Normal file
45
backend/alembic/versions/91a0a4d62b14_milestone.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Milestone
|
||||
|
||||
Revision ID: 91a0a4d62b14
|
||||
Revises: dab04867cd88
|
||||
Create Date: 2024-12-13 19:03:30.947551
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import fastapi_users_db_sqlalchemy
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "91a0a4d62b14"
|
||||
down_revision = "dab04867cd88"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"milestone",
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("tenant_id", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("event_type", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("event_tracker", postgresql.JSONB(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("event_type", name="uq_milestone_event_type"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("milestone")
|
||||
87
backend/alembic/versions/c0aab6edb6dd_delete_workspace.py
Normal file
87
backend/alembic/versions/c0aab6edb6dd_delete_workspace.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""delete workspace
|
||||
|
||||
Revision ID: c0aab6edb6dd
|
||||
Revises: 35e518e0ddf4
|
||||
Create Date: 2024-12-17 14:37:07.660631
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c0aab6edb6dd"
|
||||
down_revision = "35e518e0ddf4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = connector_specific_config - 'workspace'
|
||||
WHERE source = 'SLACK'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
import json
|
||||
from sqlalchemy import text
|
||||
from slack_sdk import WebClient
|
||||
|
||||
conn = op.get_bind()
|
||||
|
||||
# Fetch all Slack credentials
|
||||
creds_result = conn.execute(
|
||||
text("SELECT id, credential_json FROM credential WHERE source = 'SLACK'")
|
||||
)
|
||||
all_slack_creds = creds_result.fetchall()
|
||||
if not all_slack_creds:
|
||||
return
|
||||
|
||||
for cred_row in all_slack_creds:
|
||||
credential_id, credential_json = cred_row
|
||||
|
||||
credential_json = (
|
||||
credential_json.tobytes().decode("utf-8")
|
||||
if isinstance(credential_json, memoryview)
|
||||
else credential_json.decode("utf-8")
|
||||
)
|
||||
credential_data = json.loads(credential_json)
|
||||
slack_bot_token = credential_data.get("slack_bot_token")
|
||||
if not slack_bot_token:
|
||||
print(
|
||||
f"No slack_bot_token found for credential {credential_id}. "
|
||||
"Your Slack connector will not function until you upgrade and provide a valid token."
|
||||
)
|
||||
continue
|
||||
|
||||
client = WebClient(token=slack_bot_token)
|
||||
try:
|
||||
auth_response = client.auth_test()
|
||||
workspace = auth_response["url"].split("//")[1].split(".")[0]
|
||||
|
||||
# Update only the connectors linked to this credential
|
||||
# (and which are Slack connectors).
|
||||
op.execute(
|
||||
f"""
|
||||
UPDATE connector AS c
|
||||
SET connector_specific_config = jsonb_set(
|
||||
connector_specific_config,
|
||||
'{{workspace}}',
|
||||
to_jsonb('{workspace}'::text)
|
||||
)
|
||||
FROM connector_credential_pair AS ccp
|
||||
WHERE ccp.connector_id = c.id
|
||||
AND c.source = 'SLACK'
|
||||
AND ccp.credential_id = {credential_id}
|
||||
"""
|
||||
)
|
||||
except Exception:
|
||||
print(
|
||||
f"We were unable to get the workspace url for your Slack Connector with id {credential_id}."
|
||||
)
|
||||
print("This connector will no longer work until you upgrade.")
|
||||
continue
|
||||
@@ -47,3 +47,11 @@ OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", ""
|
||||
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
|
||||
)
|
||||
|
||||
|
||||
# The posthog client does not accept empty API keys or hosts however it fails silently
|
||||
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
|
||||
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
|
||||
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
|
||||
|
||||
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
|
||||
|
||||
@@ -122,7 +122,7 @@ def _cleanup_document_set__user_group_relationships__no_commit(
|
||||
)
|
||||
|
||||
|
||||
def validate_user_creation_permissions(
|
||||
def validate_object_creation_for_user(
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
target_group_ids: list[int] | None = None,
|
||||
@@ -440,32 +440,108 @@ def remove_curator_status__no_commit(db_session: Session, user: User) -> None:
|
||||
_validate_curator_status__no_commit(db_session, [user])
|
||||
|
||||
|
||||
def update_user_curator_relationship(
|
||||
def _validate_curator_relationship_update_requester(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
set_curator_request: SetCuratorRequest,
|
||||
user_making_change: User | None = None,
|
||||
) -> None:
|
||||
user = fetch_user_by_id(db_session, set_curator_request.user_id)
|
||||
if not user:
|
||||
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
|
||||
"""
|
||||
This function validates that the user making the change has the necessary permissions
|
||||
to update the curator relationship for the target user in the given user group.
|
||||
"""
|
||||
|
||||
if user.role == UserRole.ADMIN:
|
||||
if user_making_change is None or user_making_change.role == UserRole.ADMIN:
|
||||
return
|
||||
|
||||
# check if the user making the change is a curator in the group they are changing the curator relationship for
|
||||
user_making_change_curator_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user_making_change.id,
|
||||
# only check if the user making the change is a curator if they are a curator
|
||||
# otherwise, they are a global_curator and can update the curator relationship
|
||||
# for any group they are a member of
|
||||
only_curator_groups=user_making_change.role == UserRole.CURATOR,
|
||||
)
|
||||
requestor_curator_group_ids = [
|
||||
group.id for group in user_making_change_curator_groups
|
||||
]
|
||||
if user_group_id not in requestor_curator_group_ids:
|
||||
raise ValueError(
|
||||
f"User '{user.email}' is an admin and therefore has all permissions "
|
||||
f"user making change {user_making_change.email} is not a curator,"
|
||||
f" admin, or global_curator for group '{user_group_id}'"
|
||||
)
|
||||
|
||||
|
||||
def _validate_curator_relationship_update_request(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
target_user: User,
|
||||
) -> None:
|
||||
"""
|
||||
This function validates that the curator_relationship_update request itself is valid.
|
||||
"""
|
||||
if target_user.role == UserRole.ADMIN:
|
||||
raise ValueError(
|
||||
f"User '{target_user.email}' is an admin and therefore has all permissions "
|
||||
"of a curator. If you'd like this user to only have curator permissions, "
|
||||
"you must update their role to BASIC then assign them to be CURATOR in the "
|
||||
"appropriate groups."
|
||||
)
|
||||
elif target_user.role == UserRole.GLOBAL_CURATOR:
|
||||
raise ValueError(
|
||||
f"User '{target_user.email}' is a global_curator and therefore has all "
|
||||
"permissions of a curator for all groups. If you'd like this user to only "
|
||||
"have curator permissions for a specific group, you must update their role "
|
||||
"to BASIC then assign them to be CURATOR in the appropriate groups."
|
||||
)
|
||||
elif target_user.role not in [UserRole.CURATOR, UserRole.BASIC]:
|
||||
raise ValueError(
|
||||
f"This endpoint can only be used to update the curator relationship for "
|
||||
"users with the CURATOR or BASIC role. \n"
|
||||
f"Target user: {target_user.email} \n"
|
||||
f"Target user role: {target_user.role} \n"
|
||||
)
|
||||
|
||||
# check if the target user is in the group they are changing the curator relationship for
|
||||
requested_user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=set_curator_request.user_id,
|
||||
user_id=target_user.id,
|
||||
only_curator_groups=False,
|
||||
)
|
||||
|
||||
group_ids = [group.id for group in requested_user_groups]
|
||||
if user_group_id not in group_ids:
|
||||
raise ValueError(f"user is not in group '{user_group_id}'")
|
||||
raise ValueError(
|
||||
f"target user {target_user.email} is not in group '{user_group_id}'"
|
||||
)
|
||||
|
||||
|
||||
def update_user_curator_relationship(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
set_curator_request: SetCuratorRequest,
|
||||
user_making_change: User | None = None,
|
||||
) -> None:
|
||||
target_user = fetch_user_by_id(db_session, set_curator_request.user_id)
|
||||
if not target_user:
|
||||
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
|
||||
|
||||
_validate_curator_relationship_update_request(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
target_user=target_user,
|
||||
)
|
||||
|
||||
_validate_curator_relationship_update_requester(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
user_making_change=user_making_change,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"user_making_change={user_making_change.email if user_making_change else 'None'} is "
|
||||
f"updating the curator relationship for user={target_user.email} "
|
||||
f"in group={user_group_id} to is_curator={set_curator_request.is_curator}"
|
||||
)
|
||||
|
||||
relationship_to_update = (
|
||||
db_session.query(User__UserGroup)
|
||||
@@ -486,7 +562,7 @@ def update_user_curator_relationship(
|
||||
)
|
||||
db_session.add(relationship_to_update)
|
||||
|
||||
_validate_curator_status__no_commit(db_session, [user])
|
||||
_validate_curator_status__no_commit(db_session, [target_user])
|
||||
db_session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.main import get_application as get_application_base
|
||||
from onyx.main import include_auth_router_with_prefix
|
||||
from onyx.main import include_router_with_global_prefix_prepended
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
@@ -62,7 +63,7 @@ def get_application() -> FastAPI:
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
oauth_client,
|
||||
@@ -74,19 +75,17 @@ def get_application() -> FastAPI:
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
|
||||
),
|
||||
prefix="/auth/oauth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# Need basic auth router for `logout` endpoint
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_logout_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.OIDC:
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
|
||||
@@ -97,19 +96,21 @@ def get_application() -> FastAPI:
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
|
||||
),
|
||||
prefix="/auth/oidc",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# need basic auth router for `logout` endpoint
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_auth_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
elif AUTH_TYPE == AuthType.SAML:
|
||||
include_router_with_global_prefix_prepended(application, saml_router)
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
saml_router,
|
||||
prefix="/auth/saml",
|
||||
)
|
||||
|
||||
# RBAC / group access control
|
||||
include_router_with_global_prefix_prepended(application, user_group_router)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
@@ -10,11 +12,29 @@ from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
|
||||
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
@@ -62,14 +82,7 @@ class SlackOAuth:
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
url = (
|
||||
f"https://slack.com/oauth/v2/authorize"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={cls.REDIRECT_URI}"
|
||||
f"&scope={cls.BOT_SCOPE}"
|
||||
f"&state={state}"
|
||||
)
|
||||
return url
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
@@ -77,10 +90,14 @@ class SlackOAuth:
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
url = (
|
||||
f"https://slack.com/oauth/v2/authorize"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={cls.DEV_REDIRECT_URI}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&scope={cls.BOT_SCOPE}"
|
||||
f"&state={state}"
|
||||
)
|
||||
@@ -102,82 +119,151 @@ class SlackOAuth:
|
||||
return session
|
||||
|
||||
|
||||
# Work in progress
|
||||
# class ConfluenceCloudOAuth:
|
||||
# """work in progress"""
|
||||
class ConfluenceCloudOAuth:
|
||||
"""work in progress"""
|
||||
|
||||
# # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
|
||||
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
|
||||
|
||||
# class OAuthSession(BaseModel):
|
||||
# """Stored in redis to be looked up on callback"""
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
# email: str
|
||||
# redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
# CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
|
||||
# CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
# TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
|
||||
# # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
|
||||
# CONFLUENCE_OAUTH_SCOPE = (
|
||||
# "read:confluence-props%20"
|
||||
# "read:confluence-content.all%20"
|
||||
# "read:confluence-content.summary%20"
|
||||
# "read:confluence-content.permission%20"
|
||||
# "read:confluence-user%20"
|
||||
# "read:confluence-groups%20"
|
||||
# "readonly:content.attachment:confluence"
|
||||
# )
|
||||
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
|
||||
CONFLUENCE_OAUTH_SCOPE = (
|
||||
"read:confluence-props%20"
|
||||
"read:confluence-content.all%20"
|
||||
"read:confluence-content.summary%20"
|
||||
"read:confluence-content.permission%20"
|
||||
"read:confluence-user%20"
|
||||
"read:confluence-groups%20"
|
||||
"readonly:content.attachment:confluence"
|
||||
)
|
||||
|
||||
# REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
|
||||
# DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
# # eventually for Confluence Data Center
|
||||
# # oauth_url = (
|
||||
# # f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
|
||||
# # f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
|
||||
# # f"&redirect_uri={redirectme_uri}"
|
||||
# # )
|
||||
# eventually for Confluence Data Center
|
||||
# oauth_url = (
|
||||
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
|
||||
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
|
||||
# f"&redirect_uri={redirectme_uri}"
|
||||
# )
|
||||
|
||||
# @classmethod
|
||||
# def generate_oauth_url(cls, state: str) -> str:
|
||||
# return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
# @classmethod
|
||||
# def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
# """dev mode workaround for localhost testing
|
||||
# - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
# """
|
||||
# return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
# @classmethod
|
||||
# def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
# url = (
|
||||
# "https://auth.atlassian.com/authorize"
|
||||
# f"?audience=api.atlassian.com"
|
||||
# f"&client_id={cls.CLIENT_ID}"
|
||||
# f"&redirect_uri={redirect_uri}"
|
||||
# f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
|
||||
# f"&state={state}"
|
||||
# "&response_type=code"
|
||||
# "&prompt=consent"
|
||||
# )
|
||||
# return url
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
url = (
|
||||
"https://auth.atlassian.com/authorize"
|
||||
f"?audience=api.atlassian.com"
|
||||
f"&client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
|
||||
f"&state={state}"
|
||||
"&response_type=code"
|
||||
"&prompt=consent"
|
||||
)
|
||||
return url
|
||||
|
||||
# @classmethod
|
||||
# def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
# """Temporary state to store in redis. to be looked up on auth response.
|
||||
# Returns a json string.
|
||||
# """
|
||||
# session = ConfluenceCloudOAuth.OAuthSession(
|
||||
# email=email, redirect_on_success=redirect_on_success
|
||||
# )
|
||||
# return session.model_dump_json()
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = ConfluenceCloudOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
# @classmethod
|
||||
# def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
|
||||
# session = SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
# return session
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
|
||||
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
class GoogleDriveOAuth:
|
||||
# https://developers.google.com/identity/protocols/oauth2
|
||||
# https://developers.google.com/identity/protocols/oauth2/web-server
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
|
||||
TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
# SCOPE is per https://docs.onyx.app/connectors/google-drive
|
||||
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
|
||||
SCOPE = (
|
||||
"https://www.googleapis.com/auth/drive.readonly%20"
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
|
||||
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
|
||||
"https://www.googleapis.com/auth/admin.directory.group.readonly"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
# without prompt=consent, a refresh token is only issued the first time the user approves
|
||||
url = (
|
||||
f"https://accounts.google.com/o/oauth2/v2/auth"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
"&response_type=code"
|
||||
f"&scope={cls.SCOPE}"
|
||||
"&access_type=offline"
|
||||
f"&state={state}"
|
||||
"&prompt=consent"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = GoogleDriveOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> OAuthSession:
|
||||
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
@router.post("/prepare-authorization-request")
|
||||
@@ -192,8 +278,11 @@ def prepare_authorization_request(
|
||||
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
|
||||
"""
|
||||
|
||||
# create random oauth state param for security and to retrieve user data later
|
||||
oauth_uuid = uuid.uuid4()
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
# urlsafe b64 encode the uuid for the oauth url
|
||||
oauth_state = (
|
||||
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
@@ -203,6 +292,11 @@ def prepare_authorization_request(
|
||||
session = SlackOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
elif connector == DocumentSource.GOOGLE_DRIVE:
|
||||
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
|
||||
session = GoogleDriveOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
# elif connector == DocumentSource.CONFLUENCE:
|
||||
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
|
||||
# session = ConfluenceCloudOAuth.session_dump_json(
|
||||
@@ -210,8 +304,6 @@ def prepare_authorization_request(
|
||||
# )
|
||||
# elif connector == DocumentSource.JIRA:
|
||||
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
|
||||
# elif connector == DocumentSource.GOOGLE_DRIVE:
|
||||
# oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
|
||||
else:
|
||||
oauth_url = None
|
||||
|
||||
@@ -223,6 +315,7 @@ def prepare_authorization_request(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# store important session state to retrieve when the user is redirected back
|
||||
# 10 min is the max we want an oauth flow to be valid
|
||||
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
|
||||
|
||||
@@ -421,3 +514,116 @@ def handle_slack_oauth_callback(
|
||||
# "redirect_on_success": session.redirect_on_success,
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
@router.post("/connector/google-drive/callback")
|
||||
def handle_google_drive_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Google Drive client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
try:
|
||||
session = GoogleDriveOAuth.parse_session(session_json)
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
GoogleDriveOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": GoogleDriveOAuth.CLIENT_ID,
|
||||
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
authorization_response: dict[str, Any] = response.json()
|
||||
|
||||
# the connector wants us to store the json in its authorized_user_info format
|
||||
# returned from OAuthCredentials.get_authorized_user_info().
|
||||
# So refresh immediately via get_google_oauth_creds with the params filled in
|
||||
# from fields in authorization_response to get the json we need
|
||||
authorized_user_info = {}
|
||||
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
|
||||
|
||||
token_json_str = json.dumps(authorized_user_info)
|
||||
oauth_creds = get_google_oauth_creds(
|
||||
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
|
||||
)
|
||||
if not oauth_creds:
|
||||
raise RuntimeError("get_google_oauth_creds returned None.")
|
||||
|
||||
# save off the credentials
|
||||
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
|
||||
|
||||
credential_dict: dict[str, str] = {}
|
||||
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
|
||||
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
|
||||
credential_dict[
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD
|
||||
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
name="OAuth (interactive)",
|
||||
)
|
||||
|
||||
create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Google Drive OAuth completed successfully.",
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_cloud_superuser
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
@@ -12,15 +13,23 @@ from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import ImpersonateRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_jwt_strategy
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.settings.store import store_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -114,3 +123,48 @@ async def impersonate_user(
|
||||
samesite="lax",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/leave-organization")
|
||||
async def leave_organization(
|
||||
user_email: UserByEmail,
|
||||
current_user: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
) -> None:
|
||||
if current_user is None or current_user.email != user_email.user_email:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You can only leave the organization as yourself"
|
||||
)
|
||||
|
||||
user_to_delete = get_user_by_email(user_email.user_email, db_session)
|
||||
if user_to_delete is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
num_admin_users = await get_user_count(only_admin_users=True)
|
||||
|
||||
should_delete_tenant = num_admin_users == 1
|
||||
|
||||
if should_delete_tenant:
|
||||
logger.info(
|
||||
"Last admin user is leaving the organization. Deleting tenant from control plane."
|
||||
)
|
||||
try:
|
||||
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
|
||||
logger.debug("User deleted from control plane")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to remove user from control plane: {str(e)}",
|
||||
)
|
||||
|
||||
db_session.expunge(user_to_delete)
|
||||
delete_user_from_db(user_to_delete, db_session)
|
||||
|
||||
if should_delete_tenant:
|
||||
remove_all_users_from_tenant(tenant_id)
|
||||
else:
|
||||
remove_users_from_tenant([user_to_delete.email], tenant_id)
|
||||
|
||||
@@ -39,3 +39,8 @@ class TenantCreationPayload(BaseModel):
|
||||
tenant_id: str
|
||||
email: str
|
||||
referral_source: str | None = None
|
||||
|
||||
|
||||
class TenantDeletionPayload(BaseModel):
|
||||
tenant_id: str
|
||||
email: str
|
||||
|
||||
@@ -3,15 +3,19 @@ import logging
|
||||
import uuid
|
||||
|
||||
import aiohttp # Async HTTP client
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
|
||||
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
|
||||
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import TenantCreationPayload
|
||||
from ee.onyx.server.tenants.models import TenantDeletionPayload
|
||||
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
|
||||
from ee.onyx.server.tenants.schema_management import drop_schema
|
||||
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
|
||||
@@ -20,6 +24,7 @@ from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
|
||||
from onyx.auth.users import exceptions
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -35,22 +40,27 @@ from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.setup import setup_onyx
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_or_create_tenant_id(
|
||||
email: str, referral_source: str | None = None
|
||||
async def get_or_provision_tenant(
|
||||
email: str, referral_source: str | None = None, request: Request | None = None
|
||||
) -> str:
|
||||
"""Get existing tenant ID for an email or create a new tenant if none exists."""
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
if referral_source and request:
|
||||
await submit_to_hubspot(email, referral_source, request)
|
||||
|
||||
try:
|
||||
tenant_id = get_tenant_id_for_email(email)
|
||||
except exceptions.UserNotExists:
|
||||
@@ -122,6 +132,17 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
|
||||
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={
|
||||
"email": email,
|
||||
},
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create tenant {tenant_id}")
|
||||
raise HTTPException(
|
||||
@@ -165,6 +186,7 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
|
||||
try:
|
||||
# Drop the tenant's schema to rollback provisioning
|
||||
drop_schema(tenant_id)
|
||||
|
||||
# Remove tenant mapping
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
@@ -267,3 +289,59 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
logger.info(
|
||||
"COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration"
|
||||
)
|
||||
|
||||
|
||||
async def submit_to_hubspot(
|
||||
email: str, referral_source: str | None, request: Request
|
||||
) -> None:
|
||||
if not HUBSPOT_TRACKING_URL:
|
||||
logger.info("HUBSPOT_TRACKING_URL not set, skipping HubSpot submission")
|
||||
return
|
||||
|
||||
# HubSpot tracking cookie
|
||||
hubspot_cookie = request.cookies.get("hubspotutk")
|
||||
|
||||
# IP address
|
||||
ip_address = request.client.host if request.client else None
|
||||
|
||||
data = {
|
||||
"fields": [
|
||||
{"name": "email", "value": email},
|
||||
{"name": "referral_source", "value": referral_source or ""},
|
||||
],
|
||||
"context": {
|
||||
"hutk": hubspot_cookie,
|
||||
"ipAddress": ip_address,
|
||||
"pageUri": str(request.url),
|
||||
"pageName": "User Registration",
|
||||
},
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(HUBSPOT_TRACKING_URL, json=data)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to submit to HubSpot: {response.text}")
|
||||
|
||||
|
||||
async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = TenantDeletionPayload(tenant_id=tenant_id, email=email)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.delete(
|
||||
f"{CONTROL_PLANE_API_BASE_URL}/tenants/delete",
|
||||
headers=headers,
|
||||
json=payload.model_dump(),
|
||||
) as response:
|
||||
print(response)
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Control plane tenant creation failed: {error_text}")
|
||||
raise Exception(
|
||||
f"Failed to delete tenant on control plane: {error_text}"
|
||||
)
|
||||
|
||||
@@ -68,3 +68,11 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
|
||||
)
|
||||
db_session.rollback()
|
||||
|
||||
|
||||
def remove_all_users_from_tenant(tenant_id: str) -> None:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
@@ -83,7 +83,7 @@ def patch_user_group(
|
||||
def set_user_curator(
|
||||
user_group_id: int,
|
||||
set_curator_request: SetCuratorRequest,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
@@ -91,6 +91,7 @@ def set_user_curator(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
set_curator_request=set_curator_request,
|
||||
user_making_change=user,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Error setting user curator: {e}")
|
||||
|
||||
34
backend/ee/onyx/utils/telemetry.py
Normal file
34
backend/ee/onyx/utils/telemetry.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Any
|
||||
|
||||
from posthog import Posthog
|
||||
|
||||
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_HOST
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def posthog_on_error(error: Any, items: Any) -> None:
|
||||
"""Log any PostHog delivery errors."""
|
||||
logger.error(f"PostHog error: {error}, items: {items}")
|
||||
|
||||
|
||||
posthog = Posthog(
|
||||
project_api_key=POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=True,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
|
||||
|
||||
def event_telemetry(
|
||||
distinct_id: str, event: str, properties: dict | None = None
|
||||
) -> None:
|
||||
"""Capture and send an event to PostHog, flushing immediately."""
|
||||
logger.info(f"Capturing PostHog event: {distinct_id} {event} {properties}")
|
||||
try:
|
||||
posthog.capture(distinct_id, event, properties)
|
||||
posthog.flush()
|
||||
except Exception as e:
|
||||
logger.error(f"Error capturing PostHog event: {e}")
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from types import TracebackType
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
@@ -320,8 +321,6 @@ async def embed_text(
|
||||
api_url: str | None,
|
||||
api_version: str | None,
|
||||
) -> list[Embedding]:
|
||||
logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}")
|
||||
|
||||
if not all(texts):
|
||||
logger.error("Empty strings provided for embedding")
|
||||
raise ValueError("Empty strings are not allowed for embedding.")
|
||||
@@ -330,8 +329,17 @@ async def embed_text(
|
||||
logger.error("No texts provided for embedding")
|
||||
raise ValueError("No texts provided for embedding.")
|
||||
|
||||
start = time.monotonic()
|
||||
|
||||
total_chars = 0
|
||||
for text in texts:
|
||||
total_chars += len(text)
|
||||
|
||||
if provider_type is not None:
|
||||
logger.debug(f"Using cloud provider {provider_type} for embedding")
|
||||
logger.info(
|
||||
f"Embedding {len(texts)} texts with {total_chars} total characters with provider: {provider_type}"
|
||||
)
|
||||
|
||||
if api_key is None:
|
||||
logger.error("API key not provided for cloud model")
|
||||
raise RuntimeError("API key not provided for cloud model")
|
||||
@@ -363,8 +371,16 @@ async def embed_text(
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
logger.info(
|
||||
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
|
||||
f"with provider {provider_type} in {elapsed:.2f}"
|
||||
)
|
||||
elif model_name is not None:
|
||||
logger.debug(f"Using local model {model_name} for embedding")
|
||||
logger.info(
|
||||
f"Embedding {len(texts)} texts with {total_chars} total characters with local model: {model_name}"
|
||||
)
|
||||
|
||||
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
|
||||
|
||||
local_model = get_embedding_model(
|
||||
@@ -382,13 +398,17 @@ async def embed_text(
|
||||
for embedding in embeddings_vectors
|
||||
]
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
logger.info(
|
||||
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
|
||||
f"with local model {model_name} in {elapsed:.2f}"
|
||||
)
|
||||
else:
|
||||
logger.error("Neither model name nor provider specified for embedding")
|
||||
raise ValueError(
|
||||
"Either model name or provider must be provided to run embeddings."
|
||||
)
|
||||
|
||||
logger.info(f"Successfully embedded {len(texts)} texts")
|
||||
return embeddings
|
||||
|
||||
|
||||
@@ -440,7 +460,8 @@ async def process_embed_request(
|
||||
) -> EmbedResponse:
|
||||
if not embed_request.texts:
|
||||
raise HTTPException(status_code=400, detail="No texts to be embedded")
|
||||
elif not all(embed_request.texts):
|
||||
|
||||
if not all(embed_request.texts):
|
||||
raise ValueError("Empty strings are not allowed for embedding.")
|
||||
|
||||
try:
|
||||
@@ -471,9 +492,12 @@ async def process_embed_request(
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
exception_detail = f"Error during embedding process:\n{str(e)}"
|
||||
logger.exception(exception_detail)
|
||||
raise HTTPException(status_code=500, detail=exception_detail)
|
||||
logger.exception(
|
||||
f"Error during embedding process: provider={embed_request.provider_type} model={embed_request.model_name}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error during embedding process: {e}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/cross-encoder-scores")
|
||||
|
||||
@@ -27,8 +27,8 @@ from shared_configs.configs import SENTRY_DSN
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
|
||||
HF_CACHE_PATH = Path("/root/.cache/huggingface/")
|
||||
TEMP_HF_CACHE_PATH = Path("/root/.cache/temp_huggingface/")
|
||||
HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/huggingface"
|
||||
TEMP_HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/temp_huggingface"
|
||||
|
||||
transformer_logging.set_verbosity_error()
|
||||
|
||||
@@ -44,6 +44,7 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
|
||||
the files in the existing huggingface cache that don't exist in the temp
|
||||
huggingface cache.
|
||||
"""
|
||||
|
||||
for item in source.iterdir():
|
||||
target_path = dest / item.relative_to(source)
|
||||
if item.is_dir():
|
||||
|
||||
80
backend/onyx/auth/email_utils.py
Normal file
80
backend/onyx/auth/email_utils.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import smtplib
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from textwrap import dedent
|
||||
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import EMAIL_FROM
|
||||
from onyx.configs.app_configs import SMTP_PASS
|
||||
from onyx.configs.app_configs import SMTP_PORT
|
||||
from onyx.configs.app_configs import SMTP_SERVER
|
||||
from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
def send_email(
|
||||
user_email: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
if not EMAIL_CONFIGURED:
|
||||
raise ValueError("Email is not configured.")
|
||||
|
||||
msg = MIMEMultipart()
|
||||
msg["Subject"] = subject
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
|
||||
msg.attach(MIMEText(body))
|
||||
|
||||
try:
|
||||
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
|
||||
s.starttls()
|
||||
s.login(SMTP_USER, SMTP_PASS)
|
||||
s.send_message(msg)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def send_user_email_invite(user_email: str, current_user: User) -> None:
|
||||
subject = "Invitation to Join Onyx Workspace"
|
||||
body = dedent(
|
||||
f"""\
|
||||
Hello,
|
||||
|
||||
You have been invited to join a workspace on Onyx.
|
||||
|
||||
To join the workspace, please visit the following link:
|
||||
|
||||
{WEB_DOMAIN}/auth/login
|
||||
|
||||
Best regards,
|
||||
The Onyx Team
|
||||
"""
|
||||
)
|
||||
send_email(user_email, subject, body, current_user.email)
|
||||
|
||||
|
||||
def send_forgot_password_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
subject = "Onyx Forgot Password"
|
||||
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
|
||||
body = f"Click the following link to reset your password: {link}"
|
||||
send_email(user_email, subject, body, mail_from)
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
subject = "Onyx Email Verification"
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
body = f"Click the following link to verify your email address: {link}"
|
||||
send_email(user_email, subject, body, mail_from)
|
||||
@@ -4,6 +4,8 @@ from typing import cast
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
|
||||
from onyx.configs.constants import NO_AUTH_USER_EMAIL
|
||||
from onyx.configs.constants import NO_AUTH_USER_ID
|
||||
from onyx.key_value_store.store import KeyValueStore
|
||||
from onyx.key_value_store.store import KvKeyNotFoundError
|
||||
from onyx.server.manage.models import UserInfo
|
||||
@@ -28,13 +30,16 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
|
||||
)
|
||||
|
||||
|
||||
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
|
||||
def fetch_no_auth_user(
|
||||
store: KeyValueStore, *, anonymous_user_enabled: bool | None = None
|
||||
) -> UserInfo:
|
||||
return UserInfo(
|
||||
id="__no_auth_user__",
|
||||
email="anonymous@onyx.app",
|
||||
id=NO_AUTH_USER_ID,
|
||||
email=NO_AUTH_USER_EMAIL,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=UserRole.ADMIN,
|
||||
role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
|
||||
preferences=load_no_auth_user_preferences(store),
|
||||
is_anonymous_user=anonymous_user_enabled,
|
||||
)
|
||||
|
||||
@@ -49,4 +49,7 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
"""
|
||||
Role updates are not allowed through the user update endpoint for security reasons
|
||||
Role changes should be handled through a separate, admin-only process
|
||||
"""
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import smtplib
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
@@ -52,19 +50,17 @@ from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.email_utils import send_forgot_password_email
|
||||
from onyx.auth.email_utils import send_user_verification_email
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.schemas import UserUpdate
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.app_configs import EMAIL_FROM
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import SMTP_PASS
|
||||
from onyx.configs.app_configs import SMTP_PORT
|
||||
from onyx.configs.app_configs import SMTP_SERVER
|
||||
from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
@@ -72,6 +68,9 @@ from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
|
||||
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.api_key import fetch_user_for_api_key
|
||||
from onyx.db.auth import get_access_token_db
|
||||
@@ -86,8 +85,9 @@ from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
@@ -99,6 +99,11 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
@@ -139,6 +144,20 @@ def user_needs_to_be_verified() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def anonymous_user_enabled() -> bool:
|
||||
if MULTI_TENANT:
|
||||
return False
|
||||
|
||||
redis_client = get_redis_client(tenant_id=None)
|
||||
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
assert isinstance(value, bytes)
|
||||
return int(value.decode("utf-8")) == 1
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
whitelist = get_invited_users()
|
||||
if not whitelist:
|
||||
@@ -189,30 +208,6 @@ def verify_email_domain(email: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
msg = MIMEMultipart()
|
||||
msg["Subject"] = "Onyx Email Verification"
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
|
||||
body = MIMEText(f"Click the following link to verify your email address: {link}")
|
||||
msg.attach(body)
|
||||
|
||||
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
|
||||
s.starttls()
|
||||
# If credentials fails with gmail, check (You need an app password, not just the basic email password)
|
||||
# https://support.google.com/accounts/answer/185833?sjid=8512343437447396151-NA
|
||||
s.login(SMTP_USER, SMTP_PASS)
|
||||
s.send_message(msg)
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
@@ -225,17 +220,26 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> User:
|
||||
referral_source = None
|
||||
if request is not None:
|
||||
referral_source = request.cookies.get("referral_source", None)
|
||||
# We verify the password here to make sure it's valid before we proceed
|
||||
await self.validate_password(
|
||||
user_create.password, cast(schemas.UC, user_create)
|
||||
)
|
||||
|
||||
user_count: int | None = None
|
||||
referral_source = (
|
||||
request.cookies.get("referral_source", None)
|
||||
if request is not None
|
||||
else None
|
||||
)
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user_create.email,
|
||||
referral_source=referral_source,
|
||||
request=request,
|
||||
)
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -268,7 +272,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
if not user.role.is_web_login() and user_create.role.is_web_login():
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
@@ -278,7 +281,37 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return user
|
||||
return user
|
||||
|
||||
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
|
||||
# Validate password according to basic security guidelines
|
||||
if len(password) < 12:
|
||||
raise exceptions.InvalidPasswordException(
|
||||
reason="Password must be at least 12 characters long."
|
||||
)
|
||||
if len(password) > 64:
|
||||
raise exceptions.InvalidPasswordException(
|
||||
reason="Password must not exceed 64 characters."
|
||||
)
|
||||
if not any(char.isupper() for char in password):
|
||||
raise exceptions.InvalidPasswordException(
|
||||
reason="Password must contain at least one uppercase letter."
|
||||
)
|
||||
if not any(char.islower() for char in password):
|
||||
raise exceptions.InvalidPasswordException(
|
||||
reason="Password must contain at least one lowercase letter."
|
||||
)
|
||||
if not any(char.isdigit() for char in password):
|
||||
raise exceptions.InvalidPasswordException(
|
||||
reason="Password must contain at least one number."
|
||||
)
|
||||
if not any(char in PASSWORD_SPECIAL_CHARS for char in password):
|
||||
raise exceptions.InvalidPasswordException(
|
||||
reason="Password must contain at least one special character from the following set: "
|
||||
f"{PASSWORD_SPECIAL_CHARS}."
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
async def oauth_callback(
|
||||
self,
|
||||
@@ -293,17 +326,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
) -> User:
|
||||
referral_source = None
|
||||
if request:
|
||||
referral_source = getattr(request.state, "referral_source", None)
|
||||
referral_source = (
|
||||
getattr(request.state, "referral_source", None) if request else None
|
||||
)
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=account_email,
|
||||
referral_source=referral_source,
|
||||
request=request,
|
||||
)
|
||||
|
||||
if not tenant_id:
|
||||
@@ -365,6 +399,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# Add OAuth account
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
|
||||
await self.on_after_register(user, request)
|
||||
|
||||
else:
|
||||
@@ -418,6 +453,39 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async def on_after_register(
|
||||
self, user: User, request: Optional[Request] = None
|
||||
) -> None:
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user.email,
|
||||
request=request,
|
||||
)
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
user_count = await get_user_count()
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
if user_count == 1:
|
||||
create_milestone_and_report(
|
||||
user=user,
|
||||
distinct_id=user.email,
|
||||
event_type=MilestoneRecordType.USER_SIGNED_UP,
|
||||
properties=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
else:
|
||||
create_milestone_and_report(
|
||||
user=user,
|
||||
distinct_id=user.email,
|
||||
event_type=MilestoneRecordType.MULTIPLE_USERS,
|
||||
properties=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
logger.notice(f"User {user.id} has registered.")
|
||||
optional_telemetry(
|
||||
record_type=RecordType.SIGN_UP,
|
||||
@@ -428,7 +496,15 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async def on_after_forgot_password(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
) -> None:
|
||||
logger.notice(f"User {user.id} has forgot their password. Reset token: {token}")
|
||||
if not EMAIL_CONFIGURED:
|
||||
logger.error(
|
||||
"Email is not configured. Please configure email in the admin panel"
|
||||
)
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
"Your admin has not enbaled this feature.",
|
||||
)
|
||||
send_forgot_password_email(user.email, token)
|
||||
|
||||
async def on_after_request_verify(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
@@ -449,7 +525,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# Get tenant_id from mapping table
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=email,
|
||||
@@ -510,7 +586,7 @@ class TenantAwareJWTStrategy(JWTStrategy):
|
||||
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user.email,
|
||||
@@ -546,9 +622,7 @@ def get_database_strategy(
|
||||
|
||||
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="jwt" if MULTI_TENANT else "database",
|
||||
transport=cookie_transport,
|
||||
get_strategy=get_jwt_strategy if MULTI_TENANT else get_database_strategy, # type: ignore
|
||||
name="jwt", transport=cookie_transport, get_strategy=get_jwt_strategy
|
||||
) # type: ignore
|
||||
|
||||
|
||||
@@ -635,30 +709,36 @@ async def double_check_user(
|
||||
user: User | None,
|
||||
optional: bool = DISABLE_AUTH,
|
||||
include_expired: bool = False,
|
||||
allow_anonymous_access: bool = False,
|
||||
) -> User | None:
|
||||
if optional:
|
||||
return user
|
||||
|
||||
if user is not None:
|
||||
# If user attempted to authenticate, verify them, do not default
|
||||
# to anonymous access if it fails.
|
||||
if user_needs_to_be_verified() and not user.is_verified:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
if (
|
||||
user.oidc_expiry
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
if allow_anonymous_access:
|
||||
return None
|
||||
|
||||
if user is None:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
if user_needs_to_be_verified() and not user.is_verified:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
if (
|
||||
user.oidc_expiry
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
return user
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
|
||||
async def current_user_with_expired_token(
|
||||
@@ -673,6 +753,14 @@ async def current_limited_user(
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_chat_accesssible_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
return await double_check_user(
|
||||
user, allow_anonymous_access=anonymous_user_enabled()
|
||||
)
|
||||
|
||||
|
||||
async def current_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
|
||||
@@ -3,11 +3,12 @@ import multiprocessing
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import sentry_sdk
|
||||
from celery import Task
|
||||
from celery.app import trace
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.signals import task_postrun
|
||||
from celery.signals import task_prerun
|
||||
from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
from celery.worker import strategy # type: ignore
|
||||
@@ -21,6 +22,7 @@ from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
|
||||
from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
@@ -34,8 +36,11 @@ from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import ColoredFormatter
|
||||
from onyx.utils.logger import PlainFormatter
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -56,8 +61,8 @@ def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
args: tuple[Any, ...] | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
pass
|
||||
@@ -257,7 +262,8 @@ def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("Vespa: Readiness probe starting.")
|
||||
while True:
|
||||
try:
|
||||
response = requests.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
|
||||
client = get_vespa_http_client()
|
||||
response = client.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
|
||||
response.raise_for_status()
|
||||
|
||||
response_dict = response.json()
|
||||
@@ -346,26 +352,36 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
loglevel: int,
|
||||
logfile: str | None,
|
||||
format: str,
|
||||
colorize: bool,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# TODO: could unhardcode format and colorize and accept these as options from
|
||||
# celery's config
|
||||
|
||||
# reformats the root logger
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.handlers = []
|
||||
|
||||
root_handler = logging.StreamHandler() # Set up a handler for the root logger
|
||||
# Define the log format
|
||||
log_format = (
|
||||
"%(levelname)-8s %(asctime)s %(filename)15s:%(lineno)-4d: %(name)s %(message)s"
|
||||
)
|
||||
|
||||
# Set up the root handler
|
||||
root_handler = logging.StreamHandler()
|
||||
root_formatter = ColoredFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
log_format,
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
root_handler.setFormatter(root_formatter)
|
||||
root_logger.addHandler(root_handler) # Apply the handler to the root logger
|
||||
root_logger.addHandler(root_handler)
|
||||
|
||||
if logfile:
|
||||
root_file_handler = logging.FileHandler(logfile)
|
||||
root_file_formatter = PlainFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
log_format,
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
root_file_handler.setFormatter(root_file_formatter)
|
||||
@@ -373,19 +389,23 @@ def on_setup_logging(
|
||||
|
||||
root_logger.setLevel(loglevel)
|
||||
|
||||
# reformats celery's task logger
|
||||
# Configure the task logger
|
||||
task_logger.handlers = []
|
||||
|
||||
task_handler = logging.StreamHandler()
|
||||
task_handler.addFilter(TenantContextFilter())
|
||||
task_formatter = CeleryTaskColoredFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
log_format,
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
task_handler = logging.StreamHandler() # Set up a handler for the task logger
|
||||
task_handler.setFormatter(task_formatter)
|
||||
task_logger.addHandler(task_handler) # Apply the handler to the task logger
|
||||
task_logger.addHandler(task_handler)
|
||||
|
||||
if logfile:
|
||||
task_file_handler = logging.FileHandler(logfile)
|
||||
task_file_handler.addFilter(TenantContextFilter())
|
||||
task_file_formatter = CeleryTaskPlainFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
log_format,
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
task_file_handler.setFormatter(task_file_formatter)
|
||||
@@ -398,6 +418,61 @@ def on_setup_logging(
|
||||
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] received"
|
||||
strategy.logger.setLevel(logging.WARNING)
|
||||
|
||||
# hide celery task succeeded/failed spam
|
||||
# uncomment this to hide celery task succeeded/failed spam
|
||||
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] succeeded in 0.03137450001668185s: None"
|
||||
trace.logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def set_task_finished_log_level(logLevel: int) -> None:
|
||||
"""call this to override the setLevel in on_setup_logging. We are interested
|
||||
in the task timings in the cloud but it can be spammy for self hosted."""
|
||||
trace.logger.setLevel(logLevel)
|
||||
|
||||
|
||||
class TenantContextFilter(logging.Filter):
|
||||
|
||||
"""Logging filter to inject tenant ID into the logger's name."""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if not MULTI_TENANT:
|
||||
record.name = ""
|
||||
return True
|
||||
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id:
|
||||
tenant_id = tenant_id.split(TENANT_ID_PREFIX)[-1][:5]
|
||||
record.name = f"[t:{tenant_id}]"
|
||||
else:
|
||||
record.name = ""
|
||||
return True
|
||||
|
||||
|
||||
@task_prerun.connect
|
||||
def set_tenant_id(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple[Any, ...] | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
**other_kwargs: Any,
|
||||
) -> None:
|
||||
"""Signal handler to set tenant ID in context var before task starts."""
|
||||
tenant_id = (
|
||||
kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
if kwargs
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
|
||||
@task_postrun.connect
|
||||
def reset_tenant_id(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple[Any, ...] | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
**other_kwargs: Any,
|
||||
) -> None:
|
||||
"""Signal handler to reset tenant ID in context var after task ends."""
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
@@ -13,7 +13,6 @@ from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
@@ -44,18 +43,18 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
self._last_reload is None
|
||||
or (now - self._last_reload) > self._reload_interval
|
||||
):
|
||||
logger.info("Reload interval reached, initiating tenant task update")
|
||||
logger.info("Reload interval reached, initiating task update")
|
||||
self._update_tenant_tasks()
|
||||
self._last_reload = now
|
||||
logger.info("Tenant task update completed, reset reload timer")
|
||||
logger.info("Task update completed, reset reload timer")
|
||||
return retval
|
||||
|
||||
def _update_tenant_tasks(self) -> None:
|
||||
logger.info("Starting tenant task update process")
|
||||
logger.info("Starting task update process")
|
||||
try:
|
||||
logger.info("Fetching all tenant IDs")
|
||||
logger.info("Fetching all IDs")
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
logger.info(f"Found {len(tenant_ids)} tenants")
|
||||
logger.info(f"Found {len(tenant_ids)} IDs")
|
||||
|
||||
logger.info("Fetching tasks to schedule")
|
||||
tasks_to_schedule = fetch_versioned_implementation(
|
||||
@@ -70,7 +69,7 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
for task_name, _ in current_schedule:
|
||||
if "-" in task_name:
|
||||
existing_tenants.add(task_name.split("-")[-1])
|
||||
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
|
||||
logger.info(f"Found {len(existing_tenants)} existing items in schedule")
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if (
|
||||
@@ -83,7 +82,7 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
continue
|
||||
|
||||
if tenant_id not in existing_tenants:
|
||||
logger.info(f"Processing new tenant: {tenant_id}")
|
||||
logger.info(f"Processing new item: {tenant_id}")
|
||||
|
||||
for task in tasks_to_schedule():
|
||||
task_name = f"{task['name']}-{tenant_id}"
|
||||
@@ -129,11 +128,10 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
logger.info("Schedule update completed successfully")
|
||||
else:
|
||||
logger.info("Schedule is up to date, no changes needed")
|
||||
|
||||
except (AttributeError, KeyError):
|
||||
logger.exception("Failed to process task configuration")
|
||||
except Exception:
|
||||
logger.exception("Unexpected error updating tenant tasks")
|
||||
except (AttributeError, KeyError) as e:
|
||||
logger.exception(f"Failed to process task configuration: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error updating tasks: {str(e)}")
|
||||
|
||||
def _should_update_schedule(
|
||||
self, current_schedule: dict, new_schedule: dict
|
||||
@@ -155,10 +153,6 @@ def on_beat_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=2, max_overflow=0)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -61,13 +61,14 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=4, max_overflow=12)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -60,15 +60,21 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
# rkuo: been seeing transient connection exceptions here, so upping the connection count
|
||||
# from just concurrency/concurrency to concurrency/concurrency*2
|
||||
SqlEngine.init_engine(
|
||||
pool_size=sender.concurrency, max_overflow=sender.concurrency * 2
|
||||
)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -60,13 +60,15 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -84,14 +85,14 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
@@ -194,6 +195,10 @@ def on_setup_logging(
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
# this can be spammy, so just enable it in the cloud for now
|
||||
if MULTI_TENANT:
|
||||
app_base.set_task_finished_log_level(logging.INFO)
|
||||
|
||||
|
||||
class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
"""Regularly reacquires the primary worker lock outside of the task queue.
|
||||
|
||||
@@ -1,12 +1,56 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.configs.base import CELERY_SEPARATOR
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
|
||||
|
||||
def celery_get_unacked_length(r: Redis) -> int:
|
||||
"""Checking the unacked queue is useful because a non-zero length tells us there
|
||||
may be prefetched tasks.
|
||||
|
||||
There can be other tasks in here besides indexing tasks, so this is mostly useful
|
||||
just to see if the task count is non zero.
|
||||
|
||||
ref: https://blog.hikaru.run/2022/08/29/get-waiting-tasks-count-in-celery.html
|
||||
"""
|
||||
length = cast(int, r.hlen("unacked"))
|
||||
return length
|
||||
|
||||
|
||||
def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]:
|
||||
"""Gets the set of task id's matching the given queue in the unacked hash.
|
||||
|
||||
Unacked entries belonging to the indexing queue are "prefetched", so this gives
|
||||
us crucial visibility as to what tasks are in that state.
|
||||
"""
|
||||
tasks: set[str] = set()
|
||||
|
||||
for _, v in r.hscan_iter("unacked"):
|
||||
v_bytes = cast(bytes, v)
|
||||
v_str = v_bytes.decode("utf-8")
|
||||
task = json.loads(v_str)
|
||||
|
||||
task_description = task[0]
|
||||
task_queue = task[2]
|
||||
|
||||
if task_queue != queue:
|
||||
continue
|
||||
|
||||
task_id = task_description.get("headers", {}).get("id")
|
||||
if not task_id:
|
||||
continue
|
||||
|
||||
# if the queue matches and we see the task_id, add it
|
||||
tasks.add(task_id)
|
||||
return tasks
|
||||
|
||||
|
||||
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
"""This is a redis specific way to get the length of a celery queue.
|
||||
It is priority aware and knows how to count across the multiple redis lists
|
||||
@@ -23,3 +67,96 @@ def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
total_length += cast(int, length)
|
||||
|
||||
return total_length
|
||||
|
||||
|
||||
def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
|
||||
"""This is a redis specific way to find a task for a particular queue in redis.
|
||||
It is priority aware and knows how to look through the multiple redis lists
|
||||
used to implement task prioritization.
|
||||
This operation is not atomic.
|
||||
|
||||
This is a linear search O(n) ... so be careful using it when the task queues can be larger.
|
||||
|
||||
Returns true if the id is in the queue, False if not.
|
||||
"""
|
||||
for priority in range(len(OnyxCeleryPriority)):
|
||||
queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue
|
||||
|
||||
tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
|
||||
for task in tasks:
|
||||
task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
|
||||
if task_dict.get("headers", {}).get("id") == task_id:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
|
||||
"""Returns a list of current workers containing name_filter, or all workers if
|
||||
name_filter is None.
|
||||
|
||||
We've empirically discovered that the celery inspect API is potentially unstable
|
||||
and may hang or return empty results when celery is under load. Suggest using this
|
||||
more to debug and troubleshoot than in production code.
|
||||
"""
|
||||
worker_names: list[str] = []
|
||||
|
||||
# filter for and create an indexing specific inspect object
|
||||
inspect = app.control.inspect()
|
||||
workers: dict[str, Any] = inspect.ping() # type: ignore
|
||||
if workers:
|
||||
for worker_name in list(workers.keys()):
|
||||
# if the name filter not set, return all worker names
|
||||
if not name_filter:
|
||||
worker_names.append(worker_name)
|
||||
continue
|
||||
|
||||
# if the name filter is set, return only worker names that contain the name filter
|
||||
if name_filter not in worker_name:
|
||||
continue
|
||||
|
||||
worker_names.append(worker_name)
|
||||
|
||||
return worker_names
|
||||
|
||||
|
||||
def celery_inspect_get_reserved(worker_names: list[str], app: Celery) -> set[str]:
|
||||
"""Returns a list of reserved tasks on the specified workers.
|
||||
|
||||
We've empirically discovered that the celery inspect API is potentially unstable
|
||||
and may hang or return empty results when celery is under load. Suggest using this
|
||||
more to debug and troubleshoot than in production code.
|
||||
"""
|
||||
reserved_task_ids: set[str] = set()
|
||||
|
||||
inspect = app.control.inspect(destination=worker_names)
|
||||
|
||||
# get the list of reserved tasks
|
||||
reserved_tasks: dict[str, list] | None = inspect.reserved() # type: ignore
|
||||
if reserved_tasks:
|
||||
for _, task_list in reserved_tasks.items():
|
||||
for task in task_list:
|
||||
reserved_task_ids.add(task["id"])
|
||||
|
||||
return reserved_task_ids
|
||||
|
||||
|
||||
def celery_inspect_get_active(worker_names: list[str], app: Celery) -> set[str]:
|
||||
"""Returns a list of active tasks on the specified workers.
|
||||
|
||||
We've empirically discovered that the celery inspect API is potentially unstable
|
||||
and may hang or return empty results when celery is under load. Suggest using this
|
||||
more to debug and troubleshoot than in production code.
|
||||
"""
|
||||
active_task_ids: set[str] = set()
|
||||
|
||||
inspect = app.control.inspect(destination=worker_names)
|
||||
|
||||
# get the list of reserved tasks
|
||||
active_tasks: dict[str, list] | None = inspect.active() # type: ignore
|
||||
if active_tasks:
|
||||
for _, task_list in active_tasks.items():
|
||||
for task in task_list:
|
||||
active_task_ids.add(task["id"])
|
||||
|
||||
return active_task_ids
|
||||
|
||||
@@ -16,6 +16,14 @@ result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
# Indexing worker specific ... this lets us track the transition to STARTED in redis
|
||||
# We don't currently rely on this but it has the potential to be useful and
|
||||
# indexing tasks are not high volume
|
||||
|
||||
# we don't turn this on yet because celery occasionally runs tasks more than once
|
||||
# which means a duplicate run might change the task state unexpectedly
|
||||
# task_track_started = True
|
||||
|
||||
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
|
||||
@@ -4,55 +4,86 @@ from typing import Any
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
|
||||
# choosing 15 minutes because it roughly gives us enough time to process many tasks
|
||||
# we might be able to reduce this greatly if we can run a unified
|
||||
# loop across all tenants rather than tasks per tenant
|
||||
|
||||
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
|
||||
|
||||
# we set expires because it isn't necessary to queue up these tasks
|
||||
# it's only important that they run relatively regularly
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": OnyxCeleryPriority.LOWEST},
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOWEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@@ -34,7 +34,9 @@ class TaskDependencyError(RuntimeError):
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_connector_deletion_task(
|
||||
self: Task, *, tenant_id: str | None
|
||||
) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -45,7 +47,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
return None
|
||||
|
||||
# collect cc_pair_ids
|
||||
cc_pair_ids: list[int] = []
|
||||
@@ -76,11 +78,13 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
task_logger.exception("Unexpected exception during connector deletion check")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_generate_document_cc_pair_cleanup_tasks(
|
||||
app: Celery,
|
||||
@@ -131,14 +135,14 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
if redis_connector_index.fenced:
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (indexing in progress): "
|
||||
"Connector deletion - Delayed (indexing in progress): "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
if redis_connector.prune.fenced:
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (pruning in progress): "
|
||||
"Connector deletion - Delayed (pruning in progress): "
|
||||
f"cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
@@ -175,7 +179,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks finished. "
|
||||
"RedisConnectorDeletion.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from time import sleep
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
@@ -18,6 +20,7 @@ from onyx.access.models import DocExternalAccess
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -88,10 +91,10 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -99,7 +102,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
return None
|
||||
|
||||
# get all cc pairs that need to be synced
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
@@ -128,6 +131,8 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_permissions_sync_task(
|
||||
app: Celery,
|
||||
@@ -219,6 +224,43 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
# before the primary worker can finalize the fence
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
|
||||
raise ValueError(
|
||||
f"connector_permission_sync_generator_task - timed out waiting for fence to be ready: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
|
||||
if not redis_connector.permissions.fenced: # The fence must exist
|
||||
raise ValueError(
|
||||
f"connector_permission_sync_generator_task - fence not found: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
|
||||
payload = redis_connector.permissions.payload # The payload must exist
|
||||
if not payload:
|
||||
raise ValueError(
|
||||
"connector_permission_sync_generator_task: payload invalid or not found"
|
||||
)
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
logger.info(
|
||||
f"connector_permission_sync_generator_task - Waiting for fence: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
sleep(1)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"connector_permission_sync_generator_task - Fence found, continuing...: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
break
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
@@ -254,8 +296,11 @@ def connector_permission_sync_generator_task(
|
||||
if not payload:
|
||||
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
|
||||
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
new_payload = RedisConnectorPermissionSyncPayload(
|
||||
started=datetime.now(timezone.utc),
|
||||
celery_task_id=payload.celery_task_id,
|
||||
)
|
||||
redis_connector.permissions.set_fence(new_payload)
|
||||
|
||||
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
|
||||
|
||||
|
||||
@@ -94,10 +94,10 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -105,7 +105,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
return None
|
||||
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -149,6 +149,8 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_external_group_sync_task(
|
||||
app: Celery,
|
||||
@@ -162,7 +164,7 @@ def try_creating_external_group_sync_task(
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
lock = r.lock(
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
from typing import cast
|
||||
|
||||
import redis
|
||||
import sentry_sdk
|
||||
@@ -15,10 +19,13 @@ from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.indexing.job_client import SimpleJobClient
|
||||
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -26,6 +33,7 @@ from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
@@ -66,14 +74,18 @@ logger = setup_logger()
|
||||
|
||||
|
||||
class IndexingCallback(IndexingHeartbeatInterface):
|
||||
PARENT_CHECK_INTERVAL = 60
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent_pid: int,
|
||||
stop_key: str,
|
||||
generator_progress_key: str,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
):
|
||||
super().__init__()
|
||||
self.parent_pid = parent_pid
|
||||
self.redis_lock: RedisLock = redis_lock
|
||||
self.stop_key: str = stop_key
|
||||
self.generator_progress_key: str = generator_progress_key
|
||||
@@ -84,25 +96,68 @@ class IndexingCallback(IndexingHeartbeatInterface):
|
||||
self.last_tag: str = "IndexingCallback.__init__"
|
||||
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
|
||||
|
||||
self.last_parent_check = time.monotonic()
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
if self.redis_client.exists(self.stop_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
# rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
|
||||
# with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
|
||||
# so leave this code in until we're ready to test it.
|
||||
|
||||
# if self.parent_pid:
|
||||
# # check if the parent pid is alive so we aren't running as a zombie
|
||||
# now = time.monotonic()
|
||||
# if now - self.last_parent_check > IndexingCallback.PARENT_CHECK_INTERVAL:
|
||||
# try:
|
||||
# # this is unintuitive, but it checks if the parent pid is still running
|
||||
# os.kill(self.parent_pid, 0)
|
||||
# except Exception:
|
||||
# logger.exception("IndexingCallback - parent pid check exceptioned")
|
||||
# raise
|
||||
# self.last_parent_check = now
|
||||
|
||||
try:
|
||||
self.redis_lock.reacquire()
|
||||
self.last_tag = tag
|
||||
self.last_lock_reacquire = datetime.now(timezone.utc)
|
||||
except LockError:
|
||||
logger.exception(
|
||||
f"IndexingCallback - lock.reacquire exceptioned. "
|
||||
f"IndexingCallback - lock.reacquire exceptioned: "
|
||||
f"lock_timeout={self.redis_lock.timeout} "
|
||||
f"start={self.started} "
|
||||
f"last_tag={self.last_tag} "
|
||||
f"last_reacquired={self.last_lock_reacquire} "
|
||||
f"now={datetime.now(timezone.utc)}"
|
||||
)
|
||||
|
||||
# diagnostic logging for lock errors
|
||||
name = self.redis_lock.name
|
||||
ttl = self.redis_client.ttl(name)
|
||||
locked = self.redis_lock.locked()
|
||||
owned = self.redis_lock.owned()
|
||||
local_token: str | None = self.redis_lock.local.token # type: ignore
|
||||
|
||||
remote_token_raw = self.redis_client.get(self.redis_lock.name)
|
||||
if remote_token_raw:
|
||||
remote_token_bytes = cast(bytes, remote_token_raw)
|
||||
remote_token = remote_token_bytes.decode("utf-8")
|
||||
else:
|
||||
remote_token = None
|
||||
|
||||
logger.warning(
|
||||
f"IndexingCallback - lock diagnostics: "
|
||||
f"name={name} "
|
||||
f"locked={locked} "
|
||||
f"owned={owned} "
|
||||
f"local_token={local_token} "
|
||||
f"remote_token={remote_token} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
raise
|
||||
|
||||
self.redis_client.incrby(self.generator_progress_key, amount)
|
||||
@@ -162,11 +217,19 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
|
||||
bind=True,
|
||||
)
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
"""a lightweight task used to kick off indexing tasks.
|
||||
Occcasionally does some validation of existing state to clear up error conditions"""
|
||||
time_start = time.monotonic()
|
||||
|
||||
tasks_created = 0
|
||||
locked = False
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -271,7 +334,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
search_settings_instance,
|
||||
reindex,
|
||||
db_session,
|
||||
r,
|
||||
redis_client,
|
||||
tenant_id,
|
||||
)
|
||||
if attempt_id:
|
||||
@@ -286,7 +349,9 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
# Fail any index attempts in the DB that don't have fences
|
||||
# This shouldn't ever happen!
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
|
||||
unfenced_attempt_ids = get_unfenced_index_attempt_ids(
|
||||
db_session, redis_client
|
||||
)
|
||||
for attempt_id in unfenced_attempt_ids:
|
||||
lock_beat.reacquire()
|
||||
|
||||
@@ -304,12 +369,27 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
mark_attempt_failed(
|
||||
attempt.id, db_session, failure_reason=failure_reason
|
||||
)
|
||||
|
||||
# we want to run this less frequently than the overall task
|
||||
if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
|
||||
# clear any indexing fences that don't have associated celery tasks in progress
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
validate_indexing_fences(
|
||||
tenant_id, self.app, redis_client, redis_client_celery, lock_beat
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while validating indexing fences")
|
||||
|
||||
redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
task_logger.exception("Unexpected exception during indexing check")
|
||||
finally:
|
||||
if locked:
|
||||
if lock_beat.owned():
|
||||
@@ -320,9 +400,157 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.debug(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
|
||||
return tasks_created
|
||||
|
||||
|
||||
def validate_indexing_fences(
|
||||
tenant_id: str | None,
|
||||
celery_app: Celery,
|
||||
r: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
reserved_indexing_tasks = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
# validate all existing indexing jobs
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
validate_indexing_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
reserved_indexing_tasks,
|
||||
r_celery,
|
||||
db_session,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def validate_indexing_fence(
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
r_celery: Redis,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
|
||||
This can happen if the indexing worker hard crashes or is terminated.
|
||||
Being in this bad state means the fence will never clear without help, so this function
|
||||
gives the help.
|
||||
|
||||
How this works:
|
||||
1. This function renews the active signal with a 5 minute TTL under the following conditions
|
||||
1.2. When the task is seen in the redis queue
|
||||
1.3. When the task is seen in the reserved / prefetched list
|
||||
|
||||
2. Externally, the active signal is renewed when:
|
||||
2.1. The fence is created
|
||||
2.2. The indexing watchdog checks the spawned task.
|
||||
|
||||
3. The TTL allows us to get through the transitions on fence startup
|
||||
and when the task starts executing.
|
||||
|
||||
More TTL clarification: it is seemingly impossible to exactly query Celery for
|
||||
whether a task is in the queue or currently executing.
|
||||
1. An unknown task id is always returned as state PENDING.
|
||||
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
|
||||
and the time it actually starts on the worker.
|
||||
"""
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"validate_indexing_fence - could not parse composite_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
# parse out metadata and initialize the helper class with it
|
||||
parts = composite_id.split("/")
|
||||
if len(parts) != 2:
|
||||
return
|
||||
|
||||
cc_pair_id = int(parts[0])
|
||||
search_settings_id = int(parts[1])
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
# check to see if the fence/payload exists
|
||||
if not redis_connector_index.fenced:
|
||||
return
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
return
|
||||
|
||||
# OK, there's actually something for us to validate
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
# the fence is just barely set up.
|
||||
if redis_connector_index.active():
|
||||
return
|
||||
|
||||
# it would be odd to get here as there isn't that much that can go wrong during
|
||||
# initial fence setup, but it's still worth making sure we can recover
|
||||
logger.info(
|
||||
f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
|
||||
)
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
found = celery_find_task(
|
||||
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
if found:
|
||||
# the celery task exists in the redis queue
|
||||
redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
if payload.celery_task_id in reserved_tasks:
|
||||
# the celery task was prefetched and is reserved within the indexing worker
|
||||
redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
# we may want to enable this check if using the active task list somehow isn't good enough
|
||||
# if redis_connector_index.generator_locked():
|
||||
# logger.info(f"{payload.celery_task_id} is currently executing.")
|
||||
|
||||
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
|
||||
# but they still might be there due to gaps in our ability to check states during transitions
|
||||
# Checking the active signal safeguards us against these transition periods
|
||||
# (which has a duration that allows us to bridge those gaps)
|
||||
if redis_connector_index.active():
|
||||
return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
logger.warning(
|
||||
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
|
||||
f"index_attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
if payload.index_attempt_id:
|
||||
try:
|
||||
mark_attempt_failed(
|
||||
payload.index_attempt_id,
|
||||
db_session,
|
||||
"validate_indexing_fence - Canceling index attempt due to missing celery tasks",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"validate_indexing_fence - Exception while marking index attempt as failed."
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
|
||||
def _should_index(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
last_index: IndexAttempt | None,
|
||||
@@ -469,6 +697,7 @@ def try_creating_indexing_task(
|
||||
celery_task_id=None,
|
||||
)
|
||||
|
||||
redis_connector_index.set_active()
|
||||
redis_connector_index.set_fence(payload)
|
||||
|
||||
# create the index attempt for tracking purposes
|
||||
@@ -502,13 +731,14 @@ def try_creating_indexing_task(
|
||||
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
|
||||
|
||||
# now fill out the fence with the rest of the data
|
||||
redis_connector_index.set_active()
|
||||
|
||||
payload.index_attempt_id = index_attempt_id
|
||||
payload.celery_task_id = result.id
|
||||
redis_connector_index.set_fence(payload)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"try_creating_indexing_task - Unexpected exception: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
@@ -540,7 +770,6 @@ def connector_indexing_proxy_task(
|
||||
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - starting: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
@@ -563,7 +792,6 @@ def connector_indexing_proxy_task(
|
||||
if not job:
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
@@ -571,7 +799,6 @@ def connector_indexing_proxy_task(
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
@@ -582,11 +809,62 @@ def connector_indexing_proxy_task(
|
||||
while True:
|
||||
sleep(5)
|
||||
|
||||
# renew active signal
|
||||
redis_connector_index.set_active()
|
||||
|
||||
# if the job is done, clean up and break
|
||||
if job.done():
|
||||
try:
|
||||
if job.status == "error":
|
||||
ignore_exitcode = False
|
||||
|
||||
exit_code: int | None = None
|
||||
if job.process:
|
||||
exit_code = job.process.exitcode
|
||||
|
||||
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
|
||||
# even though logging clearly indicates successful completion
|
||||
# to work around this, we ignore the job error state if the completion signal is OK
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
ignore_exitcode = True
|
||||
|
||||
if not ignore_exitcode:
|
||||
raise RuntimeError("Spawned task exceptioned.")
|
||||
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - spawned task has non-zero exit code "
|
||||
"but completion signal is OK. Continuing...: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code}"
|
||||
)
|
||||
except Exception:
|
||||
task_logger.error(
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code} "
|
||||
f"error={job.exception()}"
|
||||
)
|
||||
|
||||
raise
|
||||
finally:
|
||||
job.release()
|
||||
|
||||
break
|
||||
|
||||
# if a termination signal is detected, clean up and break
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - termination signal detected: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
@@ -609,79 +887,36 @@ def connector_indexing_proxy_task(
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
job.cancel()
|
||||
|
||||
job.cancel()
|
||||
break
|
||||
|
||||
if not job.done():
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
ignore_exitcode = False
|
||||
|
||||
exit_code: int | None = None
|
||||
if job.process:
|
||||
exit_code = job.process.exitcode
|
||||
|
||||
# seeing non-deterministic behavior where spawned tasks occasionally return exit code 1
|
||||
# even though logging clearly indicates that they completed successfully
|
||||
# to work around this, we ignore the job error state if the completion signal is OK
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
ignore_exitcode = True
|
||||
|
||||
if ignore_exitcode:
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - spawned task has non-zero exit code "
|
||||
"but completion signal is OK. Continuing...: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code}"
|
||||
)
|
||||
else:
|
||||
task_logger.error(
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code} "
|
||||
f"error={job.exception()}"
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
job.release()
|
||||
break
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - finished: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
@@ -707,7 +942,7 @@ def connector_indexing_task_wrapper(
|
||||
tenant_id,
|
||||
is_ee,
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"connector_indexing_task exceptioned: "
|
||||
f"tenant={tenant_id} "
|
||||
@@ -715,13 +950,20 @@ def connector_indexing_task_wrapper(
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
raise
|
||||
|
||||
# There is a cloud related bug outside of our code
|
||||
# where spawned tasks return with an exit code of 1.
|
||||
# Unfortunately, exceptions also return with an exit code of 1,
|
||||
# so just raising an exception isn't informative
|
||||
# Exiting with 255 makes it possible to distinguish between normal exits
|
||||
# and exceptions.
|
||||
sys.exit(255)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def connector_indexing_task(
|
||||
index_attempt_id: int,
|
||||
index_attempt_id: int | None,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
@@ -787,7 +1029,17 @@ def connector_indexing_task(
|
||||
f"fence={redis_connector.stop.fence_key}"
|
||||
)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
# before the primary worker can finalize the fence
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - timed out waiting for fence to be ready: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
|
||||
if not redis_connector_index.fenced: # The fence must exist
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
|
||||
@@ -828,7 +1080,9 @@ def connector_indexing_task(
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
f"Indexing task already running, exiting...: "
|
||||
f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
|
||||
f"index_attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -864,6 +1118,7 @@ def connector_indexing_task(
|
||||
|
||||
# define a callback class
|
||||
callback = IndexingCallback(
|
||||
os.getppid(),
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector_index.generator_progress_key,
|
||||
lock,
|
||||
@@ -877,6 +1132,7 @@ def connector_indexing_task(
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
# This is where the heavy/real work happens
|
||||
run_indexing_entrypoint(
|
||||
index_attempt_id,
|
||||
tenant_id,
|
||||
@@ -896,8 +1152,19 @@ def connector_indexing_task(
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
if attempt_found:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_failed(index_attempt_id, db_session, failure_reason=str(e))
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id, db_session, failure_reason=str(e)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
raise e
|
||||
finally:
|
||||
@@ -906,7 +1173,6 @@ def connector_indexing_task(
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task finished: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
@@ -81,10 +81,10 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -92,7 +92,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
return None
|
||||
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -122,11 +122,13 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
task_logger.exception("Unexpected exception during pruning check")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_prune_generator_task(
|
||||
celery_app: Celery,
|
||||
@@ -283,6 +285,7 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
|
||||
callback = IndexingCallback(
|
||||
0,
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector.prune.generator_progress_key,
|
||||
lock,
|
||||
@@ -308,7 +311,7 @@ def connector_pruning_generator_task(
|
||||
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning set collected: "
|
||||
"Pruning set collected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector_source={cc_pair.connector.source} "
|
||||
f"docs_to_remove={len(doc_ids_to_remove)}"
|
||||
@@ -324,7 +327,7 @@ def connector_pruning_generator_task(
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks finished. "
|
||||
"RedisConnector.prune.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
connector / credential pair from the access list
|
||||
(6) delete all relevant entries from postgres
|
||||
"""
|
||||
task_logger.debug(f"Task start: tenant={tenant_id} doc={document_id}")
|
||||
task_logger.debug(f"Task start: doc={document_id}")
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -129,16 +129,13 @@ def document_by_cc_pair_cleanup_task(
|
||||
db_session.commit()
|
||||
|
||||
task_logger.info(
|
||||
f"tenant={tenant_id} "
|
||||
f"doc={document_id} "
|
||||
f"action={action} "
|
||||
f"refcount={count} "
|
||||
f"chunks={chunks_affected}"
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
f"SoftTimeLimitExceeded exception. tenant={tenant_id} doc={document_id}"
|
||||
)
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
return False
|
||||
except Exception as ex:
|
||||
if isinstance(ex, RetryError):
|
||||
@@ -157,15 +154,12 @@ def document_by_cc_pair_cleanup_task(
|
||||
if e.response.status_code == HTTPStatus.BAD_REQUEST:
|
||||
task_logger.exception(
|
||||
f"Non-retryable HTTPStatusError: "
|
||||
f"tenant={tenant_id} "
|
||||
f"doc={document_id} "
|
||||
f"status={e.response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
task_logger.exception(
|
||||
f"Unexpected exception: tenant={tenant_id} doc={document_id}"
|
||||
)
|
||||
task_logger.exception(f"Unexpected exception: doc={document_id}")
|
||||
|
||||
if self.request.retries < DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES:
|
||||
# Still retrying. Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
@@ -176,7 +170,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
# eventually gets fixed out of band via stale document reconciliation
|
||||
task_logger.warning(
|
||||
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
|
||||
f"tenant={tenant_id} doc={document_id}"
|
||||
f"doc={document_id}"
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# delete the cc pair relationship now and let reconciliation clean it up
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
@@ -19,6 +20,7 @@ from tenacity import RetryError
|
||||
from onyx.access.access import get_access_for_document
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
@@ -86,13 +88,14 @@ logger = setup_logger()
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
time_start = time.monotonic()
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -100,7 +103,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
return None
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_stale_document_sync_tasks(
|
||||
@@ -156,11 +159,15 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
task_logger.exception("Unexpected exception during vespa metadata sync")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.debug(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
|
||||
return True
|
||||
|
||||
|
||||
def try_generate_stale_document_sync_tasks(
|
||||
celery_app: Celery,
|
||||
@@ -630,15 +637,23 @@ def monitor_ccpair_indexing_taskset(
|
||||
if not payload:
|
||||
return
|
||||
|
||||
elapsed_started_str = None
|
||||
if payload.started:
|
||||
elapsed_started = datetime.now(timezone.utc) - payload.started
|
||||
elapsed_started_str = f"{elapsed_started.total_seconds():.2f}"
|
||||
|
||||
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
|
||||
|
||||
progress = redis_connector_index.get_progress()
|
||||
if progress is not None:
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: cc_pair={cc_pair_id} "
|
||||
f"Connector indexing progress: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
@@ -709,11 +724,14 @@ def monitor_ccpair_indexing_taskset(
|
||||
status_enum = HTTPStatus(status_int)
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: cc_pair={cc_pair_id} "
|
||||
f"Connector indexing finished: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"status={status_enum.name} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
@@ -730,6 +748,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
Returns True if the task actually did work, False if it exited early to prevent overlap
|
||||
"""
|
||||
time_start = time.monotonic()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -758,32 +777,43 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
n_permissions_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
|
||||
)
|
||||
n_external_group_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
|
||||
)
|
||||
n_permissions_upsert = celery_get_queue_length(
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
|
||||
)
|
||||
|
||||
prefetched = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Queue lengths: celery={n_celery} "
|
||||
f"indexing={n_indexing} "
|
||||
f"indexing_prefetched={len(prefetched)} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning} "
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
f"external_group_sync={n_external_group_sync} "
|
||||
f"permissions_upsert={n_permissions_upsert} "
|
||||
)
|
||||
|
||||
# scan and monitor activity to completion
|
||||
lock_beat.reacquire()
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
|
||||
@@ -794,28 +824,21 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
# uncomment for debugging if needed
|
||||
# r_celery = celery_app.broker_connection().channel().client
|
||||
# length = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
|
||||
# task_logger.warning(f"queue={OnyxCeleryQueues.VESPA_METADATA_SYNC} length={length}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -824,6 +847,8 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.debug(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
|
||||
return True
|
||||
|
||||
|
||||
@@ -873,13 +898,9 @@ def vespa_metadata_sync_task(
|
||||
# the sync might repeat again later
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
|
||||
task_logger.info(
|
||||
f"tenant={tenant_id} doc={document_id} action=sync chunks={chunks_affected}"
|
||||
)
|
||||
task_logger.info(f"doc={document_id} action=sync chunks={chunks_affected}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
f"SoftTimeLimitExceeded exception. tenant={tenant_id} doc={document_id}"
|
||||
)
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
except Exception as ex:
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.warning(
|
||||
@@ -897,14 +918,13 @@ def vespa_metadata_sync_task(
|
||||
if e.response.status_code == HTTPStatus.BAD_REQUEST:
|
||||
task_logger.exception(
|
||||
f"Non-retryable HTTPStatusError: "
|
||||
f"tenant={tenant_id} "
|
||||
f"doc={document_id} "
|
||||
f"status={e.response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
task_logger.exception(
|
||||
f"Unexpected exception: tenant={tenant_id} doc={document_id}"
|
||||
f"Unexpected exception during vespa metadata sync: doc={document_id}"
|
||||
)
|
||||
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
|
||||
@@ -11,8 +11,10 @@ from onyx.background.indexing.tracer import OnyxTracer
|
||||
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
|
||||
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
|
||||
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
@@ -34,6 +36,7 @@ from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.logger import TaskAttemptSingleton
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -88,6 +91,35 @@ def _get_connector_runner(
|
||||
)
|
||||
|
||||
|
||||
def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
|
||||
cleaned_batch = []
|
||||
for doc in doc_batch:
|
||||
cleaned_doc = doc.model_copy()
|
||||
|
||||
if "\x00" in cleaned_doc.id:
|
||||
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
|
||||
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
|
||||
|
||||
if "\x00" in cleaned_doc.semantic_identifier:
|
||||
logger.warning(
|
||||
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
|
||||
)
|
||||
cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
|
||||
"\x00", ""
|
||||
)
|
||||
|
||||
for section in cleaned_doc.sections:
|
||||
if section.link and "\x00" in section.link:
|
||||
logger.warning(
|
||||
f"NUL characters found in document link for document: {cleaned_doc.id}"
|
||||
)
|
||||
section.link = section.link.replace("\x00", "")
|
||||
|
||||
cleaned_batch.append(cleaned_doc)
|
||||
|
||||
return cleaned_batch
|
||||
|
||||
|
||||
class ConnectorStopSignal(Exception):
|
||||
"""A custom exception used to signal a stop in processing."""
|
||||
|
||||
@@ -236,7 +268,9 @@ def _run_indexing(
|
||||
)
|
||||
|
||||
batch_description = []
|
||||
for doc in doc_batch:
|
||||
|
||||
doc_batch_cleaned = strip_null_characters(doc_batch)
|
||||
for doc in doc_batch_cleaned:
|
||||
batch_description.append(doc.to_short_descriptor())
|
||||
|
||||
doc_size = 0
|
||||
@@ -256,15 +290,15 @@ def _run_indexing(
|
||||
|
||||
# real work happens here!
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
document_batch=doc_batch,
|
||||
document_batch=doc_batch_cleaned,
|
||||
index_attempt_metadata=index_attempt_md,
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
net_doc_change += new_docs
|
||||
chunk_count += total_batch_chunks
|
||||
document_count += len(doc_batch)
|
||||
all_connector_doc_ids.update(doc.id for doc in doc_batch)
|
||||
document_count += len(doc_batch_cleaned)
|
||||
all_connector_doc_ids.update(doc.id for doc in doc_batch_cleaned)
|
||||
|
||||
# commit transaction so that the `update` below begins
|
||||
# with a brand new transaction. Postgres uses the start
|
||||
@@ -274,7 +308,7 @@ def _run_indexing(
|
||||
db_session.commit()
|
||||
|
||||
if callback:
|
||||
callback.progress("_run_indexing", len(doc_batch))
|
||||
callback.progress("_run_indexing", len(doc_batch_cleaned))
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
update_docs_indexed(
|
||||
@@ -396,6 +430,15 @@ def _run_indexing(
|
||||
|
||||
if index_attempt_md.num_exceptions == 0:
|
||||
mark_attempt_succeeded(index_attempt, db_session)
|
||||
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id or "N/A",
|
||||
event_type=MilestoneRecordType.CONNECTOR_SUCCEEDED,
|
||||
properties=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Connector succeeded: "
|
||||
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
|
||||
|
||||
@@ -31,6 +31,8 @@ from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import NO_AUTH_USER_ID
|
||||
from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.enums import QueryFlow
|
||||
from onyx.context.search.enums import SearchType
|
||||
@@ -53,6 +55,9 @@ from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.milestone import check_multi_assistant_milestone
|
||||
from onyx.db.milestone import create_milestone_if_not_exists
|
||||
from onyx.db.milestone import update_user_assistant_milestone
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.models import User
|
||||
@@ -117,6 +122,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
|
||||
from onyx.tools.tool_runner import ToolCallFinalResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.timing import log_generator_function_time
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
@@ -356,6 +362,31 @@ def stream_chat_message_objects(
|
||||
if not persona:
|
||||
raise RuntimeError("No persona specified or found for chat session")
|
||||
|
||||
multi_assistant_milestone, _is_new = create_milestone_if_not_exists(
|
||||
user=user,
|
||||
event_type=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
update_user_assistant_milestone(
|
||||
milestone=multi_assistant_milestone,
|
||||
user_id=str(user.id) if user else NO_AUTH_USER_ID,
|
||||
assistant_id=persona.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
_, just_hit_multi_assistant_milestone = check_multi_assistant_milestone(
|
||||
milestone=multi_assistant_milestone,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
if just_hit_multi_assistant_milestone:
|
||||
mt_cloud_telemetry(
|
||||
distinct_id=tenant_id,
|
||||
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
properties=None,
|
||||
)
|
||||
|
||||
# If a prompt override is specified via the API, use that with highest priority
|
||||
# but for saving it, we are just mapping it to an existing prompt
|
||||
prompt_id = new_msg_req.prompt_id
|
||||
|
||||
@@ -65,7 +65,7 @@ class CitationProcessor:
|
||||
# Handle code blocks without language tags
|
||||
if "`" in self.curr_segment:
|
||||
if self.curr_segment.endswith("`"):
|
||||
return
|
||||
pass
|
||||
elif "```" in self.curr_segment:
|
||||
piece_that_comes_after = self.curr_segment.split("```")[1][0]
|
||||
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import urllib.parse
|
||||
from typing import cast
|
||||
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DocumentIndexType
|
||||
@@ -57,6 +58,9 @@ SESSION_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
||||
# Default request timeout, mostly used by connectors
|
||||
REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS") or 60)
|
||||
|
||||
# set `VALID_EMAIL_DOMAINS` to a comma seperated list of domains in order to
|
||||
# restrict access to Onyx to only users with emails from those domains.
|
||||
# E.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will restrict Onyx
|
||||
@@ -91,6 +95,7 @@ SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
|
||||
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
|
||||
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
|
||||
EMAIL_CONFIGURED = all([SMTP_SERVER, SMTP_USER, SMTP_PASS])
|
||||
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
|
||||
|
||||
# If set, Onyx will listen to the `expires_at` returned by the identity
|
||||
@@ -144,6 +149,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
|
||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
|
||||
|
||||
POSTGRES_API_SERVER_POOL_SIZE = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
|
||||
@@ -174,11 +180,33 @@ try:
|
||||
except ValueError:
|
||||
POSTGRES_IDLE_SESSIONS_TIMEOUT = POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
|
||||
|
||||
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
|
||||
|
||||
|
||||
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
|
||||
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
|
||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
||||
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
||||
|
||||
# Rate limiting for auth endpoints
|
||||
|
||||
|
||||
RATE_LIMIT_WINDOW_SECONDS: int | None = None
|
||||
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
|
||||
if _rate_limit_window_seconds_str is not None:
|
||||
try:
|
||||
RATE_LIMIT_WINDOW_SECONDS = int(_rate_limit_window_seconds_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
RATE_LIMIT_MAX_REQUESTS: int | None = None
|
||||
_rate_limit_max_requests_str = os.environ.get("RATE_LIMIT_MAX_REQUESTS")
|
||||
if _rate_limit_max_requests_str is not None:
|
||||
try:
|
||||
RATE_LIMIT_MAX_REQUESTS = int(_rate_limit_max_requests_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Used for general redis things
|
||||
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
|
||||
|
||||
@@ -342,12 +370,17 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
|
||||
os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Typically set to http://localhost:3000 for OAuth connector development
|
||||
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")
|
||||
|
||||
# Egnyte specific configs
|
||||
EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
|
||||
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
|
||||
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
|
||||
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
|
||||
|
||||
# Linear specific configs
|
||||
LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
|
||||
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
|
||||
|
||||
DASK_JOB_CLIENT_ENABLED = (
|
||||
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
@@ -483,6 +516,21 @@ SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")
|
||||
|
||||
PARSE_WITH_TRAFILATURA = os.environ.get("PARSE_WITH_TRAFILATURA", "").lower() == "true"
|
||||
|
||||
# allow for custom error messages for different errors returned by litellm
|
||||
# for example, can specify: {"Violated content safety policy": "EVIL REQUEST!!!"}
|
||||
# to make it so that if an LLM call returns an error containing "Violated content safety policy"
|
||||
# the end user will see "EVIL REQUEST!!!" instead of the default error message.
|
||||
_LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS = os.environ.get(
|
||||
"LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS", ""
|
||||
)
|
||||
LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS: dict[str, str] | None = None
|
||||
try:
|
||||
LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS = cast(
|
||||
dict[str, str], json.loads(_LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS)
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
#####
|
||||
# Enterprise Edition Configs
|
||||
#####
|
||||
|
||||
@@ -63,6 +63,10 @@ LANGUAGE_CHAT_NAMING_HINT = (
|
||||
or "The name of the conversation must be in the same language as the user query."
|
||||
)
|
||||
|
||||
# Number of prompts each persona should have
|
||||
NUM_PERSONA_PROMPTS = 4
|
||||
NUM_PERSONA_PROMPT_GENERATION_CHUNKS = 5
|
||||
|
||||
# Agentic search takes significantly more tokens and therefore has much higher cost.
|
||||
# This configuration allows users to get a search-only experience with instant results
|
||||
# and no involvement from the LLM.
|
||||
|
||||
@@ -15,6 +15,9 @@ ID_SEPARATOR = ":;:"
|
||||
DEFAULT_BOOST = 0
|
||||
SESSION_KEY = "session"
|
||||
|
||||
NO_AUTH_USER_ID = "__no_auth_user__"
|
||||
NO_AUTH_USER_EMAIL = "anonymous@onyx.app"
|
||||
|
||||
# For chunking/processing chunks
|
||||
RETURN_SEPARATOR = "\n\r\n"
|
||||
SECTION_SEPARATOR = "\n\n"
|
||||
@@ -33,6 +36,8 @@ DISABLED_GEN_AI_MSG = (
|
||||
|
||||
DEFAULT_PERSONA_ID = 0
|
||||
|
||||
DEFAULT_CC_PAIR_ID = 1
|
||||
|
||||
# Postgres connection constants for application_name
|
||||
POSTGRES_WEB_APP_NAME = "web"
|
||||
POSTGRES_INDEXER_APP_NAME = "indexer"
|
||||
@@ -46,6 +51,7 @@ POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
||||
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
||||
|
||||
SSL_CERT_FILE = "bundle.pem"
|
||||
# API Keys
|
||||
DANSWER_API_KEY_PREFIX = "API_KEY__"
|
||||
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "onyxapikey.ai"
|
||||
@@ -77,6 +83,9 @@ CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 60 min
|
||||
|
||||
# how long a task should wait for associated fence to be ready
|
||||
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
|
||||
|
||||
# needs to be long enough to cover the maximum time it takes to download an object
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
|
||||
@@ -133,6 +142,7 @@ class DocumentSource(str, Enum):
|
||||
FRESHDESK = "freshdesk"
|
||||
FIREFLIES = "fireflies"
|
||||
EGNYTE = "egnyte"
|
||||
AIRTABLE = "airtable"
|
||||
|
||||
|
||||
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
|
||||
@@ -170,6 +180,10 @@ class AuthType(str, Enum):
|
||||
CLOUD = "cloud"
|
||||
|
||||
|
||||
# Special characters for password validation
|
||||
PASSWORD_SPECIAL_CHARS = "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
|
||||
|
||||
class SessionType(str, Enum):
|
||||
CHAT = "Chat"
|
||||
SEARCH = "Search"
|
||||
@@ -207,9 +221,23 @@ class FileOrigin(str, Enum):
|
||||
CHAT_IMAGE_GEN = "chat_image_gen"
|
||||
CONNECTOR = "connector"
|
||||
GENERATED_REPORT = "generated_report"
|
||||
MY_DOCUMENTS = "my_documents"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class MilestoneRecordType(str, Enum):
|
||||
TENANT_CREATED = "tenant_created"
|
||||
USER_SIGNED_UP = "user_signed_up"
|
||||
MULTIPLE_USERS = "multiple_users"
|
||||
VISITED_ADMIN_PAGE = "visited_admin_page"
|
||||
CREATED_CONNECTOR = "created_connector"
|
||||
CONNECTOR_SUCCEEDED = "connector_succeeded"
|
||||
RAN_QUERY = "ran_query"
|
||||
MULTIPLE_ASSISTANTS = "multiple_assistants"
|
||||
CREATED_ASSISTANT = "created_assistant"
|
||||
CREATED_ONYX_BOT = "created_onyx_bot"
|
||||
|
||||
|
||||
class PostgresAdvisoryLocks(Enum):
|
||||
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()
|
||||
|
||||
@@ -252,6 +280,11 @@ class OnyxRedisLocks:
|
||||
|
||||
SLACK_BOT_LOCK = "da_lock:slack_bot"
|
||||
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
|
||||
ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
|
||||
|
||||
|
||||
class OnyxCeleryPriority(int, Enum):
|
||||
|
||||
268
backend/onyx/connectors/airtable/airtable_connector.py
Normal file
268
backend/onyx/connectors/airtable/airtable_connector.py
Normal file
@@ -0,0 +1,268 @@
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from pyairtable import Api as AirtableApi
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# NOTE: all are made lowercase to avoid case sensitivity issues
|
||||
# these are the field types that are considered metadata rather
|
||||
# than sections
|
||||
_METADATA_FIELD_TYPES = {
|
||||
"singlecollaborator",
|
||||
"collaborator",
|
||||
"createdby",
|
||||
"singleselect",
|
||||
"multipleselects",
|
||||
"checkbox",
|
||||
"date",
|
||||
"datetime",
|
||||
"email",
|
||||
"phone",
|
||||
"url",
|
||||
"number",
|
||||
"currency",
|
||||
"duration",
|
||||
"percent",
|
||||
"rating",
|
||||
"createdtime",
|
||||
"lastmodifiedtime",
|
||||
"autonumber",
|
||||
"rollup",
|
||||
"lookup",
|
||||
"count",
|
||||
"formula",
|
||||
"date",
|
||||
}
|
||||
|
||||
|
||||
class AirtableClientNotSetUpError(PermissionError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Airtable Client is not set up, was load_credentials called?")
|
||||
|
||||
|
||||
class AirtableConnector(LoadConnector):
|
||||
def __init__(
|
||||
self,
|
||||
base_id: str,
|
||||
table_name_or_id: str,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.base_id = base_id
|
||||
self.table_name_or_id = table_name_or_id
|
||||
self.batch_size = batch_size
|
||||
self.airtable_client: AirtableApi | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.airtable_client = AirtableApi(credentials["airtable_access_token"])
|
||||
return None
|
||||
|
||||
def _get_field_value(self, field_info: Any, field_type: str) -> list[str]:
|
||||
"""
|
||||
Extract value(s) from a field regardless of its type.
|
||||
Returns either a single string or list of strings for attachments.
|
||||
"""
|
||||
if field_info is None:
|
||||
return []
|
||||
|
||||
# skip references to other records for now (would need to do another
|
||||
# request to get the actual record name/type)
|
||||
# TODO: support this
|
||||
if field_type == "multipleRecordLinks":
|
||||
return []
|
||||
|
||||
if field_type == "multipleAttachments":
|
||||
attachment_texts: list[str] = []
|
||||
for attachment in field_info:
|
||||
url = attachment.get("url")
|
||||
filename = attachment.get("filename", "")
|
||||
if not url:
|
||||
continue
|
||||
|
||||
@retry(
|
||||
tries=5,
|
||||
delay=1,
|
||||
backoff=2,
|
||||
max_delay=10,
|
||||
)
|
||||
def get_attachment_with_retry(url: str) -> bytes | None:
|
||||
attachment_response = requests.get(url)
|
||||
if attachment_response.status_code == 200:
|
||||
return attachment_response.content
|
||||
return None
|
||||
|
||||
attachment_content = get_attachment_with_retry(url)
|
||||
if attachment_content:
|
||||
try:
|
||||
file_ext = get_file_ext(filename)
|
||||
attachment_text = extract_file_text(
|
||||
BytesIO(attachment_content),
|
||||
filename,
|
||||
break_on_unprocessable=False,
|
||||
extension=file_ext,
|
||||
)
|
||||
if attachment_text:
|
||||
attachment_texts.append(f"{filename}:\n{attachment_text}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to process attachment {filename}: {str(e)}"
|
||||
)
|
||||
return attachment_texts
|
||||
|
||||
if field_type in ["singleCollaborator", "collaborator", "createdBy"]:
|
||||
combined = []
|
||||
collab_name = field_info.get("name")
|
||||
collab_email = field_info.get("email")
|
||||
if collab_name:
|
||||
combined.append(collab_name)
|
||||
if collab_email:
|
||||
combined.append(f"({collab_email})")
|
||||
return [" ".join(combined) if combined else str(field_info)]
|
||||
|
||||
if isinstance(field_info, list):
|
||||
return [str(item) for item in field_info]
|
||||
|
||||
return [str(field_info)]
|
||||
|
||||
def _should_be_metadata(self, field_type: str) -> bool:
|
||||
"""Determine if a field type should be treated as metadata."""
|
||||
return field_type.lower() in _METADATA_FIELD_TYPES
|
||||
|
||||
def _process_field(
|
||||
self,
|
||||
field_name: str,
|
||||
field_info: Any,
|
||||
field_type: str,
|
||||
table_id: str,
|
||||
record_id: str,
|
||||
) -> tuple[list[Section], dict[str, Any]]:
|
||||
"""
|
||||
Process a single Airtable field and return sections or metadata.
|
||||
|
||||
Args:
|
||||
field_name: Name of the field
|
||||
field_info: Raw field information from Airtable
|
||||
field_type: Airtable field type
|
||||
|
||||
Returns:
|
||||
(list of Sections, dict of metadata)
|
||||
"""
|
||||
if field_info is None:
|
||||
return [], {}
|
||||
|
||||
# Get the value(s) for the field
|
||||
field_values = self._get_field_value(field_info, field_type)
|
||||
if len(field_values) == 0:
|
||||
return [], {}
|
||||
|
||||
# Determine if it should be metadata or a section
|
||||
if self._should_be_metadata(field_type):
|
||||
if len(field_values) > 1:
|
||||
return [], {field_name: field_values}
|
||||
return [], {field_name: field_values[0]}
|
||||
|
||||
# Otherwise, create relevant sections
|
||||
sections = [
|
||||
Section(
|
||||
link=f"https://airtable.com/{self.base_id}/{table_id}/{record_id}",
|
||||
text=(
|
||||
f"{field_name}:\n"
|
||||
"------------------------\n"
|
||||
f"{text}\n"
|
||||
"------------------------"
|
||||
),
|
||||
)
|
||||
for text in field_values
|
||||
]
|
||||
return sections, {}
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Fetch all records from the table.
|
||||
|
||||
NOTE: Airtable does not support filtering by time updated, so
|
||||
we have to fetch all records every time.
|
||||
"""
|
||||
if not self.airtable_client:
|
||||
raise AirtableClientNotSetUpError()
|
||||
|
||||
table = self.airtable_client.table(self.base_id, self.table_name_or_id)
|
||||
table_id = table.id
|
||||
# due to https://community.airtable.com/t5/development-apis/pagination-returns-422-error/td-p/54778,
|
||||
# we can't user the `iterate()` method - we need to get everything up front
|
||||
# this also means we can't handle tables that won't fit in memory
|
||||
records = table.all()
|
||||
|
||||
table_schema = table.schema()
|
||||
# have to get the name from the schema, since the table object will
|
||||
# give back the ID instead of the name if the ID is used to create
|
||||
# the table object
|
||||
table_name = table_schema.name
|
||||
primary_field_name = None
|
||||
|
||||
# Find a primary field from the schema
|
||||
for field in table_schema.fields:
|
||||
if field.id == table_schema.primary_field_id:
|
||||
primary_field_name = field.name
|
||||
break
|
||||
|
||||
record_documents: list[Document] = []
|
||||
for record in records:
|
||||
record_id = record["id"]
|
||||
fields = record["fields"]
|
||||
sections: list[Section] = []
|
||||
metadata: dict[str, Any] = {}
|
||||
|
||||
# Possibly retrieve the primary field's value
|
||||
primary_field_value = (
|
||||
fields.get(primary_field_name) if primary_field_name else None
|
||||
)
|
||||
for field_schema in table_schema.fields:
|
||||
field_name = field_schema.name
|
||||
field_val = fields.get(field_name)
|
||||
field_type = field_schema.type
|
||||
|
||||
field_sections, field_metadata = self._process_field(
|
||||
field_name=field_name,
|
||||
field_info=field_val,
|
||||
field_type=field_type,
|
||||
table_id=table_id,
|
||||
record_id=record_id,
|
||||
)
|
||||
|
||||
sections.extend(field_sections)
|
||||
metadata.update(field_metadata)
|
||||
|
||||
semantic_id = (
|
||||
f"{table_name}: {primary_field_value}"
|
||||
if primary_field_value
|
||||
else table_name
|
||||
)
|
||||
|
||||
record_document = Document(
|
||||
id=f"airtable__{record_id}",
|
||||
sections=sections,
|
||||
source=DocumentSource.AIRTABLE,
|
||||
semantic_identifier=semantic_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
record_documents.append(record_document)
|
||||
|
||||
if len(record_documents) >= self.batch_size:
|
||||
yield record_documents
|
||||
record_documents = []
|
||||
|
||||
if record_documents:
|
||||
yield record_documents
|
||||
@@ -56,6 +56,23 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
|
||||
|
||||
_SLIM_DOC_BATCH_SIZE = 5000
|
||||
|
||||
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
|
||||
"png",
|
||||
"jpg",
|
||||
"jpeg",
|
||||
"gif",
|
||||
"mp4",
|
||||
"mov",
|
||||
"mp3",
|
||||
"wav",
|
||||
]
|
||||
_FULL_EXTENSION_FILTER_STRING = "".join(
|
||||
[
|
||||
f" and title!~'*.{extension}'"
|
||||
for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
@@ -64,7 +81,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
is_cloud: bool,
|
||||
space: str = "",
|
||||
page_id: str = "",
|
||||
index_recursively: bool = True,
|
||||
index_recursively: bool = False,
|
||||
cql_query: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
@@ -82,23 +99,25 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# Remove trailing slash from wiki_base if present
|
||||
self.wiki_base = wiki_base.rstrip("/")
|
||||
|
||||
# if nothing is provided, we will fetch all pages
|
||||
cql_page_query = "type=page"
|
||||
"""
|
||||
If nothing is provided, we default to fetching all pages
|
||||
Only one or none of the following options should be specified so
|
||||
the order shouldn't matter
|
||||
However, we use elif to ensure that only of the following is enforced
|
||||
"""
|
||||
base_cql_page_query = "type=page"
|
||||
if cql_query:
|
||||
# if a cql_query is provided, we will use it to fetch the pages
|
||||
cql_page_query = cql_query
|
||||
base_cql_page_query = cql_query
|
||||
elif page_id:
|
||||
# if a cql_query is not provided, we will use the page_id to fetch the page
|
||||
if index_recursively:
|
||||
cql_page_query += f" and ancestor='{page_id}'"
|
||||
base_cql_page_query += f" and (ancestor='{page_id}' or id='{page_id}')"
|
||||
else:
|
||||
cql_page_query += f" and id='{page_id}'"
|
||||
base_cql_page_query += f" and id='{page_id}'"
|
||||
elif space:
|
||||
# if no cql_query or page_id is provided, we will use the space to fetch the pages
|
||||
cql_page_query += f" and space='{quote(space)}'"
|
||||
uri_safe_space = quote(space)
|
||||
base_cql_page_query += f" and space='{uri_safe_space}'"
|
||||
|
||||
self.cql_page_query = cql_page_query
|
||||
self.cql_time_filter = ""
|
||||
self.base_cql_page_query = base_cql_page_query
|
||||
|
||||
self.cql_label_filter = ""
|
||||
if labels_to_skip:
|
||||
@@ -126,6 +145,33 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
return None
|
||||
|
||||
def _construct_page_query(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> str:
|
||||
page_query = self.base_cql_page_query + self.cql_label_filter
|
||||
|
||||
# Add time filters
|
||||
if start:
|
||||
formatted_start_time = datetime.fromtimestamp(
|
||||
start, tz=self.timezone
|
||||
).strftime("%Y-%m-%d %H:%M")
|
||||
page_query += f" and lastmodified >= '{formatted_start_time}'"
|
||||
if end:
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
page_query += f" and lastmodified <= '{formatted_end_time}'"
|
||||
|
||||
return page_query
|
||||
|
||||
def _construct_attachment_query(self, confluence_page_id: str) -> str:
|
||||
attachment_query = f"type=attachment and container='{confluence_page_id}'"
|
||||
attachment_query += self.cql_label_filter
|
||||
attachment_query += _FULL_EXTENSION_FILTER_STRING
|
||||
return attachment_query
|
||||
|
||||
def _get_comment_string_for_page_id(self, page_id: str) -> str:
|
||||
comment_string = ""
|
||||
|
||||
@@ -205,11 +251,15 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
metadata=doc_metadata,
|
||||
)
|
||||
|
||||
def _fetch_document_batches(self) -> GenerateDocumentsOutput:
|
||||
def _fetch_document_batches(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch: list[Document] = []
|
||||
confluence_page_ids: list[str] = []
|
||||
|
||||
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
|
||||
page_query = self._construct_page_query(start, end)
|
||||
logger.debug(f"page_query: {page_query}")
|
||||
# Fetch pages as Documents
|
||||
for page in self.confluence_client.paginated_cql_retrieval(
|
||||
@@ -228,11 +278,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
# Fetch attachments as Documents
|
||||
for confluence_page_id in confluence_page_ids:
|
||||
attachment_cql = f"type=attachment and container='{confluence_page_id}'"
|
||||
attachment_cql += self.cql_label_filter
|
||||
attachment_query = self._construct_attachment_query(confluence_page_id)
|
||||
# TODO: maybe should add time filter as well?
|
||||
for attachment in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=attachment_cql,
|
||||
cql=attachment_query,
|
||||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
doc = self._convert_object_to_document(attachment)
|
||||
@@ -248,17 +297,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_document_batches()
|
||||
|
||||
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
|
||||
# Add time filters
|
||||
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
|
||||
self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
|
||||
return self._fetch_document_batches()
|
||||
def poll_source(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
return self._fetch_document_batches(start, end)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
@@ -269,7 +313,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
|
||||
|
||||
page_query = self.cql_page_query + self.cql_label_filter
|
||||
page_query = self.base_cql_page_query + self.cql_label_filter
|
||||
for page in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=page_query,
|
||||
expand=restrictions_expand,
|
||||
@@ -294,10 +338,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
perm_sync_data=page_perm_sync_data,
|
||||
)
|
||||
)
|
||||
attachment_cql = f"type=attachment and container='{page['id']}'"
|
||||
attachment_cql += self.cql_label_filter
|
||||
attachment_query = self._construct_attachment_query(page["id"])
|
||||
for attachment in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_cql,
|
||||
cql=attachment_query,
|
||||
expand=restrictions_expand,
|
||||
limit=_SLIM_DOC_BATCH_SIZE,
|
||||
):
|
||||
|
||||
@@ -153,7 +153,7 @@ class OnyxConfluence(Confluence):
|
||||
try:
|
||||
response = self.get(url, params=params)
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 403:
|
||||
if e.response is not None and e.response.status_code == 403:
|
||||
raise ApiPermissionError(
|
||||
"The calling user does not have permission", reason=e
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import TypeVar
|
||||
|
||||
from dateutil.parser import parse
|
||||
|
||||
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
|
||||
from onyx.configs.constants import IGNORE_FOR_QA
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.utils.text_processing import is_valid_email
|
||||
@@ -71,3 +72,10 @@ def process_in_batches(
|
||||
|
||||
def get_metadata_keys_to_ignore() -> list[str]:
|
||||
return [IGNORE_FOR_QA]
|
||||
|
||||
|
||||
def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
|
||||
if CONNECTOR_LOCALHOST_OVERRIDE:
|
||||
# Used for development
|
||||
base_domain = CONNECTOR_LOCALHOST_OVERRIDE
|
||||
return f"{base_domain.strip('/')}/connector/oauth/callback/{connector_id}"
|
||||
|
||||
@@ -190,7 +190,7 @@ class DiscourseConnector(PollConnector):
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
) -> GenerateDocumentsOutput:
|
||||
page = 1
|
||||
page = 0
|
||||
while topic_ids := self._get_latest_topics(start, end, page):
|
||||
doc_batch: list[Document] = []
|
||||
for topic_id in topic_ids:
|
||||
|
||||
@@ -3,20 +3,19 @@ import os
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from logging import Logger
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import IO
|
||||
from urllib.parse import quote
|
||||
|
||||
import requests
|
||||
from retry import retry
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN
|
||||
from onyx.configs.app_configs import EGNYTE_CLIENT_ID
|
||||
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_oauth_callback_uri,
|
||||
)
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import OAuthConnector
|
||||
@@ -33,53 +32,13 @@ from onyx.file_processing.extract_file_text import is_text_file_extension
|
||||
from onyx.file_processing.extract_file_text import is_valid_file_ext
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import request_with_retries
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
|
||||
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
|
||||
_TIMEOUT = 60
|
||||
|
||||
|
||||
def _request_with_retries(
|
||||
method: str,
|
||||
url: str,
|
||||
data: dict[str, Any] | None = None,
|
||||
headers: dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
timeout: int = _TIMEOUT,
|
||||
stream: bool = False,
|
||||
tries: int = 8,
|
||||
delay: float = 1,
|
||||
backoff: float = 2,
|
||||
) -> requests.Response:
|
||||
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
|
||||
def _make_request() -> requests.Response:
|
||||
response = requests.request(
|
||||
method,
|
||||
url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=timeout,
|
||||
stream=stream,
|
||||
)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code != 403:
|
||||
logger.exception(
|
||||
f"Failed to call Egnyte API.\n"
|
||||
f"URL: {url}\n"
|
||||
f"Headers: {headers}\n"
|
||||
f"Data: {data}\n"
|
||||
f"Params: {params}"
|
||||
)
|
||||
raise e
|
||||
return response
|
||||
|
||||
return _make_request()
|
||||
|
||||
|
||||
def _parse_last_modified(last_modified: str) -> datetime:
|
||||
@@ -166,6 +125,15 @@ def _process_egnyte_file(
|
||||
|
||||
|
||||
class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
class AdditionalOauthKwargs(OAuthConnector.AdditionalOauthKwargs):
|
||||
egnyte_domain: str = Field(
|
||||
title="Egnyte Domain",
|
||||
description=(
|
||||
"The domain for the Egnyte instance "
|
||||
"(e.g. 'company' for company.egnyte.com)"
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
folder_path: str | None = None,
|
||||
@@ -181,18 +149,20 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
return DocumentSource.EGNYTE
|
||||
|
||||
@classmethod
|
||||
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
|
||||
def oauth_authorization_url(
|
||||
cls,
|
||||
base_domain: str,
|
||||
state: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> str:
|
||||
if not EGNYTE_CLIENT_ID:
|
||||
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
|
||||
if not EGNYTE_BASE_DOMAIN:
|
||||
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
|
||||
|
||||
if EGNYTE_LOCALHOST_OVERRIDE:
|
||||
base_domain = EGNYTE_LOCALHOST_OVERRIDE
|
||||
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
|
||||
|
||||
callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte"
|
||||
callback_uri = get_oauth_callback_uri(base_domain, "egnyte")
|
||||
return (
|
||||
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
||||
f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
|
||||
f"?client_id={EGNYTE_CLIENT_ID}"
|
||||
f"&redirect_uri={callback_uri}"
|
||||
f"&scope=Egnyte.filesystem"
|
||||
@@ -201,17 +171,23 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
|
||||
def oauth_code_to_token(
|
||||
cls,
|
||||
base_domain: str,
|
||||
code: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> dict[str, Any]:
|
||||
if not EGNYTE_CLIENT_ID:
|
||||
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
|
||||
if not EGNYTE_CLIENT_SECRET:
|
||||
raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
|
||||
if not EGNYTE_BASE_DOMAIN:
|
||||
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
|
||||
|
||||
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
|
||||
|
||||
# Exchange code for token
|
||||
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
||||
redirect_uri = f"{EGNYTE_LOCALHOST_OVERRIDE or base_domain}/connector/oauth/callback/egnyte"
|
||||
url = f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
|
||||
redirect_uri = get_oauth_callback_uri(base_domain, "egnyte")
|
||||
|
||||
data = {
|
||||
"client_id": EGNYTE_CLIENT_ID,
|
||||
"client_secret": EGNYTE_CLIENT_SECRET,
|
||||
@@ -222,7 +198,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
}
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
|
||||
response = _request_with_retries(
|
||||
response = request_with_retries(
|
||||
method="POST",
|
||||
url=url,
|
||||
data=data,
|
||||
@@ -236,7 +212,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
|
||||
token_data = response.json()
|
||||
return {
|
||||
"domain": EGNYTE_BASE_DOMAIN,
|
||||
"domain": oauth_kwargs.egnyte_domain,
|
||||
"access_token": token_data["access_token"],
|
||||
}
|
||||
|
||||
@@ -260,9 +236,10 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
"list_content": True,
|
||||
}
|
||||
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{path or ''}"
|
||||
response = _request_with_retries(
|
||||
method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
|
||||
url_encoded_path = quote(path or "")
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
|
||||
response = request_with_retries(
|
||||
method="GET", url=url, headers=headers, params=params
|
||||
)
|
||||
if not response.ok:
|
||||
raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")
|
||||
@@ -315,12 +292,12 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.access_token}",
|
||||
}
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}"
|
||||
response = _request_with_retries(
|
||||
url_encoded_path = quote(file["path"])
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
|
||||
response = request_with_retries(
|
||||
method="GET",
|
||||
url=url,
|
||||
headers=headers,
|
||||
timeout=_TIMEOUT,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSourceRequiringTenantContext
|
||||
from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
from onyx.connectors.axero.connector import AxeroConnector
|
||||
from onyx.connectors.blob.connector import BlobStorageConnector
|
||||
@@ -103,6 +104,7 @@ def identify_connector_class(
|
||||
DocumentSource.FRESHDESK: FreshdeskConnector,
|
||||
DocumentSource.FIREFLIES: FirefliesConnector,
|
||||
DocumentSource.EGNYTE: EgnyteConnector,
|
||||
DocumentSource.AIRTABLE: AirtableConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Dict
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -249,17 +250,36 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
return new_creds_dict
|
||||
|
||||
def _get_all_user_emails(self) -> list[str]:
|
||||
admin_service = get_admin_service(self.creds, self.primary_admin_email)
|
||||
emails = []
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
emails.append(email)
|
||||
return emails
|
||||
"""
|
||||
List all user emails if we are on a Google Workspace domain.
|
||||
If the domain is gmail.com, or if we attempt to call the Admin SDK and
|
||||
get a 404, fall back to using the single user.
|
||||
"""
|
||||
|
||||
try:
|
||||
admin_service = get_admin_service(self.creds, self.primary_admin_email)
|
||||
emails = []
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
emails.append(email)
|
||||
return emails
|
||||
|
||||
except HttpError as e:
|
||||
if e.resp.status == 404:
|
||||
logger.warning(
|
||||
"Received 404 from Admin SDK; this may indicate a personal Gmail account "
|
||||
"with no Workspace domain. Falling back to single user."
|
||||
)
|
||||
return [self.primary_admin_email]
|
||||
raise
|
||||
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def _fetch_threads(
|
||||
self,
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import cast
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
|
||||
@@ -20,6 +21,7 @@ from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
|
||||
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
|
||||
from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
|
||||
from onyx.connectors.google_drive.models import GoogleDriveFileType
|
||||
from onyx.connectors.google_utils.google_auth import get_google_creds
|
||||
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
@@ -41,6 +43,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
# TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
|
||||
@@ -286,13 +289,30 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
logger.info(f"Impersonating user {user_email}")
|
||||
|
||||
drive_service = get_drive_service(self.creds, user_email)
|
||||
|
||||
# validate that the user has access to the drive APIs by performing a simple
|
||||
# request and checking for a 401
|
||||
try:
|
||||
retry_builder()(get_root_folder_id)(drive_service)
|
||||
except HttpError as e:
|
||||
if e.status_code == 401:
|
||||
# fail gracefully, let the other impersonations continue
|
||||
# one user without access shouldn't block the entire connector
|
||||
logger.exception(
|
||||
f"User '{user_email}' does not have access to the drive APIs."
|
||||
)
|
||||
return
|
||||
raise
|
||||
|
||||
# if we are including my drives, try to get the current user's my
|
||||
# drive if any of the following are true:
|
||||
# - include_my_drives is true
|
||||
# - the current user's email is in the requested emails
|
||||
if self.include_my_drives or user_email in self._requested_my_drive_emails:
|
||||
logger.info(f"Getting all files in my drive as '{user_email}'")
|
||||
yield from get_all_files_in_my_drive(
|
||||
service=drive_service,
|
||||
update_traversed_ids_func=self._update_traversed_parent_ids,
|
||||
@@ -303,6 +323,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
remaining_drive_ids = filtered_drive_ids - self._retrieved_ids
|
||||
for drive_id in remaining_drive_ids:
|
||||
logger.info(f"Getting files in shared drive '{drive_id}' as '{user_email}'")
|
||||
yield from get_files_in_shared_drive(
|
||||
service=drive_service,
|
||||
drive_id=drive_id,
|
||||
@@ -314,6 +335,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
remaining_folders = filtered_folder_ids - self._retrieved_ids
|
||||
for folder_id in remaining_folders:
|
||||
logger.info(f"Getting files in folder '{folder_id}' as '{user_email}'")
|
||||
yield from crawl_folders_for_files(
|
||||
service=drive_service,
|
||||
parent_id=folder_id,
|
||||
@@ -344,6 +366,15 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
elif self.include_shared_drives:
|
||||
drive_ids_to_retrieve = all_drive_ids
|
||||
|
||||
# checkpoint - we've found all users and drives, now time to actually start
|
||||
# fetching stuff
|
||||
logger.info(f"Found {len(all_org_emails)} users to impersonate")
|
||||
logger.debug(f"Users: {all_org_emails}")
|
||||
logger.info(f"Found {len(drive_ids_to_retrieve)} drives to retrieve")
|
||||
logger.debug(f"Drives: {drive_ids_to_retrieve}")
|
||||
logger.info(f"Found {len(folder_ids_to_retrieve)} folders to retrieve")
|
||||
logger.debug(f"Folders: {folder_ids_to_retrieve}")
|
||||
|
||||
# Process users in parallel using ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
future_to_email = {
|
||||
@@ -380,6 +411,13 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
drive_service = get_drive_service(self.creds, self.primary_admin_email)
|
||||
|
||||
if self.include_files_shared_with_me or self.include_my_drives:
|
||||
logger.info(
|
||||
f"Getting shared files/my drive files for OAuth "
|
||||
f"with include_files_shared_with_me={self.include_files_shared_with_me}, "
|
||||
f"include_my_drives={self.include_my_drives}, "
|
||||
f"include_shared_drives={self.include_shared_drives}."
|
||||
f"Using '{self.primary_admin_email}' as the account."
|
||||
)
|
||||
yield from get_all_files_for_oauth(
|
||||
service=drive_service,
|
||||
include_files_shared_with_me=self.include_files_shared_with_me,
|
||||
@@ -412,6 +450,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
drive_ids_to_retrieve = all_drive_ids
|
||||
|
||||
for drive_id in drive_ids_to_retrieve:
|
||||
logger.info(
|
||||
f"Getting files in shared drive '{drive_id}' as '{self.primary_admin_email}'"
|
||||
)
|
||||
yield from get_files_in_shared_drive(
|
||||
service=drive_service,
|
||||
drive_id=drive_id,
|
||||
@@ -425,6 +466,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# that could be folders.
|
||||
remaining_folders = folder_ids_to_retrieve - self._retrieved_ids
|
||||
for folder_id in remaining_folders:
|
||||
logger.info(
|
||||
f"Getting files in folder '{folder_id}' as '{self.primary_admin_email}'"
|
||||
)
|
||||
yield from crawl_folders_for_files(
|
||||
service=drive_service,
|
||||
parent_id=folder_id,
|
||||
|
||||
@@ -2,6 +2,8 @@ import abc
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
@@ -66,6 +68,10 @@ class SlimConnector(BaseConnector):
|
||||
|
||||
|
||||
class OAuthConnector(BaseConnector):
|
||||
class AdditionalOauthKwargs(BaseModel):
|
||||
# if overridden, all fields should be str type
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def oauth_id(cls) -> DocumentSource:
|
||||
@@ -73,12 +79,22 @@ class OAuthConnector(BaseConnector):
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
|
||||
def oauth_authorization_url(
|
||||
cls,
|
||||
base_domain: str,
|
||||
state: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
|
||||
def oauth_code_to_token(
|
||||
cls,
|
||||
base_domain: str,
|
||||
code: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -7,16 +7,23 @@ from typing import cast
|
||||
import requests
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import LINEAR_CLIENT_ID
|
||||
from onyx.configs.app_configs import LINEAR_CLIENT_SECRET
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_oauth_callback_uri,
|
||||
)
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import OAuthConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import request_with_retries
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -57,7 +64,7 @@ def _make_query(request_body: dict[str, Any], api_key: str) -> requests.Response
|
||||
)
|
||||
|
||||
|
||||
class LinearConnector(LoadConnector, PollConnector):
|
||||
class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
@@ -65,8 +72,68 @@ class LinearConnector(LoadConnector, PollConnector):
|
||||
self.batch_size = batch_size
|
||||
self.linear_api_key: str | None = None
|
||||
|
||||
@classmethod
|
||||
def oauth_id(cls) -> DocumentSource:
|
||||
return DocumentSource.LINEAR
|
||||
|
||||
@classmethod
|
||||
def oauth_authorization_url(
|
||||
cls, base_domain: str, state: str, additional_kwargs: dict[str, str]
|
||||
) -> str:
|
||||
if not LINEAR_CLIENT_ID:
|
||||
raise ValueError("LINEAR_CLIENT_ID environment variable must be set")
|
||||
|
||||
callback_uri = get_oauth_callback_uri(base_domain, DocumentSource.LINEAR.value)
|
||||
return (
|
||||
f"https://linear.app/oauth/authorize"
|
||||
f"?client_id={LINEAR_CLIENT_ID}"
|
||||
f"&redirect_uri={callback_uri}"
|
||||
f"&response_type=code"
|
||||
f"&scope=read"
|
||||
f"&state={state}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def oauth_code_to_token(
|
||||
cls, base_domain: str, code: str, additional_kwargs: dict[str, str]
|
||||
) -> dict[str, Any]:
|
||||
data = {
|
||||
"code": code,
|
||||
"redirect_uri": get_oauth_callback_uri(
|
||||
base_domain, DocumentSource.LINEAR.value
|
||||
),
|
||||
"client_id": LINEAR_CLIENT_ID,
|
||||
"client_secret": LINEAR_CLIENT_SECRET,
|
||||
"grant_type": "authorization_code",
|
||||
}
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
|
||||
response = request_with_retries(
|
||||
method="POST",
|
||||
url="https://api.linear.app/oauth/token",
|
||||
data=data,
|
||||
headers=headers,
|
||||
backoff=0,
|
||||
delay=0.1,
|
||||
)
|
||||
if not response.ok:
|
||||
raise RuntimeError(f"Failed to exchange code for token: {response.text}")
|
||||
|
||||
token_data = response.json()
|
||||
|
||||
return {
|
||||
"access_token": token_data["access_token"],
|
||||
}
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.linear_api_key = cast(str, credentials["linear_api_key"])
|
||||
if "linear_api_key" in credentials:
|
||||
self.linear_api_key = cast(str, credentials["linear_api_key"])
|
||||
elif "access_token" in credentials:
|
||||
self.linear_api_key = "Bearer " + cast(str, credentials["access_token"])
|
||||
else:
|
||||
# May need to handle case in the future if the OAuth flow expires
|
||||
raise ConnectorMissingCredentialError("Linear")
|
||||
|
||||
return None
|
||||
|
||||
def _process_issues(
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
import os
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from simple_salesforce import Salesforce
|
||||
from simple_salesforce import SFType
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -19,24 +15,25 @@ from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.salesforce.utils import extract_dict_text
|
||||
from onyx.connectors.salesforce.doc_conversion import extract_section
|
||||
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
|
||||
from onyx.connectors.salesforce.salesforce_calls import get_all_children_of_sf_type
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_record
|
||||
from onyx.connectors.salesforce.sqlite_functions import init_db
|
||||
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
# TODO: this connector does not work well at large scales
|
||||
# the large query against a large Salesforce instance has been reported to take 1.5 hours.
|
||||
# Additionally it seems to eat up more memory over time if the connection is long running (again a scale issue).
|
||||
|
||||
|
||||
DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
|
||||
MAX_QUERY_LENGTH = 10000 # max query length is 20,000 characters
|
||||
ID_PREFIX = "SALESFORCE_"
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
|
||||
_ID_PREFIX = "SALESFORCE_"
|
||||
|
||||
|
||||
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -44,200 +41,170 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
requested_objects: list[str] = [],
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.sf_client: Salesforce | None = None
|
||||
self._sf_client: Salesforce | None = None
|
||||
self.parent_object_list = (
|
||||
[obj.capitalize() for obj in requested_objects]
|
||||
if requested_objects
|
||||
else DEFAULT_PARENT_OBJECT_TYPES
|
||||
else _DEFAULT_PARENT_OBJECT_TYPES
|
||||
)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.sf_client = Salesforce(
|
||||
def load_credentials(
|
||||
self,
|
||||
credentials: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
self._sf_client = Salesforce(
|
||||
username=credentials["sf_username"],
|
||||
password=credentials["sf_password"],
|
||||
security_token=credentials["sf_security_token"],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _get_sf_type_object_json(self, type_name: str) -> Any:
|
||||
if self.sf_client is None:
|
||||
@property
|
||||
def sf_client(self) -> Salesforce:
|
||||
if self._sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
sf_object = SFType(
|
||||
type_name, self.sf_client.session_id, self.sf_client.sf_instance
|
||||
)
|
||||
return sf_object.describe()
|
||||
return self._sf_client
|
||||
|
||||
def _get_name_from_id(self, id: str) -> str:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
try:
|
||||
user_object_info = self.sf_client.query(
|
||||
f"SELECT Name FROM User WHERE Id = '{id}'"
|
||||
)
|
||||
name = user_object_info.get("Records", [{}])[0].get("Name", "Null User")
|
||||
return name
|
||||
except Exception:
|
||||
logger.warning(f"Couldnt find name for object id: {id}")
|
||||
return "Null User"
|
||||
def _extract_primary_owners(
|
||||
self, sf_object: SalesforceObject
|
||||
) -> list[BasicExpertInfo] | None:
|
||||
object_dict = sf_object.data
|
||||
if not (last_modified_by_id := object_dict.get("LastModifiedById")):
|
||||
return None
|
||||
if not (last_modified_by := get_record(last_modified_by_id)):
|
||||
return None
|
||||
if not (last_modified_by_name := last_modified_by.data.get("Name")):
|
||||
return None
|
||||
primary_owners = [BasicExpertInfo(display_name=last_modified_by_name)]
|
||||
return primary_owners
|
||||
|
||||
def _convert_object_instance_to_document(
|
||||
self, object_dict: dict[str, Any]
|
||||
self, sf_object: SalesforceObject
|
||||
) -> Document:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
|
||||
object_dict = sf_object.data
|
||||
salesforce_id = object_dict["Id"]
|
||||
onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
|
||||
extracted_link = f"https://{self.sf_client.sf_instance}/{salesforce_id}"
|
||||
onyx_salesforce_id = f"{_ID_PREFIX}{salesforce_id}"
|
||||
base_url = f"https://{self.sf_client.sf_instance}"
|
||||
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
|
||||
extracted_object_text = extract_dict_text(object_dict)
|
||||
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
|
||||
extracted_primary_owners = [
|
||||
BasicExpertInfo(
|
||||
display_name=self._get_name_from_id(object_dict["LastModifiedById"])
|
||||
)
|
||||
]
|
||||
|
||||
sections = [extract_section(sf_object, base_url)]
|
||||
for id in get_child_ids(sf_object.id):
|
||||
if not (child_object := get_record(id)):
|
||||
continue
|
||||
sections.append(extract_section(child_object, base_url))
|
||||
|
||||
doc = Document(
|
||||
id=onyx_salesforce_id,
|
||||
sections=[Section(link=extracted_link, text=extracted_object_text)],
|
||||
sections=sections,
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier=extracted_semantic_identifier,
|
||||
doc_updated_at=extracted_doc_updated_at,
|
||||
primary_owners=extracted_primary_owners,
|
||||
primary_owners=self._extract_primary_owners(sf_object),
|
||||
metadata={},
|
||||
)
|
||||
return doc
|
||||
|
||||
def _is_valid_child_object(self, child_relationship: dict) -> bool:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
|
||||
if not child_relationship["childSObject"]:
|
||||
return False
|
||||
if not child_relationship["relationshipName"]:
|
||||
return False
|
||||
|
||||
sf_type = child_relationship["childSObject"]
|
||||
object_description = self._get_sf_type_object_json(sf_type)
|
||||
if not object_description["queryable"]:
|
||||
return False
|
||||
|
||||
try:
|
||||
query = f"SELECT Count() FROM {sf_type} LIMIT 1"
|
||||
result = self.sf_client.query(query)
|
||||
if result["totalSize"] == 0:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
|
||||
return False
|
||||
|
||||
if child_relationship["field"]:
|
||||
if child_relationship["field"] == "RelatedToId":
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _get_all_children_of_sf_type(self, sf_type: str) -> list[dict]:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
|
||||
object_description = self._get_sf_type_object_json(sf_type)
|
||||
|
||||
children_objects: list[dict] = []
|
||||
for child_relationship in object_description["childRelationships"]:
|
||||
if self._is_valid_child_object(child_relationship):
|
||||
children_objects.append(
|
||||
{
|
||||
"relationship_name": child_relationship["relationshipName"],
|
||||
"object_type": child_relationship["childSObject"],
|
||||
}
|
||||
)
|
||||
return children_objects
|
||||
|
||||
def _get_all_fields_for_sf_type(self, sf_type: str) -> list[str]:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
|
||||
object_description = self._get_sf_type_object_json(sf_type)
|
||||
|
||||
fields = [
|
||||
field.get("name")
|
||||
for field in object_description["fields"]
|
||||
if field.get("type", "base64") != "base64"
|
||||
]
|
||||
|
||||
return fields
|
||||
|
||||
def _generate_query_per_parent_type(self, parent_sf_type: str) -> Iterator[str]:
|
||||
"""
|
||||
This function takes in an object_type and generates query(s) designed to grab
|
||||
information associated to objects of that type.
|
||||
It does that by getting all the fields of the parent object type.
|
||||
Then it gets all the child objects of that object type and all the fields of
|
||||
those children as well.
|
||||
"""
|
||||
parent_fields = self._get_all_fields_for_sf_type(parent_sf_type)
|
||||
child_sf_types = self._get_all_children_of_sf_type(parent_sf_type)
|
||||
|
||||
query = f"SELECT {', '.join(parent_fields)}"
|
||||
for child_object_dict in child_sf_types:
|
||||
fields = self._get_all_fields_for_sf_type(child_object_dict["object_type"])
|
||||
query_addition = f", \n(SELECT {', '.join(fields)} FROM {child_object_dict['relationship_name']})"
|
||||
|
||||
if len(query_addition) + len(query) > MAX_QUERY_LENGTH:
|
||||
query += f"\n FROM {parent_sf_type}"
|
||||
yield query
|
||||
query = "SELECT Id" + query_addition
|
||||
else:
|
||||
query += query_addition
|
||||
|
||||
query += f"\n FROM {parent_sf_type}"
|
||||
|
||||
yield query
|
||||
|
||||
def _fetch_from_salesforce(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
init_db()
|
||||
all_object_types: set[str] = set(self.parent_object_list)
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
logger.info(f"Starting with {len(self.parent_object_list)} parent object types")
|
||||
logger.debug(f"Parent object types: {self.parent_object_list}")
|
||||
|
||||
# This takes like 20 seconds
|
||||
for parent_object_type in self.parent_object_list:
|
||||
logger.debug(f"Processing: {parent_object_type}")
|
||||
|
||||
query_results: dict = {}
|
||||
for query in self._generate_query_per_parent_type(parent_object_type):
|
||||
if start is not None and end is not None:
|
||||
if start and start.tzinfo is None:
|
||||
start = start.replace(tzinfo=timezone.utc)
|
||||
if end and end.tzinfo is None:
|
||||
end = end.replace(tzinfo=timezone.utc)
|
||||
query += f" WHERE LastModifiedDate > {start.isoformat()} AND LastModifiedDate < {end.isoformat()}"
|
||||
|
||||
query_result = self.sf_client.query_all(query)
|
||||
|
||||
for record_dict in query_result["records"]:
|
||||
query_results.setdefault(record_dict["Id"], {}).update(record_dict)
|
||||
|
||||
logger.info(
|
||||
f"Number of {parent_object_type} Objects processed: {len(query_results)}"
|
||||
child_types = get_all_children_of_sf_type(
|
||||
self.sf_client, parent_object_type
|
||||
)
|
||||
all_object_types.update(child_types)
|
||||
logger.debug(
|
||||
f"Found {len(child_types)} child types for {parent_object_type}"
|
||||
)
|
||||
|
||||
for combined_object_dict in query_results.values():
|
||||
doc_batch.append(
|
||||
self._convert_object_instance_to_document(combined_object_dict)
|
||||
)
|
||||
logger.info(f"Found total of {len(all_object_types)} object types to fetch")
|
||||
logger.debug(f"All object types: {all_object_types}")
|
||||
|
||||
if len(doc_batch) > self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
yield doc_batch
|
||||
# checkpoint - we've found all object types, now time to fetch the data
|
||||
logger.info("Starting to fetch CSVs for all object types")
|
||||
# This takes like 30 minutes first time and <2 minutes for updates
|
||||
object_type_to_csv_path = fetch_all_csvs_in_parallel(
|
||||
sf_client=self.sf_client,
|
||||
object_types=all_object_types,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
updated_ids: set[str] = set()
|
||||
# This takes like 10 seconds
|
||||
# This is for testing the rest of the functionality if data has
|
||||
# already been fetched and put in sqlite
|
||||
# from import onyx.connectors.salesforce.sf_db.sqlite_functions find_ids_by_type
|
||||
# for object_type in self.parent_object_list:
|
||||
# updated_ids.update(list(find_ids_by_type(object_type)))
|
||||
|
||||
# This takes 10-70 minutes first time (idk why the range is so big)
|
||||
total_types = len(object_type_to_csv_path)
|
||||
logger.info(f"Starting to process {total_types} object types")
|
||||
|
||||
for i, (object_type, csv_paths) in enumerate(
|
||||
object_type_to_csv_path.items(), 1
|
||||
):
|
||||
logger.info(f"Processing object type {object_type} ({i}/{total_types})")
|
||||
# If path is None, it means it failed to fetch the csv
|
||||
if csv_paths is None:
|
||||
continue
|
||||
# Go through each csv path and use it to update the db
|
||||
for csv_path in csv_paths:
|
||||
logger.debug(f"Updating {object_type} with {csv_path}")
|
||||
new_ids = update_sf_db_with_csv(
|
||||
object_type=object_type,
|
||||
csv_download_path=csv_path,
|
||||
)
|
||||
updated_ids.update(new_ids)
|
||||
logger.debug(
|
||||
f"Added {len(new_ids)} new/updated records for {object_type}"
|
||||
)
|
||||
# Remove the csv file after it has been used
|
||||
# to successfully update the db
|
||||
os.remove(csv_path)
|
||||
|
||||
logger.info(f"Found {len(updated_ids)} total updated records")
|
||||
logger.info(
|
||||
f"Starting to process parent objects of types: {self.parent_object_list}"
|
||||
)
|
||||
|
||||
docs_to_yield: list[Document] = []
|
||||
docs_processed = 0
|
||||
# Takes 15-20 seconds per batch
|
||||
for parent_type, parent_id_batch in get_affected_parent_ids_by_type(
|
||||
updated_ids=list(updated_ids),
|
||||
parent_types=self.parent_object_list,
|
||||
):
|
||||
logger.info(
|
||||
f"Processing batch of {len(parent_id_batch)} {parent_type} objects"
|
||||
)
|
||||
for parent_id in parent_id_batch:
|
||||
if not (parent_object := get_record(parent_id, parent_type)):
|
||||
logger.warning(
|
||||
f"Failed to get parent object {parent_id} for {parent_type}"
|
||||
)
|
||||
continue
|
||||
|
||||
docs_to_yield.append(
|
||||
self._convert_object_instance_to_document(parent_object)
|
||||
)
|
||||
docs_processed += 1
|
||||
|
||||
if len(docs_to_yield) >= self.batch_size:
|
||||
yield docs_to_yield
|
||||
docs_to_yield = []
|
||||
|
||||
yield docs_to_yield
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_from_salesforce()
|
||||
@@ -245,26 +212,20 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
start_datetime = datetime.utcfromtimestamp(start)
|
||||
end_datetime = datetime.utcfromtimestamp(end)
|
||||
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
|
||||
return self._fetch_from_salesforce(start=start, end=end)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
doc_metadata_list: list[SlimDocument] = []
|
||||
for parent_object_type in self.parent_object_list:
|
||||
query = f"SELECT Id FROM {parent_object_type}"
|
||||
query_result = self.sf_client.query_all(query)
|
||||
doc_metadata_list.extend(
|
||||
SlimDocument(
|
||||
id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",
|
||||
id=f"{_ID_PREFIX}{instance_dict.get('Id', '')}",
|
||||
perm_sync_data={},
|
||||
)
|
||||
for instance_dict in query_result["records"]
|
||||
@@ -274,9 +235,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = SalesforceConnector(
|
||||
requested_objects=os.environ["REQUESTED_OBJECTS"].split(",")
|
||||
)
|
||||
import time
|
||||
|
||||
connector = SalesforceConnector(requested_objects=["Account"])
|
||||
|
||||
connector.load_credentials(
|
||||
{
|
||||
@@ -285,5 +246,20 @@ if __name__ == "__main__":
|
||||
"sf_security_token": os.environ["SF_SECURITY_TOKEN"],
|
||||
}
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
start_time = time.time()
|
||||
doc_count = 0
|
||||
section_count = 0
|
||||
text_count = 0
|
||||
for doc_batch in connector.load_from_state():
|
||||
doc_count += len(doc_batch)
|
||||
print(f"doc_count: {doc_count}")
|
||||
for doc in doc_batch:
|
||||
section_count += len(doc.sections)
|
||||
for section in doc.sections:
|
||||
text_count += len(section.text)
|
||||
end_time = time.time()
|
||||
|
||||
print(f"Doc count: {doc_count}")
|
||||
print(f"Section count: {section_count}")
|
||||
print(f"Text count: {text_count}")
|
||||
print(f"Time taken: {end_time - start_time}")
|
||||
|
||||
156
backend/onyx/connectors/salesforce/doc_conversion.py
Normal file
156
backend/onyx/connectors/salesforce/doc_conversion.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
|
||||
# All of these types of keys are handled by specific fields in the doc
|
||||
# conversion process (E.g. URLs) or are not useful for the user (E.g. UUIDs)
|
||||
_SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"
|
||||
|
||||
|
||||
def _clean_salesforce_dict(data: dict | list) -> dict | list:
|
||||
"""Clean and transform Salesforce API response data by recursively:
|
||||
1. Extracting records from the response if present
|
||||
2. Merging attributes into the main dictionary
|
||||
3. Filtering out keys matching certain patterns (Id, Date, stamp, url)
|
||||
4. Removing '__c' suffix from custom field names
|
||||
5. Removing None values and empty containers
|
||||
|
||||
Args:
|
||||
data: A dictionary or list from Salesforce API response
|
||||
|
||||
Returns:
|
||||
Cleaned dictionary or list with transformed keys and filtered values
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
if "records" in data.keys():
|
||||
data = data["records"]
|
||||
if isinstance(data, dict):
|
||||
if "attributes" in data.keys():
|
||||
if isinstance(data["attributes"], dict):
|
||||
data.update(data.pop("attributes"))
|
||||
|
||||
if isinstance(data, dict):
|
||||
filtered_dict = {}
|
||||
for key, value in data.items():
|
||||
if not re.search(_SF_JSON_FILTER, key, re.IGNORECASE):
|
||||
# remove the custom object indicator for display
|
||||
if "__c" in key:
|
||||
key = key[:-3]
|
||||
if isinstance(value, (dict, list)):
|
||||
filtered_value = _clean_salesforce_dict(value)
|
||||
# Only add non-empty dictionaries or lists
|
||||
if filtered_value:
|
||||
filtered_dict[key] = filtered_value
|
||||
elif value is not None:
|
||||
filtered_dict[key] = value
|
||||
return filtered_dict
|
||||
elif isinstance(data, list):
|
||||
filtered_list = []
|
||||
for item in data:
|
||||
if isinstance(item, (dict, list)):
|
||||
filtered_item = _clean_salesforce_dict(item)
|
||||
# Only add non-empty dictionaries or lists
|
||||
if filtered_item:
|
||||
filtered_list.append(filtered_item)
|
||||
elif item is not None:
|
||||
filtered_list.append(filtered_item)
|
||||
return filtered_list
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
def _json_to_natural_language(data: dict | list, indent: int = 0) -> str:
|
||||
"""Convert a nested dictionary or list into a human-readable string format.
|
||||
|
||||
Recursively traverses the data structure and formats it with:
|
||||
- Key-value pairs on separate lines
|
||||
- Nested structures indented for readability
|
||||
- Lists and dictionaries handled with appropriate formatting
|
||||
|
||||
Args:
|
||||
data: The dictionary or list to convert
|
||||
indent: Number of spaces to indent (default: 0)
|
||||
|
||||
Returns:
|
||||
A formatted string representation of the data structure
|
||||
"""
|
||||
result = []
|
||||
indent_str = " " * indent
|
||||
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
result.append(f"{indent_str}{key}:")
|
||||
result.append(_json_to_natural_language(value, indent + 2))
|
||||
else:
|
||||
result.append(f"{indent_str}{key}: {value}")
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
result.append(_json_to_natural_language(item, indent + 2))
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
|
||||
def _extract_dict_text(raw_dict: dict) -> str:
|
||||
"""Extract text from a Salesforce API response dictionary by:
|
||||
1. Cleaning the dictionary
|
||||
2. Converting the cleaned dictionary to natural language
|
||||
"""
|
||||
processed_dict = _clean_salesforce_dict(raw_dict)
|
||||
natural_language_for_dict = _json_to_natural_language(processed_dict)
|
||||
return natural_language_for_dict
|
||||
|
||||
|
||||
def extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
|
||||
return Section(
|
||||
text=_extract_dict_text(salesforce_object.data),
|
||||
link=f"{base_url}/{salesforce_object.id}",
|
||||
)
|
||||
|
||||
|
||||
def _field_value_is_child_object(field_value: dict) -> bool:
|
||||
"""
|
||||
Checks if the field value is a child object.
|
||||
"""
|
||||
return (
|
||||
isinstance(field_value, OrderedDict)
|
||||
and "records" in field_value.keys()
|
||||
and isinstance(field_value["records"], list)
|
||||
and len(field_value["records"]) > 0
|
||||
and "Id" in field_value["records"][0].keys()
|
||||
)
|
||||
|
||||
|
||||
def _extract_sections(salesforce_object: dict, base_url: str) -> list[Section]:
|
||||
"""
|
||||
This goes through the salesforce_object and extracts the top level fields as a Section.
|
||||
It also goes through the child objects and extracts them as Sections.
|
||||
"""
|
||||
top_level_dict = {}
|
||||
|
||||
child_object_sections = []
|
||||
for field_name, field_value in salesforce_object.items():
|
||||
# If the field value is not a child object, add it to the top level dict
|
||||
# to turn into text for the top level section
|
||||
if not _field_value_is_child_object(field_value):
|
||||
top_level_dict[field_name] = field_value
|
||||
continue
|
||||
|
||||
# If the field value is a child object, extract the child objects and add them as sections
|
||||
for record in field_value["records"]:
|
||||
child_object_id = record["Id"]
|
||||
child_object_sections.append(
|
||||
Section(
|
||||
text=f"Child Object(s): {field_name}\n{_extract_dict_text(record)}",
|
||||
link=f"{base_url}/{child_object_id}",
|
||||
)
|
||||
)
|
||||
|
||||
top_level_id = salesforce_object["Id"]
|
||||
top_level_section = Section(
|
||||
text=_extract_dict_text(top_level_dict),
|
||||
link=f"{base_url}/{top_level_id}",
|
||||
)
|
||||
return [top_level_section, *child_object_sections]
|
||||
210
backend/onyx/connectors/salesforce/salesforce_calls.py
Normal file
210
backend/onyx/connectors/salesforce/salesforce_calls.py
Normal file
@@ -0,0 +1,210 @@
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pytz import UTC
|
||||
from simple_salesforce import Salesforce
|
||||
from simple_salesforce import SFType
|
||||
from simple_salesforce.bulk2 import SFBulk2Handler
|
||||
from simple_salesforce.bulk2 import SFBulk2Type
|
||||
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.salesforce.sqlite_functions import has_at_least_one_object_of_type
|
||||
from onyx.connectors.salesforce.utils import get_object_type_path
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _build_time_filter_for_salesforce(
|
||||
start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
|
||||
) -> str:
|
||||
if start is None or end is None:
|
||||
return ""
|
||||
start_datetime = datetime.fromtimestamp(start, UTC)
|
||||
end_datetime = datetime.fromtimestamp(end, UTC)
|
||||
return (
|
||||
f" WHERE LastModifiedDate > {start_datetime.isoformat()} "
|
||||
f"AND LastModifiedDate < {end_datetime.isoformat()}"
|
||||
)
|
||||
|
||||
|
||||
def _get_sf_type_object_json(sf_client: Salesforce, type_name: str) -> Any:
|
||||
sf_object = SFType(type_name, sf_client.session_id, sf_client.sf_instance)
|
||||
return sf_object.describe()
|
||||
|
||||
|
||||
def _is_valid_child_object(
|
||||
sf_client: Salesforce, child_relationship: dict[str, Any]
|
||||
) -> bool:
|
||||
if not child_relationship["childSObject"]:
|
||||
return False
|
||||
if not child_relationship["relationshipName"]:
|
||||
return False
|
||||
|
||||
sf_type = child_relationship["childSObject"]
|
||||
object_description = _get_sf_type_object_json(sf_client, sf_type)
|
||||
if not object_description["queryable"]:
|
||||
return False
|
||||
|
||||
if child_relationship["field"]:
|
||||
if child_relationship["field"] == "RelatedToId":
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_all_children_of_sf_type(sf_client: Salesforce, sf_type: str) -> set[str]:
|
||||
object_description = _get_sf_type_object_json(sf_client, sf_type)
|
||||
|
||||
child_object_types = set()
|
||||
for child_relationship in object_description["childRelationships"]:
|
||||
if _is_valid_child_object(sf_client, child_relationship):
|
||||
logger.debug(
|
||||
f"Found valid child object {child_relationship['childSObject']}"
|
||||
)
|
||||
child_object_types.add(child_relationship["childSObject"])
|
||||
return child_object_types
|
||||
|
||||
|
||||
def _get_all_queryable_fields_of_sf_type(
|
||||
sf_client: Salesforce,
|
||||
sf_type: str,
|
||||
) -> list[str]:
|
||||
object_description = _get_sf_type_object_json(sf_client, sf_type)
|
||||
fields: list[dict[str, Any]] = object_description["fields"]
|
||||
valid_fields: set[str] = set()
|
||||
compound_field_names: set[str] = set()
|
||||
for field in fields:
|
||||
if compound_field_name := field.get("compoundFieldName"):
|
||||
compound_field_names.add(compound_field_name)
|
||||
if field.get("type", "base64") == "base64":
|
||||
continue
|
||||
if field_name := field.get("name"):
|
||||
valid_fields.add(field_name)
|
||||
|
||||
return list(valid_fields - compound_field_names)
|
||||
|
||||
|
||||
def _check_if_object_type_is_empty(sf_client: Salesforce, sf_type: str) -> bool:
|
||||
"""
|
||||
Send a small query to check if the object type is empty so we don't
|
||||
perform extra bulk queries
|
||||
"""
|
||||
try:
|
||||
query = f"SELECT Count() FROM {sf_type} LIMIT 1"
|
||||
result = sf_client.query(query)
|
||||
if result["totalSize"] == 0:
|
||||
return False
|
||||
except Exception as e:
|
||||
if "OPERATION_TOO_LARGE" not in str(e):
|
||||
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _check_for_existing_csvs(sf_type: str) -> list[str] | None:
|
||||
# Check if the csv already exists
|
||||
if os.path.exists(get_object_type_path(sf_type)):
|
||||
existing_csvs = [
|
||||
os.path.join(get_object_type_path(sf_type), f)
|
||||
for f in os.listdir(get_object_type_path(sf_type))
|
||||
if f.endswith(".csv")
|
||||
]
|
||||
# If the csv already exists, return the path
|
||||
# This is likely due to a previous run that failed
|
||||
# after downloading the csv but before the data was
|
||||
# written to the db
|
||||
if existing_csvs:
|
||||
return existing_csvs
|
||||
return None
|
||||
|
||||
|
||||
def _build_bulk_query(sf_client: Salesforce, sf_type: str, time_filter: str) -> str:
|
||||
queryable_fields = _get_all_queryable_fields_of_sf_type(sf_client, sf_type)
|
||||
query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}"
|
||||
return query
|
||||
|
||||
|
||||
def _bulk_retrieve_from_salesforce(
|
||||
sf_client: Salesforce,
|
||||
sf_type: str,
|
||||
time_filter: str,
|
||||
) -> tuple[str, list[str] | None]:
|
||||
if not _check_if_object_type_is_empty(sf_client, sf_type):
|
||||
return sf_type, None
|
||||
|
||||
if existing_csvs := _check_for_existing_csvs(sf_type):
|
||||
return sf_type, existing_csvs
|
||||
|
||||
query = _build_bulk_query(sf_client, sf_type, time_filter)
|
||||
|
||||
bulk_2_handler = SFBulk2Handler(
|
||||
session_id=sf_client.session_id,
|
||||
bulk2_url=sf_client.bulk2_url,
|
||||
proxies=sf_client.proxies,
|
||||
session=sf_client.session,
|
||||
)
|
||||
bulk_2_type = SFBulk2Type(
|
||||
object_name=sf_type,
|
||||
bulk2_url=bulk_2_handler.bulk2_url,
|
||||
headers=bulk_2_handler.headers,
|
||||
session=bulk_2_handler.session,
|
||||
)
|
||||
|
||||
logger.info(f"Downloading {sf_type}")
|
||||
logger.info(f"Query: {query}")
|
||||
|
||||
try:
|
||||
# This downloads the file to a file in the target path with a random name
|
||||
results = bulk_2_type.download(
|
||||
query=query,
|
||||
path=get_object_type_path(sf_type),
|
||||
max_records=1000000,
|
||||
)
|
||||
all_download_paths = [result["file"] for result in results]
|
||||
logger.info(f"Downloaded {sf_type} to {all_download_paths}")
|
||||
return sf_type, all_download_paths
|
||||
except Exception as e:
|
||||
logger.info(f"Failed to download salesforce csv for object type {sf_type}: {e}")
|
||||
return sf_type, None
|
||||
|
||||
|
||||
def fetch_all_csvs_in_parallel(
|
||||
sf_client: Salesforce,
|
||||
object_types: set[str],
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
) -> dict[str, list[str] | None]:
|
||||
"""
|
||||
Fetches all the csvs in parallel for the given object types
|
||||
Returns a dict of (sf_type, full_download_path)
|
||||
"""
|
||||
time_filter = _build_time_filter_for_salesforce(start, end)
|
||||
time_filter_for_each_object_type = {}
|
||||
# We do this outside of the thread pool executor because this requires
|
||||
# a database connection and we don't want to block the thread pool
|
||||
# executor from running
|
||||
for sf_type in object_types:
|
||||
"""Only add time filter if there is at least one object of the type
|
||||
in the database. We aren't worried about partially completed object update runs
|
||||
because this occurs after we check for existing csvs which covers this case"""
|
||||
if has_at_least_one_object_of_type(sf_type):
|
||||
time_filter_for_each_object_type[sf_type] = time_filter
|
||||
else:
|
||||
time_filter_for_each_object_type[sf_type] = ""
|
||||
|
||||
# Run the bulk retrieve in parallel
|
||||
with ThreadPoolExecutor() as executor:
|
||||
results = executor.map(
|
||||
lambda object_type: _bulk_retrieve_from_salesforce(
|
||||
sf_client=sf_client,
|
||||
sf_type=object_type,
|
||||
time_filter=time_filter_for_each_object_type[object_type],
|
||||
),
|
||||
object_types,
|
||||
)
|
||||
return dict(results)
|
||||
@@ -0,0 +1,209 @@
|
||||
import csv
|
||||
import shelve
|
||||
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
|
||||
get_child_to_parent_shelf_path,
|
||||
)
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_id_type_shelf_path
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_object_shelf_path
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
|
||||
get_parent_to_child_shelf_path,
|
||||
)
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.connectors.salesforce.utils import validate_salesforce_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _update_relationship_shelves(
|
||||
child_id: str,
|
||||
parent_ids: set[str],
|
||||
) -> None:
|
||||
"""Update the relationship shelf when a record is updated."""
|
||||
try:
|
||||
# Convert child_id to string once
|
||||
str_child_id = str(child_id)
|
||||
|
||||
# First update child to parent mapping
|
||||
with shelve.open(
|
||||
get_child_to_parent_shelf_path(),
|
||||
flag="c",
|
||||
protocol=None,
|
||||
writeback=True,
|
||||
) as child_to_parent_db:
|
||||
old_parent_ids = set(child_to_parent_db.get(str_child_id, []))
|
||||
child_to_parent_db[str_child_id] = list(parent_ids)
|
||||
|
||||
# Calculate differences outside the next context manager
|
||||
parent_ids_to_remove = old_parent_ids - parent_ids
|
||||
parent_ids_to_add = parent_ids - old_parent_ids
|
||||
|
||||
# Only sync once at the end
|
||||
child_to_parent_db.sync()
|
||||
|
||||
# Then update parent to child mapping in a single transaction
|
||||
if not parent_ids_to_remove and not parent_ids_to_add:
|
||||
return
|
||||
with shelve.open(
|
||||
get_parent_to_child_shelf_path(),
|
||||
flag="c",
|
||||
protocol=None,
|
||||
writeback=True,
|
||||
) as parent_to_child_db:
|
||||
# Process all removals first
|
||||
for parent_id in parent_ids_to_remove:
|
||||
str_parent_id = str(parent_id)
|
||||
existing_children = set(parent_to_child_db.get(str_parent_id, []))
|
||||
if str_child_id in existing_children:
|
||||
existing_children.remove(str_child_id)
|
||||
parent_to_child_db[str_parent_id] = list(existing_children)
|
||||
|
||||
# Then process all additions
|
||||
for parent_id in parent_ids_to_add:
|
||||
str_parent_id = str(parent_id)
|
||||
existing_children = set(parent_to_child_db.get(str_parent_id, []))
|
||||
existing_children.add(str_child_id)
|
||||
parent_to_child_db[str_parent_id] = list(existing_children)
|
||||
|
||||
# Single sync at the end
|
||||
parent_to_child_db.sync()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating relationship shelves: {e}")
|
||||
logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
|
||||
raise
|
||||
|
||||
|
||||
def get_child_ids(parent_id: str) -> set[str]:
|
||||
"""Get all child IDs for a given parent ID.
|
||||
|
||||
Args:
|
||||
parent_id: The ID of the parent object
|
||||
|
||||
Returns:
|
||||
A set of child object IDs
|
||||
"""
|
||||
with shelve.open(get_parent_to_child_shelf_path()) as parent_to_child_db:
|
||||
return set(parent_to_child_db.get(parent_id, []))
|
||||
|
||||
|
||||
def update_sf_db_with_csv(
|
||||
object_type: str,
|
||||
csv_download_path: str,
|
||||
) -> list[str]:
|
||||
"""Update the SF DB with a CSV file using shelve storage."""
|
||||
updated_ids = []
|
||||
shelf_path = get_object_shelf_path(object_type)
|
||||
|
||||
# First read the CSV to get all the data
|
||||
with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
id = row["Id"]
|
||||
parent_ids = set()
|
||||
field_to_remove: set[str] = set()
|
||||
# Update relationship shelves for any parent references
|
||||
for field, value in row.items():
|
||||
if validate_salesforce_id(value) and field != "Id":
|
||||
parent_ids.add(value)
|
||||
field_to_remove.add(field)
|
||||
if not value:
|
||||
field_to_remove.add(field)
|
||||
_update_relationship_shelves(id, parent_ids)
|
||||
for field in field_to_remove:
|
||||
# We use this to extract the Primary Owner later
|
||||
if field != "LastModifiedById":
|
||||
del row[field]
|
||||
|
||||
# Update the main object shelf
|
||||
with shelve.open(shelf_path) as object_type_db:
|
||||
object_type_db[id] = row
|
||||
# Update the ID-to-type mapping shelf
|
||||
with shelve.open(get_id_type_shelf_path()) as id_type_db:
|
||||
id_type_db[id] = object_type
|
||||
|
||||
updated_ids.append(id)
|
||||
|
||||
# os.remove(csv_download_path)
|
||||
return updated_ids
|
||||
|
||||
|
||||
def get_type_from_id(object_id: str) -> str | None:
|
||||
"""Get the type of an object from its ID."""
|
||||
# Look up the object type from the ID-to-type mapping
|
||||
with shelve.open(get_id_type_shelf_path()) as id_type_db:
|
||||
if object_id not in id_type_db:
|
||||
logger.warning(f"Object ID {object_id} not found in ID-to-type mapping")
|
||||
return None
|
||||
return id_type_db[object_id]
|
||||
|
||||
|
||||
def get_record(
|
||||
object_id: str, object_type: str | None = None
|
||||
) -> SalesforceObject | None:
|
||||
"""
|
||||
Retrieve the record and return it as a SalesforceObject.
|
||||
The object type will be looked up from the ID-to-type mapping shelf.
|
||||
"""
|
||||
if object_type is None:
|
||||
if not (object_type := get_type_from_id(object_id)):
|
||||
return None
|
||||
|
||||
shelf_path = get_object_shelf_path(object_type)
|
||||
with shelve.open(shelf_path) as db:
|
||||
if object_id not in db:
|
||||
logger.warning(f"Object ID {object_id} not found in {shelf_path}")
|
||||
return None
|
||||
data = db[object_id]
|
||||
return SalesforceObject(
|
||||
id=object_id,
|
||||
type=object_type,
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
||||
def find_ids_by_type(object_type: str) -> list[str]:
|
||||
"""
|
||||
Find all object IDs for rows of the specified type.
|
||||
"""
|
||||
shelf_path = get_object_shelf_path(object_type)
|
||||
try:
|
||||
with shelve.open(shelf_path) as db:
|
||||
return list(db.keys())
|
||||
except FileNotFoundError:
|
||||
return []
|
||||
|
||||
|
||||
def get_affected_parent_ids_by_type(
|
||||
updated_ids: set[str], parent_types: list[str]
|
||||
) -> dict[str, set[str]]:
|
||||
"""Get IDs of objects that are of the specified parent types and are either in the updated_ids
|
||||
or have children in the updated_ids.
|
||||
|
||||
Args:
|
||||
updated_ids: List of IDs that were updated
|
||||
parent_types: List of object types to filter by
|
||||
|
||||
Returns:
|
||||
A dictionary of IDs that match the criteria
|
||||
"""
|
||||
affected_ids_by_type: dict[str, set[str]] = {}
|
||||
|
||||
# Check each updated ID
|
||||
for updated_id in updated_ids:
|
||||
# Add the ID itself if it's of a parent type
|
||||
updated_type = get_type_from_id(updated_id)
|
||||
if updated_type in parent_types:
|
||||
affected_ids_by_type.setdefault(updated_type, set()).add(updated_id)
|
||||
continue
|
||||
|
||||
# Get parents of this ID and add them if they're of a parent type
|
||||
with shelve.open(get_child_to_parent_shelf_path()) as child_to_parent_db:
|
||||
parent_ids = child_to_parent_db.get(updated_id, [])
|
||||
for parent_id in parent_ids:
|
||||
parent_type = get_type_from_id(parent_id)
|
||||
if parent_type in parent_types:
|
||||
affected_ids_by_type.setdefault(parent_type, set()).add(parent_id)
|
||||
|
||||
return affected_ids_by_type
|
||||
@@ -0,0 +1,29 @@
|
||||
import os
|
||||
|
||||
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
|
||||
from onyx.connectors.salesforce.utils import get_object_type_path
|
||||
|
||||
|
||||
def get_object_shelf_path(object_type: str) -> str:
|
||||
"""Get the path to the shelf file for a specific object type."""
|
||||
base_path = get_object_type_path(object_type)
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
return os.path.join(base_path, "data.shelf")
|
||||
|
||||
|
||||
def get_id_type_shelf_path() -> str:
|
||||
"""Get the path to the ID-to-type mapping shelf."""
|
||||
os.makedirs(BASE_DATA_PATH, exist_ok=True)
|
||||
return os.path.join(BASE_DATA_PATH, "id_type_mapping.shelf.4g")
|
||||
|
||||
|
||||
def get_parent_to_child_shelf_path() -> str:
|
||||
"""Get the path to the parent-to-child mapping shelf."""
|
||||
os.makedirs(BASE_DATA_PATH, exist_ok=True)
|
||||
return os.path.join(BASE_DATA_PATH, "parent_to_child_mapping.shelf.4g")
|
||||
|
||||
|
||||
def get_child_to_parent_shelf_path() -> str:
|
||||
"""Get the path to the child-to-parent mapping shelf."""
|
||||
os.makedirs(BASE_DATA_PATH, exist_ok=True)
|
||||
return os.path.join(BASE_DATA_PATH, "child_to_parent_mapping.shelf.4g")
|
||||
@@ -0,0 +1,737 @@
|
||||
import csv
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import find_ids_by_type
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
|
||||
get_affected_parent_ids_by_type,
|
||||
)
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_child_ids
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_record
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
|
||||
update_sf_db_with_csv,
|
||||
)
|
||||
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
|
||||
from onyx.connectors.salesforce.utils import get_object_type_path
|
||||
|
||||
_VALID_SALESFORCE_IDS = [
|
||||
"001bm00000fd9Z3AAI",
|
||||
"001bm00000fdYTdAAM",
|
||||
"001bm00000fdYTeAAM",
|
||||
"001bm00000fdYTfAAM",
|
||||
"001bm00000fdYTgAAM",
|
||||
"001bm00000fdYThAAM",
|
||||
"001bm00000fdYTiAAM",
|
||||
"001bm00000fdYTjAAM",
|
||||
"001bm00000fdYTkAAM",
|
||||
"001bm00000fdYTlAAM",
|
||||
"001bm00000fdYTmAAM",
|
||||
"001bm00000fdYTnAAM",
|
||||
"001bm00000fdYToAAM",
|
||||
"500bm00000XoOxtAAF",
|
||||
"500bm00000XoOxuAAF",
|
||||
"500bm00000XoOxvAAF",
|
||||
"500bm00000XoOxwAAF",
|
||||
"500bm00000XoOxxAAF",
|
||||
"500bm00000XoOxyAAF",
|
||||
"500bm00000XoOxzAAF",
|
||||
"500bm00000XoOy0AAF",
|
||||
"500bm00000XoOy1AAF",
|
||||
"500bm00000XoOy2AAF",
|
||||
"500bm00000XoOy3AAF",
|
||||
"500bm00000XoOy4AAF",
|
||||
"500bm00000XoOy5AAF",
|
||||
"500bm00000XoOy6AAF",
|
||||
"500bm00000XoOy7AAF",
|
||||
"500bm00000XoOy8AAF",
|
||||
"500bm00000XoOy9AAF",
|
||||
"500bm00000XoOyAAAV",
|
||||
"500bm00000XoOyBAAV",
|
||||
"500bm00000XoOyCAAV",
|
||||
"500bm00000XoOyDAAV",
|
||||
"500bm00000XoOyEAAV",
|
||||
"500bm00000XoOyFAAV",
|
||||
"500bm00000XoOyGAAV",
|
||||
"500bm00000XoOyHAAV",
|
||||
"500bm00000XoOyIAAV",
|
||||
"003bm00000EjHCjAAN",
|
||||
"003bm00000EjHCkAAN",
|
||||
"003bm00000EjHClAAN",
|
||||
"003bm00000EjHCmAAN",
|
||||
"003bm00000EjHCnAAN",
|
||||
"003bm00000EjHCoAAN",
|
||||
"003bm00000EjHCpAAN",
|
||||
"003bm00000EjHCqAAN",
|
||||
"003bm00000EjHCrAAN",
|
||||
"003bm00000EjHCsAAN",
|
||||
"003bm00000EjHCtAAN",
|
||||
"003bm00000EjHCuAAN",
|
||||
"003bm00000EjHCvAAN",
|
||||
"003bm00000EjHCwAAN",
|
||||
"003bm00000EjHCxAAN",
|
||||
"003bm00000EjHCyAAN",
|
||||
"003bm00000EjHCzAAN",
|
||||
"003bm00000EjHD0AAN",
|
||||
"003bm00000EjHD1AAN",
|
||||
"003bm00000EjHD2AAN",
|
||||
"550bm00000EXc2tAAD",
|
||||
"006bm000006kyDpAAI",
|
||||
"006bm000006kyDqAAI",
|
||||
"006bm000006kyDrAAI",
|
||||
"006bm000006kyDsAAI",
|
||||
"006bm000006kyDtAAI",
|
||||
"006bm000006kyDuAAI",
|
||||
"006bm000006kyDvAAI",
|
||||
"006bm000006kyDwAAI",
|
||||
"006bm000006kyDxAAI",
|
||||
"006bm000006kyDyAAI",
|
||||
"006bm000006kyDzAAI",
|
||||
"006bm000006kyE0AAI",
|
||||
"006bm000006kyE1AAI",
|
||||
"006bm000006kyE2AAI",
|
||||
"006bm000006kyE3AAI",
|
||||
"006bm000006kyE4AAI",
|
||||
"006bm000006kyE5AAI",
|
||||
"006bm000006kyE6AAI",
|
||||
"006bm000006kyE7AAI",
|
||||
"006bm000006kyE8AAI",
|
||||
"006bm000006kyE9AAI",
|
||||
"006bm000006kyEAAAY",
|
||||
"006bm000006kyEBAAY",
|
||||
"006bm000006kyECAAY",
|
||||
"006bm000006kyEDAAY",
|
||||
"006bm000006kyEEAAY",
|
||||
"006bm000006kyEFAAY",
|
||||
"006bm000006kyEGAAY",
|
||||
"006bm000006kyEHAAY",
|
||||
"006bm000006kyEIAAY",
|
||||
"006bm000006kyEJAAY",
|
||||
"005bm000009zy0TAAQ",
|
||||
"005bm000009zy25AAA",
|
||||
"005bm000009zy26AAA",
|
||||
"005bm000009zy28AAA",
|
||||
"005bm000009zy29AAA",
|
||||
"005bm000009zy2AAAQ",
|
||||
"005bm000009zy2BAAQ",
|
||||
]
|
||||
|
||||
|
||||
def clear_sf_db() -> None:
|
||||
"""
|
||||
Clears the SF DB by deleting all files in the data directory.
|
||||
"""
|
||||
shutil.rmtree(BASE_DATA_PATH)
|
||||
|
||||
|
||||
def create_csv_file(
|
||||
object_type: str, records: list[dict], filename: str = "test_data.csv"
|
||||
) -> None:
|
||||
"""
|
||||
Creates a CSV file for the given object type and records.
|
||||
|
||||
Args:
|
||||
object_type: The Salesforce object type (e.g. "Account", "Contact")
|
||||
records: List of dictionaries containing the record data
|
||||
filename: Name of the CSV file to create (default: test_data.csv)
|
||||
"""
|
||||
if not records:
|
||||
return
|
||||
|
||||
# Get all unique fields from records
|
||||
fields: set[str] = set()
|
||||
for record in records:
|
||||
fields.update(record.keys())
|
||||
fields = set(sorted(list(fields))) # Sort for consistent order
|
||||
|
||||
# Create CSV file
|
||||
csv_path = os.path.join(get_object_type_path(object_type), filename)
|
||||
with open(csv_path, "w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fields)
|
||||
writer.writeheader()
|
||||
for record in records:
|
||||
writer.writerow(record)
|
||||
|
||||
# Update the database with the CSV
|
||||
update_sf_db_with_csv(object_type, csv_path)
|
||||
|
||||
|
||||
def create_csv_with_example_data() -> None:
|
||||
"""
|
||||
Creates CSV files with example data, organized by object type.
|
||||
"""
|
||||
example_data: dict[str, list[dict]] = {
|
||||
"Account": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[0],
|
||||
"Name": "Acme Inc.",
|
||||
"BillingCity": "New York",
|
||||
"Industry": "Technology",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[1],
|
||||
"Name": "Globex Corp",
|
||||
"BillingCity": "Los Angeles",
|
||||
"Industry": "Manufacturing",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[2],
|
||||
"Name": "Initech",
|
||||
"BillingCity": "Austin",
|
||||
"Industry": "Software",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[3],
|
||||
"Name": "TechCorp Solutions",
|
||||
"BillingCity": "San Francisco",
|
||||
"Industry": "Software",
|
||||
"AnnualRevenue": 5000000,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[4],
|
||||
"Name": "BioMed Research",
|
||||
"BillingCity": "Boston",
|
||||
"Industry": "Healthcare",
|
||||
"AnnualRevenue": 12000000,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[5],
|
||||
"Name": "Green Energy Co",
|
||||
"BillingCity": "Portland",
|
||||
"Industry": "Energy",
|
||||
"AnnualRevenue": 8000000,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[6],
|
||||
"Name": "DataFlow Analytics",
|
||||
"BillingCity": "Seattle",
|
||||
"Industry": "Technology",
|
||||
"AnnualRevenue": 3000000,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[7],
|
||||
"Name": "Cloud Nine Services",
|
||||
"BillingCity": "Denver",
|
||||
"Industry": "Cloud Computing",
|
||||
"AnnualRevenue": 7000000,
|
||||
},
|
||||
],
|
||||
"Contact": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[40],
|
||||
"FirstName": "John",
|
||||
"LastName": "Doe",
|
||||
"Email": "john.doe@acme.com",
|
||||
"Title": "CEO",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[41],
|
||||
"FirstName": "Jane",
|
||||
"LastName": "Smith",
|
||||
"Email": "jane.smith@acme.com",
|
||||
"Title": "CTO",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[42],
|
||||
"FirstName": "Bob",
|
||||
"LastName": "Johnson",
|
||||
"Email": "bob.j@globex.com",
|
||||
"Title": "Sales Director",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[43],
|
||||
"FirstName": "Sarah",
|
||||
"LastName": "Chen",
|
||||
"Email": "sarah.chen@techcorp.com",
|
||||
"Title": "Product Manager",
|
||||
"Phone": "415-555-0101",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[44],
|
||||
"FirstName": "Michael",
|
||||
"LastName": "Rodriguez",
|
||||
"Email": "m.rodriguez@biomed.com",
|
||||
"Title": "Research Director",
|
||||
"Phone": "617-555-0202",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[45],
|
||||
"FirstName": "Emily",
|
||||
"LastName": "Green",
|
||||
"Email": "emily.g@greenenergy.com",
|
||||
"Title": "Sustainability Lead",
|
||||
"Phone": "503-555-0303",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[46],
|
||||
"FirstName": "David",
|
||||
"LastName": "Kim",
|
||||
"Email": "david.kim@dataflow.com",
|
||||
"Title": "Data Scientist",
|
||||
"Phone": "206-555-0404",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[47],
|
||||
"FirstName": "Rachel",
|
||||
"LastName": "Taylor",
|
||||
"Email": "r.taylor@cloudnine.com",
|
||||
"Title": "Cloud Architect",
|
||||
"Phone": "303-555-0505",
|
||||
},
|
||||
],
|
||||
"Opportunity": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[62],
|
||||
"Name": "Acme Server Upgrade",
|
||||
"Amount": 50000,
|
||||
"Stage": "Prospecting",
|
||||
"CloseDate": "2024-06-30",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[63],
|
||||
"Name": "Globex Manufacturing Line",
|
||||
"Amount": 150000,
|
||||
"Stage": "Negotiation",
|
||||
"CloseDate": "2024-03-15",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[64],
|
||||
"Name": "Initech Software License",
|
||||
"Amount": 75000,
|
||||
"Stage": "Closed Won",
|
||||
"CloseDate": "2024-01-30",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[65],
|
||||
"Name": "TechCorp AI Implementation",
|
||||
"Amount": 250000,
|
||||
"Stage": "Needs Analysis",
|
||||
"CloseDate": "2024-08-15",
|
||||
"Probability": 60,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[66],
|
||||
"Name": "BioMed Lab Equipment",
|
||||
"Amount": 500000,
|
||||
"Stage": "Value Proposition",
|
||||
"CloseDate": "2024-09-30",
|
||||
"Probability": 75,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[67],
|
||||
"Name": "Green Energy Solar Project",
|
||||
"Amount": 750000,
|
||||
"Stage": "Proposal",
|
||||
"CloseDate": "2024-07-15",
|
||||
"Probability": 80,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[68],
|
||||
"Name": "DataFlow Analytics Platform",
|
||||
"Amount": 180000,
|
||||
"Stage": "Negotiation",
|
||||
"CloseDate": "2024-05-30",
|
||||
"Probability": 90,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[69],
|
||||
"Name": "Cloud Nine Infrastructure",
|
||||
"Amount": 300000,
|
||||
"Stage": "Qualification",
|
||||
"CloseDate": "2024-10-15",
|
||||
"Probability": 40,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Create CSV files for each object type
|
||||
for object_type, records in example_data.items():
|
||||
create_csv_file(object_type, records)
|
||||
|
||||
|
||||
def test_query() -> None:
|
||||
"""
|
||||
Tests querying functionality by verifying:
|
||||
1. All expected Account IDs are found
|
||||
2. Each Account's data matches what was inserted
|
||||
"""
|
||||
# Expected test data for verification
|
||||
expected_accounts: dict[str, dict[str, str | int]] = {
|
||||
_VALID_SALESFORCE_IDS[0]: {
|
||||
"Name": "Acme Inc.",
|
||||
"BillingCity": "New York",
|
||||
"Industry": "Technology",
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[1]: {
|
||||
"Name": "Globex Corp",
|
||||
"BillingCity": "Los Angeles",
|
||||
"Industry": "Manufacturing",
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[2]: {
|
||||
"Name": "Initech",
|
||||
"BillingCity": "Austin",
|
||||
"Industry": "Software",
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[3]: {
|
||||
"Name": "TechCorp Solutions",
|
||||
"BillingCity": "San Francisco",
|
||||
"Industry": "Software",
|
||||
"AnnualRevenue": 5000000,
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[4]: {
|
||||
"Name": "BioMed Research",
|
||||
"BillingCity": "Boston",
|
||||
"Industry": "Healthcare",
|
||||
"AnnualRevenue": 12000000,
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[5]: {
|
||||
"Name": "Green Energy Co",
|
||||
"BillingCity": "Portland",
|
||||
"Industry": "Energy",
|
||||
"AnnualRevenue": 8000000,
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[6]: {
|
||||
"Name": "DataFlow Analytics",
|
||||
"BillingCity": "Seattle",
|
||||
"Industry": "Technology",
|
||||
"AnnualRevenue": 3000000,
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[7]: {
|
||||
"Name": "Cloud Nine Services",
|
||||
"BillingCity": "Denver",
|
||||
"Industry": "Cloud Computing",
|
||||
"AnnualRevenue": 7000000,
|
||||
},
|
||||
}
|
||||
|
||||
# Get all Account IDs
|
||||
account_ids = find_ids_by_type("Account")
|
||||
|
||||
# Verify we found all expected accounts
|
||||
assert len(account_ids) == len(
|
||||
expected_accounts
|
||||
), f"Expected {len(expected_accounts)} accounts, found {len(account_ids)}"
|
||||
assert set(account_ids) == set(
|
||||
expected_accounts.keys()
|
||||
), "Found account IDs don't match expected IDs"
|
||||
|
||||
# Verify each account's data
|
||||
for acc_id in account_ids:
|
||||
combined = get_record(acc_id)
|
||||
assert combined is not None, f"Could not find account {acc_id}"
|
||||
|
||||
expected = expected_accounts[acc_id]
|
||||
|
||||
# Verify account data matches
|
||||
for key, value in expected.items():
|
||||
value = str(value)
|
||||
assert (
|
||||
combined.data[key] == value
|
||||
), f"Account {acc_id} field {key} expected {value}, got {combined.data[key]}"
|
||||
|
||||
print("All query tests passed successfully!")
|
||||
|
||||
|
||||
def test_upsert() -> None:
|
||||
"""
|
||||
Tests upsert functionality by:
|
||||
1. Updating an existing account
|
||||
2. Creating a new account
|
||||
3. Verifying both operations were successful
|
||||
"""
|
||||
# Create CSV for updating an existing account and adding a new one
|
||||
update_data: list[dict[str, str | int]] = [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[0],
|
||||
"Name": "Acme Inc. Updated",
|
||||
"BillingCity": "New York",
|
||||
"Industry": "Technology",
|
||||
"Description": "Updated company info",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[2],
|
||||
"Name": "New Company Inc.",
|
||||
"BillingCity": "Miami",
|
||||
"Industry": "Finance",
|
||||
"AnnualRevenue": 1000000,
|
||||
},
|
||||
]
|
||||
|
||||
create_csv_file("Account", update_data, "update_data.csv")
|
||||
|
||||
# Verify the update worked
|
||||
updated_record = get_record(_VALID_SALESFORCE_IDS[0])
|
||||
assert updated_record is not None, "Updated record not found"
|
||||
assert updated_record.data["Name"] == "Acme Inc. Updated", "Name not updated"
|
||||
assert (
|
||||
updated_record.data["Description"] == "Updated company info"
|
||||
), "Description not added"
|
||||
|
||||
# Verify the new record was created
|
||||
new_record = get_record(_VALID_SALESFORCE_IDS[2])
|
||||
assert new_record is not None, "New record not found"
|
||||
assert new_record.data["Name"] == "New Company Inc.", "New record name incorrect"
|
||||
assert new_record.data["AnnualRevenue"] == "1000000", "New record revenue incorrect"
|
||||
|
||||
print("All upsert tests passed successfully!")
|
||||
|
||||
|
||||
def test_relationships() -> None:
|
||||
"""
|
||||
Tests relationship shelf updates and queries by:
|
||||
1. Creating test data with relationships
|
||||
2. Verifying the relationships are correctly stored
|
||||
3. Testing relationship queries
|
||||
"""
|
||||
# Create test data for each object type
|
||||
test_data: dict[str, list[dict[str, str | int]]] = {
|
||||
"Case": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[13],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"Subject": "Test Case 1",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[14],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"Subject": "Test Case 2",
|
||||
},
|
||||
],
|
||||
"Contact": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[48],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"FirstName": "Test",
|
||||
"LastName": "Contact",
|
||||
}
|
||||
],
|
||||
"Opportunity": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[62],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"Name": "Test Opportunity",
|
||||
"Amount": 100000,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Create and update CSV files for each object type
|
||||
for object_type, records in test_data.items():
|
||||
create_csv_file(object_type, records, "relationship_test.csv")
|
||||
|
||||
# Test relationship queries
|
||||
# All these objects should be children of Acme Inc.
|
||||
child_ids = get_child_ids(_VALID_SALESFORCE_IDS[0])
|
||||
assert len(child_ids) == 4, f"Expected 4 child objects, found {len(child_ids)}"
|
||||
assert _VALID_SALESFORCE_IDS[13] in child_ids, "Case 1 not found in relationship"
|
||||
assert _VALID_SALESFORCE_IDS[14] in child_ids, "Case 2 not found in relationship"
|
||||
assert _VALID_SALESFORCE_IDS[48] in child_ids, "Contact not found in relationship"
|
||||
assert (
|
||||
_VALID_SALESFORCE_IDS[62] in child_ids
|
||||
), "Opportunity not found in relationship"
|
||||
|
||||
# Test querying relationships for a different account (should be empty)
|
||||
other_account_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
|
||||
assert (
|
||||
len(other_account_children) == 0
|
||||
), "Expected no children for different account"
|
||||
|
||||
print("All relationship tests passed successfully!")
|
||||
|
||||
|
||||
def test_account_with_children() -> None:
|
||||
"""
|
||||
Tests querying all accounts and retrieving their child objects.
|
||||
This test verifies that:
|
||||
1. All accounts can be retrieved
|
||||
2. Child objects are correctly linked
|
||||
3. Child object data is complete and accurate
|
||||
"""
|
||||
# First get all account IDs
|
||||
account_ids = find_ids_by_type("Account")
|
||||
assert len(account_ids) > 0, "No accounts found"
|
||||
|
||||
# For each account, get its children and verify the data
|
||||
for account_id in account_ids:
|
||||
account = get_record(account_id)
|
||||
assert account is not None, f"Could not find account {account_id}"
|
||||
|
||||
# Get all child objects
|
||||
child_ids = get_child_ids(account_id)
|
||||
|
||||
# For Acme Inc., verify specific relationships
|
||||
if account_id == _VALID_SALESFORCE_IDS[0]: # Acme Inc.
|
||||
assert (
|
||||
len(child_ids) == 4
|
||||
), f"Expected 4 children for Acme Inc., found {len(child_ids)}"
|
||||
|
||||
# Get all child records
|
||||
child_records = []
|
||||
for child_id in child_ids:
|
||||
child_record = get_record(child_id)
|
||||
if child_record is not None:
|
||||
child_records.append(child_record)
|
||||
# Verify Cases
|
||||
cases = [r for r in child_records if r.type == "Case"]
|
||||
assert (
|
||||
len(cases) == 2
|
||||
), f"Expected 2 cases for Acme Inc., found {len(cases)}"
|
||||
case_subjects = {case.data["Subject"] for case in cases}
|
||||
assert "Test Case 1" in case_subjects, "Test Case 1 not found"
|
||||
assert "Test Case 2" in case_subjects, "Test Case 2 not found"
|
||||
|
||||
# Verify Contacts
|
||||
contacts = [r for r in child_records if r.type == "Contact"]
|
||||
assert (
|
||||
len(contacts) == 1
|
||||
), f"Expected 1 contact for Acme Inc., found {len(contacts)}"
|
||||
contact = contacts[0]
|
||||
assert contact.data["FirstName"] == "Test", "Contact first name mismatch"
|
||||
assert contact.data["LastName"] == "Contact", "Contact last name mismatch"
|
||||
|
||||
# Verify Opportunities
|
||||
opportunities = [r for r in child_records if r.type == "Opportunity"]
|
||||
assert (
|
||||
len(opportunities) == 1
|
||||
), f"Expected 1 opportunity for Acme Inc., found {len(opportunities)}"
|
||||
opportunity = opportunities[0]
|
||||
assert (
|
||||
opportunity.data["Name"] == "Test Opportunity"
|
||||
), "Opportunity name mismatch"
|
||||
assert opportunity.data["Amount"] == "100000", "Opportunity amount mismatch"
|
||||
|
||||
print("All account with children tests passed successfully!")
|
||||
|
||||
|
||||
def test_relationship_updates() -> None:
|
||||
"""
|
||||
Tests that relationships are properly updated when a child object's parent reference changes.
|
||||
This test verifies:
|
||||
1. Initial relationship is created correctly
|
||||
2. When parent reference is updated, old relationship is removed
|
||||
3. New relationship is created correctly
|
||||
"""
|
||||
# Create initial test data - Contact linked to Acme Inc.
|
||||
initial_contact = [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[40],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"FirstName": "Test",
|
||||
"LastName": "Contact",
|
||||
}
|
||||
]
|
||||
create_csv_file("Contact", initial_contact, "initial_contact.csv")
|
||||
|
||||
# Verify initial relationship
|
||||
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
|
||||
assert (
|
||||
_VALID_SALESFORCE_IDS[40] in acme_children
|
||||
), "Initial relationship not created"
|
||||
|
||||
# Update contact to be linked to Globex Corp instead
|
||||
updated_contact = [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[40],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[1],
|
||||
"FirstName": "Test",
|
||||
"LastName": "Contact",
|
||||
}
|
||||
]
|
||||
create_csv_file("Contact", updated_contact, "updated_contact.csv")
|
||||
|
||||
# Verify old relationship is removed
|
||||
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
|
||||
assert (
|
||||
_VALID_SALESFORCE_IDS[40] not in acme_children
|
||||
), "Old relationship not removed"
|
||||
|
||||
# Verify new relationship is created
|
||||
globex_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
|
||||
assert _VALID_SALESFORCE_IDS[40] in globex_children, "New relationship not created"
|
||||
|
||||
print("All relationship update tests passed successfully!")
|
||||
|
||||
|
||||
def test_get_affected_parent_ids() -> None:
|
||||
"""
|
||||
Tests get_affected_parent_ids functionality by verifying:
|
||||
1. IDs that are directly in the parent_types list are included
|
||||
2. IDs that have children in the updated_ids list are included
|
||||
3. IDs that are neither of the above are not included
|
||||
"""
|
||||
# Create test data with relationships
|
||||
test_data = {
|
||||
"Account": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[0],
|
||||
"Name": "Parent Account 1",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[1],
|
||||
"Name": "Parent Account 2",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[2],
|
||||
"Name": "Not Affected Account",
|
||||
},
|
||||
],
|
||||
"Contact": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[40],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"FirstName": "Child",
|
||||
"LastName": "Contact",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Create and update CSV files for test data
|
||||
for object_type, records in test_data.items():
|
||||
create_csv_file(object_type, records)
|
||||
|
||||
# Test Case 1: Account directly in updated_ids and parent_types
|
||||
updated_ids = {_VALID_SALESFORCE_IDS[1]} # Parent Account 2
|
||||
parent_types = ["Account"]
|
||||
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
|
||||
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
|
||||
|
||||
# Test Case 2: Account with child in updated_ids
|
||||
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
|
||||
parent_types = ["Account"]
|
||||
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
|
||||
assert (
|
||||
_VALID_SALESFORCE_IDS[0] in affected_ids
|
||||
), "Parent of updated child not included"
|
||||
|
||||
# Test Case 3: Both direct and indirect affects
|
||||
updated_ids = {_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]} # Both cases
|
||||
parent_types = ["Account"]
|
||||
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
|
||||
assert len(affected_ids) == 2, "Expected exactly two affected parent IDs"
|
||||
assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included"
|
||||
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
|
||||
assert (
|
||||
_VALID_SALESFORCE_IDS[2] not in affected_ids
|
||||
), "Unaffected ID incorrectly included"
|
||||
|
||||
# Test Case 4: No matches
|
||||
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
|
||||
parent_types = ["Opportunity"] # Wrong type
|
||||
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
|
||||
assert len(affected_ids) == 0, "Should return empty list when no matches"
|
||||
|
||||
print("All get_affected_parent_ids tests passed successfully!")
|
||||
|
||||
|
||||
def main_build() -> None:
|
||||
clear_sf_db()
|
||||
create_csv_with_example_data()
|
||||
test_query()
|
||||
test_upsert()
|
||||
test_relationships()
|
||||
test_account_with_children()
|
||||
test_relationship_updates()
|
||||
test_get_affected_parent_ids()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main_build()
|
||||
386
backend/onyx/connectors/salesforce/sqlite_functions.py
Normal file
386
backend/onyx/connectors/salesforce/sqlite_functions.py
Normal file
@@ -0,0 +1,386 @@
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
|
||||
from onyx.connectors.salesforce.utils import get_sqlite_db_path
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.connectors.salesforce.utils import validate_salesforce_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.utils import batch_list
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_connection(
|
||||
isolation_level: str | None = None,
|
||||
) -> Iterator[sqlite3.Connection]:
|
||||
"""Get a database connection with proper isolation level and error handling.
|
||||
|
||||
Args:
|
||||
isolation_level: SQLite isolation level. None = default "DEFERRED",
|
||||
can be "IMMEDIATE" or "EXCLUSIVE" for more strict isolation.
|
||||
"""
|
||||
# 60 second timeout for locks
|
||||
conn = sqlite3.connect(get_sqlite_db_path(), timeout=60.0)
|
||||
|
||||
if isolation_level is not None:
|
||||
conn.isolation_level = isolation_level
|
||||
try:
|
||||
yield conn
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def init_db() -> None:
|
||||
"""Initialize the SQLite database with required tables if they don't exist."""
|
||||
if os.path.exists(get_sqlite_db_path()):
|
||||
return
|
||||
|
||||
# Create database directory if it doesn't exist
|
||||
os.makedirs(os.path.dirname(get_sqlite_db_path()), exist_ok=True)
|
||||
|
||||
with get_db_connection("EXCLUSIVE") as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Enable WAL mode for better concurrent access and write performance
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.execute("PRAGMA synchronous=NORMAL")
|
||||
cursor.execute("PRAGMA temp_store=MEMORY")
|
||||
cursor.execute("PRAGMA cache_size=-2000000") # Use 2GB memory for cache
|
||||
|
||||
# Main table for storing Salesforce objects
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS salesforce_objects (
|
||||
id TEXT PRIMARY KEY,
|
||||
object_type TEXT NOT NULL,
|
||||
data TEXT NOT NULL, -- JSON serialized data
|
||||
last_modified INTEGER DEFAULT (strftime('%s', 'now')) -- Add timestamp for better cache management
|
||||
) WITHOUT ROWID -- Optimize for primary key lookups
|
||||
"""
|
||||
)
|
||||
|
||||
# Table for parent-child relationships with covering index
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS relationships (
|
||||
child_id TEXT NOT NULL,
|
||||
parent_id TEXT NOT NULL,
|
||||
PRIMARY KEY (child_id, parent_id)
|
||||
) WITHOUT ROWID -- Optimize for primary key lookups
|
||||
"""
|
||||
)
|
||||
|
||||
# New table for caching parent-child relationships with object types
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS relationship_types (
|
||||
child_id TEXT NOT NULL,
|
||||
parent_id TEXT NOT NULL,
|
||||
parent_type TEXT NOT NULL,
|
||||
PRIMARY KEY (child_id, parent_id, parent_type)
|
||||
) WITHOUT ROWID
|
||||
"""
|
||||
)
|
||||
|
||||
# Always recreate indexes to ensure they exist
|
||||
cursor.execute("DROP INDEX IF EXISTS idx_object_type")
|
||||
cursor.execute("DROP INDEX IF EXISTS idx_parent_id")
|
||||
cursor.execute("DROP INDEX IF EXISTS idx_child_parent")
|
||||
cursor.execute("DROP INDEX IF EXISTS idx_object_type_id")
|
||||
cursor.execute("DROP INDEX IF EXISTS idx_relationship_types_lookup")
|
||||
|
||||
# Create covering indexes for common queries
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX idx_object_type
|
||||
ON salesforce_objects(object_type, id)
|
||||
WHERE object_type IS NOT NULL
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX idx_parent_id
|
||||
ON relationships(parent_id, child_id)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX idx_child_parent
|
||||
ON relationships(child_id)
|
||||
WHERE child_id IS NOT NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# New composite index for fast parent type lookups
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX idx_relationship_types_lookup
|
||||
ON relationship_types(parent_type, child_id, parent_id)
|
||||
"""
|
||||
)
|
||||
|
||||
# Analyze tables to help query planner
|
||||
cursor.execute("ANALYZE relationships")
|
||||
cursor.execute("ANALYZE salesforce_objects")
|
||||
cursor.execute("ANALYZE relationship_types")
|
||||
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _update_relationship_tables(
|
||||
conn: sqlite3.Connection, child_id: str, parent_ids: set[str]
|
||||
) -> None:
|
||||
"""Update the relationship tables when a record is updated.
|
||||
|
||||
Args:
|
||||
conn: The database connection to use (must be in a transaction)
|
||||
child_id: The ID of the child record
|
||||
parent_ids: Set of parent IDs to link to
|
||||
"""
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get existing parent IDs
|
||||
cursor.execute(
|
||||
"SELECT parent_id FROM relationships WHERE child_id = ?", (child_id,)
|
||||
)
|
||||
old_parent_ids = {row[0] for row in cursor.fetchall()}
|
||||
|
||||
# Calculate differences
|
||||
parent_ids_to_remove = old_parent_ids - parent_ids
|
||||
parent_ids_to_add = parent_ids - old_parent_ids
|
||||
|
||||
# Remove old relationships
|
||||
if parent_ids_to_remove:
|
||||
cursor.executemany(
|
||||
"DELETE FROM relationships WHERE child_id = ? AND parent_id = ?",
|
||||
[(child_id, pid) for pid in parent_ids_to_remove],
|
||||
)
|
||||
# Also remove from relationship_types
|
||||
cursor.executemany(
|
||||
"DELETE FROM relationship_types WHERE child_id = ? AND parent_id = ?",
|
||||
[(child_id, pid) for pid in parent_ids_to_remove],
|
||||
)
|
||||
|
||||
# Add new relationships
|
||||
if parent_ids_to_add:
|
||||
# First add to relationships table
|
||||
cursor.executemany(
|
||||
"INSERT INTO relationships (child_id, parent_id) VALUES (?, ?)",
|
||||
[(child_id, pid) for pid in parent_ids_to_add],
|
||||
)
|
||||
|
||||
# Then get the types of the parent objects and add to relationship_types
|
||||
for parent_id in parent_ids_to_add:
|
||||
cursor.execute(
|
||||
"SELECT object_type FROM salesforce_objects WHERE id = ?",
|
||||
(parent_id,),
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
parent_type = result[0]
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO relationship_types (child_id, parent_id, parent_type)
|
||||
VALUES (?, ?, ?)
|
||||
""",
|
||||
(child_id, parent_id, parent_type),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating relationship tables: {e}")
|
||||
logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
|
||||
raise
|
||||
|
||||
|
||||
def update_sf_db_with_csv(object_type: str, csv_download_path: str) -> list[str]:
|
||||
"""Update the SF DB with a CSV file using SQLite storage."""
|
||||
updated_ids = []
|
||||
|
||||
# Use IMMEDIATE to get a write lock at the start of the transaction
|
||||
with get_db_connection("IMMEDIATE") as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
if "Id" not in row:
|
||||
logger.warning(
|
||||
f"Row {row} does not have an Id field in {csv_download_path}"
|
||||
)
|
||||
continue
|
||||
id = row["Id"]
|
||||
parent_ids = set()
|
||||
field_to_remove: set[str] = set()
|
||||
|
||||
# Process relationships and clean data
|
||||
for field, value in row.items():
|
||||
if validate_salesforce_id(value) and field != "Id":
|
||||
parent_ids.add(value)
|
||||
field_to_remove.add(field)
|
||||
if not value:
|
||||
field_to_remove.add(field)
|
||||
|
||||
# Remove unwanted fields
|
||||
for field in field_to_remove:
|
||||
if field != "LastModifiedById":
|
||||
del row[field]
|
||||
|
||||
# Update main object data
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO salesforce_objects (id, object_type, data)
|
||||
VALUES (?, ?, ?)
|
||||
""",
|
||||
(id, object_type, json.dumps(row)),
|
||||
)
|
||||
|
||||
# Update relationships using the same connection
|
||||
_update_relationship_tables(conn, id, parent_ids)
|
||||
updated_ids.append(id)
|
||||
|
||||
conn.commit()
|
||||
|
||||
return updated_ids
|
||||
|
||||
|
||||
def get_child_ids(parent_id: str) -> set[str]:
|
||||
"""Get all child IDs for a given parent ID."""
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Force index usage with INDEXED BY
|
||||
cursor.execute(
|
||||
"SELECT child_id FROM relationships INDEXED BY idx_parent_id WHERE parent_id = ?",
|
||||
(parent_id,),
|
||||
)
|
||||
child_ids = {row[0] for row in cursor.fetchall()}
|
||||
return child_ids
|
||||
|
||||
|
||||
def get_type_from_id(object_id: str) -> str | None:
|
||||
"""Get the type of an object from its ID."""
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT object_type FROM salesforce_objects WHERE id = ?", (object_id,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if not result:
|
||||
logger.warning(f"Object ID {object_id} not found")
|
||||
return None
|
||||
return result[0]
|
||||
|
||||
|
||||
def get_record(
|
||||
object_id: str, object_type: str | None = None
|
||||
) -> SalesforceObject | None:
|
||||
"""Retrieve the record and return it as a SalesforceObject."""
|
||||
if object_type is None:
|
||||
object_type = get_type_from_id(object_id)
|
||||
if not object_type:
|
||||
return None
|
||||
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT data FROM salesforce_objects WHERE id = ?", (object_id,))
|
||||
result = cursor.fetchone()
|
||||
if not result:
|
||||
logger.warning(f"Object ID {object_id} not found")
|
||||
return None
|
||||
|
||||
data = json.loads(result[0])
|
||||
return SalesforceObject(id=object_id, type=object_type, data=data)
|
||||
|
||||
|
||||
def find_ids_by_type(object_type: str) -> list[str]:
|
||||
"""Find all object IDs for rows of the specified type."""
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT id FROM salesforce_objects WHERE object_type = ?", (object_type,)
|
||||
)
|
||||
return [row[0] for row in cursor.fetchall()]
|
||||
|
||||
|
||||
def get_affected_parent_ids_by_type(
|
||||
updated_ids: list[str],
|
||||
parent_types: list[str],
|
||||
batch_size: int = 500,
|
||||
) -> Iterator[tuple[str, set[str]]]:
|
||||
"""Get IDs of objects that are of the specified parent types and are either in the
|
||||
updated_ids or have children in the updated_ids. Yields tuples of (parent_type, affected_ids).
|
||||
"""
|
||||
# SQLite typically has a limit of 999 variables
|
||||
updated_ids_batches = batch_list(updated_ids, batch_size)
|
||||
updated_parent_ids: set[str] = set()
|
||||
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
for batch_ids in updated_ids_batches:
|
||||
id_placeholders = ",".join(["?" for _ in batch_ids])
|
||||
|
||||
for parent_type in parent_types:
|
||||
affected_ids: set[str] = set()
|
||||
|
||||
# Get directly updated objects of parent types - using index on object_type
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT id FROM salesforce_objects
|
||||
WHERE id IN ({id_placeholders})
|
||||
AND object_type = ?
|
||||
""",
|
||||
batch_ids + [parent_type],
|
||||
)
|
||||
affected_ids.update(row[0] for row in cursor.fetchall())
|
||||
|
||||
# Get parent objects of updated objects - using optimized relationship_types table
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT DISTINCT parent_id
|
||||
FROM relationship_types
|
||||
INDEXED BY idx_relationship_types_lookup
|
||||
WHERE parent_type = ?
|
||||
AND child_id IN ({id_placeholders})
|
||||
""",
|
||||
[parent_type] + batch_ids,
|
||||
)
|
||||
affected_ids.update(row[0] for row in cursor.fetchall())
|
||||
|
||||
# Remove any parent IDs that have already been processed
|
||||
new_affected_ids = affected_ids - updated_parent_ids
|
||||
# Add the new affected IDs to the set of updated parent IDs
|
||||
updated_parent_ids.update(new_affected_ids)
|
||||
|
||||
if new_affected_ids:
|
||||
yield parent_type, new_affected_ids
|
||||
|
||||
|
||||
def has_at_least_one_object_of_type(object_type: str) -> bool:
|
||||
"""Check if there is at least one object of the specified type in the database.
|
||||
|
||||
Args:
|
||||
object_type: The Salesforce object type to check
|
||||
|
||||
Returns:
|
||||
bool: True if at least one object exists, False otherwise
|
||||
"""
|
||||
with get_db_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT COUNT(*) FROM salesforce_objects WHERE object_type = ?",
|
||||
(object_type,),
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
return count > 0
|
||||
@@ -1,66 +1,72 @@
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _clean_salesforce_dict(data: Union[dict, list]) -> Union[dict, list]:
|
||||
if isinstance(data, dict):
|
||||
if "records" in data.keys():
|
||||
data = data["records"]
|
||||
if isinstance(data, dict):
|
||||
if "attributes" in data.keys():
|
||||
if isinstance(data["attributes"], dict):
|
||||
data.update(data.pop("attributes"))
|
||||
@dataclass
|
||||
class SalesforceObject:
|
||||
id: str
|
||||
type: str
|
||||
data: dict[str, Any]
|
||||
|
||||
if isinstance(data, dict):
|
||||
filtered_dict = {}
|
||||
for key, value in data.items():
|
||||
if not re.search(SF_JSON_FILTER, key, re.IGNORECASE):
|
||||
if "__c" in key: # remove the custom object indicator for display
|
||||
key = key[:-3]
|
||||
if isinstance(value, (dict, list)):
|
||||
filtered_value = _clean_salesforce_dict(value)
|
||||
if filtered_value: # Only add non-empty dictionaries or lists
|
||||
filtered_dict[key] = filtered_value
|
||||
elif value is not None:
|
||||
filtered_dict[key] = value
|
||||
return filtered_dict
|
||||
elif isinstance(data, list):
|
||||
filtered_list = []
|
||||
for item in data:
|
||||
if isinstance(item, (dict, list)):
|
||||
filtered_item = _clean_salesforce_dict(item)
|
||||
if filtered_item: # Only add non-empty dictionaries or lists
|
||||
filtered_list.append(filtered_item)
|
||||
elif item is not None:
|
||||
filtered_list.append(filtered_item)
|
||||
return filtered_list
|
||||
else:
|
||||
return data
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"ID": self.id,
|
||||
"Type": self.type,
|
||||
"Data": self.data,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SalesforceObject":
|
||||
return cls(
|
||||
id=data["Id"],
|
||||
type=data["Type"],
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
||||
def _json_to_natural_language(data: Union[dict, list], indent: int = 0) -> str:
|
||||
result = []
|
||||
indent_str = " " * indent
|
||||
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
result.append(f"{indent_str}{key}:")
|
||||
result.append(_json_to_natural_language(value, indent + 2))
|
||||
else:
|
||||
result.append(f"{indent_str}{key}: {value}")
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
result.append(_json_to_natural_language(item, indent))
|
||||
else:
|
||||
result.append(f"{indent_str}{data}")
|
||||
|
||||
return "\n".join(result)
|
||||
# This defines the base path for all data files relative to this file
|
||||
# AKA BE CAREFUL WHEN MOVING THIS FILE
|
||||
BASE_DATA_PATH = os.path.join(os.path.dirname(__file__), "data")
|
||||
|
||||
|
||||
def extract_dict_text(raw_dict: dict) -> str:
|
||||
processed_dict = _clean_salesforce_dict(raw_dict)
|
||||
natural_language_dict = _json_to_natural_language(processed_dict)
|
||||
return natural_language_dict
|
||||
def get_sqlite_db_path() -> str:
|
||||
"""Get the path to the sqlite db file."""
|
||||
return os.path.join(BASE_DATA_PATH, "salesforce_db.sqlite")
|
||||
|
||||
|
||||
def get_object_type_path(object_type: str) -> str:
|
||||
"""Get the directory path for a specific object type."""
|
||||
type_dir = os.path.join(BASE_DATA_PATH, object_type)
|
||||
os.makedirs(type_dir, exist_ok=True)
|
||||
return type_dir
|
||||
|
||||
|
||||
_CHECKSUM_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"
|
||||
_LOOKUP = {format(i, "05b"): _CHECKSUM_CHARS[i] for i in range(32)}
|
||||
|
||||
|
||||
def validate_salesforce_id(salesforce_id: str) -> bool:
|
||||
"""Validate the checksum portion of an 18-character Salesforce ID.
|
||||
|
||||
Args:
|
||||
salesforce_id: An 18-character Salesforce ID
|
||||
|
||||
Returns:
|
||||
bool: True if the checksum is valid, False otherwise
|
||||
"""
|
||||
if len(salesforce_id) != 18:
|
||||
return False
|
||||
|
||||
chunks = [salesforce_id[0:5], salesforce_id[5:10], salesforce_id[10:15]]
|
||||
|
||||
checksum = salesforce_id[15:18]
|
||||
calculated_checksum = ""
|
||||
|
||||
for chunk in chunks:
|
||||
result_string = "".join(
|
||||
"1" if char.isupper() else "0" for char in reversed(chunk)
|
||||
)
|
||||
calculated_checksum += _LOOKUP[result_string]
|
||||
|
||||
return checksum == calculated_checksum
|
||||
|
||||
@@ -264,24 +264,6 @@ class SlackTextCleaner:
|
||||
message = message.replace("<!everyone>", "@everyone")
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def replace_links(message: str) -> str:
|
||||
"""Replaces slack links e.g. `<URL>` -> `URL` and `<URL|DISPLAY>` -> `DISPLAY`"""
|
||||
# Find user IDs in the message
|
||||
possible_link_matches = re.findall(r"<(.*?)>", message)
|
||||
for possible_link in possible_link_matches:
|
||||
if not possible_link:
|
||||
continue
|
||||
# Special slack patterns that aren't for links
|
||||
if possible_link[0] not in ["#", "@", "!"]:
|
||||
link_display = (
|
||||
possible_link
|
||||
if "|" not in possible_link
|
||||
else possible_link.split("|")[1]
|
||||
)
|
||||
message = message.replace(f"<{possible_link}>", link_display)
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def replace_special_catchall(message: str) -> str:
|
||||
"""Replaces pattern of <!something|another-thing> with another-thing
|
||||
|
||||
@@ -33,6 +33,7 @@ from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.html_utils import web_html_cleanup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.sitemap import list_pages_for_site
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -241,6 +242,12 @@ class WebConnector(LoadConnector):
|
||||
self.to_visit_list = extract_urls_from_sitemap(_ensure_valid_url(base_url))
|
||||
|
||||
elif web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.UPLOAD:
|
||||
# Explicitly check if running in multi-tenant mode to prevent potential security risks
|
||||
if MULTI_TENANT:
|
||||
raise ValueError(
|
||||
"Upload input for web connector is not supported in cloud environments"
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
"This is not a UI supported Web Connector flow, "
|
||||
"are you sure you want to do this?"
|
||||
|
||||
@@ -40,6 +40,13 @@ class ZendeskClient:
|
||||
response = requests.get(
|
||||
f"{self.base_url}/{endpoint}", auth=self.auth, params=params
|
||||
)
|
||||
|
||||
if response.status_code == 429:
|
||||
retry_after = response.headers.get("Retry-After")
|
||||
if retry_after is not None:
|
||||
# Sleep for the duration indicated by the Retry-After header
|
||||
time.sleep(int(retry_after))
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
@@ -96,6 +96,8 @@ class Tag(BaseModel):
|
||||
class BaseFilters(BaseModel):
|
||||
source_type: list[DocumentSource] | None = None
|
||||
document_set: list[str] | None = None
|
||||
user_folders: list[str] | None = None
|
||||
document_ids: list[str] | None = None
|
||||
time_cutoff: datetime | None = None
|
||||
tags: list[Tag] | None = None
|
||||
|
||||
|
||||
@@ -54,9 +54,11 @@ def get_total_users_count(db_session: Session) -> int:
|
||||
return user_count + invited_users
|
||||
|
||||
|
||||
async def get_user_count() -> int:
|
||||
async def get_user_count(only_admin_users: bool = False) -> int:
|
||||
async with get_async_session_with_tenant() as session:
|
||||
stmt = select(func.count(User.id))
|
||||
if only_admin_users:
|
||||
stmt = stmt.where(User.role == UserRole.ADMIN)
|
||||
result = await session.execute(stmt)
|
||||
user_count = result.scalar()
|
||||
if user_count is None:
|
||||
|
||||
@@ -141,14 +141,20 @@ def get_valid_messages_from_query_sessions(
|
||||
return {row.chat_session_id: row.message for row in first_messages}
|
||||
|
||||
|
||||
# Retrieves chat sessions by user
|
||||
# Chat sessions do not include onyxbot flows
|
||||
def get_chat_sessions_by_user(
|
||||
user_id: UUID | None,
|
||||
deleted: bool | None,
|
||||
db_session: Session,
|
||||
include_onyxbot_flows: bool = False,
|
||||
limit: int = 50,
|
||||
) -> list[ChatSession]:
|
||||
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
|
||||
|
||||
if not include_onyxbot_flows:
|
||||
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
||||
|
||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
||||
|
||||
if deleted is not None:
|
||||
@@ -310,6 +316,23 @@ def update_chat_session(
|
||||
return chat_session
|
||||
|
||||
|
||||
def delete_all_chat_sessions_for_user(
|
||||
user: User | None, db_session: Session, hard_delete: bool = HARD_DELETE_CHATS
|
||||
) -> None:
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
query = db_session.query(ChatSession).filter(
|
||||
ChatSession.user_id == user_id, ChatSession.onyxbot_flow.is_(False)
|
||||
)
|
||||
|
||||
if hard_delete:
|
||||
query.delete(synchronize_session=False)
|
||||
else:
|
||||
query.update({ChatSession.deleted: True}, synchronize_session=False)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_chat_session(
|
||||
user_id: UUID | None,
|
||||
chat_session_id: UUID,
|
||||
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy import exists
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -90,15 +91,22 @@ def get_connector_credential_pairs(
|
||||
user: User | None = None,
|
||||
get_editable: bool = True,
|
||||
ids: list[int] | None = None,
|
||||
eager_load_connector: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
|
||||
if eager_load_connector:
|
||||
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
|
||||
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
|
||||
if not include_disabled:
|
||||
stmt = stmt.where(
|
||||
ConnectorCredentialPair.status == ConnectorCredentialPairStatus.ACTIVE
|
||||
) # noqa
|
||||
)
|
||||
if ids:
|
||||
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
@@ -310,6 +318,9 @@ def associate_default_cc_pair(db_session: Session) -> None:
|
||||
if existing_association is not None:
|
||||
return
|
||||
|
||||
# DefaultCCPair has id 1 since it is the first CC pair created
|
||||
# It is DEFAULT_CC_PAIR_ID, but can't set it explicitly because it messed with the
|
||||
# auto-incrementing id
|
||||
association = ConnectorCredentialPair(
|
||||
connector_id=0,
|
||||
credential_id=0,
|
||||
@@ -350,7 +361,12 @@ def add_credential_to_connector(
|
||||
last_successful_index_time: datetime | None = None,
|
||||
) -> StatusResponse:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
credential = fetch_credential_by_id(credential_id, user, db_session)
|
||||
credential = fetch_credential_by_id(
|
||||
credential_id,
|
||||
user,
|
||||
db_session,
|
||||
get_editable=False,
|
||||
)
|
||||
|
||||
if connector is None:
|
||||
raise HTTPException(status_code=404, detail="Connector does not exist")
|
||||
@@ -427,7 +443,12 @@ def remove_credential_from_connector(
|
||||
db_session: Session,
|
||||
) -> StatusResponse[int]:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
credential = fetch_credential_by_id(credential_id, user, db_session)
|
||||
credential = fetch_credential_by_id(
|
||||
credential_id,
|
||||
user,
|
||||
db_session,
|
||||
get_editable=False,
|
||||
)
|
||||
|
||||
if connector is None:
|
||||
raise HTTPException(status_code=404, detail="Connector does not exist")
|
||||
|
||||
@@ -86,7 +86,7 @@ def _add_user_filters(
|
||||
"""
|
||||
Filter Credentials by:
|
||||
- if the user is in the user_group that owns the Credential
|
||||
- if the user is not a global_curator, they must also have a curator relationship
|
||||
- if the user is a curator, they must also have a curator relationship
|
||||
to the user_group
|
||||
- if editing is being done, we also filter out Credentials that are owned by groups
|
||||
that the user isn't a curator for
|
||||
@@ -97,6 +97,7 @@ def _add_user_filters(
|
||||
where_clause = User__UserGroup.user_id == user.id
|
||||
if user.role == UserRole.CURATOR:
|
||||
where_clause &= User__UserGroup.is_curator == True # noqa: E712
|
||||
|
||||
if get_editable:
|
||||
user_groups = select(User__UserGroup.user_group_id).where(
|
||||
User__UserGroup.user_id == user.id
|
||||
@@ -152,10 +153,16 @@ def fetch_credential_by_id(
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
assume_admin: bool = False,
|
||||
get_editable: bool = True,
|
||||
) -> Credential | None:
|
||||
stmt = select(Credential).distinct()
|
||||
stmt = stmt.where(Credential.id == credential_id)
|
||||
stmt = _add_user_filters(stmt, user, assume_admin=assume_admin)
|
||||
stmt = _add_user_filters(
|
||||
stmt=stmt,
|
||||
user=user,
|
||||
assume_admin=assume_admin,
|
||||
get_editable=get_editable,
|
||||
)
|
||||
result = db_session.execute(stmt)
|
||||
credential = result.scalar_one_or_none()
|
||||
return credential
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import contextlib
|
||||
import os
|
||||
import re
|
||||
import ssl
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -10,6 +12,8 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import ContextManager
|
||||
|
||||
import asyncpg # type: ignore
|
||||
import boto3
|
||||
import jwt
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
@@ -23,6 +27,7 @@ from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from onyx.configs.app_configs import AWS_REGION_NAME
|
||||
from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
|
||||
from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
|
||||
@@ -37,6 +42,7 @@ from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -49,28 +55,87 @@ logger = setup_logger()
|
||||
SYNC_DB_API = "psycopg2"
|
||||
ASYNC_DB_API = "asyncpg"
|
||||
|
||||
# global so we don't create more than one engine per process
|
||||
# outside of being best practice, this is needed so we can properly pool
|
||||
# connections and not create a new pool on every request
|
||||
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
|
||||
|
||||
# Global so we don't create more than one engine per process
|
||||
_ASYNC_ENGINE: AsyncEngine | None = None
|
||||
SessionFactory: sessionmaker[Session] | None = None
|
||||
|
||||
|
||||
def create_ssl_context_if_iam() -> ssl.SSLContext | None:
|
||||
"""Create an SSL context if IAM authentication is enabled, else return None."""
|
||||
if USE_IAM_AUTH:
|
||||
return ssl.create_default_context(cafile=SSL_CERT_FILE)
|
||||
return None
|
||||
|
||||
|
||||
ssl_context = create_ssl_context_if_iam()
|
||||
|
||||
|
||||
def get_iam_auth_token(
|
||||
host: str, port: str, user: str, region: str = "us-east-2"
|
||||
) -> str:
|
||||
"""
|
||||
Generate an IAM authentication token using boto3.
|
||||
"""
|
||||
client = boto3.client("rds", region_name=region)
|
||||
token = client.generate_db_auth_token(
|
||||
DBHostname=host, Port=int(port), DBUsername=user
|
||||
)
|
||||
return token
|
||||
|
||||
|
||||
def configure_psycopg2_iam_auth(
|
||||
cparams: dict[str, Any], host: str, port: str, user: str, region: str
|
||||
) -> None:
|
||||
"""
|
||||
Configure cparams for psycopg2 with IAM token and SSL.
|
||||
"""
|
||||
token = get_iam_auth_token(host, port, user, region)
|
||||
cparams["password"] = token
|
||||
cparams["sslmode"] = "require"
|
||||
cparams["sslrootcert"] = SSL_CERT_FILE
|
||||
|
||||
|
||||
def build_connection_string(
|
||||
*,
|
||||
db_api: str = ASYNC_DB_API,
|
||||
user: str = POSTGRES_USER,
|
||||
password: str = POSTGRES_PASSWORD,
|
||||
host: str = POSTGRES_HOST,
|
||||
port: str = POSTGRES_PORT,
|
||||
db: str = POSTGRES_DB,
|
||||
app_name: str | None = None,
|
||||
use_iam: bool = USE_IAM_AUTH,
|
||||
region: str = "us-west-2",
|
||||
) -> str:
|
||||
if use_iam:
|
||||
base_conn_str = f"postgresql+{db_api}://{user}@{host}:{port}/{db}"
|
||||
else:
|
||||
base_conn_str = f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
|
||||
|
||||
# For asyncpg, do not include application_name in the connection string
|
||||
if app_name and db_api != "asyncpg":
|
||||
if "?" in base_conn_str:
|
||||
return f"{base_conn_str}&application_name={app_name}"
|
||||
else:
|
||||
return f"{base_conn_str}?application_name={app_name}"
|
||||
return base_conn_str
|
||||
|
||||
|
||||
if LOG_POSTGRES_LATENCY:
|
||||
# Function to log before query execution
|
||||
|
||||
@event.listens_for(Engine, "before_cursor_execute")
|
||||
def before_cursor_execute( # type: ignore
|
||||
conn, cursor, statement, parameters, context, executemany
|
||||
):
|
||||
conn.info["query_start_time"] = time.time()
|
||||
|
||||
# Function to log after query execution
|
||||
@event.listens_for(Engine, "after_cursor_execute")
|
||||
def after_cursor_execute( # type: ignore
|
||||
conn, cursor, statement, parameters, context, executemany
|
||||
):
|
||||
total_time = time.time() - conn.info["query_start_time"]
|
||||
# don't spam TOO hard
|
||||
if total_time > 0.1:
|
||||
logger.debug(
|
||||
f"Query Complete: {statement}\n\nTotal Time: {total_time:.4f} seconds"
|
||||
@@ -78,7 +143,6 @@ if LOG_POSTGRES_LATENCY:
|
||||
|
||||
|
||||
if LOG_POSTGRES_CONN_COUNTS:
|
||||
# Global counter for connection checkouts and checkins
|
||||
checkout_count = 0
|
||||
checkin_count = 0
|
||||
|
||||
@@ -105,21 +169,13 @@ if LOG_POSTGRES_CONN_COUNTS:
|
||||
logger.debug(f"Total connection checkins: {checkin_count}")
|
||||
|
||||
|
||||
"""END DEBUGGING LOGGING"""
|
||||
|
||||
|
||||
def get_db_current_time(db_session: Session) -> datetime:
|
||||
"""Get the current time from Postgres representing the start of the transaction
|
||||
Within the same transaction this value will not update
|
||||
This datetime object returned should be timezone aware, default Postgres timezone is UTC
|
||||
"""
|
||||
result = db_session.execute(text("SELECT NOW()")).scalar()
|
||||
if result is None:
|
||||
raise ValueError("Database did not return a time")
|
||||
return result
|
||||
|
||||
|
||||
# Regular expression to validate schema names to prevent SQL injection
|
||||
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
|
||||
|
||||
|
||||
@@ -128,16 +184,9 @@ def is_valid_schema_name(name: str) -> bool:
|
||||
|
||||
|
||||
class SqlEngine:
|
||||
"""Class to manage a global SQLAlchemy engine (needed for proper resource control).
|
||||
Will eventually subsume most of the standalone functions in this file.
|
||||
Sync only for now.
|
||||
"""
|
||||
|
||||
_engine: Engine | None = None
|
||||
_lock: threading.Lock = threading.Lock()
|
||||
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
|
||||
|
||||
# Default parameters for engine creation
|
||||
DEFAULT_ENGINE_KWARGS = {
|
||||
"pool_size": 20,
|
||||
"max_overflow": 5,
|
||||
@@ -145,33 +194,27 @@ class SqlEngine:
|
||||
"pool_recycle": POSTGRES_POOL_RECYCLE,
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _init_engine(cls, **engine_kwargs: Any) -> Engine:
|
||||
"""Private helper method to create and return an Engine."""
|
||||
connection_string = build_connection_string(
|
||||
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync"
|
||||
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
|
||||
)
|
||||
merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
|
||||
return create_engine(connection_string, **merged_kwargs)
|
||||
engine = create_engine(connection_string, **merged_kwargs)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
event.listen(engine, "do_connect", provide_iam_token)
|
||||
|
||||
return engine
|
||||
|
||||
@classmethod
|
||||
def init_engine(cls, **engine_kwargs: Any) -> None:
|
||||
"""Allow the caller to init the engine with extra params. Different clients
|
||||
such as the API server and different Celery workers and tasks
|
||||
need different settings.
|
||||
"""
|
||||
with cls._lock:
|
||||
if not cls._engine:
|
||||
cls._engine = cls._init_engine(**engine_kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_engine(cls) -> Engine:
|
||||
"""Gets the SQLAlchemy engine. Will init a default engine if init hasn't
|
||||
already been called. You probably want to init first!
|
||||
"""
|
||||
if not cls._engine:
|
||||
with cls._lock:
|
||||
if not cls._engine:
|
||||
@@ -180,12 +223,10 @@ class SqlEngine:
|
||||
|
||||
@classmethod
|
||||
def set_app_name(cls, app_name: str) -> None:
|
||||
"""Class method to set the app name."""
|
||||
cls._app_name = app_name
|
||||
|
||||
@classmethod
|
||||
def get_app_name(cls) -> str:
|
||||
"""Class method to get current app name."""
|
||||
if not cls._app_name:
|
||||
return ""
|
||||
return cls._app_name
|
||||
@@ -217,56 +258,71 @@ def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
for tenant in tenant_ids
|
||||
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
|
||||
]
|
||||
|
||||
return valid_tenants
|
||||
|
||||
|
||||
def build_connection_string(
|
||||
*,
|
||||
db_api: str = ASYNC_DB_API,
|
||||
user: str = POSTGRES_USER,
|
||||
password: str = POSTGRES_PASSWORD,
|
||||
host: str = POSTGRES_HOST,
|
||||
port: str = POSTGRES_PORT,
|
||||
db: str = POSTGRES_DB,
|
||||
app_name: str | None = None,
|
||||
) -> str:
|
||||
if app_name:
|
||||
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"
|
||||
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
|
||||
|
||||
|
||||
def get_sqlalchemy_engine() -> Engine:
|
||||
return SqlEngine.get_engine()
|
||||
|
||||
|
||||
async def get_async_connection() -> Any:
|
||||
"""
|
||||
Custom connection function for async engine when using IAM auth.
|
||||
"""
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
db = POSTGRES_DB
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
|
||||
|
||||
# asyncpg requires 'ssl="require"' if SSL needed
|
||||
return await asyncpg.connect(
|
||||
user=user, password=token, host=host, port=int(port), database=db, ssl="require"
|
||||
)
|
||||
|
||||
|
||||
def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
global _ASYNC_ENGINE
|
||||
if _ASYNC_ENGINE is None:
|
||||
# Underlying asyncpg cannot accept application_name directly in the connection string
|
||||
# https://github.com/MagicStack/asyncpg/issues/798
|
||||
connection_string = build_connection_string()
|
||||
app_name = SqlEngine.get_app_name() + "_async"
|
||||
connection_string = build_connection_string(
|
||||
db_api=ASYNC_DB_API,
|
||||
use_iam=USE_IAM_AUTH,
|
||||
)
|
||||
|
||||
connect_args: dict[str, Any] = {}
|
||||
if app_name:
|
||||
connect_args["server_settings"] = {"application_name": app_name}
|
||||
|
||||
connect_args["ssl"] = ssl_context
|
||||
|
||||
_ASYNC_ENGINE = create_async_engine(
|
||||
connection_string,
|
||||
connect_args={
|
||||
"server_settings": {
|
||||
"application_name": SqlEngine.get_app_name() + "_async"
|
||||
}
|
||||
},
|
||||
# async engine is only used by API server, so we can use those values
|
||||
# here as well
|
||||
connect_args=connect_args,
|
||||
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
|
||||
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
|
||||
pool_pre_ping=POSTGRES_POOL_PRE_PING,
|
||||
pool_recycle=POSTGRES_POOL_RECYCLE,
|
||||
)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
|
||||
@event.listens_for(_ASYNC_ENGINE.sync_engine, "do_connect")
|
||||
def provide_iam_token_async(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
|
||||
) -> None:
|
||||
# For async engine using asyncpg, we still need to set the IAM token here.
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
|
||||
cparams["password"] = token
|
||||
cparams["ssl"] = ssl_context
|
||||
|
||||
return _ASYNC_ENGINE
|
||||
|
||||
|
||||
# Dependency to get the current tenant ID
|
||||
# If no token is present, uses the default schema for this use case
|
||||
def get_current_tenant_id(request: Request) -> str:
|
||||
"""Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
|
||||
if not MULTI_TENANT:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
@@ -275,7 +331,6 @@ def get_current_tenant_id(request: Request) -> str:
|
||||
token = request.cookies.get("fastapiusersauth")
|
||||
if not token:
|
||||
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
# If no token is present, use the default schema or handle accordingly
|
||||
return current_value
|
||||
|
||||
try:
|
||||
@@ -289,7 +344,6 @@ def get_current_tenant_id(request: Request) -> str:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
return tenant_id
|
||||
except jwt.InvalidTokenError:
|
||||
return CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
@@ -316,7 +370,6 @@ async def get_async_session_with_tenant(
|
||||
|
||||
async with async_session_factory() as session:
|
||||
try:
|
||||
# Set the search_path to the tenant's schema
|
||||
await session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
await session.execute(
|
||||
@@ -326,8 +379,6 @@ async def get_async_session_with_tenant(
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error setting search_path.")
|
||||
# You can choose to re-raise the exception or handle it
|
||||
# Here, we'll re-raise to prevent proceeding with an incorrect session
|
||||
raise
|
||||
else:
|
||||
yield session
|
||||
@@ -335,9 +386,6 @@ async def get_async_session_with_tenant(
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_default_tenant() -> Generator[Session, None, None]:
|
||||
"""
|
||||
Get a database session using the current tenant ID from the context variable.
|
||||
"""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
with get_session_with_tenant(tenant_id) as session:
|
||||
yield session
|
||||
@@ -349,7 +397,6 @@ def get_session_with_tenant(
|
||||
) -> Generator[Session, None, None]:
|
||||
"""
|
||||
Generate a database session for a specific tenant.
|
||||
|
||||
This function:
|
||||
1. Sets the database schema to the specified tenant's schema.
|
||||
2. Preserves the tenant ID across the session.
|
||||
@@ -357,27 +404,20 @@ def get_session_with_tenant(
|
||||
4. Uses the default schema if no tenant ID is provided.
|
||||
"""
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
# Store the previous tenant ID
|
||||
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
if tenant_id is None:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
event.listen(engine, "checkout", set_search_path_on_checkout)
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
|
||||
try:
|
||||
# Establish a raw connection
|
||||
with engine.connect() as connection:
|
||||
# Access the raw DBAPI connection and set the search_path
|
||||
dbapi_connection = connection.connection
|
||||
|
||||
# Set the search_path outside of any transaction
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
@@ -390,21 +430,17 @@ def get_session_with_tenant(
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
# Bind the session to the connection
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
# Reset search_path to default after the session is used
|
||||
if MULTI_TENANT:
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute('SET search_path TO "$user", public')
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
finally:
|
||||
# Restore the previous tenant ID
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)
|
||||
|
||||
|
||||
@@ -424,12 +460,9 @@ def get_session_generator_with_tenant() -> Generator[Session, None, None]:
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
"""Generate a database session with the appropriate tenant schema set."""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
|
||||
raise BasicAuthenticationError(
|
||||
detail="User must authenticate",
|
||||
)
|
||||
raise BasicAuthenticationError(detail="User must authenticate")
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
@@ -437,20 +470,17 @@ def get_session() -> Generator[Session, None, None]:
|
||||
if MULTI_TENANT:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
# Set the search_path to the tenant's schema
|
||||
session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
yield session
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Generate an async database session with the appropriate tenant schema set."""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
engine = get_sqlalchemy_async_engine()
|
||||
async with AsyncSession(engine, expire_on_commit=False) as async_session:
|
||||
if MULTI_TENANT:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
# Set the search_path to the tenant's schema
|
||||
await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
yield async_session
|
||||
|
||||
@@ -461,7 +491,6 @@ def get_session_context_manager() -> ContextManager[Session]:
|
||||
|
||||
|
||||
def get_session_factory() -> sessionmaker[Session]:
|
||||
"""Get a session factory."""
|
||||
global SessionFactory
|
||||
if SessionFactory is None:
|
||||
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
|
||||
@@ -489,3 +518,13 @@ async def warm_up_connections(
|
||||
await async_conn.execute(text("SELECT 1"))
|
||||
for async_conn in async_connections:
|
||||
await async_conn.close()
|
||||
|
||||
|
||||
def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) -> None:
|
||||
if USE_IAM_AUTH:
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
region = os.getenv("AWS_REGION_NAME", "us-east-2")
|
||||
# Configure for psycopg2 with IAM token
|
||||
configure_psycopg2_iam_auth(cparams, host, port, user, region)
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.chat import delete_chat_session
|
||||
from onyx.db.models import ChatFolder
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_user_folders(
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[ChatFolder]:
|
||||
return db_session.query(ChatFolder).filter(ChatFolder.user_id == user_id).all()
|
||||
|
||||
|
||||
def update_folder_display_priority(
|
||||
user_id: UUID | None,
|
||||
display_priority_map: dict[int, int],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
folders = get_user_folders(user_id=user_id, db_session=db_session)
|
||||
folder_ids = {folder.id for folder in folders}
|
||||
if folder_ids != set(display_priority_map.keys()):
|
||||
raise ValueError("Invalid Folder IDs provided")
|
||||
|
||||
for folder in folders:
|
||||
folder.display_priority = display_priority_map[folder.id]
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_folder_by_id(
|
||||
user_id: UUID | None,
|
||||
folder_id: int,
|
||||
db_session: Session,
|
||||
) -> ChatFolder:
|
||||
folder = (
|
||||
db_session.query(ChatFolder).filter(ChatFolder.id == folder_id).one_or_none()
|
||||
)
|
||||
if not folder:
|
||||
raise ValueError("Folder by specified id does not exist")
|
||||
|
||||
if folder.user_id != user_id:
|
||||
raise PermissionError(f"Folder does not belong to user: {user_id}")
|
||||
|
||||
return folder
|
||||
|
||||
|
||||
def create_folder(
|
||||
user_id: UUID | None, folder_name: str | None, db_session: Session
|
||||
) -> int:
|
||||
new_folder = ChatFolder(
|
||||
user_id=user_id,
|
||||
name=folder_name,
|
||||
)
|
||||
db_session.add(new_folder)
|
||||
db_session.commit()
|
||||
|
||||
return new_folder.id
|
||||
|
||||
|
||||
def rename_folder(
|
||||
user_id: UUID | None, folder_id: int, folder_name: str | None, db_session: Session
|
||||
) -> None:
|
||||
folder = get_folder_by_id(
|
||||
user_id=user_id, folder_id=folder_id, db_session=db_session
|
||||
)
|
||||
|
||||
folder.name = folder_name
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def add_chat_to_folder(
|
||||
user_id: UUID | None, folder_id: int, chat_session: ChatSession, db_session: Session
|
||||
) -> None:
|
||||
folder = get_folder_by_id(
|
||||
user_id=user_id, folder_id=folder_id, db_session=db_session
|
||||
)
|
||||
|
||||
chat_session.folder_id = folder.id
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def remove_chat_from_folder(
|
||||
user_id: UUID | None, folder_id: int, chat_session: ChatSession, db_session: Session
|
||||
) -> None:
|
||||
folder = get_folder_by_id(
|
||||
user_id=user_id, folder_id=folder_id, db_session=db_session
|
||||
)
|
||||
|
||||
if chat_session.folder_id != folder.id:
|
||||
raise ValueError("The chat session is not in the specified folder.")
|
||||
|
||||
if folder.user_id != user_id:
|
||||
raise ValueError(
|
||||
f"Tried to remove a chat session from a folder that does not below to "
|
||||
f"this user, user id: {user_id}"
|
||||
)
|
||||
|
||||
chat_session.folder_id = None
|
||||
if chat_session in folder.chat_sessions:
|
||||
folder.chat_sessions.remove(chat_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_folder(
|
||||
user_id: UUID | None,
|
||||
folder_id: int,
|
||||
including_chats: bool,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
folder = get_folder_by_id(
|
||||
user_id=user_id, folder_id=folder_id, db_session=db_session
|
||||
)
|
||||
|
||||
# Assuming there will not be a massive number of chats in any given folder
|
||||
if including_chats:
|
||||
for chat_session in folder.chat_sessions:
|
||||
delete_chat_session(
|
||||
user_id=user_id,
|
||||
chat_session_id=chat_session.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.delete(folder)
|
||||
db_session.commit()
|
||||
99
backend/onyx/db/milestone.py
Normal file
99
backend/onyx/db/milestone.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.models import Milestone
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
USER_ASSISTANT_PREFIX = "user_assistants_used_"
|
||||
MULTI_ASSISTANT_USED = "multi_assistant_used"
|
||||
|
||||
|
||||
def create_milestone(
|
||||
user: User | None,
|
||||
event_type: MilestoneRecordType,
|
||||
db_session: Session,
|
||||
) -> Milestone:
|
||||
milestone = Milestone(
|
||||
event_type=event_type,
|
||||
user_id=user.id if user else None,
|
||||
)
|
||||
db_session.add(milestone)
|
||||
db_session.commit()
|
||||
|
||||
return milestone
|
||||
|
||||
|
||||
def create_milestone_if_not_exists(
|
||||
user: User | None, event_type: MilestoneRecordType, db_session: Session
|
||||
) -> tuple[Milestone, bool]:
|
||||
# Check if it exists
|
||||
milestone = db_session.execute(
|
||||
select(Milestone).where(Milestone.event_type == event_type)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if milestone is not None:
|
||||
return milestone, False
|
||||
|
||||
# If it doesn't exist, try to create it.
|
||||
try:
|
||||
milestone = create_milestone(user, event_type, db_session)
|
||||
return milestone, True
|
||||
except IntegrityError:
|
||||
# Another thread or process inserted it in the meantime
|
||||
db_session.rollback()
|
||||
# Fetch again to return the existing record
|
||||
milestone = db_session.execute(
|
||||
select(Milestone).where(Milestone.event_type == event_type)
|
||||
).scalar_one() # Now should exist
|
||||
return milestone, False
|
||||
|
||||
|
||||
def update_user_assistant_milestone(
|
||||
milestone: Milestone,
|
||||
user_id: str | None,
|
||||
assistant_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
event_tracker = milestone.event_tracker
|
||||
if event_tracker is None:
|
||||
milestone.event_tracker = event_tracker = {}
|
||||
|
||||
if event_tracker.get(MULTI_ASSISTANT_USED):
|
||||
# No need to keep tracking and populating if the milestone has already been hit
|
||||
return
|
||||
|
||||
user_key = f"{USER_ASSISTANT_PREFIX}{user_id}"
|
||||
|
||||
if event_tracker.get(user_key) is None:
|
||||
event_tracker[user_key] = [assistant_id]
|
||||
elif assistant_id not in event_tracker[user_key]:
|
||||
event_tracker[user_key].append(assistant_id)
|
||||
|
||||
flag_modified(milestone, "event_tracker")
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def check_multi_assistant_milestone(
|
||||
milestone: Milestone,
|
||||
db_session: Session,
|
||||
) -> tuple[bool, bool]:
|
||||
"""Returns if the milestone was hit and if it was just hit for the first time"""
|
||||
event_tracker = milestone.event_tracker
|
||||
if event_tracker is None:
|
||||
return False, False
|
||||
|
||||
if event_tracker.get(MULTI_ASSISTANT_USED):
|
||||
return True, False
|
||||
|
||||
for key, value in event_tracker.items():
|
||||
if key.startswith(USER_ASSISTANT_PREFIX) and len(value) > 1:
|
||||
event_tracker[MULTI_ASSISTANT_USED] = True
|
||||
flag_modified(milestone, "event_tracker")
|
||||
db_session.commit()
|
||||
return True, True
|
||||
|
||||
return False, False
|
||||
@@ -5,6 +5,8 @@ from typing import Literal
|
||||
from typing import NotRequired
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict # noreorder
|
||||
from uuid import UUID
|
||||
|
||||
@@ -37,7 +39,7 @@ from sqlalchemy.types import TypeDecorator
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.configs.constants import DEFAULT_BOOST, MilestoneRecordType
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import MessageType
|
||||
@@ -52,6 +54,7 @@ from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.pydantic_type import PydanticType
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
@@ -63,6 +66,8 @@ from onyx.utils.headers import HeaderItemDict
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import RerankerProvider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
__abstract__ = True
|
||||
@@ -70,6 +75,8 @@ class Base(DeclarativeBase):
|
||||
|
||||
class EncryptedString(TypeDecorator):
|
||||
impl = LargeBinary
|
||||
# This type's behavior is fully deterministic and doesn't depend on any external factors.
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value: str | None, dialect: Dialect) -> bytes | None:
|
||||
if value is not None:
|
||||
@@ -84,6 +91,8 @@ class EncryptedString(TypeDecorator):
|
||||
|
||||
class EncryptedJson(TypeDecorator):
|
||||
impl = LargeBinary
|
||||
# This type's behavior is fully deterministic and doesn't depend on any external factors.
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value: dict | None, dialect: Dialect) -> bytes | None:
|
||||
if value is not None:
|
||||
@@ -100,11 +109,76 @@ class EncryptedJson(TypeDecorator):
|
||||
return value
|
||||
|
||||
|
||||
class NullFilteredString(TypeDecorator):
|
||||
impl = String
|
||||
# This type's behavior is fully deterministic and doesn't depend on any external factors.
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value: str | None, dialect: Dialect) -> str | None:
|
||||
if value is not None and "\x00" in value:
|
||||
logger.warning(f"NUL characters found in value: {value}")
|
||||
return value.replace("\x00", "")
|
||||
return value
|
||||
|
||||
def process_result_value(self, value: str | None, dialect: Dialect) -> str | None:
|
||||
return value
|
||||
|
||||
|
||||
"""
|
||||
Auth/Authz (users, permissions, access) Tables
|
||||
"""
|
||||
|
||||
|
||||
class UserFolder(Base):
|
||||
__tablename__ = "user_folder"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), nullable=False)
|
||||
parent_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("user_folder.id"), nullable=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(nullable=False)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
default=datetime.datetime.utcnow
|
||||
)
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates="folders")
|
||||
parent: Mapped["UserFolder"] = relationship(
|
||||
remote_side=[id], back_populates="children"
|
||||
)
|
||||
children: Mapped[list["UserFolder"]] = relationship(back_populates="parent")
|
||||
files: Mapped[list["UserFile"]] = relationship(back_populates="folder")
|
||||
chat_sessions: Mapped[list["ChatSession"]] = relationship(back_populates="folder")
|
||||
|
||||
|
||||
class UserDocument(str, Enum):
|
||||
CHAT = "chat"
|
||||
RECENT = "recent"
|
||||
FILE = "file"
|
||||
|
||||
|
||||
class UserFile(Base):
|
||||
__tablename__ = "user_file"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id"), nullable=False)
|
||||
parent_folder_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("user_folder.id"), nullable=True
|
||||
)
|
||||
|
||||
file_id: Mapped[str] = mapped_column(nullable=False)
|
||||
document_id: Mapped[str] = mapped_column(nullable=False)
|
||||
name: Mapped[str] = mapped_column(nullable=False)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
default=datetime.datetime.utcnow
|
||||
)
|
||||
ccpair_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("connector_credential_pair.id"), nullable=False
|
||||
)
|
||||
user: Mapped["User"] = relationship(back_populates="files")
|
||||
folder: Mapped["UserFolder"] = relationship(back_populates="files")
|
||||
|
||||
|
||||
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
|
||||
# even an almost empty token from keycloak will not fit the default 1024 bytes
|
||||
access_token: Mapped[str] = mapped_column(Text, nullable=False) # type: ignore
|
||||
@@ -154,9 +228,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
chat_sessions: Mapped[list["ChatSession"]] = relationship(
|
||||
"ChatSession", back_populates="user"
|
||||
)
|
||||
chat_folders: Mapped[list["ChatFolder"]] = relationship(
|
||||
"ChatFolder", back_populates="user"
|
||||
)
|
||||
|
||||
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
|
||||
|
||||
@@ -174,6 +245,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
|
||||
)
|
||||
|
||||
folders: Mapped[list["UserFolder"]] = relationship(
|
||||
"UserFolder", back_populates="user"
|
||||
)
|
||||
files: Mapped[list["UserFile"]] = relationship("UserFile", back_populates="user")
|
||||
|
||||
|
||||
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
|
||||
pass
|
||||
@@ -449,16 +525,16 @@ class Document(Base):
|
||||
|
||||
# this should correspond to the ID of the document
|
||||
# (as is passed around in Onyx)
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
id: Mapped[str] = mapped_column(NullFilteredString, primary_key=True)
|
||||
from_ingestion_api: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, nullable=True
|
||||
)
|
||||
# 0 for neutral, positive for mostly endorse, negative for mostly reject
|
||||
boost: Mapped[int] = mapped_column(Integer, default=DEFAULT_BOOST)
|
||||
hidden: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
semantic_id: Mapped[str] = mapped_column(String)
|
||||
semantic_id: Mapped[str] = mapped_column(NullFilteredString)
|
||||
# First Section's link
|
||||
link: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
link: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True)
|
||||
|
||||
# The updated time is also used as a measure of the last successful state of the doc
|
||||
# pulled from the source (to help skip reindexing already updated docs in case of
|
||||
@@ -974,7 +1050,7 @@ class ChatSession(Base):
|
||||
default=ChatSessionSharedStatus.PRIVATE,
|
||||
)
|
||||
folder_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("chat_folder.id"), nullable=True
|
||||
ForeignKey("user_folder.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
|
||||
current_alternate_model: Mapped[str | None] = mapped_column(String, default=None)
|
||||
@@ -1004,11 +1080,11 @@ class ChatSession(Base):
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
|
||||
folder: Mapped["ChatFolder"] = relationship(
|
||||
"ChatFolder", back_populates="chat_sessions"
|
||||
folder: Mapped["UserFolder"] = relationship(
|
||||
"UserFolder", back_populates="chat_sessions"
|
||||
)
|
||||
messages: Mapped[list["ChatMessage"]] = relationship(
|
||||
"ChatMessage", back_populates="chat_session"
|
||||
"ChatMessage", back_populates="chat_session", cascade="all, delete-orphan"
|
||||
)
|
||||
persona: Mapped["Persona"] = relationship("Persona")
|
||||
|
||||
@@ -1076,6 +1152,8 @@ class ChatMessage(Base):
|
||||
"SearchDoc",
|
||||
secondary=ChatMessage__SearchDoc.__table__,
|
||||
back_populates="chat_messages",
|
||||
cascade="all, delete-orphan",
|
||||
single_parent=True,
|
||||
)
|
||||
|
||||
tool_call: Mapped["ToolCall"] = relationship(
|
||||
@@ -1091,33 +1169,6 @@ class ChatMessage(Base):
|
||||
)
|
||||
|
||||
|
||||
class ChatFolder(Base):
|
||||
"""For organizing chat sessions"""
|
||||
|
||||
__tablename__ = "chat_folder"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
# Only null if auth is off
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=0)
|
||||
|
||||
user: Mapped[User] = relationship("User", back_populates="chat_folders")
|
||||
chat_sessions: Mapped[list["ChatSession"]] = relationship(
|
||||
"ChatSession", back_populates="folder"
|
||||
)
|
||||
|
||||
def __lt__(self, other: Any) -> bool:
|
||||
if not isinstance(other, ChatFolder):
|
||||
return NotImplemented
|
||||
if self.display_priority == other.display_priority:
|
||||
# Bigger ID (created later) show earlier
|
||||
return self.id > other.id
|
||||
return self.display_priority < other.display_priority
|
||||
|
||||
|
||||
"""
|
||||
Feedback, Logging, Metrics Tables
|
||||
"""
|
||||
@@ -1344,6 +1395,11 @@ class StarterMessage(TypedDict):
|
||||
message: str
|
||||
|
||||
|
||||
class StarterMessageModel(BaseModel):
|
||||
name: str
|
||||
message: str
|
||||
|
||||
|
||||
class Persona(Base):
|
||||
__tablename__ = "persona"
|
||||
|
||||
@@ -1534,6 +1590,32 @@ class SlackBot(Base):
|
||||
)
|
||||
|
||||
|
||||
class Milestone(Base):
|
||||
# This table is used to track significant events for a deployment towards finding value
|
||||
# The table is currently not used for features but it may be used in the future to inform
|
||||
# users about the product features and encourage usage/exploration.
|
||||
__tablename__ = "milestone"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), primary_key=True, default=uuid4
|
||||
)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
event_type: Mapped[MilestoneRecordType] = mapped_column(String)
|
||||
# Need to track counts and specific ids of certain events to know if the Milestone has been reached
|
||||
event_tracker: Mapped[dict | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User | None] = relationship("User")
|
||||
|
||||
__table_args__ = (UniqueConstraint("event_type", name="uq_milestone_event_type"),)
|
||||
|
||||
|
||||
class TaskQueueState(Base):
|
||||
# Currently refers to Celery Tasks
|
||||
__tablename__ = "task_queue_jobs"
|
||||
|
||||
36
backend/onyx/db/my_documents.py
Normal file
36
backend/onyx/db/my_documents.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.server.documents.connector import upload_files
|
||||
from onyx.server.documents.models import FileUploadResponse
|
||||
|
||||
CHAT_FOLDER_ID = -1
|
||||
RECENT_DOCUMENTS_FOLDER_ID = -2
|
||||
|
||||
|
||||
def create_user_files(
|
||||
files: List[UploadFile],
|
||||
folder_id: int | None,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> FileUploadResponse:
|
||||
upload_response = upload_files(files, db_session)
|
||||
for file_path, file in zip(upload_response.file_paths, files):
|
||||
new_file = UserFile(
|
||||
user_id=user.id if user else None,
|
||||
parent_folder_id=folder_id,
|
||||
file_id=file_path,
|
||||
document_id=file_path, # We'll use the same ID for now
|
||||
name=file.filename,
|
||||
)
|
||||
db_session.add(new_file)
|
||||
|
||||
db_session.commit()
|
||||
return upload_response
|
||||
|
||||
|
||||
# def trigger_document_indexing(db_session: Session, user_id: int) -> None:
|
||||
@@ -543,6 +543,10 @@ def upsert_persona(
|
||||
if tools is not None:
|
||||
existing_persona.tools = tools or []
|
||||
|
||||
# We should only update display priority if it is not already set
|
||||
if existing_persona.display_priority is None:
|
||||
existing_persona.display_priority = display_priority
|
||||
|
||||
persona = existing_persona
|
||||
|
||||
else:
|
||||
|
||||
29
backend/onyx/db/user_documents.py
Normal file
29
backend/onyx/db/user_documents.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.server.documents.connector import upload_files
|
||||
from onyx.server.documents.models import FileUploadResponse
|
||||
|
||||
|
||||
def create_user_files(
|
||||
files: List[UploadFile],
|
||||
folder_id: int | None,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
) -> FileUploadResponse:
|
||||
upload_response = upload_files(files, db_session)
|
||||
for file_path, file in zip(upload_response.file_paths, files):
|
||||
new_file = UserFile(
|
||||
user_id=user.id if user else None,
|
||||
parent_folder_id=folder_id if folder_id != -1 else None,
|
||||
file_id=file_path,
|
||||
document_id=file_path,
|
||||
name=file.filename,
|
||||
)
|
||||
db_session.add(new_file)
|
||||
db_session.commit()
|
||||
return upload_response
|
||||
0
backend/onyx/db/user_file.py
Normal file
0
backend/onyx/db/user_file.py
Normal file
@@ -7,8 +7,15 @@ from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
|
||||
def validate_user_role_update(requested_role: UserRole, current_role: UserRole) -> None:
|
||||
@@ -185,3 +192,43 @@ def batch_add_ext_perm_user_if_not_exists(
|
||||
db_session.commit()
|
||||
|
||||
return found_users + new_users
|
||||
|
||||
|
||||
def delete_user_from_db(
|
||||
user_to_delete: User,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
for oauth_account in user_to_delete.oauth_accounts:
|
||||
db_session.delete(oauth_account)
|
||||
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.db.external_perm",
|
||||
"delete_user__ext_group_for_user__no_commit",
|
||||
)(
|
||||
db_session=db_session,
|
||||
user_id=user_to_delete.id,
|
||||
)
|
||||
db_session.query(SamlAccount).filter(
|
||||
SamlAccount.user_id == user_to_delete.id
|
||||
).delete()
|
||||
db_session.query(DocumentSet__User).filter(
|
||||
DocumentSet__User.user_id == user_to_delete.id
|
||||
).delete()
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.user_id == user_to_delete.id
|
||||
).delete()
|
||||
db_session.query(User__UserGroup).filter(
|
||||
User__UserGroup.user_id == user_to_delete.id
|
||||
).delete()
|
||||
db_session.delete(user_to_delete)
|
||||
db_session.commit()
|
||||
|
||||
# NOTE: edge case may exist with race conditions
|
||||
# with this `invited user` scheme generally.
|
||||
user_emails = get_invited_users()
|
||||
remaining_users = [
|
||||
remaining_user_email
|
||||
for remaining_user_email in user_emails
|
||||
if remaining_user_email != user_to_delete.email
|
||||
]
|
||||
write_invited_users(remaining_users)
|
||||
|
||||
@@ -369,6 +369,19 @@ class AdminCapable(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RandomCapable(abc.ABC):
|
||||
"""Class must implement random document retrieval capability"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def random_retrieval(
|
||||
self,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int = 10,
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""Retrieve random chunks matching the filters"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseIndex(
|
||||
Verifiable,
|
||||
Indexable,
|
||||
@@ -376,6 +389,7 @@ class BaseIndex(
|
||||
Deletable,
|
||||
AdminCapable,
|
||||
IdRetrievalCapable,
|
||||
RandomCapable,
|
||||
abc.ABC,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -112,6 +112,11 @@ schema DANSWER_CHUNK_NAME {
|
||||
rank: filter
|
||||
attribute: fast-search
|
||||
}
|
||||
field user_folders type weightedset<string> {
|
||||
indexing: summary | attribute
|
||||
rank: filter
|
||||
attribute: fast-search
|
||||
}
|
||||
}
|
||||
|
||||
# If using different tokenization settings, the fieldset has to be removed, and the field must
|
||||
@@ -218,4 +223,10 @@ schema DANSWER_CHUNK_NAME {
|
||||
expression: bm25(content) + (5 * bm25(title))
|
||||
}
|
||||
}
|
||||
|
||||
rank-profile random_ {
|
||||
first-phase {
|
||||
expression: random.match
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
<resource-limits>
|
||||
<!-- Default is 75% but this can be increased for Dockerized deployments -->
|
||||
<!-- https://docs.vespa.ai/en/operations/feed-block.html -->
|
||||
<disk>0.75</disk>
|
||||
<disk>0.85</disk>
|
||||
</resource-limits>
|
||||
</tuning>
|
||||
<engine>
|
||||
|
||||
@@ -16,6 +16,12 @@ logger = setup_logger()
|
||||
CONTENT_SUMMARY = "content_summary"
|
||||
|
||||
|
||||
@retry(tries=10, delay=1, backoff=2)
|
||||
def _retryable_http_delete(http_client: httpx.Client, url: str) -> None:
|
||||
res = http_client.delete(url)
|
||||
res.raise_for_status()
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _delete_vespa_doc_chunks(
|
||||
document_id: str, index_name: str, http_client: httpx.Client
|
||||
@@ -28,10 +34,10 @@ def _delete_vespa_doc_chunks(
|
||||
|
||||
for chunk_id in doc_chunk_ids:
|
||||
try:
|
||||
res = http_client.delete(
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{chunk_id}"
|
||||
_retryable_http_delete(
|
||||
http_client,
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{chunk_id}",
|
||||
)
|
||||
res.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"Failed to delete chunk, details: {e.response.text}")
|
||||
raise
|
||||
|
||||
@@ -2,6 +2,7 @@ import concurrent.futures
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import urllib
|
||||
@@ -312,6 +313,7 @@ class VespaIndex(DocumentIndex):
|
||||
with updating the associated permissions. Assumes that a document will not be split into
|
||||
multiple chunk batches calling this function multiple times, otherwise only the last set of
|
||||
chunks will be kept"""
|
||||
|
||||
# IMPORTANT: This must be done one index at a time, do not use secondary index here
|
||||
cleaned_chunks = [clean_chunk_id_copy(chunk) for chunk in chunks]
|
||||
|
||||
@@ -534,7 +536,7 @@ class VespaIndex(DocumentIndex):
|
||||
if self.secondary_index_name:
|
||||
index_names.append(self.secondary_index_name)
|
||||
|
||||
with get_vespa_http_client() as http_client:
|
||||
with get_vespa_http_client(http2=False) as http_client:
|
||||
for index_name in index_names:
|
||||
params = httpx.QueryParams(
|
||||
{
|
||||
@@ -545,8 +547,12 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
while True:
|
||||
try:
|
||||
vespa_url = (
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}"
|
||||
)
|
||||
logger.debug(f'update_single PUT on URL "{vespa_url}"')
|
||||
resp = http_client.put(
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}",
|
||||
vespa_url,
|
||||
params=params,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=update_dict,
|
||||
@@ -618,7 +624,7 @@ class VespaIndex(DocumentIndex):
|
||||
if self.secondary_index_name:
|
||||
index_names.append(self.secondary_index_name)
|
||||
|
||||
with get_vespa_http_client() as http_client:
|
||||
with get_vespa_http_client(http2=False) as http_client:
|
||||
for index_name in index_names:
|
||||
params = httpx.QueryParams(
|
||||
{
|
||||
@@ -629,8 +635,12 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
while True:
|
||||
try:
|
||||
vespa_url = (
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}"
|
||||
)
|
||||
logger.debug(f'delete_single DELETE on URL "{vespa_url}"')
|
||||
resp = http_client.delete(
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}",
|
||||
vespa_url,
|
||||
params=params,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
@@ -697,6 +707,8 @@ class VespaIndex(DocumentIndex):
|
||||
offset: int = 0,
|
||||
title_content_ratio: float | None = TITLE_CONTENT_RATIO,
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
print("filters", filters)
|
||||
print("filters.user_folders", filters.__dict__)
|
||||
vespa_where_clauses = build_vespa_filters(filters)
|
||||
# Needs to be at least as much as the value set in Vespa schema config
|
||||
target_hits = max(10 * num_to_retrieve, 1000)
|
||||
@@ -903,6 +915,32 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
logger.info("Batch deletion completed")
|
||||
|
||||
def random_retrieval(
|
||||
self,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int = 10,
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""Retrieve random chunks matching the filters using Vespa's random ranking
|
||||
|
||||
This method is currently used for random chunk retrieval in the context of
|
||||
assistant starter message creation (passed as sample context for usage by the assistant).
|
||||
"""
|
||||
vespa_where_clauses = build_vespa_filters(filters, remove_trailing_and=True)
|
||||
|
||||
yql = YQL_BASE.format(index_name=self.index_name) + vespa_where_clauses
|
||||
|
||||
random_seed = random.randint(0, 1000000)
|
||||
|
||||
params: dict[str, str | int | float] = {
|
||||
"yql": yql,
|
||||
"hits": num_to_retrieve,
|
||||
"timeout": VESPA_TIMEOUT,
|
||||
"ranking.profile": "random_",
|
||||
"ranking.properties.random.seed": random_seed,
|
||||
}
|
||||
|
||||
return query_vespa(params)
|
||||
|
||||
|
||||
class _VespaDeleteRequest:
|
||||
def __init__(self, document_id: str, index_name: str) -> None:
|
||||
|
||||
@@ -64,10 +64,10 @@ def _does_document_exist(
|
||||
if doc_fetch_response.status_code != 200:
|
||||
logger.debug(f"Failed to check for document with URL {doc_url}")
|
||||
raise RuntimeError(
|
||||
f"Unexpected fetch document by ID value from Vespa "
|
||||
f"with error {doc_fetch_response.status_code}"
|
||||
f"Index name: {index_name}"
|
||||
f"Doc chunk id: {doc_chunk_id}"
|
||||
f"Unexpected fetch document by ID value from Vespa: "
|
||||
f"error={doc_fetch_response.status_code} "
|
||||
f"index={index_name} "
|
||||
f"doc_chunk_id={doc_chunk_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ def remove_invalid_unicode_chars(text: str) -> str:
|
||||
return _illegal_xml_chars_RE.sub("", text)
|
||||
|
||||
|
||||
def get_vespa_http_client(no_timeout: bool = False) -> httpx.Client:
|
||||
def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx.Client:
|
||||
"""
|
||||
Configure and return an HTTP client for communicating with Vespa,
|
||||
including authentication if needed.
|
||||
@@ -67,5 +67,5 @@ def get_vespa_http_client(no_timeout: bool = False) -> httpx.Client:
|
||||
else None,
|
||||
verify=False if not MANAGED_VESPA else True,
|
||||
timeout=None if no_timeout else VESPA_REQUEST_TIMEOUT,
|
||||
http2=True,
|
||||
http2=http2,
|
||||
)
|
||||
|
||||
@@ -9,17 +9,24 @@ from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
|
||||
from onyx.document_index.vespa_constants import CHUNK_ID
|
||||
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_IDS
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_SETS
|
||||
from onyx.document_index.vespa_constants import HIDDEN
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import SOURCE_TYPE
|
||||
from onyx.document_index.vespa_constants import TENANT_ID
|
||||
from onyx.document_index.vespa_constants import USER_FOLDERS
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) -> str:
|
||||
def build_vespa_filters(
|
||||
filters: IndexFilters,
|
||||
*,
|
||||
include_hidden: bool = False,
|
||||
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
|
||||
) -> str:
|
||||
def _build_or_filters(key: str, vals: list[str] | None) -> str:
|
||||
if vals is None:
|
||||
return ""
|
||||
@@ -72,12 +79,20 @@ def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) ->
|
||||
tags = filters.tags
|
||||
if tags:
|
||||
tag_attributes = [tag.tag_key + INDEX_SEPARATOR + tag.tag_value for tag in tags]
|
||||
|
||||
filter_str += _build_or_filters(METADATA_LIST, tag_attributes)
|
||||
|
||||
filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
|
||||
filter_str += _build_or_filters(USER_FOLDERS, filters.user_folders)
|
||||
|
||||
filter_str += _build_or_filters(DOCUMENT_IDS, filters.document_ids)
|
||||
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
|
||||
if remove_trailing_and and filter_str.endswith(" and "):
|
||||
filter_str = filter_str[:-5] # We remove the trailing " and "
|
||||
|
||||
return filter_str
|
||||
|
||||
|
||||
|
||||
@@ -64,6 +64,8 @@ EMBEDDINGS = "embeddings"
|
||||
TITLE_EMBEDDING = "title_embedding"
|
||||
ACCESS_CONTROL_LIST = "access_control_list"
|
||||
DOCUMENT_SETS = "document_sets"
|
||||
USER_FOLDERS = "user_folders"
|
||||
DOCUMENT_IDS = "document_ids"
|
||||
LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
|
||||
METADATA = "metadata"
|
||||
METADATA_LIST = "metadata_list"
|
||||
|
||||
@@ -260,6 +260,21 @@ def index_doc_batch_prepare(
|
||||
def filter_documents(document_batch: list[Document]) -> list[Document]:
|
||||
documents: list[Document] = []
|
||||
for document in document_batch:
|
||||
# Remove any NUL characters from title/semantic_id
|
||||
# This is a known issue with the Zendesk connector
|
||||
# Postgres cannot handle NUL characters in text fields
|
||||
if document.title:
|
||||
document.title = document.title.replace("\x00", "")
|
||||
if document.semantic_identifier:
|
||||
document.semantic_identifier = document.semantic_identifier.replace(
|
||||
"\x00", ""
|
||||
)
|
||||
|
||||
# Remove NUL characters from all sections
|
||||
for section in document.sections:
|
||||
if section.text is not None:
|
||||
section.text = section.text.replace("\x00", "")
|
||||
|
||||
empty_contents = not any(section.text.strip() for section in document.sections)
|
||||
if (
|
||||
(not document.title or not document.title.strip())
|
||||
|
||||
@@ -266,18 +266,27 @@ class DefaultMultiLLM(LLM):
|
||||
# )
|
||||
self._custom_config = custom_config
|
||||
|
||||
# Create a dictionary for model-specific arguments if it's None
|
||||
model_kwargs = model_kwargs or {}
|
||||
|
||||
# NOTE: have to set these as environment variables for Litellm since
|
||||
# not all are able to passed in but they always support them set as env
|
||||
# variables. We'll also try passing them in, since litellm just ignores
|
||||
# addtional kwargs (and some kwargs MUST be passed in rather than set as
|
||||
# env variables)
|
||||
if custom_config:
|
||||
for k, v in custom_config.items():
|
||||
os.environ[k] = v
|
||||
# Specifically pass in "vertex_credentials" as a model_kwarg to the
|
||||
# completion call for vertex AI. More details here:
|
||||
# https://docs.litellm.ai/docs/providers/vertex
|
||||
vertex_credentials_key = "vertex_credentials"
|
||||
vertex_credentials = custom_config.get(vertex_credentials_key)
|
||||
if vertex_credentials and model_provider == "vertex_ai":
|
||||
model_kwargs[vertex_credentials_key] = vertex_credentials
|
||||
else:
|
||||
# standard case
|
||||
for k, v in custom_config.items():
|
||||
os.environ[k] = v
|
||||
|
||||
model_kwargs = model_kwargs or {}
|
||||
if custom_config:
|
||||
model_kwargs.update(custom_config)
|
||||
if extra_headers:
|
||||
model_kwargs.update({"extra_headers": extra_headers})
|
||||
if extra_body:
|
||||
@@ -453,7 +462,9 @@ class DefaultMultiLLM(LLM):
|
||||
if LOG_DANSWER_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
|
||||
if DISABLE_LITELLM_STREAMING:
|
||||
if (
|
||||
DISABLE_LITELLM_STREAMING or self.config.model_name == "o1-2024-12-17"
|
||||
): # TODO: remove once litellm supports streaming
|
||||
yield self.invoke(prompt, tools, tool_choice, structured_response_format)
|
||||
return
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ OPENAI_PROVIDER_NAME = "openai"
|
||||
OPEN_AI_MODEL_NAMES = [
|
||||
"o1-mini",
|
||||
"o1-preview",
|
||||
"o1-2024-12-17",
|
||||
"gpt-4",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
|
||||
@@ -28,6 +28,7 @@ from litellm.exceptions import RateLimitError # type: ignore
|
||||
from litellm.exceptions import Timeout # type: ignore
|
||||
from litellm.exceptions import UnprocessableEntityError # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
|
||||
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
@@ -45,10 +46,19 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def litellm_exception_to_error_msg(
|
||||
e: Exception, llm: LLM, fallback_to_error_msg: bool = False
|
||||
e: Exception,
|
||||
llm: LLM,
|
||||
fallback_to_error_msg: bool = False,
|
||||
custom_error_msg_mappings: dict[str, str]
|
||||
| None = LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS,
|
||||
) -> str:
|
||||
error_msg = str(e)
|
||||
|
||||
if custom_error_msg_mappings:
|
||||
for error_msg_pattern, custom_error_msg in custom_error_msg_mappings.items():
|
||||
if error_msg_pattern in error_msg:
|
||||
return custom_error_msg
|
||||
|
||||
if isinstance(e, BadRequestError):
|
||||
error_msg = "Bad request: The server couldn't process your request. Please check your input."
|
||||
elif isinstance(e, AuthenticationError):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user