Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-17 07:45:47 +00:00)

Compare commits: debug_logg...fix_sessio (43 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 8e71216607 | |
| | a123661c92 | |
| | c554889baf | |
| | f08fa878a6 | |
| | d307534781 | |
| | 6f54791910 | |
| | 0d5497bb6b | |
| | 7648627503 | |
| | 927554d5ca | |
| | 7dcec6caf5 | |
| | 036648146d | |
| | 2aa4697ac8 | |
| | bc9b4e4f45 | |
| | 178a64f298 | |
| | c79f1edf1d | |
| | 7c8e23aa54 | |
| | d37b427d52 | |
| | a65fefd226 | |
| | bb09bde519 | |
| | 0f6cf0fc58 | |
| | fed06b592d | |
| | 8d92a1524e | |
| | ecfea9f5ed | |
| | b269f1ba06 | |
| | 30c878efa5 | |
| | 2024776c19 | |
| | 431316929c | |
| | c5b9c6e308 | |
| | 73dd188b3f | |
| | 79b061abbc | |
| | 552f1ead4f | |
| | 17925b49e8 | |
| | 55fb5c3ca5 | |
| | 99546e4a4d | |
| | c25d56f4a5 | |
| | 35f3f4f120 | |
| | 25b69a8aca | |
| | 1b7d710b2a | |
| | ae3d3db3f4 | |
| | fb79a9e700 | |
| | 587ba11bbc | |
| | fce81ebb60 | |
| | 61facfb0a8 | |
@@ -9,6 +9,10 @@ on:
- cron: "0 16 * * *"

env:
# AWS
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}

# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
@@ -45,6 +49,8 @@ env:
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
# Github
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
# Gitbook
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
@@ -102,6 +102,7 @@ COPY ./alembic /app/alembic
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
COPY ./static /app/static

# Escape hatch scripts
COPY ./scripts/debugging /app/scripts/debugging
@@ -84,7 +84,7 @@ keys = console
keys = generic

[logger_root]
level = WARN
level = INFO
handlers = console
qualname =
@@ -25,6 +25,9 @@ from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
from onyx.db.models import Base
from celery.backends.database.session import ResultModelBase  # type: ignore

# Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be
# hidden! (defaults to level=WARN)

# Alembic Config object
config = context.config

@@ -36,6 +39,7 @@ if config.config_file_name is not None and config.attributes.get(
target_metadata = [Base.metadata, ResultModelBase.metadata]

EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}

logger = logging.getLogger(__name__)

ssl_context: ssl.SSLContext | None = None
@@ -64,7 +68,7 @@ def include_object(
return True


def get_schema_options() -> tuple[str, bool, bool]:
def get_schema_options() -> tuple[str, bool, bool, bool]:
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
@@ -76,6 +80,10 @@ def get_schema_options() -> tuple[str, bool, bool]:
create_schema = x_args.get("create_schema", "true").lower() == "true"
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"

# continue on error with individual tenant
# only applies to online migrations
continue_on_error = x_args.get("continue", "false").lower() == "true"

if (
MULTI_TENANT
and schema_name == POSTGRES_DEFAULT_SCHEMA
@@ -86,14 +94,12 @@ def get_schema_options() -> tuple[str, bool, bool]:
"Please specify a tenant-specific schema."
)

return schema_name, create_schema, upgrade_all_tenants
return schema_name, create_schema, upgrade_all_tenants, continue_on_error


def do_run_migrations(
connection: Connection, schema_name: str, create_schema: bool
) -> None:
logger.info(f"About to migrate schema: {schema_name}")

if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))
@@ -134,7 +140,12 @@ def provide_iam_token_for_alembic(


async def run_async_migrations() -> None:
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
(
schema_name,
create_schema,
upgrade_all_tenants,
continue_on_error,
) = get_schema_options()

engine = create_async_engine(
build_connection_string(),
@@ -151,9 +162,15 @@ async def run_async_migrations() -> None:

if upgrade_all_tenants:
tenant_schemas = get_all_tenant_ids()

i_tenant = 0
num_tenants = len(tenant_schemas)
for schema in tenant_schemas:
i_tenant += 1
logger.info(
f"Migrating schema: index={i_tenant} num_tenants={num_tenants} schema={schema}"
)
try:
logger.info(f"Migrating schema: {schema}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
@@ -162,7 +179,12 @@ async def run_async_migrations() -> None:
)
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
raise
if not continue_on_error:
logger.error("--continue is not set, raising exception!")
raise

logger.warning("--continue is set, continuing to next schema.")

else:
try:
logger.info(f"Migrating schema: {schema_name}")
@@ -180,7 +202,11 @@ async def run_async_migrations() -> None:


def run_migrations_offline() -> None:
schema_name, _, upgrade_all_tenants = get_schema_options()
"""This doesn't really get used when we migrate in the cloud."""

logger.info("run_migrations_offline starting.")

schema_name, _, upgrade_all_tenants, continue_on_error = get_schema_options()
url = build_connection_string()

if upgrade_all_tenants:
@@ -230,6 +256,7 @@ def run_migrations_offline() -> None:


def run_migrations_online() -> None:
logger.info("run_migrations_online starting.")
asyncio.run(run_async_migrations())
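For reference, the x-arguments parsed by `get_schema_options()` above are supplied with Alembic's standard `-x` flag (for example `alembic -x upgrade_all_tenants=true -x continue=true upgrade head`). The snippet below is a minimal, hedged sketch of the equivalent programmatic invocation; it assumes only the `x` attribute of `Config.cmd_opts` is read (which matches `EnvironmentContext.get_x_argument()`), and the `head` target is an assumption.

```python
# Hedged sketch: driving the multi-tenant migration with the x-arguments shown
# above. The flag names (schema, create_schema, upgrade_all_tenants, continue)
# come from get_schema_options(); everything else here is illustrative.
import argparse

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")
# get_x_argument() reads Config.cmd_opts.x, so a bare Namespace is enough here.
cfg.cmd_opts = argparse.Namespace(
    x=["upgrade_all_tenants=true", "continue=true", "create_schema=true"]
)
command.upgrade(cfg, "head")  # same as: alembic -x ... upgrade head
```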
@@ -28,6 +28,20 @@ depends_on = None


def upgrade() -> None:
# First, drop any existing indexes to avoid conflicts
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")

op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")

op.execute("COMMIT")
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

# Drop existing columns if they exist
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")

# Create a GIN index for full-text search on chat_message.message
op.execute(
"""
@@ -25,6 +25,10 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
#####
# Auto Permission Sync
#####
DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)

# In seconds, default is 5 minutes
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
@@ -39,6 +43,7 @@ CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)

NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)


@@ -72,6 +77,13 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)

GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)

SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("SLACK_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)

# The posthog client does not accept empty API keys or hosts however it fails silently
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
@@ -3,6 +3,8 @@ from collections.abc import Generator

from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
@@ -66,13 +68,13 @@ GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
DocumentSource.SLACK: 5 * 60,
DocumentSource.SLACK: SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
}

# If nothing is specified here, we run the doc_sync every time the celery beat runs
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all group permissions every 30 minutes
DocumentSource.GOOGLE_DRIVE: 5 * 60,
DocumentSource.GOOGLE_DRIVE: GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY,
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
}
@@ -64,7 +64,15 @@ def get_application() -> FastAPI:
|
||||
add_tenant_id_middleware(application, logger)
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
# For Google OAuth, refresh tokens are requested by:
|
||||
# 1. Adding the right scopes
|
||||
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
|
||||
oauth_client = GoogleOAuth2(
|
||||
OAUTH_CLIENT_ID,
|
||||
OAUTH_CLIENT_SECRET,
|
||||
# Use standard scopes that include profile and email
|
||||
scopes=["openid", "email", "profile"],
|
||||
)
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
@@ -87,6 +95,16 @@ def get_application() -> FastAPI:
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.OIDC:
|
||||
# Ensure we request offline_access for refresh tokens
|
||||
try:
|
||||
oidc_scopes = list(OIDC_SCOPE_OVERRIDE or BASE_SCOPES)
|
||||
if "offline_access" not in oidc_scopes:
|
||||
oidc_scopes.append("offline_access")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error configuring OIDC scopes: {e}")
|
||||
# Fall back to default scopes if there's an error
|
||||
oidc_scopes = BASE_SCOPES
|
||||
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
@@ -94,8 +112,8 @@ def get_application() -> FastAPI:
|
||||
OAUTH_CLIENT_ID,
|
||||
OAUTH_CLIENT_SECRET,
|
||||
OPENID_CONFIG_URL,
|
||||
# BASE_SCOPES is the same as not setting this
|
||||
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
|
||||
# Use the configured scopes
|
||||
base_scopes=oidc_scopes,
|
||||
),
|
||||
auth_backend,
|
||||
USER_AUTH_SECRET,
|
||||
|
||||
@@ -36,8 +36,12 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
router = APIRouter(prefix="/auth/saml")
|
||||
|
||||
# Define non-authenticated user roles that should be re-created during SAML login
|
||||
NON_AUTHENTICATED_ROLES = {UserRole.SLACK_USER, UserRole.EXT_PERM_USER}
|
||||
|
||||
|
||||
async def upsert_saml_user(email: str) -> User:
|
||||
logger.debug(f"Attempting to upsert SAML user with email: {email}")
|
||||
get_async_session_context = contextlib.asynccontextmanager(
|
||||
get_async_session
|
||||
) # type:ignore
|
||||
@@ -48,9 +52,13 @@ async def upsert_saml_user(email: str) -> User:
|
||||
async with get_user_db_context(session) as user_db:
|
||||
async with get_user_manager_context(user_db) as user_manager:
|
||||
try:
|
||||
return await user_manager.get_by_email(email)
|
||||
user = await user_manager.get_by_email(email)
|
||||
# If user has a non-authenticated role, treat as non-existent
|
||||
if user.role in NON_AUTHENTICATED_ROLES:
|
||||
raise exceptions.UserNotExists()
|
||||
return user
|
||||
except exceptions.UserNotExists:
|
||||
logger.notice("Creating user from SAML login")
|
||||
logger.info("Creating user from SAML login")
|
||||
|
||||
user_count = await get_user_count()
|
||||
role = UserRole.ADMIN if user_count == 0 else UserRole.BASIC
|
||||
@@ -59,11 +67,10 @@ async def upsert_saml_user(email: str) -> User:
|
||||
password = fastapi_users_pw_helper.generate()
|
||||
hashed_pass = fastapi_users_pw_helper.hash(password)
|
||||
|
||||
user: User = await user_manager.create(
|
||||
user = await user_manager.create(
|
||||
UserCreate(
|
||||
email=email,
|
||||
password=hashed_pass,
|
||||
is_verified=True,
|
||||
role=role,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -87,11 +87,14 @@ async def get_or_provision_tenant(
|
||||
# If we have a pre-provisioned tenant, assign it to the user
|
||||
await assign_tenant_to_user(tenant_id, email, referral_source)
|
||||
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
|
||||
return tenant_id
|
||||
else:
|
||||
# If no pre-provisioned tenant is available, create a new one on-demand
|
||||
tenant_id = await create_tenant(email, referral_source)
|
||||
return tenant_id
|
||||
|
||||
# Notify control plane if we have created / assigned a new tenant
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
return tenant_id
|
||||
|
||||
except Exception as e:
|
||||
# If we've encountered an error, log and raise an exception
|
||||
@@ -116,10 +119,6 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str:
|
||||
# Provision tenant on data plane
|
||||
await provision_tenant(tenant_id, email)
|
||||
|
||||
# Notify control plane if not already done in provision_tenant
|
||||
if not DEV_MODE and referral_source:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Tenant provisioning failed: {str(e)}")
|
||||
# Attempt to rollback the tenant provisioning
|
||||
@@ -561,7 +560,3 @@ async def assign_tenant_to_user(
|
||||
except Exception:
|
||||
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
|
||||
raise Exception("Failed to assign tenant to user")
|
||||
|
||||
# Notify control plane with retry logic
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
@@ -70,6 +70,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
"""
|
||||
Add users to a tenant with proper transaction handling.
|
||||
Checks if users already have a tenant mapping to avoid duplicates.
|
||||
If a user already has an active mapping to any tenant, the new mapping will be added as inactive.
|
||||
"""
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
@@ -88,9 +89,25 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
.first()
|
||||
)
|
||||
|
||||
# If user already has an active mapping, add this one as inactive
|
||||
if not existing_mapping:
|
||||
# Only add if mapping doesn't exist
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
# Check if the user already has an active mapping to any tenant
|
||||
has_active_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
db_session.add(
|
||||
UserTenantMapping(
|
||||
email=email,
|
||||
tenant_id=tenant_id,
|
||||
active=False if has_active_mapping else True,
|
||||
)
|
||||
)
|
||||
|
||||
# Commit the transaction
|
||||
db_session.commit()
|
||||
|
||||
@@ -16,10 +16,10 @@ from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
|
||||
from onyx.configs.constants import ONYX_SLACK_URL
|
||||
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
|
||||
from onyx.db.models import User
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.file import FileWithMimeType
|
||||
from onyx.utils.url import add_url_params
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -62,6 +62,11 @@ HTML_EMAIL_TEMPLATE = """\
|
||||
}}
|
||||
.header img {{
|
||||
max-width: 140px;
|
||||
width: 140px;
|
||||
height: auto;
|
||||
filter: brightness(1.1) contrast(1.2);
|
||||
border-radius: 8px;
|
||||
padding: 5px;
|
||||
}}
|
||||
.body-content {{
|
||||
padding: 20px 30px;
|
||||
@@ -78,12 +83,16 @@ HTML_EMAIL_TEMPLATE = """\
|
||||
}}
|
||||
.cta-button {{
|
||||
display: inline-block;
|
||||
padding: 12px 20px;
|
||||
background-color: #000000;
|
||||
padding: 14px 24px;
|
||||
background-color: #0055FF;
|
||||
color: #ffffff !important;
|
||||
text-decoration: none;
|
||||
border-radius: 4px;
|
||||
font-weight: 500;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
margin-top: 10px;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
text-align: center;
|
||||
}}
|
||||
.footer {{
|
||||
font-size: 13px;
|
||||
@@ -166,6 +175,7 @@ def send_email(
|
||||
if not EMAIL_CONFIGURED:
|
||||
raise ValueError("Email is not configured.")
|
||||
|
||||
# Create a multipart/alternative message - this indicates these are alternative versions of the same content
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = subject
|
||||
msg["To"] = user_email
|
||||
@@ -174,17 +184,30 @@ def send_email(
|
||||
msg["Date"] = formatdate(localtime=True)
|
||||
msg["Message-ID"] = make_msgid(domain="onyx.app")
|
||||
|
||||
part_text = MIMEText(text_body, "plain")
|
||||
part_html = MIMEText(html_body, "html")
|
||||
|
||||
msg.attach(part_text)
|
||||
msg.attach(part_html)
|
||||
# Add text part first (lowest priority)
|
||||
text_part = MIMEText(text_body, "plain")
|
||||
msg.attach(text_part)
|
||||
|
||||
if inline_png:
|
||||
# For HTML with images, create a multipart/related container
|
||||
related = MIMEMultipart("related")
|
||||
|
||||
# Add the HTML part to the related container
|
||||
html_part = MIMEText(html_body, "html")
|
||||
related.attach(html_part)
|
||||
|
||||
# Add image with proper Content-ID to the related container
|
||||
img = MIMEImage(inline_png[1], _subtype="png")
|
||||
img.add_header("Content-ID", inline_png[0]) # CID reference
|
||||
img.add_header("Content-ID", f"<{inline_png[0]}>")
|
||||
img.add_header("Content-Disposition", "inline", filename=inline_png[0])
|
||||
msg.attach(img)
|
||||
related.attach(img)
|
||||
|
||||
# Add the related part to the message (higher priority than text)
|
||||
msg.attach(related)
|
||||
else:
|
||||
# No images, just add HTML directly (higher priority than text)
|
||||
html_part = MIMEText(html_body, "html")
|
||||
msg.attach(html_part)
|
||||
|
||||
try:
|
||||
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
|
||||
@@ -332,17 +355,23 @@ def send_forgot_password_email(
|
||||
|
||||
onyx_file = OnyxRuntime.get_emailable_logo()
|
||||
|
||||
subject = f"{application_name} Forgot Password"
|
||||
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
|
||||
if MULTI_TENANT:
|
||||
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
|
||||
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
|
||||
subject = f"Reset Your {application_name} Password"
|
||||
heading = "Reset Your Password"
|
||||
tenant_param = f"&tenant={tenant_id}" if tenant_id and MULTI_TENANT else ""
|
||||
message = "<p>Please click the button below to reset your password. This link will expire in 24 hours.</p>"
|
||||
cta_text = "Reset Password"
|
||||
cta_link = f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
|
||||
html_content = build_html_email(
|
||||
application_name,
|
||||
"Reset Your Password",
|
||||
heading,
|
||||
message,
|
||||
cta_text,
|
||||
cta_link,
|
||||
)
|
||||
text_content = (
|
||||
f"Please click the following link to reset your password. This link will expire in 24 hours.\n"
|
||||
f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
|
||||
)
|
||||
text_content = f"Click the following link to reset your password: {link}"
|
||||
send_email(
|
||||
user_email,
|
||||
subject,
|
||||
@@ -356,6 +385,7 @@ def send_forgot_password_email(
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
new_organization: bool = False,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
# Builds a verification email
|
||||
@@ -372,6 +402,8 @@ def send_user_verification_email(
|
||||
|
||||
subject = f"{application_name} Email Verification"
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
if new_organization:
|
||||
link = add_url_params(link, {"first_user": "true"})
|
||||
message = (
|
||||
f"<p>Click the following link to verify your email address:</p><p>{link}</p>"
|
||||
)
|
||||
|
||||
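For clarity, the `send_email` changes above switch from a flat multipart/alternative message to one whose HTML part lives inside a multipart/related container, so the inline logo can be referenced by `cid:` while plain text remains the fallback. A minimal sketch of the resulting structure follows; the HTML body and image bytes are placeholders.

```python
# Resulting MIME tree (placeholder values), mirroring the send_email() changes:
#
# multipart/alternative
# ├── text/plain                (fallback, attached first = lowest priority)
# └── multipart/related
#     ├── text/html             (references the logo via <img src="cid:logo.png">)
#     └── image/png             (Content-ID: <logo.png>, Content-Disposition: inline)
from email.mime.image import MIMEImage
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

msg = MIMEMultipart("alternative")
msg.attach(MIMEText("Plain-text version of the message", "plain"))

related = MIMEMultipart("related")
related.attach(MIMEText('<p>HTML version</p><img src="cid:logo.png">', "html"))

img = MIMEImage(b"\x89PNG...", _subtype="png")  # placeholder image bytes
img.add_header("Content-ID", "<logo.png>")
img.add_header("Content-Disposition", "inline", filename="logo.png")
related.attach(img)

msg.attach(related)
```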
backend/onyx/auth/oauth_refresher.py (new file, 211 lines)
@@ -0,0 +1,211 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from fastapi_users.manager import BaseUserManager
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Standard OAuth refresh token endpoints
|
||||
REFRESH_ENDPOINTS = {
|
||||
"google": "https://oauth2.googleapis.com/token",
|
||||
}
|
||||
|
||||
|
||||
# NOTE: Keeping this as a utility function for potential future debugging,
|
||||
# but not using it in production code
|
||||
async def _test_expire_oauth_token(
|
||||
user: User,
|
||||
oauth_account: OAuthAccount,
|
||||
db_session: AsyncSession,
|
||||
user_manager: BaseUserManager[User, Any],
|
||||
expire_in_seconds: int = 10,
|
||||
) -> bool:
|
||||
"""
|
||||
Utility function for testing - Sets an OAuth token to expire in a short time
|
||||
to facilitate testing of the refresh flow.
|
||||
Not used in production code.
|
||||
"""
|
||||
try:
|
||||
new_expires_at = int(
|
||||
(datetime.now(timezone.utc).timestamp() + expire_in_seconds)
|
||||
)
|
||||
|
||||
updated_data: Dict[str, Any] = {"expires_at": new_expires_at}
|
||||
|
||||
await user_manager.user_db.update_oauth_account(
|
||||
user, cast(Any, oauth_account), updated_data
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(f"Error setting artificial expiration: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def refresh_oauth_token(
|
||||
user: User,
|
||||
oauth_account: OAuthAccount,
|
||||
db_session: AsyncSession,
|
||||
user_manager: BaseUserManager[User, Any],
|
||||
) -> bool:
|
||||
"""
|
||||
Attempt to refresh an OAuth token that's about to expire or has expired.
|
||||
Returns True if successful, False otherwise.
|
||||
"""
|
||||
if not oauth_account.refresh_token:
|
||||
logger.warning(
|
||||
f"No refresh token available for {user.email}'s {oauth_account.oauth_name} account"
|
||||
)
|
||||
return False
|
||||
|
||||
provider = oauth_account.oauth_name
|
||||
if provider not in REFRESH_ENDPOINTS:
|
||||
logger.warning(f"Refresh endpoint not configured for provider: {provider}")
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.info(f"Refreshing OAuth token for {user.email}'s {provider} account")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
REFRESH_ENDPOINTS[provider],
|
||||
data={
|
||||
"client_id": OAUTH_CLIENT_ID,
|
||||
"client_secret": OAUTH_CLIENT_SECRET,
|
||||
"refresh_token": oauth_account.refresh_token,
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
f"Failed to refresh OAuth token: Status {response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
token_data = response.json()
|
||||
|
||||
new_access_token = token_data.get("access_token")
|
||||
new_refresh_token = token_data.get(
|
||||
"refresh_token", oauth_account.refresh_token
|
||||
)
|
||||
expires_in = token_data.get("expires_in")
|
||||
|
||||
# Calculate new expiry time if provided
|
||||
new_expires_at: Optional[int] = None
|
||||
if expires_in:
|
||||
new_expires_at = int(
|
||||
(datetime.now(timezone.utc).timestamp() + expires_in)
|
||||
)
|
||||
|
||||
# Update the OAuth account
|
||||
updated_data: Dict[str, Any] = {
|
||||
"access_token": new_access_token,
|
||||
"refresh_token": new_refresh_token,
|
||||
}
|
||||
|
||||
if new_expires_at:
|
||||
updated_data["expires_at"] = new_expires_at
|
||||
|
||||
# Update oidc_expiry in user model if we're tracking it
|
||||
if TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
oidc_expiry = datetime.fromtimestamp(
|
||||
new_expires_at, tz=timezone.utc
|
||||
)
|
||||
await user_manager.user_db.update(
|
||||
user, {"oidc_expiry": oidc_expiry}
|
||||
)
|
||||
|
||||
# Update the OAuth account
|
||||
await user_manager.user_db.update_oauth_account(
|
||||
user, cast(Any, oauth_account), updated_data
|
||||
)
|
||||
|
||||
logger.info(f"Successfully refreshed OAuth token for {user.email}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error refreshing OAuth token: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def check_and_refresh_oauth_tokens(
|
||||
user: User,
|
||||
db_session: AsyncSession,
|
||||
user_manager: BaseUserManager[User, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Check if any OAuth tokens are expired or about to expire and refresh them.
|
||||
"""
|
||||
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
|
||||
return
|
||||
|
||||
now_timestamp = datetime.now(timezone.utc).timestamp()
|
||||
|
||||
# Buffer time to refresh tokens before they expire (in seconds)
|
||||
buffer_seconds = 300 # 5 minutes
|
||||
|
||||
for oauth_account in user.oauth_accounts:
|
||||
# Skip accounts without refresh tokens
|
||||
if not oauth_account.refresh_token:
|
||||
continue
|
||||
|
||||
# If token is about to expire, refresh it
|
||||
if (
|
||||
oauth_account.expires_at
|
||||
and oauth_account.expires_at - now_timestamp < buffer_seconds
|
||||
):
|
||||
logger.info(f"OAuth token for {user.email} is about to expire - refreshing")
|
||||
success = await refresh_oauth_token(
|
||||
user, oauth_account, db_session, user_manager
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.warning(
|
||||
"Failed to refresh OAuth token. User may need to re-authenticate."
|
||||
)
|
||||
|
||||
|
||||
async def check_oauth_account_has_refresh_token(
|
||||
user: User,
|
||||
oauth_account: OAuthAccount,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if an OAuth account has a refresh token.
|
||||
Returns True if a refresh token exists, False otherwise.
|
||||
"""
|
||||
return bool(oauth_account.refresh_token)
|
||||
|
||||
|
||||
async def get_oauth_accounts_requiring_refresh_token(user: User) -> List[OAuthAccount]:
|
||||
"""
|
||||
Returns a list of OAuth accounts for a user that are missing refresh tokens.
|
||||
These accounts will need re-authentication to get refresh tokens.
|
||||
"""
|
||||
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
|
||||
return []
|
||||
|
||||
accounts_needing_refresh = []
|
||||
for oauth_account in user.oauth_accounts:
|
||||
has_refresh_token = await check_oauth_account_has_refresh_token(
|
||||
user, oauth_account
|
||||
)
|
||||
if not has_refresh_token:
|
||||
accounts_needing_refresh.append(oauth_account)
|
||||
|
||||
return accounts_needing_refresh
|
||||
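For context, `refresh_oauth_token` above performs a standard OAuth 2.0 refresh-token grant against the provider's token endpoint. The sketch below shows the request form and a typical Google token response; the concrete values are placeholders, and exact response fields can vary by provider.

```python
# Standard OAuth 2.0 refresh_token grant, mirroring the httpx POST above
# (placeholder values; shapes are typical for Google's token endpoint).
token_endpoint = "https://oauth2.googleapis.com/token"

request_form = {
    "client_id": "<OAUTH_CLIENT_ID>",
    "client_secret": "<OAUTH_CLIENT_SECRET>",
    "refresh_token": "<stored refresh token>",
    "grant_type": "refresh_token",
}

# A typical successful response body. Google usually omits "refresh_token"
# here, which is why the code above falls back to the stored one.
typical_response = {
    "access_token": "<new access token>",
    "expires_in": 3599,  # seconds until expiry; used to compute expires_at
    "scope": "openid email profile",
    "token_type": "Bearer",
}
```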
@@ -5,12 +5,16 @@ import string
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Protocol
|
||||
from typing import Tuple
|
||||
from typing import TypeVar
|
||||
|
||||
import jwt
|
||||
from email_validator import EmailNotValidError
|
||||
@@ -581,8 +585,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
logger.notice(
|
||||
f"Verification requested for user {user.id}. Verification token: {token}"
|
||||
)
|
||||
|
||||
send_user_verification_email(user.email, token)
|
||||
user_count = await get_user_count()
|
||||
send_user_verification_email(
|
||||
user.email, token, new_organization=user_count == 1
|
||||
)
|
||||
|
||||
async def authenticate(
|
||||
self, credentials: OAuth2PasswordRequestForm
|
||||
@@ -688,16 +694,20 @@ cookie_transport = CookieTransport(
|
||||
)
|
||||
|
||||
|
||||
def get_redis_strategy() -> RedisStrategy:
|
||||
return TenantAwareRedisStrategy()
|
||||
T = TypeVar("T", covariant=True)
|
||||
ID = TypeVar("ID", contravariant=True)
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> DatabaseStrategy:
|
||||
return DatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
|
||||
)
|
||||
# Protocol for strategies that support token refreshing without inheritance.
|
||||
class RefreshableStrategy(Protocol):
|
||||
"""Protocol for authentication strategies that support token refreshing."""
|
||||
|
||||
async def refresh_token(self, token: Optional[str], user: Any) -> str:
|
||||
"""
|
||||
Refresh an existing token by extending its lifetime.
|
||||
Returns either the same token with extended expiration or a new token.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
@@ -756,6 +766,75 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
redis = await get_async_redis_connection()
|
||||
await redis.delete(f"{self.key_prefix}{token}")
|
||||
|
||||
async def refresh_token(self, token: Optional[str], user: User) -> str:
|
||||
"""Refresh a token by extending its expiration time in Redis."""
|
||||
if token is None:
|
||||
# If no token provided, create a new one
|
||||
return await self.write_token(user)
|
||||
|
||||
redis = await get_async_redis_connection()
|
||||
token_key = f"{self.key_prefix}{token}"
|
||||
|
||||
# Check if token exists
|
||||
token_data_str = await redis.get(token_key)
|
||||
if not token_data_str:
|
||||
# Token not found, create new one
|
||||
return await self.write_token(user)
|
||||
|
||||
# Token exists, extend its lifetime
|
||||
token_data = json.loads(token_data_str)
|
||||
await redis.set(
|
||||
token_key,
|
||||
json.dumps(token_data),
|
||||
ex=self.lifetime_seconds,
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
|
||||
class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]):
|
||||
"""Database strategy with token refreshing capabilities."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
access_token_db: AccessTokenDatabase[AccessToken],
|
||||
lifetime_seconds: Optional[int] = None,
|
||||
):
|
||||
super().__init__(access_token_db, lifetime_seconds)
|
||||
self._access_token_db = access_token_db
|
||||
|
||||
async def refresh_token(self, token: Optional[str], user: User) -> str:
|
||||
"""Refresh a token by updating its expiration time in the database."""
|
||||
if token is None:
|
||||
return await self.write_token(user)
|
||||
|
||||
# Find the token in database
|
||||
access_token = await self._access_token_db.get_by_token(token)
|
||||
|
||||
if access_token is None:
|
||||
# Token not found, create new one
|
||||
return await self.write_token(user)
|
||||
|
||||
# Update expiration time
|
||||
new_expires = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=float(self.lifetime_seconds or SESSION_EXPIRE_TIME_SECONDS)
|
||||
)
|
||||
await self._access_token_db.update(access_token, {"expires": new_expires})
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def get_redis_strategy() -> TenantAwareRedisStrategy:
|
||||
return TenantAwareRedisStrategy()
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> RefreshableDatabaseStrategy:
|
||||
return RefreshableDatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
|
||||
)
|
||||
|
||||
|
||||
if AUTH_BACKEND == AuthBackend.REDIS:
|
||||
auth_backend = AuthenticationBackend(
|
||||
@@ -806,6 +885,88 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
|
||||
return router
|
||||
|
||||
def get_refresh_router(
|
||||
self,
|
||||
backend: AuthenticationBackend,
|
||||
requires_verification: bool = REQUIRE_EMAIL_VERIFICATION,
|
||||
) -> APIRouter:
|
||||
"""
|
||||
Provide a router for session token refreshing.
|
||||
"""
|
||||
# Import the oauth_refresher here to avoid circular imports
|
||||
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
get_current_user_token = self.authenticator.current_user_token(
|
||||
active=True, verified=requires_verification
|
||||
)
|
||||
|
||||
refresh_responses: OpenAPIResponseType = {
|
||||
**{
|
||||
status.HTTP_401_UNAUTHORIZED: {
|
||||
"description": "Missing token or inactive user."
|
||||
}
|
||||
},
|
||||
**backend.transport.get_openapi_login_responses_success(),
|
||||
}
|
||||
|
||||
@router.post(
|
||||
"/refresh", name=f"auth:{backend.name}.refresh", responses=refresh_responses
|
||||
)
|
||||
async def refresh(
|
||||
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
|
||||
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(
|
||||
get_user_manager
|
||||
),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
) -> Response:
|
||||
try:
|
||||
user, token = user_token
|
||||
logger.info(f"Processing token refresh request for user {user.email}")
|
||||
|
||||
# Check if user has OAuth accounts that need refreshing
|
||||
await check_and_refresh_oauth_tokens(
|
||||
user=cast(User, user),
|
||||
db_session=db_session,
|
||||
user_manager=cast(Any, user_manager),
|
||||
)
|
||||
|
||||
# Check if strategy supports refreshing
|
||||
supports_refresh = hasattr(strategy, "refresh_token") and callable(
|
||||
getattr(strategy, "refresh_token")
|
||||
)
|
||||
|
||||
if supports_refresh:
|
||||
try:
|
||||
refresh_method = getattr(strategy, "refresh_token")
|
||||
new_token = await refresh_method(token, user)
|
||||
logger.info(
|
||||
f"Successfully refreshed session token for user {user.email}"
|
||||
)
|
||||
return await backend.transport.get_login_response(new_token)
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing session token: {str(e)}")
|
||||
# Fallback to logout and login if refresh fails
|
||||
await backend.logout(strategy, user, token)
|
||||
return await backend.login(strategy, user)
|
||||
|
||||
# Fallback: logout and login again
|
||||
logger.info(
|
||||
"Strategy doesn't support refresh - using logout/login flow"
|
||||
)
|
||||
await backend.logout(strategy, user, token)
|
||||
return await backend.login(strategy, user)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in refresh endpoint: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Token refresh failed: {str(e)}",
|
||||
)
|
||||
|
||||
return router
|
||||
|
||||
|
||||
fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
|
||||
get_user_manager, [auth_backend]
|
||||
@@ -1039,12 +1200,20 @@ def get_oauth_router(
|
||||
"referral_source": referral_source or "default_referral",
|
||||
}
|
||||
state = generate_state_token(state_data, state_secret)
|
||||
|
||||
# Get the basic authorization URL
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url,
|
||||
state,
|
||||
scopes,
|
||||
)
|
||||
|
||||
# For Google OAuth, add parameters to request refresh tokens
|
||||
if oauth_client.name == "google":
|
||||
authorization_url = add_url_params(
|
||||
authorization_url, {"access_type": "offline", "prompt": "consent"}
|
||||
)
|
||||
|
||||
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
|
||||
|
||||
@router.get(
|
||||
|
||||
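For context on the `access_type=offline` / `prompt=consent` parameters appended in the authorize route above: these are what make Google return a refresh token on the OAuth callback (consent forces re-issuance even if the user previously approved the app). A hedged sketch of the resulting authorization URL, with placeholder client ID, redirect URI, and state:

```python
# What the Google authorization URL looks like once add_url_params() appends
# the offline-access parameters (placeholder values throughout).
from urllib.parse import urlencode

params = {
    "client_id": "<OAUTH_CLIENT_ID>",
    "redirect_uri": "https://app.example.com/auth/oauth/callback",  # placeholder
    "response_type": "code",
    "scope": "openid email profile",
    "state": "<signed state token>",
    # appended by the change above so Google issues a refresh token:
    "access_type": "offline",
    "prompt": "consent",
}
authorization_url = (
    "https://accounts.google.com/o/oauth2/v2/auth?" + urlencode(params)
)
```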
@@ -34,7 +34,6 @@ from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGrou
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import ColoredFormatter
|
||||
from onyx.utils.logger import PlainFormatter
|
||||
@@ -225,7 +224,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout
|
||||
is reached."""
|
||||
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
@@ -311,7 +310,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
# Set up variables for waiting on primary worker
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
|
||||
time_start = time.monotonic()
|
||||
|
||||
logger.info("Waiting for primary worker to be ready...")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
@@ -10,12 +9,10 @@ from celery.utils.log import get_task_logger
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -141,8 +138,6 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
"""Only updates the actual beat schedule on the celery app when it changes"""
|
||||
do_update = False
|
||||
|
||||
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
task_logger.debug("_try_updating_schedule starting")
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
@@ -152,16 +147,7 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
current_schedule = self.schedule.items()
|
||||
|
||||
# get potential new state
|
||||
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
|
||||
if beat_multiplier_raw is not None:
|
||||
try:
|
||||
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
|
||||
beat_multiplier = float(beat_multiplier_bytes.decode())
|
||||
except ValueError:
|
||||
task_logger.error(
|
||||
f"Invalid beat_multiplier value: {beat_multiplier_raw}"
|
||||
)
|
||||
beat_multiplier = OnyxRuntime.get_beat_multiplier()
|
||||
|
||||
new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)
|
||||
|
||||
|
||||
@@ -38,10 +38,11 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_connector_stop import RedisConnectorStop
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -102,7 +103,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker. This is unnecessary in the multi tenant scenario
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
|
||||
info: dict[str, Any] = cast(dict, r.info("replication"))
|
||||
@@ -235,7 +236,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
|
||||
lock: RedisLock = worker.primary_worker_lock
|
||||
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
if lock.owned():
|
||||
task_logger.debug("Reacquiring primary worker lock.")
|
||||
|
||||
@@ -14,7 +14,7 @@ logger = setup_logger()
# Only set up memory monitoring in container environment
if is_running_in_container():
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_DIR = "/var/log/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024  # 10MB
MEMORY_LOG_BACKUP_COUNT = 5  # Keep 5 backup files
@@ -21,6 +21,7 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
|
||||
# we have a better implementation (backpressure, etc)
|
||||
# Note that DynamicTenantScheduler can adjust the runtime value for this via Redis
|
||||
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
|
||||
CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT = 1.0
|
||||
|
||||
# tasks that run in either self-hosted or cloud
|
||||
beat_task_templates: list[dict] = []
|
||||
|
||||
@@ -30,6 +30,9 @@ from onyx.db.connector_credential_pair import (
|
||||
)
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import (
|
||||
delete_all_documents_by_connector_credential_pair__no_commit,
|
||||
)
|
||||
from onyx.db.document import get_document_ids_for_connector_credential_pair
|
||||
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
@@ -386,6 +389,8 @@ def monitor_connector_deletion_taskset(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
credential_id_to_delete: int | None = None
|
||||
connector_id_to_delete: int | None = None
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
|
||||
@@ -440,16 +445,35 @@ def monitor_connector_deletion_taskset(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Store IDs before potentially expiring cc_pair
|
||||
connector_id_to_delete = cc_pair.connector_id
|
||||
credential_id_to_delete = cc_pair.credential_id
|
||||
|
||||
# Explicitly delete document by connector credential pair records before deleting the connector
|
||||
# This is needed because connector_id is a primary key in that table and cascading deletes won't work
|
||||
delete_all_documents_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id_to_delete,
|
||||
credential_id=credential_id_to_delete,
|
||||
)
|
||||
|
||||
# Flush to ensure document deletion happens before connector deletion
|
||||
db_session.flush()
|
||||
|
||||
# Expire the cc_pair to ensure SQLAlchemy doesn't try to manage its state
|
||||
# related to the deleted DocumentByConnectorCredentialPair during commit
|
||||
db_session.expire(cc_pair)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
connector_id=connector_id_to_delete,
|
||||
credential_id=credential_id_to_delete,
|
||||
)
|
||||
# if there are no credentials left, delete the connector
|
||||
connector = fetch_connector_by_id(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
connector_id=connector_id_to_delete,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
task_logger.info(
|
||||
@@ -482,15 +506,15 @@ def monitor_connector_deletion_taskset(
|
||||
|
||||
task_logger.exception(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
|
||||
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
|
||||
)
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
f"Connector deletion succeeded: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector={cc_pair.connector_id} "
|
||||
f"credential={cc_pair.credential_id} "
|
||||
f"connector={connector_id_to_delete} "
|
||||
f"credential={credential_id_to_delete} "
|
||||
f"docs_deleted={fence_data.num_tasks}"
|
||||
)
|
||||
|
||||
@@ -540,7 +564,7 @@ def validate_connector_deletion_fences(
|
||||
def validate_connector_deletion_fence(
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
queued_tasks: set[str],
|
||||
queued_upsert_tasks: set[str],
|
||||
r: Redis,
|
||||
) -> None:
|
||||
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
|
||||
@@ -627,7 +651,7 @@ def validate_connector_deletion_fence(
|
||||
|
||||
member_bytes = cast(bytes, member)
|
||||
member_str = member_bytes.decode("utf-8")
|
||||
if member_str in queued_tasks:
|
||||
if member_str in queued_upsert_tasks:
|
||||
continue
|
||||
|
||||
tasks_not_in_celery += 1
|
||||
|
||||
@@ -17,6 +17,7 @@ from redis.exceptions import LockError
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.db.document import upsert_document_external_perms
|
||||
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
|
||||
@@ -63,11 +64,14 @@ from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyn
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.server.utils import make_short_id
|
||||
from onyx.utils.logger import doc_permission_sync_ctx
|
||||
from onyx.utils.logger import format_error_for_logging
|
||||
from onyx.utils.logger import LoggerContextVars
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -104,9 +108,10 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
|
||||
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||
|
||||
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
|
||||
if not source_sync_period:
|
||||
return True
|
||||
source_sync_period = DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
|
||||
source_sync_period *= int(OnyxRuntime.get_doc_permission_sync_multiplier())
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
|
||||
@@ -284,7 +289,7 @@ def try_creating_permissions_sync_task(
|
||||
),
|
||||
queue=OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
# fill in the celery task id
|
||||
@@ -875,6 +880,21 @@ def monitor_ccpair_permissions_taskset(
|
||||
f"remaining={remaining} "
|
||||
f"initial={initial}"
|
||||
)
|
||||
|
||||
# Add telemetry for permission syncing progress
|
||||
optional_telemetry(
|
||||
record_type=RecordType.PERMISSION_SYNC_PROGRESS,
|
||||
data={
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"id": payload.id if payload else None,
|
||||
"total_docs": initial if initial is not None else 0,
|
||||
"remaining_docs": remaining,
|
||||
"synced_docs": (initial - remaining) if initial is not None else 0,
|
||||
"is_complete": remaining == 0,
|
||||
},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
|
||||
@@ -271,7 +271,7 @@ def try_creating_external_group_sync_task(
|
||||
),
|
||||
queue=OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
payload.celery_task_id = result.id
|
||||
|
||||
@@ -72,6 +72,7 @@ from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_utils import is_fence
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
@@ -401,7 +402,11 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
logger.warning(f"Adding {key_bytes} to the lookup table.")
|
||||
redis_client.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
|
||||
redis_client.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
|
||||
redis_client.set(
|
||||
OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE,
|
||||
1,
|
||||
ex=OnyxRuntime.get_build_fence_lookup_table_interval(),
|
||||
)
|
||||
|
||||
# 1/3: KICKOFF
|
||||
|
||||
|
||||
@@ -56,9 +56,12 @@ from onyx.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.logger import TaskAttemptSingleton
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -435,7 +438,7 @@ def _run_indexing(
|
||||
|
||||
while checkpoint.has_more:
|
||||
logger.info(
|
||||
f"Running '{ctx.source}' connector with checkpoint: {checkpoint}"
|
||||
f"Running '{ctx.source.value}' connector with checkpoint: {checkpoint}"
|
||||
)
|
||||
for document_batch, failure, next_checkpoint in connector_runner.run(
|
||||
checkpoint
|
||||
@@ -570,6 +573,22 @@ def _run_indexing(
|
||||
if callback:
|
||||
callback.progress("_run_indexing", len(doc_batch_cleaned))
|
||||
|
||||
# Add telemetry for indexing progress
|
||||
optional_telemetry(
|
||||
record_type=RecordType.INDEXING_PROGRESS,
|
||||
data={
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"cc_pair_id": ctx.cc_pair_id,
|
||||
"connector_id": ctx.connector_id,
|
||||
"credential_id": ctx.credential_id,
|
||||
"total_docs_indexed": document_count,
|
||||
"total_chunks": chunk_count,
|
||||
"batch_num": batch_num,
|
||||
"source": ctx.source.value,
|
||||
},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
memory_tracer.increment_and_maybe_trace()
|
||||
|
||||
# make sure the checkpoints aren't getting too large at some regular interval
|
||||
@@ -585,6 +604,30 @@ def _run_indexing(
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
# Add telemetry for completed indexing
|
||||
redis_connector = RedisConnector(tenant_id, ctx.cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
index_attempt_start.search_settings_id
|
||||
)
|
||||
final_progress = redis_connector_index.get_progress() or 0
|
||||
|
||||
optional_telemetry(
|
||||
record_type=RecordType.INDEXING_COMPLETE,
|
||||
data={
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"cc_pair_id": ctx.cc_pair_id,
|
||||
"connector_id": ctx.connector_id,
|
||||
"credential_id": ctx.credential_id,
|
||||
"total_docs_indexed": document_count,
|
||||
"total_chunks": chunk_count,
|
||||
"batch_count": batch_num,
|
||||
"time_elapsed_seconds": time.monotonic() - start_time,
|
||||
"source": ctx.source.value,
|
||||
"redis_progress": final_progress,
|
||||
},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Connector run exceptioned after elapsed time: "
|
||||
|
||||
@@ -73,6 +73,7 @@ from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from onyx.db.chat import update_chat_session_updated_at_timestamp
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.milestone import check_multi_assistant_milestone
|
||||
from onyx.db.milestone import create_milestone_if_not_exists
|
||||
@@ -1069,6 +1070,8 @@ def stream_chat_message_objects(
|
||||
prev_message = next_answer_message
|
||||
|
||||
logger.debug("Committing messages")
|
||||
# Explicitly update the timestamp on the chat session
|
||||
update_chat_session_updated_at_timestamp(chat_session_id, db_session)
|
||||
db_session.commit() # actually save user / assistant message
|
||||
|
||||
yield AgenticMessageResponseIDInfo(agentic_message_ids=agentic_message_ids)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
import urllib.parse
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import cast
|
||||
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
@@ -157,10 +159,7 @@ VESPA_CLOUD_CERT_PATH = os.environ.get("VESPA_CLOUD_CERT_PATH")
VESPA_CLOUD_KEY_PATH = os.environ.get("VESPA_CLOUD_KEY_PATH")

# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
try:
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
except ValueError:
INDEX_BATCH_SIZE = 16
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE") or 16)

MAX_DRIVE_WORKERS = int(os.environ.get("MAX_DRIVE_WORKERS", 4))

@@ -386,10 +385,27 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
|
||||
# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
|
||||
# https://jira.atlassian.com/browse/CONFCLOUD-69670
|
||||
|
||||
|
||||
def get_current_tz_offset() -> int:
|
||||
# datetime now() gets local time, datetime.now(timezone.utc) gets UTC time.
|
||||
# remove tzinfo to compare non-timezone-aware objects.
|
||||
time_diff = datetime.now() - datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
return round(time_diff.total_seconds() / 3600)
|
||||
|
||||
|
||||
# enter as a floating point offset from UTC in hours (-24 < val < 24)
|
||||
# this will be applied globally, so it probably makes sense to transition this to per
|
||||
# connector at some point.
|
||||
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
|
||||
# For the default value, we assume that the user's local timezone is more likely to be
|
||||
# correct (i.e. the configured user's timezone or the default server one) than UTC.
|
||||
# https://developer.atlassian.com/cloud/confluence/cql-fields/#created
|
||||
CONFLUENCE_TIMEZONE_OFFSET = float(
|
||||
os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset())
|
||||
)
|
||||
|
||||
GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
|
||||
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
|
||||
)
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
@@ -676,3 +692,7 @@ IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
|
||||
"IMAGE_ANALYSIS_SYSTEM_PROMPT",
|
||||
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
DISABLE_AUTO_AUTH_REFRESH = (
|
||||
os.environ.get("DISABLE_AUTO_AUTH_REFRESH", "").lower() == "true"
|
||||
)
|
||||
|
||||
@@ -382,6 +382,7 @@ ONYX_CLOUD_TENANT_ID = "cloud"
|
||||
|
||||
# the redis namespace for runtime variables
|
||||
ONYX_CLOUD_REDIS_RUNTIME = "runtime"
|
||||
CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT = 600
|
||||
|
||||
|
||||
class OnyxCeleryTask:
|
||||
|
||||
@@ -87,7 +87,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
credentials.get(key)
|
||||
for key in ["aws_access_key_id", "aws_secret_access_key"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Google Cloud Storage")
|
||||
raise ConnectorMissingCredentialError("Amazon S3")
|
||||
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=credentials["aws_access_key_id"],
|
||||
|
||||
@@ -65,19 +65,7 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
|
||||
|
||||
_SLIM_DOC_BATCH_SIZE = 5000
|
||||
|
||||
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
|
||||
"gif",
|
||||
"mp4",
|
||||
"mov",
|
||||
"mp3",
|
||||
"wav",
|
||||
]
|
||||
_FULL_EXTENSION_FILTER_STRING = "".join(
|
||||
[
|
||||
f" and title!~'*.{extension}'"
|
||||
for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
|
||||
]
|
||||
)
|
||||
ONE_HOUR = 3600
|
||||
|
||||
|
||||
class ConfluenceConnector(
|
||||
@@ -207,7 +195,6 @@ class ConfluenceConnector(
|
||||
def _construct_attachment_query(self, confluence_page_id: str) -> str:
|
||||
attachment_query = f"type=attachment and container='{confluence_page_id}'"
|
||||
attachment_query += self.cql_label_filter
|
||||
attachment_query += _FULL_EXTENSION_FILTER_STRING
|
||||
return attachment_query
|
||||
|
||||
def _get_comment_string_for_page_id(self, page_id: str) -> str:
|
||||
@@ -372,11 +359,13 @@ class ConfluenceConnector(
|
||||
if not validate_attachment_filetype(
|
||||
attachment,
|
||||
):
|
||||
logger.info(f"Skipping attachment: {attachment['title']}")
|
||||
continue
|
||||
|
||||
logger.info(f"Processing attachment: {attachment['title']}")
|
||||
|
||||
# Attempt to get textual content or image summarization:
|
||||
try:
|
||||
logger.info(f"Processing attachment: {attachment['title']}")
|
||||
response = convert_attachment_to_content(
|
||||
confluence_client=self.confluence_client,
|
||||
attachment=attachment,
|
||||
@@ -429,7 +418,17 @@ class ConfluenceConnector(
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
return self._fetch_document_batches(start, end)
|
||||
try:
|
||||
return self._fetch_document_batches(start, end)
|
||||
except Exception as e:
|
||||
if "field 'updated' is invalid" in str(e) and start is not None:
|
||||
logger.warning(
|
||||
"Confluence says we provided an invalid 'updated' field. This may indicate"
|
||||
"a real issue, but can also appear during edge cases like daylight"
|
||||
f"savings time changes. Retrying with a 1 hour offset. Error: {e}"
|
||||
)
|
||||
return self._fetch_document_batches(start - ONE_HOUR, end)
|
||||
raise
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
|
||||
@@ -498,10 +498,12 @@ class OnyxConfluence:
|
||||
new_start = get_start_param_from_url(url_suffix)
|
||||
previous_start = get_start_param_from_url(old_url_suffix)
|
||||
if new_start - previous_start > len(results):
|
||||
logger.warning(
|
||||
logger.debug(
|
||||
f"Start was updated by more than the amount of results "
|
||||
f"retrieved. This is a bug with Confluence. Start: {new_start}, "
|
||||
f"Previous Start: {previous_start}, Len Results: {len(results)}."
|
||||
f"retrieved for `{url_suffix}`. This is a bug with Confluence, "
|
||||
"but we have logic to work around it - don't worry this isn't"
|
||||
f" causing an issue. Start: {new_start}, Previous Start: "
|
||||
f"{previous_start}, Len Results: {len(results)}."
|
||||
)
|
||||
|
||||
# Update the url_suffix to use the adjusted start
|
||||
|
||||
@@ -28,8 +28,9 @@ from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import detect_encoding
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_accepted_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_text_file_extension
|
||||
from onyx.file_processing.extract_file_text import is_valid_file_ext
|
||||
from onyx.file_processing.extract_file_text import OnyxExtensionType
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import request_with_retries
|
||||
@@ -69,7 +70,9 @@ def _process_egnyte_file(
|
||||
|
||||
file_name = file_metadata["name"]
|
||||
extension = get_file_ext(file_name)
|
||||
if not is_valid_file_ext(extension):
|
||||
if not is_accepted_file_ext(
|
||||
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
|
||||
):
|
||||
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
|
||||
return None
|
||||
|
||||
|
||||
@@ -22,8 +22,9 @@ from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_valid_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_accepted_file_ext
|
||||
from onyx.file_processing.extract_file_text import load_files_from_zip
|
||||
from onyx.file_processing.extract_file_text import OnyxExtensionType
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -51,7 +52,7 @@ def _read_files_and_metadata(
|
||||
file_content, ignore_dirs=True
|
||||
):
|
||||
yield os.path.join(directory_path, file_info.filename), subfile, metadata
|
||||
elif is_valid_file_ext(extension):
|
||||
elif is_accepted_file_ext(extension, OnyxExtensionType.All):
|
||||
yield file_name, file_content, metadata
|
||||
else:
|
||||
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
|
||||
@@ -122,7 +123,7 @@ def _process_file(
|
||||
logger.warning(f"No file record found for '{file_name}' in PG; skipping.")
|
||||
return []
|
||||
|
||||
if not is_valid_file_ext(extension):
|
||||
if not is_accepted_file_ext(extension, OnyxExtensionType.All):
|
||||
logger.warning(
|
||||
f"Skipping file '{file_name}' with unrecognized extension '{extension}'"
|
||||
)
|
||||
@@ -219,24 +220,34 @@ def _process_file(
|
||||
|
||||
# 2) Otherwise: text-based approach. Possibly with embedded images.
|
||||
file.seek(0)
|
||||
text_content = ""
|
||||
embedded_images: list[tuple[bytes, str]] = []
|
||||
|
||||
# Extract text and images from the file
|
||||
text_content, embedded_images = extract_text_and_images(
|
||||
extraction_result = extract_text_and_images(
|
||||
file=file,
|
||||
file_name=file_name,
|
||||
pdf_pass=pdf_pass,
|
||||
)
|
||||
|
||||
# Merge file-specific metadata (from file content) with provided metadata
|
||||
if extraction_result.metadata:
|
||||
logger.debug(
|
||||
f"Found file-specific metadata for {file_name}: {extraction_result.metadata}"
|
||||
)
|
||||
metadata.update(extraction_result.metadata)
|
||||
|
||||
# Build sections: first the text as a single Section
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
link_in_meta = metadata.get("link")
|
||||
if text_content.strip():
|
||||
sections.append(TextSection(link=link_in_meta, text=text_content.strip()))
|
||||
if extraction_result.text_content.strip():
|
||||
logger.debug(f"Creating TextSection for {file_name} with link: {link_in_meta}")
|
||||
sections.append(
|
||||
TextSection(link=link_in_meta, text=extraction_result.text_content.strip())
|
||||
)
|
||||
|
||||
# Then any extracted images from docx, etc.
|
||||
for idx, (img_data, img_name) in enumerate(embedded_images, start=1):
|
||||
for idx, (img_data, img_name) in enumerate(
|
||||
extraction_result.embedded_images, start=1
|
||||
):
|
||||
# Store each embedded image as a separate file in PGFileStore
|
||||
# and create a section with the image reference
|
||||
try:
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import copy
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -13,26 +15,30 @@ from github.GithubException import GithubException
|
||||
from github.Issue import Issue
|
||||
from github.PaginatedList import PaginatedList
|
||||
from github.PullRequest import PullRequest
|
||||
from github.Requester import Requester
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import ConnectorCheckpoint
|
||||
from onyx.connectors.interfaces import ConnectorFailure
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.batching import batch_generator
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
ITEMS_PER_PAGE = 100
|
||||
|
||||
_MAX_NUM_RATE_LIMIT_RETRIES = 5
|
||||
|
||||
@@ -48,7 +54,7 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
|
||||
|
||||
def _get_batch_rate_limited(
|
||||
git_objs: PaginatedList, page_num: int, github_client: Github, attempt_num: int = 0
|
||||
) -> list[Any]:
|
||||
) -> list[PullRequest | Issue]:
|
||||
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
|
||||
raise RuntimeError(
|
||||
"Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
|
||||
@@ -69,21 +75,6 @@ def _get_batch_rate_limited(
|
||||
)
|
||||
|
||||
|
||||
def _batch_github_objects(
|
||||
git_objs: PaginatedList, github_client: Github, batch_size: int
|
||||
) -> Iterator[list[Any]]:
|
||||
page_num = 0
|
||||
while True:
|
||||
batch = _get_batch_rate_limited(git_objs, page_num, github_client)
|
||||
page_num += 1
|
||||
|
||||
if not batch:
|
||||
break
|
||||
|
||||
for mini_batch in batch_generator(batch, batch_size=batch_size):
|
||||
yield mini_batch
|
||||
|
||||
|
||||
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
|
||||
return Document(
|
||||
id=pull_request.html_url,
|
||||
@@ -95,7 +86,9 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
|
||||
# updated_at is UTC time but is timezone unaware, explicitly add UTC
|
||||
# as there is logic in indexing to prevent wrong timestamped docs
|
||||
# due to local time discrepancies with UTC
|
||||
doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc),
|
||||
doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc)
|
||||
if pull_request.updated_at
|
||||
else None,
|
||||
metadata={
|
||||
"merged": str(pull_request.merged),
|
||||
"state": pull_request.state,
|
||||
@@ -122,31 +115,58 @@ def _convert_issue_to_document(issue: Issue) -> Document:
|
||||
)
|
||||
|
||||
|
||||
class GithubConnector(LoadConnector, PollConnector):
|
||||
class SerializedRepository(BaseModel):
|
||||
# id is part of the raw_data as well, just pulled out for convenience
|
||||
id: int
|
||||
headers: dict[str, str | int]
|
||||
raw_data: dict[str, Any]
|
||||
|
||||
def to_Repository(self, requester: Requester) -> Repository.Repository:
|
||||
return Repository.Repository(
|
||||
requester, self.headers, self.raw_data, completed=True
|
||||
)
|
||||
|
||||
|
||||
class GithubConnectorStage(Enum):
|
||||
START = "start"
|
||||
PRS = "prs"
|
||||
ISSUES = "issues"
|
||||
|
||||
|
||||
class GithubConnectorCheckpoint(ConnectorCheckpoint):
|
||||
stage: GithubConnectorStage
|
||||
curr_page: int
|
||||
|
||||
cached_repo_ids: list[int] | None = None
|
||||
cached_repo: SerializedRepository | None = None
|
||||
|
||||
|
||||
class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
|
||||
def __init__(
|
||||
self,
|
||||
repo_owner: str,
|
||||
repositories: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
state_filter: str = "all",
|
||||
include_prs: bool = True,
|
||||
include_issues: bool = False,
|
||||
) -> None:
|
||||
self.repo_owner = repo_owner
|
||||
self.repositories = repositories
|
||||
self.batch_size = batch_size
|
||||
self.state_filter = state_filter
|
||||
self.include_prs = include_prs
|
||||
self.include_issues = include_issues
|
||||
self.github_client: Github | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# defaults to 30 items per page, can be set to as high as 100
|
||||
self.github_client = (
|
||||
Github(
|
||||
credentials["github_access_token"], base_url=GITHUB_CONNECTOR_BASE_URL
|
||||
credentials["github_access_token"],
|
||||
base_url=GITHUB_CONNECTOR_BASE_URL,
|
||||
per_page=ITEMS_PER_PAGE,
|
||||
)
|
||||
if GITHUB_CONNECTOR_BASE_URL
|
||||
else Github(credentials["github_access_token"])
|
||||
else Github(credentials["github_access_token"], per_page=ITEMS_PER_PAGE)
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -217,85 +237,193 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
return self._get_all_repos(github_client, attempt_num + 1)
|
||||
|
||||
def _fetch_from_github(
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
self,
|
||||
checkpoint: GithubConnectorCheckpoint,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
|
||||
if self.github_client is None:
|
||||
raise ConnectorMissingCredentialError("GitHub")
|
||||
|
||||
repos = []
|
||||
if self.repositories:
|
||||
if "," in self.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = self._get_github_repos(self.github_client)
|
||||
checkpoint = copy.deepcopy(checkpoint)
|
||||
|
||||
# First run of the connector, fetch all repos and store in checkpoint
|
||||
if checkpoint.cached_repo_ids is None:
|
||||
repos = []
|
||||
if self.repositories:
|
||||
if "," in self.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = self._get_github_repos(self.github_client)
|
||||
else:
|
||||
# Single repository (backward compatibility)
|
||||
repos = [self._get_github_repo(self.github_client)]
|
||||
else:
|
||||
# Single repository (backward compatibility)
|
||||
repos = [self._get_github_repo(self.github_client)]
|
||||
else:
|
||||
# All repositories
|
||||
repos = self._get_all_repos(self.github_client)
|
||||
# All repositories
|
||||
repos = self._get_all_repos(self.github_client)
|
||||
if not repos:
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
|
||||
|
||||
for repo in repos:
|
||||
if self.include_prs:
|
||||
logger.info(f"Fetching PRs for repo: {repo.name}")
|
||||
pull_requests = repo.get_pulls(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
checkpoint.cached_repo_ids = sorted([repo.id for repo in repos])
|
||||
checkpoint.cached_repo = SerializedRepository(
|
||||
id=checkpoint.cached_repo_ids[0],
|
||||
headers=repos[0].raw_headers,
|
||||
raw_data=repos[0].raw_data,
|
||||
)
|
||||
checkpoint.stage = GithubConnectorStage.PRS
|
||||
checkpoint.curr_page = 0
|
||||
# save checkpoint with repo ids retrieved
|
||||
return checkpoint
|
||||
|
||||
for pr_batch in _batch_github_objects(
|
||||
pull_requests, self.github_client, self.batch_size
|
||||
assert checkpoint.cached_repo is not None, "No repo saved in checkpoint"
|
||||
repo = checkpoint.cached_repo.to_Repository(self.github_client.requester)
|
||||
|
||||
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
|
||||
logger.info(f"Fetching PRs for repo: {repo.name}")
|
||||
pull_requests = repo.get_pulls(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
pr_batch = _get_batch_rate_limited(
|
||||
pull_requests, checkpoint.curr_page, self.github_client
|
||||
)
|
||||
checkpoint.curr_page += 1
|
||||
done_with_prs = False
|
||||
for pr in pr_batch:
|
||||
# we iterate backwards in time, so at this point we stop processing prs
|
||||
if (
|
||||
start is not None
|
||||
and pr.updated_at
|
||||
and pr.updated_at.replace(tzinfo=timezone.utc) < start
|
||||
):
|
||||
doc_batch: list[Document] = []
|
||||
for pr in pr_batch:
|
||||
if start is not None and pr.updated_at < start:
|
||||
yield doc_batch
|
||||
break
|
||||
if end is not None and pr.updated_at > end:
|
||||
continue
|
||||
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
|
||||
yield doc_batch
|
||||
|
||||
if self.include_issues:
|
||||
logger.info(f"Fetching issues for repo: {repo.name}")
|
||||
issues = repo.get_issues(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
for issue_batch in _batch_github_objects(
|
||||
issues, self.github_client, self.batch_size
|
||||
yield from doc_batch
|
||||
done_with_prs = True
|
||||
break
|
||||
# Skip PRs updated after the end date
|
||||
if (
|
||||
end is not None
|
||||
and pr.updated_at
|
||||
and pr.updated_at.replace(tzinfo=timezone.utc) > end
|
||||
):
|
||||
doc_batch = []
|
||||
for issue in issue_batch:
|
||||
issue = cast(Issue, issue)
|
||||
if start is not None and issue.updated_at < start:
|
||||
yield doc_batch
|
||||
break
|
||||
if end is not None and issue.updated_at > end:
|
||||
continue
|
||||
if issue.pull_request is not None:
|
||||
# PRs are handled separately
|
||||
continue
|
||||
doc_batch.append(_convert_issue_to_document(issue))
|
||||
yield doc_batch
|
||||
continue
|
||||
try:
|
||||
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
|
||||
except Exception as e:
|
||||
error_msg = f"Error converting PR to document: {e}"
|
||||
logger.exception(error_msg)
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=str(pr.id), document_link=pr.html_url
|
||||
),
|
||||
failure_message=error_msg,
|
||||
exception=e,
|
||||
)
|
||||
continue
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_from_github()
|
||||
# if we found any PRs on the page, yield any associated documents and return the checkpoint
|
||||
if not done_with_prs and len(pr_batch) > 0:
|
||||
yield from doc_batch
|
||||
return checkpoint
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_datetime = datetime.utcfromtimestamp(start)
|
||||
end_datetime = datetime.utcfromtimestamp(end)
|
||||
# if we went past the start date during the loop or there are no more
|
||||
# prs to get, we move on to issues
|
||||
checkpoint.stage = GithubConnectorStage.ISSUES
|
||||
checkpoint.curr_page = 0
|
||||
|
||||
checkpoint.stage = GithubConnectorStage.ISSUES
|
||||
|
||||
if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
|
||||
logger.info(f"Fetching issues for repo: {repo.name}")
|
||||
issues = repo.get_issues(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
doc_batch = []
|
||||
issue_batch = _get_batch_rate_limited(
|
||||
issues, checkpoint.curr_page, self.github_client
|
||||
)
|
||||
checkpoint.curr_page += 1
|
||||
done_with_issues = False
|
||||
for issue in cast(list[Issue], issue_batch):
|
||||
# we iterate backwards in time, so at this point we stop processing issues
|
||||
if (
|
||||
start is not None
|
||||
and issue.updated_at.replace(tzinfo=timezone.utc) < start
|
||||
):
|
||||
yield from doc_batch
|
||||
done_with_issues = True
|
||||
break
|
||||
# Skip issues updated after the end date
|
||||
if (
|
||||
end is not None
|
||||
and issue.updated_at.replace(tzinfo=timezone.utc) > end
|
||||
):
|
||||
continue
|
||||
|
||||
if issue.pull_request is not None:
|
||||
# PRs are handled separately
|
||||
continue
|
||||
|
||||
try:
|
||||
doc_batch.append(_convert_issue_to_document(issue))
|
||||
except Exception as e:
|
||||
error_msg = f"Error converting issue to document: {e}"
|
||||
logger.exception(error_msg)
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=str(issue.id),
|
||||
document_link=issue.html_url,
|
||||
),
|
||||
failure_message=error_msg,
|
||||
exception=e,
|
||||
)
|
||||
continue
|
||||
|
||||
# if we found any issues on the page, yield them and return the checkpoint
|
||||
if not done_with_issues and len(issue_batch) > 0:
|
||||
yield from doc_batch
|
||||
return checkpoint
|
||||
|
||||
# if we went past the start date during the loop or there are no more
|
||||
# issues to get, we move on to the next repo
|
||||
checkpoint.stage = GithubConnectorStage.PRS
|
||||
checkpoint.curr_page = 0
|
||||
|
||||
checkpoint.has_more = len(checkpoint.cached_repo_ids) > 1
|
||||
if checkpoint.cached_repo_ids:
|
||||
next_id = checkpoint.cached_repo_ids.pop()
|
||||
next_repo = self.github_client.get_repo(next_id)
|
||||
checkpoint.cached_repo = SerializedRepository(
|
||||
id=next_id,
|
||||
headers=next_repo.raw_headers,
|
||||
raw_data=next_repo.raw_data,
|
||||
)
|
||||
|
||||
return checkpoint
|
||||
|
||||
@override
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: GithubConnectorCheckpoint,
|
||||
) -> CheckpointOutput[GithubConnectorCheckpoint]:
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
|
||||
# Move start time back by 3 hours, since some Issues/PRs are getting dropped
|
||||
# Could be due to delayed processing on GitHub side
|
||||
# The non-updated issues since last poll will be shortcut-ed and not embedded
|
||||
adjusted_start_datetime = start_datetime - timedelta(hours=3)
|
||||
|
||||
epoch = datetime.utcfromtimestamp(0)
|
||||
epoch = datetime.fromtimestamp(0, tz=timezone.utc)
|
||||
if adjusted_start_datetime < epoch:
|
||||
adjusted_start_datetime = epoch
|
||||
|
||||
return self._fetch_from_github(adjusted_start_datetime, end_datetime)
|
||||
return self._fetch_from_github(
|
||||
checkpoint, start=adjusted_start_datetime, end=end_datetime
|
||||
)
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self.github_client is None:
|
||||
@@ -397,6 +525,16 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
f"Unexpected error during GitHub settings validation: {exc}"
|
||||
)
|
||||
|
||||
def validate_checkpoint_json(
|
||||
self, checkpoint_json: str
|
||||
) -> GithubConnectorCheckpoint:
|
||||
return GithubConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
|
||||
return GithubConnectorCheckpoint(
|
||||
stage=GithubConnectorStage.PRS, curr_page=0, has_more=True
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
@@ -406,7 +544,9 @@ if __name__ == "__main__":
|
||||
repositories=os.environ["REPOSITORIES"],
|
||||
)
|
||||
connector.load_credentials(
|
||||
{"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}
|
||||
{"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"]}
|
||||
)
|
||||
document_batches = connector.load_from_checkpoint(
|
||||
0, time.time(), connector.build_dummy_checkpoint()
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
|
||||
@@ -2,11 +2,11 @@ import copy
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Protocol
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -15,6 +15,7 @@ from google.oauth2.service_account import Credentials as ServiceAccountCredentia
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.app_configs import GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import MAX_DRIVE_WORKERS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -27,7 +28,9 @@ from onyx.connectors.google_drive.doc_conversion import (
|
||||
)
|
||||
from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
|
||||
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
|
||||
from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import (
|
||||
get_all_files_in_my_drive_and_shared,
|
||||
)
|
||||
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
|
||||
from onyx.connectors.google_drive.models import DriveRetrievalStage
|
||||
@@ -57,13 +60,13 @@ from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.lazy import lazy_eval
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
from onyx.utils.threadpool_concurrency import parallel_yield
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.threadpool_concurrency import ThreadSafeDict
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -85,12 +88,18 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:
|
||||
|
||||
def _convert_single_file(
|
||||
creds: Any,
|
||||
primary_admin_email: str,
|
||||
allow_images: bool,
|
||||
size_threshold: int,
|
||||
retriever_email: str,
|
||||
file: dict[str, Any],
|
||||
) -> Document | ConnectorFailure | None:
|
||||
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
|
||||
# We used to always get the user email from the file owners when available,
|
||||
# but this was causing issues with shared folders where the owner was not included in the service account
|
||||
# now we use the email of the account that successfully listed the file. Leaving this in case we end up
|
||||
# wanting to retry with file owners and/or admin email at some point.
|
||||
# user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
|
||||
|
||||
user_email = retriever_email
|
||||
# Only construct these services when needed
|
||||
user_drive_service = lazy_eval(
|
||||
lambda: get_drive_service(creds, user_email=user_email)
|
||||
@@ -103,6 +112,7 @@ def _convert_single_file(
|
||||
drive_service=user_drive_service,
|
||||
docs_service=docs_service,
|
||||
allow_images=allow_images,
|
||||
size_threshold=size_threshold,
|
||||
)
|
||||
|
||||
|
||||
@@ -238,6 +248,8 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
|
||||
self._retrieved_ids: set[str] = set()
|
||||
self.allow_images = False
|
||||
|
||||
self.size_threshold = GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD
|
||||
|
||||
def set_allow_images(self, value: bool) -> None:
|
||||
self.allow_images = value
|
||||
|
||||
@@ -445,10 +457,11 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
|
||||
logger.info(f"Getting all files in my drive as '{user_email}'")
|
||||
|
||||
yield from add_retrieval_info(
|
||||
get_all_files_in_my_drive(
|
||||
get_all_files_in_my_drive_and_shared(
|
||||
service=drive_service,
|
||||
update_traversed_ids_func=self._update_traversed_parent_ids,
|
||||
is_slim=is_slim,
|
||||
include_shared_with_me=self.include_files_shared_with_me,
|
||||
start=curr_stage.completed_until if resuming else start,
|
||||
end=end,
|
||||
),
|
||||
@@ -456,6 +469,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
|
||||
DriveRetrievalStage.MY_DRIVE_FILES,
|
||||
)
|
||||
curr_stage.stage = DriveRetrievalStage.SHARED_DRIVE_FILES
|
||||
resuming = False # we are starting the next stage for the first time
|
||||
|
||||
if curr_stage.stage == DriveRetrievalStage.SHARED_DRIVE_FILES:
|
||||
|
||||
@@ -491,7 +505,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
|
||||
)
|
||||
yield from _yield_from_drive(drive_id, start)
|
||||
curr_stage.stage = DriveRetrievalStage.FOLDER_FILES
|
||||
|
||||
resuming = False # we are starting the next stage for the first time
|
||||
if curr_stage.stage == DriveRetrievalStage.FOLDER_FILES:
|
||||
|
||||
def _yield_from_folder_crawl(
|
||||
@@ -544,6 +558,16 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
|
||||
checkpoint, is_slim, DriveRetrievalStage.MY_DRIVE_FILES
|
||||
)
|
||||
|
||||
# Setup initial completion map on first connector run
|
||||
for email in all_org_emails:
|
||||
# don't overwrite existing completion map on resuming runs
|
||||
if email in checkpoint.completion_map:
|
||||
continue
|
||||
checkpoint.completion_map[email] = StageCompletion(
|
||||
stage=DriveRetrievalStage.START,
|
||||
completed_until=0,
|
||||
)
|
||||
|
||||
# we've found all users and drives, now time to actually start
|
||||
# fetching stuff
|
||||
logger.info(f"Found {len(all_org_emails)} users to impersonate")
|
||||
@@ -557,11 +581,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
|
||||
drive_ids_to_retrieve, checkpoint
|
||||
)
|
||||
|
||||
for email in all_org_emails:
|
||||
checkpoint.completion_map[email] = StageCompletion(
|
||||
stage=DriveRetrievalStage.START,
|
||||
completed_until=0,
|
||||
)
|
||||
user_retrieval_gens = [
|
||||
self._impersonate_user_for_retrieval(
|
||||
email,
|
||||
@@ -792,10 +811,12 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
|
||||
return
|
||||
|
||||
for file in drive_files:
|
||||
if file.error is not None:
|
||||
if file.error is None:
|
||||
checkpoint.completion_map[file.user_email].update(
|
||||
stage=file.completion_stage,
|
||||
completed_until=file.drive_file[GoogleFields.MODIFIED_TIME.value],
|
||||
completed_until=datetime.fromisoformat(
|
||||
file.drive_file[GoogleFields.MODIFIED_TIME.value]
|
||||
).timestamp(),
|
||||
completed_until_parent_id=file.parent_id,
|
||||
)
|
||||
yield file
|
||||
@@ -897,117 +918,86 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[list[Document | ConnectorFailure]]:
|
||||
) -> Iterator[Document | ConnectorFailure]:
|
||||
try:
|
||||
# Create a larger process pool for file conversion
|
||||
with ThreadPoolExecutor(max_workers=8) as executor:
|
||||
# Prepare a partial function with the credentials and admin email
|
||||
convert_func = partial(
|
||||
_convert_single_file,
|
||||
self.creds,
|
||||
self.primary_admin_email,
|
||||
self.allow_images,
|
||||
# Prepare a partial function with the credentials and admin email
|
||||
convert_func = partial(
|
||||
_convert_single_file,
|
||||
self.creds,
|
||||
self.allow_images,
|
||||
self.size_threshold,
|
||||
)
|
||||
# Fetch files in batches
|
||||
batches_complete = 0
|
||||
files_batch: list[RetrievedDriveFile] = []
|
||||
|
||||
def _yield_batch(
|
||||
files_batch: list[RetrievedDriveFile],
|
||||
) -> Iterator[Document | ConnectorFailure]:
|
||||
nonlocal batches_complete
|
||||
# Process the batch using run_functions_tuples_in_parallel
|
||||
func_with_args = [
|
||||
(
|
||||
convert_func,
|
||||
(
|
||||
file.user_email,
|
||||
file.drive_file,
|
||||
),
|
||||
)
|
||||
for file in files_batch
|
||||
]
|
||||
results = cast(
|
||||
list[Document | ConnectorFailure | None],
|
||||
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
|
||||
)
|
||||
|
||||
# Fetch files in batches
|
||||
batches_complete = 0
|
||||
files_batch: list[GoogleDriveFileType] = []
|
||||
for retrieved_file in self._fetch_drive_items(
|
||||
is_slim=False,
|
||||
checkpoint=checkpoint,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
if retrieved_file.error is not None:
|
||||
failure_stage = retrieved_file.completion_stage.value
|
||||
failure_message = (
|
||||
f"retrieval failure during stage: {failure_stage},"
|
||||
)
|
||||
failure_message += f"user: {retrieved_file.user_email},"
|
||||
failure_message += (
|
||||
f"parent drive/folder: {retrieved_file.parent_id},"
|
||||
)
|
||||
failure_message += f"error: {retrieved_file.error}"
|
||||
logger.error(failure_message)
|
||||
yield [
|
||||
ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=failure_stage,
|
||||
),
|
||||
failure_message=failure_message,
|
||||
exception=retrieved_file.error,
|
||||
)
|
||||
]
|
||||
continue
|
||||
files_batch.append(retrieved_file.drive_file)
|
||||
docs_and_failures = [result for result in results if result is not None]
|
||||
|
||||
if len(files_batch) < self.batch_size:
|
||||
continue
|
||||
if docs_and_failures:
|
||||
yield from docs_and_failures
|
||||
batches_complete += 1
|
||||
|
||||
# Process the batch
|
||||
futures = [
|
||||
executor.submit(convert_func, file) for file in files_batch
|
||||
]
|
||||
documents = []
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
doc = future.result()
|
||||
if doc is not None:
|
||||
documents.append(doc)
|
||||
except Exception as e:
|
||||
error_str = f"Error converting file: {e}"
|
||||
logger.error(error_str)
|
||||
yield [
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=retrieved_file.drive_file["id"],
|
||||
document_link=retrieved_file.drive_file[
|
||||
"webViewLink"
|
||||
],
|
||||
),
|
||||
failure_message=error_str,
|
||||
exception=e,
|
||||
)
|
||||
]
|
||||
for retrieved_file in self._fetch_drive_items(
|
||||
is_slim=False,
|
||||
checkpoint=checkpoint,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
if retrieved_file.error is not None:
|
||||
failure_stage = retrieved_file.completion_stage.value
|
||||
failure_message = (
|
||||
f"retrieval failure during stage: {failure_stage},"
|
||||
)
|
||||
failure_message += f"user: {retrieved_file.user_email},"
|
||||
failure_message += (
|
||||
f"parent drive/folder: {retrieved_file.parent_id},"
|
||||
)
|
||||
failure_message += f"error: {retrieved_file.error}"
|
||||
logger.error(failure_message)
|
||||
yield ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=failure_stage,
|
||||
),
|
||||
failure_message=failure_message,
|
||||
exception=retrieved_file.error,
|
||||
)
|
||||
|
||||
if documents:
|
||||
yield documents
|
||||
batches_complete += 1
|
||||
files_batch = []
|
||||
continue
|
||||
files_batch.append(retrieved_file)
|
||||
|
||||
if batches_complete > BATCHES_PER_CHECKPOINT:
|
||||
checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
|
||||
return # create a new checkpoint
|
||||
if len(files_batch) < self.batch_size:
|
||||
continue
|
||||
|
||||
# Process any remaining files
|
||||
if files_batch:
|
||||
futures = [
|
||||
executor.submit(convert_func, file) for file in files_batch
|
||||
]
|
||||
documents = []
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
doc = future.result()
|
||||
if doc is not None:
|
||||
documents.append(doc)
|
||||
except Exception as e:
|
||||
error_str = f"Error converting file: {e}"
|
||||
logger.error(error_str)
|
||||
yield [
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=retrieved_file.drive_file["id"],
|
||||
document_link=retrieved_file.drive_file[
|
||||
"webViewLink"
|
||||
],
|
||||
),
|
||||
failure_message=error_str,
|
||||
exception=e,
|
||||
)
|
||||
]
|
||||
yield from _yield_batch(files_batch)
|
||||
files_batch = []
|
||||
|
||||
if documents:
|
||||
yield documents
|
||||
if batches_complete > BATCHES_PER_CHECKPOINT:
|
||||
checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
|
||||
return # create a new checkpoint
|
||||
|
||||
# Process any remaining files
|
||||
if files_batch:
|
||||
yield from _yield_batch(files_batch)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error extracting documents from Google Drive: {e}")
|
||||
raise e
|
||||
@@ -1029,10 +1019,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
|
||||
checkpoint = copy.deepcopy(checkpoint)
|
||||
self._retrieved_ids = checkpoint.retrieved_folder_and_drive_ids
|
||||
try:
|
||||
for doc_list in self._extract_docs_from_google_drive(
|
||||
checkpoint, start, end
|
||||
):
|
||||
yield from doc_list
|
||||
yield from self._extract_docs_from_google_drive(checkpoint, start, end)
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
@@ -1067,9 +1054,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
|
||||
raise RuntimeError(
|
||||
"_extract_slim_docs_from_google_drive: Stop signal detected"
|
||||
)
|
||||
|
||||
callback.progress("_extract_slim_docs_from_google_drive", 1)
|
||||
|
||||
yield slim_batch
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
|
||||
@@ -76,7 +76,7 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
|
||||
return is_valid_image_type(mime_type)
|
||||
|
||||
|
||||
def _extract_sections_basic(
|
||||
def _download_and_extract_sections_basic(
|
||||
file: dict[str, str],
|
||||
service: GoogleDriveService,
|
||||
allow_images: bool,
|
||||
@@ -87,35 +87,17 @@ def _extract_sections_basic(
|
||||
mime_type = file["mimeType"]
|
||||
link = file.get("webViewLink", "")
|
||||
|
||||
try:
|
||||
# skip images if not explicitly enabled
|
||||
if not allow_images and is_gdrive_image_mime_type(mime_type):
|
||||
return []
|
||||
# skip images if not explicitly enabled
|
||||
if not allow_images and is_gdrive_image_mime_type(mime_type):
|
||||
return []
|
||||
|
||||
# For Google Docs, Sheets, and Slides, export as plain text
|
||||
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
|
||||
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
|
||||
# Use the correct API call for exporting files
|
||||
request = service.files().export_media(
|
||||
fileId=file_id, mimeType=export_mime_type
|
||||
)
|
||||
response_bytes = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(response_bytes, request)
|
||||
done = False
|
||||
while not done:
|
||||
_, done = downloader.next_chunk()
|
||||
|
||||
response = response_bytes.getvalue()
|
||||
if not response:
|
||||
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
|
||||
return []
|
||||
|
||||
text = response.decode("utf-8")
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
# For other file types, download the file
|
||||
# Use the correct API call for downloading files
|
||||
request = service.files().get_media(fileId=file_id)
|
||||
# For Google Docs, Sheets, and Slides, export as plain text
|
||||
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
|
||||
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
|
||||
# Use the correct API call for exporting files
|
||||
request = service.files().export_media(
|
||||
fileId=file_id, mimeType=export_mime_type
|
||||
)
|
||||
response_bytes = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(response_bytes, request)
|
||||
done = False
|
||||
@@ -124,88 +106,100 @@ def _extract_sections_basic(
|
||||
|
||||
response = response_bytes.getvalue()
|
||||
if not response:
|
||||
logger.warning(f"Failed to download {file_name}")
|
||||
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
|
||||
return []
|
||||
|
||||
# Process based on mime type
|
||||
if mime_type == "text/plain":
|
||||
text = response.decode("utf-8")
|
||||
return [TextSection(link=link, text=text)]
|
||||
text = response.decode("utf-8")
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
):
|
||||
text, _ = docx_to_text_and_images(io.BytesIO(response))
|
||||
return [TextSection(link=link, text=text)]
|
||||
# For other file types, download the file
|
||||
# Use the correct API call for downloading files
|
||||
request = service.files().get_media(fileId=file_id)
|
||||
response_bytes = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(response_bytes, request)
|
||||
done = False
|
||||
while not done:
|
||||
_, done = downloader.next_chunk()
|
||||
|
||||
elif (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
):
|
||||
text = xlsx_to_text(io.BytesIO(response))
|
||||
return [TextSection(link=link, text=text)]
|
||||
response = response_bytes.getvalue()
|
||||
if not response:
|
||||
logger.warning(f"Failed to download {file_name}")
|
||||
return []
|
||||
|
||||
elif (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
):
|
||||
text = pptx_to_text(io.BytesIO(response))
|
||||
return [TextSection(link=link, text=text)]
|
||||
# Process based on mime type
|
||||
if mime_type == "text/plain":
|
||||
text = response.decode("utf-8")
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif is_gdrive_image_mime_type(mime_type):
|
||||
# For images, store them for later processing
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
elif (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
):
|
||||
text, _ = docx_to_text_and_images(io.BytesIO(response))
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif (
|
||||
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
):
|
||||
text = xlsx_to_text(io.BytesIO(response))
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
):
|
||||
text = pptx_to_text(io.BytesIO(response))
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif is_gdrive_image_mime_type(mime_type):
|
||||
# For images, store them for later processing
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=response,
|
||||
file_name=file_id,
|
||||
display_name=file_name,
|
||||
media_type=mime_type,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
link=link,
|
||||
)
|
||||
sections.append(section)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process image {file_name}: {e}")
|
||||
return sections
|
||||
|
||||
elif mime_type == "application/pdf":
|
||||
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
|
||||
pdf_sections: list[TextSection | ImageSection] = [
|
||||
TextSection(link=link, text=text)
|
||||
]
|
||||
|
||||
# Process embedded images in the PDF
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for idx, (img_data, img_name) in enumerate(images):
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=response,
|
||||
file_name=file_id,
|
||||
display_name=file_name,
|
||||
media_type=mime_type,
|
||||
image_data=img_data,
|
||||
file_name=f"{file_id}_img_{idx}",
|
||||
display_name=img_name or f"{file_name} - image {idx}",
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
link=link,
|
||||
)
|
||||
sections.append(section)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process image {file_name}: {e}")
|
||||
return sections
|
||||
pdf_sections.append(section)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process PDF images in {file_name}: {e}")
|
||||
return pdf_sections
|
||||
|
||||
elif mime_type == "application/pdf":
|
||||
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
|
||||
pdf_sections: list[TextSection | ImageSection] = [
|
||||
TextSection(link=link, text=text)
|
||||
]
|
||||
|
||||
# Process embedded images in the PDF
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for idx, (img_data, img_name) in enumerate(images):
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=img_data,
|
||||
file_name=f"{file_id}_img_{idx}",
|
||||
display_name=img_name or f"{file_name} - image {idx}",
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
pdf_sections.append(section)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process PDF images in {file_name}: {e}")
|
||||
return pdf_sections
|
||||
|
||||
else:
|
||||
# For unsupported file types, try to extract text
|
||||
try:
|
||||
text = extract_file_text(io.BytesIO(response), file_name)
|
||||
return [TextSection(link=link, text=text)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract text from {file_name}: {e}")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file {file_name}: {e}")
|
||||
return []
|
||||
else:
|
||||
# For unsupported file types, try to extract text
|
||||
try:
|
||||
text = extract_file_text(io.BytesIO(response), file_name)
|
||||
return [TextSection(link=link, text=text)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract text from {file_name}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def convert_drive_item_to_document(
|
||||
@@ -213,6 +207,7 @@ def convert_drive_item_to_document(
|
||||
drive_service: Callable[[], GoogleDriveService],
|
||||
docs_service: Callable[[], GoogleDocsService],
|
||||
allow_images: bool,
|
||||
size_threshold: int,
|
||||
) -> Document | ConnectorFailure | None:
|
||||
"""
|
||||
Main entry point for converting a Google Drive file => Document object.
|
||||
@@ -240,9 +235,24 @@ def convert_drive_item_to_document(
|
||||
f"Error in advanced parsing: {e}. Falling back to basic extraction."
|
||||
)
|
||||
|
||||
size_str = file.get("size")
|
||||
if size_str:
|
||||
try:
|
||||
size_int = int(size_str)
|
||||
except ValueError:
|
||||
logger.warning(f"Parsing string to int failed: size_str={size_str}")
|
||||
else:
|
||||
if size_int > size_threshold:
|
||||
logger.warning(
|
||||
f"{file.get('name')} exceeds size threshold of {size_threshold}. Skipping."
|
||||
)
|
||||
return None
|
||||
|
||||
# If we don't have sections yet, use the basic extraction method
|
||||
if not sections:
|
||||
sections = _extract_sections_basic(file, drive_service(), allow_images)
|
||||
sections = _download_and_extract_sections_basic(
|
||||
file, drive_service(), allow_images
|
||||
)
|
||||
|
||||
# If we still don't have any sections, skip this file
|
||||
if not sections:
|
||||
|
||||
@@ -123,7 +123,7 @@ def crawl_folders_for_files(
|
||||
end=end,
|
||||
):
|
||||
found_files = True
|
||||
logger.info(f"Found file: {file['name']}")
|
||||
logger.info(f"Found file: {file['name']}, user email: {user_email}")
|
||||
yield RetrievedDriveFile(
|
||||
drive_file=file,
|
||||
user_email=user_email,
|
||||
@@ -214,10 +214,11 @@ def get_files_in_shared_drive(
|
||||
yield file
|
||||
|
||||
|
||||
def get_all_files_in_my_drive(
|
||||
def get_all_files_in_my_drive_and_shared(
|
||||
service: GoogleDriveService,
|
||||
update_traversed_ids_func: Callable,
|
||||
is_slim: bool,
|
||||
include_shared_with_me: bool,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
@@ -229,7 +230,8 @@ def get_all_files_in_my_drive(
|
||||
# Get all folders being queried and add them to the traversed set
|
||||
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
|
||||
folder_query += " and trashed = false"
|
||||
folder_query += " and 'me' in owners"
|
||||
if not include_shared_with_me:
|
||||
folder_query += " and 'me' in owners"
|
||||
found_folders = False
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
@@ -246,7 +248,8 @@ def get_all_files_in_my_drive(
|
||||
# Then get the files
|
||||
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
|
||||
file_query += " and trashed = false"
|
||||
file_query += " and 'me' in owners"
|
||||
if not include_shared_with_me:
|
||||
file_query += " and 'me' in owners"
|
||||
file_query += _generate_time_range_filter(start, end)
|
||||
yield from execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
|
||||
@@ -20,8 +20,8 @@ from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import VALID_FILE_EXTENSIONS
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -298,7 +298,7 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
elif (
|
||||
is_valid_format
|
||||
and file_extension in VALID_FILE_EXTENSIONS
|
||||
and file_extension in ALL_ACCEPTED_FILE_EXTENSIONS
|
||||
and can_download
|
||||
):
|
||||
# For documents, try to get the text content
|
||||
|
||||
@@ -8,7 +8,6 @@ from typing import TypeAlias
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
@@ -231,7 +230,7 @@ class CheckpointConnector(BaseConnector[CT]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
@abc.abstractmethod
|
||||
def build_dummy_checkpoint(self) -> CT:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -438,7 +438,11 @@ def _get_all_doc_ids(

class ProcessedSlackMessage(BaseModel):
doc: Document | None
thread_ts: str | None
# if the message is part of a thread, this is the thread_ts
# otherwise, this is the message_ts. Either way, will be a unique identifier.
# In the future, if the message becomes a thread, then the thread_ts
# will be set to the message_ts.
thread_or_message_ts: str
failure: ConnectorFailure | None


@@ -452,6 +456,7 @@ def _process_message(
|
||||
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
|
||||
) -> ProcessedSlackMessage:
|
||||
thread_ts = message.get("thread_ts")
|
||||
thread_or_message_ts = thread_ts or message["ts"]
|
||||
try:
|
||||
# causes random failures for testing checkpointing / continue on failure
|
||||
# import random
|
||||
@@ -467,16 +472,18 @@ def _process_message(
|
||||
seen_thread_ts=seen_thread_ts,
|
||||
msg_filter_func=msg_filter_func,
|
||||
)
|
||||
return ProcessedSlackMessage(doc=doc, thread_ts=thread_ts, failure=None)
|
||||
return ProcessedSlackMessage(
|
||||
doc=doc, thread_or_message_ts=thread_or_message_ts, failure=None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing message {message['ts']}")
|
||||
return ProcessedSlackMessage(
|
||||
doc=None,
|
||||
thread_ts=thread_ts,
|
||||
thread_or_message_ts=thread_or_message_ts,
|
||||
failure=ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=_build_doc_id(
|
||||
channel_id=channel["id"], thread_ts=(thread_ts or message["ts"])
|
||||
channel_id=channel["id"], thread_ts=thread_or_message_ts
|
||||
),
|
||||
document_link=get_message_link(message, client, channel["id"]),
|
||||
),
|
||||
@@ -616,7 +623,7 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
|
||||
for future in as_completed(futures):
|
||||
processed_slack_message = future.result()
|
||||
doc = processed_slack_message.doc
|
||||
thread_ts = processed_slack_message.thread_ts
|
||||
thread_or_message_ts = processed_slack_message.thread_or_message_ts
|
||||
failure = processed_slack_message.failure
|
||||
if doc:
|
||||
# handle race conditions here since this is single
|
||||
@@ -624,11 +631,13 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
|
||||
# but since this is single threaded, we won't run into simul
|
||||
# writes. At worst, we can duplicate a thread, which will be
|
||||
# deduped later on.
|
||||
if thread_ts not in seen_thread_ts:
|
||||
if thread_or_message_ts not in seen_thread_ts:
|
||||
yield doc
|
||||
|
||||
assert thread_ts, "found non-None doc with None thread_ts"
|
||||
seen_thread_ts.add(thread_ts)
|
||||
assert (
|
||||
thread_or_message_ts
|
||||
), "found non-None doc with None thread_or_message_ts"
|
||||
seen_thread_ts.add(thread_or_message_ts)
|
||||
elif failure:
|
||||
yield failure
|
||||
|
||||
|
||||
@@ -1,23 +1,32 @@
|
||||
import copy
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from requests.exceptions import HTTPError
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
time_str_to_utc,
|
||||
)
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import ConnectorFailure
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
@@ -26,6 +35,7 @@ from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
MAX_PAGE_SIZE = 30 # Zendesk API maximum
|
||||
MAX_AUTHOR_MAP_SIZE = 50_000 # Reset author map cache if it gets too large
|
||||
_SLIM_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
@@ -53,10 +63,22 @@ class ZendeskClient:
|
||||
# Sleep for the duration indicated by the Retry-After header
|
||||
time.sleep(int(retry_after))
|
||||
|
||||
elif (
|
||||
response.status_code == 403
|
||||
and response.json().get("error") == "SupportProductInactive"
|
||||
):
|
||||
return response.json()
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
class ZendeskPageResponse(BaseModel):
|
||||
data: list[dict[str, Any]]
|
||||
meta: dict[str, Any]
|
||||
has_more: bool
|
||||
|
||||
|
||||
def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
|
||||
content_tags: dict[str, str] = {}
|
||||
params = {"page[size]": MAX_PAGE_SIZE}
|
||||
@@ -82,11 +104,9 @@ def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
|
||||
def _get_articles(
|
||||
client: ZendeskClient, start_time: int | None = None, page_size: int = MAX_PAGE_SIZE
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
params = (
|
||||
{"start_time": start_time, "page[size]": page_size}
|
||||
if start_time
|
||||
else {"page[size]": page_size}
|
||||
)
|
||||
params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"}
|
||||
if start_time is not None:
|
||||
params["start_time"] = start_time
|
||||
|
||||
while True:
|
||||
data = client.make_request("help_center/articles", params)
|
||||
@@ -98,10 +118,30 @@ def _get_articles(
|
||||
params["page[after]"] = data["meta"]["after_cursor"]
|
||||
|
||||
|
||||
def _get_article_page(
|
||||
client: ZendeskClient,
|
||||
start_time: int | None = None,
|
||||
after_cursor: str | None = None,
|
||||
page_size: int = MAX_PAGE_SIZE,
|
||||
) -> ZendeskPageResponse:
|
||||
params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"}
|
||||
if start_time is not None:
|
||||
params["start_time"] = start_time
|
||||
if after_cursor is not None:
|
||||
params["page[after]"] = after_cursor
|
||||
|
||||
data = client.make_request("help_center/articles", params)
|
||||
return ZendeskPageResponse(
|
||||
data=data["articles"],
|
||||
meta=data["meta"],
|
||||
has_more=bool(data["meta"].get("has_more", False)),
|
||||
)
|
||||
|
||||
|
||||
def _get_tickets(
|
||||
client: ZendeskClient, start_time: int | None = None
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
params = {"start_time": start_time} if start_time else {"start_time": 0}
|
||||
params = {"start_time": start_time or 0}
|
||||
|
||||
while True:
|
||||
data = client.make_request("incremental/tickets.json", params)
|
||||
@@ -114,9 +154,33 @@ def _get_tickets(
|
||||
break
|
||||
|
||||
|
||||
def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
|
||||
# TODO: maybe these don't need to be their own functions?
|
||||
def _get_tickets_page(
|
||||
client: ZendeskClient, start_time: int | None = None
|
||||
) -> ZendeskPageResponse:
|
||||
params = {"start_time": start_time or 0}
|
||||
|
||||
# NOTE: for some reason zendesk doesn't seem to be respecting the start_time param
|
||||
# in my local testing with very few tickets. We'll look into it if this becomes an
|
||||
# issue in larger deployments
|
||||
data = client.make_request("incremental/tickets.json", params)
|
||||
if data.get("error") == "SupportProductInactive":
|
||||
raise ValueError(
|
||||
"Zendesk Support Product is not active for this account, No tickets to index"
|
||||
)
|
||||
return ZendeskPageResponse(
|
||||
data=data["tickets"],
|
||||
meta={"end_time": data["end_time"]},
|
||||
has_more=not bool(data.get("end_of_stream", False)),
|
||||
)
|
||||
|
||||
|
||||
def _fetch_author(
|
||||
client: ZendeskClient, author_id: str | int
|
||||
) -> BasicExpertInfo | None:
|
||||
# Skip fetching if author_id is invalid
|
||||
if not author_id or author_id == "-1":
|
||||
# cast to str to avoid issues with zendesk changing their types
|
||||
if not author_id or str(author_id) == "-1":
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -278,13 +342,22 @@ def _ticket_to_document(
|
||||
)
|
||||
|
||||
|
||||
class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class ZendeskConnectorCheckpoint(ConnectorCheckpoint):
|
||||
# We use cursor-based paginated retrieval for articles
|
||||
after_cursor_articles: str | None
|
||||
|
||||
# We use timestamp-based paginated retrieval for tickets
|
||||
next_start_time_tickets: int | None
|
||||
|
||||
cached_author_map: dict[str, BasicExpertInfo] | None
|
||||
cached_content_tags: dict[str, str] | None
|
||||
|
||||
|
||||
class ZendeskConnector(SlimConnector, CheckpointConnector[ZendeskConnectorCheckpoint]):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
content_type: str = "articles",
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.content_type = content_type
|
||||
self.subdomain = ""
|
||||
# Fetch all tags ahead of time
|
||||
@@ -304,33 +377,50 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self.poll_source(None, None)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
@override
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ZendeskConnectorCheckpoint,
|
||||
) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
|
||||
if self.client is None:
|
||||
raise ZendeskCredentialsNotSetUpError()
|
||||
|
||||
self.content_tags = _get_content_tag_mapping(self.client)
|
||||
if checkpoint.cached_content_tags is None:
|
||||
checkpoint.cached_content_tags = _get_content_tag_mapping(self.client)
|
||||
return checkpoint # save the content tags to the checkpoint
|
||||
self.content_tags = checkpoint.cached_content_tags
|
||||
|
||||
if self.content_type == "articles":
|
||||
yield from self._poll_articles(start)
|
||||
checkpoint = yield from self._retrieve_articles(start, end, checkpoint)
|
||||
return checkpoint
|
||||
elif self.content_type == "tickets":
|
||||
yield from self._poll_tickets(start)
|
||||
checkpoint = yield from self._retrieve_tickets(start, end, checkpoint)
|
||||
return checkpoint
|
||||
else:
|
||||
raise ValueError(f"Unsupported content_type: {self.content_type}")
|
||||
|
||||
def _poll_articles(
|
||||
self, start: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
articles = _get_articles(self.client, start_time=int(start) if start else None)
|
||||
|
||||
def _retrieve_articles(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
checkpoint: ZendeskConnectorCheckpoint,
|
||||
) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
|
||||
checkpoint = copy.deepcopy(checkpoint)
|
||||
# This one is built on the fly as there may be many more authors than tags
|
||||
author_map: dict[str, BasicExpertInfo] = {}
|
||||
author_map: dict[str, BasicExpertInfo] = checkpoint.cached_author_map or {}
|
||||
after_cursor = checkpoint.after_cursor_articles
|
||||
doc_batch: list[Document] = []
|
||||
|
||||
doc_batch = []
|
||||
response = _get_article_page(
|
||||
self.client,
|
||||
start_time=int(start) if start else None,
|
||||
after_cursor=after_cursor,
|
||||
)
|
||||
articles = response.data
|
||||
has_more = response.has_more
|
||||
after_cursor = response.meta.get("after_cursor")
|
||||
for article in articles:
|
||||
if (
|
||||
article.get("body") is None
|
||||
@@ -342,66 +432,109 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
):
|
||||
continue
|
||||
|
||||
new_author_map, documents = _article_to_document(
|
||||
article, self.content_tags, author_map, self.client
|
||||
)
|
||||
try:
|
||||
new_author_map, document = _article_to_document(
|
||||
article, self.content_tags, author_map, self.client
|
||||
)
|
||||
except Exception as e:
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=f"{article.get('id')}",
|
||||
document_link=article.get("html_url", ""),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
)
|
||||
continue
|
||||
|
||||
if new_author_map:
|
||||
author_map.update(new_author_map)
|
||||
|
||||
doc_batch.append(documents)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch.clear()
|
||||
doc_batch.append(document)
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
if not has_more:
|
||||
yield from doc_batch
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
|
||||
|
||||
def _poll_tickets(
|
||||
self, start: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
# Sometimes no documents are retrieved, but the cursor
|
||||
# is still updated so the connector makes progress.
|
||||
yield from doc_batch
|
||||
checkpoint.after_cursor_articles = after_cursor
|
||||
|
||||
last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None
|
||||
checkpoint.has_more = bool(
|
||||
end is None
|
||||
or last_doc_updated_at is None
|
||||
or last_doc_updated_at.timestamp() <= end
|
||||
)
|
||||
checkpoint.cached_author_map = (
|
||||
author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
def _retrieve_tickets(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
checkpoint: ZendeskConnectorCheckpoint,
|
||||
) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
|
||||
checkpoint = copy.deepcopy(checkpoint)
|
||||
if self.client is None:
|
||||
raise ZendeskCredentialsNotSetUpError()
|
||||
|
||||
author_map: dict[str, BasicExpertInfo] = {}
|
||||
author_map: dict[str, BasicExpertInfo] = checkpoint.cached_author_map or {}
|
||||
|
||||
ticket_generator = _get_tickets(
|
||||
self.client, start_time=int(start) if start else None
|
||||
doc_batch: list[Document] = []
|
||||
next_start_time = int(checkpoint.next_start_time_tickets or start or 0)
|
||||
ticket_response = _get_tickets_page(self.client, start_time=next_start_time)
|
||||
tickets = ticket_response.data
|
||||
has_more = ticket_response.has_more
|
||||
next_start_time = ticket_response.meta["end_time"]
|
||||
for ticket in tickets:
|
||||
if ticket.get("status") == "deleted":
|
||||
continue
|
||||
|
||||
try:
|
||||
new_author_map, document = _ticket_to_document(
|
||||
ticket=ticket,
|
||||
author_map=author_map,
|
||||
client=self.client,
|
||||
default_subdomain=self.subdomain,
|
||||
)
|
||||
except Exception as e:
|
||||
yield ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=f"{ticket.get('id')}",
|
||||
document_link=ticket.get("url", ""),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
)
|
||||
continue
|
||||
|
||||
if new_author_map:
|
||||
author_map.update(new_author_map)
|
||||
|
||||
doc_batch.append(document)
|
||||
|
||||
if not has_more:
|
||||
yield from doc_batch
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
|
||||
|
||||
yield from doc_batch
|
||||
checkpoint.next_start_time_tickets = next_start_time
|
||||
last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None
|
||||
checkpoint.has_more = bool(
|
||||
end is None
|
||||
or last_doc_updated_at is None
|
||||
or last_doc_updated_at.timestamp() <= end
|
||||
)
|
||||
|
||||
while True:
|
||||
doc_batch = []
|
||||
for _ in range(self.batch_size):
|
||||
try:
|
||||
ticket = next(ticket_generator)
|
||||
|
||||
# Check if the ticket status is deleted and skip it if so
|
||||
if ticket.get("status") == "deleted":
|
||||
continue
|
||||
|
||||
new_author_map, documents = _ticket_to_document(
|
||||
ticket=ticket,
|
||||
author_map=author_map,
|
||||
client=self.client,
|
||||
default_subdomain=self.subdomain,
|
||||
)
|
||||
|
||||
if new_author_map:
|
||||
author_map.update(new_author_map)
|
||||
|
||||
doc_batch.append(documents)
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch.clear()
|
||||
|
||||
except StopIteration:
|
||||
# No more tickets to process
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
return
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
checkpoint.cached_author_map = (
|
||||
author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
@@ -441,10 +574,51 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
if slim_doc_batch:
|
||||
yield slim_doc_batch
|
||||
|
||||
@override
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self.client is None:
|
||||
raise ZendeskCredentialsNotSetUpError()
|
||||
|
||||
try:
|
||||
_get_article_page(self.client, start_time=0)
|
||||
except HTTPError as e:
|
||||
# Check for HTTP status codes
|
||||
if e.response.status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Your Zendesk credentials appear to be invalid or expired (HTTP 401)."
|
||||
) from e
|
||||
elif e.response.status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"Your Zendesk token does not have sufficient permissions (HTTP 403)."
|
||||
) from e
|
||||
elif e.response.status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
"Zendesk resource not found (HTTP 404)."
|
||||
) from e
|
||||
else:
|
||||
raise ConnectorValidationError(
|
||||
f"Unexpected Zendesk error (status={e.response.status_code}): {e}"
|
||||
) from e
|
||||
|
||||
@override
|
||||
def validate_checkpoint_json(
|
||||
self, checkpoint_json: str
|
||||
) -> ZendeskConnectorCheckpoint:
|
||||
return ZendeskConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
@override
|
||||
def build_dummy_checkpoint(self) -> ZendeskConnectorCheckpoint:
|
||||
return ZendeskConnectorCheckpoint(
|
||||
after_cursor_articles=None,
|
||||
next_start_time_tickets=None,
|
||||
cached_author_map=None,
|
||||
cached_content_tags=None,
|
||||
has_more=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
import time
|
||||
|
||||
connector = ZendeskConnector()
|
||||
connector.load_credentials(
|
||||
@@ -457,6 +631,8 @@ if __name__ == "__main__":
|
||||
|
||||
current = time.time()
|
||||
one_day_ago = current - 24 * 60 * 60 # 1 day
|
||||
document_batches = connector.poll_source(one_day_ago, current)
|
||||
document_batches = connector.load_from_checkpoint(
|
||||
one_day_ago, current, connector.build_dummy_checkpoint()
|
||||
)
|
||||
|
||||
print(next(document_batches))
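The __main__ snippet above only pulls the first yielded item; below is a hedged sketch of driving the checkpointed connector to completion, relying only on the load_from_checkpoint / has_more behavior shown in this diff (the generator returns the updated checkpoint via StopIteration):

checkpoint = connector.build_dummy_checkpoint()
while checkpoint.has_more:
    gen = connector.load_from_checkpoint(one_day_ago, current, checkpoint)
    while True:
        try:
            item = next(gen)  # Document or ConnectorFailure
        except StopIteration as stop:
            checkpoint = stop.value  # the generator returns the next checkpoint
            break
        print(item)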
|
||||
|
||||
@@ -1089,3 +1089,20 @@ def log_agent_sub_question_results(
|
||||
db_session.commit()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def update_chat_session_updated_at_timestamp(
|
||||
chat_session_id: UUID, db_session: Session
|
||||
) -> None:
|
||||
"""
|
||||
Explicitly update the timestamp on a chat session without modifying other fields.
|
||||
This is useful when adding messages to a chat session to reflect recent activity.
|
||||
"""
|
||||
|
||||
# Direct SQL update to avoid loading the entire object if it's not already loaded
|
||||
db_session.execute(
|
||||
update(ChatSession)
|
||||
.where(ChatSession.id == chat_session_id)
|
||||
.values(time_updated=func.now())
|
||||
)
|
||||
# No commit - the caller is responsible for committing the transaction
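A hedged usage sketch (the chat_session_id and surrounding flow are illustrative): because the helper intentionally skips the commit, the caller owns the transaction boundary.

with get_session_with_current_tenant() as db_session:
    update_chat_session_updated_at_timestamp(chat_session_id, db_session)
    db_session.commit()  # caller decides when the timestamp bump becomes visible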
|
||||
|
||||
@@ -555,6 +555,28 @@ def delete_documents_by_connector_credential_pair__no_commit(
|
||||
db_session.execute(stmt)
|
||||
|
||||
|
||||
def delete_all_documents_by_connector_credential_pair__no_commit(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> None:
|
||||
"""Deletes all document by connector credential pair entries for a specific connector and credential.
|
||||
This is primarily used during connector deletion to ensure all references are removed
|
||||
before deleting the connector itself. This is crucial because connector_id is part of the
|
||||
primary key in DocumentByConnectorCredentialPair, and attempting to delete the Connector
|
||||
would otherwise try to set the foreign key to NULL, which fails for primary keys.
|
||||
|
||||
NOTE: Does not commit the transaction, this must be done by the caller.
|
||||
"""
|
||||
stmt = delete(DocumentByConnectorCredentialPair).where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
)
|
||||
db_session.execute(stmt)
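A hedged sketch of the deletion order the docstring describes; the surrounding connector-deletion flow and the cc_pair / connector objects are assumptions, not part of this diff.

delete_all_documents_by_connector_credential_pair__no_commit(
    db_session=db_session,
    connector_id=cc_pair.connector_id,
    credential_id=cc_pair.credential_id,
)
db_session.delete(connector)  # safe now: no DocumentByConnectorCredentialPair rows reference it
db_session.commit()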
|
||||
|
||||
|
||||
def delete_documents__no_commit(db_session: Session, document_ids: list[str]) -> None:
|
||||
db_session.execute(delete(DbDocument).where(DbDocument.id.in_(document_ids)))
|
||||
|
||||
|
||||
@@ -8,23 +8,31 @@ from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import contains_eager
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import Select
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.server.documents.models import ConnectorCredentialPair
|
||||
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
# Comment out unused imports that cause mypy errors
|
||||
# from onyx.auth.models import UserRole
|
||||
# from onyx.configs.constants import MAX_LAST_VALID_CHECKPOINT_AGE_SECONDS
|
||||
# from onyx.db.connector_credential_pair import ConnectorCredentialPairIdentifier
|
||||
# from onyx.db.engine import async_query_for_dms
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -201,6 +209,17 @@ def mark_attempt_in_progress(
|
||||
attempt.status = IndexingStatus.IN_PROGRESS
|
||||
attempt.time_started = index_attempt.time_started or func.now() # type: ignore
|
||||
db_session.commit()
|
||||
|
||||
# Add telemetry for index attempt status change
|
||||
optional_telemetry(
|
||||
record_type=RecordType.INDEX_ATTEMPT_STATUS,
|
||||
data={
|
||||
"index_attempt_id": index_attempt.id,
|
||||
"status": IndexingStatus.IN_PROGRESS.value,
|
||||
"cc_pair_id": index_attempt.connector_credential_pair_id,
|
||||
"search_settings_id": index_attempt.search_settings_id,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
@@ -219,6 +238,19 @@ def mark_attempt_succeeded(
|
||||
|
||||
attempt.status = IndexingStatus.SUCCESS
|
||||
db_session.commit()
|
||||
|
||||
# Add telemetry for index attempt status change
|
||||
optional_telemetry(
|
||||
record_type=RecordType.INDEX_ATTEMPT_STATUS,
|
||||
data={
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"status": IndexingStatus.SUCCESS.value,
|
||||
"cc_pair_id": attempt.connector_credential_pair_id,
|
||||
"search_settings_id": attempt.search_settings_id,
|
||||
"total_docs_indexed": attempt.total_docs_indexed,
|
||||
"new_docs_indexed": attempt.new_docs_indexed,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
@@ -237,6 +269,19 @@ def mark_attempt_partially_succeeded(
|
||||
|
||||
attempt.status = IndexingStatus.COMPLETED_WITH_ERRORS
|
||||
db_session.commit()
|
||||
|
||||
# Add telemetry for index attempt status change
|
||||
optional_telemetry(
|
||||
record_type=RecordType.INDEX_ATTEMPT_STATUS,
|
||||
data={
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"status": IndexingStatus.COMPLETED_WITH_ERRORS.value,
|
||||
"cc_pair_id": attempt.connector_credential_pair_id,
|
||||
"search_settings_id": attempt.search_settings_id,
|
||||
"total_docs_indexed": attempt.total_docs_indexed,
|
||||
"new_docs_indexed": attempt.new_docs_indexed,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
@@ -259,6 +304,20 @@ def mark_attempt_canceled(
|
||||
attempt.status = IndexingStatus.CANCELED
|
||||
attempt.error_msg = reason
|
||||
db_session.commit()
|
||||
|
||||
# Add telemetry for index attempt status change
|
||||
optional_telemetry(
|
||||
record_type=RecordType.INDEX_ATTEMPT_STATUS,
|
||||
data={
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"status": IndexingStatus.CANCELED.value,
|
||||
"cc_pair_id": attempt.connector_credential_pair_id,
|
||||
"search_settings_id": attempt.search_settings_id,
|
||||
"reason": reason,
|
||||
"total_docs_indexed": attempt.total_docs_indexed,
|
||||
"new_docs_indexed": attempt.new_docs_indexed,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
@@ -283,6 +342,20 @@ def mark_attempt_failed(
|
||||
attempt.error_msg = failure_reason
|
||||
attempt.full_exception_trace = full_exception_trace
|
||||
db_session.commit()
|
||||
|
||||
# Add telemetry for index attempt status change
|
||||
optional_telemetry(
|
||||
record_type=RecordType.INDEX_ATTEMPT_STATUS,
|
||||
data={
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"status": IndexingStatus.FAILED.value,
|
||||
"cc_pair_id": attempt.connector_credential_pair_id,
|
||||
"search_settings_id": attempt.search_settings_id,
|
||||
"reason": failure_reason,
|
||||
"total_docs_indexed": attempt.total_docs_indexed,
|
||||
"new_docs_indexed": attempt.new_docs_indexed,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
@@ -434,7 +507,7 @@ def get_latest_index_attempts_parallel(
|
||||
eager_load_cc_pair: bool = False,
|
||||
only_finished: bool = False,
|
||||
) -> Sequence[IndexAttempt]:
|
||||
with get_session_context_manager() as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
return get_latest_index_attempts(
|
||||
secondary_index,
|
||||
db_session,
|
||||
|
||||
@@ -24,7 +24,9 @@ from onyx.db.models import User__UserGroup
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
|
||||
def validate_user_role_update(requested_role: UserRole, current_role: UserRole) -> None:
|
||||
def validate_user_role_update(
|
||||
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Validate that a user role update is valid.
|
||||
Assumed only admins can hit this endpoint.
|
||||
@@ -57,6 +59,9 @@ def validate_user_role_update(requested_role: UserRole, current_role: UserRole)
|
||||
detail="To change a Limited User's role, they must first login to Onyx via the web app.",
|
||||
)
|
||||
|
||||
if explicit_override:
|
||||
return
|
||||
|
||||
if requested_role == UserRole.CURATOR:
|
||||
# This shouldn't happen, but just in case
|
||||
raise HTTPException(
|
||||
|
||||
@@ -5,13 +5,15 @@ import re
|
||||
import zipfile
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from email.parser import Parser as EmailParser
|
||||
from enum import auto
|
||||
from enum import IntFlag
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import IO
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from typing import NamedTuple
|
||||
|
||||
import chardet
|
||||
import docx # type: ignore
|
||||
@@ -35,7 +37,7 @@ logger = setup_logger()
|
||||
|
||||
TEXT_SECTION_SEPARATOR = "\n\n"
|
||||
|
||||
PLAIN_TEXT_FILE_EXTENSIONS = [
|
||||
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS = [
|
||||
".txt",
|
||||
".md",
|
||||
".mdx",
|
||||
@@ -49,7 +51,7 @@ PLAIN_TEXT_FILE_EXTENSIONS = [
|
||||
".yaml",
|
||||
]
|
||||
|
||||
VALID_FILE_EXTENSIONS = PLAIN_TEXT_FILE_EXTENSIONS + [
|
||||
ACCEPTED_DOCUMENT_FILE_EXTENSIONS = [
|
||||
".pdf",
|
||||
".docx",
|
||||
".pptx",
|
||||
@@ -57,12 +59,21 @@ VALID_FILE_EXTENSIONS = PLAIN_TEXT_FILE_EXTENSIONS + [
|
||||
".eml",
|
||||
".epub",
|
||||
".html",
|
||||
]
|
||||
|
||||
ACCEPTED_IMAGE_FILE_EXTENSIONS = [
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".webp",
|
||||
]
|
||||
|
||||
ALL_ACCEPTED_FILE_EXTENSIONS = (
|
||||
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
|
||||
+ ACCEPTED_DOCUMENT_FILE_EXTENSIONS
|
||||
+ ACCEPTED_IMAGE_FILE_EXTENSIONS
|
||||
)
|
||||
|
||||
IMAGE_MEDIA_TYPES = [
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
@@ -70,8 +81,15 @@ IMAGE_MEDIA_TYPES = [
|
||||
]
|
||||
|
||||
|
||||
class OnyxExtensionType(IntFlag):
|
||||
Plain = auto()
|
||||
Document = auto()
|
||||
Multimedia = auto()
|
||||
All = Plain | Document | Multimedia
|
||||
|
||||
|
||||
def is_text_file_extension(file_name: str) -> bool:
|
||||
return any(file_name.endswith(ext) for ext in PLAIN_TEXT_FILE_EXTENSIONS)
|
||||
return any(file_name.endswith(ext) for ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS)
|
||||
|
||||
|
||||
def get_file_ext(file_path_or_name: str | Path) -> str:
|
||||
@@ -83,8 +101,20 @@ def is_valid_media_type(media_type: str) -> bool:
|
||||
return media_type in IMAGE_MEDIA_TYPES
|
||||
|
||||
|
||||
def is_valid_file_ext(ext: str) -> bool:
|
||||
return ext in VALID_FILE_EXTENSIONS
|
||||
def is_accepted_file_ext(ext: str, ext_type: OnyxExtensionType) -> bool:
|
||||
if ext_type & OnyxExtensionType.Plain:
|
||||
if ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS:
|
||||
return True
|
||||
|
||||
if ext_type & OnyxExtensionType.Document:
|
||||
if ext in ACCEPTED_DOCUMENT_FILE_EXTENSIONS:
|
||||
return True
|
||||
|
||||
if ext_type & OnyxExtensionType.Multimedia:
|
||||
if ext in ACCEPTED_IMAGE_FILE_EXTENSIONS:
|
||||
return True
|
||||
|
||||
return False
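A minimal usage sketch for the new flag-based check (not part of the diff); IntFlag members combine with |, so callers can accept several categories in one call:

assert is_accepted_file_ext(".md", OnyxExtensionType.Plain)
assert is_accepted_file_ext(".pdf", OnyxExtensionType.Plain | OnyxExtensionType.Document)
assert not is_accepted_file_ext(".png", OnyxExtensionType.Document)
assert is_accepted_file_ext(".png", OnyxExtensionType.All)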
|
||||
|
||||
|
||||
def is_text_file(file: IO[bytes]) -> bool:
|
||||
@@ -219,7 +249,7 @@ def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str:
|
||||
|
||||
def read_pdf_file(
|
||||
file: IO[Any], pdf_pass: str | None = None, extract_images: bool = False
|
||||
) -> tuple[str, dict, list[tuple[bytes, str]]]:
|
||||
) -> tuple[str, dict[str, Any], Sequence[tuple[bytes, str]]]:
|
||||
"""
|
||||
Returns the text, basic PDF metadata, and optionally extracted images.
|
||||
"""
|
||||
@@ -282,13 +312,13 @@ def read_pdf_file(
|
||||
|
||||
def docx_to_text_and_images(
|
||||
file: IO[Any],
|
||||
) -> Tuple[str, List[Tuple[bytes, str]]]:
|
||||
) -> tuple[str, Sequence[tuple[bytes, str]]]:
|
||||
"""
|
||||
Extract text from a docx. If embed_images=True, also extract inline images.
|
||||
Return (text_content, list_of_images).
|
||||
"""
|
||||
paragraphs = []
|
||||
embedded_images: List[Tuple[bytes, str]] = []
|
||||
embedded_images: list[tuple[bytes, str]] = []
|
||||
|
||||
doc = docx.Document(file)
|
||||
|
||||
@@ -382,6 +412,9 @@ def extract_file_text(
|
||||
"""
|
||||
Legacy function that returns *only text*, ignoring embedded images.
|
||||
For backward-compatibility in code that only wants text.
|
||||
|
||||
NOTE: Ignoring seems to be defined as returning an empty string for files it can't
|
||||
handle (such as images).
|
||||
"""
|
||||
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
|
||||
".pdf": pdf_to_text,
|
||||
@@ -405,7 +438,9 @@ def extract_file_text(
|
||||
if extension is None:
|
||||
extension = get_file_ext(file_name)
|
||||
|
||||
if is_valid_file_ext(extension):
|
||||
if is_accepted_file_ext(
|
||||
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
|
||||
):
|
||||
func = extension_to_function.get(extension, file_io_to_text)
|
||||
file.seek(0)
|
||||
return func(file)
|
||||
@@ -426,14 +461,22 @@ def extract_file_text(
|
||||
return ""
|
||||
|
||||
|
||||
class ExtractionResult(NamedTuple):
|
||||
"""Structured result from text and image extraction from various file types."""
|
||||
|
||||
text_content: str
|
||||
embedded_images: Sequence[tuple[bytes, str]]
|
||||
metadata: dict[str, Any]
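A hedged sketch of how callers consume the named fields of the new result type (the file name and handle are illustrative):

with open("report.pdf", "rb") as f:
    result = extract_text_and_images(f, "report.pdf")
print(len(result.text_content), len(result.embedded_images), result.metadata)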
|
||||
|
||||
|
||||
def extract_text_and_images(
|
||||
file: IO[Any],
|
||||
file_name: str,
|
||||
pdf_pass: str | None = None,
|
||||
) -> Tuple[str, List[Tuple[bytes, str]]]:
|
||||
) -> ExtractionResult:
|
||||
"""
|
||||
Primary new function for the updated connector.
|
||||
Returns (text_content, [(embedded_img_bytes, embedded_img_name), ...]).
|
||||
Returns structured extraction result with text content, embedded images, and metadata.
|
||||
"""
|
||||
|
||||
try:
|
||||
@@ -442,7 +485,9 @@ def extract_text_and_images(
|
||||
# If the user doesn't want embedded images, unstructured is fine
|
||||
file.seek(0)
|
||||
text_content = unstructured_to_text(file, file_name)
|
||||
return (text_content, [])
|
||||
return ExtractionResult(
|
||||
text_content=text_content, embedded_images=[], metadata={}
|
||||
)
|
||||
|
||||
extension = get_file_ext(file_name)
|
||||
|
||||
@@ -450,54 +495,76 @@ def extract_text_and_images(
|
||||
if extension == ".docx":
|
||||
file.seek(0)
|
||||
text_content, images = docx_to_text_and_images(file)
|
||||
return (text_content, images)
|
||||
return ExtractionResult(
|
||||
text_content=text_content, embedded_images=images, metadata={}
|
||||
)
|
||||
|
||||
# PDF example: we do not show complicated PDF image extraction here
|
||||
# so we simply extract text for now and skip images.
|
||||
if extension == ".pdf":
|
||||
file.seek(0)
|
||||
text_content, _, images = read_pdf_file(file, pdf_pass, extract_images=True)
|
||||
return (text_content, images)
|
||||
text_content, pdf_metadata, images = read_pdf_file(
|
||||
file, pdf_pass, extract_images=True
|
||||
)
|
||||
return ExtractionResult(
|
||||
text_content=text_content, embedded_images=images, metadata=pdf_metadata
|
||||
)
|
||||
|
||||
# For PPTX, XLSX, EML, etc., we do not show embedded image logic here.
|
||||
# You can do something similar to docx if needed.
|
||||
if extension == ".pptx":
|
||||
file.seek(0)
|
||||
return (pptx_to_text(file), [])
|
||||
return ExtractionResult(
|
||||
text_content=pptx_to_text(file), embedded_images=[], metadata={}
|
||||
)
|
||||
|
||||
if extension == ".xlsx":
|
||||
file.seek(0)
|
||||
return (xlsx_to_text(file), [])
|
||||
return ExtractionResult(
|
||||
text_content=xlsx_to_text(file), embedded_images=[], metadata={}
|
||||
)
|
||||
|
||||
if extension == ".eml":
|
||||
file.seek(0)
|
||||
return (eml_to_text(file), [])
|
||||
return ExtractionResult(
|
||||
text_content=eml_to_text(file), embedded_images=[], metadata={}
|
||||
)
|
||||
|
||||
if extension == ".epub":
|
||||
file.seek(0)
|
||||
return (epub_to_text(file), [])
|
||||
return ExtractionResult(
|
||||
text_content=epub_to_text(file), embedded_images=[], metadata={}
|
||||
)
|
||||
|
||||
if extension == ".html":
|
||||
file.seek(0)
|
||||
return (parse_html_page_basic(file), [])
|
||||
return ExtractionResult(
|
||||
text_content=parse_html_page_basic(file),
|
||||
embedded_images=[],
|
||||
metadata={},
|
||||
)
|
||||
|
||||
# If we reach here and it's a recognized text extension
|
||||
if is_text_file_extension(file_name):
|
||||
file.seek(0)
|
||||
encoding = detect_encoding(file)
|
||||
text_content_raw, _ = read_text_file(
|
||||
text_content_raw, file_metadata = read_text_file(
|
||||
file, encoding=encoding, ignore_onyx_metadata=False
|
||||
)
|
||||
return (text_content_raw, [])
|
||||
return ExtractionResult(
|
||||
text_content=text_content_raw,
|
||||
embedded_images=[],
|
||||
metadata=file_metadata,
|
||||
)
|
||||
|
||||
# If it's an image file or something else, we do not parse embedded images from them
|
||||
# just return empty text
|
||||
file.seek(0)
|
||||
return ("", [])
|
||||
return ExtractionResult(text_content="", embedded_images=[], metadata={})
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to extract text/images from {file_name}: {e}")
|
||||
return ("", [])
|
||||
return ExtractionResult(text_content="", embedded_images=[], metadata={})
|
||||
|
||||
|
||||
def convert_docx_to_txt(
|
||||
|
||||
@@ -15,6 +15,7 @@ EXCLUDED_IMAGE_TYPES = [
|
||||
"image/tiff",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
"image/avif",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -361,7 +361,15 @@ def get_application() -> FastAPI:
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
# For Google OAuth, refresh tokens are requested by:
|
||||
# 1. Adding the right scopes
|
||||
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
|
||||
oauth_client = GoogleOAuth2(
|
||||
OAUTH_CLIENT_ID,
|
||||
OAUTH_CLIENT_SECRET,
|
||||
# Use standard scopes that include profile and email
|
||||
scopes=["openid", "email", "profile"],
|
||||
)
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
@@ -383,6 +391,13 @@ def get_application() -> FastAPI:
|
||||
prefix="/auth",
|
||||
)
|
||||
|
||||
# Add refresh token endpoint for OAuth as well
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_refresh_router(auth_backend),
|
||||
prefix="/auth",
|
||||
)
|
||||
|
||||
application.add_exception_handler(
|
||||
RequestValidationError, validation_exception_handler
|
||||
)
|
||||
|
||||
@@ -15,7 +15,6 @@ from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import SearchFeedbackType
|
||||
from onyx.configs.onyxbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||
from onyx.connectors.slack.utils import expert_info_from_slack_id
|
||||
from onyx.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.db.chat import get_chat_message
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
@@ -553,8 +552,7 @@ def handle_followup_resolved_button(
|
||||
|
||||
# Delete the message with the option to mark resolved
|
||||
if not immediate:
|
||||
slack_call = make_slack_api_rate_limited(client.web_client.chat_delete)
|
||||
response = slack_call(
|
||||
response = client.web_client.chat_delete(
|
||||
channel=channel_id,
|
||||
ts=message_ts,
|
||||
)
|
||||
|
||||
@@ -170,7 +170,8 @@ def handle_message(
|
||||
respond_tag_only = channel_conf.get("respond_tag_only") or False
|
||||
respond_member_group_list = channel_conf.get("respond_member_group_list", None)
|
||||
|
||||
if respond_tag_only and not bypass_filters:
|
||||
# NOTE: always respond in DMs, as long as the default config is not disabled.
|
||||
if respond_tag_only and not bypass_filters and not is_bot_dm:
|
||||
logger.info(
|
||||
"Skipping message since the channel is configured such that "
|
||||
"OnyxBot only responds to tags"
|
||||
|
||||
@@ -18,6 +18,9 @@ from prometheus_client import start_http_server
|
||||
from redis.lock import Lock
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.http_retry import ConnectionErrorRetryHandler
|
||||
from slack_sdk.http_retry import RateLimitErrorRetryHandler
|
||||
from slack_sdk.http_retry import RetryHandler
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -944,10 +947,21 @@ def _get_socket_client(
|
||||
) -> TenantSocketModeClient:
|
||||
# For more info on how to set this up, checkout the docs:
|
||||
# https://docs.onyx.app/slack_bot_setup
|
||||
|
||||
# use the retry handlers built into the slack sdk
|
||||
connection_error_retry_handler = ConnectionErrorRetryHandler()
|
||||
rate_limit_error_retry_handler = RateLimitErrorRetryHandler(max_retry_count=7)
|
||||
slack_retry_handlers: list[RetryHandler] = [
|
||||
connection_error_retry_handler,
|
||||
rate_limit_error_retry_handler,
|
||||
]
|
||||
|
||||
return TenantSocketModeClient(
|
||||
# This app-level token will be used only for establishing a connection
|
||||
app_token=slack_bot_tokens.app_token,
|
||||
web_client=WebClient(token=slack_bot_tokens.bot_token),
|
||||
web_client=WebClient(
|
||||
token=slack_bot_tokens.bot_token, retry_handlers=slack_retry_handlers
|
||||
),
|
||||
tenant_id=tenant_id,
|
||||
slack_bot_id=slack_bot_id,
|
||||
)
|
||||
|
||||
@@ -30,7 +30,6 @@ from onyx.configs.onyxbot_configs import (
|
||||
from onyx.configs.onyxbot_configs import (
|
||||
DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS,
|
||||
)
|
||||
from onyx.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from onyx.connectors.slack.utils import SlackTextCleaner
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.users import get_user_by_email
|
||||
@@ -125,13 +124,18 @@ def update_emote_react(
|
||||
)
|
||||
return
|
||||
|
||||
func = client.reactions_remove if remove else client.reactions_add
|
||||
slack_call = make_slack_api_rate_limited(func) # type: ignore
|
||||
slack_call(
|
||||
name=emoji,
|
||||
channel=channel,
|
||||
timestamp=message_ts,
|
||||
)
|
||||
if remove:
|
||||
client.reactions_remove(
|
||||
name=emoji,
|
||||
channel=channel,
|
||||
timestamp=message_ts,
|
||||
)
|
||||
else:
|
||||
client.reactions_add(
|
||||
name=emoji,
|
||||
channel=channel,
|
||||
timestamp=message_ts,
|
||||
)
|
||||
except SlackApiError as e:
|
||||
if remove:
|
||||
logger.error(f"Failed to remove Reaction due to: {e}")
|
||||
@@ -200,9 +204,8 @@ def respond_in_thread_or_channel(
|
||||
|
||||
message_ids: list[str] = []
|
||||
if not receiver_ids:
|
||||
slack_call = make_slack_api_rate_limited(client.chat_postMessage)
|
||||
try:
|
||||
response = slack_call(
|
||||
response = client.chat_postMessage(
|
||||
channel=channel,
|
||||
text=text,
|
||||
blocks=blocks,
|
||||
@@ -224,7 +227,7 @@ def respond_in_thread_or_channel(
|
||||
blocks_without_urls.append(_build_error_block(str(e)))
|
||||
|
||||
# Try again without blocks containing URLs
|
||||
response = slack_call(
|
||||
response = client.chat_postMessage(
|
||||
channel=channel,
|
||||
text=text,
|
||||
blocks=blocks_without_urls,
|
||||
@@ -236,11 +239,9 @@ def respond_in_thread_or_channel(
|
||||
|
||||
message_ids.append(response["message_ts"])
|
||||
else:
|
||||
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
|
||||
|
||||
for receiver in receiver_ids:
|
||||
try:
|
||||
response = slack_call(
|
||||
response = client.chat_postEphemeral(
|
||||
channel=channel,
|
||||
user=receiver,
|
||||
text=text,
|
||||
@@ -263,7 +264,7 @@ def respond_in_thread_or_channel(
|
||||
blocks_without_urls.append(_build_error_block(str(e)))
|
||||
|
||||
# Try again without blocks containing URLs
|
||||
response = slack_call(
|
||||
response = client.chat_postEphemeral(
|
||||
channel=channel,
|
||||
user=receiver,
|
||||
text=text,
|
||||
@@ -500,7 +501,7 @@ def fetch_user_semantic_id_from_id(
|
||||
if not user_id:
|
||||
return None
|
||||
|
||||
response = make_slack_api_rate_limited(client.users_info)(user=user_id)
|
||||
response = client.users_info(user=user_id)
|
||||
if not response["ok"]:
|
||||
return None
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ PUBLIC_ENDPOINT_SPECS = [
|
||||
# just gets the version of Onyx (e.g. 0.3.11)
|
||||
("/version", {"GET"}),
|
||||
# stuff related to basic auth
|
||||
("/auth/refresh", {"POST"}),
|
||||
("/auth/register", {"POST"}),
|
||||
("/auth/login", {"POST"}),
|
||||
("/auth/logout", {"POST"}),
|
||||
|
||||
@@ -132,6 +132,7 @@ class UserByEmail(BaseModel):
|
||||
class UserRoleUpdateRequest(BaseModel):
|
||||
user_email: str
|
||||
new_role: UserRole
|
||||
explicit_override: bool = False
|
||||
|
||||
|
||||
class UserRoleResponse(BaseModel):
|
||||
|
||||
@@ -261,9 +261,6 @@ def create_bot(
|
||||
# Create a default Slack channel config
|
||||
default_channel_config = ChannelConfig(
|
||||
channel_name=None,
|
||||
respond_member_group_list=[],
|
||||
answer_filters=[],
|
||||
follow_up_tags=[],
|
||||
respond_tag_only=True,
|
||||
)
|
||||
insert_slack_channel_config(
|
||||
@@ -371,7 +368,9 @@ def get_all_channels_from_slack_api(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> list[SlackChannel]:
|
||||
"""
|
||||
Fetches channels the bot is a member of from the Slack API.
|
||||
Fetches all channels in the Slack workspace using the conversations_list API.
|
||||
This includes both public and private channels that are visible to the app,
|
||||
not just the ones the bot is a member of.
|
||||
Handles pagination with a limit to avoid excessive API calls.
|
||||
"""
|
||||
tokens = fetch_slack_bot_tokens(db_session, bot_id)
|
||||
@@ -386,20 +385,20 @@ def get_all_channels_from_slack_api(
|
||||
current_page = 0
|
||||
|
||||
try:
|
||||
# Use users_conversations with limited pagination
|
||||
# Use conversations_list to get all channels in the workspace (including ones the bot is not a member of)
|
||||
while current_page < MAX_SLACK_PAGES:
|
||||
current_page += 1
|
||||
|
||||
# Make API call with cursor if we have one
|
||||
if next_cursor:
|
||||
response = client.users_conversations(
|
||||
response = client.conversations_list(
|
||||
types="public_channel,private_channel",
|
||||
exclude_archived=True,
|
||||
cursor=next_cursor,
|
||||
limit=SLACK_API_CHANNELS_PER_PAGE,
|
||||
)
|
||||
else:
|
||||
response = client.users_conversations(
|
||||
response = client.conversations_list(
|
||||
types="public_channel,private_channel",
|
||||
exclude_archived=True,
|
||||
limit=SLACK_API_CHANNELS_PER_PAGE,
|
||||
|
||||
@@ -102,6 +102,7 @@ def set_user_role(
|
||||
validate_user_role_update(
|
||||
requested_role=requested_role,
|
||||
current_role=current_role,
|
||||
explicit_override=user_role_update_request.explicit_override,
|
||||
)
|
||||
|
||||
if user_to_update.id == current_user.id:
|
||||
@@ -122,6 +123,22 @@ def set_user_role(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
class TestUpsertRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
@router.post("/manage/users/test-upsert-user")
|
||||
async def test_upsert_user(
|
||||
request: TestUpsertRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> None | FullUserSnapshot:
|
||||
"""Test endpoint for upsert_saml_user. Only used for integration testing."""
|
||||
user = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.saml", "upsert_saml_user", None
|
||||
)(email=request.email)
|
||||
return FullUserSnapshot.from_user_model(user) if user else None
|
||||
|
||||
|
||||
@router.get("/manage/users/accepted")
|
||||
def list_accepted_users(
|
||||
q: str | None = Query(default=None),
|
||||
@@ -296,7 +313,7 @@ def bulk_invite_users(
|
||||
detail=f"Invalid email address: {email} - {str(e)}",
|
||||
)
|
||||
|
||||
if MULTI_TENANT and not DEV_MODE:
|
||||
if MULTI_TENANT:
|
||||
try:
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning", "add_users_to_tenant", None
|
||||
@@ -318,7 +335,7 @@ def bulk_invite_users(
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending email invite to invited users: {e}")
|
||||
|
||||
if not MULTI_TENANT:
|
||||
if not MULTI_TENANT or DEV_MODE:
|
||||
return number_of_invited_users
|
||||
|
||||
# for billing purposes, write to the control plane about the number of new users
|
||||
@@ -359,7 +376,7 @@ def remove_invited_user(
|
||||
number_of_invited_users = write_invited_users(remaining_users)
|
||||
|
||||
try:
|
||||
if MULTI_TENANT:
|
||||
if MULTI_TENANT and not DEV_MODE:
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.billing", "register_tenant_users", None
|
||||
)(tenant_id, get_total_users_count(db_session))
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
import io
|
||||
from typing import cast
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT,
|
||||
)
|
||||
from onyx.configs.constants import CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT
|
||||
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import ONYX_EMAILABLE_LOGO_MAX_DIM
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.file_store.file_store import PostgresBackedFileStore
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.utils.file import FileWithMimeType
|
||||
from onyx.utils.file import OnyxStaticFileManager
|
||||
from onyx.utils.variable_functionality import (
|
||||
@@ -87,3 +96,72 @@ class OnyxRuntime:
|
||||
)
|
||||
|
||||
return OnyxRuntime._get_with_static_fallback(db_filename, STATIC_FILENAME)
|
||||
|
||||
@staticmethod
|
||||
def get_beat_multiplier() -> float:
|
||||
"""the beat multiplier is used to scale up or down the frequency of certain beat
|
||||
tasks in the cloud. It has a significant effect on load and is useful to adjust
|
||||
in real time."""
|
||||
|
||||
beat_multiplier: float = CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
|
||||
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
|
||||
if beat_multiplier_raw is not None:
|
||||
try:
|
||||
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
|
||||
beat_multiplier = float(beat_multiplier_bytes.decode())
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if beat_multiplier <= 0.0:
|
||||
return 1.0
|
||||
|
||||
return beat_multiplier
|
||||
|
||||
@staticmethod
|
||||
def get_doc_permission_sync_multiplier() -> float:
|
||||
"""Permission syncs are a significant source of load / queueing in the cloud."""
|
||||
|
||||
value: float = CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT
|
||||
|
||||
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
value_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:doc_permission_sync_multiplier")
|
||||
if value_raw is not None:
|
||||
try:
|
||||
value_bytes = cast(bytes, value_raw)
|
||||
value = float(value_bytes.decode())
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if value <= 0.0:
|
||||
return 1.0
|
||||
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def get_build_fence_lookup_table_interval() -> int:
|
||||
"""We maintain an active fence table to make lookups of existing fences efficient.
|
||||
However, reconstructing the table is expensive, so adjusting it in realtime is useful.
|
||||
"""
|
||||
|
||||
interval: int = CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT
|
||||
|
||||
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
interval_raw = r.get(
|
||||
f"{ONYX_CLOUD_REDIS_RUNTIME}:build_fence_lookup_table_interval"
|
||||
)
|
||||
if interval_raw is not None:
|
||||
try:
|
||||
interval_bytes = cast(bytes, interval_raw)
|
||||
interval = int(interval_bytes.decode())
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if interval <= 0.0:
|
||||
return CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT
|
||||
|
||||
return interval
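A hedged operational sketch (not from the diff): the getters above read plain string values, so tuning comes down to writing the corresponding keys. The write-capable redis-py client and connection details below are assumptions; only the replica read path appears in this change.

import redis  # plain redis-py client; host/port are illustrative

from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME

r = redis.Redis(host="localhost", port=6379)
r.set(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier", "2.0")
r.set(f"{ONYX_CLOUD_REDIS_RUNTIME}:doc_permission_sync_multiplier", "0.5")
r.set(f"{ONYX_CLOUD_REDIS_RUNTIME}:build_fence_lookup_table_interval", "600")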
|
||||
|
||||
@@ -36,6 +36,10 @@ class RecordType(str, Enum):
|
||||
LATENCY = "latency"
|
||||
FAILURE = "failure"
|
||||
METRIC = "metric"
|
||||
INDEXING_PROGRESS = "indexing_progress"
|
||||
INDEXING_COMPLETE = "indexing_complete"
|
||||
PERMISSION_SYNC_PROGRESS = "permission_sync_progress"
|
||||
INDEX_ATTEMPT_STATUS = "index_attempt_status"
|
||||
|
||||
|
||||
def _get_or_generate_customer_id_mt(tenant_id: str) -> str:
|
||||
|
||||
@@ -6,14 +6,17 @@ import uuid
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import MutableMapping
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import FIRST_COMPLETED
|
||||
from concurrent.futures import Future
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import wait
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Generic
|
||||
from typing import overload
|
||||
from typing import Protocol
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
@@ -145,13 +148,20 @@ class ThreadSafeDict(MutableMapping[KT, VT]):
|
||||
return collections.abc.ValuesView(self)
|
||||
|
||||
|
||||
class CallableProtocol(Protocol):
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
...
|
||||
|
||||
|
||||
def run_functions_tuples_in_parallel(
|
||||
functions_with_args: list[tuple[Callable, tuple]],
|
||||
functions_with_args: Sequence[tuple[CallableProtocol, tuple[Any, ...]]],
|
||||
allow_failures: bool = False,
|
||||
max_workers: int | None = None,
|
||||
) -> list[Any]:
|
||||
"""
|
||||
Executes multiple functions in parallel and returns a list of the results for each function.
|
||||
This function preserves contextvars across threads, which is important for maintaining
|
||||
context like tenant IDs in database sessions.
|
||||
|
||||
Args:
|
||||
functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
|
||||
@@ -159,7 +169,7 @@ def run_functions_tuples_in_parallel(
|
||||
max_workers: Max number of worker threads
|
||||
|
||||
Returns:
|
||||
dict: A dictionary mapping function names to their results or error messages.
|
||||
list: A list of results from each function, in the same order as the input functions.
|
||||
"""
|
||||
workers = (
|
||||
min(max_workers, len(functions_with_args))
|
||||
@@ -186,7 +196,7 @@ def run_functions_tuples_in_parallel(
|
||||
results.append((index, future.result()))
|
||||
except Exception as e:
|
||||
logger.exception(f"Function at index {index} failed due to {e}")
|
||||
results.append((index, None))
|
||||
results.append((index, None)) # type: ignore
|
||||
|
||||
if not allow_failures:
|
||||
raise
|
||||
@@ -288,7 +298,7 @@ def run_with_timeout(
|
||||
if task.is_alive():
|
||||
task.end()
|
||||
|
||||
return task.result
|
||||
return task.result # type: ignore
|
||||
|
||||
|
||||
# NOTE: this function should really only be used when run_functions_tuples_in_parallel is
|
||||
@@ -304,9 +314,9 @@ def run_in_background(
|
||||
"""
|
||||
context = contextvars.copy_context()
|
||||
# Timeout not used in the non-blocking case
|
||||
task = TimeoutThread(-1, context.run, func, *args, **kwargs)
|
||||
task = TimeoutThread(-1, context.run, func, *args, **kwargs) # type: ignore
|
||||
task.start()
|
||||
return task
|
||||
return cast(TimeoutThread[R], task)
|
||||
|
||||
|
||||
def wait_on_background(task: TimeoutThread[R]) -> R:
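A hedged sketch tying the helpers in this file together (signatures inferred from the hunks above):

def add(a: int, b: int) -> int:
    return a + b

# results come back in the same order as the input (func, args) tuples
results = run_functions_tuples_in_parallel([(add, (1, 2)), (add, (3, 4))])
assert results == [3, 7]

# fire a single call in the background and collect it later
task = run_in_background(add, 5, 6)
assert wait_on_background(task) == 11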
|
||||
|
||||
@@ -56,7 +56,7 @@ puremagic==1.28
|
||||
pyairtable==3.0.1
|
||||
pycryptodome==3.19.1
|
||||
pydantic==2.8.2
|
||||
PyGithub==1.58.2
|
||||
PyGithub==2.5.0
|
||||
python-dateutil==2.8.2
|
||||
python-gitlab==3.9.0
|
||||
python-pptx==0.6.23
|
||||
|
||||
backend/tests/daily/connectors/blob/test_blob_connector.py (new file, 77 lines)
@@ -0,0 +1,77 @@
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import BlobType
|
||||
from onyx.connectors.blob.connector import BlobStorageConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import ACCEPTED_DOCUMENT_FILE_EXTENSIONS
|
||||
from onyx.file_processing.extract_file_text import ACCEPTED_IMAGE_FILE_EXTENSIONS
|
||||
from onyx.file_processing.extract_file_text import ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def blob_connector(request: pytest.FixtureRequest) -> BlobStorageConnector:
|
||||
connector = BlobStorageConnector(
|
||||
bucket_type=BlobType.S3, bucket_name="onyx-connector-tests"
|
||||
)
|
||||
|
||||
connector.load_credentials(
|
||||
{
|
||||
"aws_access_key_id": os.environ["AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS"],
|
||||
"aws_secret_access_key": os.environ[
|
||||
"AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS"
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
return connector
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
|
||||
return_value=None,
|
||||
)
|
||||
def test_blob_s3_connector(
|
||||
mock_get_api_key: MagicMock, blob_connector: BlobStorageConnector
|
||||
) -> None:
|
||||
"""
|
||||
Plain and document file types should be fully indexed.
|
||||
|
||||
Multimedia and unknown file types will be indexed by title only with one empty section.
|
||||
|
||||
This is intentional in order to allow searching by just the title even if we can't
|
||||
index the file content.
|
||||
"""
|
||||
all_docs: list[Document] = []
|
||||
document_batches = blob_connector.load_from_state()
|
||||
for doc_batch in document_batches:
|
||||
for doc in doc_batch:
|
||||
all_docs.append(doc)
|
||||
|
||||
#
|
||||
assert len(all_docs) == 19
|
||||
|
||||
for doc in all_docs:
|
||||
section = doc.sections[0]
|
||||
assert isinstance(section, TextSection)
|
||||
|
||||
file_extension = get_file_ext(doc.semantic_identifier)
|
||||
if file_extension in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS:
|
||||
assert len(section.text) > 0
|
||||
continue
|
||||
|
||||
if file_extension in ACCEPTED_DOCUMENT_FILE_EXTENSIONS:
|
||||
assert len(section.text) > 0
|
||||
continue
|
||||
|
||||
if file_extension in ACCEPTED_IMAGE_FILE_EXTENSIONS:
|
||||
assert len(section.text) == 0
|
||||
continue
|
||||
|
||||
# unknown extension
|
||||
assert len(section.text) == 0
|
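The assertions above branch purely on file extension. A small helper distilling that rule, built only from the names the test already imports (the helper itself is illustrative and not part of the change):

from onyx.file_processing.extract_file_text import ACCEPTED_DOCUMENT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import get_file_ext

def expects_indexed_text(file_name: str) -> bool:
    # True when the blob connector is expected to index the file body,
    # False when only the title is indexed (images and unknown types).
    ext = get_file_ext(file_name)
    return (
        ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
        or ext in ACCEPTED_DOCUMENT_FILE_EXTENSIONS
    )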
||||
backend/tests/daily/connectors/github/test_github_basic.py (new file, 54 lines)
@@ -0,0 +1,54 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_connector() -> GithubConnector:
|
||||
connector = GithubConnector(
|
||||
repo_owner="onyx-dot-app",
|
||||
repositories="documentation",
|
||||
include_prs=True,
|
||||
include_issues=True,
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"],
|
||||
}
|
||||
)
|
||||
return connector
|
||||
|
||||
|
||||
def test_github_connector_basic(github_connector: GithubConnector) -> None:
|
||||
docs = load_all_docs_from_checkpoint_connector(
|
||||
connector=github_connector,
|
||||
start=0,
|
||||
end=time.time(),
|
||||
)
|
||||
assert len(docs) > 0 # We expect at least one PR to exist
|
||||
|
||||
# Test the first document's structure
|
||||
doc = docs[0]
|
||||
|
||||
# Verify basic document properties
|
||||
assert doc.source == DocumentSource.GITHUB
|
||||
assert doc.secondary_owners is None
|
||||
assert doc.from_ingestion_api is False
|
||||
assert doc.additional_info is None
|
||||
|
||||
# Verify GitHub-specific properties
|
||||
assert "github.com" in doc.id # Should be a GitHub URL
|
||||
assert doc.metadata is not None
|
||||
assert "state" in doc.metadata
|
||||
assert "merged" in doc.metadata
|
||||
|
||||
# Verify sections
|
||||
assert len(doc.sections) == 1
|
||||
section = doc.sections[0]
|
||||
assert section.link == doc.id # Section link should match document ID
|
||||
assert isinstance(section.text, str) # Should have some text content
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import resource
|
||||
from collections.abc import Callable
|
||||
|
||||
import pytest
|
||||
@@ -136,3 +137,22 @@ def google_drive_service_acct_connector_factory() -> (
|
||||
return connector
|
||||
|
||||
return _connector_factory
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def set_resource_limits() -> None:
|
||||
# the google sdk is aggressive about using up file descriptors and
|
||||
# macos is stingy ... these tests will fail randomly unless the descriptor limit is raised
|
||||
RLIMIT_MINIMUM = 2048
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
desired_soft = min(RLIMIT_MINIMUM, hard) # Pick your target here
|
||||
|
||||
print(f"Open file limit: soft={soft} hard={hard} soft_required={RLIMIT_MINIMUM}")
|
||||
|
||||
if soft < desired_soft:
|
||||
print(f"Raising open file limit: {soft} -> {desired_soft}")
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (desired_soft, hard))
|
||||
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
print(f"New open file limit: soft={soft} hard={hard}")
|
||||
return
|
||||
|
||||
@@ -58,6 +58,16 @@ SECTIONS_FOLDER_URL = (
|
||||
"https://drive.google.com/drive/u/5/folders/1loe6XJ-pJxu9YYPv7cF3Hmz296VNzA33"
|
||||
)
|
||||
|
||||
EXTERNAL_SHARED_FOLDER_URL = (
|
||||
"https://drive.google.com/drive/folders/1sWC7Oi0aQGgifLiMnhTjvkhRWVeDa-XS"
|
||||
)
|
||||
EXTERNAL_SHARED_DOCS_IN_FOLDER = [
|
||||
"https://docs.google.com/document/d/1Sywmv1-H6ENk2GcgieKou3kQHR_0te1mhIUcq8XlcdY"
|
||||
]
|
||||
EXTERNAL_SHARED_DOC_SINGLETON = (
|
||||
"https://docs.google.com/document/d/11kmisDfdvNcw5LYZbkdPVjTOdj-Uc5ma6Jep68xzeeA"
|
||||
)
|
||||
|
||||
SHARED_DRIVE_3_URL = "https://drive.google.com/drive/folders/0AJYm2K_I_vtNUk9PVA"
|
||||
|
||||
ADMIN_EMAIL = "admin@onyx-test.com"
|
||||
@@ -161,10 +171,14 @@ def _get_expected_file_content(file_id: int) -> str:
|
||||
return file_text_template.format(file_id)
|
||||
|
||||
|
||||
def assert_retrieved_docs_match_expected(
|
||||
def assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs: list[Document],
|
||||
expected_file_ids: Sequence[int],
|
||||
) -> None:
|
||||
"""NOTE: as far as i can tell this does NOT assert for an exact match.
|
||||
it only checks that the expected file ids are present IN the retrieved doc list
|
||||
"""
|
||||
|
||||
expected_file_names = {
|
||||
file_name_template.format(file_id) for file_id in expected_file_ids
|
||||
}
|
||||
@@ -175,7 +189,7 @@ def assert_retrieved_docs_match_expected(
|
||||
retrieved_docs.sort(key=lambda x: x.semantic_identifier)
|
||||
|
||||
for doc in retrieved_docs:
|
||||
print(f"doc.semantic_identifier: {doc.semantic_identifier}")
|
||||
print(f"retrieved doc: doc.semantic_identifier={doc.semantic_identifier}")
|
||||
|
||||
# Filter out invalid prefixes to prevent different tests from interfering with each other
|
||||
valid_retrieved_docs = [
|
||||
|
||||
@@ -7,7 +7,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
assert_retrieved_docs_match_expected,
|
||||
assert_expected_docs_in_retrieved_docs,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
|
||||
@@ -62,7 +62,7 @@ def test_include_all(
|
||||
+ FOLDER_2_2_FILE_IDS
|
||||
+ SECTIONS_FILE_IDS
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -100,7 +100,7 @@ def test_include_shared_drives_only(
|
||||
+ FOLDER_2_2_FILE_IDS
|
||||
+ SECTIONS_FILE_IDS
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -128,7 +128,7 @@ def test_include_my_drives_only(
|
||||
|
||||
# Should only get primary_admins My Drive because we are impersonating them
|
||||
expected_file_ids = ADMIN_FILE_IDS + ADMIN_FOLDER_3_FILE_IDS
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -161,7 +161,7 @@ def test_drive_one_only(
|
||||
+ FOLDER_1_1_FILE_IDS
|
||||
+ FOLDER_1_2_FILE_IDS
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -198,7 +198,7 @@ def test_folder_and_shared_drive(
|
||||
+ FOLDER_2_1_FILE_IDS
|
||||
+ FOLDER_2_2_FILE_IDS
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -241,7 +241,7 @@ def test_folders_only(
|
||||
+ FOLDER_2_2_FILE_IDS
|
||||
+ ADMIN_FOLDER_3_FILE_IDS
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -271,7 +271,7 @@ def test_personal_folders_only(
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
@@ -1,13 +1,23 @@
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
assert_retrieved_docs_match_expected,
|
||||
assert_expected_docs_in_retrieved_docs,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
EXTERNAL_SHARED_DOC_SINGLETON,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
EXTERNAL_SHARED_DOCS_IN_FOLDER,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
EXTERNAL_SHARED_FOLDER_URL,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
|
||||
@@ -70,12 +80,40 @@ def test_include_all(
|
||||
+ FOLDER_2_2_FILE_IDS
|
||||
+ SECTIONS_FILE_IDS
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
|
||||
return_value=None,
|
||||
)
|
||||
def test_include_shared_drives_only_with_size_threshold(
|
||||
mock_get_api_key: MagicMock,
|
||||
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
) -> None:
|
||||
print("\n\nRunning test_include_shared_drives_only_with_size_threshold")
|
||||
connector = google_drive_service_acct_connector_factory(
|
||||
primary_admin_email=ADMIN_EMAIL,
|
||||
include_shared_drives=True,
|
||||
include_my_drives=False,
|
||||
include_files_shared_with_me=False,
|
||||
shared_folder_urls=None,
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
|
||||
# this threshold will skip one file
|
||||
connector.size_threshold = 16384
|
||||
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
# 2 extra files from shared drive owned by non-admin and not shared with admin
|
||||
assert len(retrieved_docs) == 52
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
|
||||
return_value=None,
|
||||
@@ -94,6 +132,7 @@ def test_include_shared_drives_only(
|
||||
shared_drive_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
# Should only get shared drives
|
||||
@@ -108,7 +147,11 @@ def test_include_shared_drives_only(
|
||||
+ FOLDER_2_2_FILE_IDS
|
||||
+ SECTIONS_FILE_IDS
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
|
||||
# 2 extra files from shared drive owned by non-admin and not shared with admin
|
||||
assert len(retrieved_docs) == 53
|
||||
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -142,7 +185,7 @@ def test_include_my_drives_only(
|
||||
+ TEST_USER_2_FILE_IDS
|
||||
+ TEST_USER_3_FILE_IDS
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -176,7 +219,7 @@ def test_drive_one_only(
|
||||
+ FOLDER_1_1_FILE_IDS
|
||||
+ FOLDER_1_2_FILE_IDS
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -214,7 +257,7 @@ def test_folder_and_shared_drive(
|
||||
+ FOLDER_2_1_FILE_IDS
|
||||
+ FOLDER_2_2_FILE_IDS
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -257,12 +300,70 @@ def test_folders_only(
|
||||
+ FOLDER_2_2_FILE_IDS
|
||||
+ ADMIN_FOLDER_3_FILE_IDS
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
|
||||
def test_shared_folder_owned_by_external_user(
|
||||
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
) -> None:
|
||||
print("\n\nRunning test_shared_folder_owned_by_external_user")
|
||||
connector = google_drive_service_acct_connector_factory(
|
||||
primary_admin_email=ADMIN_EMAIL,
|
||||
include_shared_drives=False,
|
||||
include_my_drives=False,
|
||||
include_files_shared_with_me=False,
|
||||
shared_drive_urls=None,
|
||||
shared_folder_urls=EXTERNAL_SHARED_FOLDER_URL,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
expected_docs = EXTERNAL_SHARED_DOCS_IN_FOLDER
|
||||
|
||||
assert len(retrieved_docs) == len(expected_docs) # 1 for now
|
||||
assert expected_docs[0] in retrieved_docs[0].id
|
||||
|
||||
|
||||
def test_shared_with_me(
|
||||
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
) -> None:
|
||||
print("\n\nRunning test_shared_with_me")
|
||||
connector = google_drive_service_acct_connector_factory(
|
||||
primary_admin_email=ADMIN_EMAIL,
|
||||
include_shared_drives=False,
|
||||
include_my_drives=True,
|
||||
include_files_shared_with_me=True,
|
||||
shared_drive_urls=None,
|
||||
shared_folder_urls=None,
|
||||
my_drive_emails=None,
|
||||
)
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
print(retrieved_docs)
|
||||
|
||||
expected_file_ids = (
|
||||
ADMIN_FILE_IDS
|
||||
+ ADMIN_FOLDER_3_FILE_IDS
|
||||
+ TEST_USER_1_FILE_IDS
|
||||
+ TEST_USER_2_FILE_IDS
|
||||
+ TEST_USER_3_FILE_IDS
|
||||
)
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
retrieved_ids = {urlparse(doc.id).path.split("/")[-2] for doc in retrieved_docs}
|
||||
for id in retrieved_ids:
|
||||
print(id)
|
||||
|
||||
assert EXTERNAL_SHARED_DOC_SINGLETON.split("/")[-1] in retrieved_ids
|
||||
assert EXTERNAL_SHARED_DOCS_IN_FOLDER[0].split("/")[-1] in retrieved_ids
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
|
||||
return_value=None,
|
||||
@@ -288,7 +389,7 @@ def test_specific_emails(
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
expected_file_ids = TEST_USER_1_FILE_IDS + TEST_USER_3_FILE_IDS
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -318,7 +419,7 @@ def get_specific_folders_in_my_drive(
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from unittest.mock import patch
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import (
|
||||
assert_retrieved_docs_match_expected,
|
||||
assert_expected_docs_in_retrieved_docs,
|
||||
)
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
|
||||
@@ -50,7 +50,7 @@ def test_all(
|
||||
+ ADMIN_FOLDER_3_FILE_IDS
|
||||
+ list(range(0, 2))
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -83,7 +83,7 @@ def test_shared_drives_only(
|
||||
+ FOLDER_1_1_FILE_IDS
|
||||
+ FOLDER_1_2_FILE_IDS
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -114,7 +114,7 @@ def test_shared_with_me_only(
|
||||
ADMIN_FOLDER_3_FILE_IDS
|
||||
+ list(range(0, 2))
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -142,7 +142,7 @@ def test_my_drive_only(
|
||||
|
||||
# These are the files from my drive
|
||||
expected_file_ids = TEST_USER_1_FILE_IDS
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -172,7 +172,7 @@ def test_shared_my_drive_folder(
|
||||
# this is a folder from admin's drive that is shared with me
|
||||
ADMIN_FOLDER_3_FILE_IDS
|
||||
)
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
@@ -199,7 +199,7 @@ def test_shared_drive_folder(
|
||||
retrieved_docs = load_all_docs(connector)
|
||||
|
||||
expected_file_ids = FOLDER_1_FILE_IDS + FOLDER_1_1_FILE_IDS + FOLDER_1_2_FILE_IDS
|
||||
assert_retrieved_docs_match_expected(
|
||||
assert_expected_docs_in_retrieved_docs(
|
||||
retrieved_docs=retrieved_docs,
|
||||
expected_file_ids=expected_file_ids,
|
||||
)
|
||||
|
||||
@@ -2,12 +2,14 @@ import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.zendesk.connector import ZendeskConnector
|
||||
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
|
||||
|
||||
|
||||
def load_test_data(file_name: str = "test_zendesk_data.json") -> dict[str, dict]:
|
||||
@@ -50,7 +52,7 @@ def get_credentials() -> dict[str, str]:
|
||||
def test_zendesk_connector_basic(
|
||||
request: pytest.FixtureRequest, connector_fixture: str
|
||||
) -> None:
|
||||
connector = request.getfixturevalue(connector_fixture)
|
||||
connector = cast(ZendeskConnector, request.getfixturevalue(connector_fixture))
|
||||
test_data = load_test_data()
|
||||
all_docs: list[Document] = []
|
||||
target_test_doc_id: str
|
||||
@@ -61,12 +63,11 @@ def test_zendesk_connector_basic(
|
||||
|
||||
target_doc: Document | None = None
|
||||
|
||||
for doc_batch in connector.poll_source(0, time.time()):
|
||||
for doc in doc_batch:
|
||||
all_docs.append(doc)
|
||||
if doc.id == target_test_doc_id:
|
||||
target_doc = doc
|
||||
print(f"target_doc {target_doc}")
|
||||
for doc in load_all_docs_from_checkpoint_connector(connector, 0, time.time()):
|
||||
all_docs.append(doc)
|
||||
if doc.id == target_test_doc_id:
|
||||
target_doc = doc
|
||||
print(f"target_doc {target_doc}")
|
||||
|
||||
assert len(all_docs) > 0, "No documents were retrieved from the connector"
|
||||
assert (
|
||||
@@ -111,8 +112,10 @@ def test_zendesk_connector_basic(
|
||||
def test_zendesk_connector_slim(zendesk_article_connector: ZendeskConnector) -> None:
|
||||
# Get full doc IDs
|
||||
all_full_doc_ids = set()
|
||||
for doc_batch in zendesk_article_connector.load_from_state():
|
||||
all_full_doc_ids.update([doc.id for doc in doc_batch])
|
||||
for doc in load_all_docs_from_checkpoint_connector(
|
||||
zendesk_article_connector, 0, time.time()
|
||||
):
|
||||
all_full_doc_ids.add(doc.id)
|
||||
|
||||
# Get slim doc IDs
|
||||
all_slim_doc_ids = set()
|
||||
|
||||
@@ -9,7 +9,9 @@ from requests import HTTPError
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.manage.models import UserInfo
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
from onyx.server.models import InvitedUserSnapshot
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
@@ -123,10 +125,15 @@ class UserManager:
|
||||
user_to_set: DATestUser,
|
||||
target_role: UserRole,
|
||||
user_performing_action: DATestUser,
|
||||
explicit_override: bool = False,
|
||||
) -> DATestUser:
|
||||
response = requests.patch(
|
||||
url=f"{API_SERVER_URL}/manage/set-user-role",
|
||||
json={"user_email": user_to_set.email, "new_role": target_role.value},
|
||||
json={
|
||||
"user_email": user_to_set.email,
|
||||
"new_role": target_role.value,
|
||||
"explicit_override": explicit_override,
|
||||
},
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -240,3 +247,69 @@ class UserManager:
|
||||
total_items=data["total_items"],
|
||||
)
|
||||
return paginated_result
|
||||
|
||||
@staticmethod
|
||||
def invite_user(
|
||||
user_to_invite_email: str, user_performing_action: DATestUser
|
||||
) -> None:
|
||||
"""Invite a user by email to join the organization.
|
||||
|
||||
Args:
|
||||
user_to_invite_email: Email of the user to invite
|
||||
user_performing_action: User with admin permissions performing the invitation
|
||||
"""
|
||||
response = requests.put(
|
||||
url=f"{API_SERVER_URL}/manage/admin/users",
|
||||
headers=user_performing_action.headers,
|
||||
json={"emails": [user_to_invite_email]},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def accept_invitation(tenant_id: str, user_performing_action: DATestUser) -> None:
|
||||
"""Accept an invitation to join the organization.
|
||||
|
||||
Args:
|
||||
tenant_id: ID of the tenant/organization to accept invitation for
|
||||
user_performing_action: User accepting the invitation
|
||||
"""
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/tenants/users/invite/accept",
|
||||
headers=user_performing_action.headers,
|
||||
json={"tenant_id": tenant_id},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_invited_users(
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[InvitedUserSnapshot]:
|
||||
"""Get a list of all invited users.
|
||||
|
||||
Args:
|
||||
user_performing_action: User with admin permissions performing the action
|
||||
|
||||
Returns:
|
||||
List of invited user snapshots
|
||||
"""
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/users/invited",
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return [InvitedUserSnapshot(**user) for user in response.json()]
|
||||
|
||||
@staticmethod
|
||||
def get_user_info(user_performing_action: DATestUser) -> UserInfo:
|
||||
"""Get user info for the current user.
|
||||
|
||||
Args:
|
||||
user_performing_action: User performing the action
|
||||
"""
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/me",
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return UserInfo(**response.json())
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
from onyx.db.models import UserRole
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
INVITED_BASIC_USER = "basic_user"
|
||||
INVITED_BASIC_USER_EMAIL = "basic_user@test.com"
|
||||
|
||||
|
||||
def test_user_invitation_flow(reset_multitenant: None) -> None:
|
||||
# Create first user (admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin")
|
||||
assert UserManager.is_role(admin_user, UserRole.ADMIN)
|
||||
|
||||
# Create second user
|
||||
invited_user: DATestUser = UserManager.create(name="admin_invited")
|
||||
assert UserManager.is_role(invited_user, UserRole.ADMIN)
|
||||
|
||||
# Admin user invites the previously registered and non-registered user
|
||||
UserManager.invite_user(invited_user.email, admin_user)
|
||||
UserManager.invite_user(INVITED_BASIC_USER_EMAIL, admin_user)
|
||||
|
||||
invited_basic_user: DATestUser = UserManager.create(
|
||||
name=INVITED_BASIC_USER, email=INVITED_BASIC_USER_EMAIL
|
||||
)
|
||||
assert UserManager.is_role(invited_basic_user, UserRole.BASIC)
|
||||
|
||||
# Verify the user is in the invited users list
|
||||
invited_users = UserManager.get_invited_users(admin_user)
|
||||
assert invited_user.email in [
|
||||
user.email for user in invited_users
|
||||
], f"User {invited_user.email} not found in invited users list"
|
||||
|
||||
# Get user info to check tenant information
|
||||
user_info = UserManager.get_user_info(invited_user)
|
||||
|
||||
# Extract the tenant_id from the invitation
|
||||
invited_tenant_id = (
|
||||
user_info.tenant_info.invitation.tenant_id
|
||||
if user_info.tenant_info and user_info.tenant_info.invitation
|
||||
else None
|
||||
)
|
||||
assert invited_tenant_id is not None, "Expected to find an invitation tenant_id"
|
||||
|
||||
UserManager.accept_invitation(invited_tenant_id, invited_user)
|
||||
|
||||
# Get updated user info after accepting invitation
|
||||
updated_user_info = UserManager.get_user_info(invited_user)
|
||||
|
||||
# Verify the user is no longer in the invited users list
|
||||
updated_invited_users = UserManager.get_invited_users(admin_user)
|
||||
assert invited_user.email not in [
|
||||
user.email for user in updated_invited_users
|
||||
], f"User {invited_user.email} should not be in invited users list after accepting"
|
||||
|
||||
# Verify the user has BASIC role in the organization
|
||||
assert (
|
||||
updated_user_info.role == UserRole.BASIC
|
||||
), f"Expected user to have BASIC role, but got {updated_user_info.role}"
|
||||
|
||||
# Verify user is in the organization
|
||||
user_page = UserManager.get_user_page(
|
||||
user_performing_action=admin_user, role_filter=[UserRole.BASIC]
|
||||
)
|
||||
|
||||
# Check if the invited user is in the list of users with BASIC role
|
||||
invited_user_emails = [user.email for user in user_page.items]
|
||||
assert invited_user.email in invited_user_emails, (
|
||||
f"User {invited_user.email} not found in the list of basic users "
|
||||
f"in the organization. Available users: {invited_user_emails}"
|
||||
)
|
||||
@@ -0,0 +1,90 @@
|
||||
import requests
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_saml_user_conversion(reset: None) -> None:
|
||||
"""
|
||||
Test that SAML login correctly converts users with non-authenticated roles
|
||||
(SLACK_USER or EXT_PERM_USER) to authenticated roles (BASIC).
|
||||
|
||||
This test:
|
||||
1. Creates an admin and a regular user
|
||||
2. Changes the regular user's role to EXT_PERM_USER
|
||||
3. Simulates a SAML login by calling the test endpoint
|
||||
4. Verifies the user's role is converted to BASIC
|
||||
|
||||
This tests the fix that ensures users with non-authenticated roles (SLACK_USER or EXT_PERM_USER)
|
||||
are properly converted to authenticated roles during SAML login.
|
||||
"""
|
||||
# Create an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(email="admin@onyx-test.com")
|
||||
|
||||
# Create a regular user that we'll convert to EXT_PERM_USER
|
||||
test_user_email = "ext_perm_user@example.com"
|
||||
test_user = UserManager.create(email=test_user_email)
|
||||
|
||||
# Verify the user was created with BASIC role initially
|
||||
assert UserManager.is_role(test_user, UserRole.BASIC)
|
||||
|
||||
# Change the user's role to EXT_PERM_USER using the UserManager
|
||||
UserManager.set_role(
|
||||
user_to_set=test_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
|
||||
# Verify the user has EXT_PERM_USER role now
|
||||
assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)
|
||||
|
||||
# Simulate SAML login by calling the test endpoint
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/users/test-upsert-user",
|
||||
json={"email": test_user_email},
|
||||
headers=admin_user.headers, # Use admin headers for authorization
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
# Verify the response indicates the role changed to BASIC
|
||||
user_data = response.json()
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
|
||||
# Verify user role was changed in the database
|
||||
assert UserManager.is_role(test_user, UserRole.BASIC)
|
||||
|
||||
# Do the same test with SLACK_USER
|
||||
slack_user_email = "slack_user@example.com"
|
||||
slack_user = UserManager.create(email=slack_user_email)
|
||||
|
||||
# Verify the user was created with BASIC role initially
|
||||
assert UserManager.is_role(slack_user, UserRole.BASIC)
|
||||
|
||||
# Change the user's role to SLACK_USER
|
||||
UserManager.set_role(
|
||||
user_to_set=slack_user,
|
||||
target_role=UserRole.SLACK_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
|
||||
# Verify the user has SLACK_USER role
|
||||
assert UserManager.is_role(slack_user, UserRole.SLACK_USER)
|
||||
|
||||
# Simulate SAML login again
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/users/test-upsert-user",
|
||||
json={"email": slack_user_email},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
# Verify the response indicates the role changed to BASIC
|
||||
user_data = response.json()
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
|
||||
# Verify the user's role was changed in the database
|
||||
assert UserManager.is_role(slack_user, UserRole.BASIC)
|
||||
backend/tests/unit/onyx/auth/conftest.py (new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user() -> MagicMock:
|
||||
"""Creates a mock User instance for testing."""
|
||||
user = MagicMock(spec=User)
|
||||
user.email = "test@example.com"
|
||||
user.id = "test-user-id"
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_account() -> MagicMock:
|
||||
"""Creates a mock OAuthAccount instance for testing."""
|
||||
oauth_account = MagicMock(spec=OAuthAccount)
|
||||
oauth_account.oauth_name = "google"
|
||||
oauth_account.refresh_token = "test-refresh-token"
|
||||
oauth_account.access_token = "test-access-token"
|
||||
oauth_account.expires_at = None
|
||||
return oauth_account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_manager() -> MagicMock:
|
||||
"""Creates a mock user manager for testing."""
|
||||
user_manager = MagicMock()
|
||||
user_manager.user_db = MagicMock()
|
||||
user_manager.user_db.update_oauth_account = AsyncMock()
|
||||
user_manager.user_db.update = AsyncMock()
|
||||
return user_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session() -> MagicMock:
|
||||
"""Creates a mock database session for testing."""
|
||||
return MagicMock()
|
||||
backend/tests/unit/onyx/auth/test_oauth_refresher.py (new file, 273 lines)
@@ -0,0 +1,273 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.auth.oauth_refresher import _test_expire_oauth_token
|
||||
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
|
||||
from onyx.auth.oauth_refresher import check_oauth_account_has_refresh_token
|
||||
from onyx.auth.oauth_refresher import get_oauth_accounts_requiring_refresh_token
|
||||
from onyx.auth.oauth_refresher import refresh_oauth_token
|
||||
from onyx.db.models import OAuthAccount
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_oauth_token_success(
|
||||
mock_user: MagicMock,
|
||||
mock_oauth_account: MagicMock,
|
||||
mock_user_manager: MagicMock,
|
||||
mock_db_session: AsyncSession,
|
||||
) -> None:
|
||||
"""Test successful OAuth token refresh."""
|
||||
# Mock HTTP client and response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "new_token",
|
||||
"refresh_token": "new_refresh_token",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
# Create async mock for the client post method
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
# Use fixture values but ensure refresh token exists
|
||||
mock_oauth_account.oauth_name = (
|
||||
"google" # Ensure it's google to match the refresh endpoint
|
||||
)
|
||||
mock_oauth_account.refresh_token = "old_refresh_token"
|
||||
|
||||
# Patch at the module level where it's actually being used
|
||||
with patch("onyx.auth.oauth_refresher.httpx.AsyncClient") as client_class_mock:
|
||||
# Configure the context manager
|
||||
client_instance = mock_client
|
||||
client_class_mock.return_value.__aenter__.return_value = client_instance
|
||||
|
||||
# Call the function under test
|
||||
result = await refresh_oauth_token(
|
||||
mock_user, mock_oauth_account, mock_db_session, mock_user_manager
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert result is True
|
||||
mock_client.post.assert_called_once()
|
||||
mock_user_manager.user_db.update_oauth_account.assert_called_once()
|
||||
|
||||
# Verify token data was updated correctly
|
||||
update_data = mock_user_manager.user_db.update_oauth_account.call_args[0][2]
|
||||
assert update_data["access_token"] == "new_token"
|
||||
assert update_data["refresh_token"] == "new_refresh_token"
|
||||
assert "expires_at" in update_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_oauth_token_failure(
|
||||
mock_user: MagicMock,
|
||||
mock_oauth_account: MagicMock,
|
||||
mock_user_manager: MagicMock,
|
||||
mock_db_session: AsyncSession,
|
||||
) -> None:
|
||||
"""Test OAuth token refresh failure due to HTTP error."""
|
||||
# Mock HTTP client with error response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 400 # Simulate error
|
||||
|
||||
# Create async mock for the client post method
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
# Ensure refresh token exists and provider is supported
|
||||
mock_oauth_account.oauth_name = "google"
|
||||
mock_oauth_account.refresh_token = "old_refresh_token"
|
||||
|
||||
# Patch at the module level where it's actually being used
|
||||
with patch("onyx.auth.oauth_refresher.httpx.AsyncClient") as client_class_mock:
|
||||
# Configure the context manager
|
||||
client_class_mock.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
# Call the function under test
|
||||
result = await refresh_oauth_token(
|
||||
mock_user, mock_oauth_account, mock_db_session, mock_user_manager
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert result is False
|
||||
mock_client.post.assert_called_once()
|
||||
mock_user_manager.user_db.update_oauth_account.assert_not_called()
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_oauth_token_no_refresh_token(
|
||||
mock_user: MagicMock,
|
||||
mock_oauth_account: MagicMock,
|
||||
mock_user_manager: MagicMock,
|
||||
mock_db_session: AsyncSession,
|
||||
) -> None:
|
||||
"""Test OAuth token refresh when no refresh token is available."""
|
||||
# Set refresh token to None
|
||||
mock_oauth_account.refresh_token = None
|
||||
mock_oauth_account.oauth_name = "google"
|
||||
|
||||
# No need to mock httpx since it shouldn't be called
|
||||
result = await refresh_oauth_token(
|
||||
mock_user, mock_oauth_account, mock_db_session, mock_user_manager
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_and_refresh_oauth_tokens(
|
||||
mock_user: MagicMock,
|
||||
mock_user_manager: MagicMock,
|
||||
mock_db_session: AsyncSession,
|
||||
) -> None:
|
||||
"""Test checking and refreshing multiple OAuth tokens."""
|
||||
# Create mock user with OAuth accounts
|
||||
now_timestamp = datetime.now(timezone.utc).timestamp()
|
||||
|
||||
# Create an account that needs refreshing (expiring soon)
|
||||
expiring_account = MagicMock(spec=OAuthAccount)
|
||||
expiring_account.oauth_name = "google"
|
||||
expiring_account.refresh_token = "refresh_token_1"
|
||||
expiring_account.expires_at = now_timestamp + 60 # Expires in 1 minute
|
||||
|
||||
# Create an account that doesn't need refreshing (expires later)
|
||||
valid_account = MagicMock(spec=OAuthAccount)
|
||||
valid_account.oauth_name = "google"
|
||||
valid_account.refresh_token = "refresh_token_2"
|
||||
valid_account.expires_at = now_timestamp + 3600 # Expires in 1 hour
|
||||
|
||||
# Create an account without a refresh token
|
||||
no_refresh_account = MagicMock(spec=OAuthAccount)
|
||||
no_refresh_account.oauth_name = "google"
|
||||
no_refresh_account.refresh_token = None
|
||||
no_refresh_account.expires_at = (
|
||||
now_timestamp + 60
|
||||
) # Expiring soon but no refresh token
|
||||
|
||||
# Set oauth_accounts on the mock user
|
||||
mock_user.oauth_accounts = [expiring_account, valid_account, no_refresh_account]
|
||||
|
||||
# Mock refresh_oauth_token function
|
||||
with patch(
|
||||
"onyx.auth.oauth_refresher.refresh_oauth_token", AsyncMock(return_value=True)
|
||||
) as mock_refresh:
|
||||
# Call the function under test
|
||||
await check_and_refresh_oauth_tokens(
|
||||
mock_user, mock_db_session, mock_user_manager
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert mock_refresh.call_count == 1 # Should only refresh the expiring account
|
||||
# Check it was called with the expiring account
|
||||
mock_refresh.assert_called_once_with(
|
||||
mock_user, expiring_account, mock_db_session, mock_user_manager
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_oauth_accounts_requiring_refresh_token(mock_user: MagicMock) -> None:
|
||||
"""Test identifying OAuth accounts that need refresh tokens."""
|
||||
# Create accounts with and without refresh tokens
|
||||
account_with_token = MagicMock(spec=OAuthAccount)
|
||||
account_with_token.oauth_name = "google"
|
||||
account_with_token.refresh_token = "refresh_token"
|
||||
|
||||
account_without_token = MagicMock(spec=OAuthAccount)
|
||||
account_without_token.oauth_name = "google"
|
||||
account_without_token.refresh_token = None
|
||||
|
||||
second_account_without_token = MagicMock(spec=OAuthAccount)
|
||||
second_account_without_token.oauth_name = "github"
|
||||
second_account_without_token.refresh_token = (
|
||||
"" # Empty string should also be treated as missing
|
||||
)
|
||||
|
||||
# Set accounts on user
|
||||
mock_user.oauth_accounts = [
|
||||
account_with_token,
|
||||
account_without_token,
|
||||
second_account_without_token,
|
||||
]
|
||||
|
||||
# Call the function under test
|
||||
accounts_needing_refresh = await get_oauth_accounts_requiring_refresh_token(
|
||||
mock_user
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert len(accounts_needing_refresh) == 2
|
||||
assert account_without_token in accounts_needing_refresh
|
||||
assert second_account_without_token in accounts_needing_refresh
|
||||
assert account_with_token not in accounts_needing_refresh
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_oauth_account_has_refresh_token(
|
||||
mock_user: MagicMock, mock_oauth_account: MagicMock
|
||||
) -> None:
|
||||
"""Test checking if an OAuth account has a refresh token."""
|
||||
# Test with refresh token
|
||||
mock_oauth_account.refresh_token = "refresh_token"
|
||||
has_token = await check_oauth_account_has_refresh_token(
|
||||
mock_user, mock_oauth_account
|
||||
)
|
||||
assert has_token is True
|
||||
|
||||
# Test with None refresh token
|
||||
mock_oauth_account.refresh_token = None
|
||||
has_token = await check_oauth_account_has_refresh_token(
|
||||
mock_user, mock_oauth_account
|
||||
)
|
||||
assert has_token is False
|
||||
|
||||
# Test with empty string refresh token
|
||||
mock_oauth_account.refresh_token = ""
|
||||
has_token = await check_oauth_account_has_refresh_token(
|
||||
mock_user, mock_oauth_account
|
||||
)
|
||||
assert has_token is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_test_expire_oauth_token(
|
||||
mock_user: MagicMock,
|
||||
mock_oauth_account: MagicMock,
|
||||
mock_user_manager: MagicMock,
|
||||
mock_db_session: AsyncSession,
|
||||
) -> None:
|
||||
"""Test the testing utility function for token expiration."""
|
||||
# Set up the mock account
|
||||
mock_oauth_account.oauth_name = "google"
|
||||
mock_oauth_account.refresh_token = "test_refresh_token"
|
||||
mock_oauth_account.access_token = "test_access_token"
|
||||
|
||||
# Call the function under test
|
||||
result = await _test_expire_oauth_token(
|
||||
mock_user,
|
||||
mock_oauth_account,
|
||||
mock_db_session,
|
||||
mock_user_manager,
|
||||
expire_in_seconds=10,
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert result is True
|
||||
mock_user_manager.user_db.update_oauth_account.assert_called_once()
|
||||
|
||||
# Verify the expiration time was set correctly
|
||||
update_data = mock_user_manager.user_db.update_oauth_account.call_args[0][2]
|
||||
assert "expires_at" in update_data
|
||||
|
||||
# Now should be within 10-11 seconds of the set expiration
|
||||
now = datetime.now(timezone.utc).timestamp()
|
||||
assert update_data["expires_at"] - now >= 8.9 # Allow 1 second for test execution
|
||||
assert update_data["expires_at"] - now <= 11.1 # Allow 1 second for test execution
|
||||
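The tests above only observe refresh_oauth_token from the outside: a POST to the provider's token endpoint, followed by update_oauth_account with new access_token, refresh_token, and expires_at values, and no update on a non-200 response. A rough sketch of that flow under those assumptions (the endpoint URL, payload fields, and helper name are illustrative; the real implementation lives in onyx/auth/oauth_refresher.py and is not shown in this diff):

import time

import httpx

async def _refresh_google_access_token(
    refresh_token: str, client_id: str, client_secret: str
) -> dict | None:
    # Standard OAuth2 refresh_token grant against Google's token endpoint.
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "https://oauth2.googleapis.com/token",
            data={
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": client_id,
                "client_secret": client_secret,
            },
        )
    if response.status_code != 200:
        # mirrors the failure test: no account update on an error response
        return None
    payload = response.json()
    return {
        "access_token": payload["access_token"],
        # the provider may omit refresh_token on refresh; keep the old one then
        "refresh_token": payload.get("refresh_token", refresh_token),
        "expires_at": int(time.time()) + int(payload.get("expires_in", 3600)),
    }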
@@ -0,0 +1,441 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from github import Github
|
||||
from github import GithubException
|
||||
from github import RateLimitExceededException
|
||||
from github.Issue import Issue
|
||||
from github.PullRequest import PullRequest
|
||||
from github.RateLimit import RateLimit
|
||||
from github.Repository import Repository
|
||||
from github.Requester import Requester
|
||||
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from onyx.connectors.github.connector import SerializedRepository
|
||||
from onyx.connectors.models import Document
|
||||
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repo_owner() -> str:
|
||||
return "test-org"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repositories() -> str:
|
||||
return "test-repo"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_github_client() -> MagicMock:
|
||||
"""Create a mock GitHub client with proper typing"""
|
||||
mock = MagicMock(spec=Github)
|
||||
# Add proper return typing for get_repo method
|
||||
mock.get_repo = MagicMock(return_value=MagicMock(spec=Repository))
|
||||
# Add proper return typing for get_organization method
|
||||
mock.get_organization = MagicMock()
|
||||
# Add proper return typing for get_user method
|
||||
mock.get_user = MagicMock()
|
||||
# Add proper return typing for get_rate_limit method
|
||||
mock.get_rate_limit = MagicMock(return_value=MagicMock(spec=RateLimit))
|
||||
# Add requester for repository deserialization
|
||||
mock.requester = MagicMock(spec=Requester)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_connector(
|
||||
repo_owner: str, repositories: str, mock_github_client: MagicMock
|
||||
) -> Generator[GithubConnector, None, None]:
|
||||
connector = GithubConnector(
|
||||
repo_owner=repo_owner,
|
||||
repositories=repositories,
|
||||
include_prs=True,
|
||||
include_issues=True,
|
||||
)
|
||||
connector.github_client = mock_github_client
|
||||
yield connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_mock_pr() -> Callable[..., MagicMock]:
|
||||
def _create_mock_pr(
|
||||
number: int = 1,
|
||||
title: str = "Test PR",
|
||||
body: str = "Test Description",
|
||||
state: str = "open",
|
||||
merged: bool = False,
|
||||
updated_at: datetime = datetime(2023, 1, 1, tzinfo=timezone.utc),
|
||||
) -> MagicMock:
|
||||
"""Helper to create a mock PullRequest object"""
|
||||
mock_pr = MagicMock(spec=PullRequest)
|
||||
mock_pr.number = number
|
||||
mock_pr.title = title
|
||||
mock_pr.body = body
|
||||
mock_pr.state = state
|
||||
mock_pr.merged = merged
|
||||
mock_pr.updated_at = updated_at
|
||||
mock_pr.html_url = f"https://github.com/test-org/test-repo/pull/{number}"
|
||||
return mock_pr
|
||||
|
||||
return _create_mock_pr
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_mock_issue() -> Callable[..., MagicMock]:
|
||||
def _create_mock_issue(
|
||||
number: int = 1,
|
||||
title: str = "Test Issue",
|
||||
body: str = "Test Description",
|
||||
state: str = "open",
|
||||
updated_at: datetime = datetime(2023, 1, 1, tzinfo=timezone.utc),
|
||||
) -> MagicMock:
|
||||
"""Helper to create a mock Issue object"""
|
||||
mock_issue = MagicMock(spec=Issue)
|
||||
mock_issue.number = number
|
||||
mock_issue.title = title
|
||||
mock_issue.body = body
|
||||
mock_issue.state = state
|
||||
mock_issue.updated_at = updated_at
|
||||
mock_issue.html_url = f"https://github.com/test-org/test-repo/issues/{number}"
|
||||
mock_issue.pull_request = None # Not a PR
|
||||
return mock_issue
|
||||
|
||||
return _create_mock_issue
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_mock_repo() -> Callable[..., MagicMock]:
|
||||
def _create_mock_repo(
|
||||
name: str = "test-repo",
|
||||
id: int = 1,
|
||||
) -> MagicMock:
|
||||
"""Helper to create a mock Repository object"""
|
||||
mock_repo = MagicMock(spec=Repository)
|
||||
mock_repo.name = name
|
||||
mock_repo.id = id
|
||||
mock_repo.raw_headers = {"status": "200 OK", "content-type": "application/json"}
|
||||
mock_repo.raw_data = {
|
||||
"id": str(id),
|
||||
"name": name,
|
||||
"full_name": f"test-org/{name}",
|
||||
"private": str(False),
|
||||
"description": "Test repository",
|
||||
}
|
||||
return mock_repo
|
||||
|
||||
return _create_mock_repo
|
||||
|
||||
|
||||
def test_load_from_checkpoint_happy_path(
|
||||
github_connector: GithubConnector,
|
||||
mock_github_client: MagicMock,
|
||||
create_mock_repo: Callable[..., MagicMock],
|
||||
create_mock_pr: Callable[..., MagicMock],
|
||||
create_mock_issue: Callable[..., MagicMock],
|
||||
) -> None:
|
||||
"""Test loading from checkpoint - happy path"""
|
||||
# Set up mocked repo
|
||||
mock_repo = create_mock_repo()
|
||||
github_connector.github_client = mock_github_client
|
||||
mock_github_client.get_repo.return_value = mock_repo
|
||||
|
||||
# Set up mocked PRs and issues
|
||||
mock_pr1 = create_mock_pr(number=1, title="PR 1")
|
||||
mock_pr2 = create_mock_pr(number=2, title="PR 2")
|
||||
mock_issue1 = create_mock_issue(number=1, title="Issue 1")
|
||||
mock_issue2 = create_mock_issue(number=2, title="Issue 2")
|
||||
|
||||
# Mock get_pulls and get_issues methods
|
||||
mock_repo.get_pulls.return_value = MagicMock()
|
||||
mock_repo.get_pulls.return_value.get_page.side_effect = [
|
||||
[mock_pr1, mock_pr2],
|
||||
[],
|
||||
]
|
||||
mock_repo.get_issues.return_value = MagicMock()
|
||||
mock_repo.get_issues.return_value.get_page.side_effect = [
|
||||
[mock_issue1, mock_issue2],
|
||||
[],
|
||||
]
|
||||
|
||||
# Mock SerializedRepository.to_Repository to return our mock repo
|
||||
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(
|
||||
github_connector, 0, end_time
|
||||
)
|
||||
|
||||
# Check that we got all documents and final has_more=False
|
||||
assert len(outputs) == 4
|
||||
|
||||
repo_batch = outputs[0]
|
||||
assert len(repo_batch.items) == 0
|
||||
assert repo_batch.next_checkpoint.has_more is True
|
||||
|
||||
# Check first batch (PRs)
|
||||
first_batch = outputs[1]
|
||||
assert len(first_batch.items) == 2
|
||||
assert isinstance(first_batch.items[0], Document)
|
||||
assert first_batch.items[0].id == "https://github.com/test-org/test-repo/pull/1"
|
||||
assert isinstance(first_batch.items[1], Document)
|
||||
assert first_batch.items[1].id == "https://github.com/test-org/test-repo/pull/2"
|
||||
assert first_batch.next_checkpoint.curr_page == 1
|
||||
|
||||
# Check second batch (Issues)
|
||||
second_batch = outputs[2]
|
||||
assert len(second_batch.items) == 2
|
||||
assert isinstance(second_batch.items[0], Document)
|
||||
assert (
|
||||
second_batch.items[0].id == "https://github.com/test-org/test-repo/issues/1"
|
||||
)
|
||||
assert isinstance(second_batch.items[1], Document)
|
||||
assert (
|
||||
second_batch.items[1].id == "https://github.com/test-org/test-repo/issues/2"
|
||||
)
|
||||
assert second_batch.next_checkpoint.has_more
|
||||
|
||||
# Check third batch (finished checkpoint)
|
||||
third_batch = outputs[3]
|
||||
assert len(third_batch.items) == 0
|
||||
assert third_batch.next_checkpoint.has_more is False
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_rate_limit(
|
||||
github_connector: GithubConnector,
|
||||
mock_github_client: MagicMock,
|
||||
create_mock_repo: Callable[..., MagicMock],
|
||||
create_mock_pr: Callable[..., MagicMock],
|
||||
) -> None:
|
||||
"""Test loading from checkpoint with rate limit handling"""
|
||||
# Set up mocked repo
|
||||
mock_repo = create_mock_repo()
|
||||
github_connector.github_client = mock_github_client
|
||||
mock_github_client.get_repo.return_value = mock_repo
|
||||
|
||||
# Set up mocked PR
|
||||
mock_pr = create_mock_pr()
|
||||
|
||||
# Mock get_pulls to raise RateLimitExceededException on first call
|
||||
mock_repo.get_pulls.return_value = MagicMock()
|
||||
mock_repo.get_pulls.return_value.get_page.side_effect = [
|
||||
RateLimitExceededException(403, {"message": "Rate limit exceeded"}, {}),
|
||||
[mock_pr],
|
||||
[],
|
||||
]
|
||||
|
||||
# Mock rate limit reset time
|
||||
mock_rate_limit = MagicMock(spec=RateLimit)
|
||||
mock_rate_limit.core.reset = datetime.now(timezone.utc)
|
||||
github_connector.github_client.get_rate_limit.return_value = mock_rate_limit
|
||||
|
||||
# Mock SerializedRepository.to_Repository to return our mock repo
|
||||
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
with patch(
|
||||
"onyx.connectors.github.connector._sleep_after_rate_limit_exception"
|
||||
) as mock_sleep:
|
||||
outputs = load_everything_from_checkpoint_connector(
|
||||
github_connector, 0, end_time
|
||||
)
|
||||
|
||||
assert mock_sleep.call_count == 1
|
||||
|
||||
# Check that we got the document after rate limit was handled
|
||||
assert len(outputs) >= 2
|
||||
assert len(outputs[1].items) == 1
|
||||
assert isinstance(outputs[1].items[0], Document)
|
||||
assert outputs[1].items[0].id == "https://github.com/test-org/test-repo/pull/1"
|
||||
|
||||
assert outputs[-1].next_checkpoint.has_more is False
|
||||
|
||||
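The rate-limit test patches _sleep_after_rate_limit_exception rather than exercising it, so only its contract is visible here: read the client's core rate-limit reset time and wait until then. A sketch of that behaviour under the same assumptions (the real helper is in onyx/connectors/github/connector.py and may differ):

import time
from datetime import datetime
from datetime import timezone

from github import Github

def _sleep_after_rate_limit_exception(github_client: Github) -> None:
    reset = github_client.get_rate_limit().core.reset
    if reset.tzinfo is None:
        # older PyGithub versions return a naive UTC datetime
        reset = reset.replace(tzinfo=timezone.utc)
    wait_seconds = (reset - datetime.now(timezone.utc)).total_seconds() + 1  # small buffer
    time.sleep(max(wait_seconds, 0))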
|
||||
def test_load_from_checkpoint_with_empty_repo(
|
||||
github_connector: GithubConnector,
|
||||
mock_github_client: MagicMock,
|
||||
create_mock_repo: Callable[..., MagicMock],
|
||||
) -> None:
|
||||
"""Test loading from checkpoint with an empty repository"""
|
||||
# Set up mocked repo
|
||||
mock_repo = create_mock_repo()
|
||||
github_connector.github_client = mock_github_client
|
||||
mock_github_client.get_repo.return_value = mock_repo
|
||||
|
||||
# Mock get_pulls and get_issues to return empty lists
|
||||
mock_repo.get_pulls.return_value = MagicMock()
|
||||
mock_repo.get_pulls.return_value.get_page.return_value = []
|
||||
mock_repo.get_issues.return_value = MagicMock()
|
||||
mock_repo.get_issues.return_value.get_page.return_value = []
|
||||
|
||||
# Mock SerializedRepository.to_Repository to return our mock repo
|
||||
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(
|
||||
github_connector, 0, end_time
|
||||
)
|
||||
|
||||
# Check that we got no documents
|
||||
assert len(outputs) == 2
|
||||
assert len(outputs[-1].items) == 0
|
||||
assert not outputs[-1].next_checkpoint.has_more
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_prs_only(
|
||||
github_connector: GithubConnector,
|
||||
mock_github_client: MagicMock,
|
||||
create_mock_repo: Callable[..., MagicMock],
|
||||
create_mock_pr: Callable[..., MagicMock],
|
||||
) -> None:
|
||||
"""Test loading from checkpoint with only PRs enabled"""
|
||||
# Configure connector to only include PRs
|
||||
github_connector.include_prs = True
|
||||
github_connector.include_issues = False
|
||||
|
||||
# Set up mocked repo
|
||||
mock_repo = create_mock_repo()
|
||||
github_connector.github_client = mock_github_client
|
||||
mock_github_client.get_repo.return_value = mock_repo
|
||||
|
||||
# Set up mocked PRs
|
||||
mock_pr1 = create_mock_pr(number=1, title="PR 1")
|
||||
mock_pr2 = create_mock_pr(number=2, title="PR 2")
|
||||
|
||||
# Mock get_pulls method
|
||||
mock_repo.get_pulls.return_value = MagicMock()
|
||||
mock_repo.get_pulls.return_value.get_page.side_effect = [
|
||||
[mock_pr1, mock_pr2],
|
||||
[],
|
||||
]
|
||||
|
||||
# Mock SerializedRepository.to_Repository to return our mock repo
|
||||
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(
|
||||
github_connector, 0, end_time
|
||||
)
|
||||
|
||||
# Check that we only got PRs
|
||||
assert len(outputs) >= 2
|
||||
assert len(outputs[1].items) == 2
|
||||
assert all(
|
||||
isinstance(doc, Document) and "pull" in doc.id for doc in outputs[0].items
|
||||
) # All documents should be PRs
|
||||
|
||||
assert outputs[-1].next_checkpoint.has_more is False
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_issues_only(
|
||||
github_connector: GithubConnector,
|
||||
mock_github_client: MagicMock,
|
||||
create_mock_repo: Callable[..., MagicMock],
|
||||
create_mock_issue: Callable[..., MagicMock],
|
||||
) -> None:
|
||||
"""Test loading from checkpoint with only issues enabled"""
|
||||
# Configure connector to only include issues
|
||||
github_connector.include_prs = False
|
||||
github_connector.include_issues = True
|
||||
|
||||
# Set up mocked repo
|
||||
mock_repo = create_mock_repo()
|
||||
github_connector.github_client = mock_github_client
|
||||
mock_github_client.get_repo.return_value = mock_repo
|
||||
|
||||
# Set up mocked issues
|
||||
mock_issue1 = create_mock_issue(number=1, title="Issue 1")
|
||||
mock_issue2 = create_mock_issue(number=2, title="Issue 2")
|
||||
|
||||
# Mock get_issues method
|
||||
mock_repo.get_issues.return_value = MagicMock()
|
||||
mock_repo.get_issues.return_value.get_page.side_effect = [
|
||||
[mock_issue1, mock_issue2],
|
||||
[],
|
||||
]
|
||||
|
||||
# Mock SerializedRepository.to_Repository to return our mock repo
|
||||
with patch.object(SerializedRepository, "to_Repository", return_value=mock_repo):
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
outputs = load_everything_from_checkpoint_connector(
|
||||
github_connector, 0, end_time
|
||||
)
|
||||
|
||||
# Check that we only got issues
|
||||
assert len(outputs) >= 2
|
||||
assert len(outputs[1].items) == 2
|
||||
assert all(
|
||||
isinstance(doc, Document) and "issues" in doc.id for doc in outputs[0].items
|
||||
) # All documents should be issues
|
||||
assert outputs[1].next_checkpoint.has_more
|
||||
|
||||
assert outputs[-1].next_checkpoint.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status_code,expected_exception,expected_message",
|
||||
[
|
||||
(
|
||||
401,
|
||||
CredentialExpiredError,
|
||||
"GitHub credential appears to be invalid or expired",
|
||||
),
|
||||
(
|
||||
403,
|
||||
InsufficientPermissionsError,
|
||||
"Your GitHub token does not have sufficient permissions",
|
||||
),
|
||||
(
|
||||
404,
|
||||
ConnectorValidationError,
|
||||
"GitHub repository not found",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_validate_connector_settings_errors(
|
||||
github_connector: GithubConnector,
|
||||
status_code: int,
|
||||
expected_exception: type[Exception],
|
||||
expected_message: str,
|
||||
) -> None:
|
||||
"""Test validation with various error scenarios"""
|
||||
error = GithubException(status=status_code, data={}, headers={})
|
||||
|
||||
github_client = cast(Github, github_connector.github_client)
|
||||
get_repo_mock = cast(MagicMock, github_client.get_repo)
|
||||
get_repo_mock.side_effect = error
|
||||
|
||||
with pytest.raises(expected_exception) as excinfo:
|
||||
github_connector.validate_connector_settings()
|
||||
assert expected_message in str(excinfo.value)
|
||||
|
||||
|
||||
def test_validate_connector_settings_success(
|
||||
github_connector: GithubConnector,
|
||||
mock_github_client: MagicMock,
|
||||
create_mock_repo: Callable[..., MagicMock],
|
||||
) -> None:
|
||||
"""Test successful validation"""
|
||||
# Set up mocked repo
|
||||
mock_repo = create_mock_repo()
|
||||
github_connector.github_client = mock_github_client
|
||||
mock_github_client.get_repo.return_value = mock_repo
|
||||
|
||||
# Mock get_contents to simulate successful access
|
||||
mock_repo.get_contents.return_value = MagicMock()
|
||||
|
||||
github_connector.validate_connector_settings()
|
||||
github_connector.github_client.get_repo.assert_called_once_with(
|
||||
f"{github_connector.repo_owner}/{github_connector.repositories}"
|
||||
)
|
||||
@@ -0,0 +1,472 @@
import time
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import cast
from unittest.mock import call
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from requests.exceptions import HTTPError

from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.models import Document
from onyx.connectors.zendesk.connector import ZendeskClient
from onyx.connectors.zendesk.connector import ZendeskConnector
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector


@pytest.fixture
def mock_zendesk_client() -> MagicMock:
    """Create a mock Zendesk client"""
    mock = MagicMock(spec=ZendeskClient)
    mock.base_url = "https://test.zendesk.com/api/v2"
    mock.auth = ("test@example.com/token", "test_token")
    return mock


@pytest.fixture
def zendesk_connector(
    mock_zendesk_client: MagicMock,
) -> Generator[ZendeskConnector, None, None]:
    """Create a Zendesk connector with mocked client"""
    connector = ZendeskConnector(content_type="articles")
    connector.client = mock_zendesk_client
    yield connector


@pytest.fixture
def unmocked_zendesk_connector() -> Generator[ZendeskConnector, None, None]:
    """Create a Zendesk connector with unmocked client"""
    zendesk_connector = ZendeskConnector(content_type="articles")
    zendesk_connector.client = ZendeskClient(
        "test", "test@example.com/token", "test_token"
    )
    yield zendesk_connector


@pytest.fixture
def create_mock_article() -> Callable[..., dict[str, Any]]:
    def _create_mock_article(
        id: int = 1,
        title: str = "Test Article",
        body: str = "Test Content",
        updated_at: str = "2023-01-01T12:00:00Z",
        author_id: str = "123",
        label_names: list[str] | None = None,
        draft: bool = False,
    ) -> dict[str, Any]:
        """Helper to create a mock article"""
        return {
            "id": id,
            "title": title,
            "body": body,
            "updated_at": updated_at,
            "author_id": author_id,
            "label_names": label_names or [],
            "draft": draft,
            "html_url": f"https://test.zendesk.com/hc/en-us/articles/{id}",
        }

    return _create_mock_article


@pytest.fixture
def create_mock_ticket() -> Callable[..., dict[str, Any]]:
    def _create_mock_ticket(
        id: int = 1,
        subject: str = "Test Ticket",
        description: str = "Test Description",
        updated_at: str = "2023-01-01T12:00:00Z",
        submitter_id: str = "123",
        status: str = "open",
        priority: str = "normal",
        tags: list[str] | None = None,
        ticket_type: str = "question",
    ) -> dict[str, Any]:
        """Helper to create a mock ticket"""
        return {
            "id": id,
            "subject": subject,
            "description": description,
            "updated_at": updated_at,
            "submitter": submitter_id,
            "status": status,
            "priority": priority,
            "tags": tags or [],
            "type": ticket_type,
            "url": f"https://test.zendesk.com/agent/tickets/{id}",
        }

    return _create_mock_ticket


@pytest.fixture
def create_mock_author() -> Callable[..., dict[str, Any]]:
    def _create_mock_author(
        id: str = "123",
        name: str = "Test User",
        email: str = "test@example.com",
    ) -> dict[str, Any]:
        """Helper to create a mock author"""
        return {
            "user": {
                "id": id,
                "name": name,
                "email": email,
            }
        }

    return _create_mock_author


def test_load_from_checkpoint_articles_happy_path(
    zendesk_connector: ZendeskConnector,
    mock_zendesk_client: MagicMock,
    create_mock_article: Callable[..., dict[str, Any]],
    create_mock_author: Callable[..., dict[str, Any]],
) -> None:
    """Test loading articles from checkpoint - happy path"""
    # Set up mock responses
    mock_article1 = create_mock_article(id=1, title="Article 1")
    mock_article2 = create_mock_article(id=2, title="Article 2")
    mock_author = create_mock_author()

    # Mock API responses
    mock_zendesk_client.make_request.side_effect = [
        # First call: content tags
        {"records": []},
        # Second call: articles page
        {
            "articles": [mock_article1, mock_article2],
            "meta": {
                "has_more": False,
                "after_cursor": None,
            },
        },
        # Third call: author info
        mock_author,
    ]

    # Call load_from_checkpoint
    end_time = time.time()
    outputs = load_everything_from_checkpoint_connector(zendesk_connector, 0, end_time)

    # Check that we got the documents
    assert len(outputs) == 2
    assert outputs[0].next_checkpoint.cached_content_tags is not None

    assert len(outputs[1].items) == 2

    # Check first document
    doc1 = outputs[1].items[0]
    assert isinstance(doc1, Document)
    assert doc1.id == "article:1"
    assert doc1.semantic_identifier == "Article 1"
    assert doc1.source == DocumentSource.ZENDESK

    # Check second document
    doc2 = outputs[1].items[1]
    assert isinstance(doc2, Document)
    assert doc2.id == "article:2"
    assert doc2.semantic_identifier == "Article 2"
    assert doc2.source == DocumentSource.ZENDESK

    # Check checkpoint state
    assert not outputs[1].next_checkpoint.has_more


def test_load_from_checkpoint_tickets_happy_path(
    zendesk_connector: ZendeskConnector,
    mock_zendesk_client: MagicMock,
    create_mock_ticket: Callable[..., dict[str, Any]],
    create_mock_author: Callable[..., dict[str, Any]],
) -> None:
    """Test loading tickets from checkpoint - happy path"""
    # Configure connector for tickets
    zendesk_connector.content_type = "tickets"

    # Set up mock responses
    mock_ticket1 = create_mock_ticket(id=1, subject="Ticket 1")
    mock_ticket2 = create_mock_ticket(id=2, subject="Ticket 2")
    mock_author = create_mock_author()

    # Mock API responses
    mock_zendesk_client.make_request.side_effect = [
        # First call: content tags
        {"records": []},
        # Second call: tickets page
        {
            "tickets": [mock_ticket1, mock_ticket2],
            "end_of_stream": True,
            "end_time": int(time.time()),
        },
        # Third call: author info
        mock_author,
        # Fourth call: comments page
        {"comments": []},
        # Fifth call: comments page
        {"comments": []},
    ]

    zendesk_connector.client = mock_zendesk_client

    # Call load_from_checkpoint
    end_time = time.time()
    outputs = load_everything_from_checkpoint_connector(zendesk_connector, 0, end_time)

    # Check that we got the documents
    assert len(outputs) == 2
    assert outputs[0].next_checkpoint.cached_content_tags is not None
    assert len(outputs[1].items) == 2

    # Check first document
    doc1 = outputs[1].items[0]
    print(doc1, type(doc1))
    assert isinstance(doc1, Document)
    assert doc1.id == "zendesk_ticket_1"
    assert doc1.semantic_identifier == "Ticket #1: Ticket 1"
    assert doc1.source == DocumentSource.ZENDESK

    # Check second document
    doc2 = outputs[1].items[1]
    assert isinstance(doc2, Document)
    assert doc2.id == "zendesk_ticket_2"
    assert doc2.semantic_identifier == "Ticket #2: Ticket 2"
    assert doc2.source == DocumentSource.ZENDESK

    # Check checkpoint state
    assert not outputs[1].next_checkpoint.has_more


def test_load_from_checkpoint_with_rate_limit(
    unmocked_zendesk_connector: ZendeskConnector,
    create_mock_article: Callable[..., dict[str, Any]],
    create_mock_author: Callable[..., dict[str, Any]],
) -> None:
    """Test loading from checkpoint with rate limit handling"""
    zendesk_connector = unmocked_zendesk_connector
    # Set up mock responses
    mock_article = create_mock_article()
    mock_author = create_mock_author()
    author_response = MagicMock()
    author_response.status_code = 200
    author_response.json.return_value = mock_author

    # Create mock responses for requests.get
    rate_limit_response = MagicMock()
    rate_limit_response.status_code = 429
    rate_limit_response.headers = {"Retry-After": "60"}
    rate_limit_response.raise_for_status.side_effect = HTTPError(
        response=rate_limit_response
    )

    success_response = MagicMock()
    success_response.status_code = 200
    success_response.json.return_value = {
        "articles": [mock_article],
        "meta": {
            "has_more": False,
            "after_cursor": None,
        },
    }

    # Mock requests.get to simulate rate limit then success
    with patch("onyx.connectors.zendesk.connector.requests.get") as mock_get:
        mock_get.side_effect = [
            # First call: content tags
            MagicMock(
                status_code=200,
                json=lambda: {"records": [], "meta": {"has_more": False}},
            ),
            # Second call: articles page (rate limited)
            rate_limit_response,
            # Third call: articles page (after rate limit)
            success_response,
            # Fourth call: author info
            author_response,
        ]

        # Call load_from_checkpoint
        end_time = time.time()
        with patch("onyx.connectors.zendesk.connector.time.sleep") as mock_sleep:
            outputs = load_everything_from_checkpoint_connector(
                zendesk_connector, 0, end_time
            )
            mock_sleep.assert_has_calls([call(60), call(0.1)])

        # Check that we got the document after rate limit was handled
        assert len(outputs) == 2
        assert outputs[0].next_checkpoint.cached_content_tags is not None
        assert len(outputs[1].items) == 1
        assert isinstance(outputs[1].items[0], Document)
        assert outputs[1].items[0].id == "article:1"

        # Verify the requests were made with correct parameters
        assert mock_get.call_count == 4
        # First call should be for content tags
        args, kwargs = mock_get.call_args_list[0]
        assert "guide/content_tags" in args[0]
        # Second call should be for articles (rate limited)
        args, kwargs = mock_get.call_args_list[1]
        assert "help_center/articles" in args[0]
        # Third call should be for articles (success)
        args, kwargs = mock_get.call_args_list[2]
        assert "help_center/articles" in args[0]
        # Fourth call should be for author info
        args, kwargs = mock_get.call_args_list[3]
        assert "users/123" in args[0]


def test_load_from_checkpoint_with_empty_response(
    zendesk_connector: ZendeskConnector,
    mock_zendesk_client: MagicMock,
) -> None:
    """Test loading from checkpoint with empty response"""
    # Mock API responses
    mock_zendesk_client.make_request.side_effect = [
        # First call: content tags
        {"records": []},
        # Second call: empty articles page
        {
            "articles": [],
            "meta": {
                "has_more": False,
                "after_cursor": None,
            },
        },
    ]

    # Call load_from_checkpoint
    end_time = time.time()
    outputs = load_everything_from_checkpoint_connector(zendesk_connector, 0, end_time)

    # Check that we got no documents
    assert len(outputs) == 2
    assert outputs[0].next_checkpoint.cached_content_tags is not None
    assert len(outputs[1].items) == 0
    assert not outputs[1].next_checkpoint.has_more


def test_load_from_checkpoint_with_skipped_article(
    zendesk_connector: ZendeskConnector,
    mock_zendesk_client: MagicMock,
    create_mock_article: Callable[..., dict[str, Any]],
) -> None:
    """Test loading from checkpoint with an article that should be skipped"""
    # Set up mock responses with a draft article
    mock_article = create_mock_article(draft=True)
    mock_zendesk_client.make_request.side_effect = [
        # First call: content tags
        {"records": []},
        # Second call: articles page with draft article
        {
            "articles": [mock_article],
            "meta": {
                "has_more": False,
                "after_cursor": None,
            },
        },
    ]

    # Call load_from_checkpoint
    end_time = time.time()
    outputs = load_everything_from_checkpoint_connector(zendesk_connector, 0, end_time)

    # Check that no documents were returned
    assert len(outputs) == 2
    assert outputs[0].next_checkpoint.cached_content_tags is not None
    assert len(outputs[1].items) == 0
    assert not outputs[1].next_checkpoint.has_more


def test_load_from_checkpoint_with_skipped_ticket(
    zendesk_connector: ZendeskConnector,
    mock_zendesk_client: MagicMock,
    create_mock_ticket: Callable[..., dict[str, Any]],
) -> None:
    """Test loading from checkpoint with a deleted ticket"""
    # Configure connector for tickets
    zendesk_connector.content_type = "tickets"

    # Set up mock responses with a deleted ticket
    mock_ticket = create_mock_ticket(status="deleted")
    mock_zendesk_client.make_request.side_effect = [
        # First call: content tags
        {"records": []},
        # Second call: tickets page with deleted ticket
        {
            "tickets": [mock_ticket],
            "end_of_stream": True,
            "end_time": int(time.time()),
        },
    ]

    # Call load_from_checkpoint
    end_time = time.time()
    outputs = load_everything_from_checkpoint_connector(zendesk_connector, 0, end_time)

    # Check that no documents were returned
    assert len(outputs) == 2
    assert outputs[0].next_checkpoint.cached_content_tags is not None
    assert len(outputs[1].items) == 0
    assert not outputs[1].next_checkpoint.has_more


@pytest.mark.parametrize(
    "status_code,expected_exception,expected_message",
    [
        (
            401,
            CredentialExpiredError,
            "Your Zendesk credentials appear to be invalid or expired",
        ),
        (
            403,
            InsufficientPermissionsError,
            "Your Zendesk token does not have sufficient permissions",
        ),
        (
            404,
            ConnectorValidationError,
            "Zendesk resource not found",
        ),
    ],
)
def test_validate_connector_settings_errors(
    zendesk_connector: ZendeskConnector,
    status_code: int,
    expected_exception: type[Exception],
    expected_message: str,
) -> None:
    """Test validation with various error scenarios"""
    mock_response = MagicMock()
    mock_response.status_code = status_code
    error = HTTPError(response=mock_response)

    mock_zendesk_client = cast(MagicMock, zendesk_connector.client)
    mock_zendesk_client.make_request.side_effect = error

    with pytest.raises(expected_exception) as excinfo:
        print("excinfo", excinfo)
        zendesk_connector.validate_connector_settings()

    assert expected_message in str(excinfo.value)


def test_validate_connector_settings_success(
    zendesk_connector: ZendeskConnector,
    mock_zendesk_client: MagicMock,
) -> None:
    """Test successful validation"""
    # Mock successful API response
    mock_zendesk_client.make_request.return_value = {
        "articles": [],
        "meta": {"has_more": False},
    }

    zendesk_connector.validate_connector_settings()
@@ -89,7 +89,8 @@ def test_run_in_background_and_wait_success() -> None:
    elapsed = time.time() - start_time

    assert result == 42
    assert elapsed >= 0.1  # Verify we actually waited for the sleep
    # sometimes slightly flaky
    assert elapsed >= 0.095  # Verify we actually waited for the sleep


@pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
@@ -5,7 +5,7 @@ envsubst '$DOMAIN $SSL_CERT_FILE_NAME $SSL_CERT_KEY_FILE_NAME' < "/etc/nginx/con
echo "Waiting for API server to boot up; this may take a minute or two..."
echo "If this takes more than ~5 minutes, check the logs of the API server container for errors with the following command:"
echo
echo "docker logs onyx-stack_api_server-1"
echo "docker logs onyx-stack-api_server-1"
echo

while true; do
@@ -129,6 +129,9 @@ services:
      options:
        max-size: "50m"
        max-file: "6"
    # optional, only for debugging purposes
    volumes:
      - api_server_logs:/var/log

  background:
    image: onyxdotapp/onyx-backend:${IMAGE_TAG:-latest}

@@ -256,7 +259,7 @@ services:
      - "host.docker.internal:host-gateway"
    # optional, only for debugging purposes
    volumes:
      - log_store:/var/log/persisted-logs
      - background_logs:/var/log
    logging:
      driver: json-file
      options:

@@ -325,6 +328,8 @@ services:
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - model_cache_huggingface:/root/.cache/huggingface/
      # optional, only for debugging purposes
      - inference_model_server_logs:/var/log
    logging:
      driver: json-file
      options:

@@ -357,6 +362,8 @@ services:
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - indexing_huggingface_model_cache:/root/.cache/huggingface/
      # optional, only for debugging purposes
      - indexing_model_server_logs:/var/log
    logging:
      driver: json-file
      options:

@@ -434,4 +441,8 @@ volumes:

  model_cache_huggingface:
  indexing_huggingface_model_cache:
  log_store: # for logs that we don't want to lose on container restarts
  # for logs that we don't want to lose on container restarts
  api_server_logs:
  background_logs:
  inference_model_server_logs:
  indexing_model_server_logs:

@@ -106,6 +106,9 @@ services:
      options:
        max-size: "50m"
        max-file: "6"
    volumes:
      # optional, only for debugging purposes
      - api_server_logs:/var/log

  background:
    image: onyxdotapp/onyx-backend:${IMAGE_TAG:-latest}

@@ -211,7 +214,7 @@ services:
      - "host.docker.internal:host-gateway"
    # optional, only for debugging purposes
    volumes:
      - log_store:/var/log/persisted-logs
      - background_logs:/var/log
    logging:
      driver: json-file
      options:

@@ -273,6 +276,8 @@ services:
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - model_cache_huggingface:/root/.cache/huggingface/
      # optional, only for debugging purposes
      - inference_model_server_logs:/var/log
    logging:
      driver: json-file
      options:

@@ -310,6 +315,8 @@ services:
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - indexing_huggingface_model_cache:/root/.cache/huggingface/
      # optional, only for debugging purposes
      - indexing_model_server_logs:/var/log
    logging:
      driver: json-file
      options:

@@ -387,4 +394,8 @@ volumes:
  # Created by the container itself
  model_cache_huggingface:
  indexing_huggingface_model_cache:
  log_store: # for logs that we don't want to lose on container restarts
  # for logs that we don't want to lose on container restarts
  api_server_logs:
  background_logs:
  inference_model_server_logs:
  indexing_model_server_logs:

@@ -244,8 +244,6 @@ services:
    # - ./bundle.pem:/app/bundle.pem:ro
    extra_hosts:
      - "host.docker.internal:host-gateway"
    volumes:
      - log_store:/var/log/persisted-logs
    logging:
      driver: json-file
      options:

@@ -423,4 +421,3 @@ volumes:

  model_cache_huggingface:
  indexing_huggingface_model_cache:
  log_store: # for logs that we don't want to lose on container restarts

@@ -54,9 +54,6 @@ services:
      - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
    extra_hosts:
      - "host.docker.internal:host-gateway"
    # optional, only for debugging purposes
    volumes:
      - log_store:/var/log/persisted-logs
    logging:
      driver: json-file
      options:

@@ -236,4 +233,3 @@ volumes:
  # Created by the container itself
  model_cache_huggingface:
  indexing_huggingface_model_cache:
  log_store: # for logs that we don't want to lose on container restarts

@@ -36,6 +36,10 @@ services:
      options:
        max-size: "50m"
        max-file: "6"
    volumes:
      # optional, only for debugging purposes
      - api_server_logs:/var/log


  background:
    image: onyxdotapp/onyx-backend:${IMAGE_TAG:-latest}

@@ -69,7 +73,7 @@ services:
    extra_hosts:
      - "host.docker.internal:host-gateway"
    volumes:
      - log_store:/var/log/persisted-logs
      - background_logs:/var/log
    logging:
      driver: json-file
      options:

@@ -122,6 +126,8 @@ services:
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - model_cache_huggingface:/root/.cache/huggingface/
      # optional, only for debugging purposes
      - inference_model_server_logs:/var/log
    logging:
      driver: json-file
      options:

@@ -150,6 +156,8 @@ services:
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - indexing_huggingface_model_cache:/root/.cache/huggingface/
      # optional, only for debugging purposes
      - indexing_model_server_logs:/var/log
    logging:
      driver: json-file
      options:

@@ -231,4 +239,8 @@ volumes:
  # Created by the container itself
  model_cache_huggingface:
  indexing_huggingface_model_cache:
  log_store: # for logs that we don't want to lose on container restarts
  # for logs that we don't want to lose on container restarts
  api_server_logs:
  background_logs:
  inference_model_server_logs:
  indexing_model_server_logs:

@@ -32,13 +32,14 @@ services:
    # - ./bundle.pem:/app/bundle.pem:ro
    extra_hosts:
      - "host.docker.internal:host-gateway"
    volumes:
      - log_store:/var/log/persisted-logs
    logging:
      driver: json-file
      options:
        max-size: "50m"
        max-file: "6"
    volumes:
      - api_server_logs:/var/log

  background:
    image: onyxdotapp/onyx-backend:${IMAGE_TAG:-latest}
    build:

@@ -76,7 +77,7 @@ services:
    extra_hosts:
      - "host.docker.internal:host-gateway"
    volumes:
      - log_store:/var/log/persisted-logs
      - background_logs:/var/log
    logging:
      driver: json-file
      options:

@@ -152,6 +153,8 @@ services:
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - model_cache_huggingface:/root/.cache/huggingface/
      # optional, only for debugging purposes
      - inference_model_server_logs:/var/log
    logging:
      driver: json-file
      options:

@@ -180,6 +183,8 @@ services:
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - indexing_huggingface_model_cache:/root/.cache/huggingface/
      # optional, only for debugging purposes
      - indexing_model_server_logs:/var/log
    logging:
      driver: json-file
      options:

@@ -264,4 +269,8 @@ volumes:
  # Created by the container itself
  model_cache_huggingface:
  indexing_huggingface_model_cache:
  log_store: # for logs that we don't want to lose on container restarts
  # for logs that we don't want to lose on container restarts
  api_server_logs:
  background_logs:
  inference_model_server_logs:
  indexing_model_server_logs:

@@ -63,7 +63,7 @@ services:
    extra_hosts:
      - "host.docker.internal:host-gateway"
    volumes:
      - log_store:/var/log/persisted-logs
      - log_store:/var/log
    logging:
      driver: json-file
      options:

1 openapi.json (new file)
File diff suppressed because one or more lines are too long
@@ -45,7 +45,7 @@ export function ActionsTable({ tools }: { tools: ToolSnapshot[] }) {
              className="mr-1 my-auto cursor-pointer"
              onClick={() =>
                router.push(
                  `/admin/tools/edit/${tool.id}?u=${Date.now()}`
                  `/admin/actions/edit/${tool.id}?u=${Date.now()}`
                )
              }
            />

@@ -1079,7 +1079,7 @@ export function AssistantEditor({
                </Tooltip>
              </TooltipProvider>
              <span className="text-sm ml-2">
                {values.is_public ? "Public" : "Private"}
                Organization Public
              </span>
            </div>

@@ -1088,17 +1088,22 @@ export function AssistantEditor({
                <InfoIcon size={16} className="mr-2" />
                <span className="text-sm">
                  Default persona must be public. Visibility has been
                  automatically set to public.
                  automatically set to organization public.
                </span>
              </div>
            )}

            {values.is_public ? (
              <p className="text-sm text-text-dark">
                Anyone from your team can view and use this assistant
                This assistant will be available to everyone in your
                organization
              </p>
            ) : (
              <>
                <p className="text-sm text-text-dark mb-2">
                  This assistant will only be available to specific
                  users and groups
                </p>
                <div className="mt-2">
                  <Label className="mb-2" small>
                    Share with Users and Groups

@@ -254,14 +254,14 @@ export function SlackChannelConfigFormFields({
                onSearchTermChange={(term) => {
                  form.setFieldValue("channel_name", term);
                }}
                allowCustomValues={true}
              />
            )}
          </Field>
          <p className="mt-2 text-sm dark:text-neutral-400 text-neutral-600">
            Note: This list shows public and private channels where the
            bot is a member (up to 500 channels). If you don't see a
            channel, make sure the bot is added to that channel in Slack
            first, or type the channel name manually.
            Note: This list shows existing public and private channels (up
            to 500). You can either select from the list or type any
            channel name directly.
          </p>
        </>
      )}

@@ -281,7 +281,7 @@ export default function AddConnector({
  return (
    <Formik
      initialValues={{
        ...createConnectorInitialValues(connector),
        ...createConnectorInitialValues(connector, currentCredential),
        ...Object.fromEntries(
          connectorConfigs[connector].advanced_values.map((field) => [
            field.name,

@@ -1,4 +1,3 @@
|
||||
import { Button } from "@/components/Button";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
@@ -8,10 +7,14 @@ import { adminDeleteCredential } from "@/lib/credential";
|
||||
import { setupGoogleDriveOAuth } from "@/lib/googleDrive";
|
||||
import { GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME } from "@/lib/constants";
|
||||
import Cookies from "js-cookie";
|
||||
import { TextFormField } from "@/components/admin/connectors/Field";
|
||||
import {
|
||||
TextFormField,
|
||||
SectionHeader,
|
||||
SubLabel,
|
||||
} from "@/components/admin/connectors/Field";
|
||||
import { Form, Formik } from "formik";
|
||||
import { User } from "@/lib/types";
|
||||
import { Button as TremorButton } from "@/components/ui/button";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Credential,
|
||||
GoogleDriveCredentialJson,
|
||||
@@ -20,6 +23,15 @@ import {
|
||||
import { refreshAllGoogleData } from "@/lib/googleConnector";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
|
||||
import {
|
||||
FiFile,
|
||||
FiUpload,
|
||||
FiTrash2,
|
||||
FiCheck,
|
||||
FiLink,
|
||||
FiAlertTriangle,
|
||||
} from "react-icons/fi";
|
||||
import { cn, truncateString } from "@/lib/utils";
|
||||
|
||||
type GoogleDriveCredentialJsonTypes = "authorized_user" | "service_account";
|
||||
|
||||
@@ -31,126 +43,202 @@ export const DriveJsonUpload = ({
|
||||
onSuccess?: () => void;
|
||||
}) => {
|
||||
const { mutate } = useSWRConfig();
|
||||
const [credentialJsonStr, setCredentialJsonStr] = useState<
|
||||
string | undefined
|
||||
>();
|
||||
const [isUploading, setIsUploading] = useState(false);
|
||||
const [fileName, setFileName] = useState<string | undefined>();
|
||||
const [isDragging, setIsDragging] = useState(false);
|
||||
|
||||
const handleFileUpload = async (file: File) => {
|
||||
setIsUploading(true);
|
||||
setFileName(file.name);
|
||||
|
||||
const reader = new FileReader();
|
||||
reader.onload = async (loadEvent) => {
|
||||
if (!loadEvent?.target?.result) {
|
||||
setIsUploading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const credentialJsonStr = loadEvent.target.result as string;
|
||||
|
||||
// Check credential type
|
||||
let credentialFileType: GoogleDriveCredentialJsonTypes;
|
||||
try {
|
||||
const appCredentialJson = JSON.parse(credentialJsonStr);
|
||||
if (appCredentialJson.web) {
|
||||
credentialFileType = "authorized_user";
|
||||
} else if (appCredentialJson.type === "service_account") {
|
||||
credentialFileType = "service_account";
|
||||
} else {
|
||||
throw new Error(
|
||||
"Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
setPopup({
|
||||
message: `Invalid file provided - ${e}`,
|
||||
type: "error",
|
||||
});
|
||||
setIsUploading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (credentialFileType === "authorized_user") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/app-credential",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded app credentials",
|
||||
type: "success",
|
||||
});
|
||||
mutate("/api/manage/admin/connector/google-drive/app-credential");
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload app credentials - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (credentialFileType === "service_account") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded service account key",
|
||||
type: "success",
|
||||
});
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key"
|
||||
);
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload service account key - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
setIsUploading(false);
|
||||
};
|
||||
|
||||
reader.readAsText(file);
|
||||
};
|
||||
|
||||
const handleDragEnter = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
if (!isUploading) {
|
||||
setIsDragging(true);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDragLeave = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setIsDragging(false);
|
||||
};
|
||||
|
||||
const handleDragOver = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
};
|
||||
|
||||
const handleDrop = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setIsDragging(false);
|
||||
|
||||
if (isUploading) return;
|
||||
|
||||
const files = e.dataTransfer.files;
|
||||
if (files.length > 0) {
|
||||
const file = files[0];
|
||||
if (file.type === "application/json" || file.name.endsWith(".json")) {
|
||||
handleFileUpload(file);
|
||||
} else {
|
||||
setPopup({
|
||||
message: "Please upload a JSON file",
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<input
|
||||
className={
|
||||
"mr-3 text-sm text-text-900 border border-background-300 " +
|
||||
"cursor-pointer bg-backgrournd dark:text-text-400 focus:outline-none " +
|
||||
"dark:bg-background-700 dark:border-background-600 dark:placeholder-text-400"
|
||||
}
|
||||
type="file"
|
||||
accept=".json"
|
||||
onChange={(event) => {
|
||||
if (!event.target.files) {
|
||||
return;
|
||||
}
|
||||
const file = event.target.files[0];
|
||||
const reader = new FileReader();
|
||||
|
||||
reader.onload = function (loadEvent) {
|
||||
if (!loadEvent?.target?.result) {
|
||||
return;
|
||||
}
|
||||
const fileContents = loadEvent.target.result;
|
||||
setCredentialJsonStr(fileContents as string);
|
||||
};
|
||||
|
||||
reader.readAsText(file);
|
||||
}}
|
||||
/>
|
||||
|
||||
<Button
|
||||
disabled={!credentialJsonStr}
|
||||
onClick={async () => {
|
||||
let credentialFileType: GoogleDriveCredentialJsonTypes;
|
||||
try {
|
||||
const appCredentialJson = JSON.parse(credentialJsonStr!);
|
||||
if (appCredentialJson.web) {
|
||||
credentialFileType = "authorized_user";
|
||||
} else if (appCredentialJson.type === "service_account") {
|
||||
credentialFileType = "service_account";
|
||||
} else {
|
||||
throw new Error(
|
||||
"Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
setPopup({
|
||||
message: `Invalid file provided - ${e}`,
|
||||
type: "error",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (credentialFileType === "authorized_user") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/app-credential",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded app credentials",
|
||||
type: "success",
|
||||
});
|
||||
mutate("/api/manage/admin/connector/google-drive/app-credential");
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload app credentials - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (credentialFileType === "service_account") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded service account key",
|
||||
type: "success",
|
||||
});
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key"
|
||||
);
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload service account key - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
}}
|
||||
>
|
||||
Upload
|
||||
</Button>
|
||||
</>
|
||||
<div className="flex flex-col mt-4">
|
||||
<div className="flex items-center">
|
||||
<div className="relative flex flex-1 items-center">
|
||||
<label
|
||||
className={cn(
|
||||
"flex h-10 items-center justify-center w-full px-4 py-2 border border-dashed rounded-md transition-colors",
|
||||
isUploading
|
||||
? "opacity-70 cursor-not-allowed border-background-400 bg-background-50/30"
|
||||
: isDragging
|
||||
? "bg-background-50/50 border-primary dark:border-primary"
|
||||
: "cursor-pointer hover:bg-background-50/30 hover:border-primary dark:hover:border-primary border-background-300 dark:border-background-600"
|
||||
)}
|
||||
onDragEnter={handleDragEnter}
|
||||
onDragLeave={handleDragLeave}
|
||||
onDragOver={handleDragOver}
|
||||
onDrop={handleDrop}
|
||||
>
|
||||
<div className="flex items-center space-x-2">
|
||||
{isUploading ? (
|
||||
<div className="h-4 w-4 border-t-2 border-b-2 border-primary rounded-full animate-spin"></div>
|
||||
) : (
|
||||
<FiFile className="h-4 w-4 text-text-500" />
|
||||
)}
|
||||
<span className="text-sm text-text-500">
|
||||
{isUploading
|
||||
? `Uploading ${truncateString(fileName || "file", 50)}...`
|
||||
: isDragging
|
||||
? "Drop JSON file here"
|
||||
: truncateString(
|
||||
fileName || "Select or drag JSON credentials file...",
|
||||
50
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
<input
|
||||
className="sr-only"
|
||||
type="file"
|
||||
accept=".json"
|
||||
disabled={isUploading}
|
||||
onChange={(event) => {
|
||||
if (!event.target.files?.length) {
|
||||
return;
|
||||
}
|
||||
const file = event.target.files[0];
|
||||
handleFileUpload(file);
|
||||
}}
|
||||
/>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -160,6 +248,7 @@ interface DriveJsonUploadSectionProps {
|
||||
serviceAccountCredentialData?: { service_account_email: string };
|
||||
isAdmin: boolean;
|
||||
onSuccess?: () => void;
|
||||
existingAuthCredential?: boolean;
|
||||
}
|
||||
|
||||
export const DriveJsonUploadSection = ({
|
||||
@@ -168,6 +257,7 @@ export const DriveJsonUploadSection = ({
|
||||
serviceAccountCredentialData,
|
||||
isAdmin,
|
||||
onSuccess,
|
||||
existingAuthCredential,
|
||||
}: DriveJsonUploadSectionProps) => {
|
||||
const { mutate } = useSWRConfig();
|
||||
const router = useRouter();
|
||||
@@ -177,6 +267,7 @@ export const DriveJsonUploadSection = ({
|
||||
const [localAppCredentialData, setLocalAppCredentialData] =
|
||||
useState(appCredentialData);
|
||||
|
||||
// Update local state when props change
|
||||
useEffect(() => {
|
||||
setLocalServiceAccountData(serviceAccountCredentialData);
|
||||
setLocalAppCredentialData(appCredentialData);
|
||||
@@ -190,153 +281,135 @@ export const DriveJsonUploadSection = ({
|
||||
}
|
||||
};
|
||||
|
||||
if (localServiceAccountData?.service_account_email) {
|
||||
return (
|
||||
<div className="mt-2 text-sm">
|
||||
<div>
|
||||
Found existing service account key with the following <b>Email:</b>
|
||||
<p className="italic mt-1">
|
||||
{localServiceAccountData.service_account_email}
|
||||
</p>
|
||||
</div>
|
||||
{isAdmin ? (
|
||||
<>
|
||||
<div className="mt-4 mb-1">
|
||||
If you want to update these credentials, delete the existing
|
||||
credentials through the button below, and then upload a new
|
||||
credentials JSON.
|
||||
</div>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key",
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key"
|
||||
);
|
||||
mutate(
|
||||
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive)
|
||||
);
|
||||
setPopup({
|
||||
message: "Successfully deleted service account key",
|
||||
type: "success",
|
||||
});
|
||||
setLocalServiceAccountData(undefined);
|
||||
handleSuccess();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete service account key - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div className="mt-4 mb-1">
|
||||
To change these credentials, please contact an administrator.
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (localAppCredentialData?.client_id) {
|
||||
return (
|
||||
<div className="mt-2 text-sm">
|
||||
<div>
|
||||
Found existing app credentials with the following <b>Client ID:</b>
|
||||
<p className="italic mt-1">{localAppCredentialData.client_id}</p>
|
||||
</div>
|
||||
{isAdmin ? (
|
||||
<>
|
||||
<div className="mt-4 mb-1">
|
||||
If you want to update these credentials, delete the existing
|
||||
credentials through the button below, and then upload a new
|
||||
credentials JSON.
|
||||
</div>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/app-credential",
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/app-credential"
|
||||
);
|
||||
mutate(
|
||||
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive)
|
||||
);
|
||||
setPopup({
|
||||
message: "Successfully deleted app credentials",
|
||||
type: "success",
|
||||
});
|
||||
setLocalAppCredentialData(undefined);
|
||||
handleSuccess();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete app credential - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</>
|
||||
) : (
|
||||
<div className="mt-4 mb-1">
|
||||
To change these credentials, please contact an administrator.
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!isAdmin) {
|
||||
return (
|
||||
<div className="mt-2">
|
||||
<p className="text-sm mb-2">
|
||||
Curators are unable to set up the google drive credentials. To add a
|
||||
Google Drive connector, please contact an administrator.
|
||||
</p>
|
||||
<div>
|
||||
<div className="flex items-start py-3 px-4 bg-yellow-50/30 dark:bg-yellow-900/5 rounded">
|
||||
<FiAlertTriangle className="text-yellow-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
|
||||
<p className="text-sm">
|
||||
Curators are unable to set up the Google Drive credentials. To add a
|
||||
Google Drive connector, please contact an administrator.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="mt-2">
|
||||
<p className="text-sm mb-2">
|
||||
Follow the guide{" "}
|
||||
<div>
|
||||
<p className="text-sm mb-3">
|
||||
To connect your Google Drive, create credentials (either OAuth App or
|
||||
Service Account), download the JSON file, and upload it below.
|
||||
</p>
|
||||
<div className="mb-4">
|
||||
<a
|
||||
className="text-link"
|
||||
className="text-primary hover:text-primary/80 flex items-center gap-1 text-sm"
|
||||
target="_blank"
|
||||
href="https://docs.onyx.app/connectors/google_drive#authorization"
|
||||
rel="noreferrer"
|
||||
>
|
||||
here
|
||||
</a>{" "}
|
||||
to either (1) setup a google OAuth App in your company workspace or (2)
|
||||
create a Service Account.
|
||||
<br />
|
||||
<br />
|
||||
Download the credentials JSON if choosing option (1) or the Service
|
||||
Account key JSON if chooosing option (2), and upload it here.
|
||||
</p>
|
||||
<DriveJsonUpload setPopup={setPopup} onSuccess={handleSuccess} />
|
||||
<FiLink className="h-3 w-3" />
|
||||
View detailed setup instructions
|
||||
</a>
|
||||
</div>
|
||||
|
||||
{(localServiceAccountData?.service_account_email ||
|
||||
localAppCredentialData?.client_id) && (
|
||||
<div className="mb-4">
|
||||
<div className="relative flex flex-1 items-center">
|
||||
<label
|
||||
className={cn(
|
||||
"flex h-10 items-center justify-center w-full px-4 py-2 border border-dashed rounded-md transition-colors",
|
||||
false
|
||||
? "opacity-70 cursor-not-allowed border-background-400 bg-background-50/30"
|
||||
: "cursor-pointer hover:bg-background-50/30 hover:border-primary dark:hover:border-primary border-background-300 dark:border-background-600"
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center space-x-2">
|
||||
{false ? (
|
||||
<div className="h-4 w-4 border-t-2 border-b-2 border-primary rounded-full animate-spin"></div>
|
||||
) : (
|
||||
<FiFile className="h-4 w-4 text-text-500" />
|
||||
)}
|
||||
<span className="text-sm text-text-500">
|
||||
{truncateString(
|
||||
localServiceAccountData?.service_account_email ||
|
||||
localAppCredentialData?.client_id ||
|
||||
"",
|
||||
50
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
</label>
|
||||
</div>
|
||||
{isAdmin && !existingAuthCredential && (
|
||||
<div className="mt-2">
|
||||
<Button
|
||||
variant="destructive"
|
||||
type="button"
|
||||
onClick={async () => {
|
||||
const endpoint =
|
||||
localServiceAccountData?.service_account_email
|
||||
? "/api/manage/admin/connector/google-drive/service-account-key"
|
||||
: "/api/manage/admin/connector/google-drive/app-credential";
|
||||
|
||||
const response = await fetch(endpoint, {
|
||||
method: "DELETE",
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
mutate(endpoint);
|
||||
// Also mutate the credential endpoints to ensure Step 2 is reset
|
||||
mutate(
|
||||
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive)
|
||||
);
|
||||
|
||||
// Add additional mutations to refresh all credential-related endpoints
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/credentials"
|
||||
);
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/public-credential"
|
||||
);
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/service-account-credential"
|
||||
);
|
||||
|
||||
setPopup({
|
||||
message: `Successfully deleted ${
|
||||
localServiceAccountData
|
||||
? "service account key"
|
||||
: "app credentials"
|
||||
}`,
|
||||
type: "success",
|
||||
});
|
||||
// Immediately update local state
|
||||
if (localServiceAccountData) {
|
||||
setLocalServiceAccountData(undefined);
|
||||
} else {
|
||||
setLocalAppCredentialData(undefined);
|
||||
}
|
||||
handleSuccess();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete credentials - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete Credentials
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!(
|
||||
localServiceAccountData?.service_account_email ||
|
||||
localAppCredentialData?.client_id
|
||||
) && <DriveJsonUpload setPopup={setPopup} onSuccess={handleSuccess} />}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -391,6 +464,7 @@ export const DriveAuthSection = ({
|
||||
user,
|
||||
}: DriveCredentialSectionProps) => {
|
||||
const router = useRouter();
|
||||
const [isAuthenticating, setIsAuthenticating] = useState(false);
|
||||
const [localServiceAccountData, setLocalServiceAccountData] = useState(
|
||||
serviceAccountKeyData
|
||||
);
|
||||
@@ -405,6 +479,7 @@ export const DriveAuthSection = ({
|
||||
setLocalGoogleDriveServiceAccountCredential,
|
||||
] = useState(googleDriveServiceAccountCredential);
|
||||
|
||||
// Update local state when props change
|
||||
useEffect(() => {
|
||||
setLocalServiceAccountData(serviceAccountKeyData);
|
||||
setLocalAppCredentialData(appCredentialData);
|
||||
@@ -424,126 +499,181 @@ export const DriveAuthSection = ({
|
||||
localGoogleDriveServiceAccountCredential;
|
||||
if (existingCredential) {
|
||||
return (
|
||||
<>
|
||||
<p className="mb-2 text-sm">
|
||||
<i>Uploaded and authenticated credential already exists!</i>
|
||||
</p>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
handleRevokeAccess(
|
||||
connectorAssociated,
|
||||
setPopup,
|
||||
existingCredential,
|
||||
refreshCredentials
|
||||
);
|
||||
}}
|
||||
>
|
||||
Revoke Access
|
||||
</Button>
|
||||
</>
|
||||
<div>
|
||||
<div className="mt-4">
|
||||
<div className="py-3 px-4 bg-blue-50/30 dark:bg-blue-900/5 rounded mb-4 flex items-start">
|
||||
<FiCheck className="text-blue-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
|
||||
<div className="flex-1">
|
||||
<span className="font-medium block">Authentication Complete</span>
|
||||
<p className="text-sm mt-1 text-text-500 dark:text-text-400 break-words">
|
||||
Your Google Drive credentials have been successfully uploaded
|
||||
and authenticated.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
variant="destructive"
|
||||
type="button"
|
||||
onClick={async () => {
|
||||
handleRevokeAccess(
|
||||
connectorAssociated,
|
||||
setPopup,
|
||||
existingCredential,
|
||||
refreshCredentials
|
||||
);
|
||||
}}
|
||||
>
|
||||
Revoke Access
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// If no credentials are uploaded, show message to complete step 1 first
|
||||
if (
|
||||
!localServiceAccountData?.service_account_email &&
|
||||
!localAppCredentialData?.client_id
|
||||
) {
|
||||
return (
|
||||
<div>
|
||||
<SectionHeader>Google Drive Authentication</SectionHeader>
|
||||
<div className="mt-4">
|
||||
<div className="flex items-start py-3 px-4 bg-yellow-50/30 dark:bg-yellow-900/5 rounded">
|
||||
<FiAlertTriangle className="text-yellow-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
|
||||
<p className="text-sm">
|
||||
Please complete Step 1 by uploading either OAuth credentials or a
|
||||
Service Account key before proceeding with authentication.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (localServiceAccountData?.service_account_email) {
return (
<div>
<Formik
initialValues={{
google_primary_admin: user?.email || "",
}}
validationSchema={Yup.object().shape({
google_primary_admin: Yup.string().required(
"User email is required"
),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
const response = await fetch(
"/api/manage/admin/connector/google-drive/service-account-credential",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
google_primary_admin: values.google_primary_admin,
}),
}
);
<div className="mt-4">
<Formik
initialValues={{
google_primary_admin: user?.email || "",
}}
validationSchema={Yup.object().shape({
google_primary_admin: Yup.string()
.email("Must be a valid email")
.required("Required"),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
try {
const response = await fetch(
"/api/manage/admin/connector/google-drive/service-account-credential",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
google_primary_admin: values.google_primary_admin,
}),
}
);

if (response.ok) {
setPopup({
message: "Successfully created service account credential",
type: "success",
});
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to create service account credential - ${errorMsg}`,
type: "error",
});
}
refreshCredentials();
}}
>
{({ isSubmitting }) => (
<Form>
<TextFormField
name="google_primary_admin"
label="Primary Admin Email:"
subtext="Enter the email of an admin/owner of the Google Organization that owns the Google Drive(s) you want to index."
/>
<div className="flex">
<TremorButton type="submit" disabled={isSubmitting}>
Create Credential
</TremorButton>
</div>
</Form>
)}
</Formik>
if (response.ok) {
setPopup({
message: "Successfully created service account credential",
type: "success",
});
refreshCredentials();
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to create service account credential - ${errorMsg}`,
type: "error",
});
}
} catch (error) {
setPopup({
message: `Failed to create service account credential - ${error}`,
type: "error",
});
} finally {
formikHelpers.setSubmitting(false);
}
}}
>
{({ isSubmitting }) => (
<Form>
<TextFormField
name="google_primary_admin"
label="Primary Admin Email:"
subtext="Enter the email of an admin/owner of the Google Organization that owns the Google Drive(s) you want to index."
/>
<div className="flex">
<Button type="submit" disabled={isSubmitting}>
{isSubmitting ? "Creating..." : "Create Credential"}
</Button>
</div>
</Form>
)}
</Formik>
</div>
</div>
);
}

if (localAppCredentialData?.client_id) {
return (
<div className="text-sm mb-4">
<p className="mb-2">
Next, you must provide credentials via OAuth. This gives us read
access to the docs you have access to in your google drive account.
</p>
<div>
<div className="bg-background-50/30 dark:bg-background-900/20 rounded mb-4">
<p className="text-sm">
Next, you need to authenticate with Google Drive via OAuth. This
gives us read access to the documents you have access to in your
Google Drive account.
</p>
</div>
<Button
disabled={isAuthenticating}
onClick={async () => {
const [authUrl, errorMsg] = await setupGoogleDriveOAuth({
isAdmin: true,
name: "OAuth (uploaded)",
});
if (authUrl) {
setIsAuthenticating(true);
try {
// cookie used by callback to determine where to finally redirect to
Cookies.set(GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME, "true", {
path: "/",
});
router.push(authUrl);
return;
}

setPopup({
message: errorMsg,
type: "error",
});
const [authUrl, errorMsg] = await setupGoogleDriveOAuth({
isAdmin: true,
name: "OAuth (uploaded)",
});

if (authUrl) {
router.push(authUrl);
} else {
setPopup({
message: errorMsg,
type: "error",
});
setIsAuthenticating(false);
}
} catch (error) {
setPopup({
message: `Failed to authenticate with Google Drive - ${error}`,
type: "error",
});
setIsAuthenticating(false);
}
}}
>
Authenticate with Google Drive
{isAuthenticating
? "Authenticating..."
: "Authenticate with Google Drive"}
</Button>
</div>
);
}

// case where no keys have been uploaded in step 1
return (
<p className="text-sm">
Please upload either a OAuth Client Credential JSON or a Google Drive
Service Account Key JSON in Step 1 before moving onto Step 2.
</p>
);
// This code path should not be reached with the new conditions above
return null;
};

@@ -165,6 +165,10 @@ const GDriveMain = ({
serviceAccountCredentialData={serviceAccountKeyData}
isAdmin={isAdmin}
onSuccess={handleRefresh}
existingAuthCredential={Boolean(
googleDrivePublicUploadedCredential ||
googleDriveServiceAccountCredential
)}
/>

{isAdmin &&

@@ -1,4 +1,4 @@
import { Button } from "@/components/Button";
import { Button } from "@/components/ui/button";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import React, { useState, useEffect } from "react";
import { useSWRConfig } from "swr";
@@ -8,7 +8,11 @@ import { adminDeleteCredential } from "@/lib/credential";
import { setupGmailOAuth } from "@/lib/gmail";
import { GMAIL_AUTH_IS_ADMIN_COOKIE_NAME } from "@/lib/constants";
import Cookies from "js-cookie";
import { TextFormField } from "@/components/admin/connectors/Field";
import {
TextFormField,
SectionHeader,
SubLabel,
} from "@/components/admin/connectors/Field";
import { Form, Formik } from "formik";
import { User } from "@/lib/types";
import CardSection from "@/components/admin/CardSection";
@@ -20,10 +24,19 @@ import {
import { refreshAllGoogleData } from "@/lib/googleConnector";
import { ValidSources } from "@/lib/types";
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
import {
FiFile,
FiUpload,
FiTrash2,
FiCheck,
FiLink,
FiAlertTriangle,
} from "react-icons/fi";
import { cn, truncateString } from "@/lib/utils";

type GmailCredentialJsonTypes = "authorized_user" | "service_account";

const DriveJsonUpload = ({
|
||||
const GmailCredentialUpload = ({
|
||||
setPopup,
|
||||
onSuccess,
|
||||
}: {
|
||||
@@ -31,134 +44,210 @@ const DriveJsonUpload = ({
|
||||
onSuccess?: () => void;
|
||||
}) => {
|
||||
const { mutate } = useSWRConfig();
|
||||
const [credentialJsonStr, setCredentialJsonStr] = useState<
|
||||
string | undefined
|
||||
>();
|
||||
const [isUploading, setIsUploading] = useState(false);
|
||||
const [fileName, setFileName] = useState<string | undefined>();
|
||||
const [isDragging, setIsDragging] = useState(false);
|
||||
|
||||
const handleFileUpload = async (file: File) => {
|
||||
setIsUploading(true);
|
||||
setFileName(file.name);
|
||||
|
||||
const reader = new FileReader();
|
||||
reader.onload = async (loadEvent) => {
|
||||
if (!loadEvent?.target?.result) {
|
||||
setIsUploading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const credentialJsonStr = loadEvent.target.result as string;
|
||||
|
||||
// Check credential type
|
||||
let credentialFileType: GmailCredentialJsonTypes;
|
||||
try {
|
||||
const appCredentialJson = JSON.parse(credentialJsonStr);
|
||||
if (appCredentialJson.web) {
|
||||
credentialFileType = "authorized_user";
|
||||
} else if (appCredentialJson.type === "service_account") {
|
||||
credentialFileType = "service_account";
|
||||
} else {
|
||||
throw new Error(
|
||||
"Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
setPopup({
|
||||
message: `Invalid file provided - ${e}`,
|
||||
type: "error",
|
||||
});
|
||||
setIsUploading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (credentialFileType === "authorized_user") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/app-credential",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded app credentials",
|
||||
type: "success",
|
||||
});
|
||||
mutate("/api/manage/admin/connector/gmail/app-credential");
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload app credentials - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (credentialFileType === "service_account") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/service-account-key",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded service account key",
|
||||
type: "success",
|
||||
});
|
||||
mutate("/api/manage/admin/connector/gmail/service-account-key");
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload service account key - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
setIsUploading(false);
|
||||
};
|
||||
|
||||
reader.readAsText(file);
|
||||
};
|
||||
|
||||
const handleDragEnter = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
if (!isUploading) {
|
||||
setIsDragging(true);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDragLeave = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setIsDragging(false);
|
||||
};
|
||||
|
||||
const handleDragOver = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
};
|
||||
|
||||
const handleDrop = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setIsDragging(false);
|
||||
|
||||
if (isUploading) return;
|
||||
|
||||
const files = e.dataTransfer.files;
|
||||
if (files.length > 0) {
|
||||
const file = files[0];
|
||||
if (file.type === "application/json" || file.name.endsWith(".json")) {
|
||||
handleFileUpload(file);
|
||||
} else {
|
||||
setPopup({
|
||||
message: "Please upload a JSON file",
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<input
|
||||
className={
|
||||
"mr-3 text-sm text-text-900 border border-background-300 overflow-visible " +
|
||||
"cursor-pointer bg-background dark:text-text-400 focus:outline-none " +
|
||||
"dark:bg-background-700 dark:border-background-600 dark:placeholder-text-400"
|
||||
}
|
||||
type="file"
|
||||
accept=".json"
|
||||
onChange={(event) => {
|
||||
if (!event.target.files) {
|
||||
return;
|
||||
}
|
||||
const file = event.target.files[0];
|
||||
const reader = new FileReader();
|
||||
|
||||
reader.onload = function (loadEvent) {
|
||||
if (!loadEvent?.target?.result) {
|
||||
return;
|
||||
}
|
||||
const fileContents = loadEvent.target.result;
|
||||
setCredentialJsonStr(fileContents as string);
|
||||
};
|
||||
|
||||
reader.readAsText(file);
|
||||
}}
|
||||
/>
|
||||
|
||||
<Button
|
||||
disabled={!credentialJsonStr}
|
||||
onClick={async () => {
|
||||
// check if the JSON is a app credential or a service account credential
|
||||
let credentialFileType: GmailCredentialJsonTypes;
|
||||
try {
|
||||
const appCredentialJson = JSON.parse(credentialJsonStr!);
|
||||
if (appCredentialJson.web) {
|
||||
credentialFileType = "authorized_user";
|
||||
} else if (appCredentialJson.type === "service_account") {
|
||||
credentialFileType = "service_account";
|
||||
} else {
|
||||
throw new Error(
|
||||
"Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
setPopup({
|
||||
message: `Invalid file provided - ${e}`,
|
||||
type: "error",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (credentialFileType === "authorized_user") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/app-credential",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded app credentials",
|
||||
type: "success",
|
||||
});
|
||||
mutate("/api/manage/admin/connector/gmail/app-credential");
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload app credentials - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (credentialFileType === "service_account") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/service-account-key",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded service account key",
|
||||
type: "success",
|
||||
});
|
||||
mutate("/api/manage/admin/connector/gmail/service-account-key");
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload service account key - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
}}
|
||||
>
|
||||
Upload
|
||||
</Button>
|
||||
</>
|
||||
<div className="flex flex-col mt-4">
|
||||
<div className="flex items-center">
|
||||
<div className="relative flex flex-1 items-center">
|
||||
<label
|
||||
className={cn(
|
||||
"flex h-10 items-center justify-center w-full px-4 py-2 border border-dashed rounded-md transition-colors",
|
||||
isUploading
|
||||
? "opacity-70 cursor-not-allowed border-background-400 bg-background-50/30"
|
||||
: isDragging
|
||||
? "bg-background-50/50 border-primary dark:border-primary"
|
||||
: "cursor-pointer hover:bg-background-50/30 hover:border-primary dark:hover:border-primary border-background-300 dark:border-background-600"
|
||||
)}
|
||||
onDragEnter={handleDragEnter}
|
||||
onDragLeave={handleDragLeave}
|
||||
onDragOver={handleDragOver}
|
||||
onDrop={handleDrop}
|
||||
>
|
||||
<div className="flex items-center space-x-2">
|
||||
{isUploading ? (
|
||||
<div className="h-4 w-4 border-t-2 border-b-2 border-primary rounded-full animate-spin"></div>
|
||||
) : (
|
||||
<FiFile className="h-4 w-4 text-text-500" />
|
||||
)}
|
||||
<span className="text-sm text-text-500">
|
||||
{isUploading
|
||||
? `Uploading ${truncateString(fileName || "file", 50)}...`
|
||||
: isDragging
|
||||
? "Drop JSON file here"
|
||||
: truncateString(
|
||||
fileName || "Select or drag JSON credentials file...",
|
||||
50
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
<input
|
||||
className="sr-only"
|
||||
type="file"
|
||||
accept=".json"
|
||||
disabled={isUploading}
|
||||
onChange={(event) => {
|
||||
if (!event.target.files?.length) {
|
||||
return;
|
||||
}
|
||||
const file = event.target.files[0];
|
||||
handleFileUpload(file);
|
||||
}}
|
||||
/>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
interface DriveJsonUploadSectionProps {
|
||||
interface GmailJsonUploadSectionProps {
|
||||
setPopup: (popupSpec: PopupSpec | null) => void;
|
||||
appCredentialData?: { client_id: string };
|
||||
serviceAccountCredentialData?: { service_account_email: string };
|
||||
isAdmin: boolean;
|
||||
onSuccess?: () => void;
|
||||
existingAuthCredential?: boolean;
|
||||
}
|
||||
|
||||
export const GmailJsonUploadSection = ({
|
||||
@@ -167,7 +256,8 @@ export const GmailJsonUploadSection = ({
|
||||
serviceAccountCredentialData,
|
||||
isAdmin,
|
||||
onSuccess,
|
||||
}: DriveJsonUploadSectionProps) => {
|
||||
existingAuthCredential,
|
||||
}: GmailJsonUploadSectionProps) => {
|
||||
const { mutate } = useSWRConfig();
|
||||
const router = useRouter();
|
||||
const [localServiceAccountData, setLocalServiceAccountData] = useState(
|
||||
@@ -190,156 +280,138 @@ export const GmailJsonUploadSection = ({
|
||||
}
|
||||
};
|
||||
|
||||
if (localServiceAccountData?.service_account_email) {
|
||||
return (
|
||||
<div className="mt-2 text-sm">
|
||||
<div>
|
||||
Found existing service account key with the following <b>Email:</b>
|
||||
<p className="italic mt-1">
|
||||
{localServiceAccountData.service_account_email}
|
||||
</p>
|
||||
</div>
|
||||
{isAdmin ? (
|
||||
<>
|
||||
<div className="mt-4 mb-1">
|
||||
If you want to update these credentials, delete the existing
|
||||
credentials through the button below, and then upload a new
|
||||
credentials JSON.
|
||||
</div>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/service-account-key",
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
mutate(
|
||||
"/api/manage/admin/connector/gmail/service-account-key"
|
||||
);
|
||||
// Also mutate the credential endpoints to ensure Step 2 is reset
|
||||
mutate(buildSimilarCredentialInfoURL(ValidSources.Gmail));
|
||||
setPopup({
|
||||
message: "Successfully deleted service account key",
|
||||
type: "success",
|
||||
});
|
||||
// Immediately update local state
|
||||
setLocalServiceAccountData(undefined);
|
||||
handleSuccess();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete service account key - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div className="mt-4 mb-1">
|
||||
To change these credentials, please contact an administrator.
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (localAppCredentialData?.client_id) {
|
||||
return (
|
||||
<div className="mt-2 text-sm">
|
||||
<div>
|
||||
Found existing app credentials with the following <b>Client ID:</b>
|
||||
<p className="italic mt-1">{localAppCredentialData.client_id}</p>
|
||||
</div>
|
||||
{isAdmin ? (
|
||||
<>
|
||||
<div className="mt-4 mb-1">
|
||||
If you want to update these credentials, delete the existing
|
||||
credentials through the button below, and then upload a new
|
||||
credentials JSON.
|
||||
</div>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/app-credential",
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
mutate("/api/manage/admin/connector/gmail/app-credential");
|
||||
// Also mutate the credential endpoints to ensure Step 2 is reset
|
||||
mutate(buildSimilarCredentialInfoURL(ValidSources.Gmail));
|
||||
setPopup({
|
||||
message: "Successfully deleted app credentials",
|
||||
type: "success",
|
||||
});
|
||||
// Immediately update local state
|
||||
setLocalAppCredentialData(undefined);
|
||||
handleSuccess();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete app credential - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</>
|
||||
) : (
|
||||
<div className="mt-4 mb-1">
|
||||
To change these credentials, please contact an administrator.
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!isAdmin) {
|
||||
return (
|
||||
<div className="mt-2">
|
||||
<p className="text-sm mb-2">
|
||||
Curators are unable to set up the Gmail credentials. To add a Gmail
|
||||
connector, please contact an administrator.
|
||||
</p>
|
||||
<div>
|
||||
<div className="flex items-start py-3 px-4 bg-yellow-50/30 dark:bg-yellow-900/5 rounded">
|
||||
<FiAlertTriangle className="text-yellow-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
|
||||
<p className="text-sm">
|
||||
Curators are unable to set up the Gmail credentials. To add a Gmail
|
||||
connector, please contact an administrator.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="mt-2">
|
||||
<p className="text-sm mb-2">
|
||||
Follow the guide{" "}
|
||||
<div>
|
||||
<p className="text-sm mb-3">
|
||||
To connect your Gmail, create credentials (either OAuth App or Service
|
||||
Account), download the JSON file, and upload it below.
|
||||
</p>
|
||||
<div className="mb-4">
|
||||
<a
|
||||
className="text-link"
|
||||
className="text-primary hover:text-primary/80 flex items-center gap-1 text-sm"
|
||||
target="_blank"
|
||||
href="https://docs.onyx.app/connectors/gmail#authorization"
|
||||
rel="noreferrer"
|
||||
>
|
||||
here
|
||||
</a>{" "}
|
||||
to either (1) setup a Google OAuth App in your company workspace or (2)
|
||||
create a Service Account.
|
||||
<br />
|
||||
<br />
|
||||
Download the credentials JSON if choosing option (1) or the Service
|
||||
Account key JSON if choosing option (2), and upload it here.
|
||||
</p>
|
||||
<DriveJsonUpload setPopup={setPopup} onSuccess={handleSuccess} />
|
||||
<FiLink className="h-3 w-3" />
|
||||
View detailed setup instructions
|
||||
</a>
|
||||
</div>
|
||||
|
||||
{(localServiceAccountData?.service_account_email ||
|
||||
localAppCredentialData?.client_id) && (
|
||||
<div className="mb-4">
|
||||
<div className="relative flex flex-1 items-center">
|
||||
<label
|
||||
className={cn(
|
||||
"flex h-10 items-center justify-center w-full px-4 py-2 border border-dashed rounded-md transition-colors",
|
||||
false
|
||||
? "opacity-70 cursor-not-allowed border-background-400 bg-background-50/30"
|
||||
: "cursor-pointer hover:bg-background-50/30 hover:border-primary dark:hover:border-primary border-background-300 dark:border-background-600"
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center space-x-2">
|
||||
{false ? (
|
||||
<div className="h-4 w-4 border-t-2 border-b-2 border-primary rounded-full animate-spin"></div>
|
||||
) : (
|
||||
<FiFile className="h-4 w-4 text-text-500" />
|
||||
)}
|
||||
<span className="text-sm text-text-500">
|
||||
{truncateString(
|
||||
localServiceAccountData?.service_account_email ||
|
||||
localAppCredentialData?.client_id ||
|
||||
"",
|
||||
50
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
</label>
|
||||
</div>
|
||||
{isAdmin && !existingAuthCredential && (
|
||||
<div className="mt-2">
|
||||
<Button
|
||||
variant="destructive"
|
||||
type="button"
|
||||
onClick={async () => {
|
||||
const endpoint =
|
||||
localServiceAccountData?.service_account_email
|
||||
? "/api/manage/admin/connector/gmail/service-account-key"
|
||||
: "/api/manage/admin/connector/gmail/app-credential";
|
||||
|
||||
const response = await fetch(endpoint, {
|
||||
method: "DELETE",
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
mutate(endpoint);
|
||||
// Also mutate the credential endpoints to ensure Step 2 is reset
|
||||
mutate(buildSimilarCredentialInfoURL(ValidSources.Gmail));
|
||||
|
||||
// Add additional mutations to refresh all credential-related endpoints
|
||||
mutate("/api/manage/admin/connector/gmail/credentials");
|
||||
mutate(
|
||||
"/api/manage/admin/connector/gmail/public-credential"
|
||||
);
|
||||
mutate(
|
||||
"/api/manage/admin/connector/gmail/service-account-credential"
|
||||
);
|
||||
|
||||
setPopup({
|
||||
message: `Successfully deleted ${
|
||||
localServiceAccountData
|
||||
? "service account key"
|
||||
: "app credentials"
|
||||
}`,
|
||||
type: "success",
|
||||
});
|
||||
// Immediately update local state
|
||||
if (localServiceAccountData) {
|
||||
setLocalServiceAccountData(undefined);
|
||||
} else {
|
||||
setLocalAppCredentialData(undefined);
|
||||
}
|
||||
handleSuccess();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete credentials - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete Credentials
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!(
|
||||
localServiceAccountData?.service_account_email ||
|
||||
localAppCredentialData?.client_id
|
||||
) && (
|
||||
<GmailCredentialUpload setPopup={setPopup} onSuccess={handleSuccess} />
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
interface DriveCredentialSectionProps {
|
||||
interface GmailCredentialSectionProps {
|
||||
gmailPublicCredential?: Credential<GmailCredentialJson>;
|
||||
gmailServiceAccountCredential?: Credential<GmailServiceAccountCredentialJson>;
|
||||
serviceAccountKeyData?: { service_account_email: string };
|
||||
@@ -387,7 +459,7 @@ export const GmailAuthSection = ({
|
||||
refreshCredentials,
|
||||
connectorExists,
|
||||
user,
|
||||
}: DriveCredentialSectionProps) => {
|
||||
}: GmailCredentialSectionProps) => {
|
||||
const router = useRouter();
|
||||
const [isAuthenticating, setIsAuthenticating] = useState(false);
|
||||
const [localServiceAccountData, setLocalServiceAccountData] = useState(
|
||||
@@ -420,104 +492,141 @@ export const GmailAuthSection = ({
|
||||
localGmailPublicCredential || localGmailServiceAccountCredential;
|
||||
if (existingCredential) {
|
||||
return (
|
||||
<>
|
||||
<p className="mb-2 text-sm">
|
||||
<i>Uploaded and authenticated credential already exists!</i>
|
||||
</p>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
handleRevokeAccess(
|
||||
connectorExists,
|
||||
setPopup,
|
||||
existingCredential,
|
||||
refreshCredentials
|
||||
);
|
||||
}}
|
||||
>
|
||||
Revoke Access
|
||||
</Button>
|
||||
</>
|
||||
<div>
|
||||
<div className="mt-4">
|
||||
<div className="py-3 px-4 bg-blue-50/30 dark:bg-blue-900/5 rounded mb-4 flex items-start">
|
||||
<FiCheck className="text-blue-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
|
||||
<div className="flex-1">
|
||||
<span className="font-medium block">Authentication Complete</span>
|
||||
<p className="text-sm mt-1 text-text-500 dark:text-text-400 break-words">
|
||||
Your Gmail credentials have been successfully uploaded and
|
||||
authenticated.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
variant="destructive"
|
||||
type="button"
|
||||
onClick={async () => {
|
||||
handleRevokeAccess(
|
||||
connectorExists,
|
||||
setPopup,
|
||||
existingCredential,
|
||||
refreshCredentials
|
||||
);
|
||||
}}
|
||||
>
|
||||
Revoke Access
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// If no credentials are uploaded, show message to complete step 1 first
|
||||
if (
|
||||
!localServiceAccountData?.service_account_email &&
|
||||
!localAppCredentialData?.client_id
|
||||
) {
|
||||
return (
|
||||
<div>
|
||||
<SectionHeader>Gmail Authentication</SectionHeader>
|
||||
<div className="mt-4">
|
||||
<div className="flex items-start py-3 px-4 bg-yellow-50/30 dark:bg-yellow-900/5 rounded">
|
||||
<FiAlertTriangle className="text-yellow-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
|
||||
<p className="text-sm">
|
||||
Please complete Step 1 by uploading either OAuth credentials or a
|
||||
Service Account key before proceeding with authentication.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (localServiceAccountData?.service_account_email) {
|
||||
return (
|
||||
<div>
|
||||
<Formik
|
||||
initialValues={{
|
||||
google_primary_admin: user?.email || "",
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
google_primary_admin: Yup.string()
|
||||
.email("Must be a valid email")
|
||||
.required("Required"),
|
||||
})}
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
try {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/service-account-credential",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
google_primary_admin: values.google_primary_admin,
|
||||
}),
|
||||
}
|
||||
);
|
||||
<div className="mt-4">
|
||||
<Formik
|
||||
initialValues={{
|
||||
google_primary_admin: user?.email || "",
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
google_primary_admin: Yup.string()
|
||||
.email("Must be a valid email")
|
||||
.required("Required"),
|
||||
})}
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
try {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/service-account-credential",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
google_primary_admin: values.google_primary_admin,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
if (response.ok) {
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully created service account credential",
|
||||
type: "success",
|
||||
});
|
||||
refreshCredentials();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to create service account credential - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
setPopup({
|
||||
message: "Successfully created service account credential",
|
||||
type: "success",
|
||||
});
|
||||
refreshCredentials();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to create service account credential - ${errorMsg}`,
|
||||
message: `Failed to create service account credential - ${error}`,
|
||||
type: "error",
|
||||
});
|
||||
} finally {
|
||||
formikHelpers.setSubmitting(false);
|
||||
}
|
||||
} catch (error) {
|
||||
setPopup({
|
||||
message: `Failed to create service account credential - ${error}`,
|
||||
type: "error",
|
||||
});
|
||||
} finally {
|
||||
formikHelpers.setSubmitting(false);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting }) => (
|
||||
<Form>
|
||||
<TextFormField
|
||||
name="google_primary_admin"
|
||||
label="Primary Admin Email:"
|
||||
subtext="Enter the email of an admin/owner of the Google Organization that owns the Gmail account(s) you want to index."
|
||||
/>
|
||||
<div className="flex">
|
||||
<Button type="submit" disabled={isSubmitting}>
|
||||
Create Credential
|
||||
</Button>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting }) => (
|
||||
<Form>
|
||||
<TextFormField
|
||||
name="google_primary_admin"
|
||||
label="Primary Admin Email:"
|
||||
subtext="Enter the email of an admin/owner of the Google Organization that owns the Gmail account(s) you want to index."
|
||||
/>
|
||||
<div className="flex">
|
||||
<Button type="submit" disabled={isSubmitting}>
|
||||
{isSubmitting ? "Creating..." : "Create Credential"}
|
||||
</Button>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (localAppCredentialData?.client_id) {
|
||||
return (
|
||||
<div className="text-sm mb-4">
|
||||
<p className="mb-2">
|
||||
Next, you must provide credentials via OAuth. This gives us read
|
||||
access to the emails you have access to in your Gmail account.
|
||||
</p>
|
||||
<div>
|
||||
<div className="bg-background-50/30 dark:bg-background-900/20 rounded mb-4">
|
||||
<p className="text-sm">
|
||||
Next, you need to authenticate with Gmail via OAuth. This gives us
|
||||
read access to the emails you have access to in your Gmail account.
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
disabled={isAuthenticating}
|
||||
onClick={async () => {
|
||||
setIsAuthenticating(true);
|
||||
try {
|
||||
@@ -545,7 +654,6 @@ export const GmailAuthSection = ({
setIsAuthenticating(false);
}
}}
disabled={isAuthenticating}
>
{isAuthenticating ? "Authenticating..." : "Authenticate with Gmail"}
</Button>
@@ -553,11 +661,6 @@ export const GmailAuthSection = ({
);
}

// case where no keys have been uploaded in step 1
return (
<p className="text-sm">
Please upload either a OAuth Client Credential JSON or a Gmail Service
Account Key JSON in Step 1 before moving onto Step 2.
</p>
);
// This code path should not be reached with the new conditions above
return null;
};

@@ -173,6 +173,9 @@ export const GmailMain = () => {
serviceAccountCredentialData={serviceAccountKeyData}
isAdmin={isAdmin}
onSuccess={handleRefresh}
existingAuthCredential={Boolean(
gmailPublicUploadedCredential || gmailServiceAccountCredential
)}
/>

{isAdmin && hasUploadedCredentials && (

@@ -100,7 +100,10 @@ export function EmailPasswordForm({
// server-side provider values)
window.location.href = "/auth/waiting-on-verification";
} else {
// See above comment
// The searchparam is purely for multi-tenant development purposes.
// It replicates the behavior of the case where a user
// has signed up with email / password as the only user to an instance
// and has just completed verification
window.location.href = nextUrl
? encodeURI(nextUrl)
: `/chat${isSignup && !isJoin ? "?new_team=true" : ""}`;

@@ -7,7 +7,7 @@ import Text from "@/components/ui/text";
import { RequestNewVerificationEmail } from "../waiting-on-verification/RequestNewVerificationEmail";
import { User } from "@/lib/types";
import { Logo } from "@/components/logo/Logo";

import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
export function Verify({ user }: { user: User | null }) {
const searchParams = useSearchParams();
const router = useRouter();
@@ -16,6 +16,8 @@ export function Verify({ user }: { user: User | null }) {

const verify = useCallback(async () => {
const token = searchParams.get("token");
const firstUser =
searchParams.get("first_user") && NEXT_PUBLIC_CLOUD_ENABLED;
if (!token) {
setError(
"Missing verification token. Try requesting a new verification email."
@@ -35,7 +37,7 @@ export function Verify({ user }: { user: User | null }) {
// Use window.location.href to force a full page reload,
// ensuring app re-initializes with the new state (including
// server-side provider values)
window.location.href = "/";
window.location.href = firstUser ? "/chat?new_team=true" : "/chat";
} else {
const errorDetail = (await response.json()).detail;
setError(

@@ -1158,6 +1158,7 @@ export function ChatPage({
let frozenSessionId = currentSessionId();
updateCanContinue(false, frozenSessionId);
setUncaughtError(null);
setLoadingError(null);

// Mark that we've sent a message for this session in the current page load
markSessionMessageSent(frozenSessionId);

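The markSessionMessageSent helper is called here but its implementation is not shown in this diff. As a hedged sketch only, a per-page-load tracker like this could be written as follows in TypeScript, assuming session ids are strings and using a module-level Set; the storage choice and the hasSessionMessageBeenSent name are illustrative assumptions, not the repository's actual code:

// Illustrative sketch only - not the repository's actual implementation.
// Tracks which chat sessions have had a message sent during the current page load.
const sessionsWithSentMessages = new Set<string>();

export function markSessionMessageSent(sessionId: string | null): void {
  if (sessionId !== null) {
    sessionsWithSentMessages.add(sessionId);
  }
}

export function hasSessionMessageBeenSent(sessionId: string | null): boolean {
  return sessionId !== null && sessionsWithSentMessages.has(sessionId);
}

Because the Set lives at module scope, it resets on a full page reload, which matches the "current page load" wording in the comment above.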
@@ -2,7 +2,6 @@ import { redirect } from "next/navigation";
import { unstable_noStore as noStore } from "next/cache";
import { fetchChatData } from "@/lib/chat/fetchChatData";
import { ChatProvider } from "@/components/context/ChatContext";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";

export default async function Layout({
children,
@@ -41,7 +40,6 @@ export default async function Layout({

return (
<>
<InstantSSRAutoRefresh />
<ChatProvider
value={{
proSearchToggled,

@@ -54,6 +54,7 @@ export const SourceCard: React.FC<{

<div className="flex items-center gap-1 mt-1">
<ResultIcon doc={document} size={18} />

<div className="text-text-700 text-xs leading-tight truncate flex-1 min-w-0">
{truncatedIdentifier}
</div>

@@ -54,6 +54,7 @@ export function SearchMultiSelectDropdown({
onDelete,
onSearchTermChange,
initialSearchTerm = "",
allowCustomValues = false,
}: {
options: StringOrNumberOption[];
onSelect: (selected: StringOrNumberOption) => void;
@@ -62,6 +63,7 @@ export function SearchMultiSelectDropdown({
onDelete?: (name: string) => void;
onSearchTermChange?: (term: string) => void;
initialSearchTerm?: string;
allowCustomValues?: boolean;
}) {
const [isOpen, setIsOpen] = useState(false);
const [searchTerm, setSearchTerm] = useState(initialSearchTerm);
@@ -77,12 +79,29 @@ export function SearchMultiSelectDropdown({
option.name.toLowerCase().includes(searchTerm.toLowerCase())
);

// Handle selecting a custom value not in the options list
const handleCustomValueSelect = () => {
if (allowCustomValues && searchTerm.trim() !== "") {
const customOption: StringOrNumberOption = {
name: searchTerm,
value: searchTerm,
};
onSelect(customOption);
setIsOpen(false);
}
};

useEffect(() => {
const handleClickOutside = (event: MouseEvent) => {
if (
dropdownRef.current &&
!dropdownRef.current.contains(event.target as Node)
) {
// If allowCustomValues is enabled and there's text in the search field,
// treat clicking outside as selecting the custom value
if (allowCustomValues && searchTerm.trim() !== "") {
handleCustomValueSelect();
}
setIsOpen(false);
}
};
@@ -91,7 +110,7 @@ export function SearchMultiSelectDropdown({
return () => {
document.removeEventListener("mousedown", handleClickOutside);
};
}, []);
}, [allowCustomValues, searchTerm]);

useEffect(() => {
setSearchTerm(initialSearchTerm);
@@ -102,17 +121,33 @@ export function SearchMultiSelectDropdown({
<div>
<input
type="text"
placeholder="Search..."
placeholder={
allowCustomValues ? "Search or enter custom value..." : "Search..."
}
value={searchTerm}
onChange={(e: ChangeEvent<HTMLInputElement>) => {
setSearchTerm(e.target.value);
if (e.target.value) {
const newValue = e.target.value;
setSearchTerm(newValue);
if (onSearchTermChange) {
onSearchTermChange(newValue);
}
if (newValue) {
setIsOpen(true);
} else {
setIsOpen(false);
}
}}
onFocus={() => setIsOpen(true)}
onKeyDown={(e) => {
if (
e.key === "Enter" &&
allowCustomValues &&
searchTerm.trim() !== ""
) {
e.preventDefault();
handleCustomValueSelect();
}
}}
className="inline-flex justify-between w-full px-4 py-2 text-sm bg-white dark:bg-transparent text-text-800 border border-background-300 rounded-md shadow-sm"
/>
<button
@@ -153,6 +188,22 @@ export function SearchMultiSelectDropdown({
)
)}

{allowCustomValues &&
searchTerm.trim() !== "" &&
!filteredOptions.some(
(option) =>
option.name.toLowerCase() === searchTerm.toLowerCase()
) && (
<button
className="w-full text-left flex items-center px-4 py-2 text-sm text-text-800 hover:bg-background-100"
role="menuitem"
onClick={handleCustomValueSelect}
>
<PlusIcon className="w-4 h-4 mr-2 text-text-600" />
Use "{searchTerm}" as custom value
</button>
)}

{onCreate &&
searchTerm.trim() !== "" &&
!filteredOptions.some(
@@ -177,7 +228,8 @@ export function SearchMultiSelectDropdown({
)}

{filteredOptions.length === 0 &&
(!onCreate || searchTerm.trim() === "") && (
((!onCreate && !allowCustomValues) ||
searchTerm.trim() === "") && (
<div className="px-4 py-2.5 text-sm text-text-500">
No matches found
</div>

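For context, the new allowCustomValues prop added above lets the dropdown commit free text as a regular option (via the Enter key, the "Use ... as custom value" row, or clicking outside the dropdown). A minimal usage sketch follows; the TagPicker component, the import path, and the surrounding state handling are assumptions for illustration, not code from this diff, and the component may require additional props not shown here:

// Illustrative usage sketch of allowCustomValues - parent component and import path are assumed.
import { useState } from "react";
import { SearchMultiSelectDropdown, StringOrNumberOption } from "@/components/Dropdown";

export function TagPicker() {
  const [selected, setSelected] = useState<StringOrNumberOption[]>([]);

  return (
    <SearchMultiSelectDropdown
      options={[{ name: "engineering", value: "engineering" }]}
      allowCustomValues // free text that matches no option can be committed as a custom option
      onSelect={(option) => setSelected((prev) => [...prev, option])}
      onDelete={(name) => setSelected((prev) => prev.filter((o) => o.name !== name))}
    />
  );
}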
@@ -49,7 +49,7 @@ export function SearchResultIcon({ url }: { url: string }) {
if (!faviconUrl) {
return <SourceIcon sourceType={ValidSources.Web} iconSize={18} />;
}
if (url.includes("docs.onyx.app")) {
if (url.includes("onyx.app")) {
return <OnyxIcon size={18} className="dark:text-[#fff] text-[#000]" />;
}
Some files were not shown because too many files have changed in this diff.