Compare commits


6 Commits

Author SHA1 Message Date
joachim-danswer 0e38898c7a fixed migration 2025-06-06 16:05:01 -07:00
joachim-danswer ce6a597eca mt-test observations/fixes 2025-06-06 13:42:27 -07:00
joachim-danswer d251ba40ae update migration 2025-06-06 08:52:22 -07:00
joachim-danswer 26395d81c9 path fix 2025-06-05 22:16:39 -07:00
joachim-danswer e1a3e11ec9 github error correction 1 2025-06-05 22:11:14 -07:00
joachim-danswer e013711664 Initial Knowledge Graph Implementation, including:
  - private schema upgrade
  - migration of tenant schema
  - extraction & clustering for KG
  - Graph for KG Answers
  2025-06-05 19:46:50 -07:00
500 changed files with 8688 additions and 20595 deletions

View File

@@ -1,86 +0,0 @@
name: External Dependency Unit Tests
on:
merge_group:
pull_request:
branches: [main]
env:
# AWS
S3_AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}
S3_AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }}
# MinIO
S3_ENDPOINT_URL: "http://localhost:9004"
jobs:
discover-test-dirs:
runs-on: ubuntu-latest
outputs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Discover test directories
id: set-matrix
run: |
# Find all subdirectories in backend/tests/external_dependency_unit
dirs=$(find backend/tests/external_dependency_unit -mindepth 1 -maxdepth 1 -type d -exec basename {} \; | sort | jq -R -s -c 'split("\n")[:-1]')
echo "test-dirs=$dirs" >> $GITHUB_OUTPUT
external-dependency-unit-tests:
needs: discover-test-dirs
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
strategy:
fail-fast: false
matrix:
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
env:
PYTHONPATH: ./backend
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: "pip"
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
playwright install chromium
playwright install-deps chromium
- name: Set up Standard Dependencies
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack up -d minio relational_db cache index
- name: Run migrations
run: |
cd backend
alembic upgrade head
- name: Run Tests for ${{ matrix.test-dir }}
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test \
-n 8 \
--dist loadfile \
--durations=8 \
-o junit_family=xunit2 \
-xv \
--ff \
backend/tests/external_dependency_unit/${{ matrix.test-dir }}

View File

@@ -16,21 +16,12 @@ env:
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
PLATFORM_PAIR: linux-amd64
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
runs-on:
[
runs-on,
runner=32cpu-linux-x64,
disk=large,
"run-id=${{ github.run_id }}",
]
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -269,9 +260,6 @@ jobs:
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \

View File

@@ -16,20 +16,11 @@ env:
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
PLATFORM_PAIR: linux-amd64
jobs:
integration-tests-mit:
# See https://runs-on.com/runners/linux/
runs-on:
[
runs-on,
runner=32cpu-linux-x64,
disk=large,
"run-id=${{ github.run_id }}",
]
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -204,9 +195,6 @@ jobs:
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \

View File

@@ -54,6 +54,11 @@ jobs:
cd backend
mypy .
- name: Run ruff
run: |
cd backend
ruff .
- name: Check import order with reorder-python-imports
run: |
cd backend

View File

@@ -22,7 +22,6 @@ env:
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
# Jira
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
@@ -50,9 +49,6 @@ env:
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
# Hubspot
HUBSPOT_ACCESS_TOKEN: ${{ secrets.HUBSPOT_ACCESS_TOKEN }}
# Airtable
AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}

View File

@@ -15,9 +15,6 @@ jobs:
env:
PYTHONPATH: ./backend
REDIS_CLOUD_PYTEST_PASSWORD: ${{ secrets.REDIS_CLOUD_PYTEST_PASSWORD }}
SF_USERNAME: ${{ secrets.SF_USERNAME }}
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
steps:
- name: Checkout code

View File

@@ -58,9 +58,3 @@ AGENT_RETRIEVAL_STATS=False # Note: This setting will incur substantial re-ran
AGENT_RERANKING_STATS=True
AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20
# S3 File Store Configuration (MinIO for local development)
S3_ENDPOINT_URL=http://localhost:9004
S3_FILE_STORE_BUCKET_NAME=onyx-file-store-bucket
S3_AWS_ACCESS_KEY_ID=minioadmin
S3_AWS_SECRET_ACCESS_KEY=minioadmin

View File

@@ -77,9 +77,6 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Install postgresql-client for easy manual tests
# Install it here to avoid it being cleaned up above
RUN apt-get update && apt-get install -y postgresql-client
# Pre-downloading models for setups with limited egress
RUN python -c "from tokenizers import Tokenizer; \

View File

@@ -20,44 +20,3 @@ To run all un-applied migrations:
To undo migrations:
`alembic downgrade -X`
where X is the number of migrations you want to undo from the current state
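For example, to undo just the most recent migration:
```bash
alembic downgrade -1
```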
### Multi-tenant migrations
For multi-tenant deployments, you can use additional options:
**Upgrade all tenants:**
```bash
alembic -x upgrade_all_tenants=true upgrade head
```
**Upgrade specific schemas:**
```bash
# Single schema
alembic -x schemas=tenant_12345678-1234-1234-1234-123456789012 upgrade head
# Multiple schemas (comma-separated)
alembic -x schemas=tenant_12345678-1234-1234-1234-123456789012,public,another_tenant upgrade head
```
**Upgrade tenants within an alphabetical range:**
```bash
# Upgrade tenants 100-200 when sorted alphabetically (positions 100 to 200)
alembic -x upgrade_all_tenants=true -x tenant_range_start=100 -x tenant_range_end=200 upgrade head
# Upgrade tenants starting from position 1000 alphabetically
alembic -x upgrade_all_tenants=true -x tenant_range_start=1000 upgrade head
# Upgrade first 500 tenants alphabetically
alembic -x upgrade_all_tenants=true -x tenant_range_end=500 upgrade head
```
**Continue on error (for batch operations):**
```bash
alembic -x upgrade_all_tenants=true -x continue=true upgrade head
```
The tenant range filtering works by:
1. Sorting tenant IDs alphabetically
2. Using 1-based position numbers (1st, 2nd, 3rd tenant, etc.)
3. Filtering to the specified range of positions
4. Always including non-tenant schemas (like 'public')
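Putting these options together, a batch run that migrates the first 100 tenants alphabetically and keeps going past individual failures might look like:
```bash
alembic -x upgrade_all_tenants=true -x tenant_range_start=1 -x tenant_range_end=100 -x continue=true upgrade head
```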

View File

@@ -1,12 +1,12 @@
from typing import Any, Literal
from onyx.db.engine.iam_auth import get_iam_auth_token
from onyx.db.engine import get_iam_auth_token
from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import AWS_REGION_NAME
from onyx.db.engine.sql_engine import build_connection_string
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_all_tenant_ids
from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy import text
@@ -21,14 +21,10 @@ from alembic import context
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql.schema import SchemaItem
from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import (
MULTI_TENANT,
POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE,
TENANT_ID_PREFIX,
)
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
from onyx.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.engine import SqlEngine
# Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be
# hidden! (defaults to level=WARN)
@@ -73,67 +69,15 @@ def include_object(
return True
def filter_tenants_by_range(
tenant_ids: list[str], start_range: int | None = None, end_range: int | None = None
) -> list[str]:
"""
Filter tenant IDs by alphabetical position range.
Args:
tenant_ids: List of tenant IDs to filter
start_range: Starting position in alphabetically sorted list (1-based, inclusive)
end_range: Ending position in alphabetically sorted list (1-based, inclusive)
Returns:
Filtered list of tenant IDs in their original order
"""
if start_range is None and end_range is None:
return tenant_ids
# Separate tenant IDs from non-tenant schemas
tenant_schemas = [tid for tid in tenant_ids if tid.startswith(TENANT_ID_PREFIX)]
non_tenant_schemas = [
tid for tid in tenant_ids if not tid.startswith(TENANT_ID_PREFIX)
]
# Sort tenant schemas alphabetically.
# NOTE: can cause missed schemas if a schema is created in between workers'
# fetching of all tenant IDs. We accept this risk for now. Just re-running
# the migration will fix the issue.
sorted_tenant_schemas = sorted(tenant_schemas)
# Apply range filtering (0-based indexing)
start_idx = start_range if start_range is not None else 0
end_idx = end_range if end_range is not None else len(sorted_tenant_schemas)
# Ensure indices are within bounds
start_idx = max(0, start_idx)
end_idx = min(len(sorted_tenant_schemas), end_idx)
# Get the filtered tenant schemas
filtered_tenant_schemas = sorted_tenant_schemas[start_idx:end_idx]
# Combine with non-tenant schemas and preserve original order
filtered_tenants = []
for tenant_id in tenant_ids:
if tenant_id in filtered_tenant_schemas or tenant_id in non_tenant_schemas:
filtered_tenants.append(tenant_id)
return filtered_tenants
def get_schema_options() -> (
tuple[bool, bool, bool, int | None, int | None, list[str] | None]
):
def get_schema_options() -> tuple[str, bool, bool, bool]:
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
if "=" in arg:
key, value = arg.split("=", 1)
x_args[key.strip()] = value.strip()
else:
raise ValueError(f"Invalid argument: {arg}")
for pair in arg.split(","):
if "=" in pair:
key, value = pair.split("=", 1)
x_args[key.strip()] = value.strip()
schema_name = x_args.get("schema", POSTGRES_DEFAULT_SCHEMA)
create_schema = x_args.get("create_schema", "true").lower() == "true"
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
@@ -141,81 +85,17 @@ def get_schema_options() -> (
# only applies to online migrations
continue_on_error = x_args.get("continue", "false").lower() == "true"
# Tenant range filtering
tenant_range_start = None
tenant_range_end = None
if "tenant_range_start" in x_args:
try:
tenant_range_start = int(x_args["tenant_range_start"])
except ValueError:
raise ValueError(
f"Invalid tenant_range_start value: {x_args['tenant_range_start']}. Must be an integer."
)
if "tenant_range_end" in x_args:
try:
tenant_range_end = int(x_args["tenant_range_end"])
except ValueError:
raise ValueError(
f"Invalid tenant_range_end value: {x_args['tenant_range_end']}. Must be an integer."
)
# Validate range
if tenant_range_start is not None and tenant_range_end is not None:
if tenant_range_start > tenant_range_end:
raise ValueError(
f"tenant_range_start ({tenant_range_start}) cannot be greater than tenant_range_end ({tenant_range_end})"
)
# Specific schema names filtering (replaces both schema_name and the old tenant_ids approach)
schemas = None
if "schemas" in x_args:
schema_names_str = x_args["schemas"].strip()
if schema_names_str:
# Split by comma and strip whitespace
schemas = [
name.strip() for name in schema_names_str.split(",") if name.strip()
]
if schemas:
logger.info(f"Specific schema names specified: {schemas}")
# Validate that only one method is used at a time
range_filtering = tenant_range_start is not None or tenant_range_end is not None
specific_filtering = schemas is not None and len(schemas) > 0
if range_filtering and specific_filtering:
if (
MULTI_TENANT
and schema_name == POSTGRES_DEFAULT_SCHEMA
and not upgrade_all_tenants
):
raise ValueError(
"Cannot use both tenant range filtering (tenant_range_start/tenant_range_end) "
"and specific schema filtering (schemas) at the same time. "
"Please use only one filtering method."
"Cannot run default migrations in public schema when multi-tenancy is enabled. "
"Please specify a tenant-specific schema."
)
if upgrade_all_tenants and specific_filtering:
raise ValueError(
"Cannot use both upgrade_all_tenants=true and schemas at the same time. "
"Use either upgrade_all_tenants=true for all tenants, or schemas for specific schemas."
)
# If any filtering parameters are specified, we're not doing the default single schema migration
if range_filtering:
upgrade_all_tenants = True
# Validate multi-tenant requirements
if MULTI_TENANT and not upgrade_all_tenants and not specific_filtering:
raise ValueError(
"In multi-tenant mode, you must specify either upgrade_all_tenants=true "
"or provide schemas. Cannot run default migration."
)
return (
create_schema,
upgrade_all_tenants,
continue_on_error,
tenant_range_start,
tenant_range_end,
schemas,
)
return schema_name, create_schema, upgrade_all_tenants, continue_on_error
def do_run_migrations(
@@ -262,17 +142,12 @@ def provide_iam_token_for_alembic(
async def run_async_migrations() -> None:
(
schema_name,
create_schema,
upgrade_all_tenants,
continue_on_error,
tenant_range_start,
tenant_range_end,
schemas,
) = get_schema_options()
if not schemas and not MULTI_TENANT:
schemas = [POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE]
# without init_engine, subsequent engine calls fail hard intentionally
SqlEngine.init_engine(pool_size=20, max_overflow=5)
@@ -289,50 +164,12 @@ async def run_async_migrations() -> None:
) -> None:
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
if schemas:
# Use specific schema names directly without fetching all tenants
logger.info(f"Migrating specific schema names: {schemas}")
i_schema = 0
num_schemas = len(schemas)
for schema in schemas:
i_schema += 1
logger.info(
f"Migrating schema: index={i_schema} num_schemas={num_schemas} schema={schema}"
)
try:
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema,
create_schema=create_schema,
)
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
if not continue_on_error:
logger.error("--continue=true is not set, raising exception!")
raise
logger.warning("--continue=true is set, continuing to next schema.")
elif upgrade_all_tenants:
if upgrade_all_tenants:
tenant_schemas = get_all_tenant_ids()
filtered_tenant_schemas = filter_tenants_by_range(
tenant_schemas, tenant_range_start, tenant_range_end
)
if tenant_range_start is not None or tenant_range_end is not None:
logger.info(
f"Filtering tenants by range: start={tenant_range_start}, end={tenant_range_end}"
)
logger.info(
f"Total tenants: {len(tenant_schemas)}, Filtered tenants: {len(filtered_tenant_schemas)}"
)
i_tenant = 0
num_tenants = len(filtered_tenant_schemas)
for schema in filtered_tenant_schemas:
num_tenants = len(tenant_schemas)
for schema in tenant_schemas:
i_tenant += 1
logger.info(
f"Migrating schema: index={i_tenant} num_tenants={num_tenants} schema={schema}"
@@ -353,13 +190,17 @@ async def run_async_migrations() -> None:
logger.warning("--continue=true is set, continuing to next schema.")
else:
# This should not happen in the new design since we require either
# upgrade_all_tenants=true or schemas in multi-tenant mode
# and for non-multi-tenant mode, we should use schemas with the default schema
raise ValueError(
"No migration target specified. Use either upgrade_all_tenants=true for all tenants "
"or schemas for specific schemas."
)
try:
logger.info(f"Migrating schema: {schema_name}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema_name,
create_schema=create_schema,
)
except Exception as e:
logger.error(f"Error migrating schema {schema_name}: {e}")
raise
await engine.dispose()
@@ -380,37 +221,10 @@ def run_migrations_offline() -> None:
# without init_engine, subsequent engine calls fail hard intentionally
SqlEngine.init_engine(pool_size=20, max_overflow=5)
(
create_schema,
upgrade_all_tenants,
continue_on_error,
tenant_range_start,
tenant_range_end,
schemas,
) = get_schema_options()
schema_name, _, upgrade_all_tenants, continue_on_error = get_schema_options()
url = build_connection_string()
if schemas:
# Use specific schema names directly without fetching all tenants
logger.info(f"Migrating specific schema names: {schemas}")
for schema in schemas:
logger.info(f"Migrating schema: {schema}")
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
version_table_schema=schema,
include_schemas=True,
script_location=config.get_main_option("script_location"),
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
elif upgrade_all_tenants:
if upgrade_all_tenants:
engine = create_async_engine(url)
if USE_IAM_AUTH:
@@ -424,19 +238,7 @@ def run_migrations_offline() -> None:
tenant_schemas = get_all_tenant_ids()
engine.sync_engine.dispose()
filtered_tenant_schemas = filter_tenants_by_range(
tenant_schemas, tenant_range_start, tenant_range_end
)
if tenant_range_start is not None or tenant_range_end is not None:
logger.info(
f"Filtering tenants by range: start={tenant_range_start}, end={tenant_range_end}"
)
logger.info(
f"Total tenants: {len(tenant_schemas)}, Filtered tenants: {len(filtered_tenant_schemas)}"
)
for schema in filtered_tenant_schemas:
for schema in tenant_schemas:
logger.info(f"Migrating schema: {schema}")
context.configure(
url=url,
@@ -452,12 +254,21 @@ def run_migrations_offline() -> None:
with context.begin_transaction():
context.run_migrations()
else:
# This should not happen in the new design
raise ValueError(
"No migration target specified. Use either upgrade_all_tenants=true for all tenants "
"or schemas for specific schemas."
logger.info(f"Migrating schema: {schema_name}")
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
version_table_schema=schema_name,
include_schemas=True,
script_location=config.get_main_option("script_location"),
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
logger.info("run_migrations_online starting.")

View File

@@ -1,121 +0,0 @@
"""rework-kg-config
Revision ID: 03bf8be6b53a
Revises: 65bc6e0f8500
Create Date: 2025-06-16 10:52:34.815335
"""
import json
from datetime import datetime
from datetime import timedelta
from sqlalchemy.dialects import postgresql
from sqlalchemy import text
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "03bf8be6b53a"
down_revision = "65bc6e0f8500"
branch_labels = None
depends_on = None
def upgrade() -> None:
# get current config
current_configs = (
op.get_bind()
.execute(text("SELECT kg_variable_name, kg_variable_values FROM kg_config"))
.all()
)
current_config_dict = {
config.kg_variable_name: (
config.kg_variable_values[0]
if config.kg_variable_name
not in ("KG_VENDOR_DOMAINS", "KG_IGNORE_EMAIL_DOMAINS")
else config.kg_variable_values
)
for config in current_configs
if config.kg_variable_values
}
# not using the KGConfigSettings model here in case it changes in the future
kg_config_settings = json.dumps(
{
"KG_EXPOSED": current_config_dict.get("KG_EXPOSED", False),
"KG_ENABLED": current_config_dict.get("KG_ENABLED", False),
"KG_VENDOR": current_config_dict.get("KG_VENDOR", None),
"KG_VENDOR_DOMAINS": current_config_dict.get("KG_VENDOR_DOMAINS", []),
"KG_IGNORE_EMAIL_DOMAINS": current_config_dict.get(
"KG_IGNORE_EMAIL_DOMAINS", []
),
"KG_COVERAGE_START": current_config_dict.get(
"KG_COVERAGE_START",
(datetime.now() - timedelta(days=90)).strftime("%Y-%m-%d"),
),
"KG_MAX_COVERAGE_DAYS": current_config_dict.get("KG_MAX_COVERAGE_DAYS", 90),
"KG_MAX_PARENT_RECURSION_DEPTH": current_config_dict.get(
"KG_MAX_PARENT_RECURSION_DEPTH", 2
),
"KG_BETA_PERSONA_ID": current_config_dict.get("KG_BETA_PERSONA_ID", None),
}
)
op.execute(
f"INSERT INTO key_value_store (key, value) VALUES ('kg_config', '{kg_config_settings}')"
)
# drop kg config table
op.drop_table("kg_config")
def downgrade() -> None:
# get current config
current_config_dict = {
"KG_EXPOSED": False,
"KG_ENABLED": False,
"KG_VENDOR": [],
"KG_VENDOR_DOMAINS": [],
"KG_IGNORE_EMAIL_DOMAINS": [],
"KG_COVERAGE_START": (datetime.now() - timedelta(days=90)).strftime("%Y-%m-%d"),
"KG_MAX_COVERAGE_DAYS": 90,
"KG_MAX_PARENT_RECURSION_DEPTH": 2,
}
current_configs = (
op.get_bind()
.execute(text("SELECT value FROM key_value_store WHERE key = 'kg_config'"))
.one_or_none()
)
if current_configs is not None:
current_config_dict.update(current_configs[0])
insert_values = [
{
"kg_variable_name": name,
"kg_variable_values": (
[str(val).lower() if isinstance(val, bool) else str(val)]
if not isinstance(val, list)
else val
),
}
for name, val in current_config_dict.items()
]
op.create_table(
"kg_config",
sa.Column("id", sa.Integer(), primary_key=True, nullable=False, index=True),
sa.Column("kg_variable_name", sa.String(), nullable=False, index=True),
sa.Column("kg_variable_values", postgresql.ARRAY(sa.String()), nullable=False),
sa.UniqueConstraint("kg_variable_name", name="uq_kg_config_variable_name"),
)
op.bulk_insert(
sa.table(
"kg_config",
sa.column("kg_variable_name", sa.String),
sa.column("kg_variable_values", postgresql.ARRAY(sa.String)),
),
insert_values,
)
op.execute("DELETE FROM key_value_store WHERE key = 'kg_config'")

View File

@@ -1,136 +0,0 @@
"""update_kg_trigger_functions
Revision ID: 36e9220ab794
Revises: c9e2cd766c29
Create Date: 2025-06-22 17:33:25.833733
"""
from alembic import op
from sqlalchemy.orm import Session
from sqlalchemy import text
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
# revision identifiers, used by Alembic.
revision = "36e9220ab794"
down_revision = "c9e2cd766c29"
branch_labels = None
depends_on = None
def _get_tenant_contextvar(session: Session) -> str:
"""Get the current schema for the migration"""
current_tenant = session.execute(text("SELECT current_schema()")).scalar()
if isinstance(current_tenant, str):
return current_tenant
else:
raise ValueError("Current tenant is not a string")
def upgrade() -> None:
bind = op.get_bind()
session = Session(bind=bind)
# Create kg_entity trigger to update kg_entity.name and its trigrams
tenant_id = _get_tenant_contextvar(session)
alphanum_pattern = r"[^a-z0-9]+"
truncate_length = 1000
function = "update_kg_entity_name"
op.execute(
text(
f"""
CREATE OR REPLACE FUNCTION "{tenant_id}".{function}()
RETURNS TRIGGER AS $$
DECLARE
name text;
cleaned_name text;
BEGIN
-- Set name to semantic_id if document_id is not NULL
IF NEW.document_id IS NOT NULL THEN
SELECT lower(semantic_id) INTO name
FROM "{tenant_id}".document
WHERE id = NEW.document_id;
ELSE
name = lower(NEW.name);
END IF;
-- Clean name and truncate if too long
cleaned_name = regexp_replace(
name,
'{alphanum_pattern}', '', 'g'
);
IF length(cleaned_name) > {truncate_length} THEN
cleaned_name = left(cleaned_name, {truncate_length});
END IF;
-- Set name and name trigrams
NEW.name = name;
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
"""
)
)
trigger = f"{function}_trigger"
op.execute(f'DROP TRIGGER IF EXISTS {trigger} ON "{tenant_id}".kg_entity')
op.execute(
f"""
CREATE TRIGGER {trigger}
BEFORE INSERT OR UPDATE OF name
ON "{tenant_id}".kg_entity
FOR EACH ROW
EXECUTE FUNCTION "{tenant_id}".{function}();
"""
)
# Create kg_entity trigger to update kg_entity.name and its trigrams
function = "update_kg_entity_name_from_doc"
op.execute(
text(
f"""
CREATE OR REPLACE FUNCTION "{tenant_id}".{function}()
RETURNS TRIGGER AS $$
DECLARE
doc_name text;
cleaned_name text;
BEGIN
doc_name = lower(NEW.semantic_id);
-- Clean name and truncate if too long
cleaned_name = regexp_replace(
doc_name,
'{alphanum_pattern}', '', 'g'
);
IF length(cleaned_name) > {truncate_length} THEN
cleaned_name = left(cleaned_name, {truncate_length});
END IF;
-- Set name and name trigrams for all entities referencing this document
UPDATE "{tenant_id}".kg_entity
SET
name = doc_name,
name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name)
WHERE document_id = NEW.id;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
"""
)
)
trigger = f"{function}_trigger"
op.execute(f'DROP TRIGGER IF EXISTS {trigger} ON "{tenant_id}".document')
op.execute(
f"""
CREATE TRIGGER {trigger}
AFTER UPDATE OF semantic_id
ON "{tenant_id}".document
FOR EACH ROW
EXECUTE FUNCTION "{tenant_id}".{function}();
"""
)
def downgrade() -> None:
pass

View File

@@ -21,14 +21,22 @@ depends_on = None
# an outage by creating an index without using CONCURRENTLY. This migration:
#
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
# 2. Adds indexes to both chat_message and chat_session tables for comprehensive search
# 3. Note: CONCURRENTLY was removed due to operational issues
# 2. Uses CONCURRENTLY for all index creation to prevent table locking
# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work
# (see: https://www.postgresql.org/docs/9.4/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY)
# (see: https://github.com/sqlalchemy/alembic/issues/277)
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
def upgrade() -> None:
# First, drop any existing indexes to avoid conflicts
op.execute("DROP INDEX IF EXISTS idx_chat_message_tsv;")
op.execute("DROP INDEX IF EXISTS idx_chat_session_desc_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
# Drop existing columns if they exist
@@ -44,9 +52,12 @@ def upgrade() -> None:
"""
)
# Commit the current transaction before creating concurrent indexes
op.execute("COMMIT")
op.execute(
"""
CREATE INDEX IF NOT EXISTS idx_chat_message_tsv
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
ON chat_message
USING GIN (message_tsv)
"""
@@ -61,9 +72,12 @@ def upgrade() -> None:
"""
)
# Commit again before creating the second concurrent index
op.execute("COMMIT")
op.execute(
"""
CREATE INDEX IF NOT EXISTS idx_chat_session_desc_tsv
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
ON chat_session
USING GIN (description_tsv)
"""
@@ -71,9 +85,12 @@ def upgrade() -> None:
def downgrade() -> None:
# Drop the indexes first
op.execute("DROP INDEX IF EXISTS idx_chat_message_tsv;")
op.execute("DROP INDEX IF EXISTS idx_chat_session_desc_tsv;")
# Drop the indexes first (use CONCURRENTLY for dropping too)
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
# Then drop the columns
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")

View File

@@ -15,7 +15,6 @@ from datetime import datetime, timedelta
from onyx.configs.app_configs import DB_READONLY_USER
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
# revision identifiers, used by Alembic.
@@ -34,8 +33,11 @@ def upgrade() -> None:
# Note: in order for the migration to run, the DB_READONLY_USER and DB_READONLY_PASSWORD
# environment variables MUST be set. Otherwise, an exception will be raised.
print("MULTI_TENANT: ", MULTI_TENANT)
if not MULTI_TENANT:
print("Single tenant mode")
# Enable pg_trgm extension if not already enabled
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
@@ -66,6 +68,7 @@ def upgrade() -> None:
)
# Grant usage on current schema to readonly user
op.execute(
text(
f"""
@@ -467,11 +470,11 @@ def upgrade() -> None:
# Create GIN index for clustering and normalization
op.execute(
"CREATE INDEX IF NOT EXISTS idx_kg_entity_clustering_trigrams "
f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.gin_trgm_ops)"
"CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_kg_entity_clustering_trigrams "
"ON kg_entity USING GIN (name public.gin_trgm_ops)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS idx_kg_entity_normalization_trigrams "
"CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_kg_entity_normalization_trigrams "
"ON kg_entity USING GIN (name_trigrams)"
)
@@ -508,7 +511,7 @@ def upgrade() -> None:
-- Set name and name trigrams
NEW.name = name;
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name);
NEW.name_trigrams = public.show_trgm(cleaned_name);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
@@ -553,7 +556,7 @@ def upgrade() -> None:
UPDATE kg_entity
SET
name = doc_name,
name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name)
name_trigrams = public.show_trgm(cleaned_name)
WHERE document_id = NEW.id;
RETURN NEW;
END;
@@ -573,9 +576,13 @@ def upgrade() -> None:
"""
)
def downgrade() -> None:
# Drop all views that start with 'kg_'
op.execute(
"""
@@ -625,8 +632,9 @@ def downgrade() -> None:
op.execute(f"DROP FUNCTION IF EXISTS {function}()")
# Drop index
op.execute("DROP INDEX IF EXISTS idx_kg_entity_clustering_trigrams")
op.execute("DROP INDEX IF EXISTS idx_kg_entity_normalization_trigrams")
op.execute("COMMIT") # Commit to allow CONCURRENTLY
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_kg_entity_clustering_trigrams")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_kg_entity_normalization_trigrams")
# Drop tables in reverse order of creation to handle dependencies
op.drop_table("kg_term")
@@ -643,21 +651,6 @@ def downgrade() -> None:
op.drop_column("document", "kg_processing_time")
op.drop_table("kg_config")
# Revoke usage on current schema for the readonly user
op.execute(
text(
f"""
DO $$
BEGIN
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
EXECUTE format('REVOKE ALL ON SCHEMA %I FROM %I', current_schema(), '{DB_READONLY_USER}');
END IF;
END
$$;
"""
)
)
if not MULTI_TENANT:
# Drop read-only db user here only in single tenant mode. For multi-tenant mode,
# the user is dropped in the alembic_tenants migration.
@@ -678,4 +671,20 @@ def downgrade() -> None:
"""
)
)
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
else:
op.execute(
text(
f"""
DO $$
BEGIN
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
EXECUTE format('REVOKE ALL ON SCHEMA %I FROM %I', current_schema(), '{DB_READONLY_USER}');
END IF;
END
$$;
"""
)
)

View File

@@ -1,90 +0,0 @@
"""add stale column to external user group tables
Revision ID: 58c50ef19f08
Revises: 7b9b952abdf6
Create Date: 2025-06-25 14:08:14.162380
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "58c50ef19f08"
down_revision = "7b9b952abdf6"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add the stale column with default value False to user__external_user_group_id
op.add_column(
"user__external_user_group_id",
sa.Column("stale", sa.Boolean(), nullable=False, server_default="false"),
)
# Create index for efficient querying of stale rows by cc_pair_id
op.create_index(
"ix_user__external_user_group_id_cc_pair_id_stale",
"user__external_user_group_id",
["cc_pair_id", "stale"],
unique=False,
)
# Create index for efficient querying of all stale rows
op.create_index(
"ix_user__external_user_group_id_stale",
"user__external_user_group_id",
["stale"],
unique=False,
)
# Add the stale column with default value False to public_external_user_group
op.add_column(
"public_external_user_group",
sa.Column("stale", sa.Boolean(), nullable=False, server_default="false"),
)
# Create index for efficient querying of stale rows by cc_pair_id
op.create_index(
"ix_public_external_user_group_cc_pair_id_stale",
"public_external_user_group",
["cc_pair_id", "stale"],
unique=False,
)
# Create index for efficient querying of all stale rows
op.create_index(
"ix_public_external_user_group_stale",
"public_external_user_group",
["stale"],
unique=False,
)
def downgrade() -> None:
# Drop the indices for public_external_user_group first
op.drop_index(
"ix_public_external_user_group_stale", table_name="public_external_user_group"
)
op.drop_index(
"ix_public_external_user_group_cc_pair_id_stale",
table_name="public_external_user_group",
)
# Drop the stale column from public_external_user_group
op.drop_column("public_external_user_group", "stale")
# Drop the indices for user__external_user_group_id
op.drop_index(
"ix_user__external_user_group_id_stale",
table_name="user__external_user_group_id",
)
op.drop_index(
"ix_user__external_user_group_id_cc_pair_id_stale",
table_name="user__external_user_group_id",
)
# Drop the stale column from user__external_user_group_id
op.drop_column("user__external_user_group_id", "stale")

View File

@@ -1,41 +0,0 @@
"""remove kg subtype from db
Revision ID: 65bc6e0f8500
Revises: cec7ec36c505
Create Date: 2025-06-13 10:04:27.705976
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "65bc6e0f8500"
down_revision = "cec7ec36c505"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_column("kg_entity", "entity_class")
op.drop_column("kg_entity", "entity_subtype")
op.drop_column("kg_entity_extraction_staging", "entity_class")
op.drop_column("kg_entity_extraction_staging", "entity_subtype")
def downgrade() -> None:
op.add_column(
"kg_entity_extraction_staging",
sa.Column("entity_subtype", sa.String(), nullable=True, index=True),
)
op.add_column(
"kg_entity_extraction_staging",
sa.Column("entity_class", sa.String(), nullable=True, index=True),
)
op.add_column(
"kg_entity", sa.Column("entity_subtype", sa.String(), nullable=True, index=True)
)
op.add_column(
"kg_entity", sa.Column("entity_class", sa.String(), nullable=True, index=True)
)

View File

@@ -1,318 +0,0 @@
"""update-entities
Revision ID: 7b9b952abdf6
Revises: 36e9220ab794
Create Date: 2025-06-23 20:24:08.139201
"""
import json
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7b9b952abdf6"
down_revision = "36e9220ab794"
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
# new entity type metadata_attribute_conversion
new_entity_type_conversion = {
"LINEAR": {
"team": {"name": "team", "keep": True, "implication_property": None},
"state": {"name": "state", "keep": True, "implication_property": None},
"priority": {
"name": "priority",
"keep": True,
"implication_property": None,
},
"estimate": {
"name": "estimate",
"keep": True,
"implication_property": None,
},
"created_at": {
"name": "created_at",
"keep": True,
"implication_property": None,
},
"started_at": {
"name": "started_at",
"keep": True,
"implication_property": None,
},
"completed_at": {
"name": "completed_at",
"keep": True,
"implication_property": None,
},
"due_date": {
"name": "due_date",
"keep": True,
"implication_property": None,
},
"creator": {
"name": "creator",
"keep": False,
"implication_property": {
"implied_entity_type": "from_email",
"implied_relationship_name": "is_creator_of",
},
},
"assignee": {
"name": "assignee",
"keep": False,
"implication_property": {
"implied_entity_type": "from_email",
"implied_relationship_name": "is_assignee_of",
},
},
},
"JIRA": {
"issuetype": {
"name": "subtype",
"keep": True,
"implication_property": None,
},
"status": {"name": "status", "keep": True, "implication_property": None},
"priority": {
"name": "priority",
"keep": True,
"implication_property": None,
},
"project_name": {
"name": "project",
"keep": True,
"implication_property": None,
},
"created": {
"name": "created_at",
"keep": True,
"implication_property": None,
},
"updated": {
"name": "updated_at",
"keep": True,
"implication_property": None,
},
"resolution_date": {
"name": "completed_at",
"keep": True,
"implication_property": None,
},
"duedate": {"name": "due_date", "keep": True, "implication_property": None},
"reporter_email": {
"name": "creator",
"keep": False,
"implication_property": {
"implied_entity_type": "from_email",
"implied_relationship_name": "is_creator_of",
},
},
"assignee_email": {
"name": "assignee",
"keep": False,
"implication_property": {
"implied_entity_type": "from_email",
"implied_relationship_name": "is_assignee_of",
},
},
"key": {"name": "key", "keep": True, "implication_property": None},
"parent": {"name": "parent", "keep": True, "implication_property": None},
},
"GITHUB_PR": {
"repo": {"name": "repository", "keep": True, "implication_property": None},
"state": {"name": "state", "keep": True, "implication_property": None},
"num_commits": {
"name": "num_commits",
"keep": True,
"implication_property": None,
},
"num_files_changed": {
"name": "num_files_changed",
"keep": True,
"implication_property": None,
},
"labels": {"name": "labels", "keep": True, "implication_property": None},
"merged": {"name": "merged", "keep": True, "implication_property": None},
"merged_at": {
"name": "merged_at",
"keep": True,
"implication_property": None,
},
"closed_at": {
"name": "closed_at",
"keep": True,
"implication_property": None,
},
"created_at": {
"name": "created_at",
"keep": True,
"implication_property": None,
},
"updated_at": {
"name": "updated_at",
"keep": True,
"implication_property": None,
},
"user": {
"name": "creator",
"keep": False,
"implication_property": {
"implied_entity_type": "from_email",
"implied_relationship_name": "is_creator_of",
},
},
"assignees": {
"name": "assignees",
"keep": False,
"implication_property": {
"implied_entity_type": "from_email",
"implied_relationship_name": "is_assignee_of",
},
},
},
"GITHUB_ISSUE": {
"repo": {"name": "repository", "keep": True, "implication_property": None},
"state": {"name": "state", "keep": True, "implication_property": None},
"labels": {"name": "labels", "keep": True, "implication_property": None},
"closed_at": {
"name": "closed_at",
"keep": True,
"implication_property": None,
},
"created_at": {
"name": "created_at",
"keep": True,
"implication_property": None,
},
"updated_at": {
"name": "updated_at",
"keep": True,
"implication_property": None,
},
"user": {
"name": "creator",
"keep": False,
"implication_property": {
"implied_entity_type": "from_email",
"implied_relationship_name": "is_creator_of",
},
},
"assignees": {
"name": "assignees",
"keep": False,
"implication_property": {
"implied_entity_type": "from_email",
"implied_relationship_name": "is_assignee_of",
},
},
},
"FIREFLIES": {},
"ACCOUNT": {},
"OPPORTUNITY": {
"name": {"name": "name", "keep": True, "implication_property": None},
"stage_name": {"name": "stage", "keep": True, "implication_property": None},
"type": {"name": "type", "keep": True, "implication_property": None},
"amount": {"name": "amount", "keep": True, "implication_property": None},
"fiscal_year": {
"name": "fiscal_year",
"keep": True,
"implication_property": None,
},
"fiscal_quarter": {
"name": "fiscal_quarter",
"keep": True,
"implication_property": None,
},
"is_closed": {
"name": "is_closed",
"keep": True,
"implication_property": None,
},
"close_date": {
"name": "close_date",
"keep": True,
"implication_property": None,
},
"probability": {
"name": "close_probability",
"keep": True,
"implication_property": None,
},
"created_date": {
"name": "created_at",
"keep": True,
"implication_property": None,
},
"last_modified_date": {
"name": "updated_at",
"keep": True,
"implication_property": None,
},
"account": {
"name": "account",
"keep": False,
"implication_property": {
"implied_entity_type": "ACCOUNT",
"implied_relationship_name": "is_account_of",
},
},
},
"VENDOR": {},
"EMPLOYEE": {},
}
current_entity_types = conn.execute(
sa.text("SELECT id_name, attributes from kg_entity_type")
).all()
for entity_type, attributes in current_entity_types:
# delete removed entity types
if entity_type not in new_entity_type_conversion:
op.execute(
sa.text(f"DELETE FROM kg_entity_type WHERE id_name = '{entity_type}'")
)
continue
# update entity type attributes
if "metadata_attributes" in attributes:
del attributes["metadata_attributes"]
attributes["metadata_attribute_conversion"] = new_entity_type_conversion[
entity_type
]
attributes_str = json.dumps(attributes).replace("'", "''")
op.execute(
sa.text(
f"UPDATE kg_entity_type SET attributes = '{attributes_str}'"
f"WHERE id_name = '{entity_type}'"
),
)
def downgrade() -> None:
conn = op.get_bind()
current_entity_types = conn.execute(
sa.text("SELECT id_name, attributes from kg_entity_type")
).all()
for entity_type, attributes in current_entity_types:
conversion = {}
if "metadata_attribute_conversion" in attributes:
conversion = attributes.pop("metadata_attribute_conversion")
attributes["metadata_attributes"] = {
attr: prop["name"] for attr, prop in conversion.items() if prop["keep"]
}
attributes_str = json.dumps(attributes).replace("'", "''")
op.execute(
sa.text(
f"UPDATE kg_entity_type SET attributes = '{attributes_str}'"
f"WHERE id_name = '{entity_type}'"
),
)

View File

@@ -1,312 +0,0 @@
"""modify_file_store_for_external_storage
Revision ID: c9e2cd766c29
Revises: 03bf8be6b53a
Create Date: 2025-06-13 14:02:09.867679
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.orm import Session
from sqlalchemy import text
from typing import cast, Any
from botocore.exceptions import ClientError
from onyx.db._deprecated.pg_file_store import delete_lobj_by_id, read_lobj
from onyx.file_store.file_store import get_s3_file_store
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
# revision identifiers, used by Alembic.
revision = "c9e2cd766c29"
down_revision = "03bf8be6b53a"
branch_labels = None
depends_on = None
def upgrade() -> None:
try:
# Modify existing file_store table to support external storage
op.rename_table("file_store", "file_record")
# Make lobj_oid nullable (for external storage files)
op.alter_column("file_record", "lobj_oid", nullable=True)
# Add external storage columns with generic names
op.add_column(
"file_record", sa.Column("bucket_name", sa.String(), nullable=True)
)
op.add_column(
"file_record", sa.Column("object_key", sa.String(), nullable=True)
)
# Add timestamps for tracking
op.add_column(
"file_record",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
)
op.add_column(
"file_record",
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
)
op.alter_column("file_record", "file_name", new_column_name="file_id")
except Exception as e:
if "does not exist" in str(e) or 'relation "file_store" does not exist' in str(
e
):
print(
f"Ran into error - {e}. Likely means we had a partial success in the past, continuing..."
)
else:
raise
print(
"External storage configured - migrating files from PostgreSQL to external storage..."
)
# if we fail midway through this, we'll have a partial success. Running the migration
# again should allow us to continue.
_migrate_files_to_external_storage()
print("File migration completed successfully!")
# Remove lobj_oid column
op.drop_column("file_record", "lobj_oid")
def downgrade() -> None:
"""Revert schema changes and migrate files from external storage back to PostgreSQL large objects."""
print(
"Reverting to PostgreSQL-backed file store migrating files from external storage …"
)
# 1. Ensure `lobj_oid` exists on the current `file_record` table (nullable for now).
op.add_column("file_record", sa.Column("lobj_oid", sa.Integer(), nullable=True))
# 2. Move content from external storage back into PostgreSQL large objects (table is still
# called `file_record` so application code continues to work during the copy).
try:
_migrate_files_to_postgres()
except Exception:
print("Error during downgrade migration, rolling back …")
op.drop_column("file_record", "lobj_oid")
raise
# 3. After migration every row should now have `lobj_oid` populated; mark it NOT NULL.
op.alter_column("file_record", "lobj_oid", nullable=False)
# 4. Remove columns that are only relevant to external storage.
op.drop_column("file_record", "updated_at")
op.drop_column("file_record", "created_at")
op.drop_column("file_record", "object_key")
op.drop_column("file_record", "bucket_name")
# 5. Rename `file_id` back to `file_name` (still on `file_record`).
op.alter_column("file_record", "file_id", new_column_name="file_name")
# 6. Finally, rename the table back to its original name expected by the legacy codebase.
op.rename_table("file_record", "file_store")
print(
"Downgrade migration completed files are now stored inside PostgreSQL again."
)
# -----------------------------------------------------------------------------
# Helper: migrate from external storage (S3/MinIO) back into PostgreSQL large objects
def _migrate_files_to_postgres() -> None:
"""Move any files whose content lives in external S3-compatible storage back into PostgreSQL.
The logic mirrors the *inverse* of `_migrate_files_to_external_storage` used on upgrade.
"""
# Obtain DB session from Alembic context
bind = op.get_bind()
session = Session(bind=bind)
# Fetch rows that have external storage pointers (bucket/object_key not NULL)
result = session.execute(
text(
"SELECT file_id, bucket_name, object_key FROM file_record "
"WHERE bucket_name IS NOT NULL AND object_key IS NOT NULL"
)
)
files_to_migrate = [row[0] for row in result.fetchall()]
total_files = len(files_to_migrate)
if total_files == 0:
print("No files found in external storage to migrate back to PostgreSQL.")
return
print(f"Found {total_files} files to migrate back to PostgreSQL large objects.")
_set_tenant_contextvar(session)
migrated_count = 0
# only create external store if we have files to migrate. This line
# makes it so we need to have S3/MinIO configured to run this migration.
external_store = get_s3_file_store(db_session=session)
for i, file_id in enumerate(files_to_migrate, 1):
print(f"Migrating file {i}/{total_files}: {file_id}")
# Read file content from external storage (always binary)
try:
file_io = external_store.read_file(
file_id=file_id, mode="b", use_tempfile=True
)
file_io.seek(0)
# Import lazily to avoid circular deps at Alembic runtime
from onyx.db._deprecated.pg_file_store import (
create_populate_lobj,
) # noqa: E402
# Create new Postgres large object and populate it
lobj_oid = create_populate_lobj(content=file_io, db_session=session)
# Update DB row: set lobj_oid, clear bucket/object_key
session.execute(
text(
"UPDATE file_record SET lobj_oid = :lobj_oid, bucket_name = NULL, "
"object_key = NULL WHERE file_id = :file_id"
),
{"lobj_oid": lobj_oid, "file_id": file_id},
)
except ClientError as e:
if "NoSuchKey" in str(e):
print(
f"File {file_id} not found in external storage. Deleting from database."
)
session.execute(
text("DELETE FROM file_record WHERE file_id = :file_id"),
{"file_id": file_id},
)
else:
raise
migrated_count += 1
print(f"✓ Successfully migrated file {i}/{total_files}: {file_id}")
# Flush the SQLAlchemy session so statements are sent to the DB, but **do not**
# commit the transaction. The surrounding Alembic migration will commit once
# the *entire* downgrade succeeds. This keeps the whole downgrade atomic and
# avoids leaving the database in a partially-migrated state if a later schema
# operation fails.
session.flush()
print(
f"Migration back to PostgreSQL completed: {migrated_count} files staged for commit."
)
def _migrate_files_to_external_storage() -> None:
"""Migrate files from PostgreSQL large objects to external storage"""
# Get database session
bind = op.get_bind()
session = Session(bind=bind)
external_store = get_s3_file_store(db_session=session)
# Find all files currently stored in PostgreSQL (lobj_oid is not null)
result = session.execute(
text(
"SELECT file_id FROM file_record WHERE lobj_oid IS NOT NULL "
"AND bucket_name IS NULL AND object_key IS NULL"
)
)
files_to_migrate = [row[0] for row in result.fetchall()]
total_files = len(files_to_migrate)
if total_files == 0:
print("No files found in PostgreSQL storage to migrate.")
return
print(f"Found {total_files} files to migrate from PostgreSQL to external storage.")
_set_tenant_contextvar(session)
migrated_count = 0
for i, file_id in enumerate(files_to_migrate, 1):
print(f"Migrating file {i}/{total_files}: {file_id}")
# Read file record to get metadata
file_record = session.execute(
text("SELECT * FROM file_record WHERE file_id = :file_id"),
{"file_id": file_id},
).fetchone()
if file_record is None:
print(f"File {file_id} not found in PostgreSQL storage.")
continue
lobj_id = cast(int, file_record.lobj_oid) # type: ignore
file_metadata = cast(Any, file_record.file_metadata) # type: ignore
# Read file content from PostgreSQL
try:
file_content = read_lobj(
lobj_id, db_session=session, mode="b", use_tempfile=True
)
except Exception as e:
if "large object" in str(e) and "does not exist" in str(e):
print(f"File {file_id} not found in PostgreSQL storage.")
continue
else:
raise
# Handle file_metadata type conversion: keep dicts as-is, otherwise try to
# coerce to a dict, falling back to None
if file_metadata is not None and not isinstance(file_metadata, dict):
try:
file_metadata = dict(file_record.file_metadata)  # type: ignore
except (TypeError, ValueError):
file_metadata = None
# Save to external storage (this will handle the database record update and cleanup)
# NOTE: this WILL .commit() the transaction.
external_store.save_file(
file_id=file_id,
content=file_content,
display_name=file_record.display_name,
file_origin=file_record.file_origin,
file_type=file_record.file_type,
file_metadata=file_metadata,
)
delete_lobj_by_id(lobj_id, db_session=session)
migrated_count += 1
print(f"✓ Successfully migrated file {i}/{total_files}: {file_id}")
# See note above flush but do **not** commit so the outer Alembic transaction
# controls atomicity.
session.flush()
print(
f"Migration completed: {migrated_count} files staged for commit to external storage."
)
def _set_tenant_contextvar(session: Session) -> None:
"""Set the tenant contextvar to the default schema"""
current_tenant = session.execute(text("SELECT current_schema()")).scalar()
print(f"Migrating files for tenant: {current_tenant}")
CURRENT_TENANT_ID_CONTEXTVAR.set(current_tenant)

View File

@@ -1,29 +0,0 @@
"""kgentity_parent
Revision ID: cec7ec36c505
Revises: 495cb26ce93e
Create Date: 2025-06-07 20:07:46.400770
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "cec7ec36c505"
down_revision = "495cb26ce93e"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"kg_entity",
sa.Column("parent_key", sa.String(), nullable=True, index=True),
)
# NOTE: you will have to reindex the KG after this migration as the parent_key will be null
def downgrade() -> None:
op.drop_column("kg_entity", "parent_key")

View File

@@ -11,7 +11,7 @@ import sqlalchemy as sa
import json
from onyx.configs.constants import DocumentSource
from onyx.connectors.jira.utils import extract_jira_project
from onyx.connectors.onyx_jira.utils import extract_jira_project
# revision identifiers, used by Alembic.

View File

@@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.schema import SchemaItem
from alembic import context
from onyx.db.engine.sql_engine import build_connection_string
from onyx.db.engine import build_connection_string
from onyx.db.models import PublicBase
# this is the Alembic Config object, which provides

View File

@@ -77,4 +77,3 @@ def downgrade() -> None:
"""
)
)
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))

View File

@@ -4,7 +4,10 @@ from ee.onyx.db.external_perm import fetch_external_groups_for_user
from ee.onyx.db.external_perm import fetch_public_external_group_ids
from ee.onyx.db.user_group import fetch_user_groups_for_documents
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
from ee.onyx.external_permissions.post_query_censoring import (
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from onyx.access.access import (
_get_access_for_documents as get_access_for_documents_without_groups,
)
@@ -15,10 +18,6 @@ from onyx.access.utils import prefix_user_group
from onyx.db.document import get_document_sources
from onyx.db.document import get_documents_by_ids
from onyx.db.models import User
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _get_access_for_document(
@@ -71,15 +70,9 @@ def _get_access_for_documents(
for document_id, non_ee_access in non_ee_access_dict.items():
document = doc_id_map[document_id]
source = doc_id_to_source_map.get(document_id)
if source is None:
logger.error(f"Document {document_id} has no source")
continue
perm_sync_config = get_source_perm_sync_config(source)
is_only_censored = (
perm_sync_config
and perm_sync_config.censoring_config is not None
and perm_sync_config.doc_sync_config is None
source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
and source not in DOC_PERMISSIONS_FUNC_MAP
)
ext_u_emails = (

View File

@@ -1,10 +1,12 @@
import csv
import io
from datetime import datetime
from datetime import timezone
from celery import shared_task
from celery import Task
from ee.onyx.background.task_name_builders import query_history_task_name
from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
@@ -16,10 +18,11 @@ from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import QueryHistoryType
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import TaskStatus
from onyx.db.tasks import delete_task_with_id
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import mark_task_as_started_with_id
from onyx.db.tasks import register_task
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
@@ -34,19 +37,13 @@ logger = setup_logger()
bind=True,
trail=False,
)
def export_query_history_task(
self: Task,
*,
start: datetime,
end: datetime,
start_time: datetime,
# Need to include the tenant_id since the TenantAwareTask needs this
tenant_id: str,
) -> None:
def export_query_history_task(self: Task, *, start: datetime, end: datetime) -> None:
if not self.request.id:
raise RuntimeError("No task id defined for this task; cannot identify it")
task_id = self.request.id
start_time = datetime.now(tz=timezone.utc)
stream = io.StringIO()
writer = csv.DictWriter(
stream,
@@ -56,9 +53,12 @@ def export_query_history_task(
with get_session_with_current_tenant() as db_session:
try:
mark_task_as_started_with_id(
register_task(
db_session=db_session,
task_name=query_history_task_name(start=start, end=end),
task_id=task_id,
status=TaskStatus.STARTED,
start_time=start_time,
)
snapshot_generator = fetch_and_process_chat_session_history(
@@ -92,6 +92,7 @@ def export_query_history_task(
try:
stream.seek(0)
get_default_file_store(db_session).save_file(
file_name=report_name,
content=stream,
display_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,
@@ -101,7 +102,6 @@ def export_query_history_task(
"end": end.isoformat(),
"start_time": start_time.isoformat(),
},
file_id=report_name,
)
delete_task_with_id(

View File

@@ -13,7 +13,7 @@ from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import TaskStatus
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import register_task

View File

@@ -20,36 +20,39 @@ from shared_configs.configs import MULTI_TENANT
ee_beat_system_tasks: list[dict] = []
ee_beat_task_templates: list[dict] = [
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
ee_beat_task_templates: list[dict] = []
ee_beat_task_templates.extend(
[
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
},
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
},
{
"name": "export-query-history-cleanup-task",
"task": OnyxCeleryTask.EXPORT_QUERY_HISTORY_CLEANUP_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.CSV_GENERATION,
{
"name": "export-query-history-cleanup-task",
"task": OnyxCeleryTask.EXPORT_QUERY_HISTORY_CLEANUP_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.CSV_GENERATION,
},
},
},
]
]
)
ee_tasks_to_schedule: list[dict] = []
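These template dicts follow the shape Celery beat expects for schedule entries ("task", "schedule", "options"). A standalone, hedged sketch of registering one such entry; the app name and task name are placeholders, not values from this codebase.
from datetime import timedelta
from celery import Celery
app = Celery("example_app")  # placeholder app name
app.conf.beat_schedule = {
    "autogenerate-usage-report": {
        "task": "autogenerate_usage_report_task",  # placeholder task name
        "schedule": timedelta(days=30),
        "options": {"priority": 5, "expires": 3600},
    },
}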

View File

@@ -6,7 +6,7 @@ from celery import shared_task
from ee.onyx.db.query_history import get_all_query_history_export_tasks
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import TaskStatus
from onyx.db.tasks import delete_task_with_id
from onyx.utils.logger import setup_logger

View File

@@ -13,7 +13,7 @@ from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.db.engine import get_all_tenant_ids
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST

View File

@@ -21,16 +21,20 @@ from tenacity import retry_if_exception
from tenacity import stop_after_delay
from tenacity import wait_random_exponential
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.document import upsert_document_external_perms
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from ee.onyx.external_permissions.sync_params import (
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from onyx.access.models import DocExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
@@ -48,8 +52,8 @@ from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
@@ -74,7 +78,6 @@ from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -89,24 +92,6 @@ LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
def _get_fence_validation_block_expiration() -> int:
"""
Compute the expiration time for the fence validation block signal.
Base expiration is 300 seconds, multiplied by the beat multiplier only in MULTI_TENANT mode.
"""
base_expiration = 300 # seconds
if not MULTI_TENANT:
return base_expiration
try:
beat_multiplier = OnyxRuntime.get_beat_multiplier()
except Exception:
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
return int(base_expiration * beat_multiplier)
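For a concrete sense of the scaling above: the 300-second base TTL for the block signal is stretched by the beat multiplier only in multi-tenant mode, replacing the previously hard-coded ex=300. A tiny self-contained sketch; the 4.0 multiplier is an assumed example value, not a documented default.
def _expiration_example(
    base: int = 300, beat_multiplier: float = 4.0, multi_tenant: bool = True
) -> int:
    # Mirrors the helper above; scale the base only in multi-tenant mode.
    return int(base * beat_multiplier) if multi_tenant else base
assert _expiration_example() == 1200                    # multi-tenant: scaled TTL
assert _expiration_example(multi_tenant=False) == 300   # single-tenant: fixed TTL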
"""Jobs / utils for kicking off doc permissions sync tasks."""
@@ -120,29 +105,16 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
sync_config = get_source_perm_sync_config(cc_pair.connector.source)
if sync_config is None:
logger.error(f"No sync config found for {cc_pair.connector.source}")
return False
if sync_config.doc_sync_config is None:
logger.error(f"No doc sync config found for {cc_pair.connector.source}")
return False
# if indexing also does perm sync, don't start running doc_sync until at
# least one indexing is done
if (
sync_config.doc_sync_config.initial_index_should_sync
and cc_pair.last_successful_index_time is None
):
return False
# If the last sync is None, it has never been run so we run the sync
last_perm_sync = cc_pair.last_time_perm_sync
if last_perm_sync is None:
return True
source_sync_period = sync_config.doc_sync_config.doc_sync_frequency
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
if not source_sync_period:
source_sync_period = DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
source_sync_period *= int(OnyxRuntime.get_doc_permission_sync_multiplier())
# If the last sync is greater than the full fetch period, we run the sync
@@ -214,11 +186,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
"Exception while validating permission sync fences"
)
r.set(
OnyxRedisSignals.BLOCK_VALIDATE_PERMISSION_SYNC_FENCES,
1,
ex=_get_fence_validation_block_expiration(),
)
r.set(OnyxRedisSignals.BLOCK_VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=300)
# use a lookup table to find active fences. We still have to verify the fence
# exists since it is an optimization and not the source of truth.
@@ -449,7 +417,6 @@ def connector_permission_sync_generator_task(
created = validate_ccpair_for_user(
cc_pair.connector.id,
cc_pair.credential.id,
cc_pair.access_type,
db_session,
enforce_creation=False,
)
@@ -465,15 +432,11 @@ def connector_permission_sync_generator_task(
raise
source_type = cc_pair.connector.source
sync_config = get_source_perm_sync_config(source_type)
if sync_config is None:
logger.error(f"No sync config found for {source_type}")
return None
if sync_config.doc_sync_config is None:
if sync_config.censoring_config:
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
if doc_sync_func is None:
if source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION:
return None
raise ValueError(
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
)
@@ -505,7 +468,6 @@ def connector_permission_sync_generator_task(
credential_id=cc_pair.credential.id,
)
doc_sync_func = sync_config.doc_sync_config.doc_sync_func
document_external_accesses = doc_sync_func(
cc_pair, fetch_all_existing_docs_fn, callback
)
@@ -622,6 +584,91 @@ def document_update_permissions(
return True
# NOTE(rkuo): Deprecating this due to degenerate behavior in Redis from sending
# large permissions through celery (over 1MB in size)
# @shared_task(
# name=OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
# soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
# time_limit=LIGHT_TIME_LIMIT,
# max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
# bind=True,
# )
# def update_external_document_permissions_task(
# self: Task,
# tenant_id: str,
# serialized_doc_external_access: dict,
# source_string: str,
# connector_id: int,
# credential_id: int,
# ) -> bool:
# start = time.monotonic()
# completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
# document_external_access = DocExternalAccess.from_dict(
# serialized_doc_external_access
# )
# doc_id = document_external_access.doc_id
# external_access = document_external_access.external_access
# try:
# with get_session_with_current_tenant() as db_session:
# # Add the users to the DB if they don't exist
# batch_add_ext_perm_user_if_not_exists(
# db_session=db_session,
# emails=list(external_access.external_user_emails),
# continue_on_error=True,
# )
# # Then upsert the document's external permissions
# created_new_doc = upsert_document_external_perms(
# db_session=db_session,
# doc_id=doc_id,
# external_access=external_access,
# source_type=DocumentSource(source_string),
# )
# if created_new_doc:
# # If a new document was created, we associate it with the cc_pair
# upsert_document_by_connector_credential_pair(
# db_session=db_session,
# connector_id=connector_id,
# credential_id=credential_id,
# document_ids=[doc_id],
# )
# elapsed = time.monotonic() - start
# task_logger.info(
# f"connector_id={connector_id} "
# f"doc={doc_id} "
# f"action=update_permissions "
# f"elapsed={elapsed:.2f}"
# )
# completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
# except Exception as e:
# error_msg = format_error_for_logging(e)
# task_logger.warning(
# f"Exception in update_external_document_permissions_task: connector_id={connector_id} doc_id={doc_id} {error_msg}"
# )
# task_logger.exception(
# f"update_external_document_permissions_task exceptioned: "
# f"connector_id={connector_id} doc_id={doc_id}"
# )
# completion_status = OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
# finally:
# task_logger.info(
# f"update_external_document_permissions_task completed: status={completion_status.value} doc={doc_id}"
# )
# if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
# return False
# task_logger.info(
# f"update_external_document_permissions_task finished: connector_id={connector_id} doc_id={doc_id}"
# )
# return True
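The call sites in this file (get_source_perm_sync_config, sync_config.doc_sync_config.doc_sync_func, .doc_sync_frequency, .initial_index_should_sync, sync_config.group_sync_config, sync_config.censoring_config) imply a per-source config object whose definition is not part of this excerpt. A rough, assumed sketch of that shape, for orientation only and not the project's real classes:
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
@dataclass
class DocSyncConfig:
    doc_sync_func: Callable[..., Any]    # produces DocExternalAccess objects
    doc_sync_frequency: int              # seconds between doc permission syncs
    initial_index_should_sync: bool      # wait for a first successful index
@dataclass
class GroupSyncConfig:
    group_sync_func: Callable[..., Any]  # yields ExternalUserGroup objects
    group_sync_frequency: int | None     # seconds; None means "always due"
@dataclass
class SourcePermSyncConfig:
    doc_sync_config: DocSyncConfig | None = None
    group_sync_config: GroupSyncConfig | None = None
    censoring_config: Any | None = None  # post-query chunk censoring, if used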
def validate_permission_sync_fences(
tenant_id: str,
r: Redis,

View File

@@ -20,17 +20,15 @@ from ee.onyx.background.celery.tasks.external_group_syncing.group_sync_utils imp
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.db.external_perm import mark_old_external_groups_as_stale
from ee.onyx.db.external_perm import remove_stale_external_groups
from ee.onyx.db.external_perm import upsert_external_groups
from ee.onyx.db.external_perm import replace_user__ext_group_for_cc_pair
from ee.onyx.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
from ee.onyx.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
from ee.onyx.external_permissions.sync_params import (
get_all_cc_pair_agnostic_group_sync_sources,
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
)
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.error_logging import emit_background_error
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
@@ -42,8 +40,9 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
@@ -58,34 +57,19 @@ from onyx.redis.redis_connector_ext_group_sync import (
)
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.server.utils import make_short_id
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
_EXTERNAL_GROUP_BATCH_SIZE = 100
EXTERNAL_GROUPS_UPDATE_MAX_RETRIES = 3
def _get_fence_validation_block_expiration() -> int:
"""
Compute the expiration time for the fence validation block signal.
Base expiration is 300 seconds, multiplied by the beat multiplier only in MULTI_TENANT mode.
"""
base_expiration = 300 # seconds
if not MULTI_TENANT:
return base_expiration
try:
beat_multiplier = OnyxRuntime.get_beat_multiplier()
except Exception:
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
return int(base_expiration * beat_multiplier)
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
@@ -105,20 +89,12 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
)
return False
sync_config = get_source_perm_sync_config(cc_pair.connector.source)
if sync_config is None:
task_logger.debug(
f"Skipping group sync for CC Pair {cc_pair.id} - "
f"no sync config found for {cc_pair.connector.source}"
)
return False
# If there is no group sync function for the connector, we don't run the sync
# This is fine because not all sources have a concept of groups

if sync_config.group_sync_config is None:
if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source):
task_logger.debug(
f"Skipping group sync for CC Pair {cc_pair.id} - "
f"no group sync config found for {cc_pair.connector.source}"
f"no group sync function for {cc_pair.connector.source}"
)
return False
@@ -127,7 +103,11 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
if last_ext_group_sync is None:
return True
source_sync_period = sync_config.group_sync_config.group_sync_frequency
source_sync_period = EXTERNAL_GROUP_SYNC_PERIODS.get(cc_pair.connector.source)
# If EXTERNAL_GROUP_SYNC_PERIODS is None, we always run the sync.
if not source_sync_period:
return True
# If the last sync is greater than the full fetch period, we run the sync
next_sync = last_ext_group_sync + timedelta(seconds=source_sync_period)
@@ -167,8 +147,9 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
with get_session_with_current_tenant() as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
# For some sources, we only want to sync one cc_pair per source type
for source in get_all_cc_pair_agnostic_group_sync_sources():
# We only want to sync one cc_pair per source type in
# GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
# These are ordered by cc_pair id so the first one is the one we want
cc_pairs_to_dedupe = get_cc_pairs_by_source(
db_session,
@@ -176,7 +157,8 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
access_type=AccessType.SYNC,
status=ConnectorCredentialPairStatus.ACTIVE,
)
# dedupe cc_pairs to only keep the first one
# We only want to sync one cc_pair per source type
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
for cc_pair_to_remove in cc_pairs_to_dedupe[1:]:
cc_pairs = [
cc_pair
@@ -215,11 +197,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
"Exception while validating external group sync fences"
)
r.set(
OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES,
1,
ex=_get_fence_validation_block_expiration(),
)
r.set(OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=300)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -398,12 +376,55 @@ def connector_external_group_sync_generator_task(
payload.started = datetime.now(timezone.utc)
redis_connector.external_group_sync.set_fence(payload)
_perform_external_group_sync(
cc_pair_id=cc_pair_id,
tenant_id=tenant_id,
)
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
eager_load_credential=True,
)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
)
source_type = cc_pair.connector.source
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if ext_group_sync_func is None:
msg = f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
external_user_groups: list[ExternalUserGroup] = []
try:
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
except ConnectorValidationError as e:
# TODO: add some notification to the admins here
logger.exception(
f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
)
raise e
logger.info(
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
)
logger.debug(f"New external user groups: {external_user_groups}")
replace_user__ext_group_for_cc_pair(
db_session=db_session,
cc_pair_id=cc_pair.id,
group_defs=external_user_groups,
source=cc_pair.connector.source,
)
logger.info(
f"Synced {len(external_user_groups)} external user groups for {source_type}"
)
mark_all_relevant_cc_pairs_as_external_group_synced(db_session, cc_pair)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
@@ -445,81 +466,6 @@ def connector_external_group_sync_generator_task(
)
def _perform_external_group_sync(
cc_pair_id: int,
tenant_id: str,
) -> None:
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
eager_load_credential=True,
)
if cc_pair is None:
raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")
source_type = cc_pair.connector.source
sync_config = get_source_perm_sync_config(source_type)
if sync_config is None:
msg = f"No sync config found for {source_type} for cc_pair: {cc_pair_id}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)
if sync_config.group_sync_config is None:
msg = f"No group sync config found for {source_type} for cc_pair: {cc_pair_id}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)
ext_group_sync_func = sync_config.group_sync_config.group_sync_func
logger.info(
f"Marking old external groups as stale for {source_type} for cc_pair: {cc_pair_id}"
)
mark_old_external_groups_as_stale(db_session, cc_pair_id)
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
external_user_group_batch: list[ExternalUserGroup] = []
try:
external_user_group_generator = ext_group_sync_func(tenant_id, cc_pair)
for external_user_group in external_user_group_generator:
external_user_group_batch.append(external_user_group)
if len(external_user_group_batch) >= _EXTERNAL_GROUP_BATCH_SIZE:
logger.debug(
f"New external user groups: {external_user_group_batch}"
)
upsert_external_groups(
db_session=db_session,
cc_pair_id=cc_pair_id,
external_groups=external_user_group_batch,
source=cc_pair.connector.source,
)
external_user_group_batch = []
if external_user_group_batch:
logger.debug(f"New external user groups: {external_user_group_batch}")
upsert_external_groups(
db_session=db_session,
cc_pair_id=cc_pair_id,
external_groups=external_user_group_batch,
source=cc_pair.connector.source,
)
except Exception as e:
# TODO: add some notification to the admins here
logger.exception(
f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
)
raise e
logger.info(
f"Removing stale external groups for {source_type} for cc_pair: {cc_pair_id}"
)
remove_stale_external_groups(db_session, cc_pair_id)
mark_all_relevant_cc_pairs_as_external_group_synced(db_session, cc_pair)
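Condensing the streaming flow shown above: mark the existing rows stale, upsert the incoming groups in batches (the diff uses a batch size of 100), then delete whatever is still marked stale. A minimal sketch using the helpers imported at the top of this file; error handling, logging, and sync-record updates are omitted.
from ee.onyx.db.external_perm import mark_old_external_groups_as_stale
from ee.onyx.db.external_perm import remove_stale_external_groups
from ee.onyx.db.external_perm import upsert_external_groups
def sync_external_groups_in_batches(
    db_session, tenant_id, cc_pair, ext_group_sync_func, batch_size=100
):
    mark_old_external_groups_as_stale(db_session, cc_pair.id)
    batch = []
    for group in ext_group_sync_func(tenant_id, cc_pair):
        batch.append(group)
        if len(batch) >= batch_size:
            upsert_external_groups(db_session, cc_pair.id, batch, cc_pair.connector.source)
            batch = []
    if batch:
        upsert_external_groups(db_session, cc_pair.id, batch, cc_pair.connector.source)
    remove_stale_external_groups(db_session, cc_pair.id)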
def validate_external_group_sync_fences(
tenant_id: str,
celery_app: Celery,

View File

@@ -19,7 +19,7 @@ from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.models import AvailableTenant
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT

View File

@@ -53,16 +53,6 @@ CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
)
#####
# JIRA
#####
# In seconds, default is 30 minutes
JIRA_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("JIRA_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60
)
#####
# Google Drive
#####
@@ -81,15 +71,6 @@ SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
#####
# Teams
#####
# In seconds, default is 5 minutes
TEAMS_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("TEAMS_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
####
# Celery Job Frequency
####

View File

@@ -1,28 +0,0 @@
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.interfaces import BaseConnector
def validate_confluence_perm_sync(connector: ConfluenceConnector) -> None:
"""
Validate that the connector is configured correctly for permissions syncing.
"""
def validate_drive_perm_sync(connector: GoogleDriveConnector) -> None:
"""
Validate that the connector is configured correctly for permissions syncing.
"""
def validate_perm_sync(connector: BaseConnector) -> None:
"""
Override this if your connector needs to validate permissions syncing.
Raise an exception if invalid, otherwise do nothing.
Default is a no-op (always successful).
"""
if isinstance(connector, ConfluenceConnector):
validate_confluence_perm_sync(connector)
elif isinstance(connector, GoogleDriveConnector):
validate_drive_perm_sync(connector)

View File

@@ -4,7 +4,6 @@ from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.access.utils import build_ext_group_name_for_onyx
@@ -63,41 +62,20 @@ def delete_public_external_group_for_cc_pair__no_commit(
)
def mark_old_external_groups_as_stale(
def replace_user__ext_group_for_cc_pair(
db_session: Session,
cc_pair_id: int,
) -> None:
db_session.execute(
update(User__ExternalUserGroupId)
.where(User__ExternalUserGroupId.cc_pair_id == cc_pair_id)
.values(stale=True)
)
db_session.execute(
update(PublicExternalUserGroup)
.where(PublicExternalUserGroup.cc_pair_id == cc_pair_id)
.values(stale=True)
)
def upsert_external_groups(
db_session: Session,
cc_pair_id: int,
external_groups: list[ExternalUserGroup],
group_defs: list[ExternalUserGroup],
source: DocumentSource,
) -> None:
"""
Performs a true upsert operation for external user groups:
- For existing groups (same user_id, external_user_group_id, cc_pair_id), updates the stale flag to False
- For new groups, inserts them with stale=False
- For public groups, uses upsert logic as well
This function clears all existing external user group relations for a given cc_pair_id
and replaces them with the new group definitions and commits the changes.
"""
# If there are no groups to add, return early
if not external_groups:
return
# collect all emails from all groups to batch add all users at once for efficiency
all_group_member_emails = set()
for external_group in external_groups:
for external_group in group_defs:
for user_email in external_group.user_emails:
all_group_member_emails.add(user_email)
@@ -108,17 +86,26 @@ def upsert_external_groups(
emails=list(all_group_member_emails),
)
# map emails to ids
email_id_map = {user.email.lower(): user.id for user in all_group_members}
delete_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
delete_public_external_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
# Process each external group
for external_group in external_groups:
# map emails to ids
email_id_map = {user.email: user.id for user in all_group_members}
# use these ids to create new external user group relations relating group_id to user_ids
new_external_permissions: list[User__ExternalUserGroupId] = []
new_public_external_groups: list[PublicExternalUserGroup] = []
for external_group in group_defs:
external_group_id = build_ext_group_name_for_onyx(
ext_group_name=external_group.id,
source=source,
)
# Handle user-group mappings
for user_email in external_group.user_emails:
user_id = email_id_map.get(user_email.lower())
if user_id is None:
@@ -127,71 +114,24 @@ def upsert_external_groups(
f" with email {user_email} not found"
)
continue
# Check if the user-group mapping already exists
existing_user_group = db_session.scalar(
select(User__ExternalUserGroupId).where(
User__ExternalUserGroupId.user_id == user_id,
User__ExternalUserGroupId.external_user_group_id
== external_group_id,
User__ExternalUserGroupId.cc_pair_id == cc_pair_id,
)
)
if existing_user_group:
# Update existing record
existing_user_group.stale = False
else:
# Insert new record
new_user_group = User__ExternalUserGroupId(
new_external_permissions.append(
User__ExternalUserGroupId(
user_id=user_id,
external_user_group_id=external_group_id,
cc_pair_id=cc_pair_id,
stale=False,
)
db_session.add(new_user_group)
# Handle public group if needed
if external_group.gives_anyone_access:
# Check if the public group already exists
existing_public_group = db_session.scalar(
select(PublicExternalUserGroup).where(
PublicExternalUserGroup.external_user_group_id == external_group_id,
PublicExternalUserGroup.cc_pair_id == cc_pair_id,
)
)
if existing_public_group:
# Update existing record
existing_public_group.stale = False
else:
# Insert new record
new_public_group = PublicExternalUserGroup(
if external_group.gives_anyone_access:
new_public_external_groups.append(
PublicExternalUserGroup(
external_user_group_id=external_group_id,
cc_pair_id=cc_pair_id,
stale=False,
)
db_session.add(new_public_group)
)
db_session.commit()
def remove_stale_external_groups(
db_session: Session,
cc_pair_id: int,
) -> None:
db_session.execute(
delete(User__ExternalUserGroupId).where(
User__ExternalUserGroupId.cc_pair_id == cc_pair_id,
User__ExternalUserGroupId.stale.is_(True),
)
)
db_session.execute(
delete(PublicExternalUserGroup).where(
PublicExternalUserGroup.cc_pair_id == cc_pair_id,
PublicExternalUserGroup.stale.is_(True),
)
)
db_session.add_all(new_external_permissions)
db_session.add_all(new_public_external_groups)
db_session.commit()
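One small but load-bearing detail above: on the side of the diff that lowercases, the email-to-id map is keyed on user.email.lower() and looked up with user_email.lower(), so membership still resolves when the identity provider and the stored user record disagree on email casing. Illustrative values only:
email_id_map = {"Jane.Doe@example.com".lower(): "user-123"}  # hypothetical user id
assert email_id_map.get("jane.doe@EXAMPLE.com".lower()) == "user-123"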

View File

@@ -115,24 +115,11 @@ def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]:
def get_usage_report_data(
db_session: Session,
report_display_name: str,
report_name: str,
) -> IO:
"""
Get the usage report data from the file store.
Args:
db_session: The database session.
report_display_name: The display name of the usage report. Also assumes
that the file is stored with this as the ID in the file store.
Returns:
The usage report data.
"""
file_store = get_default_file_store(db_session)
# usage report may be very large, so don't load it all into memory
return file_store.read_file(
file_id=report_display_name, mode="b", use_tempfile=True
)
return file_store.read_file(file_name=report_name, mode="b", use_tempfile=True)
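A hedged usage sketch for the reader above; the report name and output path are made up, the report name is passed positionally since the two sides of this diff name the parameter differently, and imports for the helpers from this codebase are omitted. use_tempfile=True keeps a potentially large report out of memory, so the caller streams it in chunks.
with get_session_with_current_tenant() as db_session:
    report_io = get_usage_report_data(db_session, "usage_report_2025_06.csv")
    with open("/tmp/usage_report.csv", "wb") as out:
        for chunk in iter(lambda: report_io.read(1024 * 1024), b""):
            out.write(chunk)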
def write_usage_report(

View File

@@ -2,6 +2,3 @@
# Instead of setting a page to public, we just add this group so that the page
# is only accessible to users who have confluence accounts.
ALL_CONF_EMAILS_GROUP_NAME = "All_Confluence_Users_Found_By_Onyx"
VIEWSPACE_PERMISSION_TYPE = "VIEWSPACE"
REQUEST_PAGINATION_LIMIT = 5000

View File

@@ -4,13 +4,20 @@ https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.htm
"""
from collections.abc import Generator
from typing import Any
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.access.models import ExternalAccess
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.confluence.onyx_confluence import (
get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -18,8 +25,369 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
_VIEWSPACE_PERMISSION_TYPE = "VIEWSPACE"
_REQUEST_PAGINATION_LIMIT = 5000
CONFLUENCE_DOC_SYNC_LABEL = "confluence_doc_sync"
def _get_server_space_permissions(
confluence_client: OnyxConfluence, space_key: str
) -> ExternalAccess:
space_permissions = confluence_client.get_all_space_permissions_server(
space_key=space_key
)
viewspace_permissions = []
for permission_category in space_permissions:
if permission_category.get("type") == _VIEWSPACE_PERMISSION_TYPE:
viewspace_permissions.extend(
permission_category.get("spacePermissions", [])
)
is_public = False
user_names = set()
group_names = set()
for permission in viewspace_permissions:
user_name = permission.get("userName")
if user_name:
user_names.add(user_name)
group_name = permission.get("groupName")
if group_name:
group_names.add(group_name)
# It seems that if anonymous access is turned on for the site and space,
# then the space is publicly accessible.
# For Confluence Server, we build a group containing every user that exists
# in Confluence and add that group to the space permissions when anonymous
# access is enabled for the site and space. Alternatively, if the env variable
# CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC is set to True, we set is_public = True
# so that Confluence Server deployments that want anonymous access to be
# public are supported (we can't test this ourselves because it's paywalled).
if user_name is None and group_name is None:
# Defaults to False
if CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC:
is_public = True
else:
group_names.add(ALL_CONF_EMAILS_GROUP_NAME)
user_emails = set()
for user_name in user_names:
user_email = get_user_email_from_username__server(confluence_client, user_name)
if user_email:
user_emails.add(user_email)
else:
logger.warning(f"Email for user {user_name} not found in Confluence")
if not user_emails and not group_names:
logger.warning(
"No user emails or group names found in Confluence space permissions"
f"\nSpace key: {space_key}"
f"\nSpace permissions: {space_permissions}"
)
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_names,
is_public=is_public,
)
def _get_cloud_space_permissions(
confluence_client: OnyxConfluence, space_key: str
) -> ExternalAccess:
space_permissions_result = confluence_client.get_space(
space_key=space_key, expand="permissions"
)
space_permissions = space_permissions_result.get("permissions", [])
user_emails = set()
group_names = set()
is_externally_public = False
for permission in space_permissions:
subs = permission.get("subjects")
if subs:
# If there are subjects, then there are explicit users or groups with access
if email := subs.get("user", {}).get("results", [{}])[0].get("email"):
user_emails.add(email)
if group_name := subs.get("group", {}).get("results", [{}])[0].get("name"):
group_names.add(group_name)
else:
# If there are no subjects, then the permission is for everyone
if permission.get("operation", {}).get(
"operation"
) == "read" and permission.get("anonymousAccess", False):
# If the permission specifies read access for anonymous users, then
# the space is publicly accessible
is_externally_public = True
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_names,
is_public=is_externally_public,
)
def _get_space_permissions(
confluence_client: OnyxConfluence,
is_cloud: bool,
) -> dict[str, ExternalAccess]:
logger.debug("Getting space permissions")
# Gets all the spaces in the Confluence instance
all_space_keys = []
start = 0
while True:
spaces_batch = confluence_client.get_all_spaces(
start=start, limit=_REQUEST_PAGINATION_LIMIT
)
for space in spaces_batch.get("results", []):
all_space_keys.append(space.get("key"))
if len(spaces_batch.get("results", [])) < _REQUEST_PAGINATION_LIMIT:
break
start += len(spaces_batch.get("results", []))
# Gets the permissions for each space
logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
space_permissions_by_space_key: dict[str, ExternalAccess] = {}
for space_key in all_space_keys:
if is_cloud:
space_permissions = _get_cloud_space_permissions(
confluence_client=confluence_client, space_key=space_key
)
else:
space_permissions = _get_server_space_permissions(
confluence_client=confluence_client, space_key=space_key
)
# Stores the permissions for each space
space_permissions_by_space_key[space_key] = space_permissions
if (
not space_permissions.is_public
and not space_permissions.external_user_emails
and not space_permissions.external_user_group_ids
):
logger.warning(
f"No permissions found for space '{space_key}'. This is very unlikely"
"to be correct and is more likely caused by an access token with"
"insufficient permissions. Make sure that the access token has Admin"
f"permissions for space '{space_key}'"
)
return space_permissions_by_space_key
def _extract_read_access_restrictions(
confluence_client: OnyxConfluence, restrictions: dict[str, Any]
) -> tuple[set[str], set[str], bool]:
"""
Extracts the read-access restrictions from a page's restrictions dict.
Returns the set of user emails with read access, the set of group names with
read access, and whether any read restriction was found at all.
"""
read_access = restrictions.get("read", {})
read_access_restrictions = read_access.get("restrictions", {})
# Extract the users with read access
read_access_user = read_access_restrictions.get("user", {})
read_access_user_jsons = read_access_user.get("results", [])
# any items found means that there is a restriction
found_any_restriction = bool(read_access_user_jsons)
read_access_user_emails = []
for user in read_access_user_jsons:
# If the user has an email, then add it to the list
if user.get("email"):
read_access_user_emails.append(user["email"])
# If the user has a username and not an email, then get the email from Confluence
elif user.get("username"):
email = get_user_email_from_username__server(
confluence_client=confluence_client, user_name=user["username"]
)
if email:
read_access_user_emails.append(email)
else:
logger.warning(
f"Email for user {user['username']} not found in Confluence"
)
else:
if user.get("email") is not None:
logger.warning(f"Cant find email for user {user.get('displayName')}")
logger.warning(
"This user needs to make their email accessible in Confluence Settings"
)
logger.warning(f"no user email or username for {user}")
# Extract the groups with read access
read_access_group = read_access_restrictions.get("group", {})
read_access_group_jsons = read_access_group.get("results", [])
# any items found means that there is a restriction
found_any_restriction |= bool(read_access_group_jsons)
read_access_group_names = [
group["name"] for group in read_access_group_jsons if group.get("name")
]
return (
set(read_access_user_emails),
set(read_access_group_names),
found_any_restriction,
)
def _get_all_page_restrictions(
confluence_client: OnyxConfluence,
perm_sync_data: dict[str, Any],
) -> ExternalAccess | None:
"""
This function gets the restrictions for a page. In Confluence, a child can have
at MOST the same level of accessibility as its immediate parent.
If no restrictions are found anywhere, then return None, indicating that the page
should inherit the space's restrictions.
"""
found_user_emails: set[str] = set()
found_group_names: set[str] = set()
# NOTE: need the found_any_restriction, since we can find restrictions
# but not be able to extract any user emails or group names
# in this case, we should just give no access
found_user_emails, found_group_names, found_any_page_level_restriction = (
_extract_read_access_restrictions(
confluence_client=confluence_client,
restrictions=perm_sync_data.get("restrictions", {}),
)
)
# if there are individual page-level restrictions, then this is the accurate
# restriction for the page. You cannot both have page-level restrictions AND
# inherit restrictions from the parent.
if found_any_page_level_restriction:
return ExternalAccess(
external_user_emails=found_user_emails,
external_user_group_ids=found_group_names,
is_public=False,
)
ancestors: list[dict[str, Any]] = perm_sync_data.get("ancestors", [])
# ancestors seem to be in order from root to immediate parent
# https://community.atlassian.com/forums/Confluence-questions/Order-of-ancestors-in-REST-API-response-Confluence-Server-amp/qaq-p/2385981
# we want the restrictions from the immediate parent to take precedence, so we should
# reverse the list
for ancestor in reversed(ancestors):
(
ancestor_user_emails,
ancestor_group_names,
found_any_restrictions_in_ancestor,
) = _extract_read_access_restrictions(
confluence_client=confluence_client,
restrictions=ancestor.get("restrictions", {}),
)
if found_any_restrictions_in_ancestor:
# if inheriting restrictions from the parent, then the first one we run into
# should be applied (the reason why we'd traverse more than one ancestor is if
# the ancestor also is in "inherit" mode.)
logger.info(
f"Found user restrictions {ancestor_user_emails} and group restrictions {ancestor_group_names}"
f"for document {perm_sync_data.get('id')} based on ancestor {ancestor}"
)
return ExternalAccess(
external_user_emails=ancestor_user_emails,
external_user_group_ids=ancestor_group_names,
is_public=False,
)
# we didn't find any restrictions, so the page inherits the space's restrictions
return None
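A worked example of the inheritance rule above, with fabricated restriction payloads: the page itself is unrestricted, its root ancestor is unrestricted, and its immediate parent restricts read access to one group, so the reversed traversal stops at the parent.
ancestors = [
    {"id": "ROOT", "restrictions": {}},  # unrestricted root
    {
        "id": "PARENT",
        "restrictions": {
            "read": {"restrictions": {"group": {"results": [{"name": "engineering"}]}}}
        },
    },
]
inherited_groups: set[str] = set()
for ancestor in reversed(ancestors):  # immediate parent first
    results = (
        ancestor["restrictions"]
        .get("read", {})
        .get("restrictions", {})
        .get("group", {})
        .get("results", [])
    )
    if results:
        inherited_groups = {g["name"] for g in results if g.get("name")}
        break
assert inherited_groups == {"engineering"}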
def _fetch_all_page_restrictions(
confluence_client: OnyxConfluence,
slim_docs: list[SlimDocument],
space_permissions_by_space_key: dict[str, ExternalAccess],
is_cloud: bool,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
"""
For all pages, if a page has restrictions, then use those restrictions.
Otherwise, use the space's restrictions.
"""
for slim_doc in slim_docs:
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)
if slim_doc.perm_sync_data is None:
raise ValueError(
f"No permission sync data found for document {slim_doc.id}"
)
if restrictions := _get_all_page_restrictions(
confluence_client=confluence_client,
perm_sync_data=slim_doc.perm_sync_data,
):
logger.info(f"Found restrictions {restrictions} for document {slim_doc.id}")
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=restrictions,
)
# If there are restrictions, then we don't need to use the space's restrictions
continue
space_key = slim_doc.perm_sync_data.get("space_key")
if not (space_permissions := space_permissions_by_space_key.get(space_key)):
logger.warning(
f"Individually fetching space permissions for space {space_key}. This is "
"unexpected. It means the permissions were not able to fetched initially."
)
try:
# If the space permissions are not in the cache, then fetch them
if is_cloud:
retrieved_space_permissions = _get_cloud_space_permissions(
confluence_client=confluence_client, space_key=space_key
)
else:
retrieved_space_permissions = _get_server_space_permissions(
confluence_client=confluence_client, space_key=space_key
)
space_permissions_by_space_key[space_key] = retrieved_space_permissions
space_permissions = retrieved_space_permissions
except Exception as e:
logger.warning(
f"Error fetching space permissions for space {space_key}: {e}"
)
if not space_permissions:
logger.warning(
f"No permissions found for document {slim_doc.id} in space {space_key}"
)
# be safe, if we can't get the permissions then make the document inaccessible
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=ExternalAccess(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=False,
),
)
continue
# If there are no restrictions, then use the space's restrictions
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=space_permissions,
)
if (
not space_permissions.is_public
and not space_permissions.external_user_emails
and not space_permissions.external_user_group_ids
):
logger.warning(
f"Permissions are empty for document: {slim_doc.id}\n"
"This means space permissions may be wrong for"
f" Space key: {space_key}"
)
logger.info("Finished fetching all page restrictions")
def confluence_doc_sync(
@@ -32,6 +400,7 @@ def confluence_doc_sync(
Compares fetched documents against existing documents in the DB for the connector.
If a document exists in the DB but not in the Confluence fetch, it's marked as restricted.
"""
logger.info(f"Starting confluence doc sync for CC Pair ID: {cc_pair.id}")
confluence_connector = ConfluenceConnector(
**cc_pair.connector.connector_specific_config
)
@@ -41,11 +410,62 @@ def confluence_doc_sync(
)
confluence_connector.set_credentials_provider(provider)
yield from generic_doc_sync(
cc_pair=cc_pair,
fetch_all_existing_docs_fn=fetch_all_existing_docs_fn,
callback=callback,
doc_source=DocumentSource.CONFLUENCE,
slim_connector=confluence_connector,
label=CONFLUENCE_DOC_SYNC_LABEL,
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
space_permissions_by_space_key = _get_space_permissions(
confluence_client=confluence_connector.confluence_client,
is_cloud=is_cloud,
)
logger.info("Space permissions by space key:")
for space_key, space_permissions in space_permissions_by_space_key.items():
logger.info(f"Space key: {space_key}, Permissions: {space_permissions}")
slim_docs: list[SlimDocument] = []
logger.info("Fetching all slim documents from confluence")
for doc_batch in confluence_connector.retrieve_all_slim_documents(
callback=callback
):
logger.info(f"Got {len(doc_batch)} slim documents from confluence")
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync", 1)
slim_docs.extend(doc_batch)
# Find documents that are no longer accessible in Confluence
logger.info(f"Querying existing document IDs for CC Pair ID: {cc_pair.id}")
existing_doc_ids = fetch_all_existing_docs_fn()
# Find missing doc IDs
fetched_doc_ids = {doc.id for doc in slim_docs}
missing_doc_ids = set(existing_doc_ids) - fetched_doc_ids
# Yield access removal for missing docs. Better to be safe.
if missing_doc_ids:
logger.warning(
f"Found {len(missing_doc_ids)} documents that are in the DB but "
"not present in Confluence fetch. Making them inaccessible."
)
for missing_id in missing_doc_ids:
logger.warning(f"Removing access for document ID: {missing_id}")
yield DocExternalAccess(
doc_id=missing_id,
external_access=ExternalAccess(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=False,
),
)
logger.info("Fetching all page restrictions for fetched documents")
yield from _fetch_all_page_restrictions(
confluence_client=confluence_connector.confluence_client,
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,
is_cloud=is_cloud,
callback=callback,
)
logger.info("Finished confluence doc sync")

View File

@@ -1,5 +1,3 @@
from collections.abc import Generator
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from onyx.background.error_logging import emit_background_error
@@ -67,7 +65,7 @@ def _build_group_member_email_map(
def confluence_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> Generator[ExternalUserGroup, None, None]:
) -> list[ExternalUserGroup]:
provider = OnyxDBCredentialsProvider(tenant_id, "confluence", cc_pair.credential_id)
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
wiki_base: str = cc_pair.connector.connector_specific_config["wiki_base"]
@@ -91,10 +89,10 @@ def confluence_group_sync(
confluence_client=confluence_client,
cc_pair_id=cc_pair.id,
)
onyx_groups: list[ExternalUserGroup] = []
all_found_emails = set()
for group_id, group_member_emails in group_member_email_map.items():
yield (
onyx_groups.append(
ExternalUserGroup(
id=group_id,
user_emails=list(group_member_emails),
@@ -109,4 +107,6 @@ def confluence_group_sync(
id=ALL_CONF_EMAILS_GROUP_NAME,
user_emails=list(all_found_emails),
)
yield all_found_group
onyx_groups.append(all_found_group)
return onyx_groups
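One side of this diff yields ExternalUserGroup objects lazily while the other materializes the full list before returning. A generic batching helper (an assumption for illustration, not from this codebase) shows how a caller can consume the generator form without holding every group in memory at once:
from collections.abc import Generator, Iterable
from typing import TypeVar
T = TypeVar("T")
def batched(items: Iterable[T], batch_size: int = 100) -> Generator[list[T], None, None]:
    # Accumulate items and emit them in fixed-size chunks.
    batch: list[T] = []
    for item in items:
        batch.append(item)
        if len(batch) >= batch_size:
            yield batch
            batch = []
    if batch:
        yield batch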

View File

@@ -1,133 +0,0 @@
from typing import Any
from onyx.access.models import ExternalAccess
from onyx.connectors.confluence.onyx_confluence import (
get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _extract_read_access_restrictions(
confluence_client: OnyxConfluence, restrictions: dict[str, Any]
) -> tuple[set[str], set[str], bool]:
"""
Extracts the read-access restrictions from a page's restrictions dict.
Returns the set of user emails with read access, the set of group names with
read access, and whether any read restriction was found at all.
"""
read_access = restrictions.get("read", {})
read_access_restrictions = read_access.get("restrictions", {})
# Extract the users with read access
read_access_user = read_access_restrictions.get("user", {})
read_access_user_jsons = read_access_user.get("results", [])
# any items found means that there is a restriction
found_any_restriction = bool(read_access_user_jsons)
read_access_user_emails = []
for user in read_access_user_jsons:
# If the user has an email, then add it to the list
if user.get("email"):
read_access_user_emails.append(user["email"])
# If the user has a username and not an email, then get the email from Confluence
elif user.get("username"):
email = get_user_email_from_username__server(
confluence_client=confluence_client, user_name=user["username"]
)
if email:
read_access_user_emails.append(email)
else:
logger.warning(
f"Email for user {user['username']} not found in Confluence"
)
else:
if user.get("email") is not None:
logger.warning(f"Cant find email for user {user.get('displayName')}")
logger.warning(
"This user needs to make their email accessible in Confluence Settings"
)
logger.warning(f"no user email or username for {user}")
# Extract the groups with read access
read_access_group = read_access_restrictions.get("group", {})
read_access_group_jsons = read_access_group.get("results", [])
# any items found means that there is a restriction
found_any_restriction |= bool(read_access_group_jsons)
read_access_group_names = [
group["name"] for group in read_access_group_jsons if group.get("name")
]
return (
set(read_access_user_emails),
set(read_access_group_names),
found_any_restriction,
)
def get_page_restrictions(
confluence_client: OnyxConfluence,
page_id: str,
page_restrictions: dict[str, Any],
ancestors: list[dict[str, Any]],
) -> ExternalAccess | None:
"""
This function gets the restrictions for a page. In Confluence, a child can have
at MOST the same level of accessibility as its immediate parent.
If no restrictions are found anywhere, then return None, indicating that the page
should inherit the space's restrictions.
"""
found_user_emails: set[str] = set()
found_group_names: set[str] = set()
# NOTE: need the found_any_restriction, since we can find restrictions
# but not be able to extract any user emails or group names
# in this case, we should just give no access
found_user_emails, found_group_names, found_any_page_level_restriction = (
_extract_read_access_restrictions(
confluence_client=confluence_client,
restrictions=page_restrictions,
)
)
# if there are individual page-level restrictions, then this is the accurate
# restriction for the page. You cannot both have page-level restrictions AND
# inherit restrictions from the parent.
if found_any_page_level_restriction:
return ExternalAccess(
external_user_emails=found_user_emails,
external_user_group_ids=found_group_names,
is_public=False,
)
# ancestors seem to be in order from root to immediate parent
# https://community.atlassian.com/forums/Confluence-questions/Order-of-ancestors-in-REST-API-response-Confluence-Server-amp/qaq-p/2385981
# we want the restrictions from the immediate parent to take precedence, so we should
# reverse the list
for ancestor in reversed(ancestors):
(
ancestor_user_emails,
ancestor_group_names,
found_any_restrictions_in_ancestor,
) = _extract_read_access_restrictions(
confluence_client=confluence_client,
restrictions=ancestor.get("restrictions", {}),
)
if found_any_restrictions_in_ancestor:
# if inheriting restrictions from the parent, then the first one we run into
# should be applied (the reason why we'd traverse more than one ancestor is if
# the ancestor also is in "inherit" mode.)
logger.debug(
f"Found user restrictions {ancestor_user_emails} and group restrictions {ancestor_group_names}"
f"for document {page_id} based on ancestor {ancestor}"
)
return ExternalAccess(
external_user_emails=ancestor_user_emails,
external_user_group_ids=ancestor_group_names,
is_public=False,
)
# we didn't find any restrictions, so the page inherits the space's restrictions
return None

View File

@@ -1,165 +0,0 @@
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from ee.onyx.external_permissions.confluence.constants import REQUEST_PAGINATION_LIMIT
from ee.onyx.external_permissions.confluence.constants import VIEWSPACE_PERMISSION_TYPE
from onyx.access.models import ExternalAccess
from onyx.connectors.confluence.onyx_confluence import (
get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _get_server_space_permissions(
confluence_client: OnyxConfluence, space_key: str
) -> ExternalAccess:
space_permissions = confluence_client.get_all_space_permissions_server(
space_key=space_key
)
viewspace_permissions = []
for permission_category in space_permissions:
if permission_category.get("type") == VIEWSPACE_PERMISSION_TYPE:
viewspace_permissions.extend(
permission_category.get("spacePermissions", [])
)
is_public = False
user_names = set()
group_names = set()
for permission in viewspace_permissions:
if user_name := permission.get("userName"):
user_names.add(user_name)
if group_name := permission.get("groupName"):
group_names.add(group_name)
# It seems that if anonymous access is turned on for the site and space,
# then the space is publicly accessible.
# For Confluence Server, we build a group containing every user that exists
# in Confluence and add that group to the space permissions when anonymous
# access is enabled for the site and space. Alternatively, if the env variable
# CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC is set to True, we set is_public = True
# so that Confluence Server deployments that want anonymous access to be
# public are supported (we can't test this ourselves because it's paywalled).
if user_name is None and group_name is None:
# Defaults to False
if CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC:
is_public = True
else:
group_names.add(ALL_CONF_EMAILS_GROUP_NAME)
user_emails = set()
for user_name in user_names:
user_email = get_user_email_from_username__server(confluence_client, user_name)
if user_email:
user_emails.add(user_email)
else:
logger.warning(f"Email for user {user_name} not found in Confluence")
if not user_emails and not group_names:
logger.warning(
"No user emails or group names found in Confluence space permissions"
f"\nSpace key: {space_key}"
f"\nSpace permissions: {space_permissions}"
)
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_names,
is_public=is_public,
)
def _get_cloud_space_permissions(
confluence_client: OnyxConfluence, space_key: str
) -> ExternalAccess:
space_permissions_result = confluence_client.get_space(
space_key=space_key, expand="permissions"
)
space_permissions = space_permissions_result.get("permissions", [])
user_emails = set()
group_names = set()
is_externally_public = False
for permission in space_permissions:
subs = permission.get("subjects")
if subs:
# If there are subjects, then there are explicit users or groups with access
if email := subs.get("user", {}).get("results", [{}])[0].get("email"):
user_emails.add(email)
if group_name := subs.get("group", {}).get("results", [{}])[0].get("name"):
group_names.add(group_name)
else:
# If there are no subjects, then the permission is for everyone
if permission.get("operation", {}).get(
"operation"
) == "read" and permission.get("anonymousAccess", False):
# If the permission specifies read access for anonymous users, then
# the space is publicly accessible
is_externally_public = True
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_names,
is_public=is_externally_public,
)
def get_space_permission(
confluence_client: OnyxConfluence,
space_key: str,
is_cloud: bool,
) -> ExternalAccess:
if is_cloud:
space_permissions = _get_cloud_space_permissions(confluence_client, space_key)
else:
space_permissions = _get_server_space_permissions(confluence_client, space_key)
if (
not space_permissions.is_public
and not space_permissions.external_user_emails
and not space_permissions.external_user_group_ids
):
logger.warning(
f"No permissions found for space '{space_key}'. This is very unlikely"
"to be correct and is more likely caused by an access token with"
"insufficient permissions. Make sure that the access token has Admin"
f"permissions for space '{space_key}'"
)
return space_permissions
def get_all_space_permissions(
confluence_client: OnyxConfluence,
is_cloud: bool,
) -> dict[str, ExternalAccess]:
logger.debug("Getting space permissions")
# Gets all the spaces in the Confluence instance
all_space_keys = []
start = 0
while True:
spaces_batch = confluence_client.get_all_spaces(
start=start, limit=REQUEST_PAGINATION_LIMIT
)
for space in spaces_batch.get("results", []):
all_space_keys.append(space.get("key"))
if len(spaces_batch.get("results", [])) < REQUEST_PAGINATION_LIMIT:
break
start += len(spaces_batch.get("results", []))
# Gets the permissions for each space
logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
space_permissions_by_space_key: dict[str, ExternalAccess] = {}
for space_key in all_space_keys:
space_permissions = get_space_permission(confluence_client, space_key, is_cloud)
# Stores the permissions for each space
space_permissions_by_space_key[space_key] = space_permissions
return space_permissions_by_space_key
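
For reference, here is a minimal, self-contained sketch of how the server-side rules above combine explicit users, groups, and anonymous-access entries into a space-level access object. `SpaceAccess` and `ALL_USERS_GROUP` are stand-ins for illustration (not the real onyx `ExternalAccess` model or constant), the sketch checks each permission entry individually, and the email lookup is faked.

```
from dataclasses import dataclass, field

ALL_USERS_GROUP = "confluence_all_users"  # stand-in for ALL_CONF_EMAILS_GROUP_NAME


@dataclass
class SpaceAccess:
    # stand-in for onyx's ExternalAccess
    user_emails: set[str] = field(default_factory=set)
    group_names: set[str] = field(default_factory=set)
    is_public: bool = False


def resolve_server_space_access(
    viewspace_permissions: list[dict],
    anonymous_access_is_public: bool,
) -> SpaceAccess:
    access = SpaceAccess()
    for permission in viewspace_permissions:
        user = permission.get("userName")
        group = permission.get("groupName")
        if user:
            # The real code resolves the username to an email via Confluence.
            access.user_emails.add(f"{user}@example.com")
        if group:
            access.group_names.add(group)
        if user is None and group is None:
            # Anonymous-access entry: either mark the space public or fall
            # back to the "all users" group, depending on configuration.
            if anonymous_access_is_public:
                access.is_public = True
            else:
                access.group_names.add(ALL_USERS_GROUP)
    return access


if __name__ == "__main__":
    perms = [{"userName": "alice"}, {"groupName": "eng"}, {}]
    print(resolve_server_space_access(perms, anonymous_access_is_public=False))
```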

View File

@@ -4,6 +4,7 @@ from datetime import timezone
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.db.models import ConnectorCredentialPair
@@ -58,11 +59,17 @@ def gmail_doc_sync(
callback.progress("gmail_doc_sync", 1)
if slim_doc.external_access is None:
if slim_doc.perm_sync_data is None:
logger.warning(f"No permissions found for document {slim_doc.id}")
continue
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=slim_doc.external_access,
)
if user_email := slim_doc.perm_sync_data.get("user_email"):
ext_access = ExternalAccess(
external_user_emails=set([user_email]),
external_user_group_ids=set(),
is_public=False,
)
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=ext_access,
)
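
A small sketch of the per-message access rule in the Gmail hunk above: each message is visible only to the mailbox owner found in `perm_sync_data`. `DocAccess` is a stand-in type, not the real `DocExternalAccess`.

```
from collections.abc import Iterator
from dataclasses import dataclass


@dataclass
class DocAccess:
    # stand-in for DocExternalAccess: a single user, never public
    doc_id: str
    user_emails: set[str]
    is_public: bool = False


def gmail_style_doc_access(slim_docs: list[dict]) -> Iterator[DocAccess]:
    """Each Gmail message is only visible to the mailbox owner, so the
    access entry is just that single user's email."""
    for doc in slim_docs:
        user_email = (doc.get("perm_sync_data") or {}).get("user_email")
        if not user_email:
            # Mirrors the warning-and-skip behaviour above.
            continue
        yield DocAccess(doc_id=doc["id"], user_emails={user_email})


if __name__ == "__main__":
    docs = [{"id": "msg-1", "perm_sync_data": {"user_email": "a@example.com"}}]
    print(list(gmail_style_doc_access(docs)))
```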

View File

@@ -1,6 +1,11 @@
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
from ee.onyx.external_permissions.google_drive.models import PermissionType
@@ -11,9 +16,10 @@ from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFuncti
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.google_utils.resources import RefreshableDriveObject
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -40,34 +46,80 @@ def _get_slim_doc_generator(
)
def get_external_access_for_raw_gdrive_file(
file: GoogleDriveFileType, company_domain: str, drive_service: GoogleDriveService
) -> ExternalAccess:
"""
Get the external access for a raw Google Drive file.
def _drive_connector_creds_getter(
google_drive_connector: GoogleDriveConnector,
) -> Callable[[], ServiceAccountCredentials | OAuthCredentials]:
def inner() -> ServiceAccountCredentials | OAuthCredentials:
if not google_drive_connector._creds_dict:
raise ValueError(
"Creds dict not found, load_credentials must be called first"
)
google_drive_connector.load_credentials(google_drive_connector._creds_dict)
return google_drive_connector.creds
Assumes the file we retrieved has EITHER `permissions` or `permission_ids`
"""
doc_id = file.get("id")
if not doc_id:
raise ValueError("No doc_id found in file")
return inner
permissions = file.get("permissions")
permission_ids = file.get("permissionIds")
drive_id = file.get("driveId")
permissions_list: list[GoogleDrivePermission] = []
if permissions:
permissions_list = [
GoogleDrivePermission.from_drive_permission(p) for p in permissions
]
elif permission_ids:
permissions_list = get_permissions_by_ids(
drive_service=drive_service,
doc_id=doc_id,
permission_ids=permission_ids,
def _fetch_permissions_for_permission_ids(
google_drive_connector: GoogleDriveConnector,
permission_info: dict[str, Any],
) -> list[GoogleDrivePermission]:
doc_id = permission_info.get("doc_id")
if not permission_info or not doc_id:
return []
owner_email = permission_info.get("owner_email")
permission_ids = permission_info.get("permission_ids", [])
if not permission_ids:
return []
if not owner_email:
logger.warning(
f"No owner email found for document {doc_id}. Permission info: {permission_info}"
)
refreshable_drive_service = RefreshableDriveObject(
call_stack=lambda creds: get_drive_service(
creds=creds,
user_email=(owner_email or google_drive_connector.primary_admin_email),
),
creds=google_drive_connector.creds,
creds_getter=_drive_connector_creds_getter(google_drive_connector),
)
return get_permissions_by_ids(
drive_service=refreshable_drive_service,
doc_id=doc_id,
permission_ids=permission_ids,
)
def _get_permissions_from_slim_doc(
google_drive_connector: GoogleDriveConnector,
slim_doc: SlimDocument,
) -> ExternalAccess:
permission_info = slim_doc.perm_sync_data or {}
permissions_list: list[GoogleDrivePermission] = []
raw_permissions_list = permission_info.get("permissions", [])
if not raw_permissions_list:
permissions_list = _fetch_permissions_for_permission_ids(
google_drive_connector=google_drive_connector,
permission_info=permission_info,
)
if not permissions_list:
logger.warning(f"No permissions found for document {slim_doc.id}")
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=False,
)
else:
permissions_list = [
GoogleDrivePermission.from_drive_permission(p) for p in raw_permissions_list
]
company_domain = google_drive_connector.google_domain
folder_ids_to_inherit_permissions_from: set[str] = set()
user_emails: set[str] = set()
group_emails: set[str] = set()
@@ -92,7 +144,7 @@ def get_external_access_for_raw_gdrive_file(
else:
logger.error(
"Permission is type `user` but no email address is "
f"provided for document {doc_id}"
f"provided for document {slim_doc.id}"
f"\n {permission}"
)
elif permission.type == PermissionType.GROUP:
@@ -102,7 +154,7 @@ def get_external_access_for_raw_gdrive_file(
else:
logger.error(
"Permission is type `group` but no email address is "
f"provided for document {doc_id}"
f"provided for document {slim_doc.id}"
f"\n {permission}"
)
elif permission.type == PermissionType.DOMAIN and company_domain:
@@ -116,6 +168,7 @@ def get_external_access_for_raw_gdrive_file(
elif permission.type == PermissionType.ANYONE:
public = True
drive_id = permission_info.get("drive_id")
group_ids = (
group_emails
| folder_ids_to_inherit_permissions_from
@@ -157,13 +210,12 @@ def gdrive_doc_sync(
callback.progress("gdrive_doc_sync", 1)
if slim_doc.external_access is None:
raise ValueError(
f"Drive perm sync: No external access for document {slim_doc.id}"
)
ext_access = _get_permissions_from_slim_doc(
google_drive_connector=google_drive_connector,
slim_doc=slim_doc,
)
yield DocExternalAccess(
external_access=slim_doc.external_access,
external_access=ext_access,
doc_id=slim_doc.id,
)
total_processed += len(slim_doc_batch)
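
To make the permission branches in this hunk easier to follow, here is a self-contained sketch that folds a list of raw permission dicts into a single access object. It is a simplification under assumed field names (`type`, `email`, `domain`); the real code also handles folder inheritance and unresolved permission IDs.

```
from dataclasses import dataclass, field


@dataclass
class DriveAccess:
    # stand-in for ExternalAccess
    user_emails: set[str] = field(default_factory=set)
    group_ids: set[str] = field(default_factory=set)
    is_public: bool = False


def resolve_drive_permissions(permissions: list[dict], company_domain: str) -> DriveAccess:
    """Fold {"type": ..., "email": ...} permission dicts into one access
    object, roughly following the user/group/domain/anyone branches above."""
    access = DriveAccess()
    for perm in permissions:
        kind = perm.get("type")
        email = perm.get("email")
        if kind == "user" and email:
            access.user_emails.add(email)
        elif kind == "group" and email:
            access.group_ids.add(email)
        elif kind == "domain" and perm.get("domain") == company_domain:
            # Simplification: domain-wide sharing within the company is
            # treated as public here.
            access.is_public = True
        elif kind == "anyone":
            access.is_public = True
    return access


if __name__ == "__main__":
    perms = [
        {"type": "user", "email": "alice@corp.com"},
        {"type": "group", "email": "eng@corp.com"},
        {"type": "anyone"},
    ]
    print(resolve_drive_permissions(perms, company_domain="corp.com"))
```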

View File

@@ -1,5 +1,3 @@
from collections.abc import Generator
from googleapiclient.errors import HttpError # type: ignore
from pydantic import BaseModel
@@ -101,44 +99,6 @@ def _get_all_folders(
return all_folders
def _drive_folder_to_onyx_group(
folder: FolderInfo,
group_email_to_member_emails_map: dict[str, list[str]],
) -> ExternalUserGroup:
"""
Converts a folder into an Onyx group.
"""
anyone_can_access = False
folder_member_emails: set[str] = set()
for permission in folder.permissions:
if permission.type == PermissionType.USER:
if permission.email_address is None:
logger.warning(
f"User email is None for folder {folder.id} permission {permission}"
)
continue
folder_member_emails.add(permission.email_address)
elif permission.type == PermissionType.GROUP:
if permission.email_address not in group_email_to_member_emails_map:
logger.warning(
f"Group email {permission.email_address} for folder {folder.id} "
"not found in group_email_to_member_emails_map"
)
continue
folder_member_emails.update(
group_email_to_member_emails_map[permission.email_address]
)
elif permission.type == PermissionType.ANYONE:
anyone_can_access = True
return ExternalUserGroup(
id=folder.id,
user_emails=list(folder_member_emails),
gives_anyone_access=anyone_can_access,
)
"""Individual Shared Drive / My Drive Permission Sync"""
@@ -207,29 +167,7 @@ def _get_drive_members(
return drive_id_to_members_map
def _drive_member_map_to_onyx_groups(
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]],
group_email_to_member_emails_map: dict[str, list[str]],
) -> Generator[ExternalUserGroup, None, None]:
"""The `user_emails` for the Shared Drive should be all individuals in the
Shared Drive + the union of all flattened group emails."""
for drive_id, (group_emails, user_emails) in drive_id_to_members_map.items():
drive_member_emails: set[str] = user_emails
for group_email in group_emails:
if group_email not in group_email_to_member_emails_map:
logger.warning(
f"Group email {group_email} for drive {drive_id} not found in "
"group_email_to_member_emails_map"
)
continue
drive_member_emails.update(group_email_to_member_emails_map[group_email])
yield ExternalUserGroup(
id=drive_id,
user_emails=list(drive_member_emails),
)
def _get_all_google_groups(
def _get_all_groups(
admin_service: AdminService,
google_domain: str,
) -> set[str]:
@@ -247,28 +185,6 @@ def _get_all_google_groups(
return group_emails
def _google_group_to_onyx_group(
admin_service: AdminService,
group_email: str,
) -> ExternalUserGroup:
"""
Converts a single google group (identified by its email) into an Onyx group containing that group's member emails.
"""
group_member_emails: set[str] = set()
for member in execute_paginated_retrieval(
admin_service.members().list,
list_key="members",
groupKey=group_email,
fields="members(email),nextPageToken",
):
group_member_emails.add(member["email"])
return ExternalUserGroup(
id=group_email,
user_emails=list(group_member_emails),
)
def _map_group_email_to_member_emails(
admin_service: AdminService,
group_emails: set[str],
@@ -366,7 +282,7 @@ def _build_onyx_groups(
def gdrive_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> Generator[ExternalUserGroup, None, None]:
) -> list[ExternalUserGroup]:
# Initialize connector and build credential/service objects
google_drive_connector = GoogleDriveConnector(
**cc_pair.connector.connector_specific_config
@@ -380,27 +296,26 @@ def gdrive_group_sync(
drive_id_to_members_map = _get_drive_members(google_drive_connector, admin_service)
# Get all group emails
all_group_emails = _get_all_google_groups(
all_group_emails = _get_all_groups(
admin_service, google_drive_connector.google_domain
)
# Each google group is an Onyx group, yield those
group_email_to_member_emails_map: dict[str, list[str]] = {}
for group_email in all_group_emails:
onyx_group = _google_group_to_onyx_group(admin_service, group_email)
group_email_to_member_emails_map[group_email] = onyx_group.user_emails
yield onyx_group
# Each drive is a group, yield those
for onyx_group in _drive_member_map_to_onyx_groups(
drive_id_to_members_map, group_email_to_member_emails_map
):
yield onyx_group
# Get all folder permissions
folder_info = _get_all_folders(
google_drive_connector=google_drive_connector,
skip_folders_without_permissions=True,
)
for folder in folder_info:
yield _drive_folder_to_onyx_group(folder, group_email_to_member_emails_map)
# Map group emails to their members
group_email_to_member_emails_map = _map_group_email_to_member_emails(
admin_service, all_group_emails
)
# Convert the maps to onyx groups
onyx_groups = _build_onyx_groups(
drive_id_to_members_map=drive_id_to_members_map,
group_email_to_member_emails_map=group_email_to_member_emails_map,
folder_info=folder_info,
)
return onyx_groups
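
A minimal sketch of the drive-to-group flattening described above: each shared drive becomes one group whose members are its direct users plus the flattened members of every google group attached to it. `Group` is a stand-in for `ExternalUserGroup`, and the input maps are plain dicts rather than the connector's objects.

```
from dataclasses import dataclass


@dataclass
class Group:
    # stand-in for ExternalUserGroup
    id: str
    user_emails: list[str]


def drives_to_groups(
    drive_members: dict[str, tuple[set[str], set[str]]],  # drive_id -> (group emails, user emails)
    group_members: dict[str, list[str]],                  # group email -> member emails
) -> list[Group]:
    groups: list[Group] = []
    for drive_id, (member_groups, member_users) in drive_members.items():
        emails = set(member_users)
        for group_email in member_groups:
            # Unknown groups are skipped (with a warning in the real code).
            emails.update(group_members.get(group_email, []))
        groups.append(Group(id=drive_id, user_emails=sorted(emails)))
    return groups


if __name__ == "__main__":
    drives = {"drive-1": ({"eng@corp.com"}, {"alice@corp.com"})}
    google_groups = {"eng@corp.com": ["bob@corp.com", "carol@corp.com"]}
    print(drives_to_groups(drives, google_groups))
```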

View File

@@ -2,7 +2,7 @@ from retry import retry
from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.google_utils.resources import RefreshableDriveObject
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -10,7 +10,7 @@ logger = setup_logger()
@retry(tries=3, delay=2, backoff=2)
def get_permissions_by_ids(
drive_service: GoogleDriveService,
drive_service: RefreshableDriveObject,
doc_id: str,
permission_ids: list[str],
) -> list[GoogleDrivePermission]:

View File

@@ -1,34 +0,0 @@
from collections.abc import Generator
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.jira.connector import JiraConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
JIRA_DOC_SYNC_TAG = "jira_doc_sync"
def jira_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
callback: IndexingHeartbeatInterface | None = None,
) -> Generator[DocExternalAccess, None, None]:
jira_connector = JiraConnector(
**cc_pair.connector.connector_specific_config,
)
jira_connector.load_credentials(cc_pair.credential.credential_json)
yield from generic_doc_sync(
cc_pair=cc_pair,
fetch_all_existing_docs_fn=fetch_all_existing_docs_fn,
callback=callback,
doc_source=DocumentSource.JIRA,
slim_connector=jira_connector,
label=JIRA_DOC_SYNC_TAG,
)

View File

@@ -1,25 +0,0 @@
from typing import Any
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic.alias_generators import to_camel
Holder = dict[str, Any]
class Permission(BaseModel):
id: int
permission: str
holder: Holder | None
class User(BaseModel):
account_id: str
email_address: str
display_name: str
active: bool
model_config = ConfigDict(
alias_generator=to_camel,
)

View File

@@ -1,209 +0,0 @@
from collections import defaultdict
from jira import JIRA
from jira.resources import PermissionScheme
from pydantic import ValidationError
from ee.onyx.external_permissions.jira.models import Holder
from ee.onyx.external_permissions.jira.models import Permission
from ee.onyx.external_permissions.jira.models import User
from onyx.access.models import ExternalAccess
from onyx.utils.logger import setup_logger
HolderMap = dict[str, list[Holder]]
logger = setup_logger()
def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
"""
A "Holder" in JIRA is a person / entity who "holds" the corresponding permission.
It can have different types. They can be one of (but not limited to):
- user (an explicitly whitelisted user)
- projectRole (for project level "roles")
- reporter (the reporter of an issue)
A "Holder" usually has following structure:
- `{ "type": "user", "value": "$USER_ID", "user": { .. }, .. }`
- `{ "type": "projectRole", "value": "$PROJECT_ID", .. }`
When we fetch the PermissionSchema from JIRA, we retrieve a list of "Holder"s.
The list of "Holder"s can have multiple "Holder"s of the same type in the list (e.g., you can have two `"type": "user"`s in
there, each corresponding to a different user).
This function constructs a map from each "Holder" type to the list of "Holder"s of that type.
Returns:
A dict mapping each "Holder" type to the list of "Holder" instances of that type.
Example:
```
{
"user": [
{ "type": "user", "value": "10000", "user": { .. }, .. },
{ "type": "user", "value": "10001", "user": { .. }, .. },
],
"projectRole": [
{ "type": "projectRole", "value": "10010", .. },
{ "type": "projectRole", "value": "10011", .. },
],
"applicationRole": [
{ "type": "applicationRole" },
],
..
}
```
"""
holder_map: defaultdict[str, list[Holder]] = defaultdict(list)
for raw_perm in permissions:
if not hasattr(raw_perm, "raw"):
logger.warn(f"Expected a 'raw' field, but none was found: {raw_perm=}")
continue
permission = Permission(**raw_perm.raw)
# We only care about ability to browse through projects + issues (not other permissions such as read/write).
if permission.permission != "BROWSE_PROJECTS":
continue
# In order to associate this permission to some Atlassian entity, we need the "Holder".
# If this doesn't exist, then we cannot associate this permission to anyone; just skip.
if not permission.holder:
logger.warn(
f"Expected to find a permission holder, but none was found: {permission=}"
)
continue
type = permission.holder.get("type")
if not type:
logger.warn(
f"Expected to find the type of permission holder, but none was found: {permission=}"
)
continue
holder_map[type].append(permission.holder)
return holder_map
def _get_user_emails(user_holders: list[Holder]) -> list[str]:
emails = []
for user_holder in user_holders:
if "user" not in user_holder:
continue
raw_user_dict = user_holder["user"]
try:
user_model = User.model_validate(raw_user_dict)
except ValidationError:
logger.error(
"Expected to be able to serialize the raw-user-dict into an instance of `User`, but validation failed;"
f"{raw_user_dict=}"
)
continue
emails.append(user_model.email_address)
return emails
def _get_user_emails_from_project_roles(
jira_client: JIRA,
jira_project: str,
project_role_holders: list[Holder],
) -> list[str]:
# NOTE (@raunakab) a `parallel_yield` may be helpful here...?
roles = [
jira_client.project_role(project=jira_project, id=project_role_holder["value"])
for project_role_holder in project_role_holders
if "value" in project_role_holder
]
emails = []
for role in roles:
if not hasattr(role, "actors"):
continue
for actor in role.actors:
if not hasattr(actor, "actorUser") or not hasattr(
actor.actorUser, "accountId"
):
continue
user = jira_client.user(id=actor.actorUser.accountId)
if not hasattr(user, "accountType") or user.accountType != "atlassian":
continue
if not hasattr(user, "emailAddress"):
msg = f"User's email address was not able to be retrieved; {actor.actorUser.accountId=}"
if hasattr(user, "displayName"):
msg += f" {actor.displayName=}"
logger.warn(msg)
continue
emails.append(user.emailAddress)
return emails
def _build_external_access_from_holder_map(
jira_client: JIRA, jira_project: str, holder_map: HolderMap
) -> ExternalAccess:
"""
# Note:
If the `holder_map` contains an instance of "anyone", then this is a public JIRA project.
Otherwise, we fetch the "projectRole"s (i.e., the user-groups in JIRA speak), and the user emails.
"""
if "anyone" in holder_map:
return ExternalAccess(
external_user_emails=set(), external_user_group_ids=set(), is_public=True
)
user_emails = (
_get_user_emails(user_holders=holder_map["user"])
if "user" in holder_map
else []
)
project_role_user_emails = (
_get_user_emails_from_project_roles(
jira_client=jira_client,
jira_project=jira_project,
project_role_holders=holder_map["projectRole"],
)
if "projectRole" in holder_map
else []
)
external_user_emails = set(user_emails + project_role_user_emails)
return ExternalAccess(
external_user_emails=external_user_emails,
external_user_group_ids=set(),
is_public=False,
)
def get_project_permissions(
jira_client: JIRA,
jira_project: str,
) -> ExternalAccess | None:
project_permissions: PermissionScheme = jira_client.project_permissionscheme(
project=jira_project
)
if not hasattr(project_permissions, "permissions"):
return None
if not isinstance(project_permissions.permissions, list):
return None
holder_map = _build_holder_map(permissions=project_permissions.permissions)
return _build_external_access_from_holder_map(
jira_client=jira_client, jira_project=jira_project, holder_map=holder_map
)
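
Since the holder-map docstring above is central to this file, here is a self-contained sketch of the same idea operating on plain dicts instead of the jira library's `Permission`/`Holder` objects: group BROWSE_PROJECTS holders by type, then treat an "anyone" holder as a public project.

```
from collections import defaultdict


def build_holder_map(permissions: list[dict]) -> dict[str, list[dict]]:
    """Group BROWSE_PROJECTS permission holders by their "type" field."""
    holder_map: defaultdict[str, list[dict]] = defaultdict(list)
    for perm in permissions:
        if perm.get("permission") != "BROWSE_PROJECTS":
            continue
        holder = perm.get("holder")
        if not holder or "type" not in holder:
            continue
        holder_map[holder["type"]].append(holder)
    return dict(holder_map)


def is_public_project(holder_map: dict[str, list[dict]]) -> bool:
    # An "anyone" holder means the project is browsable without logging in.
    return "anyone" in holder_map


if __name__ == "__main__":
    perms = [
        {"permission": "BROWSE_PROJECTS", "holder": {"type": "user", "value": "10000"}},
        {"permission": "EDIT_ISSUES", "holder": {"type": "user", "value": "10001"}},
    ]
    hm = build_holder_map(perms)
    print(hm, is_public_project(hm))
```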

View File

@@ -4,8 +4,6 @@ from typing import Optional
from typing import Protocol
from typing import TYPE_CHECKING
from onyx.context.search.models import InferenceChunk
# Avoid circular imports
if TYPE_CHECKING:
from ee.onyx.db.external_perm import ExternalUserGroup # noqa
@@ -39,11 +37,8 @@ DocSyncFuncType = Callable[
GroupSyncFuncType = Callable[
[
str, # tenant_id
"ConnectorCredentialPair", # cc_pair
str,
"ConnectorCredentialPair",
],
Generator["ExternalUserGroup", None, None],
list["ExternalUserGroup"],
]
# list of chunks to be censored and the user email. returns censored chunks
CensoringFuncType = Callable[[list[InferenceChunk], str], list[InferenceChunk]]

View File

@@ -1,33 +1,43 @@
from collections.abc import Callable
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.external_permissions.sync_params import get_all_censoring_enabled_sources
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
from ee.onyx.external_permissions.salesforce.postprocessing import (
censor_salesforce_chunks,
)
from onyx.configs.constants import DocumentSource
from onyx.context.search.pipeline import InferenceChunk
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.db.models import User
from onyx.utils.logger import setup_logger
logger = setup_logger()
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION: dict[
DocumentSource,
# list of chunks to be censored and the user email. returns censored chunks
Callable[[list[InferenceChunk], str], list[InferenceChunk]],
] = {
DocumentSource.SALESFORCE: censor_salesforce_chunks,
}
def _get_all_censoring_enabled_sources() -> set[DocumentSource]:
"""
Returns the set of sources that have censoring enabled.
This is based on if the access_type is set to sync and the connector
source has a censoring config.
source is included in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION.
NOTE: This means that if a source has even a single cc_pair with access_type
set to sync, all chunks for that source will be censored, even if the
connector that indexed a given chunk is not set to sync. This was done to
avoid having to look up the cc_pair for every single chunk.
"""
all_censoring_enabled_sources = get_all_censoring_enabled_sources()
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
enabled_sync_connectors = get_all_auto_sync_cc_pairs(db_session)
return {
cc_pair.connector.source
for cc_pair in enabled_sync_connectors
if cc_pair.connector.source in all_censoring_enabled_sources
if cc_pair.connector.source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
}
@@ -60,11 +70,7 @@ def _post_query_chunk_censoring(
# check function for that source
# TODO: Use a threadpool/multiprocessing to process the sources in parallel
for source, chunks_for_source in chunks_to_process.items():
sync_config = get_source_perm_sync_config(source)
if sync_config is None or sync_config.censoring_config is None:
raise ValueError(f"No sync config found for {source}")
censor_chunks_for_source = sync_config.censoring_config.chunk_censoring_func
censor_chunks_for_source = DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION[source]
try:
censored_chunks = censor_chunks_for_source(chunks_for_source, user.email)
except Exception as e:
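
A compact sketch of the post-query censoring dispatch this file implements: group retrieved chunks by source, apply a per-source censor where one is registered, and fail closed when a censor raises. The `Chunk` tuple and toy censor are stand-ins, not the real InferenceChunk or Salesforce logic.

```
from collections.abc import Callable

# stand-in chunk type: (source, text)
Chunk = tuple[str, str]
CensorFn = Callable[[list[Chunk], str], list[Chunk]]


def censor_salesforce(chunks: list[Chunk], user_email: str) -> list[Chunk]:
    # toy censor: drop everything (the real function checks per-object access)
    return []


CENSOR_BY_SOURCE: dict[str, CensorFn] = {"salesforce": censor_salesforce}


def post_query_censor(chunks: list[Chunk], user_email: str) -> list[Chunk]:
    """Group chunks by source, run the per-source censor where one exists,
    and drop a source's chunks entirely if its censor raises."""
    by_source: dict[str, list[Chunk]] = {}
    for chunk in chunks:
        by_source.setdefault(chunk[0], []).append(chunk)

    kept: list[Chunk] = []
    for source, source_chunks in by_source.items():
        censor = CENSOR_BY_SOURCE.get(source)
        if censor is None:
            kept.extend(source_chunks)
            continue
        try:
            kept.extend(censor(source_chunks, user_email))
        except Exception:
            # Fail closed: better to hide chunks than leak them.
            continue
    return kept


if __name__ == "__main__":
    chunks = [("web", "public text"), ("salesforce", "restricted text")]
    print(post_query_censor(chunks, "alice@example.com"))
```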

View File

@@ -10,7 +10,7 @@ from ee.onyx.external_permissions.salesforce.utils import (
)
from onyx.configs.app_configs import BLURB_SIZE
from onyx.context.search.models import InferenceChunk
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -44,7 +44,7 @@ def _get_objects_access_for_user_email_from_salesforce(
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
# but subsequent queries for this source are essentially instant
first_doc_id = chunks[0].document_id
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
salesforce_client = get_any_salesforce_client_for_doc_id(
db_session, first_doc_id
)
@@ -217,7 +217,7 @@ def censor_salesforce_chunks(
def _get_objects_access_for_user_email(
object_ids: set[str], user_email: str
) -> dict[str, bool]:
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
external_groups = fetch_external_groups_for_user_email_and_group_ids(
db_session=db_session,
user_email=user_email,

View File

@@ -1,63 +0,0 @@
from slack_sdk import WebClient
from onyx.access.models import ExternalAccess
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.slack.connector import ChannelType
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import make_paginated_slack_api_call
def get_channel_access(
client: WebClient,
channel: ChannelType,
user_cache: dict[str, BasicExpertInfo | None],
) -> ExternalAccess:
"""
Get channel access permissions for a Slack channel.
Args:
client: Slack WebClient instance
channel: Slack channel object containing channel info
user_cache: Cache of user IDs to BasicExpertInfo objects. May be updated in place.
Returns:
ExternalAccess object for the channel.
"""
channel_is_public = not channel["is_private"]
if channel_is_public:
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=True,
)
channel_id = channel["id"]
# Get all member IDs for the channel
member_ids = []
for result in make_paginated_slack_api_call(
client.conversations_members,
channel=channel_id,
):
member_ids.extend(result.get("members", []))
member_emails = set()
for member_id in member_ids:
# Try to get user info from cache or fetch it
user_info = expert_info_from_slack_id(
user_id=member_id,
client=client,
user_cache=user_cache,
)
# If we have user info and an email, add it to the set
if user_info and user_info.email:
member_emails.add(user_info.email)
return ExternalAccess(
external_user_emails=member_emails,
# NOTE: groups are not used, since adding a group to a channel just adds all
# users that are in the group.
external_user_group_ids=set(),
is_public=False,
)
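
The decision in `get_channel_access` boils down to two cases, sketched below with stand-in types: a public channel is readable by everyone in the workspace, while a private channel is restricted to its member emails (no groups needed).

```
from dataclasses import dataclass, field


@dataclass
class ChannelAccess:
    # stand-in for ExternalAccess
    user_emails: set[str] = field(default_factory=set)
    is_public: bool = False


def channel_access(
    channel: dict,
    member_emails_by_channel: dict[str, list[str]],
) -> ChannelAccess:
    """Public channels are open to the workspace; private channels are
    restricted to their resolved member emails."""
    if not channel["is_private"]:
        return ChannelAccess(is_public=True)
    members = member_emails_by_channel.get(channel["id"], [])
    return ChannelAccess(user_emails=set(members))


if __name__ == "__main__":
    members = {"C123": ["alice@example.com", "bob@example.com"]}
    print(channel_access({"id": "C123", "is_private": True}, members))
    print(channel_access({"id": "C456", "is_private": False}, members))
```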

View File

@@ -108,15 +108,11 @@ def _get_slack_document_access(
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if doc_metadata.external_access is None:
raise ValueError(
f"No external access for document {doc_metadata.id}. "
"Please check to make sure that your Slack bot token has the "
"`channels:read` scope"
)
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
yield DocExternalAccess(
external_access=doc_metadata.external_access,
external_access=channel_permissions[channel_id],
doc_id=doc_metadata.id,
)

View File

@@ -58,8 +58,6 @@ def slack_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
"""NOTE: not used atm. All channel access is done at the
individual user level. Leaving in for now in case we need it later."""
provider = OnyxDBCredentialsProvider(tenant_id, "slack", cc_pair.credential.id)
r = get_redis_client(tenant_id=tenant_id)

View File

@@ -1,207 +1,83 @@
from collections.abc import Generator
from typing import Optional
from typing import TYPE_CHECKING
from pydantic import BaseModel
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import JIRA_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import TEAMS_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync
from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
from ee.onyx.external_permissions.jira.doc_sync import jira_doc_sync
from ee.onyx.external_permissions.perm_sync_types import CensoringFuncType
from ee.onyx.external_permissions.perm_sync_types import DocSyncFuncType
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import GroupSyncFuncType
from ee.onyx.external_permissions.salesforce.postprocessing import (
censor_salesforce_chunks,
from ee.onyx.external_permissions.post_query_censoring import (
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from ee.onyx.external_permissions.teams.doc_sync import teams_doc_sync
from ee.onyx.external_permissions.slack.group_sync import slack_group_sync
from onyx.configs.constants import DocumentSource
if TYPE_CHECKING:
from onyx.access.models import DocExternalAccess # noqa
from onyx.db.models import ConnectorCredentialPair # noqa
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface # noqa
class DocSyncConfig(BaseModel):
doc_sync_frequency: int
doc_sync_func: DocSyncFuncType
initial_index_should_sync: bool
class GroupSyncConfig(BaseModel):
group_sync_frequency: int
group_sync_func: GroupSyncFuncType
group_sync_is_cc_pair_agnostic: bool
class CensoringConfig(BaseModel):
chunk_censoring_func: CensoringFuncType
class SyncConfig(BaseModel):
# None means we don't perform a doc_sync
doc_sync_config: DocSyncConfig | None = None
# None means we don't perform a group_sync
group_sync_config: GroupSyncConfig | None = None
# None means we don't perform a chunk_censoring
censoring_config: CensoringConfig | None = None
# Mock doc sync function for testing (no-op)
def mock_doc_sync(
cc_pair: "ConnectorCredentialPair",
fetch_all_docs_fn: FetchAllDocumentsFunction,
callback: Optional["IndexingHeartbeatInterface"],
) -> Generator["DocExternalAccess", None, None]:
"""Mock doc sync function for testing - returns empty list since permissions are fetched during indexing"""
yield from []
_SOURCE_TO_SYNC_CONFIG: dict[DocumentSource, SyncConfig] = {
DocumentSource.GOOGLE_DRIVE: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=gdrive_doc_sync,
initial_index_should_sync=True,
),
group_sync_config=GroupSyncConfig(
group_sync_frequency=GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY,
group_sync_func=gdrive_group_sync,
group_sync_is_cc_pair_agnostic=False,
),
),
DocumentSource.CONFLUENCE: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=confluence_doc_sync,
initial_index_should_sync=False,
),
group_sync_config=GroupSyncConfig(
group_sync_frequency=CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
group_sync_func=confluence_group_sync,
group_sync_is_cc_pair_agnostic=True,
),
),
DocumentSource.JIRA: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=JIRA_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=jira_doc_sync,
initial_index_should_sync=True,
),
),
# Groups are not needed for Slack.
# All channel access is done at the individual user level.
DocumentSource.SLACK: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=slack_doc_sync,
initial_index_should_sync=True,
),
),
DocumentSource.GMAIL: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=gmail_doc_sync,
initial_index_should_sync=False,
),
),
DocumentSource.SALESFORCE: SyncConfig(
censoring_config=CensoringConfig(
chunk_censoring_func=censor_salesforce_chunks,
),
),
DocumentSource.MOCK_CONNECTOR: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=mock_doc_sync,
initial_index_should_sync=True,
),
),
# Groups are not needed for Teams.
# All channel access is done at the individual user level.
DocumentSource.TEAMS: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=TEAMS_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=teams_doc_sync,
initial_index_should_sync=True,
),
),
# These functions update:
# - the user_email <-> document mapping
# - the external_user_group_id <-> document mapping
# in postgres without committing
# THIS ONE IS NECESSARY FOR AUTO SYNC TO WORK
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync,
DocumentSource.CONFLUENCE: confluence_doc_sync,
DocumentSource.SLACK: slack_doc_sync,
DocumentSource.GMAIL: gmail_doc_sync,
}
def source_requires_doc_sync(source: DocumentSource) -> bool:
"""Checks if the given DocumentSource requires doc syncing."""
if source not in _SOURCE_TO_SYNC_CONFIG:
return False
return _SOURCE_TO_SYNC_CONFIG[source].doc_sync_config is not None
return source in DOC_PERMISSIONS_FUNC_MAP
# These functions update:
# - the user_email <-> external_user_group_id mapping
# in postgres without committing
# THIS ONE IS OPTIONAL ON AN APP BY APP BASIS
GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = {
DocumentSource.GOOGLE_DRIVE: gdrive_group_sync,
DocumentSource.CONFLUENCE: confluence_group_sync,
DocumentSource.SLACK: slack_group_sync,
}
def source_requires_external_group_sync(source: DocumentSource) -> bool:
"""Checks if the given DocumentSource requires external group syncing."""
if source not in _SOURCE_TO_SYNC_CONFIG:
return False
return _SOURCE_TO_SYNC_CONFIG[source].group_sync_config is not None
return source in GROUP_PERMISSIONS_FUNC_MAP
def get_source_perm_sync_config(source: DocumentSource) -> SyncConfig | None:
"""Returns the frequency of the external group sync for the given DocumentSource."""
return _SOURCE_TO_SYNC_CONFIG.get(source)
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
DocumentSource.CONFLUENCE,
}
def source_group_sync_is_cc_pair_agnostic(source: DocumentSource) -> bool:
"""Checks if the given DocumentSource requires external group syncing."""
if source not in _SOURCE_TO_SYNC_CONFIG:
return False
group_sync_config = _SOURCE_TO_SYNC_CONFIG[source].group_sync_config
if group_sync_config is None:
return False
return group_sync_config.group_sync_is_cc_pair_agnostic
return source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC
def get_all_cc_pair_agnostic_group_sync_sources() -> set[DocumentSource]:
"""Returns the set of sources that have external group syncing that is cc_pair agnostic."""
return {
source
for source, sync_config in _SOURCE_TO_SYNC_CONFIG.items()
if sync_config.group_sync_config is not None
and sync_config.group_sync_config.group_sync_is_cc_pair_agnostic
}
# If nothing is specified here, we run the doc_sync every time the celery beat runs
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
DocumentSource.SLACK: SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
}
# If nothing is specified here, we run the doc_sync every time the celery beat runs
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all group permissions every 30 minutes
DocumentSource.GOOGLE_DRIVE: GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY,
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
}
def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
return source_type in _SOURCE_TO_SYNC_CONFIG
def get_all_censoring_enabled_sources() -> set[DocumentSource]:
"""Returns the set of sources that have censoring enabled."""
return {
source
for source, sync_config in _SOURCE_TO_SYNC_CONFIG.items()
if sync_config.censoring_config is not None
}
def source_should_fetch_permissions_during_indexing(source: DocumentSource) -> bool:
"""Returns True if the given DocumentSource requires permissions to be fetched during indexing."""
if source not in _SOURCE_TO_SYNC_CONFIG:
return False
doc_sync_config = _SOURCE_TO_SYNC_CONFIG[source].doc_sync_config
if doc_sync_config is None:
return False
return doc_sync_config.initial_index_should_sync
return (
source_type in DOC_PERMISSIONS_FUNC_MAP
or source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
)
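
To illustrate the registry pattern this file revolves around (one optional config per sync/censoring step, with predicate helpers that read the registry), here is a minimal sketch using stand-in dataclasses and no-op lambdas in place of the real sync functions and frequencies.

```
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class DocSyncConfig:
    frequency_secs: int
    sync_func: Callable[..., object]


@dataclass
class GroupSyncConfig:
    frequency_secs: int
    sync_func: Callable[..., object]
    cc_pair_agnostic: bool = False


@dataclass
class SyncConfig:
    # None means the corresponding doc-sync / group-sync / censoring step is skipped
    doc_sync: Optional[DocSyncConfig] = None
    group_sync: Optional[GroupSyncConfig] = None
    chunk_censoring_func: Optional[Callable[..., object]] = None


SOURCE_TO_SYNC_CONFIG: dict[str, SyncConfig] = {
    "google_drive": SyncConfig(
        doc_sync=DocSyncConfig(frequency_secs=300, sync_func=lambda: None),
        group_sync=GroupSyncConfig(frequency_secs=1800, sync_func=lambda: None),
    ),
    "salesforce": SyncConfig(chunk_censoring_func=lambda chunks, email: chunks),
}


def source_requires_doc_sync(source: str) -> bool:
    cfg = SOURCE_TO_SYNC_CONFIG.get(source)
    return cfg is not None and cfg.doc_sync is not None


def source_requires_group_sync(source: str) -> bool:
    cfg = SOURCE_TO_SYNC_CONFIG.get(source)
    return cfg is not None and cfg.group_sync is not None


if __name__ == "__main__":
    print(source_requires_doc_sync("google_drive"), source_requires_group_sync("salesforce"))
```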

View File

@@ -1,35 +0,0 @@
from collections.abc import Generator
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.teams.connector import TeamsConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
TEAMS_DOC_SYNC_LABEL = "teams_doc_sync"
def teams_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
teams_connector = TeamsConnector(
**cc_pair.connector.connector_specific_config,
)
teams_connector.load_credentials(cc_pair.credential.credential_json)
yield from generic_doc_sync(
cc_pair=cc_pair,
fetch_all_existing_docs_fn=fetch_all_existing_docs_fn,
callback=callback,
doc_source=DocumentSource.TEAMS,
slim_connector=teams_connector,
label=TEAMS_DOC_SYNC_LABEL,
)

View File

@@ -1,83 +0,0 @@
from collections.abc import Generator
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import SlimConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
def generic_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
callback: IndexingHeartbeatInterface | None,
doc_source: DocumentSource,
slim_connector: SlimConnector,
label: str,
) -> Generator[DocExternalAccess, None, None]:
"""
A convenience function for performing a generic document synchronization.
Notes:
A generic doc sync includes:
- fetching existing docs
- fetching *all* new (slim) docs
- yielding external-access permissions for existing docs which do not exist in the newly fetched slim-docs set (with their
`external_access` set to "private")
- yielding external-access permissions for newly fetched docs
Returns:
A `Generator` which yields existing and newly fetched external-access permissions.
"""
logger.info(f"Starting {doc_source} doc sync for CC Pair ID: {cc_pair.id}")
newly_fetched_doc_ids: set[str] = set()
logger.info(f"Fetching all slim documents from {doc_source}")
for doc_batch in slim_connector.retrieve_all_slim_documents(callback=callback):
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
if callback:
if callback.should_stop():
raise RuntimeError(f"{label}: Stop signal detected")
callback.progress(label, 1)
for doc in doc_batch:
if not doc.external_access:
raise RuntimeError(
f"No external access found for document ID; {cc_pair.id=} {doc_source=} {doc.id=}"
)
newly_fetched_doc_ids.add(doc.id)
yield DocExternalAccess(
doc_id=doc.id,
external_access=doc.external_access,
)
logger.info(f"Querying existing document IDs for CC Pair ID: {cc_pair.id=}")
existing_doc_ids = set(fetch_all_existing_docs_fn())
missing_doc_ids = existing_doc_ids - newly_fetched_doc_ids
if not missing_doc_ids:
return
logger.warning(
f"Found {len(missing_doc_ids)=} documents that are in the DB but not present in fetch. Making them inaccessible."
)
for missing_id in missing_doc_ids:
logger.warning(f"Removing access for {missing_id=}")
yield DocExternalAccess(
doc_id=missing_id,
external_access=ExternalAccess.empty(),
)
logger.info(f"Finished {doc_source} doc sync")

View File

@@ -19,7 +19,7 @@ from ee.onyx.db.analytics import fetch_query_analytics
from ee.onyx.db.analytics import user_can_view_assistant_stats
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
router = APIRouter(prefix="/analytics")

View File

@@ -17,7 +17,7 @@ from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.db.connector_credential_pair import (
get_connector_credential_pair_from_id_for_user,
)
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_pool import get_redis_client

View File

@@ -26,9 +26,9 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user_with_expired_token
from onyx.auth.users import get_user_manager
from onyx.auth.users import UserManager
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.file_store import PostgresBackedFileStore
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -142,12 +142,11 @@ def put_logo(
def fetch_logo_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store(db_session)
file_store = PostgresBackedFileStore(db_session)
onyx_file = file_store.get_file_with_mime_type(get_logo_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")
except Exception:
logger.exception("Faield to fetch logo file")
raise HTTPException(
status_code=404,
detail="No logo file found",
@@ -158,7 +157,7 @@ def fetch_logo_helper(db_session: Session) -> Response:
def fetch_logotype_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store(db_session)
file_store = PostgresBackedFileStore(db_session)
onyx_file = file_store.get_file_with_mime_type(get_logotype_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")

View File

@@ -131,11 +131,11 @@ def upload_logo(
file_store = get_default_file_store(db_session)
file_store.save_file(
file_name=_LOGOTYPE_FILENAME if is_logotype else _LOGO_FILENAME,
content=content,
display_name=display_name,
file_origin=FileOrigin.OTHER,
file_type=file_type,
file_id=_LOGOTYPE_FILENAME if is_logotype else _LOGO_FILENAME,
)
return True

View File

@@ -13,7 +13,7 @@ from ee.onyx.db.standard_answer import remove_standard_answer
from ee.onyx.db.standard_answer import update_standard_answer
from ee.onyx.db.standard_answer import update_standard_answer_category
from onyx.auth.users import current_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.server.manage.models import StandardAnswer
from onyx.server.manage.models import StandardAnswerCategory

View File

@@ -11,7 +11,7 @@ from ee.onyx.auth.users import decode_anonymous_user_jwt_token
from onyx.auth.api_key import extract_tenant_from_api_key_header
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.engine.sql_engine import is_valid_schema_name
from onyx.db.engine import is_valid_schema_name
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

View File

@@ -12,10 +12,10 @@ from ee.onyx.server.oauth.slack import SlackOAuth
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import DocumentSource
from onyx.db.engine import get_current_tenant_id
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()

View File

@@ -25,12 +25,12 @@ from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL
from onyx.db.credentials import create_credential
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.credentials import update_credential_json
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()

View File

@@ -33,11 +33,11 @@ from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from shared_configs.contextvars import get_current_tenant_id
class GoogleDriveOAuth:

View File

@@ -17,11 +17,11 @@ from onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.db.credentials import create_credential
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from shared_configs.contextvars import get_current_tenant_id
class SlackOAuth:

View File

@@ -40,7 +40,7 @@ from onyx.context.search.models import SavedSearchDoc
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.llm.factory import get_llms_for_persona
from onyx.natural_language_processing.utils import get_tokenizer

View File

@@ -31,7 +31,7 @@ from onyx.context.search.utils import dedupe_documents
from onyx.context.search.utils import drop_llm_indices
from onyx.context.search.utils import relevant_sections_to_indices
from onyx.db.chat import get_prompt_by_id
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id

View File

@@ -13,7 +13,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.api_key import is_api_key_email_address
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import TokenRateLimit

View File

@@ -1,4 +1,3 @@
import uuid
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
@@ -12,7 +11,6 @@ from fastapi import Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from ee.onyx.background.task_name_builders import query_history_task_name
from ee.onyx.db.query_history import get_all_query_history_export_tasks
from ee.onyx.db.query_history import get_page_of_chat_sessions
from ee.onyx.db.query_history import get_total_filtered_chat_sessions_count
@@ -37,19 +35,17 @@ from onyx.configs.constants import QueryHistoryType
from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.enums import TaskStatus
from onyx.db.file_record import get_query_history_export_files
from onyx.db.models import ChatSession
from onyx.db.models import User
from onyx.db.pg_file_store import get_query_history_export_files
from onyx.db.tasks import get_task_with_id
from onyx.db.tasks import register_task
from onyx.file_store.file_store import get_default_file_store
from onyx.server.documents.models import PaginatedReturn
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionsResponse
from onyx.utils.threadpool_concurrency import parallel_yield
from shared_configs.contextvars import get_current_tenant_id
router = APIRouter()
@@ -314,32 +310,17 @@ def start_query_history_export(
f"Start time must come before end time, but instead got the start time coming after; {start=} {end=}",
)
task_id_uuid = uuid.uuid4()
task_id = str(task_id_uuid)
start_time = datetime.now(tz=timezone.utc)
register_task(
db_session=db_session,
task_name=query_history_task_name(start=start, end=end),
task_id=task_id,
status=TaskStatus.PENDING,
start_time=start_time,
)
client_app.send_task(
task = client_app.send_task(
OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
task_id=task_id,
priority=OnyxCeleryPriority.MEDIUM,
queue=OnyxCeleryQueues.CSV_GENERATION,
kwargs={
"start": start,
"end": end,
"start_time": start_time,
"tenant_id": get_current_tenant_id(),
},
)
return {"request_id": task_id}
return {"request_id": task.id}
@router.get("/admin/query-history/export-status")
@@ -362,7 +343,7 @@ def get_query_history_export_status(
report_name = construct_query_history_report_name(request_id)
has_file = file_store.has_file(
file_id=report_name,
file_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,
file_type=FileType.CSV,
)
@@ -387,7 +368,7 @@ def download_query_history_csv(
report_name = construct_query_history_report_name(request_id)
file_store = get_default_file_store(db_session)
has_file = file_store.has_file(
file_id=report_name,
file_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,
file_type=FileType.CSV,
)

View File

@@ -12,7 +12,7 @@ from onyx.configs.constants import SessionType
from onyx.db.enums import TaskStatus
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import FileRecord
from onyx.db.models import PGFileStore
from onyx.db.models import TaskQueueState
@@ -254,7 +254,7 @@ class QueryHistoryExport(BaseModel):
@classmethod
def from_file(
cls,
file: FileRecord,
file: PGFileStore,
) -> "QueryHistoryExport":
if not file.file_metadata or not isinstance(file.file_metadata, dict):
raise RuntimeError(
@@ -262,7 +262,7 @@ class QueryHistoryExport(BaseModel):
)
metadata = QueryHistoryFileMetadata.model_validate(dict(file.file_metadata))
task_id = extract_task_id_from_query_history_report_name(file.file_id)
task_id = extract_task_id_from_query_history_report_name(file.file_name)
return cls(
task_id=task_id,

View File

@@ -14,7 +14,7 @@ from ee.onyx.db.usage_export import get_usage_report_data
from ee.onyx.db.usage_export import UsageReportMetadata
from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
from onyx.auth.users import current_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.file_store.constants import STANDARD_CHUNK_SIZE

View File

@@ -62,16 +62,17 @@ def generate_chat_messages_report(
]
)
# after writing seek to beginning of buffer
# after writing seek to begining of buffer
temp_file.seek(0)
file_id = file_store.save_file(
file_store.save_file(
file_name=file_name,
content=temp_file,
display_name=file_name,
file_origin=FileOrigin.OTHER,
file_type="text/csv",
)
return file_id
return file_name
def generate_user_report(
@@ -96,14 +97,15 @@ def generate_user_report(
csvwriter.writerow([user_skeleton.user_id, user_skeleton.is_active])
temp_file.seek(0)
file_id = file_store.save_file(
file_store.save_file(
file_name=file_name,
content=temp_file,
display_name=file_name,
file_origin=FileOrigin.OTHER,
file_type="text/csv",
)
return file_id
return file_name
def create_new_usage_report(
@@ -114,16 +116,16 @@ def create_new_usage_report(
report_id = str(uuid.uuid4())
file_store = get_default_file_store(db_session)
messages_file_id = generate_chat_messages_report(
messages_filename = generate_chat_messages_report(
db_session, file_store, report_id, period
)
users_file_id = generate_user_report(db_session, file_store, report_id)
users_filename = generate_user_report(db_session, file_store, report_id)
with tempfile.SpooledTemporaryFile(max_size=MAX_IN_MEMORY_SIZE) as zip_buffer:
with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED) as zip_file:
# write messages
chat_messages_tmpfile = file_store.read_file(
messages_file_id, mode="b", use_tempfile=True
messages_filename, mode="b", use_tempfile=True
)
zip_file.writestr(
"chat_messages.csv",
@@ -132,7 +134,7 @@ def create_new_usage_report(
# write users
users_tmpfile = file_store.read_file(
users_file_id, mode="b", use_tempfile=True
users_filename, mode="b", use_tempfile=True
)
zip_file.writestr("users.csv", users_tmpfile.read())
@@ -144,11 +146,11 @@ def create_new_usage_report(
f"_{report_id}_usage_report.zip"
)
file_store.save_file(
file_name=report_name,
content=zip_buffer,
display_name=report_name,
file_origin=FileOrigin.GENERATED_REPORT,
file_type="application/zip",
file_id=report_name,
)
# add report after zip file is written

View File

@@ -27,9 +27,9 @@ from onyx.auth.users import get_user_manager
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
from onyx.db.engine.async_sql_engine import get_async_session
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_context_manager
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.utils.logger import setup_logger

View File

@@ -19,7 +19,7 @@ from ee.onyx.server.enterprise_settings.store import (
)
from ee.onyx.server.enterprise_settings.store import upload_logo
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
@@ -235,7 +235,7 @@ def seed_db() -> None:
logger.debug("No seeding configuration file passed")
return
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
if seed_config.llms is not None:
_seed_llms(db_session, seed_config.llms)
if seed_config.personas is not None:

View File

@@ -10,7 +10,7 @@ from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from onyx.auth.users import auth_backend
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger

View File

@@ -18,7 +18,7 @@ from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_shared_schema
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

View File

@@ -28,8 +28,8 @@ from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider

View File

@@ -8,8 +8,8 @@ from sqlalchemy.schema import CreateSchema
from alembic import command
from alembic.config import Config
from onyx.db.engine.sql_engine import build_connection_string
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_sqlalchemy_engine
logger = logging.getLogger(__name__)
@@ -34,7 +34,7 @@ def run_alembic_migrations(schema_name: str) -> None:
# Mimic command-line options by adding 'cmd_opts' to the config
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
alembic_cfg.cmd_opts.x = [f"schemas={schema_name}"] # type: ignore
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
# Run migrations programmatically
command.upgrade(alembic_cfg, "head")
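The runner above passes the target schema to Alembic through `cmd_opts.x` (note the `schema=` → `schemas=` key change). Inside `env.py`, such a value is typically read back with `context.get_x_argument`; a minimal sketch assuming the common pattern, not this repo's actual `env.py`:

```python
# env.py (illustrative sketch, not this repo's file)
from alembic import context

# get_x_argument(as_dictionary=True) turns ["schemas=tenant_abc"] into {"schemas": "tenant_abc"}
x_args = context.get_x_argument(as_dictionary=True)
schema_names = [s.strip() for s in x_args.get("schemas", "public").split(",")]
```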

View File

@@ -9,7 +9,7 @@ from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.auth import get_user_count
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail

View File

@@ -5,8 +5,8 @@ from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import get_pending_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.invited_users import write_pending_users
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import UserTenantMapping
from onyx.server.manage.models import TenantSnapshot
from onyx.utils.logger import setup_logger

View File

@@ -9,7 +9,7 @@ from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.db.token_limit import fetch_all_user_token_rate_limits
from onyx.db.token_limit import insert_user_token_rate_limit

View File

@@ -16,7 +16,7 @@ from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger

View File

@@ -1,16 +1,11 @@
from collections.abc import Callable
from typing import cast
from sqlalchemy.orm import Session
from onyx.access.models import DocumentAccess
from onyx.access.utils import prefix_user_email
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import PUBLIC_DOC_PAT
from onyx.db.document import get_access_info_for_document
from onyx.db.document import get_access_info_for_documents
from onyx.db.models import User
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -112,15 +107,3 @@ def get_acl_for_user(user: User | None, db_session: Session | None = None) -> se
"onyx.access.access", "_get_acl_for_user"
)
return versioned_acl_for_user_fn(user, db_session) # type: ignore
def source_should_fetch_permissions_during_indexing(source: DocumentSource) -> bool:
_source_should_fetch_permissions_during_indexing_func = cast(
Callable[[DocumentSource], bool],
fetch_ee_implementation_or_noop(
"onyx.external_permissions.sync_params",
"source_should_fetch_permissions_during_indexing",
False,
),
)
return _source_should_fetch_permissions_during_indexing_func(source)

View File

@@ -40,30 +40,6 @@ class ExternalAccess:
def num_entries(self) -> int:
return len(self.external_user_emails) + len(self.external_user_group_ids)
@classmethod
def public(cls) -> "ExternalAccess":
return cls(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=True,
)
@classmethod
def empty(cls) -> "ExternalAccess":
"""
A helper function that returns an *empty* set of external user-emails and group-ids, and sets `is_public` to `False`.
This effectively makes the document in question "private" or inaccessible to anyone else.
This is especially useful when you are performing permission syncing and a document's permissions cannot
be determined (for whatever reason); setting its `ExternalAccess` to "private" is a safe fallback.
"""
return cls(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=False,
)
@dataclass(frozen=True)
class DocExternalAccess:

View File

@@ -78,7 +78,7 @@ def should_continue(state: BasicState) -> str:
if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.context.search.models import SearchRequest
from onyx.llm.factory import get_default_llms
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
@@ -87,7 +87,7 @@ if __name__ == "__main__":
compiled_graph = graph.compile()
input = BasicInput(unused=True)
primary_llm, fast_llm = get_default_llms()
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
config, _ = get_test_config(
db_session=db_session,
primary_llm=primary_llm,

View File

@@ -4,7 +4,7 @@ from typing import cast
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,

View File

@@ -111,7 +111,7 @@ def answer_query_graph_builder() -> StateGraph:
if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
@@ -121,7 +121,7 @@ if __name__ == "__main__":
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
graph_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)

View File

@@ -238,7 +238,7 @@ def agent_search_graph_builder() -> StateGraph:
if __name__ == "__main__":
pass
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
@@ -246,7 +246,7 @@ if __name__ == "__main__":
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
search_request = SearchRequest(query="Who created Excel?")
graph_config = get_test_config(
db_session, primary_llm, fast_llm, search_request

View File

@@ -109,7 +109,7 @@ def answer_refined_query_graph_builder() -> StateGraph:
if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
@@ -119,7 +119,7 @@ if __name__ == "__main__":
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
inputs = SubQuestionAnsweringInput(
question="what can you do with onyx?",
question_id="0_0",

View File

@@ -131,7 +131,7 @@ def expanded_retrieval_graph_builder() -> StateGraph:
if __name__ == "__main__":
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
@@ -142,7 +142,7 @@ if __name__ == "__main__":
query="what can you do with onyx or danswer?",
)
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
graph_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)

View File

@@ -24,7 +24,7 @@ from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RerankingDetails
from onyx.context.search.postprocessing.postprocessing import rerank_sections
from onyx.context.search.postprocessing.postprocessing import should_rerank
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.db.search_settings import get_current_search_settings
from onyx.utils.timing import log_function_time
@@ -60,7 +60,7 @@ def rerank_documents(
allow_agent_reranking = graph_config.behavior.allow_agent_reranking
if rerank_settings is None:
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
search_settings = get_current_search_settings(db_session)
if not search_settings.disable_rerank_for_streaming:
rerank_settings = RerankingDetails.from_db_model(search_settings)

View File

@@ -21,7 +21,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.configs.agent_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
from onyx.context.search.models import InferenceSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_context_manager
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
@@ -67,7 +67,7 @@ def retrieve_documents(
callback_container: list[list[InferenceSection]] = []
# new db session to avoid concurrency issues
with get_session_with_current_tenant() as db_session:
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=query_to_retrieve,
override_kwargs=SearchToolOverrideKwargs(

View File

@@ -6,13 +6,8 @@ from langgraph.types import Send
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import KGSourceDivisionType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
from onyx.configs.kg_configs import KG_MAX_DECOMPOSITION_SEGMENTS
from onyx.utils.logger import setup_logger
logger = setup_logger()
class KGAnalysisPath(str, Enum):
@@ -49,17 +44,6 @@ def research_individual_object(
and state.strategy == KGAnswerStrategy.DEEP
):
if state.source_filters and state.source_division:
segments = state.source_filters
segment_type = KGSourceDivisionType.SOURCE.value
else:
segments = state.div_con_entities
segment_type = KGSourceDivisionType.ENTITY.value
if segments and (len(segments) > KG_MAX_DECOMPOSITION_SEGMENTS):
logger.debug(f"Too many sources ({len(segments)}), usingfiltered search")
return "filtered_search"
return [
Send(
"process_individual_deep_search",
@@ -70,14 +54,13 @@ def research_individual_object(
vespa_filter_results=state.vespa_filter_results,
source_division=state.source_division,
source_entity_filters=state.source_filters,
segment_type=segment_type,
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
step_results=[],
),
)
for research_nr, entity in enumerate(segments)
for research_nr, entity in enumerate(state.div_con_entities)
]
elif state.search_type == KGSearchType.SEARCH:
return "filtered_search"

View File

@@ -19,7 +19,7 @@ from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.db.document import get_kg_doc_info_for_entity_name
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_document_id_for_entity
from onyx.db.entities import get_entity_name
from onyx.db.entity_type import get_entity_types
@@ -217,6 +217,31 @@ def stream_write_close_steps(writer: StreamWriter, level: int = 0) -> None:
write_custom_event("stream_finished", stop_event, writer)
def stream_write_close_main_answer(writer: StreamWriter, level: int = 0) -> None:
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.MAIN_ANSWER,
level=level,
level_question_num=0,
)
write_custom_event("stream_finished", stop_event, writer)
def stream_write_main_answer_token(
writer: StreamWriter, token: str, level: int = 0, level_question_num: int = 0
) -> None:
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=token, # No need to add space as tokenizer handles this
level=level,
level_question_num=level_question_num,
answer_type="agent_level_answer",
),
writer,
)
def get_doc_information_for_entity(entity_id_name: str) -> KGEntityDocInfo:
"""
Get document information for an entity, including its semantic name and document details.
@@ -267,9 +292,11 @@ def rename_entities_in_answer(answer: str) -> str:
# Clean up any spaces around ::
answer = re.sub(r"::\s+", "::", answer)
logger.debug(f"After cleaning spaces: {answer}")
# Pattern to match entity_type::entity_name, with optional quotes
pattern = r"(?:')?([a-zA-Z0-9-]+)::([a-zA-Z0-9]+)(?:')?"
logger.debug(f"Using pattern: {pattern}")
matches = list(re.finditer(pattern, answer))
logger.debug(f"Found {len(matches)} matches")
@@ -288,8 +315,10 @@ def rename_entities_in_answer(answer: str) -> str:
entity_type = match.group(1).upper().strip()
entity_name = match.group(2).strip()
potential_entity_id_name = make_entity_id(entity_type, entity_name)
logger.debug(f"Processing entity: {potential_entity_id_name}")
if entity_type not in active_entity_types:
logger.debug(f"Entity type {entity_type} not in active types")
continue
replacement_candidate = get_doc_information_for_entity(potential_entity_id_name)
@@ -299,8 +328,14 @@ def rename_entities_in_answer(answer: str) -> str:
processed_refs[match.group(0)] = (
replacement_candidate.semantic_linked_entity_name
)
logger.debug(
f"Added replacement: {match.group(0)} -> {replacement_candidate.semantic_linked_entity_name}"
)
else:
processed_refs[match.group(0)] = replacement_candidate.semantic_entity_name
logger.debug(
f"Added replacement: {match.group(0)} -> {replacement_candidate.semantic_entity_name}"
)
# Replace all references in the answer
for ref, replacement in processed_refs.items():

View File

@@ -2,26 +2,19 @@ from pydantic import BaseModel
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import YesNoEnum
class KGQuestionEntityExtractionResult(BaseModel):
entities: list[str]
terms: list[str]
time_filter: str | None
class KGViewNames(BaseModel):
allowed_docs_view_name: str
kg_relationships_view_name: str
kg_entity_view_name: str
class KGAnswerApproach(BaseModel):
search_type: KGSearchType
search_strategy: KGAnswerStrategy
relationship_detection: KGRelationshipDetection
format: KGAnswerFormat
broken_down_question: str | None = None
divide_and_conquer: YesNoEnum | None = None
@@ -34,6 +27,7 @@ class KGQuestionRelationshipExtractionResult(BaseModel):
class KGQuestionExtractionResult(BaseModel):
entities: list[str]
relationships: list[str]
terms: list[str]
time_filter: str | None

View File

@@ -4,7 +4,6 @@ from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from pydantic import ValidationError
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
@@ -25,17 +24,16 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.configs.kg_configs import KG_ENTITY_EXTRACTION_TIMEOUT
from onyx.configs.kg_configs import KG_RELATIONSHIP_EXTRACTION_TIMEOUT
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.kg_temp_view import create_views
from onyx.db.kg_temp_view import get_user_view_names
from onyx.db.relationships import get_allowed_relationship_type_pairs
from onyx.kg.utils.extraction_utils import get_entity_types_str
from onyx.kg.utils.extraction_utils import get_relationship_types_str
from onyx.kg.extractions.extraction_processing import get_entity_types_str
from onyx.kg.extractions.extraction_processing import get_relationship_types_str
from onyx.prompts.kg_prompts import QUERY_ENTITY_EXTRACTION_PROMPT
from onyx.prompts.kg_prompts import QUERY_RELATIONSHIP_EXTRACTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -81,16 +79,13 @@ def extract_ert(
stream_write_step_activities(writer, _KG_STEP_NR)
# Create temporary views. TODO: move into parallel step, if ultimately materialized
tenant_id = get_current_tenant_id()
kg_views = get_user_view_names(user_email, tenant_id)
allowed_docs_view_name, kg_relationships_view_name = get_user_view_names(user_email)
with get_session_with_current_tenant() as db_session:
create_views(
db_session,
tenant_id=tenant_id,
user_email=user_email,
allowed_docs_view_name=kg_views.allowed_docs_view_name,
kg_relationships_view_name=kg_views.kg_relationships_view_name,
kg_entity_view_name=kg_views.kg_entity_view_name,
allowed_docs_view_name=allowed_docs_view_name,
kg_relationships_view_name=kg_relationships_view_name,
)
### get the entities, terms, and filters
@@ -136,18 +131,25 @@ def extract_ert(
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
entity_extraction_result = KGQuestionEntityExtractionResult.model_validate_json(
cleaned_response
)
except ValidationError:
logger.error("Failed to parse LLM response as JSON in Entity Extraction")
entity_extraction_result = KGQuestionEntityExtractionResult(
entities=[], time_filter=""
)
try:
entity_extraction_result = (
KGQuestionEntityExtractionResult.model_validate_json(cleaned_response)
)
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
entity_extraction_result = KGQuestionEntityExtractionResult(
entities=[],
terms=[],
time_filter="",
)
except Exception as e:
logger.error(f"Error in extract_ert: {e}")
entity_extraction_result = KGQuestionEntityExtractionResult(
entities=[], time_filter=""
entities=[],
terms=[],
time_filter="",
)
# remove the attribute filters from the entities to for the purpose of the relationship
@@ -214,9 +216,9 @@ def extract_ert(
cleaned_response
)
)
except ValidationError:
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Relationship Extraction"
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
relationship_extraction_result = KGQuestionRelationshipExtractionResult(
relationships=[],
@@ -251,10 +253,10 @@ Entities: {extracted_entity_string} - \n Relationships: {extracted_relationship_
extracted_entities_w_attributes=entity_extraction_result.entities,
extracted_entities_no_attributes=entities_no_attributes,
extracted_relationships=relationship_extraction_result.relationships,
extracted_terms=entity_extraction_result.terms,
time_filter=entity_extraction_result.time_filter,
kg_doc_temp_view_name=kg_views.allowed_docs_view_name,
kg_rel_temp_view_name=kg_views.kg_relationships_view_name,
kg_entity_temp_view_name=kg_views.kg_entity_view_name,
kg_doc_temp_view_name=allowed_docs_view_name,
kg_rel_temp_view_name=kg_relationships_view_name,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
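The extraction node above trims the LLM reply to its outermost braces, validates it with pydantic, and falls back to an empty result when parsing fails (the two sides of this hunk differ mainly in which exception is caught and how the empty result is built). A minimal sketch of that parse-with-fallback, assuming pydantic v2:

```python
from pydantic import BaseModel, ValidationError


class EntityExtractionSketch(BaseModel):
    entities: list[str] = []
    terms: list[str] = []
    time_filter: str | None = None


def parse_extraction(raw_llm_output: str) -> EntityExtractionSketch:
    # keep only the outermost JSON object, mirroring the first/last-bracket trim above
    start, end = raw_llm_output.find("{"), raw_llm_output.rfind("}")
    cleaned = raw_llm_output[start : end + 1] if start != -1 and end != -1 else ""
    try:
        return EntityExtractionSketch.model_validate_json(cleaned)
    except (ValidationError, ValueError):
        # degrade gracefully: an empty extraction rather than failing the whole graph step
        return EntityExtractionSketch()
```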

View File

@@ -18,7 +18,6 @@ from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import YesNoEnum
@@ -27,10 +26,12 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.kg_configs import KG_STRATEGY_GENERATION_TIMEOUT
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_document_id_for_entity
from onyx.kg.clustering.normalizations import normalize_entities
from onyx.kg.clustering.normalizations import normalize_entities_w_attributes_from_map
from onyx.kg.clustering.normalizations import normalize_relationships
from onyx.kg.clustering.normalizations import normalize_terms
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.prompts.kg_prompts import STRATEGY_GENERATION_PROMPT
from onyx.utils.logger import setup_logger
@@ -146,6 +147,7 @@ def analyze(
state.extracted_entities_no_attributes
) # attribute knowledge is not required for this step
relationships = state.extracted_relationships
terms = state.extracted_terms
time_filter = state.time_filter
## STEP 2 - stream out goals
@@ -155,14 +157,18 @@ def analyze(
# Continue with node
normalized_entities = normalize_entities(
entities,
entities, allowed_docs_temp_view_name=state.kg_doc_temp_view_name
)
query_graph_entities_w_attributes = normalize_entities_w_attributes_from_map(
state.extracted_entities_w_attributes,
allowed_docs_temp_view_name=state.kg_doc_temp_view_name,
normalized_entities.entity_normalization_map,
)
normalized_relationships = normalize_relationships(
relationships, normalized_entities.entity_normalization_map
)
normalized_terms = normalize_terms(terms)
normalized_time_filter = time_filter
# If single-doc inquiry, send to single-doc processing directly
@@ -231,9 +237,6 @@ def analyze(
)
search_type = approach_extraction_result.search_type
search_strategy = approach_extraction_result.search_strategy
relationship_detection = (
approach_extraction_result.relationship_detection.value
)
output_format = approach_extraction_result.format
broken_down_question = approach_extraction_result.broken_down_question
divide_and_conquer = approach_extraction_result.divide_and_conquer
@@ -243,7 +246,6 @@ def analyze(
)
search_type = KGSearchType.SEARCH
search_strategy = KGAnswerStrategy.DEEP
relationship_detection = KGRelationshipDetection.RELATIONSHIPS.value
output_format = KGAnswerFormat.TEXT
broken_down_question = None
divide_and_conquer = YesNoEnum.NO
@@ -264,19 +266,6 @@ def analyze(
step_answer = f"Strategy and format have been extracted from query. Strategy: {search_strategy.value}, \
Format: {output_format.value}, Broken down question: {broken_down_question}"
extraction_detected_relationships = len(query_graph_relationships) > 0
if extraction_detected_relationships:
query_type = KGRelationshipDetection.RELATIONSHIPS.value
if extraction_detected_relationships:
logger.warning(
"Fyi - Extraction detected relationships: "
f"{extraction_detected_relationships}, "
f"but relationship detection: {relationship_detection}"
)
else:
query_type = KGRelationshipDetection.NO_RELATIONSHIPS.value
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=step_answer)
stream_close_step_answer(writer, _KG_STEP_NR)
@@ -289,9 +278,9 @@ Format: {output_format.value}, Broken down question: {broken_down_question}"
entity_normalization_map=normalized_entities.entity_normalization_map,
relationship_normalization_map=normalized_relationships.relationship_normalization_map,
query_graph_entities_no_attributes=query_graph_entities,
query_graph_entities_w_attributes=normalized_entities.entities_w_attributes,
query_graph_entities_w_attributes=query_graph_entities_w_attributes,
query_graph_relationships=query_graph_relationships,
normalized_terms=[], # TODO: remove fully later
normalized_terms=normalized_terms.terms,
normalized_time_filter=normalized_time_filter,
strategy=search_strategy,
broken_down_question=broken_down_question,
@@ -299,7 +288,6 @@ Format: {output_format.value}, Broken down question: {broken_down_question}"
divide_and_conquer=divide_and_conquer,
single_doc_id=single_doc_id,
search_type=search_type,
query_type=query_type,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",

View File

@@ -14,7 +14,6 @@ from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGRelationshipDetection
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import SQLSimpleGenerationUpdate
@@ -22,18 +21,13 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from shared_configs.contextvars import get_current_tenant_id
from onyx.configs.kg_configs import KG_MAX_DEEP_SEARCH_RESULTS
from onyx.configs.kg_configs import KG_SQL_GENERATION_MAX_TOKENS
from onyx.configs.kg_configs import KG_SQL_GENERATION_TIMEOUT
from onyx.configs.kg_configs import KG_SQL_GENERATION_TIMEOUT_OVERRIDE
from onyx.configs.kg_configs import KG_TEMP_ALLOWED_DOCS_VIEW_NAME_PREFIX
from onyx.configs.kg_configs import KG_TEMP_KG_ENTITIES_VIEW_NAME_PREFIX
from onyx.configs.kg_configs import KG_TEMP_KG_RELATIONSHIPS_VIEW_NAME_PREFIX
from onyx.db.engine.sql_engine import get_db_readonly_user_session_with_current_tenant
from onyx.db.engine import get_db_readonly_user_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.kg_temp_view import drop_views
from onyx.llm.interfaces import LLM
from onyx.prompts.kg_prompts import ENTITY_SOURCE_DETECTION_PROMPT
from onyx.prompts.kg_prompts import SIMPLE_ENTITY_SQL_PROMPT
from onyx.prompts.kg_prompts import SIMPLE_SQL_CORRECTION_PROMPT
from onyx.prompts.kg_prompts import SIMPLE_SQL_PROMPT
from onyx.prompts.kg_prompts import SOURCE_DETECTION_PROMPT
@@ -44,65 +38,16 @@ from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def _raise_error_if_sql_fails_problem_test(
sql_statement: str, relationship_view_name: str, entity_view_name: str | None
) -> bool:
"""
Check if the SQL statement is valid.
"""
if entity_view_name is None:
raise ValueError("entity_view_name is not set for sql_statement")
authorized_user_specification = relationship_view_name[
len(KG_TEMP_KG_RELATIONSHIPS_VIEW_NAME_PREFIX) :
]
# remove the proper relationship and entity view names
base_sql_statement = sql_statement.replace(relationship_view_name, " ").replace(
entity_view_name, " "
)
# check whether other non-authorized relationship or entity view names are in sql_statement
if any(
view_name in base_sql_statement
for view_name in [
KG_TEMP_ALLOWED_DOCS_VIEW_NAME_PREFIX,
KG_TEMP_KG_RELATIONSHIPS_VIEW_NAME_PREFIX,
KG_TEMP_KG_ENTITIES_VIEW_NAME_PREFIX,
]
):
raise ValueError(
f"SQL statement would attempt to access unauthorized views: {sql_statement} for \
user specification {authorized_user_specification}"
def _drop_temp_views(
allowed_docs_view_name: str, kg_relationships_view_name: str
) -> None:
with get_session_with_current_tenant() as db_session:
drop_views(
db_session,
allowed_docs_view_name=allowed_docs_view_name,
kg_relationships_view_name=kg_relationships_view_name,
)
# check whether the sql statement would attempt to do unauthorized operations
# (the restrictive db privileges would preclude this anyway, but we check for safety and reporting)
UNAUTHORIZED_SQL_OPERATIONS = [
"INSERT ",
"UPDATE ",
"DELETE ",
"CREATE ",
"DROP ",
"ALTER ",
"TRUNCATE ",
"RENAME ",
"GRANT ",
"REVOKE ",
"DENY ",
]
if any(
operation.upper() in base_sql_statement.upper()
for operation in UNAUTHORIZED_SQL_OPERATIONS
):
raise ValueError(
f"SQL statement would attempt to do unauthorized operations: {sql_statement}"
)
return True
def _build_entity_explanation_str(entity_normalization_map: dict[str, str]) -> str:
"""
@@ -118,14 +63,13 @@ def _build_entity_explanation_str(entity_normalization_map: dict[str, str]) -> s
def _sql_is_aggregate_query(sql_statement: str) -> bool:
return any(
agg_func in sql_statement.upper()
for agg_func in ["COUNT(", "MAX(", "MIN(", "AVG(", "SUM(", "GROUP BY"]
for agg_func in ["COUNT(", "MAX(", "MIN(", "AVG(", "SUM("]
)
def _get_source_documents(
sql_statement: str,
llm: LLM,
focus: str | None,
allowed_docs_view_name: str,
kg_relationships_view_name: str,
) -> str | None:
@@ -133,13 +77,7 @@ def _get_source_documents(
Generate SQL to retrieve source documents based on the input sql statement.
"""
base_prompt = (
SOURCE_DETECTION_PROMPT
if (focus == KGRelationshipDetection.RELATIONSHIPS.value or focus is None)
else ENTITY_SOURCE_DETECTION_PROMPT
)
source_detection_prompt = base_prompt.replace(
source_detection_prompt = SOURCE_DETECTION_PROMPT.replace(
"---original_sql_statement---", sql_statement
)
@@ -155,8 +93,8 @@ def _get_source_documents(
KG_SQL_GENERATION_TIMEOUT,
llm.invoke,
prompt=msg,
timeout_override=KG_SQL_GENERATION_TIMEOUT_OVERRIDE,
max_tokens=KG_SQL_GENERATION_MAX_TOKENS,
timeout_override=25,
max_tokens=1200,
)
cleaned_response = (
@@ -166,11 +104,12 @@ def _get_source_documents(
sql_statement = sql_statement.split("</sql>")[0].strip()
except Exception as e:
error_msg = f"Could not generate source documents SQL: {e}"
if cleaned_response:
error_msg += f". Original model response: {cleaned_response}"
logger.error(error_msg)
if cleaned_response is not None:
logger.error(
f"Could not generate source documents SQL: {e}. Original model response: {cleaned_response}"
)
else:
logger.error(f"Could not generate source documents SQL: {e}")
return None
@@ -200,9 +139,6 @@ def generate_simple_sql(
if state.kg_rel_temp_view_name is None:
raise ValueError("kg_rel_temp_view_name is not set")
if state.kg_entity_temp_view_name is None:
raise ValueError("kg_entity_temp_view_name is not set")
## STEP 3 - articulate goals
stream_write_step_activities(writer, _KG_STEP_NR)
@@ -251,41 +187,30 @@ def generate_simple_sql(
# First, build a string of contextualized entities so the model knows what
# e.g. ACCOUNT::SF_8254Hs means as a normalized entity
# TODO: restructure with broader node rework
entity_explanation_str = _build_entity_explanation_str(
state.entity_normalization_map
)
doc_temp_view = state.kg_doc_temp_view_name
rel_temp_view = state.kg_rel_temp_view_name
ent_temp_view = state.kg_entity_temp_view_name
current_tenant = get_current_tenant_id()
if state.query_type == KGRelationshipDetection.NO_RELATIONSHIPS.value:
simple_sql_prompt = (
SIMPLE_ENTITY_SQL_PROMPT.replace(
"---entity_types---", entities_types_str
)
.replace("---question---", question)
.replace("---entity_explanation_string---", entity_explanation_str)
current_tenant_view_name = f'"{current_tenant}".{state.kg_doc_temp_view_name}'
current_tenant_rel_view_name = f'"{current_tenant}".{state.kg_rel_temp_view_name}'
simple_sql_prompt = (
SIMPLE_SQL_PROMPT.replace("---entity_types---", entities_types_str)
.replace("---relationship_types---", relationship_types_str)
.replace("---question---", question)
.replace("---entity_explanation_string---", entity_explanation_str)
.replace(
"---query_entities_with_attributes---",
"\n".join(state.query_graph_entities_w_attributes),
)
else:
simple_sql_prompt = (
SIMPLE_SQL_PROMPT.replace("---entity_types---", entities_types_str)
.replace("---relationship_types---", relationship_types_str)
.replace("---question---", question)
.replace("---entity_explanation_string---", entity_explanation_str)
.replace(
"---query_entities_with_attributes---",
"\n".join(state.query_graph_entities_w_attributes),
)
.replace(
"---query_relationships---",
"\n".join(state.query_graph_relationships),
)
.replace("---today_date---", datetime.now().strftime("%Y-%m-%d"))
.replace("---user_name---", f"EMPLOYEE:{user_name}")
.replace(
"---query_relationships---", "\n".join(state.query_graph_relationships)
)
.replace("---today_date---", datetime.now().strftime("%Y-%m-%d"))
.replace("---user_name---", f"EMPLOYEE:{user_name}")
)
# prepare SQL query generation
@@ -302,8 +227,8 @@ def generate_simple_sql(
KG_SQL_GENERATION_TIMEOUT,
primary_llm.invoke,
prompt=msg,
timeout_override=KG_SQL_GENERATION_TIMEOUT_OVERRIDE,
max_tokens=KG_SQL_GENERATION_MAX_TOKENS,
timeout_override=25,
max_tokens=1500,
)
cleaned_response = (
@@ -314,8 +239,9 @@ def generate_simple_sql(
)
sql_statement = sql_statement.split(";")[0].strip() + ";"
sql_statement = sql_statement.replace("sql", "").strip()
sql_statement = sql_statement.replace("relationship_table", rel_temp_view)
sql_statement = sql_statement.replace("entity_table", ent_temp_view)
sql_statement = sql_statement.replace(
"kg_relationship", current_tenant_rel_view_name
)
reasoning = (
cleaned_response.split("<reasoning>")[1]
@@ -324,101 +250,75 @@ def generate_simple_sql(
)
except Exception as e:
# TODO: restructure with broader node rework
logger.error(f"Error in SQL generation: {e}")
logger.error(f"Error in strategy generation: {e}")
drop_views(
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
kg_entity_view_name=ent_temp_view,
_drop_temp_views(
allowed_docs_view_name=current_tenant_view_name,
kg_relationships_view_name=current_tenant_rel_view_name,
)
raise e
if state.query_type == KGRelationshipDetection.RELATIONSHIPS.value:
# Correction if needed:
logger.debug(f"A3 - sql_statement: {sql_statement}")
correction_prompt = SIMPLE_SQL_CORRECTION_PROMPT.replace(
"---draft_sql---", sql_statement
# Correction if needed:
correction_prompt = SIMPLE_SQL_CORRECTION_PROMPT.replace(
"---draft_sql---", sql_statement
)
msg = [
HumanMessage(
content=correction_prompt,
)
]
try:
llm_response = run_with_timeout(
KG_SQL_GENERATION_TIMEOUT,
primary_llm.invoke,
prompt=msg,
timeout_override=25,
max_tokens=1500,
)
msg = [
HumanMessage(
content=correction_prompt,
)
]
cleaned_response = (
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
)
try:
llm_response = run_with_timeout(
KG_SQL_GENERATION_TIMEOUT,
primary_llm.invoke,
prompt=msg,
timeout_override=25,
max_tokens=1500,
)
sql_statement = (
cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
)
sql_statement = sql_statement.split(";")[0].strip() + ";"
sql_statement = sql_statement.replace("sql", "").strip()
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
)
except Exception as e:
logger.error(
f"Error in generating the sql correction: {e}. Original model response: {cleaned_response}"
)
sql_statement = (
cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
)
sql_statement = sql_statement.split(";")[0].strip() + ";"
sql_statement = sql_statement.replace("sql", "").strip()
_drop_temp_views(
allowed_docs_view_name=current_tenant_view_name,
kg_relationships_view_name=current_tenant_rel_view_name,
)
except Exception as e:
logger.error(
f"Error in generating the sql correction: {e}. Original model response: {cleaned_response}"
)
drop_views(
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
kg_entity_view_name=ent_temp_view,
)
raise e
raise e
logger.debug(f"A3 - sql_statement after correction: {sql_statement}")
# Get SQL for source documents
source_documents_sql = None
source_documents_sql = _get_source_documents(
sql_statement,
llm=primary_llm,
allowed_docs_view_name=current_tenant_view_name,
kg_relationships_view_name=current_tenant_rel_view_name,
)
if (
state.query_type == KGRelationshipDetection.RELATIONSHIPS.value
or _sql_is_aggregate_query(sql_statement)
):
source_documents_sql = _get_source_documents(
sql_statement,
llm=primary_llm,
focus=state.query_type,
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
)
if source_documents_sql and ent_temp_view:
source_documents_sql = source_documents_sql.replace(
"entity_table", ent_temp_view
)
if source_documents_sql and rel_temp_view:
source_documents_sql = source_documents_sql.replace(
"relationship_table", rel_temp_view
)
logger.debug(f"A3 source_documents_sql: {source_documents_sql}")
logger.info(f"A3 source_documents_sql: {source_documents_sql}")
scalar_result = None
query_results = None
# check sql, just in case
_raise_error_if_sql_fails_problem_test(
sql_statement, rel_temp_view, ent_temp_view
)
with get_db_readonly_user_session_with_current_tenant() as db_session:
try:
result = db_session.execute(text(sql_statement))
@@ -440,16 +340,8 @@ def generate_simple_sql(
raise e
source_document_results = None
if source_documents_sql is not None and source_documents_sql != sql_statement:
# check source document sql, just in case
_raise_error_if_sql_fails_problem_test(
source_documents_sql, rel_temp_view, ent_temp_view
)
with get_db_readonly_user_session_with_current_tenant() as db_session:
try:
result = db_session.execute(text(source_documents_sql))
rows = result.fetchall()
@@ -460,29 +352,18 @@ def generate_simple_sql(
]
except Exception as e:
# No stopping here, the individualized SQL query is not mandatory
# TODO: raise error on frontend
logger.error(f"Error executing Individualized SQL query: {e}")
# individualized_query_results = None
else:
source_document_results = None
if state.query_type == KGRelationshipDetection.NO_RELATIONSHIPS.value:
# source documents should be returned for entity queries
source_document_results = [
x["source_document"]
for x in query_results
if "source_document" in x
]
else:
source_document_results = None
drop_views(
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
kg_entity_view_name=ent_temp_view,
_drop_temp_views(
allowed_docs_view_name=state.kg_doc_temp_view_name,
kg_relationships_view_name=state.kg_rel_temp_view_name,
)
logger.debug(f"A3 - Number of query_results: {len(query_results)}")
logger.info(f"A3 - Number of query_results: {len(query_results)}")
# Stream out reasoning and SQL query

Some files were not shown because too many files have changed in this diff.