mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-19 08:45:47 +00:00
Compare commits
23 Commits
bugfixfix
...
content-mo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
68aa00e330 | ||
|
|
3a78421e38 | ||
|
|
0aa9e8968a | ||
|
|
7d9e133e35 | ||
|
|
838160e660 | ||
|
|
6b84332f1b | ||
|
|
4fe5561f44 | ||
|
|
2d81d6082a | ||
|
|
ef291fcf0c | ||
|
|
b8f64d10a2 | ||
|
|
9f37ca23e8 | ||
|
|
34b2e5d9d3 | ||
|
|
ff6e4cd231 | ||
|
|
9d7137b3bb | ||
|
|
91cdcd820f | ||
|
|
4ff0c18822 | ||
|
|
66976529d2 | ||
|
|
1954105cc6 | ||
|
|
625ea6f24b | ||
|
|
60d2c8c86c | ||
|
|
324b8e42a5 | ||
|
|
125877ec65 | ||
|
|
ca803859cc |
@@ -45,16 +45,11 @@ env:
|
||||
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
|
||||
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
|
||||
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
|
||||
# Github
|
||||
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
|
||||
# Gitbook
|
||||
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
|
||||
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
|
||||
# Notion
|
||||
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
|
||||
# Highspot
|
||||
HIGHSPOT_KEY: ${{ secrets.HIGHSPOT_KEY }}
|
||||
HIGHSPOT_SECRET: ${{ secrets.HIGHSPOT_SECRET }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
|
||||
@@ -8,7 +8,7 @@ Edition features outside of personal development or testing purposes. Please rea
|
||||
founders@onyx.app for more information. Please visit https://github.com/onyx-dot-app/onyx"
|
||||
|
||||
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG ONYX_VERSION=0.0.0-dev
|
||||
ARG ONYX_VERSION=0.8-dev
|
||||
# DO_NOT_TRACK is used to disable telemetry for Unstructured
|
||||
ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true" \
|
||||
|
||||
@@ -7,7 +7,7 @@ You can find it at https://hub.docker.com/r/onyx/onyx-model-server. For more det
|
||||
visit https://github.com/onyx-dot-app/onyx."
|
||||
|
||||
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG ONYX_VERSION=0.0.0-dev
|
||||
ARG ONYX_VERSION=0.8-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ keys = console
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = INFO
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
|
||||
@@ -25,9 +25,6 @@ from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
|
||||
from onyx.db.models import Base
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
|
||||
# Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be
|
||||
# hidden! (defaults to level=WARN)
|
||||
|
||||
# Alembic Config object
|
||||
config = context.config
|
||||
|
||||
@@ -39,7 +36,6 @@ if config.config_file_name is not None and config.attributes.get(
|
||||
target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ssl_context: ssl.SSLContext | None = None
|
||||
@@ -68,7 +64,7 @@ def include_object(
|
||||
return True
|
||||
|
||||
|
||||
def get_schema_options() -> tuple[str, bool, bool, bool]:
|
||||
def get_schema_options() -> tuple[str, bool, bool]:
|
||||
x_args_raw = context.get_x_argument()
|
||||
x_args = {}
|
||||
for arg in x_args_raw:
|
||||
@@ -80,10 +76,6 @@ def get_schema_options() -> tuple[str, bool, bool, bool]:
|
||||
create_schema = x_args.get("create_schema", "true").lower() == "true"
|
||||
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
|
||||
|
||||
# continue on error with individual tenant
|
||||
# only applies to online migrations
|
||||
continue_on_error = x_args.get("continue", "false").lower() == "true"
|
||||
|
||||
if (
|
||||
MULTI_TENANT
|
||||
and schema_name == POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -94,12 +86,14 @@ def get_schema_options() -> tuple[str, bool, bool, bool]:
|
||||
"Please specify a tenant-specific schema."
|
||||
)
|
||||
|
||||
return schema_name, create_schema, upgrade_all_tenants, continue_on_error
|
||||
return schema_name, create_schema, upgrade_all_tenants
|
||||
|
||||
|
||||
def do_run_migrations(
|
||||
connection: Connection, schema_name: str, create_schema: bool
|
||||
) -> None:
|
||||
logger.info(f"About to migrate schema: {schema_name}")
|
||||
|
||||
if create_schema:
|
||||
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
|
||||
connection.execute(text("COMMIT"))
|
||||
@@ -140,12 +134,7 @@ def provide_iam_token_for_alembic(
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
(
|
||||
schema_name,
|
||||
create_schema,
|
||||
upgrade_all_tenants,
|
||||
continue_on_error,
|
||||
) = get_schema_options()
|
||||
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
|
||||
|
||||
engine = create_async_engine(
|
||||
build_connection_string(),
|
||||
@@ -162,15 +151,9 @@ async def run_async_migrations() -> None:
|
||||
|
||||
if upgrade_all_tenants:
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
|
||||
i_tenant = 0
|
||||
num_tenants = len(tenant_schemas)
|
||||
for schema in tenant_schemas:
|
||||
i_tenant += 1
|
||||
logger.info(
|
||||
f"Migrating schema: index={i_tenant} num_tenants={num_tenants} schema={schema}"
|
||||
)
|
||||
try:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
async with engine.connect() as connection:
|
||||
await connection.run_sync(
|
||||
do_run_migrations,
|
||||
@@ -179,12 +162,7 @@ async def run_async_migrations() -> None:
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema}: {e}")
|
||||
if not continue_on_error:
|
||||
logger.error("--continue is not set, raising exception!")
|
||||
raise
|
||||
|
||||
logger.warning("--continue is set, continuing to next schema.")
|
||||
|
||||
raise
|
||||
else:
|
||||
try:
|
||||
logger.info(f"Migrating schema: {schema_name}")
|
||||
@@ -202,11 +180,7 @@ async def run_async_migrations() -> None:
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""This doesn't really get used when we migrate in the cloud."""
|
||||
|
||||
logger.info("run_migrations_offline starting.")
|
||||
|
||||
schema_name, _, upgrade_all_tenants, continue_on_error = get_schema_options()
|
||||
schema_name, _, upgrade_all_tenants = get_schema_options()
|
||||
url = build_connection_string()
|
||||
|
||||
if upgrade_all_tenants:
|
||||
@@ -256,7 +230,6 @@ def run_migrations_offline() -> None:
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
logger.info("run_migrations_online starting.")
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
Rules defined here:
|
||||
https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.html
|
||||
"""
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
|
||||
@@ -264,11 +263,13 @@ def _fetch_all_page_restrictions(
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess],
|
||||
is_cloud: bool,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
For all pages, if a page has restrictions, then use those restrictions.
|
||||
Otherwise, use the space's restrictions.
|
||||
"""
|
||||
document_restrictions: list[DocExternalAccess] = []
|
||||
|
||||
for slim_doc in slim_docs:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
@@ -285,9 +286,11 @@ def _fetch_all_page_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
perm_sync_data=slim_doc.perm_sync_data,
|
||||
):
|
||||
yield DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=restrictions,
|
||||
document_restrictions.append(
|
||||
DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=restrictions,
|
||||
)
|
||||
)
|
||||
# If there are restrictions, then we don't need to use the space's restrictions
|
||||
continue
|
||||
@@ -321,9 +324,11 @@ def _fetch_all_page_restrictions(
|
||||
continue
|
||||
|
||||
# If there are no restrictions, then use the space's restrictions
|
||||
yield DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=space_permissions,
|
||||
document_restrictions.append(
|
||||
DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=space_permissions,
|
||||
)
|
||||
)
|
||||
if (
|
||||
not space_permissions.is_public
|
||||
@@ -337,12 +342,13 @@ def _fetch_all_page_restrictions(
|
||||
)
|
||||
|
||||
logger.debug("Finished fetching all page restrictions for space")
|
||||
return document_restrictions
|
||||
|
||||
|
||||
def confluence_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
@@ -381,7 +387,7 @@ def confluence_doc_sync(
|
||||
slim_docs.extend(doc_batch)
|
||||
|
||||
logger.debug("Fetching all page restrictions for space")
|
||||
yield from _fetch_all_page_restrictions(
|
||||
return _fetch_all_page_restrictions(
|
||||
confluence_client=confluence_connector.confluence_client,
|
||||
slim_docs=slim_docs,
|
||||
space_permissions_by_space_key=space_permissions_by_space_key,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
@@ -35,7 +34,7 @@ def _get_slim_doc_generator(
|
||||
def gmail_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
@@ -49,6 +48,7 @@ def gmail_doc_sync(
|
||||
cc_pair, gmail_connector, callback=callback
|
||||
)
|
||||
|
||||
document_external_access: list[DocExternalAccess] = []
|
||||
for slim_doc_batch in slim_doc_generator:
|
||||
for slim_doc in slim_doc_batch:
|
||||
if callback:
|
||||
@@ -60,14 +60,17 @@ def gmail_doc_sync(
|
||||
if slim_doc.perm_sync_data is None:
|
||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||
continue
|
||||
|
||||
if user_email := slim_doc.perm_sync_data.get("user_email"):
|
||||
ext_access = ExternalAccess(
|
||||
external_user_emails=set([user_email]),
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
|
||||
yield DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=ext_access,
|
||||
document_external_access.append(
|
||||
DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=ext_access,
|
||||
)
|
||||
)
|
||||
|
||||
return document_external_access
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
@@ -148,7 +147,7 @@ def _get_permissions_from_slim_doc(
|
||||
def gdrive_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
@@ -162,6 +161,7 @@ def gdrive_doc_sync(
|
||||
|
||||
slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)
|
||||
|
||||
document_external_accesses = []
|
||||
for slim_doc_batch in slim_doc_generator:
|
||||
for slim_doc in slim_doc_batch:
|
||||
if callback:
|
||||
@@ -174,7 +174,10 @@ def gdrive_doc_sync(
|
||||
google_drive_connector=google_drive_connector,
|
||||
slim_doc=slim_doc,
|
||||
)
|
||||
yield DocExternalAccess(
|
||||
external_access=ext_access,
|
||||
doc_id=slim_doc.id,
|
||||
document_external_accesses.append(
|
||||
DocExternalAccess(
|
||||
external_access=ext_access,
|
||||
doc_id=slim_doc.id,
|
||||
)
|
||||
)
|
||||
return document_external_accesses
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from slack_sdk import WebClient
|
||||
|
||||
from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
@@ -16,6 +14,35 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_slack_document_ids_and_channels(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
) -> dict[str, list[str]]:
|
||||
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
|
||||
|
||||
channel_doc_map: dict[str, list[str]] = {}
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
for doc_metadata in doc_metadata_batch:
|
||||
if doc_metadata.perm_sync_data is None:
|
||||
continue
|
||||
channel_id = doc_metadata.perm_sync_data["channel_id"]
|
||||
if channel_id not in channel_doc_map:
|
||||
channel_doc_map[channel_id] = []
|
||||
channel_doc_map[channel_id].append(doc_metadata.id)
|
||||
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"_get_slack_document_ids_and_channels: Stop signal detected"
|
||||
)
|
||||
|
||||
callback.progress("_get_slack_document_ids_and_channels", 1)
|
||||
|
||||
return channel_doc_map
|
||||
|
||||
|
||||
def _fetch_workspace_permissions(
|
||||
user_id_to_email_map: dict[str, str],
|
||||
) -> ExternalAccess:
|
||||
@@ -95,37 +122,10 @@ def _fetch_channel_permissions(
|
||||
return channel_permissions
|
||||
|
||||
|
||||
def _get_slack_document_access(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
channel_permissions: dict[str, ExternalAccess],
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
|
||||
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
for doc_metadata in doc_metadata_batch:
|
||||
if doc_metadata.perm_sync_data is None:
|
||||
continue
|
||||
channel_id = doc_metadata.perm_sync_data["channel_id"]
|
||||
yield DocExternalAccess(
|
||||
external_access=channel_permissions[channel_id],
|
||||
doc_id=doc_metadata.id,
|
||||
)
|
||||
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("_get_slack_document_access: Stop signal detected")
|
||||
|
||||
callback.progress("_get_slack_document_access", 1)
|
||||
|
||||
|
||||
def slack_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
@@ -136,12 +136,9 @@ def slack_doc_sync(
|
||||
token=cc_pair.credential.credential_json["slack_bot_token"]
|
||||
)
|
||||
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
|
||||
if not user_id_to_email_map:
|
||||
raise ValueError(
|
||||
"No user id to email map found. Please check to make sure that "
|
||||
"your Slack bot token has the `users:read.email` scope"
|
||||
)
|
||||
|
||||
channel_doc_map = _get_slack_document_ids_and_channels(
|
||||
cc_pair=cc_pair, callback=callback
|
||||
)
|
||||
workspace_permissions = _fetch_workspace_permissions(
|
||||
user_id_to_email_map=user_id_to_email_map,
|
||||
)
|
||||
@@ -151,8 +148,18 @@ def slack_doc_sync(
|
||||
user_id_to_email_map=user_id_to_email_map,
|
||||
)
|
||||
|
||||
yield from _get_slack_document_access(
|
||||
cc_pair=cc_pair,
|
||||
channel_permissions=channel_permissions,
|
||||
callback=callback,
|
||||
)
|
||||
document_external_accesses = []
|
||||
for channel_id, ext_access in channel_permissions.items():
|
||||
doc_ids = channel_doc_map.get(channel_id)
|
||||
if not doc_ids:
|
||||
# No documents found for channel the channel_id
|
||||
continue
|
||||
|
||||
for doc_id in doc_ids:
|
||||
document_external_accesses.append(
|
||||
DocExternalAccess(
|
||||
external_access=ext_access,
|
||||
doc_id=doc_id,
|
||||
)
|
||||
)
|
||||
return document_external_accesses
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
|
||||
@@ -24,7 +23,7 @@ DocSyncFuncType = Callable[
|
||||
ConnectorCredentialPair,
|
||||
IndexingHeartbeatInterface | None,
|
||||
],
|
||||
Generator[DocExternalAccess, None, None],
|
||||
list[DocExternalAccess],
|
||||
]
|
||||
|
||||
GroupSyncFuncType = Callable[
|
||||
|
||||
@@ -15,8 +15,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
|
||||
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
|
||||
from ee.onyx.server.enterprise_settings.store import get_logo_filename
|
||||
from ee.onyx.server.enterprise_settings.store import get_logotype_filename
|
||||
from ee.onyx.server.enterprise_settings.store import _LOGO_FILENAME
|
||||
from ee.onyx.server.enterprise_settings.store import _LOGOTYPE_FILENAME
|
||||
from ee.onyx.server.enterprise_settings.store import load_analytics_script
|
||||
from ee.onyx.server.enterprise_settings.store import load_settings
|
||||
from ee.onyx.server.enterprise_settings.store import store_analytics_script
|
||||
@@ -28,7 +28,7 @@ from onyx.auth.users import get_user_manager
|
||||
from onyx.auth.users import UserManager
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.file_store.file_store import PostgresBackedFileStore
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/enterprise-settings")
|
||||
@@ -131,49 +131,31 @@ def put_logo(
|
||||
upload_logo(file=file, db_session=db_session, is_logotype=is_logotype)
|
||||
|
||||
|
||||
def fetch_logo_helper(db_session: Session) -> Response:
|
||||
def fetch_logo_or_logotype(is_logotype: bool, db_session: Session) -> Response:
|
||||
try:
|
||||
file_store = PostgresBackedFileStore(db_session)
|
||||
onyx_file = file_store.get_file_with_mime_type(get_logo_filename())
|
||||
if not onyx_file:
|
||||
raise ValueError("get_onyx_file returned None!")
|
||||
file_store = get_default_file_store(db_session)
|
||||
filename = _LOGOTYPE_FILENAME if is_logotype else _LOGO_FILENAME
|
||||
file_io = file_store.read_file(filename, mode="b")
|
||||
# NOTE: specifying "image/jpeg" here, but it still works for pngs
|
||||
# TODO: do this properly
|
||||
return Response(content=file_io.read(), media_type="image/jpeg")
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No logo file found",
|
||||
detail=f"No {'logotype' if is_logotype else 'logo'} file found",
|
||||
)
|
||||
else:
|
||||
return Response(content=onyx_file.data, media_type=onyx_file.mime_type)
|
||||
|
||||
|
||||
def fetch_logotype_helper(db_session: Session) -> Response:
|
||||
try:
|
||||
file_store = PostgresBackedFileStore(db_session)
|
||||
onyx_file = file_store.get_file_with_mime_type(get_logotype_filename())
|
||||
if not onyx_file:
|
||||
raise ValueError("get_onyx_file returned None!")
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No logotype file found",
|
||||
)
|
||||
else:
|
||||
return Response(content=onyx_file.data, media_type=onyx_file.mime_type)
|
||||
|
||||
|
||||
@basic_router.get("/logotype")
|
||||
def fetch_logotype(db_session: Session = Depends(get_session)) -> Response:
|
||||
return fetch_logotype_helper(db_session)
|
||||
return fetch_logo_or_logotype(is_logotype=True, db_session=db_session)
|
||||
|
||||
|
||||
@basic_router.get("/logo")
|
||||
def fetch_logo(
|
||||
is_logotype: bool = False, db_session: Session = Depends(get_session)
|
||||
) -> Response:
|
||||
if is_logotype:
|
||||
return fetch_logotype_helper(db_session)
|
||||
|
||||
return fetch_logo_helper(db_session)
|
||||
return fetch_logo_or_logotype(is_logotype=is_logotype, db_session=db_session)
|
||||
|
||||
|
||||
@admin_router.put("/custom-analytics-script")
|
||||
|
||||
@@ -13,7 +13,6 @@ from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import KV_CUSTOM_ANALYTICS_SCRIPT_KEY
|
||||
from onyx.configs.constants import KV_ENTERPRISE_SETTINGS_KEY
|
||||
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
@@ -22,18 +21,8 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_LOGO_FILENAME = "__logo__"
|
||||
_LOGOTYPE_FILENAME = "__logotype__"
|
||||
|
||||
|
||||
def load_settings() -> EnterpriseSettings:
|
||||
"""Loads settings data directly from DB. This should be used primarily
|
||||
for checking what is actually in the DB, aka for editing and saving back settings.
|
||||
|
||||
Runtime settings actually used by the application should be checked with
|
||||
load_runtime_settings as defaults may be applied at runtime.
|
||||
"""
|
||||
|
||||
dynamic_config_store = get_kv_store()
|
||||
try:
|
||||
settings = EnterpriseSettings(
|
||||
@@ -47,24 +36,9 @@ def load_settings() -> EnterpriseSettings:
|
||||
|
||||
|
||||
def store_settings(settings: EnterpriseSettings) -> None:
|
||||
"""Stores settings directly to the kv store / db."""
|
||||
|
||||
get_kv_store().store(KV_ENTERPRISE_SETTINGS_KEY, settings.model_dump())
|
||||
|
||||
|
||||
def load_runtime_settings() -> EnterpriseSettings:
|
||||
"""Loads settings from DB and applies any defaults or transformations for use
|
||||
at runtime.
|
||||
|
||||
Should not be stored back to the DB.
|
||||
"""
|
||||
enterprise_settings = load_settings()
|
||||
if not enterprise_settings.application_name:
|
||||
enterprise_settings.application_name = ONYX_DEFAULT_APPLICATION_NAME
|
||||
|
||||
return enterprise_settings
|
||||
|
||||
|
||||
_CUSTOM_ANALYTICS_SECRET_KEY = os.environ.get("CUSTOM_ANALYTICS_SECRET_KEY")
|
||||
|
||||
|
||||
@@ -86,6 +60,10 @@ def store_analytics_script(analytics_script_upload: AnalyticsScriptUpload) -> No
|
||||
get_kv_store().store(KV_CUSTOM_ANALYTICS_SCRIPT_KEY, analytics_script_upload.script)
|
||||
|
||||
|
||||
_LOGO_FILENAME = "__logo__"
|
||||
_LOGOTYPE_FILENAME = "__logotype__"
|
||||
|
||||
|
||||
def is_valid_file_type(filename: str) -> bool:
|
||||
valid_extensions = (".png", ".jpg", ".jpeg")
|
||||
return filename.endswith(valid_extensions)
|
||||
@@ -138,11 +116,3 @@ def upload_logo(
|
||||
file_type=file_type,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def get_logo_filename() -> str:
|
||||
return _LOGO_FILENAME
|
||||
|
||||
|
||||
def get_logotype_filename() -> str:
|
||||
return _LOGOTYPE_FILENAME
|
||||
|
||||
@@ -87,14 +87,11 @@ async def get_or_provision_tenant(
|
||||
# If we have a pre-provisioned tenant, assign it to the user
|
||||
await assign_tenant_to_user(tenant_id, email, referral_source)
|
||||
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
|
||||
return tenant_id
|
||||
else:
|
||||
# If no pre-provisioned tenant is available, create a new one on-demand
|
||||
tenant_id = await create_tenant(email, referral_source)
|
||||
|
||||
# Notify control plane if we have created / assigned a new tenant
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
return tenant_id
|
||||
return tenant_id
|
||||
|
||||
except Exception as e:
|
||||
# If we've encountered an error, log and raise an exception
|
||||
@@ -119,6 +116,10 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str:
|
||||
# Provision tenant on data plane
|
||||
await provision_tenant(tenant_id, email)
|
||||
|
||||
# Notify control plane if not already done in provision_tenant
|
||||
if not DEV_MODE and referral_source:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Tenant provisioning failed: {str(e)}")
|
||||
# Attempt to rollback the tenant provisioning
|
||||
@@ -270,7 +271,6 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
fast_default_model_name="claude-3-5-sonnet-20241022",
|
||||
model_names=ANTHROPIC_MODEL_NAMES,
|
||||
display_model_names=["claude-3-5-sonnet-20241022"],
|
||||
api_key_changed=True,
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(anthropic_provider, db_session)
|
||||
@@ -283,7 +283,7 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
)
|
||||
|
||||
if OPENAI_DEFAULT_API_KEY:
|
||||
openai_provider = LLMProviderUpsertRequest(
|
||||
open_provider = LLMProviderUpsertRequest(
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
@@ -291,10 +291,9 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
fast_default_model_name="gpt-4o-mini",
|
||||
model_names=OPEN_AI_MODEL_NAMES,
|
||||
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
|
||||
api_key_changed=True,
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(openai_provider, db_session)
|
||||
full_provider = upsert_llm_provider(open_provider, db_session)
|
||||
update_default_provider(full_provider.id, db_session)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure OpenAI provider: {e}")
|
||||
@@ -560,3 +559,7 @@ async def assign_tenant_to_user(
|
||||
except Exception:
|
||||
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
|
||||
raise Exception("Failed to assign tenant to user")
|
||||
|
||||
# Notify control plane with retry logic
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
@@ -65,17 +65,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
|
||||
app.state.gpu_type = gpu_type
|
||||
|
||||
try:
|
||||
if TEMP_HF_CACHE_PATH.is_dir():
|
||||
logger.notice("Moving contents of temp_huggingface to huggingface cache.")
|
||||
_move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
|
||||
shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
|
||||
logger.notice("Moved contents of temp_huggingface to huggingface cache.")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error moving contents of temp_huggingface to huggingface cache: {e}. "
|
||||
"This is not a critical error and the model server will continue to run."
|
||||
)
|
||||
if TEMP_HF_CACHE_PATH.is_dir():
|
||||
logger.notice("Moving contents of temp_huggingface to huggingface cache.")
|
||||
_move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
|
||||
shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
|
||||
logger.notice("Moved contents of temp_huggingface to huggingface cache.")
|
||||
|
||||
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
|
||||
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
|
||||
|
||||
@@ -20,7 +20,7 @@ class ExternalAccess:
|
||||
class DocExternalAccess:
|
||||
"""
|
||||
This is just a class to wrap the external access and the document ID
|
||||
together. It's used for syncing document permissions to Vespa.
|
||||
together. It's used for syncing document permissions to Redis.
|
||||
"""
|
||||
|
||||
external_access: ExternalAccess
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import smtplib
|
||||
from datetime import datetime
|
||||
from email.mime.image import MIMEImage
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from email.utils import formatdate
|
||||
@@ -14,13 +13,8 @@ from onyx.configs.app_configs import SMTP_SERVER
|
||||
from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
|
||||
from onyx.configs.constants import ONYX_SLACK_URL
|
||||
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
|
||||
from onyx.db.models import User
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.file import FileWithMimeType
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
HTML_EMAIL_TEMPLATE = """\
|
||||
@@ -103,8 +97,8 @@ HTML_EMAIL_TEMPLATE = """\
|
||||
<td class="header">
|
||||
<img
|
||||
style="background-color: #ffffff; border-radius: 8px;"
|
||||
src="cid:logo.png"
|
||||
alt="{application_name} Logo"
|
||||
src="https://www.onyx.app/logos/customer/onyx.png"
|
||||
alt="Onyx Logo"
|
||||
>
|
||||
</td>
|
||||
</tr>
|
||||
@@ -119,8 +113,9 @@ HTML_EMAIL_TEMPLATE = """\
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="footer">
|
||||
© {year} {application_name}. All rights reserved.
|
||||
{slack_fragment}
|
||||
© {year} Onyx. All rights reserved.
|
||||
<br>
|
||||
Have questions? Join our Slack community <a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA">here</a>.
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
@@ -130,27 +125,17 @@ HTML_EMAIL_TEMPLATE = """\
|
||||
|
||||
|
||||
def build_html_email(
|
||||
application_name: str | None,
|
||||
heading: str,
|
||||
message: str,
|
||||
cta_text: str | None = None,
|
||||
cta_link: str | None = None,
|
||||
heading: str, message: str, cta_text: str | None = None, cta_link: str | None = None
|
||||
) -> str:
|
||||
slack_fragment = ""
|
||||
if application_name == ONYX_DEFAULT_APPLICATION_NAME:
|
||||
slack_fragment = f'<br>Have questions? Join our Slack community <a href="{ONYX_SLACK_URL}">here</a>.'
|
||||
|
||||
if cta_text and cta_link:
|
||||
cta_block = f'<a class="cta-button" href="{cta_link}">{cta_text}</a>'
|
||||
else:
|
||||
cta_block = ""
|
||||
return HTML_EMAIL_TEMPLATE.format(
|
||||
application_name=application_name,
|
||||
title=heading,
|
||||
heading=heading,
|
||||
message=message,
|
||||
cta_block=cta_block,
|
||||
slack_fragment=slack_fragment,
|
||||
year=datetime.now().year,
|
||||
)
|
||||
|
||||
@@ -161,7 +146,6 @@ def send_email(
|
||||
html_body: str,
|
||||
text_body: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
inline_png: tuple[str, bytes] | None = None,
|
||||
) -> None:
|
||||
if not EMAIL_CONFIGURED:
|
||||
raise ValueError("Email is not configured.")
|
||||
@@ -180,12 +164,6 @@ def send_email(
|
||||
msg.attach(part_text)
|
||||
msg.attach(part_html)
|
||||
|
||||
if inline_png:
|
||||
img = MIMEImage(inline_png[1], _subtype="png")
|
||||
img.add_header("Content-ID", inline_png[0]) # CID reference
|
||||
img.add_header("Content-Disposition", "inline", filename=inline_png[0])
|
||||
msg.attach(img)
|
||||
|
||||
try:
|
||||
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
|
||||
s.starttls()
|
||||
@@ -196,21 +174,8 @@ def send_email(
|
||||
|
||||
|
||||
def send_subscription_cancellation_email(user_email: str) -> None:
|
||||
"""This is templated but isn't meaningful for whitelabeling."""
|
||||
|
||||
# Example usage of the reusable HTML
|
||||
try:
|
||||
load_runtime_settings_fn = fetch_versioned_implementation(
|
||||
"onyx.server.enterprise_settings.store", "load_runtime_settings"
|
||||
)
|
||||
settings = load_runtime_settings_fn()
|
||||
application_name = settings.application_name
|
||||
except ModuleNotFoundError:
|
||||
application_name = ONYX_DEFAULT_APPLICATION_NAME
|
||||
|
||||
onyx_file = OnyxRuntime.get_emailable_logo()
|
||||
|
||||
subject = f"Your {application_name} Subscription Has Been Canceled"
|
||||
subject = "Your Onyx Subscription Has Been Canceled"
|
||||
heading = "Subscription Canceled"
|
||||
message = (
|
||||
"<p>We're sorry to see you go.</p>"
|
||||
@@ -219,48 +184,23 @@ def send_subscription_cancellation_email(user_email: str) -> None:
|
||||
)
|
||||
cta_text = "Renew Subscription"
|
||||
cta_link = "https://www.onyx.app/pricing"
|
||||
html_content = build_html_email(
|
||||
application_name,
|
||||
heading,
|
||||
message,
|
||||
cta_text,
|
||||
cta_link,
|
||||
)
|
||||
html_content = build_html_email(heading, message, cta_text, cta_link)
|
||||
text_content = (
|
||||
"We're sorry to see you go.\n"
|
||||
"Your subscription has been canceled and will end on your next billing date.\n"
|
||||
"If you change your mind, visit https://www.onyx.app/pricing"
|
||||
)
|
||||
send_email(
|
||||
user_email,
|
||||
subject,
|
||||
html_content,
|
||||
text_content,
|
||||
inline_png=("logo.png", onyx_file.data),
|
||||
)
|
||||
send_email(user_email, subject, html_content, text_content)
|
||||
|
||||
|
||||
def send_user_email_invite(
|
||||
user_email: str, current_user: User, auth_type: AuthType
|
||||
) -> None:
|
||||
onyx_file: FileWithMimeType | None = None
|
||||
|
||||
try:
|
||||
load_runtime_settings_fn = fetch_versioned_implementation(
|
||||
"onyx.server.enterprise_settings.store", "load_runtime_settings"
|
||||
)
|
||||
settings = load_runtime_settings_fn()
|
||||
application_name = settings.application_name
|
||||
except ModuleNotFoundError:
|
||||
application_name = ONYX_DEFAULT_APPLICATION_NAME
|
||||
|
||||
onyx_file = OnyxRuntime.get_emailable_logo()
|
||||
|
||||
subject = f"Invitation to Join {application_name} Organization"
|
||||
subject = "Invitation to Join Onyx Organization"
|
||||
heading = "You've Been Invited!"
|
||||
|
||||
# the exact action taken by the user, and thus the message, depends on the auth type
|
||||
message = f"<p>You have been invited by {current_user.email} to join an organization on {application_name}.</p>"
|
||||
message = f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
|
||||
if auth_type == AuthType.CLOUD:
|
||||
message += (
|
||||
"<p>To join the organization, please click the button below to set a password "
|
||||
@@ -286,32 +226,19 @@ def send_user_email_invite(
|
||||
|
||||
cta_text = "Join Organization"
|
||||
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
|
||||
|
||||
html_content = build_html_email(
|
||||
application_name,
|
||||
heading,
|
||||
message,
|
||||
cta_text,
|
||||
cta_link,
|
||||
)
|
||||
html_content = build_html_email(heading, message, cta_text, cta_link)
|
||||
|
||||
# text content is the fallback for clients that don't support HTML
|
||||
# not as critical, so not having special cases for each auth type
|
||||
text_content = (
|
||||
f"You have been invited by {current_user.email} to join an organization on {application_name}.\n"
|
||||
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
|
||||
"To join the organization, please visit the following link:\n"
|
||||
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
|
||||
)
|
||||
if auth_type == AuthType.CLOUD:
|
||||
text_content += "You'll be asked to set a password or login with Google to complete your registration."
|
||||
|
||||
send_email(
|
||||
user_email,
|
||||
subject,
|
||||
html_content,
|
||||
text_content,
|
||||
inline_png=("logo.png", onyx_file.data),
|
||||
)
|
||||
send_email(user_email, subject, html_content, text_content)
|
||||
|
||||
|
||||
def send_forgot_password_email(
|
||||
@@ -321,36 +248,14 @@ def send_forgot_password_email(
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
# Builds a forgot password email with or without fancy HTML
|
||||
try:
|
||||
load_runtime_settings_fn = fetch_versioned_implementation(
|
||||
"onyx.server.enterprise_settings.store", "load_runtime_settings"
|
||||
)
|
||||
settings = load_runtime_settings_fn()
|
||||
application_name = settings.application_name
|
||||
except ModuleNotFoundError:
|
||||
application_name = ONYX_DEFAULT_APPLICATION_NAME
|
||||
|
||||
onyx_file = OnyxRuntime.get_emailable_logo()
|
||||
|
||||
subject = f"{application_name} Forgot Password"
|
||||
subject = "Onyx Forgot Password"
|
||||
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
|
||||
if MULTI_TENANT:
|
||||
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
|
||||
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
|
||||
html_content = build_html_email(
|
||||
application_name,
|
||||
"Reset Your Password",
|
||||
message,
|
||||
)
|
||||
html_content = build_html_email("Reset Your Password", message)
|
||||
text_content = f"Click the following link to reset your password: {link}"
|
||||
send_email(
|
||||
user_email,
|
||||
subject,
|
||||
html_content,
|
||||
text_content,
|
||||
mail_from,
|
||||
inline_png=("logo.png", onyx_file.data),
|
||||
)
|
||||
send_email(user_email, subject, html_content, text_content, mail_from)
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
@@ -359,33 +264,11 @@ def send_user_verification_email(
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
# Builds a verification email
|
||||
try:
|
||||
load_runtime_settings_fn = fetch_versioned_implementation(
|
||||
"onyx.server.enterprise_settings.store", "load_runtime_settings"
|
||||
)
|
||||
settings = load_runtime_settings_fn()
|
||||
application_name = settings.application_name
|
||||
except ModuleNotFoundError:
|
||||
application_name = ONYX_DEFAULT_APPLICATION_NAME
|
||||
|
||||
onyx_file = OnyxRuntime.get_emailable_logo()
|
||||
|
||||
subject = f"{application_name} Email Verification"
|
||||
subject = "Onyx Email Verification"
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
message = (
|
||||
f"<p>Click the following link to verify your email address:</p><p>{link}</p>"
|
||||
)
|
||||
html_content = build_html_email(
|
||||
application_name,
|
||||
"Verify Your Email",
|
||||
message,
|
||||
)
|
||||
html_content = build_html_email("Verify Your Email", message)
|
||||
text_content = f"Click the following link to verify your email address: {link}"
|
||||
send_email(
|
||||
user_email,
|
||||
subject,
|
||||
html_content,
|
||||
text_content,
|
||||
mail_from,
|
||||
inline_png=("logo.png", onyx_file.data),
|
||||
)
|
||||
send_email(user_email, subject, html_content, text_content, mail_from)
|
||||
|
||||
@@ -105,7 +105,6 @@ from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import async_return_default_schema
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -594,7 +593,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
tenant_id = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_tenant_id_for_email",
|
||||
POSTGRES_DEFAULT_SCHEMA,
|
||||
None,
|
||||
)(
|
||||
email=email,
|
||||
)
|
||||
|
||||
@@ -194,16 +194,6 @@ if not MULTI_TENANT:
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-process-memory",
|
||||
"task": OnyxCeleryTask.MONITOR_PROCESS_MEMORY,
|
||||
"schedule": timedelta(minutes=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -30,9 +30,6 @@ from onyx.db.connector_credential_pair import (
|
||||
)
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import (
|
||||
delete_all_documents_by_connector_credential_pair__no_commit,
|
||||
)
|
||||
from onyx.db.document import get_document_ids_for_connector_credential_pair
|
||||
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
@@ -443,14 +440,6 @@ def monitor_connector_deletion_taskset(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Explicitly delete document by connector credential pair records before deleting the connector
|
||||
# This is needed because connector_id is a primary key in that table and cascading deletes won't work
|
||||
delete_all_documents_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
|
||||
@@ -46,6 +46,7 @@ from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.connectors.factory import validate_ccpair_for_user
|
||||
from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
@@ -419,7 +420,12 @@ def connector_permission_sync_generator_task(
|
||||
task_logger.exception(
|
||||
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
|
||||
)
|
||||
# TODO: add some notification to the admins here
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
raise
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
@@ -447,23 +453,23 @@ def connector_permission_sync_generator_task(
|
||||
redis_connector.permissions.set_fence(new_payload)
|
||||
|
||||
callback = PermissionSyncCallback(redis_connector, lock, r)
|
||||
document_external_accesses = doc_sync_func(cc_pair, callback)
|
||||
document_external_accesses: list[DocExternalAccess] = doc_sync_func(
|
||||
cc_pair, callback
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
tasks_generated = 0
|
||||
for doc_external_access in document_external_accesses:
|
||||
redis_connector.permissions.generate_tasks(
|
||||
celery_app=self.app,
|
||||
lock=lock,
|
||||
new_permissions=[doc_external_access],
|
||||
source_string=source_type,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
)
|
||||
tasks_generated += 1
|
||||
tasks_generated = redis_connector.permissions.generate_tasks(
|
||||
celery_app=self.app,
|
||||
lock=lock,
|
||||
new_permissions=document_external_accesses,
|
||||
source_string=source_type,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.permissions.generate_tasks finished. "
|
||||
|
||||
@@ -41,6 +41,7 @@ from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.factory import validate_ccpair_for_user
|
||||
from onyx.db.connector import mark_cc_pair_as_external_group_synced
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -401,7 +402,12 @@ def connector_external_group_sync_generator_task(
|
||||
task_logger.exception(
|
||||
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
|
||||
)
|
||||
# TODO: add some notification to the admins here
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
raise
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
@@ -419,9 +425,12 @@ def connector_external_group_sync_generator_task(
|
||||
try:
|
||||
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
|
||||
except ConnectorValidationError as e:
|
||||
# TODO: add some notification to the admins here
|
||||
logger.exception(
|
||||
f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
|
||||
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from itertools import islice
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
import psutil
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
@@ -20,7 +19,6 @@ from sqlalchemy.orm import Session
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -41,10 +39,8 @@ from onyx.db.models import UserGroup
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.utils.logger import is_running_in_container
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
|
||||
@@ -908,93 +904,3 @@ def monitor_celery_queues_helper(
|
||||
f"external_group_sync={n_external_group_sync} "
|
||||
f"permissions_upsert={n_permissions_upsert} "
|
||||
)
|
||||
|
||||
|
||||
"""Memory monitoring"""
|
||||
|
||||
|
||||
def _get_cmdline_for_process(process: psutil.Process) -> str | None:
|
||||
try:
|
||||
return " ".join(process.cmdline())
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.MONITOR_PROCESS_MEMORY,
|
||||
ignore_result=True,
|
||||
soft_time_limit=_MONITORING_SOFT_TIME_LIMIT,
|
||||
time_limit=_MONITORING_TIME_LIMIT,
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
bind=True,
|
||||
)
|
||||
def monitor_process_memory(self: Task, *, tenant_id: str) -> None:
|
||||
"""
|
||||
Task to monitor memory usage of supervisor-managed processes.
|
||||
This periodically checks the memory usage of processes and logs information
|
||||
in a standardized format.
|
||||
|
||||
The task looks for processes managed by supervisor and logs their
|
||||
memory usage statistics. This is useful for monitoring memory consumption
|
||||
over time and identifying potential memory leaks.
|
||||
"""
|
||||
# don't run this task in multi-tenant mode, have other, better means of monitoring
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
# Skip memory monitoring if not in container
|
||||
if not is_running_in_container():
|
||||
return
|
||||
|
||||
try:
|
||||
# Get all supervisor-managed processes
|
||||
supervisor_processes: dict[int, str] = {}
|
||||
|
||||
# Map cmd line elements to more readable process names
|
||||
process_type_mapping = {
|
||||
"--hostname=primary": "primary",
|
||||
"--hostname=light": "light",
|
||||
"--hostname=heavy": "heavy",
|
||||
"--hostname=indexing": "indexing",
|
||||
"--hostname=monitoring": "monitoring",
|
||||
"beat": "beat",
|
||||
"slack/listener.py": "slack",
|
||||
}
|
||||
|
||||
# Find all python processes that are likely celery workers
|
||||
for proc in psutil.process_iter():
|
||||
cmdline = _get_cmdline_for_process(proc)
|
||||
if not cmdline:
|
||||
continue
|
||||
|
||||
# Match supervisor-managed processes
|
||||
for process_name, process_type in process_type_mapping.items():
|
||||
if process_name in cmdline:
|
||||
if process_type in supervisor_processes.values():
|
||||
task_logger.error(
|
||||
f"Duplicate process type for type {process_type} "
|
||||
f"with cmd {cmdline} with pid={proc.pid}."
|
||||
)
|
||||
continue
|
||||
|
||||
supervisor_processes[proc.pid] = process_type
|
||||
break
|
||||
|
||||
if len(supervisor_processes) != len(process_type_mapping):
|
||||
task_logger.error(
|
||||
"Missing processes: "
|
||||
f"{set(process_type_mapping.keys()).symmetric_difference(supervisor_processes.values())}"
|
||||
)
|
||||
|
||||
# Log memory usage for each process
|
||||
for pid, process_type in supervisor_processes.items():
|
||||
try:
|
||||
emit_process_memory(pid, process_type, {})
|
||||
except psutil.NoSuchProcess:
|
||||
# Process may have terminated since we obtained the list
|
||||
continue
|
||||
except Exception as e:
|
||||
task_logger.exception(f"Error monitoring process {pid}: {str(e)}")
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Error in monitor_process_memory task")
|
||||
|
||||
@@ -6,8 +6,6 @@ from sqlalchemy import and_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.db.engine import get_db_current_time
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
@@ -18,6 +16,7 @@ from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.object_size_check import deep_getsizeof
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_NUM_RECENT_ATTEMPTS_TO_CONSIDER = 20
|
||||
@@ -53,7 +52,7 @@ def save_checkpoint(
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
db_session: Session, index_attempt_id: int, connector: BaseConnector
|
||||
db_session: Session, index_attempt_id: int
|
||||
) -> ConnectorCheckpoint | None:
|
||||
"""Load a checkpoint for a given index attempt from the file store"""
|
||||
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
|
||||
@@ -61,8 +60,6 @@ def load_checkpoint(
|
||||
try:
|
||||
checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
|
||||
checkpoint_data = checkpoint_io.read().decode("utf-8")
|
||||
if isinstance(connector, CheckpointConnector):
|
||||
return connector.validate_checkpoint_json(checkpoint_data)
|
||||
return ConnectorCheckpoint.model_validate_json(checkpoint_data)
|
||||
except RuntimeError:
|
||||
return None
|
||||
@@ -74,7 +71,6 @@ def get_latest_valid_checkpoint(
|
||||
search_settings_id: int,
|
||||
window_start: datetime,
|
||||
window_end: datetime,
|
||||
connector: BaseConnector,
|
||||
) -> ConnectorCheckpoint:
|
||||
"""Get the latest valid checkpoint for a given connector credential pair"""
|
||||
checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
|
||||
@@ -109,7 +105,7 @@ def get_latest_valid_checkpoint(
|
||||
f"for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
|
||||
"from scratch."
|
||||
)
|
||||
return connector.build_dummy_checkpoint()
|
||||
return ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
|
||||
# assumes latest checkpoint is the furthest along. This only isn't true
|
||||
# if something else has gone wrong.
|
||||
@@ -117,13 +113,12 @@ def get_latest_valid_checkpoint(
|
||||
checkpoint_candidates[0] if checkpoint_candidates else None
|
||||
)
|
||||
|
||||
checkpoint = connector.build_dummy_checkpoint()
|
||||
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
if latest_valid_checkpoint_candidate:
|
||||
try:
|
||||
previous_checkpoint = load_checkpoint(
|
||||
db_session=db_session,
|
||||
index_attempt_id=latest_valid_checkpoint_candidate.id,
|
||||
connector=connector,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
@@ -198,7 +193,7 @@ def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
|
||||
|
||||
def check_checkpoint_size(checkpoint: ConnectorCheckpoint) -> None:
|
||||
"""Check if the checkpoint content size exceeds the limit (200MB)"""
|
||||
content_size = deep_getsizeof(checkpoint.model_dump())
|
||||
content_size = deep_getsizeof(checkpoint.checkpoint_content)
|
||||
if content_size > 200_000_000: # 200MB in bytes
|
||||
raise ValueError(
|
||||
f"Checkpoint content size ({content_size} bytes) exceeds 200MB limit"
|
||||
|
||||
@@ -24,6 +24,7 @@ from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
@@ -31,11 +32,8 @@ from onyx.connectors.models import TextSection
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
from onyx.db.constants import CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.index_attempt import create_index_attempt_error
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
|
||||
@@ -48,6 +46,8 @@ from onyx.db.index_attempt import transition_attempt_to_in_progress
|
||||
from onyx.db.index_attempt import update_docs_indexed
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
@@ -387,7 +387,6 @@ def _run_indexing(
|
||||
net_doc_change = 0
|
||||
document_count = 0
|
||||
chunk_count = 0
|
||||
index_attempt: IndexAttempt | None = None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
@@ -406,7 +405,7 @@ def _run_indexing(
|
||||
# the beginning in order to avoid weird interactions between
|
||||
# checkpointing / failure handling.
|
||||
if index_attempt.from_beginning:
|
||||
checkpoint = connector_runner.connector.build_dummy_checkpoint()
|
||||
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
else:
|
||||
checkpoint = get_latest_valid_checkpoint(
|
||||
db_session=db_session_temp,
|
||||
@@ -414,7 +413,6 @@ def _run_indexing(
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
window_start=window_start,
|
||||
window_end=window_end,
|
||||
connector=connector_runner.connector,
|
||||
)
|
||||
|
||||
unresolved_errors = get_index_attempt_errors_for_cc_pair(
|
||||
@@ -435,7 +433,7 @@ def _run_indexing(
|
||||
|
||||
while checkpoint.has_more:
|
||||
logger.info(
|
||||
f"Running '{ctx.source.value}' connector with checkpoint: {checkpoint}"
|
||||
f"Running '{ctx.source}' connector with checkpoint: {checkpoint}"
|
||||
)
|
||||
for document_batch, failure, next_checkpoint in connector_runner.run(
|
||||
checkpoint
|
||||
@@ -598,44 +596,16 @@ def _run_indexing(
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}",
|
||||
reason=str(e),
|
||||
)
|
||||
|
||||
if ctx.is_primary:
|
||||
if not index_attempt:
|
||||
# should always be set by now
|
||||
raise RuntimeError("Should never happen.")
|
||||
|
||||
VALIDATION_ERROR_THRESHOLD = 5
|
||||
|
||||
recent_index_attempts = get_recent_completed_attempts_for_cc_pair(
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
limit=VALIDATION_ERROR_THRESHOLD,
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
num_validation_errors = len(
|
||||
[
|
||||
index_attempt
|
||||
for index_attempt in recent_index_attempts
|
||||
if index_attempt.error_msg
|
||||
and index_attempt.error_msg.startswith(
|
||||
CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
if num_validation_errors >= VALIDATION_ERROR_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Connector {ctx.connector_id} has {num_validation_errors} consecutive validation"
|
||||
f" errors. Marking the CC Pair as invalid."
|
||||
)
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
memory_tracer.stop()
|
||||
raise e
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.gpu_utils import fast_gpu_status_request
|
||||
from onyx.utils.gpu_utils import gpu_status_request
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -88,9 +88,7 @@ class Answer:
|
||||
rerank_settings is not None
|
||||
and rerank_settings.rerank_provider_type is not None
|
||||
)
|
||||
allow_agent_reranking = (
|
||||
fast_gpu_status_request(indexing=False) or using_cloud_reranking
|
||||
)
|
||||
allow_agent_reranking = gpu_status_request() or using_cloud_reranking
|
||||
|
||||
# TODO: this is a hack to force the query to be used for the search tool
|
||||
# this should be removed once we fully unify graph inputs (i.e.
|
||||
|
||||
@@ -33,10 +33,6 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
|
||||
) # 1 day
|
||||
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
|
||||
|
||||
# Controls whether to allow admin query history reports with:
|
||||
# 1. associated user emails
|
||||
# 2. anonymized user emails
|
||||
# 3. no queries
|
||||
ONYX_QUERY_HISTORY_TYPE = QueryHistoryType(
|
||||
(os.environ.get("ONYX_QUERY_HISTORY_TYPE") or QueryHistoryType.NORMAL.value).lower()
|
||||
)
|
||||
@@ -157,9 +153,10 @@ VESPA_CLOUD_CERT_PATH = os.environ.get("VESPA_CLOUD_CERT_PATH")
|
||||
VESPA_CLOUD_KEY_PATH = os.environ.get("VESPA_CLOUD_KEY_PATH")
|
||||
|
||||
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
|
||||
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE") or 16)
|
||||
|
||||
MAX_DRIVE_WORKERS = int(os.environ.get("MAX_DRIVE_WORKERS", 4))
|
||||
try:
|
||||
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
|
||||
except ValueError:
|
||||
INDEX_BATCH_SIZE = 16
|
||||
|
||||
# Below are intended to match the env variables names used by the official postgres docker image
|
||||
# https://hub.docker.com/_/postgres
|
||||
@@ -344,8 +341,8 @@ HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY = os.environ.get(
|
||||
HtmlBasedConnectorTransformLinksStrategy.STRIP,
|
||||
)
|
||||
|
||||
NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP = (
|
||||
os.environ.get("NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
|
||||
NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP = (
|
||||
os.environ.get("NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
@@ -388,10 +385,6 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
|
||||
# connector as some point.
|
||||
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
|
||||
|
||||
GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
|
||||
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
|
||||
)
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
@@ -421,9 +414,6 @@ EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
|
||||
LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
|
||||
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
|
||||
|
||||
# Slack specific configs
|
||||
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 2)
|
||||
|
||||
DASK_JOB_CLIENT_ENABLED = (
|
||||
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
@@ -3,10 +3,6 @@ import socket
|
||||
from enum import auto
|
||||
from enum import Enum
|
||||
|
||||
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
|
||||
ONYX_SLACK_URL = "https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA"
|
||||
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
|
||||
|
||||
SOURCE_TYPE = "source_type"
|
||||
# stored in the `metadata` of a chunk. Used to signify that this chunk should
|
||||
# not be used for QA. For example, Google Drive file types which can't be parsed
|
||||
@@ -44,7 +40,6 @@ DISABLED_GEN_AI_MSG = (
|
||||
"You can still use Onyx as a search engine."
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_PERSONA_ID = 0
|
||||
|
||||
DEFAULT_CC_PAIR_ID = 1
|
||||
@@ -179,7 +174,6 @@ class DocumentSource(str, Enum):
|
||||
FIREFLIES = "fireflies"
|
||||
EGNYTE = "egnyte"
|
||||
AIRTABLE = "airtable"
|
||||
HIGHSPOT = "highspot"
|
||||
|
||||
# Special case just for integration tests
|
||||
MOCK_CONNECTOR = "mock_connector"
|
||||
@@ -394,9 +388,6 @@ class OnyxCeleryTask:
|
||||
)
|
||||
CHECK_AVAILABLE_TENANTS = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_available_tenants"
|
||||
|
||||
# Tenant pre-provisioning
|
||||
PRE_PROVISION_TENANT = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_pre_provision_tenant"
|
||||
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
CHECK_FOR_INDEXING = "check_for_indexing"
|
||||
@@ -411,7 +402,9 @@ class OnyxCeleryTask:
|
||||
|
||||
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
|
||||
MONITOR_CELERY_QUEUES = "monitor_celery_queues"
|
||||
MONITOR_PROCESS_MEMORY = "monitor_process_memory"
|
||||
|
||||
# Tenant pre-provisioning
|
||||
PRE_PROVISION_TENANT = "pre_provision_tenant"
|
||||
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
|
||||
@@ -114,7 +114,6 @@ class ConfluenceConnector(
|
||||
self.timezone_offset = timezone_offset
|
||||
self._confluence_client: OnyxConfluence | None = None
|
||||
self._fetched_titles: set[str] = set()
|
||||
self.allow_images = False
|
||||
|
||||
# Remove trailing slash from wiki_base if present
|
||||
self.wiki_base = wiki_base.rstrip("/")
|
||||
@@ -159,9 +158,6 @@ class ConfluenceConnector(
|
||||
"max_backoff_seconds": 60,
|
||||
}
|
||||
|
||||
def set_allow_images(self, value: bool) -> None:
|
||||
self.allow_images = value
|
||||
|
||||
@property
|
||||
def confluence_client(self) -> OnyxConfluence:
|
||||
if self._confluence_client is None:
|
||||
@@ -237,9 +233,7 @@ class ConfluenceConnector(
|
||||
# Extract basic page information
|
||||
page_id = page["id"]
|
||||
page_title = page["title"]
|
||||
page_url = build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
page_url = f"{self.wiki_base}{page['_links']['webui']}"
|
||||
|
||||
# Get the page content
|
||||
page_content = extract_text_from_confluence_html(
|
||||
@@ -270,7 +264,6 @@ class ConfluenceConnector(
|
||||
self.confluence_client,
|
||||
attachment,
|
||||
page_id,
|
||||
self.allow_images,
|
||||
)
|
||||
|
||||
if result and result.text:
|
||||
@@ -311,14 +304,13 @@ class ConfluenceConnector(
|
||||
if "version" in page and "by" in page["version"]:
|
||||
author = page["version"]["by"]
|
||||
display_name = author.get("displayName", "Unknown")
|
||||
email = author.get("email", "unknown@domain.invalid")
|
||||
primary_owners.append(
|
||||
BasicExpertInfo(display_name=display_name, email=email)
|
||||
)
|
||||
primary_owners.append(BasicExpertInfo(display_name=display_name))
|
||||
|
||||
# Create the document
|
||||
return Document(
|
||||
id=page_url,
|
||||
id=build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
),
|
||||
sections=sections,
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=page_title,
|
||||
@@ -381,7 +373,6 @@ class ConfluenceConnector(
|
||||
confluence_client=self.confluence_client,
|
||||
attachment=attachment,
|
||||
page_id=page["id"],
|
||||
allow_images=self.allow_images,
|
||||
)
|
||||
if response is None:
|
||||
continue
|
||||
|
||||
@@ -498,12 +498,10 @@ class OnyxConfluence:
|
||||
new_start = get_start_param_from_url(url_suffix)
|
||||
previous_start = get_start_param_from_url(old_url_suffix)
|
||||
if new_start - previous_start > len(results):
|
||||
logger.debug(
|
||||
logger.warning(
|
||||
f"Start was updated by more than the amount of results "
|
||||
f"retrieved for `{url_suffix}`. This is a bug with Confluence, "
|
||||
"but we have logic to work around it - don't worry this isn't"
|
||||
f" causing an issue. Start: {new_start}, Previous Start: "
|
||||
f"{previous_start}, Len Results: {len(results)}."
|
||||
f"retrieved. This is a bug with Confluence. Start: {new_start}, "
|
||||
f"Previous Start: {previous_start}, Len Results: {len(results)}."
|
||||
)
|
||||
|
||||
# Update the url_suffix to use the adjusted start
|
||||
|
||||
@@ -112,7 +112,6 @@ def process_attachment(
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
parent_content_id: str | None,
|
||||
allow_images: bool,
|
||||
) -> AttachmentProcessingResult:
|
||||
"""
|
||||
Processes a Confluence attachment. If it's a document, extracts text,
|
||||
@@ -120,7 +119,7 @@ def process_attachment(
|
||||
"""
|
||||
try:
|
||||
# Get the media type from the attachment metadata
|
||||
media_type: str = attachment.get("metadata", {}).get("mediaType", "")
|
||||
media_type = attachment.get("metadata", {}).get("mediaType", "")
|
||||
# Validate the attachment type
|
||||
if not validate_attachment_filetype(attachment):
|
||||
return AttachmentProcessingResult(
|
||||
@@ -139,14 +138,7 @@ def process_attachment(
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
|
||||
if media_type.startswith("image/"):
|
||||
if not allow_images:
|
||||
return AttachmentProcessingResult(
|
||||
text=None,
|
||||
file_name=None,
|
||||
error="Image downloading is not enabled",
|
||||
)
|
||||
else:
|
||||
if not media_type.startswith("image/"):
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {attachment_link} due to size. "
|
||||
@@ -302,7 +294,6 @@ def convert_attachment_to_content(
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
page_id: str,
|
||||
allow_images: bool,
|
||||
) -> tuple[str | None, str | None] | None:
|
||||
"""
|
||||
Facade function which:
|
||||
@@ -318,7 +309,7 @@ def convert_attachment_to_content(
|
||||
)
|
||||
return None
|
||||
|
||||
result = process_attachment(confluence_client, attachment, page_id, allow_images)
|
||||
result = process_attachment(confluence_client, attachment, page_id)
|
||||
if result.error is not None:
|
||||
logger.warning(
|
||||
f"Attachment {attachment['title']} encountered error: {result.error}"
|
||||
|
||||
@@ -2,8 +2,6 @@ import sys
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
@@ -21,10 +19,8 @@ logger = setup_logger()
|
||||
|
||||
TimeRange = tuple[datetime, datetime]
|
||||
|
||||
CT = TypeVar("CT", bound=ConnectorCheckpoint)
|
||||
|
||||
|
||||
class CheckpointOutputWrapper(Generic[CT]):
|
||||
class CheckpointOutputWrapper:
|
||||
"""
|
||||
Wraps a CheckpointOutput generator to give things back in a more digestible format.
|
||||
The connector format is easier for the connector implementor (e.g. it enforces exactly
|
||||
@@ -33,20 +29,20 @@ class CheckpointOutputWrapper(Generic[CT]):
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.next_checkpoint: CT | None = None
|
||||
self.next_checkpoint: ConnectorCheckpoint | None = None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
checkpoint_connector_generator: CheckpointOutput[CT],
|
||||
checkpoint_connector_generator: CheckpointOutput,
|
||||
) -> Generator[
|
||||
tuple[Document | None, ConnectorFailure | None, CT | None],
|
||||
tuple[Document | None, ConnectorFailure | None, ConnectorCheckpoint | None],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
# grabs the final return value and stores it in the `next_checkpoint` variable
|
||||
def _inner_wrapper(
|
||||
checkpoint_connector_generator: CheckpointOutput[CT],
|
||||
) -> CheckpointOutput[CT]:
|
||||
checkpoint_connector_generator: CheckpointOutput,
|
||||
) -> CheckpointOutput:
|
||||
self.next_checkpoint = yield from checkpoint_connector_generator
|
||||
return self.next_checkpoint # not used
|
||||
|
||||
@@ -68,7 +64,7 @@ class CheckpointOutputWrapper(Generic[CT]):
|
||||
yield None, None, self.next_checkpoint
|
||||
|
||||
|
||||
class ConnectorRunner(Generic[CT]):
|
||||
class ConnectorRunner:
|
||||
"""
|
||||
Handles:
|
||||
- Batching
|
||||
@@ -89,9 +85,11 @@ class ConnectorRunner(Generic[CT]):
|
||||
self.doc_batch: list[Document] = []
|
||||
|
||||
def run(
|
||||
self, checkpoint: CT
|
||||
self, checkpoint: ConnectorCheckpoint
|
||||
) -> Generator[
|
||||
tuple[list[Document] | None, ConnectorFailure | None, CT | None],
|
||||
tuple[
|
||||
list[Document] | None, ConnectorFailure | None, ConnectorCheckpoint | None
|
||||
],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
@@ -107,9 +105,9 @@ class ConnectorRunner(Generic[CT]):
|
||||
end=self.time_range[1].timestamp(),
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
next_checkpoint: CT | None = None
|
||||
next_checkpoint: ConnectorCheckpoint | None = None
|
||||
# this is guaranteed to always run at least once with next_checkpoint being non-None
|
||||
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
|
||||
for document, failure, next_checkpoint in CheckpointOutputWrapper()(
|
||||
checkpoint_connector_generator
|
||||
):
|
||||
if document is not None:
|
||||
@@ -134,7 +132,7 @@ class ConnectorRunner(Generic[CT]):
|
||||
)
|
||||
|
||||
else:
|
||||
finished_checkpoint = self.connector.build_dummy_checkpoint()
|
||||
finished_checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
finished_checkpoint.has_more = False
|
||||
|
||||
if isinstance(self.connector, PollConnector):
|
||||
|
||||
@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
from onyx.connectors.axero.connector import AxeroConnector
|
||||
@@ -31,7 +30,6 @@ from onyx.connectors.gong.connector import GongConnector
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_site.connector import GoogleSitesConnector
|
||||
from onyx.connectors.guru.connector import GuruConnector
|
||||
from onyx.connectors.highspot.connector import HighspotConnector
|
||||
from onyx.connectors.hubspot.connector import HubSpotConnector
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
@@ -119,7 +117,6 @@ def identify_connector_class(
|
||||
DocumentSource.FIREFLIES: FirefliesConnector,
|
||||
DocumentSource.EGNYTE: EgnyteConnector,
|
||||
DocumentSource.AIRTABLE: AirtableConnector,
|
||||
DocumentSource.HIGHSPOT: HighspotConnector,
|
||||
# just for integration tests
|
||||
DocumentSource.MOCK_CONNECTOR: MockConnector,
|
||||
}
|
||||
@@ -185,8 +182,6 @@ def instantiate_connector(
|
||||
if new_credentials is not None:
|
||||
backend_update_credential_json(credential, new_credentials, db_session)
|
||||
|
||||
connector.set_allow_images(get_image_extraction_and_analysis_enabled())
|
||||
|
||||
return connector
|
||||
|
||||
|
||||
|
||||
@@ -219,34 +219,24 @@ def _process_file(
|
||||
|
||||
# 2) Otherwise: text-based approach. Possibly with embedded images.
|
||||
file.seek(0)
|
||||
text_content = ""
|
||||
embedded_images: list[tuple[bytes, str]] = []
|
||||
|
||||
# Extract text and images from the file
|
||||
extraction_result = extract_text_and_images(
|
||||
text_content, embedded_images = extract_text_and_images(
|
||||
file=file,
|
||||
file_name=file_name,
|
||||
pdf_pass=pdf_pass,
|
||||
)
|
||||
|
||||
# Merge file-specific metadata (from file content) with provided metadata
|
||||
if extraction_result.metadata:
|
||||
logger.debug(
|
||||
f"Found file-specific metadata for {file_name}: {extraction_result.metadata}"
|
||||
)
|
||||
metadata.update(extraction_result.metadata)
|
||||
|
||||
# Build sections: first the text as a single Section
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
link_in_meta = metadata.get("link")
|
||||
if extraction_result.text_content.strip():
|
||||
logger.debug(f"Creating TextSection for {file_name} with link: {link_in_meta}")
|
||||
sections.append(
|
||||
TextSection(link=link_in_meta, text=extraction_result.text_content.strip())
|
||||
)
|
||||
if text_content.strip():
|
||||
sections.append(TextSection(link=link_in_meta, text=text_content.strip()))
|
||||
|
||||
# Then any extracted images from docx, etc.
|
||||
for idx, (img_data, img_name) in enumerate(
|
||||
extraction_result.embedded_images, start=1
|
||||
):
|
||||
for idx, (img_data, img_name) in enumerate(embedded_images, start=1):
|
||||
# Store each embedded image as a separate file in PGFileStore
|
||||
# and create a section with the image reference
|
||||
try:
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import copy
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -15,30 +13,26 @@ from github.GithubException import GithubException
|
||||
from github.Issue import Issue
|
||||
from github.PaginatedList import PaginatedList
|
||||
from github.PullRequest import PullRequest
|
||||
from github.Requester import Requester
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import ConnectorCheckpoint
|
||||
from onyx.connectors.interfaces import ConnectorFailure
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.batching import batch_generator
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
ITEMS_PER_PAGE = 100
|
||||
|
||||
_MAX_NUM_RATE_LIMIT_RETRIES = 5
|
||||
|
||||
@@ -54,7 +48,7 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
|
||||
|
||||
def _get_batch_rate_limited(
|
||||
git_objs: PaginatedList, page_num: int, github_client: Github, attempt_num: int = 0
|
||||
) -> list[PullRequest | Issue]:
|
||||
) -> list[Any]:
|
||||
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
|
||||
raise RuntimeError(
|
||||
"Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
|
||||
@@ -75,6 +69,21 @@ def _get_batch_rate_limited(
|
||||
)
|
||||
|
||||
|
||||
def _batch_github_objects(
|
||||
git_objs: PaginatedList, github_client: Github, batch_size: int
|
||||
) -> Iterator[list[Any]]:
|
||||
page_num = 0
|
||||
while True:
|
||||
batch = _get_batch_rate_limited(git_objs, page_num, github_client)
|
||||
page_num += 1
|
||||
|
||||
if not batch:
|
||||
break
|
||||
|
||||
for mini_batch in batch_generator(batch, batch_size=batch_size):
|
||||
yield mini_batch
|
||||
|
||||
|
||||
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
|
||||
return Document(
|
||||
id=pull_request.html_url,
|
||||
@@ -86,9 +95,7 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
|
||||
# updated_at is UTC time but is timezone unaware, explicitly add UTC
|
||||
# as there is logic in indexing to prevent wrong timestamped docs
|
||||
# due to local time discrepancies with UTC
|
||||
doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc)
|
||||
if pull_request.updated_at
|
||||
else None,
|
||||
doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc),
|
||||
metadata={
|
||||
"merged": str(pull_request.merged),
|
||||
"state": pull_request.state,
|
||||
@@ -115,58 +122,31 @@ def _convert_issue_to_document(issue: Issue) -> Document:
|
||||
)
|
||||
|
||||
|
||||
class SerializedRepository(BaseModel):
|
||||
# id is part of the raw_data as well, just pulled out for convenience
|
||||
id: int
|
||||
headers: dict[str, str | int]
|
||||
raw_data: dict[str, Any]
|
||||
|
||||
def to_Repository(self, requester: Requester) -> Repository.Repository:
|
||||
return Repository.Repository(
|
||||
requester, self.headers, self.raw_data, completed=True
|
||||
)
|
||||
|
||||
|
||||
class GithubConnectorStage(Enum):
|
||||
START = "start"
|
||||
PRS = "prs"
|
||||
ISSUES = "issues"
|
||||
|
||||
|
||||
class GithubConnectorCheckpoint(ConnectorCheckpoint):
|
||||
stage: GithubConnectorStage
|
||||
curr_page: int
|
||||
|
||||
cached_repo_ids: list[int] | None = None
|
||||
cached_repo: SerializedRepository | None = None
|
||||
|
||||
|
||||
class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
|
||||
class GithubConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
repo_owner: str,
|
||||
repositories: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
state_filter: str = "all",
|
||||
include_prs: bool = True,
|
||||
include_issues: bool = False,
|
||||
) -> None:
|
||||
self.repo_owner = repo_owner
|
||||
self.repositories = repositories
|
||||
self.batch_size = batch_size
|
||||
self.state_filter = state_filter
|
||||
self.include_prs = include_prs
|
||||
self.include_issues = include_issues
|
||||
self.github_client: Github | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# defaults to 30 items per page, can be set to as high as 100
|
||||
self.github_client = (
|
||||
Github(
|
||||
credentials["github_access_token"],
|
||||
base_url=GITHUB_CONNECTOR_BASE_URL,
|
||||
per_page=ITEMS_PER_PAGE,
|
||||
credentials["github_access_token"], base_url=GITHUB_CONNECTOR_BASE_URL
|
||||
)
|
||||
if GITHUB_CONNECTOR_BASE_URL
|
||||
else Github(credentials["github_access_token"], per_page=ITEMS_PER_PAGE)
|
||||
else Github(credentials["github_access_token"])
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -237,193 +217,85 @@ class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
|
||||
return self._get_all_repos(github_client, attempt_num + 1)
|
||||
|
||||
def _fetch_from_github(
|
||||
self,
|
||||
checkpoint: GithubConnectorCheckpoint,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.github_client is None:
|
||||
raise ConnectorMissingCredentialError("GitHub")
|
||||
|
||||
checkpoint = copy.deepcopy(checkpoint)
|
||||
|
||||
# First run of the connector, fetch all repos and store in checkpoint
|
||||
if checkpoint.cached_repo_ids is None:
|
||||
repos = []
|
||||
if self.repositories:
|
||||
if "," in self.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = self._get_github_repos(self.github_client)
|
||||
else:
|
||||
# Single repository (backward compatibility)
|
||||
repos = [self._get_github_repo(self.github_client)]
|
||||
repos = []
|
||||
if self.repositories:
|
||||
if "," in self.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = self._get_github_repos(self.github_client)
|
||||
else:
|
||||
# All repositories
|
||||
repos = self._get_all_repos(self.github_client)
|
||||
if not repos:
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
|
||||
# Single repository (backward compatibility)
|
||||
repos = [self._get_github_repo(self.github_client)]
|
||||
else:
|
||||
# All repositories
|
||||
repos = self._get_all_repos(self.github_client)
|
||||
|
||||
checkpoint.cached_repo_ids = sorted([repo.id for repo in repos])
|
||||
checkpoint.cached_repo = SerializedRepository(
|
||||
id=checkpoint.cached_repo_ids[0],
|
||||
headers=repos[0].raw_headers,
|
||||
raw_data=repos[0].raw_data,
|
||||
)
|
||||
checkpoint.stage = GithubConnectorStage.PRS
|
||||
checkpoint.curr_page = 0
|
||||
# save checkpoint with repo ids retrieved
|
||||
return checkpoint
|
||||
for repo in repos:
|
||||
if self.include_prs:
|
||||
logger.info(f"Fetching PRs for repo: {repo.name}")
|
||||
pull_requests = repo.get_pulls(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
assert checkpoint.cached_repo is not None, "No repo saved in checkpoint"
|
||||
repo = checkpoint.cached_repo.to_Repository(self.github_client.requester)
|
||||
|
||||
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
|
||||
logger.info(f"Fetching PRs for repo: {repo.name}")
|
||||
pull_requests = repo.get_pulls(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
pr_batch = _get_batch_rate_limited(
|
||||
pull_requests, checkpoint.curr_page, self.github_client
|
||||
)
|
||||
checkpoint.curr_page += 1
|
||||
done_with_prs = False
|
||||
for pr in pr_batch:
|
||||
# we iterate backwards in time, so at this point we stop processing prs
|
||||
if (
|
||||
start is not None
|
||||
and pr.updated_at
|
||||
and pr.updated_at.replace(tzinfo=timezone.utc) < start
|
||||
for pr_batch in _batch_github_objects(
|
||||
pull_requests, self.github_client, self.batch_size
|
||||
):
|
||||
yield from doc_batch
|
||||
done_with_prs = True
|
||||
break
|
||||
# Skip PRs updated after the end date
|
||||
if (
|
||||
end is not None
|
||||
and pr.updated_at
|
||||
and pr.updated_at.replace(tzinfo=timezone.utc) > end
|
||||
doc_batch: list[Document] = []
|
||||
for pr in pr_batch:
|
||||
if start is not None and pr.updated_at < start:
|
||||
yield doc_batch
|
||||
break
|
||||
if end is not None and pr.updated_at > end:
|
||||
continue
|
||||
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
|
||||
yield doc_batch
|
||||
|
||||
if self.include_issues:
|
||||
logger.info(f"Fetching issues for repo: {repo.name}")
|
||||
issues = repo.get_issues(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
for issue_batch in _batch_github_objects(
|
||||
issues, self.github_client, self.batch_size
|
||||
):
|
||||
continue
|
||||
try:
|
||||
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
|
||||
except Exception as e:
|
||||
error_msg = f"Error converting PR to document: {e}"
|
||||
logger.exception(error_msg)
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=str(pr.id), document_link=pr.html_url
|
||||
),
|
||||
failure_message=error_msg,
|
||||
exception=e,
|
||||
)
|
||||
continue
|
||||
doc_batch = []
|
||||
for issue in issue_batch:
|
||||
issue = cast(Issue, issue)
|
||||
if start is not None and issue.updated_at < start:
|
||||
yield doc_batch
|
||||
break
|
||||
if end is not None and issue.updated_at > end:
|
||||
continue
|
||||
if issue.pull_request is not None:
|
||||
# PRs are handled separately
|
||||
continue
|
||||
doc_batch.append(_convert_issue_to_document(issue))
|
||||
yield doc_batch
|
||||
|
||||
# if we found any PRs on the page, yield any associated documents and return the checkpoint
|
||||
if not done_with_prs and len(pr_batch) > 0:
|
||||
yield from doc_batch
|
||||
return checkpoint
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_from_github()
|
||||
|
||||
# if we went past the start date during the loop or there are no more
|
||||
# prs to get, we move on to issues
|
||||
checkpoint.stage = GithubConnectorStage.ISSUES
|
||||
checkpoint.curr_page = 0
|
||||
|
||||
checkpoint.stage = GithubConnectorStage.ISSUES
|
||||
|
||||
if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
|
||||
logger.info(f"Fetching issues for repo: {repo.name}")
|
||||
issues = repo.get_issues(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
doc_batch = []
|
||||
issue_batch = _get_batch_rate_limited(
|
||||
issues, checkpoint.curr_page, self.github_client
|
||||
)
|
||||
checkpoint.curr_page += 1
|
||||
done_with_issues = False
|
||||
for issue in cast(list[Issue], issue_batch):
|
||||
# we iterate backwards in time, so at this point we stop processing prs
|
||||
if (
|
||||
start is not None
|
||||
and issue.updated_at.replace(tzinfo=timezone.utc) < start
|
||||
):
|
||||
yield from doc_batch
|
||||
done_with_issues = True
|
||||
break
|
||||
# Skip PRs updated after the end date
|
||||
if (
|
||||
end is not None
|
||||
and issue.updated_at.replace(tzinfo=timezone.utc) > end
|
||||
):
|
||||
continue
|
||||
|
||||
if issue.pull_request is not None:
|
||||
# PRs are handled separately
|
||||
continue
|
||||
|
||||
try:
|
||||
doc_batch.append(_convert_issue_to_document(issue))
|
||||
except Exception as e:
|
||||
error_msg = f"Error converting issue to document: {e}"
|
||||
logger.exception(error_msg)
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=str(issue.id),
|
||||
document_link=issue.html_url,
|
||||
),
|
||||
failure_message=error_msg,
|
||||
exception=e,
|
||||
)
|
||||
continue
|
||||
|
||||
# if we found any issues on the page, yield them and return the checkpoint
|
||||
if not done_with_issues and len(issue_batch) > 0:
|
||||
yield from doc_batch
|
||||
return checkpoint
|
||||
|
||||
# if we went past the start date during the loop or there are no more
|
||||
# issues to get, we move on to the next repo
|
||||
checkpoint.stage = GithubConnectorStage.PRS
|
||||
checkpoint.curr_page = 0
|
||||
|
||||
checkpoint.has_more = len(checkpoint.cached_repo_ids) > 1
|
||||
if checkpoint.cached_repo_ids:
|
||||
next_id = checkpoint.cached_repo_ids.pop()
|
||||
next_repo = self.github_client.get_repo(next_id)
|
||||
checkpoint.cached_repo = SerializedRepository(
|
||||
id=next_id,
|
||||
headers=next_repo.raw_headers,
|
||||
raw_data=next_repo.raw_data,
|
||||
)
|
||||
|
||||
return checkpoint
|
||||
|
||||
@override
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: GithubConnectorCheckpoint,
|
||||
) -> CheckpointOutput[GithubConnectorCheckpoint]:
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_datetime = datetime.utcfromtimestamp(start)
|
||||
end_datetime = datetime.utcfromtimestamp(end)
|
||||
|
||||
# Move start time back by 3 hours, since some Issues/PRs are getting dropped
|
||||
# Could be due to delayed processing on GitHub side
|
||||
# The non-updated issues since last poll will be shortcut-ed and not embedded
|
||||
adjusted_start_datetime = start_datetime - timedelta(hours=3)
|
||||
|
||||
epoch = datetime.fromtimestamp(0, tz=timezone.utc)
|
||||
epoch = datetime.utcfromtimestamp(0)
|
||||
if adjusted_start_datetime < epoch:
|
||||
adjusted_start_datetime = epoch
|
||||
|
||||
return self._fetch_from_github(
|
||||
checkpoint, start=adjusted_start_datetime, end=end_datetime
|
||||
)
|
||||
return self._fetch_from_github(adjusted_start_datetime, end_datetime)
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self.github_client is None:
|
||||
@@ -525,16 +397,6 @@ class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
|
||||
f"Unexpected error during GitHub settings validation: {exc}"
|
||||
)
|
||||
|
||||
def validate_checkpoint_json(
|
||||
self, checkpoint_json: str
|
||||
) -> GithubConnectorCheckpoint:
|
||||
return GithubConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
|
||||
return GithubConnectorCheckpoint(
|
||||
stage=GithubConnectorStage.PRS, curr_page=0, has_more=True
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
@@ -544,9 +406,7 @@ if __name__ == "__main__":
|
||||
repositories=os.environ["REPOSITORIES"],
|
||||
)
|
||||
connector.load_credentials(
|
||||
{"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"]}
|
||||
)
|
||||
document_batches = connector.load_from_checkpoint(
|
||||
0, time.time(), connector.build_dummy_checkpoint()
|
||||
{"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,4 @@
|
||||
import io
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
@@ -14,9 +13,7 @@ from onyx.connectors.google_drive.models import GoogleDriveFileType
|
||||
from onyx.connectors.google_drive.section_extraction import get_document_sections
|
||||
from onyx.connectors.google_utils.resources import GoogleDocsService
|
||||
from onyx.connectors.google_utils.resources import GoogleDriveService
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
@@ -76,10 +73,9 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
|
||||
return is_valid_image_type(mime_type)
|
||||
|
||||
|
||||
def _download_and_extract_sections_basic(
|
||||
def _extract_sections_basic(
|
||||
file: dict[str, str],
|
||||
service: GoogleDriveService,
|
||||
allow_images: bool,
|
||||
) -> list[TextSection | ImageSection]:
|
||||
"""Extract text and images from a Google Drive file."""
|
||||
file_id = file["id"]
|
||||
@@ -88,10 +84,6 @@ def _download_and_extract_sections_basic(
|
||||
link = file.get("webViewLink", "")
|
||||
|
||||
try:
|
||||
# skip images if not explicitly enabled
|
||||
if not allow_images and is_gdrive_image_mime_type(mime_type):
|
||||
return []
|
||||
|
||||
# For Google Docs, Sheets, and Slides, export as plain text
|
||||
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
|
||||
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
|
||||
@@ -210,17 +202,12 @@ def _download_and_extract_sections_basic(
|
||||
|
||||
def convert_drive_item_to_document(
|
||||
file: GoogleDriveFileType,
|
||||
drive_service: Callable[[], GoogleDriveService],
|
||||
docs_service: Callable[[], GoogleDocsService],
|
||||
allow_images: bool,
|
||||
size_threshold: int,
|
||||
) -> Document | ConnectorFailure | None:
|
||||
drive_service: GoogleDriveService,
|
||||
docs_service: GoogleDocsService,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Main entry point for converting a Google Drive file => Document object.
|
||||
"""
|
||||
doc_id = ""
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
|
||||
try:
|
||||
# skip shortcuts or folders
|
||||
if file.get("mimeType") in [DRIVE_SHORTCUT_TYPE, DRIVE_FOLDER_TYPE]:
|
||||
@@ -228,11 +215,13 @@ def convert_drive_item_to_document(
|
||||
return None
|
||||
|
||||
# If it's a Google Doc, we might do advanced parsing
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
|
||||
# Try to get sections using the advanced method first
|
||||
if file.get("mimeType") == GDriveMimeType.DOC.value:
|
||||
try:
|
||||
# get_document_sections is the advanced approach for Google Docs
|
||||
doc_sections = get_document_sections(
|
||||
docs_service=docs_service(), doc_id=file.get("id", "")
|
||||
docs_service=docs_service, doc_id=file.get("id", "")
|
||||
)
|
||||
if doc_sections:
|
||||
sections = cast(list[TextSection | ImageSection], doc_sections)
|
||||
@@ -241,24 +230,9 @@ def convert_drive_item_to_document(
|
||||
f"Error in advanced parsing: {e}. Falling back to basic extraction."
|
||||
)
|
||||
|
||||
size_str = file.get("size")
|
||||
if size_str:
|
||||
try:
|
||||
size_int = int(size_str)
|
||||
except ValueError:
|
||||
logger.warning(f"Parsing string to int failed: size_str={size_str}")
|
||||
else:
|
||||
if size_int > size_threshold:
|
||||
logger.warning(
|
||||
f"{file.get('name')} exceeds size threshold of {size_threshold}. Skipping."
|
||||
)
|
||||
return None
|
||||
|
||||
# If we don't have sections yet, use the basic extraction method
|
||||
if not sections:
|
||||
sections = _download_and_extract_sections_basic(
|
||||
file, drive_service(), allow_images
|
||||
)
|
||||
sections = _extract_sections_basic(file, drive_service)
|
||||
|
||||
# If we still don't have any sections, skip this file
|
||||
if not sections:
|
||||
@@ -283,19 +257,8 @@ def convert_drive_item_to_document(
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
error_str = f"Error converting file '{file.get('name')}' to Document: {e}"
|
||||
logger.exception(error_str)
|
||||
return ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=doc_id,
|
||||
document_link=sections[0].link
|
||||
if sections
|
||||
else None, # TODO: see if this is the best way to get a link
|
||||
),
|
||||
failed_entity=None,
|
||||
failure_message=error_str,
|
||||
exception=e,
|
||||
)
|
||||
logger.error(f"Error converting file {file.get('name')}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:
|
||||
|
||||
@@ -1,23 +1,17 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
|
||||
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
|
||||
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
|
||||
from onyx.connectors.google_drive.models import DriveRetrievalStage
|
||||
from onyx.connectors.google_drive.models import GoogleDriveFileType
|
||||
from onyx.connectors.google_drive.models import RetrievedDriveFile
|
||||
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
from onyx.connectors.google_utils.google_utils import GoogleFields
|
||||
from onyx.connectors.google_utils.google_utils import ORDER_BY_KEY
|
||||
from onyx.connectors.google_utils.resources import GoogleDriveService
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
FILE_FIELDS = (
|
||||
@@ -37,13 +31,11 @@ def _generate_time_range_filter(
|
||||
) -> str:
|
||||
time_range_filter = ""
|
||||
if start is not None:
|
||||
time_start = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
|
||||
time_range_filter += (
|
||||
f" and {GoogleFields.MODIFIED_TIME.value} >= '{time_start}'"
|
||||
)
|
||||
time_start = datetime.utcfromtimestamp(start).isoformat() + "Z"
|
||||
time_range_filter += f" and modifiedTime >= '{time_start}'"
|
||||
if end is not None:
|
||||
time_stop = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
|
||||
time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} <= '{time_stop}'"
|
||||
time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z"
|
||||
time_range_filter += f" and modifiedTime <= '{time_stop}'"
|
||||
return time_range_filter
|
||||
|
||||
|
||||
@@ -74,9 +66,9 @@ def _get_folders_in_parent(
|
||||
def _get_files_in_parent(
|
||||
service: Resource,
|
||||
parent_id: str,
|
||||
is_slim: bool,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
is_slim: bool = False,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
|
||||
query += " and trashed = false"
|
||||
@@ -91,7 +83,6 @@ def _get_files_in_parent(
|
||||
includeItemsFromAllDrives=True,
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=query,
|
||||
**({} if is_slim else {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}),
|
||||
):
|
||||
yield file
|
||||
|
||||
@@ -99,50 +90,30 @@ def _get_files_in_parent(
|
||||
def crawl_folders_for_files(
|
||||
service: Resource,
|
||||
parent_id: str,
|
||||
is_slim: bool,
|
||||
user_email: str,
|
||||
traversed_parent_ids: set[str],
|
||||
update_traversed_ids_func: Callable[[str], None],
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[RetrievedDriveFile]:
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
"""
|
||||
This function starts crawling from any folder. It is slower though.
|
||||
"""
|
||||
logger.info("Entered crawl_folders_for_files with parent_id: " + parent_id)
|
||||
if parent_id not in traversed_parent_ids:
|
||||
logger.info("Parent id not in traversed parent ids, getting files")
|
||||
found_files = False
|
||||
file = {}
|
||||
try:
|
||||
for file in _get_files_in_parent(
|
||||
service=service,
|
||||
parent_id=parent_id,
|
||||
is_slim=is_slim,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
found_files = True
|
||||
logger.info(f"Found file: {file['name']}")
|
||||
yield RetrievedDriveFile(
|
||||
drive_file=file,
|
||||
user_email=user_email,
|
||||
parent_id=parent_id,
|
||||
completion_stage=DriveRetrievalStage.FOLDER_FILES,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting files in parent {parent_id}: {e}")
|
||||
yield RetrievedDriveFile(
|
||||
drive_file=file,
|
||||
user_email=user_email,
|
||||
parent_id=parent_id,
|
||||
completion_stage=DriveRetrievalStage.FOLDER_FILES,
|
||||
error=e,
|
||||
)
|
||||
if found_files:
|
||||
update_traversed_ids_func(parent_id)
|
||||
else:
|
||||
logger.info(f"Skipping subfolder files since already traversed: {parent_id}")
|
||||
if parent_id in traversed_parent_ids:
|
||||
logger.info(f"Skipping subfolder since already traversed: {parent_id}")
|
||||
return
|
||||
|
||||
found_files = False
|
||||
for file in _get_files_in_parent(
|
||||
service=service,
|
||||
start=start,
|
||||
end=end,
|
||||
parent_id=parent_id,
|
||||
):
|
||||
found_files = True
|
||||
yield file
|
||||
|
||||
if found_files:
|
||||
update_traversed_ids_func(parent_id)
|
||||
|
||||
for subfolder in _get_folders_in_parent(
|
||||
service=service,
|
||||
@@ -152,8 +123,6 @@ def crawl_folders_for_files(
|
||||
yield from crawl_folders_for_files(
|
||||
service=service,
|
||||
parent_id=subfolder["id"],
|
||||
is_slim=is_slim,
|
||||
user_email=user_email,
|
||||
traversed_parent_ids=traversed_parent_ids,
|
||||
update_traversed_ids_func=update_traversed_ids_func,
|
||||
start=start,
|
||||
@@ -164,19 +133,16 @@ def crawl_folders_for_files(
|
||||
def get_files_in_shared_drive(
|
||||
service: Resource,
|
||||
drive_id: str,
|
||||
is_slim: bool,
|
||||
is_slim: bool = False,
|
||||
update_traversed_ids_func: Callable[[str], None] = lambda _: None,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
kwargs = {}
|
||||
if not is_slim:
|
||||
kwargs[ORDER_BY_KEY] = GoogleFields.MODIFIED_TIME.value
|
||||
|
||||
# If we know we are going to folder crawl later, we can cache the folders here
|
||||
# Get all folders being queried and add them to the traversed set
|
||||
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
|
||||
folder_query += " and trashed = false"
|
||||
found_folders = False
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
@@ -189,13 +155,15 @@ def get_files_in_shared_drive(
|
||||
q=folder_query,
|
||||
):
|
||||
update_traversed_ids_func(file["id"])
|
||||
found_folders = True
|
||||
if found_folders:
|
||||
update_traversed_ids_func(drive_id)
|
||||
|
||||
# Get all files in the shared drive
|
||||
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
|
||||
file_query += " and trashed = false"
|
||||
file_query += _generate_time_range_filter(start, end)
|
||||
|
||||
for file in execute_paginated_retrieval(
|
||||
yield from execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
@@ -205,26 +173,16 @@ def get_files_in_shared_drive(
|
||||
includeItemsFromAllDrives=True,
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=file_query,
|
||||
**kwargs,
|
||||
):
|
||||
# If we found any files, mark this drive as traversed. When a user has access to a drive,
|
||||
# they have access to all the files in the drive. Also not a huge deal if we re-traverse
|
||||
# empty drives.
|
||||
update_traversed_ids_func(drive_id)
|
||||
yield file
|
||||
)
|
||||
|
||||
|
||||
def get_all_files_in_my_drive(
|
||||
service: GoogleDriveService,
|
||||
service: Any,
|
||||
update_traversed_ids_func: Callable,
|
||||
is_slim: bool,
|
||||
is_slim: bool = False,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
kwargs = {}
|
||||
if not is_slim:
|
||||
kwargs[ORDER_BY_KEY] = GoogleFields.MODIFIED_TIME.value
|
||||
|
||||
# If we know we are going to folder crawl later, we can cache the folders here
|
||||
# Get all folders being queried and add them to the traversed set
|
||||
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
|
||||
@@ -238,7 +196,7 @@ def get_all_files_in_my_drive(
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=folder_query,
|
||||
):
|
||||
update_traversed_ids_func(file[GoogleFields.ID])
|
||||
update_traversed_ids_func(file["id"])
|
||||
found_folders = True
|
||||
if found_folders:
|
||||
update_traversed_ids_func(get_root_folder_id(service))
|
||||
@@ -251,28 +209,22 @@ def get_all_files_in_my_drive(
|
||||
yield from execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=False,
|
||||
corpora="user",
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=file_query,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def get_all_files_for_oauth(
|
||||
service: GoogleDriveService,
|
||||
service: Any,
|
||||
include_files_shared_with_me: bool,
|
||||
include_my_drives: bool,
|
||||
# One of the above 2 should be true
|
||||
include_shared_drives: bool,
|
||||
is_slim: bool,
|
||||
is_slim: bool = False,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
kwargs = {}
|
||||
if not is_slim:
|
||||
kwargs[ORDER_BY_KEY] = GoogleFields.MODIFIED_TIME.value
|
||||
|
||||
should_get_all = (
|
||||
include_shared_drives and include_my_drives and include_files_shared_with_me
|
||||
)
|
||||
@@ -291,13 +243,11 @@ def get_all_files_for_oauth(
|
||||
yield from execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=False,
|
||||
corpora=corpora,
|
||||
includeItemsFromAllDrives=should_get_all,
|
||||
supportsAllDrives=should_get_all,
|
||||
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
|
||||
q=file_query,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -305,8 +255,4 @@ def get_all_files_for_oauth(
|
||||
def get_root_folder_id(service: Resource) -> str:
|
||||
# we dont paginate here because there is only one root folder per user
|
||||
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
|
||||
return (
|
||||
service.files()
|
||||
.get(fileId="root", fields=GoogleFields.ID.value)
|
||||
.execute()[GoogleFields.ID.value]
|
||||
)
|
||||
return service.files().get(fileId="root", fields="id").execute()["id"]
|
||||
|
||||
@@ -1,15 +1,6 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import field_serializer
|
||||
from pydantic import field_validator
|
||||
|
||||
from onyx.connectors.interfaces import ConnectorCheckpoint
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.utils.threadpool_concurrency import ThreadSafeDict
|
||||
|
||||
|
||||
class GDriveMimeType(str, Enum):
|
||||
DOC = "application/vnd.google-apps.document"
|
||||
@@ -29,128 +20,3 @@ class GDriveMimeType(str, Enum):
|
||||
|
||||
|
||||
GoogleDriveFileType = dict[str, Any]
|
||||
|
||||
|
||||
TOKEN_EXPIRATION_TIME = 3600 # 1 hour
|
||||
|
||||
|
||||
# These correspond to The major stages of retrieval for google drive.
|
||||
# The stages for the oauth flow are:
|
||||
# get_all_files_for_oauth(),
|
||||
# get_all_drive_ids(),
|
||||
# get_files_in_shared_drive(),
|
||||
# crawl_folders_for_files()
|
||||
#
|
||||
# The stages for the service account flow are roughly:
|
||||
# get_all_user_emails(),
|
||||
# get_all_drive_ids(),
|
||||
# get_files_in_shared_drive(),
|
||||
# Then for each user:
|
||||
# get_files_in_my_drive()
|
||||
# get_files_in_shared_drive()
|
||||
# crawl_folders_for_files()
|
||||
class DriveRetrievalStage(str, Enum):
|
||||
START = "start"
|
||||
DONE = "done"
|
||||
# OAuth specific stages
|
||||
OAUTH_FILES = "oauth_files"
|
||||
|
||||
# Service account specific stages
|
||||
USER_EMAILS = "user_emails"
|
||||
MY_DRIVE_FILES = "my_drive_files"
|
||||
|
||||
# Used for both oauth and service account flows
|
||||
DRIVE_IDS = "drive_ids"
|
||||
SHARED_DRIVE_FILES = "shared_drive_files"
|
||||
FOLDER_FILES = "folder_files"
|
||||
|
||||
|
||||
class StageCompletion(BaseModel):
|
||||
"""
|
||||
Describes the point in the retrieval+indexing process that the
|
||||
connector is at. completed_until is the timestamp of the latest
|
||||
file that has been retrieved or error that has been yielded.
|
||||
Optional fields are used for retrieval stages that need more information
|
||||
for resuming than just the timestamp of the latest file.
|
||||
"""
|
||||
|
||||
stage: DriveRetrievalStage
|
||||
completed_until: SecondsSinceUnixEpoch
|
||||
completed_until_parent_id: str | None = None
|
||||
|
||||
# only used for shared drives
|
||||
processed_drive_ids: set[str] = set()
|
||||
|
||||
def update(
|
||||
self,
|
||||
stage: DriveRetrievalStage,
|
||||
completed_until: SecondsSinceUnixEpoch,
|
||||
completed_until_parent_id: str | None = None,
|
||||
) -> None:
|
||||
self.stage = stage
|
||||
self.completed_until = completed_until
|
||||
self.completed_until_parent_id = completed_until_parent_id
|
||||
|
||||
|
||||
class RetrievedDriveFile(BaseModel):
|
||||
"""
|
||||
Describes a file that has been retrieved from google drive.
|
||||
user_email is the email of the user that the file was retrieved
|
||||
by impersonating. If an error worthy of being reported is encountered,
|
||||
error should be set and later propagated as a ConnectorFailure.
|
||||
"""
|
||||
|
||||
# The stage at which this file was retrieved
|
||||
completion_stage: DriveRetrievalStage
|
||||
|
||||
# The file that was retrieved
|
||||
drive_file: GoogleDriveFileType
|
||||
|
||||
# The email of the user that the file was retrieved by impersonating
|
||||
user_email: str
|
||||
|
||||
# The id of the parent folder or drive of the file
|
||||
parent_id: str | None = None
|
||||
|
||||
# Any unexpected error that occurred while retrieving the file.
|
||||
# In particular, this is not used for 403/404 errors, which are expected
|
||||
# in the context of impersonating all the users to try to retrieve all
|
||||
# files from all their Drives and Folders.
|
||||
error: Exception | None = None
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class GoogleDriveCheckpoint(ConnectorCheckpoint):
|
||||
# Checkpoint version of _retrieved_ids
|
||||
retrieved_folder_and_drive_ids: set[str]
|
||||
|
||||
# Describes the point in the retrieval+indexing process that the
|
||||
# checkpoint is at. when this is set to a given stage, the connector
|
||||
# has finished yielding all values from the previous stage.
|
||||
completion_stage: DriveRetrievalStage
|
||||
|
||||
# The latest timestamp of a file that has been retrieved per user email.
|
||||
# StageCompletion is used to track the completion of each stage, but the
|
||||
# timestamp part is not used for folder crawling.
|
||||
completion_map: ThreadSafeDict[str, StageCompletion]
|
||||
|
||||
# cached version of the drive and folder ids to retrieve
|
||||
drive_ids_to_retrieve: list[str] | None = None
|
||||
folder_ids_to_retrieve: list[str] | None = None
|
||||
|
||||
# cached user emails
|
||||
user_emails: list[str] | None = None
|
||||
|
||||
@field_serializer("completion_map")
|
||||
def serialize_completion_map(
|
||||
self, completion_map: ThreadSafeDict[str, StageCompletion], _info: Any
|
||||
) -> dict[str, StageCompletion]:
|
||||
return completion_map._dict
|
||||
|
||||
@field_validator("completion_map", mode="before")
|
||||
def validate_completion_map(cls, v: Any) -> ThreadSafeDict[str, StageCompletion]:
|
||||
assert isinstance(v, dict) or isinstance(v, ThreadSafeDict)
|
||||
return ThreadSafeDict(
|
||||
{k: StageCompletion.model_validate(v) for k, v in v.items()}
|
||||
)
|
||||
|
||||
@@ -4,7 +4,6 @@ from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
@@ -17,37 +16,20 @@ logger = setup_logger()
|
||||
|
||||
|
||||
# Google Drive APIs are quite flakey and may 500 for an
|
||||
# extended period of time. This is now addressed by checkpointing.
|
||||
#
|
||||
# NOTE: We previously tried to combat this here by adding a very
|
||||
# long retry period (~20 minutes of trying, one request a minute.)
|
||||
# This is no longer necessary due to checkpointing.
|
||||
add_retries = retry_builder(tries=5, max_delay=10)
|
||||
|
||||
NEXT_PAGE_TOKEN_KEY = "nextPageToken"
|
||||
PAGE_TOKEN_KEY = "pageToken"
|
||||
ORDER_BY_KEY = "orderBy"
|
||||
|
||||
|
||||
# See https://developers.google.com/drive/api/reference/rest/v3/files/list for more
|
||||
class GoogleFields(str, Enum):
|
||||
ID = "id"
|
||||
CREATED_TIME = "createdTime"
|
||||
MODIFIED_TIME = "modifiedTime"
|
||||
NAME = "name"
|
||||
SIZE = "size"
|
||||
PARENTS = "parents"
|
||||
# extended period of time. Trying to combat here by adding a very
|
||||
# long retry period (~20 minutes of trying every minute)
|
||||
add_retries = retry_builder(tries=50, max_delay=30)
|
||||
|
||||
|
||||
def _execute_with_retry(request: Any) -> Any:
|
||||
max_attempts = 6
|
||||
max_attempts = 10
|
||||
attempt = 1
|
||||
|
||||
while attempt < max_attempts:
|
||||
# Note for reasons unknown, the Google API will sometimes return a 429
|
||||
# and even after waiting the retry period, it will return another 429.
|
||||
# It could be due to a few possibilities:
|
||||
# 1. Other things are also requesting from the Drive/Gmail API with the same key
|
||||
# 1. Other things are also requesting from the Gmail API with the same key
|
||||
# 2. It's a rolling rate limit so the moment we get some amount of requests cleared, we hit it again very quickly
|
||||
# 3. The retry-after has a maximum and we've already hit the limit for the day
|
||||
# or it's something else...
|
||||
@@ -108,11 +90,11 @@ def execute_paginated_retrieval(
|
||||
retrieval_function: The specific list function to call (e.g., service.files().list)
|
||||
**kwargs: Arguments to pass to the list function
|
||||
"""
|
||||
next_page_token = kwargs.get(PAGE_TOKEN_KEY, "")
|
||||
next_page_token = ""
|
||||
while next_page_token is not None:
|
||||
request_kwargs = kwargs.copy()
|
||||
if next_page_token:
|
||||
request_kwargs[PAGE_TOKEN_KEY] = next_page_token
|
||||
request_kwargs["pageToken"] = next_page_token
|
||||
|
||||
try:
|
||||
results = retrieval_function(**request_kwargs).execute()
|
||||
@@ -135,7 +117,7 @@ def execute_paginated_retrieval(
|
||||
logger.exception("Error executing request:")
|
||||
raise e
|
||||
|
||||
next_page_token = results.get(NEXT_PAGE_TOKEN_KEY)
|
||||
next_page_token = results.get("nextPageToken")
|
||||
if list_key:
|
||||
for item in results.get(list_key, []):
|
||||
yield item
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
"""
|
||||
Highspot connector package for Onyx.
|
||||
Enables integration with Highspot's knowledge base.
|
||||
"""
|
||||
@@ -1,280 +0,0 @@
|
||||
import base64
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.exceptions import HTTPError
|
||||
from requests.exceptions import RequestException
|
||||
from requests.exceptions import Timeout
|
||||
from urllib3.util.retry import Retry
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class HighspotClientError(Exception):
|
||||
"""Base exception for Highspot API client errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: Optional[int] = None):
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class HighspotAuthenticationError(HighspotClientError):
|
||||
"""Exception raised for authentication errors."""
|
||||
|
||||
|
||||
class HighspotRateLimitError(HighspotClientError):
|
||||
"""Exception raised when rate limit is exceeded."""
|
||||
|
||||
def __init__(self, message: str, retry_after: Optional[str] = None):
|
||||
self.retry_after = retry_after
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class HighspotClient:
|
||||
"""
|
||||
Client for interacting with the Highspot API.
|
||||
|
||||
Uses basic authentication with provided key (username) and secret (password).
|
||||
Implements retry logic, error handling, and connection pooling.
|
||||
"""
|
||||
|
||||
BASE_URL = "https://api-su2.highspot.com/v1.0/"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
secret: str,
|
||||
base_url: str = BASE_URL,
|
||||
timeout: int = 30,
|
||||
max_retries: int = 3,
|
||||
backoff_factor: float = 0.5,
|
||||
status_forcelist: Optional[List[int]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the Highspot API client.
|
||||
|
||||
Args:
|
||||
key: API key (used as username)
|
||||
secret: API secret (used as password)
|
||||
base_url: Base URL for the Highspot API
|
||||
timeout: Request timeout in seconds
|
||||
max_retries: Maximum number of retries for failed requests
|
||||
backoff_factor: Backoff factor for retries
|
||||
status_forcelist: HTTP status codes to retry on
|
||||
"""
|
||||
if not key or not secret:
|
||||
raise ValueError("API key and secret are required")
|
||||
|
||||
self.key = key
|
||||
self.secret = secret
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
|
||||
# Set up session with retry logic
|
||||
self.session = requests.Session()
|
||||
retry_strategy = Retry(
|
||||
total=max_retries,
|
||||
backoff_factor=backoff_factor,
|
||||
status_forcelist=status_forcelist or [429, 500, 502, 503, 504],
|
||||
allowed_methods=["GET", "POST", "PUT", "DELETE"],
|
||||
)
|
||||
adapter = HTTPAdapter(max_retries=retry_strategy)
|
||||
self.session.mount("http://", adapter)
|
||||
self.session.mount("https://", adapter)
|
||||
|
||||
# Set up authentication
|
||||
self._setup_auth()
|
||||
|
||||
def _setup_auth(self) -> None:
|
||||
"""Set up basic authentication for the session."""
|
||||
auth = f"{self.key}:{self.secret}"
|
||||
encoded_auth = base64.b64encode(auth.encode()).decode()
|
||||
self.session.headers.update(
|
||||
{
|
||||
"Authorization": f"Basic {encoded_auth}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
)
|
||||
|
||||
def _make_request(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Make a request to the Highspot API.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
endpoint: API endpoint
|
||||
params: URL parameters
|
||||
data: Form data
|
||||
json_data: JSON data
|
||||
headers: Additional headers
|
||||
|
||||
Returns:
|
||||
API response as a dictionary
|
||||
|
||||
Raises:
|
||||
HighspotClientError: On API errors
|
||||
HighspotAuthenticationError: On authentication errors
|
||||
HighspotRateLimitError: On rate limiting
|
||||
requests.exceptions.RequestException: On request failures
|
||||
"""
|
||||
url = urljoin(self.base_url, endpoint)
|
||||
request_headers = {}
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
|
||||
try:
|
||||
logger.debug(f"Making {method} request to {url}")
|
||||
response = self.session.request(
|
||||
method=method,
|
||||
url=url,
|
||||
params=params,
|
||||
data=data,
|
||||
json=json_data,
|
||||
headers=request_headers,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
if response.content and response.content.strip():
|
||||
return response.json()
|
||||
return {}
|
||||
|
||||
except HTTPError as e:
|
||||
status_code = e.response.status_code
|
||||
error_msg = str(e)
|
||||
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
if isinstance(error_data, dict):
|
||||
error_msg = error_data.get("message", str(e))
|
||||
except (ValueError, KeyError):
|
||||
pass
|
||||
|
||||
if status_code == 401:
|
||||
raise HighspotAuthenticationError(f"Authentication failed: {error_msg}")
|
||||
elif status_code == 429:
|
||||
retry_after = e.response.headers.get("Retry-After")
|
||||
raise HighspotRateLimitError(
|
||||
f"Rate limit exceeded: {error_msg}", retry_after=retry_after
|
||||
)
|
||||
else:
|
||||
raise HighspotClientError(
|
||||
f"API error {status_code}: {error_msg}", status_code=status_code
|
||||
)
|
||||
|
||||
except Timeout:
|
||||
raise HighspotClientError("Request timed out")
|
||||
except RequestException as e:
|
||||
raise HighspotClientError(f"Request failed: {str(e)}")
|
||||
|
||||
def get_spots(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all available spots.
|
||||
|
||||
Returns:
|
||||
List of spots with their names and IDs
|
||||
"""
|
||||
params = {"right": "view"}
|
||||
response = self._make_request("GET", "spots", params=params)
|
||||
logger.info(f"Received {response} spots")
|
||||
total_counts = response.get("counts_total")
|
||||
# Fix comparison to handle None value
|
||||
if total_counts is not None and total_counts > 0:
|
||||
return response.get("collection", [])
|
||||
return []
|
||||
|
||||
def get_spot(self, spot_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get details for a specific spot.
|
||||
|
||||
Args:
|
||||
spot_id: ID of the spot
|
||||
|
||||
Returns:
|
||||
Spot details
|
||||
"""
|
||||
if not spot_id:
|
||||
raise ValueError("spot_id is required")
|
||||
return self._make_request("GET", f"spots/{spot_id}")
|
||||
|
||||
def get_spot_items(
|
||||
self, spot_id: str, offset: int = 0, page_size: int = 100
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get items in a specific spot.
|
||||
|
||||
Args:
|
||||
spot_id: ID of the spot
|
||||
offset: offset number
|
||||
page_size: Number of items per page
|
||||
|
||||
Returns:
|
||||
Items in the spot
|
||||
"""
|
||||
if not spot_id:
|
||||
raise ValueError("spot_id is required")
|
||||
|
||||
params = {"spot": spot_id, "start": offset, "limit": page_size}
|
||||
return self._make_request("GET", "items", params=params)
|
||||
|
||||
def get_item(self, item_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get details for a specific item.
|
||||
|
||||
Args:
|
||||
item_id: ID of the item
|
||||
|
||||
Returns:
|
||||
Item details
|
||||
"""
|
||||
if not item_id:
|
||||
raise ValueError("item_id is required")
|
||||
return self._make_request("GET", f"items/{item_id}")
|
||||
|
||||
def get_item_content(self, item_id: str) -> bytes:
|
||||
"""
|
||||
Get the raw content of an item.
|
||||
|
||||
Args:
|
||||
item_id: ID of the item
|
||||
|
||||
Returns:
|
||||
Raw content bytes
|
||||
"""
|
||||
if not item_id:
|
||||
raise ValueError("item_id is required")
|
||||
|
||||
url = urljoin(self.base_url, f"items/{item_id}/content")
|
||||
response = self.session.get(url, timeout=self.timeout)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
def health_check(self) -> bool:
|
||||
"""
|
||||
Check if the API is accessible and credentials are valid.
|
||||
|
||||
Returns:
|
||||
True if API is accessible, False otherwise
|
||||
"""
|
||||
try:
|
||||
self._make_request("GET", "spots", params={"limit": 1})
|
||||
return True
|
||||
except (HighspotClientError, HighspotAuthenticationError):
|
||||
return False
|
||||
@@ -1,431 +0,0 @@
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.highspot.client import HighspotClient
|
||||
from onyx.connectors.highspot.client import HighspotClientError
|
||||
from onyx.connectors.highspot.utils import scrape_url_content
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import VALID_FILE_EXTENSIONS
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
_SLIM_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
"""
|
||||
Connector for loading data from Highspot.
|
||||
|
||||
Retrieves content from specified spots using the Highspot API.
|
||||
If no spots are specified, retrieves content from all available spots.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spot_names: List[str] = [],
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
):
|
||||
"""
|
||||
Initialize the Highspot connector.
|
||||
|
||||
Args:
|
||||
spot_names: List of spot names to retrieve content from (if empty, gets all spots)
|
||||
batch_size: Number of items to retrieve in each batch
|
||||
"""
|
||||
self.spot_names = spot_names
|
||||
self.batch_size = batch_size
|
||||
self._client: Optional[HighspotClient] = None
|
||||
self._spot_id_map: Dict[str, str] = {} # Maps spot names to spot IDs
|
||||
self._all_spots_fetched = False
|
||||
self.highspot_url: Optional[str] = None
|
||||
self.key: Optional[str] = None
|
||||
self.secret: Optional[str] = None
|
||||
|
||||
@property
|
||||
def client(self) -> HighspotClient:
|
||||
if self._client is None:
|
||||
if not self.key or not self.secret:
|
||||
raise ConnectorMissingCredentialError("Highspot")
|
||||
# Ensure highspot_url is a string, use default if None
|
||||
base_url = (
|
||||
self.highspot_url
|
||||
if self.highspot_url is not None
|
||||
else HighspotClient.BASE_URL
|
||||
)
|
||||
self._client = HighspotClient(self.key, self.secret, base_url=base_url)
|
||||
return self._client
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
logger.info("Loading Highspot credentials")
|
||||
self.highspot_url = credentials.get("highspot_url")
|
||||
self.key = credentials.get("highspot_key")
|
||||
self.secret = credentials.get("highspot_secret")
|
||||
return None
|
||||
|
||||
def _populate_spot_id_map(self) -> None:
|
||||
"""
|
||||
Populate the spot ID map with all available spots.
|
||||
Keys are stored as lowercase for case-insensitive lookups.
|
||||
"""
|
||||
spots = self.client.get_spots()
|
||||
for spot in spots:
|
||||
if "title" in spot and "id" in spot:
|
||||
spot_name = spot["title"]
|
||||
self._spot_id_map[spot_name.lower()] = spot["id"]
|
||||
|
||||
self._all_spots_fetched = True
|
||||
logger.info(f"Retrieved {len(self._spot_id_map)} spots from Highspot")
|
||||
|
||||
def _get_all_spot_names(self) -> List[str]:
|
||||
"""
|
||||
Retrieve all available spot names.
|
||||
|
||||
Returns:
|
||||
List of all spot names
|
||||
"""
|
||||
if not self._all_spots_fetched:
|
||||
self._populate_spot_id_map()
|
||||
|
||||
return [spot_name for spot_name in self._spot_id_map.keys()]
|
||||
|
||||
def _get_spot_id_from_name(self, spot_name: str) -> str:
|
||||
"""
|
||||
Get spot ID from a spot name.
|
||||
|
||||
Args:
|
||||
spot_name: Name of the spot
|
||||
|
||||
Returns:
|
||||
ID of the spot
|
||||
|
||||
Raises:
|
||||
ValueError: If spot name is not found
|
||||
"""
|
||||
if not self._all_spots_fetched:
|
||||
self._populate_spot_id_map()
|
||||
|
||||
spot_name_lower = spot_name.lower()
|
||||
if spot_name_lower not in self._spot_id_map:
|
||||
raise ValueError(f"Spot '{spot_name}' not found")
|
||||
|
||||
return self._spot_id_map[spot_name_lower]
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Load content from configured spots in Highspot.
|
||||
If no spots are configured, loads from all spots.
|
||||
|
||||
Yields:
|
||||
Batches of Document objects
|
||||
"""
|
||||
return self.poll_source(None, None)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Poll Highspot for content updated since the start time.
|
||||
|
||||
Args:
|
||||
start: Start time as seconds since Unix epoch
|
||||
end: End time as seconds since Unix epoch
|
||||
|
||||
Yields:
|
||||
Batches of Document objects
|
||||
"""
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
# If no spots specified, get all spots
|
||||
spot_names_to_process = self.spot_names
|
||||
if not spot_names_to_process:
|
||||
spot_names_to_process = self._get_all_spot_names()
|
||||
logger.info(
|
||||
f"No spots specified, using all {len(spot_names_to_process)} available spots"
|
||||
)
|
||||
|
||||
for spot_name in spot_names_to_process:
|
||||
try:
|
||||
spot_id = self._get_spot_id_from_name(spot_name)
|
||||
if spot_id is None:
|
||||
logger.warning(f"Spot ID not found for spot {spot_name}")
|
||||
continue
|
||||
offset = 0
|
||||
has_more = True
|
||||
|
||||
while has_more:
|
||||
logger.info(
|
||||
f"Retrieving items from spot {spot_name}, offset {offset}"
|
||||
)
|
||||
response = self.client.get_spot_items(
|
||||
spot_id=spot_id, offset=offset, page_size=self.batch_size
|
||||
)
|
||||
items = response.get("collection", [])
|
||||
logger.info(f"Received Items: {items}")
|
||||
if not items:
|
||||
has_more = False
|
||||
continue
|
||||
|
||||
for item in items:
|
||||
try:
|
||||
item_id = item.get("id")
|
||||
if not item_id:
|
||||
logger.warning("Item without ID found, skipping")
|
||||
continue
|
||||
|
||||
item_details = self.client.get_item(item_id)
|
||||
if not item_details:
|
||||
logger.warning(
|
||||
f"Item {item_id} details not found, skipping"
|
||||
)
|
||||
continue
|
||||
# Apply time filter if specified
|
||||
if start or end:
|
||||
updated_at = item_details.get("date_updated")
|
||||
if updated_at:
|
||||
# Convert to datetime for comparison
|
||||
try:
|
||||
updated_time = datetime.fromisoformat(
|
||||
updated_at.replace("Z", "+00:00")
|
||||
)
|
||||
if (
|
||||
start and updated_time.timestamp() < start
|
||||
) or (end and updated_time.timestamp() > end):
|
||||
continue
|
||||
except (ValueError, TypeError):
|
||||
# Skip if date cannot be parsed
|
||||
logger.warning(
|
||||
f"Invalid date format for item {item_id}: {updated_at}"
|
||||
)
|
||||
continue
|
||||
|
||||
content = self._get_item_content(item_details)
|
||||
title = item_details.get("title", "")
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=f"HIGHSPOT_{item_id}",
|
||||
sections=[
|
||||
TextSection(
|
||||
link=item_details.get(
|
||||
"url",
|
||||
f"https://www.highspot.com/items/{item_id}",
|
||||
),
|
||||
text=content,
|
||||
)
|
||||
],
|
||||
source=DocumentSource.HIGHSPOT,
|
||||
semantic_identifier=title,
|
||||
metadata={
|
||||
"spot_name": spot_name,
|
||||
"type": item_details.get("content_type", ""),
|
||||
"created_at": item_details.get(
|
||||
"date_added", ""
|
||||
),
|
||||
"author": item_details.get("author", ""),
|
||||
"language": item_details.get("language", ""),
|
||||
"can_download": str(
|
||||
item_details.get("can_download", False)
|
||||
),
|
||||
},
|
||||
doc_updated_at=item_details.get("date_updated"),
|
||||
)
|
||||
)
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
except HighspotClientError as e:
|
||||
item_id = "ID" if not item_id else item_id
|
||||
logger.error(f"Error retrieving item {item_id}: {str(e)}")
|
||||
|
||||
has_more = len(items) >= self.batch_size
|
||||
offset += self.batch_size
|
||||
|
||||
except (HighspotClientError, ValueError) as e:
|
||||
logger.error(f"Error processing spot {spot_name}: {str(e)}")
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def _get_item_content(self, item_details: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Get the text content of an item.
|
||||
|
||||
Args:
|
||||
item_details: Item details from the API
|
||||
|
||||
Returns:
|
||||
Text content of the item
|
||||
"""
|
||||
item_id = item_details.get("id", "")
|
||||
content_name = item_details.get("content_name", "")
|
||||
is_valid_format = content_name and "." in content_name
|
||||
file_extension = content_name.split(".")[-1].lower() if is_valid_format else ""
|
||||
file_extension = "." + file_extension if file_extension else ""
|
||||
can_download = item_details.get("can_download", False)
|
||||
content_type = item_details.get("content_type", "")
|
||||
|
||||
# Extract title and description once at the beginning
|
||||
title, description = self._extract_title_and_description(item_details)
|
||||
default_content = f"{title}\n{description}"
|
||||
logger.info(f"Processing item {item_id} with extension {file_extension}")
|
||||
|
||||
try:
|
||||
if content_type == "WebLink":
|
||||
url = item_details.get("url")
|
||||
if not url:
|
||||
return default_content
|
||||
content = scrape_url_content(url, True)
|
||||
return content if content else default_content
|
||||
|
||||
elif (
|
||||
is_valid_format
|
||||
and file_extension in VALID_FILE_EXTENSIONS
|
||||
and can_download
|
||||
):
|
||||
# For documents, try to get the text content
|
||||
if not item_id: # Ensure item_id is defined
|
||||
return default_content
|
||||
|
||||
content_response = self.client.get_item_content(item_id)
|
||||
# Process and extract text from binary content based on type
|
||||
if content_response:
|
||||
text_content = extract_file_text(
|
||||
BytesIO(content_response), content_name
|
||||
)
|
||||
return text_content
|
||||
return default_content
|
||||
|
||||
else:
|
||||
return default_content
|
||||
|
||||
except HighspotClientError as e:
|
||||
# Use item_id safely in the warning message
|
||||
error_context = f"item {item_id}" if item_id else "item"
|
||||
logger.warning(f"Could not retrieve content for {error_context}: {str(e)}")
|
||||
return ""
|
||||
|
||||
def _extract_title_and_description(
|
||||
self, item_details: Dict[str, Any]
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Extract the title and description from item details.
|
||||
|
||||
Args:
|
||||
item_details: Item details from the API
|
||||
|
||||
Returns:
|
||||
Tuple of title and description
|
||||
"""
|
||||
title = item_details.get("title", "")
|
||||
description = item_details.get("description", "")
|
||||
return title, description
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
"""
|
||||
Retrieve all document IDs from the configured spots.
|
||||
If no spots are configured, retrieves from all spots.
|
||||
|
||||
Args:
|
||||
start: Optional start time filter
|
||||
end: Optional end time filter
|
||||
callback: Optional indexing heartbeat callback
|
||||
|
||||
Yields:
|
||||
Batches of SlimDocument objects
|
||||
"""
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
|
||||
# If no spots specified, get all spots
|
||||
spot_names_to_process = self.spot_names
|
||||
if not spot_names_to_process:
|
||||
spot_names_to_process = self._get_all_spot_names()
|
||||
logger.info(
|
||||
f"No spots specified, using all {len(spot_names_to_process)} available spots for slim documents"
|
||||
)
|
||||
|
||||
for spot_name in spot_names_to_process:
|
||||
try:
|
||||
spot_id = self._get_spot_id_from_name(spot_name)
|
||||
offset = 0
|
||||
has_more = True
|
||||
|
||||
while has_more:
|
||||
logger.info(
|
||||
f"Retrieving slim documents from spot {spot_name}, offset {offset}"
|
||||
)
|
||||
response = self.client.get_spot_items(
|
||||
spot_id=spot_id, offset=offset, page_size=self.batch_size
|
||||
)
|
||||
|
||||
items = response.get("collection", [])
|
||||
if not items:
|
||||
has_more = False
|
||||
continue
|
||||
|
||||
for item in items:
|
||||
item_id = item.get("id")
|
||||
if not item_id:
|
||||
continue
|
||||
|
||||
slim_doc_batch.append(SlimDocument(id=f"HIGHSPOT_{item_id}"))
|
||||
|
||||
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
|
||||
yield slim_doc_batch
|
||||
slim_doc_batch = []
|
||||
|
||||
has_more = len(items) >= self.batch_size
|
||||
offset += self.batch_size
|
||||
|
||||
except (HighspotClientError, ValueError) as e:
|
||||
logger.error(
|
||||
f"Error retrieving slim documents from spot {spot_name}: {str(e)}"
|
||||
)
|
||||
|
||||
if slim_doc_batch:
|
||||
yield slim_doc_batch
|
||||
|
||||
def validate_credentials(self) -> bool:
|
||||
"""
|
||||
Validate that the provided credentials can access the Highspot API.
|
||||
|
||||
Returns:
|
||||
True if credentials are valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
return self.client.health_check()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to validate credentials: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
spot_names: List[str] = []
|
||||
connector = HighspotConnector(spot_names)
|
||||
credentials = {"highspot_key": "", "highspot_secret": ""}
|
||||
connector.load_credentials(credentials=credentials)
|
||||
for doc in connector.load_from_state():
|
||||
print(doc)
|
||||
@@ -1,122 +0,0 @@
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from playwright.sync_api import sync_playwright
|
||||
|
||||
from onyx.file_processing.html_utils import web_html_cleanup
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Constants
|
||||
WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
|
||||
JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser"
|
||||
DEFAULT_TIMEOUT = 60000 # 60 seconds
|
||||
|
||||
|
||||
def scrape_url_content(
|
||||
url: str, scroll_before_scraping: bool = False, timeout_ms: int = DEFAULT_TIMEOUT
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Scrapes content from a given URL and returns the cleaned text.
|
||||
|
||||
Args:
|
||||
url: The URL to scrape
|
||||
scroll_before_scraping: Whether to scroll through the page to load lazy content
|
||||
timeout_ms: Timeout in milliseconds for page navigation and loading
|
||||
|
||||
Returns:
|
||||
The cleaned text content of the page or None if scraping fails
|
||||
"""
|
||||
playwright = None
|
||||
browser = None
|
||||
try:
|
||||
validate_url(url)
|
||||
playwright = sync_playwright().start()
|
||||
browser = playwright.chromium.launch(headless=True)
|
||||
context = browser.new_context()
|
||||
page = context.new_page()
|
||||
|
||||
logger.info(f"Navigating to URL: {url}")
|
||||
try:
|
||||
page.goto(url, timeout=timeout_ms)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to navigate to {url}: {str(e)}")
|
||||
return None
|
||||
|
||||
if scroll_before_scraping:
|
||||
logger.debug("Scrolling page to load lazy content")
|
||||
scroll_attempts = 0
|
||||
previous_height = page.evaluate("document.body.scrollHeight")
|
||||
while scroll_attempts < WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS:
|
||||
page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
|
||||
try:
|
||||
page.wait_for_load_state("networkidle", timeout=timeout_ms)
|
||||
except Exception as e:
|
||||
logger.warning(f"Network idle wait timed out: {str(e)}")
|
||||
break
|
||||
|
||||
new_height = page.evaluate("document.body.scrollHeight")
|
||||
if new_height == previous_height:
|
||||
break
|
||||
previous_height = new_height
|
||||
scroll_attempts += 1
|
||||
|
||||
content = page.content()
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
|
||||
parsed_html = web_html_cleanup(soup)
|
||||
|
||||
if JAVASCRIPT_DISABLED_MESSAGE in parsed_html.cleaned_text:
|
||||
logger.debug("JavaScript disabled message detected, checking iframes")
|
||||
try:
|
||||
iframe_count = page.frame_locator("iframe").locator("html").count()
|
||||
if iframe_count > 0:
|
||||
iframe_texts = (
|
||||
page.frame_locator("iframe").locator("html").all_inner_texts()
|
||||
)
|
||||
iframe_content = "\n".join(iframe_texts)
|
||||
|
||||
if len(parsed_html.cleaned_text) < 700:
|
||||
parsed_html.cleaned_text = iframe_content
|
||||
else:
|
||||
parsed_html.cleaned_text += "\n" + iframe_content
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing iframes: {str(e)}")
|
||||
|
||||
return parsed_html.cleaned_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error scraping URL {url}: {str(e)}")
|
||||
return None
|
||||
|
||||
finally:
|
||||
if browser:
|
||||
try:
|
||||
browser.close()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error closing browser: {str(e)}")
|
||||
if playwright:
|
||||
try:
|
||||
playwright.stop()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error stopping playwright: {str(e)}")
|
||||
|
||||
|
||||
def validate_url(url: str) -> None:
|
||||
"""
|
||||
Validates that a URL is properly formatted.
|
||||
|
||||
Args:
|
||||
url: The URL to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If URL is not valid
|
||||
"""
|
||||
parse = urlparse(url)
|
||||
if parse.scheme != "http" and parse.scheme != "https":
|
||||
raise ValueError("URL must be of scheme https?://")
|
||||
|
||||
if not parse.hostname:
|
||||
raise ValueError("URL must include a hostname")
|
||||
@@ -4,7 +4,6 @@ from collections.abc import Iterator
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeAlias
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -20,11 +19,10 @@ SecondsSinceUnixEpoch = float
|
||||
|
||||
GenerateDocumentsOutput = Iterator[list[Document]]
|
||||
GenerateSlimDocumentOutput = Iterator[list[SlimDocument]]
|
||||
|
||||
CT = TypeVar("CT", bound=ConnectorCheckpoint)
|
||||
CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]
|
||||
|
||||
|
||||
class BaseConnector(abc.ABC, Generic[CT]):
|
||||
class BaseConnector(abc.ABC):
|
||||
REDIS_KEY_PREFIX = "da_connector_data:"
|
||||
# Common image file extensions supported across connectors
|
||||
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
|
||||
@@ -59,14 +57,6 @@ class BaseConnector(abc.ABC, Generic[CT]):
|
||||
Default is a no-op (always successful).
|
||||
"""
|
||||
|
||||
def set_allow_images(self, value: bool) -> None:
|
||||
"""Implement if the underlying connector wants to skip/allow image downloading
|
||||
based on the application level image analysis setting."""
|
||||
|
||||
def build_dummy_checkpoint(self) -> CT:
|
||||
# TODO: find a way to make this work without type: ignore
|
||||
return ConnectorCheckpoint(has_more=True) # type: ignore
|
||||
|
||||
|
||||
# Large set update or reindex, generally pulling a complete state or from a savestate file
|
||||
class LoadConnector(BaseConnector):
|
||||
@@ -84,8 +74,6 @@ class PollConnector(BaseConnector):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Slim connectors can retrieve just the ids and
|
||||
# permission syncing information for connected documents
|
||||
class SlimConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def retrieve_all_slim_documents(
|
||||
@@ -198,17 +186,14 @@ class EventConnector(BaseConnector):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
CheckpointOutput: TypeAlias = Generator[Document | ConnectorFailure, None, CT]
|
||||
|
||||
|
||||
class CheckpointConnector(BaseConnector[CT]):
|
||||
class CheckpointConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: CT,
|
||||
) -> CheckpointOutput[CT]:
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> CheckpointOutput:
|
||||
"""Yields back documents or failures. Final return is the new checkpoint.
|
||||
|
||||
Final return can be access via either:
|
||||
@@ -229,12 +214,3 @@ class CheckpointConnector(BaseConnector[CT]):
|
||||
```
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def build_dummy_checkpoint(self) -> CT:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> CT:
|
||||
"""Validate the checkpoint json and return the checkpoint object"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -2,7 +2,6 @@ from typing import Any
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
@@ -16,18 +15,14 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class MockConnectorCheckpoint(ConnectorCheckpoint):
|
||||
last_document_id: str | None = None
|
||||
|
||||
|
||||
class SingleConnectorYield(BaseModel):
|
||||
documents: list[Document]
|
||||
checkpoint: MockConnectorCheckpoint
|
||||
checkpoint: ConnectorCheckpoint
|
||||
failures: list[ConnectorFailure]
|
||||
unhandled_exception: str | None = None
|
||||
|
||||
|
||||
class MockConnector(CheckpointConnector[MockConnectorCheckpoint]):
|
||||
class MockConnector(CheckpointConnector):
|
||||
def __init__(
|
||||
self,
|
||||
mock_server_host: str,
|
||||
@@ -53,7 +48,7 @@ class MockConnector(CheckpointConnector[MockConnectorCheckpoint]):
|
||||
def _get_mock_server_url(self, endpoint: str) -> str:
|
||||
return f"http://{self.mock_server_host}:{self.mock_server_port}/{endpoint}"
|
||||
|
||||
def _save_checkpoint(self, checkpoint: MockConnectorCheckpoint) -> None:
|
||||
def _save_checkpoint(self, checkpoint: ConnectorCheckpoint) -> None:
|
||||
response = self.client.post(
|
||||
self._get_mock_server_url("add-checkpoint"),
|
||||
json=checkpoint.model_dump(mode="json"),
|
||||
@@ -64,8 +59,8 @@ class MockConnector(CheckpointConnector[MockConnectorCheckpoint]):
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: MockConnectorCheckpoint,
|
||||
) -> CheckpointOutput[MockConnectorCheckpoint]:
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> CheckpointOutput:
|
||||
if self.connector_yields is None:
|
||||
raise ValueError("No connector yields configured")
|
||||
|
||||
@@ -89,13 +84,3 @@ class MockConnector(CheckpointConnector[MockConnectorCheckpoint]):
|
||||
yield failure
|
||||
|
||||
return current_yield.checkpoint
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> MockConnectorCheckpoint:
|
||||
return MockConnectorCheckpoint(
|
||||
has_more=True,
|
||||
last_document_id=None,
|
||||
)
|
||||
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> MockConnectorCheckpoint:
|
||||
return MockConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
@@ -231,16 +232,21 @@ class IndexAttemptMetadata(BaseModel):
|
||||
|
||||
class ConnectorCheckpoint(BaseModel):
|
||||
# TODO: maybe move this to something disk-based to handle extremely large checkpoints?
|
||||
checkpoint_content: dict
|
||||
has_more: bool
|
||||
|
||||
@classmethod
|
||||
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
|
||||
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the checkpoint, with truncation for large checkpoint content."""
|
||||
MAX_CHECKPOINT_CONTENT_CHARS = 1000
|
||||
|
||||
content_str = self.model_dump_json()
|
||||
content_str = json.dumps(self.checkpoint_content)
|
||||
if len(content_str) > MAX_CHECKPOINT_CONTENT_CHARS:
|
||||
content_str = content_str[: MAX_CHECKPOINT_CONTENT_CHARS - 3] + "..."
|
||||
return content_str
|
||||
return f"ConnectorCheckpoint(checkpoint_content={content_str}, has_more={self.has_more})"
|
||||
|
||||
|
||||
class DocumentFailure(BaseModel):
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import fields
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP
|
||||
from onyx.configs.app_configs import NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rl_requests,
|
||||
@@ -25,7 +25,6 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.batching import batch_generator
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -39,7 +38,8 @@ _NOTION_CALL_TIMEOUT = 30 # 30 seconds
|
||||
# TODO: Tables need to be ingested, Pages need to have their metadata ingested
|
||||
|
||||
|
||||
class NotionPage(BaseModel):
|
||||
@dataclass
|
||||
class NotionPage:
|
||||
"""Represents a Notion Page object"""
|
||||
|
||||
id: str
|
||||
@@ -49,10 +49,17 @@ class NotionPage(BaseModel):
|
||||
properties: dict[str, Any]
|
||||
url: str
|
||||
|
||||
database_name: str | None = None # Only applicable to the database type page (wiki)
|
||||
database_name: str | None # Only applicable to the database type page (wiki)
|
||||
|
||||
def __init__(self, **kwargs: dict[str, Any]) -> None:
|
||||
names = set([f.name for f in fields(self)])
|
||||
for k, v in kwargs.items():
|
||||
if k in names:
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
class NotionBlock(BaseModel):
|
||||
@dataclass
|
||||
class NotionBlock:
|
||||
"""Represents a Notion Block object"""
|
||||
|
||||
id: str # Used for the URL
|
||||
@@ -62,13 +69,20 @@ class NotionBlock(BaseModel):
|
||||
prefix: str
|
||||
|
||||
|
||||
class NotionSearchResponse(BaseModel):
|
||||
@dataclass
|
||||
class NotionSearchResponse:
|
||||
"""Represents the response from the Notion Search API"""
|
||||
|
||||
results: list[dict[str, Any]]
|
||||
next_cursor: Optional[str]
|
||||
has_more: bool = False
|
||||
|
||||
def __init__(self, **kwargs: dict[str, Any]) -> None:
|
||||
names = set([f.name for f in fields(self)])
|
||||
for k, v in kwargs.items():
|
||||
if k in names:
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
class NotionConnector(LoadConnector, PollConnector):
|
||||
"""Notion Page connector that reads all Notion pages
|
||||
@@ -81,7 +95,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
recursive_index_enabled: bool = not NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP,
|
||||
recursive_index_enabled: bool = NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP,
|
||||
root_page_id: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize with parameters."""
|
||||
@@ -450,53 +464,23 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
page_blocks, child_page_ids = self._read_blocks(page.id)
|
||||
all_child_page_ids.extend(child_page_ids)
|
||||
|
||||
# okay to mark here since there's no way for this to not succeed
|
||||
# without a critical failure
|
||||
self.indexed_pages.add(page.id)
|
||||
|
||||
raw_page_title = self._read_page_title(page)
|
||||
page_title = raw_page_title or f"Untitled Page with ID {page.id}"
|
||||
|
||||
if not page_blocks:
|
||||
if not raw_page_title:
|
||||
logger.warning(
|
||||
f"No blocks OR title found for page with ID '{page.id}'. Skipping."
|
||||
)
|
||||
continue
|
||||
continue
|
||||
|
||||
logger.debug(f"No blocks found for page with ID '{page.id}'")
|
||||
"""
|
||||
Something like:
|
||||
|
||||
TITLE
|
||||
|
||||
PROP1: PROP1_VALUE
|
||||
PROP2: PROP2_VALUE
|
||||
"""
|
||||
text = page_title
|
||||
if page.properties:
|
||||
text += "\n\n" + "\n".join(
|
||||
[f"{key}: {value}" for key, value in page.properties.items()]
|
||||
)
|
||||
sections = [
|
||||
TextSection(
|
||||
link=f"{page.url}",
|
||||
text=text,
|
||||
)
|
||||
]
|
||||
else:
|
||||
sections = [
|
||||
TextSection(
|
||||
link=f"{page.url}#{block.id.replace('-', '')}",
|
||||
text=block.prefix + block.text,
|
||||
)
|
||||
for block in page_blocks
|
||||
]
|
||||
page_title = (
|
||||
self._read_page_title(page) or f"Untitled Page with ID {page.id}"
|
||||
)
|
||||
|
||||
yield (
|
||||
Document(
|
||||
id=page.id,
|
||||
sections=cast(list[TextSection | ImageSection], sections),
|
||||
sections=[
|
||||
TextSection(
|
||||
link=f"{page.url}#{block.id.replace('-', '')}",
|
||||
text=block.prefix + block.text,
|
||||
)
|
||||
for block in page_blocks
|
||||
],
|
||||
source=DocumentSource.NOTION,
|
||||
semantic_identifier=page_title,
|
||||
doc_updated_at=datetime.fromisoformat(
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any
|
||||
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
|
||||
@@ -16,16 +15,14 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_t
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info
|
||||
@@ -45,112 +42,121 @@ _JIRA_SLIM_PAGE_SIZE = 500
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
|
||||
|
||||
def _perform_jql_search(
|
||||
def _paginate_jql_search(
|
||||
jira_client: JIRA,
|
||||
jql: str,
|
||||
start: int,
|
||||
max_results: int,
|
||||
fields: str | None = None,
|
||||
) -> Iterable[Issue]:
|
||||
logger.debug(
|
||||
f"Fetching Jira issues with JQL: {jql}, "
|
||||
f"starting at {start}, max results: {max_results}"
|
||||
)
|
||||
issues = jira_client.search_issues(
|
||||
jql_str=jql,
|
||||
startAt=start,
|
||||
maxResults=max_results,
|
||||
fields=fields,
|
||||
)
|
||||
start = 0
|
||||
while True:
|
||||
logger.debug(
|
||||
f"Fetching Jira issues with JQL: {jql}, "
|
||||
f"starting at {start}, max results: {max_results}"
|
||||
)
|
||||
issues = jira_client.search_issues(
|
||||
jql_str=jql,
|
||||
startAt=start,
|
||||
maxResults=max_results,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
for issue in issues:
|
||||
if isinstance(issue, Issue):
|
||||
yield issue
|
||||
else:
|
||||
raise RuntimeError(f"Found Jira object not of type Issue: {issue}")
|
||||
for issue in issues:
|
||||
if isinstance(issue, Issue):
|
||||
yield issue
|
||||
else:
|
||||
raise Exception(f"Found Jira object not of type Issue: {issue}")
|
||||
|
||||
if len(issues) < max_results:
|
||||
break
|
||||
|
||||
start += max_results
|
||||
|
||||
|
||||
def process_jira_issue(
|
||||
def fetch_jira_issues_batch(
|
||||
jira_client: JIRA,
|
||||
issue: Issue,
|
||||
jql: str,
|
||||
batch_size: int,
|
||||
comment_email_blacklist: tuple[str, ...] = (),
|
||||
labels_to_skip: set[str] | None = None,
|
||||
) -> Document | None:
|
||||
if labels_to_skip:
|
||||
if any(label in issue.fields.labels for label in labels_to_skip):
|
||||
logger.info(
|
||||
f"Skipping {issue.key} because it has a label to skip. Found "
|
||||
f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
|
||||
)
|
||||
return None
|
||||
) -> Iterable[Document]:
|
||||
for issue in _paginate_jql_search(
|
||||
jira_client=jira_client,
|
||||
jql=jql,
|
||||
max_results=batch_size,
|
||||
):
|
||||
if labels_to_skip:
|
||||
if any(label in issue.fields.labels for label in labels_to_skip):
|
||||
logger.info(
|
||||
f"Skipping {issue.key} because it has a label to skip. Found "
|
||||
f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
|
||||
)
|
||||
continue
|
||||
|
||||
description = (
|
||||
issue.fields.description
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(issue.raw["fields"]["description"])
|
||||
)
|
||||
comments = get_comment_strs(
|
||||
issue=issue,
|
||||
comment_email_blacklist=comment_email_blacklist,
|
||||
)
|
||||
ticket_content = f"{description}\n" + "\n".join(
|
||||
[f"Comment: {comment}" for comment in comments if comment]
|
||||
)
|
||||
|
||||
# Check ticket size
|
||||
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
|
||||
logger.info(
|
||||
f"Skipping {issue.key} because it exceeds the maximum size of "
|
||||
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
|
||||
description = (
|
||||
issue.fields.description
|
||||
if JIRA_API_VERSION == "2"
|
||||
else extract_text_from_adf(issue.raw["fields"]["description"])
|
||||
)
|
||||
comments = get_comment_strs(
|
||||
issue=issue,
|
||||
comment_email_blacklist=comment_email_blacklist,
|
||||
)
|
||||
ticket_content = f"{description}\n" + "\n".join(
|
||||
[f"Comment: {comment}" for comment in comments if comment]
|
||||
)
|
||||
return None
|
||||
|
||||
page_url = build_jira_url(jira_client, issue.key)
|
||||
# Check ticket size
|
||||
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
|
||||
logger.info(
|
||||
f"Skipping {issue.key} because it exceeds the maximum size of "
|
||||
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
|
||||
)
|
||||
continue
|
||||
|
||||
people = set()
|
||||
try:
|
||||
creator = best_effort_get_field_from_issue(issue, "creator")
|
||||
if basic_expert_info := best_effort_basic_expert_info(creator):
|
||||
people.add(basic_expert_info)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
page_url = f"{jira_client.client_info()}/browse/{issue.key}"
|
||||
|
||||
try:
|
||||
assignee = best_effort_get_field_from_issue(issue, "assignee")
|
||||
if basic_expert_info := best_effort_basic_expert_info(assignee):
|
||||
people.add(basic_expert_info)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
people = set()
|
||||
try:
|
||||
creator = best_effort_get_field_from_issue(issue, "creator")
|
||||
if basic_expert_info := best_effort_basic_expert_info(creator):
|
||||
people.add(basic_expert_info)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
|
||||
metadata_dict = {}
|
||||
if priority := best_effort_get_field_from_issue(issue, "priority"):
|
||||
metadata_dict["priority"] = priority.name
|
||||
if status := best_effort_get_field_from_issue(issue, "status"):
|
||||
metadata_dict["status"] = status.name
|
||||
if resolution := best_effort_get_field_from_issue(issue, "resolution"):
|
||||
metadata_dict["resolution"] = resolution.name
|
||||
if labels := best_effort_get_field_from_issue(issue, "labels"):
|
||||
metadata_dict["labels"] = labels
|
||||
try:
|
||||
assignee = best_effort_get_field_from_issue(issue, "assignee")
|
||||
if basic_expert_info := best_effort_basic_expert_info(assignee):
|
||||
people.add(basic_expert_info)
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
|
||||
return Document(
|
||||
id=page_url,
|
||||
sections=[TextSection(link=page_url, text=ticket_content)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
|
||||
title=f"{issue.key} {issue.fields.summary}",
|
||||
doc_updated_at=time_str_to_utc(issue.fields.updated),
|
||||
primary_owners=list(people) or None,
|
||||
metadata=metadata_dict,
|
||||
)
|
||||
metadata_dict = {}
|
||||
if priority := best_effort_get_field_from_issue(issue, "priority"):
|
||||
metadata_dict["priority"] = priority.name
|
||||
if status := best_effort_get_field_from_issue(issue, "status"):
|
||||
metadata_dict["status"] = status.name
|
||||
if resolution := best_effort_get_field_from_issue(issue, "resolution"):
|
||||
metadata_dict["resolution"] = resolution.name
|
||||
if labels := best_effort_get_field_from_issue(issue, "labels"):
|
||||
metadata_dict["label"] = labels
|
||||
|
||||
yield Document(
|
||||
id=page_url,
|
||||
sections=[TextSection(link=page_url, text=ticket_content)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
|
||||
title=f"{issue.key} {issue.fields.summary}",
|
||||
doc_updated_at=time_str_to_utc(issue.fields.updated),
|
||||
primary_owners=list(people) or None,
|
||||
# TODO add secondary_owners (commenters) if needed
|
||||
metadata=metadata_dict,
|
||||
)
|
||||
|
||||
|
||||
class JiraConnectorCheckpoint(ConnectorCheckpoint):
|
||||
offset: int | None = None
|
||||
|
||||
|
||||
class JiraConnector(CheckpointConnector[JiraConnectorCheckpoint], SlimConnector):
|
||||
class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
self,
|
||||
jira_base_url: str,
|
||||
@@ -194,10 +200,33 @@ class JiraConnector(CheckpointConnector[JiraConnectorCheckpoint], SlimConnector)
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_jql_query(
|
||||
def _get_jql_query(self) -> str:
|
||||
"""Get the JQL query based on whether a specific project is set"""
|
||||
if self.jira_project:
|
||||
return f"project = {self.quoted_jira_project}"
|
||||
return "" # Empty string means all accessible projects
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
jql = self._get_jql_query()
|
||||
|
||||
document_batch = []
|
||||
for doc in fetch_jira_issues_batch(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
batch_size=_JIRA_FULL_PAGE_SIZE,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
):
|
||||
document_batch.append(doc)
|
||||
if len(document_batch) >= self.batch_size:
|
||||
yield document_batch
|
||||
document_batch = []
|
||||
|
||||
yield document_batch
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> str:
|
||||
"""Get the JQL query based on whether a specific project is set and time range"""
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
@@ -205,61 +234,25 @@ class JiraConnector(CheckpointConnector[JiraConnectorCheckpoint], SlimConnector)
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
|
||||
time_jql = f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'"
|
||||
base_jql = self._get_jql_query()
|
||||
jql = (
|
||||
f"{base_jql} AND " if base_jql else ""
|
||||
) + f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'"
|
||||
|
||||
if self.jira_project:
|
||||
base_jql = f"project = {self.quoted_jira_project}"
|
||||
return f"{base_jql} AND {time_jql}"
|
||||
|
||||
return time_jql
|
||||
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: JiraConnectorCheckpoint,
|
||||
) -> CheckpointOutput[JiraConnectorCheckpoint]:
|
||||
jql = self._get_jql_query(start, end)
|
||||
|
||||
# Get the current offset from checkpoint or start at 0
|
||||
starting_offset = checkpoint.offset or 0
|
||||
current_offset = starting_offset
|
||||
|
||||
for issue in _perform_jql_search(
|
||||
document_batch = []
|
||||
for doc in fetch_jira_issues_batch(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
start=current_offset,
|
||||
max_results=_JIRA_FULL_PAGE_SIZE,
|
||||
batch_size=_JIRA_FULL_PAGE_SIZE,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
):
|
||||
issue_key = issue.key
|
||||
try:
|
||||
if document := process_jira_issue(
|
||||
jira_client=self.jira_client,
|
||||
issue=issue,
|
||||
comment_email_blacklist=self.comment_email_blacklist,
|
||||
labels_to_skip=self.labels_to_skip,
|
||||
):
|
||||
yield document
|
||||
document_batch.append(doc)
|
||||
if len(document_batch) >= self.batch_size:
|
||||
yield document_batch
|
||||
document_batch = []
|
||||
|
||||
except Exception as e:
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=issue_key,
|
||||
document_link=build_jira_url(self.jira_client, issue_key),
|
||||
),
|
||||
failure_message=f"Failed to process Jira issue: {str(e)}",
|
||||
exception=e,
|
||||
)
|
||||
|
||||
current_offset += 1
|
||||
|
||||
# Update checkpoint
|
||||
checkpoint = JiraConnectorCheckpoint(
|
||||
offset=current_offset,
|
||||
# if we didn't retrieve a full batch, we're done
|
||||
has_more=current_offset - starting_offset == _JIRA_FULL_PAGE_SIZE,
|
||||
)
|
||||
return checkpoint
|
||||
yield document_batch
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
@@ -267,13 +260,12 @@ class JiraConnector(CheckpointConnector[JiraConnectorCheckpoint], SlimConnector)
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
jql = self._get_jql_query(start or 0, end or float("inf"))
|
||||
jql = self._get_jql_query()
|
||||
|
||||
slim_doc_batch = []
|
||||
for issue in _perform_jql_search(
|
||||
for issue in _paginate_jql_search(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
start=0,
|
||||
max_results=_JIRA_SLIM_PAGE_SIZE,
|
||||
fields="key",
|
||||
):
|
||||
@@ -342,16 +334,6 @@ class JiraConnector(CheckpointConnector[JiraConnectorCheckpoint], SlimConnector)
|
||||
|
||||
raise RuntimeError(f"Unexpected Jira error during validation: {e}")
|
||||
|
||||
@override
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> JiraConnectorCheckpoint:
|
||||
return JiraConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> JiraConnectorCheckpoint:
|
||||
return JiraConnectorCheckpoint(
|
||||
has_more=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
@@ -368,7 +350,5 @@ if __name__ == "__main__":
|
||||
"jira_api_token": os.environ["JIRA_API_TOKEN"],
|
||||
}
|
||||
)
|
||||
document_batches = connector.load_from_checkpoint(
|
||||
0, float("inf"), JiraConnectorCheckpoint(has_more=True)
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
|
||||
@@ -10,15 +10,13 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import SLACK_NUM_THREADS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
@@ -58,8 +56,8 @@ MessageType = dict[str, Any]
|
||||
ThreadType = list[MessageType]
|
||||
|
||||
|
||||
class SlackCheckpoint(ConnectorCheckpoint):
|
||||
channel_ids: list[str] | None
|
||||
class SlackCheckpointContent(TypedDict):
|
||||
channel_ids: list[str]
|
||||
channel_completion_map: dict[str, str]
|
||||
current_channel: ChannelType | None
|
||||
seen_thread_ts: list[str]
|
||||
@@ -414,8 +412,8 @@ def _get_all_doc_ids(
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
message_ts_set: set[str] = set()
|
||||
for message_batch in channel_message_batches:
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
for message in message_batch:
|
||||
if msg_filter_func(message):
|
||||
continue
|
||||
@@ -423,27 +421,18 @@ def _get_all_doc_ids(
|
||||
# The document id is the channel id and the ts of the first message in the thread
|
||||
# Since we already have the first message of the thread, we dont have to
|
||||
# fetch the thread for id retrieval, saving time and API calls
|
||||
message_ts_set.add(message["ts"])
|
||||
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=_build_doc_id(
|
||||
channel_id=channel_id, thread_ts=message["ts"]
|
||||
),
|
||||
perm_sync_data={"channel_id": channel_id},
|
||||
)
|
||||
channel_metadata_list: list[SlimDocument] = []
|
||||
for message_ts in message_ts_set:
|
||||
channel_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=_build_doc_id(channel_id=channel_id, thread_ts=message_ts),
|
||||
perm_sync_data={"channel_id": channel_id},
|
||||
)
|
||||
)
|
||||
|
||||
yield slim_doc_batch
|
||||
|
||||
|
||||
class ProcessedSlackMessage(BaseModel):
|
||||
doc: Document | None
|
||||
# if the message is part of a thread, this is the thread_ts
|
||||
# otherwise, this is the message_ts. Either way, will be a unique identifier.
|
||||
# In the future, if the message becomes a thread, then the thread_ts
|
||||
# will be set to the message_ts.
|
||||
thread_or_message_ts: str
|
||||
failure: ConnectorFailure | None
|
||||
yield channel_metadata_list
|
||||
|
||||
|
||||
def _process_message(
|
||||
@@ -454,9 +443,8 @@ def _process_message(
|
||||
user_cache: dict[str, BasicExpertInfo | None],
|
||||
seen_thread_ts: set[str],
|
||||
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
|
||||
) -> ProcessedSlackMessage:
|
||||
) -> tuple[Document | None, str | None, ConnectorFailure | None]:
|
||||
thread_ts = message.get("thread_ts")
|
||||
thread_or_message_ts = thread_ts or message["ts"]
|
||||
try:
|
||||
# causes random failures for testing checkpointing / continue on failure
|
||||
# import random
|
||||
@@ -472,18 +460,16 @@ def _process_message(
|
||||
seen_thread_ts=seen_thread_ts,
|
||||
msg_filter_func=msg_filter_func,
|
||||
)
|
||||
return ProcessedSlackMessage(
|
||||
doc=doc, thread_or_message_ts=thread_or_message_ts, failure=None
|
||||
)
|
||||
return (doc, thread_ts, None)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing message {message['ts']}")
|
||||
return ProcessedSlackMessage(
|
||||
doc=None,
|
||||
thread_or_message_ts=thread_or_message_ts,
|
||||
failure=ConnectorFailure(
|
||||
return (
|
||||
None,
|
||||
thread_ts,
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=_build_doc_id(
|
||||
channel_id=channel["id"], thread_ts=thread_or_message_ts
|
||||
channel_id=channel["id"], thread_ts=(thread_ts or message["ts"])
|
||||
),
|
||||
document_link=get_message_link(message, client, channel["id"]),
|
||||
),
|
||||
@@ -493,9 +479,7 @@ def _process_message(
|
||||
)
|
||||
|
||||
|
||||
class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
|
||||
FAST_TIMEOUT = 1
|
||||
|
||||
class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
def __init__(
|
||||
self,
|
||||
channels: list[str] | None = None,
|
||||
@@ -503,14 +487,12 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
|
||||
# regexes, and will only index channels that fully match the regexes
|
||||
channel_regex_enabled: bool = False,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
num_threads: int = SLACK_NUM_THREADS,
|
||||
) -> None:
|
||||
self.channels = channels
|
||||
self.channel_regex_enabled = channel_regex_enabled
|
||||
self.batch_size = batch_size
|
||||
self.num_threads = num_threads
|
||||
self.client: WebClient | None = None
|
||||
self.fast_client: WebClient | None = None
|
||||
|
||||
# just used for efficiency
|
||||
self.text_cleaner: SlackTextCleaner | None = None
|
||||
self.user_cache: dict[str, BasicExpertInfo | None] = {}
|
||||
@@ -518,10 +500,6 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
bot_token = credentials["slack_bot_token"]
|
||||
self.client = WebClient(token=bot_token)
|
||||
# use for requests that must return quickly (e.g. realtime flows where user is waiting)
|
||||
self.fast_client = WebClient(
|
||||
token=bot_token, timeout=SlackConnector.FAST_TIMEOUT
|
||||
)
|
||||
self.text_cleaner = SlackTextCleaner(client=self.client)
|
||||
return None
|
||||
|
||||
@@ -545,8 +523,8 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: SlackCheckpoint,
|
||||
) -> CheckpointOutput[SlackCheckpoint]:
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> CheckpointOutput:
|
||||
"""Rough outline:
|
||||
|
||||
Step 1: Get all channels, yield back Checkpoint.
|
||||
@@ -562,36 +540,49 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
|
||||
if self.client is None or self.text_cleaner is None:
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
|
||||
checkpoint = cast(SlackCheckpoint, copy.deepcopy(checkpoint))
|
||||
checkpoint_content = cast(
|
||||
SlackCheckpointContent,
|
||||
(
|
||||
copy.deepcopy(checkpoint.checkpoint_content)
|
||||
or {
|
||||
"channel_ids": None,
|
||||
"channel_completion_map": {},
|
||||
"current_channel": None,
|
||||
"seen_thread_ts": [],
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
# if this is the very first time we've called this, need to
|
||||
# get all relevant channels and save them into the checkpoint
|
||||
if checkpoint.channel_ids is None:
|
||||
if checkpoint_content["channel_ids"] is None:
|
||||
raw_channels = get_channels(self.client)
|
||||
filtered_channels = filter_channels(
|
||||
raw_channels, self.channels, self.channel_regex_enabled
|
||||
)
|
||||
checkpoint.channel_ids = [c["id"] for c in filtered_channels]
|
||||
if len(filtered_channels) == 0:
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
|
||||
|
||||
checkpoint.current_channel = filtered_channels[0]
|
||||
checkpoint.has_more = True
|
||||
checkpoint_content["channel_ids"] = [c["id"] for c in filtered_channels]
|
||||
checkpoint_content["current_channel"] = filtered_channels[0]
|
||||
checkpoint = ConnectorCheckpoint(
|
||||
checkpoint_content=checkpoint_content, # type: ignore
|
||||
has_more=True,
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
final_channel_ids = checkpoint.channel_ids
|
||||
channel = checkpoint.current_channel
|
||||
final_channel_ids = checkpoint_content["channel_ids"]
|
||||
channel = checkpoint_content["current_channel"]
|
||||
if channel is None:
|
||||
raise ValueError("current_channel key not set in checkpoint")
|
||||
raise ValueError("current_channel key not found in checkpoint")
|
||||
|
||||
channel_id = channel["id"]
|
||||
if channel_id not in final_channel_ids:
|
||||
raise ValueError(f"Channel {channel_id} not found in checkpoint")
|
||||
|
||||
oldest = str(start) if start else None
|
||||
latest = checkpoint.channel_completion_map.get(channel_id, str(end))
|
||||
seen_thread_ts = set(checkpoint.seen_thread_ts)
|
||||
latest = checkpoint_content["channel_completion_map"].get(channel_id, str(end))
|
||||
seen_thread_ts = set(checkpoint_content["seen_thread_ts"])
|
||||
try:
|
||||
logger.debug(
|
||||
f"Getting messages for channel {channel} within range {oldest} - {latest}"
|
||||
@@ -602,8 +593,8 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
|
||||
new_latest = message_batch[-1]["ts"] if message_batch else latest
|
||||
|
||||
# Process messages in parallel using ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
|
||||
futures: list[Future[ProcessedSlackMessage]] = []
|
||||
with ThreadPoolExecutor(max_workers=8) as executor:
|
||||
futures: list[Future] = []
|
||||
for message in message_batch:
|
||||
# Capture the current context so that the thread gets the current tenant ID
|
||||
current_context = contextvars.copy_context()
|
||||
@@ -621,46 +612,46 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
|
||||
)
|
||||
|
||||
for future in as_completed(futures):
|
||||
processed_slack_message = future.result()
|
||||
doc = processed_slack_message.doc
|
||||
thread_or_message_ts = processed_slack_message.thread_or_message_ts
|
||||
failure = processed_slack_message.failure
|
||||
doc, thread_ts, failures = future.result()
|
||||
if doc:
|
||||
# handle race conditions here since this is single
|
||||
# threaded. Multi-threaded _process_message reads from this
|
||||
# but since this is single threaded, we won't run into simul
|
||||
# writes. At worst, we can duplicate a thread, which will be
|
||||
# deduped later on.
|
||||
if thread_or_message_ts not in seen_thread_ts:
|
||||
if thread_ts not in seen_thread_ts:
|
||||
yield doc
|
||||
|
||||
assert (
|
||||
thread_or_message_ts
|
||||
), "found non-None doc with None thread_or_message_ts"
|
||||
seen_thread_ts.add(thread_or_message_ts)
|
||||
elif failure:
|
||||
yield failure
|
||||
if thread_ts:
|
||||
seen_thread_ts.add(thread_ts)
|
||||
elif failures:
|
||||
for failure in failures:
|
||||
yield failure
|
||||
|
||||
checkpoint.seen_thread_ts = list(seen_thread_ts)
|
||||
checkpoint.channel_completion_map[channel["id"]] = new_latest
|
||||
checkpoint_content["seen_thread_ts"] = list(seen_thread_ts)
|
||||
checkpoint_content["channel_completion_map"][channel["id"]] = new_latest
|
||||
if has_more_in_channel:
|
||||
checkpoint.current_channel = channel
|
||||
checkpoint_content["current_channel"] = channel
|
||||
else:
|
||||
new_channel_id = next(
|
||||
(
|
||||
channel_id
|
||||
for channel_id in final_channel_ids
|
||||
if channel_id not in checkpoint.channel_completion_map
|
||||
if channel_id
|
||||
not in checkpoint_content["channel_completion_map"]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if new_channel_id:
|
||||
new_channel = _get_channel_by_id(self.client, new_channel_id)
|
||||
checkpoint.current_channel = new_channel
|
||||
checkpoint_content["current_channel"] = new_channel
|
||||
else:
|
||||
checkpoint.current_channel = None
|
||||
checkpoint_content["current_channel"] = None
|
||||
|
||||
checkpoint.has_more = checkpoint.current_channel is not None
|
||||
checkpoint = ConnectorCheckpoint(
|
||||
checkpoint_content=checkpoint_content, # type: ignore
|
||||
has_more=checkpoint_content["current_channel"] is not None,
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
except Exception as e:
|
||||
@@ -684,12 +675,12 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
|
||||
2. Ensure the bot has enough scope to list channels.
|
||||
3. Check that every channel specified in self.channels exists (only when regex is not enabled).
|
||||
"""
|
||||
if self.fast_client is None:
|
||||
if self.client is None:
|
||||
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
|
||||
|
||||
try:
|
||||
# 1) Validate connection to workspace
|
||||
auth_response = self.fast_client.auth_test()
|
||||
auth_response = self.client.auth_test()
|
||||
if not auth_response.get("ok", False):
|
||||
error_msg = auth_response.get(
|
||||
"error", "Unknown error from Slack auth_test"
|
||||
@@ -697,7 +688,7 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
|
||||
raise ConnectorValidationError(f"Failed Slack auth_test: {error_msg}")
|
||||
|
||||
# 2) Minimal test to confirm listing channels works
|
||||
test_resp = self.fast_client.conversations_list(
|
||||
test_resp = self.client.conversations_list(
|
||||
limit=1, types=["public_channel"]
|
||||
)
|
||||
if not test_resp.get("ok", False):
|
||||
@@ -715,41 +706,29 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
|
||||
)
|
||||
|
||||
# 3) If channels are specified and regex is not enabled, verify each is accessible
|
||||
# NOTE: removed this for now since it may be too slow for large workspaces which may
|
||||
# have some automations which create a lot of channels (100k+)
|
||||
if self.channels and not self.channel_regex_enabled:
|
||||
accessible_channels = get_channels(
|
||||
client=self.client,
|
||||
exclude_archived=True,
|
||||
get_public=True,
|
||||
get_private=True,
|
||||
)
|
||||
# For quick lookups by name or ID, build a map:
|
||||
accessible_channel_names = {ch["name"] for ch in accessible_channels}
|
||||
accessible_channel_ids = {ch["id"] for ch in accessible_channels}
|
||||
|
||||
# if self.channels and not self.channel_regex_enabled:
|
||||
# accessible_channels = get_channels(
|
||||
# client=self.fast_client,
|
||||
# exclude_archived=True,
|
||||
# get_public=True,
|
||||
# get_private=True,
|
||||
# )
|
||||
# # For quick lookups by name or ID, build a map:
|
||||
# accessible_channel_names = {ch["name"] for ch in accessible_channels}
|
||||
# accessible_channel_ids = {ch["id"] for ch in accessible_channels}
|
||||
|
||||
# for user_channel in self.channels:
|
||||
# if (
|
||||
# user_channel not in accessible_channel_names
|
||||
# and user_channel not in accessible_channel_ids
|
||||
# ):
|
||||
# raise ConnectorValidationError(
|
||||
# f"Channel '{user_channel}' not found or inaccessible in this workspace."
|
||||
# )
|
||||
for user_channel in self.channels:
|
||||
if (
|
||||
user_channel not in accessible_channel_names
|
||||
and user_channel not in accessible_channel_ids
|
||||
):
|
||||
raise ConnectorValidationError(
|
||||
f"Channel '{user_channel}' not found or inaccessible in this workspace."
|
||||
)
|
||||
|
||||
except SlackApiError as e:
|
||||
slack_error = e.response.get("error", "")
|
||||
if slack_error == "ratelimited":
|
||||
# Handle rate limiting specifically
|
||||
retry_after = int(e.response.headers.get("Retry-After", 1))
|
||||
logger.warning(
|
||||
f"Slack API rate limited during validation. Retry suggested after {retry_after} seconds. "
|
||||
"Proceeding with validation, but be aware that connector operations might be throttled."
|
||||
)
|
||||
# Continue validation without failing - the connector is likely valid but just rate limited
|
||||
return
|
||||
elif slack_error == "missing_scope":
|
||||
if slack_error == "missing_scope":
|
||||
raise InsufficientPermissionsError(
|
||||
"Slack bot token lacks the necessary scope to list/access channels. "
|
||||
"Please ensure your Slack app has 'channels:read' (and/or 'groups:read' for private channels)."
|
||||
@@ -772,20 +751,6 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
|
||||
f"Unexpected error during Slack settings validation: {e}"
|
||||
)
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> SlackCheckpoint:
|
||||
return SlackCheckpoint(
|
||||
channel_ids=None,
|
||||
channel_completion_map={},
|
||||
current_channel=None,
|
||||
seen_thread_ts=[],
|
||||
has_more=True,
|
||||
)
|
||||
|
||||
@override
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> SlackCheckpoint:
|
||||
return SlackCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
@@ -800,11 +765,9 @@ if __name__ == "__main__":
|
||||
current = time.time()
|
||||
one_day_ago = current - 24 * 60 * 60 # 1 day
|
||||
|
||||
checkpoint = connector.build_dummy_checkpoint()
|
||||
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
|
||||
gen = connector.load_from_checkpoint(
|
||||
one_day_ago, current, cast(SlackCheckpoint, checkpoint)
|
||||
)
|
||||
gen = connector.load_from_checkpoint(one_day_ago, current, checkpoint)
|
||||
try:
|
||||
for document_or_failure in gen:
|
||||
if isinstance(document_or_failure, Document):
|
||||
|
||||
@@ -1,32 +1,23 @@
|
||||
import copy
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from requests.exceptions import HTTPError
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
time_str_to_utc,
|
||||
)
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import ConnectorFailure
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
@@ -35,7 +26,6 @@ from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
MAX_PAGE_SIZE = 30 # Zendesk API maximum
|
||||
MAX_AUTHOR_MAP_SIZE = 50_000 # Reset author map cache if it gets too large
|
||||
_SLIM_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
@@ -63,22 +53,10 @@ class ZendeskClient:
|
||||
# Sleep for the duration indicated by the Retry-After header
|
||||
time.sleep(int(retry_after))
|
||||
|
||||
elif (
|
||||
response.status_code == 403
|
||||
and response.json().get("error") == "SupportProductInactive"
|
||||
):
|
||||
return response.json()
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
class ZendeskPageResponse(BaseModel):
|
||||
data: list[dict[str, Any]]
|
||||
meta: dict[str, Any]
|
||||
has_more: bool
|
||||
|
||||
|
||||
def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
|
||||
content_tags: dict[str, str] = {}
|
||||
params = {"page[size]": MAX_PAGE_SIZE}
|
||||
@@ -104,9 +82,11 @@ def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
|
||||
def _get_articles(
|
||||
client: ZendeskClient, start_time: int | None = None, page_size: int = MAX_PAGE_SIZE
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"}
|
||||
if start_time is not None:
|
||||
params["start_time"] = start_time
|
||||
params = (
|
||||
{"start_time": start_time, "page[size]": page_size}
|
||||
if start_time
|
||||
else {"page[size]": page_size}
|
||||
)
|
||||
|
||||
while True:
|
||||
data = client.make_request("help_center/articles", params)
|
||||
@@ -118,30 +98,10 @@ def _get_articles(
|
||||
params["page[after]"] = data["meta"]["after_cursor"]
|
||||
|
||||
|
||||
def _get_article_page(
|
||||
client: ZendeskClient,
|
||||
start_time: int | None = None,
|
||||
after_cursor: str | None = None,
|
||||
page_size: int = MAX_PAGE_SIZE,
|
||||
) -> ZendeskPageResponse:
|
||||
params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"}
|
||||
if start_time is not None:
|
||||
params["start_time"] = start_time
|
||||
if after_cursor is not None:
|
||||
params["page[after]"] = after_cursor
|
||||
|
||||
data = client.make_request("help_center/articles", params)
|
||||
return ZendeskPageResponse(
|
||||
data=data["articles"],
|
||||
meta=data["meta"],
|
||||
has_more=bool(data["meta"].get("has_more", False)),
|
||||
)
|
||||
|
||||
|
||||
def _get_tickets(
|
||||
client: ZendeskClient, start_time: int | None = None
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
params = {"start_time": start_time or 0}
|
||||
params = {"start_time": start_time} if start_time else {"start_time": 0}
|
||||
|
||||
while True:
|
||||
data = client.make_request("incremental/tickets.json", params)
|
||||
@@ -154,27 +114,6 @@ def _get_tickets(
|
||||
break
|
||||
|
||||
|
||||
# TODO: maybe these don't need to be their own functions?
|
||||
def _get_tickets_page(
|
||||
client: ZendeskClient, start_time: int | None = None
|
||||
) -> ZendeskPageResponse:
|
||||
params = {"start_time": start_time or 0}
|
||||
|
||||
# NOTE: for some reason zendesk doesn't seem to be respecting the start_time param
|
||||
# in my local testing with very few tickets. We'll look into it if this becomes an
|
||||
# issue in larger deployments
|
||||
data = client.make_request("incremental/tickets.json", params)
|
||||
if data.get("error") == "SupportProductInactive":
|
||||
raise ValueError(
|
||||
"Zendesk Support Product is not active for this account, No tickets to index"
|
||||
)
|
||||
return ZendeskPageResponse(
|
||||
data=data["tickets"],
|
||||
meta={"end_time": data["end_time"]},
|
||||
has_more=not bool(data.get("end_of_stream", False)),
|
||||
)
|
||||
|
||||
|
||||
def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
|
||||
# Skip fetching if author_id is invalid
|
||||
if not author_id or author_id == "-1":
|
||||
@@ -339,22 +278,13 @@ def _ticket_to_document(
|
||||
)
|
||||
|
||||
|
||||
class ZendeskConnectorCheckpoint(ConnectorCheckpoint):
|
||||
# We use cursor-based paginated retrieval for articles
|
||||
after_cursor_articles: str | None
|
||||
|
||||
# We use timestamp-based paginated retrieval for tickets
|
||||
next_start_time_tickets: int | None
|
||||
|
||||
cached_author_map: dict[str, BasicExpertInfo] | None
|
||||
cached_content_tags: dict[str, str] | None
|
||||
|
||||
|
||||
class ZendeskConnector(SlimConnector, CheckpointConnector[ZendeskConnectorCheckpoint]):
|
||||
class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
content_type: str = "articles",
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.content_type = content_type
|
||||
self.subdomain = ""
|
||||
# Fetch all tags ahead of time
|
||||
@@ -374,50 +304,33 @@ class ZendeskConnector(SlimConnector, CheckpointConnector[ZendeskConnectorCheckp
|
||||
)
|
||||
return None
|
||||
|
||||
@override
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ZendeskConnectorCheckpoint,
|
||||
) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self.poll_source(None, None)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.client is None:
|
||||
raise ZendeskCredentialsNotSetUpError()
|
||||
|
||||
if checkpoint.cached_content_tags is None:
|
||||
checkpoint.cached_content_tags = _get_content_tag_mapping(self.client)
|
||||
return checkpoint # save the content tags to the checkpoint
|
||||
self.content_tags = checkpoint.cached_content_tags
|
||||
self.content_tags = _get_content_tag_mapping(self.client)
|
||||
|
||||
if self.content_type == "articles":
|
||||
checkpoint = yield from self._retrieve_articles(start, end, checkpoint)
|
||||
return checkpoint
|
||||
yield from self._poll_articles(start)
|
||||
elif self.content_type == "tickets":
|
||||
checkpoint = yield from self._retrieve_tickets(start, end, checkpoint)
|
||||
return checkpoint
|
||||
yield from self._poll_tickets(start)
|
||||
else:
|
||||
raise ValueError(f"Unsupported content_type: {self.content_type}")
|
||||
|
||||
def _retrieve_articles(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
checkpoint: ZendeskConnectorCheckpoint,
|
||||
) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
|
||||
checkpoint = copy.deepcopy(checkpoint)
|
||||
# This one is built on the fly as there may be more many more authors than tags
|
||||
author_map: dict[str, BasicExpertInfo] = checkpoint.cached_author_map or {}
|
||||
after_cursor = checkpoint.after_cursor_articles
|
||||
doc_batch: list[Document] = []
|
||||
def _poll_articles(
|
||||
self, start: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
articles = _get_articles(self.client, start_time=int(start) if start else None)
|
||||
|
||||
response = _get_article_page(
|
||||
self.client,
|
||||
start_time=int(start) if start else None,
|
||||
after_cursor=after_cursor,
|
||||
)
|
||||
articles = response.data
|
||||
has_more = response.has_more
|
||||
after_cursor = response.meta.get("after_cursor")
|
||||
# This one is built on the fly as there may be more many more authors than tags
|
||||
author_map: dict[str, BasicExpertInfo] = {}
|
||||
|
||||
doc_batch = []
|
||||
for article in articles:
|
||||
if (
|
||||
article.get("body") is None
|
||||
@@ -429,109 +342,66 @@ class ZendeskConnector(SlimConnector, CheckpointConnector[ZendeskConnectorCheckp
|
||||
):
|
||||
continue
|
||||
|
||||
try:
|
||||
new_author_map, document = _article_to_document(
|
||||
article, self.content_tags, author_map, self.client
|
||||
)
|
||||
except Exception as e:
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=f"{article.get('id')}",
|
||||
document_link=article.get("html_url", ""),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
)
|
||||
continue
|
||||
|
||||
new_author_map, documents = _article_to_document(
|
||||
article, self.content_tags, author_map, self.client
|
||||
)
|
||||
if new_author_map:
|
||||
author_map.update(new_author_map)
|
||||
|
||||
doc_batch.append(document)
|
||||
doc_batch.append(documents)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch.clear()
|
||||
|
||||
if not has_more:
|
||||
yield from doc_batch
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
# Sometimes no documents are retrieved, but the cursor
|
||||
# is still updated so the connector makes progress.
|
||||
yield from doc_batch
|
||||
checkpoint.after_cursor_articles = after_cursor
|
||||
|
||||
last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None
|
||||
checkpoint.has_more = bool(
|
||||
end is None
|
||||
or last_doc_updated_at is None
|
||||
or last_doc_updated_at.timestamp() <= end
|
||||
)
|
||||
checkpoint.cached_author_map = (
|
||||
author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
def _retrieve_tickets(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
checkpoint: ZendeskConnectorCheckpoint,
|
||||
) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
|
||||
checkpoint = copy.deepcopy(checkpoint)
|
||||
def _poll_tickets(
|
||||
self, start: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.client is None:
|
||||
raise ZendeskCredentialsNotSetUpError()
|
||||
|
||||
author_map: dict[str, BasicExpertInfo] = checkpoint.cached_author_map or {}
|
||||
author_map: dict[str, BasicExpertInfo] = {}
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
next_start_time = int(checkpoint.next_start_time_tickets or start or 0)
|
||||
ticket_response = _get_tickets_page(self.client, start_time=next_start_time)
|
||||
tickets = ticket_response.data
|
||||
has_more = ticket_response.has_more
|
||||
next_start_time = ticket_response.meta["end_time"]
|
||||
for ticket in tickets:
|
||||
if ticket.get("status") == "deleted":
|
||||
continue
|
||||
|
||||
try:
|
||||
new_author_map, document = _ticket_to_document(
|
||||
ticket=ticket,
|
||||
author_map=author_map,
|
||||
client=self.client,
|
||||
default_subdomain=self.subdomain,
|
||||
)
|
||||
except Exception as e:
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=f"{ticket.get('id')}",
|
||||
document_link=ticket.get("url", ""),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
)
|
||||
continue
|
||||
|
||||
if new_author_map:
|
||||
author_map.update(new_author_map)
|
||||
|
||||
doc_batch.append(document)
|
||||
|
||||
if not has_more:
|
||||
yield from doc_batch
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
|
||||
|
||||
yield from doc_batch
|
||||
checkpoint.next_start_time_tickets = next_start_time
|
||||
last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None
|
||||
checkpoint.has_more = bool(
|
||||
end is None
|
||||
or last_doc_updated_at is None
|
||||
or last_doc_updated_at.timestamp() <= end
|
||||
ticket_generator = _get_tickets(
|
||||
self.client, start_time=int(start) if start else None
|
||||
)
|
||||
checkpoint.cached_author_map = (
|
||||
author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
while True:
|
||||
doc_batch = []
|
||||
for _ in range(self.batch_size):
|
||||
try:
|
||||
ticket = next(ticket_generator)
|
||||
|
||||
# Check if the ticket status is deleted and skip it if so
|
||||
if ticket.get("status") == "deleted":
|
||||
continue
|
||||
|
||||
new_author_map, documents = _ticket_to_document(
|
||||
ticket=ticket,
|
||||
author_map=author_map,
|
||||
client=self.client,
|
||||
default_subdomain=self.subdomain,
|
||||
)
|
||||
|
||||
if new_author_map:
|
||||
author_map.update(new_author_map)
|
||||
|
||||
doc_batch.append(documents)
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch.clear()
|
||||
|
||||
except StopIteration:
|
||||
# No more tickets to process
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
return
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
@@ -571,51 +441,10 @@ class ZendeskConnector(SlimConnector, CheckpointConnector[ZendeskConnectorCheckp
|
||||
if slim_doc_batch:
|
||||
yield slim_doc_batch
|
||||
|
||||
@override
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self.client is None:
|
||||
raise ZendeskCredentialsNotSetUpError()
|
||||
|
||||
try:
|
||||
_get_article_page(self.client, start_time=0)
|
||||
except HTTPError as e:
|
||||
# Check for HTTP status codes
|
||||
if e.response.status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Your Zendesk credentials appear to be invalid or expired (HTTP 401)."
|
||||
) from e
|
||||
elif e.response.status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"Your Zendesk token does not have sufficient permissions (HTTP 403)."
|
||||
) from e
|
||||
elif e.response.status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
"Zendesk resource not found (HTTP 404)."
|
||||
) from e
|
||||
else:
|
||||
raise ConnectorValidationError(
|
||||
f"Unexpected Zendesk error (status={e.response.status_code}): {e}"
|
||||
) from e
|
||||
|
||||
@override
|
||||
def validate_checkpoint_json(
|
||||
self, checkpoint_json: str
|
||||
) -> ZendeskConnectorCheckpoint:
|
||||
return ZendeskConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> ZendeskConnectorCheckpoint:
|
||||
return ZendeskConnectorCheckpoint(
|
||||
after_cursor_articles=None,
|
||||
next_start_time_tickets=None,
|
||||
cached_author_map=None,
|
||||
cached_content_tags=None,
|
||||
has_more=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
import time
|
||||
|
||||
connector = ZendeskConnector()
|
||||
connector.load_credentials(
|
||||
@@ -628,8 +457,6 @@ if __name__ == "__main__":
|
||||
|
||||
current = time.time()
|
||||
one_day_ago = current - 24 * 60 * 60 # 1 day
|
||||
document_batches = connector.load_from_checkpoint(
|
||||
one_day_ago, current, connector.build_dummy_checkpoint()
|
||||
)
|
||||
document_batches = connector.poll_source(one_day_ago, current)
|
||||
|
||||
print(next(document_batches))
|
||||
|
||||
@@ -1,4 +1,2 @@
|
||||
SLACK_BOT_PERSONA_PREFIX = "__slack_bot_persona__"
|
||||
DEFAULT_PERSONA_SLACK_CHANNEL_NAME = "DEFAULT_SLACK_CHANNEL"
|
||||
|
||||
CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX = "ConnectorValidationError:"
|
||||
|
||||
@@ -555,28 +555,6 @@ def delete_documents_by_connector_credential_pair__no_commit(
|
||||
db_session.execute(stmt)
|
||||
|
||||
|
||||
def delete_all_documents_by_connector_credential_pair__no_commit(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> None:
|
||||
"""Deletes all document by connector credential pair entries for a specific connector and credential.
|
||||
This is primarily used during connector deletion to ensure all references are removed
|
||||
before deleting the connector itself. This is crucial because connector_id is part of the
|
||||
primary key in DocumentByConnectorCredentialPair, and attempting to delete the Connector
|
||||
would otherwise try to set the foreign key to NULL, which fails for primary keys.
|
||||
|
||||
NOTE: Does not commit the transaction, this must be done by the caller.
|
||||
"""
|
||||
stmt = delete(DocumentByConnectorCredentialPair).where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
|
||||
|
||||
def delete_documents__no_commit(db_session: Session, document_ids: list[str]) -> None:
|
||||
db_session.execute(delete(DbDocument).where(DbDocument.id.in_(document_ids)))
|
||||
|
||||
|
||||
@@ -16,8 +16,8 @@ from onyx.db.models import User__UserGroup
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import FullLLMProvider
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ def upsert_cloud_embedding_provider(
|
||||
def upsert_llm_provider(
|
||||
llm_provider: LLMProviderUpsertRequest,
|
||||
db_session: Session,
|
||||
) -> LLMProviderView:
|
||||
) -> FullLLMProvider:
|
||||
existing_llm_provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
|
||||
)
|
||||
@@ -98,7 +98,7 @@ def upsert_llm_provider(
|
||||
group_ids=llm_provider.groups,
|
||||
db_session=db_session,
|
||||
)
|
||||
full_llm_provider = LLMProviderView.from_model(existing_llm_provider)
|
||||
full_llm_provider = FullLLMProvider.from_model(existing_llm_provider)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@@ -132,16 +132,6 @@ def fetch_existing_llm_providers(
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def fetch_existing_llm_provider(
|
||||
provider_name: str, db_session: Session
|
||||
) -> LLMProviderModel | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
|
||||
)
|
||||
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_existing_llm_providers_for_user(
|
||||
db_session: Session,
|
||||
user: User | None = None,
|
||||
@@ -187,7 +177,7 @@ def fetch_embedding_provider(
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_provider(db_session: Session) -> LLMProviderView | None:
|
||||
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.is_default_provider == True # noqa: E712
|
||||
@@ -195,10 +185,10 @@ def fetch_default_provider(db_session: Session) -> LLMProviderView | None:
|
||||
)
|
||||
if not provider_model:
|
||||
return None
|
||||
return LLMProviderView.from_model(provider_model)
|
||||
return FullLLMProvider.from_model(provider_model)
|
||||
|
||||
|
||||
def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None:
|
||||
def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.is_default_vision_provider == True # noqa: E712
|
||||
@@ -206,18 +196,16 @@ def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None
|
||||
)
|
||||
if not provider_model:
|
||||
return None
|
||||
return LLMProviderView.from_model(provider_model)
|
||||
return FullLLMProvider.from_model(provider_model)
|
||||
|
||||
|
||||
def fetch_llm_provider_view(
|
||||
db_session: Session, provider_name: str
|
||||
) -> LLMProviderView | None:
|
||||
def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
|
||||
)
|
||||
if not provider_model:
|
||||
return None
|
||||
return LLMProviderView.from_model(provider_model)
|
||||
return FullLLMProvider.from_model(provider_model)
|
||||
|
||||
|
||||
def remove_embedding_provider(
|
||||
|
||||
@@ -5,13 +5,13 @@ import re
|
||||
import zipfile
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from email.parser import Parser as EmailParser
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import IO
|
||||
from typing import NamedTuple
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import chardet
|
||||
import docx # type: ignore
|
||||
@@ -219,7 +219,7 @@ def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str:
|
||||
|
||||
def read_pdf_file(
|
||||
file: IO[Any], pdf_pass: str | None = None, extract_images: bool = False
|
||||
) -> tuple[str, dict[str, Any], Sequence[tuple[bytes, str]]]:
|
||||
) -> tuple[str, dict, list[tuple[bytes, str]]]:
|
||||
"""
|
||||
Returns the text, basic PDF metadata, and optionally extracted images.
|
||||
"""
|
||||
@@ -282,13 +282,13 @@ def read_pdf_file(
|
||||
|
||||
def docx_to_text_and_images(
|
||||
file: IO[Any],
|
||||
) -> tuple[str, Sequence[tuple[bytes, str]]]:
|
||||
) -> Tuple[str, List[Tuple[bytes, str]]]:
|
||||
"""
|
||||
Extract text from a docx. If embed_images=True, also extract inline images.
|
||||
Return (text_content, list_of_images).
|
||||
"""
|
||||
paragraphs = []
|
||||
embedded_images: list[tuple[bytes, str]] = []
|
||||
embedded_images: List[Tuple[bytes, str]] = []
|
||||
|
||||
doc = docx.Document(file)
|
||||
|
||||
@@ -426,22 +426,14 @@ def extract_file_text(
|
||||
return ""
|
||||
|
||||
|
||||
class ExtractionResult(NamedTuple):
|
||||
"""Structured result from text and image extraction from various file types."""
|
||||
|
||||
text_content: str
|
||||
embedded_images: Sequence[tuple[bytes, str]]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
def extract_text_and_images(
|
||||
file: IO[Any],
|
||||
file_name: str,
|
||||
pdf_pass: str | None = None,
|
||||
) -> ExtractionResult:
|
||||
) -> Tuple[str, List[Tuple[bytes, str]]]:
|
||||
"""
|
||||
Primary new function for the updated connector.
|
||||
Returns structured extraction result with text content, embedded images, and metadata.
|
||||
Returns (text_content, [(embedded_img_bytes, embedded_img_name), ...]).
|
||||
"""
|
||||
|
||||
try:
|
||||
@@ -450,9 +442,7 @@ def extract_text_and_images(
|
||||
# If the user doesn't want embedded images, unstructured is fine
|
||||
file.seek(0)
|
||||
text_content = unstructured_to_text(file, file_name)
|
||||
return ExtractionResult(
|
||||
text_content=text_content, embedded_images=[], metadata={}
|
||||
)
|
||||
return (text_content, [])
|
||||
|
||||
extension = get_file_ext(file_name)
|
||||
|
||||
@@ -460,76 +450,54 @@ def extract_text_and_images(
|
||||
if extension == ".docx":
|
||||
file.seek(0)
|
||||
text_content, images = docx_to_text_and_images(file)
|
||||
return ExtractionResult(
|
||||
text_content=text_content, embedded_images=images, metadata={}
|
||||
)
|
||||
return (text_content, images)
|
||||
|
||||
# PDF example: we do not show complicated PDF image extraction here
|
||||
# so we simply extract text for now and skip images.
|
||||
if extension == ".pdf":
|
||||
file.seek(0)
|
||||
text_content, pdf_metadata, images = read_pdf_file(
|
||||
file, pdf_pass, extract_images=True
|
||||
)
|
||||
return ExtractionResult(
|
||||
text_content=text_content, embedded_images=images, metadata=pdf_metadata
|
||||
)
|
||||
text_content, _, images = read_pdf_file(file, pdf_pass, extract_images=True)
|
||||
return (text_content, images)
|
||||
|
||||
# For PPTX, XLSX, EML, etc., we do not show embedded image logic here.
|
||||
# You can do something similar to docx if needed.
|
||||
if extension == ".pptx":
|
||||
file.seek(0)
|
||||
return ExtractionResult(
|
||||
text_content=pptx_to_text(file), embedded_images=[], metadata={}
|
||||
)
|
||||
return (pptx_to_text(file), [])
|
||||
|
||||
if extension == ".xlsx":
|
||||
file.seek(0)
|
||||
return ExtractionResult(
|
||||
text_content=xlsx_to_text(file), embedded_images=[], metadata={}
|
||||
)
|
||||
return (xlsx_to_text(file), [])
|
||||
|
||||
if extension == ".eml":
|
||||
file.seek(0)
|
||||
return ExtractionResult(
|
||||
text_content=eml_to_text(file), embedded_images=[], metadata={}
|
||||
)
|
||||
return (eml_to_text(file), [])
|
||||
|
||||
if extension == ".epub":
|
||||
file.seek(0)
|
||||
return ExtractionResult(
|
||||
text_content=epub_to_text(file), embedded_images=[], metadata={}
|
||||
)
|
||||
return (epub_to_text(file), [])
|
||||
|
||||
if extension == ".html":
|
||||
file.seek(0)
|
||||
return ExtractionResult(
|
||||
text_content=parse_html_page_basic(file),
|
||||
embedded_images=[],
|
||||
metadata={},
|
||||
)
|
||||
return (parse_html_page_basic(file), [])
|
||||
|
||||
# If we reach here and it's a recognized text extension
|
||||
if is_text_file_extension(file_name):
|
||||
file.seek(0)
|
||||
encoding = detect_encoding(file)
|
||||
text_content_raw, file_metadata = read_text_file(
|
||||
text_content_raw, _ = read_text_file(
|
||||
file, encoding=encoding, ignore_onyx_metadata=False
|
||||
)
|
||||
return ExtractionResult(
|
||||
text_content=text_content_raw,
|
||||
embedded_images=[],
|
||||
metadata=file_metadata,
|
||||
)
|
||||
return (text_content_raw, [])
|
||||
|
||||
# If it's an image file or something else, we do not parse embedded images from them
|
||||
# just return empty text
|
||||
file.seek(0)
|
||||
return ExtractionResult(text_content="", embedded_images=[], metadata={})
|
||||
return ("", [])
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to extract text/images from {file_name}: {e}")
|
||||
return ExtractionResult(text_content="", embedded_images=[], metadata={})
|
||||
return ("", [])
|
||||
|
||||
|
||||
def convert_docx_to_txt(
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import cast
|
||||
from typing import IO
|
||||
|
||||
import puremagic
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import FileOrigin
|
||||
@@ -14,7 +12,6 @@ from onyx.db.pg_file_store import delete_pgfilestore_by_file_name
|
||||
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
|
||||
from onyx.db.pg_file_store import read_lobj
|
||||
from onyx.db.pg_file_store import upsert_pgfilestore
|
||||
from onyx.utils.file import FileWithMimeType
|
||||
|
||||
|
||||
class FileStore(ABC):
|
||||
@@ -143,18 +140,6 @@ class PostgresBackedFileStore(FileStore):
|
||||
self.db_session.rollback()
|
||||
raise
|
||||
|
||||
def get_file_with_mime_type(self, filename: str) -> FileWithMimeType | None:
|
||||
mime_type: str = "application/octet-stream"
|
||||
try:
|
||||
file_io = self.read_file(filename, mode="b")
|
||||
file_content = file_io.read()
|
||||
matches = puremagic.magic_string(file_content)
|
||||
if matches:
|
||||
mime_type = cast(str, matches[0].mime_type)
|
||||
return FileWithMimeType(data=file_content, mime_type=mime_type)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def get_default_file_store(db_session: Session) -> FileStore:
|
||||
# The only supported file store now is the Postgres File Store
|
||||
|
||||
@@ -9,14 +9,14 @@ from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.llm import fetch_default_provider
|
||||
from onyx.db.llm import fetch_default_vision_provider
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_llm_provider_view
|
||||
from onyx.db.llm import fetch_provider
|
||||
from onyx.db.models import Persona
|
||||
from onyx.llm.chat_llm import DefaultMultiLLM
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import FullLLMProvider
|
||||
from onyx.utils.headers import build_llm_extra_headers
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
@@ -62,7 +62,7 @@ def get_llms_for_persona(
|
||||
)
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
llm_provider = fetch_llm_provider_view(db_session, provider_name)
|
||||
llm_provider = fetch_provider(db_session, provider_name)
|
||||
|
||||
if not llm_provider:
|
||||
raise ValueError("No LLM provider found")
|
||||
@@ -106,7 +106,7 @@ def get_default_llm_with_vision(
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
raise GenAIDisabledException()
|
||||
|
||||
def create_vision_llm(provider: LLMProviderView, model: str) -> LLM:
|
||||
def create_vision_llm(provider: FullLLMProvider, model: str) -> LLM:
|
||||
"""Helper to create an LLM if the provider supports image input."""
|
||||
return get_llm(
|
||||
provider=provider.provider,
|
||||
@@ -148,7 +148,7 @@ def get_default_llm_with_vision(
|
||||
provider.default_vision_model, provider.provider
|
||||
):
|
||||
return create_vision_llm(
|
||||
LLMProviderView.from_model(provider), provider.default_vision_model
|
||||
FullLLMProvider.from_model(provider), provider.default_vision_model
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@@ -56,9 +56,7 @@ BEDROCK_PROVIDER_NAME = "bedrock"
|
||||
# models
|
||||
BEDROCK_MODEL_NAMES = [
|
||||
model
|
||||
# bedrock_converse_models are just extensions of the bedrock_models, not sure why
|
||||
# litellm has split them into two lists :(
|
||||
for model in litellm.bedrock_models + litellm.bedrock_converse_models
|
||||
for model in litellm.bedrock_models
|
||||
if "/" not in model and "embed" not in model
|
||||
][::-1]
|
||||
|
||||
|
||||
@@ -170,8 +170,7 @@ def handle_message(
|
||||
respond_tag_only = channel_conf.get("respond_tag_only") or False
|
||||
respond_member_group_list = channel_conf.get("respond_member_group_list", None)
|
||||
|
||||
# NOTE: always respond in the DMs, as long the default config is not disabled.
|
||||
if respond_tag_only and not bypass_filters and not is_bot_dm:
|
||||
if respond_tag_only and not bypass_filters:
|
||||
logger.info(
|
||||
"Skipping message since the channel is configured such that "
|
||||
"OnyxBot only responds to tags"
|
||||
|
||||
@@ -41,7 +41,6 @@ from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import SlackBot
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.slack_bot import fetch_slack_bot
|
||||
from onyx.db.slack_bot import fetch_slack_bots
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
@@ -520,25 +519,6 @@ class SlackbotHandler:
|
||||
|
||||
def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool:
|
||||
"""True to keep going, False to ignore this Slack request"""
|
||||
|
||||
# skip cases where the bot is disabled in the web UI
|
||||
bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
slack_bot = fetch_slack_bot(
|
||||
db_session=db_session, slack_bot_id=client.slack_bot_id
|
||||
)
|
||||
if not slack_bot:
|
||||
logger.error(
|
||||
f"Slack bot with ID '{client.slack_bot_id}' not found. Skipping request."
|
||||
)
|
||||
return False
|
||||
|
||||
if not slack_bot.enabled:
|
||||
logger.info(
|
||||
f"Slack bot with ID '{client.slack_bot_id}' is disabled. Skipping request."
|
||||
)
|
||||
return False
|
||||
|
||||
if req.type == "events_api":
|
||||
# Verify channel is valid
|
||||
event = cast(dict[str, Any], req.payload.get("event", {}))
|
||||
|
||||
@@ -9,9 +9,9 @@ from sqlalchemy.orm import Session
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_existing_llm_providers_for_user
|
||||
from onyx.db.llm import fetch_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import update_default_vision_provider
|
||||
@@ -24,9 +24,9 @@ from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.llm.utils import test_llm
|
||||
from onyx.server.manage.llm.models import FullLLMProvider
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import TestLLMRequest
|
||||
from onyx.server.manage.llm.models import VisionProviderResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -49,27 +49,11 @@ def fetch_llm_options(
|
||||
def test_llm_configuration(
|
||||
test_llm_request: TestLLMRequest,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
"""Test regular llm and fast llm settings"""
|
||||
|
||||
# the api key is sanitized if we are testing a provider already in the system
|
||||
|
||||
test_api_key = test_llm_request.api_key
|
||||
if test_llm_request.name:
|
||||
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
|
||||
# as it turns out the name is not editable in the UI and other code also keys off name,
|
||||
# so we won't rock the boat just yet.
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
test_llm_request.name, db_session
|
||||
)
|
||||
if existing_provider:
|
||||
test_api_key = existing_provider.api_key
|
||||
|
||||
llm = get_llm(
|
||||
provider=test_llm_request.provider,
|
||||
model=test_llm_request.default_model_name,
|
||||
api_key=test_api_key,
|
||||
api_key=test_llm_request.api_key,
|
||||
api_base=test_llm_request.api_base,
|
||||
api_version=test_llm_request.api_version,
|
||||
custom_config=test_llm_request.custom_config,
|
||||
@@ -85,7 +69,7 @@ def test_llm_configuration(
|
||||
fast_llm = get_llm(
|
||||
provider=test_llm_request.provider,
|
||||
model=test_llm_request.fast_default_model_name,
|
||||
api_key=test_api_key,
|
||||
api_key=test_llm_request.api_key,
|
||||
api_base=test_llm_request.api_base,
|
||||
api_version=test_llm_request.api_version,
|
||||
custom_config=test_llm_request.custom_config,
|
||||
@@ -135,17 +119,11 @@ def test_default_provider(
|
||||
def list_llm_providers(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderView]:
|
||||
llm_provider_list: list[LLMProviderView] = []
|
||||
for llm_provider_model in fetch_existing_llm_providers(db_session):
|
||||
full_llm_provider = LLMProviderView.from_model(llm_provider_model)
|
||||
if full_llm_provider.api_key:
|
||||
full_llm_provider.api_key = (
|
||||
full_llm_provider.api_key[:4] + "****" + full_llm_provider.api_key[-4:]
|
||||
)
|
||||
llm_provider_list.append(full_llm_provider)
|
||||
|
||||
return llm_provider_list
|
||||
) -> list[FullLLMProvider]:
|
||||
return [
|
||||
FullLLMProvider.from_model(llm_provider_model)
|
||||
for llm_provider_model in fetch_existing_llm_providers(db_session)
|
||||
]
|
||||
|
||||
|
||||
@admin_router.put("/provider")
|
||||
@@ -157,11 +135,11 @@ def put_llm_provider(
|
||||
),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderView:
|
||||
) -> FullLLMProvider:
|
||||
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
|
||||
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
|
||||
# the result
|
||||
existing_provider = fetch_existing_llm_provider(llm_provider.name, db_session)
|
||||
existing_provider = fetch_provider(db_session, llm_provider.name)
|
||||
if existing_provider and is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@@ -183,11 +161,6 @@ def put_llm_provider(
|
||||
llm_provider.fast_default_model_name
|
||||
)
|
||||
|
||||
# the llm api key is sanitized when returned to clients, so the only time we
|
||||
# should get a real key is when it is explicitly changed
|
||||
if existing_provider and not llm_provider.api_key_changed:
|
||||
llm_provider.api_key = existing_provider.api_key
|
||||
|
||||
try:
|
||||
return upsert_llm_provider(
|
||||
llm_provider=llm_provider,
|
||||
@@ -261,7 +234,7 @@ def get_vision_capable_providers(
|
||||
|
||||
# Only include providers with at least one vision-capable model
|
||||
if vision_models:
|
||||
provider_dict = LLMProviderView.from_model(provider).model_dump()
|
||||
provider_dict = FullLLMProvider.from_model(provider).model_dump()
|
||||
provider_dict["vision_models"] = vision_models
|
||||
logger.info(
|
||||
f"Vision provider: {provider.provider} with models: {vision_models}"
|
||||
|
||||
@@ -12,7 +12,6 @@ if TYPE_CHECKING:
|
||||
|
||||
class TestLLMRequest(BaseModel):
|
||||
# provider level
|
||||
name: str | None = None
|
||||
provider: str
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
@@ -77,19 +76,16 @@ class LLMProviderUpsertRequest(LLMProvider):
|
||||
# should only be used for a "custom" provider
|
||||
# for default providers, the built-in model names are used
|
||||
model_names: list[str] | None = None
|
||||
api_key_changed: bool = False
|
||||
|
||||
|
||||
class LLMProviderView(LLMProvider):
|
||||
"""Stripped down representation of LLMProvider for display / limited access info only"""
|
||||
|
||||
class FullLLMProvider(LLMProvider):
|
||||
id: int
|
||||
is_default_provider: bool | None = None
|
||||
is_default_vision_provider: bool | None = None
|
||||
model_names: list[str]
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "LLMProviderView":
|
||||
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider":
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
@@ -115,7 +111,7 @@ class LLMProviderView(LLMProvider):
|
||||
)
|
||||
|
||||
|
||||
class VisionProviderResponse(LLMProviderView):
|
||||
class VisionProviderResponse(FullLLMProvider):
|
||||
"""Response model for vision providers endpoint, including vision-specific fields."""
|
||||
|
||||
vision_models: list[str]
|
||||
|
||||
@@ -32,14 +32,10 @@ from onyx.server.manage.models import SlackChannelConfig
|
||||
from onyx.server.manage.models import SlackChannelConfigCreationRequest
|
||||
from onyx.server.manage.validate_tokens import validate_app_token
|
||||
from onyx.server.manage.validate_tokens import validate_bot_token
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
router = APIRouter(prefix="/manage")
|
||||
|
||||
|
||||
@@ -261,6 +257,9 @@ def create_bot(
|
||||
# Create a default Slack channel config
|
||||
default_channel_config = ChannelConfig(
|
||||
channel_name=None,
|
||||
respond_member_group_list=[],
|
||||
answer_filters=[],
|
||||
follow_up_tags=[],
|
||||
respond_tag_only=True,
|
||||
)
|
||||
insert_slack_channel_config(
|
||||
@@ -368,9 +367,7 @@ def get_all_channels_from_slack_api(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> list[SlackChannel]:
|
||||
"""
|
||||
Fetches all channels in the Slack workspace using the conversations_list API.
|
||||
This includes both public and private channels that are visible to the app,
|
||||
not just the ones the bot is a member of.
|
||||
Fetches channels the bot is a member of from the Slack API.
|
||||
Handles pagination with a limit to avoid excessive API calls.
|
||||
"""
|
||||
tokens = fetch_slack_bot_tokens(db_session, bot_id)
|
||||
@@ -379,26 +376,26 @@ def get_all_channels_from_slack_api(
|
||||
status_code=404, detail="Bot token not found for the given bot ID"
|
||||
)
|
||||
|
||||
client = WebClient(token=tokens["bot_token"], timeout=1)
|
||||
client = WebClient(token=tokens["bot_token"])
|
||||
all_channels = []
|
||||
next_cursor = None
|
||||
current_page = 0
|
||||
|
||||
try:
|
||||
# Use conversations_list to get all channels in the workspace (including ones the bot is not a member of)
|
||||
# Use users_conversations with limited pagination
|
||||
while current_page < MAX_SLACK_PAGES:
|
||||
current_page += 1
|
||||
|
||||
# Make API call with cursor if we have one
|
||||
if next_cursor:
|
||||
response = client.conversations_list(
|
||||
response = client.users_conversations(
|
||||
types="public_channel,private_channel",
|
||||
exclude_archived=True,
|
||||
cursor=next_cursor,
|
||||
limit=SLACK_API_CHANNELS_PER_PAGE,
|
||||
)
|
||||
else:
|
||||
response = client.conversations_list(
|
||||
response = client.users_conversations(
|
||||
types="public_channel,private_channel",
|
||||
exclude_archived=True,
|
||||
limit=SLACK_API_CHANNELS_PER_PAGE,
|
||||
@@ -434,7 +431,6 @@ def get_all_channels_from_slack_api(
|
||||
|
||||
except SlackApiError as e:
|
||||
# Handle rate limiting or other API errors
|
||||
logger.exception("Error fetching channels from Slack API")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error fetching channels from Slack API: {str(e)}",
|
||||
|
||||
@@ -351,11 +351,9 @@ def remove_invited_user(
|
||||
user_emails = get_invited_users()
|
||||
remaining_users = [user for user in user_emails if user != user_email.user_email]
|
||||
|
||||
if MULTI_TENANT:
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
|
||||
)([user_email.user_email], tenant_id)
|
||||
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
|
||||
)([user_email.user_email], tenant_id)
|
||||
number_of_invited_users = write_invited_users(remaining_users)
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
import io
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from onyx.configs.constants import ONYX_EMAILABLE_LOGO_MAX_DIM
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.file_store.file_store import PostgresBackedFileStore
|
||||
from onyx.utils.file import FileWithMimeType
|
||||
from onyx.utils.file import OnyxStaticFileManager
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_ee_implementation_or_noop,
|
||||
)
|
||||
|
||||
|
||||
class OnyxRuntime:
|
||||
"""Used by the application to get the final runtime value of a setting.
|
||||
|
||||
Rationale: Settings and overrides may be persisted in multiple places, including the
|
||||
DB, Redis, env vars, and default constants, etc. The logic to present a final
|
||||
setting to the application should be centralized and in one place.
|
||||
|
||||
Example: To get the logo for the application, one must check the DB for an override,
|
||||
use the override if present, fall back to the filesystem if not present, and worry
|
||||
about enterprise or not enterprise.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _get_with_static_fallback(
|
||||
db_filename: str | None, static_filename: str
|
||||
) -> FileWithMimeType:
|
||||
onyx_file: FileWithMimeType | None = None
|
||||
|
||||
if db_filename:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
file_store = PostgresBackedFileStore(db_session)
|
||||
onyx_file = file_store.get_file_with_mime_type(db_filename)
|
||||
|
||||
if not onyx_file:
|
||||
onyx_file = OnyxStaticFileManager.get_static(static_filename)
|
||||
|
||||
if not onyx_file:
|
||||
raise RuntimeError(
|
||||
f"Resource not found: db={db_filename} static={static_filename}"
|
||||
)
|
||||
|
||||
return onyx_file
|
||||
|
||||
@staticmethod
|
||||
def get_logo() -> FileWithMimeType:
|
||||
STATIC_FILENAME = "static/images/logo.png"
|
||||
|
||||
db_filename: str | None = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.enterprise_settings.store", "get_logo_filename", None
|
||||
)
|
||||
|
||||
return OnyxRuntime._get_with_static_fallback(db_filename, STATIC_FILENAME)
|
||||
|
||||
@staticmethod
|
||||
def get_emailable_logo() -> FileWithMimeType:
|
||||
onyx_file = OnyxRuntime.get_logo()
|
||||
|
||||
# check dimensions and resize downwards if necessary or if not PNG
|
||||
image = Image.open(io.BytesIO(onyx_file.data))
|
||||
if (
|
||||
image.size[0] > ONYX_EMAILABLE_LOGO_MAX_DIM
|
||||
or image.size[1] > ONYX_EMAILABLE_LOGO_MAX_DIM
|
||||
or image.format != "PNG"
|
||||
):
|
||||
image.thumbnail(
|
||||
(ONYX_EMAILABLE_LOGO_MAX_DIM, ONYX_EMAILABLE_LOGO_MAX_DIM),
|
||||
Image.LANCZOS,
|
||||
) # maintains aspect ratio
|
||||
output_buffer = io.BytesIO()
|
||||
image.save(output_buffer, format="PNG")
|
||||
onyx_file = FileWithMimeType(
|
||||
data=output_buffer.getvalue(), mime_type="image/png"
|
||||
)
|
||||
|
||||
return onyx_file
|
||||
|
||||
@staticmethod
|
||||
def get_logotype() -> FileWithMimeType:
|
||||
STATIC_FILENAME = "static/images/logotype.png"
|
||||
|
||||
db_filename: str | None = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.enterprise_settings.store", "get_logotype_filename", None
|
||||
)
|
||||
|
||||
return OnyxRuntime._get_with_static_fallback(db_filename, STATIC_FILENAME)
|
||||
@@ -307,7 +307,6 @@ def setup_postgres(db_session: Session) -> None:
|
||||
groups=[],
|
||||
display_model_names=OPEN_AI_MODEL_NAMES,
|
||||
model_names=OPEN_AI_MODEL_NAMES,
|
||||
api_key_changed=True,
|
||||
)
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider=model_req, db_session=db_session
|
||||
@@ -324,7 +323,7 @@ def update_default_multipass_indexing(db_session: Session) -> None:
|
||||
logger.info(
|
||||
"No existing docs or connectors found. Checking GPU availability for multipass indexing."
|
||||
)
|
||||
gpu_available = gpu_status_request(indexing=True)
|
||||
gpu_available = gpu_status_request()
|
||||
logger.info(f"GPU available: {gpu_available}")
|
||||
|
||||
current_settings = get_current_search_settings(db_session)
|
||||
|
||||
@@ -21,6 +21,7 @@ def build_tool_message(
|
||||
)
|
||||
|
||||
|
||||
# TODO: does this NEED to be BaseModel__v1?
|
||||
class ToolCallSummary(BaseModel):
|
||||
tool_call_request: AIMessage
|
||||
tool_call_result: ToolMessage
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
import puremagic
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class FileWithMimeType(BaseModel):
|
||||
data: bytes
|
||||
mime_type: str
|
||||
|
||||
|
||||
class OnyxStaticFileManager:
|
||||
"""Retrieve static resources with this class. Currently, these should all be located
|
||||
in the static directory ... e.g. static/images/logo.png"""
|
||||
|
||||
@staticmethod
|
||||
def get_static(filename: str) -> FileWithMimeType | None:
|
||||
try:
|
||||
mime_type: str = "application/octet-stream"
|
||||
with open(filename, "rb") as f:
|
||||
file_content = f.read()
|
||||
matches = puremagic.magic_string(file_content)
|
||||
if matches:
|
||||
mime_type = cast(str, matches[0].mime_type)
|
||||
except (OSError, FileNotFoundError, PermissionError) as e:
|
||||
logger.error(f"Failed to read file {filename}: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected exception reading file {filename}: {e}")
|
||||
return None
|
||||
|
||||
return FileWithMimeType(data=file_content, mime_type=mime_type)
|
||||
@@ -1,5 +1,3 @@
|
||||
from functools import lru_cache
|
||||
|
||||
import requests
|
||||
from retry import retry
|
||||
|
||||
@@ -12,7 +10,8 @@ from shared_configs.configs import MODEL_SERVER_PORT
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_gpu_status_from_model_server(indexing: bool) -> bool:
|
||||
@retry(tries=5, delay=5)
|
||||
def gpu_status_request(indexing: bool = True) -> bool:
|
||||
if indexing:
|
||||
model_server_url = f"{INDEXING_MODEL_SERVER_HOST}:{INDEXING_MODEL_SERVER_PORT}"
|
||||
else:
|
||||
@@ -29,14 +28,3 @@ def _get_gpu_status_from_model_server(indexing: bool) -> bool:
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Error: Unable to fetch GPU status. Error: {str(e)}")
|
||||
raise # Re-raise exception to trigger a retry
|
||||
|
||||
|
||||
@retry(tries=5, delay=5)
|
||||
def gpu_status_request(indexing: bool) -> bool:
|
||||
return _get_gpu_status_from_model_server(indexing)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def fast_gpu_status_request(indexing: bool) -> bool:
|
||||
"""For use in sync flows, where we don't want to retry / we want to cache this."""
|
||||
return gpu_status_request(indexing=indexing)
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from functools import lru_cache
|
||||
from typing import TypeVar
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def lazy_eval(func: Callable[[], R]) -> Callable[[], R]:
|
||||
@lru_cache(maxsize=1)
|
||||
def lazy_func() -> R:
|
||||
return func()
|
||||
|
||||
return lazy_func
|
||||
@@ -1,148 +1,18 @@
|
||||
import collections.abc
|
||||
import contextvars
|
||||
import copy
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import MutableMapping
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import FIRST_COMPLETED
|
||||
from concurrent.futures import Future
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import wait
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import overload
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import core_schema
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
R = TypeVar("R")
|
||||
KT = TypeVar("KT") # Key type
|
||||
VT = TypeVar("VT") # Value type
|
||||
_T = TypeVar("_T") # Default type
|
||||
|
||||
|
||||
class ThreadSafeDict(MutableMapping[KT, VT]):
|
||||
"""
|
||||
A thread-safe dictionary implementation that uses a lock to ensure thread safety.
|
||||
Implements the MutableMapping interface to provide a complete dictionary-like interface.
|
||||
|
||||
Example usage:
|
||||
# Create a thread-safe dictionary
|
||||
safe_dict: ThreadSafeDict[str, int] = ThreadSafeDict()
|
||||
|
||||
# Basic operations (atomic)
|
||||
safe_dict["key"] = 1
|
||||
value = safe_dict["key"]
|
||||
del safe_dict["key"]
|
||||
|
||||
# Bulk operations (atomic)
|
||||
safe_dict.update({"key1": 1, "key2": 2})
|
||||
"""
|
||||
|
||||
def __init__(self, input_dict: dict[KT, VT] | None = None) -> None:
|
||||
self._dict: dict[KT, VT] = input_dict or {}
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def __getitem__(self, key: KT) -> VT:
|
||||
with self.lock:
|
||||
return self._dict[key]
|
||||
|
||||
def __setitem__(self, key: KT, value: VT) -> None:
|
||||
with self.lock:
|
||||
self._dict[key] = value
|
||||
|
||||
def __delitem__(self, key: KT) -> None:
|
||||
with self.lock:
|
||||
del self._dict[key]
|
||||
|
||||
def __iter__(self) -> Iterator[KT]:
|
||||
# Return a snapshot of keys to avoid potential modification during iteration
|
||||
with self.lock:
|
||||
return iter(list(self._dict.keys()))
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self.lock:
|
||||
return len(self._dict)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, source_type: Any, handler: GetCoreSchemaHandler
|
||||
) -> core_schema.CoreSchema:
|
||||
return core_schema.no_info_after_validator_function(
|
||||
cls.validate, handler(dict[KT, VT])
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate(cls, v: Any) -> "ThreadSafeDict[KT, VT]":
|
||||
if isinstance(v, dict):
|
||||
return ThreadSafeDict(v)
|
||||
return v
|
||||
|
||||
def __deepcopy__(self, memo: Any) -> "ThreadSafeDict[KT, VT]":
|
||||
return ThreadSafeDict(copy.deepcopy(self._dict))
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove all items from the dictionary atomically."""
|
||||
with self.lock:
|
||||
self._dict.clear()
|
||||
|
||||
def copy(self) -> dict[KT, VT]:
|
||||
"""Return a shallow copy of the dictionary atomically."""
|
||||
with self.lock:
|
||||
return self._dict.copy()
|
||||
|
||||
@overload
|
||||
def get(self, key: KT) -> VT | None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get(self, key: KT, default: VT | _T) -> VT | _T:
|
||||
...
|
||||
|
||||
def get(self, key: KT, default: Any = None) -> Any:
|
||||
"""Get a value with a default, atomically."""
|
||||
with self.lock:
|
||||
return self._dict.get(key, default)
|
||||
|
||||
def pop(self, key: KT, default: Any = None) -> Any:
|
||||
"""Remove and return a value with optional default, atomically."""
|
||||
with self.lock:
|
||||
if default is None:
|
||||
return self._dict.pop(key)
|
||||
return self._dict.pop(key, default)
|
||||
|
||||
def setdefault(self, key: KT, default: VT) -> VT:
|
||||
"""Set a default value if key is missing, atomically."""
|
||||
with self.lock:
|
||||
return self._dict.setdefault(key, default)
|
||||
|
||||
def update(self, *args: Any, **kwargs: VT) -> None:
|
||||
"""Update the dictionary atomically from another mapping or from kwargs."""
|
||||
with self.lock:
|
||||
self._dict.update(*args, **kwargs)
|
||||
|
||||
def items(self) -> collections.abc.ItemsView[KT, VT]:
|
||||
"""Return a view of (key, value) pairs atomically."""
|
||||
with self.lock:
|
||||
return collections.abc.ItemsView(self)
|
||||
|
||||
def keys(self) -> collections.abc.KeysView[KT]:
|
||||
"""Return a view of keys atomically."""
|
||||
with self.lock:
|
||||
return collections.abc.KeysView(self)
|
||||
|
||||
def values(self) -> collections.abc.ValuesView[VT]:
|
||||
"""Return a view of values atomically."""
|
||||
with self.lock:
|
||||
return collections.abc.ValuesView(self)
|
||||
|
||||
|
||||
def run_functions_tuples_in_parallel(
|
||||
@@ -320,27 +190,3 @@ def wait_on_background(task: TimeoutThread[R]) -> R:
|
||||
raise task.exception
|
||||
|
||||
return task.result
|
||||
|
||||
|
||||
def _next_or_none(ind: int, g: Iterator[R]) -> tuple[int, R | None]:
|
||||
return ind, next(g, None)
|
||||
|
||||
|
||||
def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R]:
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_index: dict[Future[tuple[int, R | None]], int] = {
|
||||
executor.submit(_next_or_none, i, g): i for i, g in enumerate(gens)
|
||||
}
|
||||
|
||||
next_ind = len(gens)
|
||||
while future_to_index:
|
||||
done, _ = wait(future_to_index, return_when=FIRST_COMPLETED)
|
||||
for future in done:
|
||||
ind, result = future.result()
|
||||
if result is not None:
|
||||
yield result
|
||||
future_to_index[
|
||||
executor.submit(_next_or_none, ind, gens[ind])
|
||||
] = next_ind
|
||||
next_ind += 1
|
||||
del future_to_index[future]
|
||||
|
||||
@@ -38,7 +38,7 @@ langchainhub==0.1.21
|
||||
langgraph==0.2.72
|
||||
langgraph-checkpoint==2.0.13
|
||||
langgraph-sdk==0.1.44
|
||||
litellm==1.63.8
|
||||
litellm==1.61.16
|
||||
lxml==5.3.0
|
||||
lxml_html_clean==0.2.2
|
||||
llama-index==0.9.45
|
||||
@@ -47,16 +47,15 @@ msal==1.28.0
|
||||
nltk==3.8.1
|
||||
Office365-REST-Python-Client==2.5.9
|
||||
oauthlib==3.2.2
|
||||
openai==1.66.3
|
||||
openai==1.61.0
|
||||
openpyxl==3.1.2
|
||||
playwright==1.41.2
|
||||
psutil==5.9.5
|
||||
psycopg2-binary==2.9.9
|
||||
puremagic==1.28
|
||||
pyairtable==3.0.1
|
||||
pycryptodome==3.19.1
|
||||
pydantic==2.8.2
|
||||
PyGithub==2.5.0
|
||||
PyGithub==1.58.2
|
||||
python-dateutil==2.8.2
|
||||
python-gitlab==3.9.0
|
||||
python-pptx==0.6.23
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# USAGE: nohup ./docker_memory_tracking.sh &
|
||||
|
||||
# Set default output file or use the provided argument
|
||||
OUTPUT_FILE="./docker_stats.log"
|
||||
if [ $# -ge 1 ]; then
|
||||
OUTPUT_FILE="$1"
|
||||
fi
|
||||
|
||||
INTERVAL_SECONDS=600 # 10 minutes
|
||||
|
||||
# Create the output file if it doesn't exist, or append to it if it does
|
||||
touch "$OUTPUT_FILE"
|
||||
|
||||
echo "Docker stats will be collected every 10 minutes and saved to $OUTPUT_FILE"
|
||||
echo "Press Ctrl+C to stop the script"
|
||||
|
||||
# Function to handle script termination
|
||||
cleanup() {
|
||||
echo -e "\nStopping docker stats collection"
|
||||
exit 0
|
||||
}
|
||||
|
||||
# Set up trap for clean exit
|
||||
trap cleanup SIGINT SIGTERM
|
||||
|
||||
# Main loop
|
||||
while true; do
|
||||
# Add timestamp
|
||||
echo -e "\n--- Docker Stats: $(date) ---" >> "$OUTPUT_FILE"
|
||||
|
||||
# Run docker stats for a single snapshot (--no-stream ensures it runs once)
|
||||
docker stats --no-stream --all >> "$OUTPUT_FILE"
|
||||
|
||||
# Wait for the next interval
|
||||
echo "Stats collected at $(date). Next collection in 10 minutes."
|
||||
sleep $INTERVAL_SECONDS
|
||||
done
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 6.6 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 44 KiB |
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -8,16 +7,15 @@ import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.confluence.utils import AttachmentProcessingResult
|
||||
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
|
||||
from onyx.connectors.models import Document
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def confluence_connector(space: str) -> ConfluenceConnector:
|
||||
def confluence_connector() -> ConfluenceConnector:
|
||||
connector = ConfluenceConnector(
|
||||
wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"],
|
||||
space=space,
|
||||
space=os.environ["CONFLUENCE_TEST_SPACE"],
|
||||
is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == "true",
|
||||
page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""),
|
||||
)
|
||||
@@ -34,15 +32,14 @@ def confluence_connector(space: str) -> ConfluenceConnector:
|
||||
return connector
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", [os.environ["CONFLUENCE_TEST_SPACE"]])
|
||||
@patch(
|
||||
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
|
||||
return_value=None,
|
||||
)
|
||||
@pytest.mark.skip(reason="Skipping this test")
|
||||
def test_confluence_connector_basic(
|
||||
mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
|
||||
) -> None:
|
||||
confluence_connector.set_allow_images(False)
|
||||
doc_batch_generator = confluence_connector.poll_source(0, time.time())
|
||||
|
||||
doc_batch = next(doc_batch_generator)
|
||||
@@ -53,14 +50,15 @@ def test_confluence_connector_basic(
|
||||
|
||||
page_within_a_page_doc: Document | None = None
|
||||
page_doc: Document | None = None
|
||||
txt_doc: Document | None = None
|
||||
|
||||
for doc in doc_batch:
|
||||
if doc.semantic_identifier == "DailyConnectorTestSpace Home":
|
||||
page_doc = doc
|
||||
elif ".txt" in doc.semantic_identifier:
|
||||
txt_doc = doc
|
||||
elif doc.semantic_identifier == "Page Within A Page":
|
||||
page_within_a_page_doc = doc
|
||||
else:
|
||||
pass
|
||||
|
||||
assert page_within_a_page_doc is not None
|
||||
assert page_within_a_page_doc.semantic_identifier == "Page Within A Page"
|
||||
@@ -81,7 +79,7 @@ def test_confluence_connector_basic(
|
||||
assert page_doc.metadata["labels"] == ["testlabel"]
|
||||
assert page_doc.primary_owners
|
||||
assert page_doc.primary_owners[0].email == "hagen@danswer.ai"
|
||||
assert len(page_doc.sections) == 2 # page text + attachment text
|
||||
assert len(page_doc.sections) == 1
|
||||
|
||||
page_section = page_doc.sections[0]
|
||||
assert page_section.text == "test123 " + page_within_a_page_text
|
||||
@@ -90,65 +88,13 @@ def test_confluence_connector_basic(
|
||||
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/overview"
|
||||
)
|
||||
|
||||
text_attachment_section = page_doc.sections[1]
|
||||
assert text_attachment_section.text == "small"
|
||||
assert text_attachment_section.link
|
||||
assert text_attachment_section.link.endswith("small-file.txt")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", ["MI"])
|
||||
@patch(
|
||||
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
|
||||
return_value=None,
|
||||
)
|
||||
def test_confluence_connector_skip_images(
|
||||
mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
|
||||
) -> None:
|
||||
confluence_connector.set_allow_images(False)
|
||||
doc_batch_generator = confluence_connector.poll_source(0, time.time())
|
||||
|
||||
doc_batch = next(doc_batch_generator)
|
||||
with pytest.raises(StopIteration):
|
||||
next(doc_batch_generator)
|
||||
|
||||
assert len(doc_batch) == 8
|
||||
assert sum(len(doc.sections) for doc in doc_batch) == 8
|
||||
|
||||
|
||||
def mock_process_image_attachment(
|
||||
*args: Any, **kwargs: Any
|
||||
) -> AttachmentProcessingResult:
|
||||
"""We need this mock to bypass DB access happening in the connector. Which shouldn't
|
||||
be done as a rule to begin with, but life is not perfect. Fix it later"""
|
||||
|
||||
return AttachmentProcessingResult(
|
||||
text="Hi_text",
|
||||
file_name="Hi_filename",
|
||||
error=None,
|
||||
assert txt_doc is not None
|
||||
assert txt_doc.semantic_identifier == "small-file.txt"
|
||||
assert len(txt_doc.sections) == 1
|
||||
assert txt_doc.sections[0].text == "small"
|
||||
assert txt_doc.primary_owners
|
||||
assert txt_doc.primary_owners[0].email == "chris@onyx.app"
|
||||
assert (
|
||||
txt_doc.sections[0].link
|
||||
== "https://danswerai.atlassian.net/wiki/pages/viewpageattachments.action?pageId=52494430&preview=%2F52494430%2F52527123%2Fsmall-file.txt"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("space", ["MI"])
|
||||
@patch(
|
||||
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
|
||||
return_value=None,
|
||||
)
|
||||
@patch(
|
||||
"onyx.connectors.confluence.utils._process_image_attachment",
|
||||
side_effect=mock_process_image_attachment,
|
||||
)
|
||||
def test_confluence_connector_allow_images(
|
||||
mock_get_api_key: MagicMock,
|
||||
mock_process_image_attachment: MagicMock,
|
||||
confluence_connector: ConfluenceConnector,
|
||||
) -> None:
|
||||
confluence_connector.set_allow_images(True)
|
||||
|
||||
doc_batch_generator = confluence_connector.poll_source(0, time.time())
|
||||
|
||||
doc_batch = next(doc_batch_generator)
|
||||
with pytest.raises(StopIteration):
|
||||
next(doc_batch_generator)
|
||||
|
||||
assert len(doc_batch) == 8
|
||||
assert sum(len(doc.sections) for doc in doc_batch) == 12
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_connector() -> GithubConnector:
|
||||
connector = GithubConnector(
|
||||
repo_owner="onyx-dot-app",
|
||||
repositories="documentation",
|
||||
include_prs=True,
|
||||
include_issues=True,
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"],
|
||||
}
|
||||
)
|
||||
return connector
|
||||
|
||||
|
||||
def test_github_connector_basic(github_connector: GithubConnector) -> None:
|
||||
docs = load_all_docs_from_checkpoint_connector(
|
||||
connector=github_connector,
|
||||
start=0,
|
||||
end=time.time(),
|
||||
)
|
||||
assert len(docs) > 0 # We expect at least one PR to exist
|
||||
|
||||
# Test the first document's structure
|
||||
doc = docs[0]
|
||||
|
||||
# Verify basic document properties
|
||||
assert doc.source == DocumentSource.GITHUB
|
||||
assert doc.secondary_owners is None
|
||||
assert doc.from_ingestion_api is False
|
||||
assert doc.additional_info is None
|
||||
|
||||
# Verify GitHub-specific properties
|
||||
assert "github.com" in doc.id # Should be a GitHub URL
|
||||
assert doc.metadata is not None
|
||||
assert "state" in doc.metadata
|
||||
assert "merged" in doc.metadata
|
||||
|
||||
# Verify sections
|
||||
assert len(doc.sections) == 1
|
||||
section = doc.sections[0]
|
||||
assert section.link == doc.id # Section link should match document ID
|
||||
assert isinstance(section.text, str) # Should have some text content
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import resource
|
||||
from collections.abc import Callable
|
||||
|
||||
import pytest
|
||||
@@ -137,22 +136,3 @@ def google_drive_service_acct_connector_factory() -> (
|
||||
return connector
|
||||
|
||||
return _connector_factory
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_resource_limits() -> None:
|
||||
# the google sdk is aggressive about using up file descriptors and
|
||||
# macos is stingy ... these tests will fail randomly unless the descriptor limit is raised
|
||||
RLIMIT_MINIMUM = 2048
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
desired_soft = min(RLIMIT_MINIMUM, hard) # Pick your target here
|
||||
|
||||
print(f"Open file limit: soft={soft} hard={hard} soft_required={RLIMIT_MINIMUM}")
|
||||
|
||||
if soft < desired_soft:
|
||||
print(f"Raising open file limit: {soft} -> {desired_soft}")
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (desired_soft, hard))
|
||||
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
print(f"New open file limit: soft={soft} hard={hard}")
|
||||
return
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
|
||||
|
||||
ALL_FILES = list(range(0, 60))
|
||||
SHARED_DRIVE_FILES = list(range(20, 25))
|
||||
@@ -24,7 +21,6 @@ FOLDER_2_FILE_IDS = list(range(45, 50))
|
||||
FOLDER_2_1_FILE_IDS = list(range(50, 55))
|
||||
FOLDER_2_2_FILE_IDS = list(range(55, 60))
|
||||
SECTIONS_FILE_IDS = [61]
|
||||
FOLDER_3_FILE_IDS = list(range(62, 65))
|
||||
|
||||
PUBLIC_FOLDER_RANGE = FOLDER_1_2_FILE_IDS
|
||||
PUBLIC_FILE_IDS = list(range(55, 57))
|
||||
@@ -58,8 +54,6 @@ SECTIONS_FOLDER_URL = (
|
||||
"https://drive.google.com/drive/u/5/folders/1loe6XJ-pJxu9YYPv7cF3Hmz296VNzA33"
|
||||
)
|
||||
|
||||
SHARED_DRIVE_3_URL = "https://drive.google.com/drive/folders/0AJYm2K_I_vtNUk9PVA"
|
||||
|
||||
ADMIN_EMAIL = "admin@onyx-test.com"
|
||||
TEST_USER_1_EMAIL = "test_user_1@onyx-test.com"
|
||||
TEST_USER_2_EMAIL = "test_user_2@onyx-test.com"
|
||||
@@ -139,19 +133,17 @@ def filter_invalid_prefixes(names: set[str]) -> set[str]:
|
||||
return {name for name in names if name.startswith(_VALID_PREFIX)}
|
||||
|
||||
|
||||
def print_discrepancies(
|
||||
def print_discrepencies(
|
||||
expected: set[str],
|
||||
retrieved: set[str],
|
||||
) -> None:
|
||||
if expected != retrieved:
|
||||
expected_list = sorted(expected)
|
||||
retrieved_list = sorted(retrieved)
|
||||
print(expected_list)
|
||||
print(retrieved_list)
|
||||
print(expected)
|
||||
print(retrieved)
|
||||
print("Extra:")
|
||||
print(sorted(retrieved - expected))
|
||||
print(retrieved - expected)
|
||||
print("Missing:")
|
||||
print(sorted(expected - retrieved))
|
||||
print(expected - retrieved)
|
||||
|
||||
|
||||
def _get_expected_file_content(file_id: int) -> str:
|
||||
@@ -161,14 +153,10 @@ def _get_expected_file_content(file_id: int) -> str:
|
||||
return file_text_template.format(file_id)
|
||||
|
||||
|
||||
def assert_expected_docs_in_retrieved_docs(
|
||||
def assert_retrieved_docs_match_expected(
|
||||
retrieved_docs: list[Document],
|
||||
expected_file_ids: Sequence[int],
|
||||
) -> None:
|
||||
"""NOTE: as far as i can tell this does NOT assert for an exact match.
|
||||
it only checks to see if that the expected file id's are IN the retrieved doc list
|
||||
"""
|
||||
|
||||
expected_file_names = {
|
||||
file_name_template.format(file_id) for file_id in expected_file_ids
|
||||
}
|
||||
@@ -176,10 +164,8 @@ def assert_expected_docs_in_retrieved_docs(
|
||||
_get_expected_file_content(file_id) for file_id in expected_file_ids
|
||||
}
|
||||
|
||||
retrieved_docs.sort(key=lambda x: x.semantic_identifier)
|
||||
|
||||
for doc in retrieved_docs:
|
||||
print(f"retrieved doc: doc.semantic_identifier={doc.semantic_identifier}")
|
||||
print(f"doc.semantic_identifier: {doc.semantic_identifier}")
|
||||
|
||||
# Filter out invalid prefixes to prevent different tests from interfering with each other
|
||||
valid_retrieved_docs = [
|
||||
@@ -204,23 +190,15 @@ def assert_expected_docs_in_retrieved_docs(
|
||||
)
|
||||
|
||||
# Check file names
|
||||
print_discrepancies(
|
||||
print_discrepencies(
|
||||
expected=expected_file_names,
|
||||
retrieved=valid_retrieved_file_names,
|
||||
)
|
||||
assert expected_file_names == valid_retrieved_file_names
|
||||
|
||||
# Check file texts
|
||||
print_discrepancies(
|
||||
print_discrepencies(
|
||||
expected=expected_file_texts,
|
||||
retrieved=valid_retrieved_texts,
|
||||
)
|
||||
assert expected_file_texts == valid_retrieved_texts
|
||||
|
||||
|
||||
def load_all_docs(connector: GoogleDriveConnector) -> list[Document]:
|
||||
return load_all_docs_from_checkpoint_connector(
|
||||
connector,
|
||||
0,
|
||||
time.time(),
|
||||
)
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.models import Document
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
assert_expected_docs_in_retrieved_docs,
|
||||
assert_retrieved_docs_match_expected,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
|
||||
@@ -21,7 +23,6 @@ from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL
|
||||
@@ -46,7 +47,9 @@ def test_include_all(
|
||||
my_drive_emails=None,
|
||||
shared_drive_urls=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
# Should get everything in shared and admin's My Drive with oauth
|
||||
expected_file_ids = (
|
||||
@@ -62,7 +65,7 @@ def test_include_all(
|
||||
+ FOLDER_2_2_FILE_IDS
|
||||
+ SECTIONS_FILE_IDS
|
||||
)
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -86,7 +89,9 @@ def test_include_shared_drives_only(
|
||||
my_drive_emails=None,
|
||||
shared_drive_urls=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
# Should only get shared drives
|
||||
expected_file_ids = (
|
||||
@@ -100,7 +105,7 @@ def test_include_shared_drives_only(
|
||||
+ FOLDER_2_2_FILE_IDS
|
||||
+ SECTIONS_FILE_IDS
|
||||
)
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -124,11 +129,13 @@ def test_include_my_drives_only(
|
||||
my_drive_emails=None,
|
||||
shared_drive_urls=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
# Should only get primary_admins My Drive because we are impersonating them
|
||||
expected_file_ids = ADMIN_FILE_IDS + ADMIN_FOLDER_3_FILE_IDS
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -153,7 +160,9 @@ def test_drive_one_only(
|
||||
my_drive_emails=None,
|
||||
shared_drive_urls=",".join([str(url) for url in drive_urls]),
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
expected_file_ids = (
|
||||
SHARED_DRIVE_1_FILE_IDS
|
||||
@@ -161,7 +170,7 @@ def test_drive_one_only(
|
||||
+ FOLDER_1_1_FILE_IDS
|
||||
+ FOLDER_1_2_FILE_IDS
|
||||
)
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -187,7 +196,9 @@ def test_folder_and_shared_drive(
|
||||
my_drive_emails=None,
|
||||
shared_drive_urls=",".join([str(url) for url in drive_urls]),
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
expected_file_ids = (
|
||||
SHARED_DRIVE_1_FILE_IDS
|
||||
@@ -198,7 +209,7 @@ def test_folder_and_shared_drive(
|
||||
+ FOLDER_2_1_FILE_IDS
|
||||
+ FOLDER_2_2_FILE_IDS
|
||||
)
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -232,7 +243,9 @@ def test_folders_only(
|
||||
my_drive_emails=None,
|
||||
shared_drive_urls=",".join([str(url) for url in shared_drive_urls]),
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
expected_file_ids = (
|
||||
FOLDER_1_1_FILE_IDS
|
||||
@@ -241,7 +254,7 @@ def test_folders_only(
|
||||
+ FOLDER_2_2_FILE_IDS
|
||||
+ ADMIN_FOLDER_3_FILE_IDS
|
||||
)
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -268,10 +281,12 @@ def test_personal_folders_only(
|
||||
my_drive_emails=None,
|
||||
shared_drive_urls=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.models import Document
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FOLDER_URL
|
||||
|
||||
|
||||
@@ -36,7 +37,9 @@ def test_google_drive_sections(
|
||||
my_drive_emails=None,
|
||||
)
|
||||
for connector in [oauth_connector, service_acct_connector]:
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
# Verify we got the 1 doc with sections
|
||||
assert len(retrieved_docs) == 1
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.models import Document
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
assert_expected_docs_in_retrieved_docs,
|
||||
assert_retrieved_docs_match_expected,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
|
||||
@@ -21,7 +23,6 @@ from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL
|
||||
@@ -51,7 +52,9 @@ def test_include_all(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
# Should get everything
|
||||
expected_file_ids = (
|
||||
@@ -70,39 +73,12 @@ def test_include_all(
|
||||
+ FOLDER_2_2_FILE_IDS
|
||||
+ SECTIONS_FILE_IDS
|
||||
)
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
|
||||
return_value=None,
|
||||
)
|
||||
def test_include_shared_drives_only_with_size_threshold(
|
||||
mock_get_api_key: MagicMock,
|
||||
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
) -> None:
|
||||
print("\n\nRunning test_include_shared_drives_only_with_size_threshold")
|
||||
connector = google_drive_service_acct_connector_factory(
|
||||
primary_admin_email=ADMIN_EMAIL,
|
||||
include_shared_drives=True,
|
||||
include_my_drives=False,
|
||||
include_files_shared_with_me=False,
|
||||
shared_folder_urls=None,
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
|
||||
# this threshold will skip one file
|
||||
connector.size_threshold = 16384
|
||||
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
assert len(retrieved_docs) == 50
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
|
||||
return_value=None,
|
||||
@@ -121,8 +97,9 @@ def test_include_shared_drives_only(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
# Should only get shared drives
|
||||
expected_file_ids = (
|
||||
@@ -136,10 +113,7 @@ def test_include_shared_drives_only(
|
||||
+ FOLDER_2_2_FILE_IDS
|
||||
+ SECTIONS_FILE_IDS
|
||||
)
|
||||
|
||||
assert len(retrieved_docs) == 51
|
||||
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -163,7 +137,9 @@ def test_include_my_drives_only(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
# Should only get everyone's My Drives
|
||||
expected_file_ids = (
|
||||
@@ -173,7 +149,7 @@ def test_include_my_drives_only(
|
||||
+ TEST_USER_2_FILE_IDS
|
||||
+ TEST_USER_3_FILE_IDS
|
||||
)
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -198,7 +174,9 @@ def test_drive_one_only(
|
||||
shared_drive_urls=",".join([str(url) for url in urls]),
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
# We ignore shared_drive_urls if include_shared_drives is False
|
||||
expected_file_ids = (
|
||||
@@ -207,7 +185,7 @@ def test_drive_one_only(
|
||||
+ FOLDER_1_1_FILE_IDS
|
||||
+ FOLDER_1_2_FILE_IDS
|
||||
)
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -233,7 +211,9 @@ def test_folder_and_shared_drive(
|
||||
shared_folder_urls=",".join([str(url) for url in folder_urls]),
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
# Should get everything except for the top level files in drive 2
|
||||
expected_file_ids = (
|
||||
@@ -245,7 +225,7 @@ def test_folder_and_shared_drive(
|
||||
+ FOLDER_2_1_FILE_IDS
|
||||
+ FOLDER_2_2_FILE_IDS
|
||||
)
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -279,7 +259,9 @@ def test_folders_only(
|
||||
shared_folder_urls=",".join([str(url) for url in folder_urls]),
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
expected_file_ids = (
|
||||
FOLDER_1_1_FILE_IDS
|
||||
@@ -288,7 +270,7 @@ def test_folders_only(
|
||||
+ FOLDER_2_2_FILE_IDS
|
||||
+ ADMIN_FOLDER_3_FILE_IDS
|
||||
)
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -316,10 +298,12 @@ def test_specific_emails(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=",".join([str(email) for email in my_drive_emails]),
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
expected_file_ids = TEST_USER_1_FILE_IDS + TEST_USER_3_FILE_IDS
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -346,10 +330,12 @@ def get_specific_folders_in_my_drive(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
@@ -22,7 +22,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_I
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import print_discrepancies
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import print_discrepencies
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import PUBLIC_RANGE
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
|
||||
@@ -83,7 +83,7 @@ def assert_correct_access_for_user(
|
||||
expected_file_names = {file_name_template.format(i) for i in all_accessible_ids}
|
||||
|
||||
filtered_retrieved_file_names = filter_invalid_prefixes(retrieved_file_names)
|
||||
print_discrepancies(expected_file_names, filtered_retrieved_file_names)
|
||||
print_discrepencies(expected_file_names, filtered_retrieved_file_names)
|
||||
|
||||
assert expected_file_names == filtered_retrieved_file_names
|
||||
|
||||
@@ -175,7 +175,7 @@ def test_all_permissions(
|
||||
|
||||
# Should get everything
|
||||
filtered_retrieved_file_names = filter_invalid_prefixes(found_file_names)
|
||||
print_discrepancies(expected_file_names, filtered_retrieved_file_names)
|
||||
print_discrepencies(expected_file_names, filtered_retrieved_file_names)
|
||||
assert expected_file_names == filtered_retrieved_file_names
|
||||
|
||||
group_map = get_group_map(google_drive_connector)
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.models import Document
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
assert_expected_docs_in_retrieved_docs,
|
||||
assert_retrieved_docs_match_expected,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_EMAIL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_FILE_IDS
|
||||
@@ -36,7 +37,9 @@ def test_all(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
expected_file_ids = (
|
||||
# These are the files from my drive
|
||||
@@ -50,7 +53,7 @@ def test_all(
|
||||
+ ADMIN_FOLDER_3_FILE_IDS
|
||||
+ list(range(0, 2))
|
||||
)
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -74,7 +77,9 @@ def test_shared_drives_only(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
expected_file_ids = (
|
||||
# These are the files from shared drives
|
||||
@@ -83,7 +88,7 @@ def test_shared_drives_only(
|
||||
+ FOLDER_1_1_FILE_IDS
|
||||
+ FOLDER_1_2_FILE_IDS
|
||||
)
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -107,14 +112,16 @@ def test_shared_with_me_only(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
expected_file_ids = (
|
||||
# These are the files shared with me from admin
|
||||
ADMIN_FOLDER_3_FILE_IDS
|
||||
+ list(range(0, 2))
|
||||
)
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -138,11 +145,13 @@ def test_my_drive_only(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
# These are the files from my drive
|
||||
expected_file_ids = TEST_USER_1_FILE_IDS
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -166,13 +175,15 @@ def test_shared_my_drive_folder(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
expected_file_ids = (
|
||||
# this is a folder from admin's drive that is shared with me
|
||||
ADMIN_FOLDER_3_FILE_IDS
|
||||
)
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -196,10 +207,12 @@ def test_shared_drive_folder(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
retrieved_docs: list[Document] = []
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
retrieved_docs.extend(doc_batch)
|
||||
|
||||
expected_file_ids = FOLDER_1_FILE_IDS + FOLDER_1_1_FILE_IDS + FOLDER_1_2_FILE_IDS
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
assert_retrieved_docs_match_expected(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.highspot.connector import HighspotConnector
|
||||
from onyx.connectors.models import Document
|
||||
|
||||
|
||||
def load_test_data(file_name: str = "test_highspot_data.json") -> dict:
|
||||
"""Load test data from JSON file."""
|
||||
current_dir = Path(__file__).parent
|
||||
with open(current_dir / file_name, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def highspot_connector() -> HighspotConnector:
|
||||
"""Create a Highspot connector with credentials from environment variables."""
|
||||
# Check if required environment variables are set
|
||||
if not os.environ.get("HIGHSPOT_KEY") or not os.environ.get("HIGHSPOT_SECRET"):
|
||||
pytest.fail("HIGHSPOT_KEY or HIGHSPOT_SECRET environment variables not set")
|
||||
|
||||
connector = HighspotConnector(
|
||||
spot_names=["Test content"], # Use specific spot name instead of empty list
|
||||
batch_size=10, # Smaller batch size for testing
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"highspot_key": os.environ["HIGHSPOT_KEY"],
|
||||
"highspot_secret": os.environ["HIGHSPOT_SECRET"],
|
||||
"highspot_url": os.environ.get(
|
||||
"HIGHSPOT_URL", "https://api-su2.highspot.com/v1.0/"
|
||||
),
|
||||
}
|
||||
)
|
||||
return connector
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Accessing postgres that isn't available in connector only tests",
|
||||
)
|
||||
def test_highspot_connector_basic(highspot_connector: HighspotConnector) -> None:
|
||||
"""Test basic functionality of the Highspot connector."""
|
||||
all_docs: list[Document] = []
|
||||
test_data = load_test_data()
|
||||
target_test_doc_id = test_data.get("target_doc_id")
|
||||
target_test_doc: Document | None = None
|
||||
|
||||
# Test loading documents
|
||||
for doc_batch in highspot_connector.poll_source(0, time.time()):
|
||||
for doc in doc_batch:
|
||||
all_docs.append(doc)
|
||||
if doc.id == f"HIGHSPOT_{target_test_doc_id}":
|
||||
target_test_doc = doc
|
||||
|
||||
# Verify documents were loaded
|
||||
assert len(all_docs) > 0
|
||||
|
||||
# If we have a specific test document ID, validate it
|
||||
if target_test_doc_id and target_test_doc is not None:
|
||||
assert target_test_doc.semantic_identifier == test_data.get(
|
||||
"semantic_identifier"
|
||||
)
|
||||
assert target_test_doc.source == DocumentSource.HIGHSPOT
|
||||
assert target_test_doc.metadata is not None
|
||||
|
||||
assert len(target_test_doc.sections) == 1
|
||||
section = target_test_doc.sections[0]
|
||||
assert section.link is not None
|
||||
# Only check if content exists, as exact content might change
|
||||
assert section.text is not None
|
||||
assert len(section.text) > 0
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Possibly accessing postgres that isn't available in connector only tests",
|
||||
)
|
||||
def test_highspot_connector_slim(highspot_connector: HighspotConnector) -> None:
|
||||
"""Test slim document retrieval."""
|
||||
# Get all doc IDs from the full connector
|
||||
all_full_doc_ids = set()
|
||||
for doc_batch in highspot_connector.load_from_state():
|
||||
all_full_doc_ids.update([doc.id for doc in doc_batch])
|
||||
|
||||
# Get all doc IDs from the slim connector
|
||||
all_slim_doc_ids = set()
|
||||
for slim_doc_batch in highspot_connector.retrieve_all_slim_documents():
|
||||
all_slim_doc_ids.update([doc.id for doc in slim_doc_batch])
|
||||
|
||||
# The set of full doc IDs should be a subset of the slim doc IDs
|
||||
assert all_full_doc_ids.issubset(all_slim_doc_ids)
|
||||
# Make sure we actually got some documents
|
||||
assert len(all_slim_doc_ids) > 0
|
||||
|
||||
|
||||
def test_highspot_connector_validate_credentials(
|
||||
highspot_connector: HighspotConnector,
|
||||
) -> None:
|
||||
"""Test credential validation."""
|
||||
assert highspot_connector.validate_credentials() is True
|
||||
@@ -1,5 +0,0 @@
|
||||
{
|
||||
"target_doc_id": "67cd8eb35d3ee0487de2e704",
|
||||
"semantic_identifier": "Highspot in Action _ Salesforce Integration",
|
||||
"link": "https://www.highspot.com/items/67cd8eb35d3ee0487de2e704"
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.onyx_jira.connector import JiraConnector
|
||||
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -25,13 +24,15 @@ def jira_connector() -> JiraConnector:
|
||||
|
||||
|
||||
def test_jira_connector_basic(jira_connector: JiraConnector) -> None:
|
||||
docs = load_all_docs_from_checkpoint_connector(
|
||||
connector=jira_connector,
|
||||
start=0,
|
||||
end=time.time(),
|
||||
)
|
||||
assert len(docs) == 1
|
||||
doc = docs[0]
|
||||
doc_batch_generator = jira_connector.poll_source(0, time.time())
|
||||
|
||||
doc_batch = next(doc_batch_generator)
|
||||
with pytest.raises(StopIteration):
|
||||
next(doc_batch_generator)
|
||||
|
||||
assert len(doc_batch) == 1
|
||||
|
||||
doc = doc_batch[0]
|
||||
|
||||
assert doc.id == "https://danswerai.atlassian.net/browse/AS-2"
|
||||
assert doc.semantic_identifier == "AS-2: test123small"
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
from onyx.connectors.connector_runner import CheckpointOutputWrapper
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
|
||||
_ITERATION_LIMIT = 100_000
|
||||
|
||||
CT = TypeVar("CT", bound=ConnectorCheckpoint)
|
||||
|
||||
|
||||
def load_all_docs_from_checkpoint_connector(
|
||||
connector: CheckpointConnector[CT],
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
) -> list[Document]:
|
||||
num_iterations = 0
|
||||
|
||||
checkpoint = cast(CT, connector.build_dummy_checkpoint())
|
||||
documents: list[Document] = []
|
||||
while checkpoint.has_more:
|
||||
doc_batch_generator = CheckpointOutputWrapper[CT]()(
|
||||
connector.load_from_checkpoint(start, end, checkpoint)
|
||||
)
|
||||
for document, failure, next_checkpoint in doc_batch_generator:
|
||||
if failure is not None:
|
||||
raise RuntimeError(f"Failed to load documents: {failure}")
|
||||
if document is not None:
|
||||
documents.append(document)
|
||||
if next_checkpoint is not None:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
num_iterations += 1
|
||||
if num_iterations > _ITERATION_LIMIT:
|
||||
raise RuntimeError("Too many iterations. Infinite loop?")
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
def load_everything_from_checkpoint_connector(
|
||||
connector: CheckpointConnector[CT],
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
) -> list[Document | ConnectorFailure]:
|
||||
"""Like load_all_docs_from_checkpoint_connector but returns both documents and failures"""
|
||||
num_iterations = 0
|
||||
|
||||
checkpoint = connector.build_dummy_checkpoint()
|
||||
outputs: list[Document | ConnectorFailure] = []
|
||||
while checkpoint.has_more:
|
||||
doc_batch_generator = CheckpointOutputWrapper[CT]()(
|
||||
connector.load_from_checkpoint(start, end, checkpoint)
|
||||
)
|
||||
for document, failure, next_checkpoint in doc_batch_generator:
|
||||
if failure is not None:
|
||||
outputs.append(failure)
|
||||
if document is not None:
|
||||
outputs.append(document)
|
||||
if next_checkpoint is not None:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
num_iterations += 1
|
||||
if num_iterations > _ITERATION_LIMIT:
|
||||
raise RuntimeError("Too many iterations. Infinite loop?")
|
||||
|
||||
return outputs
|
||||
@@ -2,14 +2,12 @@ import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.zendesk.connector import ZendeskConnector
|
||||
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
|
||||
|
||||
|
||||
def load_test_data(file_name: str = "test_zendesk_data.json") -> dict[str, dict]:
|
||||
@@ -52,7 +50,7 @@ def get_credentials() -> dict[str, str]:
|
||||
def test_zendesk_connector_basic(
|
||||
request: pytest.FixtureRequest, connector_fixture: str
|
||||
) -> None:
|
||||
connector = cast(ZendeskConnector, request.getfixturevalue(connector_fixture))
|
||||
connector = request.getfixturevalue(connector_fixture)
|
||||
test_data = load_test_data()
|
||||
all_docs: list[Document] = []
|
||||
target_test_doc_id: str
|
||||
@@ -63,11 +61,12 @@ def test_zendesk_connector_basic(
|
||||
|
||||
target_doc: Document | None = None
|
||||
|
||||
for doc in load_all_docs_from_checkpoint_connector(connector, 0, time.time()):
|
||||
all_docs.append(doc)
|
||||
if doc.id == target_test_doc_id:
|
||||
target_doc = doc
|
||||
print(f"target_doc {target_doc}")
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
for doc in doc_batch:
|
||||
all_docs.append(doc)
|
||||
if doc.id == target_test_doc_id:
|
||||
target_doc = doc
|
||||
print(f"target_doc {target_doc}")
|
||||
|
||||
assert len(all_docs) > 0, "No documents were retrieved from the connector"
|
||||
assert (
|
||||
@@ -112,10 +111,8 @@ def test_zendesk_connector_basic(
|
||||
def test_zendesk_connector_slim(zendesk_article_connector: ZendeskConnector) -> None:
|
||||
# Get full doc IDs
|
||||
all_full_doc_ids = set()
|
||||
for doc in load_all_docs_from_checkpoint_connector(
|
||||
zendesk_article_connector, 0, time.time()
|
||||
):
|
||||
all_full_doc_ids.add(doc.id)
|
||||
for doc_batch in zendesk_article_connector.load_from_state():
|
||||
all_full_doc_ids.update([doc.id for doc in doc_batch])
|
||||
|
||||
# Get slim doc IDs
|
||||
all_slim_doc_ids = set()
|
||||
|
||||
@@ -12,8 +12,8 @@ The idea is that each test can use the manager class to create (.create()) a "te
|
||||
## Instructions for Running Integration Tests Locally
|
||||
|
||||
1. Launch onyx (using Docker or running with a debugger), ensuring the API server is running on port 8080.
|
||||
- If you'd like to set environment variables, you can do so by creating a `.env` file in the onyx/backend/tests/integration/ directory.
|
||||
- Onyx MUST be launched with AUTH_TYPE=basic and ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
a. If you'd like to set environment variables, you can do so by creating a `.env` file in the onyx/backend/tests/integration/ directory.
|
||||
b. Onyx MUST be launched with AUTH_TYPE=basic and ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
2. Navigate to `onyx/backend`.
|
||||
3. Run the following command in the terminal:
|
||||
```sh
|
||||
@@ -28,14 +28,6 @@ The idea is that each test can use the manager class to create (.create()) a "te
|
||||
pytest -s tests/integration/tests/path_to/test_file.py::test_function_name
|
||||
```
|
||||
|
||||
Running some single tests require the `mock_connector_server` container to be running. If the above doesn't work,
|
||||
navigate to `backend/tests/integration/mock_services` and run
|
||||
```sh
|
||||
docker compose -f docker-compose.mock-it-services.yml -p mock-it-services-stack up -d
|
||||
```
|
||||
You will have to modify the networks section of the docker-compose file to `<your stack name>_default` if you brought up the standard
|
||||
onyx services with a name different from the default `onyx-stack`.
|
||||
|
||||
## Guidelines for Writing Integration Tests
|
||||
|
||||
- As authentication is currently required for all tests, each test should start by creating a user.
|
||||
|
||||
@@ -3,8 +3,8 @@ from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.server.manage.llm.models import FullLLMProvider
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
@@ -39,7 +39,6 @@ class LLMProviderManager:
|
||||
groups=groups or [],
|
||||
display_model_names=None,
|
||||
model_names=None,
|
||||
api_key_changed=True,
|
||||
)
|
||||
|
||||
llm_response = requests.put(
|
||||
@@ -91,7 +90,7 @@ class LLMProviderManager:
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[LLMProviderView]:
|
||||
) -> list[FullLLMProvider]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=user_performing_action.headers
|
||||
@@ -99,7 +98,7 @@ class LLMProviderManager:
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [LLMProviderView(**ug) for ug in response.json()]
|
||||
return [FullLLMProvider(**ug) for ug in response.json()]
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
@@ -112,19 +111,18 @@ class LLMProviderManager:
|
||||
if llm_provider.id == fetched_llm_provider.id:
|
||||
if verify_deleted:
|
||||
raise ValueError(
|
||||
f"LLM Provider {llm_provider.id} found but should be deleted"
|
||||
f"User group {llm_provider.id} found but should be deleted"
|
||||
)
|
||||
fetched_llm_groups = set(fetched_llm_provider.groups)
|
||||
llm_provider_groups = set(llm_provider.groups)
|
||||
|
||||
# NOTE: returned api keys are sanitized and should not match
|
||||
if (
|
||||
fetched_llm_groups == llm_provider_groups
|
||||
and llm_provider.provider == fetched_llm_provider.provider
|
||||
and llm_provider.api_key == fetched_llm_provider.api_key
|
||||
and llm_provider.default_model_name
|
||||
== fetched_llm_provider.default_model_name
|
||||
and llm_provider.is_public == fetched_llm_provider.is_public
|
||||
):
|
||||
return
|
||||
if not verify_deleted:
|
||||
raise ValueError(f"LLM Provider {llm_provider.id} not found")
|
||||
raise ValueError(f"User group {llm_provider.id} not found")
|
||||
|
||||
@@ -7,7 +7,7 @@ import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.mock_connector.connector import MockConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import InputType
|
||||
@@ -54,9 +54,9 @@ def test_mock_connector_basic_flow(
|
||||
json=[
|
||||
{
|
||||
"documents": [test_doc.model_dump(mode="json")],
|
||||
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"failures": [],
|
||||
}
|
||||
],
|
||||
@@ -128,9 +128,9 @@ def test_mock_connector_with_failures(
|
||||
json=[
|
||||
{
|
||||
"documents": [doc1.model_dump(mode="json")],
|
||||
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"failures": [doc2_failure.model_dump(mode="json")],
|
||||
}
|
||||
],
|
||||
@@ -208,9 +208,9 @@ def test_mock_connector_failure_recovery(
|
||||
json=[
|
||||
{
|
||||
"documents": [doc1.model_dump(mode="json")],
|
||||
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"failures": [
|
||||
doc2_failure.model_dump(mode="json"),
|
||||
ConnectorFailure(
|
||||
@@ -292,9 +292,9 @@ def test_mock_connector_failure_recovery(
|
||||
doc1.model_dump(mode="json"),
|
||||
doc2.model_dump(mode="json"),
|
||||
],
|
||||
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"failures": [],
|
||||
}
|
||||
],
|
||||
@@ -372,23 +372,23 @@ def test_mock_connector_checkpoint_recovery(
|
||||
json=[
|
||||
{
|
||||
"documents": [doc.model_dump(mode="json") for doc in docs_batch_1],
|
||||
"checkpoint": MockConnectorCheckpoint(
|
||||
has_more=True, last_document_id=docs_batch_1[-1].id
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=True
|
||||
).model_dump(mode="json"),
|
||||
"failures": [],
|
||||
},
|
||||
{
|
||||
"documents": [doc2.model_dump(mode="json")],
|
||||
"checkpoint": MockConnectorCheckpoint(
|
||||
has_more=True, last_document_id=doc2.id
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=True
|
||||
).model_dump(mode="json"),
|
||||
"failures": [],
|
||||
},
|
||||
{
|
||||
"documents": [],
|
||||
# should never hit this, unhandled exception happens first
|
||||
"checkpoint": MockConnectorCheckpoint(
|
||||
has_more=False, last_document_id=doc2.id
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"failures": [],
|
||||
"unhandled_exception": "Simulated unhandled error",
|
||||
@@ -446,16 +446,12 @@ def test_mock_connector_checkpoint_recovery(
|
||||
initial_checkpoints = response.json()
|
||||
|
||||
# Verify we got the expected checkpoints in order
|
||||
assert len(initial_checkpoints) == 3
|
||||
assert initial_checkpoints[0] == {
|
||||
"has_more": True,
|
||||
"last_document_id": None,
|
||||
} # Initial empty checkpoint
|
||||
assert initial_checkpoints[1] == {
|
||||
"has_more": True,
|
||||
"last_document_id": docs_batch_1[-1].id,
|
||||
}
|
||||
assert initial_checkpoints[2] == {"has_more": True, "last_document_id": doc2.id}
|
||||
assert len(initial_checkpoints) > 0
|
||||
assert (
|
||||
initial_checkpoints[0]["checkpoint_content"] == {}
|
||||
) # Initial empty checkpoint
|
||||
assert initial_checkpoints[1]["checkpoint_content"] == {}
|
||||
assert initial_checkpoints[2]["checkpoint_content"] == {}
|
||||
|
||||
# Reset the mock server for the next run
|
||||
response = mock_server_client.post("/reset")
|
||||
@@ -467,8 +463,8 @@ def test_mock_connector_checkpoint_recovery(
|
||||
json=[
|
||||
{
|
||||
"documents": [doc3.model_dump(mode="json")],
|
||||
"checkpoint": MockConnectorCheckpoint(
|
||||
has_more=False, last_document_id=doc3.id
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"failures": [],
|
||||
}
|
||||
@@ -519,4 +515,4 @@ def test_mock_connector_checkpoint_recovery(
|
||||
|
||||
# Verify the recovery run started from the last successful checkpoint
|
||||
assert len(recovery_checkpoints) == 1
|
||||
assert recovery_checkpoints[0] == {"has_more": True, "last_document_id": doc2.id}
|
||||
assert recovery_checkpoints[0]["checkpoint_content"] == {}
|
||||
|
||||
@@ -34,7 +34,6 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None:
|
||||
json={
|
||||
"name": str(uuid.uuid4()),
|
||||
"provider": "openai",
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"default_model_name": _DEFAULT_MODELS[0],
|
||||
"model_names": _DEFAULT_MODELS,
|
||||
"is_public": True,
|
||||
@@ -50,9 +49,6 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None:
|
||||
assert provider_data["model_names"] == _DEFAULT_MODELS
|
||||
assert provider_data["default_model_name"] == _DEFAULT_MODELS[0]
|
||||
assert provider_data["display_model_names"] is None
|
||||
assert (
|
||||
provider_data["api_key"] == "sk-0****0000"
|
||||
) # test that returned key is sanitized
|
||||
|
||||
|
||||
def test_update_llm_provider_model_names(reset: None) -> None:
|
||||
@@ -68,12 +64,10 @@ def test_update_llm_provider_model_names(reset: None) -> None:
|
||||
json={
|
||||
"name": name,
|
||||
"provider": "openai",
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"default_model_name": _DEFAULT_MODELS[0],
|
||||
"model_names": [_DEFAULT_MODELS[0]],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
"api_key_changed": True,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -87,7 +81,6 @@ def test_update_llm_provider_model_names(reset: None) -> None:
|
||||
"id": created_provider["id"],
|
||||
"name": name,
|
||||
"provider": created_provider["provider"],
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000001",
|
||||
"default_model_name": _DEFAULT_MODELS[0],
|
||||
"model_names": _DEFAULT_MODELS,
|
||||
"is_public": True,
|
||||
@@ -100,30 +93,6 @@ def test_update_llm_provider_model_names(reset: None) -> None:
|
||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||
assert provider_data is not None
|
||||
assert provider_data["model_names"] == _DEFAULT_MODELS
|
||||
assert (
|
||||
provider_data["api_key"] == "sk-0****0000"
|
||||
) # test that key was NOT updated due to api_key_changed not being set
|
||||
|
||||
# Update with api_key_changed properly set
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"id": created_provider["id"],
|
||||
"name": name,
|
||||
"provider": created_provider["provider"],
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000001",
|
||||
"default_model_name": _DEFAULT_MODELS[0],
|
||||
"model_names": _DEFAULT_MODELS,
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
"api_key_changed": True,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||
assert provider_data is not None
|
||||
assert provider_data["api_key"] == "sk-0****0001" # test that key was updated
|
||||
|
||||
|
||||
def test_delete_llm_provider(reset: None) -> None:
|
||||
@@ -138,7 +107,6 @@ def test_delete_llm_provider(reset: None) -> None:
|
||||
json={
|
||||
"name": "test-provider-delete",
|
||||
"provider": "openai",
|
||||
"api_key": "sk-000000000000000000000000000000000000000000000000",
|
||||
"default_model_name": _DEFAULT_MODELS[0],
|
||||
"model_names": _DEFAULT_MODELS,
|
||||
"is_public": True,
|
||||
|
||||
@@ -50,7 +50,7 @@ def answer_instance(
|
||||
mocker: MockerFixture,
|
||||
) -> Answer:
|
||||
mocker.patch(
|
||||
"onyx.chat.answer.fast_gpu_status_request",
|
||||
"onyx.chat.answer.gpu_status_request",
|
||||
return_value=True,
|
||||
)
|
||||
return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config)
|
||||
@@ -400,7 +400,7 @@ def test_no_slow_reranking(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"onyx.chat.answer.fast_gpu_status_request",
|
||||
"onyx.chat.answer.gpu_status_request",
|
||||
return_value=gpu_enabled,
|
||||
)
|
||||
rerank_settings = (
|
||||
|
||||
@@ -39,7 +39,7 @@ def test_skip_gen_ai_answer_generation_flag(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"onyx.chat.answer.fast_gpu_status_request",
|
||||
"onyx.chat.answer.gpu_status_request",
|
||||
return_value=True,
|
||||
)
|
||||
question = config["question"]
|
||||
|
||||
@@ -1,441 +0,0 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from github import Github
|
||||
from github import GithubException
|
||||
from github import RateLimitExceededException
|
||||
from github.Issue import Issue
|
||||
from github.PullRequest import PullRequest
|
||||
from github.RateLimit import RateLimit
|
||||
from github.Repository import Repository
|
||||
from github.Requester import Requester
|
||||
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from onyx.connectors.github.connector import SerializedRepository
|
||||
from onyx.connectors.models import Document
|
||||
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repo_owner() -> str:
|
||||
return "test-org"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repositories() -> str:
|
||||
return "test-repo"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_github_client() -> MagicMock:
|
||||
"""Create a mock GitHub client with proper typing"""
|
||||
mock = MagicMock(spec=Github)
|
||||
# Add proper return typing for get_repo method
|
||||
mock.get_repo = MagicMock(return_value=MagicMock(spec=Repository))
|
||||
# Add proper return typing for get_organization method
|
||||
mock.get_organization = MagicMock()
|
||||
# Add proper return typing for get_user method
|
||||
mock.get_user = MagicMock()
|
||||
# Add proper return typing for get_rate_limit method
|
||||
mock.get_rate_limit = MagicMock(return_value=MagicMock(spec=RateLimit))
|
||||
# Add requester for repository deserialization
|
||||
mock.requester = MagicMock(spec=Requester)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_connector(
|
||||
repo_owner: str, repositories: str, mock_github_client: MagicMock
|
||||
) -> Generator[GithubConnector, None, None]:
|
||||
connector = GithubConnector(
|
||||
repo_owner=repo_owner,
|
||||
repositories=repositories,
|
||||
include_prs=True,
|
||||
include_issues=True,
|
||||
)
|
||||
connector.github_client = mock_github_client
|
||||
yield connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_mock_pr() -> Callable[..., MagicMock]:
|
||||
def _create_mock_pr(
|
||||
number: int = 1,
|
||||
title: str = "Test PR",
|
||||
body: str = "Test Description",
|
||||
state: str = "open",
|
||||
merged: bool = False,
|
||||
updated_at: datetime = datetime(2023, 1, 1, tzinfo=timezone.utc),
|
||||
) -> MagicMock:
|
||||
"""Helper to create a mock PullRequest object"""
|
||||
mock_pr = MagicMock(spec=PullRequest)
|
||||
mock_pr.number = number
|
||||
mock_pr.title = title
|
||||
mock_pr.body = body
|
||||
mock_pr.state = state
|
||||
mock_pr.merged = merged
|
||||
mock_pr.updated_at = updated_at
|
||||
mock_pr.html_url = f"https://github.com/test-org/test-repo/pull/{number}"
|
||||
return mock_pr
|
||||
|
||||
return _create_mock_pr
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_mock_issue() -> Callable[..., MagicMock]:
|
||||
def _create_mock_issue(
|
||||
number: int = 1,
|
||||
title: str = "Test Issue",
|
||||
body: str = "Test Description",
|
||||
state: str = "open",
|
||||
updated_at: datetime = datetime(2023, 1, 1, tzinfo=timezone.utc),
|
||||
) -> MagicMock:
|
||||
"""Helper to create a mock Issue object"""
|
||||
mock_issue = MagicMock(spec=Issue)
|
||||
mock_issue.number = number
|
||||
mock_issue.title = title
|
||||
mock_issue.body = body
|
||||
mock_issue.state = state
|
||||
mock_issue.updated_at = updated_at
|
||||
mock_issue.html_url = f"https://github.com/test-org/test-repo/issues/{number}"
|
||||
mock_issue.pull_request = None # Not a PR
|
||||
return mock_issue
|
||||
|
||||
return _create_mock_issue
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_mock_repo() -> Callable[..., MagicMock]:
|
||||
def _create_mock_repo(
|
||||
name: str = "test-repo",
|
||||
id: int = 1,
|
||||
) -> MagicMock:
|
||||
"""Helper to create a mock Repository object"""
|
||||
mock_repo = MagicMock(spec=Repository)
|
||||
mock_repo.name = name
|
||||
mock_repo.id = id
|
||||
mock_repo.raw_headers = {"status": "200 OK", "content-type": "application/json"}
|
||||
mock_repo.raw_data = {
|
||||
"id": str(id),
|
||||
"name": name,
|
||||
"full_name": f"test-org/{name}",
|
||||
"private": str(False),
|
||||
"description": "Test repository",
|
||||
}
|
||||
return mock_repo
|
||||
|
||||
return _create_mock_repo
|
||||
|
||||
|
||||
def test_load_from_checkpoint_happy_path(
|
||||
github_connector: GithubConnector,
|
||||
mock_github_client: MagicMock,
|
||||
create_mock_repo: Callable[..., MagicMock],
|
||||
create_mock_pr: Callable[..., MagicMock],
|
||||
create_mock_issue: Callable[..., MagicMock],
|
||||
) -> None:
|
||||
"""Test loading from checkpoint - happy path"""
|
||||
# Set up mocked repo
|
||||
mock_repo = create_mock_repo()
|
||||
github_connector.github_client = mock_github_client
|
||||
mock_github_client.get_repo.return_value = mock_repo
|
||||
|
||||
# Set up mocked PRs and issues
|
||||
mock_pr1 = create_mock_pr(number=1, title="PR 1")
|
||||
mock_pr2 = create_mock_pr(number=2, title="PR 2")
|
||||
mock_issue1 = create_mock_issue(number=1, title="Issue 1")
|
||||
mock_issue2 = create_mock_issue(number=2, title="Issue 2")
|
||||
|
||||
# Mock get_pulls and get_issues methods
|
||||
mock_repo.get_pulls.return_value = MagicMock()
|
||||
mock_repo.get_pulls.return_value.get_page.side_effect = [
|
||||
[mock_pr1, mock_pr2],
|
||||
[],
|
||||
]
|
||||
mock_repo.get_issues.return_value = MagicMock()
|
||||
mock_repo.get_issues.return_value.get_page.side_effect = [
|
||||
[mock_issue1, mock_issue2],
|
||||
[],
|
||||
]
|
||||
|
||||
# Mock SerializedRepository.to_Repository to return our mock repo
|
||||
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(
|
||||
github_connector, 0, end_time
|
||||
)
|
||||
|
||||
# Check that we got all documents and final has_more=False
|
||||
assert len(outputs) == 4
|
||||
|
||||
repo_batch = outputs[0]
|
||||
assert len(repo_batch.items) == 0
|
||||
assert repo_batch.next_checkpoint.has_more is True
|
||||
|
||||
# Check first batch (PRs)
|
||||
first_batch = outputs[1]
|
||||
assert len(first_batch.items) == 2
|
||||
assert isinstance(first_batch.items[0], Document)
|
||||
assert first_batch.items[0].id == "https://github.com/test-org/test-repo/pull/1"
|
||||
assert isinstance(first_batch.items[1], Document)
|
||||
assert first_batch.items[1].id == "https://github.com/test-org/test-repo/pull/2"
|
||||
assert first_batch.next_checkpoint.curr_page == 1
|
||||
|
||||
# Check second batch (Issues)
|
||||
second_batch = outputs[2]
|
||||
assert len(second_batch.items) == 2
|
||||
assert isinstance(second_batch.items[0], Document)
|
||||
assert (
|
||||
second_batch.items[0].id == "https://github.com/test-org/test-repo/issues/1"
|
||||
)
|
||||
assert isinstance(second_batch.items[1], Document)
|
||||
assert (
|
||||
second_batch.items[1].id == "https://github.com/test-org/test-repo/issues/2"
|
||||
)
|
||||
assert second_batch.next_checkpoint.has_more
|
||||
|
||||
# Check third batch (finished checkpoint)
|
||||
third_batch = outputs[3]
|
||||
assert len(third_batch.items) == 0
|
||||
assert third_batch.next_checkpoint.has_more is False
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_rate_limit(
|
||||
github_connector: GithubConnector,
|
||||
mock_github_client: MagicMock,
|
||||
create_mock_repo: Callable[..., MagicMock],
|
||||
create_mock_pr: Callable[..., MagicMock],
|
||||
) -> None:
|
||||
"""Test loading from checkpoint with rate limit handling"""
|
||||
# Set up mocked repo
|
||||
mock_repo = create_mock_repo()
|
||||
github_connector.github_client = mock_github_client
|
||||
mock_github_client.get_repo.return_value = mock_repo
|
||||
|
||||
# Set up mocked PR
|
||||
mock_pr = create_mock_pr()
|
||||
|
||||
# Mock get_pulls to raise RateLimitExceededException on first call
|
||||
mock_repo.get_pulls.return_value = MagicMock()
|
||||
mock_repo.get_pulls.return_value.get_page.side_effect = [
|
||||
RateLimitExceededException(403, {"message": "Rate limit exceeded"}, {}),
|
||||
[mock_pr],
|
||||
[],
|
||||
]
|
||||
|
||||
# Mock rate limit reset time
|
||||
mock_rate_limit = MagicMock(spec=RateLimit)
|
||||
mock_rate_limit.core.reset = datetime.now(timezone.utc)
|
||||
github_connector.github_client.get_rate_limit.return_value = mock_rate_limit
|
||||
|
||||
# Mock SerializedRepository.to_Repository to return our mock repo
|
||||
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
with patch(
|
||||
"onyx.connectors.github.connector._sleep_after_rate_limit_exception"
|
||||
) as mock_sleep:
|
||||
outputs = load_everything_from_checkpoint_connector(
|
||||
github_connector, 0, end_time
|
||||
)
|
||||
|
||||
assert mock_sleep.call_count == 1
|
||||
|
||||
# Check that we got the document after rate limit was handled
|
||||
assert len(outputs) >= 2
|
||||
assert len(outputs[1].items) == 1
|
||||
assert isinstance(outputs[1].items[0], Document)
|
||||
assert outputs[1].items[0].id == "https://github.com/test-org/test-repo/pull/1"
|
||||
|
||||
assert outputs[-1].next_checkpoint.has_more is False
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_empty_repo(
|
||||
github_connector: GithubConnector,
|
||||
mock_github_client: MagicMock,
|
||||
create_mock_repo: Callable[..., MagicMock],
|
||||
) -> None:
|
||||
"""Test loading from checkpoint with an empty repository"""
|
||||
# Set up mocked repo
|
||||
mock_repo = create_mock_repo()
|
||||
github_connector.github_client = mock_github_client
|
||||
mock_github_client.get_repo.return_value = mock_repo
|
||||
|
||||
# Mock get_pulls and get_issues to return empty lists
|
||||
mock_repo.get_pulls.return_value = MagicMock()
|
||||
mock_repo.get_pulls.return_value.get_page.return_value = []
|
||||
mock_repo.get_issues.return_value = MagicMock()
|
||||
mock_repo.get_issues.return_value.get_page.return_value = []
|
||||
|
||||
# Mock SerializedRepository.to_Repository to return our mock repo
|
||||
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(
|
||||
github_connector, 0, end_time
|
||||
)
|
||||
|
||||
# Check that we got no documents
|
||||
assert len(outputs) == 2
|
||||
assert len(outputs[-1].items) == 0
|
||||
assert not outputs[-1].next_checkpoint.has_more
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_prs_only(
|
||||
github_connector: GithubConnector,
|
||||
mock_github_client: MagicMock,
|
||||
create_mock_repo: Callable[..., MagicMock],
|
||||
create_mock_pr: Callable[..., MagicMock],
|
||||
) -> None:
|
||||
"""Test loading from checkpoint with only PRs enabled"""
|
||||
# Configure connector to only include PRs
|
||||
github_connector.include_prs = True
|
||||
github_connector.include_issues = False
|
||||
|
||||
# Set up mocked repo
|
||||
mock_repo = create_mock_repo()
|
||||
github_connector.github_client = mock_github_client
|
||||
mock_github_client.get_repo.return_value = mock_repo
|
||||
|
||||
# Set up mocked PRs
|
||||
mock_pr1 = create_mock_pr(number=1, title="PR 1")
|
||||
mock_pr2 = create_mock_pr(number=2, title="PR 2")
|
||||
|
||||
# Mock get_pulls method
|
||||
mock_repo.get_pulls.return_value = MagicMock()
|
||||
mock_repo.get_pulls.return_value.get_page.side_effect = [
|
||||
[mock_pr1, mock_pr2],
|
||||
[],
|
||||
]
|
||||
|
||||
# Mock SerializedRepository.to_Repository to return our mock repo
|
||||
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(
|
||||
github_connector, 0, end_time
|
||||
)
|
||||
|
||||
# Check that we only got PRs
|
||||
assert len(outputs) >= 2
|
||||
assert len(outputs[1].items) == 2
|
||||
assert all(
|
||||
isinstance(doc, Document) and "pull" in doc.id for doc in outputs[0].items
|
||||
) # All documents should be PRs
|
||||
|
||||
assert outputs[-1].next_checkpoint.has_more is False
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_issues_only(
|
||||
github_connector: GithubConnector,
|
||||
mock_github_client: MagicMock,
|
||||
create_mock_repo: Callable[..., MagicMock],
|
||||
create_mock_issue: Callable[..., MagicMock],
|
||||
) -> None:
|
||||
"""Test loading from checkpoint with only issues enabled"""
|
||||
# Configure connector to only include issues
|
||||
github_connector.include_prs = False
|
||||
github_connector.include_issues = True
|
||||
|
||||
# Set up mocked repo
|
||||
mock_repo = create_mock_repo()
|
||||
github_connector.github_client = mock_github_client
|
||||
mock_github_client.get_repo.return_value = mock_repo
|
||||
|
||||
# Set up mocked issues
|
||||
mock_issue1 = create_mock_issue(number=1, title="Issue 1")
|
||||
mock_issue2 = create_mock_issue(number=2, title="Issue 2")
|
||||
|
||||
# Mock get_issues method
|
||||
mock_repo.get_issues.return_value = MagicMock()
|
||||
mock_repo.get_issues.return_value.get_page.side_effect = [
|
||||
[mock_issue1, mock_issue2],
|
||||
[],
|
||||
]
|
||||
|
||||
# Mock SerializedRepository.to_Repository to return our mock repo
|
||||
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(
|
||||
github_connector, 0, end_time
|
||||
)
|
||||
|
||||
# Check that we only got issues
|
||||
assert len(outputs) >= 2
|
||||
assert len(outputs[1].items) == 2
|
||||
assert all(
|
||||
isinstance(doc, Document) and "issues" in doc.id for doc in outputs[0].items
|
||||
) # All documents should be issues
|
||||
assert outputs[1].next_checkpoint.has_more
|
||||
|
||||
assert outputs[-1].next_checkpoint.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code,expected_exception,expected_message",
|
||||
[
|
||||
(
|
||||
401,
|
||||
CredentialExpiredError,
|
||||
"GitHub credential appears to be invalid or expired",
|
||||
),
|
||||
(
|
||||
403,
|
||||
InsufficientPermissionsError,
|
||||
"Your GitHub token does not have sufficient permissions",
|
||||
),
|
||||
(
|
||||
404,
|
||||
ConnectorValidationError,
|
||||
"GitHub repository not found",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_validate_connector_settings_errors(
|
||||
github_connector: GithubConnector,
|
||||
status_code: int,
|
||||
expected_exception: type[Exception],
|
||||
expected_message: str,
|
||||
) -> None:
|
||||
"""Test validation with various error scenarios"""
|
||||
error = GithubException(status=status_code, data={}, headers={})
|
||||
|
||||
github_client = cast(Github, github_connector.github_client)
|
||||
get_repo_mock = cast(MagicMock, github_client.get_repo)
|
||||
get_repo_mock.side_effect = error
|
||||
|
||||
with pytest.raises(expected_exception) as excinfo:
|
||||
github_connector.validate_connector_settings()
|
||||
assert expected_message in str(excinfo.value)
|
||||
|
||||
|
||||
def test_validate_connector_settings_success(
|
||||
github_connector: GithubConnector,
|
||||
mock_github_client: MagicMock,
|
||||
create_mock_repo: Callable[..., MagicMock],
|
||||
) -> None:
|
||||
"""Test successful validation"""
|
||||
# Set up mocked repo
|
||||
mock_repo = create_mock_repo()
|
||||
github_connector.github_client = mock_github_client
|
||||
mock_github_client.get_repo.return_value = mock_repo
|
||||
|
||||
# Mock get_contents to simulate successful access
|
||||
mock_repo.get_contents.return_value = MagicMock()
|
||||
|
||||
github_connector.validate_connector_settings()
|
||||
github_connector.github_client.get_repo.assert_called_once_with(
|
||||
f"{github_connector.repo_owner}/{github_connector.repositories}"
|
||||
)
|
||||
@@ -1,436 +0,0 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from jira import JIRA
|
||||
from jira import JIRAError
|
||||
from jira.resources import Issue
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.onyx_jira.connector import JiraConnector
|
||||
from onyx.connectors.onyx_jira.connector import JiraConnectorCheckpoint
|
||||
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
|
||||
|
||||
PAGE_SIZE = 2
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def jira_base_url() -> str:
|
||||
return "https://jira.example.com"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project_key() -> str:
|
||||
return "TEST"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jira_client() -> MagicMock:
|
||||
"""Create a mock JIRA client with proper typing"""
|
||||
mock = MagicMock(spec=JIRA)
|
||||
# Add proper return typing for search_issues method
|
||||
mock.search_issues = MagicMock()
|
||||
# Add proper return typing for project method
|
||||
mock.project = MagicMock()
|
||||
# Add proper return typing for projects method
|
||||
mock.projects = MagicMock()
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def jira_connector(
|
||||
jira_base_url: str, project_key: str, mock_jira_client: MagicMock
|
||||
) -> Generator[JiraConnector, None, None]:
|
||||
connector = JiraConnector(
|
||||
jira_base_url=jira_base_url,
|
||||
project_key=project_key,
|
||||
comment_email_blacklist=["blacklist@example.com"],
|
||||
labels_to_skip=["secret", "sensitive"],
|
||||
)
|
||||
connector._jira_client = mock_jira_client
|
||||
connector._jira_client.client_info.return_value = jira_base_url
|
||||
with patch("onyx.connectors.onyx_jira.connector._JIRA_FULL_PAGE_SIZE", 2):
|
||||
yield connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_mock_issue() -> Callable[..., MagicMock]:
|
||||
def _create_mock_issue(
|
||||
key: str = "TEST-123",
|
||||
summary: str = "Test Issue",
|
||||
updated: str = "2023-01-01T12:00:00.000+0000",
|
||||
description: str = "Test Description",
|
||||
labels: list[str] | None = None,
|
||||
) -> MagicMock:
|
||||
"""Helper to create a mock Issue object"""
|
||||
mock_issue = MagicMock(spec=Issue)
|
||||
# Create fields attribute first
|
||||
mock_issue.fields = MagicMock()
|
||||
mock_issue.key = key
|
||||
mock_issue.fields.summary = summary
|
||||
mock_issue.fields.updated = updated
|
||||
mock_issue.fields.description = description
|
||||
mock_issue.fields.labels = labels or []
|
||||
|
||||
# Set up creator and assignee for testing owner extraction
|
||||
mock_issue.fields.creator = MagicMock()
|
||||
mock_issue.fields.creator.displayName = "Test Creator"
|
||||
mock_issue.fields.creator.emailAddress = "creator@example.com"
|
||||
|
||||
mock_issue.fields.assignee = MagicMock()
|
||||
mock_issue.fields.assignee.displayName = "Test Assignee"
|
||||
mock_issue.fields.assignee.emailAddress = "assignee@example.com"
|
||||
|
||||
# Set up priority, status, and resolution
|
||||
mock_issue.fields.priority = MagicMock()
|
||||
mock_issue.fields.priority.name = "High"
|
||||
|
||||
mock_issue.fields.status = MagicMock()
|
||||
mock_issue.fields.status.name = "In Progress"
|
||||
|
||||
mock_issue.fields.resolution = MagicMock()
|
||||
mock_issue.fields.resolution.name = "Fixed"
|
||||
|
||||
# Add raw field for accessing through API version check
|
||||
mock_issue.raw = {"fields": {"description": description}}
|
||||
|
||||
return mock_issue
|
||||
|
||||
return _create_mock_issue
|
||||
|
||||
|
||||
def test_load_credentials(jira_connector: JiraConnector) -> None:
|
||||
"""Test loading credentials"""
|
||||
with patch(
|
||||
"onyx.connectors.onyx_jira.connector.build_jira_client"
|
||||
) as mock_build_client:
|
||||
mock_build_client.return_value = jira_connector._jira_client
|
||||
credentials = {
|
||||
"jira_user_email": "user@example.com",
|
||||
"jira_api_token": "token123",
|
||||
}
|
||||
|
||||
result = jira_connector.load_credentials(credentials)
|
||||
|
||||
mock_build_client.assert_called_once_with(
|
||||
credentials=credentials, jira_base=jira_connector.jira_base
|
||||
)
|
||||
assert result is None
|
||||
assert jira_connector._jira_client == mock_build_client.return_value
|
||||
|
||||
|
||||
def test_get_jql_query_with_project(jira_connector: JiraConnector) -> None:
|
||||
"""Test JQL query generation with project specified"""
|
||||
start = datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
end = datetime(2023, 1, 2, tzinfo=timezone.utc).timestamp()
|
||||
|
||||
query = jira_connector._get_jql_query(start, end)
|
||||
|
||||
# Check that the project part and time part are both in the query
|
||||
assert f'project = "{jira_connector.jira_project}"' in query
|
||||
assert "updated >= '2023-01-01 00:00'" in query
|
||||
assert "updated <= '2023-01-02 00:00'" in query
|
||||
assert " AND " in query
|
||||
|
||||
|
||||
def test_get_jql_query_without_project(jira_base_url: str) -> None:
|
||||
"""Test JQL query generation without project specified"""
|
||||
# Create connector without project key
|
||||
connector = JiraConnector(jira_base_url=jira_base_url)
|
||||
|
||||
start = datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
end = datetime(2023, 1, 2, tzinfo=timezone.utc).timestamp()
|
||||
|
||||
query = connector._get_jql_query(start, end)
|
||||
|
||||
# Check that only time part is in the query
|
||||
assert "project =" not in query
|
||||
assert "updated >= '2023-01-01 00:00'" in query
|
||||
assert "updated <= '2023-01-02 00:00'" in query
|
||||
|
||||
|
||||
def test_load_from_checkpoint_happy_path(
|
||||
jira_connector: JiraConnector, create_mock_issue: Callable[..., MagicMock]
|
||||
) -> None:
|
||||
"""Test loading from checkpoint - happy path"""
|
||||
# Set up mocked issues
|
||||
mock_issue1 = create_mock_issue(key="TEST-1", summary="Issue 1")
|
||||
mock_issue2 = create_mock_issue(key="TEST-2", summary="Issue 2")
|
||||
mock_issue3 = create_mock_issue(key="TEST-3", summary="Issue 3")
|
||||
|
||||
# Only mock the search_issues method
|
||||
jira_client = cast(JIRA, jira_connector._jira_client)
|
||||
search_issues_mock = cast(MagicMock, jira_client.search_issues)
|
||||
search_issues_mock.side_effect = [
|
||||
[mock_issue1, mock_issue2],
|
||||
[mock_issue3],
|
||||
[],
|
||||
]
|
||||
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(jira_connector, 0, end_time)
|
||||
|
||||
# Check that the documents were returned
|
||||
assert len(outputs) == 2
|
||||
|
||||
checkpoint_output1 = outputs[0]
|
||||
assert len(checkpoint_output1.items) == 2
|
||||
document1 = checkpoint_output1.items[0]
|
||||
assert isinstance(document1, Document)
|
||||
assert document1.id == "https://jira.example.com/browse/TEST-1"
|
||||
document2 = checkpoint_output1.items[1]
|
||||
assert isinstance(document2, Document)
|
||||
assert document2.id == "https://jira.example.com/browse/TEST-2"
|
||||
assert checkpoint_output1.next_checkpoint == JiraConnectorCheckpoint(
|
||||
offset=2,
|
||||
has_more=True,
|
||||
)
|
||||
|
||||
checkpoint_output2 = outputs[1]
|
||||
assert len(checkpoint_output2.items) == 1
|
||||
document3 = checkpoint_output2.items[0]
|
||||
assert isinstance(document3, Document)
|
||||
assert document3.id == "https://jira.example.com/browse/TEST-3"
|
||||
assert checkpoint_output2.next_checkpoint == JiraConnectorCheckpoint(
|
||||
offset=3,
|
||||
has_more=False,
|
||||
)
|
||||
|
||||
# Check that search_issues was called with the right parameters
|
||||
assert search_issues_mock.call_count == 2
|
||||
args, kwargs = search_issues_mock.call_args_list[0]
|
||||
assert kwargs["startAt"] == 0
|
||||
assert kwargs["maxResults"] == PAGE_SIZE
|
||||
|
||||
args, kwargs = search_issues_mock.call_args_list[1]
|
||||
assert kwargs["startAt"] == 2
|
||||
assert kwargs["maxResults"] == PAGE_SIZE
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_issue_processing_error(
|
||||
jira_connector: JiraConnector, create_mock_issue: Callable[..., MagicMock]
|
||||
) -> None:
|
||||
"""Test loading from checkpoint with a mix of successful and failed issue processing across multiple batches"""
|
||||
# Set up mocked issues for first batch
|
||||
mock_issue1 = create_mock_issue(key="TEST-1", summary="Issue 1")
|
||||
mock_issue2 = create_mock_issue(key="TEST-2", summary="Issue 2")
|
||||
# Set up mocked issues for second batch
|
||||
mock_issue3 = create_mock_issue(key="TEST-3", summary="Issue 3")
|
||||
mock_issue4 = create_mock_issue(key="TEST-4", summary="Issue 4")
|
||||
|
||||
# Mock search_issues to return our mock issues in batches
|
||||
jira_client = cast(JIRA, jira_connector._jira_client)
|
||||
search_issues_mock = cast(MagicMock, jira_client.search_issues)
|
||||
search_issues_mock.side_effect = [
|
||||
[mock_issue1, mock_issue2], # First batch
|
||||
[mock_issue3, mock_issue4], # Second batch
|
||||
[], # Empty batch to indicate end
|
||||
]
|
||||
|
||||
# Mock process_jira_issue to succeed for some issues and fail for others
|
||||
def mock_process_side_effect(
|
||||
jira_client: JIRA, issue: Issue, *args: Any, **kwargs: Any
|
||||
) -> Document | None:
|
||||
if issue.key in ["TEST-1", "TEST-3"]:
|
||||
return Document(
|
||||
id=f"https://jira.example.com/browse/{issue.key}",
|
||||
sections=[],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
|
||||
title=f"{issue.key} {issue.fields.summary}",
|
||||
metadata={},
|
||||
)
|
||||
else:
|
||||
raise Exception(f"Processing error for {issue.key}")
|
||||
|
||||
with patch(
|
||||
"onyx.connectors.onyx_jira.connector.process_jira_issue"
|
||||
) as mock_process:
|
||||
mock_process.side_effect = mock_process_side_effect
|
||||
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(jira_connector, 0, end_time)
|
||||
|
||||
assert len(outputs) == 3
|
||||
|
||||
# Check first batch
|
||||
first_batch = outputs[0]
|
||||
assert len(first_batch.items) == 2
|
||||
# First item should be successful
|
||||
assert isinstance(first_batch.items[0], Document)
|
||||
assert first_batch.items[0].id == "https://jira.example.com/browse/TEST-1"
|
||||
# Second item should be a failure
|
||||
assert isinstance(first_batch.items[1], ConnectorFailure)
|
||||
assert first_batch.items[1].failed_document is not None
|
||||
assert first_batch.items[1].failed_document.document_id == "TEST-2"
|
||||
assert "Failed to process Jira issue" in first_batch.items[1].failure_message
|
||||
# Check checkpoint indicates more items (full batch)
|
||||
assert first_batch.next_checkpoint.has_more is True
|
||||
assert first_batch.next_checkpoint.offset == 2
|
||||
|
||||
# Check second batch
|
||||
second_batch = outputs[1]
|
||||
assert len(second_batch.items) == 2
|
||||
# First item should be successful
|
||||
assert isinstance(second_batch.items[0], Document)
|
||||
assert second_batch.items[0].id == "https://jira.example.com/browse/TEST-3"
|
||||
# Second item should be a failure
|
||||
assert isinstance(second_batch.items[1], ConnectorFailure)
|
||||
assert second_batch.items[1].failed_document is not None
|
||||
assert second_batch.items[1].failed_document.document_id == "TEST-4"
|
||||
assert "Failed to process Jira issue" in second_batch.items[1].failure_message
|
||||
# Check checkpoint indicates more items
|
||||
assert second_batch.next_checkpoint.has_more is True
|
||||
assert second_batch.next_checkpoint.offset == 4
|
||||
|
||||
# Check third, empty batch
|
||||
third_batch = outputs[2]
|
||||
assert len(third_batch.items) == 0
|
||||
assert third_batch.next_checkpoint.has_more is False
|
||||
assert third_batch.next_checkpoint.offset == 4
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_skipped_issue(
|
||||
jira_connector: JiraConnector, create_mock_issue: Callable[..., MagicMock]
|
||||
) -> None:
|
||||
"""Test loading from checkpoint with an issue that should be skipped due to labels"""
|
||||
LABEL_TO_SKIP = "secret"
|
||||
jira_connector.labels_to_skip = {LABEL_TO_SKIP}
|
||||
|
||||
# Set up mocked issue with a label to skip
|
||||
mock_issue = create_mock_issue(
|
||||
key="TEST-1", summary="Issue 1", labels=[LABEL_TO_SKIP]
|
||||
)
|
||||
|
||||
# Mock search_issues to return our mock issue
|
||||
jira_client = cast(JIRA, jira_connector._jira_client)
|
||||
search_issues_mock = cast(MagicMock, jira_client.search_issues)
|
||||
search_issues_mock.return_value = [mock_issue]
|
||||
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(jira_connector, 0, end_time)
|
||||
|
||||
assert len(outputs) == 1
|
||||
checkpoint_output = outputs[0]
|
||||
# Check that no documents were returned
|
||||
assert len(checkpoint_output.items) == 0
|
||||
|
||||
|
||||
def test_retrieve_all_slim_documents(
|
||||
jira_connector: JiraConnector, create_mock_issue: Any
|
||||
) -> None:
|
||||
"""Test retrieving all slim documents"""
|
||||
# Set up mocked issues
|
||||
mock_issue1 = create_mock_issue(key="TEST-1")
|
||||
mock_issue2 = create_mock_issue(key="TEST-2")
|
||||
|
||||
# Mock search_issues to return our mock issues
|
||||
jira_client = cast(JIRA, jira_connector._jira_client)
|
||||
search_issues_mock = cast(MagicMock, jira_client.search_issues)
|
||||
search_issues_mock.return_value = [mock_issue1, mock_issue2]
|
||||
|
||||
# Mock best_effort_get_field_from_issue to return the keys
|
||||
with patch(
|
||||
"onyx.connectors.onyx_jira.connector.best_effort_get_field_from_issue"
|
||||
) as mock_field:
|
||||
mock_field.side_effect = ["TEST-1", "TEST-2"]
|
||||
|
||||
# Mock build_jira_url to return URLs
|
||||
with patch("onyx.connectors.onyx_jira.connector.build_jira_url") as mock_url:
|
||||
mock_url.side_effect = [
|
||||
"https://jira.example.com/browse/TEST-1",
|
||||
"https://jira.example.com/browse/TEST-2",
|
||||
]
|
||||
|
||||
# Call retrieve_all_slim_documents
|
||||
batches = list(jira_connector.retrieve_all_slim_documents(0, 100))
|
||||
|
||||
# Check that a batch with 2 documents was returned
|
||||
assert len(batches) == 1
|
||||
assert len(batches[0]) == 2
|
||||
assert isinstance(batches[0][0], SlimDocument)
|
||||
assert batches[0][0].id == "https://jira.example.com/browse/TEST-1"
|
||||
assert batches[0][1].id == "https://jira.example.com/browse/TEST-2"
|
||||
|
||||
# Check that search_issues was called with the right parameters
|
||||
search_issues_mock.assert_called_once()
|
||||
args, kwargs = search_issues_mock.call_args
|
||||
assert kwargs["fields"] == "key"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code,expected_exception,expected_message",
|
||||
[
|
||||
(
|
||||
401,
|
||||
CredentialExpiredError,
|
||||
"Jira credential appears to be expired or invalid",
|
||||
),
|
||||
(
|
||||
403,
|
||||
InsufficientPermissionsError,
|
||||
"Your Jira token does not have sufficient permissions",
|
||||
),
|
||||
(404, ConnectorValidationError, "Jira project not found"),
|
||||
(
|
||||
429,
|
||||
ConnectorValidationError,
|
||||
"Validation failed due to Jira rate-limits being exceeded",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_validate_connector_settings_errors(
|
||||
jira_connector: JiraConnector,
|
||||
status_code: int,
|
||||
expected_exception: type[Exception],
|
||||
expected_message: str,
|
||||
) -> None:
|
||||
"""Test validation with various error scenarios"""
|
||||
error = JIRAError(status_code=status_code)
|
||||
|
||||
jira_client = cast(JIRA, jira_connector._jira_client)
|
||||
project_mock = cast(MagicMock, jira_client.project)
|
||||
project_mock.side_effect = error
|
||||
|
||||
with pytest.raises(expected_exception) as excinfo:
|
||||
jira_connector.validate_connector_settings()
|
||||
assert expected_message in str(excinfo.value)
|
||||
|
||||
|
||||
def test_validate_connector_settings_with_project_success(
|
||||
jira_connector: JiraConnector,
|
||||
) -> None:
|
||||
"""Test successful validation with project specified"""
|
||||
jira_client = cast(JIRA, jira_connector._jira_client)
|
||||
project_mock = cast(MagicMock, jira_client.project)
|
||||
project_mock.return_value = MagicMock()
|
||||
jira_connector.validate_connector_settings()
|
||||
project_mock.assert_called_once_with(jira_connector.jira_project)
|
||||
|
||||
|
||||
def test_validate_connector_settings_without_project_success(
|
||||
jira_base_url: str,
|
||||
) -> None:
|
||||
"""Test successful validation without project specified"""
|
||||
connector = JiraConnector(jira_base_url=jira_base_url)
|
||||
connector._jira_client = MagicMock()
|
||||
connector._jira_client.projects.return_value = [MagicMock()]
|
||||
|
||||
connector.validate_connector_settings()
|
||||
connector._jira_client.projects.assert_called_once()
|
||||
@@ -7,8 +7,7 @@ import pytest
|
||||
from jira.resources import Issue
|
||||
from pytest_mock import MockFixture
|
||||
|
||||
from onyx.connectors.onyx_jira.connector import _perform_jql_search
|
||||
from onyx.connectors.onyx_jira.connector import process_jira_issue
|
||||
from onyx.connectors.onyx_jira.connector import fetch_jira_issues_batch
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -80,22 +79,14 @@ def test_fetch_jira_issues_batch_small_ticket(
|
||||
) -> None:
|
||||
mock_jira_client.search_issues.return_value = [mock_issue_small]
|
||||
|
||||
# First get the issues via pagination
|
||||
issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50))
|
||||
assert len(issues) == 1
|
||||
|
||||
# Then process each issue
|
||||
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
|
||||
docs = [doc for doc in docs if doc is not None] # Filter out None values
|
||||
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
|
||||
|
||||
assert len(docs) == 1
|
||||
doc = docs[0]
|
||||
assert doc is not None # Type assertion for mypy
|
||||
assert doc.id.endswith("/SMALL-1")
|
||||
assert doc.sections[0].text is not None
|
||||
assert "Small description" in doc.sections[0].text
|
||||
assert "Small comment 1" in doc.sections[0].text
|
||||
assert "Small comment 2" in doc.sections[0].text
|
||||
assert docs[0].id.endswith("/SMALL-1")
|
||||
assert docs[0].sections[0].text is not None
|
||||
assert "Small description" in docs[0].sections[0].text
|
||||
assert "Small comment 1" in docs[0].sections[0].text
|
||||
assert "Small comment 2" in docs[0].sections[0].text
|
||||
|
||||
|
||||
def test_fetch_jira_issues_batch_large_ticket(
|
||||
@@ -105,13 +96,7 @@ def test_fetch_jira_issues_batch_large_ticket(
|
||||
) -> None:
|
||||
mock_jira_client.search_issues.return_value = [mock_issue_large]
|
||||
|
||||
# First get the issues via pagination
|
||||
issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50))
|
||||
assert len(issues) == 1
|
||||
|
||||
# Then process each issue
|
||||
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
|
||||
docs = [doc for doc in docs if doc is not None] # Filter out None values
|
||||
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
|
||||
|
||||
assert len(docs) == 0 # The large ticket should be skipped
|
||||
|
||||
@@ -124,18 +109,10 @@ def test_fetch_jira_issues_batch_mixed_tickets(
|
||||
) -> None:
|
||||
mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]
|
||||
|
||||
# First get the issues via pagination
|
||||
issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50))
|
||||
assert len(issues) == 2
|
||||
|
||||
# Then process each issue
|
||||
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
|
||||
docs = [doc for doc in docs if doc is not None] # Filter out None values
|
||||
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
|
||||
|
||||
assert len(docs) == 1 # Only the small ticket should be included
|
||||
doc = docs[0]
|
||||
assert doc is not None # Type assertion for mypy
|
||||
assert doc.id.endswith("/SMALL-1")
|
||||
assert docs[0].id.endswith("/SMALL-1")
|
||||
|
||||
|
||||
@patch("onyx.connectors.onyx_jira.connector.JIRA_CONNECTOR_MAX_TICKET_SIZE", 50)
|
||||
@@ -147,12 +124,6 @@ def test_fetch_jira_issues_batch_custom_size_limit(
|
||||
) -> None:
|
||||
mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]
|
||||
|
||||
# First get the issues via pagination
|
||||
issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50))
|
||||
assert len(issues) == 2
|
||||
|
||||
# Then process each issue
|
||||
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
|
||||
docs = [doc for doc in docs if doc is not None] # Filter out None values
|
||||
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
|
||||
|
||||
assert len(docs) == 0 # Both tickets should be skipped due to the low size limit
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
from typing import cast
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.connectors.connector_runner import CheckpointOutputWrapper
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
|
||||
_ITERATION_LIMIT = 100_000
|
||||
|
||||
|
||||
CT = TypeVar("CT", bound=ConnectorCheckpoint)
|
||||
|
||||
|
||||
class SingleConnectorCallOutput(BaseModel, Generic[CT]):
|
||||
items: list[Document | ConnectorFailure]
|
||||
next_checkpoint: CT
|
||||
|
||||
|
||||
def load_everything_from_checkpoint_connector(
|
||||
connector: CheckpointConnector[CT],
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
) -> list[SingleConnectorCallOutput[CT]]:
|
||||
num_iterations = 0
|
||||
|
||||
checkpoint = cast(CT, connector.build_dummy_checkpoint())
|
||||
outputs: list[SingleConnectorCallOutput[CT]] = []
|
||||
while checkpoint.has_more:
|
||||
items: list[Document | ConnectorFailure] = []
|
||||
doc_batch_generator = CheckpointOutputWrapper[CT]()(
|
||||
connector.load_from_checkpoint(start, end, checkpoint)
|
||||
)
|
||||
for document, failure, next_checkpoint in doc_batch_generator:
|
||||
if failure is not None:
|
||||
items.append(failure)
|
||||
if document is not None:
|
||||
items.append(document)
|
||||
if next_checkpoint is not None:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
outputs.append(
|
||||
SingleConnectorCallOutput(items=items, next_checkpoint=checkpoint)
|
||||
)
|
||||
|
||||
num_iterations += 1
|
||||
if num_iterations > _ITERATION_LIMIT:
|
||||
raise RuntimeError("Too many iterations. Infinite loop?")
|
||||
|
||||
return outputs
|
||||
@@ -1,472 +0,0 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from unittest.mock import call
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.zendesk.connector import ZendeskClient
|
||||
from onyx.connectors.zendesk.connector import ZendeskConnector
|
||||
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_zendesk_client() -> MagicMock:
|
||||
"""Create a mock Zendesk client"""
|
||||
mock = MagicMock(spec=ZendeskClient)
|
||||
mock.base_url = "https://test.zendesk.com/api/v2"
|
||||
mock.auth = ("test@example.com/token", "test_token")
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zendesk_connector(
|
||||
mock_zendesk_client: MagicMock,
|
||||
) -> Generator[ZendeskConnector, None, None]:
|
||||
"""Create a Zendesk connector with mocked client"""
|
||||
connector = ZendeskConnector(content_type="articles")
|
||||
connector.client = mock_zendesk_client
|
||||
yield connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unmocked_zendesk_connector() -> Generator[ZendeskConnector, None, None]:
|
||||
"""Create a Zendesk connector with unmocked client"""
|
||||
zendesk_connector = ZendeskConnector(content_type="articles")
|
||||
zendesk_connector.client = ZendeskClient(
|
||||
"test", "test@example.com/token", "test_token"
|
||||
)
|
||||
yield zendesk_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_mock_article() -> Callable[..., dict[str, Any]]:
|
||||
def _create_mock_article(
|
||||
id: int = 1,
|
||||
title: str = "Test Article",
|
||||
body: str = "Test Content",
|
||||
updated_at: str = "2023-01-01T12:00:00Z",
|
||||
author_id: str = "123",
|
||||
label_names: list[str] | None = None,
|
||||
draft: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Helper to create a mock article"""
|
||||
return {
|
||||
"id": id,
|
||||
"title": title,
|
||||
"body": body,
|
||||
"updated_at": updated_at,
|
||||
"author_id": author_id,
|
||||
"label_names": label_names or [],
|
||||
"draft": draft,
|
||||
"html_url": f"https://test.zendesk.com/hc/en-us/articles/{id}",
|
||||
}
|
||||
|
||||
return _create_mock_article
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_mock_ticket() -> Callable[..., dict[str, Any]]:
|
||||
def _create_mock_ticket(
|
||||
id: int = 1,
|
||||
subject: str = "Test Ticket",
|
||||
description: str = "Test Description",
|
||||
updated_at: str = "2023-01-01T12:00:00Z",
|
||||
submitter_id: str = "123",
|
||||
status: str = "open",
|
||||
priority: str = "normal",
|
||||
tags: list[str] | None = None,
|
||||
ticket_type: str = "question",
|
||||
) -> dict[str, Any]:
|
||||
"""Helper to create a mock ticket"""
|
||||
return {
|
||||
"id": id,
|
||||
"subject": subject,
|
||||
"description": description,
|
||||
"updated_at": updated_at,
|
||||
"submitter": submitter_id,
|
||||
"status": status,
|
||||
"priority": priority,
|
||||
"tags": tags or [],
|
||||
"type": ticket_type,
|
||||
"url": f"https://test.zendesk.com/agent/tickets/{id}",
|
||||
}
|
||||
|
||||
return _create_mock_ticket
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_mock_author() -> Callable[..., dict[str, Any]]:
|
||||
def _create_mock_author(
|
||||
id: str = "123",
|
||||
name: str = "Test User",
|
||||
email: str = "test@example.com",
|
||||
) -> dict[str, Any]:
|
||||
"""Helper to create a mock author"""
|
||||
return {
|
||||
"user": {
|
||||
"id": id,
|
||||
"name": name,
|
||||
"email": email,
|
||||
}
|
||||
}
|
||||
|
||||
return _create_mock_author
|
||||
|
||||
|
||||
def test_load_from_checkpoint_articles_happy_path(
|
||||
zendesk_connector: ZendeskConnector,
|
||||
mock_zendesk_client: MagicMock,
|
||||
create_mock_article: Callable[..., dict[str, Any]],
|
||||
create_mock_author: Callable[..., dict[str, Any]],
|
||||
) -> None:
|
||||
"""Test loading articles from checkpoint - happy path"""
|
||||
# Set up mock responses
|
||||
mock_article1 = create_mock_article(id=1, title="Article 1")
|
||||
mock_article2 = create_mock_article(id=2, title="Article 2")
|
||||
mock_author = create_mock_author()
|
||||
|
||||
# Mock API responses
|
||||
mock_zendesk_client.make_request.side_effect = [
|
||||
# First call: content tags
|
||||
{"records": []},
|
||||
# Second call: articles page
|
||||
{
|
||||
"articles": [mock_article1, mock_article2],
|
||||
"meta": {
|
||||
"has_more": False,
|
||||
"after_cursor": None,
|
||||
},
|
||||
},
|
||||
# Third call: author info
|
||||
mock_author,
|
||||
]
|
||||
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(zendesk_connector, 0, end_time)
|
||||
|
||||
# Check that we got the documents
|
||||
assert len(outputs) == 2
|
||||
assert outputs[0].next_checkpoint.cached_content_tags is not None
|
||||
|
||||
assert len(outputs[1].items) == 2
|
||||
|
||||
# Check first document
|
||||
doc1 = outputs[1].items[0]
|
||||
assert isinstance(doc1, Document)
|
||||
assert doc1.id == "article:1"
|
||||
assert doc1.semantic_identifier == "Article 1"
|
||||
assert doc1.source == DocumentSource.ZENDESK
|
||||
|
||||
# Check second document
|
||||
doc2 = outputs[1].items[1]
|
||||
assert isinstance(doc2, Document)
|
||||
assert doc2.id == "article:2"
|
||||
assert doc2.semantic_identifier == "Article 2"
|
||||
assert doc2.source == DocumentSource.ZENDESK
|
||||
|
||||
# Check checkpoint state
|
||||
assert not outputs[1].next_checkpoint.has_more
|
||||
|
||||
|
||||
def test_load_from_checkpoint_tickets_happy_path(
|
||||
zendesk_connector: ZendeskConnector,
|
||||
mock_zendesk_client: MagicMock,
|
||||
create_mock_ticket: Callable[..., dict[str, Any]],
|
||||
create_mock_author: Callable[..., dict[str, Any]],
|
||||
) -> None:
|
||||
"""Test loading tickets from checkpoint - happy path"""
|
||||
# Configure connector for tickets
|
||||
zendesk_connector.content_type = "tickets"
|
||||
|
||||
# Set up mock responses
|
||||
mock_ticket1 = create_mock_ticket(id=1, subject="Ticket 1")
|
||||
mock_ticket2 = create_mock_ticket(id=2, subject="Ticket 2")
|
||||
mock_author = create_mock_author()
|
||||
|
||||
# Mock API responses
|
||||
mock_zendesk_client.make_request.side_effect = [
|
||||
# First call: content tags
|
||||
{"records": []},
|
||||
# Second call: tickets page
|
||||
{
|
||||
"tickets": [mock_ticket1, mock_ticket2],
|
||||
"end_of_stream": True,
|
||||
"end_time": int(time.time()),
|
||||
},
|
||||
# Third call: author info
|
||||
mock_author,
|
||||
# Fourth call: comments page
|
||||
{"comments": []},
|
||||
# Fifth call: comments page
|
||||
{"comments": []},
|
||||
]
|
||||
|
||||
zendesk_connector.client = mock_zendesk_client
|
||||
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(zendesk_connector, 0, end_time)
|
||||
|
||||
# Check that we got the documents
|
||||
assert len(outputs) == 2
|
||||
assert outputs[0].next_checkpoint.cached_content_tags is not None
|
||||
assert len(outputs[1].items) == 2
|
||||
|
||||
# Check first document
|
||||
doc1 = outputs[1].items[0]
|
||||
print(doc1, type(doc1))
|
||||
assert isinstance(doc1, Document)
|
||||
assert doc1.id == "zendesk_ticket_1"
|
||||
assert doc1.semantic_identifier == "Ticket #1: Ticket 1"
|
||||
assert doc1.source == DocumentSource.ZENDESK
|
||||
|
||||
# Check second document
|
||||
doc2 = outputs[1].items[1]
|
||||
assert isinstance(doc2, Document)
|
||||
assert doc2.id == "zendesk_ticket_2"
|
||||
assert doc2.semantic_identifier == "Ticket #2: Ticket 2"
|
||||
assert doc2.source == DocumentSource.ZENDESK
|
||||
|
||||
# Check checkpoint state
|
||||
assert not outputs[1].next_checkpoint.has_more
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_rate_limit(
|
||||
unmocked_zendesk_connector: ZendeskConnector,
|
||||
create_mock_article: Callable[..., dict[str, Any]],
|
||||
create_mock_author: Callable[..., dict[str, Any]],
|
||||
) -> None:
|
||||
"""Test loading from checkpoint with rate limit handling"""
|
||||
zendesk_connector = unmocked_zendesk_connector
|
||||
# Set up mock responses
|
||||
mock_article = create_mock_article()
|
||||
mock_author = create_mock_author()
|
||||
author_response = MagicMock()
|
||||
author_response.status_code = 200
|
||||
author_response.json.return_value = mock_author
|
||||
|
||||
# Create mock responses for requests.get
|
||||
rate_limit_response = MagicMock()
|
||||
rate_limit_response.status_code = 429
|
||||
rate_limit_response.headers = {"Retry-After": "60"}
|
||||
rate_limit_response.raise_for_status.side_effect = HTTPError(
|
||||
response=rate_limit_response
|
||||
)
|
||||
|
||||
success_response = MagicMock()
|
||||
success_response.status_code = 200
|
||||
success_response.json.return_value = {
|
||||
"articles": [mock_article],
|
||||
"meta": {
|
||||
"has_more": False,
|
||||
"after_cursor": None,
|
||||
},
|
||||
}
|
||||
|
||||
# Mock requests.get to simulate rate limit then success
|
||||
with patch("onyx.connectors.zendesk.connector.requests.get") as mock_get:
|
||||
mock_get.side_effect = [
|
||||
# First call: content tags
|
||||
MagicMock(
|
||||
status_code=200,
|
||||
json=lambda: {"records": [], "meta": {"has_more": False}},
|
||||
),
|
||||
# Second call: articles page (rate limited)
|
||||
rate_limit_response,
|
||||
# Third call: articles page (after rate limit)
|
||||
success_response,
|
||||
# Fourth call: author info
|
||||
author_response,
|
||||
]
|
||||
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
with patch("onyx.connectors.zendesk.connector.time.sleep") as mock_sleep:
|
||||
outputs = load_everything_from_checkpoint_connector(
|
||||
zendesk_connector, 0, end_time
|
||||
)
|
||||
mock_sleep.assert_has_calls([call(60), call(0.1)])
|
||||
|
||||
# Check that we got the document after rate limit was handled
|
||||
assert len(outputs) == 2
|
||||
assert outputs[0].next_checkpoint.cached_content_tags is not None
|
||||
assert len(outputs[1].items) == 1
|
||||
assert isinstance(outputs[1].items[0], Document)
|
||||
assert outputs[1].items[0].id == "article:1"
|
||||
|
||||
# Verify the requests were made with correct parameters
|
||||
assert mock_get.call_count == 4
|
||||
# First call should be for content tags
|
||||
args, kwargs = mock_get.call_args_list[0]
|
||||
assert "guide/content_tags" in args[0]
|
||||
# Second call should be for articles (rate limited)
|
||||
args, kwargs = mock_get.call_args_list[1]
|
||||
assert "help_center/articles" in args[0]
|
||||
# Third call should be for articles (success)
|
||||
args, kwargs = mock_get.call_args_list[2]
|
||||
assert "help_center/articles" in args[0]
|
||||
# Fourth call should be for author info
|
||||
args, kwargs = mock_get.call_args_list[3]
|
||||
assert "users/123" in args[0]
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_empty_response(
|
||||
zendesk_connector: ZendeskConnector,
|
||||
mock_zendesk_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test loading from checkpoint with empty response"""
|
||||
# Mock API responses
|
||||
mock_zendesk_client.make_request.side_effect = [
|
||||
# First call: content tags
|
||||
{"records": []},
|
||||
# Second call: empty articles page
|
||||
{
|
||||
"articles": [],
|
||||
"meta": {
|
||||
"has_more": False,
|
||||
"after_cursor": None,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(zendesk_connector, 0, end_time)
|
||||
|
||||
# Check that we got no documents
|
||||
assert len(outputs) == 2
|
||||
assert outputs[0].next_checkpoint.cached_content_tags is not None
|
||||
assert len(outputs[1].items) == 0
|
||||
assert not outputs[1].next_checkpoint.has_more
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_skipped_article(
|
||||
zendesk_connector: ZendeskConnector,
|
||||
mock_zendesk_client: MagicMock,
|
||||
create_mock_article: Callable[..., dict[str, Any]],
|
||||
) -> None:
|
||||
"""Test loading from checkpoint with an article that should be skipped"""
|
||||
# Set up mock responses with a draft article
|
||||
mock_article = create_mock_article(draft=True)
|
||||
mock_zendesk_client.make_request.side_effect = [
|
||||
# First call: content tags
|
||||
{"records": []},
|
||||
# Second call: articles page with draft article
|
||||
{
|
||||
"articles": [mock_article],
|
||||
"meta": {
|
||||
"has_more": False,
|
||||
"after_cursor": None,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(zendesk_connector, 0, end_time)
|
||||
|
||||
# Check that no documents were returned
|
||||
assert len(outputs) == 2
|
||||
assert outputs[0].next_checkpoint.cached_content_tags is not None
|
||||
assert len(outputs[1].items) == 0
|
||||
assert not outputs[1].next_checkpoint.has_more
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_skipped_ticket(
|
||||
zendesk_connector: ZendeskConnector,
|
||||
mock_zendesk_client: MagicMock,
|
||||
create_mock_ticket: Callable[..., dict[str, Any]],
|
||||
) -> None:
|
||||
"""Test loading from checkpoint with a deleted ticket"""
|
||||
# Configure connector for tickets
|
||||
zendesk_connector.content_type = "tickets"
|
||||
|
||||
# Set up mock responses with a deleted ticket
|
||||
mock_ticket = create_mock_ticket(status="deleted")
|
||||
mock_zendesk_client.make_request.side_effect = [
|
||||
# First call: content tags
|
||||
{"records": []},
|
||||
# Second call: tickets page with deleted ticket
|
||||
{
|
||||
"tickets": [mock_ticket],
|
||||
"end_of_stream": True,
|
||||
"end_time": int(time.time()),
|
||||
},
|
||||
]
|
||||
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(zendesk_connector, 0, end_time)
|
||||
|
||||
# Check that no documents were returned
|
||||
assert len(outputs) == 2
|
||||
assert outputs[0].next_checkpoint.cached_content_tags is not None
|
||||
assert len(outputs[1].items) == 0
|
||||
assert not outputs[1].next_checkpoint.has_more
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code,expected_exception,expected_message",
|
||||
[
|
||||
(
|
||||
401,
|
||||
CredentialExpiredError,
|
||||
"Your Zendesk credentials appear to be invalid or expired",
|
||||
),
|
||||
(
|
||||
403,
|
||||
InsufficientPermissionsError,
|
||||
"Your Zendesk token does not have sufficient permissions",
|
||||
),
|
||||
(
|
||||
404,
|
||||
ConnectorValidationError,
|
||||
"Zendesk resource not found",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_validate_connector_settings_errors(
|
||||
zendesk_connector: ZendeskConnector,
|
||||
status_code: int,
|
||||
expected_exception: type[Exception],
|
||||
expected_message: str,
|
||||
) -> None:
|
||||
"""Test validation with various error scenarios"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
error = HTTPError(response=mock_response)
|
||||
|
||||
mock_zendesk_client = cast(MagicMock, zendesk_connector.client)
|
||||
mock_zendesk_client.make_request.side_effect = error
|
||||
|
||||
with pytest.raises(expected_exception) as excinfo:
|
||||
print("excinfo", excinfo)
|
||||
zendesk_connector.validate_connector_settings()
|
||||
|
||||
assert expected_message in str(excinfo.value)
|
||||
|
||||
|
||||
def test_validate_connector_settings_success(
|
||||
zendesk_connector: ZendeskConnector,
|
||||
mock_zendesk_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test successful validation"""
|
||||
# Mock successful API response
|
||||
mock_zendesk_client.make_request.return_value = {
|
||||
"articles": [],
|
||||
"meta": {"has_more": False},
|
||||
}
|
||||
|
||||
zendesk_connector.validate_connector_settings()
|
||||
@@ -1,16 +1,10 @@
|
||||
import contextvars
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.utils.threadpool_concurrency import parallel_yield
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.threadpool_concurrency import ThreadSafeDict
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
# Create a context variable for testing
|
||||
@@ -154,237 +148,3 @@ def test_multiple_background_tasks() -> None:
|
||||
|
||||
# Verify tasks ran in parallel (total time should be ~0.2s, not ~0.6s)
|
||||
assert 0.2 <= elapsed < 0.4 # Allow some buffer for test environment variations
|
||||
|
||||
|
||||
def test_thread_safe_dict_basic_operations() -> None:
|
||||
"""Test basic operations of ThreadSafeDict"""
|
||||
d = ThreadSafeDict[str, int]()
|
||||
|
||||
# Test setting and getting
|
||||
d["a"] = 1
|
||||
assert d["a"] == 1
|
||||
|
||||
# Test get with default
|
||||
assert d.get("a", None) == 1
|
||||
assert d.get("b", 2) == 2
|
||||
|
||||
# Test deletion
|
||||
del d["a"]
|
||||
assert "a" not in d
|
||||
|
||||
# Test length
|
||||
d["x"] = 10
|
||||
d["y"] = 20
|
||||
assert len(d) == 2
|
||||
|
||||
# Test iteration
|
||||
keys = sorted(d.keys())
|
||||
assert keys == ["x", "y"]
|
||||
|
||||
# Test items and values
|
||||
assert dict(d.items()) == {"x": 10, "y": 20}
|
||||
assert sorted(d.values()) == [10, 20]
|
||||
|
||||
|
||||
def test_thread_safe_dict_concurrent_access() -> None:
|
||||
"""Test ThreadSafeDict with concurrent access from multiple threads"""
|
||||
d = ThreadSafeDict[str, int]()
|
||||
num_threads = 10
|
||||
iterations = 1000
|
||||
|
||||
def increment_values() -> None:
|
||||
for i in range(iterations):
|
||||
key = str(i % 5) # Use 5 different keys
|
||||
# Get current value or 0 if not exists, increment, then store
|
||||
d[key] = d.get(key, 0) + 1
|
||||
|
||||
# Create and start threads
|
||||
threads = []
|
||||
for _ in range(num_threads):
|
||||
t = threading.Thread(target=increment_values)
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
# Wait for all threads to complete
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Verify results
|
||||
# Each key should have been incremented (num_threads * iterations) / 5 times
|
||||
expected_value = (num_threads * iterations) // 5
|
||||
for i in range(5):
|
||||
assert d[str(i)] == expected_value
|
||||
|
||||
|
||||
def test_thread_safe_dict_bulk_operations() -> None:
|
||||
"""Test bulk operations of ThreadSafeDict"""
|
||||
d = ThreadSafeDict[str, int]()
|
||||
|
||||
# Test update with dict
|
||||
d.update({"a": 1, "b": 2})
|
||||
assert dict(d.items()) == {"a": 1, "b": 2}
|
||||
|
||||
# Test update with kwargs
|
||||
d.update(c=3, d=4)
|
||||
assert dict(d.items()) == {"a": 1, "b": 2, "c": 3, "d": 4}
|
||||
|
||||
# Test clear
|
||||
d.clear()
|
||||
assert len(d) == 0
|
||||
|
||||
|
||||
def test_thread_safe_dict_concurrent_bulk_operations() -> None:
|
||||
"""Test ThreadSafeDict with concurrent bulk operations"""
|
||||
d = ThreadSafeDict[str, int]()
|
||||
num_threads = 5
|
||||
|
||||
def bulk_update(start: int) -> None:
|
||||
# Each thread updates with its own range of numbers
|
||||
updates = {str(i): i for i in range(start, start + 20)}
|
||||
d.update(updates)
|
||||
time.sleep(0.01) # Add some delay to increase chance of thread overlap
|
||||
|
||||
# Run updates concurrently
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(bulk_update, i * 20) for i in range(num_threads)]
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
# Verify results
|
||||
assert len(d) == num_threads * 20
|
||||
# Verify all numbers from 0 to (num_threads * 20) are present
|
||||
for i in range(num_threads * 20):
|
||||
assert d[str(i)] == i
|
||||
|
||||
|
||||
def test_thread_safe_dict_atomic_operations() -> None:
|
||||
"""Test atomic operations with ThreadSafeDict's lock"""
|
||||
d = ThreadSafeDict[str, list[int]]()
|
||||
d["numbers"] = []
|
||||
|
||||
def append_numbers(start: int) -> None:
|
||||
numbers = d["numbers"]
|
||||
with d.lock:
|
||||
for i in range(start, start + 5):
|
||||
numbers.append(i)
|
||||
time.sleep(0.001) # Add delay to increase chance of thread overlap
|
||||
d["numbers"] = numbers
|
||||
|
||||
# Run concurrent append operations
|
||||
threads = []
|
||||
for i in range(4): # 4 threads, each adding 5 numbers
|
||||
t = threading.Thread(target=append_numbers, args=(i * 5,))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Verify results
|
||||
numbers = d["numbers"]
|
||||
assert len(numbers) == 20 # 4 threads * 5 numbers each
|
||||
assert sorted(numbers) == list(range(20)) # All numbers 0-19 should be present
|
||||
|
||||
|
||||
def test_parallel_yield_basic() -> None:
|
||||
"""Test that parallel_yield correctly yields values from multiple generators."""
|
||||
|
||||
def make_gen(values: list[int], delay: float) -> Generator[int, None, None]:
|
||||
for v in values:
|
||||
time.sleep(delay)
|
||||
yield v
|
||||
|
||||
# Create generators with different delays
|
||||
gen1 = make_gen([1, 4, 7], 0.1) # Slower generator
|
||||
gen2 = make_gen([2, 5, 8], 0.05) # Faster generator
|
||||
gen3 = make_gen([3, 6, 9], 0.15) # Slowest generator
|
||||
|
||||
# Collect results with timestamps
|
||||
results: list[tuple[float, int]] = []
|
||||
start_time = time.time()
|
||||
|
||||
for value in parallel_yield([gen1, gen2, gen3]):
|
||||
results.append((time.time() - start_time, value))
|
||||
|
||||
# Verify all values were yielded
|
||||
assert sorted(v for _, v in results) == list(range(1, 10))
|
||||
|
||||
# Verify that faster generators yielded earlier
|
||||
# Group results by generator (values 1,4,7 are gen1, 2,5,8 are gen2, 3,6,9 are gen3)
|
||||
gen1_times = [t for t, v in results if v in (1, 4, 7)]
|
||||
gen2_times = [t for t, v in results if v in (2, 5, 8)]
|
||||
gen3_times = [t for t, v in results if v in (3, 6, 9)]
|
||||
|
||||
# Average times for each generator
|
||||
avg_gen1 = sum(gen1_times) / len(gen1_times)
|
||||
avg_gen2 = sum(gen2_times) / len(gen2_times)
|
||||
avg_gen3 = sum(gen3_times) / len(gen3_times)
|
||||
|
||||
# Verify gen2 (fastest) has lowest average time
|
||||
assert avg_gen2 < avg_gen1
|
||||
assert avg_gen2 < avg_gen3
|
||||
|
||||
|
||||
def test_parallel_yield_empty_generators() -> None:
|
||||
"""Test parallel_yield with empty generators."""
|
||||
|
||||
def empty_gen() -> Iterator[int]:
|
||||
if False:
|
||||
yield 0 # Makes this a generator function
|
||||
|
||||
gens = [empty_gen() for _ in range(3)]
|
||||
results = list(parallel_yield(gens))
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_parallel_yield_different_lengths() -> None:
|
||||
"""Test parallel_yield with generators of different lengths."""
|
||||
|
||||
def make_gen(count: int) -> Iterator[int]:
|
||||
for i in range(count):
|
||||
yield i
|
||||
time.sleep(0.01) # Small delay to ensure concurrent execution
|
||||
|
||||
gens = [
|
||||
make_gen(1), # Yields: [0]
|
||||
make_gen(3), # Yields: [0, 1, 2]
|
||||
make_gen(2), # Yields: [0, 1]
|
||||
]
|
||||
|
||||
results = list(parallel_yield(gens))
|
||||
assert len(results) == 6 # Total number of items from all generators
|
||||
assert sorted(results) == [0, 0, 0, 1, 1, 2]
|
||||
|
||||
|
||||
def test_parallel_yield_exception_handling() -> None:
|
||||
"""Test parallel_yield handles exceptions in generators properly."""
|
||||
|
||||
def failing_gen() -> Iterator[int]:
|
||||
yield 1
|
||||
raise ValueError("Generator failure")
|
||||
|
||||
def normal_gen() -> Iterator[int]:
|
||||
yield 2
|
||||
yield 3
|
||||
|
||||
gens = [failing_gen(), normal_gen()]
|
||||
|
||||
with pytest.raises(ValueError, match="Generator failure"):
|
||||
list(parallel_yield(gens))
|
||||
|
||||
|
||||
def test_parallel_yield_non_blocking() -> None:
|
||||
"""Test parallel_yield with non-blocking generators (simple ranges)."""
|
||||
|
||||
def range_gen(start: int, end: int) -> Iterator[int]:
|
||||
for i in range(start, end):
|
||||
yield i
|
||||
|
||||
# Create three overlapping ranges
|
||||
gens = [range_gen(0, 100), range_gen(100, 200), range_gen(200, 300)]
|
||||
|
||||
results = list(parallel_yield(gens))
|
||||
|
||||
# Verify no values are missing
|
||||
assert len(results) == 300 # Should have all values from 0 to 299
|
||||
assert sorted(results) == list(range(300))
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user