mirror of https://github.com/onyx-dot-app/onyx.git
synced 2026-02-18 16:25:45 +00:00

Compare commits: hackathon- ... KG_dev_cop
6 Commits

| Author | SHA1 | Date |
|---|---|---|
| | 0e38898c7a | |
| | ce6a597eca | |
| | d251ba40ae | |
| | 26395d81c9 | |
| | e1a3e11ec9 | |
| | e013711664 | |
@@ -1,86 +0,0 @@
name: External Dependency Unit Tests

on:
  merge_group:
  pull_request:
    branches: [main]

env:
  # AWS
  S3_AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}
  S3_AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }}

  # MinIO
  S3_ENDPOINT_URL: "http://localhost:9004"

jobs:
  discover-test-dirs:
    runs-on: ubuntu-latest
    outputs:
      test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Discover test directories
        id: set-matrix
        run: |
          # Find all subdirectories in backend/tests/external_dependency_unit
          dirs=$(find backend/tests/external_dependency_unit -mindepth 1 -maxdepth 1 -type d -exec basename {} \; | sort | jq -R -s -c 'split("\n")[:-1]')
          echo "test-dirs=$dirs" >> $GITHUB_OUTPUT

  external-dependency-unit-tests:
    needs: discover-test-dirs
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]

    strategy:
      fail-fast: false
      matrix:
        test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}

    env:
      PYTHONPATH: ./backend

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"
          cache: "pip"
          cache-dependency-path: |
            backend/requirements/default.txt
            backend/requirements/dev.txt

      - name: Install Dependencies
        run: |
          python -m pip install --upgrade pip
          pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
          pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
          playwright install chromium
          playwright install-deps chromium

      - name: Set up Standard Dependencies
        run: |
          cd deployment/docker_compose
          docker compose -f docker-compose.dev.yml -p onyx-stack up -d minio relational_db cache index

      - name: Run migrations
        run: |
          cd backend
          alembic upgrade head

      - name: Run Tests for ${{ matrix.test-dir }}
        shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
        run: |
          py.test \
            -n 8 \
            --dist loadfile \
            --durations=8 \
            -o junit_family=xunit2 \
            -xv \
            --ff \
            backend/tests/external_dependency_unit/${{ matrix.test-dir }}

14 .github/workflows/pr-integration-tests.yml vendored
@@ -16,21 +16,12 @@ env:
  CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
  CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
  CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
  JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
  JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
  JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
  PLATFORM_PAIR: linux-amd64

jobs:
  integration-tests:
    # See https://runs-on.com/runners/linux/
    runs-on:
      [
        runs-on,
        runner=32cpu-linux-x64,
        disk=large,
        "run-id=${{ github.run_id }}",
      ]
    runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
@@ -269,9 +260,6 @@ jobs:
          -e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
          -e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
          -e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
          -e JIRA_BASE_URL=${JIRA_BASE_URL} \
          -e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
          -e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
          -e TEST_WEB_HOSTNAME=test-runner \
          -e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
          -e MOCK_CONNECTOR_SERVER_PORT=8001 \

14 .github/workflows/pr-mit-integration-tests.yml vendored
@@ -16,20 +16,11 @@ env:
  CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
  CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
  CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
  JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
  JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
  JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
  PLATFORM_PAIR: linux-amd64
jobs:
  integration-tests-mit:
    # See https://runs-on.com/runners/linux/
    runs-on:
      [
        runs-on,
        runner=32cpu-linux-x64,
        disk=large,
        "run-id=${{ github.run_id }}",
      ]
    runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
@@ -204,9 +195,6 @@ jobs:
          -e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
          -e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
          -e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
          -e JIRA_BASE_URL=${JIRA_BASE_URL} \
          -e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
          -e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
          -e TEST_WEB_HOSTNAME=test-runner \
          -e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
          -e MOCK_CONNECTOR_SERVER_PORT=8001 \

5 .github/workflows/pr-python-checks.yml vendored
@@ -54,6 +54,11 @@ jobs:
          cd backend
          mypy .

      - name: Run ruff
        run: |
          cd backend
          ruff .

      - name: Check import order with reorder-python-imports
        run: |
          cd backend

@@ -22,7 +22,6 @@ env:
  CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}

  # Jira
  JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
  JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
  JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}

@@ -50,9 +49,6 @@ env:
  SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
  SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}

  # Hubspot
  HUBSPOT_ACCESS_TOKEN: ${{ secrets.HUBSPOT_ACCESS_TOKEN }}

  # Airtable
  AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
  AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}

3 .github/workflows/pr-python-tests.yml vendored
@@ -15,9 +15,6 @@ jobs:
    env:
      PYTHONPATH: ./backend
      REDIS_CLOUD_PYTEST_PASSWORD: ${{ secrets.REDIS_CLOUD_PYTEST_PASSWORD }}
      SF_USERNAME: ${{ secrets.SF_USERNAME }}
      SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
      SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}

    steps:
      - name: Checkout code

6 .vscode/env_template.txt vendored
@@ -58,9 +58,3 @@ AGENT_RETRIEVAL_STATS=False # Note: This setting will incur substantial re-ran
AGENT_RERANKING_STATS=True
AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20

# S3 File Store Configuration (MinIO for local development)
S3_ENDPOINT_URL=http://localhost:9004
S3_FILE_STORE_BUCKET_NAME=onyx-file-store-bucket
S3_AWS_ACCESS_KEY_ID=minioadmin
S3_AWS_SECRET_ACCESS_KEY=minioadmin

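These env_template values are the local-development defaults: they point the backend's S3 file store at the MinIO container started by the dev docker compose stack. A quick, hypothetical sanity check (assuming boto3 is installed and the MinIO container from `docker-compose.dev.yml` is running; the bucket name and credentials are simply the template defaults above) could look like this:

```python
# Hypothetical sanity check against the local MinIO defaults shown above;
# requires boto3 and the dev MinIO container listening on localhost:9004.
import boto3
from botocore.exceptions import ClientError

s3 = boto3.client(
    "s3",
    endpoint_url="http://localhost:9004",
    aws_access_key_id="minioadmin",
    aws_secret_access_key="minioadmin",
)

try:
    s3.create_bucket(Bucket="onyx-file-store-bucket")
except ClientError:
    pass  # bucket already exists

s3.put_object(Bucket="onyx-file-store-bucket", Key="healthcheck.txt", Body=b"ok")
print(s3.get_object(Bucket="onyx-file-store-bucket", Key="healthcheck.txt")["Body"].read())
```
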
@@ -77,9 +77,6 @@ RUN apt-get update && \
    rm -rf /var/lib/apt/lists/* && \
    rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key

# Install postgresql-client for easy manual tests
# Install it here to avoid it being cleaned up above
RUN apt-get update && apt-get install -y postgresql-client

# Pre-downloading models for setups with limited egress
RUN python -c "from tokenizers import Tokenizer; \

@@ -20,44 +20,3 @@ To run all un-applied migrations:
To undo migrations:
`alembic downgrade -X`
where X is the number of migrations you want to undo from the current state

### Multi-tenant migrations

For multi-tenant deployments, you can use additional options:

**Upgrade all tenants:**
```bash
alembic -x upgrade_all_tenants=true upgrade head
```

**Upgrade specific schemas:**
```bash
# Single schema
alembic -x schemas=tenant_12345678-1234-1234-1234-123456789012 upgrade head

# Multiple schemas (comma-separated)
alembic -x schemas=tenant_12345678-1234-1234-1234-123456789012,public,another_tenant upgrade head
```

**Upgrade tenants within an alphabetical range:**
```bash
# Upgrade tenants 100-200 when sorted alphabetically (positions 100 to 200)
alembic -x upgrade_all_tenants=true -x tenant_range_start=100 -x tenant_range_end=200 upgrade head

# Upgrade tenants starting from position 1000 alphabetically
alembic -x upgrade_all_tenants=true -x tenant_range_start=1000 upgrade head

# Upgrade first 500 tenants alphabetically
alembic -x upgrade_all_tenants=true -x tenant_range_end=500 upgrade head
```

**Continue on error (for batch operations):**
```bash
alembic -x upgrade_all_tenants=true -x continue=true upgrade head
```

The tenant range filtering works by (see the sketch below):
1. Sorting tenant IDs alphabetically
2. Using 1-based position numbers (1st, 2nd, 3rd tenant, etc.)
3. Filtering to the specified range of positions
4. Non-tenant schemas (like 'public') are always included

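A minimal, hypothetical sketch of that position-based filtering, assuming 1-based inclusive positions as the list describes; the names `TENANT_ID_PREFIX` and `filter_tenants_by_range` are chosen to mirror the helper that appears in `alembic/env.py` further down in this diff, but this snippet is illustrative rather than the shipped implementation:

```python
# Hypothetical sketch of the tenant range filter described above.
TENANT_ID_PREFIX = "tenant_"  # assumption: tenant schemas share this prefix


def filter_tenants_by_range(
    tenant_ids: list[str],
    start: int | None = None,
    end: int | None = None,
) -> list[str]:
    if start is None and end is None:
        return tenant_ids

    # Only tenant schemas participate in range filtering; sort them alphabetically.
    tenant_schemas = sorted(t for t in tenant_ids if t.startswith(TENANT_ID_PREFIX))

    # Convert 1-based inclusive positions into Python slice bounds.
    lo = (start - 1) if start is not None else 0
    hi = end if end is not None else len(tenant_schemas)
    selected = set(tenant_schemas[max(lo, 0):hi])

    # Non-tenant schemas (e.g. "public") are always kept; original order preserved.
    return [
        t
        for t in tenant_ids
        if not t.startswith(TENANT_ID_PREFIX) or t in selected
    ]


print(filter_tenants_by_range(["public", "tenant_a", "tenant_b", "tenant_c"], 2, 3))
# -> ['public', 'tenant_b', 'tenant_c']
```
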
@@ -1,12 +1,12 @@
|
||||
from typing import Any, Literal
|
||||
from onyx.db.engine.iam_auth import get_iam_auth_token
|
||||
from onyx.db.engine import get_iam_auth_token
|
||||
from onyx.configs.app_configs import USE_IAM_AUTH
|
||||
from onyx.configs.app_configs import POSTGRES_HOST
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.app_configs import AWS_REGION_NAME
|
||||
from onyx.db.engine.sql_engine import build_connection_string
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.db.engine import build_connection_string
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy import text
|
||||
@@ -21,14 +21,10 @@ from alembic import context
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
from shared_configs.configs import (
|
||||
MULTI_TENANT,
|
||||
POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE,
|
||||
TENANT_ID_PREFIX,
|
||||
)
|
||||
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
|
||||
from onyx.db.models import Base
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.engine import SqlEngine
|
||||
|
||||
# Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be
|
||||
# hidden! (defaults to level=WARN)
|
||||
@@ -73,67 +69,15 @@ def include_object(
|
||||
return True
|
||||
|
||||
|
||||
def filter_tenants_by_range(
|
||||
tenant_ids: list[str], start_range: int | None = None, end_range: int | None = None
|
||||
) -> list[str]:
|
||||
"""
|
||||
Filter tenant IDs by alphabetical position range.
|
||||
|
||||
Args:
|
||||
tenant_ids: List of tenant IDs to filter
|
||||
start_range: Starting position in alphabetically sorted list (1-based, inclusive)
|
||||
end_range: Ending position in alphabetically sorted list (1-based, inclusive)
|
||||
|
||||
Returns:
|
||||
Filtered list of tenant IDs in their original order
|
||||
"""
|
||||
if start_range is None and end_range is None:
|
||||
return tenant_ids
|
||||
|
||||
# Separate tenant IDs from non-tenant schemas
|
||||
tenant_schemas = [tid for tid in tenant_ids if tid.startswith(TENANT_ID_PREFIX)]
|
||||
non_tenant_schemas = [
|
||||
tid for tid in tenant_ids if not tid.startswith(TENANT_ID_PREFIX)
|
||||
]
|
||||
|
||||
# Sort tenant schemas alphabetically.
|
||||
# NOTE: can cause missed schemas if a schema is created in between workers
|
||||
# fetching of all tenant IDs. We accept this risk for now. Just re-running
|
||||
# the migration will fix the issue.
|
||||
sorted_tenant_schemas = sorted(tenant_schemas)
|
||||
|
||||
# Apply range filtering (0-based indexing)
|
||||
start_idx = start_range if start_range is not None else 0
|
||||
end_idx = end_range if end_range is not None else len(sorted_tenant_schemas)
|
||||
|
||||
# Ensure indices are within bounds
|
||||
start_idx = max(0, start_idx)
|
||||
end_idx = min(len(sorted_tenant_schemas), end_idx)
|
||||
|
||||
# Get the filtered tenant schemas
|
||||
filtered_tenant_schemas = sorted_tenant_schemas[start_idx:end_idx]
|
||||
|
||||
# Combine with non-tenant schemas and preserve original order
|
||||
filtered_tenants = []
|
||||
for tenant_id in tenant_ids:
|
||||
if tenant_id in filtered_tenant_schemas or tenant_id in non_tenant_schemas:
|
||||
filtered_tenants.append(tenant_id)
|
||||
|
||||
return filtered_tenants
|
||||
|
||||
|
||||
def get_schema_options() -> (
|
||||
tuple[bool, bool, bool, int | None, int | None, list[str] | None]
|
||||
):
|
||||
def get_schema_options() -> tuple[str, bool, bool, bool]:
|
||||
x_args_raw = context.get_x_argument()
|
||||
x_args = {}
|
||||
for arg in x_args_raw:
|
||||
if "=" in arg:
|
||||
key, value = arg.split("=", 1)
|
||||
x_args[key.strip()] = value.strip()
|
||||
else:
|
||||
raise ValueError(f"Invalid argument: {arg}")
|
||||
|
||||
for pair in arg.split(","):
|
||||
if "=" in pair:
|
||||
key, value = pair.split("=", 1)
|
||||
x_args[key.strip()] = value.strip()
|
||||
schema_name = x_args.get("schema", POSTGRES_DEFAULT_SCHEMA)
|
||||
create_schema = x_args.get("create_schema", "true").lower() == "true"
|
||||
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
|
||||
|
||||
@@ -141,81 +85,17 @@ def get_schema_options() -> (
|
||||
# only applies to online migrations
|
||||
continue_on_error = x_args.get("continue", "false").lower() == "true"
|
||||
|
||||
# Tenant range filtering
|
||||
tenant_range_start = None
|
||||
tenant_range_end = None
|
||||
|
||||
if "tenant_range_start" in x_args:
|
||||
try:
|
||||
tenant_range_start = int(x_args["tenant_range_start"])
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid tenant_range_start value: {x_args['tenant_range_start']}. Must be an integer."
|
||||
)
|
||||
|
||||
if "tenant_range_end" in x_args:
|
||||
try:
|
||||
tenant_range_end = int(x_args["tenant_range_end"])
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid tenant_range_end value: {x_args['tenant_range_end']}. Must be an integer."
|
||||
)
|
||||
|
||||
# Validate range
|
||||
if tenant_range_start is not None and tenant_range_end is not None:
|
||||
if tenant_range_start > tenant_range_end:
|
||||
raise ValueError(
|
||||
f"tenant_range_start ({tenant_range_start}) cannot be greater than tenant_range_end ({tenant_range_end})"
|
||||
)
|
||||
|
||||
# Specific schema names filtering (replaces both schema_name and the old tenant_ids approach)
|
||||
schemas = None
|
||||
if "schemas" in x_args:
|
||||
schema_names_str = x_args["schemas"].strip()
|
||||
if schema_names_str:
|
||||
# Split by comma and strip whitespace
|
||||
schemas = [
|
||||
name.strip() for name in schema_names_str.split(",") if name.strip()
|
||||
]
|
||||
if schemas:
|
||||
logger.info(f"Specific schema names specified: {schemas}")
|
||||
|
||||
# Validate that only one method is used at a time
|
||||
range_filtering = tenant_range_start is not None or tenant_range_end is not None
|
||||
specific_filtering = schemas is not None and len(schemas) > 0
|
||||
|
||||
if range_filtering and specific_filtering:
|
||||
if (
|
||||
MULTI_TENANT
|
||||
and schema_name == POSTGRES_DEFAULT_SCHEMA
|
||||
and not upgrade_all_tenants
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot use both tenant range filtering (tenant_range_start/tenant_range_end) "
|
||||
"and specific schema filtering (schemas) at the same time. "
|
||||
"Please use only one filtering method."
|
||||
"Cannot run default migrations in public schema when multi-tenancy is enabled. "
|
||||
"Please specify a tenant-specific schema."
|
||||
)
|
||||
|
||||
if upgrade_all_tenants and specific_filtering:
|
||||
raise ValueError(
|
||||
"Cannot use both upgrade_all_tenants=true and schemas at the same time. "
|
||||
"Use either upgrade_all_tenants=true for all tenants, or schemas for specific schemas."
|
||||
)
|
||||
|
||||
# If any filtering parameters are specified, we're not doing the default single schema migration
|
||||
if range_filtering:
|
||||
upgrade_all_tenants = True
|
||||
|
||||
# Validate multi-tenant requirements
|
||||
if MULTI_TENANT and not upgrade_all_tenants and not specific_filtering:
|
||||
raise ValueError(
|
||||
"In multi-tenant mode, you must specify either upgrade_all_tenants=true "
|
||||
"or provide schemas. Cannot run default migration."
|
||||
)
|
||||
|
||||
return (
|
||||
create_schema,
|
||||
upgrade_all_tenants,
|
||||
continue_on_error,
|
||||
tenant_range_start,
|
||||
tenant_range_end,
|
||||
schemas,
|
||||
)
|
||||
return schema_name, create_schema, upgrade_all_tenants, continue_on_error
|
||||
|
||||
|
||||
def do_run_migrations(
|
||||
@@ -262,17 +142,12 @@ def provide_iam_token_for_alembic(
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
(
|
||||
schema_name,
|
||||
create_schema,
|
||||
upgrade_all_tenants,
|
||||
continue_on_error,
|
||||
tenant_range_start,
|
||||
tenant_range_end,
|
||||
schemas,
|
||||
) = get_schema_options()
|
||||
|
||||
if not schemas and not MULTI_TENANT:
|
||||
schemas = [POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE]
|
||||
|
||||
# without init_engine, subsequent engine calls fail hard intentionally
|
||||
SqlEngine.init_engine(pool_size=20, max_overflow=5)
|
||||
|
||||
@@ -289,50 +164,12 @@ async def run_async_migrations() -> None:
|
||||
) -> None:
|
||||
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
|
||||
|
||||
if schemas:
|
||||
# Use specific schema names directly without fetching all tenants
|
||||
logger.info(f"Migrating specific schema names: {schemas}")
|
||||
|
||||
i_schema = 0
|
||||
num_schemas = len(schemas)
|
||||
for schema in schemas:
|
||||
i_schema += 1
|
||||
logger.info(
|
||||
f"Migrating schema: index={i_schema} num_schemas={num_schemas} schema={schema}"
|
||||
)
|
||||
try:
|
||||
async with engine.connect() as connection:
|
||||
await connection.run_sync(
|
||||
do_run_migrations,
|
||||
schema_name=schema,
|
||||
create_schema=create_schema,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema}: {e}")
|
||||
if not continue_on_error:
|
||||
logger.error("--continue=true is not set, raising exception!")
|
||||
raise
|
||||
|
||||
logger.warning("--continue=true is set, continuing to next schema.")
|
||||
|
||||
elif upgrade_all_tenants:
|
||||
if upgrade_all_tenants:
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
|
||||
filtered_tenant_schemas = filter_tenants_by_range(
|
||||
tenant_schemas, tenant_range_start, tenant_range_end
|
||||
)
|
||||
|
||||
if tenant_range_start is not None or tenant_range_end is not None:
|
||||
logger.info(
|
||||
f"Filtering tenants by range: start={tenant_range_start}, end={tenant_range_end}"
|
||||
)
|
||||
logger.info(
|
||||
f"Total tenants: {len(tenant_schemas)}, Filtered tenants: {len(filtered_tenant_schemas)}"
|
||||
)
|
||||
|
||||
i_tenant = 0
|
||||
num_tenants = len(filtered_tenant_schemas)
|
||||
for schema in filtered_tenant_schemas:
|
||||
num_tenants = len(tenant_schemas)
|
||||
for schema in tenant_schemas:
|
||||
i_tenant += 1
|
||||
logger.info(
|
||||
f"Migrating schema: index={i_tenant} num_tenants={num_tenants} schema={schema}"
|
||||
@@ -353,13 +190,17 @@ async def run_async_migrations() -> None:
|
||||
logger.warning("--continue=true is set, continuing to next schema.")
|
||||
|
||||
else:
|
||||
# This should not happen in the new design since we require either
|
||||
# upgrade_all_tenants=true or schemas in multi-tenant mode
|
||||
# and for non-multi-tenant mode, we should use schemas with the default schema
|
||||
raise ValueError(
|
||||
"No migration target specified. Use either upgrade_all_tenants=true for all tenants "
|
||||
"or schemas for specific schemas."
|
||||
)
|
||||
try:
|
||||
logger.info(f"Migrating schema: {schema_name}")
|
||||
async with engine.connect() as connection:
|
||||
await connection.run_sync(
|
||||
do_run_migrations,
|
||||
schema_name=schema_name,
|
||||
create_schema=create_schema,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema_name}: {e}")
|
||||
raise
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
@@ -380,37 +221,10 @@ def run_migrations_offline() -> None:
|
||||
# without init_engine, subsequent engine calls fail hard intentionally
|
||||
SqlEngine.init_engine(pool_size=20, max_overflow=5)
|
||||
|
||||
(
|
||||
create_schema,
|
||||
upgrade_all_tenants,
|
||||
continue_on_error,
|
||||
tenant_range_start,
|
||||
tenant_range_end,
|
||||
schemas,
|
||||
) = get_schema_options()
|
||||
schema_name, _, upgrade_all_tenants, continue_on_error = get_schema_options()
|
||||
url = build_connection_string()
|
||||
|
||||
if schemas:
|
||||
# Use specific schema names directly without fetching all tenants
|
||||
logger.info(f"Migrating specific schema names: {schemas}")
|
||||
|
||||
for schema in schemas:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
elif upgrade_all_tenants:
|
||||
if upgrade_all_tenants:
|
||||
engine = create_async_engine(url)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
@@ -424,19 +238,7 @@ def run_migrations_offline() -> None:
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
engine.sync_engine.dispose()
|
||||
|
||||
filtered_tenant_schemas = filter_tenants_by_range(
|
||||
tenant_schemas, tenant_range_start, tenant_range_end
|
||||
)
|
||||
|
||||
if tenant_range_start is not None or tenant_range_end is not None:
|
||||
logger.info(
|
||||
f"Filtering tenants by range: start={tenant_range_start}, end={tenant_range_end}"
|
||||
)
|
||||
logger.info(
|
||||
f"Total tenants: {len(tenant_schemas)}, Filtered tenants: {len(filtered_tenant_schemas)}"
|
||||
)
|
||||
|
||||
for schema in filtered_tenant_schemas:
|
||||
for schema in tenant_schemas:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
context.configure(
|
||||
url=url,
|
||||
@@ -452,12 +254,21 @@ def run_migrations_offline() -> None:
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
else:
|
||||
# This should not happen in the new design
|
||||
raise ValueError(
|
||||
"No migration target specified. Use either upgrade_all_tenants=true for all tenants "
|
||||
"or schemas for specific schemas."
|
||||
logger.info(f"Migrating schema: {schema_name}")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
logger.info("run_migrations_online starting.")
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
"""rework-kg-config
|
||||
|
||||
Revision ID: 03bf8be6b53a
|
||||
Revises: 65bc6e0f8500
|
||||
Create Date: 2025-06-16 10:52:34.815335
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy import text
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "03bf8be6b53a"
|
||||
down_revision = "65bc6e0f8500"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# get current config
|
||||
current_configs = (
|
||||
op.get_bind()
|
||||
.execute(text("SELECT kg_variable_name, kg_variable_values FROM kg_config"))
|
||||
.all()
|
||||
)
|
||||
current_config_dict = {
|
||||
config.kg_variable_name: (
|
||||
config.kg_variable_values[0]
|
||||
if config.kg_variable_name
|
||||
not in ("KG_VENDOR_DOMAINS", "KG_IGNORE_EMAIL_DOMAINS")
|
||||
else config.kg_variable_values
|
||||
)
|
||||
for config in current_configs
|
||||
if config.kg_variable_values
|
||||
}
|
||||
|
||||
# not using the KGConfigSettings model here in case it changes in the future
|
||||
kg_config_settings = json.dumps(
|
||||
{
|
||||
"KG_EXPOSED": current_config_dict.get("KG_EXPOSED", False),
|
||||
"KG_ENABLED": current_config_dict.get("KG_ENABLED", False),
|
||||
"KG_VENDOR": current_config_dict.get("KG_VENDOR", None),
|
||||
"KG_VENDOR_DOMAINS": current_config_dict.get("KG_VENDOR_DOMAINS", []),
|
||||
"KG_IGNORE_EMAIL_DOMAINS": current_config_dict.get(
|
||||
"KG_IGNORE_EMAIL_DOMAINS", []
|
||||
),
|
||||
"KG_COVERAGE_START": current_config_dict.get(
|
||||
"KG_COVERAGE_START",
|
||||
(datetime.now() - timedelta(days=90)).strftime("%Y-%m-%d"),
|
||||
),
|
||||
"KG_MAX_COVERAGE_DAYS": current_config_dict.get("KG_MAX_COVERAGE_DAYS", 90),
|
||||
"KG_MAX_PARENT_RECURSION_DEPTH": current_config_dict.get(
|
||||
"KG_MAX_PARENT_RECURSION_DEPTH", 2
|
||||
),
|
||||
"KG_BETA_PERSONA_ID": current_config_dict.get("KG_BETA_PERSONA_ID", None),
|
||||
}
|
||||
)
|
||||
op.execute(
|
||||
f"INSERT INTO key_value_store (key, value) VALUES ('kg_config', '{kg_config_settings}')"
|
||||
)
|
||||
|
||||
# drop kg config table
|
||||
op.drop_table("kg_config")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# get current config
|
||||
current_config_dict = {
|
||||
"KG_EXPOSED": False,
|
||||
"KG_ENABLED": False,
|
||||
"KG_VENDOR": [],
|
||||
"KG_VENDOR_DOMAINS": [],
|
||||
"KG_IGNORE_EMAIL_DOMAINS": [],
|
||||
"KG_COVERAGE_START": (datetime.now() - timedelta(days=90)).strftime("%Y-%m-%d"),
|
||||
"KG_MAX_COVERAGE_DAYS": 90,
|
||||
"KG_MAX_PARENT_RECURSION_DEPTH": 2,
|
||||
}
|
||||
current_configs = (
|
||||
op.get_bind()
|
||||
.execute(text("SELECT value FROM key_value_store WHERE key = 'kg_config'"))
|
||||
.one_or_none()
|
||||
)
|
||||
if current_configs is not None:
|
||||
current_config_dict.update(current_configs[0])
|
||||
insert_values = [
|
||||
{
|
||||
"kg_variable_name": name,
|
||||
"kg_variable_values": (
|
||||
[str(val).lower() if isinstance(val, bool) else str(val)]
|
||||
if not isinstance(val, list)
|
||||
else val
|
||||
),
|
||||
}
|
||||
for name, val in current_config_dict.items()
|
||||
]
|
||||
|
||||
op.create_table(
|
||||
"kg_config",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, nullable=False, index=True),
|
||||
sa.Column("kg_variable_name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("kg_variable_values", postgresql.ARRAY(sa.String()), nullable=False),
|
||||
sa.UniqueConstraint("kg_variable_name", name="uq_kg_config_variable_name"),
|
||||
)
|
||||
op.bulk_insert(
|
||||
sa.table(
|
||||
"kg_config",
|
||||
sa.column("kg_variable_name", sa.String),
|
||||
sa.column("kg_variable_values", postgresql.ARRAY(sa.String)),
|
||||
),
|
||||
insert_values,
|
||||
)
|
||||
|
||||
op.execute("DELETE FROM key_value_store WHERE key = 'kg_config'")
|
||||
@@ -1,136 +0,0 @@
|
||||
"""update_kg_trigger_functions
|
||||
|
||||
Revision ID: 36e9220ab794
|
||||
Revises: c9e2cd766c29
|
||||
Create Date: 2025-06-22 17:33:25.833733
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "36e9220ab794"
|
||||
down_revision = "c9e2cd766c29"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _get_tenant_contextvar(session: Session) -> str:
|
||||
"""Get the current schema for the migration"""
|
||||
current_tenant = session.execute(text("SELECT current_schema()")).scalar()
|
||||
if isinstance(current_tenant, str):
|
||||
return current_tenant
|
||||
else:
|
||||
raise ValueError("Current tenant is not a string")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
bind = op.get_bind()
|
||||
session = Session(bind=bind)
|
||||
|
||||
# Create kg_entity trigger to update kg_entity.name and its trigrams
|
||||
tenant_id = _get_tenant_contextvar(session)
|
||||
alphanum_pattern = r"[^a-z0-9]+"
|
||||
truncate_length = 1000
|
||||
function = "update_kg_entity_name"
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
CREATE OR REPLACE FUNCTION "{tenant_id}".{function}()
|
||||
RETURNS TRIGGER AS $$
|
||||
DECLARE
|
||||
name text;
|
||||
cleaned_name text;
|
||||
BEGIN
|
||||
-- Set name to semantic_id if document_id is not NULL
|
||||
IF NEW.document_id IS NOT NULL THEN
|
||||
SELECT lower(semantic_id) INTO name
|
||||
FROM "{tenant_id}".document
|
||||
WHERE id = NEW.document_id;
|
||||
ELSE
|
||||
name = lower(NEW.name);
|
||||
END IF;
|
||||
|
||||
-- Clean name and truncate if too long
|
||||
cleaned_name = regexp_replace(
|
||||
name,
|
||||
'{alphanum_pattern}', '', 'g'
|
||||
);
|
||||
IF length(cleaned_name) > {truncate_length} THEN
|
||||
cleaned_name = left(cleaned_name, {truncate_length});
|
||||
END IF;
|
||||
|
||||
-- Set name and name trigrams
|
||||
NEW.name = name;
|
||||
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name);
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
"""
|
||||
)
|
||||
)
|
||||
trigger = f"{function}_trigger"
|
||||
op.execute(f'DROP TRIGGER IF EXISTS {trigger} ON "{tenant_id}".kg_entity')
|
||||
op.execute(
|
||||
f"""
|
||||
CREATE TRIGGER {trigger}
|
||||
BEFORE INSERT OR UPDATE OF name
|
||||
ON "{tenant_id}".kg_entity
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION "{tenant_id}".{function}();
|
||||
"""
|
||||
)
|
||||
|
||||
# Create kg_entity trigger to update kg_entity.name and its trigrams
|
||||
function = "update_kg_entity_name_from_doc"
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
CREATE OR REPLACE FUNCTION "{tenant_id}".{function}()
|
||||
RETURNS TRIGGER AS $$
|
||||
DECLARE
|
||||
doc_name text;
|
||||
cleaned_name text;
|
||||
BEGIN
|
||||
doc_name = lower(NEW.semantic_id);
|
||||
|
||||
-- Clean name and truncate if too long
|
||||
cleaned_name = regexp_replace(
|
||||
doc_name,
|
||||
'{alphanum_pattern}', '', 'g'
|
||||
);
|
||||
IF length(cleaned_name) > {truncate_length} THEN
|
||||
cleaned_name = left(cleaned_name, {truncate_length});
|
||||
END IF;
|
||||
|
||||
-- Set name and name trigrams for all entities referencing this document
|
||||
UPDATE "{tenant_id}".kg_entity
|
||||
SET
|
||||
name = doc_name,
|
||||
name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name)
|
||||
WHERE document_id = NEW.id;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
"""
|
||||
)
|
||||
)
|
||||
trigger = f"{function}_trigger"
|
||||
op.execute(f'DROP TRIGGER IF EXISTS {trigger} ON "{tenant_id}".document')
|
||||
op.execute(
|
||||
f"""
|
||||
CREATE TRIGGER {trigger}
|
||||
AFTER UPDATE OF semantic_id
|
||||
ON "{tenant_id}".document
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION "{tenant_id}".{function}();
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -21,14 +21,22 @@ depends_on = None
|
||||
# an outage by creating an index without using CONCURRENTLY. This migration:
|
||||
#
|
||||
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
|
||||
# 2. Adds indexes to both chat_message and chat_session tables for comprehensive search
|
||||
# 3. Note: CONCURRENTLY was removed due to operational issues
|
||||
# 2. Uses CONCURRENTLY for all index creation to prevent table locking
|
||||
# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work
|
||||
# (see: https://www.postgresql.org/docs/9.4/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY)
|
||||
# (see: https://github.com/sqlalchemy/alembic/issues/277)
|
||||
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# First, drop any existing indexes to avoid conflicts
|
||||
op.execute("DROP INDEX IF EXISTS idx_chat_message_tsv;")
|
||||
op.execute("DROP INDEX IF EXISTS idx_chat_session_desc_tsv;")
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
|
||||
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
|
||||
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
|
||||
|
||||
# Drop existing columns if they exist
|
||||
@@ -44,9 +52,12 @@ def upgrade() -> None:
|
||||
"""
|
||||
)
|
||||
|
||||
# Commit the current transaction before creating concurrent indexes
|
||||
op.execute("COMMIT")
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_chat_message_tsv
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
|
||||
ON chat_message
|
||||
USING GIN (message_tsv)
|
||||
"""
|
||||
@@ -61,9 +72,12 @@ def upgrade() -> None:
|
||||
"""
|
||||
)
|
||||
|
||||
# Commit again before creating the second concurrent index
|
||||
op.execute("COMMIT")
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_chat_session_desc_tsv
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
|
||||
ON chat_session
|
||||
USING GIN (description_tsv)
|
||||
"""
|
||||
@@ -71,9 +85,12 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the indexes first
|
||||
op.execute("DROP INDEX IF EXISTS idx_chat_message_tsv;")
|
||||
op.execute("DROP INDEX IF EXISTS idx_chat_session_desc_tsv;")
|
||||
# Drop the indexes first (use CONCURRENTLY for dropping too)
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
|
||||
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
|
||||
|
||||
# Then drop the columns
|
||||
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
|
||||
|
||||
@@ -15,7 +15,6 @@ from datetime import datetime, timedelta
|
||||
from onyx.configs.app_configs import DB_READONLY_USER
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
@@ -34,8 +33,11 @@ def upgrade() -> None:
|
||||
# Note: in order for the migration to run, the DB_READONLY_USER and DB_READONLY_PASSWORD
|
||||
# environment variables MUST be set. Otherwise, an exception will be raised.
|
||||
|
||||
print("MULTI_TENANT: ", MULTI_TENANT)
|
||||
if not MULTI_TENANT:
|
||||
|
||||
print("Single tenant mode")
|
||||
|
||||
# Enable pg_trgm extension if not already enabled
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
|
||||
|
||||
@@ -66,6 +68,7 @@ def upgrade() -> None:
|
||||
)
|
||||
|
||||
# Grant usage on current schema to readonly user
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
@@ -467,11 +470,11 @@ def upgrade() -> None:
|
||||
|
||||
# Create GIN index for clustering and normalization
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kg_entity_clustering_trigrams "
|
||||
f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.gin_trgm_ops)"
|
||||
"CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_kg_entity_clustering_trigrams "
|
||||
"ON kg_entity USING GIN (name public.gin_trgm_ops)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kg_entity_normalization_trigrams "
|
||||
"CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_kg_entity_normalization_trigrams "
|
||||
"ON kg_entity USING GIN (name_trigrams)"
|
||||
)
|
||||
|
||||
@@ -508,7 +511,7 @@ def upgrade() -> None:
|
||||
|
||||
-- Set name and name trigrams
|
||||
NEW.name = name;
|
||||
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name);
|
||||
NEW.name_trigrams = public.show_trgm(cleaned_name);
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
@@ -553,7 +556,7 @@ def upgrade() -> None:
|
||||
UPDATE kg_entity
|
||||
SET
|
||||
name = doc_name,
|
||||
name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name)
|
||||
name_trigrams = public.show_trgm(cleaned_name)
|
||||
WHERE document_id = NEW.id;
|
||||
RETURN NEW;
|
||||
END;
|
||||
@@ -573,9 +576,13 @@ def upgrade() -> None:
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
|
||||
# Drop all views that start with 'kg_'
|
||||
op.execute(
|
||||
"""
|
||||
@@ -625,8 +632,9 @@ def downgrade() -> None:
|
||||
op.execute(f"DROP FUNCTION IF EXISTS {function}()")
|
||||
|
||||
# Drop index
|
||||
op.execute("DROP INDEX IF EXISTS idx_kg_entity_clustering_trigrams")
|
||||
op.execute("DROP INDEX IF EXISTS idx_kg_entity_normalization_trigrams")
|
||||
op.execute("COMMIT") # Commit to allow CONCURRENTLY
|
||||
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_kg_entity_clustering_trigrams")
|
||||
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_kg_entity_normalization_trigrams")
|
||||
|
||||
# Drop tables in reverse order of creation to handle dependencies
|
||||
op.drop_table("kg_term")
|
||||
@@ -643,21 +651,6 @@ def downgrade() -> None:
|
||||
op.drop_column("document", "kg_processing_time")
|
||||
op.drop_table("kg_config")
|
||||
|
||||
# Revoke usage on current schema for the readonly user
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA %I FROM %I', current_schema(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
if not MULTI_TENANT:
|
||||
# Drop read-only db user here only in single tenant mode. For multi-tenant mode,
|
||||
# the user is dropped in the alembic_tenants migration.
|
||||
@@ -678,4 +671,20 @@ def downgrade() -> None:
|
||||
"""
|
||||
)
|
||||
)
|
||||
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
|
||||
else:
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA %I FROM %I', current_schema(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
"""add stale column to external user group tables
|
||||
|
||||
Revision ID: 58c50ef19f08
|
||||
Revises: 7b9b952abdf6
|
||||
Create Date: 2025-06-25 14:08:14.162380
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "58c50ef19f08"
|
||||
down_revision = "7b9b952abdf6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add the stale column with default value False to user__external_user_group_id
|
||||
op.add_column(
|
||||
"user__external_user_group_id",
|
||||
sa.Column("stale", sa.Boolean(), nullable=False, server_default="false"),
|
||||
)
|
||||
|
||||
# Create index for efficient querying of stale rows by cc_pair_id
|
||||
op.create_index(
|
||||
"ix_user__external_user_group_id_cc_pair_id_stale",
|
||||
"user__external_user_group_id",
|
||||
["cc_pair_id", "stale"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# Create index for efficient querying of all stale rows
|
||||
op.create_index(
|
||||
"ix_user__external_user_group_id_stale",
|
||||
"user__external_user_group_id",
|
||||
["stale"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# Add the stale column with default value False to public_external_user_group
|
||||
op.add_column(
|
||||
"public_external_user_group",
|
||||
sa.Column("stale", sa.Boolean(), nullable=False, server_default="false"),
|
||||
)
|
||||
|
||||
# Create index for efficient querying of stale rows by cc_pair_id
|
||||
op.create_index(
|
||||
"ix_public_external_user_group_cc_pair_id_stale",
|
||||
"public_external_user_group",
|
||||
["cc_pair_id", "stale"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# Create index for efficient querying of all stale rows
|
||||
op.create_index(
|
||||
"ix_public_external_user_group_stale",
|
||||
"public_external_user_group",
|
||||
["stale"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the indices for public_external_user_group first
|
||||
op.drop_index(
|
||||
"ix_public_external_user_group_stale", table_name="public_external_user_group"
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_public_external_user_group_cc_pair_id_stale",
|
||||
table_name="public_external_user_group",
|
||||
)
|
||||
|
||||
# Drop the stale column from public_external_user_group
|
||||
op.drop_column("public_external_user_group", "stale")
|
||||
|
||||
# Drop the indices for user__external_user_group_id
|
||||
op.drop_index(
|
||||
"ix_user__external_user_group_id_stale",
|
||||
table_name="user__external_user_group_id",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_user__external_user_group_id_cc_pair_id_stale",
|
||||
table_name="user__external_user_group_id",
|
||||
)
|
||||
|
||||
# Drop the stale column from user__external_user_group_id
|
||||
op.drop_column("user__external_user_group_id", "stale")
|
||||
@@ -1,41 +0,0 @@
|
||||
"""remove kg subtype from db
|
||||
|
||||
Revision ID: 65bc6e0f8500
|
||||
Revises: cec7ec36c505
|
||||
Create Date: 2025-06-13 10:04:27.705976
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "65bc6e0f8500"
|
||||
down_revision = "cec7ec36c505"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_column("kg_entity", "entity_class")
|
||||
op.drop_column("kg_entity", "entity_subtype")
|
||||
op.drop_column("kg_entity_extraction_staging", "entity_class")
|
||||
op.drop_column("kg_entity_extraction_staging", "entity_subtype")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"kg_entity_extraction_staging",
|
||||
sa.Column("entity_subtype", sa.String(), nullable=True, index=True),
|
||||
)
|
||||
op.add_column(
|
||||
"kg_entity_extraction_staging",
|
||||
sa.Column("entity_class", sa.String(), nullable=True, index=True),
|
||||
)
|
||||
op.add_column(
|
||||
"kg_entity", sa.Column("entity_subtype", sa.String(), nullable=True, index=True)
|
||||
)
|
||||
op.add_column(
|
||||
"kg_entity", sa.Column("entity_class", sa.String(), nullable=True, index=True)
|
||||
)
|
||||
@@ -1,318 +0,0 @@
|
||||
"""update-entities
|
||||
|
||||
Revision ID: 7b9b952abdf6
|
||||
Revises: 36e9220ab794
|
||||
Create Date: 2025-06-23 20:24:08.139201
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7b9b952abdf6"
|
||||
down_revision = "36e9220ab794"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# new entity type metadata_attribute_conversion
|
||||
new_entity_type_conversion = {
|
||||
"LINEAR": {
|
||||
"team": {"name": "team", "keep": True, "implication_property": None},
|
||||
"state": {"name": "state", "keep": True, "implication_property": None},
|
||||
"priority": {
|
||||
"name": "priority",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"estimate": {
|
||||
"name": "estimate",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"created_at": {
|
||||
"name": "created_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"started_at": {
|
||||
"name": "started_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"completed_at": {
|
||||
"name": "completed_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"due_date": {
|
||||
"name": "due_date",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"creator": {
|
||||
"name": "creator",
|
||||
"keep": False,
|
||||
"implication_property": {
|
||||
"implied_entity_type": "from_email",
|
||||
"implied_relationship_name": "is_creator_of",
|
||||
},
|
||||
},
|
||||
"assignee": {
|
||||
"name": "assignee",
|
||||
"keep": False,
|
||||
"implication_property": {
|
||||
"implied_entity_type": "from_email",
|
||||
"implied_relationship_name": "is_assignee_of",
|
||||
},
|
||||
},
|
||||
},
|
||||
"JIRA": {
|
||||
"issuetype": {
|
||||
"name": "subtype",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"status": {"name": "status", "keep": True, "implication_property": None},
|
||||
"priority": {
|
||||
"name": "priority",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"project_name": {
|
||||
"name": "project",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"created": {
|
||||
"name": "created_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"updated": {
|
||||
"name": "updated_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"resolution_date": {
|
||||
"name": "completed_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"duedate": {"name": "due_date", "keep": True, "implication_property": None},
|
||||
"reporter_email": {
|
||||
"name": "creator",
|
||||
"keep": False,
|
||||
"implication_property": {
|
||||
"implied_entity_type": "from_email",
|
||||
"implied_relationship_name": "is_creator_of",
|
||||
},
|
||||
},
|
||||
"assignee_email": {
|
||||
"name": "assignee",
|
||||
"keep": False,
|
||||
"implication_property": {
|
||||
"implied_entity_type": "from_email",
|
||||
"implied_relationship_name": "is_assignee_of",
|
||||
},
|
||||
},
|
||||
"key": {"name": "key", "keep": True, "implication_property": None},
|
||||
"parent": {"name": "parent", "keep": True, "implication_property": None},
|
||||
},
|
||||
"GITHUB_PR": {
|
||||
"repo": {"name": "repository", "keep": True, "implication_property": None},
|
||||
"state": {"name": "state", "keep": True, "implication_property": None},
|
||||
"num_commits": {
|
||||
"name": "num_commits",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"num_files_changed": {
|
||||
"name": "num_files_changed",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"labels": {"name": "labels", "keep": True, "implication_property": None},
|
||||
"merged": {"name": "merged", "keep": True, "implication_property": None},
|
||||
"merged_at": {
|
||||
"name": "merged_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"closed_at": {
|
||||
"name": "closed_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"created_at": {
|
||||
"name": "created_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"updated_at": {
|
||||
"name": "updated_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"user": {
|
||||
"name": "creator",
|
||||
"keep": False,
|
||||
"implication_property": {
|
||||
"implied_entity_type": "from_email",
|
||||
"implied_relationship_name": "is_creator_of",
|
||||
},
|
||||
},
|
||||
"assignees": {
|
||||
"name": "assignees",
|
||||
"keep": False,
|
||||
"implication_property": {
|
||||
"implied_entity_type": "from_email",
|
||||
"implied_relationship_name": "is_assignee_of",
|
||||
},
|
||||
},
|
||||
},
|
||||
"GITHUB_ISSUE": {
|
||||
"repo": {"name": "repository", "keep": True, "implication_property": None},
|
||||
"state": {"name": "state", "keep": True, "implication_property": None},
|
||||
"labels": {"name": "labels", "keep": True, "implication_property": None},
|
||||
"closed_at": {
|
||||
"name": "closed_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"created_at": {
|
||||
"name": "created_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"updated_at": {
|
||||
"name": "updated_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"user": {
|
||||
"name": "creator",
|
||||
"keep": False,
|
||||
"implication_property": {
|
||||
"implied_entity_type": "from_email",
|
||||
"implied_relationship_name": "is_creator_of",
|
||||
},
|
||||
},
|
||||
"assignees": {
|
||||
"name": "assignees",
|
||||
"keep": False,
|
||||
"implication_property": {
|
||||
"implied_entity_type": "from_email",
|
||||
"implied_relationship_name": "is_assignee_of",
|
||||
},
|
||||
},
|
||||
},
|
||||
"FIREFLIES": {},
|
||||
"ACCOUNT": {},
|
||||
"OPPORTUNITY": {
|
||||
"name": {"name": "name", "keep": True, "implication_property": None},
|
||||
"stage_name": {"name": "stage", "keep": True, "implication_property": None},
|
||||
"type": {"name": "type", "keep": True, "implication_property": None},
|
||||
"amount": {"name": "amount", "keep": True, "implication_property": None},
|
||||
"fiscal_year": {
|
||||
"name": "fiscal_year",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"fiscal_quarter": {
|
||||
"name": "fiscal_quarter",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"is_closed": {
|
||||
"name": "is_closed",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"close_date": {
|
||||
"name": "close_date",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"probability": {
|
||||
"name": "close_probability",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"created_date": {
|
||||
"name": "created_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"last_modified_date": {
|
||||
"name": "updated_at",
|
||||
"keep": True,
|
||||
"implication_property": None,
|
||||
},
|
||||
"account": {
|
||||
"name": "account",
|
||||
"keep": False,
|
||||
"implication_property": {
|
||||
"implied_entity_type": "ACCOUNT",
|
||||
"implied_relationship_name": "is_account_of",
|
||||
},
|
||||
},
|
||||
},
|
||||
"VENDOR": {},
|
||||
"EMPLOYEE": {},
|
||||
}
|
||||
|
||||
current_entity_types = conn.execute(
|
||||
sa.text("SELECT id_name, attributes from kg_entity_type")
|
||||
).all()
|
||||
for entity_type, attributes in current_entity_types:
|
||||
# delete removed entity types
|
||||
if entity_type not in new_entity_type_conversion:
|
||||
op.execute(
|
||||
sa.text(f"DELETE FROM kg_entity_type WHERE id_name = '{entity_type}'")
|
||||
)
|
||||
continue
|
||||
|
||||
# update entity type attributes
|
||||
if "metadata_attributes" in attributes:
|
||||
del attributes["metadata_attributes"]
|
||||
attributes["metadata_attribute_conversion"] = new_entity_type_conversion[
|
||||
entity_type
|
||||
]
|
||||
attributes_str = json.dumps(attributes).replace("'", "''")
|
||||
op.execute(
|
||||
sa.text(
|
||||
f"UPDATE kg_entity_type SET attributes = '{attributes_str}'"
|
||||
f"WHERE id_name = '{entity_type}'"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
current_entity_types = conn.execute(
|
||||
sa.text("SELECT id_name, attributes from kg_entity_type")
|
||||
).all()
|
||||
for entity_type, attributes in current_entity_types:
|
||||
conversion = {}
|
||||
if "metadata_attribute_conversion" in attributes:
|
||||
conversion = attributes.pop("metadata_attribute_conversion")
|
||||
attributes["metadata_attributes"] = {
|
||||
attr: prop["name"] for attr, prop in conversion.items() if prop["keep"]
|
||||
}
|
||||
|
||||
attributes_str = json.dumps(attributes).replace("'", "''")
|
||||
op.execute(
|
||||
sa.text(
|
||||
f"UPDATE kg_entity_type SET attributes = '{attributes_str}'"
|
||||
f"WHERE id_name = '{entity_type}'"
|
||||
),
|
||||
)
|
||||
@@ -1,312 +0,0 @@
|
||||
"""modify_file_store_for_external_storage
|
||||
|
||||
Revision ID: c9e2cd766c29
|
||||
Revises: 03bf8be6b53a
|
||||
Create Date: 2025-06-13 14:02:09.867679
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
from typing import cast, Any
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from onyx.db._deprecated.pg_file_store import delete_lobj_by_id, read_lobj
|
||||
from onyx.file_store.file_store import get_s3_file_store
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c9e2cd766c29"
|
||||
down_revision = "03bf8be6b53a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
try:
|
||||
# Modify existing file_store table to support external storage
|
||||
op.rename_table("file_store", "file_record")
|
||||
|
||||
# Make lobj_oid nullable (for external storage files)
|
||||
op.alter_column("file_record", "lobj_oid", nullable=True)
|
||||
|
||||
# Add external storage columns with generic names
|
||||
op.add_column(
|
||||
"file_record", sa.Column("bucket_name", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"file_record", sa.Column("object_key", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
# Add timestamps for tracking
|
||||
op.add_column(
|
||||
"file_record",
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"file_record",
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
op.alter_column("file_record", "file_name", new_column_name="file_id")
|
||||
except Exception as e:
|
||||
if "does not exist" in str(e) or 'relation "file_store" does not exist' in str(
|
||||
e
|
||||
):
|
||||
print(
|
||||
f"Ran into error - {e}. Likely means we had a partial success in the past, continuing..."
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
print(
|
||||
"External storage configured - migrating files from PostgreSQL to external storage..."
|
||||
)
|
||||
# if we fail midway through this, we'll have a partial success. Running the migration
|
||||
# again should allow us to continue.
|
||||
_migrate_files_to_external_storage()
|
||||
print("File migration completed successfully!")
|
||||
|
||||
# Remove lobj_oid column
|
||||
op.drop_column("file_record", "lobj_oid")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Revert schema changes and migrate files from external storage back to PostgreSQL large objects."""
|
||||
|
||||
print(
|
||||
"Reverting to PostgreSQL-backed file store – migrating files from external storage …"
|
||||
)
|
||||
|
||||
# 1. Ensure `lobj_oid` exists on the current `file_record` table (nullable for now).
|
||||
op.add_column("file_record", sa.Column("lobj_oid", sa.Integer(), nullable=True))
|
||||
|
||||
# 2. Move content from external storage back into PostgreSQL large objects (table is still
|
||||
# called `file_record` so application code continues to work during the copy).
|
||||
try:
|
||||
_migrate_files_to_postgres()
|
||||
except Exception:
|
||||
print("Error during downgrade migration, rolling back …")
|
||||
op.drop_column("file_record", "lobj_oid")
|
||||
raise
|
||||
|
||||
# 3. After migration every row should now have `lobj_oid` populated – mark NOT NULL.
|
||||
op.alter_column("file_record", "lobj_oid", nullable=False)
|
||||
|
||||
# 4. Remove columns that are only relevant to external storage.
|
||||
op.drop_column("file_record", "updated_at")
|
||||
op.drop_column("file_record", "created_at")
|
||||
op.drop_column("file_record", "object_key")
|
||||
op.drop_column("file_record", "bucket_name")
|
||||
|
||||
# 5. Rename `file_id` back to `file_name` (still on `file_record`).
|
||||
op.alter_column("file_record", "file_id", new_column_name="file_name")
|
||||
|
||||
# 6. Finally, rename the table back to its original name expected by the legacy codebase.
|
||||
op.rename_table("file_record", "file_store")
|
||||
|
||||
print(
|
||||
"Downgrade migration completed – files are now stored inside PostgreSQL again."
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helper: migrate from external storage (S3/MinIO) back into PostgreSQL large objects
|
||||
|
||||
|
||||
def _migrate_files_to_postgres() -> None:
|
||||
"""Move any files whose content lives in external S3-compatible storage back into PostgreSQL.
|
||||
|
||||
The logic mirrors *inverse* of `_migrate_files_to_external_storage` used on upgrade.
|
||||
"""
|
||||
|
||||
# Obtain DB session from Alembic context
|
||||
bind = op.get_bind()
|
||||
session = Session(bind=bind)
|
||||
|
||||
# Fetch rows that have external storage pointers (bucket/object_key not NULL)
|
||||
result = session.execute(
|
||||
text(
|
||||
"SELECT file_id, bucket_name, object_key FROM file_record "
|
||||
"WHERE bucket_name IS NOT NULL AND object_key IS NOT NULL"
|
||||
)
|
||||
)
|
||||
|
||||
files_to_migrate = [row[0] for row in result.fetchall()]
|
||||
total_files = len(files_to_migrate)
|
||||
|
||||
if total_files == 0:
|
||||
print("No files found in external storage to migrate back to PostgreSQL.")
|
||||
return
|
||||
|
||||
print(f"Found {total_files} files to migrate back to PostgreSQL large objects.")
|
||||
|
||||
_set_tenant_contextvar(session)
|
||||
migrated_count = 0
|
||||
|
||||
# only create external store if we have files to migrate. This line
|
||||
# makes it so we need to have S3/MinIO configured to run this migration.
|
||||
external_store = get_s3_file_store(db_session=session)
|
||||
|
||||
for i, file_id in enumerate(files_to_migrate, 1):
|
||||
print(f"Migrating file {i}/{total_files}: {file_id}")
|
||||
|
||||
# Read file content from external storage (always binary)
|
||||
try:
|
||||
file_io = external_store.read_file(
|
||||
file_id=file_id, mode="b", use_tempfile=True
|
||||
)
|
||||
file_io.seek(0)
|
||||
|
||||
# Import lazily to avoid circular deps at Alembic runtime
|
||||
from onyx.db._deprecated.pg_file_store import (
|
||||
create_populate_lobj,
|
||||
) # noqa: E402
|
||||
|
||||
# Create new Postgres large object and populate it
|
||||
lobj_oid = create_populate_lobj(content=file_io, db_session=session)
|
||||
|
||||
# Update DB row: set lobj_oid, clear bucket/object_key
|
||||
session.execute(
|
||||
text(
|
||||
"UPDATE file_record SET lobj_oid = :lobj_oid, bucket_name = NULL, "
|
||||
"object_key = NULL WHERE file_id = :file_id"
|
||||
),
|
||||
{"lobj_oid": lobj_oid, "file_id": file_id},
|
||||
)
|
||||
except ClientError as e:
|
||||
if "NoSuchKey" in str(e):
|
||||
print(
|
||||
f"File {file_id} not found in external storage. Deleting from database."
|
||||
)
|
||||
session.execute(
|
||||
text("DELETE FROM file_record WHERE file_id = :file_id"),
|
||||
{"file_id": file_id},
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
migrated_count += 1
|
||||
print(f"✓ Successfully migrated file {i}/{total_files}: {file_id}")
|
||||
|
||||
# Flush the SQLAlchemy session so statements are sent to the DB, but **do not**
|
||||
# commit the transaction. The surrounding Alembic migration will commit once
|
||||
# the *entire* downgrade succeeds. This keeps the whole downgrade atomic and
|
||||
# avoids leaving the database in a partially-migrated state if a later schema
|
||||
# operation fails.
|
||||
session.flush()
|
||||
|
||||
print(
|
||||
f"Migration back to PostgreSQL completed: {migrated_count} files staged for commit."
|
||||
)
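# --- Sketch (assumption, not the migration's code): the NoSuchKey handling above
# matches on str(e). botocore's ClientError also exposes a structured error code,
# which is less likely to catch unrelated errors by accident.
def _is_missing_object_error(e: ClientError) -> bool:
    code = e.response.get("Error", {}).get("Code", "")
    return code in ("NoSuchKey", "404")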
|
||||
|
||||
|
||||
def _migrate_files_to_external_storage() -> None:
|
||||
"""Migrate files from PostgreSQL large objects to external storage"""
|
||||
# Get database session
|
||||
bind = op.get_bind()
|
||||
session = Session(bind=bind)
|
||||
external_store = get_s3_file_store(db_session=session)
|
||||
|
||||
# Find all files currently stored in PostgreSQL (lobj_oid is not null)
|
||||
result = session.execute(
|
||||
text(
|
||||
"SELECT file_id FROM file_record WHERE lobj_oid IS NOT NULL "
|
||||
"AND bucket_name IS NULL AND object_key IS NULL"
|
||||
)
|
||||
)
|
||||
|
||||
files_to_migrate = [row[0] for row in result.fetchall()]
|
||||
total_files = len(files_to_migrate)
|
||||
|
||||
if total_files == 0:
|
||||
print("No files found in PostgreSQL storage to migrate.")
|
||||
return
|
||||
|
||||
print(f"Found {total_files} files to migrate from PostgreSQL to external storage.")
|
||||
|
||||
_set_tenant_contextvar(session)
|
||||
migrated_count = 0
|
||||
|
||||
for i, file_id in enumerate(files_to_migrate, 1):
|
||||
print(f"Migrating file {i}/{total_files}: {file_id}")
|
||||
|
||||
# Read file record to get metadata
|
||||
file_record = session.execute(
|
||||
text("SELECT * FROM file_record WHERE file_id = :file_id"),
|
||||
{"file_id": file_id},
|
||||
).fetchone()
|
||||
|
||||
if file_record is None:
|
||||
print(f"File {file_id} not found in PostgreSQL storage.")
|
||||
continue
|
||||
|
||||
lobj_id = cast(int, file_record.lobj_oid) # type: ignore
|
||||
file_metadata = cast(Any, file_record.file_metadata) # type: ignore
|
||||
|
||||
# Read file content from PostgreSQL
|
||||
try:
|
||||
file_content = read_lobj(
|
||||
lobj_id, db_session=session, mode="b", use_tempfile=True
|
||||
)
|
||||
except Exception as e:
|
||||
if "large object" in str(e) and "does not exist" in str(e):
|
||||
print(f"File {file_id} not found in PostgreSQL storage.")
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
|
||||
# Handle file_metadata type conversion (normalize to a plain dict or None)
file_metadata = None
if file_record.file_metadata is not None:
    if isinstance(file_record.file_metadata, dict):
        file_metadata = file_record.file_metadata
    else:
        # Convert other types to dict if possible, otherwise None
        try:
            file_metadata = dict(file_record.file_metadata)  # type: ignore
        except (TypeError, ValueError):
            file_metadata = None
|
||||
|
||||
# Save to external storage (this will handle the database record update and cleanup)
|
||||
# NOTE: this WILL .commit() the transaction.
|
||||
external_store.save_file(
|
||||
file_id=file_id,
|
||||
content=file_content,
|
||||
display_name=file_record.display_name,
|
||||
file_origin=file_record.file_origin,
|
||||
file_type=file_record.file_type,
|
||||
file_metadata=file_metadata,
|
||||
)
|
||||
delete_lobj_by_id(lobj_id, db_session=session)
|
||||
|
||||
migrated_count += 1
|
||||
print(f"✓ Successfully migrated file {i}/{total_files}: {file_id}")
|
||||
|
||||
# See note above – flush but do **not** commit so the outer Alembic transaction
|
||||
# controls atomicity.
|
||||
session.flush()
|
||||
|
||||
print(
|
||||
f"Migration completed: {migrated_count} files staged for commit to external storage."
|
||||
)
|
||||
|
||||
|
||||
def _set_tenant_contextvar(session: Session) -> None:
|
||||
"""Set the tenant contextvar to the default schema"""
|
||||
current_tenant = session.execute(text("SELECT current_schema()")).scalar()
|
||||
print(f"Migrating files for tenant: {current_tenant}")
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(current_tenant)
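# --- Sketch of the contextvar pattern used by _set_tenant_contextvar above. The
# real variable lives in shared_configs.contextvars; this stand-in only shows the
# set/get/reset semantics.
from contextvars import ContextVar

_TENANT: ContextVar[str] = ContextVar("tenant", default="public")

def _run_for_tenant(tenant_id: str) -> None:
    token = _TENANT.set(tenant_id)  # downstream lookups in this context see tenant_id
    try:
        print(f"working in schema {_TENANT.get()}")
    finally:
        _TENANT.reset(token)  # restore the previous tenant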
|
||||
@@ -1,29 +0,0 @@
"""kgentity_parent

Revision ID: cec7ec36c505
Revises: 495cb26ce93e
Create Date: 2025-06-07 20:07:46.400770

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "cec7ec36c505"
down_revision = "495cb26ce93e"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "kg_entity",
        sa.Column("parent_key", sa.String(), nullable=True, index=True),
    )
    # NOTE: you will have to reindex the KG after this migration as the parent_key will be null


def downgrade() -> None:
    op.drop_column("kg_entity", "parent_key")
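# --- Note + sketch (not part of the migration above): passing index=True to the
# Column handed to op.add_column does not normally emit a CREATE INDEX on its own
# in Alembic. If an index on parent_key is wanted, an explicit pair of operations
# would look roughly like this; the index name is a guess following SQLAlchemy's
# default naming convention.
def upgrade_with_index() -> None:
    op.add_column("kg_entity", sa.Column("parent_key", sa.String(), nullable=True))
    op.create_index("ix_kg_entity_parent_key", "kg_entity", ["parent_key"])


def downgrade_with_index() -> None:
    op.drop_index("ix_kg_entity_parent_key", table_name="kg_entity")
    op.drop_column("kg_entity", "parent_key")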
@@ -11,7 +11,7 @@ import sqlalchemy as sa
import json

from onyx.configs.constants import DocumentSource
from onyx.connectors.jira.utils import extract_jira_project
from onyx.connectors.onyx_jira.utils import extract_jira_project


# revision identifiers, used by Alembic.
@@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.schema import SchemaItem

from alembic import context
from onyx.db.engine.sql_engine import build_connection_string
from onyx.db.engine import build_connection_string
from onyx.db.models import PublicBase

# this is the Alembic Config object, which provides
@@ -77,4 +77,3 @@ def downgrade() -> None:
        """
        )
    )
    op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
@@ -4,7 +4,10 @@ from ee.onyx.db.external_perm import fetch_external_groups_for_user
|
||||
from ee.onyx.db.external_perm import fetch_public_external_group_ids
|
||||
from ee.onyx.db.user_group import fetch_user_groups_for_documents
|
||||
from ee.onyx.db.user_group import fetch_user_groups_for_user
|
||||
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
|
||||
from ee.onyx.external_permissions.post_query_censoring import (
|
||||
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
|
||||
)
|
||||
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
||||
from onyx.access.access import (
|
||||
_get_access_for_documents as get_access_for_documents_without_groups,
|
||||
)
|
||||
@@ -15,10 +18,6 @@ from onyx.access.utils import prefix_user_group
|
||||
from onyx.db.document import get_document_sources
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_access_for_document(
|
||||
@@ -71,15 +70,9 @@ def _get_access_for_documents(
|
||||
for document_id, non_ee_access in non_ee_access_dict.items():
|
||||
document = doc_id_map[document_id]
|
||||
source = doc_id_to_source_map.get(document_id)
|
||||
if source is None:
|
||||
logger.error(f"Document {document_id} has no source")
|
||||
continue
|
||||
|
||||
perm_sync_config = get_source_perm_sync_config(source)
|
||||
is_only_censored = (
|
||||
perm_sync_config
|
||||
and perm_sync_config.censoring_config is not None
|
||||
and perm_sync_config.doc_sync_config is None
|
||||
source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
|
||||
and source not in DOC_PERMISSIONS_FUNC_MAP
|
||||
)
|
||||
|
||||
ext_u_emails = (
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.background.task_name_builders import query_history_task_name
|
||||
from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
|
||||
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
|
||||
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
|
||||
@@ -16,10 +18,11 @@ from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import FileType
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.tasks import delete_task_with_id
|
||||
from onyx.db.tasks import mark_task_as_finished_with_id
|
||||
from onyx.db.tasks import mark_task_as_started_with_id
|
||||
from onyx.db.tasks import register_task
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -34,19 +37,13 @@ logger = setup_logger()
|
||||
bind=True,
|
||||
trail=False,
|
||||
)
|
||||
def export_query_history_task(
|
||||
self: Task,
|
||||
*,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
start_time: datetime,
|
||||
# Need to include the tenant_id since the TenantAwareTask needs this
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
def export_query_history_task(self: Task, *, start: datetime, end: datetime) -> None:
|
||||
if not self.request.id:
|
||||
raise RuntimeError("No task id defined for this task; cannot identify it")
|
||||
|
||||
task_id = self.request.id
|
||||
start_time = datetime.now(tz=timezone.utc)
|
||||
|
||||
stream = io.StringIO()
|
||||
writer = csv.DictWriter(
|
||||
stream,
|
||||
@@ -56,9 +53,12 @@ def export_query_history_task(
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
mark_task_as_started_with_id(
|
||||
register_task(
|
||||
db_session=db_session,
|
||||
task_name=query_history_task_name(start=start, end=end),
|
||||
task_id=task_id,
|
||||
status=TaskStatus.STARTED,
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
snapshot_generator = fetch_and_process_chat_session_history(
|
||||
@@ -92,6 +92,7 @@ def export_query_history_task(
|
||||
try:
|
||||
stream.seek(0)
|
||||
get_default_file_store(db_session).save_file(
|
||||
file_name=report_name,
|
||||
content=stream,
|
||||
display_name=report_name,
|
||||
file_origin=FileOrigin.QUERY_HISTORY_CSV,
|
||||
@@ -101,7 +102,6 @@ def export_query_history_task(
|
||||
"end": end.isoformat(),
|
||||
"start_time": start_time.isoformat(),
|
||||
},
|
||||
file_id=report_name,
|
||||
)
|
||||
|
||||
delete_task_with_id(
|
||||
|
||||
@@ -13,7 +13,7 @@ from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.chat import delete_chat_session
|
||||
from onyx.db.chat import get_chat_sessions_older_than
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.tasks import mark_task_as_finished_with_id
|
||||
from onyx.db.tasks import register_task
|
||||
|
||||
@@ -20,36 +20,39 @@ from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
ee_beat_system_tasks: list[dict] = []
|
||||
|
||||
ee_beat_task_templates: list[dict] = [
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
ee_beat_task_templates: list[dict] = []
|
||||
ee_beat_task_templates.extend(
|
||||
[
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-ttl-management",
|
||||
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
"schedule": timedelta(hours=CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
{
|
||||
"name": "check-ttl-management",
|
||||
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
"schedule": timedelta(hours=CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "export-query-history-cleanup-task",
|
||||
"task": OnyxCeleryTask.EXPORT_QUERY_HISTORY_CLEANUP_TASK,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.CSV_GENERATION,
|
||||
{
|
||||
"name": "export-query-history-cleanup-task",
|
||||
"task": OnyxCeleryTask.EXPORT_QUERY_HISTORY_CLEANUP_TASK,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.CSV_GENERATION,
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
ee_tasks_to_schedule: list[dict] = []
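# --- Sketch (assumption): each template above is a plain dict that is eventually
# turned into a Celery beat entry. With a hypothetical task name, the equivalent
# static configuration would look like this.
from datetime import timedelta
from celery import Celery

_app = Celery("onyx")
_app.conf.beat_schedule = {
    "autogenerate-usage-report": {
        "task": "autogenerate_usage_report_task",  # placeholder task name
        "schedule": timedelta(days=30),
        "options": {"priority": 5, "expires": 60 * 60},
    },
}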
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from celery import shared_task
|
||||
from ee.onyx.db.query_history import get_all_query_history_export_tasks
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.tasks import delete_task_with_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -13,7 +13,7 @@ from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
|
||||
@@ -21,16 +21,20 @@ from tenacity import retry_if_exception
|
||||
from tenacity import stop_after_delay
|
||||
from tenacity import wait_random_exponential
|
||||
|
||||
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.db.document import upsert_document_external_perms
|
||||
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
|
||||
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
|
||||
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
||||
from ee.onyx.external_permissions.sync_params import (
|
||||
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
|
||||
)
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
|
||||
@@ -48,8 +52,8 @@ from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.document import get_document_ids_for_connector_credential_pair
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
@@ -74,7 +78,6 @@ from onyx.utils.logger import LoggerContextVars
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -89,24 +92,6 @@ LIGHT_SOFT_TIME_LIMIT = 105
|
||||
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
def _get_fence_validation_block_expiration() -> int:
|
||||
"""
|
||||
Compute the expiration time for the fence validation block signal.
|
||||
Base expiration is 300 seconds, multiplied by the beat multiplier only in MULTI_TENANT mode.
|
||||
"""
|
||||
base_expiration = 300 # seconds
|
||||
|
||||
if not MULTI_TENANT:
|
||||
return base_expiration
|
||||
|
||||
try:
|
||||
beat_multiplier = OnyxRuntime.get_beat_multiplier()
|
||||
except Exception:
|
||||
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
|
||||
return int(base_expiration * beat_multiplier)
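# --- Worked example (illustrative numbers): with the logic above, a single-tenant
# deployment always gets the 300s base, while a multi-tenant deployment with a
# beat multiplier of 4 gets 1200s.
_base_expiration = 300
_beat_multiplier = 4  # hypothetical value from OnyxRuntime.get_beat_multiplier()
assert int(_base_expiration * _beat_multiplier) == 1200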
|
||||
|
||||
|
||||
"""Jobs / utils for kicking off doc permissions sync tasks."""
|
||||
|
||||
|
||||
@@ -120,29 +105,16 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
sync_config = get_source_perm_sync_config(cc_pair.connector.source)
|
||||
if sync_config is None:
|
||||
logger.error(f"No sync config found for {cc_pair.connector.source}")
|
||||
return False
|
||||
|
||||
if sync_config.doc_sync_config is None:
|
||||
logger.error(f"No doc sync config found for {cc_pair.connector.source}")
|
||||
return False
|
||||
|
||||
# if indexing also does perm sync, don't start running doc_sync until at
|
||||
# least one indexing is done
|
||||
if (
|
||||
sync_config.doc_sync_config.initial_index_should_sync
|
||||
and cc_pair.last_successful_index_time is None
|
||||
):
|
||||
return False
|
||||
|
||||
# If the last sync is None, it has never been run so we run the sync
|
||||
last_perm_sync = cc_pair.last_time_perm_sync
|
||||
if last_perm_sync is None:
|
||||
return True
|
||||
|
||||
source_sync_period = sync_config.doc_sync_config.doc_sync_frequency
|
||||
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||
|
||||
if not source_sync_period:
|
||||
source_sync_period = DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
|
||||
source_sync_period *= int(OnyxRuntime.get_doc_permission_sync_multiplier())
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
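# --- Sketch (names illustrative): the scheduling test above reduces to "has the
# sync frequency elapsed since the last successful permission sync".
from datetime import datetime, timedelta, timezone

def _sync_is_due(last_sync: datetime | None, frequency_seconds: int) -> bool:
    if last_sync is None:
        return True  # never synced -> always due
    return datetime.now(timezone.utc) >= last_sync + timedelta(seconds=frequency_seconds)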
|
||||
@@ -214,11 +186,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
"Exception while validating permission sync fences"
|
||||
)
|
||||
|
||||
r.set(
|
||||
OnyxRedisSignals.BLOCK_VALIDATE_PERMISSION_SYNC_FENCES,
|
||||
1,
|
||||
ex=_get_fence_validation_block_expiration(),
|
||||
)
|
||||
r.set(OnyxRedisSignals.BLOCK_VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=300)
|
||||
|
||||
# use a lookup table to find active fences. We still have to verify the fence
|
||||
# exists since it is an optimization and not the source of truth.
|
||||
@@ -449,7 +417,6 @@ def connector_permission_sync_generator_task(
|
||||
created = validate_ccpair_for_user(
|
||||
cc_pair.connector.id,
|
||||
cc_pair.credential.id,
|
||||
cc_pair.access_type,
|
||||
db_session,
|
||||
enforce_creation=False,
|
||||
)
|
||||
@@ -465,15 +432,11 @@ def connector_permission_sync_generator_task(
|
||||
raise
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
sync_config = get_source_perm_sync_config(source_type)
|
||||
if sync_config is None:
|
||||
logger.error(f"No sync config found for {source_type}")
|
||||
return None
|
||||
|
||||
if sync_config.doc_sync_config is None:
|
||||
if sync_config.censoring_config:
|
||||
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if doc_sync_func is None:
|
||||
if source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION:
|
||||
return None
|
||||
|
||||
raise ValueError(
|
||||
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
|
||||
)
|
||||
@@ -505,7 +468,6 @@ def connector_permission_sync_generator_task(
|
||||
credential_id=cc_pair.credential.id,
|
||||
)
|
||||
|
||||
doc_sync_func = sync_config.doc_sync_config.doc_sync_func
|
||||
document_external_accesses = doc_sync_func(
|
||||
cc_pair, fetch_all_existing_docs_fn, callback
|
||||
)
|
||||
@@ -622,6 +584,91 @@ def document_update_permissions(
|
||||
return True
|
||||
|
||||
|
||||
# NOTE(rkuo): Deprecating this due to degenerate behavior in Redis from sending
|
||||
# large permissions through celery (over 1MB in size)
|
||||
# @shared_task(
|
||||
# name=OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
|
||||
# soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
# time_limit=LIGHT_TIME_LIMIT,
|
||||
# max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
|
||||
# bind=True,
|
||||
# )
|
||||
# def update_external_document_permissions_task(
|
||||
# self: Task,
|
||||
# tenant_id: str,
|
||||
# serialized_doc_external_access: dict,
|
||||
# source_string: str,
|
||||
# connector_id: int,
|
||||
# credential_id: int,
|
||||
# ) -> bool:
|
||||
# start = time.monotonic()
|
||||
|
||||
# completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
|
||||
|
||||
# document_external_access = DocExternalAccess.from_dict(
|
||||
# serialized_doc_external_access
|
||||
# )
|
||||
# doc_id = document_external_access.doc_id
|
||||
# external_access = document_external_access.external_access
|
||||
|
||||
# try:
|
||||
# with get_session_with_current_tenant() as db_session:
|
||||
# # Add the users to the DB if they don't exist
|
||||
# batch_add_ext_perm_user_if_not_exists(
|
||||
# db_session=db_session,
|
||||
# emails=list(external_access.external_user_emails),
|
||||
# continue_on_error=True,
|
||||
# )
|
||||
# # Then upsert the document's external permissions
|
||||
# created_new_doc = upsert_document_external_perms(
|
||||
# db_session=db_session,
|
||||
# doc_id=doc_id,
|
||||
# external_access=external_access,
|
||||
# source_type=DocumentSource(source_string),
|
||||
# )
|
||||
|
||||
# if created_new_doc:
|
||||
# # If a new document was created, we associate it with the cc_pair
|
||||
# upsert_document_by_connector_credential_pair(
|
||||
# db_session=db_session,
|
||||
# connector_id=connector_id,
|
||||
# credential_id=credential_id,
|
||||
# document_ids=[doc_id],
|
||||
# )
|
||||
|
||||
# elapsed = time.monotonic() - start
|
||||
# task_logger.info(
|
||||
# f"connector_id={connector_id} "
|
||||
# f"doc={doc_id} "
|
||||
# f"action=update_permissions "
|
||||
# f"elapsed={elapsed:.2f}"
|
||||
# )
|
||||
|
||||
# completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
|
||||
# except Exception as e:
|
||||
# error_msg = format_error_for_logging(e)
|
||||
# task_logger.warning(
|
||||
# f"Exception in update_external_document_permissions_task: connector_id={connector_id} doc_id={doc_id} {error_msg}"
|
||||
# )
|
||||
# task_logger.exception(
|
||||
# f"update_external_document_permissions_task exceptioned: "
|
||||
# f"connector_id={connector_id} doc_id={doc_id}"
|
||||
# )
|
||||
# completion_status = OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
|
||||
# finally:
|
||||
# task_logger.info(
|
||||
# f"update_external_document_permissions_task completed: status={completion_status.value} doc={doc_id}"
|
||||
# )
|
||||
|
||||
# if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
|
||||
# return False
|
||||
|
||||
# task_logger.info(
|
||||
# f"update_external_document_permissions_task finished: connector_id={connector_id} doc_id={doc_id}"
|
||||
# )
|
||||
# return True
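# --- Sketch (not the project's chosen fix): the deprecation note above points at
# >1MB permission payloads being pushed through Celery/Redis. A common workaround
# is to persist the payload out-of-band and enqueue only a small reference key;
# the task name below is hypothetical.
import json
import uuid

def _enqueue_permissions_update_by_ref(redis_client, payload: dict) -> str:
    key = f"perm_update:{uuid.uuid4()}"
    redis_client.set(key, json.dumps(payload), ex=3600)  # worker fetches and deletes it
    # celery_app.send_task("update_permissions_by_ref", args=[key])
    return key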
|
||||
|
||||
|
||||
def validate_permission_sync_fences(
|
||||
tenant_id: str,
|
||||
r: Redis,
|
||||
|
||||
@@ -20,17 +20,15 @@ from ee.onyx.background.celery.tasks.external_group_syncing.group_sync_utils imp
|
||||
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.db.external_perm import mark_old_external_groups_as_stale
|
||||
from ee.onyx.db.external_perm import remove_stale_external_groups
|
||||
from ee.onyx.db.external_perm import upsert_external_groups
|
||||
from ee.onyx.db.external_perm import replace_user__ext_group_for_cc_pair
|
||||
from ee.onyx.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
|
||||
from ee.onyx.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
|
||||
from ee.onyx.external_permissions.sync_params import (
|
||||
get_all_cc_pair_agnostic_group_sync_sources,
|
||||
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
|
||||
)
|
||||
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
|
||||
@@ -42,8 +40,9 @@ from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
@@ -58,34 +57,19 @@ from onyx.redis.redis_connector_ext_group_sync import (
|
||||
)
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.server.utils import make_short_id
|
||||
from onyx.utils.logger import format_error_for_logging
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_EXTERNAL_GROUP_BATCH_SIZE = 100
|
||||
EXTERNAL_GROUPS_UPDATE_MAX_RETRIES = 3
|
||||
|
||||
|
||||
def _get_fence_validation_block_expiration() -> int:
|
||||
"""
|
||||
Compute the expiration time for the fence validation block signal.
|
||||
Base expiration is 300 seconds, multiplied by the beat multiplier only in MULTI_TENANT mode.
|
||||
"""
|
||||
base_expiration = 300 # seconds
|
||||
|
||||
if not MULTI_TENANT:
|
||||
return base_expiration
|
||||
|
||||
try:
|
||||
beat_multiplier = OnyxRuntime.get_beat_multiplier()
|
||||
except Exception:
|
||||
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
|
||||
return int(base_expiration * beat_multiplier)
|
||||
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
|
||||
LIGHT_SOFT_TIME_LIMIT = 105
|
||||
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
@@ -105,20 +89,12 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
)
|
||||
return False
|
||||
|
||||
sync_config = get_source_perm_sync_config(cc_pair.connector.source)
|
||||
if sync_config is None:
|
||||
task_logger.debug(
|
||||
f"Skipping group sync for CC Pair {cc_pair.id} - "
|
||||
f"no sync config found for {cc_pair.connector.source}"
|
||||
)
|
||||
return False
|
||||
|
||||
# If there is not group sync function for the connector, we don't run the sync
|
||||
# This is fine because all sources dont necessarily have a concept of groups
|
||||
if sync_config.group_sync_config is None:
|
||||
if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source):
|
||||
task_logger.debug(
|
||||
f"Skipping group sync for CC Pair {cc_pair.id} - "
|
||||
f"no group sync config found for {cc_pair.connector.source}"
|
||||
f"no group sync function for {cc_pair.connector.source}"
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -127,7 +103,11 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
if last_ext_group_sync is None:
|
||||
return True
|
||||
|
||||
source_sync_period = sync_config.group_sync_config.group_sync_frequency
|
||||
source_sync_period = EXTERNAL_GROUP_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||
|
||||
# If EXTERNAL_GROUP_SYNC_PERIODS is None, we always run the sync.
|
||||
if not source_sync_period:
|
||||
return True
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
next_sync = last_ext_group_sync + timedelta(seconds=source_sync_period)
|
||||
@@ -167,8 +147,9 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
|
||||
# For some sources, we only want to sync one cc_pair per source type
|
||||
for source in get_all_cc_pair_agnostic_group_sync_sources():
|
||||
# We only want to sync one cc_pair per source type in
|
||||
# GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC
|
||||
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
|
||||
# These are ordered by cc_pair id so the first one is the one we want
|
||||
cc_pairs_to_dedupe = get_cc_pairs_by_source(
|
||||
db_session,
|
||||
@@ -176,7 +157,8 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
access_type=AccessType.SYNC,
|
||||
status=ConnectorCredentialPairStatus.ACTIVE,
|
||||
)
|
||||
# dedupe cc_pairs to only keep the first one
|
||||
# We only want to sync one cc_pair per source type
|
||||
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
|
||||
for cc_pair_to_remove in cc_pairs_to_dedupe[1:]:
|
||||
cc_pairs = [
|
||||
cc_pair
|
||||
@@ -215,11 +197,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
"Exception while validating external group sync fences"
|
||||
)
|
||||
|
||||
r.set(
|
||||
OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES,
|
||||
1,
|
||||
ex=_get_fence_validation_block_expiration(),
|
||||
)
|
||||
r.set(OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=300)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -398,12 +376,55 @@ def connector_external_group_sync_generator_task(
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector.external_group_sync.set_fence(payload)
|
||||
|
||||
_perform_external_group_sync(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
eager_load_credential=True,
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
f"No connector credential pair found for id: {cc_pair_id}"
|
||||
)
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
|
||||
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if ext_group_sync_func is None:
|
||||
msg = f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
raise ValueError(msg)
|
||||
|
||||
logger.info(
|
||||
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
external_user_groups: list[ExternalUserGroup] = []
|
||||
try:
|
||||
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
|
||||
except ConnectorValidationError as e:
|
||||
# TODO: add some notification to the admins here
|
||||
logger.exception(
|
||||
f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
|
||||
)
|
||||
raise e
|
||||
|
||||
logger.info(
|
||||
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
|
||||
)
|
||||
logger.debug(f"New external user groups: {external_user_groups}")
|
||||
|
||||
replace_user__ext_group_for_cc_pair(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
group_defs=external_user_groups,
|
||||
source=cc_pair.connector.source,
|
||||
)
|
||||
logger.info(
|
||||
f"Synced {len(external_user_groups)} external user groups for {source_type}"
|
||||
)
|
||||
|
||||
mark_all_relevant_cc_pairs_as_external_group_synced(db_session, cc_pair)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
@@ -445,81 +466,6 @@ def connector_external_group_sync_generator_task(
|
||||
)
|
||||
|
||||
|
||||
def _perform_external_group_sync(
|
||||
cc_pair_id: int,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
eager_load_credential=True,
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
sync_config = get_source_perm_sync_config(source_type)
|
||||
if sync_config is None:
|
||||
msg = f"No sync config found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
raise ValueError(msg)
|
||||
|
||||
if sync_config.group_sync_config is None:
|
||||
msg = f"No group sync config found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
raise ValueError(msg)
|
||||
|
||||
ext_group_sync_func = sync_config.group_sync_config.group_sync_func
|
||||
|
||||
logger.info(
|
||||
f"Marking old external groups as stale for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
mark_old_external_groups_as_stale(db_session, cc_pair_id)
|
||||
|
||||
logger.info(
|
||||
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
external_user_group_batch: list[ExternalUserGroup] = []
|
||||
try:
|
||||
external_user_group_generator = ext_group_sync_func(tenant_id, cc_pair)
|
||||
for external_user_group in external_user_group_generator:
|
||||
external_user_group_batch.append(external_user_group)
|
||||
if len(external_user_group_batch) >= _EXTERNAL_GROUP_BATCH_SIZE:
|
||||
logger.debug(
|
||||
f"New external user groups: {external_user_group_batch}"
|
||||
)
|
||||
upsert_external_groups(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
external_groups=external_user_group_batch,
|
||||
source=cc_pair.connector.source,
|
||||
)
|
||||
external_user_group_batch = []
|
||||
|
||||
if external_user_group_batch:
|
||||
logger.debug(f"New external user groups: {external_user_group_batch}")
|
||||
upsert_external_groups(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
external_groups=external_user_group_batch,
|
||||
source=cc_pair.connector.source,
|
||||
)
|
||||
except Exception as e:
|
||||
# TODO: add some notification to the admins here
|
||||
logger.exception(
|
||||
f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
|
||||
)
|
||||
raise e
|
||||
|
||||
logger.info(
|
||||
f"Removing stale external groups for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
remove_stale_external_groups(db_session, cc_pair_id)
|
||||
|
||||
mark_all_relevant_cc_pairs_as_external_group_synced(db_session, cc_pair)
|
||||
|
||||
|
||||
def validate_external_group_sync_fences(
|
||||
tenant_id: str,
|
||||
celery_app: Celery,
|
||||
|
||||
@@ -19,7 +19,7 @@ from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.models import AvailableTenant
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
@@ -53,16 +53,6 @@ CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
)


#####
# JIRA
#####

# In seconds, default is 30 minutes
JIRA_PERMISSION_DOC_SYNC_FREQUENCY = int(
    os.environ.get("JIRA_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60
)


#####
# Google Drive
#####
@@ -81,15 +71,6 @@ SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)


#####
# Teams
#####
# In seconds, default is 5 minutes
TEAMS_PERMISSION_DOC_SYNC_FREQUENCY = int(
    os.environ.get("TEAMS_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)


####
# Celery Job Frequency
####
@@ -1,28 +0,0 @@
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.interfaces import BaseConnector


def validate_confluence_perm_sync(connector: ConfluenceConnector) -> None:
    """
    Validate that the connector is configured correctly for permissions syncing.
    """


def validate_drive_perm_sync(connector: GoogleDriveConnector) -> None:
    """
    Validate that the connector is configured correctly for permissions syncing.
    """


def validate_perm_sync(connector: BaseConnector) -> None:
    """
    Override this if your connector needs to validate permissions syncing.
    Raise an exception if invalid, otherwise do nothing.

    Default is a no-op (always successful).
    """
    if isinstance(connector, ConfluenceConnector):
        validate_confluence_perm_sync(connector)
    elif isinstance(connector, GoogleDriveConnector):
        validate_drive_perm_sync(connector)
@@ -4,7 +4,6 @@ from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
@@ -63,41 +62,20 @@ def delete_public_external_group_for_cc_pair__no_commit(
|
||||
)
|
||||
|
||||
|
||||
def mark_old_external_groups_as_stale(
|
||||
def replace_user__ext_group_for_cc_pair(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
) -> None:
|
||||
db_session.execute(
|
||||
update(User__ExternalUserGroupId)
|
||||
.where(User__ExternalUserGroupId.cc_pair_id == cc_pair_id)
|
||||
.values(stale=True)
|
||||
)
|
||||
db_session.execute(
|
||||
update(PublicExternalUserGroup)
|
||||
.where(PublicExternalUserGroup.cc_pair_id == cc_pair_id)
|
||||
.values(stale=True)
|
||||
)
|
||||
|
||||
|
||||
def upsert_external_groups(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
external_groups: list[ExternalUserGroup],
|
||||
group_defs: list[ExternalUserGroup],
|
||||
source: DocumentSource,
|
||||
) -> None:
|
||||
"""
|
||||
Performs a true upsert operation for external user groups:
|
||||
- For existing groups (same user_id, external_user_group_id, cc_pair_id), updates the stale flag to False
|
||||
- For new groups, inserts them with stale=False
|
||||
- For public groups, uses upsert logic as well
|
||||
This function clears all existing external user group relations for a given cc_pair_id
|
||||
and replaces them with the new group definitions and commits the changes.
|
||||
"""
|
||||
# If there are no groups to add, return early
|
||||
if not external_groups:
|
||||
return
|
||||
|
||||
# collect all emails from all groups to batch add all users at once for efficiency
|
||||
all_group_member_emails = set()
|
||||
for external_group in external_groups:
|
||||
for external_group in group_defs:
|
||||
for user_email in external_group.user_emails:
|
||||
all_group_member_emails.add(user_email)
|
||||
|
||||
@@ -108,17 +86,26 @@ def upsert_external_groups(
|
||||
emails=list(all_group_member_emails),
|
||||
)
|
||||
|
||||
# map emails to ids
|
||||
email_id_map = {user.email.lower(): user.id for user in all_group_members}
|
||||
delete_user__ext_group_for_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
delete_public_external_group_for_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# Process each external group
|
||||
for external_group in external_groups:
|
||||
# map emails to ids
|
||||
email_id_map = {user.email: user.id for user in all_group_members}
|
||||
|
||||
# use these ids to create new external user group relations relating group_id to user_ids
|
||||
new_external_permissions: list[User__ExternalUserGroupId] = []
|
||||
new_public_external_groups: list[PublicExternalUserGroup] = []
|
||||
for external_group in group_defs:
|
||||
external_group_id = build_ext_group_name_for_onyx(
|
||||
ext_group_name=external_group.id,
|
||||
source=source,
|
||||
)
|
||||
|
||||
# Handle user-group mappings
|
||||
for user_email in external_group.user_emails:
|
||||
user_id = email_id_map.get(user_email.lower())
|
||||
if user_id is None:
|
||||
@@ -127,71 +114,24 @@ def upsert_external_groups(
|
||||
f" with email {user_email} not found"
|
||||
)
|
||||
continue
|
||||
|
||||
# Check if the user-group mapping already exists
|
||||
existing_user_group = db_session.scalar(
|
||||
select(User__ExternalUserGroupId).where(
|
||||
User__ExternalUserGroupId.user_id == user_id,
|
||||
User__ExternalUserGroupId.external_user_group_id
|
||||
== external_group_id,
|
||||
User__ExternalUserGroupId.cc_pair_id == cc_pair_id,
|
||||
)
|
||||
)
|
||||
|
||||
if existing_user_group:
|
||||
# Update existing record
|
||||
existing_user_group.stale = False
|
||||
else:
|
||||
# Insert new record
|
||||
new_user_group = User__ExternalUserGroupId(
|
||||
new_external_permissions.append(
|
||||
User__ExternalUserGroupId(
|
||||
user_id=user_id,
|
||||
external_user_group_id=external_group_id,
|
||||
cc_pair_id=cc_pair_id,
|
||||
stale=False,
|
||||
)
|
||||
db_session.add(new_user_group)
|
||||
|
||||
# Handle public group if needed
|
||||
if external_group.gives_anyone_access:
|
||||
# Check if the public group already exists
|
||||
existing_public_group = db_session.scalar(
|
||||
select(PublicExternalUserGroup).where(
|
||||
PublicExternalUserGroup.external_user_group_id == external_group_id,
|
||||
PublicExternalUserGroup.cc_pair_id == cc_pair_id,
|
||||
)
|
||||
)
|
||||
|
||||
if existing_public_group:
|
||||
# Update existing record
|
||||
existing_public_group.stale = False
|
||||
else:
|
||||
# Insert new record
|
||||
new_public_group = PublicExternalUserGroup(
|
||||
if external_group.gives_anyone_access:
|
||||
new_public_external_groups.append(
|
||||
PublicExternalUserGroup(
|
||||
external_user_group_id=external_group_id,
|
||||
cc_pair_id=cc_pair_id,
|
||||
stale=False,
|
||||
)
|
||||
db_session.add(new_public_group)
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def remove_stale_external_groups(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
) -> None:
|
||||
db_session.execute(
|
||||
delete(User__ExternalUserGroupId).where(
|
||||
User__ExternalUserGroupId.cc_pair_id == cc_pair_id,
|
||||
User__ExternalUserGroupId.stale.is_(True),
|
||||
)
|
||||
)
|
||||
db_session.execute(
|
||||
delete(PublicExternalUserGroup).where(
|
||||
PublicExternalUserGroup.cc_pair_id == cc_pair_id,
|
||||
PublicExternalUserGroup.stale.is_(True),
|
||||
)
|
||||
)
|
||||
db_session.add_all(new_external_permissions)
|
||||
db_session.add_all(new_public_external_groups)
|
||||
db_session.commit()
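# --- Sketch: the helpers above compose into a mark-stale / upsert / sweep cycle,
# mirroring how _perform_external_group_sync drives them. Anything not re-upserted
# during this pass is removed as stale at the end.
def sync_external_groups_for_cc_pair(
    db_session: Session,
    cc_pair_id: int,
    groups: list[ExternalUserGroup],
    source: DocumentSource,
) -> None:
    mark_old_external_groups_as_stale(db_session, cc_pair_id)
    upsert_external_groups(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
        external_groups=groups,
        source=source,
    )
    remove_stale_external_groups(db_session, cc_pair_id)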
|
||||
|
||||
|
||||
|
||||
@@ -115,24 +115,11 @@ def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]:
|
||||
|
||||
def get_usage_report_data(
|
||||
db_session: Session,
|
||||
report_display_name: str,
|
||||
report_name: str,
|
||||
) -> IO:
|
||||
"""
|
||||
Get the usage report data from the file store.
|
||||
|
||||
Args:
|
||||
db_session: The database session.
|
||||
report_display_name: The display name of the usage report. Also assumes
|
||||
that the file is stored with this as the ID in the file store.
|
||||
|
||||
Returns:
|
||||
The usage report data.
|
||||
"""
|
||||
file_store = get_default_file_store(db_session)
|
||||
# usage report may be very large, so don't load it all into memory
|
||||
return file_store.read_file(
|
||||
file_id=report_display_name, mode="b", use_tempfile=True
|
||||
)
|
||||
return file_store.read_file(file_name=report_name, mode="b", use_tempfile=True)
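# --- Sketch (usage only): because a usage report can be large, read_file is asked
# for a temp-file-backed handle; callers can then stream it instead of loading it
# all at once. Chunk size is arbitrary.
def _stream_report_to_disk(report_io, out_path: str, chunk_size: int = 1 << 20) -> None:
    with open(out_path, "wb") as out:
        while chunk := report_io.read(chunk_size):
            out.write(chunk)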
|
||||
|
||||
|
||||
def write_usage_report(
|
||||
|
||||
@@ -2,6 +2,3 @@
# Instead of setting a page to public, we just add this group so that the page
# is only accessible to users who have confluence accounts.
ALL_CONF_EMAILS_GROUP_NAME = "All_Confluence_Users_Found_By_Onyx"

VIEWSPACE_PERMISSION_TYPE = "VIEWSPACE"
REQUEST_PAGINATION_LIMIT = 5000
@@ -4,13 +4,20 @@ https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.htm
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.utils import generic_doc_sync
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
get_user_email_from_username__server,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -18,8 +25,369 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_VIEWSPACE_PERMISSION_TYPE = "VIEWSPACE"
|
||||
_REQUEST_PAGINATION_LIMIT = 5000
|
||||
|
||||
CONFLUENCE_DOC_SYNC_LABEL = "confluence_doc_sync"
|
||||
|
||||
def _get_server_space_permissions(
|
||||
confluence_client: OnyxConfluence, space_key: str
|
||||
) -> ExternalAccess:
|
||||
space_permissions = confluence_client.get_all_space_permissions_server(
|
||||
space_key=space_key
|
||||
)
|
||||
|
||||
viewspace_permissions = []
|
||||
for permission_category in space_permissions:
|
||||
if permission_category.get("type") == _VIEWSPACE_PERMISSION_TYPE:
|
||||
viewspace_permissions.extend(
|
||||
permission_category.get("spacePermissions", [])
|
||||
)
|
||||
|
||||
is_public = False
|
||||
user_names = set()
|
||||
group_names = set()
|
||||
for permission in viewspace_permissions:
|
||||
user_name = permission.get("userName")
|
||||
if user_name:
|
||||
user_names.add(user_name)
|
||||
group_name = permission.get("groupName")
|
||||
if group_name:
|
||||
group_names.add(group_name)
|
||||
|
||||
# It seems that if anonymous access is turned on for the site and space,
|
||||
# then the space is publicly accessible.
|
||||
# For confluence server, we make a group that contains all users
|
||||
# that exist in confluence and then just add that group to the space permissions
|
||||
# if anonymous access is turned on for the site and space or we set is_public = True
|
||||
# if they set the env variable CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC to True so
|
||||
# that we can support confluence server deployments that want anonymous access
|
||||
# to be public (we cant test this because its paywalled)
|
||||
if user_name is None and group_name is None:
|
||||
# Defaults to False
|
||||
if CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC:
|
||||
is_public = True
|
||||
else:
|
||||
group_names.add(ALL_CONF_EMAILS_GROUP_NAME)
|
||||
|
||||
user_emails = set()
|
||||
for user_name in user_names:
|
||||
user_email = get_user_email_from_username__server(confluence_client, user_name)
|
||||
if user_email:
|
||||
user_emails.add(user_email)
|
||||
else:
|
||||
logger.warning(f"Email for user {user_name} not found in Confluence")
|
||||
|
||||
if not user_emails and not group_names:
|
||||
logger.warning(
|
||||
"No user emails or group names found in Confluence space permissions"
|
||||
f"\nSpace key: {space_key}"
|
||||
f"\nSpace permissions: {space_permissions}"
|
||||
)
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
external_user_group_ids=group_names,
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
|
||||
def _get_cloud_space_permissions(
|
||||
confluence_client: OnyxConfluence, space_key: str
|
||||
) -> ExternalAccess:
|
||||
space_permissions_result = confluence_client.get_space(
|
||||
space_key=space_key, expand="permissions"
|
||||
)
|
||||
space_permissions = space_permissions_result.get("permissions", [])
|
||||
|
||||
user_emails = set()
|
||||
group_names = set()
|
||||
is_externally_public = False
|
||||
for permission in space_permissions:
|
||||
subs = permission.get("subjects")
|
||||
if subs:
|
||||
# If there are subjects, then there are explicit users or groups with access
|
||||
if email := subs.get("user", {}).get("results", [{}])[0].get("email"):
|
||||
user_emails.add(email)
|
||||
if group_name := subs.get("group", {}).get("results", [{}])[0].get("name"):
|
||||
group_names.add(group_name)
|
||||
else:
|
||||
# If there are no subjects, then the permission is for everyone
|
||||
if permission.get("operation", {}).get(
|
||||
"operation"
|
||||
) == "read" and permission.get("anonymousAccess", False):
|
||||
# If the permission specifies read access for anonymous users, then
|
||||
# the space is publicly accessible
|
||||
is_externally_public = True
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
external_user_group_ids=group_names,
|
||||
is_public=is_externally_public,
|
||||
)
|
||||
|
||||
|
||||
def _get_space_permissions(
    confluence_client: OnyxConfluence,
    is_cloud: bool,
) -> dict[str, ExternalAccess]:
    logger.debug("Getting space permissions")
    # Gets all the spaces in the Confluence instance
    all_space_keys = []
    start = 0
    while True:
        spaces_batch = confluence_client.get_all_spaces(
            start=start, limit=_REQUEST_PAGINATION_LIMIT
        )
        for space in spaces_batch.get("results", []):
            all_space_keys.append(space.get("key"))

        if len(spaces_batch.get("results", [])) < _REQUEST_PAGINATION_LIMIT:
            break

        start += len(spaces_batch.get("results", []))

    # Gets the permissions for each space
    logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
    space_permissions_by_space_key: dict[str, ExternalAccess] = {}
    for space_key in all_space_keys:
        if is_cloud:
            space_permissions = _get_cloud_space_permissions(
                confluence_client=confluence_client, space_key=space_key
            )
        else:
            space_permissions = _get_server_space_permissions(
                confluence_client=confluence_client, space_key=space_key
            )

        # Stores the permissions for each space
        space_permissions_by_space_key[space_key] = space_permissions
        if (
            not space_permissions.is_public
            and not space_permissions.external_user_emails
            and not space_permissions.external_user_group_ids
        ):
            logger.warning(
                f"No permissions found for space '{space_key}'. This is very unlikely "
                "to be correct and is more likely caused by an access token with "
                "insufficient permissions. Make sure that the access token has Admin "
                f"permissions for space '{space_key}'"
            )

    return space_permissions_by_space_key

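The mapping returned above is keyed by space key; a hedged example of its shape, with illustrative space keys and values only:

```
# Illustrative only.
space_permissions_by_space_key = {
    "ENG": ExternalAccess(
        external_user_emails={"jdoe@example.com"},
        external_user_group_ids={"engineering"},
        is_public=False,
    ),
    "DOCS": ExternalAccess(
        external_user_emails=set(),
        external_user_group_ids=set(),
        is_public=True,  # e.g. anonymous read access enabled for the space
    ),
}
```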
def _extract_read_access_restrictions(
    confluence_client: OnyxConfluence, restrictions: dict[str, Any]
) -> tuple[set[str], set[str], bool]:
    """
    Extracts the read restrictions from a page's restrictions dict.

    Returns a tuple of (user emails with read access, group names with read access,
    whether any read restriction was found at all).
    """
    read_access = restrictions.get("read", {})
    read_access_restrictions = read_access.get("restrictions", {})

    # Extract the users with read access
    read_access_user = read_access_restrictions.get("user", {})
    read_access_user_jsons = read_access_user.get("results", [])
    # any items found means that there is a restriction
    found_any_restriction = bool(read_access_user_jsons)

    read_access_user_emails = []
    for user in read_access_user_jsons:
        # If the user has an email, then add it to the list
        if user.get("email"):
            read_access_user_emails.append(user["email"])
        # If the user has a username and not an email, then get the email from Confluence
        elif user.get("username"):
            email = get_user_email_from_username__server(
                confluence_client=confluence_client, user_name=user["username"]
            )
            if email:
                read_access_user_emails.append(email)
            else:
                logger.warning(
                    f"Email for user {user['username']} not found in Confluence"
                )
        else:
            if user.get("email") is not None:
                logger.warning(f"Can't find email for user {user.get('displayName')}")
                logger.warning(
                    "This user needs to make their email accessible in Confluence Settings"
                )

            logger.warning(f"No user email or username for {user}")

    # Extract the groups with read access
    read_access_group = read_access_restrictions.get("group", {})
    read_access_group_jsons = read_access_group.get("results", [])
    # any items found means that there is a restriction
    found_any_restriction |= bool(read_access_group_jsons)
    read_access_group_names = [
        group["name"] for group in read_access_group_jsons if group.get("name")
    ]

    return (
        set(read_access_user_emails),
        set(read_access_group_names),
        found_any_restriction,
    )

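A hedged example of the read-restrictions payload the extractor above handles and the tuple it returns; the email and group name are illustrative.

```
# Illustrative only: mirrors the "read" -> "restrictions" -> "user"/"group"
# structure read above.
restrictions = {
    "read": {
        "restrictions": {
            "user": {"results": [{"email": "jdoe@example.com"}]},
            "group": {"results": [{"name": "engineering"}]},
        }
    }
}
# _extract_read_access_restrictions(confluence_client, restrictions) would return roughly:
#   ({"jdoe@example.com"}, {"engineering"}, True)
# An empty payload ({}) returns (set(), set(), False), i.e. "no restriction found".
```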
def _get_all_page_restrictions(
    confluence_client: OnyxConfluence,
    perm_sync_data: dict[str, Any],
) -> ExternalAccess | None:
    """
    This function gets the restrictions for a page. In Confluence, a child can have
    at MOST the same level of accessibility as its immediate parent.

    If no restrictions are found anywhere, then return None, indicating that the page
    should inherit the space's restrictions.
    """
    found_user_emails: set[str] = set()
    found_group_names: set[str] = set()

    # NOTE: we need found_any_restriction, since we can find restrictions
    # but not be able to extract any user emails or group names.
    # In that case, we should just give no access.
    found_user_emails, found_group_names, found_any_page_level_restriction = (
        _extract_read_access_restrictions(
            confluence_client=confluence_client,
            restrictions=perm_sync_data.get("restrictions", {}),
        )
    )
    # if there are individual page-level restrictions, then this is the accurate
    # restriction for the page. You cannot both have page-level restrictions AND
    # inherit restrictions from the parent.
    if found_any_page_level_restriction:
        return ExternalAccess(
            external_user_emails=found_user_emails,
            external_user_group_ids=found_group_names,
            is_public=False,
        )

    ancestors: list[dict[str, Any]] = perm_sync_data.get("ancestors", [])
    # ancestors seem to be in order from root to immediate parent
    # https://community.atlassian.com/forums/Confluence-questions/Order-of-ancestors-in-REST-API-response-Confluence-Server-amp/qaq-p/2385981
    # we want the restrictions from the immediate parent to take precedence, so we
    # reverse the list
    for ancestor in reversed(ancestors):
        (
            ancestor_user_emails,
            ancestor_group_names,
            found_any_restrictions_in_ancestor,
        ) = _extract_read_access_restrictions(
            confluence_client=confluence_client,
            restrictions=ancestor.get("restrictions", {}),
        )
        if found_any_restrictions_in_ancestor:
            # if inheriting restrictions from the parent, then the first one we run into
            # should be applied (the reason why we'd traverse more than one ancestor is if
            # the ancestor also is in "inherit" mode.)
            logger.info(
                f"Found user restrictions {ancestor_user_emails} and group restrictions "
                f"{ancestor_group_names} for document {perm_sync_data.get('id')} "
                f"based on ancestor {ancestor}"
            )
            return ExternalAccess(
                external_user_emails=ancestor_user_emails,
                external_user_group_ids=ancestor_group_names,
                is_public=False,
            )

    # we didn't find any restrictions, so the page inherits the space's restrictions
    return None

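To make the inheritance rule above concrete, a hedged sketch of a perm_sync_data payload and which ancestor wins; the page id and group name are illustrative assumptions.

```
# Illustrative only: no page-level restrictions, so the ancestors are walked in
# reverse (immediate parent first) and the first ancestor with any restriction wins.
perm_sync_data = {
    "id": "12345",
    "restrictions": {},  # nothing page-level -> fall through to ancestors
    "ancestors": [
        {"restrictions": {}},  # root: unrestricted, checked last
        {
            # immediate parent: restricted, checked first -> its groups are applied
            "restrictions": {
                "read": {
                    "restrictions": {"group": {"results": [{"name": "engineering"}]}}
                }
            }
        },
    ],
}
```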
def _fetch_all_page_restrictions(
    confluence_client: OnyxConfluence,
    slim_docs: list[SlimDocument],
    space_permissions_by_space_key: dict[str, ExternalAccess],
    is_cloud: bool,
    callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
    """
    For all pages, if a page has restrictions, then use those restrictions.
    Otherwise, use the space's restrictions.
    """
    for slim_doc in slim_docs:
        if callback:
            if callback.should_stop():
                raise RuntimeError("confluence_doc_sync: Stop signal detected")

            callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)

        if slim_doc.perm_sync_data is None:
            raise ValueError(
                f"No permission sync data found for document {slim_doc.id}"
            )

        if restrictions := _get_all_page_restrictions(
            confluence_client=confluence_client,
            perm_sync_data=slim_doc.perm_sync_data,
        ):
            logger.info(f"Found restrictions {restrictions} for document {slim_doc.id}")
            yield DocExternalAccess(
                doc_id=slim_doc.id,
                external_access=restrictions,
            )
            # If there are restrictions, then we don't need to use the space's restrictions
            continue

        space_key = slim_doc.perm_sync_data.get("space_key")
        if not (space_permissions := space_permissions_by_space_key.get(space_key)):
            logger.warning(
                f"Individually fetching space permissions for space {space_key}. This is "
                "unexpected. It means the permissions were not able to be fetched initially."
            )
            try:
                # If the space permissions are not in the cache, then fetch them
                if is_cloud:
                    retrieved_space_permissions = _get_cloud_space_permissions(
                        confluence_client=confluence_client, space_key=space_key
                    )
                else:
                    retrieved_space_permissions = _get_server_space_permissions(
                        confluence_client=confluence_client, space_key=space_key
                    )
                space_permissions_by_space_key[space_key] = retrieved_space_permissions
                space_permissions = retrieved_space_permissions
            except Exception as e:
                logger.warning(
                    f"Error fetching space permissions for space {space_key}: {e}"
                )

        if not space_permissions:
            logger.warning(
                f"No permissions found for document {slim_doc.id} in space {space_key}"
            )
            # be safe, if we can't get the permissions then make the document inaccessible
            yield DocExternalAccess(
                doc_id=slim_doc.id,
                external_access=ExternalAccess(
                    external_user_emails=set(),
                    external_user_group_ids=set(),
                    is_public=False,
                ),
            )
            continue

        # If there are no restrictions, then use the space's restrictions
        yield DocExternalAccess(
            doc_id=slim_doc.id,
            external_access=space_permissions,
        )
        if (
            not space_permissions.is_public
            and not space_permissions.external_user_emails
            and not space_permissions.external_user_group_ids
        ):
            logger.warning(
                f"Permissions are empty for document: {slim_doc.id}\n"
                "This means space permissions may be wrong for"
                f" Space key: {space_key}"
            )

    logger.info("Finished fetching all page restrictions")

def confluence_doc_sync(
|
||||
@@ -32,6 +400,7 @@ def confluence_doc_sync(
|
||||
Compares fetched documents against existing documents in the DB for the connector.
|
||||
If a document exists in the DB but not in the Confluence fetch, it's marked as restricted.
|
||||
"""
|
||||
logger.info(f"Starting confluence doc sync for CC Pair ID: {cc_pair.id}")
|
||||
confluence_connector = ConfluenceConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
@@ -41,11 +410,62 @@ def confluence_doc_sync(
|
||||
)
|
||||
confluence_connector.set_credentials_provider(provider)
|
||||
|
||||
yield from generic_doc_sync(
|
||||
cc_pair=cc_pair,
|
||||
fetch_all_existing_docs_fn=fetch_all_existing_docs_fn,
|
||||
callback=callback,
|
||||
doc_source=DocumentSource.CONFLUENCE,
|
||||
slim_connector=confluence_connector,
|
||||
label=CONFLUENCE_DOC_SYNC_LABEL,
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
|
||||
space_permissions_by_space_key = _get_space_permissions(
|
||||
confluence_client=confluence_connector.confluence_client,
|
||||
is_cloud=is_cloud,
|
||||
)
|
||||
logger.info("Space permissions by space key:")
|
||||
for space_key, space_permissions in space_permissions_by_space_key.items():
|
||||
logger.info(f"Space key: {space_key}, Permissions: {space_permissions}")
|
||||
|
||||
slim_docs: list[SlimDocument] = []
|
||||
logger.info("Fetching all slim documents from confluence")
|
||||
for doc_batch in confluence_connector.retrieve_all_slim_documents(
|
||||
callback=callback
|
||||
):
|
||||
logger.info(f"Got {len(doc_batch)} slim documents from confluence")
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("confluence_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("confluence_doc_sync", 1)
|
||||
|
||||
slim_docs.extend(doc_batch)
|
||||
|
||||
# Find documents that are no longer accessible in Confluence
|
||||
logger.info(f"Querying existing document IDs for CC Pair ID: {cc_pair.id}")
|
||||
existing_doc_ids = fetch_all_existing_docs_fn()
|
||||
|
||||
# Find missing doc IDs
|
||||
fetched_doc_ids = {doc.id for doc in slim_docs}
|
||||
missing_doc_ids = set(existing_doc_ids) - fetched_doc_ids
|
||||
|
||||
# Yield access removal for missing docs. Better to be safe.
|
||||
if missing_doc_ids:
|
||||
logger.warning(
|
||||
f"Found {len(missing_doc_ids)} documents that are in the DB but "
|
||||
"not present in Confluence fetch. Making them inaccessible."
|
||||
)
|
||||
for missing_id in missing_doc_ids:
|
||||
logger.warning(f"Removing access for document ID: {missing_id}")
|
||||
yield DocExternalAccess(
|
||||
doc_id=missing_id,
|
||||
external_access=ExternalAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
),
|
||||
)
|
||||
|
||||
logger.info("Fetching all page restrictions for fetched documents")
|
||||
yield from _fetch_all_page_restrictions(
|
||||
confluence_client=confluence_connector.confluence_client,
|
||||
slim_docs=slim_docs,
|
||||
space_permissions_by_space_key=space_permissions_by_space_key,
|
||||
is_cloud=is_cloud,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
logger.info("Finished confluence doc sync")
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
@@ -67,7 +65,7 @@ def _build_group_member_email_map(
|
||||
def confluence_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> Generator[ExternalUserGroup, None, None]:
|
||||
) -> list[ExternalUserGroup]:
|
||||
provider = OnyxDBCredentialsProvider(tenant_id, "confluence", cc_pair.credential_id)
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
wiki_base: str = cc_pair.connector.connector_specific_config["wiki_base"]
|
||||
@@ -91,10 +89,10 @@ def confluence_group_sync(
|
||||
confluence_client=confluence_client,
|
||||
cc_pair_id=cc_pair.id,
|
||||
)
|
||||
|
||||
onyx_groups: list[ExternalUserGroup] = []
|
||||
all_found_emails = set()
|
||||
for group_id, group_member_emails in group_member_email_map.items():
|
||||
yield (
|
||||
onyx_groups.append(
|
||||
ExternalUserGroup(
|
||||
id=group_id,
|
||||
user_emails=list(group_member_emails),
|
||||
@@ -109,4 +107,6 @@ def confluence_group_sync(
|
||||
id=ALL_CONF_EMAILS_GROUP_NAME,
|
||||
user_emails=list(all_found_emails),
|
||||
)
|
||||
yield all_found_group
|
||||
onyx_groups.append(all_found_group)
|
||||
|
||||
return onyx_groups
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
get_user_email_from_username__server,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_read_access_restrictions(
|
||||
confluence_client: OnyxConfluence, restrictions: dict[str, Any]
|
||||
) -> tuple[set[str], set[str], bool]:
|
||||
"""
|
||||
Converts a page's restrictions dict into an ExternalAccess object.
|
||||
If there are no restrictions, then return None
|
||||
"""
|
||||
read_access = restrictions.get("read", {})
|
||||
read_access_restrictions = read_access.get("restrictions", {})
|
||||
|
||||
# Extract the users with read access
|
||||
read_access_user = read_access_restrictions.get("user", {})
|
||||
read_access_user_jsons = read_access_user.get("results", [])
|
||||
# any items found means that there is a restriction
|
||||
found_any_restriction = bool(read_access_user_jsons)
|
||||
|
||||
read_access_user_emails = []
|
||||
for user in read_access_user_jsons:
|
||||
# If the user has an email, then add it to the list
|
||||
if user.get("email"):
|
||||
read_access_user_emails.append(user["email"])
|
||||
# If the user has a username and not an email, then get the email from Confluence
|
||||
elif user.get("username"):
|
||||
email = get_user_email_from_username__server(
|
||||
confluence_client=confluence_client, user_name=user["username"]
|
||||
)
|
||||
if email:
|
||||
read_access_user_emails.append(email)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Email for user {user['username']} not found in Confluence"
|
||||
)
|
||||
else:
|
||||
if user.get("email") is not None:
|
||||
logger.warning(f"Cant find email for user {user.get('displayName')}")
|
||||
logger.warning(
|
||||
"This user needs to make their email accessible in Confluence Settings"
|
||||
)
|
||||
|
||||
logger.warning(f"no user email or username for {user}")
|
||||
|
||||
# Extract the groups with read access
|
||||
read_access_group = read_access_restrictions.get("group", {})
|
||||
read_access_group_jsons = read_access_group.get("results", [])
|
||||
# any items found means that there is a restriction
|
||||
found_any_restriction |= bool(read_access_group_jsons)
|
||||
read_access_group_names = [
|
||||
group["name"] for group in read_access_group_jsons if group.get("name")
|
||||
]
|
||||
|
||||
return (
|
||||
set(read_access_user_emails),
|
||||
set(read_access_group_names),
|
||||
found_any_restriction,
|
||||
)
|
||||
|
||||
|
||||
def get_page_restrictions(
|
||||
confluence_client: OnyxConfluence,
|
||||
page_id: str,
|
||||
page_restrictions: dict[str, Any],
|
||||
ancestors: list[dict[str, Any]],
|
||||
) -> ExternalAccess | None:
|
||||
"""
|
||||
This function gets the restrictions for a page. In Confluence, a child can have
|
||||
at MOST the same level accessibility as its immediate parent.
|
||||
|
||||
If no restrictions are found anywhere, then return None, indicating that the page
|
||||
should inherit the space's restrictions.
|
||||
"""
|
||||
found_user_emails: set[str] = set()
|
||||
found_group_names: set[str] = set()
|
||||
|
||||
# NOTE: need the found_any_restriction, since we can find restrictions
|
||||
# but not be able to extract any user emails or group names
|
||||
# in this case, we should just give no access
|
||||
found_user_emails, found_group_names, found_any_page_level_restriction = (
|
||||
_extract_read_access_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
restrictions=page_restrictions,
|
||||
)
|
||||
)
|
||||
# if there are individual page-level restrictions, then this is the accurate
|
||||
# restriction for the page. You cannot both have page-level restrictions AND
|
||||
# inherit restrictions from the parent.
|
||||
if found_any_page_level_restriction:
|
||||
return ExternalAccess(
|
||||
external_user_emails=found_user_emails,
|
||||
external_user_group_ids=found_group_names,
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
# ancestors seem to be in order from root to immediate parent
|
||||
# https://community.atlassian.com/forums/Confluence-questions/Order-of-ancestors-in-REST-API-response-Confluence-Server-amp/qaq-p/2385981
|
||||
# we want the restrictions from the immediate parent to take precedence, so we should
|
||||
# reverse the list
|
||||
for ancestor in reversed(ancestors):
|
||||
(
|
||||
ancestor_user_emails,
|
||||
ancestor_group_names,
|
||||
found_any_restrictions_in_ancestor,
|
||||
) = _extract_read_access_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
restrictions=ancestor.get("restrictions", {}),
|
||||
)
|
||||
if found_any_restrictions_in_ancestor:
|
||||
# if inheriting restrictions from the parent, then the first one we run into
|
||||
# should be applied (the reason why we'd traverse more than one ancestor is if
|
||||
# the ancestor also is in "inherit" mode.)
|
||||
logger.debug(
|
||||
f"Found user restrictions {ancestor_user_emails} and group restrictions {ancestor_group_names}"
|
||||
f"for document {page_id} based on ancestor {ancestor}"
|
||||
)
|
||||
return ExternalAccess(
|
||||
external_user_emails=ancestor_user_emails,
|
||||
external_user_group_ids=ancestor_group_names,
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
# we didn't find any restrictions, so the page inherits the space's restrictions
|
||||
return None
|
||||
@@ -1,165 +0,0 @@
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from ee.onyx.external_permissions.confluence.constants import REQUEST_PAGINATION_LIMIT
|
||||
from ee.onyx.external_permissions.confluence.constants import VIEWSPACE_PERMISSION_TYPE
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
get_user_email_from_username__server,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_server_space_permissions(
|
||||
confluence_client: OnyxConfluence, space_key: str
|
||||
) -> ExternalAccess:
|
||||
space_permissions = confluence_client.get_all_space_permissions_server(
|
||||
space_key=space_key
|
||||
)
|
||||
|
||||
viewspace_permissions = []
|
||||
for permission_category in space_permissions:
|
||||
if permission_category.get("type") == VIEWSPACE_PERMISSION_TYPE:
|
||||
viewspace_permissions.extend(
|
||||
permission_category.get("spacePermissions", [])
|
||||
)
|
||||
|
||||
is_public = False
|
||||
user_names = set()
|
||||
group_names = set()
|
||||
for permission in viewspace_permissions:
|
||||
if user_name := permission.get("userName"):
|
||||
user_names.add(user_name)
|
||||
if group_name := permission.get("groupName"):
|
||||
group_names.add(group_name)
|
||||
|
||||
# It seems that if anonymous access is turned on for the site and space,
|
||||
# then the space is publicly accessible.
|
||||
# For confluence server, we make a group that contains all users
|
||||
# that exist in confluence and then just add that group to the space permissions
|
||||
# if anonymous access is turned on for the site and space or we set is_public = True
|
||||
# if they set the env variable CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC to True so
|
||||
# that we can support confluence server deployments that want anonymous access
|
||||
# to be public (we cant test this because its paywalled)
|
||||
if user_name is None and group_name is None:
|
||||
# Defaults to False
|
||||
if CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC:
|
||||
is_public = True
|
||||
else:
|
||||
group_names.add(ALL_CONF_EMAILS_GROUP_NAME)
|
||||
|
||||
user_emails = set()
|
||||
for user_name in user_names:
|
||||
user_email = get_user_email_from_username__server(confluence_client, user_name)
|
||||
if user_email:
|
||||
user_emails.add(user_email)
|
||||
else:
|
||||
logger.warning(f"Email for user {user_name} not found in Confluence")
|
||||
|
||||
if not user_emails and not group_names:
|
||||
logger.warning(
|
||||
"No user emails or group names found in Confluence space permissions"
|
||||
f"\nSpace key: {space_key}"
|
||||
f"\nSpace permissions: {space_permissions}"
|
||||
)
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
external_user_group_ids=group_names,
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
|
||||
def _get_cloud_space_permissions(
|
||||
confluence_client: OnyxConfluence, space_key: str
|
||||
) -> ExternalAccess:
|
||||
space_permissions_result = confluence_client.get_space(
|
||||
space_key=space_key, expand="permissions"
|
||||
)
|
||||
space_permissions = space_permissions_result.get("permissions", [])
|
||||
|
||||
user_emails = set()
|
||||
group_names = set()
|
||||
is_externally_public = False
|
||||
for permission in space_permissions:
|
||||
subs = permission.get("subjects")
|
||||
if subs:
|
||||
# If there are subjects, then there are explicit users or groups with access
|
||||
if email := subs.get("user", {}).get("results", [{}])[0].get("email"):
|
||||
user_emails.add(email)
|
||||
if group_name := subs.get("group", {}).get("results", [{}])[0].get("name"):
|
||||
group_names.add(group_name)
|
||||
else:
|
||||
# If there are no subjects, then the permission is for everyone
|
||||
if permission.get("operation", {}).get(
|
||||
"operation"
|
||||
) == "read" and permission.get("anonymousAccess", False):
|
||||
# If the permission specifies read access for anonymous users, then
|
||||
# the space is publicly accessible
|
||||
is_externally_public = True
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
external_user_group_ids=group_names,
|
||||
is_public=is_externally_public,
|
||||
)
|
||||
|
||||
|
||||
def get_space_permission(
|
||||
confluence_client: OnyxConfluence,
|
||||
space_key: str,
|
||||
is_cloud: bool,
|
||||
) -> ExternalAccess:
|
||||
if is_cloud:
|
||||
space_permissions = _get_cloud_space_permissions(confluence_client, space_key)
|
||||
else:
|
||||
space_permissions = _get_server_space_permissions(confluence_client, space_key)
|
||||
|
||||
if (
|
||||
not space_permissions.is_public
|
||||
and not space_permissions.external_user_emails
|
||||
and not space_permissions.external_user_group_ids
|
||||
):
|
||||
logger.warning(
|
||||
f"No permissions found for space '{space_key}'. This is very unlikely"
|
||||
"to be correct and is more likely caused by an access token with"
|
||||
"insufficient permissions. Make sure that the access token has Admin"
|
||||
f"permissions for space '{space_key}'"
|
||||
)
|
||||
|
||||
return space_permissions
|
||||
|
||||
|
||||
def get_all_space_permissions(
|
||||
confluence_client: OnyxConfluence,
|
||||
is_cloud: bool,
|
||||
) -> dict[str, ExternalAccess]:
|
||||
logger.debug("Getting space permissions")
|
||||
# Gets all the spaces in the Confluence instance
|
||||
all_space_keys = []
|
||||
start = 0
|
||||
while True:
|
||||
spaces_batch = confluence_client.get_all_spaces(
|
||||
start=start, limit=REQUEST_PAGINATION_LIMIT
|
||||
)
|
||||
for space in spaces_batch.get("results", []):
|
||||
all_space_keys.append(space.get("key"))
|
||||
|
||||
if len(spaces_batch.get("results", [])) < REQUEST_PAGINATION_LIMIT:
|
||||
break
|
||||
|
||||
start += len(spaces_batch.get("results", []))
|
||||
|
||||
# Gets the permissions for each space
|
||||
logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess] = {}
|
||||
for space_key in all_space_keys:
|
||||
space_permissions = get_space_permission(confluence_client, space_key, is_cloud)
|
||||
|
||||
# Stores the permissions for each space
|
||||
space_permissions_by_space_key[space_key] = space_permissions
|
||||
|
||||
return space_permissions_by_space_key
|
||||
@@ -4,6 +4,7 @@ from datetime import timezone
|
||||
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.gmail.connector import GmailConnector
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -58,11 +59,17 @@ def gmail_doc_sync(
|
||||
|
||||
callback.progress("gmail_doc_sync", 1)
|
||||
|
||||
if slim_doc.external_access is None:
|
||||
if slim_doc.perm_sync_data is None:
|
||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||
continue
|
||||
|
||||
yield DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=slim_doc.external_access,
|
||||
)
|
||||
if user_email := slim_doc.perm_sync_data.get("user_email"):
|
||||
ext_access = ExternalAccess(
|
||||
external_user_emails=set([user_email]),
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
|
||||
yield DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=ext_access,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
|
||||
from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
|
||||
from ee.onyx.external_permissions.google_drive.models import PermissionType
|
||||
@@ -11,9 +16,10 @@ from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFuncti
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_drive.models import GoogleDriveFileType
|
||||
from onyx.connectors.google_utils.resources import GoogleDriveService
|
||||
from onyx.connectors.google_utils.resources import get_drive_service
|
||||
from onyx.connectors.google_utils.resources import RefreshableDriveObject
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -40,34 +46,80 @@ def _get_slim_doc_generator(
|
||||
)
|
||||
|
||||
|
||||
def get_external_access_for_raw_gdrive_file(
|
||||
file: GoogleDriveFileType, company_domain: str, drive_service: GoogleDriveService
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
Get the external access for a raw Google Drive file.
|
||||
def _drive_connector_creds_getter(
|
||||
google_drive_connector: GoogleDriveConnector,
|
||||
) -> Callable[[], ServiceAccountCredentials | OAuthCredentials]:
|
||||
def inner() -> ServiceAccountCredentials | OAuthCredentials:
|
||||
if not google_drive_connector._creds_dict:
|
||||
raise ValueError(
|
||||
"Creds dict not found, load_credentials must be called first"
|
||||
)
|
||||
google_drive_connector.load_credentials(google_drive_connector._creds_dict)
|
||||
return google_drive_connector.creds
|
||||
|
||||
Assumes the file we retrieved has EITHER `permissions` or `permission_ids`
|
||||
"""
|
||||
doc_id = file.get("id")
|
||||
if not doc_id:
|
||||
raise ValueError("No doc_id found in file")
|
||||
return inner
|
||||
|
||||
permissions = file.get("permissions")
|
||||
permission_ids = file.get("permissionIds")
|
||||
drive_id = file.get("driveId")
|
||||
|
||||
permissions_list: list[GoogleDrivePermission] = []
|
||||
if permissions:
|
||||
permissions_list = [
|
||||
GoogleDrivePermission.from_drive_permission(p) for p in permissions
|
||||
]
|
||||
elif permission_ids:
|
||||
permissions_list = get_permissions_by_ids(
|
||||
drive_service=drive_service,
|
||||
doc_id=doc_id,
|
||||
permission_ids=permission_ids,
|
||||
def _fetch_permissions_for_permission_ids(
|
||||
google_drive_connector: GoogleDriveConnector,
|
||||
permission_info: dict[str, Any],
|
||||
) -> list[GoogleDrivePermission]:
|
||||
doc_id = permission_info.get("doc_id")
|
||||
if not permission_info or not doc_id:
|
||||
return []
|
||||
|
||||
owner_email = permission_info.get("owner_email")
|
||||
permission_ids = permission_info.get("permission_ids", [])
|
||||
if not permission_ids:
|
||||
return []
|
||||
|
||||
if not owner_email:
|
||||
logger.warning(
|
||||
f"No owner email found for document {doc_id}. Permission info: {permission_info}"
|
||||
)
|
||||
|
||||
refreshable_drive_service = RefreshableDriveObject(
|
||||
call_stack=lambda creds: get_drive_service(
|
||||
creds=creds,
|
||||
user_email=(owner_email or google_drive_connector.primary_admin_email),
|
||||
),
|
||||
creds=google_drive_connector.creds,
|
||||
creds_getter=_drive_connector_creds_getter(google_drive_connector),
|
||||
)
|
||||
|
||||
return get_permissions_by_ids(
|
||||
drive_service=refreshable_drive_service,
|
||||
doc_id=doc_id,
|
||||
permission_ids=permission_ids,
|
||||
)
|
||||
|
||||
|
||||
def _get_permissions_from_slim_doc(
|
||||
google_drive_connector: GoogleDriveConnector,
|
||||
slim_doc: SlimDocument,
|
||||
) -> ExternalAccess:
|
||||
permission_info = slim_doc.perm_sync_data or {}
|
||||
|
||||
permissions_list: list[GoogleDrivePermission] = []
|
||||
raw_permissions_list = permission_info.get("permissions", [])
|
||||
if not raw_permissions_list:
|
||||
permissions_list = _fetch_permissions_for_permission_ids(
|
||||
google_drive_connector=google_drive_connector,
|
||||
permission_info=permission_info,
|
||||
)
|
||||
if not permissions_list:
|
||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||
return ExternalAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
|
||||
else:
|
||||
permissions_list = [
|
||||
GoogleDrivePermission.from_drive_permission(p) for p in raw_permissions_list
|
||||
]
|
||||
|
||||
company_domain = google_drive_connector.google_domain
|
||||
folder_ids_to_inherit_permissions_from: set[str] = set()
|
||||
user_emails: set[str] = set()
|
||||
group_emails: set[str] = set()
|
||||
@@ -92,7 +144,7 @@ def get_external_access_for_raw_gdrive_file(
|
||||
else:
|
||||
logger.error(
|
||||
"Permission is type `user` but no email address is "
|
||||
f"provided for document {doc_id}"
|
||||
f"provided for document {slim_doc.id}"
|
||||
f"\n {permission}"
|
||||
)
|
||||
elif permission.type == PermissionType.GROUP:
|
||||
@@ -102,7 +154,7 @@ def get_external_access_for_raw_gdrive_file(
|
||||
else:
|
||||
logger.error(
|
||||
"Permission is type `group` but no email address is "
|
||||
f"provided for document {doc_id}"
|
||||
f"provided for document {slim_doc.id}"
|
||||
f"\n {permission}"
|
||||
)
|
||||
elif permission.type == PermissionType.DOMAIN and company_domain:
|
||||
@@ -116,6 +168,7 @@ def get_external_access_for_raw_gdrive_file(
|
||||
elif permission.type == PermissionType.ANYONE:
|
||||
public = True
|
||||
|
||||
drive_id = permission_info.get("drive_id")
|
||||
group_ids = (
|
||||
group_emails
|
||||
| folder_ids_to_inherit_permissions_from
|
||||
@@ -157,13 +210,12 @@ def gdrive_doc_sync(
|
||||
|
||||
callback.progress("gdrive_doc_sync", 1)
|
||||
|
||||
if slim_doc.external_access is None:
|
||||
raise ValueError(
|
||||
f"Drive perm sync: No external access for document {slim_doc.id}"
|
||||
)
|
||||
|
||||
ext_access = _get_permissions_from_slim_doc(
|
||||
google_drive_connector=google_drive_connector,
|
||||
slim_doc=slim_doc,
|
||||
)
|
||||
yield DocExternalAccess(
|
||||
external_access=slim_doc.external_access,
|
||||
external_access=ext_access,
|
||||
doc_id=slim_doc.id,
|
||||
)
|
||||
total_processed += len(slim_doc_batch)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -101,44 +99,6 @@ def _get_all_folders(
|
||||
return all_folders
|
||||
|
||||
|
||||
def _drive_folder_to_onyx_group(
|
||||
folder: FolderInfo,
|
||||
group_email_to_member_emails_map: dict[str, list[str]],
|
||||
) -> ExternalUserGroup:
|
||||
"""
|
||||
Converts a folder into an Onyx group.
|
||||
"""
|
||||
anyone_can_access = False
|
||||
folder_member_emails: set[str] = set()
|
||||
|
||||
for permission in folder.permissions:
|
||||
if permission.type == PermissionType.USER:
|
||||
if permission.email_address is None:
|
||||
logger.warning(
|
||||
f"User email is None for folder {folder.id} permission {permission}"
|
||||
)
|
||||
continue
|
||||
folder_member_emails.add(permission.email_address)
|
||||
elif permission.type == PermissionType.GROUP:
|
||||
if permission.email_address not in group_email_to_member_emails_map:
|
||||
logger.warning(
|
||||
f"Group email {permission.email_address} for folder {folder.id} "
|
||||
"not found in group_email_to_member_emails_map"
|
||||
)
|
||||
continue
|
||||
folder_member_emails.update(
|
||||
group_email_to_member_emails_map[permission.email_address]
|
||||
)
|
||||
elif permission.type == PermissionType.ANYONE:
|
||||
anyone_can_access = True
|
||||
|
||||
return ExternalUserGroup(
|
||||
id=folder.id,
|
||||
user_emails=list(folder_member_emails),
|
||||
gives_anyone_access=anyone_can_access,
|
||||
)
|
||||
|
||||
|
||||
"""Individual Shared Drive / My Drive Permission Sync"""
|
||||
|
||||
|
||||
@@ -207,29 +167,7 @@ def _get_drive_members(
|
||||
return drive_id_to_members_map
|
||||
|
||||
|
||||
def _drive_member_map_to_onyx_groups(
|
||||
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]],
|
||||
group_email_to_member_emails_map: dict[str, list[str]],
|
||||
) -> Generator[ExternalUserGroup, None, None]:
|
||||
"""The `user_emails` for the Shared Drive should be all individuals in the
|
||||
Shared Drive + the union of all flattened group emails."""
|
||||
for drive_id, (group_emails, user_emails) in drive_id_to_members_map.items():
|
||||
drive_member_emails: set[str] = user_emails
|
||||
for group_email in group_emails:
|
||||
if group_email not in group_email_to_member_emails_map:
|
||||
logger.warning(
|
||||
f"Group email {group_email} for drive {drive_id} not found in "
|
||||
"group_email_to_member_emails_map"
|
||||
)
|
||||
continue
|
||||
drive_member_emails.update(group_email_to_member_emails_map[group_email])
|
||||
yield ExternalUserGroup(
|
||||
id=drive_id,
|
||||
user_emails=list(drive_member_emails),
|
||||
)
|
||||
|
||||
|
||||
def _get_all_google_groups(
|
||||
def _get_all_groups(
|
||||
admin_service: AdminService,
|
||||
google_domain: str,
|
||||
) -> set[str]:
|
||||
@@ -247,28 +185,6 @@ def _get_all_google_groups(
|
||||
return group_emails
|
||||
|
||||
|
||||
def _google_group_to_onyx_group(
|
||||
admin_service: AdminService,
|
||||
group_email: str,
|
||||
) -> ExternalUserGroup:
|
||||
"""
|
||||
This maps google group emails to their member emails.
|
||||
"""
|
||||
group_member_emails: set[str] = set()
|
||||
for member in execute_paginated_retrieval(
|
||||
admin_service.members().list,
|
||||
list_key="members",
|
||||
groupKey=group_email,
|
||||
fields="members(email),nextPageToken",
|
||||
):
|
||||
group_member_emails.add(member["email"])
|
||||
|
||||
return ExternalUserGroup(
|
||||
id=group_email,
|
||||
user_emails=list(group_member_emails),
|
||||
)
|
||||
|
||||
|
||||
def _map_group_email_to_member_emails(
|
||||
admin_service: AdminService,
|
||||
group_emails: set[str],
|
||||
@@ -366,7 +282,7 @@ def _build_onyx_groups(
|
||||
def gdrive_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> Generator[ExternalUserGroup, None, None]:
|
||||
) -> list[ExternalUserGroup]:
|
||||
# Initialize connector and build credential/service objects
|
||||
google_drive_connector = GoogleDriveConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
@@ -380,27 +296,26 @@ def gdrive_group_sync(
|
||||
drive_id_to_members_map = _get_drive_members(google_drive_connector, admin_service)
|
||||
|
||||
# Get all group emails
|
||||
all_group_emails = _get_all_google_groups(
|
||||
all_group_emails = _get_all_groups(
|
||||
admin_service, google_drive_connector.google_domain
|
||||
)
|
||||
|
||||
# Each google group is an Onyx group, yield those
|
||||
group_email_to_member_emails_map: dict[str, list[str]] = {}
|
||||
for group_email in all_group_emails:
|
||||
onyx_group = _google_group_to_onyx_group(admin_service, group_email)
|
||||
group_email_to_member_emails_map[group_email] = onyx_group.user_emails
|
||||
yield onyx_group
|
||||
|
||||
# Each drive is a group, yield those
|
||||
for onyx_group in _drive_member_map_to_onyx_groups(
|
||||
drive_id_to_members_map, group_email_to_member_emails_map
|
||||
):
|
||||
yield onyx_group
|
||||
|
||||
# Get all folder permissions
|
||||
folder_info = _get_all_folders(
|
||||
google_drive_connector=google_drive_connector,
|
||||
skip_folders_without_permissions=True,
|
||||
)
|
||||
for folder in folder_info:
|
||||
yield _drive_folder_to_onyx_group(folder, group_email_to_member_emails_map)
|
||||
|
||||
# Map group emails to their members
|
||||
group_email_to_member_emails_map = _map_group_email_to_member_emails(
|
||||
admin_service, all_group_emails
|
||||
)
|
||||
|
||||
# Convert the maps to onyx groups
|
||||
onyx_groups = _build_onyx_groups(
|
||||
drive_id_to_members_map=drive_id_to_members_map,
|
||||
group_email_to_member_emails_map=group_email_to_member_emails_map,
|
||||
folder_info=folder_info,
|
||||
)
|
||||
|
||||
return onyx_groups
|
||||
|
||||
@@ -2,7 +2,7 @@ from retry import retry
|
||||
|
||||
from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
|
||||
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
from onyx.connectors.google_utils.resources import GoogleDriveService
|
||||
from onyx.connectors.google_utils.resources import RefreshableDriveObject
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -10,7 +10,7 @@ logger = setup_logger()
|
||||
|
||||
@retry(tries=3, delay=2, backoff=2)
|
||||
def get_permissions_by_ids(
|
||||
drive_service: GoogleDriveService,
|
||||
drive_service: RefreshableDriveObject,
|
||||
doc_id: str,
|
||||
permission_ids: list[str],
|
||||
) -> list[GoogleDrivePermission]:
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.utils import generic_doc_sync
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.jira.connector import JiraConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
JIRA_DOC_SYNC_TAG = "jira_doc_sync"
|
||||
|
||||
|
||||
def jira_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
jira_connector = JiraConnector(
|
||||
**cc_pair.connector.connector_specific_config,
|
||||
)
|
||||
jira_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
yield from generic_doc_sync(
|
||||
cc_pair=cc_pair,
|
||||
fetch_all_existing_docs_fn=fetch_all_existing_docs_fn,
|
||||
callback=callback,
|
||||
doc_source=DocumentSource.JIRA,
|
||||
slim_connector=jira_connector,
|
||||
label=JIRA_DOC_SYNC_TAG,
|
||||
)
|
||||
@@ -1,25 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.alias_generators import to_camel
|
||||
|
||||
|
||||
Holder = dict[str, Any]
|
||||
|
||||
|
||||
class Permission(BaseModel):
|
||||
id: int
|
||||
permission: str
|
||||
holder: Holder | None
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
account_id: str
|
||||
email_address: str
|
||||
display_name: str
|
||||
active: bool
|
||||
|
||||
model_config = ConfigDict(
|
||||
alias_generator=to_camel,
|
||||
)
|
||||
@@ -1,209 +0,0 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from jira import JIRA
|
||||
from jira.resources import PermissionScheme
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ee.onyx.external_permissions.jira.models import Holder
|
||||
from ee.onyx.external_permissions.jira.models import Permission
|
||||
from ee.onyx.external_permissions.jira.models import User
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
HolderMap = dict[str, list[Holder]]
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
|
||||
"""
|
||||
A "Holder" in JIRA is a person / entity who "holds" the corresponding permission.
|
||||
It can have different types. They can be one of (but not limited to):
|
||||
- user (an explicitly whitelisted user)
|
||||
- projectRole (for project level "roles")
|
||||
- reporter (the reporter of an issue)
|
||||
|
||||
A "Holder" usually has following structure:
|
||||
- `{ "type": "user", "value": "$USER_ID", "user": { .. }, .. }`
|
||||
- `{ "type": "projectRole", "value": "$PROJECT_ID", .. }`
|
||||
|
||||
When we fetch the PermissionSchema from JIRA, we retrieve a list of "Holder"s.
|
||||
The list of "Holder"s can have multiple "Holder"s of the same type in the list (e.g., you can have two `"type": "user"`s in
|
||||
there, each corresponding to a different user).
|
||||
This function constructs a map of "Holder" types to a list of the "Holder"s which contained that type.
|
||||
|
||||
Returns:
|
||||
A dict from the "Holder" type to the actual "Holder" instance.
|
||||
|
||||
Example:
|
||||
```
|
||||
{
|
||||
"user": [
|
||||
{ "type": "user", "value": "10000", "user": { .. }, .. },
|
||||
{ "type": "user", "value": "10001", "user": { .. }, .. },
|
||||
],
|
||||
"projectRole": [
|
||||
{ "type": "projectRole", "value": "10010", .. },
|
||||
{ "type": "projectRole", "value": "10011", .. },
|
||||
],
|
||||
"applicationRole": [
|
||||
{ "type": "applicationRole" },
|
||||
],
|
||||
..
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
holder_map: defaultdict[str, list[Holder]] = defaultdict(list)
|
||||
|
||||
for raw_perm in permissions:
|
||||
if not hasattr(raw_perm, "raw"):
|
||||
logger.warn(f"Expected a 'raw' field, but none was found: {raw_perm=}")
|
||||
continue
|
||||
|
||||
permission = Permission(**raw_perm.raw)
|
||||
|
||||
# We only care about ability to browse through projects + issues (not other permissions such as read/write).
|
||||
if permission.permission != "BROWSE_PROJECTS":
|
||||
continue
|
||||
|
||||
# In order to associate this permission to some Atlassian entity, we need the "Holder".
|
||||
# If this doesn't exist, then we cannot associate this permission to anyone; just skip.
|
||||
if not permission.holder:
|
||||
logger.warn(
|
||||
f"Expected to find a permission holder, but none was found: {permission=}"
|
||||
)
|
||||
continue
|
||||
|
||||
type = permission.holder.get("type")
|
||||
if not type:
|
||||
logger.warn(
|
||||
f"Expected to find the type of permission holder, but none was found: {permission=}"
|
||||
)
|
||||
continue
|
||||
|
||||
holder_map[type].append(permission.holder)
|
||||
|
||||
return holder_map
|
||||
|
||||
|
||||
def _get_user_emails(user_holders: list[Holder]) -> list[str]:
|
||||
emails = []
|
||||
|
||||
for user_holder in user_holders:
|
||||
if "user" not in user_holder:
|
||||
continue
|
||||
raw_user_dict = user_holder["user"]
|
||||
|
||||
try:
|
||||
user_model = User.model_validate(raw_user_dict)
|
||||
except ValidationError:
|
||||
logger.error(
|
||||
"Expected to be able to serialize the raw-user-dict into an instance of `User`, but validation failed;"
|
||||
f"{raw_user_dict=}"
|
||||
)
|
||||
continue
|
||||
|
||||
emails.append(user_model.email_address)
|
||||
|
||||
return emails
|
||||
|
||||
|
||||
def _get_user_emails_from_project_roles(
|
||||
jira_client: JIRA,
|
||||
jira_project: str,
|
||||
project_role_holders: list[Holder],
|
||||
) -> list[str]:
|
||||
# NOTE (@raunakab) a `parallel_yield` may be helpful here...?
|
||||
roles = [
|
||||
jira_client.project_role(project=jira_project, id=project_role_holder["value"])
|
||||
for project_role_holder in project_role_holders
|
||||
if "value" in project_role_holder
|
||||
]
|
||||
|
||||
emails = []
|
||||
|
||||
for role in roles:
|
||||
if not hasattr(role, "actors"):
|
||||
continue
|
||||
|
||||
for actor in role.actors:
|
||||
if not hasattr(actor, "actorUser") or not hasattr(
|
||||
actor.actorUser, "accountId"
|
||||
):
|
||||
continue
|
||||
|
||||
user = jira_client.user(id=actor.actorUser.accountId)
|
||||
if not hasattr(user, "accountType") or user.accountType != "atlassian":
|
||||
continue
|
||||
|
||||
if not hasattr(user, "emailAddress"):
|
||||
msg = f"User's email address was not able to be retrieved; {actor.actorUser.accountId=}"
|
||||
if hasattr(user, "displayName"):
|
||||
msg += f" {actor.displayName=}"
|
||||
logger.warn(msg)
|
||||
continue
|
||||
|
||||
emails.append(user.emailAddress)
|
||||
|
||||
return emails
|
||||
|
||||
|
||||
def _build_external_access_from_holder_map(
|
||||
jira_client: JIRA, jira_project: str, holder_map: HolderMap
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
# Note:
|
||||
If the `holder_map` contains an instance of "anyone", then this is a public JIRA project.
|
||||
Otherwise, we fetch the "projectRole"s (i.e., the user-groups in JIRA speak), and the user emails.
|
||||
"""
|
||||
|
||||
if "anyone" in holder_map:
|
||||
return ExternalAccess(
|
||||
external_user_emails=set(), external_user_group_ids=set(), is_public=True
|
||||
)
|
||||
|
||||
user_emails = (
|
||||
_get_user_emails(user_holders=holder_map["user"])
|
||||
if "user" in holder_map
|
||||
else []
|
||||
)
|
||||
project_role_user_emails = (
|
||||
_get_user_emails_from_project_roles(
|
||||
jira_client=jira_client,
|
||||
jira_project=jira_project,
|
||||
project_role_holders=holder_map["projectRole"],
|
||||
)
|
||||
if "projectRole" in holder_map
|
||||
else []
|
||||
)
|
||||
|
||||
external_user_emails = set(user_emails + project_role_user_emails)
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=external_user_emails,
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
|
||||
def get_project_permissions(
|
||||
jira_client: JIRA,
|
||||
jira_project: str,
|
||||
) -> ExternalAccess | None:
|
||||
project_permissions: PermissionScheme = jira_client.project_permissionscheme(
|
||||
project=jira_project
|
||||
)
|
||||
|
||||
if not hasattr(project_permissions, "permissions"):
|
||||
return None
|
||||
|
||||
if not isinstance(project_permissions.permissions, list):
|
||||
return None
|
||||
|
||||
holder_map = _build_holder_map(permissions=project_permissions.permissions)
|
||||
|
||||
return _build_external_access_from_holder_map(
|
||||
jira_client=jira_client, jira_project=jira_project, holder_map=holder_map
|
||||
)
|
||||
@@ -4,8 +4,6 @@ from typing import Optional
|
||||
from typing import Protocol
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
|
||||
# Avoid circular imports
|
||||
if TYPE_CHECKING:
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup # noqa
|
||||
@@ -39,11 +37,8 @@ DocSyncFuncType = Callable[
|
||||
|
||||
GroupSyncFuncType = Callable[
|
||||
[
|
||||
str, # tenant_id
|
||||
"ConnectorCredentialPair", # cc_pair
|
||||
str,
|
||||
"ConnectorCredentialPair",
|
||||
],
|
||||
Generator["ExternalUserGroup", None, None],
|
||||
list["ExternalUserGroup"],
|
||||
]
|
||||
|
||||
# list of chunks to be censored and the user email. returns censored chunks
|
||||
CensoringFuncType = Callable[[list[InferenceChunk], str], list[InferenceChunk]]
|
||||
|
||||
@@ -1,33 +1,43 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.external_permissions.sync_params import get_all_censoring_enabled_sources
|
||||
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
|
||||
from ee.onyx.external_permissions.salesforce.postprocessing import (
|
||||
censor_salesforce_chunks,
|
||||
)
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.pipeline import InferenceChunk
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION: dict[
|
||||
DocumentSource,
|
||||
# list of chunks to be censored and the user email. returns censored chunks
|
||||
Callable[[list[InferenceChunk], str], list[InferenceChunk]],
|
||||
] = {
|
||||
DocumentSource.SALESFORCE: censor_salesforce_chunks,
|
||||
}
|
||||
|
||||
|
||||
def _get_all_censoring_enabled_sources() -> set[DocumentSource]:
|
||||
"""
|
||||
Returns the set of sources that have censoring enabled.
|
||||
This is based on if the access_type is set to sync and the connector
|
||||
source has a censoring config.
|
||||
source is included in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION.
|
||||
|
||||
NOTE: This means if there is a source has a single cc_pair that is sync,
|
||||
all chunks for that source will be censored, even if the connector that
|
||||
indexed that chunk is not sync. This was done to avoid getting the cc_pair
|
||||
for every single chunk.
|
||||
"""
|
||||
all_censoring_enabled_sources = get_all_censoring_enabled_sources()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_context_manager() as db_session:
|
||||
enabled_sync_connectors = get_all_auto_sync_cc_pairs(db_session)
|
||||
return {
|
||||
cc_pair.connector.source
|
||||
for cc_pair in enabled_sync_connectors
|
||||
if cc_pair.connector.source in all_censoring_enabled_sources
|
||||
if cc_pair.connector.source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
|
||||
}
|
||||
|
||||
|
||||
@@ -60,11 +70,7 @@ def _post_query_chunk_censoring(
|
||||
# check function for that source
|
||||
# TODO: Use a threadpool/multiprocessing to process the sources in parallel
|
||||
for source, chunks_for_source in chunks_to_process.items():
|
||||
sync_config = get_source_perm_sync_config(source)
|
||||
if sync_config is None or sync_config.censoring_config is None:
|
||||
raise ValueError(f"No sync config found for {source}")
|
||||
|
||||
censor_chunks_for_source = sync_config.censoring_config.chunk_censoring_func
|
||||
censor_chunks_for_source = DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION[source]
|
||||
try:
|
||||
censored_chunks = censor_chunks_for_source(chunks_for_source, user.email)
|
||||
except Exception as e:
|
||||
|
||||
@@ -10,7 +10,7 @@ from ee.onyx.external_permissions.salesforce.utils import (
|
||||
)
|
||||
from onyx.configs.app_configs import BLURB_SIZE
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -44,7 +44,7 @@ def _get_objects_access_for_user_email_from_salesforce(
|
||||
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
|
||||
# but subsequent queries for this source are essentially instant
|
||||
first_doc_id = chunks[0].document_id
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_context_manager() as db_session:
|
||||
salesforce_client = get_any_salesforce_client_for_doc_id(
|
||||
db_session, first_doc_id
|
||||
)
|
||||
@@ -217,7 +217,7 @@ def censor_salesforce_chunks(
|
||||
def _get_objects_access_for_user_email(
|
||||
object_ids: set[str], user_email: str
|
||||
) -> dict[str, bool]:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_context_manager() as db_session:
|
||||
external_groups = fetch_external_groups_for_user_email_and_group_ids(
|
||||
db_session=db_session,
|
||||
user_email=user_email,
|
||||
|
||||
@@ -1,63 +0,0 @@
from slack_sdk import WebClient

from onyx.access.models import ExternalAccess
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.slack.connector import ChannelType
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import make_paginated_slack_api_call


def get_channel_access(
client: WebClient,
channel: ChannelType,
user_cache: dict[str, BasicExpertInfo | None],
) -> ExternalAccess:
"""
Get channel access permissions for a Slack channel.

Args:
client: Slack WebClient instance
channel: Slack channel object containing channel info
user_cache: Cache of user IDs to BasicExpertInfo objects. May be updated in place.

Returns:
ExternalAccess object for the channel.
"""
channel_is_public = not channel["is_private"]
if channel_is_public:
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=True,
)

channel_id = channel["id"]

# Get all member IDs for the channel
member_ids = []
for result in make_paginated_slack_api_call(
client.conversations_members,
channel=channel_id,
):
member_ids.extend(result.get("members", []))

member_emails = set()
for member_id in member_ids:
# Try to get user info from cache or fetch it
user_info = expert_info_from_slack_id(
user_id=member_id,
client=client,
user_cache=user_cache,
)

# If we have user info and an email, add it to the set
if user_info and user_info.email:
member_emails.add(user_info.email)

return ExternalAccess(
external_user_emails=member_emails,
# NOTE: groups are not used, since adding a group to a channel just adds all
# users that are in the group.
external_user_group_ids=set(),
is_public=False,
)
@@ -108,15 +108,11 @@ def _get_slack_document_access(

for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if doc_metadata.external_access is None:
raise ValueError(
f"No external access for document {doc_metadata.id}. "
"Please check to make sure that your Slack bot token has the "
"`channels:read` scope"
)

if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
yield DocExternalAccess(
external_access=doc_metadata.external_access,
external_access=channel_permissions[channel_id],
doc_id=doc_metadata.id,
)


@@ -58,8 +58,6 @@ def slack_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
"""NOTE: not used atm. All channel access is done at the
individual user level. Leaving in for now in case we need it later."""

provider = OnyxDBCredentialsProvider(tenant_id, "slack", cc_pair.credential.id)
r = get_redis_client(tenant_id=tenant_id)
@@ -1,207 +1,83 @@
from collections.abc import Generator
from typing import Optional
from typing import TYPE_CHECKING

from pydantic import BaseModel

from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import JIRA_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import TEAMS_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync
from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
from ee.onyx.external_permissions.jira.doc_sync import jira_doc_sync
from ee.onyx.external_permissions.perm_sync_types import CensoringFuncType
from ee.onyx.external_permissions.perm_sync_types import DocSyncFuncType
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import GroupSyncFuncType
from ee.onyx.external_permissions.salesforce.postprocessing import (
censor_salesforce_chunks,
from ee.onyx.external_permissions.post_query_censoring import (
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from ee.onyx.external_permissions.teams.doc_sync import teams_doc_sync
from ee.onyx.external_permissions.slack.group_sync import slack_group_sync
from onyx.configs.constants import DocumentSource

if TYPE_CHECKING:
from onyx.access.models import DocExternalAccess # noqa
from onyx.db.models import ConnectorCredentialPair # noqa
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface # noqa


class DocSyncConfig(BaseModel):
doc_sync_frequency: int
doc_sync_func: DocSyncFuncType
initial_index_should_sync: bool


class GroupSyncConfig(BaseModel):
group_sync_frequency: int
group_sync_func: GroupSyncFuncType
group_sync_is_cc_pair_agnostic: bool


class CensoringConfig(BaseModel):
chunk_censoring_func: CensoringFuncType


class SyncConfig(BaseModel):
# None means we don't perform a doc_sync
doc_sync_config: DocSyncConfig | None = None
# None means we don't perform a group_sync
group_sync_config: GroupSyncConfig | None = None
# None means we don't perform a chunk_censoring
censoring_config: CensoringConfig | None = None


# Mock doc sync function for testing (no-op)
def mock_doc_sync(
cc_pair: "ConnectorCredentialPair",
fetch_all_docs_fn: FetchAllDocumentsFunction,
callback: Optional["IndexingHeartbeatInterface"],
) -> Generator["DocExternalAccess", None, None]:
"""Mock doc sync function for testing - returns empty list since permissions are fetched during indexing"""
yield from []


_SOURCE_TO_SYNC_CONFIG: dict[DocumentSource, SyncConfig] = {
DocumentSource.GOOGLE_DRIVE: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=gdrive_doc_sync,
initial_index_should_sync=True,
),
group_sync_config=GroupSyncConfig(
group_sync_frequency=GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY,
group_sync_func=gdrive_group_sync,
group_sync_is_cc_pair_agnostic=False,
),
),
DocumentSource.CONFLUENCE: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=confluence_doc_sync,
initial_index_should_sync=False,
),
group_sync_config=GroupSyncConfig(
group_sync_frequency=CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
group_sync_func=confluence_group_sync,
group_sync_is_cc_pair_agnostic=True,
),
),
DocumentSource.JIRA: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=JIRA_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=jira_doc_sync,
initial_index_should_sync=True,
),
),
# Groups are not needed for Slack.
# All channel access is done at the individual user level.
DocumentSource.SLACK: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=slack_doc_sync,
initial_index_should_sync=True,
),
),
DocumentSource.GMAIL: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=gmail_doc_sync,
initial_index_should_sync=False,
),
),
DocumentSource.SALESFORCE: SyncConfig(
censoring_config=CensoringConfig(
chunk_censoring_func=censor_salesforce_chunks,
),
),
DocumentSource.MOCK_CONNECTOR: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=mock_doc_sync,
initial_index_should_sync=True,
),
),
# Groups are not needed for Teams.
# All channel access is done at the individual user level.
DocumentSource.TEAMS: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=TEAMS_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=teams_doc_sync,
initial_index_should_sync=True,
),
),
# These functions update:
# - the user_email <-> document mapping
# - the external_user_group_id <-> document mapping
# in postgres without committing
# THIS ONE IS NECESSARY FOR AUTO SYNC TO WORK
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync,
DocumentSource.CONFLUENCE: confluence_doc_sync,
DocumentSource.SLACK: slack_doc_sync,
DocumentSource.GMAIL: gmail_doc_sync,
}


def source_requires_doc_sync(source: DocumentSource) -> bool:
"""Checks if the given DocumentSource requires doc syncing."""
if source not in _SOURCE_TO_SYNC_CONFIG:
return False
return _SOURCE_TO_SYNC_CONFIG[source].doc_sync_config is not None
return source in DOC_PERMISSIONS_FUNC_MAP


# These functions update:
# - the user_email <-> external_user_group_id mapping
# in postgres without committing
# THIS ONE IS OPTIONAL ON AN APP BY APP BASIS
GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = {
DocumentSource.GOOGLE_DRIVE: gdrive_group_sync,
DocumentSource.CONFLUENCE: confluence_group_sync,
DocumentSource.SLACK: slack_group_sync,
}


def source_requires_external_group_sync(source: DocumentSource) -> bool:
"""Checks if the given DocumentSource requires external group syncing."""
if source not in _SOURCE_TO_SYNC_CONFIG:
return False
return _SOURCE_TO_SYNC_CONFIG[source].group_sync_config is not None
return source in GROUP_PERMISSIONS_FUNC_MAP


def get_source_perm_sync_config(source: DocumentSource) -> SyncConfig | None:
"""Returns the frequency of the external group sync for the given DocumentSource."""
return _SOURCE_TO_SYNC_CONFIG.get(source)
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
DocumentSource.CONFLUENCE,
}


def source_group_sync_is_cc_pair_agnostic(source: DocumentSource) -> bool:
"""Checks if the given DocumentSource requires external group syncing."""
if source not in _SOURCE_TO_SYNC_CONFIG:
return False

group_sync_config = _SOURCE_TO_SYNC_CONFIG[source].group_sync_config
if group_sync_config is None:
return False

return group_sync_config.group_sync_is_cc_pair_agnostic
return source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC


def get_all_cc_pair_agnostic_group_sync_sources() -> set[DocumentSource]:
"""Returns the set of sources that have external group syncing that is cc_pair agnostic."""
return {
source
for source, sync_config in _SOURCE_TO_SYNC_CONFIG.items()
if sync_config.group_sync_config is not None
and sync_config.group_sync_config.group_sync_is_cc_pair_agnostic
}
# If nothing is specified here, we run the doc_sync every time the celery beat runs
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
DocumentSource.SLACK: SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
}

# If nothing is specified here, we run the doc_sync every time the celery beat runs
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all group permissions every 30 minutes
DocumentSource.GOOGLE_DRIVE: GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY,
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
}


def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
return source_type in _SOURCE_TO_SYNC_CONFIG


def get_all_censoring_enabled_sources() -> set[DocumentSource]:
"""Returns the set of sources that have censoring enabled."""
return {
source
for source, sync_config in _SOURCE_TO_SYNC_CONFIG.items()
if sync_config.censoring_config is not None
}


def source_should_fetch_permissions_during_indexing(source: DocumentSource) -> bool:
"""Returns True if the given DocumentSource requires permissions to be fetched during indexing."""
if source not in _SOURCE_TO_SYNC_CONFIG:
return False

doc_sync_config = _SOURCE_TO_SYNC_CONFIG[source].doc_sync_config
if doc_sync_config is None:
return False

return doc_sync_config.initial_index_should_sync
return (
source_type in DOC_PERMISSIONS_FUNC_MAP
or source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
)
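The hunk above collapses the old per-source maps (DOC_PERMISSIONS_FUNC_MAP, GROUP_PERMISSIONS_FUNC_MAP, the sync-period dicts) into the single _SOURCE_TO_SYNC_CONFIG table, with every helper deriving its answer from that table. A small sketch of how a caller would interrogate it, based only on the entries shown above; the ee.onyx.external_permissions.sync_params import path is assumed from the surrounding hunks:

from onyx.configs.constants import DocumentSource
from ee.onyx.external_permissions.sync_params import (
    get_source_perm_sync_config,
    source_requires_doc_sync,
    source_requires_external_group_sync,
)

# Salesforce is censoring-only in the table above: no doc sync, no group sync.
assert not source_requires_doc_sync(DocumentSource.SALESFORCE)
assert not source_requires_external_group_sync(DocumentSource.SALESFORCE)

# Confluence carries both a doc_sync_config and a cc-pair-agnostic group_sync_config.
confluence_config = get_source_perm_sync_config(DocumentSource.CONFLUENCE)
assert confluence_config is not None and confluence_config.doc_sync_config is not None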
@@ -1,35 +0,0 @@
from collections.abc import Generator

from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.teams.connector import TeamsConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()


TEAMS_DOC_SYNC_LABEL = "teams_doc_sync"


def teams_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
teams_connector = TeamsConnector(
**cc_pair.connector.connector_specific_config,
)
teams_connector.load_credentials(cc_pair.credential.credential_json)

yield from generic_doc_sync(
cc_pair=cc_pair,
fetch_all_existing_docs_fn=fetch_all_existing_docs_fn,
callback=callback,
doc_source=DocumentSource.TEAMS,
slim_connector=teams_connector,
label=TEAMS_DOC_SYNC_LABEL,
)
@@ -1,83 +0,0 @@
from collections.abc import Generator

from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import SlimConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()


def generic_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
callback: IndexingHeartbeatInterface | None,
doc_source: DocumentSource,
slim_connector: SlimConnector,
label: str,
) -> Generator[DocExternalAccess, None, None]:
"""
A convenience function for performing a generic document synchronization.

Notes:
A generic doc sync includes:
- fetching existing docs
- fetching *all* new (slim) docs
- yielding external-access permissions for existing docs which do not exist in the newly fetched slim-docs set (with their
`external_access` set to "private")
- yielding external-access permissions for newly fetched docs

Returns:
A `Generator` which yields existing and newly fetched external-access permissions.
"""

logger.info(f"Starting {doc_source} doc sync for CC Pair ID: {cc_pair.id}")

newly_fetched_doc_ids: set[str] = set()

logger.info(f"Fetching all slim documents from {doc_source}")
for doc_batch in slim_connector.retrieve_all_slim_documents(callback=callback):
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")

if callback:
if callback.should_stop():
raise RuntimeError(f"{label}: Stop signal detected")
callback.progress(label, 1)

for doc in doc_batch:
if not doc.external_access:
raise RuntimeError(
f"No external access found for document ID; {cc_pair.id=} {doc_source=} {doc.id=}"
)

newly_fetched_doc_ids.add(doc.id)

yield DocExternalAccess(
doc_id=doc.id,
external_access=doc.external_access,
)

logger.info(f"Querying existing document IDs for CC Pair ID: {cc_pair.id=}")
existing_doc_ids = set(fetch_all_existing_docs_fn())

missing_doc_ids = existing_doc_ids - newly_fetched_doc_ids

if not missing_doc_ids:
return

logger.warning(
f"Found {len(missing_doc_ids)=} documents that are in the DB but not present in fetch. Making them inaccessible."
)

for missing_id in missing_doc_ids:
logger.warning(f"Removing access for {missing_id=}")
yield DocExternalAccess(
doc_id=missing_id,
external_access=ExternalAccess.empty(),
)

logger.info(f"Finished {doc_source} doc sync")
@@ -19,7 +19,7 @@ from ee.onyx.db.analytics import fetch_query_analytics
from ee.onyx.db.analytics import user_can_view_assistant_stats
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User

router = APIRouter(prefix="/analytics")

@@ -17,7 +17,7 @@ from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.db.connector_credential_pair import (
get_connector_credential_pair_from_id_for_user,
)
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_pool import get_redis_client

@@ -26,9 +26,9 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user_with_expired_token
from onyx.auth.users import get_user_manager
from onyx.auth.users import UserManager
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.file_store import PostgresBackedFileStore
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -142,12 +142,11 @@ def put_logo(

def fetch_logo_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store(db_session)
file_store = PostgresBackedFileStore(db_session)
onyx_file = file_store.get_file_with_mime_type(get_logo_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")
except Exception:
logger.exception("Faield to fetch logo file")
raise HTTPException(
status_code=404,
detail="No logo file found",
@@ -158,7 +157,7 @@ def fetch_logo_helper(db_session: Session) -> Response:

def fetch_logotype_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store(db_session)
file_store = PostgresBackedFileStore(db_session)
onyx_file = file_store.get_file_with_mime_type(get_logotype_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")

@@ -131,11 +131,11 @@ def upload_logo(

file_store = get_default_file_store(db_session)
file_store.save_file(
file_name=_LOGOTYPE_FILENAME if is_logotype else _LOGO_FILENAME,
content=content,
display_name=display_name,
file_origin=FileOrigin.OTHER,
file_type=file_type,
file_id=_LOGOTYPE_FILENAME if is_logotype else _LOGO_FILENAME,
)
return True


@@ -13,7 +13,7 @@ from ee.onyx.db.standard_answer import remove_standard_answer
from ee.onyx.db.standard_answer import update_standard_answer
from ee.onyx.db.standard_answer import update_standard_answer_category
from onyx.auth.users import current_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.server.manage.models import StandardAnswer
from onyx.server.manage.models import StandardAnswerCategory

@@ -11,7 +11,7 @@ from ee.onyx.auth.users import decode_anonymous_user_jwt_token
from onyx.auth.api_key import extract_tenant_from_api_key_header
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.engine.sql_engine import is_valid_schema_name
from onyx.db.engine import is_valid_schema_name
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

@@ -12,10 +12,10 @@ from ee.onyx.server.oauth.slack import SlackOAuth
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import DocumentSource
from onyx.db.engine import get_current_tenant_id
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()


@@ -25,12 +25,12 @@ from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL
from onyx.db.credentials import create_credential
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.credentials import update_credential_json
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()


@@ -33,11 +33,11 @@ from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from shared_configs.contextvars import get_current_tenant_id


class GoogleDriveOAuth:

@@ -17,11 +17,11 @@ from onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.db.credentials import create_credential
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from shared_configs.contextvars import get_current_tenant_id


class SlackOAuth:

@@ -40,7 +40,7 @@ from onyx.context.search.models import SavedSearchDoc
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.llm.factory import get_llms_for_persona
from onyx.natural_language_processing.utils import get_tokenizer

@@ -31,7 +31,7 @@ from onyx.context.search.utils import dedupe_documents
from onyx.context.search.utils import drop_llm_indices
from onyx.context.search.utils import relevant_sections_to_indices
from onyx.db.chat import get_prompt_by_id
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id

@@ -13,7 +13,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.db.api_key import is_api_key_email_address
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import TokenRateLimit
@@ -1,4 +1,3 @@
import uuid
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
@@ -12,7 +11,6 @@ from fastapi import Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from ee.onyx.background.task_name_builders import query_history_task_name
from ee.onyx.db.query_history import get_all_query_history_export_tasks
from ee.onyx.db.query_history import get_page_of_chat_sessions
from ee.onyx.db.query_history import get_total_filtered_chat_sessions_count
@@ -37,19 +35,17 @@ from onyx.configs.constants import QueryHistoryType
from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.enums import TaskStatus
from onyx.db.file_record import get_query_history_export_files
from onyx.db.models import ChatSession
from onyx.db.models import User
from onyx.db.pg_file_store import get_query_history_export_files
from onyx.db.tasks import get_task_with_id
from onyx.db.tasks import register_task
from onyx.file_store.file_store import get_default_file_store
from onyx.server.documents.models import PaginatedReturn
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionsResponse
from onyx.utils.threadpool_concurrency import parallel_yield
from shared_configs.contextvars import get_current_tenant_id

router = APIRouter()

@@ -314,32 +310,17 @@ def start_query_history_export(
f"Start time must come before end time, but instead got the start time coming after; {start=} {end=}",
)

task_id_uuid = uuid.uuid4()
task_id = str(task_id_uuid)
start_time = datetime.now(tz=timezone.utc)

register_task(
db_session=db_session,
task_name=query_history_task_name(start=start, end=end),
task_id=task_id,
status=TaskStatus.PENDING,
start_time=start_time,
)

client_app.send_task(
task = client_app.send_task(
OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
task_id=task_id,
priority=OnyxCeleryPriority.MEDIUM,
queue=OnyxCeleryQueues.CSV_GENERATION,
kwargs={
"start": start,
"end": end,
"start_time": start_time,
"tenant_id": get_current_tenant_id(),
},
)

return {"request_id": task_id}
return {"request_id": task.id}


@router.get("/admin/query-history/export-status")
@@ -362,7 +343,7 @@ def get_query_history_export_status(

report_name = construct_query_history_report_name(request_id)
has_file = file_store.has_file(
file_id=report_name,
file_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,
file_type=FileType.CSV,
)
@@ -387,7 +368,7 @@ def download_query_history_csv(
report_name = construct_query_history_report_name(request_id)
file_store = get_default_file_store(db_session)
has_file = file_store.has_file(
file_id=report_name,
file_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,
file_type=FileType.CSV,
)

@@ -12,7 +12,7 @@ from onyx.configs.constants import SessionType
from onyx.db.enums import TaskStatus
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import FileRecord
from onyx.db.models import PGFileStore
from onyx.db.models import TaskQueueState


@@ -254,7 +254,7 @@ class QueryHistoryExport(BaseModel):
@classmethod
def from_file(
cls,
file: FileRecord,
file: PGFileStore,
) -> "QueryHistoryExport":
if not file.file_metadata or not isinstance(file.file_metadata, dict):
raise RuntimeError(
@@ -262,7 +262,7 @@ class QueryHistoryExport(BaseModel):
)

metadata = QueryHistoryFileMetadata.model_validate(dict(file.file_metadata))
task_id = extract_task_id_from_query_history_report_name(file.file_id)
task_id = extract_task_id_from_query_history_report_name(file.file_name)

return cls(
task_id=task_id,
@@ -14,7 +14,7 @@ from ee.onyx.db.usage_export import get_usage_report_data
from ee.onyx.db.usage_export import UsageReportMetadata
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
from onyx.auth.users import current_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.file_store.constants import STANDARD_CHUNK_SIZE


@@ -62,16 +62,17 @@ def generate_chat_messages_report(
]
)

# after writing seek to beginning of buffer
# after writing seek to begining of buffer
temp_file.seek(0)
file_id = file_store.save_file(
file_store.save_file(
file_name=file_name,
content=temp_file,
display_name=file_name,
file_origin=FileOrigin.OTHER,
file_type="text/csv",
)

return file_id
return file_name


def generate_user_report(
@@ -96,14 +97,15 @@ def generate_user_report(
csvwriter.writerow([user_skeleton.user_id, user_skeleton.is_active])

temp_file.seek(0)
file_id = file_store.save_file(
file_store.save_file(
file_name=file_name,
content=temp_file,
display_name=file_name,
file_origin=FileOrigin.OTHER,
file_type="text/csv",
)

return file_id
return file_name


def create_new_usage_report(
@@ -114,16 +116,16 @@ def create_new_usage_report(
report_id = str(uuid.uuid4())
file_store = get_default_file_store(db_session)

messages_file_id = generate_chat_messages_report(
messages_filename = generate_chat_messages_report(
db_session, file_store, report_id, period
)
users_file_id = generate_user_report(db_session, file_store, report_id)
users_filename = generate_user_report(db_session, file_store, report_id)

with tempfile.SpooledTemporaryFile(max_size=MAX_IN_MEMORY_SIZE) as zip_buffer:
with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED) as zip_file:
# write messages
chat_messages_tmpfile = file_store.read_file(
messages_file_id, mode="b", use_tempfile=True
messages_filename, mode="b", use_tempfile=True
)
zip_file.writestr(
"chat_messages.csv",
@@ -132,7 +134,7 @@

# write users
users_tmpfile = file_store.read_file(
users_file_id, mode="b", use_tempfile=True
users_filename, mode="b", use_tempfile=True
)
zip_file.writestr("users.csv", users_tmpfile.read())

@@ -144,11 +146,11 @@
f"_{report_id}_usage_report.zip"
)
file_store.save_file(
file_name=report_name,
content=zip_buffer,
display_name=report_name,
file_origin=FileOrigin.GENERATED_REPORT,
file_type="application/zip",
file_id=report_name,
)

# add report after zip file is written

@@ -27,9 +27,9 @@ from onyx.auth.users import get_user_manager
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
from onyx.db.engine.async_sql_engine import get_async_session
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_context_manager
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.utils.logger import setup_logger
@@ -19,7 +19,7 @@ from ee.onyx.server.enterprise_settings.store import (
)
from ee.onyx.server.enterprise_settings.store import upload_logo
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
@@ -235,7 +235,7 @@ def seed_db() -> None:
logger.debug("No seeding configuration file passed")
return

with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
if seed_config.llms is not None:
_seed_llms(db_session, seed_config.llms)
if seed_config.personas is not None:

@@ -10,7 +10,7 @@ from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from onyx.auth.users import auth_backend
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger


@@ -18,7 +18,7 @@ from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_shared_schema
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id


@@ -28,8 +28,8 @@ from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider

@@ -8,8 +8,8 @@ from sqlalchemy.schema import CreateSchema

from alembic import command
from alembic.config import Config
from onyx.db.engine.sql_engine import build_connection_string
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_sqlalchemy_engine

logger = logging.getLogger(__name__)

@@ -34,7 +34,7 @@ def run_alembic_migrations(schema_name: str) -> None:

# Mimic command-line options by adding 'cmd_opts' to the config
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
alembic_cfg.cmd_opts.x = [f"schemas={schema_name}"] # type: ignore
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore

# Run migrations programmatically
command.upgrade(alembic_cfg, "head")

@@ -9,7 +9,7 @@ from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.auth import get_user_count
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail

@@ -5,8 +5,8 @@ from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import get_pending_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.invited_users import write_pending_users
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import UserTenantMapping
from onyx.server.manage.models import TenantSnapshot
from onyx.utils.logger import setup_logger

@@ -9,7 +9,7 @@ from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.db.token_limit import fetch_all_user_token_rate_limits
from onyx.db.token_limit import insert_user_token_rate_limit

@@ -16,7 +16,7 @@ from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger

@@ -1,16 +1,11 @@
from collections.abc import Callable
from typing import cast

from sqlalchemy.orm import Session

from onyx.access.models import DocumentAccess
from onyx.access.utils import prefix_user_email
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import PUBLIC_DOC_PAT
from onyx.db.document import get_access_info_for_document
from onyx.db.document import get_access_info_for_documents
from onyx.db.models import User
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation


@@ -112,15 +107,3 @@ def get_acl_for_user(user: User | None, db_session: Session | None = None) -> se
"onyx.access.access", "_get_acl_for_user"
)
return versioned_acl_for_user_fn(user, db_session) # type: ignore


def source_should_fetch_permissions_during_indexing(source: DocumentSource) -> bool:
_source_should_fetch_permissions_during_indexing_func = cast(
Callable[[DocumentSource], bool],
fetch_ee_implementation_or_noop(
"onyx.external_permissions.sync_params",
"source_should_fetch_permissions_during_indexing",
False,
),
)
return _source_should_fetch_permissions_during_indexing_func(source)

@@ -40,30 +40,6 @@ class ExternalAccess:
def num_entries(self) -> int:
return len(self.external_user_emails) + len(self.external_user_group_ids)

@classmethod
def public(cls) -> "ExternalAccess":
return cls(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=True,
)

@classmethod
def empty(cls) -> "ExternalAccess":
"""
A helper function that returns an *empty* set of external user-emails and group-ids, and sets `is_public` to `False`.
This effectively makes the document in question "private" or inaccessible to anyone else.

This is especially helpful to use when you are performing permission-syncing, and some document's permissions aren't able
to be determined (for whatever reason). Setting its `ExternalAccess` to "private" is a feasible fallback.
"""

return cls(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=False,
)


@dataclass(frozen=True)
class DocExternalAccess:
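The hunks above add ExternalAccess.public() and ExternalAccess.empty() and use the latter as the "make it private" fallback in generic_doc_sync. A minimal sketch of that fallback pattern, using only the fields shown in the diff (the helper name access_for is illustrative):

from onyx.access.models import DocExternalAccess, ExternalAccess

def access_for(doc_id: str, resolved: ExternalAccess | None) -> DocExternalAccess:
    # When a document's permissions cannot be determined, mark it private via
    # ExternalAccess.empty() rather than guessing at emails or groups.
    return DocExternalAccess(
        doc_id=doc_id,
        external_access=resolved if resolved is not None else ExternalAccess.empty(),
    )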
@@ -78,7 +78,7 @@ def should_continue(state: BasicState) -> str:


if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.context.search.models import SearchRequest
from onyx.llm.factory import get_default_llms
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
@@ -87,7 +87,7 @@ if __name__ == "__main__":
compiled_graph = graph.compile()
input = BasicInput(unused=True)
primary_llm, fast_llm = get_default_llms()
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
config, _ = get_test_config(
db_session=db_session,
primary_llm=primary_llm,

@@ -4,7 +4,7 @@ from typing import cast
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,

@@ -111,7 +111,7 @@ def answer_query_graph_builder() -> StateGraph:


if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest

@@ -121,7 +121,7 @@ if __name__ == "__main__":
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
graph_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)

@@ -238,7 +238,7 @@ def agent_search_graph_builder() -> StateGraph:
if __name__ == "__main__":
pass

from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest

@@ -246,7 +246,7 @@ if __name__ == "__main__":
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()

with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
search_request = SearchRequest(query="Who created Excel?")
graph_config = get_test_config(
db_session, primary_llm, fast_llm, search_request

@@ -109,7 +109,7 @@ def answer_refined_query_graph_builder() -> StateGraph:


if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest

@@ -119,7 +119,7 @@ if __name__ == "__main__":
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
inputs = SubQuestionAnsweringInput(
question="what can you do with onyx?",
question_id="0_0",

@@ -131,7 +131,7 @@ def expanded_retrieval_graph_builder() -> StateGraph:


if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest

@@ -142,7 +142,7 @@ if __name__ == "__main__":
query="what can you do with onyx or danswer?",
)

with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
graph_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)

@@ -24,7 +24,7 @@ from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RerankingDetails
from onyx.context.search.postprocessing.postprocessing import rerank_sections
from onyx.context.search.postprocessing.postprocessing import should_rerank
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.db.search_settings import get_current_search_settings
from onyx.utils.timing import log_function_time

@@ -60,7 +60,7 @@ def rerank_documents(
allow_agent_reranking = graph_config.behavior.allow_agent_reranking

if rerank_settings is None:
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
search_settings = get_current_search_settings(db_session)
if not search_settings.disable_rerank_for_streaming:
rerank_settings = RerankingDetails.from_db_model(search_settings)

@@ -21,7 +21,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.configs.agent_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
from onyx.context.search.models import InferenceSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
@@ -67,7 +67,7 @@ def retrieve_documents(
callback_container: list[list[InferenceSection]] = []

# new db session to avoid concurrency issues
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=query_to_retrieve,
override_kwargs=SearchToolOverrideKwargs(

@@ -6,13 +6,8 @@ from langgraph.types import Send

from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import KGSourceDivisionType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
from onyx.configs.kg_configs import KG_MAX_DECOMPOSITION_SEGMENTS
from onyx.utils.logger import setup_logger

logger = setup_logger()


class KGAnalysisPath(str, Enum):
@@ -49,17 +44,6 @@ def research_individual_object(
and state.strategy == KGAnswerStrategy.DEEP
):

if state.source_filters and state.source_division:
segments = state.source_filters
segment_type = KGSourceDivisionType.SOURCE.value
else:
segments = state.div_con_entities
segment_type = KGSourceDivisionType.ENTITY.value

if segments and (len(segments) > KG_MAX_DECOMPOSITION_SEGMENTS):
logger.debug(f"Too many sources ({len(segments)}), usingfiltered search")
return "filtered_search"

return [
Send(
"process_individual_deep_search",
@@ -70,14 +54,13 @@
vespa_filter_results=state.vespa_filter_results,
source_division=state.source_division,
source_entity_filters=state.source_filters,
segment_type=segment_type,
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
step_results=[],
),
)
for research_nr, entity in enumerate(segments)
for research_nr, entity in enumerate(state.div_con_entities)
]
elif state.search_type == KGSearchType.SEARCH:
return "filtered_search"

@@ -19,7 +19,7 @@ from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.db.document import get_kg_doc_info_for_entity_name
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_document_id_for_entity
from onyx.db.entities import get_entity_name
from onyx.db.entity_type import get_entity_types
@@ -217,6 +217,31 @@ def stream_write_close_steps(writer: StreamWriter, level: int = 0) -> None:
write_custom_event("stream_finished", stop_event, writer)


def stream_write_close_main_answer(writer: StreamWriter, level: int = 0) -> None:
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.MAIN_ANSWER,
level=level,
level_question_num=0,
)
write_custom_event("stream_finished", stop_event, writer)


def stream_write_main_answer_token(
writer: StreamWriter, token: str, level: int = 0, level_question_num: int = 0
) -> None:
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=token, # No need to add space as tokenizer handles this
level=level,
level_question_num=level_question_num,
answer_type="agent_level_answer",
),
writer,
)


def get_doc_information_for_entity(entity_id_name: str) -> KGEntityDocInfo:
"""
Get document information for an entity, including its semantic name and document details.
@@ -267,9 +292,11 @@ def rename_entities_in_answer(answer: str) -> str:

# Clean up any spaces around ::
answer = re.sub(r"::\s+", "::", answer)
logger.debug(f"After cleaning spaces: {answer}")

# Pattern to match entity_type::entity_name, with optional quotes
pattern = r"(?:')?([a-zA-Z0-9-]+)::([a-zA-Z0-9]+)(?:')?"
logger.debug(f"Using pattern: {pattern}")

matches = list(re.finditer(pattern, answer))
logger.debug(f"Found {len(matches)} matches")
@@ -288,8 +315,10 @@
entity_type = match.group(1).upper().strip()
entity_name = match.group(2).strip()
potential_entity_id_name = make_entity_id(entity_type, entity_name)
logger.debug(f"Processing entity: {potential_entity_id_name}")

if entity_type not in active_entity_types:
logger.debug(f"Entity type {entity_type} not in active types")
continue

replacement_candidate = get_doc_information_for_entity(potential_entity_id_name)
@@ -299,8 +328,14 @@
processed_refs[match.group(0)] = (
replacement_candidate.semantic_linked_entity_name
)
logger.debug(
f"Added replacement: {match.group(0)} -> {replacement_candidate.semantic_linked_entity_name}"
)
else:
processed_refs[match.group(0)] = replacement_candidate.semantic_entity_name
logger.debug(
f"Added replacement: {match.group(0)} -> {replacement_candidate.semantic_entity_name}"
)

# Replace all references in the answer
for ref, replacement in processed_refs.items():
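The rename_entities_in_answer hunks above pivot on the entity_type::entity_name pattern. A small self-contained sketch of how that regex behaves on a made-up answer string (the sample text is illustrative only):

import re

# The pattern from rename_entities_in_answer: an entity_type::entity_name pair,
# optionally wrapped in single quotes.
pattern = r"(?:')?([a-zA-Z0-9-]+)::([a-zA-Z0-9]+)(?:')?"

sample = "Ticket 'JIRA::ONYX1' was filed by ACCOUNT::Acme."
for match in re.finditer(pattern, sample):
    entity_type = match.group(1).upper().strip()
    entity_name = match.group(2).strip()
    print(entity_type, entity_name)  # prints "JIRA ONYX1" then "ACCOUNT Acme"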
@@ -2,26 +2,19 @@ from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
|
||||
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
|
||||
from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
|
||||
from onyx.agents.agent_search.kb_search.states import KGSearchType
|
||||
from onyx.agents.agent_search.kb_search.states import YesNoEnum
|
||||
|
||||
|
||||
class KGQuestionEntityExtractionResult(BaseModel):
|
||||
entities: list[str]
|
||||
terms: list[str]
|
||||
time_filter: str | None
|
||||
|
||||
|
||||
class KGViewNames(BaseModel):
|
||||
allowed_docs_view_name: str
|
||||
kg_relationships_view_name: str
|
||||
kg_entity_view_name: str
|
||||
|
||||
|
||||
class KGAnswerApproach(BaseModel):
|
||||
search_type: KGSearchType
|
||||
search_strategy: KGAnswerStrategy
|
||||
relationship_detection: KGRelationshipDetection
|
||||
format: KGAnswerFormat
|
||||
broken_down_question: str | None = None
|
||||
divide_and_conquer: YesNoEnum | None = None
|
||||
@@ -34,6 +27,7 @@ class KGQuestionRelationshipExtractionResult(BaseModel):
|
||||
class KGQuestionExtractionResult(BaseModel):
|
||||
entities: list[str]
|
||||
relationships: list[str]
|
||||
terms: list[str]
|
||||
time_filter: str | None
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from pydantic import ValidationError

from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
@@ -25,17 +24,16 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.configs.kg_configs import KG_ENTITY_EXTRACTION_TIMEOUT
from onyx.configs.kg_configs import KG_RELATIONSHIP_EXTRACTION_TIMEOUT
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.kg_temp_view import create_views
from onyx.db.kg_temp_view import get_user_view_names
from onyx.db.relationships import get_allowed_relationship_type_pairs
from onyx.kg.utils.extraction_utils import get_entity_types_str
from onyx.kg.utils.extraction_utils import get_relationship_types_str
from onyx.kg.extractions.extraction_processing import get_entity_types_str
from onyx.kg.extractions.extraction_processing import get_relationship_types_str
from onyx.prompts.kg_prompts import QUERY_ENTITY_EXTRACTION_PROMPT
from onyx.prompts.kg_prompts import QUERY_RELATIONSHIP_EXTRACTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

@@ -81,16 +79,13 @@ def extract_ert(
stream_write_step_activities(writer, _KG_STEP_NR)

# Create temporary views. TODO: move into parallel step, if ultimately materialized
tenant_id = get_current_tenant_id()
kg_views = get_user_view_names(user_email, tenant_id)
allowed_docs_view_name, kg_relationships_view_name = get_user_view_names(user_email)
with get_session_with_current_tenant() as db_session:
create_views(
db_session,
tenant_id=tenant_id,
user_email=user_email,
allowed_docs_view_name=kg_views.allowed_docs_view_name,
kg_relationships_view_name=kg_views.kg_relationships_view_name,
kg_entity_view_name=kg_views.kg_entity_view_name,
allowed_docs_view_name=allowed_docs_view_name,
kg_relationships_view_name=kg_relationships_view_name,
)

### get the entities, terms, and filters
@@ -136,18 +131,25 @@ def extract_ert(
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]

entity_extraction_result = KGQuestionEntityExtractionResult.model_validate_json(
cleaned_response
)
except ValidationError:
logger.error("Failed to parse LLM response as JSON in Entity Extraction")
entity_extraction_result = KGQuestionEntityExtractionResult(
entities=[], time_filter=""
)
try:
entity_extraction_result = (
KGQuestionEntityExtractionResult.model_validate_json(cleaned_response)
)
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
entity_extraction_result = KGQuestionEntityExtractionResult(
entities=[],
terms=[],
time_filter="",
)
except Exception as e:
logger.error(f"Error in extract_ert: {e}")
entity_extraction_result = KGQuestionEntityExtractionResult(
entities=[], time_filter=""
entities=[],
terms=[],
time_filter="",
)

# remove the attribute filters from the entities to for the purpose of the relationship
@@ -214,9 +216,9 @@ def extract_ert(
cleaned_response
)
)
except ValidationError:
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Relationship Extraction"
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
relationship_extraction_result = KGQuestionRelationshipExtractionResult(
relationships=[],
@@ -251,10 +253,10 @@ Entities: {extracted_entity_string} - \n Relationships: {extracted_relationship_
extracted_entities_w_attributes=entity_extraction_result.entities,
extracted_entities_no_attributes=entities_no_attributes,
extracted_relationships=relationship_extraction_result.relationships,
extracted_terms=entity_extraction_result.terms,
time_filter=entity_extraction_result.time_filter,
kg_doc_temp_view_name=kg_views.allowed_docs_view_name,
kg_rel_temp_view_name=kg_views.kg_relationships_view_name,
kg_entity_temp_view_name=kg_views.kg_entity_view_name,
kg_doc_temp_view_name=allowed_docs_view_name,
kg_rel_temp_view_name=kg_relationships_view_name,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",

@@ -18,7 +18,6 @@ from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import YesNoEnum
@@ -27,10 +26,12 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.kg_configs import KG_STRATEGY_GENERATION_TIMEOUT
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_document_id_for_entity
from onyx.kg.clustering.normalizations import normalize_entities
from onyx.kg.clustering.normalizations import normalize_entities_w_attributes_from_map
from onyx.kg.clustering.normalizations import normalize_relationships
from onyx.kg.clustering.normalizations import normalize_terms
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.prompts.kg_prompts import STRATEGY_GENERATION_PROMPT
from onyx.utils.logger import setup_logger
@@ -146,6 +147,7 @@ def analyze(
state.extracted_entities_no_attributes
) # attribute knowledge is not required for this step
relationships = state.extracted_relationships
terms = state.extracted_terms
time_filter = state.time_filter

## STEP 2 - stream out goals
@@ -155,14 +157,18 @@ def analyze(
# Continue with node

normalized_entities = normalize_entities(
entities,
entities, allowed_docs_temp_view_name=state.kg_doc_temp_view_name
)

query_graph_entities_w_attributes = normalize_entities_w_attributes_from_map(
state.extracted_entities_w_attributes,
allowed_docs_temp_view_name=state.kg_doc_temp_view_name,
normalized_entities.entity_normalization_map,
)

normalized_relationships = normalize_relationships(
relationships, normalized_entities.entity_normalization_map
)
normalized_terms = normalize_terms(terms)
normalized_time_filter = time_filter

# If single-doc inquiry, send to single-doc processing directly
@@ -231,9 +237,6 @@ def analyze(
)
search_type = approach_extraction_result.search_type
search_strategy = approach_extraction_result.search_strategy
relationship_detection = (
approach_extraction_result.relationship_detection.value
)
output_format = approach_extraction_result.format
broken_down_question = approach_extraction_result.broken_down_question
divide_and_conquer = approach_extraction_result.divide_and_conquer
@@ -243,7 +246,6 @@ def analyze(
)
search_type = KGSearchType.SEARCH
search_strategy = KGAnswerStrategy.DEEP
relationship_detection = KGRelationshipDetection.RELATIONSHIPS.value
output_format = KGAnswerFormat.TEXT
broken_down_question = None
divide_and_conquer = YesNoEnum.NO
@@ -264,19 +266,6 @@ def analyze(
step_answer = f"Strategy and format have been extracted from query. Strategy: {search_strategy.value}, \
Format: {output_format.value}, Broken down question: {broken_down_question}"

extraction_detected_relationships = len(query_graph_relationships) > 0
if extraction_detected_relationships:
query_type = KGRelationshipDetection.RELATIONSHIPS.value

if extraction_detected_relationships:
logger.warning(
"Fyi - Extraction detected relationships: "
f"{extraction_detected_relationships}, "
f"but relationship detection: {relationship_detection}"
)
else:
query_type = KGRelationshipDetection.NO_RELATIONSHIPS.value

stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=step_answer)

stream_close_step_answer(writer, _KG_STEP_NR)
@@ -289,9 +278,9 @@ Format: {output_format.value}, Broken down question: {broken_down_question}"
entity_normalization_map=normalized_entities.entity_normalization_map,
relationship_normalization_map=normalized_relationships.relationship_normalization_map,
query_graph_entities_no_attributes=query_graph_entities,
query_graph_entities_w_attributes=normalized_entities.entities_w_attributes,
query_graph_entities_w_attributes=query_graph_entities_w_attributes,
query_graph_relationships=query_graph_relationships,
normalized_terms=[], # TODO: remove fully later
normalized_terms=normalized_terms.terms,
normalized_time_filter=normalized_time_filter,
strategy=search_strategy,
broken_down_question=broken_down_question,
@@ -299,7 +288,6 @@ Format: {output_format.value}, Broken down question: {broken_down_question}"
divide_and_conquer=divide_and_conquer,
single_doc_id=single_doc_id,
search_type=search_type,
query_type=query_type,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",

@@ -14,7 +14,6 @@ from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import SQLSimpleGenerationUpdate
@@ -22,18 +21,13 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from shared_configs.contextvars import get_current_tenant_id
from onyx.configs.kg_configs import KG_MAX_DEEP_SEARCH_RESULTS
from onyx.configs.kg_configs import KG_SQL_GENERATION_MAX_TOKENS
from onyx.configs.kg_configs import KG_SQL_GENERATION_TIMEOUT
from onyx.configs.kg_configs import KG_SQL_GENERATION_TIMEOUT_OVERRIDE
from onyx.configs.kg_configs import KG_TEMP_ALLOWED_DOCS_VIEW_NAME_PREFIX
from onyx.configs.kg_configs import KG_TEMP_KG_ENTITIES_VIEW_NAME_PREFIX
from onyx.configs.kg_configs import KG_TEMP_KG_RELATIONSHIPS_VIEW_NAME_PREFIX
from onyx.db.engine.sql_engine import get_db_readonly_user_session_with_current_tenant
from onyx.db.engine import get_db_readonly_user_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.kg_temp_view import drop_views
from onyx.llm.interfaces import LLM
from onyx.prompts.kg_prompts import ENTITY_SOURCE_DETECTION_PROMPT
from onyx.prompts.kg_prompts import SIMPLE_ENTITY_SQL_PROMPT
from onyx.prompts.kg_prompts import SIMPLE_SQL_CORRECTION_PROMPT
from onyx.prompts.kg_prompts import SIMPLE_SQL_PROMPT
from onyx.prompts.kg_prompts import SOURCE_DETECTION_PROMPT
@@ -44,65 +38,16 @@ from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()


def _raise_error_if_sql_fails_problem_test(
sql_statement: str, relationship_view_name: str, entity_view_name: str | None
) -> bool:
"""
Check if the SQL statement is valid.
"""

if entity_view_name is None:
raise ValueError("entity_view_name is not set for sql_statement")

authorized_user_specification = relationship_view_name[
len(KG_TEMP_KG_RELATIONSHIPS_VIEW_NAME_PREFIX) :
]

# remove the proper relationship and entity view names
base_sql_statement = sql_statement.replace(relationship_view_name, " ").replace(
entity_view_name, " "
)

# check whether other non-authorized relationship or entity viewnames are in sql_statement
if any(
view_name in base_sql_statement
for view_name in [
KG_TEMP_ALLOWED_DOCS_VIEW_NAME_PREFIX,
KG_TEMP_KG_RELATIONSHIPS_VIEW_NAME_PREFIX,
KG_TEMP_KG_ENTITIES_VIEW_NAME_PREFIX,
]
):
raise ValueError(
f"SQL statement would attempt to access unauthorized views: {sql_statement} for \
user specification {authorized_user_specification}"
def _drop_temp_views(
allowed_docs_view_name: str, kg_relationships_view_name: str
) -> None:
with get_session_with_current_tenant() as db_session:
drop_views(
db_session,
allowed_docs_view_name=allowed_docs_view_name,
kg_relationships_view_name=kg_relationships_view_name,
)

# check whether the sql statement would attempt to do unauthorized operations
# (the restrivtive db priviledges would preclude this anyway, but we check for safety and reporting)
UNAUTHORIZED_SQL_OPERATIONS = [
"INSERT ",
"UPDATE ",
"DELETE ",
"CREATE ",
"DROP ",
"ALTER ",
"TRUNCATE ",
"RENAME ",
"GRANT ",
"REVOKE ",
"DENY ",
]

if any(
operation.upper() in base_sql_statement.upper()
for operation in UNAUTHORIZED_SQL_OPERATIONS
):
raise ValueError(
f"SQL statement would attempt to do unauthorized operations: {sql_statement}"
)

return True


def _build_entity_explanation_str(entity_normalization_map: dict[str, str]) -> str:
"""
@@ -118,14 +63,13 @@ def _build_entity_explanation_str(entity_normalization_map: dict[str, str]) -> s
def _sql_is_aggregate_query(sql_statement: str) -> bool:
return any(
agg_func in sql_statement.upper()
for agg_func in ["COUNT(", "MAX(", "MIN(", "AVG(", "SUM(", "GROUP BY"]
for agg_func in ["COUNT(", "MAX(", "MIN(", "AVG(", "SUM("]
)


def _get_source_documents(
sql_statement: str,
llm: LLM,
focus: str | None,
allowed_docs_view_name: str,
kg_relationships_view_name: str,
) -> str | None:
@@ -133,13 +77,7 @@ def _get_source_documents(
Generate SQL to retrieve source documents based on the input sql statement.
"""

base_prompt = (
SOURCE_DETECTION_PROMPT
if (focus == KGRelationshipDetection.RELATIONSHIPS.value or focus is None)
else ENTITY_SOURCE_DETECTION_PROMPT
)

source_detection_prompt = base_prompt.replace(
source_detection_prompt = SOURCE_DETECTION_PROMPT.replace(
"---original_sql_statement---", sql_statement
)

@@ -155,8 +93,8 @@ def _get_source_documents(
KG_SQL_GENERATION_TIMEOUT,
llm.invoke,
prompt=msg,
timeout_override=KG_SQL_GENERATION_TIMEOUT_OVERRIDE,
max_tokens=KG_SQL_GENERATION_MAX_TOKENS,
timeout_override=25,
max_tokens=1200,
)

cleaned_response = (
@@ -166,11 +104,12 @@ def _get_source_documents(
sql_statement = sql_statement.split("</sql>")[0].strip()

except Exception as e:
error_msg = f"Could not generate source documents SQL: {e}"
if cleaned_response:
error_msg += f". Original model response: {cleaned_response}"

logger.error(error_msg)
if cleaned_response is not None:
logger.error(
f"Could not generate source documents SQL: {e}. Original model response: {cleaned_response}"
)
else:
logger.error(f"Could not generate source documents SQL: {e}")

return None

@@ -200,9 +139,6 @@ def generate_simple_sql(
if state.kg_rel_temp_view_name is None:
raise ValueError("kg_rel_temp_view_name is not set")

if state.kg_entity_temp_view_name is None:
raise ValueError("kg_entity_temp_view_name is not set")

## STEP 3 - articulate goals

stream_write_step_activities(writer, _KG_STEP_NR)
@@ -251,41 +187,30 @@ def generate_simple_sql(
# First, create string of contextualized entities to avoid the model not
# being aware of what eg ACCOUNT::SF_8254Hs means as a normalized entity

# TODO: restructure with broader node rework

entity_explanation_str = _build_entity_explanation_str(
state.entity_normalization_map
)

doc_temp_view = state.kg_doc_temp_view_name
rel_temp_view = state.kg_rel_temp_view_name
ent_temp_view = state.kg_entity_temp_view_name
current_tenant = get_current_tenant_id()

if state.query_type == KGRelationshipDetection.NO_RELATIONSHIPS.value:
simple_sql_prompt = (
SIMPLE_ENTITY_SQL_PROMPT.replace(
"---entity_types---", entities_types_str
)
.replace("---question---", question)
.replace("---entity_explanation_string---", entity_explanation_str)
current_tenant_view_name = f'"{current_tenant}".{state.kg_doc_temp_view_name}'
current_tenant_rel_view_name = f'"{current_tenant}".{state.kg_rel_temp_view_name}'

simple_sql_prompt = (
SIMPLE_SQL_PROMPT.replace("---entity_types---", entities_types_str)
.replace("---relationship_types---", relationship_types_str)
.replace("---question---", question)
.replace("---entity_explanation_string---", entity_explanation_str)
.replace(
"---query_entities_with_attributes---",
"\n".join(state.query_graph_entities_w_attributes),
)
else:
simple_sql_prompt = (
SIMPLE_SQL_PROMPT.replace("---entity_types---", entities_types_str)
.replace("---relationship_types---", relationship_types_str)
.replace("---question---", question)
.replace("---entity_explanation_string---", entity_explanation_str)
.replace(
"---query_entities_with_attributes---",
"\n".join(state.query_graph_entities_w_attributes),
)
.replace(
"---query_relationships---",
"\n".join(state.query_graph_relationships),
)
.replace("---today_date---", datetime.now().strftime("%Y-%m-%d"))
.replace("---user_name---", f"EMPLOYEE:{user_name}")
.replace(
"---query_relationships---", "\n".join(state.query_graph_relationships)
)
.replace("---today_date---", datetime.now().strftime("%Y-%m-%d"))
.replace("---user_name---", f"EMPLOYEE:{user_name}")
)

# prepare SQL query generation

@@ -302,8 +227,8 @@ def generate_simple_sql(
KG_SQL_GENERATION_TIMEOUT,
primary_llm.invoke,
prompt=msg,
timeout_override=KG_SQL_GENERATION_TIMEOUT_OVERRIDE,
max_tokens=KG_SQL_GENERATION_MAX_TOKENS,
timeout_override=25,
max_tokens=1500,
)

cleaned_response = (
@@ -314,8 +239,9 @@ def generate_simple_sql(
)
sql_statement = sql_statement.split(";")[0].strip() + ";"
sql_statement = sql_statement.replace("sql", "").strip()
sql_statement = sql_statement.replace("relationship_table", rel_temp_view)
sql_statement = sql_statement.replace("entity_table", ent_temp_view)
sql_statement = sql_statement.replace(
"kg_relationship", current_tenant_rel_view_name
)

reasoning = (
cleaned_response.split("<reasoning>")[1]
@@ -324,101 +250,75 @@ def generate_simple_sql(
)

except Exception as e:
# TODO: restructure with broader node rework
logger.error(f"Error in SQL generation: {e}")
logger.error(f"Error in strategy generation: {e}")

drop_views(
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
kg_entity_view_name=ent_temp_view,
_drop_temp_views(
allowed_docs_view_name=current_tenant_view_name,
kg_relationships_view_name=current_tenant_rel_view_name,
)
raise e

if state.query_type == KGRelationshipDetection.RELATIONSHIPS.value:
# Correction if needed:
logger.debug(f"A3 - sql_statement: {sql_statement}")

correction_prompt = SIMPLE_SQL_CORRECTION_PROMPT.replace(
"---draft_sql---", sql_statement
# Correction if needed:

correction_prompt = SIMPLE_SQL_CORRECTION_PROMPT.replace(
"---draft_sql---", sql_statement
)

msg = [
HumanMessage(
content=correction_prompt,
)
]

try:
llm_response = run_with_timeout(
KG_SQL_GENERATION_TIMEOUT,
primary_llm.invoke,
prompt=msg,
timeout_override=25,
max_tokens=1500,
)

msg = [
HumanMessage(
content=correction_prompt,
)
]
cleaned_response = (
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
)

try:
llm_response = run_with_timeout(
KG_SQL_GENERATION_TIMEOUT,
primary_llm.invoke,
prompt=msg,
timeout_override=25,
max_tokens=1500,
)
sql_statement = (
cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
)
sql_statement = sql_statement.split(";")[0].strip() + ";"
sql_statement = sql_statement.replace("sql", "").strip()

cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
)
except Exception as e:
logger.error(
f"Error in generating the sql correction: {e}. Original model response: {cleaned_response}"
)

sql_statement = (
cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
)
sql_statement = sql_statement.split(";")[0].strip() + ";"
sql_statement = sql_statement.replace("sql", "").strip()
_drop_temp_views(
allowed_docs_view_name=current_tenant_view_name,
kg_relationships_view_name=current_tenant_rel_view_name,
)

except Exception as e:
logger.error(
f"Error in generating the sql correction: {e}. Original model response: {cleaned_response}"
)

drop_views(
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
kg_entity_view_name=ent_temp_view,
)

raise e
raise e

logger.debug(f"A3 - sql_statement after correction: {sql_statement}")

# Get SQL for source documents

source_documents_sql = None
source_documents_sql = _get_source_documents(
sql_statement,
llm=primary_llm,
allowed_docs_view_name=current_tenant_view_name,
kg_relationships_view_name=current_tenant_rel_view_name,
)

if (
state.query_type == KGRelationshipDetection.RELATIONSHIPS.value
or _sql_is_aggregate_query(sql_statement)
):
source_documents_sql = _get_source_documents(
sql_statement,
llm=primary_llm,
focus=state.query_type,
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
)

if source_documents_sql and ent_temp_view:
source_documents_sql = source_documents_sql.replace(
"entity_table", ent_temp_view
)

if source_documents_sql and rel_temp_view:
source_documents_sql = source_documents_sql.replace(
"relationship_table", rel_temp_view
)

logger.debug(f"A3 source_documents_sql: {source_documents_sql}")
logger.info(f"A3 source_documents_sql: {source_documents_sql}")

scalar_result = None
query_results = None

# check sql, just in case
_raise_error_if_sql_fails_problem_test(
sql_statement, rel_temp_view, ent_temp_view
)

with get_db_readonly_user_session_with_current_tenant() as db_session:
try:
result = db_session.execute(text(sql_statement))
@@ -440,16 +340,8 @@ def generate_simple_sql(
raise e

source_document_results = None

if source_documents_sql is not None and source_documents_sql != sql_statement:

# check source document sql, just in case
_raise_error_if_sql_fails_problem_test(
source_documents_sql, rel_temp_view, ent_temp_view
)

with get_db_readonly_user_session_with_current_tenant() as db_session:

try:
result = db_session.execute(text(source_documents_sql))
rows = result.fetchall()
@@ -460,29 +352,18 @@ def generate_simple_sql(
]
except Exception as e:
# No stopping here, the individualized SQL query is not mandatory
# TODO: raise error on frontend
logger.error(f"Error executing Individualized SQL query: {e}")
# individualized_query_results = None

else:
source_document_results = None

if state.query_type == KGRelationshipDetection.NO_RELATIONSHIPS.value:
# source documents should be returned for entity queries
source_document_results = [
x["source_document"]
for x in query_results
if "source_document" in x
]

else:
source_document_results = None

drop_views(
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
kg_entity_view_name=ent_temp_view,
_drop_temp_views(
allowed_docs_view_name=state.kg_doc_temp_view_name,
kg_relationships_view_name=state.kg_rel_temp_view_name,
)

logger.debug(f"A3 - Number of query_results: {len(query_results)}")
logger.info(f"A3 - Number of query_results: {len(query_results)}")

# Stream out reasoning and SQL query

Some files were not shown because too many files have changed in this diff.