Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-18 00:05:47 +00:00)

Compare commits: 16 commits, `checkpoint` ... `bugfixfix`

| Author | SHA1 | Date |
|---|---|---|
|  | 65bf67b2c1 |  |
|  | 55fb5c3ca5 |  |
|  | 99546e4a4d |  |
|  | c25d56f4a5 |  |
|  | 35f3f4f120 |  |
|  | 25b69a8aca |  |
|  | 1b7d710b2a |  |
|  | ae3d3db3f4 |  |
|  | fb79a9e700 |  |
|  | 587ba11bbc |  |
|  | fce81ebb60 |  |
|  | 61facfb0a8 |  |
|  | 52b96854a2 |  |
|  | d123713c00 |  |
|  | 775c847f82 |  |
|  | 6d330131fd |  |
```diff
@@ -84,7 +84,7 @@ keys = console
 keys = generic
 
 [logger_root]
-level = WARN
+level = INFO
 handlers = console
 qualname =
```
```diff
@@ -25,6 +25,9 @@ from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
 from onyx.db.models import Base
 from celery.backends.database.session import ResultModelBase  # type: ignore
 
+# Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be
+# hidden! (defaults to level=WARN)
+
 # Alembic Config object
 config = context.config
@@ -36,6 +39,7 @@ if config.config_file_name is not None and config.attributes.get(
 target_metadata = [Base.metadata, ResultModelBase.metadata]
 
 EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
 
+logger = logging.getLogger(__name__)
 
 ssl_context: ssl.SSLContext | None = None
@@ -64,7 +68,7 @@ def include_object(
         return True
 
 
-def get_schema_options() -> tuple[str, bool, bool]:
+def get_schema_options() -> tuple[str, bool, bool, bool]:
     x_args_raw = context.get_x_argument()
     x_args = {}
     for arg in x_args_raw:
@@ -76,6 +80,10 @@ def get_schema_options() -> tuple[str, bool, bool]:
     create_schema = x_args.get("create_schema", "true").lower() == "true"
     upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
 
+    # continue on error with individual tenant
+    # only applies to online migrations
+    continue_on_error = x_args.get("continue", "false").lower() == "true"
+
     if (
         MULTI_TENANT
         and schema_name == POSTGRES_DEFAULT_SCHEMA
@@ -86,14 +94,12 @@ def get_schema_options() -> tuple[str, bool, bool]:
             "Please specify a tenant-specific schema."
         )
 
-    return schema_name, create_schema, upgrade_all_tenants
+    return schema_name, create_schema, upgrade_all_tenants, continue_on_error
 
 
 def do_run_migrations(
     connection: Connection, schema_name: str, create_schema: bool
 ) -> None:
+    logger.info(f"About to migrate schema: {schema_name}")
+
     if create_schema:
         connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
         connection.execute(text("COMMIT"))
@@ -134,7 +140,12 @@ def provide_iam_token_for_alembic(
 
 
 async def run_async_migrations() -> None:
-    schema_name, create_schema, upgrade_all_tenants = get_schema_options()
+    (
+        schema_name,
+        create_schema,
+        upgrade_all_tenants,
+        continue_on_error,
+    ) = get_schema_options()
 
     engine = create_async_engine(
         build_connection_string(),
@@ -151,9 +162,15 @@ async def run_async_migrations() -> None:
 
     if upgrade_all_tenants:
         tenant_schemas = get_all_tenant_ids()
 
+        i_tenant = 0
+        num_tenants = len(tenant_schemas)
         for schema in tenant_schemas:
+            i_tenant += 1
+            logger.info(
+                f"Migrating schema: index={i_tenant} num_tenants={num_tenants} schema={schema}"
+            )
             try:
-                logger.info(f"Migrating schema: {schema}")
                 async with engine.connect() as connection:
                     await connection.run_sync(
                         do_run_migrations,
@@ -162,7 +179,12 @@ async def run_async_migrations() -> None:
                     )
             except Exception as e:
                 logger.error(f"Error migrating schema {schema}: {e}")
-                raise
+                if not continue_on_error:
+                    logger.error("--continue is not set, raising exception!")
+                    raise
+
+                logger.warning("--continue is set, continuing to next schema.")
+
     else:
         try:
             logger.info(f"Migrating schema: {schema_name}")
@@ -180,7 +202,11 @@ async def run_async_migrations() -> None:
 
 
 def run_migrations_offline() -> None:
-    schema_name, _, upgrade_all_tenants = get_schema_options()
+    """This doesn't really get used when we migrate in the cloud."""
+
+    logger.info("run_migrations_offline starting.")
+
+    schema_name, _, upgrade_all_tenants, continue_on_error = get_schema_options()
     url = build_connection_string()
 
     if upgrade_all_tenants:
@@ -230,6 +256,7 @@ def run_migrations_offline() -> None:
 
 
 def run_migrations_online() -> None:
+    logger.info("run_migrations_online starting.")
     asyncio.run(run_async_migrations())
```
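The new `continue` flag rides on Alembic's `-x` mechanism and is read back out of `context.get_x_argument()` by `get_schema_options`. A minimal sketch of that plumbing; the invocation shown is a hypothetical example, and treating each `-x` value as a single `key=value` pair is an assumption, since the parsing loop body is truncated in the hunk:

```python
# Hypothetical invocation this sketch assumes:
#   alembic -x schema=tenant_abc -x continue=true upgrade head


def parse_x_args(x_args_raw: list[str]) -> dict[str, str]:
    """Split 'key=value' pairs the way env.py's loop over x_args_raw does."""
    x_args: dict[str, str] = {}
    for arg in x_args_raw:
        if "=" in arg:
            key, value = arg.split("=", 1)
            x_args[key.strip()] = value.strip()
    return x_args


x_args = parse_x_args(["schema=tenant_abc", "continue=true"])
continue_on_error = x_args.get("continue", "false").lower() == "true"
create_schema = x_args.get("create_schema", "true").lower() == "true"
print(continue_on_error, create_schema)  # True True
```

With `-x continue=true`, a failed tenant migration is logged and the loop moves on to the next schema; without it, the first failure aborts the whole run, matching the `continue_on_error` branch above.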
```diff
@@ -87,11 +87,14 @@ async def get_or_provision_tenant(
             # If we have a pre-provisioned tenant, assign it to the user
             await assign_tenant_to_user(tenant_id, email, referral_source)
             logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
-            return tenant_id
         else:
             # If no pre-provisioned tenant is available, create a new one on-demand
             tenant_id = await create_tenant(email, referral_source)
-            return tenant_id
+
+        # Notify control plane if we have created / assigned a new tenant
+        if not DEV_MODE:
+            await notify_control_plane(tenant_id, email, referral_source)
+        return tenant_id
 
     except Exception as e:
         # If we've encountered an error, log and raise an exception
@@ -116,10 +119,6 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str:
         # Provision tenant on data plane
         await provision_tenant(tenant_id, email)
 
-        # Notify control plane if not already done in provision_tenant
-        if not DEV_MODE and referral_source:
-            await notify_control_plane(tenant_id, email, referral_source)
-
     except Exception as e:
         logger.exception(f"Tenant provisioning failed: {str(e)}")
         # Attempt to rollback the tenant provisioning
@@ -561,7 +560,3 @@ async def assign_tenant_to_user(
     except Exception:
         logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
         raise Exception("Failed to assign tenant to user")
-
-    # Notify control plane with retry logic
-    if not DEV_MODE:
-        await notify_control_plane(tenant_id, email, referral_source)
```
```diff
@@ -65,11 +65,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
 
     app.state.gpu_type = gpu_type
 
-    if TEMP_HF_CACHE_PATH.is_dir():
-        logger.notice("Moving contents of temp_huggingface to huggingface cache.")
-        _move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
-        shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
-        logger.notice("Moved contents of temp_huggingface to huggingface cache.")
+    try:
+        if TEMP_HF_CACHE_PATH.is_dir():
+            logger.notice("Moving contents of temp_huggingface to huggingface cache.")
+            _move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
+            shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
+            logger.notice("Moved contents of temp_huggingface to huggingface cache.")
+    except Exception as e:
+        logger.warning(
+            f"Error moving contents of temp_huggingface to huggingface cache: {e}. "
+            "This is not a critical error and the model server will continue to run."
+        )
 
     torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
     logger.notice(f"Torch Threads: {torch.get_num_threads()}")
```
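The change turns the one-time HuggingFace cache migration into a best-effort startup task: a filesystem hiccup now logs a warning instead of taking down the model server. A hedged sketch of the same pattern; the paths and logger wiring here are illustrative, not the real constants:

```python
import logging
import shutil
from pathlib import Path

logger = logging.getLogger(__name__)


def best_effort_cache_move(src: Path, dst: Path) -> None:
    try:
        if src.is_dir():
            dst.mkdir(parents=True, exist_ok=True)
            for item in src.iterdir():
                shutil.move(str(item), str(dst / item.name))
            shutil.rmtree(src, ignore_errors=True)
    except Exception as e:
        # non-fatal by design: the server can re-download models if needed
        logger.warning(f"Cache move failed: {e}. Continuing startup.")
```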
```diff
@@ -30,6 +30,9 @@ from onyx.db.connector_credential_pair import (
 )
 from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
 from onyx.db.connector_credential_pair import get_connector_credential_pairs
+from onyx.db.document import (
+    delete_all_documents_by_connector_credential_pair__no_commit,
+)
 from onyx.db.document import get_document_ids_for_connector_credential_pair
 from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
 from onyx.db.engine import get_session_with_current_tenant
@@ -440,6 +443,14 @@ def monitor_connector_deletion_taskset(
         db_session=db_session,
     )
 
+    # Explicitly delete document by connector credential pair records before deleting the connector
+    # This is needed because connector_id is a primary key in that table and cascading deletes won't work
+    delete_all_documents_by_connector_credential_pair__no_commit(
+        db_session=db_session,
+        connector_id=cc_pair.connector_id,
+        credential_id=cc_pair.credential_id,
+    )
+
     # finally, delete the cc-pair
     delete_connector_credential_pair__no_commit(
         db_session=db_session,
```
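The ordering matters here: `connector_id` is part of the association table's composite primary key, so a cascading parent delete would try to NULL a primary-key column and fail — the rows must be removed explicitly first. A stand-in sketch of that constraint (the real model is `onyx.db.models.DocumentByConnectorCredentialPair`; `DocByCCPair` below is a hypothetical reduction):

```python
from sqlalchemy import Integer, String, and_, delete
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class DocByCCPair(Base):  # stand-in for DocumentByConnectorCredentialPair
    __tablename__ = "document_by_connector_credential_pair"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    # part of the PK, so it can never be set to NULL by a cascading delete
    connector_id: Mapped[int] = mapped_column(Integer, primary_key=True)
    credential_id: Mapped[int] = mapped_column(Integer, primary_key=True)


def delete_assoc_rows__no_commit(
    db_session: Session, connector_id: int, credential_id: int
) -> None:
    # 1) remove the association rows first...
    db_session.execute(
        delete(DocByCCPair).where(
            and_(
                DocByCCPair.connector_id == connector_id,
                DocByCCPair.credential_id == credential_id,
            )
        )
    )
    # 2) ...only then delete the cc-pair / connector, and commit once
```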
```diff
@@ -30,7 +30,7 @@ from onyx.tools.tool import Tool
 from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.tools.utils import explicit_tool_calling_supported
-from onyx.utils.gpu_utils import gpu_status_request
+from onyx.utils.gpu_utils import fast_gpu_status_request
 from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -88,7 +88,9 @@ class Answer:
             rerank_settings is not None
             and rerank_settings.rerank_provider_type is not None
         )
-        allow_agent_reranking = gpu_status_request() or using_cloud_reranking
+        allow_agent_reranking = (
+            fast_gpu_status_request(indexing=False) or using_cloud_reranking
+        )
 
         # TODO: this is a hack to force the query to be used for the search tool
         # this should be removed once we fully unify graph inputs (i.e.
```
```diff
@@ -388,6 +388,10 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
 # connector as some point.
 CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
 
+GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
+    os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
+)
+
 JIRA_CONNECTOR_LABELS_TO_SKIP = [
     ignored_tag
     for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
```
```diff
@@ -114,6 +114,7 @@ class ConfluenceConnector(
         self.timezone_offset = timezone_offset
         self._confluence_client: OnyxConfluence | None = None
         self._fetched_titles: set[str] = set()
+        self.allow_images = False
 
         # Remove trailing slash from wiki_base if present
         self.wiki_base = wiki_base.rstrip("/")
@@ -158,6 +159,9 @@ class ConfluenceConnector(
             "max_backoff_seconds": 60,
         }
 
+    def set_allow_images(self, value: bool) -> None:
+        self.allow_images = value
+
     @property
     def confluence_client(self) -> OnyxConfluence:
         if self._confluence_client is None:
@@ -233,7 +237,9 @@ class ConfluenceConnector(
         # Extract basic page information
         page_id = page["id"]
         page_title = page["title"]
-        page_url = f"{self.wiki_base}{page['_links']['webui']}"
+        page_url = build_confluence_document_id(
+            self.wiki_base, page["_links"]["webui"], self.is_cloud
+        )
 
         # Get the page content
         page_content = extract_text_from_confluence_html(
@@ -264,6 +270,7 @@ class ConfluenceConnector(
                     self.confluence_client,
                     attachment,
                     page_id,
+                    self.allow_images,
                 )
 
                 if result and result.text:
@@ -304,13 +311,14 @@ class ConfluenceConnector(
         if "version" in page and "by" in page["version"]:
             author = page["version"]["by"]
             display_name = author.get("displayName", "Unknown")
-            primary_owners.append(BasicExpertInfo(display_name=display_name))
+            email = author.get("email", "unknown@domain.invalid")
+            primary_owners.append(
+                BasicExpertInfo(display_name=display_name, email=email)
+            )
 
         # Create the document
         return Document(
-            id=build_confluence_document_id(
-                self.wiki_base, page["_links"]["webui"], self.is_cloud
-            ),
+            id=page_url,
             sections=sections,
             source=DocumentSource.CONFLUENCE,
             semantic_identifier=page_title,
@@ -373,6 +381,7 @@ class ConfluenceConnector(
                 confluence_client=self.confluence_client,
                 attachment=attachment,
                 page_id=page["id"],
+                allow_images=self.allow_images,
            )
             if response is None:
                 continue
```
```diff
@@ -498,10 +498,12 @@ class OnyxConfluence:
             new_start = get_start_param_from_url(url_suffix)
             previous_start = get_start_param_from_url(old_url_suffix)
             if new_start - previous_start > len(results):
-                logger.warning(
+                logger.debug(
                     f"Start was updated by more than the amount of results "
-                    f"retrieved. This is a bug with Confluence. Start: {new_start}, "
-                    f"Previous Start: {previous_start}, Len Results: {len(results)}."
+                    f"retrieved for `{url_suffix}`. This is a bug with Confluence, "
+                    "but we have logic to work around it - don't worry this isn't"
+                    f" causing an issue. Start: {new_start}, Previous Start: "
+                    f"{previous_start}, Len Results: {len(results)}."
                 )
 
             # Update the url_suffix to use the adjusted start
```
```diff
@@ -112,6 +112,7 @@ def process_attachment(
     confluence_client: "OnyxConfluence",
     attachment: dict[str, Any],
     parent_content_id: str | None,
+    allow_images: bool,
 ) -> AttachmentProcessingResult:
     """
     Processes a Confluence attachment. If it's a document, extracts text,
@@ -119,7 +120,7 @@ def process_attachment(
     """
     try:
         # Get the media type from the attachment metadata
-        media_type = attachment.get("metadata", {}).get("mediaType", "")
+        media_type: str = attachment.get("metadata", {}).get("mediaType", "")
         # Validate the attachment type
         if not validate_attachment_filetype(attachment):
             return AttachmentProcessingResult(
@@ -138,7 +139,14 @@ def process_attachment(
 
         attachment_size = attachment["extensions"]["fileSize"]
 
-        if not media_type.startswith("image/"):
+        if media_type.startswith("image/"):
+            if not allow_images:
+                return AttachmentProcessingResult(
+                    text=None,
+                    file_name=None,
+                    error="Image downloading is not enabled",
+                )
+        else:
             if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
                 logger.warning(
                     f"Skipping {attachment_link} due to size. "
@@ -294,6 +302,7 @@ def convert_attachment_to_content(
     confluence_client: "OnyxConfluence",
     attachment: dict[str, Any],
     page_id: str,
+    allow_images: bool,
 ) -> tuple[str | None, str | None] | None:
     """
     Facade function which:
@@ -309,7 +318,7 @@ def convert_attachment_to_content(
         )
         return None
 
-    result = process_attachment(confluence_client, attachment, page_id)
+    result = process_attachment(confluence_client, attachment, page_id, allow_images)
     if result.error is not None:
         logger.warning(
             f"Attachment {attachment['title']} encountered error: {result.error}"
```
```diff
@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
 
 from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
 from onyx.configs.constants import DocumentSource
+from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
 from onyx.connectors.airtable.airtable_connector import AirtableConnector
 from onyx.connectors.asana.connector import AsanaConnector
 from onyx.connectors.axero.connector import AxeroConnector
@@ -184,6 +185,8 @@ def instantiate_connector(
     if new_credentials is not None:
         backend_update_credential_json(credential, new_credentials, db_session)
 
+    connector.set_allow_images(get_image_extraction_and_analysis_enabled())
+
     return connector
```
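Taken together, these hunks thread one application-level switch from the connector factory down into attachment handling. A condensed sketch of the plumbing — names mirror the diff, but bodies are trimmed to the control flow, and `process_attachment` (a module-level function in the real code) is folded into the class here for brevity:

```python
class BaseConnector:
    def set_allow_images(self, value: bool) -> None:
        """No-op by default; connectors that download attachments override it."""


class ConfluenceConnector(BaseConnector):
    def __init__(self) -> None:
        self.allow_images = False  # off until the factory flips it

    def set_allow_images(self, value: bool) -> None:
        self.allow_images = value

    def process_attachment(self, media_type: str) -> str | None:
        # mirrors the gating added in process_attachment()
        if media_type.startswith("image/") and not self.allow_images:
            return "Image downloading is not enabled"
        return None  # no error -> attachment gets processed


def instantiate_connector(image_analysis_enabled: bool) -> BaseConnector:
    connector = ConfluenceConnector()
    # the factory applies the app-level setting exactly once
    connector.set_allow_images(image_analysis_enabled)
    return connector


connector = instantiate_connector(image_analysis_enabled=False)
print(connector.process_attachment("image/png"))  # Image downloading is not enabled
```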
```diff
@@ -219,24 +219,34 @@ def _process_file(
 
     # 2) Otherwise: text-based approach. Possibly with embedded images.
     file.seek(0)
-    text_content = ""
-    embedded_images: list[tuple[bytes, str]] = []
 
     # Extract text and images from the file
-    text_content, embedded_images = extract_text_and_images(
+    extraction_result = extract_text_and_images(
         file=file,
         file_name=file_name,
         pdf_pass=pdf_pass,
     )
 
+    # Merge file-specific metadata (from file content) with provided metadata
+    if extraction_result.metadata:
+        logger.debug(
+            f"Found file-specific metadata for {file_name}: {extraction_result.metadata}"
+        )
+        metadata.update(extraction_result.metadata)
+
     # Build sections: first the text as a single Section
     sections: list[TextSection | ImageSection] = []
     link_in_meta = metadata.get("link")
-    if text_content.strip():
-        sections.append(TextSection(link=link_in_meta, text=text_content.strip()))
+    if extraction_result.text_content.strip():
+        logger.debug(f"Creating TextSection for {file_name} with link: {link_in_meta}")
+        sections.append(
+            TextSection(link=link_in_meta, text=extraction_result.text_content.strip())
+        )
 
     # Then any extracted images from docx, etc.
-    for idx, (img_data, img_name) in enumerate(embedded_images, start=1):
+    for idx, (img_data, img_name) in enumerate(
+        extraction_result.embedded_images, start=1
+    ):
         # Store each embedded image as a separate file in PGFileStore
         # and create a section with the image reference
         try:
```
```diff
@@ -15,6 +15,7 @@ from google.oauth2.service_account import Credentials as ServiceAccountCredentia
 from googleapiclient.errors import HttpError  # type: ignore
 from typing_extensions import override
 
+from onyx.configs.app_configs import GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD
 from onyx.configs.app_configs import INDEX_BATCH_SIZE
 from onyx.configs.app_configs import MAX_DRIVE_WORKERS
 from onyx.configs.constants import DocumentSource
@@ -86,6 +87,8 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:
 def _convert_single_file(
     creds: Any,
     primary_admin_email: str,
+    allow_images: bool,
+    size_threshold: int,
     file: dict[str, Any],
 ) -> Document | ConnectorFailure | None:
     user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
@@ -101,6 +104,8 @@ def _convert_single_file(
         file=file,
         drive_service=user_drive_service,
         docs_service=docs_service,
+        allow_images=allow_images,
+        size_threshold=size_threshold,
     )
 
 
@@ -234,6 +239,12 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
         self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
 
         self._retrieved_ids: set[str] = set()
+        self.allow_images = False
+
+        self.size_threshold = GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD
+
+    def set_allow_images(self, value: bool) -> None:
+        self.allow_images = value
 
     @property
     def primary_admin_email(self) -> str:
@@ -900,6 +911,8 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
             _convert_single_file,
             self.creds,
             self.primary_admin_email,
+            self.allow_images,
+            self.size_threshold,
         )
 
         # Fetch files in batches
@@ -1097,7 +1110,9 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
             drive_service.files().list(pageSize=1, fields="files(id)").execute()
 
         if isinstance(self._creds, ServiceAccountCredentials):
-            retry_builder()(get_root_folder_id)(drive_service)
+            # default is ~17mins of retries, don't do that here since this is called from
+            # the UI
+            retry_builder(tries=3, delay=0.1)(get_root_folder_id)(drive_service)
 
     except HttpError as e:
         status_code = e.resp.status if e.resp else None
```
```diff
@@ -76,9 +76,10 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
     return is_valid_image_type(mime_type)
 
 
-def _extract_sections_basic(
+def _download_and_extract_sections_basic(
     file: dict[str, str],
     service: GoogleDriveService,
+    allow_images: bool,
 ) -> list[TextSection | ImageSection]:
     """Extract text and images from a Google Drive file."""
     file_id = file["id"]
@@ -87,6 +88,10 @@ def _extract_sections_basic(
     link = file.get("webViewLink", "")
 
     try:
+        # skip images if not explicitly enabled
+        if not allow_images and is_gdrive_image_mime_type(mime_type):
+            return []
+
         # For Google Docs, Sheets, and Slides, export as plain text
         if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
             export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
@@ -207,6 +212,8 @@ def convert_drive_item_to_document(
     file: GoogleDriveFileType,
     drive_service: Callable[[], GoogleDriveService],
     docs_service: Callable[[], GoogleDocsService],
+    allow_images: bool,
+    size_threshold: int,
 ) -> Document | ConnectorFailure | None:
     """
     Main entry point for converting a Google Drive file => Document object.
@@ -234,9 +241,24 @@ def convert_drive_item_to_document(
                 f"Error in advanced parsing: {e}. Falling back to basic extraction."
             )
 
+        size_str = file.get("size")
+        if size_str:
+            try:
+                size_int = int(size_str)
+            except ValueError:
+                logger.warning(f"Parsing string to int failed: size_str={size_str}")
+            else:
+                if size_int > size_threshold:
+                    logger.warning(
+                        f"{file.get('name')} exceeds size threshold of {size_threshold}. Skipping."
+                    )
+                    return None
+
         # If we don't have sections yet, use the basic extraction method
         if not sections:
-            sections = _extract_sections_basic(file, drive_service())
+            sections = _download_and_extract_sections_basic(
+                file, drive_service(), allow_images
+            )
 
         # If we still don't have any sections, skip this file
         if not sections:
```
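Drive reports `size` as a string and may omit it entirely, so the new gate parses defensively and skips a file only on a confirmed oversize. The same logic as a small standalone sketch:

```python
import logging

logger = logging.getLogger(__name__)


def should_skip_by_size(file: dict[str, str], size_threshold: int) -> bool:
    size_str = file.get("size")
    if not size_str:
        return False  # no size reported -> don't skip
    try:
        size_int = int(size_str)
    except ValueError:
        logger.warning(f"Parsing string to int failed: size_str={size_str}")
        return False  # unparseable -> err on the side of indexing
    return size_int > size_threshold


# 10 MiB default from GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD above
print(should_skip_by_size({"size": "20971520"}, 10 * 1024 * 1024))  # True
```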
```diff
@@ -1,6 +1,7 @@
 from collections.abc import Callable
 from collections.abc import Iterator
 from datetime import datetime
+from datetime import timezone
 
 from googleapiclient.discovery import Resource  # type: ignore
 
@@ -36,12 +37,12 @@ def _generate_time_range_filter(
 ) -> str:
     time_range_filter = ""
     if start is not None:
-        time_start = datetime.utcfromtimestamp(start).isoformat() + "Z"
+        time_start = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
         time_range_filter += (
             f" and {GoogleFields.MODIFIED_TIME.value} >= '{time_start}'"
         )
     if end is not None:
-        time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z"
+        time_stop = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
         time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} <= '{time_stop}'"
     return time_range_filter
```
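`datetime.utcfromtimestamp` is deprecated as of Python 3.12 and returns a naive datetime, which is why the old code had to append a manual `"Z"`. The timezone-aware replacement carries the offset itself; both forms render valid RFC 3339 timestamps:

```python
from datetime import datetime, timezone

ts = 1_700_000_000
old = datetime.utcfromtimestamp(ts).isoformat() + "Z"           # naive + hand-rolled suffix
new = datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()   # aware, offset included

print(old)  # 2023-11-14T22:13:20Z
print(new)  # 2023-11-14T22:13:20+00:00
```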
```diff
@@ -17,9 +17,12 @@ logger = setup_logger()
 
 
 # Google Drive APIs are quite flakey and may 500 for an
-# extended period of time. Trying to combat here by adding a very
-# long retry period (~20 minutes of trying every minute)
-add_retries = retry_builder(tries=50, max_delay=30)
+# extended period of time. This is now addressed by checkpointing.
+#
+# NOTE: We previously tried to combat this here by adding a very
+# long retry period (~20 minutes of trying, one request a minute.)
+# This is no longer necessary due to checkpointing.
+add_retries = retry_builder(tries=5, max_delay=10)
 
 NEXT_PAGE_TOKEN_KEY = "nextPageToken"
 PAGE_TOKEN_KEY = "pageToken"
@@ -37,14 +40,14 @@ class GoogleFields(str, Enum):
 
 
 def _execute_with_retry(request: Any) -> Any:
-    max_attempts = 10
+    max_attempts = 6
     attempt = 1
 
     while attempt < max_attempts:
         # Note for reasons unknown, the Google API will sometimes return a 429
         # and even after waiting the retry period, it will return another 429.
         # It could be due to a few possibilities:
-        # 1. Other things are also requesting from the Gmail API with the same key
+        # 1. Other things are also requesting from the Drive/Gmail API with the same key
         # 2. It's a rolling rate limit so the moment we get some amount of requests cleared, we hit it again very quickly
         # 3. The retry-after has a maximum and we've already hit the limit for the day
         # or it's something else...
```
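The comments explain the shift: a fixed, modest retry budget (honoring the server's `Retry-After` hint when present) replaces the old ~20-minute loop, since checkpointing now absorbs long outages. A rough sketch of that bounded strategy — the real helper differs in its details (jitter, logging, which errors count as retryable):

```python
import time
from collections.abc import Callable

import requests


def execute_with_retry(
    send: Callable[[], requests.Response], max_attempts: int = 6
) -> requests.Response:
    for attempt in range(1, max_attempts + 1):
        response = send()
        if response.status_code != 429:
            return response
        # honor Retry-After if given, otherwise back off exponentially
        retry_after = int(response.headers.get("Retry-After", 2**attempt))
        time.sleep(retry_after)
    raise RuntimeError(f"still rate limited after {max_attempts} attempts")
```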
```diff
@@ -8,7 +8,6 @@ from typing import TypeAlias
 from typing import TypeVar
 
 from pydantic import BaseModel
-from typing_extensions import override
 
 from onyx.configs.constants import DocumentSource
 from onyx.connectors.models import ConnectorCheckpoint
@@ -60,6 +59,10 @@ class BaseConnector(abc.ABC, Generic[CT]):
     Default is a no-op (always successful).
     """
 
+    def set_allow_images(self, value: bool) -> None:
+        """Implement if the underlying connector wants to skip/allow image downloading
+        based on the application level image analysis setting."""
+
     def build_dummy_checkpoint(self) -> CT:
         # TODO: find a way to make this work without type: ignore
         return ConnectorCheckpoint(has_more=True)  # type: ignore
@@ -227,7 +230,7 @@ class CheckpointConnector(BaseConnector[CT]):
         """
         raise NotImplementedError
 
-    @override
+    @abc.abstractmethod
     def build_dummy_checkpoint(self) -> CT:
         raise NotImplementedError
```
```diff
@@ -438,7 +438,11 @@ def _get_all_doc_ids(
 
 class ProcessedSlackMessage(BaseModel):
     doc: Document | None
-    thread_ts: str | None
+    # if the message is part of a thread, this is the thread_ts
+    # otherwise, this is the message_ts. Either way, will be a unique identifier.
+    # In the future, if the message becomes a thread, then the thread_ts
+    # will be set to the message_ts.
+    thread_or_message_ts: str
     failure: ConnectorFailure | None
 
 
@@ -452,6 +456,7 @@ def _process_message(
     msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
 ) -> ProcessedSlackMessage:
     thread_ts = message.get("thread_ts")
+    thread_or_message_ts = thread_ts or message["ts"]
     try:
         # causes random failures for testing checkpointing / continue on failure
         # import random
@@ -467,16 +472,18 @@ def _process_message(
             seen_thread_ts=seen_thread_ts,
             msg_filter_func=msg_filter_func,
         )
-        return ProcessedSlackMessage(doc=doc, thread_ts=thread_ts, failure=None)
+        return ProcessedSlackMessage(
+            doc=doc, thread_or_message_ts=thread_or_message_ts, failure=None
+        )
     except Exception as e:
         logger.exception(f"Error processing message {message['ts']}")
         return ProcessedSlackMessage(
             doc=None,
-            thread_ts=thread_ts,
+            thread_or_message_ts=thread_or_message_ts,
             failure=ConnectorFailure(
                 failed_document=DocumentFailure(
                     document_id=_build_doc_id(
-                        channel_id=channel["id"], thread_ts=(thread_ts or message["ts"])
+                        channel_id=channel["id"], thread_ts=thread_or_message_ts
                     ),
                     document_link=get_message_link(message, client, channel["id"]),
                 ),
@@ -616,7 +623,7 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
         for future in as_completed(futures):
             processed_slack_message = future.result()
             doc = processed_slack_message.doc
-            thread_ts = processed_slack_message.thread_ts
+            thread_or_message_ts = processed_slack_message.thread_or_message_ts
             failure = processed_slack_message.failure
             if doc:
                 # handle race conditions here since this is single
@@ -624,11 +631,13 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
                 # but since this is single threaded, we won't run into simul
                 # writes. At worst, we can duplicate a thread, which will be
                 # deduped later on.
-                if thread_ts not in seen_thread_ts:
+                if thread_or_message_ts not in seen_thread_ts:
                     yield doc
 
-                assert thread_ts, "found non-None doc with None thread_ts"
-                seen_thread_ts.add(thread_ts)
+                assert (
+                    thread_or_message_ts
+                ), "found non-None doc with None thread_or_message_ts"
+                seen_thread_ts.add(thread_or_message_ts)
             elif failure:
                 yield failure
```
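The key change: a thread reply shares its thread's `thread_ts`, while a standalone message falls back to its own `ts`, so the dedup key is always a non-None unique identifier. A tiny self-contained illustration:

```python
def dedup_key(message: dict) -> str:
    return message.get("thread_ts") or message["ts"]


messages = [
    {"ts": "1.0"},                      # standalone message
    {"ts": "2.0", "thread_ts": "1.5"},  # reply in thread 1.5
    {"ts": "3.0", "thread_ts": "1.5"},  # same thread -> deduped
]

seen: set[str] = set()
unique = []
for m in messages:
    key = dedup_key(m)
    if key not in seen:
        unique.append(m)
    seen.add(key)
print(len(unique))  # 2
```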
```diff
@@ -1,23 +1,32 @@
+import copy
 import time
 from collections.abc import Iterator
 from typing import Any
+from typing import cast
 
 import requests
+from pydantic import BaseModel
 from requests.exceptions import HTTPError
+from typing_extensions import override
 
 from onyx.configs.app_configs import INDEX_BATCH_SIZE
 from onyx.configs.app_configs import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
 from onyx.configs.constants import DocumentSource
 from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
     time_str_to_utc,
 )
-from onyx.connectors.interfaces import GenerateDocumentsOutput
+from onyx.connectors.exceptions import ConnectorValidationError
+from onyx.connectors.exceptions import CredentialExpiredError
+from onyx.connectors.exceptions import InsufficientPermissionsError
+from onyx.connectors.interfaces import CheckpointConnector
+from onyx.connectors.interfaces import CheckpointOutput
+from onyx.connectors.interfaces import ConnectorFailure
 from onyx.connectors.interfaces import GenerateSlimDocumentOutput
-from onyx.connectors.interfaces import LoadConnector
-from onyx.connectors.interfaces import PollConnector
 from onyx.connectors.interfaces import SecondsSinceUnixEpoch
 from onyx.connectors.interfaces import SlimConnector
 from onyx.connectors.models import BasicExpertInfo
+from onyx.connectors.models import ConnectorCheckpoint
 from onyx.connectors.models import Document
+from onyx.connectors.models import DocumentFailure
 from onyx.connectors.models import SlimDocument
 from onyx.connectors.models import TextSection
 from onyx.file_processing.html_utils import parse_html_page_basic
@@ -26,6 +35,7 @@ from onyx.utils.retry_wrapper import retry_builder
 
 
 MAX_PAGE_SIZE = 30  # Zendesk API maximum
+MAX_AUTHOR_MAP_SIZE = 50_000  # Reset author map cache if it gets too large
 _SLIM_BATCH_SIZE = 1000
```
```diff
@@ -53,10 +63,22 @@ class ZendeskClient:
             # Sleep for the duration indicated by the Retry-After header
             time.sleep(int(retry_after))
 
+        elif (
+            response.status_code == 403
+            and response.json().get("error") == "SupportProductInactive"
+        ):
+            return response.json()
+
         response.raise_for_status()
         return response.json()
 
 
+class ZendeskPageResponse(BaseModel):
+    data: list[dict[str, Any]]
+    meta: dict[str, Any]
+    has_more: bool
+
+
 def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
     content_tags: dict[str, str] = {}
     params = {"page[size]": MAX_PAGE_SIZE}
```
```diff
@@ -82,11 +104,9 @@ def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
 def _get_articles(
     client: ZendeskClient, start_time: int | None = None, page_size: int = MAX_PAGE_SIZE
 ) -> Iterator[dict[str, Any]]:
-    params = (
-        {"start_time": start_time, "page[size]": page_size}
-        if start_time
-        else {"page[size]": page_size}
-    )
+    params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"}
+    if start_time is not None:
+        params["start_time"] = start_time
 
     while True:
         data = client.make_request("help_center/articles", params)
```
```diff
@@ -98,10 +118,30 @@ def _get_articles(
         params["page[after]"] = data["meta"]["after_cursor"]
 
 
+def _get_article_page(
+    client: ZendeskClient,
+    start_time: int | None = None,
+    after_cursor: str | None = None,
+    page_size: int = MAX_PAGE_SIZE,
+) -> ZendeskPageResponse:
+    params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"}
+    if start_time is not None:
+        params["start_time"] = start_time
+    if after_cursor is not None:
+        params["page[after]"] = after_cursor
+
+    data = client.make_request("help_center/articles", params)
+    return ZendeskPageResponse(
+        data=data["articles"],
+        meta=data["meta"],
+        has_more=bool(data["meta"].get("has_more", False)),
+    )
+
+
 def _get_tickets(
     client: ZendeskClient, start_time: int | None = None
 ) -> Iterator[dict[str, Any]]:
-    params = {"start_time": start_time} if start_time else {"start_time": 0}
+    params = {"start_time": start_time or 0}
 
     while True:
         data = client.make_request("incremental/tickets.json", params)
```
```diff
@@ -114,6 +154,27 @@ def _get_tickets(
             break
 
 
+# TODO: maybe these don't need to be their own functions?
+def _get_tickets_page(
+    client: ZendeskClient, start_time: int | None = None
+) -> ZendeskPageResponse:
+    params = {"start_time": start_time or 0}
+
+    # NOTE: for some reason zendesk doesn't seem to be respecting the start_time param
+    # in my local testing with very few tickets. We'll look into it if this becomes an
+    # issue in larger deployments
+    data = client.make_request("incremental/tickets.json", params)
+    if data.get("error") == "SupportProductInactive":
+        raise ValueError(
+            "Zendesk Support Product is not active for this account, No tickets to index"
+        )
+    return ZendeskPageResponse(
+        data=data["tickets"],
+        meta={"end_time": data["end_time"]},
+        has_more=not bool(data.get("end_of_stream", False)),
+    )
+
+
 def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
     # Skip fetching if author_id is invalid
     if not author_id or author_id == "-1":
```
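The two page helpers normalize Zendesk's two pagination styles into one `ZendeskPageResponse` shape: articles advance via an opaque `after_cursor`, tickets via the incremental export's returned `end_time`. A sketch of the parameter handoff from one page to the next, using the field names from the hunks above:

```python
from typing import Any


def next_article_params(meta: dict[str, Any], page_size: int = 30) -> dict[str, Any]:
    # cursor-based: the next request carries meta["after_cursor"]
    params: dict[str, Any] = {
        "page[size]": page_size,
        "sort_by": "updated_at",
        "sort_order": "asc",
    }
    if meta.get("after_cursor"):
        params["page[after]"] = meta["after_cursor"]
    return params


def next_ticket_params(last_end_time: int) -> dict[str, Any]:
    # timestamp-based: incremental/tickets.json resumes from the prior end_time
    return {"start_time": last_end_time}


print(next_article_params({"after_cursor": "abc123"}))
print(next_ticket_params(1_700_000_000))
```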
```diff
@@ -278,13 +339,22 @@ def _ticket_to_document(
     )
 
 
-class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
+class ZendeskConnectorCheckpoint(ConnectorCheckpoint):
+    # We use cursor-based paginated retrieval for articles
+    after_cursor_articles: str | None
+
+    # We use timestamp-based paginated retrieval for tickets
+    next_start_time_tickets: int | None
+
+    cached_author_map: dict[str, BasicExpertInfo] | None
+    cached_content_tags: dict[str, str] | None
+
+
+class ZendeskConnector(SlimConnector, CheckpointConnector[ZendeskConnectorCheckpoint]):
     def __init__(
         self,
         batch_size: int = INDEX_BATCH_SIZE,
         content_type: str = "articles",
     ) -> None:
         self.batch_size = batch_size
         self.content_type = content_type
         self.subdomain = ""
         # Fetch all tags ahead of time
```
```diff
@@ -304,33 +374,50 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
             )
             return None
 
-    def load_from_state(self) -> GenerateDocumentsOutput:
-        return self.poll_source(None, None)
-
-    def poll_source(
-        self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
-    ) -> GenerateDocumentsOutput:
+    @override
+    def load_from_checkpoint(
+        self,
+        start: SecondsSinceUnixEpoch,
+        end: SecondsSinceUnixEpoch,
+        checkpoint: ZendeskConnectorCheckpoint,
+    ) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
         if self.client is None:
             raise ZendeskCredentialsNotSetUpError()
 
-        self.content_tags = _get_content_tag_mapping(self.client)
+        if checkpoint.cached_content_tags is None:
+            checkpoint.cached_content_tags = _get_content_tag_mapping(self.client)
+            return checkpoint  # save the content tags to the checkpoint
+        self.content_tags = checkpoint.cached_content_tags
 
         if self.content_type == "articles":
-            yield from self._poll_articles(start)
+            checkpoint = yield from self._retrieve_articles(start, end, checkpoint)
+            return checkpoint
         elif self.content_type == "tickets":
-            yield from self._poll_tickets(start)
+            checkpoint = yield from self._retrieve_tickets(start, end, checkpoint)
+            return checkpoint
         else:
             raise ValueError(f"Unsupported content_type: {self.content_type}")
 
-    def _poll_articles(
-        self, start: SecondsSinceUnixEpoch | None
-    ) -> GenerateDocumentsOutput:
-        articles = _get_articles(self.client, start_time=int(start) if start else None)
-
-        # This one is built on the fly as there may be more many more authors than tags
-        author_map: dict[str, BasicExpertInfo] = {}
-
-        doc_batch = []
+    def _retrieve_articles(
+        self,
+        start: SecondsSinceUnixEpoch | None,
+        end: SecondsSinceUnixEpoch | None,
+        checkpoint: ZendeskConnectorCheckpoint,
+    ) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
+        checkpoint = copy.deepcopy(checkpoint)
+        # This one is built on the fly as there may be more many more authors than tags
+        author_map: dict[str, BasicExpertInfo] = checkpoint.cached_author_map or {}
+        after_cursor = checkpoint.after_cursor_articles
+        doc_batch: list[Document] = []
+
+        response = _get_article_page(
+            self.client,
+            start_time=int(start) if start else None,
+            after_cursor=after_cursor,
+        )
+        articles = response.data
+        has_more = response.has_more
+        after_cursor = response.meta.get("after_cursor")
+
         for article in articles:
             if (
                 article.get("body") is None
```
```diff
@@ -342,66 +429,109 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
             ):
                 continue
 
-            new_author_map, documents = _article_to_document(
-                article, self.content_tags, author_map, self.client
-            )
+            try:
+                new_author_map, document = _article_to_document(
+                    article, self.content_tags, author_map, self.client
+                )
+            except Exception as e:
+                yield ConnectorFailure(
+                    failed_document=DocumentFailure(
+                        document_id=f"{article.get('id')}",
+                        document_link=article.get("html_url", ""),
+                    ),
+                    failure_message=str(e),
+                    exception=e,
+                )
+                continue
+
             if new_author_map:
                 author_map.update(new_author_map)
 
-            doc_batch.append(documents)
-            if len(doc_batch) >= self.batch_size:
-                yield doc_batch
-                doc_batch.clear()
-
-        if doc_batch:
-            yield doc_batch
+            doc_batch.append(document)
+
+        if not has_more:
+            yield from doc_batch
+            checkpoint.has_more = False
+            return checkpoint
+
+        # Sometimes no documents are retrieved, but the cursor
+        # is still updated so the connector makes progress.
+        yield from doc_batch
+        checkpoint.after_cursor_articles = after_cursor
+
+        last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None
+        checkpoint.has_more = bool(
+            end is None
+            or last_doc_updated_at is None
+            or last_doc_updated_at.timestamp() <= end
+        )
+        checkpoint.cached_author_map = (
+            author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None
+        )
+        return checkpoint
 
-    def _poll_tickets(
-        self, start: SecondsSinceUnixEpoch | None
-    ) -> GenerateDocumentsOutput:
+    def _retrieve_tickets(
+        self,
+        start: SecondsSinceUnixEpoch | None,
+        end: SecondsSinceUnixEpoch | None,
+        checkpoint: ZendeskConnectorCheckpoint,
+    ) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
+        checkpoint = copy.deepcopy(checkpoint)
         if self.client is None:
             raise ZendeskCredentialsNotSetUpError()
 
-        author_map: dict[str, BasicExpertInfo] = {}
+        author_map: dict[str, BasicExpertInfo] = checkpoint.cached_author_map or {}
 
-        ticket_generator = _get_tickets(
-            self.client, start_time=int(start) if start else None
-        )
-
-        while True:
-            doc_batch = []
-            for _ in range(self.batch_size):
-                try:
-                    ticket = next(ticket_generator)
-
-                    # Check if the ticket status is deleted and skip it if so
-                    if ticket.get("status") == "deleted":
-                        continue
-
-                    new_author_map, documents = _ticket_to_document(
-                        ticket=ticket,
-                        author_map=author_map,
-                        client=self.client,
-                        default_subdomain=self.subdomain,
-                    )
-
-                    if new_author_map:
-                        author_map.update(new_author_map)
-
-                    doc_batch.append(documents)
-
-                    if len(doc_batch) >= self.batch_size:
-                        yield doc_batch
-                        doc_batch.clear()
-
-                except StopIteration:
-                    # No more tickets to process
-                    if doc_batch:
-                        yield doc_batch
-                    return
-
-            if doc_batch:
-                yield doc_batch
+        doc_batch: list[Document] = []
+        next_start_time = int(checkpoint.next_start_time_tickets or start or 0)
+        ticket_response = _get_tickets_page(self.client, start_time=next_start_time)
+        tickets = ticket_response.data
+        has_more = ticket_response.has_more
+        next_start_time = ticket_response.meta["end_time"]
+
+        for ticket in tickets:
+            if ticket.get("status") == "deleted":
+                continue
+
+            try:
+                new_author_map, document = _ticket_to_document(
+                    ticket=ticket,
+                    author_map=author_map,
+                    client=self.client,
+                    default_subdomain=self.subdomain,
+                )
+            except Exception as e:
+                yield ConnectorFailure(
+                    failed_document=DocumentFailure(
+                        document_id=f"{ticket.get('id')}",
+                        document_link=ticket.get("url", ""),
+                    ),
+                    failure_message=str(e),
+                    exception=e,
+                )
+                continue
+
+            if new_author_map:
+                author_map.update(new_author_map)
+
+            doc_batch.append(document)
+
+        if not has_more:
+            yield from doc_batch
+            checkpoint.has_more = False
+            return checkpoint
+
+        yield from doc_batch
+        checkpoint.next_start_time_tickets = next_start_time
+        last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None
+        checkpoint.has_more = bool(
+            end is None
+            or last_doc_updated_at is None
+            or last_doc_updated_at.timestamp() <= end
+        )
+        checkpoint.cached_author_map = (
+            author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None
+        )
+        return checkpoint
 
     def retrieve_all_slim_documents(
         self,
```
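A caller drives the checkpointed connector by re-invoking `load_from_checkpoint` with the checkpoint returned from the previous pass until `has_more` goes false; the generator's return value arrives via `StopIteration`. A minimal sketch of that loop — the real orchestration lives in Onyx's indexing pipeline, and `connector` is any `CheckpointConnector`, such as the `ZendeskConnector` above:

```python
from collections.abc import Iterator


def run_to_completion(connector, start: float, end: float) -> Iterator:
    checkpoint = connector.build_dummy_checkpoint()
    while checkpoint.has_more:
        gen = connector.load_from_checkpoint(start, end, checkpoint)
        while True:
            try:
                yield next(gen)  # Document or ConnectorFailure
            except StopIteration as stop:
                checkpoint = stop.value  # persist this between passes
                break
```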
```diff
@@ -441,10 +571,51 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
             if slim_doc_batch:
                 yield slim_doc_batch
 
+    @override
+    def validate_connector_settings(self) -> None:
+        if self.client is None:
+            raise ZendeskCredentialsNotSetUpError()
+
+        try:
+            _get_article_page(self.client, start_time=0)
+        except HTTPError as e:
+            # Check for HTTP status codes
+            if e.response.status_code == 401:
+                raise CredentialExpiredError(
+                    "Your Zendesk credentials appear to be invalid or expired (HTTP 401)."
+                ) from e
+            elif e.response.status_code == 403:
+                raise InsufficientPermissionsError(
+                    "Your Zendesk token does not have sufficient permissions (HTTP 403)."
+                ) from e
+            elif e.response.status_code == 404:
+                raise ConnectorValidationError(
+                    "Zendesk resource not found (HTTP 404)."
+                ) from e
+            else:
+                raise ConnectorValidationError(
+                    f"Unexpected Zendesk error (status={e.response.status_code}): {e}"
+                ) from e
+
+    @override
+    def validate_checkpoint_json(
+        self, checkpoint_json: str
+    ) -> ZendeskConnectorCheckpoint:
+        return ZendeskConnectorCheckpoint.model_validate_json(checkpoint_json)
+
+    @override
+    def build_dummy_checkpoint(self) -> ZendeskConnectorCheckpoint:
+        return ZendeskConnectorCheckpoint(
+            after_cursor_articles=None,
+            next_start_time_tickets=None,
+            cached_author_map=None,
+            cached_content_tags=None,
+            has_more=True,
+        )
+
 
 if __name__ == "__main__":
     import os
     import time
 
     connector = ZendeskConnector()
     connector.load_credentials(
```
```diff
@@ -457,6 +628,8 @@ if __name__ == "__main__":
 
     current = time.time()
     one_day_ago = current - 24 * 60 * 60  # 1 day
-    document_batches = connector.poll_source(one_day_ago, current)
+    document_batches = connector.load_from_checkpoint(
+        one_day_ago, current, connector.build_dummy_checkpoint()
+    )
 
     print(next(document_batches))
```
```diff
@@ -555,6 +555,28 @@ def delete_documents_by_connector_credential_pair__no_commit(
     db_session.execute(stmt)
 
 
+def delete_all_documents_by_connector_credential_pair__no_commit(
+    db_session: Session,
+    connector_id: int,
+    credential_id: int,
+) -> None:
+    """Deletes all document by connector credential pair entries for a specific connector and credential.
+
+    This is primarily used during connector deletion to ensure all references are removed
+    before deleting the connector itself. This is crucial because connector_id is part of the
+    primary key in DocumentByConnectorCredentialPair, and attempting to delete the Connector
+    would otherwise try to set the foreign key to NULL, which fails for primary keys.
+
+    NOTE: Does not commit the transaction, this must be done by the caller.
+    """
+    stmt = delete(DocumentByConnectorCredentialPair).where(
+        and_(
+            DocumentByConnectorCredentialPair.connector_id == connector_id,
+            DocumentByConnectorCredentialPair.credential_id == credential_id,
+        )
+    )
+    db_session.execute(stmt)
+
+
 def delete_documents__no_commit(db_session: Session, document_ids: list[str]) -> None:
     db_session.execute(delete(DbDocument).where(DbDocument.id.in_(document_ids)))
```
```diff
@@ -5,13 +5,13 @@ import re
 import zipfile
 from collections.abc import Callable
 from collections.abc import Iterator
+from collections.abc import Sequence
 from email.parser import Parser as EmailParser
 from io import BytesIO
 from pathlib import Path
 from typing import Any
 from typing import IO
-from typing import List
-from typing import Tuple
+from typing import NamedTuple
 
 import chardet
 import docx  # type: ignore
@@ -219,7 +219,7 @@ def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str:
 
 def read_pdf_file(
     file: IO[Any], pdf_pass: str | None = None, extract_images: bool = False
-) -> tuple[str, dict, list[tuple[bytes, str]]]:
+) -> tuple[str, dict[str, Any], Sequence[tuple[bytes, str]]]:
     """
     Returns the text, basic PDF metadata, and optionally extracted images.
     """
@@ -282,13 +282,13 @@ def read_pdf_file(
 
 def docx_to_text_and_images(
     file: IO[Any],
-) -> Tuple[str, List[Tuple[bytes, str]]]:
+) -> tuple[str, Sequence[tuple[bytes, str]]]:
     """
     Extract text from a docx. If embed_images=True, also extract inline images.
     Return (text_content, list_of_images).
     """
     paragraphs = []
-    embedded_images: List[Tuple[bytes, str]] = []
+    embedded_images: list[tuple[bytes, str]] = []
 
     doc = docx.Document(file)
```
```diff
@@ -426,14 +426,22 @@ def extract_file_text(
         return ""
 
 
+class ExtractionResult(NamedTuple):
+    """Structured result from text and image extraction from various file types."""
+
+    text_content: str
+    embedded_images: Sequence[tuple[bytes, str]]
+    metadata: dict[str, Any]
+
+
 def extract_text_and_images(
     file: IO[Any],
     file_name: str,
     pdf_pass: str | None = None,
-) -> Tuple[str, List[Tuple[bytes, str]]]:
+) -> ExtractionResult:
     """
     Primary new function for the updated connector.
-    Returns (text_content, [(embedded_img_bytes, embedded_img_name), ...]).
+    Returns structured extraction result with text content, embedded images, and metadata.
     """
 
     try:
@@ -442,7 +450,9 @@ def extract_text_and_images(
             # If the user doesn't want embedded images, unstructured is fine
             file.seek(0)
             text_content = unstructured_to_text(file, file_name)
-            return (text_content, [])
+            return ExtractionResult(
+                text_content=text_content, embedded_images=[], metadata={}
+            )
 
         extension = get_file_ext(file_name)
```
```diff
@@ -450,54 +460,76 @@ def extract_text_and_images(
         if extension == ".docx":
             file.seek(0)
             text_content, images = docx_to_text_and_images(file)
-            return (text_content, images)
+            return ExtractionResult(
+                text_content=text_content, embedded_images=images, metadata={}
+            )
 
         # PDF example: we do not show complicated PDF image extraction here
         # so we simply extract text for now and skip images.
         if extension == ".pdf":
             file.seek(0)
-            text_content, _, images = read_pdf_file(file, pdf_pass, extract_images=True)
-            return (text_content, images)
+            text_content, pdf_metadata, images = read_pdf_file(
+                file, pdf_pass, extract_images=True
+            )
+            return ExtractionResult(
+                text_content=text_content, embedded_images=images, metadata=pdf_metadata
+            )
 
         # For PPTX, XLSX, EML, etc., we do not show embedded image logic here.
         # You can do something similar to docx if needed.
         if extension == ".pptx":
             file.seek(0)
-            return (pptx_to_text(file), [])
+            return ExtractionResult(
+                text_content=pptx_to_text(file), embedded_images=[], metadata={}
+            )
 
         if extension == ".xlsx":
             file.seek(0)
-            return (xlsx_to_text(file), [])
+            return ExtractionResult(
+                text_content=xlsx_to_text(file), embedded_images=[], metadata={}
+            )
 
         if extension == ".eml":
             file.seek(0)
-            return (eml_to_text(file), [])
+            return ExtractionResult(
+                text_content=eml_to_text(file), embedded_images=[], metadata={}
+            )
 
         if extension == ".epub":
             file.seek(0)
-            return (epub_to_text(file), [])
+            return ExtractionResult(
+                text_content=epub_to_text(file), embedded_images=[], metadata={}
+            )
 
         if extension == ".html":
             file.seek(0)
-            return (parse_html_page_basic(file), [])
+            return ExtractionResult(
+                text_content=parse_html_page_basic(file),
+                embedded_images=[],
+                metadata={},
+            )
 
         # If we reach here and it's a recognized text extension
         if is_text_file_extension(file_name):
             file.seek(0)
             encoding = detect_encoding(file)
-            text_content_raw, _ = read_text_file(
+            text_content_raw, file_metadata = read_text_file(
                 file, encoding=encoding, ignore_onyx_metadata=False
             )
-            return (text_content_raw, [])
+            return ExtractionResult(
+                text_content=text_content_raw,
+                embedded_images=[],
+                metadata=file_metadata,
+            )
 
         # If it's an image file or something else, we do not parse embedded images from them
         # just return empty text
         file.seek(0)
-        return ("", [])
+        return ExtractionResult(text_content="", embedded_images=[], metadata={})
 
     except Exception as e:
         logger.exception(f"Failed to extract text/images from {file_name}: {e}")
-        return ("", [])
+        return ExtractionResult(text_content="", embedded_images=[], metadata={})
 
 
 def convert_docx_to_txt(
```
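Because `ExtractionResult` is a `NamedTuple`, named access and positional unpacking both work, which is what lets call sites migrate incrementally from the old bare tuple. A small usage sketch:

```python
from collections.abc import Sequence
from typing import Any, NamedTuple


class ExtractionResult(NamedTuple):
    text_content: str
    embedded_images: Sequence[tuple[bytes, str]]
    metadata: dict[str, Any]


result = ExtractionResult("hello", [], {"link": "https://example.com"})

# named access, as in the updated _process_file
if result.text_content.strip():
    print(result.metadata.get("link"))

# positional unpacking still works for untouched callers,
# since NamedTuple instances are tuple subclasses
text, images, meta = result
```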
```diff
@@ -170,7 +170,8 @@ def handle_message(
     respond_tag_only = channel_conf.get("respond_tag_only") or False
     respond_member_group_list = channel_conf.get("respond_member_group_list", None)
 
-    if respond_tag_only and not bypass_filters:
+    # NOTE: always respond in the DMs, as long the default config is not disabled.
+    if respond_tag_only and not bypass_filters and not is_bot_dm:
         logger.info(
             "Skipping message since the channel is configured such that "
             "OnyxBot only responds to tags"
```
```diff
@@ -261,9 +261,6 @@ def create_bot(
     # Create a default Slack channel config
     default_channel_config = ChannelConfig(
         channel_name=None,
-        respond_member_group_list=[],
-        answer_filters=[],
-        follow_up_tags=[],
         respond_tag_only=True,
     )
     insert_slack_channel_config(
@@ -371,7 +368,9 @@ def get_all_channels_from_slack_api(
     _: User | None = Depends(current_admin_user),
 ) -> list[SlackChannel]:
     """
-    Fetches channels the bot is a member of from the Slack API.
+    Fetches all channels in the Slack workspace using the conversations_list API.
+    This includes both public and private channels that are visible to the app,
+    not just the ones the bot is a member of.
     Handles pagination with a limit to avoid excessive API calls.
     """
     tokens = fetch_slack_bot_tokens(db_session, bot_id)
@@ -386,20 +385,20 @@ def get_all_channels_from_slack_api(
     current_page = 0
 
     try:
-        # Use users_conversations with limited pagination
+        # Use conversations_list to get all channels in the workspace (including ones the bot is not a member of)
         while current_page < MAX_SLACK_PAGES:
             current_page += 1
 
             # Make API call with cursor if we have one
             if next_cursor:
-                response = client.users_conversations(
+                response = client.conversations_list(
                     types="public_channel,private_channel",
                     exclude_archived=True,
                     cursor=next_cursor,
                     limit=SLACK_API_CHANNELS_PER_PAGE,
                 )
             else:
-                response = client.users_conversations(
+                response = client.conversations_list(
                     types="public_channel,private_channel",
                     exclude_archived=True,
                     limit=SLACK_API_CHANNELS_PER_PAGE,
```
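`conversations_list` pages with an opaque cursor, and the endpoint caps the page count to bound API usage. A sketch of that loop against slack_sdk's `WebClient` — `max_pages` and `per_page` stand in for the `MAX_SLACK_PAGES` and `SLACK_API_CHANNELS_PER_PAGE` constants in the diff:

```python
from slack_sdk import WebClient


def list_all_channels(client: WebClient, max_pages: int, per_page: int) -> list[dict]:
    channels: list[dict] = []
    cursor: str | None = None
    for _ in range(max_pages):
        response = client.conversations_list(
            types="public_channel,private_channel",
            exclude_archived=True,
            limit=per_page,
            cursor=cursor,  # None on the first page; the SDK drops None params
        )
        channels.extend(response["channels"])
        cursor = (response.get("response_metadata") or {}).get("next_cursor")
        if not cursor:
            break  # no more pages
    return channels
```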
```diff
@@ -324,7 +324,7 @@ def update_default_multipass_indexing(db_session: Session) -> None:
     logger.info(
         "No existing docs or connectors found. Checking GPU availability for multipass indexing."
     )
-    gpu_available = gpu_status_request()
+    gpu_available = gpu_status_request(indexing=True)
     logger.info(f"GPU available: {gpu_available}")
 
     current_settings = get_current_search_settings(db_session)
```
@@ -1,3 +1,5 @@
|
||||
from functools import lru_cache
|
||||
|
||||
import requests
|
||||
from retry import retry
|
||||
|
||||
@@ -10,8 +12,7 @@ from shared_configs.configs import MODEL_SERVER_PORT
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@retry(tries=5, delay=5)
|
||||
def gpu_status_request(indexing: bool = True) -> bool:
|
||||
def _get_gpu_status_from_model_server(indexing: bool) -> bool:
|
||||
if indexing:
|
||||
model_server_url = f"{INDEXING_MODEL_SERVER_HOST}:{INDEXING_MODEL_SERVER_PORT}"
|
||||
else:
|
||||
@@ -28,3 +29,14 @@ def gpu_status_request(indexing: bool = True) -> bool:
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Error: Unable to fetch GPU status. Error: {str(e)}")
|
||||
raise # Re-raise exception to trigger a retry
|
||||
|
||||
|
||||
@retry(tries=5, delay=5)
|
||||
def gpu_status_request(indexing: bool) -> bool:
|
||||
return _get_gpu_status_from_model_server(indexing)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def fast_gpu_status_request(indexing: bool) -> bool:
|
||||
"""For use in sync flows, where we don't want to retry / we want to cache this."""
|
||||
return gpu_status_request(indexing=indexing)
|
||||
|
||||
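
The fast_gpu_status_request helper added above memoizes results with functools.lru_cache so sync flows skip repeated network round trips. A minimal sketch of the same pattern; probe_gpu here is a stand-in for the real model-server request, not the project's actual call:

from functools import lru_cache

def probe_gpu(indexing: bool) -> bool:
    # Stand-in for the HTTP request to the model server.
    print(f"probing model server (indexing={indexing})...")
    return False

@lru_cache(maxsize=1)
def fast_probe_gpu(indexing: bool) -> bool:
    # Only the first call with a given argument runs probe_gpu;
    # repeats are answered from the cache.
    return probe_gpu(indexing)

fast_probe_gpu(indexing=True)  # runs the probe
fast_probe_gpu(indexing=True)  # cache hit, no probe

Note that maxsize=1 keeps only the most recent argument's result, so alternating indexing=True/False calls would evict each other's cache entries.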
@@ -1,5 +1,6 @@
import os
import time
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch

@@ -7,15 +8,16 @@ import pytest

from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.confluence.utils import AttachmentProcessingResult
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
from onyx.connectors.models import Document


@pytest.fixture
def confluence_connector() -> ConfluenceConnector:
def confluence_connector(space: str) -> ConfluenceConnector:
    connector = ConfluenceConnector(
        wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"],
        space=os.environ["CONFLUENCE_TEST_SPACE"],
        space=space,
        is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == "true",
        page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""),
    )

@@ -32,14 +34,15 @@ def confluence_connector() -> ConfluenceConnector:
    return connector


@pytest.mark.parametrize("space", [os.environ["CONFLUENCE_TEST_SPACE"]])
@patch(
    "onyx.file_processing.extract_file_text.get_unstructured_api_key",
    return_value=None,
)
@pytest.mark.skip(reason="Skipping this test")
def test_confluence_connector_basic(
    mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
) -> None:
    confluence_connector.set_allow_images(False)
    doc_batch_generator = confluence_connector.poll_source(0, time.time())

    doc_batch = next(doc_batch_generator)

@@ -50,15 +53,14 @@ def test_confluence_connector_basic(

    page_within_a_page_doc: Document | None = None
    page_doc: Document | None = None
    txt_doc: Document | None = None

    for doc in doc_batch:
        if doc.semantic_identifier == "DailyConnectorTestSpace Home":
            page_doc = doc
        elif ".txt" in doc.semantic_identifier:
            txt_doc = doc
        elif doc.semantic_identifier == "Page Within A Page":
            page_within_a_page_doc = doc
        else:
            pass

    assert page_within_a_page_doc is not None
    assert page_within_a_page_doc.semantic_identifier == "Page Within A Page"

@@ -79,7 +81,7 @@ def test_confluence_connector_basic(
    assert page_doc.metadata["labels"] == ["testlabel"]
    assert page_doc.primary_owners
    assert page_doc.primary_owners[0].email == "hagen@danswer.ai"
    assert len(page_doc.sections) == 1
    assert len(page_doc.sections) == 2  # page text + attachment text

    page_section = page_doc.sections[0]
    assert page_section.text == "test123 " + page_within_a_page_text

@@ -88,13 +90,65 @@ def test_confluence_connector_basic(
        == "https://danswerai.atlassian.net/wiki/spaces/DailyConne/overview"
    )

    assert txt_doc is not None
    assert txt_doc.semantic_identifier == "small-file.txt"
    assert len(txt_doc.sections) == 1
    assert txt_doc.sections[0].text == "small"
    assert txt_doc.primary_owners
    assert txt_doc.primary_owners[0].email == "chris@onyx.app"
    assert (
        txt_doc.sections[0].link
        == "https://danswerai.atlassian.net/wiki/pages/viewpageattachments.action?pageId=52494430&preview=%2F52494430%2F52527123%2Fsmall-file.txt"
    text_attachment_section = page_doc.sections[1]
    assert text_attachment_section.text == "small"
    assert text_attachment_section.link
    assert text_attachment_section.link.endswith("small-file.txt")


@pytest.mark.parametrize("space", ["MI"])
@patch(
    "onyx.file_processing.extract_file_text.get_unstructured_api_key",
    return_value=None,
)
def test_confluence_connector_skip_images(
    mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
) -> None:
    confluence_connector.set_allow_images(False)
    doc_batch_generator = confluence_connector.poll_source(0, time.time())

    doc_batch = next(doc_batch_generator)
    with pytest.raises(StopIteration):
        next(doc_batch_generator)

    assert len(doc_batch) == 8
    assert sum(len(doc.sections) for doc in doc_batch) == 8


def mock_process_image_attachment(
    *args: Any, **kwargs: Any
) -> AttachmentProcessingResult:
    """We need this mock to bypass DB access happening in the connector. Which shouldn't
    be done as a rule to begin with, but life is not perfect. Fix it later"""

    return AttachmentProcessingResult(
        text="Hi_text",
        file_name="Hi_filename",
        error=None,
    )


@pytest.mark.parametrize("space", ["MI"])
@patch(
    "onyx.file_processing.extract_file_text.get_unstructured_api_key",
    return_value=None,
)
@patch(
    "onyx.connectors.confluence.utils._process_image_attachment",
    side_effect=mock_process_image_attachment,
)
def test_confluence_connector_allow_images(
    mock_get_api_key: MagicMock,
    mock_process_image_attachment: MagicMock,
    confluence_connector: ConfluenceConnector,
) -> None:
    confluence_connector.set_allow_images(True)

    doc_batch_generator = confluence_connector.poll_source(0, time.time())

    doc_batch = next(doc_batch_generator)
    with pytest.raises(StopIteration):
        next(doc_batch_generator)

    assert len(doc_batch) == 8
    assert sum(len(doc.sections) for doc in doc_batch) == 12

@@ -1,5 +1,6 @@
import json
import os
import resource
from collections.abc import Callable

import pytest

@@ -136,3 +137,22 @@ def google_drive_service_acct_connector_factory() -> (
        return connector

    return _connector_factory


@pytest.fixture(scope="session", autouse=True)
def set_resource_limits() -> None:
    # the google sdk is aggressive about using up file descriptors and
    # macos is stingy ... these tests will fail randomly unless the descriptor limit is raised
    RLIMIT_MINIMUM = 2048
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    desired_soft = min(RLIMIT_MINIMUM, hard)  # Pick your target here

    print(f"Open file limit: soft={soft} hard={hard} soft_required={RLIMIT_MINIMUM}")

    if soft < desired_soft:
        print(f"Raising open file limit: {soft} -> {desired_soft}")
        resource.setrlimit(resource.RLIMIT_NOFILE, (desired_soft, hard))

        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        print(f"New open file limit: soft={soft} hard={hard}")
    return

@@ -161,10 +161,14 @@ def _get_expected_file_content(file_id: int) -> str:
    return file_text_template.format(file_id)


def assert_retrieved_docs_match_expected(
def assert_expected_docs_in_retrieved_docs(
    retrieved_docs: list[Document],
    expected_file_ids: Sequence[int],
) -> None:
    """NOTE: as far as i can tell this does NOT assert for an exact match.
    it only checks to see if that the expected file id's are IN the retrieved doc list
    """

    expected_file_names = {
        file_name_template.format(file_id) for file_id in expected_file_ids
    }

@@ -175,7 +179,7 @@ def assert_retrieved_docs_match_expected(
    retrieved_docs.sort(key=lambda x: x.semantic_identifier)

    for doc in retrieved_docs:
        print(f"doc.semantic_identifier: {doc.semantic_identifier}")
        print(f"retrieved doc: doc.semantic_identifier={doc.semantic_identifier}")

    # Filter out invalid prefixes to prevent different tests from interfering with each other
    valid_retrieved_docs = [

@@ -7,7 +7,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import (
    assert_retrieved_docs_match_expected,
    assert_expected_docs_in_retrieved_docs,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL

@@ -62,7 +62,7 @@ def test_include_all(
        + FOLDER_2_2_FILE_IDS
        + SECTIONS_FILE_IDS
    )
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -100,7 +100,7 @@ def test_include_shared_drives_only(
        + FOLDER_2_2_FILE_IDS
        + SECTIONS_FILE_IDS
    )
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -128,7 +128,7 @@ def test_include_my_drives_only(

    # Should only get primary_admins My Drive because we are impersonating them
    expected_file_ids = ADMIN_FILE_IDS + ADMIN_FOLDER_3_FILE_IDS
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -161,7 +161,7 @@ def test_drive_one_only(
        + FOLDER_1_1_FILE_IDS
        + FOLDER_1_2_FILE_IDS
    )
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -198,7 +198,7 @@ def test_folder_and_shared_drive(
        + FOLDER_2_1_FILE_IDS
        + FOLDER_2_2_FILE_IDS
    )
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -241,7 +241,7 @@ def test_folders_only(
        + FOLDER_2_2_FILE_IDS
        + ADMIN_FOLDER_3_FILE_IDS
    )
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -271,7 +271,7 @@ def test_personal_folders_only(
    retrieved_docs = load_all_docs(connector)

    expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -7,7 +7,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import (
    assert_retrieved_docs_match_expected,
    assert_expected_docs_in_retrieved_docs,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL

@@ -70,12 +70,39 @@ def test_include_all(
        + FOLDER_2_2_FILE_IDS
        + SECTIONS_FILE_IDS
    )
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )


@patch(
    "onyx.file_processing.extract_file_text.get_unstructured_api_key",
    return_value=None,
)
def test_include_shared_drives_only_with_size_threshold(
    mock_get_api_key: MagicMock,
    google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
    print("\n\nRunning test_include_shared_drives_only_with_size_threshold")
    connector = google_drive_service_acct_connector_factory(
        primary_admin_email=ADMIN_EMAIL,
        include_shared_drives=True,
        include_my_drives=False,
        include_files_shared_with_me=False,
        shared_folder_urls=None,
        shared_drive_urls=None,
        my_drive_emails=None,
    )

    # this threshold will skip one file
    connector.size_threshold = 16384

    retrieved_docs = load_all_docs(connector)

    assert len(retrieved_docs) == 50


@patch(
    "onyx.file_processing.extract_file_text.get_unstructured_api_key",
    return_value=None,

@@ -94,6 +121,7 @@ def test_include_shared_drives_only(
        shared_drive_urls=None,
        my_drive_emails=None,
    )

    retrieved_docs = load_all_docs(connector)

    # Should only get shared drives

@@ -108,7 +136,10 @@ def test_include_shared_drives_only(
        + FOLDER_2_2_FILE_IDS
        + SECTIONS_FILE_IDS
    )
    assert_retrieved_docs_match_expected(

    assert len(retrieved_docs) == 51

    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -142,7 +173,7 @@ def test_include_my_drives_only(
        + TEST_USER_2_FILE_IDS
        + TEST_USER_3_FILE_IDS
    )
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -176,7 +207,7 @@ def test_drive_one_only(
        + FOLDER_1_1_FILE_IDS
        + FOLDER_1_2_FILE_IDS
    )
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -214,7 +245,7 @@ def test_folder_and_shared_drive(
        + FOLDER_2_1_FILE_IDS
        + FOLDER_2_2_FILE_IDS
    )
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -257,7 +288,7 @@ def test_folders_only(
        + FOLDER_2_2_FILE_IDS
        + ADMIN_FOLDER_3_FILE_IDS
    )
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -288,7 +319,7 @@ def test_specific_emails(
    retrieved_docs = load_all_docs(connector)

    expected_file_ids = TEST_USER_1_FILE_IDS + TEST_USER_3_FILE_IDS
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -318,7 +349,7 @@ def get_specific_folders_in_my_drive(
    retrieved_docs = load_all_docs(connector)

    expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -5,7 +5,7 @@ from unittest.mock import patch
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import (
    assert_retrieved_docs_match_expected,
    assert_expected_docs_in_retrieved_docs,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS

@@ -50,7 +50,7 @@ def test_all(
        + ADMIN_FOLDER_3_FILE_IDS
        + list(range(0, 2))
    )
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -83,7 +83,7 @@ def test_shared_drives_only(
        + FOLDER_1_1_FILE_IDS
        + FOLDER_1_2_FILE_IDS
    )
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -114,7 +114,7 @@ def test_shared_with_me_only(
        ADMIN_FOLDER_3_FILE_IDS
        + list(range(0, 2))
    )
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -142,7 +142,7 @@ def test_my_drive_only(

    # These are the files from my drive
    expected_file_ids = TEST_USER_1_FILE_IDS
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -172,7 +172,7 @@ def test_shared_my_drive_folder(
        # this is a folder from admin's drive that is shared with me
        ADMIN_FOLDER_3_FILE_IDS
    )
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -199,7 +199,7 @@ def test_shared_drive_folder(
    retrieved_docs = load_all_docs(connector)

    expected_file_ids = FOLDER_1_FILE_IDS + FOLDER_1_1_FILE_IDS + FOLDER_1_2_FILE_IDS
    assert_retrieved_docs_match_expected(
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=retrieved_docs,
        expected_file_ids=expected_file_ids,
    )

@@ -2,12 +2,14 @@ import json
import os
import time
from pathlib import Path
from typing import cast

import pytest

from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.zendesk.connector import ZendeskConnector
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector


def load_test_data(file_name: str = "test_zendesk_data.json") -> dict[str, dict]:

@@ -50,7 +52,7 @@ def get_credentials() -> dict[str, str]:
def test_zendesk_connector_basic(
    request: pytest.FixtureRequest, connector_fixture: str
) -> None:
    connector = request.getfixturevalue(connector_fixture)
    connector = cast(ZendeskConnector, request.getfixturevalue(connector_fixture))
    test_data = load_test_data()
    all_docs: list[Document] = []
    target_test_doc_id: str

@@ -61,12 +63,11 @@ def test_zendesk_connector_basic(

    target_doc: Document | None = None

    for doc_batch in connector.poll_source(0, time.time()):
        for doc in doc_batch:
            all_docs.append(doc)
            if doc.id == target_test_doc_id:
                target_doc = doc
                print(f"target_doc {target_doc}")
    for doc in load_all_docs_from_checkpoint_connector(connector, 0, time.time()):
        all_docs.append(doc)
        if doc.id == target_test_doc_id:
            target_doc = doc
            print(f"target_doc {target_doc}")

    assert len(all_docs) > 0, "No documents were retrieved from the connector"
    assert (

@@ -111,8 +112,10 @@ def test_zendesk_connector_basic(
def test_zendesk_connector_slim(zendesk_article_connector: ZendeskConnector) -> None:
    # Get full doc IDs
    all_full_doc_ids = set()
    for doc_batch in zendesk_article_connector.load_from_state():
        all_full_doc_ids.update([doc.id for doc in doc_batch])
    for doc in load_all_docs_from_checkpoint_connector(
        zendesk_article_connector, 0, time.time()
    ):
        all_full_doc_ids.add(doc.id)

    # Get slim doc IDs
    all_slim_doc_ids = set()
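
The slim/full comparison in the test above reduces to a set-containment check. A small illustrative sketch of that invariant with generic stand-in data (not the project's connector API):

def check_slim_covers_full(full_doc_ids: set[str], slim_doc_ids: set[str]) -> None:
    # Every fully-indexed document should also appear in the slim listing;
    # the slim listing may legitimately contain extra IDs.
    missing = full_doc_ids - slim_doc_ids
    assert not missing, f"slim listing is missing ids: {sorted(missing)}"

check_slim_covers_full({"a", "b"}, {"a", "b", "c"})  # passes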
@@ -50,7 +50,7 @@ def answer_instance(
    mocker: MockerFixture,
) -> Answer:
    mocker.patch(
        "onyx.chat.answer.gpu_status_request",
        "onyx.chat.answer.fast_gpu_status_request",
        return_value=True,
    )
    return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config)

@@ -400,7 +400,7 @@ def test_no_slow_reranking(
    mocker: MockerFixture,
) -> None:
    mocker.patch(
        "onyx.chat.answer.gpu_status_request",
        "onyx.chat.answer.fast_gpu_status_request",
        return_value=gpu_enabled,
    )
    rerank_settings = (

@@ -39,7 +39,7 @@ def test_skip_gen_ai_answer_generation_flag(
    mocker: MockerFixture,
) -> None:
    mocker.patch(
        "onyx.chat.answer.gpu_status_request",
        "onyx.chat.answer.fast_gpu_status_request",
        return_value=True,
    )
    question = config["question"]

@@ -0,0 +1,472 @@
import time
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import cast
from unittest.mock import call
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from requests.exceptions import HTTPError

from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.models import Document
from onyx.connectors.zendesk.connector import ZendeskClient
from onyx.connectors.zendesk.connector import ZendeskConnector
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector


@pytest.fixture
def mock_zendesk_client() -> MagicMock:
    """Create a mock Zendesk client"""
    mock = MagicMock(spec=ZendeskClient)
    mock.base_url = "https://test.zendesk.com/api/v2"
    mock.auth = ("test@example.com/token", "test_token")
    return mock


@pytest.fixture
def zendesk_connector(
    mock_zendesk_client: MagicMock,
) -> Generator[ZendeskConnector, None, None]:
    """Create a Zendesk connector with mocked client"""
    connector = ZendeskConnector(content_type="articles")
    connector.client = mock_zendesk_client
    yield connector


@pytest.fixture
def unmocked_zendesk_connector() -> Generator[ZendeskConnector, None, None]:
    """Create a Zendesk connector with unmocked client"""
    zendesk_connector = ZendeskConnector(content_type="articles")
    zendesk_connector.client = ZendeskClient(
        "test", "test@example.com/token", "test_token"
    )
    yield zendesk_connector


@pytest.fixture
def create_mock_article() -> Callable[..., dict[str, Any]]:
    def _create_mock_article(
        id: int = 1,
        title: str = "Test Article",
        body: str = "Test Content",
        updated_at: str = "2023-01-01T12:00:00Z",
        author_id: str = "123",
        label_names: list[str] | None = None,
        draft: bool = False,
    ) -> dict[str, Any]:
        """Helper to create a mock article"""
        return {
            "id": id,
            "title": title,
            "body": body,
            "updated_at": updated_at,
            "author_id": author_id,
            "label_names": label_names or [],
            "draft": draft,
            "html_url": f"https://test.zendesk.com/hc/en-us/articles/{id}",
        }

    return _create_mock_article


@pytest.fixture
def create_mock_ticket() -> Callable[..., dict[str, Any]]:
    def _create_mock_ticket(
        id: int = 1,
        subject: str = "Test Ticket",
        description: str = "Test Description",
        updated_at: str = "2023-01-01T12:00:00Z",
        submitter_id: str = "123",
        status: str = "open",
        priority: str = "normal",
        tags: list[str] | None = None,
        ticket_type: str = "question",
    ) -> dict[str, Any]:
        """Helper to create a mock ticket"""
        return {
            "id": id,
            "subject": subject,
            "description": description,
            "updated_at": updated_at,
            "submitter": submitter_id,
            "status": status,
            "priority": priority,
            "tags": tags or [],
            "type": ticket_type,
            "url": f"https://test.zendesk.com/agent/tickets/{id}",
        }

    return _create_mock_ticket


@pytest.fixture
def create_mock_author() -> Callable[..., dict[str, Any]]:
    def _create_mock_author(
        id: str = "123",
        name: str = "Test User",
        email: str = "test@example.com",
    ) -> dict[str, Any]:
        """Helper to create a mock author"""
        return {
            "user": {
                "id": id,
                "name": name,
                "email": email,
            }
        }

    return _create_mock_author


def test_load_from_checkpoint_articles_happy_path(
    zendesk_connector: ZendeskConnector,
    mock_zendesk_client: MagicMock,
    create_mock_article: Callable[..., dict[str, Any]],
    create_mock_author: Callable[..., dict[str, Any]],
) -> None:
    """Test loading articles from checkpoint - happy path"""
    # Set up mock responses
    mock_article1 = create_mock_article(id=1, title="Article 1")
    mock_article2 = create_mock_article(id=2, title="Article 2")
    mock_author = create_mock_author()

    # Mock API responses
    mock_zendesk_client.make_request.side_effect = [
        # First call: content tags
        {"records": []},
        # Second call: articles page
        {
            "articles": [mock_article1, mock_article2],
            "meta": {
                "has_more": False,
                "after_cursor": None,
            },
        },
        # Third call: author info
        mock_author,
    ]

    # Call load_from_checkpoint
    end_time = time.time()
    outputs = load_everything_from_checkpoint_connector(zendesk_connector, 0, end_time)

    # Check that we got the documents
    assert len(outputs) == 2
    assert outputs[0].next_checkpoint.cached_content_tags is not None

    assert len(outputs[1].items) == 2

    # Check first document
    doc1 = outputs[1].items[0]
    assert isinstance(doc1, Document)
    assert doc1.id == "article:1"
    assert doc1.semantic_identifier == "Article 1"
    assert doc1.source == DocumentSource.ZENDESK

    # Check second document
    doc2 = outputs[1].items[1]
    assert isinstance(doc2, Document)
    assert doc2.id == "article:2"
    assert doc2.semantic_identifier == "Article 2"
    assert doc2.source == DocumentSource.ZENDESK

    # Check checkpoint state
    assert not outputs[1].next_checkpoint.has_more


def test_load_from_checkpoint_tickets_happy_path(
    zendesk_connector: ZendeskConnector,
    mock_zendesk_client: MagicMock,
    create_mock_ticket: Callable[..., dict[str, Any]],
    create_mock_author: Callable[..., dict[str, Any]],
) -> None:
    """Test loading tickets from checkpoint - happy path"""
    # Configure connector for tickets
    zendesk_connector.content_type = "tickets"

    # Set up mock responses
    mock_ticket1 = create_mock_ticket(id=1, subject="Ticket 1")
    mock_ticket2 = create_mock_ticket(id=2, subject="Ticket 2")
    mock_author = create_mock_author()

    # Mock API responses
    mock_zendesk_client.make_request.side_effect = [
        # First call: content tags
        {"records": []},
        # Second call: tickets page
        {
            "tickets": [mock_ticket1, mock_ticket2],
            "end_of_stream": True,
            "end_time": int(time.time()),
        },
        # Third call: author info
        mock_author,
        # Fourth call: comments page
        {"comments": []},
        # Fifth call: comments page
        {"comments": []},
    ]

    zendesk_connector.client = mock_zendesk_client

    # Call load_from_checkpoint
    end_time = time.time()
    outputs = load_everything_from_checkpoint_connector(zendesk_connector, 0, end_time)

    # Check that we got the documents
    assert len(outputs) == 2
    assert outputs[0].next_checkpoint.cached_content_tags is not None
    assert len(outputs[1].items) == 2

    # Check first document
    doc1 = outputs[1].items[0]
    print(doc1, type(doc1))
    assert isinstance(doc1, Document)
    assert doc1.id == "zendesk_ticket_1"
    assert doc1.semantic_identifier == "Ticket #1: Ticket 1"
    assert doc1.source == DocumentSource.ZENDESK

    # Check second document
    doc2 = outputs[1].items[1]
    assert isinstance(doc2, Document)
    assert doc2.id == "zendesk_ticket_2"
    assert doc2.semantic_identifier == "Ticket #2: Ticket 2"
    assert doc2.source == DocumentSource.ZENDESK

    # Check checkpoint state
    assert not outputs[1].next_checkpoint.has_more


def test_load_from_checkpoint_with_rate_limit(
    unmocked_zendesk_connector: ZendeskConnector,
    create_mock_article: Callable[..., dict[str, Any]],
    create_mock_author: Callable[..., dict[str, Any]],
) -> None:
    """Test loading from checkpoint with rate limit handling"""
    zendesk_connector = unmocked_zendesk_connector
    # Set up mock responses
    mock_article = create_mock_article()
    mock_author = create_mock_author()
    author_response = MagicMock()
    author_response.status_code = 200
    author_response.json.return_value = mock_author

    # Create mock responses for requests.get
    rate_limit_response = MagicMock()
    rate_limit_response.status_code = 429
    rate_limit_response.headers = {"Retry-After": "60"}
    rate_limit_response.raise_for_status.side_effect = HTTPError(
        response=rate_limit_response
    )

    success_response = MagicMock()
    success_response.status_code = 200
    success_response.json.return_value = {
        "articles": [mock_article],
        "meta": {
            "has_more": False,
            "after_cursor": None,
        },
    }

    # Mock requests.get to simulate rate limit then success
    with patch("onyx.connectors.zendesk.connector.requests.get") as mock_get:
        mock_get.side_effect = [
            # First call: content tags
            MagicMock(
                status_code=200,
                json=lambda: {"records": [], "meta": {"has_more": False}},
            ),
            # Second call: articles page (rate limited)
            rate_limit_response,
            # Third call: articles page (after rate limit)
            success_response,
            # Fourth call: author info
            author_response,
        ]

        # Call load_from_checkpoint
        end_time = time.time()
        with patch("onyx.connectors.zendesk.connector.time.sleep") as mock_sleep:
            outputs = load_everything_from_checkpoint_connector(
                zendesk_connector, 0, end_time
            )
            mock_sleep.assert_has_calls([call(60), call(0.1)])

        # Check that we got the document after rate limit was handled
        assert len(outputs) == 2
        assert outputs[0].next_checkpoint.cached_content_tags is not None
        assert len(outputs[1].items) == 1
        assert isinstance(outputs[1].items[0], Document)
        assert outputs[1].items[0].id == "article:1"

        # Verify the requests were made with correct parameters
        assert mock_get.call_count == 4
        # First call should be for content tags
        args, kwargs = mock_get.call_args_list[0]
        assert "guide/content_tags" in args[0]
        # Second call should be for articles (rate limited)
        args, kwargs = mock_get.call_args_list[1]
        assert "help_center/articles" in args[0]
        # Third call should be for articles (success)
        args, kwargs = mock_get.call_args_list[2]
        assert "help_center/articles" in args[0]
        # Fourth call should be for author info
        args, kwargs = mock_get.call_args_list[3]
        assert "users/123" in args[0]
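
The rate-limit test above expects the connector to sleep for the Retry-After value on HTTP 429 and then retry. A generic sketch of that pattern with the requests library; the URL, retry cap, and numeric Retry-After parsing are assumptions, not the connector's exact logic:

import time
import requests

def get_with_retry_after(url: str, max_attempts: int = 5) -> requests.Response:
    for _ in range(max_attempts):
        response = requests.get(url)
        if response.status_code == 429:
            # Honor the server-provided backoff before retrying
            # (assumes a numeric Retry-After header).
            delay = int(response.headers.get("Retry-After", "1"))
            time.sleep(delay)
            continue
        response.raise_for_status()
        return response
    raise RuntimeError(f"still rate limited after {max_attempts} attempts")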
def test_load_from_checkpoint_with_empty_response(
    zendesk_connector: ZendeskConnector,
    mock_zendesk_client: MagicMock,
) -> None:
    """Test loading from checkpoint with empty response"""
    # Mock API responses
    mock_zendesk_client.make_request.side_effect = [
        # First call: content tags
        {"records": []},
        # Second call: empty articles page
        {
            "articles": [],
            "meta": {
                "has_more": False,
                "after_cursor": None,
            },
        },
    ]

    # Call load_from_checkpoint
    end_time = time.time()
    outputs = load_everything_from_checkpoint_connector(zendesk_connector, 0, end_time)

    # Check that we got no documents
    assert len(outputs) == 2
    assert outputs[0].next_checkpoint.cached_content_tags is not None
    assert len(outputs[1].items) == 0
    assert not outputs[1].next_checkpoint.has_more


def test_load_from_checkpoint_with_skipped_article(
    zendesk_connector: ZendeskConnector,
    mock_zendesk_client: MagicMock,
    create_mock_article: Callable[..., dict[str, Any]],
) -> None:
    """Test loading from checkpoint with an article that should be skipped"""
    # Set up mock responses with a draft article
    mock_article = create_mock_article(draft=True)
    mock_zendesk_client.make_request.side_effect = [
        # First call: content tags
        {"records": []},
        # Second call: articles page with draft article
        {
            "articles": [mock_article],
            "meta": {
                "has_more": False,
                "after_cursor": None,
            },
        },
    ]

    # Call load_from_checkpoint
    end_time = time.time()
    outputs = load_everything_from_checkpoint_connector(zendesk_connector, 0, end_time)

    # Check that no documents were returned
    assert len(outputs) == 2
    assert outputs[0].next_checkpoint.cached_content_tags is not None
    assert len(outputs[1].items) == 0
    assert not outputs[1].next_checkpoint.has_more


def test_load_from_checkpoint_with_skipped_ticket(
    zendesk_connector: ZendeskConnector,
    mock_zendesk_client: MagicMock,
    create_mock_ticket: Callable[..., dict[str, Any]],
) -> None:
    """Test loading from checkpoint with a deleted ticket"""
    # Configure connector for tickets
    zendesk_connector.content_type = "tickets"

    # Set up mock responses with a deleted ticket
    mock_ticket = create_mock_ticket(status="deleted")
    mock_zendesk_client.make_request.side_effect = [
        # First call: content tags
        {"records": []},
        # Second call: tickets page with deleted ticket
        {
            "tickets": [mock_ticket],
            "end_of_stream": True,
            "end_time": int(time.time()),
        },
    ]

    # Call load_from_checkpoint
    end_time = time.time()
    outputs = load_everything_from_checkpoint_connector(zendesk_connector, 0, end_time)

    # Check that no documents were returned
    assert len(outputs) == 2
    assert outputs[0].next_checkpoint.cached_content_tags is not None
    assert len(outputs[1].items) == 0
    assert not outputs[1].next_checkpoint.has_more


@pytest.mark.parametrize(
    "status_code,expected_exception,expected_message",
    [
        (
            401,
            CredentialExpiredError,
            "Your Zendesk credentials appear to be invalid or expired",
        ),
        (
            403,
            InsufficientPermissionsError,
            "Your Zendesk token does not have sufficient permissions",
        ),
        (
            404,
            ConnectorValidationError,
            "Zendesk resource not found",
        ),
    ],
)
def test_validate_connector_settings_errors(
    zendesk_connector: ZendeskConnector,
    status_code: int,
    expected_exception: type[Exception],
    expected_message: str,
) -> None:
    """Test validation with various error scenarios"""
    mock_response = MagicMock()
    mock_response.status_code = status_code
    error = HTTPError(response=mock_response)

    mock_zendesk_client = cast(MagicMock, zendesk_connector.client)
    mock_zendesk_client.make_request.side_effect = error

    with pytest.raises(expected_exception) as excinfo:
        print("excinfo", excinfo)
        zendesk_connector.validate_connector_settings()

    assert expected_message in str(excinfo.value)


def test_validate_connector_settings_success(
    zendesk_connector: ZendeskConnector,
    mock_zendesk_client: MagicMock,
) -> None:
    """Test successful validation"""
    # Mock successful API response
    mock_zendesk_client.make_request.return_value = {
        "articles": [],
        "meta": {"has_more": False},
    }

    zendesk_connector.validate_connector_settings()
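
The parametrized test above pins down how HTTP status codes map to connector exceptions during validation. A hedged sketch of that mapping; the fallback for other status codes is an assumption, not confirmed by the diff:

from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError

def classify_zendesk_error(status_code: int) -> type[Exception]:
    # 401/403/404 mirror the cases exercised by the test above.
    if status_code == 401:
        return CredentialExpiredError
    if status_code == 403:
        return InsufficientPermissionsError
    return ConnectorValidationError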
@@ -254,14 +254,14 @@ export function SlackChannelConfigFormFields({
|
||||
onSearchTermChange={(term) => {
|
||||
form.setFieldValue("channel_name", term);
|
||||
}}
|
||||
allowCustomValues={true}
|
||||
/>
|
||||
)}
|
||||
</Field>
|
||||
<p className="mt-2 text-sm dark:text-neutral-400 text-neutral-600">
|
||||
Note: This list shows public and private channels where the
|
||||
bot is a member (up to 500 channels). If you don't see a
|
||||
channel, make sure the bot is added to that channel in Slack
|
||||
first, or type the channel name manually.
|
||||
Note: This list shows existing public and private channels (up
|
||||
to 500). You can either select from the list or type any
|
||||
channel name directly.
|
||||
</p>
|
||||
</>
|
||||
)}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { Button } from "@/components/Button";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
@@ -8,10 +7,14 @@ import { adminDeleteCredential } from "@/lib/credential";
|
||||
import { setupGoogleDriveOAuth } from "@/lib/googleDrive";
|
||||
import { GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME } from "@/lib/constants";
|
||||
import Cookies from "js-cookie";
|
||||
import { TextFormField } from "@/components/admin/connectors/Field";
|
||||
import {
|
||||
TextFormField,
|
||||
SectionHeader,
|
||||
SubLabel,
|
||||
} from "@/components/admin/connectors/Field";
|
||||
import { Form, Formik } from "formik";
|
||||
import { User } from "@/lib/types";
|
||||
import { Button as TremorButton } from "@/components/ui/button";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Credential,
|
||||
GoogleDriveCredentialJson,
|
||||
@@ -20,6 +23,15 @@ import {
|
||||
import { refreshAllGoogleData } from "@/lib/googleConnector";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
|
||||
import {
|
||||
FiFile,
|
||||
FiUpload,
|
||||
FiTrash2,
|
||||
FiCheck,
|
||||
FiLink,
|
||||
FiAlertTriangle,
|
||||
} from "react-icons/fi";
|
||||
import { cn, truncateString } from "@/lib/utils";
|
||||
|
||||
type GoogleDriveCredentialJsonTypes = "authorized_user" | "service_account";
|
||||
|
||||
@@ -31,126 +43,202 @@ export const DriveJsonUpload = ({
|
||||
onSuccess?: () => void;
|
||||
}) => {
|
||||
const { mutate } = useSWRConfig();
|
||||
const [credentialJsonStr, setCredentialJsonStr] = useState<
|
||||
string | undefined
|
||||
>();
|
||||
const [isUploading, setIsUploading] = useState(false);
|
||||
const [fileName, setFileName] = useState<string | undefined>();
|
||||
const [isDragging, setIsDragging] = useState(false);
|
||||
|
||||
const handleFileUpload = async (file: File) => {
|
||||
setIsUploading(true);
|
||||
setFileName(file.name);
|
||||
|
||||
const reader = new FileReader();
|
||||
reader.onload = async (loadEvent) => {
|
||||
if (!loadEvent?.target?.result) {
|
||||
setIsUploading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const credentialJsonStr = loadEvent.target.result as string;
|
||||
|
||||
// Check credential type
|
||||
let credentialFileType: GoogleDriveCredentialJsonTypes;
|
||||
try {
|
||||
const appCredentialJson = JSON.parse(credentialJsonStr);
|
||||
if (appCredentialJson.web) {
|
||||
credentialFileType = "authorized_user";
|
||||
} else if (appCredentialJson.type === "service_account") {
|
||||
credentialFileType = "service_account";
|
||||
} else {
|
||||
throw new Error(
|
||||
"Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
setPopup({
|
||||
message: `Invalid file provided - ${e}`,
|
||||
type: "error",
|
||||
});
|
||||
setIsUploading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (credentialFileType === "authorized_user") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/app-credential",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded app credentials",
|
||||
type: "success",
|
||||
});
|
||||
mutate("/api/manage/admin/connector/google-drive/app-credential");
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload app credentials - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (credentialFileType === "service_account") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded service account key",
|
||||
type: "success",
|
||||
});
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key"
|
||||
);
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload service account key - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
setIsUploading(false);
|
||||
};
|
||||
|
||||
reader.readAsText(file);
|
||||
};
|
||||
|
||||
const handleDragEnter = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
if (!isUploading) {
|
||||
setIsDragging(true);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDragLeave = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setIsDragging(false);
|
||||
};
|
||||
|
||||
const handleDragOver = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
};
|
||||
|
||||
const handleDrop = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setIsDragging(false);
|
||||
|
||||
if (isUploading) return;
|
||||
|
||||
const files = e.dataTransfer.files;
|
||||
if (files.length > 0) {
|
||||
const file = files[0];
|
||||
if (file.type === "application/json" || file.name.endsWith(".json")) {
|
||||
handleFileUpload(file);
|
||||
} else {
|
||||
setPopup({
|
||||
message: "Please upload a JSON file",
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<input
|
||||
className={
|
||||
"mr-3 text-sm text-text-900 border border-background-300 " +
|
||||
"cursor-pointer bg-backgrournd dark:text-text-400 focus:outline-none " +
|
||||
"dark:bg-background-700 dark:border-background-600 dark:placeholder-text-400"
|
||||
}
|
||||
type="file"
|
||||
accept=".json"
|
||||
onChange={(event) => {
|
||||
if (!event.target.files) {
|
||||
return;
|
||||
}
|
||||
const file = event.target.files[0];
|
||||
const reader = new FileReader();
|
||||
|
||||
reader.onload = function (loadEvent) {
|
||||
if (!loadEvent?.target?.result) {
|
||||
return;
|
||||
}
|
||||
const fileContents = loadEvent.target.result;
|
||||
setCredentialJsonStr(fileContents as string);
|
||||
};
|
||||
|
||||
reader.readAsText(file);
|
||||
}}
|
||||
/>
|
||||
|
||||
<Button
|
||||
disabled={!credentialJsonStr}
|
||||
onClick={async () => {
|
||||
let credentialFileType: GoogleDriveCredentialJsonTypes;
|
||||
try {
|
||||
const appCredentialJson = JSON.parse(credentialJsonStr!);
|
||||
if (appCredentialJson.web) {
|
||||
credentialFileType = "authorized_user";
|
||||
} else if (appCredentialJson.type === "service_account") {
|
||||
credentialFileType = "service_account";
|
||||
} else {
|
||||
throw new Error(
|
||||
"Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
setPopup({
|
||||
message: `Invalid file provided - ${e}`,
|
||||
type: "error",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (credentialFileType === "authorized_user") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/app-credential",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded app credentials",
|
||||
type: "success",
|
||||
});
|
||||
mutate("/api/manage/admin/connector/google-drive/app-credential");
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload app credentials - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (credentialFileType === "service_account") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded service account key",
|
||||
type: "success",
|
||||
});
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key"
|
||||
);
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload service account key - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
}}
|
||||
>
|
||||
Upload
|
||||
</Button>
|
||||
</>
|
||||
<div className="flex flex-col mt-4">
|
||||
<div className="flex items-center">
|
||||
<div className="relative flex flex-1 items-center">
|
||||
<label
|
||||
className={cn(
|
||||
"flex h-10 items-center justify-center w-full px-4 py-2 border border-dashed rounded-md transition-colors",
|
||||
isUploading
|
||||
? "opacity-70 cursor-not-allowed border-background-400 bg-background-50/30"
|
||||
: isDragging
|
||||
? "bg-background-50/50 border-primary dark:border-primary"
|
||||
: "cursor-pointer hover:bg-background-50/30 hover:border-primary dark:hover:border-primary border-background-300 dark:border-background-600"
|
||||
)}
|
||||
onDragEnter={handleDragEnter}
|
||||
onDragLeave={handleDragLeave}
|
||||
onDragOver={handleDragOver}
|
||||
onDrop={handleDrop}
|
||||
>
|
||||
<div className="flex items-center space-x-2">
|
||||
{isUploading ? (
|
||||
<div className="h-4 w-4 border-t-2 border-b-2 border-primary rounded-full animate-spin"></div>
|
||||
) : (
|
||||
<FiFile className="h-4 w-4 text-text-500" />
|
||||
)}
|
||||
<span className="text-sm text-text-500">
|
||||
{isUploading
|
||||
? `Uploading ${truncateString(fileName || "file", 50)}...`
|
||||
: isDragging
|
||||
? "Drop JSON file here"
|
||||
: truncateString(
|
||||
fileName || "Select or drag JSON credentials file...",
|
||||
50
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
<input
|
||||
className="sr-only"
|
||||
type="file"
|
||||
accept=".json"
|
||||
disabled={isUploading}
|
||||
onChange={(event) => {
|
||||
if (!event.target.files?.length) {
|
||||
return;
|
||||
}
|
||||
const file = event.target.files[0];
|
||||
handleFileUpload(file);
|
||||
}}
|
||||
/>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -160,6 +248,7 @@ interface DriveJsonUploadSectionProps {
|
||||
serviceAccountCredentialData?: { service_account_email: string };
|
||||
isAdmin: boolean;
|
||||
onSuccess?: () => void;
|
||||
existingAuthCredential?: boolean;
|
||||
}
|
||||
|
||||
export const DriveJsonUploadSection = ({
|
||||
@@ -168,6 +257,7 @@ export const DriveJsonUploadSection = ({
|
||||
serviceAccountCredentialData,
|
||||
isAdmin,
|
||||
onSuccess,
|
||||
existingAuthCredential,
|
||||
}: DriveJsonUploadSectionProps) => {
|
||||
const { mutate } = useSWRConfig();
|
||||
const router = useRouter();
|
||||
@@ -177,6 +267,7 @@ export const DriveJsonUploadSection = ({
|
||||
const [localAppCredentialData, setLocalAppCredentialData] =
|
||||
useState(appCredentialData);
|
||||
|
||||
// Update local state when props change
|
||||
useEffect(() => {
|
||||
setLocalServiceAccountData(serviceAccountCredentialData);
|
||||
setLocalAppCredentialData(appCredentialData);
|
||||
@@ -190,153 +281,135 @@ export const DriveJsonUploadSection = ({
|
||||
}
|
||||
};
|
||||
|
||||
if (localServiceAccountData?.service_account_email) {
|
||||
return (
|
||||
<div className="mt-2 text-sm">
|
||||
<div>
|
||||
Found existing service account key with the following <b>Email:</b>
|
||||
<p className="italic mt-1">
|
||||
{localServiceAccountData.service_account_email}
|
||||
</p>
|
||||
</div>
|
||||
{isAdmin ? (
|
||||
<>
|
||||
<div className="mt-4 mb-1">
|
||||
If you want to update these credentials, delete the existing
|
||||
credentials through the button below, and then upload a new
|
||||
credentials JSON.
|
||||
</div>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key",
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key"
|
||||
);
|
||||
mutate(
|
||||
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive)
|
||||
);
|
||||
setPopup({
|
||||
message: "Successfully deleted service account key",
|
||||
type: "success",
|
||||
});
|
||||
setLocalServiceAccountData(undefined);
|
||||
handleSuccess();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete service account key - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div className="mt-4 mb-1">
|
||||
To change these credentials, please contact an administrator.
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (localAppCredentialData?.client_id) {
|
||||
return (
|
||||
<div className="mt-2 text-sm">
|
||||
<div>
|
||||
Found existing app credentials with the following <b>Client ID:</b>
|
||||
<p className="italic mt-1">{localAppCredentialData.client_id}</p>
|
||||
</div>
|
||||
{isAdmin ? (
|
||||
<>
|
||||
<div className="mt-4 mb-1">
|
||||
If you want to update these credentials, delete the existing
|
||||
credentials through the button below, and then upload a new
|
||||
credentials JSON.
|
||||
</div>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/app-credential",
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/app-credential"
|
||||
);
|
||||
mutate(
|
||||
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive)
|
||||
);
|
||||
setPopup({
|
||||
message: "Successfully deleted app credentials",
|
||||
type: "success",
|
||||
});
|
||||
setLocalAppCredentialData(undefined);
|
||||
handleSuccess();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete app credential - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</>
|
||||
) : (
|
||||
<div className="mt-4 mb-1">
|
||||
To change these credentials, please contact an administrator.
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!isAdmin) {
|
||||
return (
|
||||
<div className="mt-2">
|
||||
<p className="text-sm mb-2">
|
||||
Curators are unable to set up the google drive credentials. To add a
|
||||
Google Drive connector, please contact an administrator.
|
||||
</p>
|
||||
<div>
|
||||
<div className="flex items-start py-3 px-4 bg-yellow-50/30 dark:bg-yellow-900/5 rounded">
|
||||
<FiAlertTriangle className="text-yellow-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
|
||||
<p className="text-sm">
|
||||
Curators are unable to set up the Google Drive credentials. To add a
|
||||
Google Drive connector, please contact an administrator.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="mt-2">
|
||||
<p className="text-sm mb-2">
|
||||
Follow the guide{" "}
|
||||
<div>
|
||||
<p className="text-sm mb-3">
|
||||
To connect your Google Drive, create credentials (either OAuth App or
|
||||
Service Account), download the JSON file, and upload it below.
|
||||
</p>
|
||||
<div className="mb-4">
|
||||
<a
|
||||
className="text-link"
|
||||
className="text-primary hover:text-primary/80 flex items-center gap-1 text-sm"
|
||||
target="_blank"
|
||||
href="https://docs.onyx.app/connectors/google_drive#authorization"
|
||||
rel="noreferrer"
|
||||
>
|
||||
here
|
||||
</a>{" "}
|
||||
to either (1) setup a google OAuth App in your company workspace or (2)
|
||||
create a Service Account.
|
||||
<br />
|
||||
<br />
|
||||
Download the credentials JSON if choosing option (1) or the Service
|
||||
Account key JSON if chooosing option (2), and upload it here.
|
||||
</p>
|
||||
<DriveJsonUpload setPopup={setPopup} onSuccess={handleSuccess} />
|
||||
<FiLink className="h-3 w-3" />
|
||||
View detailed setup instructions
|
||||
</a>
|
||||
</div>
|
||||
|
||||
{(localServiceAccountData?.service_account_email ||
|
||||
localAppCredentialData?.client_id) && (
|
||||
<div className="mb-4">
|
||||
<div className="relative flex flex-1 items-center">
|
||||
<label
|
||||
className={cn(
|
||||
"flex h-10 items-center justify-center w-full px-4 py-2 border border-dashed rounded-md transition-colors",
|
||||
false
|
||||
? "opacity-70 cursor-not-allowed border-background-400 bg-background-50/30"
|
||||
: "cursor-pointer hover:bg-background-50/30 hover:border-primary dark:hover:border-primary border-background-300 dark:border-background-600"
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center space-x-2">
|
||||
{false ? (
|
||||
<div className="h-4 w-4 border-t-2 border-b-2 border-primary rounded-full animate-spin"></div>
|
||||
) : (
|
||||
<FiFile className="h-4 w-4 text-text-500" />
|
||||
)}
|
||||
<span className="text-sm text-text-500">
|
||||
{truncateString(
|
||||
localServiceAccountData?.service_account_email ||
|
||||
localAppCredentialData?.client_id ||
|
||||
"",
|
||||
50
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
</label>
|
||||
</div>
|
||||
{isAdmin && !existingAuthCredential && (
|
||||
<div className="mt-2">
|
||||
<Button
|
||||
variant="destructive"
|
||||
type="button"
|
||||
onClick={async () => {
|
||||
const endpoint =
|
||||
localServiceAccountData?.service_account_email
|
||||
? "/api/manage/admin/connector/google-drive/service-account-key"
|
||||
: "/api/manage/admin/connector/google-drive/app-credential";
|
||||
|
||||
const response = await fetch(endpoint, {
|
||||
method: "DELETE",
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
mutate(endpoint);
|
||||
// Also mutate the credential endpoints to ensure Step 2 is reset
|
||||
mutate(
|
||||
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive)
|
||||
);
|
||||
|
||||
// Add additional mutations to refresh all credential-related endpoints
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/credentials"
|
||||
);
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/public-credential"
|
||||
);
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/service-account-credential"
|
||||
);
|
||||
|
||||
setPopup({
|
||||
message: `Successfully deleted ${
|
||||
localServiceAccountData
|
||||
? "service account key"
|
||||
: "app credentials"
|
||||
}`,
|
||||
type: "success",
|
||||
});
|
||||
// Immediately update local state
|
||||
if (localServiceAccountData) {
|
||||
setLocalServiceAccountData(undefined);
|
||||
} else {
|
||||
setLocalAppCredentialData(undefined);
|
||||
}
|
||||
handleSuccess();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete credentials - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete Credentials
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!(
|
||||
localServiceAccountData?.service_account_email ||
|
||||
localAppCredentialData?.client_id
|
||||
) && <DriveJsonUpload setPopup={setPopup} onSuccess={handleSuccess} />}
|
||||
</div>
|
||||
);
|
||||
};
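
The delete handler in the hunk above pairs one DELETE request with a fan-out of SWR `mutate` calls so every cached view of credential state revalidates and Step 2 of the setup flow resets. A minimal sketch of that pattern, using only the endpoints visible in the diff (the hook name `useDeleteDriveCredential` is illustrative, not from the codebase):

```ts
import { useSWRConfig } from "swr";

// Sketch: delete the uploaded Google Drive credential, then revalidate
// every SWR key that renders credential state. The real handler also
// mutates buildSimilarCredentialInfoURL(ValidSources.GoogleDrive).
export function useDeleteDriveCredential() {
  const { mutate } = useSWRConfig();

  // Returns the error text on failure, or null on success.
  return async (hasServiceAccountKey: boolean): Promise<string | null> => {
    const endpoint = hasServiceAccountKey
      ? "/api/manage/admin/connector/google-drive/service-account-key"
      : "/api/manage/admin/connector/google-drive/app-credential";

    const response = await fetch(endpoint, { method: "DELETE" });
    if (!response.ok) {
      return await response.text();
    }

    // Revalidate the deleted key plus all related credential endpoints
    // so the UI reflects the deletion immediately.
    mutate(endpoint);
    mutate("/api/manage/admin/connector/google-drive/credentials");
    mutate("/api/manage/admin/connector/google-drive/public-credential");
    mutate("/api/manage/admin/connector/google-drive/service-account-credential");
    return null;
  };
}
```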
@@ -391,6 +464,7 @@ export const DriveAuthSection = ({
user,
}: DriveCredentialSectionProps) => {
const router = useRouter();
const [isAuthenticating, setIsAuthenticating] = useState(false);
const [localServiceAccountData, setLocalServiceAccountData] = useState(
serviceAccountKeyData
);
@@ -405,6 +479,7 @@ export const DriveAuthSection = ({
setLocalGoogleDriveServiceAccountCredential,
] = useState(googleDriveServiceAccountCredential);

// Update local state when props change
useEffect(() => {
setLocalServiceAccountData(serviceAccountKeyData);
setLocalAppCredentialData(appCredentialData);
@@ -424,126 +499,181 @@ export const DriveAuthSection = ({
localGoogleDriveServiceAccountCredential;
if (existingCredential) {
return (
<>
<p className="mb-2 text-sm">
<i>Uploaded and authenticated credential already exists!</i>
</p>
<Button
onClick={async () => {
handleRevokeAccess(
connectorAssociated,
setPopup,
existingCredential,
refreshCredentials
);
}}
>
Revoke Access
</Button>
</>
<div>
<div className="mt-4">
<div className="py-3 px-4 bg-blue-50/30 dark:bg-blue-900/5 rounded mb-4 flex items-start">
<FiCheck className="text-blue-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
<div className="flex-1">
<span className="font-medium block">Authentication Complete</span>
<p className="text-sm mt-1 text-text-500 dark:text-text-400 break-words">
Your Google Drive credentials have been successfully uploaded
and authenticated.
</p>
</div>
</div>
<Button
variant="destructive"
type="button"
onClick={async () => {
handleRevokeAccess(
connectorAssociated,
setPopup,
existingCredential,
refreshCredentials
);
}}
>
Revoke Access
</Button>
</div>
</div>
);
}

// If no credentials are uploaded, show message to complete step 1 first
if (
!localServiceAccountData?.service_account_email &&
!localAppCredentialData?.client_id
) {
return (
<div>
<SectionHeader>Google Drive Authentication</SectionHeader>
<div className="mt-4">
<div className="flex items-start py-3 px-4 bg-yellow-50/30 dark:bg-yellow-900/5 rounded">
<FiAlertTriangle className="text-yellow-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
<p className="text-sm">
Please complete Step 1 by uploading either OAuth credentials or a
Service Account key before proceeding with authentication.
</p>
</div>
</div>
</div>
);
}

if (localServiceAccountData?.service_account_email) {
return (
<div>
<Formik
initialValues={{
google_primary_admin: user?.email || "",
}}
validationSchema={Yup.object().shape({
google_primary_admin: Yup.string().required(
"User email is required"
),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
const response = await fetch(
"/api/manage/admin/connector/google-drive/service-account-credential",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
google_primary_admin: values.google_primary_admin,
}),
}
);
<div className="mt-4">
<Formik
initialValues={{
google_primary_admin: user?.email || "",
}}
validationSchema={Yup.object().shape({
google_primary_admin: Yup.string()
.email("Must be a valid email")
.required("Required"),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
try {
const response = await fetch(
"/api/manage/admin/connector/google-drive/service-account-credential",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
google_primary_admin: values.google_primary_admin,
}),
}
);

if (response.ok) {
setPopup({
message: "Successfully created service account credential",
type: "success",
});
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to create service account credential - ${errorMsg}`,
type: "error",
});
}
refreshCredentials();
}}
>
{({ isSubmitting }) => (
<Form>
<TextFormField
name="google_primary_admin"
label="Primary Admin Email:"
subtext="Enter the email of an admin/owner of the Google Organization that owns the Google Drive(s) you want to index."
/>
<div className="flex">
<TremorButton type="submit" disabled={isSubmitting}>
Create Credential
</TremorButton>
</div>
</Form>
)}
</Formik>
if (response.ok) {
setPopup({
message: "Successfully created service account credential",
type: "success",
});
refreshCredentials();
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to create service account credential - ${errorMsg}`,
type: "error",
});
}
} catch (error) {
setPopup({
message: `Failed to create service account credential - ${error}`,
type: "error",
});
} finally {
formikHelpers.setSubmitting(false);
}
}}
>
{({ isSubmitting }) => (
<Form>
<TextFormField
name="google_primary_admin"
label="Primary Admin Email:"
subtext="Enter the email of an admin/owner of the Google Organization that owns the Google Drive(s) you want to index."
/>
<div className="flex">
<Button type="submit" disabled={isSubmitting}>
{isSubmitting ? "Creating..." : "Create Credential"}
</Button>
</div>
</Form>
)}
</Formik>
</div>
</div>
);
}

if (localAppCredentialData?.client_id) {
return (
<div className="text-sm mb-4">
<p className="mb-2">
Next, you must provide credentials via OAuth. This gives us read
access to the docs you have access to in your google drive account.
</p>
<div>
<div className="bg-background-50/30 dark:bg-background-900/20 rounded mb-4">
<p className="text-sm">
Next, you need to authenticate with Google Drive via OAuth. This
gives us read access to the documents you have access to in your
Google Drive account.
</p>
</div>
<Button
disabled={isAuthenticating}
onClick={async () => {
const [authUrl, errorMsg] = await setupGoogleDriveOAuth({
isAdmin: true,
name: "OAuth (uploaded)",
});
if (authUrl) {
setIsAuthenticating(true);
try {
// cookie used by callback to determine where to finally redirect to
Cookies.set(GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME, "true", {
path: "/",
});
router.push(authUrl);
return;
}

setPopup({
message: errorMsg,
type: "error",
});
const [authUrl, errorMsg] = await setupGoogleDriveOAuth({
isAdmin: true,
name: "OAuth (uploaded)",
});

if (authUrl) {
router.push(authUrl);
} else {
setPopup({
message: errorMsg,
type: "error",
});
setIsAuthenticating(false);
}
} catch (error) {
setPopup({
message: `Failed to authenticate with Google Drive - ${error}`,
type: "error",
});
setIsAuthenticating(false);
}
}}
>
Authenticate with Google Drive
{isAuthenticating
? "Authenticating..."
: "Authenticate with Google Drive"}
</Button>
</div>
);
}

// case where no keys have been uploaded in step 1
return (
<p className="text-sm">
Please upload either an OAuth Client Credential JSON or a Google Drive
Service Account Key JSON in Step 1 before moving onto Step 2.
</p>
);
// This code path should not be reached with the new conditions above
return null;
};
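
Stripped of the Formik wiring, the new submit handler above is a single PUT wrapped in try/catch/finally. A condensed sketch (endpoint and payload shape taken from the hunk; `showPopup` stands in for the component's `setPopup`):

```ts
// Sketch: create a Drive service-account credential for a primary admin.
async function createServiceAccountCredential(
  googlePrimaryAdmin: string,
  showPopup: (message: string, type: "success" | "error") => void
): Promise<void> {
  try {
    const response = await fetch(
      "/api/manage/admin/connector/google-drive/service-account-credential",
      {
        method: "PUT",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ google_primary_admin: googlePrimaryAdmin }),
      }
    );
    if (response.ok) {
      showPopup("Successfully created service account credential", "success");
    } else {
      // Non-2xx responses carry the error text in the body.
      showPopup(
        `Failed to create service account credential - ${await response.text()}`,
        "error"
      );
    }
  } catch (error) {
    // Network-level failures land here rather than in the !ok branch.
    showPopup(`Failed to create service account credential - ${error}`, "error");
  }
}
```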
@@ -165,6 +165,10 @@ const GDriveMain = ({
serviceAccountCredentialData={serviceAccountKeyData}
isAdmin={isAdmin}
onSuccess={handleRefresh}
existingAuthCredential={Boolean(
googleDrivePublicUploadedCredential ||
googleDriveServiceAccountCredential
)}
/>

{isAdmin &&

@@ -1,4 +1,4 @@
import { Button } from "@/components/Button";
import { Button } from "@/components/ui/button";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import React, { useState, useEffect } from "react";
import { useSWRConfig } from "swr";
@@ -8,7 +8,11 @@ import { adminDeleteCredential } from "@/lib/credential";
import { setupGmailOAuth } from "@/lib/gmail";
import { GMAIL_AUTH_IS_ADMIN_COOKIE_NAME } from "@/lib/constants";
import Cookies from "js-cookie";
import { TextFormField } from "@/components/admin/connectors/Field";
import {
TextFormField,
SectionHeader,
SubLabel,
} from "@/components/admin/connectors/Field";
import { Form, Formik } from "formik";
import { User } from "@/lib/types";
import CardSection from "@/components/admin/CardSection";
@@ -20,10 +24,19 @@ import {
import { refreshAllGoogleData } from "@/lib/googleConnector";
import { ValidSources } from "@/lib/types";
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
import {
FiFile,
FiUpload,
FiTrash2,
FiCheck,
FiLink,
FiAlertTriangle,
} from "react-icons/fi";
import { cn, truncateString } from "@/lib/utils";

type GmailCredentialJsonTypes = "authorized_user" | "service_account";

const DriveJsonUpload = ({
const GmailCredentialUpload = ({
setPopup,
onSuccess,
}: {
@@ -31,134 +44,210 @@ const DriveJsonUpload = ({
onSuccess?: () => void;
}) => {
const { mutate } = useSWRConfig();
const [credentialJsonStr, setCredentialJsonStr] = useState<
string | undefined
>();
const [isUploading, setIsUploading] = useState(false);
const [fileName, setFileName] = useState<string | undefined>();
const [isDragging, setIsDragging] = useState(false);

const handleFileUpload = async (file: File) => {
setIsUploading(true);
setFileName(file.name);

const reader = new FileReader();
reader.onload = async (loadEvent) => {
if (!loadEvent?.target?.result) {
setIsUploading(false);
return;
}

const credentialJsonStr = loadEvent.target.result as string;

// Check credential type
let credentialFileType: GmailCredentialJsonTypes;
try {
const appCredentialJson = JSON.parse(credentialJsonStr);
if (appCredentialJson.web) {
credentialFileType = "authorized_user";
} else if (appCredentialJson.type === "service_account") {
credentialFileType = "service_account";
} else {
throw new Error(
"Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
);
}
} catch (e) {
setPopup({
message: `Invalid file provided - ${e}`,
type: "error",
});
setIsUploading(false);
return;
}

if (credentialFileType === "authorized_user") {
const response = await fetch(
"/api/manage/admin/connector/gmail/app-credential",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: credentialJsonStr,
}
);
if (response.ok) {
setPopup({
message: "Successfully uploaded app credentials",
type: "success",
});
mutate("/api/manage/admin/connector/gmail/app-credential");
if (onSuccess) {
onSuccess();
}
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to upload app credentials - ${errorMsg}`,
type: "error",
});
}
}

if (credentialFileType === "service_account") {
const response = await fetch(
"/api/manage/admin/connector/gmail/service-account-key",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: credentialJsonStr,
}
);
if (response.ok) {
setPopup({
message: "Successfully uploaded service account key",
type: "success",
});
mutate("/api/manage/admin/connector/gmail/service-account-key");
if (onSuccess) {
onSuccess();
}
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to upload service account key - ${errorMsg}`,
type: "error",
});
}
}
setIsUploading(false);
};

reader.readAsText(file);
};

const handleDragEnter = (e: React.DragEvent<HTMLLabelElement>) => {
e.preventDefault();
e.stopPropagation();
if (!isUploading) {
setIsDragging(true);
}
};

const handleDragLeave = (e: React.DragEvent<HTMLLabelElement>) => {
e.preventDefault();
e.stopPropagation();
setIsDragging(false);
};

const handleDragOver = (e: React.DragEvent<HTMLLabelElement>) => {
e.preventDefault();
e.stopPropagation();
};

const handleDrop = (e: React.DragEvent<HTMLLabelElement>) => {
e.preventDefault();
e.stopPropagation();
setIsDragging(false);

if (isUploading) return;

const files = e.dataTransfer.files;
if (files.length > 0) {
const file = files[0];
if (file.type === "application/json" || file.name.endsWith(".json")) {
handleFileUpload(file);
} else {
setPopup({
message: "Please upload a JSON file",
type: "error",
});
}
}
};

return (
<>
<input
className={
"mr-3 text-sm text-text-900 border border-background-300 overflow-visible " +
"cursor-pointer bg-background dark:text-text-400 focus:outline-none " +
"dark:bg-background-700 dark:border-background-600 dark:placeholder-text-400"
}
type="file"
accept=".json"
onChange={(event) => {
if (!event.target.files) {
return;
}
const file = event.target.files[0];
const reader = new FileReader();

reader.onload = function (loadEvent) {
if (!loadEvent?.target?.result) {
return;
}
const fileContents = loadEvent.target.result;
setCredentialJsonStr(fileContents as string);
};

reader.readAsText(file);
}}
/>

<Button
disabled={!credentialJsonStr}
onClick={async () => {
// check if the JSON is an app credential or a service account credential
let credentialFileType: GmailCredentialJsonTypes;
try {
const appCredentialJson = JSON.parse(credentialJsonStr!);
if (appCredentialJson.web) {
credentialFileType = "authorized_user";
} else if (appCredentialJson.type === "service_account") {
credentialFileType = "service_account";
} else {
throw new Error(
"Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
);
}
} catch (e) {
setPopup({
message: `Invalid file provided - ${e}`,
type: "error",
});
return;
}

if (credentialFileType === "authorized_user") {
const response = await fetch(
"/api/manage/admin/connector/gmail/app-credential",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: credentialJsonStr,
}
);
if (response.ok) {
setPopup({
message: "Successfully uploaded app credentials",
type: "success",
});
mutate("/api/manage/admin/connector/gmail/app-credential");
if (onSuccess) {
onSuccess();
}
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to upload app credentials - ${errorMsg}`,
type: "error",
});
}
}

if (credentialFileType === "service_account") {
const response = await fetch(
"/api/manage/admin/connector/gmail/service-account-key",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: credentialJsonStr,
}
);
if (response.ok) {
setPopup({
message: "Successfully uploaded service account key",
type: "success",
});
mutate("/api/manage/admin/connector/gmail/service-account-key");
if (onSuccess) {
onSuccess();
}
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to upload service account key - ${errorMsg}`,
type: "error",
});
}
}
}}
>
Upload
</Button>
</>
<div className="flex flex-col mt-4">
<div className="flex items-center">
<div className="relative flex flex-1 items-center">
<label
className={cn(
"flex h-10 items-center justify-center w-full px-4 py-2 border border-dashed rounded-md transition-colors",
isUploading
? "opacity-70 cursor-not-allowed border-background-400 bg-background-50/30"
: isDragging
? "bg-background-50/50 border-primary dark:border-primary"
: "cursor-pointer hover:bg-background-50/30 hover:border-primary dark:hover:border-primary border-background-300 dark:border-background-600"
)}
onDragEnter={handleDragEnter}
onDragLeave={handleDragLeave}
onDragOver={handleDragOver}
onDrop={handleDrop}
>
<div className="flex items-center space-x-2">
{isUploading ? (
<div className="h-4 w-4 border-t-2 border-b-2 border-primary rounded-full animate-spin"></div>
) : (
<FiFile className="h-4 w-4 text-text-500" />
)}
<span className="text-sm text-text-500">
{isUploading
? `Uploading ${truncateString(fileName || "file", 50)}...`
: isDragging
? "Drop JSON file here"
: truncateString(
fileName || "Select or drag JSON credentials file...",
50
)}
</span>
</div>
<input
className="sr-only"
type="file"
accept=".json"
disabled={isUploading}
onChange={(event) => {
if (!event.target.files?.length) {
return;
}
const file = event.target.files[0];
handleFileUpload(file);
}}
/>
</label>
</div>
</div>
</div>
);
};
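
GmailCredentialUpload decides which endpoint to PUT to by sniffing the JSON's shape: an OAuth web-app download carries a top-level `web` key, while a service-account key declares `type: "service_account"`. The detection, isolated as a sketch:

```ts
type GmailCredentialJsonTypes = "authorized_user" | "service_account";

// Sketch: classify an uploaded Google credentials JSON by its shape.
// Throws on malformed JSON (SyntaxError) or an unrecognized shape.
function classifyCredentialJson(raw: string): GmailCredentialJsonTypes {
  const parsed = JSON.parse(raw);
  if (parsed.web) {
    return "authorized_user"; // OAuth "Web application" download
  }
  if (parsed.type === "service_account") {
    return "service_account"; // service-account key file
  }
  throw new Error(
    "Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
  );
}
```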
interface DriveJsonUploadSectionProps {
interface GmailJsonUploadSectionProps {
setPopup: (popupSpec: PopupSpec | null) => void;
appCredentialData?: { client_id: string };
serviceAccountCredentialData?: { service_account_email: string };
isAdmin: boolean;
onSuccess?: () => void;
existingAuthCredential?: boolean;
}

export const GmailJsonUploadSection = ({
@@ -167,7 +256,8 @@ export const GmailJsonUploadSection = ({
serviceAccountCredentialData,
isAdmin,
onSuccess,
}: DriveJsonUploadSectionProps) => {
existingAuthCredential,
}: GmailJsonUploadSectionProps) => {
const { mutate } = useSWRConfig();
const router = useRouter();
const [localServiceAccountData, setLocalServiceAccountData] = useState(
@@ -190,156 +280,138 @@ export const GmailJsonUploadSection = ({
}
};

if (localServiceAccountData?.service_account_email) {
return (
<div className="mt-2 text-sm">
<div>
Found existing service account key with the following <b>Email:</b>
<p className="italic mt-1">
{localServiceAccountData.service_account_email}
</p>
</div>
{isAdmin ? (
<>
<div className="mt-4 mb-1">
If you want to update these credentials, delete the existing
credentials through the button below, and then upload a new
credentials JSON.
</div>
<Button
onClick={async () => {
const response = await fetch(
"/api/manage/admin/connector/gmail/service-account-key",
{
method: "DELETE",
}
);
if (response.ok) {
mutate(
"/api/manage/admin/connector/gmail/service-account-key"
);
// Also mutate the credential endpoints to ensure Step 2 is reset
mutate(buildSimilarCredentialInfoURL(ValidSources.Gmail));
setPopup({
message: "Successfully deleted service account key",
type: "success",
});
// Immediately update local state
setLocalServiceAccountData(undefined);
handleSuccess();
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to delete service account key - ${errorMsg}`,
type: "error",
});
}
}}
>
Delete
</Button>
</>
) : (
<>
<div className="mt-4 mb-1">
To change these credentials, please contact an administrator.
</div>
</>
)}
</div>
);
}

if (localAppCredentialData?.client_id) {
return (
<div className="mt-2 text-sm">
<div>
Found existing app credentials with the following <b>Client ID:</b>
<p className="italic mt-1">{localAppCredentialData.client_id}</p>
</div>
{isAdmin ? (
<>
<div className="mt-4 mb-1">
If you want to update these credentials, delete the existing
credentials through the button below, and then upload a new
credentials JSON.
</div>
<Button
onClick={async () => {
const response = await fetch(
"/api/manage/admin/connector/gmail/app-credential",
{
method: "DELETE",
}
);
if (response.ok) {
mutate("/api/manage/admin/connector/gmail/app-credential");
// Also mutate the credential endpoints to ensure Step 2 is reset
mutate(buildSimilarCredentialInfoURL(ValidSources.Gmail));
setPopup({
message: "Successfully deleted app credentials",
type: "success",
});
// Immediately update local state
setLocalAppCredentialData(undefined);
handleSuccess();
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to delete app credential - ${errorMsg}`,
type: "error",
});
}
}}
>
Delete
</Button>
</>
) : (
<div className="mt-4 mb-1">
To change these credentials, please contact an administrator.
</div>
)}
</div>
);
}

if (!isAdmin) {
return (
<div className="mt-2">
<p className="text-sm mb-2">
Curators are unable to set up the Gmail credentials. To add a Gmail
connector, please contact an administrator.
</p>
<div>
<div className="flex items-start py-3 px-4 bg-yellow-50/30 dark:bg-yellow-900/5 rounded">
<FiAlertTriangle className="text-yellow-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
<p className="text-sm">
Curators are unable to set up the Gmail credentials. To add a Gmail
connector, please contact an administrator.
</p>
</div>
</div>
);
}

return (
<div className="mt-2">
<p className="text-sm mb-2">
Follow the guide{" "}
<div>
<p className="text-sm mb-3">
To connect your Gmail, create credentials (either OAuth App or Service
Account), download the JSON file, and upload it below.
</p>
<div className="mb-4">
<a
className="text-link"
className="text-primary hover:text-primary/80 flex items-center gap-1 text-sm"
target="_blank"
href="https://docs.onyx.app/connectors/gmail#authorization"
rel="noreferrer"
>
here
</a>{" "}
to either (1) set up a Google OAuth App in your company workspace or (2)
create a Service Account.
<br />
<br />
Download the credentials JSON if choosing option (1) or the Service
Account key JSON if choosing option (2), and upload it here.
</p>
<DriveJsonUpload setPopup={setPopup} onSuccess={handleSuccess} />
<FiLink className="h-3 w-3" />
View detailed setup instructions
</a>
</div>

{(localServiceAccountData?.service_account_email ||
localAppCredentialData?.client_id) && (
<div className="mb-4">
<div className="relative flex flex-1 items-center">
<label
className={cn(
"flex h-10 items-center justify-center w-full px-4 py-2 border border-dashed rounded-md transition-colors",
false
? "opacity-70 cursor-not-allowed border-background-400 bg-background-50/30"
: "cursor-pointer hover:bg-background-50/30 hover:border-primary dark:hover:border-primary border-background-300 dark:border-background-600"
)}
>
<div className="flex items-center space-x-2">
{false ? (
<div className="h-4 w-4 border-t-2 border-b-2 border-primary rounded-full animate-spin"></div>
) : (
<FiFile className="h-4 w-4 text-text-500" />
)}
<span className="text-sm text-text-500">
{truncateString(
localServiceAccountData?.service_account_email ||
localAppCredentialData?.client_id ||
"",
50
)}
</span>
</div>
</label>
</div>
{isAdmin && !existingAuthCredential && (
<div className="mt-2">
<Button
variant="destructive"
type="button"
onClick={async () => {
const endpoint =
localServiceAccountData?.service_account_email
? "/api/manage/admin/connector/gmail/service-account-key"
: "/api/manage/admin/connector/gmail/app-credential";

const response = await fetch(endpoint, {
method: "DELETE",
});

if (response.ok) {
mutate(endpoint);
// Also mutate the credential endpoints to ensure Step 2 is reset
mutate(buildSimilarCredentialInfoURL(ValidSources.Gmail));

// Add additional mutations to refresh all credential-related endpoints
mutate("/api/manage/admin/connector/gmail/credentials");
mutate(
"/api/manage/admin/connector/gmail/public-credential"
);
mutate(
"/api/manage/admin/connector/gmail/service-account-credential"
);

setPopup({
message: `Successfully deleted ${
localServiceAccountData
? "service account key"
: "app credentials"
}`,
type: "success",
});
// Immediately update local state
if (localServiceAccountData) {
setLocalServiceAccountData(undefined);
} else {
setLocalAppCredentialData(undefined);
}
handleSuccess();
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to delete credentials - ${errorMsg}`,
type: "error",
});
}
}}
>
Delete Credentials
</Button>
</div>
)}
</div>
)}

{!(
localServiceAccountData?.service_account_email ||
localAppCredentialData?.client_id
) && (
<GmailCredentialUpload setPopup={setPopup} onSuccess={handleSuccess} />
)}
</div>
);
};

interface DriveCredentialSectionProps {
interface GmailCredentialSectionProps {
gmailPublicCredential?: Credential<GmailCredentialJson>;
gmailServiceAccountCredential?: Credential<GmailServiceAccountCredentialJson>;
serviceAccountKeyData?: { service_account_email: string };
@@ -387,7 +459,7 @@ export const GmailAuthSection = ({
refreshCredentials,
connectorExists,
user,
}: DriveCredentialSectionProps) => {
}: GmailCredentialSectionProps) => {
const router = useRouter();
const [isAuthenticating, setIsAuthenticating] = useState(false);
const [localServiceAccountData, setLocalServiceAccountData] = useState(
@@ -420,104 +492,141 @@ export const GmailAuthSection = ({
localGmailPublicCredential || localGmailServiceAccountCredential;
if (existingCredential) {
return (
<>
<p className="mb-2 text-sm">
<i>Uploaded and authenticated credential already exists!</i>
</p>
<Button
onClick={async () => {
handleRevokeAccess(
connectorExists,
setPopup,
existingCredential,
refreshCredentials
);
}}
>
Revoke Access
</Button>
</>
<div>
<div className="mt-4">
<div className="py-3 px-4 bg-blue-50/30 dark:bg-blue-900/5 rounded mb-4 flex items-start">
<FiCheck className="text-blue-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
<div className="flex-1">
<span className="font-medium block">Authentication Complete</span>
<p className="text-sm mt-1 text-text-500 dark:text-text-400 break-words">
Your Gmail credentials have been successfully uploaded and
authenticated.
</p>
</div>
</div>
<Button
variant="destructive"
type="button"
onClick={async () => {
handleRevokeAccess(
connectorExists,
setPopup,
existingCredential,
refreshCredentials
);
}}
>
Revoke Access
</Button>
</div>
</div>
);
}

// If no credentials are uploaded, show message to complete step 1 first
if (
!localServiceAccountData?.service_account_email &&
!localAppCredentialData?.client_id
) {
return (
<div>
<SectionHeader>Gmail Authentication</SectionHeader>
<div className="mt-4">
<div className="flex items-start py-3 px-4 bg-yellow-50/30 dark:bg-yellow-900/5 rounded">
<FiAlertTriangle className="text-yellow-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
<p className="text-sm">
Please complete Step 1 by uploading either OAuth credentials or a
Service Account key before proceeding with authentication.
</p>
</div>
</div>
</div>
);
}

if (localServiceAccountData?.service_account_email) {
return (
<div>
<Formik
initialValues={{
google_primary_admin: user?.email || "",
}}
validationSchema={Yup.object().shape({
google_primary_admin: Yup.string()
.email("Must be a valid email")
.required("Required"),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
try {
const response = await fetch(
"/api/manage/admin/connector/gmail/service-account-credential",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
google_primary_admin: values.google_primary_admin,
}),
}
);
<div className="mt-4">
<Formik
initialValues={{
google_primary_admin: user?.email || "",
}}
validationSchema={Yup.object().shape({
google_primary_admin: Yup.string()
.email("Must be a valid email")
.required("Required"),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
try {
const response = await fetch(
"/api/manage/admin/connector/gmail/service-account-credential",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
google_primary_admin: values.google_primary_admin,
}),
}
);

if (response.ok) {
if (response.ok) {
setPopup({
message: "Successfully created service account credential",
type: "success",
});
refreshCredentials();
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to create service account credential - ${errorMsg}`,
type: "error",
});
}
} catch (error) {
setPopup({
message: "Successfully created service account credential",
type: "success",
});
refreshCredentials();
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to create service account credential - ${errorMsg}`,
message: `Failed to create service account credential - ${error}`,
type: "error",
});
} finally {
formikHelpers.setSubmitting(false);
}
} catch (error) {
setPopup({
message: `Failed to create service account credential - ${error}`,
type: "error",
});
} finally {
formikHelpers.setSubmitting(false);
}
}}
>
{({ isSubmitting }) => (
<Form>
<TextFormField
name="google_primary_admin"
label="Primary Admin Email:"
subtext="Enter the email of an admin/owner of the Google Organization that owns the Gmail account(s) you want to index."
/>
<div className="flex">
<Button type="submit" disabled={isSubmitting}>
Create Credential
</Button>
</div>
</Form>
)}
</Formik>
}}
>
{({ isSubmitting }) => (
<Form>
<TextFormField
name="google_primary_admin"
label="Primary Admin Email:"
subtext="Enter the email of an admin/owner of the Google Organization that owns the Gmail account(s) you want to index."
/>
<div className="flex">
<Button type="submit" disabled={isSubmitting}>
{isSubmitting ? "Creating..." : "Create Credential"}
</Button>
</div>
</Form>
)}
</Formik>
</div>
</div>
);
}

if (localAppCredentialData?.client_id) {
return (
<div className="text-sm mb-4">
<p className="mb-2">
Next, you must provide credentials via OAuth. This gives us read
access to the emails you have access to in your Gmail account.
</p>
<div>
<div className="bg-background-50/30 dark:bg-background-900/20 rounded mb-4">
<p className="text-sm">
Next, you need to authenticate with Gmail via OAuth. This gives us
read access to the emails you have access to in your Gmail account.
</p>
</div>
<Button
disabled={isAuthenticating}
onClick={async () => {
setIsAuthenticating(true);
try {
@@ -545,7 +654,6 @@ export const GmailAuthSection = ({
setIsAuthenticating(false);
}
}}
disabled={isAuthenticating}
>
{isAuthenticating ? "Authenticating..." : "Authenticate with Gmail"}
</Button>
@@ -553,11 +661,6 @@ export const GmailAuthSection = ({
);
}

// case where no keys have been uploaded in step 1
return (
<p className="text-sm">
Please upload either an OAuth Client Credential JSON or a Gmail Service
Account Key JSON in Step 1 before moving onto Step 2.
</p>
);
// This code path should not be reached with the new conditions above
return null;
};
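
The OAuth branch follows the same hand-off shown for Drive earlier in this comparison: fetch an authorization URL, set the admin cookie the callback reads, then redirect. A sketch using the Gmail helpers from this file's imports, assuming `setupGmailOAuth` accepts the same options and resolves to the same `[authUrl, errorMsg]` tuple as its Drive counterpart:

```ts
import Cookies from "js-cookie";
import { setupGmailOAuth } from "@/lib/gmail";
import { GMAIL_AUTH_IS_ADMIN_COOKIE_NAME } from "@/lib/constants";

// Sketch: start the Gmail OAuth hand-off from the admin page.
async function startGmailOAuth(
  redirect: (url: string) => void, // e.g. router.push
  onError: (message: string) => void // e.g. wired to setPopup
): Promise<void> {
  const [authUrl, errorMsg] = await setupGmailOAuth({
    isAdmin: true,
    name: "OAuth (uploaded)",
  });
  if (!authUrl) {
    onError(errorMsg);
    return;
  }
  // Cookie used by the OAuth callback to decide where to redirect afterwards.
  Cookies.set(GMAIL_AUTH_IS_ADMIN_COOKIE_NAME, "true", { path: "/" });
  redirect(authUrl);
}
```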
@@ -173,6 +173,9 @@ export const GmailMain = () => {
serviceAccountCredentialData={serviceAccountKeyData}
isAdmin={isAdmin}
onSuccess={handleRefresh}
existingAuthCredential={Boolean(
gmailPublicUploadedCredential || gmailServiceAccountCredential
)}
/>

{isAdmin && hasUploadedCredentials && (

@@ -54,6 +54,7 @@ export const SourceCard: React.FC<{

<div className="flex items-center gap-1 mt-1">
<ResultIcon doc={document} size={18} />

<div className="text-text-700 text-xs leading-tight truncate flex-1 min-w-0">
{truncatedIdentifier}
</div>

@@ -54,6 +54,7 @@ export function SearchMultiSelectDropdown({
onDelete,
onSearchTermChange,
initialSearchTerm = "",
allowCustomValues = false,
}: {
options: StringOrNumberOption[];
onSelect: (selected: StringOrNumberOption) => void;
@@ -62,6 +63,7 @@ export function SearchMultiSelectDropdown({
onDelete?: (name: string) => void;
onSearchTermChange?: (term: string) => void;
initialSearchTerm?: string;
allowCustomValues?: boolean;
}) {
const [isOpen, setIsOpen] = useState(false);
const [searchTerm, setSearchTerm] = useState(initialSearchTerm);
@@ -77,12 +79,29 @@ export function SearchMultiSelectDropdown({
option.name.toLowerCase().includes(searchTerm.toLowerCase())
);

// Handle selecting a custom value not in the options list
const handleCustomValueSelect = () => {
if (allowCustomValues && searchTerm.trim() !== "") {
const customOption: StringOrNumberOption = {
name: searchTerm,
value: searchTerm,
};
onSelect(customOption);
setIsOpen(false);
}
};

useEffect(() => {
const handleClickOutside = (event: MouseEvent) => {
if (
dropdownRef.current &&
!dropdownRef.current.contains(event.target as Node)
) {
// If allowCustomValues is enabled and there's text in the search field,
// treat clicking outside as selecting the custom value
if (allowCustomValues && searchTerm.trim() !== "") {
handleCustomValueSelect();
}
setIsOpen(false);
}
};
@@ -91,7 +110,7 @@ export function SearchMultiSelectDropdown({
return () => {
document.removeEventListener("mousedown", handleClickOutside);
};
}, []);
}, [allowCustomValues, searchTerm]);

useEffect(() => {
setSearchTerm(initialSearchTerm);
@@ -102,17 +121,33 @@ export function SearchMultiSelectDropdown({
<div>
<input
type="text"
placeholder="Search..."
placeholder={
allowCustomValues ? "Search or enter custom value..." : "Search..."
}
value={searchTerm}
onChange={(e: ChangeEvent<HTMLInputElement>) => {
setSearchTerm(e.target.value);
if (e.target.value) {
const newValue = e.target.value;
setSearchTerm(newValue);
if (onSearchTermChange) {
onSearchTermChange(newValue);
}
if (newValue) {
setIsOpen(true);
} else {
setIsOpen(false);
}
}}
onFocus={() => setIsOpen(true)}
onKeyDown={(e) => {
if (
e.key === "Enter" &&
allowCustomValues &&
searchTerm.trim() !== ""
) {
e.preventDefault();
handleCustomValueSelect();
}
}}
className="inline-flex justify-between w-full px-4 py-2 text-sm bg-white dark:bg-transparent text-text-800 border border-background-300 rounded-md shadow-sm"
/>
<button
@@ -153,6 +188,22 @@ export function SearchMultiSelectDropdown({
)
)}

{allowCustomValues &&
searchTerm.trim() !== "" &&
!filteredOptions.some(
(option) =>
option.name.toLowerCase() === searchTerm.toLowerCase()
) && (
<button
className="w-full text-left flex items-center px-4 py-2 text-sm text-text-800 hover:bg-background-100"
role="menuitem"
onClick={handleCustomValueSelect}
>
<PlusIcon className="w-4 h-4 mr-2 text-text-600" />
Use "{searchTerm}" as custom value
</button>
)}

{onCreate &&
searchTerm.trim() !== "" &&
!filteredOptions.some(
@@ -177,7 +228,8 @@ export function SearchMultiSelectDropdown({
)}

{filteredOptions.length === 0 &&
(!onCreate || searchTerm.trim() === "") && (
((!onCreate && !allowCustomValues) ||
searchTerm.trim() === "") && (
<div className="px-4 py-2.5 text-sm text-text-500">
No matches found
</div>
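
With `allowCustomValues` set, the dropdown commits the raw search term as a synthetic `{ name, value }` option on Enter, on outside-click, or via the explicit "Use … as custom value" row. A usage sketch (the import path and the tag-list state are illustrative assumptions, not from the diff):

```tsx
import { useState } from "react";
// Assumed import path for the dropdown shown in the diff above.
import { SearchMultiSelectDropdown } from "@/components/Dropdown";

// Sketch: a picker that accepts values outside the predefined list.
export function TagPicker() {
  const [selected, setSelected] = useState<
    { name: string; value: string | number }[]
  >([]);

  return (
    <SearchMultiSelectDropdown
      options={[
        { name: "engineering", value: "engineering" },
        { name: "sales", value: "sales" },
      ]}
      allowCustomValues // Enter / outside-click / menu row commits the raw term
      onSelect={(option) => setSelected((prev) => [...prev, option])}
    />
  );
}
```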
@@ -49,7 +49,7 @@ export function SearchResultIcon({ url }: { url: string }) {
if (!faviconUrl) {
return <SourceIcon sourceType={ValidSources.Web} iconSize={18} />;
}
if (url.includes("docs.onyx.app")) {
if (url.includes("onyx.app")) {
return <OnyxIcon size={18} className="dark:text-[#fff] text-[#000]" />;
}

@@ -17,12 +17,11 @@ export function WebResultIcon({
try {
hostname = new URL(url).hostname;
} catch (e) {
// console.log(e);
hostname = "docs.onyx.app";
hostname = "onyx.app";
}
return (
<>
{hostname == "docs.onyx.app" ? (
{hostname.includes("onyx.app") ? (
<OnyxIcon size={size} className="dark:text-[#fff] text-[#000]" />
) : !error ? (
<img
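
Both icon components now treat any `onyx.app` host, not just `docs.onyx.app`, as first-party; switching to `includes` is what lets subdomains match. The shared idea, reduced to a sketch (`isOnyxUrl` is an illustrative helper, not in the diff):

```ts
// Sketch: does a result URL point at first-party Onyx content?
function isOnyxUrl(url: string): boolean {
  try {
    // WebResultIcon matches on the parsed hostname...
    return new URL(url).hostname.includes("onyx.app");
  } catch {
    // ...and falls back to the first-party hostname ("onyx.app")
    // when the URL cannot be parsed, as in the hunk above.
    return true;
  }
}
```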
@@ -26,35 +26,6 @@ export const ResultIcon = ({
);
};

// export default function SourceCard({
// doc,
// setPresentingDocument,
// }: {
// doc: OnyxDocument;
// setPresentingDocument?: (document: OnyxDocument) => void;
// }) {
// return (
// <div
// key={doc.document_id}
// onClick={() => openDocument(doc, setPresentingDocument)}
// className="cursor-pointer h-[80px] text-left overflow-hidden flex flex-col gap-0.5 rounded-lg px-3 py-2 bg-accent-background hover:bg-accent-background-hovered w-[200px]"
// >
// <div className="line-clamp-1 font-semibold text-ellipsis text-text-900 flex h-6 items-center gap-2 text-sm">
// {doc.is_internet || doc.source_type === "web" ? (
// <WebResultIcon url={doc.link} />
// ) : (
// <SourceIcon sourceType={doc.source_type} iconSize={18} />
// )}
// <p>{truncateString(doc.semantic_identifier || doc.document_id, 20)}</p>
// </div>
// <div className="line-clamp-2 text-sm font-semibold"></div>
// <div className="line-clamp-2 text-sm font-normal leading-snug text-text-700">
// {doc.blurb}
// </div>
// </div>
// );
// }

interface SeeMoreBlockProps {
toggleDocumentSelection: () => void;
docs: OnyxDocument[];