Compare commits

...

2 Commits

Author SHA1 Message Date
Raunak Bhagat
b136b04925 fix: Fix Confluence pagination (#5035)
* Re-implement pagination

* Add note

* Fix invalid integration test configs

* Fix other failing test

* Edit failing test

* Revert test

* Revert pagination size

* Add comment on yielding style

* Use fixture instead of manually initializing sql-engine

* Fix failing tests

* Move code back and copy-paste
2025-07-17 23:54:25 -07:00
Chris Weaver
9f2b0723a8 Improve support for non-default postgres schemas (#5046) 2025-07-17 23:53:20 -07:00
16 changed files with 279 additions and 42 deletions

View File

@@ -23,7 +23,7 @@ from sqlalchemy.sql.schema import SchemaItem
from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import (
MULTI_TENANT,
POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE,
POSTGRES_DEFAULT_SCHEMA,
TENANT_ID_PREFIX,
)
from onyx.db.models import Base
@@ -271,7 +271,7 @@ async def run_async_migrations() -> None:
) = get_schema_options()
if not schemas and not MULTI_TENANT:
schemas = [POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE]
schemas = [POSTGRES_DEFAULT_SCHEMA]
# without init_engine, subsequent engine calls fail hard intentionally
SqlEngine.init_engine(pool_size=20, max_overflow=5)

View File

@@ -9,7 +9,7 @@ Create Date: 2025-06-22 17:33:25.833733
from alembic import op
from sqlalchemy.orm import Session
from sqlalchemy import text
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
# revision identifiers, used by Alembic.
revision = "36e9220ab794"
@@ -66,7 +66,7 @@ def upgrade() -> None:
-- Set name and name trigrams
NEW.name = name;
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name);
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
@@ -111,7 +111,7 @@ def upgrade() -> None:
UPDATE "{tenant_id}".kg_entity
SET
name = doc_name,
name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name)
name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name)
WHERE document_id = NEW.id;
RETURN NEW;
END;

View File

@@ -15,7 +15,7 @@ from datetime import datetime, timedelta
from onyx.configs.app_configs import DB_READONLY_USER
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
# revision identifiers, used by Alembic.
@@ -478,7 +478,7 @@ def upgrade() -> None:
# Create GIN index for clustering and normalization
op.execute(
"CREATE INDEX IF NOT EXISTS idx_kg_entity_clustering_trigrams "
f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.gin_trgm_ops)"
f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA}.gin_trgm_ops)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS idx_kg_entity_normalization_trigrams "
@@ -518,7 +518,7 @@ def upgrade() -> None:
-- Set name and name trigrams
NEW.name = name;
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name);
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
@@ -563,7 +563,7 @@ def upgrade() -> None:
UPDATE kg_entity
SET
name = doc_name,
name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name)
name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name)
WHERE document_id = NEW.id;
RETURN NEW;
END;

View File

@@ -1,3 +1,17 @@
"""
# README (notes on Confluence pagination):
We've noticed that the `search/users` and `users/memberof` endpoints for Confluence Cloud use offset-based pagination as
opposed to cursor-based. We also know that page-retrieval uses cursor-based pagination.
Our default pagination strategy right now for cloud is to assume cursor-based.
However, if you notice that a cloud API is not being properly paginated (i.e., if the `_links.next` is not appearing in the
returned payload), then you can force offset-based pagination.
# TODO (@raunakab)
We haven't explored all of the cloud APIs' pagination strategies. @raunakab take time to go through this and figure them out.
"""
import json
import time
from collections.abc import Callable
@@ -46,15 +60,13 @@ _REPLACEMENT_EXPANSIONS = "body.view.value"
_USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
_USER_EMAIL_CACHE: dict[str, str | None] = {}
_DEFAULT_PAGINATION_LIMIT = 1000
class ConfluenceRateLimitError(Exception):
pass
_DEFAULT_PAGINATION_LIMIT = 1000
class OnyxConfluence:
"""
This is a custom Confluence class that:
@@ -463,6 +475,7 @@ class OnyxConfluence:
limit: int | None = None,
# Called with the next url to use to get the next page
next_page_callback: Callable[[str], None] | None = None,
force_offset_pagination: bool = False,
) -> Iterator[dict[str, Any]]:
"""
This will paginate through the top level query.
@@ -548,14 +561,32 @@ class OnyxConfluence:
)
raise e
# yield the results individually
# Yield the results individually.
results = cast(list[dict[str, Any]], next_response.get("results", []))
# make sure we don't update the start by more than the amount
# Note 1:
# Make sure we don't update the start by more than the amount
# of results we were able to retrieve. The Confluence API has a
# weird behavior where if you pass in a limit that is too large for
# the configured server, it will artificially limit the amount of
# results returned BUT will not apply this to the start parameter.
# This will cause us to miss results.
#
# Note 2:
# We specifically perform manual yielding (i.e., `for x in xs: yield x`) as opposed to using a `yield from xs`
# because we *have to call the `next_page_callback`* prior to yielding the last element!
#
# If we did:
#
# ```py
# yield from results
# if next_page_callback:
# next_page_callback(url_suffix)
# ```
#
# then the logic would fail since the iterator would finish (and the calling scope would exit out of its driving
# loop) prior to the callback being called.
old_url_suffix = url_suffix
updated_start = get_start_param_from_url(old_url_suffix)
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
@@ -571,6 +602,12 @@ class OnyxConfluence:
)
# notify the caller of the new url
next_page_callback(url_suffix)
elif force_offset_pagination and i == len(results) - 1:
url_suffix = update_param_in_path(
old_url_suffix, "start", str(updated_start)
)
yield result
# we've observed that Confluence sometimes returns a next link despite giving
@@ -668,7 +705,9 @@ class OnyxConfluence:
url = "rest/api/search/user"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
for user_result in self._paginate_url(url, limit):
for user_result in self._paginate_url(
url, limit, force_offset_pagination=True
):
# Example response:
# {
# 'user': {
@@ -758,7 +797,7 @@ class OnyxConfluence:
user_query = f"{user_field}={quote(user_value)}"
url = f"rest/api/user/memberof?{user_query}"
yield from self._paginate_url(url, limit)
yield from self._paginate_url(url, limit, force_offset_pagination=True)
def paginated_groups_retrieval(
self,

View File

@@ -29,6 +29,7 @@ from onyx.db.engine.sql_engine import is_valid_schema_name
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.engine.sql_engine import USE_IAM_AUTH
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
from shared_configs.contextvars import get_current_tenant_id
@@ -118,7 +119,7 @@ async def get_async_session(
engine = get_sqlalchemy_async_engine()
# no need to use the schema translation map for self-hosted + default schema
if not MULTI_TENANT:
if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
async with AsyncSession(bind=engine, expire_on_commit=False) as session:
yield session
return

View File

@@ -31,6 +31,7 @@ from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
@@ -324,7 +325,7 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# no need to use the schema translation map for self-hosted + default schema
if not MULTI_TENANT:
if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
with Session(bind=engine, expire_on_commit=False) as session:
yield session
return
@@ -370,12 +371,11 @@ def get_db_readonly_user_session_with_current_tenant() -> (
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# no need to use the schema translation map for self-hosted + default schema
if not MULTI_TENANT:
if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
with Session(readonly_engine, expire_on_commit=False) as session:
yield session
return
# no need to use the schema translation map for self-hosted + default schema
schema_translate_map = {None: tenant_id}
with readonly_engine.connect().execution_options(
schema_translate_map=schema_translate_map

View File

@@ -34,7 +34,7 @@ from onyx.kg.models import KGGroundingType
from onyx.kg.utils.formatting_utils import make_relationship_id
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@@ -180,7 +180,7 @@ def _cluster_one_grounded_entity(
# find entities of the same type with a similar name
*filtering,
KGEntity.entity_type_id_name == entity.entity_type_id_name,
getattr(func, POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE).similarity_op(
getattr(func, POSTGRES_DEFAULT_SCHEMA).similarity_op(
KGEntity.name, entity_name
),
)

View File

@@ -33,7 +33,7 @@ from onyx.kg.utils.formatting_utils import split_entity_id
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@@ -95,7 +95,7 @@ def _normalize_one_entity(
# generate trigrams of the queried entity Q
query_trigrams = db_session.query(
getattr(func, POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE)
getattr(func, POSTGRES_DEFAULT_SCHEMA)
.show_trgm(cleaned_entity)
.cast(ARRAY(String(3)))
.label("trigrams")

View File

@@ -140,6 +140,8 @@ else:
# Multi-tenancy configuration
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
# Outside this file, should almost always use `POSTGRES_DEFAULT_SCHEMA` unless you
# have a very good reason
POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE = "public"
POSTGRES_DEFAULT_SCHEMA = (
os.environ.get("POSTGRES_DEFAULT_SCHEMA") or POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE

View File

@@ -6,6 +6,7 @@ import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from onyx.db.engine.sql_engine import SqlEngine
from onyx.main import fetch_versioned_implementation
from onyx.utils.logger import setup_logger
@@ -23,3 +24,12 @@ def client() -> Generator[TestClient, Any, None]:
)()
client = TestClient(app)
yield client
@pytest.fixture(scope="session", autouse=True)
def initialize_db() -> None:
# Make sure that the db engine is initialized before any tests are run
SqlEngine.init_engine(
pool_size=10,
max_overflow=5,
)

View File

@@ -0,0 +1,35 @@
import os
from typing import Any
import pytest
@pytest.fixture
def confluence_connector_config() -> dict[str, Any]:
url_base = os.environ.get("CONFLUENCE_TEST_SPACE_URL")
space_key = os.environ.get("CONFLUENCE_SPACE_KEY")
page_id = os.environ.get("CONFLUENCE_PAGE_ID")
is_cloud = os.environ.get("CONFLUENCE_IS_CLOUD", "").lower() == "true"
assert url_base, "CONFLUENCE_URL environment variable is required"
return {
"wiki_base": url_base,
"is_cloud": is_cloud,
"space": space_key or "",
"page_id": page_id or "",
}
@pytest.fixture
def confluence_credential_json() -> dict[str, Any]:
username = os.environ.get("CONFLUENCE_USER_NAME")
access_token = os.environ.get("CONFLUENCE_ACCESS_TOKEN")
assert username, "CONFLUENCE_USERNAME environment variable is required"
assert access_token, "CONFLUENCE_ACCESS_TOKEN environment variable is required"
return {
"confluence_username": username,
"confluence_access_token": access_token,
}

View File

@@ -0,0 +1,22 @@
from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
class ExternalUserGroupSet(BaseModel):
"""A version of ExternalUserGroup that uses a set for user_emails to avoid order-dependent comparisons."""
id: str
user_emails: set[str]
gives_anyone_access: bool
@classmethod
def from_model(
cls, external_user_group: ExternalUserGroup
) -> "ExternalUserGroupSet":
"""Convert from ExternalUserGroup to ExternalUserGroupSet."""
return cls(
id=external_user_group.id,
user_emails=set(external_user_group.user_emails),
gives_anyone_access=external_user_group.gives_anyone_access,
)

View File

@@ -0,0 +1,133 @@
from typing import Any
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import InputType
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from shared_configs.contextvars import get_current_tenant_id
from tests.daily.connectors.confluence.models import ExternalUserGroupSet
# In order to get these tests to run, use the credentials from Bitwarden.
# Search up "ENV vars for local and Github tests", and find the Confluence relevant key-value pairs.
_EXPECTED_CONFLUENCE_GROUPS = [
ExternalUserGroupSet(
id="confluence-admins-danswerai",
user_emails={"chris@onyx.app", "yuhong@onyx.app"},
gives_anyone_access=False,
),
ExternalUserGroupSet(
id="org-admins",
user_emails={"founders@onyx.app", "chris@onyx.app", "yuhong@onyx.app"},
gives_anyone_access=False,
),
ExternalUserGroupSet(
id="confluence-users-danswerai",
user_emails={
"chris@onyx.app",
"hagen@danswer.ai",
"founders@onyx.app",
"pablo@onyx.app",
"yuhong@onyx.app",
},
gives_anyone_access=False,
),
ExternalUserGroupSet(
id="jira-users-danswerai",
user_emails={
"hagen@danswer.ai",
"founders@onyx.app",
"pablo@onyx.app",
"chris@onyx.app",
},
gives_anyone_access=False,
),
ExternalUserGroupSet(
id="jira-admins-danswerai",
user_emails={"hagen@danswer.ai", "founders@onyx.app", "pablo@onyx.app"},
gives_anyone_access=False,
),
ExternalUserGroupSet(
id="confluence-user-access-admins-danswerai",
user_emails={"hagen@danswer.ai"},
gives_anyone_access=False,
),
ExternalUserGroupSet(
id="jira-user-access-admins-danswerai",
user_emails={"hagen@danswer.ai"},
gives_anyone_access=False,
),
ExternalUserGroupSet(
id="Yuhong Only No Chris Allowed",
user_emails={"yuhong@onyx.app"},
gives_anyone_access=False,
),
ExternalUserGroupSet(
id="All_Confluence_Users_Found_By_Onyx",
user_emails={
"chris@onyx.app",
"hagen@danswer.ai",
"founders@onyx.app",
"pablo@onyx.app",
"yuhong@onyx.app",
},
gives_anyone_access=False,
),
]
def test_confluence_group_sync(
initialize_db: None,
confluence_connector_config: dict[str, Any],
confluence_credential_json: dict[str, Any],
) -> None:
with get_session_with_current_tenant() as db_session:
connector = Connector(
name="Test Connector",
source=DocumentSource.CONFLUENCE,
input_type=InputType.POLL,
connector_specific_config=confluence_connector_config,
refresh_freq=None,
prune_freq=None,
indexing_start=None,
)
db_session.add(connector)
db_session.flush()
credential = Credential(
source=DocumentSource.CONFLUENCE,
credential_json=confluence_credential_json,
)
db_session.add(credential)
db_session.flush()
cc_pair = ConnectorCredentialPair(
connector_id=connector.id,
credential_id=credential.id,
name="Test CC Pair",
status=ConnectorCredentialPairStatus.ACTIVE,
access_type=AccessType.SYNC,
auto_sync_options=None,
)
db_session.add(cc_pair)
db_session.commit()
db_session.refresh(cc_pair)
tenant_id = get_current_tenant_id()
group_sync_iter = confluence_group_sync(
tenant_id=tenant_id,
cc_pair=cc_pair,
)
expected_groups = {group.id: group for group in _EXPECTED_CONFLUENCE_GROUPS}
actual_groups = {
group.id: ExternalUserGroupSet.from_model(external_user_group=group)
for group in group_sync_iter
}
assert expected_groups == actual_groups

View File

@@ -117,5 +117,7 @@ def test_paginated_cql_user_retrieval_no_overrides_cloud() -> None:
# Check that the cloud-specific user search URL is called
mock_paginate.assert_called_once_with(
"rest/api/search/user?cql=type=user", None
"rest/api/search/user?cql=type=user",
None,
force_offset_pagination=True,
)

View File

@@ -19,6 +19,15 @@ from tests.integration.common_utils.vespa import vespa_fixture
BASIC_USER_NAME = "basic_user"
@pytest.fixture(scope="session", autouse=True)
def initialize_db() -> None:
# Make sure that the db engine is initialized before any tests are run
SqlEngine.init_engine(
pool_size=10,
max_overflow=5,
)
def load_env_vars(env_file: str = ".env") -> None:
current_dir = os.path.dirname(os.path.abspath(__file__))
env_path = os.path.join(current_dir, env_file)
@@ -45,19 +54,6 @@ errors.
Commenting out till we can get to the bottom of it. For now, just using
instantiate the session directly within the test.
"""
# @pytest.fixture
# def db_session() -> Generator[Session, None, None]:
# with get_session_with_current_tenant() as session:
# yield session
@pytest.fixture(scope="session", autouse=True)
def initialize_db() -> None:
# Make sure that the db engine is initialized before any tests are run
SqlEngine.init_engine(
pool_size=10,
max_overflow=5,
)
@pytest.fixture

View File

@@ -37,9 +37,8 @@ def test_overlapping_connector_creation(reset: None) -> None:
config = {
"wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"],
"space": "DailyConne",
"space": "DailyConnectorTestSpace",
"is_cloud": True,
"page_id": "",
}
credential = {
@@ -98,9 +97,7 @@ def test_connector_pause_while_indexing(reset: None) -> None:
config = {
"wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"],
"space": "",
"is_cloud": True,
"page_id": "",
}
credential = {