Compare commits

..

1 Commits

Author SHA1 Message Date
Weves
66779d5692 Add DB_READONLY_PASSWORD option in helm chart 2025-06-19 17:23:35 -07:00
693 changed files with 17166 additions and 37946 deletions

2
.github/CODEOWNERS vendored
View File

@@ -1,3 +1 @@
* @onyx-dot-app/onyx-core-team
# Helm charts Owners
/helm/ @justin-tahara

View File

@@ -1,40 +0,0 @@
name: Release Onyx Helm Charts
on:
push:
branches:
- main
permissions: write-all
jobs:
release:
permissions:
contents: write
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Configure Git
run: |
git config user.name "$GITHUB_ACTOR"
git config user.email "$GITHUB_ACTOR@users.noreply.github.com"
- name: Install Helm
uses: azure/setup-helm@v4
with:
version: v3.12.1
- name: Add Required Helm Repositories
run: |
helm repo add bitnami https://charts.bitnami.com/bitnami
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo update
- name: Run chart-releaser
uses: helm/chart-releaser-action@v1.7.0
env:
CR_TOKEN: "${{ secrets.GITHUB_TOKEN }}"

View File

@@ -1,94 +0,0 @@
name: External Dependency Unit Tests
on:
merge_group:
pull_request:
branches: [main]
env:
# AWS
S3_AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}
S3_AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }}
# MinIO
S3_ENDPOINT_URL: "http://localhost:9004"
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
discover-test-dirs:
runs-on: ubuntu-latest
outputs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Discover test directories
id: set-matrix
run: |
# Find all subdirectories in backend/tests/external_dependency_unit
dirs=$(find backend/tests/external_dependency_unit -mindepth 1 -maxdepth 1 -type d -exec basename {} \; | sort | jq -R -s -c 'split("\n")[:-1]')
echo "test-dirs=$dirs" >> $GITHUB_OUTPUT
external-dependency-unit-tests:
needs: discover-test-dirs
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
strategy:
fail-fast: false
matrix:
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
env:
PYTHONPATH: ./backend
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: "pip"
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
playwright install chromium
playwright install-deps chromium
- name: Set up Standard Dependencies
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack up -d minio relational_db cache index
- name: Run migrations
run: |
cd backend
alembic upgrade head
- name: Run Tests for ${{ matrix.test-dir }}
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test \
-n 8 \
--dist loadfile \
--durations=8 \
-o junit_family=xunit2 \
-xv \
--ff \
backend/tests/external_dependency_unit/${{ matrix.test-dir }}

View File

@@ -16,9 +16,6 @@ env:
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
PLATFORM_PAIR: linux-amd64
jobs:
@@ -269,9 +266,6 @@ jobs:
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \

View File

@@ -1,38 +0,0 @@
name: PR Labeler
on:
pull_request_target:
branches:
- main
types:
- opened
- reopened
- synchronize
- edited
permissions:
contents: read
pull-requests: write
jobs:
validate_pr_title:
runs-on: ubuntu-latest
steps:
- name: Check PR title for Conventional Commits
env:
PR_TITLE: ${{ github.event.pull_request.title }}
run: |
echo "PR Title: $PR_TITLE"
if [[ ! "$PR_TITLE" =~ ^(feat|fix|docs|test|ci|refactor|perf|chore|revert|build)(\(.+\))?:\ .+ ]]; then
echo "::error::❌ Your PR title does not follow the Conventional Commits format.
This check ensures that all pull requests use clear, consistent titles that help automate changelogs and improve project history.
Please update your PR title to follow the Conventional Commits style.
Here is a link to a blog explaining the reason why we've included the Conventional Commits style into our PR titles: https://xfuture-blog.com/working-with-conventional-commits
**Here are some examples of valid PR titles:**
- feat: add user authentication
- fix(login): handle null password error
- docs(readme): update installation instructions"
exit 1
fi

View File

@@ -16,9 +16,6 @@ env:
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
PLATFORM_PAIR: linux-amd64
jobs:
integration-tests-mit:
@@ -204,9 +201,6 @@ jobs:
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \

View File

@@ -47,7 +47,7 @@ jobs:
-i /local/openapi.json \
-g python \
-o /local/onyx_openapi_client \
--package-name onyx_openapi_client \
--package-name onyx_openapi_client
- name: Run MyPy
run: |

View File

@@ -16,13 +16,12 @@ env:
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
# Jira
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
@@ -50,15 +49,6 @@ env:
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
# Hubspot
HUBSPOT_ACCESS_TOKEN: ${{ secrets.HUBSPOT_ACCESS_TOKEN }}
# IMAP
IMAP_HOST: ${{ secrets.IMAP_HOST }}
IMAP_USERNAME: ${{ secrets.IMAP_USERNAME }}
IMAP_PASSWORD: ${{ secrets.IMAP_PASSWORD }}
IMAP_MAILBOXES: ${{ secrets.IMAP_MAILBOXES }}
# Airtable
AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}

View File

@@ -45,9 +45,8 @@ PYTHONPATH=../backend
PYTHONUNBUFFERED=1
# Internet Search
# Internet Search
BING_API_KEY=<REPLACE THIS>
EXA_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
@@ -59,9 +58,3 @@ AGENT_RETRIEVAL_STATS=False # Note: This setting will incur substantial re-ran
AGENT_RERANKING_STATS=True
AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20
# S3 File Store Configuration (MinIO for local development)
S3_ENDPOINT_URL=http://localhost:9004
S3_FILE_STORE_BUCKET_NAME=onyx-file-store-bucket
S3_AWS_ACCESS_KEY_ID=minioadmin
S3_AWS_SECRET_ACCESS_KEY=minioadmin

View File

@@ -24,8 +24,8 @@
"Celery primary",
"Celery light",
"Celery heavy",
"Celery docfetching",
"Celery docprocessing",
"Celery indexing",
"Celery user files indexing",
"Celery beat",
"Celery monitoring"
],
@@ -46,8 +46,8 @@
"Celery primary",
"Celery light",
"Celery heavy",
"Celery docfetching",
"Celery docprocessing",
"Celery indexing",
"Celery user files indexing",
"Celery beat",
"Celery monitoring"
],
@@ -226,66 +226,35 @@
"consoleTitle": "Celery heavy Console"
},
{
"name": "Celery docfetching",
"name": "Celery indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.docfetching",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docfetching@%n",
"-Q",
"connector_doc_fetching,user_files_indexing"
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=indexing@%n",
"-Q",
"connector_indexing"
],
"presentation": {
"group": "2"
"group": "2"
},
"consoleTitle": "Celery docfetching Console",
"justMyCode": false
},
{
"name": "Celery docprocessing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.docprocessing",
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docprocessing@%n",
"-Q",
"docprocessing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery docprocessing Console",
"justMyCode": false
},
"consoleTitle": "Celery indexing Console"
},
{
"name": "Celery monitoring",
"type": "debugpy",
@@ -334,6 +303,35 @@
},
"consoleTitle": "Celery beat Console"
},
{
"name": "Celery user files indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_files_indexing@%n",
"-Q",
"user_files_indexing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user files indexing Console"
},
{
"name": "Pytest",
"consoleName": "Pytest",
@@ -428,7 +426,7 @@
},
"args": [
"--filename",
"generated/openapi.json"
"generated/openapi.json",
]
},
{

View File

@@ -59,7 +59,6 @@ Onyx being a fully functional app, relies on some external software, specificall
- [Postgres](https://www.postgresql.org/) (Relational DB)
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
- [Redis](https://redis.io/) (Cache)
- [MinIO](https://min.io/) (File Store)
- [Nginx](https://nginx.org/) (Not needed for development flows generally)
> **Note:**
@@ -172,10 +171,10 @@ Otherwise, you can follow the instructions below to run the application for deve
You will need Docker installed to run these containers.
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis/MinIO with:
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis with:
```bash
docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db cache minio
docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db cache
```
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)

View File

@@ -37,7 +37,8 @@ RUN apt-get update && \
pkg-config \
gcc \
nano \
vim && \
vim \
postgresql-client && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean
@@ -77,9 +78,6 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Install postgresql-client for easy manual tests
# Install it here to avoid it being cleaned up above
RUN apt-get update && apt-get install -y postgresql-client
# Pre-downloading models for setups with limited egress
RUN python -c "from tokenizers import Tokenizer; \

View File

@@ -20,44 +20,3 @@ To run all un-applied migrations:
To undo migrations:
`alembic downgrade -X`
where X is the number of migrations you want to undo from the current state
### Multi-tenant migrations
For multi-tenant deployments, you can use additional options:
**Upgrade all tenants:**
```bash
alembic -x upgrade_all_tenants=true upgrade head
```
**Upgrade specific schemas:**
```bash
# Single schema
alembic -x schemas=tenant_12345678-1234-1234-1234-123456789012 upgrade head
# Multiple schemas (comma-separated)
alembic -x schemas=tenant_12345678-1234-1234-1234-123456789012,public,another_tenant upgrade head
```
**Upgrade tenants within an alphabetical range:**
```bash
# Upgrade tenants 100-200 when sorted alphabetically (positions 100 to 200)
alembic -x upgrade_all_tenants=true -x tenant_range_start=100 -x tenant_range_end=200 upgrade head
# Upgrade tenants starting from position 1000 alphabetically
alembic -x upgrade_all_tenants=true -x tenant_range_start=1000 upgrade head
# Upgrade first 500 tenants alphabetically
alembic -x upgrade_all_tenants=true -x tenant_range_end=500 upgrade head
```
**Continue on error (for batch operations):**
```bash
alembic -x upgrade_all_tenants=true -x continue=true upgrade head
```
The tenant range filtering works by:
1. Sorting tenant IDs alphabetically
2. Using 1-based position numbers (1st, 2nd, 3rd tenant, etc.)
3. Filtering to the specified range of positions
4. Non-tenant schemas (like 'public') are always included

View File

@@ -1,12 +1,12 @@
from typing import Any, Literal
from onyx.db.engine.iam_auth import get_iam_auth_token
from onyx.db.engine import get_iam_auth_token
from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import AWS_REGION_NAME
from onyx.db.engine.sql_engine import build_connection_string
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_all_tenant_ids
from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy import text
@@ -21,14 +21,10 @@ from alembic import context
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql.schema import SchemaItem
from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import (
MULTI_TENANT,
POSTGRES_DEFAULT_SCHEMA,
TENANT_ID_PREFIX,
)
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
from onyx.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.engine import SqlEngine
# Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be
# hidden! (defaults to level=WARN)
@@ -73,67 +69,15 @@ def include_object(
return True
def filter_tenants_by_range(
tenant_ids: list[str], start_range: int | None = None, end_range: int | None = None
) -> list[str]:
"""
Filter tenant IDs by alphabetical position range.
Args:
tenant_ids: List of tenant IDs to filter
start_range: Starting position in alphabetically sorted list (1-based, inclusive)
end_range: Ending position in alphabetically sorted list (1-based, inclusive)
Returns:
Filtered list of tenant IDs in their original order
"""
if start_range is None and end_range is None:
return tenant_ids
# Separate tenant IDs from non-tenant schemas
tenant_schemas = [tid for tid in tenant_ids if tid.startswith(TENANT_ID_PREFIX)]
non_tenant_schemas = [
tid for tid in tenant_ids if not tid.startswith(TENANT_ID_PREFIX)
]
# Sort tenant schemas alphabetically.
# NOTE: can cause missed schemas if a schema is created in between workers
# fetching of all tenant IDs. We accept this risk for now. Just re-running
# the migration will fix the issue.
sorted_tenant_schemas = sorted(tenant_schemas)
# Apply range filtering (0-based indexing)
start_idx = start_range if start_range is not None else 0
end_idx = end_range if end_range is not None else len(sorted_tenant_schemas)
# Ensure indices are within bounds
start_idx = max(0, start_idx)
end_idx = min(len(sorted_tenant_schemas), end_idx)
# Get the filtered tenant schemas
filtered_tenant_schemas = sorted_tenant_schemas[start_idx:end_idx]
# Combine with non-tenant schemas and preserve original order
filtered_tenants = []
for tenant_id in tenant_ids:
if tenant_id in filtered_tenant_schemas or tenant_id in non_tenant_schemas:
filtered_tenants.append(tenant_id)
return filtered_tenants
def get_schema_options() -> (
tuple[bool, bool, bool, int | None, int | None, list[str] | None]
):
def get_schema_options() -> tuple[str, bool, bool, bool]:
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
if "=" in arg:
key, value = arg.split("=", 1)
x_args[key.strip()] = value.strip()
else:
raise ValueError(f"Invalid argument: {arg}")
for pair in arg.split(","):
if "=" in pair:
key, value = pair.split("=", 1)
x_args[key.strip()] = value.strip()
schema_name = x_args.get("schema", POSTGRES_DEFAULT_SCHEMA)
create_schema = x_args.get("create_schema", "true").lower() == "true"
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
@@ -141,81 +85,17 @@ def get_schema_options() -> (
# only applies to online migrations
continue_on_error = x_args.get("continue", "false").lower() == "true"
# Tenant range filtering
tenant_range_start = None
tenant_range_end = None
if "tenant_range_start" in x_args:
try:
tenant_range_start = int(x_args["tenant_range_start"])
except ValueError:
raise ValueError(
f"Invalid tenant_range_start value: {x_args['tenant_range_start']}. Must be an integer."
)
if "tenant_range_end" in x_args:
try:
tenant_range_end = int(x_args["tenant_range_end"])
except ValueError:
raise ValueError(
f"Invalid tenant_range_end value: {x_args['tenant_range_end']}. Must be an integer."
)
# Validate range
if tenant_range_start is not None and tenant_range_end is not None:
if tenant_range_start > tenant_range_end:
raise ValueError(
f"tenant_range_start ({tenant_range_start}) cannot be greater than tenant_range_end ({tenant_range_end})"
)
# Specific schema names filtering (replaces both schema_name and the old tenant_ids approach)
schemas = None
if "schemas" in x_args:
schema_names_str = x_args["schemas"].strip()
if schema_names_str:
# Split by comma and strip whitespace
schemas = [
name.strip() for name in schema_names_str.split(",") if name.strip()
]
if schemas:
logger.info(f"Specific schema names specified: {schemas}")
# Validate that only one method is used at a time
range_filtering = tenant_range_start is not None or tenant_range_end is not None
specific_filtering = schemas is not None and len(schemas) > 0
if range_filtering and specific_filtering:
if (
MULTI_TENANT
and schema_name == POSTGRES_DEFAULT_SCHEMA
and not upgrade_all_tenants
):
raise ValueError(
"Cannot use both tenant range filtering (tenant_range_start/tenant_range_end) "
"and specific schema filtering (schemas) at the same time. "
"Please use only one filtering method."
"Cannot run default migrations in public schema when multi-tenancy is enabled. "
"Please specify a tenant-specific schema."
)
if upgrade_all_tenants and specific_filtering:
raise ValueError(
"Cannot use both upgrade_all_tenants=true and schemas at the same time. "
"Use either upgrade_all_tenants=true for all tenants, or schemas for specific schemas."
)
# If any filtering parameters are specified, we're not doing the default single schema migration
if range_filtering:
upgrade_all_tenants = True
# Validate multi-tenant requirements
if MULTI_TENANT and not upgrade_all_tenants and not specific_filtering:
raise ValueError(
"In multi-tenant mode, you must specify either upgrade_all_tenants=true "
"or provide schemas. Cannot run default migration."
)
return (
create_schema,
upgrade_all_tenants,
continue_on_error,
tenant_range_start,
tenant_range_end,
schemas,
)
return schema_name, create_schema, upgrade_all_tenants, continue_on_error
def do_run_migrations(
@@ -262,17 +142,12 @@ def provide_iam_token_for_alembic(
async def run_async_migrations() -> None:
(
schema_name,
create_schema,
upgrade_all_tenants,
continue_on_error,
tenant_range_start,
tenant_range_end,
schemas,
) = get_schema_options()
if not schemas and not MULTI_TENANT:
schemas = [POSTGRES_DEFAULT_SCHEMA]
# without init_engine, subsequent engine calls fail hard intentionally
SqlEngine.init_engine(pool_size=20, max_overflow=5)
@@ -289,50 +164,12 @@ async def run_async_migrations() -> None:
) -> None:
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
if schemas:
# Use specific schema names directly without fetching all tenants
logger.info(f"Migrating specific schema names: {schemas}")
i_schema = 0
num_schemas = len(schemas)
for schema in schemas:
i_schema += 1
logger.info(
f"Migrating schema: index={i_schema} num_schemas={num_schemas} schema={schema}"
)
try:
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema,
create_schema=create_schema,
)
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
if not continue_on_error:
logger.error("--continue=true is not set, raising exception!")
raise
logger.warning("--continue=true is set, continuing to next schema.")
elif upgrade_all_tenants:
if upgrade_all_tenants:
tenant_schemas = get_all_tenant_ids()
filtered_tenant_schemas = filter_tenants_by_range(
tenant_schemas, tenant_range_start, tenant_range_end
)
if tenant_range_start is not None or tenant_range_end is not None:
logger.info(
f"Filtering tenants by range: start={tenant_range_start}, end={tenant_range_end}"
)
logger.info(
f"Total tenants: {len(tenant_schemas)}, Filtered tenants: {len(filtered_tenant_schemas)}"
)
i_tenant = 0
num_tenants = len(filtered_tenant_schemas)
for schema in filtered_tenant_schemas:
num_tenants = len(tenant_schemas)
for schema in tenant_schemas:
i_tenant += 1
logger.info(
f"Migrating schema: index={i_tenant} num_tenants={num_tenants} schema={schema}"
@@ -353,13 +190,17 @@ async def run_async_migrations() -> None:
logger.warning("--continue=true is set, continuing to next schema.")
else:
# This should not happen in the new design since we require either
# upgrade_all_tenants=true or schemas in multi-tenant mode
# and for non-multi-tenant mode, we should use schemas with the default schema
raise ValueError(
"No migration target specified. Use either upgrade_all_tenants=true for all tenants "
"or schemas for specific schemas."
)
try:
logger.info(f"Migrating schema: {schema_name}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema_name,
create_schema=create_schema,
)
except Exception as e:
logger.error(f"Error migrating schema {schema_name}: {e}")
raise
await engine.dispose()
@@ -380,37 +221,10 @@ def run_migrations_offline() -> None:
# without init_engine, subsequent engine calls fail hard intentionally
SqlEngine.init_engine(pool_size=20, max_overflow=5)
(
create_schema,
upgrade_all_tenants,
continue_on_error,
tenant_range_start,
tenant_range_end,
schemas,
) = get_schema_options()
schema_name, _, upgrade_all_tenants, continue_on_error = get_schema_options()
url = build_connection_string()
if schemas:
# Use specific schema names directly without fetching all tenants
logger.info(f"Migrating specific schema names: {schemas}")
for schema in schemas:
logger.info(f"Migrating schema: {schema}")
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
version_table_schema=schema,
include_schemas=True,
script_location=config.get_main_option("script_location"),
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
elif upgrade_all_tenants:
if upgrade_all_tenants:
engine = create_async_engine(url)
if USE_IAM_AUTH:
@@ -424,19 +238,7 @@ def run_migrations_offline() -> None:
tenant_schemas = get_all_tenant_ids()
engine.sync_engine.dispose()
filtered_tenant_schemas = filter_tenants_by_range(
tenant_schemas, tenant_range_start, tenant_range_end
)
if tenant_range_start is not None or tenant_range_end is not None:
logger.info(
f"Filtering tenants by range: start={tenant_range_start}, end={tenant_range_end}"
)
logger.info(
f"Total tenants: {len(tenant_schemas)}, Filtered tenants: {len(filtered_tenant_schemas)}"
)
for schema in filtered_tenant_schemas:
for schema in tenant_schemas:
logger.info(f"Migrating schema: {schema}")
context.configure(
url=url,
@@ -452,12 +254,21 @@ def run_migrations_offline() -> None:
with context.begin_transaction():
context.run_migrations()
else:
# This should not happen in the new design
raise ValueError(
"No migration target specified. Use either upgrade_all_tenants=true for all tenants "
"or schemas for specific schemas."
logger.info(f"Migrating schema: {schema_name}")
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
version_table_schema=schema_name,
include_schemas=True,
script_location=config.get_main_option("script_location"),
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
logger.info("run_migrations_online starting.")

View File

@@ -1,72 +0,0 @@
"""add federated connector tables
Revision ID: 0816326d83aa
Revises: 12635f6655b7
Create Date: 2025-06-29 14:09:45.109518
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "0816326d83aa"
down_revision = "12635f6655b7"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create federated_connector table
op.create_table(
"federated_connector",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("source", sa.String(), nullable=False),
sa.Column("credentials", sa.LargeBinary(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
# Create federated_connector_oauth_token table
op.create_table(
"federated_connector_oauth_token",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("federated_connector_id", sa.Integer(), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("token", sa.LargeBinary(), nullable=False),
sa.Column("expires_at", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(
["federated_connector_id"], ["federated_connector.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
# Create federated_connector__document_set table
op.create_table(
"federated_connector__document_set",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("federated_connector_id", sa.Integer(), nullable=False),
sa.Column("document_set_id", sa.Integer(), nullable=False),
sa.Column("entities", postgresql.JSONB(), nullable=False),
sa.ForeignKeyConstraint(
["federated_connector_id"], ["federated_connector.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(
["document_set_id"], ["document_set.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"federated_connector_id",
"document_set_id",
name="uq_federated_connector_document_set",
),
)
def downgrade() -> None:
# Drop tables in reverse order due to foreign key dependencies
op.drop_table("federated_connector__document_set")
op.drop_table("federated_connector_oauth_token")
op.drop_table("federated_connector")

View File

@@ -1,596 +0,0 @@
"""drive-canonical-ids
Revision ID: 12635f6655b7
Revises: 58c50ef19f08
Create Date: 2025-06-20 14:44:54.241159
"""
from alembic import op
import sqlalchemy as sa
from urllib.parse import urlparse, urlunparse
from httpx import HTTPStatusError
import httpx
from onyx.document_index.factory import get_default_document_index
from onyx.db.search_settings import SearchSettings
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.utils.logger import setup_logger
import os
logger = setup_logger()
# revision identifiers, used by Alembic.
revision = "12635f6655b7"
down_revision = "58c50ef19f08"
branch_labels = None
depends_on = None
SKIP_CANON_DRIVE_IDS = os.environ.get("SKIP_CANON_DRIVE_IDS", "true").lower() == "true"
def active_search_settings() -> tuple[SearchSettings, SearchSettings | None]:
result = op.get_bind().execute(
sa.text(
"""
SELECT * FROM search_settings WHERE status = 'PRESENT' ORDER BY id DESC LIMIT 1
"""
)
)
search_settings_fetch = result.fetchall()
search_settings = (
SearchSettings(**search_settings_fetch[0]._asdict())
if search_settings_fetch
else None
)
result2 = op.get_bind().execute(
sa.text(
"""
SELECT * FROM search_settings WHERE status = 'FUTURE' ORDER BY id DESC LIMIT 1
"""
)
)
search_settings_future_fetch = result2.fetchall()
search_settings_future = (
SearchSettings(**search_settings_future_fetch[0]._asdict())
if search_settings_future_fetch
else None
)
if not isinstance(search_settings, SearchSettings):
raise RuntimeError(
"current search settings is of type " + str(type(search_settings))
)
if (
not isinstance(search_settings_future, SearchSettings)
and search_settings_future is not None
):
raise RuntimeError(
"future search settings is of type " + str(type(search_settings_future))
)
return search_settings, search_settings_future
def normalize_google_drive_url(url: str) -> str:
"""Remove query parameters from Google Drive URLs to create canonical document IDs.
NOTE: copied from drive doc_conversion.py
"""
parsed_url = urlparse(url)
parsed_url = parsed_url._replace(query="")
spl_path = parsed_url.path.split("/")
if spl_path and (spl_path[-1] in ["edit", "view", "preview"]):
spl_path.pop()
parsed_url = parsed_url._replace(path="/".join(spl_path))
# Remove query parameters and reconstruct URL
return urlunparse(parsed_url)
def get_google_drive_documents_from_database() -> list[dict]:
"""Get all Google Drive documents from the database."""
bind = op.get_bind()
result = bind.execute(
sa.text(
"""
SELECT d.id
FROM document d
JOIN document_by_connector_credential_pair dcc ON d.id = dcc.id
JOIN connector_credential_pair cc ON dcc.connector_id = cc.connector_id
AND dcc.credential_id = cc.credential_id
JOIN connector c ON cc.connector_id = c.id
WHERE c.source = 'GOOGLE_DRIVE'
"""
)
)
documents = []
for row in result:
documents.append({"document_id": row.id})
return documents
def update_document_id_in_database(
old_doc_id: str, new_doc_id: str, index_name: str
) -> None:
"""Update document IDs in all relevant database tables using copy-and-swap approach."""
bind = op.get_bind()
# print(f"Updating database tables for document {old_doc_id} -> {new_doc_id}")
# Check if new document ID already exists
result = bind.execute(
sa.text("SELECT COUNT(*) FROM document WHERE id = :new_id"),
{"new_id": new_doc_id},
)
row = result.fetchone()
if row and row[0] > 0:
# print(f"Document with ID {new_doc_id} already exists, deleting old one")
delete_document_from_db(old_doc_id, index_name)
return
# Step 1: Create a new document row with the new ID (copy all fields from old row)
# Use a conservative approach to handle columns that might not exist in all installations
try:
bind.execute(
sa.text(
"""
INSERT INTO document (id, from_ingestion_api, boost, hidden, semantic_id,
link, doc_updated_at, primary_owners, secondary_owners,
external_user_emails, external_user_group_ids, is_public,
chunk_count, last_modified, last_synced, kg_stage, kg_processing_time)
SELECT :new_id, from_ingestion_api, boost, hidden, semantic_id,
link, doc_updated_at, primary_owners, secondary_owners,
external_user_emails, external_user_group_ids, is_public,
chunk_count, last_modified, last_synced, kg_stage, kg_processing_time
FROM document
WHERE id = :old_id
"""
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated database tables for document {old_doc_id} -> {new_doc_id}")
except Exception as e:
# If the full INSERT fails, try a more basic version with only core columns
logger.warning(f"Full INSERT failed, trying basic version: {e}")
bind.execute(
sa.text(
"""
INSERT INTO document (id, from_ingestion_api, boost, hidden, semantic_id,
link, doc_updated_at, primary_owners, secondary_owners)
SELECT :new_id, from_ingestion_api, boost, hidden, semantic_id,
link, doc_updated_at, primary_owners, secondary_owners
FROM document
WHERE id = :old_id
"""
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# Step 2: Update all foreign key references to point to the new ID
# Update document_by_connector_credential_pair table
bind.execute(
sa.text(
"UPDATE document_by_connector_credential_pair SET id = :new_id WHERE id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated document_by_connector_credential_pair table for document {old_doc_id} -> {new_doc_id}")
# Update search_doc table (stores search results for chat replay)
# This is critical for agent functionality
bind.execute(
sa.text(
"UPDATE search_doc SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated search_doc table for document {old_doc_id} -> {new_doc_id}")
# Update document_retrieval_feedback table (user feedback on documents)
bind.execute(
sa.text(
"UPDATE document_retrieval_feedback SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated document_retrieval_feedback table for document {old_doc_id} -> {new_doc_id}")
# Update document__tag table (document-tag relationships)
bind.execute(
sa.text(
"UPDATE document__tag SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated document__tag table for document {old_doc_id} -> {new_doc_id}")
# Update user_file table (user uploaded files linked to documents)
bind.execute(
sa.text(
"UPDATE user_file SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated user_file table for document {old_doc_id} -> {new_doc_id}")
# Update KG and chunk_stats tables (these may not exist in all installations)
try:
# Update kg_entity table
bind.execute(
sa.text(
"UPDATE kg_entity SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated kg_entity table for document {old_doc_id} -> {new_doc_id}")
# Update kg_entity_extraction_staging table
bind.execute(
sa.text(
"UPDATE kg_entity_extraction_staging SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated kg_entity_extraction_staging table for document {old_doc_id} -> {new_doc_id}")
# Update kg_relationship table
bind.execute(
sa.text(
"UPDATE kg_relationship SET source_document = :new_id WHERE source_document = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated kg_relationship table for document {old_doc_id} -> {new_doc_id}")
# Update kg_relationship_extraction_staging table
bind.execute(
sa.text(
"UPDATE kg_relationship_extraction_staging SET source_document = :new_id WHERE source_document = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated kg_relationship_extraction_staging table for document {old_doc_id} -> {new_doc_id}")
# Update chunk_stats table
bind.execute(
sa.text(
"UPDATE chunk_stats SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated chunk_stats table for document {old_doc_id} -> {new_doc_id}")
# Update chunk_stats ID field which includes document_id
bind.execute(
sa.text(
"""
UPDATE chunk_stats
SET id = REPLACE(id, :old_id, :new_id)
WHERE id LIKE :old_id_pattern
"""
),
{
"new_id": new_doc_id,
"old_id": old_doc_id,
"old_id_pattern": f"{old_doc_id}__%",
},
)
# print(f"Successfully updated chunk_stats ID field for document {old_doc_id} -> {new_doc_id}")
except Exception as e:
logger.warning(f"Some KG/chunk tables may not exist or failed to update: {e}")
# Step 3: Delete the old document row (this should now be safe since all FKs point to new row)
bind.execute(
sa.text("DELETE FROM document WHERE id = :old_id"), {"old_id": old_doc_id}
)
# print(f"Successfully deleted document {old_doc_id} from database")
def _visit_chunks(
*,
http_client: httpx.Client,
index_name: str,
selection: str,
continuation: str | None = None,
) -> tuple[list[dict], str | None]:
"""Helper that calls the /document/v1 visit API once and returns (docs, next_token)."""
# Use the same URL as the document API, but with visit-specific params
base_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
params: dict[str, str] = {
"selection": selection,
"wantedDocumentCount": "1000",
}
if continuation:
params["continuation"] = continuation
# print(f"Visiting chunks for selection '{selection}' with params {params}")
resp = http_client.get(base_url, params=params, timeout=None)
# print(f"Visited chunks for document {selection}")
resp.raise_for_status()
payload = resp.json()
return payload.get("documents", []), payload.get("continuation")
def delete_document_chunks_from_vespa(index_name: str, doc_id: str) -> None:
"""Delete all chunks for *doc_id* from Vespa using continuation-token paging (no offset)."""
total_deleted = 0
# Use exact match instead of contains - Document Selector Language doesn't support contains
selection = f'{index_name}.document_id=="{doc_id}"'
with get_vespa_http_client() as http_client:
continuation: str | None = None
while True:
docs, continuation = _visit_chunks(
http_client=http_client,
index_name=index_name,
selection=selection,
continuation=continuation,
)
if not docs:
break
for doc in docs:
vespa_full_id = doc.get("id")
if not vespa_full_id:
continue
vespa_doc_uuid = vespa_full_id.split("::")[-1]
delete_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}"
try:
resp = http_client.delete(delete_url)
resp.raise_for_status()
total_deleted += 1
except Exception as e:
print(f"Failed to delete chunk {vespa_doc_uuid}: {e}")
if not continuation:
break
def update_document_id_in_vespa(
index_name: str, old_doc_id: str, new_doc_id: str
) -> None:
"""Update all chunks' document_id field from *old_doc_id* to *new_doc_id* using continuation paging."""
clean_new_doc_id = replace_invalid_doc_id_characters(new_doc_id)
# Use exact match instead of contains - Document Selector Language doesn't support contains
selection = f'{index_name}.document_id=="{old_doc_id}"'
with get_vespa_http_client() as http_client:
continuation: str | None = None
while True:
# print(f"Visiting chunks for document {old_doc_id} -> {new_doc_id}")
docs, continuation = _visit_chunks(
http_client=http_client,
index_name=index_name,
selection=selection,
continuation=continuation,
)
if not docs:
break
for doc in docs:
vespa_full_id = doc.get("id")
if not vespa_full_id:
continue
vespa_doc_uuid = vespa_full_id.split("::")[-1]
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}"
update_request = {
"fields": {"document_id": {"assign": clean_new_doc_id}}
}
try:
resp = http_client.put(vespa_url, json=update_request)
resp.raise_for_status()
except Exception as e:
print(f"Failed to update chunk {vespa_doc_uuid}: {e}")
raise
if not continuation:
break
def delete_document_from_db(current_doc_id: str, index_name: str) -> None:
# Delete all foreign key references first, then delete the document
try:
bind = op.get_bind()
# Delete from agent-related tables first (order matters due to foreign keys)
# Delete from agent__sub_query__search_doc first since it references search_doc
bind.execute(
sa.text(
"""
DELETE FROM agent__sub_query__search_doc
WHERE search_doc_id IN (
SELECT id FROM search_doc WHERE document_id = :doc_id
)
"""
),
{"doc_id": current_doc_id},
)
# Delete from chat_message__search_doc
bind.execute(
sa.text(
"""
DELETE FROM chat_message__search_doc
WHERE search_doc_id IN (
SELECT id FROM search_doc WHERE document_id = :doc_id
)
"""
),
{"doc_id": current_doc_id},
)
# Now we can safely delete from search_doc
bind.execute(
sa.text("DELETE FROM search_doc WHERE document_id = :doc_id"),
{"doc_id": current_doc_id},
)
# Delete from document_by_connector_credential_pair
bind.execute(
sa.text(
"DELETE FROM document_by_connector_credential_pair WHERE id = :doc_id"
),
{"doc_id": current_doc_id},
)
# Delete from other tables that reference this document
bind.execute(
sa.text(
"DELETE FROM document_retrieval_feedback WHERE document_id = :doc_id"
),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text("DELETE FROM document__tag WHERE document_id = :doc_id"),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text("DELETE FROM user_file WHERE document_id = :doc_id"),
{"doc_id": current_doc_id},
)
# Delete from KG tables if they exist
try:
bind.execute(
sa.text("DELETE FROM kg_entity WHERE document_id = :doc_id"),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text(
"DELETE FROM kg_entity_extraction_staging WHERE document_id = :doc_id"
),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text("DELETE FROM kg_relationship WHERE source_document = :doc_id"),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text(
"DELETE FROM kg_relationship_extraction_staging WHERE source_document = :doc_id"
),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text("DELETE FROM chunk_stats WHERE document_id = :doc_id"),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text("DELETE FROM chunk_stats WHERE id LIKE :doc_id_pattern"),
{"doc_id_pattern": f"{current_doc_id}__%"},
)
except Exception as e:
logger.warning(
f"Some KG/chunk tables may not exist or failed to delete from: {e}"
)
# Finally delete the document itself
bind.execute(
sa.text("DELETE FROM document WHERE id = :doc_id"),
{"doc_id": current_doc_id},
)
# Delete chunks from vespa
delete_document_chunks_from_vespa(index_name, current_doc_id)
except Exception as e:
print(f"Failed to delete duplicate document {current_doc_id}: {e}")
# Continue with other documents instead of failing the entire migration
def upgrade() -> None:
if SKIP_CANON_DRIVE_IDS:
return
current_search_settings, future_search_settings = active_search_settings()
document_index = get_default_document_index(
current_search_settings,
future_search_settings,
)
# Get the index name
if hasattr(document_index, "index_name"):
index_name = document_index.index_name
else:
# Default index name if we can't get it from the document_index
index_name = "danswer_index"
# Get all Google Drive documents from the database (this is faster and more reliable)
gdrive_documents = get_google_drive_documents_from_database()
if not gdrive_documents:
return
# Track normalized document IDs to detect duplicates
all_normalized_doc_ids = set()
updated_count = 0
for doc_info in gdrive_documents:
current_doc_id = doc_info["document_id"]
normalized_doc_id = normalize_google_drive_url(current_doc_id)
print(f"Processing document {current_doc_id} -> {normalized_doc_id}")
# Check for duplicates
if normalized_doc_id in all_normalized_doc_ids:
# print(f"Deleting duplicate document {current_doc_id}")
delete_document_from_db(current_doc_id, index_name)
continue
all_normalized_doc_ids.add(normalized_doc_id)
# If the document ID already doesn't have query parameters, skip it
if current_doc_id == normalized_doc_id:
# print(f"Skipping document {current_doc_id} -> {normalized_doc_id} because it already has no query parameters")
continue
try:
# Update both database and Vespa in order
# Database first to ensure consistency
update_document_id_in_database(
current_doc_id, normalized_doc_id, index_name
)
# For Vespa, we can now use the original document IDs since we're using contains matching
update_document_id_in_vespa(index_name, current_doc_id, normalized_doc_id)
updated_count += 1
# print(f"Finished updating document {current_doc_id} -> {normalized_doc_id}")
except Exception as e:
print(f"Failed to update document {current_doc_id}: {e}")
if isinstance(e, HTTPStatusError):
print(f"HTTPStatusError: {e}")
print(f"Response: {e.response.text}")
print(f"Status: {e.response.status_code}")
print(f"Headers: {e.response.headers}")
print(f"Request: {e.request.url}")
print(f"Request headers: {e.request.headers}")
# Note: Rollback is complex with copy-and-swap approach since the old document is already deleted
# In case of failure, manual intervention may be required
# Continue with other documents instead of failing the entire migration
continue
logger.info(f"Migration complete. Updated {updated_count} Google Drive documents")
def downgrade() -> None:
# this is a one way migration, so no downgrade.
# It wouldn't make sense to store the extra query parameters
# and duplicate documents to allow a reversal.
pass

View File

@@ -144,34 +144,27 @@ def upgrade() -> None:
def downgrade() -> None:
op.execute("TRUNCATE TABLE index_attempt")
conn = op.get_bind()
inspector = sa.inspect(conn)
existing_columns = {col["name"] for col in inspector.get_columns("index_attempt")}
if "input_type" not in existing_columns:
op.add_column(
"index_attempt",
sa.Column("input_type", sa.VARCHAR(), autoincrement=False, nullable=False),
)
if "source" not in existing_columns:
op.add_column(
"index_attempt",
sa.Column("source", sa.VARCHAR(), autoincrement=False, nullable=False),
)
if "connector_specific_config" not in existing_columns:
op.add_column(
"index_attempt",
sa.Column(
"connector_specific_config",
postgresql.JSONB(astext_type=sa.Text()),
autoincrement=False,
nullable=False,
),
)
op.add_column(
"index_attempt",
sa.Column("input_type", sa.VARCHAR(), autoincrement=False, nullable=False),
)
op.add_column(
"index_attempt",
sa.Column("source", sa.VARCHAR(), autoincrement=False, nullable=False),
)
op.add_column(
"index_attempt",
sa.Column(
"connector_specific_config",
postgresql.JSONB(astext_type=sa.Text()),
autoincrement=False,
nullable=False,
),
)
# Check if the constraint exists before dropping
conn = op.get_bind()
inspector = sa.inspect(conn)
constraints = inspector.get_foreign_keys("index_attempt")
if any(
@@ -190,12 +183,8 @@ def downgrade() -> None:
"fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
)
if "credential_id" in existing_columns:
op.drop_column("index_attempt", "credential_id")
if "connector_id" in existing_columns:
op.drop_column("index_attempt", "connector_id")
op.execute("DROP TABLE IF EXISTS connector_credential_pair CASCADE")
op.execute("DROP TABLE IF EXISTS credential CASCADE")
op.execute("DROP TABLE IF EXISTS connector CASCADE")
op.drop_column("index_attempt", "credential_id")
op.drop_column("index_attempt", "connector_id")
op.drop_table("connector_credential_pair")
op.drop_table("credential")
op.drop_table("connector")

View File

@@ -1,115 +0,0 @@
"""add_indexing_coordination
Revision ID: 2f95e36923e6
Revises: 0816326d83aa
Create Date: 2025-07-10 16:17:57.762182
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2f95e36923e6"
down_revision = "0816326d83aa"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add database-based coordination fields (replacing Redis fencing)
op.add_column(
"index_attempt", sa.Column("celery_task_id", sa.String(), nullable=True)
)
op.add_column(
"index_attempt",
sa.Column(
"cancellation_requested",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
# Add batch coordination fields (replacing FileStore state)
op.add_column(
"index_attempt", sa.Column("total_batches", sa.Integer(), nullable=True)
)
op.add_column(
"index_attempt",
sa.Column(
"completed_batches", sa.Integer(), nullable=False, server_default="0"
),
)
op.add_column(
"index_attempt",
sa.Column(
"total_failures_batch_level",
sa.Integer(),
nullable=False,
server_default="0",
),
)
op.add_column(
"index_attempt",
sa.Column("total_chunks", sa.Integer(), nullable=False, server_default="0"),
)
# Progress tracking for stall detection
op.add_column(
"index_attempt",
sa.Column("last_progress_time", sa.DateTime(timezone=True), nullable=True),
)
op.add_column(
"index_attempt",
sa.Column(
"last_batches_completed_count",
sa.Integer(),
nullable=False,
server_default="0",
),
)
# Heartbeat tracking for worker liveness detection
op.add_column(
"index_attempt",
sa.Column(
"heartbeat_counter", sa.Integer(), nullable=False, server_default="0"
),
)
op.add_column(
"index_attempt",
sa.Column(
"last_heartbeat_value", sa.Integer(), nullable=False, server_default="0"
),
)
op.add_column(
"index_attempt",
sa.Column("last_heartbeat_time", sa.DateTime(timezone=True), nullable=True),
)
# Add index for coordination queries
op.create_index(
"ix_index_attempt_active_coordination",
"index_attempt",
["connector_credential_pair_id", "search_settings_id", "status"],
)
def downgrade() -> None:
# Remove the new index
op.drop_index("ix_index_attempt_active_coordination", table_name="index_attempt")
# Remove the new columns
op.drop_column("index_attempt", "last_batches_completed_count")
op.drop_column("index_attempt", "last_progress_time")
op.drop_column("index_attempt", "last_heartbeat_time")
op.drop_column("index_attempt", "last_heartbeat_value")
op.drop_column("index_attempt", "heartbeat_counter")
op.drop_column("index_attempt", "total_chunks")
op.drop_column("index_attempt", "total_failures_batch_level")
op.drop_column("index_attempt", "completed_batches")
op.drop_column("index_attempt", "total_batches")
op.drop_column("index_attempt", "cancellation_requested")
op.drop_column("index_attempt", "celery_task_id")

View File

@@ -1,136 +0,0 @@
"""update_kg_trigger_functions
Revision ID: 36e9220ab794
Revises: c9e2cd766c29
Create Date: 2025-06-22 17:33:25.833733
"""
from alembic import op
from sqlalchemy.orm import Session
from sqlalchemy import text
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
# revision identifiers, used by Alembic.
revision = "36e9220ab794"
down_revision = "c9e2cd766c29"
branch_labels = None
depends_on = None
def _get_tenant_contextvar(session: Session) -> str:
"""Get the current schema for the migration"""
current_tenant = session.execute(text("SELECT current_schema()")).scalar()
if isinstance(current_tenant, str):
return current_tenant
else:
raise ValueError("Current tenant is not a string")
def upgrade() -> None:
bind = op.get_bind()
session = Session(bind=bind)
# Create kg_entity trigger to update kg_entity.name and its trigrams
tenant_id = _get_tenant_contextvar(session)
alphanum_pattern = r"[^a-z0-9]+"
truncate_length = 1000
function = "update_kg_entity_name"
op.execute(
text(
f"""
CREATE OR REPLACE FUNCTION "{tenant_id}".{function}()
RETURNS TRIGGER AS $$
DECLARE
name text;
cleaned_name text;
BEGIN
-- Set name to semantic_id if document_id is not NULL
IF NEW.document_id IS NOT NULL THEN
SELECT lower(semantic_id) INTO name
FROM "{tenant_id}".document
WHERE id = NEW.document_id;
ELSE
name = lower(NEW.name);
END IF;
-- Clean name and truncate if too long
cleaned_name = regexp_replace(
name,
'{alphanum_pattern}', '', 'g'
);
IF length(cleaned_name) > {truncate_length} THEN
cleaned_name = left(cleaned_name, {truncate_length});
END IF;
-- Set name and name trigrams
NEW.name = name;
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
"""
)
)
trigger = f"{function}_trigger"
op.execute(f'DROP TRIGGER IF EXISTS {trigger} ON "{tenant_id}".kg_entity')
op.execute(
f"""
CREATE TRIGGER {trigger}
BEFORE INSERT OR UPDATE OF name
ON "{tenant_id}".kg_entity
FOR EACH ROW
EXECUTE FUNCTION "{tenant_id}".{function}();
"""
)
# Create kg_entity trigger to update kg_entity.name and its trigrams
function = "update_kg_entity_name_from_doc"
op.execute(
text(
f"""
CREATE OR REPLACE FUNCTION "{tenant_id}".{function}()
RETURNS TRIGGER AS $$
DECLARE
doc_name text;
cleaned_name text;
BEGIN
doc_name = lower(NEW.semantic_id);
-- Clean name and truncate if too long
cleaned_name = regexp_replace(
doc_name,
'{alphanum_pattern}', '', 'g'
);
IF length(cleaned_name) > {truncate_length} THEN
cleaned_name = left(cleaned_name, {truncate_length});
END IF;
-- Set name and name trigrams for all entities referencing this document
UPDATE "{tenant_id}".kg_entity
SET
name = doc_name,
name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name)
WHERE document_id = NEW.id;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
"""
)
)
trigger = f"{function}_trigger"
op.execute(f'DROP TRIGGER IF EXISTS {trigger} ON "{tenant_id}".document')
op.execute(
f"""
CREATE TRIGGER {trigger}
AFTER UPDATE OF semantic_id
ON "{tenant_id}".document
FOR EACH ROW
EXECUTE FUNCTION "{tenant_id}".{function}();
"""
)
def downgrade() -> None:
pass

View File

@@ -21,14 +21,22 @@ depends_on = None
# an outage by creating an index without using CONCURRENTLY. This migration:
#
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
# 2. Adds indexes to both chat_message and chat_session tables for comprehensive search
# 3. Note: CONCURRENTLY was removed due to operational issues
# 2. Uses CONCURRENTLY for all index creation to prevent table locking
# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work
# (see: https://www.postgresql.org/docs/9.4/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY)
# (see: https://github.com/sqlalchemy/alembic/issues/277)
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
def upgrade() -> None:
# First, drop any existing indexes to avoid conflicts
op.execute("DROP INDEX IF EXISTS idx_chat_message_tsv;")
op.execute("DROP INDEX IF EXISTS idx_chat_session_desc_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
# Drop existing columns if they exist
@@ -44,9 +52,12 @@ def upgrade() -> None:
"""
)
# Commit the current transaction before creating concurrent indexes
op.execute("COMMIT")
op.execute(
"""
CREATE INDEX IF NOT EXISTS idx_chat_message_tsv
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
ON chat_message
USING GIN (message_tsv)
"""
@@ -61,9 +72,12 @@ def upgrade() -> None:
"""
)
# Commit again before creating the second concurrent index
op.execute("COMMIT")
op.execute(
"""
CREATE INDEX IF NOT EXISTS idx_chat_session_desc_tsv
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
ON chat_session
USING GIN (description_tsv)
"""
@@ -71,9 +85,12 @@ def upgrade() -> None:
def downgrade() -> None:
# Drop the indexes first
op.execute("DROP INDEX IF EXISTS idx_chat_message_tsv;")
op.execute("DROP INDEX IF EXISTS idx_chat_session_desc_tsv;")
# Drop the indexes first (use CONCURRENTLY for dropping too)
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
# Then drop the columns
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")

View File

@@ -1,30 +0,0 @@
"""add_doc_metadata_field_in_document_model
Revision ID: 3fc5d75723b3
Revises: 2f95e36923e6
Create Date: 2025-07-28 18:45:37.985406
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "3fc5d75723b3"
down_revision = "2f95e36923e6"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"document",
sa.Column(
"doc_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
)
def downgrade() -> None:
op.drop_column("document", "doc_metadata")

View File

@@ -15,7 +15,7 @@ from datetime import datetime, timedelta
from onyx.configs.app_configs import DB_READONLY_USER
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
# revision identifiers, used by Alembic.
@@ -80,7 +80,6 @@ def upgrade() -> None:
)
)
op.execute("DROP TABLE IF EXISTS kg_config CASCADE")
op.create_table(
"kg_config",
sa.Column("id", sa.Integer(), primary_key=True, nullable=False, index=True),
@@ -124,7 +123,6 @@ def upgrade() -> None:
],
)
op.execute("DROP TABLE IF EXISTS kg_entity_type CASCADE")
op.create_table(
"kg_entity_type",
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
@@ -158,7 +156,6 @@ def upgrade() -> None:
),
)
op.execute("DROP TABLE IF EXISTS kg_relationship_type CASCADE")
# Create KGRelationshipType table
op.create_table(
"kg_relationship_type",
@@ -197,7 +194,6 @@ def upgrade() -> None:
),
)
op.execute("DROP TABLE IF EXISTS kg_relationship_type_extraction_staging CASCADE")
# Create KGRelationshipTypeExtractionStaging table
op.create_table(
"kg_relationship_type_extraction_staging",
@@ -231,8 +227,6 @@ def upgrade() -> None:
),
)
op.execute("DROP TABLE IF EXISTS kg_entity CASCADE")
# Create KGEntity table
op.create_table(
"kg_entity",
@@ -287,7 +281,6 @@ def upgrade() -> None:
"ix_entity_name_search", "kg_entity", ["name", "entity_type_id_name"]
)
op.execute("DROP TABLE IF EXISTS kg_entity_extraction_staging CASCADE")
# Create KGEntityExtractionStaging table
op.create_table(
"kg_entity_extraction_staging",
@@ -337,7 +330,6 @@ def upgrade() -> None:
["name", "entity_type_id_name"],
)
op.execute("DROP TABLE IF EXISTS kg_relationship CASCADE")
# Create KGRelationship table
op.create_table(
"kg_relationship",
@@ -379,7 +371,6 @@ def upgrade() -> None:
"ix_kg_relationship_nodes", "kg_relationship", ["source_node", "target_node"]
)
op.execute("DROP TABLE IF EXISTS kg_relationship_extraction_staging CASCADE")
# Create KGRelationshipExtractionStaging table
op.create_table(
"kg_relationship_extraction_staging",
@@ -423,7 +414,6 @@ def upgrade() -> None:
["source_node", "target_node"],
)
op.execute("DROP TABLE IF EXISTS kg_term CASCADE")
# Create KGTerm table
op.create_table(
"kg_term",
@@ -477,11 +467,11 @@ def upgrade() -> None:
# Create GIN index for clustering and normalization
op.execute(
"CREATE INDEX IF NOT EXISTS idx_kg_entity_clustering_trigrams "
f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA}.gin_trgm_ops)"
"CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_kg_entity_clustering_trigrams "
f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.gin_trgm_ops)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS idx_kg_entity_normalization_trigrams "
"CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_kg_entity_normalization_trigrams "
"ON kg_entity USING GIN (name_trigrams)"
)
@@ -518,7 +508,7 @@ def upgrade() -> None:
-- Set name and name trigrams
NEW.name = name;
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name);
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
@@ -563,7 +553,7 @@ def upgrade() -> None:
UPDATE kg_entity
SET
name = doc_name,
name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name)
name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name)
WHERE document_id = NEW.id;
RETURN NEW;
END;
@@ -635,8 +625,9 @@ def downgrade() -> None:
op.execute(f"DROP FUNCTION IF EXISTS {function}()")
# Drop index
op.execute("DROP INDEX IF EXISTS idx_kg_entity_clustering_trigrams")
op.execute("DROP INDEX IF EXISTS idx_kg_entity_normalization_trigrams")
op.execute("COMMIT") # Commit to allow CONCURRENTLY
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_kg_entity_clustering_trigrams")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_kg_entity_normalization_trigrams")
# Drop tables in reverse order of creation to handle dependencies
op.drop_table("kg_term")

View File

@@ -1,90 +0,0 @@
"""add stale column to external user group tables
Revision ID: 58c50ef19f08
Revises: 7b9b952abdf6
Create Date: 2025-06-25 14:08:14.162380
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "58c50ef19f08"
down_revision = "7b9b952abdf6"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add the stale column with default value False to user__external_user_group_id
op.add_column(
"user__external_user_group_id",
sa.Column("stale", sa.Boolean(), nullable=False, server_default="false"),
)
# Create index for efficient querying of stale rows by cc_pair_id
op.create_index(
"ix_user__external_user_group_id_cc_pair_id_stale",
"user__external_user_group_id",
["cc_pair_id", "stale"],
unique=False,
)
# Create index for efficient querying of all stale rows
op.create_index(
"ix_user__external_user_group_id_stale",
"user__external_user_group_id",
["stale"],
unique=False,
)
# Add the stale column with default value False to public_external_user_group
op.add_column(
"public_external_user_group",
sa.Column("stale", sa.Boolean(), nullable=False, server_default="false"),
)
# Create index for efficient querying of stale rows by cc_pair_id
op.create_index(
"ix_public_external_user_group_cc_pair_id_stale",
"public_external_user_group",
["cc_pair_id", "stale"],
unique=False,
)
# Create index for efficient querying of all stale rows
op.create_index(
"ix_public_external_user_group_stale",
"public_external_user_group",
["stale"],
unique=False,
)
def downgrade() -> None:
# Drop the indices for public_external_user_group first
op.drop_index(
"ix_public_external_user_group_stale", table_name="public_external_user_group"
)
op.drop_index(
"ix_public_external_user_group_cc_pair_id_stale",
table_name="public_external_user_group",
)
# Drop the stale column from public_external_user_group
op.drop_column("public_external_user_group", "stale")
# Drop the indices for user__external_user_group_id
op.drop_index(
"ix_user__external_user_group_id_stale",
table_name="user__external_user_group_id",
)
op.drop_index(
"ix_user__external_user_group_id_cc_pair_id_stale",
table_name="user__external_user_group_id",
)
# Drop the stale column from user__external_user_group_id
op.drop_column("user__external_user_group_id", "stale")

View File

@@ -1,132 +0,0 @@
"""add file names to file connector config
Revision ID: 62c3a055a141
Revises: 3fc5d75723b3
Create Date: 2025-07-30 17:01:24.417551
"""
from alembic import op
import sqlalchemy as sa
import json
import os
import logging
# revision identifiers, used by Alembic.
revision = "62c3a055a141"
down_revision = "3fc5d75723b3"
branch_labels = None
depends_on = None
SKIP_FILE_NAME_MIGRATION = (
os.environ.get("SKIP_FILE_NAME_MIGRATION", "true").lower() == "true"
)
logger = logging.getLogger("alembic.runtime.migration")
def upgrade() -> None:
if SKIP_FILE_NAME_MIGRATION:
logger.info(
"Skipping file name migration. Hint: set SKIP_FILE_NAME_MIGRATION=false to run this migration"
)
return
logger.info("Running file name migration")
# Get connection
conn = op.get_bind()
# Get all FILE connectors with their configs
file_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'FILE'
"""
)
).fetchall()
for connector_id, config in file_connectors:
# Parse config if it's a string
if isinstance(config, str):
config = json.loads(config)
# Get file_locations list
file_locations = config.get("file_locations", [])
# Get display names for each file_id
file_names = []
for file_id in file_locations:
result = conn.execute(
sa.text(
"""
SELECT display_name
FROM file_record
WHERE file_id = :file_id
"""
),
{"file_id": file_id},
).fetchone()
if result:
file_names.append(result[0])
else:
file_names.append(file_id) # Should not happen
# Add file_names to config
new_config = dict(config)
new_config["file_names"] = file_names
# Update the connector
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"connector_id": connector_id, "new_config": json.dumps(new_config)},
)
def downgrade() -> None:
# Get connection
conn = op.get_bind()
# Remove file_names from all FILE connectors
file_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'FILE'
"""
)
).fetchall()
for connector_id, config in file_connectors:
# Parse config if it's a string
if isinstance(config, str):
config = json.loads(config)
# Remove file_names if it exists
if "file_names" in config:
new_config = dict(config)
del new_config["file_names"]
# Update the connector
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{
"connector_id": connector_id,
"new_config": json.dumps(new_config),
},
)

View File

@@ -1,318 +0,0 @@
"""update-entities
Revision ID: 7b9b952abdf6
Revises: 36e9220ab794
Create Date: 2025-06-23 20:24:08.139201
"""
import json
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7b9b952abdf6"
down_revision = "36e9220ab794"
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
# new entity type metadata_attribute_conversion
new_entity_type_conversion = {
"LINEAR": {
"team": {"name": "team", "keep": True, "implication_property": None},
"state": {"name": "state", "keep": True, "implication_property": None},
"priority": {
"name": "priority",
"keep": True,
"implication_property": None,
},
"estimate": {
"name": "estimate",
"keep": True,
"implication_property": None,
},
"created_at": {
"name": "created_at",
"keep": True,
"implication_property": None,
},
"started_at": {
"name": "started_at",
"keep": True,
"implication_property": None,
},
"completed_at": {
"name": "completed_at",
"keep": True,
"implication_property": None,
},
"due_date": {
"name": "due_date",
"keep": True,
"implication_property": None,
},
"creator": {
"name": "creator",
"keep": False,
"implication_property": {
"implied_entity_type": "from_email",
"implied_relationship_name": "is_creator_of",
},
},
"assignee": {
"name": "assignee",
"keep": False,
"implication_property": {
"implied_entity_type": "from_email",
"implied_relationship_name": "is_assignee_of",
},
},
},
"JIRA": {
"issuetype": {
"name": "subtype",
"keep": True,
"implication_property": None,
},
"status": {"name": "status", "keep": True, "implication_property": None},
"priority": {
"name": "priority",
"keep": True,
"implication_property": None,
},
"project_name": {
"name": "project",
"keep": True,
"implication_property": None,
},
"created": {
"name": "created_at",
"keep": True,
"implication_property": None,
},
"updated": {
"name": "updated_at",
"keep": True,
"implication_property": None,
},
"resolution_date": {
"name": "completed_at",
"keep": True,
"implication_property": None,
},
"duedate": {"name": "due_date", "keep": True, "implication_property": None},
"reporter_email": {
"name": "creator",
"keep": False,
"implication_property": {
"implied_entity_type": "from_email",
"implied_relationship_name": "is_creator_of",
},
},
"assignee_email": {
"name": "assignee",
"keep": False,
"implication_property": {
"implied_entity_type": "from_email",
"implied_relationship_name": "is_assignee_of",
},
},
"key": {"name": "key", "keep": True, "implication_property": None},
"parent": {"name": "parent", "keep": True, "implication_property": None},
},
"GITHUB_PR": {
"repo": {"name": "repository", "keep": True, "implication_property": None},
"state": {"name": "state", "keep": True, "implication_property": None},
"num_commits": {
"name": "num_commits",
"keep": True,
"implication_property": None,
},
"num_files_changed": {
"name": "num_files_changed",
"keep": True,
"implication_property": None,
},
"labels": {"name": "labels", "keep": True, "implication_property": None},
"merged": {"name": "merged", "keep": True, "implication_property": None},
"merged_at": {
"name": "merged_at",
"keep": True,
"implication_property": None,
},
"closed_at": {
"name": "closed_at",
"keep": True,
"implication_property": None,
},
"created_at": {
"name": "created_at",
"keep": True,
"implication_property": None,
},
"updated_at": {
"name": "updated_at",
"keep": True,
"implication_property": None,
},
"user": {
"name": "creator",
"keep": False,
"implication_property": {
"implied_entity_type": "from_email",
"implied_relationship_name": "is_creator_of",
},
},
"assignees": {
"name": "assignees",
"keep": False,
"implication_property": {
"implied_entity_type": "from_email",
"implied_relationship_name": "is_assignee_of",
},
},
},
"GITHUB_ISSUE": {
"repo": {"name": "repository", "keep": True, "implication_property": None},
"state": {"name": "state", "keep": True, "implication_property": None},
"labels": {"name": "labels", "keep": True, "implication_property": None},
"closed_at": {
"name": "closed_at",
"keep": True,
"implication_property": None,
},
"created_at": {
"name": "created_at",
"keep": True,
"implication_property": None,
},
"updated_at": {
"name": "updated_at",
"keep": True,
"implication_property": None,
},
"user": {
"name": "creator",
"keep": False,
"implication_property": {
"implied_entity_type": "from_email",
"implied_relationship_name": "is_creator_of",
},
},
"assignees": {
"name": "assignees",
"keep": False,
"implication_property": {
"implied_entity_type": "from_email",
"implied_relationship_name": "is_assignee_of",
},
},
},
"FIREFLIES": {},
"ACCOUNT": {},
"OPPORTUNITY": {
"name": {"name": "name", "keep": True, "implication_property": None},
"stage_name": {"name": "stage", "keep": True, "implication_property": None},
"type": {"name": "type", "keep": True, "implication_property": None},
"amount": {"name": "amount", "keep": True, "implication_property": None},
"fiscal_year": {
"name": "fiscal_year",
"keep": True,
"implication_property": None,
},
"fiscal_quarter": {
"name": "fiscal_quarter",
"keep": True,
"implication_property": None,
},
"is_closed": {
"name": "is_closed",
"keep": True,
"implication_property": None,
},
"close_date": {
"name": "close_date",
"keep": True,
"implication_property": None,
},
"probability": {
"name": "close_probability",
"keep": True,
"implication_property": None,
},
"created_date": {
"name": "created_at",
"keep": True,
"implication_property": None,
},
"last_modified_date": {
"name": "updated_at",
"keep": True,
"implication_property": None,
},
"account": {
"name": "account",
"keep": False,
"implication_property": {
"implied_entity_type": "ACCOUNT",
"implied_relationship_name": "is_account_of",
},
},
},
"VENDOR": {},
"EMPLOYEE": {},
}
current_entity_types = conn.execute(
sa.text("SELECT id_name, attributes from kg_entity_type")
).all()
for entity_type, attributes in current_entity_types:
# delete removed entity types
if entity_type not in new_entity_type_conversion:
op.execute(
sa.text(f"DELETE FROM kg_entity_type WHERE id_name = '{entity_type}'")
)
continue
# update entity type attributes
if "metadata_attributes" in attributes:
del attributes["metadata_attributes"]
attributes["metadata_attribute_conversion"] = new_entity_type_conversion[
entity_type
]
attributes_str = json.dumps(attributes).replace("'", "''")
op.execute(
sa.text(
f"UPDATE kg_entity_type SET attributes = '{attributes_str}'"
f"WHERE id_name = '{entity_type}'"
),
)
def downgrade() -> None:
conn = op.get_bind()
current_entity_types = conn.execute(
sa.text("SELECT id_name, attributes from kg_entity_type")
).all()
for entity_type, attributes in current_entity_types:
conversion = {}
if "metadata_attribute_conversion" in attributes:
conversion = attributes.pop("metadata_attribute_conversion")
attributes["metadata_attributes"] = {
attr: prop["name"] for attr, prop in conversion.items() if prop["keep"]
}
attributes_str = json.dumps(attributes).replace("'", "''")
op.execute(
sa.text(
f"UPDATE kg_entity_type SET attributes = '{attributes_str}'"
f"WHERE id_name = '{entity_type}'"
),
)

View File

@@ -1,315 +0,0 @@
"""modify_file_store_for_external_storage
Revision ID: c9e2cd766c29
Revises: 03bf8be6b53a
Create Date: 2025-06-13 14:02:09.867679
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.orm import Session
from sqlalchemy import text
from typing import cast, Any
from botocore.exceptions import ClientError
from onyx.db._deprecated.pg_file_store import delete_lobj_by_id, read_lobj
from onyx.file_store.file_store import get_s3_file_store
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
# revision identifiers, used by Alembic.
revision = "c9e2cd766c29"
down_revision = "03bf8be6b53a"
branch_labels = None
depends_on = None
def upgrade() -> None:
try:
# Modify existing file_store table to support external storage
op.rename_table("file_store", "file_record")
# Make lobj_oid nullable (for external storage files)
op.alter_column("file_record", "lobj_oid", nullable=True)
# Add external storage columns with generic names
op.add_column(
"file_record", sa.Column("bucket_name", sa.String(), nullable=True)
)
op.add_column(
"file_record", sa.Column("object_key", sa.String(), nullable=True)
)
# Add timestamps for tracking
op.add_column(
"file_record",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
)
op.add_column(
"file_record",
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
)
op.alter_column("file_record", "file_name", new_column_name="file_id")
except Exception as e:
if "does not exist" in str(e) or 'relation "file_store" does not exist' in str(
e
):
print(
f"Ran into error - {e}. Likely means we had a partial success in the past, continuing..."
)
else:
raise
print(
"External storage configured - migrating files from PostgreSQL to external storage..."
)
# if we fail midway through this, we'll have a partial success. Running the migration
# again should allow us to continue.
_migrate_files_to_external_storage()
print("File migration completed successfully!")
# Remove lobj_oid column
op.drop_column("file_record", "lobj_oid")
def downgrade() -> None:
"""Revert schema changes and migrate files from external storage back to PostgreSQL large objects."""
print(
"Reverting to PostgreSQL-backed file store migrating files from external storage …"
)
# 1. Ensure `lobj_oid` exists on the current `file_record` table (nullable for now).
op.add_column("file_record", sa.Column("lobj_oid", sa.Integer(), nullable=True))
# 2. Move content from external storage back into PostgreSQL large objects (table is still
# called `file_record` so application code continues to work during the copy).
try:
_migrate_files_to_postgres()
except Exception:
print("Error during downgrade migration, rolling back …")
op.drop_column("file_record", "lobj_oid")
raise
# 3. After migration every row should now have `lobj_oid` populated mark NOT NULL.
op.alter_column("file_record", "lobj_oid", nullable=False)
# 4. Remove columns that are only relevant to external storage.
op.drop_column("file_record", "updated_at")
op.drop_column("file_record", "created_at")
op.drop_column("file_record", "object_key")
op.drop_column("file_record", "bucket_name")
# 5. Rename `file_id` back to `file_name` (still on `file_record`).
op.alter_column("file_record", "file_id", new_column_name="file_name")
# 6. Finally, rename the table back to its original name expected by the legacy codebase.
op.rename_table("file_record", "file_store")
print(
"Downgrade migration completed files are now stored inside PostgreSQL again."
)
# -----------------------------------------------------------------------------
# Helper: migrate from external storage (S3/MinIO) back into PostgreSQL large objects
def _migrate_files_to_postgres() -> None:
"""Move any files whose content lives in external S3-compatible storage back into PostgreSQL.
The logic mirrors *inverse* of `_migrate_files_to_external_storage` used on upgrade.
"""
# Obtain DB session from Alembic context
bind = op.get_bind()
session = Session(bind=bind)
# Fetch rows that have external storage pointers (bucket/object_key not NULL)
result = session.execute(
text(
"SELECT file_id, bucket_name, object_key FROM file_record "
"WHERE bucket_name IS NOT NULL AND object_key IS NOT NULL"
)
)
files_to_migrate = [row[0] for row in result.fetchall()]
total_files = len(files_to_migrate)
if total_files == 0:
print("No files found in external storage to migrate back to PostgreSQL.")
return
print(f"Found {total_files} files to migrate back to PostgreSQL large objects.")
_set_tenant_contextvar(session)
migrated_count = 0
# only create external store if we have files to migrate. This line
# makes it so we need to have S3/MinIO configured to run this migration.
external_store = get_s3_file_store()
for i, file_id in enumerate(files_to_migrate, 1):
print(f"Migrating file {i}/{total_files}: {file_id}")
# Read file content from external storage (always binary)
try:
file_io = external_store.read_file(
file_id=file_id, mode="b", use_tempfile=True
)
file_io.seek(0)
# Import lazily to avoid circular deps at Alembic runtime
from onyx.db._deprecated.pg_file_store import (
create_populate_lobj,
) # noqa: E402
# Create new Postgres large object and populate it
lobj_oid = create_populate_lobj(content=file_io, db_session=session)
# Update DB row: set lobj_oid, clear bucket/object_key
session.execute(
text(
"UPDATE file_record SET lobj_oid = :lobj_oid, bucket_name = NULL, "
"object_key = NULL WHERE file_id = :file_id"
),
{"lobj_oid": lobj_oid, "file_id": file_id},
)
except ClientError as e:
if "NoSuchKey" in str(e):
print(
f"File {file_id} not found in external storage. Deleting from database."
)
session.execute(
text("DELETE FROM file_record WHERE file_id = :file_id"),
{"file_id": file_id},
)
else:
raise
migrated_count += 1
print(f"✓ Successfully migrated file {i}/{total_files}: {file_id}")
# Flush the SQLAlchemy session so statements are sent to the DB, but **do not**
# commit the transaction. The surrounding Alembic migration will commit once
# the *entire* downgrade succeeds. This keeps the whole downgrade atomic and
# avoids leaving the database in a partially-migrated state if a later schema
# operation fails.
session.flush()
print(
f"Migration back to PostgreSQL completed: {migrated_count} files staged for commit."
)
def _migrate_files_to_external_storage() -> None:
"""Migrate files from PostgreSQL large objects to external storage"""
# Get database session
bind = op.get_bind()
session = Session(bind=bind)
external_store = get_s3_file_store()
# Find all files currently stored in PostgreSQL (lobj_oid is not null)
result = session.execute(
text(
"SELECT file_id FROM file_record WHERE lobj_oid IS NOT NULL "
"AND bucket_name IS NULL AND object_key IS NULL"
)
)
files_to_migrate = [row[0] for row in result.fetchall()]
total_files = len(files_to_migrate)
if total_files == 0:
print("No files found in PostgreSQL storage to migrate.")
return
# might need to move this above the if statement when creating a new multi-tenant
# system. VERY extreme edge case.
external_store.initialize()
print(f"Found {total_files} files to migrate from PostgreSQL to external storage.")
_set_tenant_contextvar(session)
migrated_count = 0
for i, file_id in enumerate(files_to_migrate, 1):
print(f"Migrating file {i}/{total_files}: {file_id}")
# Read file record to get metadata
file_record = session.execute(
text("SELECT * FROM file_record WHERE file_id = :file_id"),
{"file_id": file_id},
).fetchone()
if file_record is None:
print(f"File {file_id} not found in PostgreSQL storage.")
continue
lobj_id = cast(int, file_record.lobj_oid) # type: ignore
file_metadata = cast(Any, file_record.file_metadata) # type: ignore
# Read file content from PostgreSQL
try:
file_content = read_lobj(
lobj_id, db_session=session, mode="b", use_tempfile=True
)
except Exception as e:
if "large object" in str(e) and "does not exist" in str(e):
print(f"File {file_id} not found in PostgreSQL storage.")
continue
else:
raise
# Handle file_metadata type conversion
file_metadata = None
if file_metadata is not None:
if isinstance(file_metadata, dict):
file_metadata = file_metadata
else:
# Convert other types to dict if possible, otherwise None
try:
file_metadata = dict(file_record.file_metadata) # type: ignore
except (TypeError, ValueError):
file_metadata = None
# Save to external storage (this will handle the database record update and cleanup)
# NOTE: this WILL .commit() the transaction.
external_store.save_file(
file_id=file_id,
content=file_content,
display_name=file_record.display_name,
file_origin=file_record.file_origin,
file_type=file_record.file_type,
file_metadata=file_metadata,
)
delete_lobj_by_id(lobj_id, db_session=session)
migrated_count += 1
print(f"✓ Successfully migrated file {i}/{total_files}: {file_id}")
# See note above flush but do **not** commit so the outer Alembic transaction
# controls atomicity.
session.flush()
print(
f"Migration completed: {migrated_count} files staged for commit to external storage."
)
def _set_tenant_contextvar(session: Session) -> None:
"""Set the tenant contextvar to the default schema"""
current_tenant = session.execute(text("SELECT current_schema()")).scalar()
print(f"Migrating files for tenant: {current_tenant}")
CURRENT_TENANT_ID_CONTEXTVAR.set(current_tenant)

View File

@@ -11,7 +11,7 @@ import sqlalchemy as sa
import json
from onyx.configs.constants import DocumentSource
from onyx.connectors.jira.utils import extract_jira_project
from onyx.connectors.onyx_jira.utils import extract_jira_project
# revision identifiers, used by Alembic.

View File

@@ -18,13 +18,11 @@ depends_on: None = None
def upgrade() -> None:
op.execute("DROP TABLE IF EXISTS document CASCADE")
op.create_table(
"document",
sa.Column("id", sa.String(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.execute("DROP TABLE IF EXISTS chunk CASCADE")
op.create_table(
"chunk",
sa.Column("id", sa.String(), nullable=False),
@@ -45,7 +43,6 @@ def upgrade() -> None:
),
sa.PrimaryKeyConstraint("id", "document_store_type"),
)
op.execute("DROP TABLE IF EXISTS deletion_attempt CASCADE")
op.create_table(
"deletion_attempt",
sa.Column("id", sa.Integer(), nullable=False),
@@ -87,7 +84,6 @@ def upgrade() -> None:
),
sa.PrimaryKeyConstraint("id"),
)
op.execute("DROP TABLE IF EXISTS document_by_connector_credential_pair CASCADE")
op.create_table(
"document_by_connector_credential_pair",
sa.Column("id", sa.String(), nullable=False),
@@ -110,10 +106,7 @@ def upgrade() -> None:
def downgrade() -> None:
# upstream tables first
op.drop_table("document_by_connector_credential_pair")
op.drop_table("deletion_attempt")
op.drop_table("chunk")
# Alembic op.drop_table() has no "cascade" flag issue raw SQL
op.execute("DROP TABLE IF EXISTS document CASCADE")
op.drop_table("document")

View File

@@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.schema import SchemaItem
from alembic import context
from onyx.db.engine.sql_engine import build_connection_string
from onyx.db.engine import build_connection_string
from onyx.db.models import PublicBase
# this is the Alembic Config object, which provides

View File

@@ -16,7 +16,7 @@ from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import QueryHistoryType
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.tasks import delete_task_with_id
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import mark_task_as_started_with_id
@@ -35,13 +35,7 @@ logger = setup_logger()
trail=False,
)
def export_query_history_task(
self: Task,
*,
start: datetime,
end: datetime,
start_time: datetime,
# Need to include the tenant_id since the TenantAwareTask needs this
tenant_id: str,
self: Task, *, start: datetime, end: datetime, start_time: datetime
) -> None:
if not self.request.id:
raise RuntimeError("No task id defined for this task; cannot identify it")
@@ -91,7 +85,8 @@ def export_query_history_task(
with get_session_with_current_tenant() as db_session:
try:
stream.seek(0)
get_default_file_store().save_file(
get_default_file_store(db_session).save_file(
file_name=report_name,
content=stream,
display_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,
@@ -101,7 +96,6 @@ def export_query_history_task(
"end": end.isoformat(),
"start_time": start_time.isoformat(),
},
file_id=report_name,
)
delete_task_with_id(

View File

@@ -13,7 +13,7 @@ from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import TaskStatus
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import register_task

View File

@@ -20,36 +20,39 @@ from shared_configs.configs import MULTI_TENANT
ee_beat_system_tasks: list[dict] = []
ee_beat_task_templates: list[dict] = [
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
ee_beat_task_templates: list[dict] = []
ee_beat_task_templates.extend(
[
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
},
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
},
{
"name": "export-query-history-cleanup-task",
"task": OnyxCeleryTask.EXPORT_QUERY_HISTORY_CLEANUP_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.CSV_GENERATION,
{
"name": "export-query-history-cleanup-task",
"task": OnyxCeleryTask.EXPORT_QUERY_HISTORY_CLEANUP_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.CSV_GENERATION,
},
},
},
]
]
)
ee_tasks_to_schedule: list[dict] = []

View File

@@ -6,7 +6,7 @@ from celery import shared_task
from ee.onyx.db.query_history import get_all_query_history_export_tasks
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import TaskStatus
from onyx.db.tasks import delete_task_with_id
from onyx.utils.logger import setup_logger

View File

@@ -13,7 +13,7 @@ from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.db.engine import get_all_tenant_ids
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST

View File

@@ -30,7 +30,6 @@ from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
@@ -47,10 +46,9 @@ from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document import get_documents_for_connector_credential_pair_limited_columns
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
@@ -59,9 +57,7 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.utils import DocumentRow
from onyx.db.utils import is_retryable_sqlalchemy_error
from onyx.db.utils import SortOrder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
@@ -77,7 +73,6 @@ from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -92,24 +87,6 @@ LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
def _get_fence_validation_block_expiration() -> int:
"""
Compute the expiration time for the fence validation block signal.
Base expiration is 300 seconds, multiplied by the beat multiplier only in MULTI_TENANT mode.
"""
base_expiration = 300 # seconds
if not MULTI_TENANT:
return base_expiration
try:
beat_multiplier = OnyxRuntime.get_beat_multiplier()
except Exception:
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
return int(base_expiration * beat_multiplier)
"""Jobs / utils for kicking off doc permissions sync tasks."""
@@ -217,11 +194,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
"Exception while validating permission sync fences"
)
r.set(
OnyxRedisSignals.BLOCK_VALIDATE_PERMISSION_SYNC_FENCES,
1,
ex=_get_fence_validation_block_expiration(),
)
r.set(OnyxRedisSignals.BLOCK_VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=300)
# use a lookup table to find active fences. We still have to verify the fence
# exists since it is an optimization and not the source of truth.
@@ -425,7 +398,7 @@ def connector_permission_sync_generator_task(
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
+ f"_{redis_connector.cc_pair_id}",
+ f"_{redis_connector.id}",
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
thread_local=False,
)
@@ -452,7 +425,6 @@ def connector_permission_sync_generator_task(
created = validate_ccpair_for_user(
cc_pair.connector.id,
cc_pair.credential.id,
cc_pair.access_type,
db_session,
enforce_creation=False,
)
@@ -501,31 +473,16 @@ def connector_permission_sync_generator_task(
# this is can be used to determine documents that are "missing" and thus
# should no longer be accessible. The decision as to whether we should find
# every document during the doc sync process is connector-specific.
def fetch_all_existing_docs_fn(
sort_order: SortOrder | None = None,
) -> list[DocumentRow]:
result = get_documents_for_connector_credential_pair_limited_columns(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
sort_order=sort_order,
)
return list(result)
def fetch_all_existing_docs_ids_fn() -> list[str]:
result = get_document_ids_for_connector_credential_pair(
def fetch_all_existing_docs_fn() -> list[str]:
return get_document_ids_for_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
return result
doc_sync_func = sync_config.doc_sync_config.doc_sync_func
document_external_accesses = doc_sync_func(
cc_pair,
fetch_all_existing_docs_fn,
fetch_all_existing_docs_ids_fn,
callback,
cc_pair, fetch_all_existing_docs_fn, callback
)
task_logger.info(
@@ -640,6 +597,91 @@ def document_update_permissions(
return True
# NOTE(rkuo): Deprecating this due to degenerate behavior in Redis from sending
# large permissions through celery (over 1MB in size)
# @shared_task(
# name=OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
# soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
# time_limit=LIGHT_TIME_LIMIT,
# max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
# bind=True,
# )
# def update_external_document_permissions_task(
# self: Task,
# tenant_id: str,
# serialized_doc_external_access: dict,
# source_string: str,
# connector_id: int,
# credential_id: int,
# ) -> bool:
# start = time.monotonic()
# completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
# document_external_access = DocExternalAccess.from_dict(
# serialized_doc_external_access
# )
# doc_id = document_external_access.doc_id
# external_access = document_external_access.external_access
# try:
# with get_session_with_current_tenant() as db_session:
# # Add the users to the DB if they don't exist
# batch_add_ext_perm_user_if_not_exists(
# db_session=db_session,
# emails=list(external_access.external_user_emails),
# continue_on_error=True,
# )
# # Then upsert the document's external permissions
# created_new_doc = upsert_document_external_perms(
# db_session=db_session,
# doc_id=doc_id,
# external_access=external_access,
# source_type=DocumentSource(source_string),
# )
# if created_new_doc:
# # If a new document was created, we associate it with the cc_pair
# upsert_document_by_connector_credential_pair(
# db_session=db_session,
# connector_id=connector_id,
# credential_id=credential_id,
# document_ids=[doc_id],
# )
# elapsed = time.monotonic() - start
# task_logger.info(
# f"connector_id={connector_id} "
# f"doc={doc_id} "
# f"action=update_permissions "
# f"elapsed={elapsed:.2f}"
# )
# completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
# except Exception as e:
# error_msg = format_error_for_logging(e)
# task_logger.warning(
# f"Exception in update_external_document_permissions_task: connector_id={connector_id} doc_id={doc_id} {error_msg}"
# )
# task_logger.exception(
# f"update_external_document_permissions_task exceptioned: "
# f"connector_id={connector_id} doc_id={doc_id}"
# )
# completion_status = OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
# finally:
# task_logger.info(
# f"update_external_document_permissions_task completed: status={completion_status.value} doc={doc_id}"
# )
# if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
# return False
# task_logger.info(
# f"update_external_document_permissions_task finished: connector_id={connector_id} doc_id={doc_id}"
# )
# return True
def validate_permission_sync_fences(
tenant_id: str,
r: Redis,

View File

@@ -20,9 +20,7 @@ from ee.onyx.background.celery.tasks.external_group_syncing.group_sync_utils imp
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.db.external_perm import mark_old_external_groups_as_stale
from ee.onyx.db.external_perm import remove_stale_external_groups
from ee.onyx.db.external_perm import upsert_external_groups
from ee.onyx.db.external_perm import replace_user__ext_group_for_cc_pair
from ee.onyx.external_permissions.sync_params import (
get_all_cc_pair_agnostic_group_sync_sources,
)
@@ -30,7 +28,6 @@ from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.error_logging import emit_background_error
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
@@ -42,8 +39,9 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
@@ -58,34 +56,19 @@ from onyx.redis.redis_connector_ext_group_sync import (
)
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.server.utils import make_short_id
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
_EXTERNAL_GROUP_BATCH_SIZE = 100
EXTERNAL_GROUPS_UPDATE_MAX_RETRIES = 3
def _get_fence_validation_block_expiration() -> int:
"""
Compute the expiration time for the fence validation block signal.
Base expiration is 300 seconds, multiplied by the beat multiplier only in MULTI_TENANT mode.
"""
base_expiration = 300 # seconds
if not MULTI_TENANT:
return base_expiration
try:
beat_multiplier = OnyxRuntime.get_beat_multiplier()
except Exception:
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
return int(base_expiration * beat_multiplier)
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
@@ -215,11 +198,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
"Exception while validating external group sync fences"
)
r.set(
OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES,
1,
ex=_get_fence_validation_block_expiration(),
)
r.set(OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=300)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -383,7 +362,7 @@ def connector_external_group_sync_generator_task(
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.cc_pair_id}",
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
)
@@ -398,12 +377,63 @@ def connector_external_group_sync_generator_task(
payload.started = datetime.now(timezone.utc)
redis_connector.external_group_sync.set_fence(payload)
_perform_external_group_sync(
cc_pair_id=cc_pair_id,
tenant_id=tenant_id,
)
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
eager_load_credential=True,
)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
)
source_type = cc_pair.connector.source
sync_config = get_source_perm_sync_config(source_type)
if sync_config is None:
msg = (
f"No sync config found for {source_type} for cc_pair: {cc_pair_id}"
)
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)
if sync_config.group_sync_config is None:
msg = f"No group sync config found for {source_type} for cc_pair: {cc_pair_id}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)
ext_group_sync_func = sync_config.group_sync_config.group_sync_func
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
external_user_groups: list[ExternalUserGroup] = []
try:
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
except ConnectorValidationError as e:
# TODO: add some notification to the admins here
logger.exception(
f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
)
raise e
logger.info(
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
)
logger.debug(f"New external user groups: {external_user_groups}")
replace_user__ext_group_for_cc_pair(
db_session=db_session,
cc_pair_id=cc_pair.id,
group_defs=external_user_groups,
source=cc_pair.connector.source,
)
logger.info(
f"Synced {len(external_user_groups)} external user groups for {source_type}"
)
mark_all_relevant_cc_pairs_as_external_group_synced(db_session, cc_pair)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
@@ -445,81 +475,6 @@ def connector_external_group_sync_generator_task(
)
def _perform_external_group_sync(
cc_pair_id: int,
tenant_id: str,
) -> None:
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
eager_load_credential=True,
)
if cc_pair is None:
raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")
source_type = cc_pair.connector.source
sync_config = get_source_perm_sync_config(source_type)
if sync_config is None:
msg = f"No sync config found for {source_type} for cc_pair: {cc_pair_id}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)
if sync_config.group_sync_config is None:
msg = f"No group sync config found for {source_type} for cc_pair: {cc_pair_id}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)
ext_group_sync_func = sync_config.group_sync_config.group_sync_func
logger.info(
f"Marking old external groups as stale for {source_type} for cc_pair: {cc_pair_id}"
)
mark_old_external_groups_as_stale(db_session, cc_pair_id)
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
external_user_group_batch: list[ExternalUserGroup] = []
try:
external_user_group_generator = ext_group_sync_func(tenant_id, cc_pair)
for external_user_group in external_user_group_generator:
external_user_group_batch.append(external_user_group)
if len(external_user_group_batch) >= _EXTERNAL_GROUP_BATCH_SIZE:
logger.debug(
f"New external user groups: {external_user_group_batch}"
)
upsert_external_groups(
db_session=db_session,
cc_pair_id=cc_pair_id,
external_groups=external_user_group_batch,
source=cc_pair.connector.source,
)
external_user_group_batch = []
if external_user_group_batch:
logger.debug(f"New external user groups: {external_user_group_batch}")
upsert_external_groups(
db_session=db_session,
cc_pair_id=cc_pair_id,
external_groups=external_user_group_batch,
source=cc_pair.connector.source,
)
except Exception as e:
# TODO: add some notification to the admins here
logger.exception(
f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
)
raise e
logger.info(
f"Removing stale external groups for {source_type} for cc_pair: {cc_pair_id}"
)
remove_stale_external_groups(db_session, cc_pair_id)
mark_all_relevant_cc_pairs_as_external_group_synced(db_session, cc_pair)
def validate_external_group_sync_fences(
tenant_id: str,
celery_app: Celery,

View File

@@ -19,7 +19,7 @@ from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.models import AvailableTenant
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT

View File

@@ -53,16 +53,6 @@ CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
)
#####
# JIRA
#####
# In seconds, default is 30 minutes
JIRA_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("JIRA_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60
)
#####
# Google Drive
#####
@@ -71,19 +61,6 @@ GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
)
#####
# GitHub
#####
# In seconds, default is 5 minutes
GITHUB_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("GITHUB_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# In seconds, default is 5 minutes
GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
#####
# Slack
#####
@@ -94,15 +71,6 @@ SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
#####
# Teams
#####
# In seconds, default is 5 minutes
TEAMS_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("TEAMS_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
####
# Celery Job Frequency
####

View File

@@ -1,28 +0,0 @@
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.interfaces import BaseConnector
def validate_confluence_perm_sync(connector: ConfluenceConnector) -> None:
"""
Validate that the connector is configured correctly for permissions syncing.
"""
def validate_drive_perm_sync(connector: GoogleDriveConnector) -> None:
"""
Validate that the connector is configured correctly for permissions syncing.
"""
def validate_perm_sync(connector: BaseConnector) -> None:
"""
Override this if your connector needs to validate permissions syncing.
Raise an exception if invalid, otherwise do nothing.
Default is a no-op (always successful).
"""
if isinstance(connector, ConfluenceConnector):
validate_confluence_perm_sync(connector)
elif isinstance(connector, GoogleDriveConnector):
validate_drive_perm_sync(connector)

View File

@@ -4,7 +4,6 @@ from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.access.utils import build_ext_group_name_for_onyx
@@ -63,41 +62,20 @@ def delete_public_external_group_for_cc_pair__no_commit(
)
def mark_old_external_groups_as_stale(
def replace_user__ext_group_for_cc_pair(
db_session: Session,
cc_pair_id: int,
) -> None:
db_session.execute(
update(User__ExternalUserGroupId)
.where(User__ExternalUserGroupId.cc_pair_id == cc_pair_id)
.values(stale=True)
)
db_session.execute(
update(PublicExternalUserGroup)
.where(PublicExternalUserGroup.cc_pair_id == cc_pair_id)
.values(stale=True)
)
def upsert_external_groups(
db_session: Session,
cc_pair_id: int,
external_groups: list[ExternalUserGroup],
group_defs: list[ExternalUserGroup],
source: DocumentSource,
) -> None:
"""
Performs a true upsert operation for external user groups:
- For existing groups (same user_id, external_user_group_id, cc_pair_id), updates the stale flag to False
- For new groups, inserts them with stale=False
- For public groups, uses upsert logic as well
This function clears all existing external user group relations for a given cc_pair_id
and replaces them with the new group definitions and commits the changes.
"""
# If there are no groups to add, return early
if not external_groups:
return
# collect all emails from all groups to batch add all users at once for efficiency
all_group_member_emails = set()
for external_group in external_groups:
for external_group in group_defs:
for user_email in external_group.user_emails:
all_group_member_emails.add(user_email)
@@ -108,17 +86,26 @@ def upsert_external_groups(
emails=list(all_group_member_emails),
)
# map emails to ids
email_id_map = {user.email.lower(): user.id for user in all_group_members}
delete_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
delete_public_external_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
# Process each external group
for external_group in external_groups:
# map emails to ids
email_id_map = {user.email: user.id for user in all_group_members}
# use these ids to create new external user group relations relating group_id to user_ids
new_external_permissions: list[User__ExternalUserGroupId] = []
new_public_external_groups: list[PublicExternalUserGroup] = []
for external_group in group_defs:
external_group_id = build_ext_group_name_for_onyx(
ext_group_name=external_group.id,
source=source,
)
# Handle user-group mappings
for user_email in external_group.user_emails:
user_id = email_id_map.get(user_email.lower())
if user_id is None:
@@ -127,71 +114,24 @@ def upsert_external_groups(
f" with email {user_email} not found"
)
continue
# Check if the user-group mapping already exists
existing_user_group = db_session.scalar(
select(User__ExternalUserGroupId).where(
User__ExternalUserGroupId.user_id == user_id,
User__ExternalUserGroupId.external_user_group_id
== external_group_id,
User__ExternalUserGroupId.cc_pair_id == cc_pair_id,
)
)
if existing_user_group:
# Update existing record
existing_user_group.stale = False
else:
# Insert new record
new_user_group = User__ExternalUserGroupId(
new_external_permissions.append(
User__ExternalUserGroupId(
user_id=user_id,
external_user_group_id=external_group_id,
cc_pair_id=cc_pair_id,
stale=False,
)
db_session.add(new_user_group)
# Handle public group if needed
if external_group.gives_anyone_access:
# Check if the public group already exists
existing_public_group = db_session.scalar(
select(PublicExternalUserGroup).where(
PublicExternalUserGroup.external_user_group_id == external_group_id,
PublicExternalUserGroup.cc_pair_id == cc_pair_id,
)
)
if existing_public_group:
# Update existing record
existing_public_group.stale = False
else:
# Insert new record
new_public_group = PublicExternalUserGroup(
if external_group.gives_anyone_access:
new_public_external_groups.append(
PublicExternalUserGroup(
external_user_group_id=external_group_id,
cc_pair_id=cc_pair_id,
stale=False,
)
db_session.add(new_public_group)
)
db_session.commit()
def remove_stale_external_groups(
db_session: Session,
cc_pair_id: int,
) -> None:
db_session.execute(
delete(User__ExternalUserGroupId).where(
User__ExternalUserGroupId.cc_pair_id == cc_pair_id,
User__ExternalUserGroupId.stale.is_(True),
)
)
db_session.execute(
delete(PublicExternalUserGroup).where(
PublicExternalUserGroup.cc_pair_id == cc_pair_id,
PublicExternalUserGroup.stale.is_(True),
)
)
db_session.add_all(new_external_permissions)
db_session.add_all(new_public_external_groups)
db_session.commit()

View File

@@ -114,24 +114,12 @@ def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]:
def get_usage_report_data(
report_display_name: str,
db_session: Session,
report_name: str,
) -> IO:
"""
Get the usage report data from the file store.
Args:
db_session: The database session.
report_display_name: The display name of the usage report. Also assumes
that the file is stored with this as the ID in the file store.
Returns:
The usage report data.
"""
file_store = get_default_file_store()
file_store = get_default_file_store(db_session)
# usage report may be very large, so don't load it all into memory
return file_store.read_file(
file_id=report_display_name, mode="b", use_tempfile=True
)
return file_store.read_file(file_name=report_name, mode="b", use_tempfile=True)
def write_usage_report(

View File

@@ -128,14 +128,11 @@ def validate_object_creation_for_user(
target_group_ids: list[int] | None = None,
object_is_public: bool | None = None,
object_is_perm_sync: bool | None = None,
object_is_owned_by_user: bool = False,
object_is_new: bool = False,
) -> None:
"""
All users can create/edit permission synced objects if they don't specify a group
All admin actions are allowed.
Curators and global curators can create public objects.
Prevents other non-admins from creating/editing:
Prevents non-admins from creating/editing:
- public objects
- objects with no groups
- objects that belong to a group they don't curate
@@ -146,23 +143,13 @@ def validate_object_creation_for_user(
if not user or user.role == UserRole.ADMIN:
return
# Allow curators and global curators to create public objects
# w/o associated groups IF the object is new/owned by them
if (
object_is_public
and user.role in [UserRole.CURATOR, UserRole.GLOBAL_CURATOR]
and (object_is_new or object_is_owned_by_user)
):
return
if object_is_public and user.role == UserRole.BASIC:
detail = "User does not have permission to create public objects"
if object_is_public:
detail = "User does not have permission to create public credentials"
logger.error(detail)
raise HTTPException(
status_code=400,
detail=detail,
)
if not target_group_ids:
detail = "Curators must specify 1+ groups"
logger.error(detail)

View File

@@ -18,9 +18,9 @@
<!-- <document type="danswer_chunk" mode="index" /> -->
{{ document_elements }}
</documents>
<nodes count="60">
<resources vcpu="8.0" memory="128.0Gb" architecture="arm64" storage-type="local"
disk="475.0Gb" />
<nodes count="75">
<resources vcpu="8.0" memory="64.0Gb" architecture="arm64" storage-type="local"
disk="474.0Gb" />
</nodes>
<engine>
<proton>

View File

@@ -6,12 +6,11 @@ https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.htm
from collections.abc import Generator
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.access.models import ExternalAccess
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -20,13 +19,9 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
CONFLUENCE_DOC_SYNC_LABEL = "confluence_doc_sync"
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
"""
@@ -34,6 +29,7 @@ def confluence_doc_sync(
Compares fetched documents against existing documents in the DB for the connector.
If a document exists in the DB but not in the Confluence fetch, it's marked as restricted.
"""
logger.info(f"Starting confluence doc sync for CC Pair ID: {cc_pair.id}")
confluence_connector = ConfluenceConnector(
**cc_pair.connector.connector_specific_config
)
@@ -43,11 +39,52 @@ def confluence_doc_sync(
)
confluence_connector.set_credentials_provider(provider)
yield from generic_doc_sync(
cc_pair=cc_pair,
fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
callback=callback,
doc_source=DocumentSource.CONFLUENCE,
slim_connector=confluence_connector,
label=CONFLUENCE_DOC_SYNC_LABEL,
)
slim_docs: list[SlimDocument] = []
logger.info("Fetching all slim documents from confluence")
for doc_batch in confluence_connector.retrieve_all_slim_documents(
callback=callback
):
logger.info(f"Got {len(doc_batch)} slim documents from confluence")
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync", 1)
slim_docs.extend(doc_batch)
# Find documents that are no longer accessible in Confluence
logger.info(f"Querying existing document IDs for CC Pair ID: {cc_pair.id}")
existing_doc_ids = fetch_all_existing_docs_fn()
# Find missing doc IDs
fetched_doc_ids = {doc.id for doc in slim_docs}
missing_doc_ids = set(existing_doc_ids) - fetched_doc_ids
# Yield access removal for missing docs. Better to be safe.
if missing_doc_ids:
logger.warning(
f"Found {len(missing_doc_ids)} documents that are in the DB but "
"not present in Confluence fetch. Making them inaccessible."
)
for missing_id in missing_doc_ids:
logger.warning(f"Removing access for document ID: {missing_id}")
yield DocExternalAccess(
doc_id=missing_id,
external_access=ExternalAccess(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=False,
),
)
for doc in slim_docs:
if not doc.external_access:
raise RuntimeError(f"No external access found for document ID: {doc.id}")
yield DocExternalAccess(
doc_id=doc.id,
external_access=doc.external_access,
)
logger.info("Finished confluence doc sync")

View File

@@ -1,5 +1,3 @@
from collections.abc import Generator
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from onyx.background.error_logging import emit_background_error
@@ -67,7 +65,7 @@ def _build_group_member_email_map(
def confluence_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> Generator[ExternalUserGroup, None, None]:
) -> list[ExternalUserGroup]:
provider = OnyxDBCredentialsProvider(tenant_id, "confluence", cc_pair.credential_id)
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
wiki_base: str = cc_pair.connector.connector_specific_config["wiki_base"]
@@ -91,10 +89,10 @@ def confluence_group_sync(
confluence_client=confluence_client,
cc_pair_id=cc_pair.id,
)
onyx_groups: list[ExternalUserGroup] = []
all_found_emails = set()
for group_id, group_member_emails in group_member_email_map.items():
yield (
onyx_groups.append(
ExternalUserGroup(
id=group_id,
user_emails=list(group_member_emails),
@@ -109,4 +107,6 @@ def confluence_group_sync(
id=ALL_CONF_EMAILS_GROUP_NAME,
user_emails=list(all_found_emails),
)
yield all_found_group
onyx_groups.append(all_found_group)
return onyx_groups

View File

@@ -1,294 +0,0 @@
import json
from collections.abc import Generator
from github import Github
from github.Repository import Repository
from ee.onyx.external_permissions.github.utils import fetch_repository_team_slugs
from ee.onyx.external_permissions.github.utils import form_collaborators_group_id
from ee.onyx.external_permissions.github.utils import form_organization_group_id
from ee.onyx.external_permissions.github.utils import (
form_outside_collaborators_group_id,
)
from ee.onyx.external_permissions.github.utils import get_external_access_permission
from ee.onyx.external_permissions.github.utils import get_repository_visibility
from ee.onyx.external_permissions.github.utils import GitHubVisibility
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.github.connector import DocMetadata
from onyx.connectors.github.connector import GithubConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
GITHUB_DOC_SYNC_LABEL = "github_doc_sync"
def github_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None = None,
) -> Generator[DocExternalAccess, None, None]:
"""
Sync GitHub documents with external access permissions.
This function checks each repository for visibility/team changes and updates
document permissions accordingly without using checkpoints.
"""
logger.info(f"Starting GitHub document sync for CC pair ID: {cc_pair.id}")
# Initialize GitHub connector with credentials
github_connector: GithubConnector = GithubConnector(
**cc_pair.connector.connector_specific_config
)
github_connector.load_credentials(cc_pair.credential.credential_json)
logger.info("GitHub connector credentials loaded successfully")
if not github_connector.github_client:
logger.error("GitHub client initialization failed")
raise ValueError("github_client is required")
# Get all repositories from GitHub API
logger.info("Fetching all repositories from GitHub API")
try:
repos = []
if github_connector.repositories:
if "," in github_connector.repositories:
# Multiple repositories specified
repos = github_connector.get_github_repos(
github_connector.github_client
)
else:
# Single repository
repos = [
github_connector.get_github_repo(github_connector.github_client)
]
else:
# All repositories
repos = github_connector.get_all_repos(github_connector.github_client)
logger.info(f"Found {len(repos)} repositories to check")
except Exception as e:
logger.error(f"Failed to fetch repositories: {e}")
raise
repo_to_doc_list_map: dict[str, list[DocumentRow]] = {}
# sort order is ascending because we want to get the oldest documents first
existing_docs: list[DocumentRow] = fetch_all_existing_docs_fn(
sort_order=SortOrder.ASC
)
logger.info(f"Found {len(existing_docs)} documents to check")
for doc in existing_docs:
try:
doc_metadata = DocMetadata.model_validate_json(json.dumps(doc.doc_metadata))
if doc_metadata.repo not in repo_to_doc_list_map:
repo_to_doc_list_map[doc_metadata.repo] = []
repo_to_doc_list_map[doc_metadata.repo].append(doc)
except Exception as e:
logger.error(f"Failed to parse doc metadata: {e} for doc {doc.id}")
continue
logger.info(f"Found {len(repo_to_doc_list_map)} documents to check")
# Process each repository individually
for repo in repos:
try:
logger.info(f"Processing repository: {repo.id} (name: {repo.name})")
repo_doc_list: list[DocumentRow] = repo_to_doc_list_map.get(
repo.full_name, []
)
if not repo_doc_list:
logger.warning(
f"No documents found for repository {repo.id} ({repo.name})"
)
continue
current_external_group_ids = repo_doc_list[0].external_user_group_ids or []
# Check if repository has any permission changes
has_changes = _check_repository_for_changes(
repo=repo,
github_client=github_connector.github_client,
current_external_group_ids=current_external_group_ids,
)
if has_changes:
logger.info(
f"Repository {repo.id} ({repo.name}) has changes, updating documents"
)
# Get new external access permissions for this repository
new_external_access = get_external_access_permission(
repo, github_connector.github_client
)
logger.info(
f"Found {len(repo_doc_list)} documents for repository {repo.full_name}"
)
# Yield updated external access for each document
for doc in repo_doc_list:
if callback:
callback.progress(GITHUB_DOC_SYNC_LABEL, 1)
yield DocExternalAccess(
doc_id=doc.id,
external_access=new_external_access,
)
else:
logger.info(
f"Repository {repo.id} ({repo.name}) has no changes, skipping"
)
except Exception as e:
logger.error(f"Error processing repository {repo.id} ({repo.name}): {e}")
logger.info(f"GitHub document sync completed for CC pair ID: {cc_pair.id}")
def _check_repository_for_changes(
repo: Repository,
github_client: Github,
current_external_group_ids: list[str],
) -> bool:
"""
Check if repository has any permission changes (visibility or team updates).
"""
logger.info(f"Checking repository {repo.id} ({repo.name}) for changes")
# Check for repository visibility changes using the sample document data
if _is_repo_visibility_changed_from_groups(
repo=repo,
current_external_group_ids=current_external_group_ids,
):
logger.info(f"Repository {repo.id} ({repo.name}) has visibility changes")
return True
# Check for team membership changes if repository is private
if get_repository_visibility(
repo
) == GitHubVisibility.PRIVATE and _teams_updated_from_groups(
repo=repo,
github_client=github_client,
current_external_group_ids=current_external_group_ids,
):
logger.info(f"Repository {repo.id} ({repo.name}) has team changes")
return True
logger.info(f"Repository {repo.id} ({repo.name}) has no changes")
return False
def _is_repo_visibility_changed_from_groups(
repo: Repository,
current_external_group_ids: list[str],
) -> bool:
"""
Check if repository visibility has changed by analyzing existing external group IDs.
Args:
repo: GitHub repository object
current_external_group_ids: List of external group IDs from existing document
Returns:
True if visibility has changed
"""
current_repo_visibility = get_repository_visibility(repo)
logger.info(f"Current repository visibility: {current_repo_visibility.value}")
# Build expected group IDs for current visibility
collaborators_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=form_collaborators_group_id(repo.id),
)
org_group_id = None
if repo.organization:
org_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=form_organization_group_id(repo.organization.id),
)
# Determine existing visibility from group IDs
has_collaborators_group = collaborators_group_id in current_external_group_ids
has_org_group = org_group_id and org_group_id in current_external_group_ids
if has_collaborators_group:
existing_repo_visibility = GitHubVisibility.PRIVATE
elif has_org_group:
existing_repo_visibility = GitHubVisibility.INTERNAL
else:
existing_repo_visibility = GitHubVisibility.PUBLIC
logger.info(f"Inferred existing visibility: {existing_repo_visibility.value}")
visibility_changed = existing_repo_visibility != current_repo_visibility
if visibility_changed:
logger.info(
f"Visibility changed for repo {repo.id} ({repo.name}): "
f"{existing_repo_visibility.value} -> {current_repo_visibility.value}"
)
return visibility_changed
def _teams_updated_from_groups(
repo: Repository,
github_client: Github,
current_external_group_ids: list[str],
) -> bool:
"""
Check if repository team memberships have changed using existing group IDs.
"""
# Fetch current team slugs for the repository
current_teams = fetch_repository_team_slugs(repo=repo, github_client=github_client)
logger.info(
f"Current teams for repository {repo.id} (name: {repo.name}): {current_teams}"
)
# Build group IDs to exclude from team comparison (non-team groups)
collaborators_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=form_collaborators_group_id(repo.id),
)
outside_collaborators_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=form_outside_collaborators_group_id(repo.id),
)
non_team_group_ids = {collaborators_group_id, outside_collaborators_group_id}
# Extract existing team IDs from current external group IDs
existing_team_ids = set()
for group_id in current_external_group_ids:
# Skip all non-team groups, keep only team groups
if group_id not in non_team_group_ids:
existing_team_ids.add(group_id)
# Note: existing_team_ids from DB are already prefixed (e.g., "github__team-slug")
# but current_teams from API are raw team slugs, so we need to add the prefix
current_team_ids = set()
for team_slug in current_teams:
team_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=team_slug,
)
current_team_ids.add(team_group_id)
logger.info(
f"Existing team IDs: {existing_team_ids}, Current team IDs: {current_team_ids}"
)
# Compare actual team IDs to detect changes
teams_changed = current_team_ids != existing_team_ids
if teams_changed:
logger.info(
f"Team changes detected for repo {repo.id} (name: {repo.name}): "
f"existing={existing_team_ids}, current={current_team_ids}"
)
return teams_changed

View File

@@ -1,46 +0,0 @@
from collections.abc import Generator
from github import Repository
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.github.utils import get_external_user_group
from onyx.connectors.github.connector import GithubConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
logger = setup_logger()
def github_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> Generator[ExternalUserGroup, None, None]:
github_connector: GithubConnector = GithubConnector(
**cc_pair.connector.connector_specific_config
)
github_connector.load_credentials(cc_pair.credential.credential_json)
if not github_connector.github_client:
raise ValueError("github_client is required")
logger.info("Starting GitHub group sync...")
repos: list[Repository.Repository] = []
if github_connector.repositories:
if "," in github_connector.repositories:
# Multiple repositories specified
repos = github_connector.get_github_repos(github_connector.github_client)
else:
# Single repository (backward compatibility)
repos = [github_connector.get_github_repo(github_connector.github_client)]
else:
# All repositories
repos = github_connector.get_all_repos(github_connector.github_client)
for repo in repos:
try:
for external_group in get_external_user_group(
repo, github_connector.github_client
):
logger.info(f"External group: {external_group}")
yield external_group
except Exception as e:
logger.error(f"Error processing repository {repo.id} ({repo.name}): {e}")

View File

@@ -1,488 +0,0 @@
from collections.abc import Callable
from enum import Enum
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypeVar
from github import Github
from github import RateLimitExceededException
from github.GithubException import GithubException
from github.NamedUser import NamedUser
from github.Organization import Organization
from github.PaginatedList import PaginatedList
from github.Repository import Repository
from github.Team import Team
from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.github.rate_limit_utils import sleep_after_rate_limit_exception
from onyx.utils.logger import setup_logger
logger = setup_logger()
class GitHubVisibility(Enum):
"""GitHub repository visibility options."""
PUBLIC = "public"
PRIVATE = "private"
INTERNAL = "internal"
MAX_RETRY_COUNT = 3
T = TypeVar("T")
# Higher-order function to wrap GitHub operations with retry and exception handling
def _run_with_retry(
operation: Callable[[], T],
description: str,
github_client: Github,
retry_count: int = 0,
) -> Optional[T]:
"""Execute a GitHub operation with retry on rate limit and exception handling."""
logger.debug(f"Starting operation '{description}', attempt {retry_count + 1}")
try:
result = operation()
logger.debug(f"Operation '{description}' completed successfully")
return result
except RateLimitExceededException:
if retry_count < MAX_RETRY_COUNT:
sleep_after_rate_limit_exception(github_client)
logger.warning(
f"Rate limit exceeded while {description}. Retrying... "
f"(attempt {retry_count + 1}/{MAX_RETRY_COUNT})"
)
return _run_with_retry(
operation, description, github_client, retry_count + 1
)
else:
error_msg = f"Max retries exceeded for {description}"
logger.exception(error_msg)
raise RuntimeError(error_msg)
except GithubException as e:
logger.warning(f"GitHub API error during {description}: {e}")
return None
except Exception as e:
logger.exception(f"Unexpected error during {description}: {e}")
return None
class UserInfo(BaseModel):
"""Represents a GitHub user with their basic information."""
login: str
name: Optional[str] = None
email: Optional[str] = None
class TeamInfo(BaseModel):
"""Represents a GitHub team with its members."""
name: str
slug: str
members: List[UserInfo]
def _fetch_organization_members(
github_client: Github, org_name: str, retry_count: int = 0
) -> List[UserInfo]:
"""Fetch all organization members including owners and regular members."""
org_members: List[UserInfo] = []
logger.info(f"Fetching organization members for {org_name}")
org = _run_with_retry(
lambda: github_client.get_organization(org_name),
f"get organization {org_name}",
github_client,
)
if not org:
logger.error(f"Failed to fetch organization {org_name}")
raise RuntimeError(f"Failed to fetch organization {org_name}")
member_objs: PaginatedList[NamedUser] | list[NamedUser] = (
_run_with_retry(
lambda: org.get_members(filter_="all"),
f"get members for organization {org_name}",
github_client,
)
or []
)
for member in member_objs:
user_info = UserInfo(login=member.login, name=member.name, email=member.email)
org_members.append(user_info)
logger.info(f"Fetched {len(org_members)} members for organization {org_name}")
return org_members
def _fetch_repository_teams_detailed(
repo: Repository, github_client: Github, retry_count: int = 0
) -> List[TeamInfo]:
"""Fetch teams with access to the repository and their members."""
teams_data: List[TeamInfo] = []
logger.info(f"Fetching teams for repository {repo.full_name}")
team_objs: PaginatedList[Team] | list[Team] = (
_run_with_retry(
lambda: repo.get_teams(),
f"get teams for repository {repo.full_name}",
github_client,
)
or []
)
for team in team_objs:
logger.info(
f"Processing team {team.name} (slug: {team.slug}) for repository {repo.full_name}"
)
members: PaginatedList[NamedUser] | list[NamedUser] = (
_run_with_retry(
lambda: team.get_members(),
f"get members for team {team.name}",
github_client,
)
or []
)
team_members = []
for m in members:
user_info = UserInfo(login=m.login, name=m.name, email=m.email)
team_members.append(user_info)
team_info = TeamInfo(name=team.name, slug=team.slug, members=team_members)
teams_data.append(team_info)
logger.info(f"Team {team.name} has {len(team_members)} members")
logger.info(f"Fetched {len(teams_data)} teams for repository {repo.full_name}")
return teams_data
def fetch_repository_team_slugs(
repo: Repository, github_client: Github, retry_count: int = 0
) -> List[str]:
"""Fetch team slugs with access to the repository."""
logger.info(f"Fetching team slugs for repository {repo.full_name}")
teams_data: List[str] = []
team_objs: PaginatedList[Team] | list[Team] = (
_run_with_retry(
lambda: repo.get_teams(),
f"get teams for repository {repo.full_name}",
github_client,
)
or []
)
for team in team_objs:
teams_data.append(team.slug)
logger.info(f"Fetched {len(teams_data)} team slugs for repository {repo.full_name}")
return teams_data
def _get_collaborators_and_outside_collaborators(
github_client: Github,
repo: Repository,
) -> Tuple[List[UserInfo], List[UserInfo]]:
"""Fetch and categorize collaborators into regular and outside collaborators."""
collaborators: List[UserInfo] = []
outside_collaborators: List[UserInfo] = []
logger.info(f"Fetching collaborators for repository {repo.full_name}")
repo_collaborators: PaginatedList[NamedUser] | list[NamedUser] = (
_run_with_retry(
lambda: repo.get_collaborators(),
f"get collaborators for repository {repo.full_name}",
github_client,
)
or []
)
for collaborator in repo_collaborators:
is_outside = False
# Check if collaborator is outside the organization
if repo.organization:
org: Organization | None = _run_with_retry(
lambda: github_client.get_organization(repo.organization.login),
f"get organization {repo.organization.login}",
github_client,
)
if org is not None:
org_obj = org
membership = _run_with_retry(
lambda: org_obj.has_in_members(collaborator),
f"check membership for {collaborator.login} in org {org_obj.login}",
github_client,
)
is_outside = membership is not None and not membership
info = UserInfo(
login=collaborator.login, name=collaborator.name, email=collaborator.email
)
if repo.organization and is_outside:
outside_collaborators.append(info)
else:
collaborators.append(info)
logger.info(
f"Categorized {len(collaborators)} regular and {len(outside_collaborators)} outside collaborators for {repo.full_name}"
)
return collaborators, outside_collaborators
def form_collaborators_group_id(repository_id: int) -> str:
"""Generate group ID for repository collaborators."""
if not repository_id:
logger.exception("Repository ID is required to generate collaborators group ID")
raise ValueError("Repository ID must be set to generate group ID.")
group_id = f"{repository_id}_collaborators"
return group_id
def form_organization_group_id(organization_id: int) -> str:
"""Generate group ID for organization using organization ID."""
if not organization_id:
logger.exception(
"Organization ID is required to generate organization group ID"
)
raise ValueError("Organization ID must be set to generate group ID.")
group_id = f"{organization_id}_organization"
return group_id
def form_outside_collaborators_group_id(repository_id: int) -> str:
"""Generate group ID for outside collaborators."""
if not repository_id:
logger.exception(
"Repository ID is required to generate outside collaborators group ID"
)
raise ValueError("Repository ID must be set to generate group ID.")
group_id = f"{repository_id}_outside_collaborators"
return group_id
def get_repository_visibility(repo: Repository) -> GitHubVisibility:
"""
Get the visibility of a repository.
Returns GitHubVisibility enum member.
"""
if hasattr(repo, "visibility"):
visibility = repo.visibility
logger.info(
f"Repository {repo.full_name} visibility from attribute: {visibility}"
)
try:
return GitHubVisibility(visibility)
except ValueError:
logger.warning(
f"Unknown visibility '{visibility}' for repo {repo.full_name}, defaulting to private"
)
return GitHubVisibility.PRIVATE
logger.info(f"Repository {repo.full_name} is private")
return GitHubVisibility.PRIVATE
def get_external_access_permission(
repo: Repository, github_client: Github, add_prefix: bool = False
) -> ExternalAccess:
"""
Get the external access permission for a repository.
Uses group-based permissions for efficiency and scalability.
add_prefix: When this method is called during the initial permission sync via the connector,
the group ID isn't prefixed with the source while inserting the document record.
So in that case, set add_prefix to True, allowing the method itself to handle
prefixing. However, when the same method is invoked from doc_sync, our system
already adds the prefix to the group ID while processing the ExternalAccess object.
"""
# We maintain collaborators, and outside collaborators as two separate groups
# instead of adding individual user emails to ExternalAccess.external_user_emails for two reasons:
# 1. Changes in repo collaborators (additions/removals) would require updating all documents.
# 2. Repo permissions can change without updating the repo's updated_at timestamp,
# forcing full permission syncs for all documents every time, which is inefficient.
repo_visibility = get_repository_visibility(repo)
logger.info(
f"Generating ExternalAccess for {repo.full_name}: visibility={repo_visibility.value}"
)
if repo_visibility == GitHubVisibility.PUBLIC:
logger.info(
f"Repository {repo.full_name} is public - allowing access to all users"
)
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=True,
)
elif repo_visibility == GitHubVisibility.PRIVATE:
logger.info(
f"Repository {repo.full_name} is private - setting up restricted access"
)
collaborators_group_id = form_collaborators_group_id(repo.id)
outside_collaborators_group_id = form_outside_collaborators_group_id(repo.id)
if add_prefix:
collaborators_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=collaborators_group_id,
)
outside_collaborators_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=outside_collaborators_group_id,
)
group_ids = {collaborators_group_id, outside_collaborators_group_id}
team_slugs = fetch_repository_team_slugs(repo, github_client)
if add_prefix:
team_slugs = [
build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=slug,
)
for slug in team_slugs
]
group_ids.update(team_slugs)
logger.info(f"ExternalAccess groups for {repo.full_name}: {group_ids}")
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=group_ids,
is_public=False,
)
else:
# Internal repositories - accessible to organization members
logger.info(
f"Repository {repo.full_name} is internal - accessible to org members"
)
org_group_id = form_organization_group_id(repo.organization.id)
if add_prefix:
org_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=org_group_id,
)
group_ids = {org_group_id}
logger.info(f"ExternalAccess groups for {repo.full_name}: {group_ids}")
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=group_ids,
is_public=False,
)
def get_external_user_group(
repo: Repository, github_client: Github
) -> list[ExternalUserGroup]:
"""
Get the external user group for a repository.
Creates ExternalUserGroup objects with actual user emails for each permission group.
"""
repo_visibility = get_repository_visibility(repo)
logger.info(
f"Generating ExternalUserGroups for {repo.full_name}: visibility={repo_visibility.value}"
)
if repo_visibility == GitHubVisibility.PRIVATE:
logger.info(f"Processing private repository {repo.full_name}")
collaborators, outside_collaborators = (
_get_collaborators_and_outside_collaborators(github_client, repo)
)
teams = _fetch_repository_teams_detailed(repo, github_client)
external_user_groups = []
user_emails = set()
for collab in collaborators:
if collab.email:
user_emails.add(collab.email)
else:
logger.error(f"Collaborator {collab.login} has no email")
if user_emails:
collaborators_group = ExternalUserGroup(
id=form_collaborators_group_id(repo.id),
user_emails=list(user_emails),
)
external_user_groups.append(collaborators_group)
logger.info(f"Created collaborators group with {len(user_emails)} emails")
# Create group for outside collaborators
user_emails = set()
for collab in outside_collaborators:
if collab.email:
user_emails.add(collab.email)
else:
logger.error(f"Outside collaborator {collab.login} has no email")
if user_emails:
outside_collaborators_group = ExternalUserGroup(
id=form_outside_collaborators_group_id(repo.id),
user_emails=list(user_emails),
)
external_user_groups.append(outside_collaborators_group)
logger.info(
f"Created outside collaborators group with {len(user_emails)} emails"
)
# Create groups for teams
for team in teams:
user_emails = set()
for member in team.members:
if member.email:
user_emails.add(member.email)
else:
logger.error(f"Team member {member.login} has no email")
if user_emails:
team_group = ExternalUserGroup(
id=team.slug,
user_emails=list(user_emails),
)
external_user_groups.append(team_group)
logger.info(
f"Created team group {team.name} with {len(user_emails)} emails"
)
logger.info(
f"Created {len(external_user_groups)} ExternalUserGroups for private repository {repo.full_name}"
)
return external_user_groups
if repo_visibility == GitHubVisibility.INTERNAL:
logger.info(f"Processing internal repository {repo.full_name}")
org_group_id = form_organization_group_id(repo.organization.id)
org_members = _fetch_organization_members(
github_client, repo.organization.login
)
user_emails = set()
for member in org_members:
if member.email:
user_emails.add(member.email)
else:
logger.error(f"Org member {member.login} has no email")
org_group = ExternalUserGroup(
id=org_group_id,
user_emails=list(user_emails),
)
logger.info(
f"Created organization group with {len(user_emails)} emails for internal repository {repo.full_name}"
)
return [org_group]
logger.info(f"Repository {repo.full_name} is public - no user groups needed")
return []

View File

@@ -3,7 +3,6 @@ from datetime import datetime
from datetime import timezone
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from onyx.access.models import DocExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
@@ -36,7 +35,6 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
"""

View File

@@ -8,7 +8,6 @@ from ee.onyx.external_permissions.google_drive.permission_retrieval import (
get_permissions_by_ids,
)
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.google_drive.connector import GoogleDriveConnector
@@ -41,28 +40,8 @@ def _get_slim_doc_generator(
)
def _merge_permissions_lists(
permission_lists: list[list[GoogleDrivePermission]],
) -> list[GoogleDrivePermission]:
"""
Merge a list of permission lists into a single list of permissions.
"""
seen_permission_ids: set[str] = set()
merged_permissions: list[GoogleDrivePermission] = []
for permission_list in permission_lists:
for permission in permission_list:
if permission.id not in seen_permission_ids:
merged_permissions.append(permission)
seen_permission_ids.add(permission.id)
return merged_permissions
def get_external_access_for_raw_gdrive_file(
file: GoogleDriveFileType,
company_domain: str,
retriever_drive_service: GoogleDriveService | None,
admin_drive_service: GoogleDriveService,
file: GoogleDriveFileType, company_domain: str, drive_service: GoogleDriveService
) -> ExternalAccess:
"""
Get the external access for a raw Google Drive file.
@@ -83,28 +62,11 @@ def get_external_access_for_raw_gdrive_file(
GoogleDrivePermission.from_drive_permission(p) for p in permissions
]
elif permission_ids:
def _get_permissions(
drive_service: GoogleDriveService,
) -> list[GoogleDrivePermission]:
return get_permissions_by_ids(
drive_service=drive_service,
doc_id=doc_id,
permission_ids=permission_ids,
)
permissions_list = _get_permissions(
retriever_drive_service or admin_drive_service
permissions_list = get_permissions_by_ids(
drive_service=drive_service,
doc_id=doc_id,
permission_ids=permission_ids,
)
if len(permissions_list) != len(permission_ids) and retriever_drive_service:
logger.warning(
f"Failed to get all permissions for file {doc_id} with retriever service, "
"trying admin service"
)
backup_permissions_list = _get_permissions(admin_drive_service)
permissions_list = _merge_permissions_lists(
[permissions_list, backup_permissions_list]
)
folder_ids_to_inherit_permissions_from: set[str] = set()
user_emails: set[str] = set()
@@ -170,7 +132,6 @@ def get_external_access_for_raw_gdrive_file(
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
"""

View File

@@ -1,5 +1,3 @@
from collections.abc import Generator
from googleapiclient.errors import HttpError # type: ignore
from pydantic import BaseModel
@@ -44,17 +42,11 @@ def _get_all_folders(
TODO: tweak things so we can fetch deltas.
"""
MAX_FAILED_PERCENTAGE = 0.5
all_folders: list[FolderInfo] = []
seen_folder_ids: set[str] = set()
def _get_all_folders_for_user(
google_drive_connector: GoogleDriveConnector,
skip_folders_without_permissions: bool,
user_email: str,
) -> None:
"""Helper to get folders for a specific user + update shared seen_folder_ids"""
user_emails = google_drive_connector._get_all_user_emails()
for user_email in user_emails:
drive_service = get_drive_service(
google_drive_connector.creds,
user_email,
@@ -104,61 +96,9 @@ def _get_all_folders(
)
)
failed_count = 0
user_emails = google_drive_connector._get_all_user_emails()
for user_email in user_emails:
try:
_get_all_folders_for_user(
google_drive_connector, skip_folders_without_permissions, user_email
)
except Exception:
logger.exception(f"Error getting folders for user {user_email}")
failed_count += 1
if failed_count > MAX_FAILED_PERCENTAGE * len(user_emails):
raise RuntimeError("Too many failed folder fetches during group sync")
return all_folders
def _drive_folder_to_onyx_group(
folder: FolderInfo,
group_email_to_member_emails_map: dict[str, list[str]],
) -> ExternalUserGroup:
"""
Converts a folder into an Onyx group.
"""
anyone_can_access = False
folder_member_emails: set[str] = set()
for permission in folder.permissions:
if permission.type == PermissionType.USER:
if permission.email_address is None:
logger.warning(
f"User email is None for folder {folder.id} permission {permission}"
)
continue
folder_member_emails.add(permission.email_address)
elif permission.type == PermissionType.GROUP:
if permission.email_address not in group_email_to_member_emails_map:
logger.warning(
f"Group email {permission.email_address} for folder {folder.id} "
"not found in group_email_to_member_emails_map"
)
continue
folder_member_emails.update(
group_email_to_member_emails_map[permission.email_address]
)
elif permission.type == PermissionType.ANYONE:
anyone_can_access = True
return ExternalUserGroup(
id=folder.id,
user_emails=list(folder_member_emails),
gives_anyone_access=anyone_can_access,
)
"""Individual Shared Drive / My Drive Permission Sync"""
@@ -227,29 +167,7 @@ def _get_drive_members(
return drive_id_to_members_map
def _drive_member_map_to_onyx_groups(
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]],
group_email_to_member_emails_map: dict[str, list[str]],
) -> Generator[ExternalUserGroup, None, None]:
"""The `user_emails` for the Shared Drive should be all individuals in the
Shared Drive + the union of all flattened group emails."""
for drive_id, (group_emails, user_emails) in drive_id_to_members_map.items():
drive_member_emails: set[str] = user_emails
for group_email in group_emails:
if group_email not in group_email_to_member_emails_map:
logger.warning(
f"Group email {group_email} for drive {drive_id} not found in "
"group_email_to_member_emails_map"
)
continue
drive_member_emails.update(group_email_to_member_emails_map[group_email])
yield ExternalUserGroup(
id=drive_id,
user_emails=list(drive_member_emails),
)
def _get_all_google_groups(
def _get_all_groups(
admin_service: AdminService,
google_domain: str,
) -> set[str]:
@@ -267,28 +185,6 @@ def _get_all_google_groups(
return group_emails
def _google_group_to_onyx_group(
admin_service: AdminService,
group_email: str,
) -> ExternalUserGroup:
"""
This maps google group emails to their member emails.
"""
group_member_emails: set[str] = set()
for member in execute_paginated_retrieval(
admin_service.members().list,
list_key="members",
groupKey=group_email,
fields="members(email),nextPageToken",
):
group_member_emails.add(member["email"])
return ExternalUserGroup(
id=group_email,
user_emails=list(group_member_emails),
)
def _map_group_email_to_member_emails(
admin_service: AdminService,
group_emails: set[str],
@@ -386,7 +282,7 @@ def _build_onyx_groups(
def gdrive_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> Generator[ExternalUserGroup, None, None]:
) -> list[ExternalUserGroup]:
# Initialize connector and build credential/service objects
google_drive_connector = GoogleDriveConnector(
**cc_pair.connector.connector_specific_config
@@ -400,27 +296,26 @@ def gdrive_group_sync(
drive_id_to_members_map = _get_drive_members(google_drive_connector, admin_service)
# Get all group emails
all_group_emails = _get_all_google_groups(
all_group_emails = _get_all_groups(
admin_service, google_drive_connector.google_domain
)
# Each google group is an Onyx group, yield those
group_email_to_member_emails_map: dict[str, list[str]] = {}
for group_email in all_group_emails:
onyx_group = _google_group_to_onyx_group(admin_service, group_email)
group_email_to_member_emails_map[group_email] = onyx_group.user_emails
yield onyx_group
# Each drive is a group, yield those
for onyx_group in _drive_member_map_to_onyx_groups(
drive_id_to_members_map, group_email_to_member_emails_map
):
yield onyx_group
# Get all folder permissions
folder_info = _get_all_folders(
google_drive_connector=google_drive_connector,
skip_folders_without_permissions=True,
)
for folder in folder_info:
yield _drive_folder_to_onyx_group(folder, group_email_to_member_emails_map)
# Map group emails to their members
group_email_to_member_emails_map = _map_group_email_to_member_emails(
admin_service, all_group_emails
)
# Convert the maps to onyx groups
onyx_groups = _build_onyx_groups(
drive_id_to_members_map=drive_id_to_members_map,
group_email_to_member_emails_map=group_email_to_member_emails_map,
folder_info=folder_info,
)
return onyx_groups

View File

@@ -1,36 +0,0 @@
from collections.abc import Generator
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.jira.connector import JiraConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
JIRA_DOC_SYNC_TAG = "jira_doc_sync"
def jira_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None = None,
) -> Generator[DocExternalAccess, None, None]:
jira_connector = JiraConnector(
**cc_pair.connector.connector_specific_config,
)
jira_connector.load_credentials(cc_pair.credential.credential_json)
yield from generic_doc_sync(
cc_pair=cc_pair,
fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
callback=callback,
doc_source=DocumentSource.JIRA,
slim_connector=jira_connector,
label=JIRA_DOC_SYNC_TAG,
)

View File

@@ -1,25 +0,0 @@
from typing import Any
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic.alias_generators import to_camel
Holder = dict[str, Any]
class Permission(BaseModel):
id: int
permission: str
holder: Holder | None
class User(BaseModel):
account_id: str
email_address: str
display_name: str
active: bool
model_config = ConfigDict(
alias_generator=to_camel,
)

View File

@@ -1,209 +0,0 @@
from collections import defaultdict
from jira import JIRA
from jira.resources import PermissionScheme
from pydantic import ValidationError
from ee.onyx.external_permissions.jira.models import Holder
from ee.onyx.external_permissions.jira.models import Permission
from ee.onyx.external_permissions.jira.models import User
from onyx.access.models import ExternalAccess
from onyx.utils.logger import setup_logger
HolderMap = dict[str, list[Holder]]
logger = setup_logger()
def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
"""
A "Holder" in JIRA is a person / entity who "holds" the corresponding permission.
It can have different types. They can be one of (but not limited to):
- user (an explicitly whitelisted user)
- projectRole (for project level "roles")
- reporter (the reporter of an issue)
A "Holder" usually has following structure:
- `{ "type": "user", "value": "$USER_ID", "user": { .. }, .. }`
- `{ "type": "projectRole", "value": "$PROJECT_ID", .. }`
When we fetch the PermissionSchema from JIRA, we retrieve a list of "Holder"s.
The list of "Holder"s can have multiple "Holder"s of the same type in the list (e.g., you can have two `"type": "user"`s in
there, each corresponding to a different user).
This function constructs a map of "Holder" types to a list of the "Holder"s which contained that type.
Returns:
A dict from the "Holder" type to the actual "Holder" instance.
Example:
```
{
"user": [
{ "type": "user", "value": "10000", "user": { .. }, .. },
{ "type": "user", "value": "10001", "user": { .. }, .. },
],
"projectRole": [
{ "type": "projectRole", "value": "10010", .. },
{ "type": "projectRole", "value": "10011", .. },
],
"applicationRole": [
{ "type": "applicationRole" },
],
..
}
```
"""
holder_map: defaultdict[str, list[Holder]] = defaultdict(list)
for raw_perm in permissions:
if not hasattr(raw_perm, "raw"):
logger.warn(f"Expected a 'raw' field, but none was found: {raw_perm=}")
continue
permission = Permission(**raw_perm.raw)
# We only care about ability to browse through projects + issues (not other permissions such as read/write).
if permission.permission != "BROWSE_PROJECTS":
continue
# In order to associate this permission to some Atlassian entity, we need the "Holder".
# If this doesn't exist, then we cannot associate this permission to anyone; just skip.
if not permission.holder:
logger.warn(
f"Expected to find a permission holder, but none was found: {permission=}"
)
continue
type = permission.holder.get("type")
if not type:
logger.warn(
f"Expected to find the type of permission holder, but none was found: {permission=}"
)
continue
holder_map[type].append(permission.holder)
return holder_map
def _get_user_emails(user_holders: list[Holder]) -> list[str]:
emails = []
for user_holder in user_holders:
if "user" not in user_holder:
continue
raw_user_dict = user_holder["user"]
try:
user_model = User.model_validate(raw_user_dict)
except ValidationError:
logger.error(
"Expected to be able to serialize the raw-user-dict into an instance of `User`, but validation failed;"
f"{raw_user_dict=}"
)
continue
emails.append(user_model.email_address)
return emails
def _get_user_emails_from_project_roles(
jira_client: JIRA,
jira_project: str,
project_role_holders: list[Holder],
) -> list[str]:
# NOTE (@raunakab) a `parallel_yield` may be helpful here...?
roles = [
jira_client.project_role(project=jira_project, id=project_role_holder["value"])
for project_role_holder in project_role_holders
if "value" in project_role_holder
]
emails = []
for role in roles:
if not hasattr(role, "actors"):
continue
for actor in role.actors:
if not hasattr(actor, "actorUser") or not hasattr(
actor.actorUser, "accountId"
):
continue
user = jira_client.user(id=actor.actorUser.accountId)
if not hasattr(user, "accountType") or user.accountType != "atlassian":
continue
if not hasattr(user, "emailAddress"):
msg = f"User's email address was not able to be retrieved; {actor.actorUser.accountId=}"
if hasattr(user, "displayName"):
msg += f" {actor.displayName=}"
logger.warn(msg)
continue
emails.append(user.emailAddress)
return emails
def _build_external_access_from_holder_map(
jira_client: JIRA, jira_project: str, holder_map: HolderMap
) -> ExternalAccess:
"""
# Note:
If the `holder_map` contains an instance of "anyone", then this is a public JIRA project.
Otherwise, we fetch the "projectRole"s (i.e., the user-groups in JIRA speak), and the user emails.
"""
if "anyone" in holder_map:
return ExternalAccess(
external_user_emails=set(), external_user_group_ids=set(), is_public=True
)
user_emails = (
_get_user_emails(user_holders=holder_map["user"])
if "user" in holder_map
else []
)
project_role_user_emails = (
_get_user_emails_from_project_roles(
jira_client=jira_client,
jira_project=jira_project,
project_role_holders=holder_map["projectRole"],
)
if "projectRole" in holder_map
else []
)
external_user_emails = set(user_emails + project_role_user_emails)
return ExternalAccess(
external_user_emails=external_user_emails,
external_user_group_ids=set(),
is_public=False,
)
def get_project_permissions(
jira_client: JIRA,
jira_project: str,
) -> ExternalAccess | None:
project_permissions: PermissionScheme = jira_client.project_permissionscheme(
project=jira_project
)
if not hasattr(project_permissions, "permissions"):
return None
if not isinstance(project_permissions.permissions, list):
return None
holder_map = _build_holder_map(permissions=project_permissions.permissions)
return _build_external_access_from_holder_map(
jira_client=jira_client, jira_project=jira_project, holder_map=holder_map
)

View File

@@ -5,8 +5,6 @@ from typing import Protocol
from typing import TYPE_CHECKING
from onyx.context.search.models import InferenceChunk
from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder
# Avoid circular imports
if TYPE_CHECKING:
@@ -17,34 +15,14 @@ if TYPE_CHECKING:
class FetchAllDocumentsFunction(Protocol):
"""Protocol for a function that fetches documents for a connector credential pair.
"""Protocol for a function that fetches all document IDs for a connector credential pair."""
This protocol defines the interface for functions that retrieve documents
from the database, typically used in permission synchronization workflows.
"""
def __call__(
self,
sort_order: SortOrder | None,
) -> list[DocumentRow]:
def __call__(self) -> list[str]:
"""
Fetches documents for a connector credential pair.
"""
...
Returns a list of document IDs for a connector credential pair.
class FetchAllDocumentsIdsFunction(Protocol):
"""Protocol for a function that fetches document IDs for a connector credential pair.
This protocol defines the interface for functions that retrieve document IDs
from the database, typically used in permission synchronization workflows.
"""
def __call__(
self,
) -> list[str]:
"""
Fetches document IDs for a connector credential pair.
This is typically used to determine which documents should no longer be
accessible during the document sync process.
"""
...
@@ -54,7 +32,6 @@ DocSyncFuncType = Callable[
[
"ConnectorCredentialPair",
FetchAllDocumentsFunction,
FetchAllDocumentsIdsFunction,
Optional["IndexingHeartbeatInterface"],
],
Generator["DocExternalAccess", None, None],
@@ -62,10 +39,10 @@ DocSyncFuncType = Callable[
GroupSyncFuncType = Callable[
[
str, # tenant_id
"ConnectorCredentialPair", # cc_pair
str,
"ConnectorCredentialPair",
],
Generator["ExternalUserGroup", None, None],
list["ExternalUserGroup"],
]
# list of chunks to be censored and the user email. returns censored chunks

View File

@@ -3,7 +3,7 @@ from ee.onyx.external_permissions.sync_params import get_all_censoring_enabled_s
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
from onyx.configs.constants import DocumentSource
from onyx.context.search.pipeline import InferenceChunk
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.db.models import User
from onyx.utils.logger import setup_logger
@@ -22,7 +22,7 @@ def _get_all_censoring_enabled_sources() -> set[DocumentSource]:
for every single chunk.
"""
all_censoring_enabled_sources = get_all_censoring_enabled_sources()
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
enabled_sync_connectors = get_all_auto_sync_cc_pairs(db_session)
return {
cc_pair.connector.source

View File

@@ -10,7 +10,7 @@ from ee.onyx.external_permissions.salesforce.utils import (
)
from onyx.configs.app_configs import BLURB_SIZE
from onyx.context.search.models import InferenceChunk
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -44,7 +44,7 @@ def _get_objects_access_for_user_email_from_salesforce(
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
# but subsequent queries for this source are essentially instant
first_doc_id = chunks[0].document_id
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
salesforce_client = get_any_salesforce_client_for_doc_id(
db_session, first_doc_id
)
@@ -217,7 +217,7 @@ def censor_salesforce_chunks(
def _get_objects_access_for_user_email(
object_ids: set[str], user_email: str
) -> dict[str, bool]:
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
external_groups = fetch_external_groups_for_user_email_and_group_ids(
db_session=db_session,
user_email=user_email,

View File

@@ -3,7 +3,6 @@ from collections.abc import Generator
from slack_sdk import WebClient
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
@@ -131,7 +130,6 @@ def _get_slack_document_access(
def slack_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
"""

View File

@@ -7,30 +7,21 @@ from pydantic import BaseModel
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GITHUB_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import JIRA_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import TEAMS_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
from ee.onyx.external_permissions.github.doc_sync import github_doc_sync
from ee.onyx.external_permissions.github.group_sync import github_group_sync
from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync
from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
from ee.onyx.external_permissions.jira.doc_sync import jira_doc_sync
from ee.onyx.external_permissions.perm_sync_types import CensoringFuncType
from ee.onyx.external_permissions.perm_sync_types import DocSyncFuncType
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.perm_sync_types import GroupSyncFuncType
from ee.onyx.external_permissions.salesforce.postprocessing import (
censor_salesforce_chunks,
)
from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from ee.onyx.external_permissions.teams.doc_sync import teams_doc_sync
from onyx.configs.constants import DocumentSource
if TYPE_CHECKING:
@@ -68,7 +59,6 @@ class SyncConfig(BaseModel):
def mock_doc_sync(
cc_pair: "ConnectorCredentialPair",
fetch_all_docs_fn: FetchAllDocumentsFunction,
fetch_all_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: Optional["IndexingHeartbeatInterface"],
) -> Generator["DocExternalAccess", None, None]:
"""Mock doc sync function for testing - returns empty list since permissions are fetched during indexing"""
@@ -100,21 +90,15 @@ _SOURCE_TO_SYNC_CONFIG: dict[DocumentSource, SyncConfig] = {
group_sync_is_cc_pair_agnostic=True,
),
),
DocumentSource.JIRA: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=JIRA_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=jira_doc_sync,
initial_index_should_sync=True,
),
),
# Groups are not needed for Slack.
# All channel access is done at the individual user level.
DocumentSource.SLACK: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=slack_doc_sync,
initial_index_should_sync=True,
),
# groups are not needed for Slack. All channel access is done at the
# individual user level
group_sync_config=None,
),
DocumentSource.GMAIL: SyncConfig(
doc_sync_config=DocSyncConfig(
@@ -123,18 +107,6 @@ _SOURCE_TO_SYNC_CONFIG: dict[DocumentSource, SyncConfig] = {
initial_index_should_sync=False,
),
),
DocumentSource.GITHUB: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=GITHUB_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=github_doc_sync,
initial_index_should_sync=True,
),
group_sync_config=GroupSyncConfig(
group_sync_frequency=GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY,
group_sync_func=github_group_sync,
group_sync_is_cc_pair_agnostic=False,
),
),
DocumentSource.SALESFORCE: SyncConfig(
censoring_config=CensoringConfig(
chunk_censoring_func=censor_salesforce_chunks,
@@ -147,15 +119,6 @@ _SOURCE_TO_SYNC_CONFIG: dict[DocumentSource, SyncConfig] = {
initial_index_should_sync=True,
),
),
# Groups are not needed for Teams.
# All channel access is done at the individual user level.
DocumentSource.TEAMS: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=TEAMS_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=teams_doc_sync,
initial_index_should_sync=True,
),
),
}

View File

@@ -1,37 +0,0 @@
from collections.abc import Generator
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.teams.connector import TeamsConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
TEAMS_DOC_SYNC_LABEL = "teams_doc_sync"
def teams_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
teams_connector = TeamsConnector(
**cc_pair.connector.connector_specific_config,
)
teams_connector.load_credentials(cc_pair.credential.credential_json)
yield from generic_doc_sync(
cc_pair=cc_pair,
fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
callback=callback,
doc_source=DocumentSource.TEAMS,
slim_connector=teams_connector,
label=TEAMS_DOC_SYNC_LABEL,
)

View File

@@ -1,83 +0,0 @@
from collections.abc import Generator
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import SlimConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
def generic_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
doc_source: DocumentSource,
slim_connector: SlimConnector,
label: str,
) -> Generator[DocExternalAccess, None, None]:
"""
A convenience function for performing a generic document synchronization.
Notes:
A generic doc sync includes:
- fetching existing docs
- fetching *all* new (slim) docs
- yielding external-access permissions for existing docs which do not exist in the newly fetched slim-docs set (with their
`external_access` set to "private")
- yielding external-access permissions for newly fetched docs
Returns:
A `Generator` which yields existing and newly fetched external-access permissions.
"""
logger.info(f"Starting {doc_source} doc sync for CC Pair ID: {cc_pair.id}")
newly_fetched_doc_ids: set[str] = set()
logger.info(f"Fetching all slim documents from {doc_source}")
for doc_batch in slim_connector.retrieve_all_slim_documents(callback=callback):
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
if callback:
if callback.should_stop():
raise RuntimeError(f"{label}: Stop signal detected")
callback.progress(label, 1)
for doc in doc_batch:
if not doc.external_access:
raise RuntimeError(
f"No external access found for document ID; {cc_pair.id=} {doc_source=} {doc.id=}"
)
newly_fetched_doc_ids.add(doc.id)
yield DocExternalAccess(
doc_id=doc.id,
external_access=doc.external_access,
)
logger.info(f"Querying existing document IDs for CC Pair ID: {cc_pair.id=}")
existing_doc_ids: list[str] = fetch_all_existing_docs_ids_fn()
missing_doc_ids = set(existing_doc_ids) - newly_fetched_doc_ids
if not missing_doc_ids:
return
logger.warning(
f"Found {len(missing_doc_ids)=} documents that are in the DB but not present in fetch. Making them inaccessible."
)
for missing_id in missing_doc_ids:
logger.warning(f"Removing access for {missing_id=}")
yield DocExternalAccess(
doc_id=missing_id,
external_access=ExternalAccess.empty(),
)
logger.info(f"Finished {doc_source} doc sync")

View File

@@ -19,7 +19,7 @@ from ee.onyx.db.analytics import fetch_query_analytics
from ee.onyx.db.analytics import user_can_view_assistant_stats
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
router = APIRouter(prefix="/analytics")

View File

@@ -17,7 +17,7 @@ from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.db.connector_credential_pair import (
get_connector_credential_pair_from_id_for_user,
)
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_pool import get_redis_client

View File

@@ -26,9 +26,9 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user_with_expired_token
from onyx.auth.users import get_user_manager
from onyx.auth.users import UserManager
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.file_store import PostgresBackedFileStore
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -134,19 +134,19 @@ def ee_fetch_settings() -> EnterpriseSettings:
def put_logo(
file: UploadFile,
is_logotype: bool = False,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_admin_user),
) -> None:
upload_logo(file=file, is_logotype=is_logotype)
upload_logo(file=file, db_session=db_session, is_logotype=is_logotype)
def fetch_logo_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store()
file_store = PostgresBackedFileStore(db_session)
onyx_file = file_store.get_file_with_mime_type(get_logo_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")
except Exception:
logger.exception("Faield to fetch logo file")
raise HTTPException(
status_code=404,
detail="No logo file found",
@@ -157,7 +157,7 @@ def fetch_logo_helper(db_session: Session) -> Response:
def fetch_logotype_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store()
file_store = PostgresBackedFileStore(db_session)
onyx_file = file_store.get_file_with_mime_type(get_logotype_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")

View File

@@ -6,6 +6,7 @@ from typing import IO
from fastapi import HTTPException
from fastapi import UploadFile
from sqlalchemy.orm import Session
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
@@ -98,7 +99,9 @@ def guess_file_type(filename: str) -> str:
return "application/octet-stream"
def upload_logo(file: UploadFile | str, is_logotype: bool = False) -> bool:
def upload_logo(
db_session: Session, file: UploadFile | str, is_logotype: bool = False
) -> bool:
content: IO[Any]
if isinstance(file, str):
@@ -126,13 +129,13 @@ def upload_logo(file: UploadFile | str, is_logotype: bool = False) -> bool:
display_name = file.filename
file_type = file.content_type or "image/jpeg"
file_store = get_default_file_store()
file_store = get_default_file_store(db_session)
file_store.save_file(
file_name=_LOGOTYPE_FILENAME if is_logotype else _LOGO_FILENAME,
content=content,
display_name=display_name,
file_origin=FileOrigin.OTHER,
file_type=file_type,
file_id=_LOGOTYPE_FILENAME if is_logotype else _LOGO_FILENAME,
)
return True

View File

@@ -13,7 +13,7 @@ from ee.onyx.db.standard_answer import remove_standard_answer
from ee.onyx.db.standard_answer import update_standard_answer
from ee.onyx.db.standard_answer import update_standard_answer_category
from onyx.auth.users import current_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.server.manage.models import StandardAnswer
from onyx.server.manage.models import StandardAnswerCategory

View File

@@ -11,7 +11,7 @@ from ee.onyx.auth.users import decode_anonymous_user_jwt_token
from onyx.auth.api_key import extract_tenant_from_api_key_header
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.engine.sql_engine import is_valid_schema_name
from onyx.db.engine import is_valid_schema_name
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

View File

@@ -12,10 +12,10 @@ from ee.onyx.server.oauth.slack import SlackOAuth
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import DocumentSource
from onyx.db.engine import get_current_tenant_id
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()

View File

@@ -25,12 +25,12 @@ from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL
from onyx.db.credentials import create_credential
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.credentials import update_credential_json
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()

View File

@@ -33,11 +33,11 @@ from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from shared_configs.contextvars import get_current_tenant_id
class GoogleDriveOAuth:

View File

@@ -17,11 +17,11 @@ from onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.db.credentials import create_credential
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from shared_configs.contextvars import get_current_tenant_id
class SlackOAuth:

View File

@@ -1,6 +1,5 @@
import re
from typing import cast
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
@@ -41,7 +40,7 @@ from onyx.context.search.models import SavedSearchDoc
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.llm.factory import get_llms_for_persona
from onyx.natural_language_processing.utils import get_tokenizer
@@ -74,7 +73,6 @@ def _get_final_context_doc_indices(
def _convert_packet_stream_to_response(
packets: ChatPacketStream,
chat_session_id: UUID,
) -> ChatBasicResponse:
response = ChatBasicResponse()
final_context_docs: list[LlmDoc] = []
@@ -218,8 +216,6 @@ def _convert_packet_stream_to_response(
if answer:
response.answer_citationless = remove_answer_citations(answer)
response.chat_session_id = chat_session_id
return response
@@ -241,36 +237,13 @@ def handle_simplified_chat_message(
if not chat_message_req.message:
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
# Handle chat session creation if chat_session_id is not provided
if chat_message_req.chat_session_id is None:
if chat_message_req.persona_id is None:
raise HTTPException(
status_code=400,
detail="Either chat_session_id or persona_id must be provided",
)
# Create a new chat session with the provided persona_id
try:
new_chat_session = create_chat_session(
db_session=db_session,
description="", # Leave empty for simple API
user_id=user.id if user else None,
persona_id=chat_message_req.persona_id,
)
chat_session_id = new_chat_session.id
except Exception as e:
logger.exception(e)
raise HTTPException(status_code=400, detail="Invalid Persona provided.")
else:
chat_session_id = chat_message_req.chat_session_id
try:
parent_message, _ = create_chat_chain(
chat_session_id=chat_session_id, db_session=db_session
chat_session_id=chat_message_req.chat_session_id, db_session=db_session
)
except Exception:
parent_message = get_or_create_root_message(
chat_session_id=chat_session_id, db_session=db_session
chat_session_id=chat_message_req.chat_session_id, db_session=db_session
)
if (
@@ -285,7 +258,7 @@ def handle_simplified_chat_message(
retrieval_options = chat_message_req.retrieval_options
full_chat_msg_info = CreateChatMessageRequest(
chat_session_id=chat_session_id,
chat_session_id=chat_message_req.chat_session_id,
parent_message_id=parent_message.id,
message=chat_message_req.message,
file_descriptors=[],
@@ -310,7 +283,7 @@ def handle_simplified_chat_message(
enforce_chat_session_id_for_search_docs=False,
)
return _convert_packet_stream_to_response(packets, chat_session_id)
return _convert_packet_stream_to_response(packets)
@router.post("/send-message-simple-with-history")
@@ -430,4 +403,4 @@ def handle_send_message_simple_with_history(
enforce_chat_session_id_for_search_docs=False,
)
return _convert_packet_stream_to_response(packets, chat_session.id)
return _convert_packet_stream_to_response(packets)

View File

@@ -41,13 +41,11 @@ class DocumentSearchRequest(ChunkContext):
class BasicCreateChatMessageRequest(ChunkContext):
"""If a chat_session_id is not provided, a persona_id must be provided to automatically create a new chat session
"""Before creating messages, be sure to create a chat_session and get an id
Note, for simplicity this option only allows for a single linear chain of messages
"""
chat_session_id: UUID | None = None
# Optional persona_id to create a new chat session if chat_session_id is not provided
persona_id: int | None = None
chat_session_id: UUID
# New message contents
message: str
# Defaults to using retrieval with no additional filters
@@ -64,12 +62,6 @@ class BasicCreateChatMessageRequest(ChunkContext):
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
@model_validator(mode="after")
def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest":
if self.chat_session_id is None and self.persona_id is None:
raise ValueError("Either chat_session_id or persona_id must be provided")
return self
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# Last element is the new query. All previous elements are historical context
@@ -179,9 +171,6 @@ class ChatBasicResponse(BaseModel):
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
agent_refined_answer_improvement: bool | None = None
# Chat session ID for tracking conversation continuity
chat_session_id: UUID | None = None
class OneShotQARequest(ChunkContext):
# Supports simplier APIs that don't deal with chat histories or message edits

View File

@@ -31,7 +31,7 @@ from onyx.context.search.utils import dedupe_documents
from onyx.context.search.utils import drop_llm_indices
from onyx.context.search.utils import relevant_sections_to_indices
from onyx.db.chat import get_prompt_by_id
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id

View File

@@ -13,7 +13,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.api_key import is_api_key_email_address
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import TokenRateLimit

View File

@@ -37,11 +37,11 @@ from onyx.configs.constants import QueryHistoryType
from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.enums import TaskStatus
from onyx.db.file_record import get_query_history_export_files
from onyx.db.models import ChatSession
from onyx.db.models import User
from onyx.db.pg_file_store import get_query_history_export_files
from onyx.db.tasks import get_task_with_id
from onyx.db.tasks import register_task
from onyx.file_store.file_store import get_default_file_store
@@ -49,7 +49,6 @@ from onyx.server.documents.models import PaginatedReturn
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionsResponse
from onyx.utils.threadpool_concurrency import parallel_yield
from shared_configs.contextvars import get_current_tenant_id
router = APIRouter()
@@ -335,7 +334,6 @@ def start_query_history_export(
"start": start,
"end": end,
"start_time": start_time,
"tenant_id": get_current_tenant_id(),
},
)
@@ -358,11 +356,11 @@ def get_query_history_export_status(
# If task is None, then it's possible that the task has already finished processing.
# Therefore, we should then check if the export file has already been stored inside of the file-store.
# If that *also* doesn't exist, then we can return a 404.
file_store = get_default_file_store()
file_store = get_default_file_store(db_session)
report_name = construct_query_history_report_name(request_id)
has_file = file_store.has_file(
file_id=report_name,
file_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,
file_type=FileType.CSV,
)
@@ -385,9 +383,9 @@ def download_query_history_csv(
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
report_name = construct_query_history_report_name(request_id)
file_store = get_default_file_store()
file_store = get_default_file_store(db_session)
has_file = file_store.has_file(
file_id=report_name,
file_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,
file_type=FileType.CSV,
)

View File

@@ -12,7 +12,7 @@ from onyx.configs.constants import SessionType
from onyx.db.enums import TaskStatus
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import FileRecord
from onyx.db.models import PGFileStore
from onyx.db.models import TaskQueueState
@@ -254,7 +254,7 @@ class QueryHistoryExport(BaseModel):
@classmethod
def from_file(
cls,
file: FileRecord,
file: PGFileStore,
) -> "QueryHistoryExport":
if not file.file_metadata or not isinstance(file.file_metadata, dict):
raise RuntimeError(
@@ -262,7 +262,7 @@ class QueryHistoryExport(BaseModel):
)
metadata = QueryHistoryFileMetadata.model_validate(dict(file.file_metadata))
task_id = extract_task_id_from_query_history_report_name(file.file_id)
task_id = extract_task_id_from_query_history_report_name(file.file_name)
return cls(
task_id=task_id,

View File

@@ -14,7 +14,7 @@ from ee.onyx.db.usage_export import get_usage_report_data
from ee.onyx.db.usage_export import UsageReportMetadata
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
from onyx.auth.users import current_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.file_store.constants import STANDARD_CHUNK_SIZE
@@ -53,7 +53,7 @@ def read_usage_report(
db_session: Session = Depends(get_session),
) -> Response:
try:
file = get_usage_report_data(report_name)
file = get_usage_report_data(db_session, report_name)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -62,16 +62,17 @@ def generate_chat_messages_report(
]
)
# after writing seek to beginning of buffer
# after writing seek to begining of buffer
temp_file.seek(0)
file_id = file_store.save_file(
file_store.save_file(
file_name=file_name,
content=temp_file,
display_name=file_name,
file_origin=FileOrigin.OTHER,
file_type="text/csv",
)
return file_id
return file_name
def generate_user_report(
@@ -96,14 +97,15 @@ def generate_user_report(
csvwriter.writerow([user_skeleton.user_id, user_skeleton.is_active])
temp_file.seek(0)
file_id = file_store.save_file(
file_store.save_file(
file_name=file_name,
content=temp_file,
display_name=file_name,
file_origin=FileOrigin.OTHER,
file_type="text/csv",
)
return file_id
return file_name
def create_new_usage_report(
@@ -112,18 +114,18 @@ def create_new_usage_report(
period: tuple[datetime, datetime] | None,
) -> UsageReportMetadata:
report_id = str(uuid.uuid4())
file_store = get_default_file_store()
file_store = get_default_file_store(db_session)
messages_file_id = generate_chat_messages_report(
messages_filename = generate_chat_messages_report(
db_session, file_store, report_id, period
)
users_file_id = generate_user_report(db_session, file_store, report_id)
users_filename = generate_user_report(db_session, file_store, report_id)
with tempfile.SpooledTemporaryFile(max_size=MAX_IN_MEMORY_SIZE) as zip_buffer:
with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED) as zip_file:
# write messages
chat_messages_tmpfile = file_store.read_file(
messages_file_id, mode="b", use_tempfile=True
messages_filename, mode="b", use_tempfile=True
)
zip_file.writestr(
"chat_messages.csv",
@@ -132,7 +134,7 @@ def create_new_usage_report(
# write users
users_tmpfile = file_store.read_file(
users_file_id, mode="b", use_tempfile=True
users_filename, mode="b", use_tempfile=True
)
zip_file.writestr("users.csv", users_tmpfile.read())
@@ -144,11 +146,11 @@ def create_new_usage_report(
f"_{report_id}_usage_report.zip"
)
file_store.save_file(
file_name=report_name,
content=zip_buffer,
display_name=report_name,
file_origin=FileOrigin.GENERATED_REPORT,
file_type="application/zip",
file_id=report_name,
)
# add report after zip file is written

View File

@@ -27,9 +27,9 @@ from onyx.auth.users import get_user_manager
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
from onyx.db.engine.async_sql_engine import get_async_session
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_context_manager
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.utils.logger import setup_logger

View File

@@ -19,7 +19,7 @@ from ee.onyx.server.enterprise_settings.store import (
)
from ee.onyx.server.enterprise_settings.store import upload_logo
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
@@ -200,10 +200,10 @@ def _seed_enterprise_settings(seed_config: SeedConfiguration) -> None:
store_ee_settings(final_enterprise_settings)
def _seed_logo(logo_path: str | None) -> None:
def _seed_logo(db_session: Session, logo_path: str | None) -> None:
if logo_path:
logger.notice("Uploading logo")
upload_logo(file=logo_path)
upload_logo(db_session=db_session, file=logo_path)
def _seed_analytics_script(seed_config: SeedConfiguration) -> None:
@@ -235,7 +235,7 @@ def seed_db() -> None:
logger.debug("No seeding configuration file passed")
return
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
if seed_config.llms is not None:
_seed_llms(db_session, seed_config.llms)
if seed_config.personas is not None:
@@ -245,7 +245,7 @@ def seed_db() -> None:
if seed_config.custom_tools is not None:
_seed_custom_tools(db_session, seed_config.custom_tools)
_seed_logo(seed_config.seeded_logo_path)
_seed_logo(db_session, seed_config.seeded_logo_path)
_seed_enterprise_settings(seed_config)
_seed_analytics_script(seed_config)

View File

@@ -10,7 +10,7 @@ from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from onyx.auth.users import auth_backend
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger

View File

@@ -18,7 +18,7 @@ from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_shared_schema
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

View File

@@ -10,12 +10,10 @@ from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import overwrite_full_gated_set
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
@@ -49,26 +47,6 @@ def gate_product(
return ProductGatingResponse(updated=False, error=str(e))
@router.post("/product-gating/full-sync")
def gate_product_full_sync(
product_gating_request: ProductGatingFullSyncRequest,
_: None = Depends(control_plane_dep),
) -> ProductGatingResponse:
"""
Bulk operation to overwrite the entire gated tenant set.
This replaces all currently gated tenants with the provided list.
Gated tenants are not available to access the product and will be
directed to the billing page when their subscription has ended.
"""
try:
overwrite_full_gated_set(product_gating_request.gated_tenant_ids)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate products during full sync")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),

View File

@@ -19,10 +19,6 @@ class ProductGatingRequest(BaseModel):
application_status: ApplicationStatus
class ProductGatingFullSyncRequest(BaseModel):
gated_tenant_ids: list[str]
class SubscriptionStatusResponse(BaseModel):
subscribed: bool

View File

@@ -16,6 +16,10 @@ logger = setup_logger()
def update_tenant_gating(tenant_id: str, status: ApplicationStatus) -> None:
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
# Store the full status
status_key = f"tenant:{tenant_id}:status"
redis_client.set(status_key, status.value)
# Maintain the GATED_ACCESS set
if status == ApplicationStatus.GATED_ACCESS:
redis_client.sadd(GATED_TENANTS_KEY, tenant_id)
@@ -42,25 +46,6 @@ def store_product_gating(tenant_id: str, application_status: ApplicationStatus)
raise
def overwrite_full_gated_set(tenant_ids: list[str]) -> None:
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
pipeline = redis_client.pipeline()
# using pipeline doesn't automatically add the tenant_id prefix
full_gated_set_key = f"{ONYX_CLOUD_TENANT_ID}:{GATED_TENANTS_KEY}"
# Clear the existing set
pipeline.delete(full_gated_set_key)
# Add all tenant IDs to the set and set their status
for tenant_id in tenant_ids:
pipeline.sadd(full_gated_set_key, tenant_id)
# Execute all commands at once
pipeline.execute()
def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))

View File

@@ -28,8 +28,8 @@ from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider

View File

@@ -8,8 +8,8 @@ from sqlalchemy.schema import CreateSchema
from alembic import command
from alembic.config import Config
from onyx.db.engine.sql_engine import build_connection_string
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_sqlalchemy_engine
logger = logging.getLogger(__name__)
@@ -34,7 +34,7 @@ def run_alembic_migrations(schema_name: str) -> None:
# Mimic command-line options by adding 'cmd_opts' to the config
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
alembic_cfg.cmd_opts.x = [f"schemas={schema_name}"] # type: ignore
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
# Run migrations programmatically
command.upgrade(alembic_cfg, "head")

View File

@@ -9,7 +9,7 @@ from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.auth import get_user_count
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail

View File

@@ -5,8 +5,8 @@ from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import get_pending_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.invited_users import write_pending_users
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import UserTenantMapping
from onyx.server.manage.models import TenantSnapshot
from onyx.utils.logger import setup_logger

View File

@@ -9,7 +9,7 @@ from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.db.token_limit import fetch_all_user_token_rate_limits
from onyx.db.token_limit import insert_user_token_rate_limit

View File

@@ -16,7 +16,7 @@ from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger

View File

@@ -40,30 +40,6 @@ class ExternalAccess:
def num_entries(self) -> int:
return len(self.external_user_emails) + len(self.external_user_group_ids)
@classmethod
def public(cls) -> "ExternalAccess":
return cls(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=True,
)
@classmethod
def empty(cls) -> "ExternalAccess":
"""
A helper function that returns an *empty* set of external user-emails and group-ids, and sets `is_public` to `False`.
This effectively makes the document in question "private" or inaccessible to anyone else.
This is especially helpful to use when you are performing permission-syncing, and some document's permissions aren't able
to be determined (for whatever reason). Setting its `ExternalAccess` to "private" is a feasible fallback.
"""
return cls(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=False,
)
@dataclass(frozen=True)
class DocExternalAccess:

View File

@@ -78,7 +78,7 @@ def should_continue(state: BasicState) -> str:
if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.context.search.models import SearchRequest
from onyx.llm.factory import get_default_llms
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
@@ -87,7 +87,7 @@ if __name__ == "__main__":
compiled_graph = graph.compile()
input = BasicInput(unused=True)
primary_llm, fast_llm = get_default_llms()
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
config, _ = get_test_config(
db_session=db_session,
primary_llm=primary_llm,

View File

@@ -4,7 +4,7 @@ from typing import cast
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,

View File

@@ -111,7 +111,7 @@ def answer_query_graph_builder() -> StateGraph:
if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
@@ -121,7 +121,7 @@ if __name__ == "__main__":
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
graph_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)

View File

@@ -238,7 +238,7 @@ def agent_search_graph_builder() -> StateGraph:
if __name__ == "__main__":
pass
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
@@ -246,7 +246,7 @@ if __name__ == "__main__":
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
search_request = SearchRequest(query="Who created Excel?")
graph_config = get_test_config(
db_session, primary_llm, fast_llm, search_request

View File

@@ -109,7 +109,7 @@ def answer_refined_query_graph_builder() -> StateGraph:
if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
@@ -119,7 +119,7 @@ if __name__ == "__main__":
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
inputs = SubQuestionAnsweringInput(
question="what can you do with onyx?",
question_id="0_0",

View File

@@ -131,7 +131,7 @@ def expanded_retrieval_graph_builder() -> StateGraph:
if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
@@ -142,7 +142,7 @@ if __name__ == "__main__":
query="what can you do with onyx or danswer?",
)
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
graph_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)

View File

@@ -24,7 +24,7 @@ from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RerankingDetails
from onyx.context.search.postprocessing.postprocessing import rerank_sections
from onyx.context.search.postprocessing.postprocessing import should_rerank
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.db.search_settings import get_current_search_settings
from onyx.utils.timing import log_function_time
@@ -60,7 +60,7 @@ def rerank_documents(
allow_agent_reranking = graph_config.behavior.allow_agent_reranking
if rerank_settings is None:
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
search_settings = get_current_search_settings(db_session)
if not search_settings.disable_rerank_for_streaming:
rerank_settings = RerankingDetails.from_db_model(search_settings)

Some files were not shown because too many files have changed in this diff Show More