mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-23 10:45:44 +00:00
Compare commits
38 Commits
nik/remove
...
csv_render
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f3672b6819 | ||
|
|
921f5d9e96 | ||
|
|
15fe47adc5 | ||
|
|
29958f1a52 | ||
|
|
ac7f9838bc | ||
|
|
d0fa4b3319 | ||
|
|
3fb4fb422e | ||
|
|
ba5da22ea1 | ||
|
|
9909049047 | ||
|
|
c516aa3e3c | ||
|
|
5cc6220417 | ||
|
|
15da1e0a88 | ||
|
|
e9ff00890b | ||
|
|
67747a9d93 | ||
|
|
edfc51b439 | ||
|
|
ac4fba947e | ||
|
|
c142b2db02 | ||
|
|
fb7e7e4395 | ||
|
|
113f23398e | ||
|
|
5a8716026a | ||
|
|
3389140bfd | ||
|
|
13109e7b81 | ||
|
|
56ad457168 | ||
|
|
a81aea2afc | ||
|
|
7cb5c9c4a6 | ||
|
|
3520c58a22 | ||
|
|
bd9d1bfa27 | ||
|
|
14416cc3db | ||
|
|
d7fce14d26 | ||
|
|
39a8d8ed05 | ||
|
|
82f735a434 | ||
|
|
aadb58518b | ||
|
|
0755499e0f | ||
|
|
27aaf977a2 | ||
|
|
9f707f195e | ||
|
|
3e35570f70 | ||
|
|
53b1bf3b2c | ||
|
|
5a3fa6b648 |
2
.github/workflows/helm-chart-releases.yml
vendored
2
.github/workflows/helm-chart-releases.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
|
||||
helm repo update
|
||||
|
||||
- name: Build chart dependencies
|
||||
|
||||
@@ -45,9 +45,6 @@ env:
|
||||
# TODO: debug why this is failing and enable
|
||||
CODE_INTERPRETER_BASE_URL: http://localhost:8000
|
||||
|
||||
# OpenSearch
|
||||
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
|
||||
|
||||
jobs:
|
||||
discover-test-dirs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
@@ -118,9 +115,10 @@ jobs:
|
||||
- name: Create .env file for Docker Compose
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
|
||||
CODE_INTERPRETER_BETA_ENABLED=true
|
||||
DISABLE_TELEMETRY=true
|
||||
OPENSEARCH_FOR_ONYX_ENABLED=true
|
||||
EOF
|
||||
|
||||
- name: Set up Standard Dependencies
|
||||
@@ -129,7 +127,6 @@ jobs:
|
||||
docker compose \
|
||||
-f docker-compose.yml \
|
||||
-f docker-compose.dev.yml \
|
||||
-f docker-compose.opensearch.yml \
|
||||
up -d \
|
||||
minio \
|
||||
relational_db \
|
||||
|
||||
1
.github/workflows/pr-helm-chart-testing.yml
vendored
1
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -91,6 +91,7 @@ jobs:
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
|
||||
helm repo update
|
||||
|
||||
- name: Install Redis operator
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
"""add scim_username to scim_user_mapping
|
||||
|
||||
Revision ID: 0bb4558f35df
|
||||
Revises: 631fd2504136
|
||||
Create Date: 2026-02-20 10:45:30.340188
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0bb4558f35df"
|
||||
down_revision = "631fd2504136"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"scim_user_mapping",
|
||||
sa.Column("scim_username", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("scim_user_mapping", "scim_username")
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
@@ -18,11 +19,15 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Credential__UserGroup
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__UserGroup
|
||||
from onyx.db.models import FederatedConnector__DocumentSet
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.models import TokenRateLimit__UserGroup
|
||||
from onyx.db.models import User
|
||||
@@ -195,8 +200,60 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def _add_user_group_snapshot_eager_loads(
|
||||
stmt: Select,
|
||||
) -> Select:
|
||||
"""Add eager loading options needed by UserGroup.from_model snapshot creation."""
|
||||
return stmt.options(
|
||||
selectinload(UserGroup.users),
|
||||
selectinload(UserGroup.user_group_relationships),
|
||||
selectinload(UserGroup.cc_pair_relationships)
|
||||
.selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
|
||||
.options(
|
||||
selectinload(ConnectorCredentialPair.connector),
|
||||
selectinload(ConnectorCredentialPair.credential).selectinload(
|
||||
Credential.user
|
||||
),
|
||||
),
|
||||
selectinload(UserGroup.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(UserGroup.personas).options(
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.hierarchy_nodes),
|
||||
selectinload(Persona.attached_documents).selectinload(
|
||||
Document.parent_hierarchy_node
|
||||
),
|
||||
selectinload(Persona.labels),
|
||||
selectinload(Persona.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(Persona.user),
|
||||
selectinload(Persona.user_files),
|
||||
selectinload(Persona.users),
|
||||
selectinload(Persona.groups),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def fetch_user_groups(
|
||||
db_session: Session, only_up_to_date: bool = True
|
||||
db_session: Session,
|
||||
only_up_to_date: bool = True,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
) -> Sequence[UserGroup]:
|
||||
"""
|
||||
Fetches user groups from the database.
|
||||
@@ -209,6 +266,8 @@ def fetch_user_groups(
|
||||
db_session (Session): The SQLAlchemy session used to query the database.
|
||||
only_up_to_date (bool, optional): Flag to determine whether to filter the results
|
||||
to include only up to date user groups. Defaults to `True`.
|
||||
eager_load_for_snapshot: If True, adds eager loading for all relationships
|
||||
needed by UserGroup.from_model snapshot creation.
|
||||
|
||||
Returns:
|
||||
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
|
||||
@@ -216,11 +275,16 @@ def fetch_user_groups(
|
||||
stmt = select(UserGroup)
|
||||
if only_up_to_date:
|
||||
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
|
||||
return db_session.scalars(stmt).all()
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
def fetch_user_groups_for_user(
|
||||
db_session: Session, user_id: UUID, only_curator_groups: bool = False
|
||||
db_session: Session,
|
||||
user_id: UUID,
|
||||
only_curator_groups: bool = False,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
) -> Sequence[UserGroup]:
|
||||
stmt = (
|
||||
select(UserGroup)
|
||||
@@ -230,7 +294,9 @@ def fetch_user_groups_for_user(
|
||||
)
|
||||
if only_curator_groups:
|
||||
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
|
||||
return db_session.scalars(stmt).all()
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
def construct_document_id_select_by_usergroup(
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
get_sharepoint_external_groups,
|
||||
)
|
||||
from onyx.configs.app_configs import SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
|
||||
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -43,14 +47,27 @@ def sharepoint_group_sync(
|
||||
|
||||
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
|
||||
|
||||
# Process each site
|
||||
enumerate_all = connector_config.get(
|
||||
"exhaustive_ad_enumeration", SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
|
||||
)
|
||||
|
||||
msal_app = connector.msal_app
|
||||
sp_tenant_domain = connector.sp_tenant_domain
|
||||
sp_domain_suffix = connector.sharepoint_domain_suffix
|
||||
for site_descriptor in site_descriptors:
|
||||
logger.debug(f"Processing site: {site_descriptor.url}")
|
||||
|
||||
ctx = connector._create_rest_client_context(site_descriptor.url)
|
||||
ctx = ClientContext(site_descriptor.url).with_access_token(
|
||||
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
|
||||
)
|
||||
|
||||
# Get external groups for this site
|
||||
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
|
||||
external_groups = get_sharepoint_external_groups(
|
||||
ctx,
|
||||
connector.graph_client,
|
||||
graph_api_base=connector.graph_api_base,
|
||||
get_access_token=connector._get_graph_access_token,
|
||||
enumerate_all_ad_groups=enumerate_all,
|
||||
)
|
||||
|
||||
# Yield each group
|
||||
for group in external_groups:
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import re
|
||||
import time
|
||||
from collections import deque
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests as _requests
|
||||
from office365.graph_client import GraphClient # type: ignore[import-untyped]
|
||||
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
|
||||
from office365.runtime.client_request import ClientRequestException # type: ignore
|
||||
@@ -14,7 +18,10 @@ from pydantic import BaseModel
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.sharepoint.connector import GRAPH_API_MAX_RETRIES
|
||||
from onyx.connectors.sharepoint.connector import GRAPH_API_RETRYABLE_STATUSES
|
||||
from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE
|
||||
from onyx.connectors.sharepoint.connector import sleep_and_retry
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -33,6 +40,70 @@ LIMITED_ACCESS_ROLE_TYPES = [1, 9]
|
||||
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]
|
||||
|
||||
|
||||
AD_GROUP_ENUMERATION_THRESHOLD = 100_000
|
||||
|
||||
|
||||
def _graph_api_get(
|
||||
url: str,
|
||||
get_access_token: Callable[[], str],
|
||||
params: dict[str, str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Authenticated Graph API GET with retry on transient errors."""
|
||||
for attempt in range(GRAPH_API_MAX_RETRIES + 1):
|
||||
access_token = get_access_token()
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
try:
|
||||
resp = _requests.get(
|
||||
url, headers=headers, params=params, timeout=REQUEST_TIMEOUT_SECONDS
|
||||
)
|
||||
if (
|
||||
resp.status_code in GRAPH_API_RETRYABLE_STATUSES
|
||||
and attempt < GRAPH_API_MAX_RETRIES
|
||||
):
|
||||
wait = min(int(resp.headers.get("Retry-After", str(2**attempt))), 60)
|
||||
logger.warning(
|
||||
f"Graph API {resp.status_code} on attempt {attempt + 1}, "
|
||||
f"retrying in {wait}s: {url}"
|
||||
)
|
||||
time.sleep(wait)
|
||||
continue
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
|
||||
if attempt < GRAPH_API_MAX_RETRIES:
|
||||
wait = min(2**attempt, 60)
|
||||
logger.warning(
|
||||
f"Graph API connection error on attempt {attempt + 1}, "
|
||||
f"retrying in {wait}s: {url}"
|
||||
)
|
||||
time.sleep(wait)
|
||||
continue
|
||||
raise
|
||||
raise RuntimeError(
|
||||
f"Graph API request failed after {GRAPH_API_MAX_RETRIES + 1} attempts: {url}"
|
||||
)
|
||||
|
||||
|
||||
def _iter_graph_collection(
|
||||
initial_url: str,
|
||||
get_access_token: Callable[[], str],
|
||||
params: dict[str, str] | None = None,
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
"""Paginate through a Graph API collection, yielding items one at a time."""
|
||||
url: str | None = initial_url
|
||||
while url:
|
||||
data = _graph_api_get(url, get_access_token, params)
|
||||
params = None
|
||||
yield from data.get("value", [])
|
||||
url = data.get("@odata.nextLink")
|
||||
|
||||
|
||||
def _normalize_email(email: str) -> str:
|
||||
if MICROSOFT_DOMAIN in email:
|
||||
return email.replace(MICROSOFT_DOMAIN, "")
|
||||
return email
|
||||
|
||||
|
||||
class SharepointGroup(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
|
||||
@@ -572,8 +643,65 @@ def get_external_access_from_sharepoint(
|
||||
)
|
||||
|
||||
|
||||
def _enumerate_ad_groups_paginated(
|
||||
get_access_token: Callable[[], str],
|
||||
already_resolved: set[str],
|
||||
graph_api_base: str,
|
||||
) -> Generator[ExternalUserGroup, None, None]:
|
||||
"""Paginate through all Azure AD groups and yield ExternalUserGroup for each.
|
||||
|
||||
Skips groups whose suffixed name is already in *already_resolved*.
|
||||
Stops early if the number of groups exceeds AD_GROUP_ENUMERATION_THRESHOLD.
|
||||
"""
|
||||
groups_url = f"{graph_api_base}/groups"
|
||||
groups_params: dict[str, str] = {"$select": "id,displayName", "$top": "999"}
|
||||
total_groups = 0
|
||||
|
||||
for group_json in _iter_graph_collection(
|
||||
groups_url, get_access_token, groups_params
|
||||
):
|
||||
group_id: str = group_json.get("id", "")
|
||||
display_name: str = group_json.get("displayName", "")
|
||||
if not group_id or not display_name:
|
||||
continue
|
||||
|
||||
total_groups += 1
|
||||
if total_groups > AD_GROUP_ENUMERATION_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Azure AD group enumeration exceeded {AD_GROUP_ENUMERATION_THRESHOLD} "
|
||||
"groups — stopping to avoid excessive memory/API usage. "
|
||||
"Remaining groups will be resolved from role assignments only."
|
||||
)
|
||||
return
|
||||
|
||||
name = f"{display_name}_{group_id}"
|
||||
if name in already_resolved:
|
||||
continue
|
||||
|
||||
member_emails: list[str] = []
|
||||
members_url = f"{graph_api_base}/groups/{group_id}/members"
|
||||
members_params: dict[str, str] = {
|
||||
"$select": "userPrincipalName,mail",
|
||||
"$top": "999",
|
||||
}
|
||||
for member_json in _iter_graph_collection(
|
||||
members_url, get_access_token, members_params
|
||||
):
|
||||
email = member_json.get("userPrincipalName") or member_json.get("mail")
|
||||
if email:
|
||||
member_emails.append(_normalize_email(email))
|
||||
|
||||
yield ExternalUserGroup(id=name, user_emails=member_emails)
|
||||
|
||||
logger.info(f"Enumerated {total_groups} Azure AD groups via paginated Graph API")
|
||||
|
||||
|
||||
def get_sharepoint_external_groups(
|
||||
client_context: ClientContext, graph_client: GraphClient
|
||||
client_context: ClientContext,
|
||||
graph_client: GraphClient,
|
||||
graph_api_base: str,
|
||||
get_access_token: Callable[[], str] | None = None,
|
||||
enumerate_all_ad_groups: bool = False,
|
||||
) -> list[ExternalUserGroup]:
|
||||
|
||||
groups: set[SharepointGroup] = set()
|
||||
@@ -629,57 +757,22 @@ def get_sharepoint_external_groups(
|
||||
client_context, graph_client, groups, is_group_sync=True
|
||||
)
|
||||
|
||||
# get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
|
||||
# We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
|
||||
azure_ad_groups = sleep_and_retry(
|
||||
graph_client.groups.get_all(page_loaded=lambda _: None),
|
||||
"get_sharepoint_external_groups:get_azure_ad_groups",
|
||||
)
|
||||
logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
|
||||
identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
|
||||
ad_groups_to_emails: dict[str, set[str]] = {}
|
||||
for group in azure_ad_groups:
|
||||
# If the group is already identified, we don't need to get the members
|
||||
if group.display_name in identified_groups:
|
||||
continue
|
||||
# AD groups allows same display name for multiple groups, so we need to add the GUID to the name
|
||||
name = group.display_name
|
||||
name = _get_group_name_with_suffix(group.id, name, graph_client)
|
||||
external_user_groups: list[ExternalUserGroup] = [
|
||||
ExternalUserGroup(id=group_name, user_emails=list(emails))
|
||||
for group_name, emails in groups_and_members.groups_to_emails.items()
|
||||
]
|
||||
|
||||
members = sleep_and_retry(
|
||||
group.members.get_all(page_loaded=lambda _: None),
|
||||
"get_sharepoint_external_groups:get_azure_ad_groups:get_members",
|
||||
if not enumerate_all_ad_groups or get_access_token is None:
|
||||
logger.info(
|
||||
"Skipping exhaustive Azure AD group enumeration. "
|
||||
"Only groups found in site role assignments are included."
|
||||
)
|
||||
for member in members:
|
||||
member_data = member.to_json()
|
||||
user_principal_name = member_data.get("userPrincipalName")
|
||||
mail = member_data.get("mail")
|
||||
if not ad_groups_to_emails.get(name):
|
||||
ad_groups_to_emails[name] = set()
|
||||
if user_principal_name:
|
||||
if MICROSOFT_DOMAIN in user_principal_name:
|
||||
user_principal_name = user_principal_name.replace(
|
||||
MICROSOFT_DOMAIN, ""
|
||||
)
|
||||
ad_groups_to_emails[name].add(user_principal_name)
|
||||
elif mail:
|
||||
if MICROSOFT_DOMAIN in mail:
|
||||
mail = mail.replace(MICROSOFT_DOMAIN, "")
|
||||
ad_groups_to_emails[name].add(mail)
|
||||
return external_user_groups
|
||||
|
||||
external_user_groups: list[ExternalUserGroup] = []
|
||||
for group_name, emails in groups_and_members.groups_to_emails.items():
|
||||
external_user_group = ExternalUserGroup(
|
||||
id=group_name,
|
||||
user_emails=list(emails),
|
||||
)
|
||||
external_user_groups.append(external_user_group)
|
||||
|
||||
for group_name, emails in ad_groups_to_emails.items():
|
||||
external_user_group = ExternalUserGroup(
|
||||
id=group_name,
|
||||
user_emails=list(emails),
|
||||
)
|
||||
external_user_groups.append(external_user_group)
|
||||
already_resolved = set(groups_and_members.groups_to_emails.keys())
|
||||
for group in _enumerate_ad_groups_paginated(
|
||||
get_access_token, already_resolved, graph_api_base
|
||||
):
|
||||
external_user_groups.append(group)
|
||||
|
||||
return external_user_groups
|
||||
|
||||
@@ -37,12 +37,15 @@ def list_user_groups(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserGroup]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
|
||||
user_groups = fetch_user_groups(
|
||||
db_session, only_up_to_date=False, eager_load_for_snapshot=True
|
||||
)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
only_curator_groups=user.role == UserRole.CURATOR,
|
||||
eager_load_for_snapshot=True,
|
||||
)
|
||||
return [UserGroup.from_model(user_group) for user_group in user_groups]
|
||||
|
||||
|
||||
@@ -53,7 +53,8 @@ class UserGroup(BaseModel):
|
||||
id=cc_pair_relationship.cc_pair.id,
|
||||
name=cc_pair_relationship.cc_pair.name,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
cc_pair_relationship.cc_pair.connector
|
||||
cc_pair_relationship.cc_pair.connector,
|
||||
credential_ids=[cc_pair_relationship.cc_pair.credential_id],
|
||||
),
|
||||
credential=CredentialSnapshot.from_credential_db_model(
|
||||
cc_pair_relationship.cc_pair.credential
|
||||
|
||||
@@ -1671,7 +1671,10 @@ def get_oauth_router(
|
||||
if redirect_url is not None:
|
||||
authorize_redirect_url = redirect_url
|
||||
else:
|
||||
authorize_redirect_url = str(request.url_for(callback_route_name))
|
||||
# Use WEB_DOMAIN instead of request.url_for() to prevent host
|
||||
# header poisoning — request.url_for() trusts the Host header.
|
||||
callback_path = request.app.url_path_for(callback_route_name)
|
||||
authorize_redirect_url = f"{WEB_DOMAIN}{callback_path}"
|
||||
|
||||
next_url = request.query_params.get("next", "/")
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
@@ -21,12 +22,14 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
@@ -57,6 +60,17 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_queued_key(user_file_id: str | UUID) -> str:
|
||||
"""Key that exists while a process_single_user_file task is sitting in the queue.
|
||||
|
||||
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
before enqueuing and the worker deletes it as its first action. This prevents
|
||||
the beat from adding duplicate tasks for files that already have a live task
|
||||
in flight.
|
||||
"""
|
||||
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
@@ -120,7 +134,24 @@ def _get_document_chunk_count(
|
||||
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
|
||||
|
||||
Uses direct Redis locks to avoid overlapping runs.
|
||||
Three mechanisms prevent queue runaway:
|
||||
|
||||
1. **Queue depth backpressure** – if the broker queue already has more than
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
|
||||
entirely. Workers are clearly behind; adding more tasks would only make
|
||||
the backlog worse.
|
||||
|
||||
2. **Per-file queued guard** – before enqueuing a task we set a short-lived
|
||||
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
|
||||
already exists the file already has a live task in the queue, so we skip
|
||||
it. The worker deletes the key the moment it picks up the task so the
|
||||
next beat cycle can re-enqueue if the file is still PROCESSING.
|
||||
|
||||
3. **Task expiry** – every enqueued task carries an `expires` value equal to
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
|
||||
the queue after that deadline, Celery discards it without touching the DB.
|
||||
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
|
||||
Redis restart), stale tasks evict themselves rather than piling up forever.
|
||||
"""
|
||||
task_logger.info("check_user_file_processing - Starting")
|
||||
|
||||
@@ -135,7 +166,21 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"check_user_file_processing - Queue depth {queue_len} exceeds "
|
||||
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
@@ -148,12 +193,35 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
|
||||
for user_file_id in user_file_ids:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
# --- Protection 2: per-file queued guard ---
|
||||
queued_key = _user_file_queued_key(user_file_id)
|
||||
guard_set = redis_client.set(
|
||||
queued_key,
|
||||
1,
|
||||
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
nx=True,
|
||||
)
|
||||
if not guard_set:
|
||||
skipped_guard += 1
|
||||
continue
|
||||
|
||||
# --- Protection 3: task expiry ---
|
||||
# If task submission fails, clear the guard immediately so the
|
||||
# next beat cycle can retry enqueuing this file.
|
||||
try:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={
|
||||
"user_file_id": str(user_file_id),
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
)
|
||||
except Exception:
|
||||
redis_client.delete(queued_key)
|
||||
raise
|
||||
enqueued += 1
|
||||
|
||||
finally:
|
||||
@@ -161,7 +229,8 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
|
||||
f"tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -304,6 +373,12 @@ def process_single_user_file(
|
||||
start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
@@ -45,6 +46,7 @@ from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
IMAGE_GENERATION_TOOL_NAME = "generate_image"
|
||||
|
||||
|
||||
def create_chat_session_from_request(
|
||||
@@ -422,6 +424,40 @@ def convert_chat_history_basic(
|
||||
return list(reversed(trimmed_reversed))
|
||||
|
||||
|
||||
def _build_tool_call_response_history_message(
|
||||
tool_name: str,
|
||||
generated_images: list[dict] | None,
|
||||
tool_call_response: str | None,
|
||||
) -> str:
|
||||
if tool_name != IMAGE_GENERATION_TOOL_NAME:
|
||||
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
|
||||
|
||||
if generated_images:
|
||||
llm_image_context: list[dict[str, str]] = []
|
||||
for image in generated_images:
|
||||
file_id = image.get("file_id")
|
||||
revised_prompt = image.get("revised_prompt")
|
||||
if not isinstance(file_id, str):
|
||||
continue
|
||||
|
||||
llm_image_context.append(
|
||||
{
|
||||
"file_id": file_id,
|
||||
"revised_prompt": (
|
||||
revised_prompt if isinstance(revised_prompt, str) else ""
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
if llm_image_context:
|
||||
return json.dumps(llm_image_context)
|
||||
|
||||
if tool_call_response:
|
||||
return tool_call_response
|
||||
|
||||
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
|
||||
|
||||
|
||||
def convert_chat_history(
|
||||
chat_history: list[ChatMessage],
|
||||
files: list[ChatLoadedFile],
|
||||
@@ -582,10 +618,24 @@ def convert_chat_history(
|
||||
|
||||
# Add TOOL_CALL_RESPONSE messages for each tool call in this turn
|
||||
for tool_call in turn_tool_calls:
|
||||
tool_name = tool_id_to_name_map.get(
|
||||
tool_call.tool_id, "unknown"
|
||||
)
|
||||
tool_response_message = (
|
||||
_build_tool_call_response_history_message(
|
||||
tool_name=tool_name,
|
||||
generated_images=tool_call.generated_images,
|
||||
tool_call_response=tool_call.tool_call_response,
|
||||
)
|
||||
)
|
||||
simple_messages.append(
|
||||
ChatMessageSimple(
|
||||
message=TOOL_CALL_RESPONSE_CROSS_MESSAGE,
|
||||
token_count=20, # Tiny overestimate
|
||||
message=tool_response_message,
|
||||
token_count=(
|
||||
token_counter(tool_response_message)
|
||||
if tool_name == IMAGE_GENERATION_TOOL_NAME
|
||||
else 20
|
||||
),
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
|
||||
@@ -190,7 +190,7 @@ def _build_user_information_section(
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
return USER_INFORMATION_HEADER + "".join(sections)
|
||||
return USER_INFORMATION_HEADER + "\n".join(sections)
|
||||
|
||||
|
||||
def build_system_prompt(
|
||||
@@ -228,23 +228,21 @@ def build_system_prompt(
|
||||
system_prompt += REQUIRE_CITATION_GUIDANCE
|
||||
|
||||
if include_all_guidance:
|
||||
system_prompt += (
|
||||
TOOL_SECTION_HEADER
|
||||
+ TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
+ INTERNAL_SEARCH_GUIDANCE
|
||||
+ WEB_SEARCH_GUIDANCE.format(
|
||||
tool_sections = [
|
||||
TOOL_DESCRIPTION_SEARCH_GUIDANCE,
|
||||
INTERNAL_SEARCH_GUIDANCE,
|
||||
WEB_SEARCH_GUIDANCE.format(
|
||||
site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
)
|
||||
+ OPEN_URLS_GUIDANCE
|
||||
+ PYTHON_TOOL_GUIDANCE
|
||||
+ GENERATE_IMAGE_GUIDANCE
|
||||
+ MEMORY_GUIDANCE
|
||||
)
|
||||
),
|
||||
OPEN_URLS_GUIDANCE,
|
||||
PYTHON_TOOL_GUIDANCE,
|
||||
GENERATE_IMAGE_GUIDANCE,
|
||||
MEMORY_GUIDANCE,
|
||||
]
|
||||
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_sections)
|
||||
return system_prompt
|
||||
|
||||
if tools:
|
||||
system_prompt += TOOL_SECTION_HEADER
|
||||
|
||||
has_web_search = any(isinstance(tool, WebSearchTool) for tool in tools)
|
||||
has_internal_search = any(isinstance(tool, SearchTool) for tool in tools)
|
||||
has_open_urls = any(isinstance(tool, OpenURLTool) for tool in tools)
|
||||
@@ -254,12 +252,14 @@ def build_system_prompt(
|
||||
)
|
||||
has_memory = any(isinstance(tool, MemoryTool) for tool in tools)
|
||||
|
||||
tool_guidance_sections: list[str] = []
|
||||
|
||||
if has_web_search or has_internal_search or include_all_guidance:
|
||||
system_prompt += TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
tool_guidance_sections.append(TOOL_DESCRIPTION_SEARCH_GUIDANCE)
|
||||
|
||||
# These are not included at the Tool level because the ordering may matter.
|
||||
if has_internal_search or include_all_guidance:
|
||||
system_prompt += INTERNAL_SEARCH_GUIDANCE
|
||||
tool_guidance_sections.append(INTERNAL_SEARCH_GUIDANCE)
|
||||
|
||||
if has_web_search or include_all_guidance:
|
||||
site_disabled_guidance = ""
|
||||
@@ -269,20 +269,23 @@ def build_system_prompt(
|
||||
)
|
||||
if web_search_tool and not web_search_tool.supports_site_filter:
|
||||
site_disabled_guidance = WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
system_prompt += WEB_SEARCH_GUIDANCE.format(
|
||||
site_colon_disabled=site_disabled_guidance
|
||||
tool_guidance_sections.append(
|
||||
WEB_SEARCH_GUIDANCE.format(site_colon_disabled=site_disabled_guidance)
|
||||
)
|
||||
|
||||
if has_open_urls or include_all_guidance:
|
||||
system_prompt += OPEN_URLS_GUIDANCE
|
||||
tool_guidance_sections.append(OPEN_URLS_GUIDANCE)
|
||||
|
||||
if has_python or include_all_guidance:
|
||||
system_prompt += PYTHON_TOOL_GUIDANCE
|
||||
tool_guidance_sections.append(PYTHON_TOOL_GUIDANCE)
|
||||
|
||||
if has_generate_image or include_all_guidance:
|
||||
system_prompt += GENERATE_IMAGE_GUIDANCE
|
||||
tool_guidance_sections.append(GENERATE_IMAGE_GUIDANCE)
|
||||
|
||||
if has_memory or include_all_guidance:
|
||||
system_prompt += MEMORY_GUIDANCE
|
||||
tool_guidance_sections.append(MEMORY_GUIDANCE)
|
||||
|
||||
if tool_guidance_sections:
|
||||
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_guidance_sections)
|
||||
|
||||
return system_prompt
|
||||
|
||||
@@ -251,7 +251,9 @@ DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S = int(
|
||||
os.environ.get("DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S") or 50
|
||||
)
|
||||
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
|
||||
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
|
||||
OPENSEARCH_ADMIN_PASSWORD = os.environ.get(
|
||||
"OPENSEARCH_ADMIN_PASSWORD", "StrongPassword123!"
|
||||
)
|
||||
USING_AWS_MANAGED_OPENSEARCH = (
|
||||
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
|
||||
)
|
||||
@@ -282,6 +284,9 @@ OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "englis
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
|
||||
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
|
||||
)
|
||||
# NOTE: This effectively does nothing anymore, admins can now toggle whether
|
||||
# retrieval is through OpenSearch. This value is only used as a final fallback
|
||||
# in case that doesn't work for whatever reason.
|
||||
# Given that the "base" config above is true, this enables whether we want to
|
||||
# retrieve from OpenSearch or Vespa. We want to be able to quickly toggle this
|
||||
# in the event we see issues with OpenSearch retrieval in our dev environments.
|
||||
@@ -637,6 +642,14 @@ SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
|
||||
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
|
||||
)
|
||||
|
||||
# When True, group sync enumerates every Azure AD group in the tenant (expensive).
|
||||
# When False (default), only groups found in site role assignments are synced.
|
||||
# Can be overridden per-connector via the "exhaustive_ad_enumeration" key in
|
||||
# connector_specific_config.
|
||||
SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION = (
|
||||
os.environ.get("SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION", "").lower() == "true"
|
||||
)
|
||||
|
||||
BLOB_STORAGE_SIZE_THRESHOLD = int(
|
||||
os.environ.get("BLOB_STORAGE_SIZE_THRESHOLD", 20 * 1024 * 1024)
|
||||
)
|
||||
|
||||
@@ -157,6 +157,17 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
|
||||
|
||||
# How long a queued user-file task is valid before workers discard it.
|
||||
# Should be longer than the beat interval (20 s) but short enough to prevent
|
||||
# indefinite queue growth. Workers drop tasks older than this without touching
|
||||
# the DB, so a shorter value = faster drain of stale duplicates.
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
|
||||
# Maximum number of tasks allowed in the user-file-processing queue before the
|
||||
# beat generator stops adding more. Prevents unbounded queue growth when workers
|
||||
# fall behind.
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
@@ -443,6 +454,9 @@ class OnyxRedisLocks:
|
||||
# User file processing
|
||||
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
|
||||
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
|
||||
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
|
||||
# Prevents the beat from re-enqueuing the same file while a task is already queued.
|
||||
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
|
||||
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
|
||||
@@ -83,7 +83,11 @@ SHARED_DOCUMENTS_MAP_REVERSE = {v: k for k, v in SHARED_DOCUMENTS_MAP.items()}
|
||||
|
||||
ASPX_EXTENSION = ".aspx"
|
||||
|
||||
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
|
||||
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
|
||||
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
|
||||
DEFAULT_SHAREPOINT_DOMAIN_SUFFIX = "sharepoint.com"
|
||||
|
||||
GRAPH_API_BASE = f"{DEFAULT_GRAPH_API_HOST}/v1.0"
|
||||
GRAPH_API_MAX_RETRIES = 5
|
||||
GRAPH_API_RETRYABLE_STATUSES = frozenset({429, 500, 502, 503, 504})
|
||||
|
||||
@@ -176,6 +180,25 @@ class CertificateData(BaseModel):
|
||||
thumbprint: str
|
||||
|
||||
|
||||
def _site_page_in_time_window(
|
||||
page: dict[str, Any],
|
||||
start: datetime | None,
|
||||
end: datetime | None,
|
||||
) -> bool:
|
||||
"""Return True if the page's lastModifiedDateTime falls within [start, end]."""
|
||||
if start is None and end is None:
|
||||
return True
|
||||
raw = page.get("lastModifiedDateTime")
|
||||
if not raw:
|
||||
return True
|
||||
if not isinstance(raw, str):
|
||||
raise ValueError(f"lastModifiedDateTime is not a string: {raw}")
|
||||
last_modified = datetime.fromisoformat(raw.replace("Z", "+00:00"))
|
||||
return (start is None or last_modified >= start) and (
|
||||
end is None or last_modified <= end
|
||||
)
|
||||
|
||||
|
||||
def sleep_and_retry(
|
||||
query_obj: ClientQuery, method_name: str, max_retries: int = 3
|
||||
) -> Any:
|
||||
@@ -221,6 +244,12 @@ class SharepointConnectorCheckpoint(ConnectorCheckpoint):
|
||||
current_drive_name: str | None = None
|
||||
# Drive's web_url from the API - used as raw_node_id for DRIVE hierarchy nodes
|
||||
current_drive_web_url: str | None = None
|
||||
# Resolved drive ID — avoids re-resolving on checkpoint resume
|
||||
current_drive_id: str | None = None
|
||||
# Next delta API page URL for per-page checkpointing within a drive.
|
||||
# When set, Phase 3b fetches one page at a time so progress is persisted
|
||||
# between pages. None means BFS path or no active delta traversal.
|
||||
current_drive_delta_next_link: str | None = None
|
||||
|
||||
process_site_pages: bool = False
|
||||
|
||||
@@ -266,10 +295,12 @@ def load_certificate_from_pfx(pfx_data: bytes, password: str) -> CertificateData
|
||||
|
||||
|
||||
def acquire_token_for_rest(
|
||||
msal_app: msal.ConfidentialClientApplication, sp_tenant_domain: str
|
||||
msal_app: msal.ConfidentialClientApplication,
|
||||
sp_tenant_domain: str,
|
||||
sharepoint_domain_suffix: str,
|
||||
) -> TokenResponse:
|
||||
token = msal_app.acquire_token_for_client(
|
||||
scopes=[f"https://{sp_tenant_domain}.sharepoint.com/.default"]
|
||||
scopes=[f"https://{sp_tenant_domain}.{sharepoint_domain_suffix}/.default"]
|
||||
)
|
||||
return TokenResponse.from_json(token)
|
||||
|
||||
@@ -384,12 +415,13 @@ def _download_via_graph_api(
|
||||
drive_id: str,
|
||||
item_id: str,
|
||||
bytes_allowed: int,
|
||||
graph_api_base: str,
|
||||
) -> bytes:
|
||||
"""Download a drive item via the Graph API /content endpoint with a byte cap.
|
||||
|
||||
Raises SizeCapExceeded if the cap is exceeded.
|
||||
"""
|
||||
url = f"{GRAPH_API_BASE}/drives/{drive_id}/items/{item_id}/content"
|
||||
url = f"{graph_api_base}/drives/{drive_id}/items/{item_id}/content"
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
with requests.get(
|
||||
url, headers=headers, stream=True, timeout=REQUEST_TIMEOUT_SECONDS
|
||||
@@ -410,6 +442,7 @@ def _convert_driveitem_to_document_with_permissions(
|
||||
drive_name: str,
|
||||
ctx: ClientContext | None,
|
||||
graph_client: GraphClient,
|
||||
graph_api_base: str,
|
||||
include_permissions: bool = False,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
access_token: str | None = None,
|
||||
@@ -466,6 +499,7 @@ def _convert_driveitem_to_document_with_permissions(
|
||||
driveitem.drive_id,
|
||||
driveitem.id,
|
||||
SHAREPOINT_CONNECTOR_SIZE_THRESHOLD,
|
||||
graph_api_base=graph_api_base,
|
||||
)
|
||||
except SizeCapExceeded:
|
||||
logger.warning(
|
||||
@@ -785,6 +819,9 @@ class SharepointConnector(
|
||||
sites: list[str] = [],
|
||||
include_site_pages: bool = True,
|
||||
include_site_documents: bool = True,
|
||||
authority_host: str = DEFAULT_AUTHORITY_HOST,
|
||||
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
|
||||
sharepoint_domain_suffix: str = DEFAULT_SHAREPOINT_DOMAIN_SUFFIX,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.sites = list(sites)
|
||||
@@ -800,6 +837,10 @@ class SharepointConnector(
|
||||
self._cached_rest_ctx: ClientContext | None = None
|
||||
self._cached_rest_ctx_url: str | None = None
|
||||
self._cached_rest_ctx_created_at: float = 0.0
|
||||
self.authority_host = authority_host.rstrip("/")
|
||||
self.graph_api_host = graph_api_host.rstrip("/")
|
||||
self.graph_api_base = f"{self.graph_api_host}/v1.0"
|
||||
self.sharepoint_domain_suffix = sharepoint_domain_suffix
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
# Validate that at least one content type is enabled
|
||||
@@ -856,8 +897,9 @@ class SharepointConnector(
|
||||
|
||||
msal_app = self.msal_app
|
||||
sp_tenant_domain = self.sp_tenant_domain
|
||||
sp_domain_suffix = self.sharepoint_domain_suffix
|
||||
self._cached_rest_ctx = ClientContext(site_url).with_access_token(
|
||||
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
|
||||
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
|
||||
)
|
||||
self._cached_rest_ctx_url = site_url
|
||||
self._cached_rest_ctx_created_at = time.monotonic()
|
||||
@@ -1117,76 +1159,36 @@ class SharepointConnector(
|
||||
site_descriptor: SiteDescriptor,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch SharePoint site pages (.aspx files) using the SharePoint Pages API."""
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
"""Yield SharePoint site pages (.aspx files) one at a time.
|
||||
|
||||
# Get the site to extract the site ID
|
||||
Pages are fetched via the Graph Pages API and yielded lazily as each
|
||||
API page arrives, so memory stays bounded regardless of total page count.
|
||||
Time-window filtering is applied per-item before yielding.
|
||||
"""
|
||||
site = self.graph_client.sites.get_by_url(site_descriptor.url)
|
||||
site.execute_query() # Execute the query to actually fetch the data
|
||||
site.execute_query()
|
||||
site_id = site.id
|
||||
|
||||
# Get the token acquisition function from the GraphClient
|
||||
token_data = self._acquire_token()
|
||||
access_token = token_data.get("access_token")
|
||||
if not access_token:
|
||||
raise RuntimeError("Failed to acquire access token")
|
||||
|
||||
# Construct the SharePoint Pages API endpoint
|
||||
# Using API directly, since the Graph Client doesn't support the Pages API
|
||||
pages_endpoint = f"https://graph.microsoft.com/v1.0/sites/{site_id}/pages/microsoft.graph.sitePage"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# Add expand parameter to get canvas layout content
|
||||
params = {"$expand": "canvasLayout"}
|
||||
|
||||
response = requests.get(
|
||||
pages_endpoint,
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=REQUEST_TIMEOUT_SECONDS,
|
||||
page_url: str | None = (
|
||||
f"{self.graph_api_base}/sites/{site_id}" f"/pages/microsoft.graph.sitePage"
|
||||
)
|
||||
response.raise_for_status()
|
||||
pages_data = response.json()
|
||||
all_pages = pages_data.get("value", [])
|
||||
params: dict[str, str] | None = {"$expand": "canvasLayout"}
|
||||
total_yielded = 0
|
||||
|
||||
# Handle pagination if there are more pages
|
||||
# TODO: This accumulates all pages in memory and can be heavy on large tenants.
|
||||
# We should process each page incrementally to avoid unbounded growth.
|
||||
while "@odata.nextLink" in pages_data:
|
||||
next_url = pages_data["@odata.nextLink"]
|
||||
response = requests.get(
|
||||
next_url, headers=headers, timeout=REQUEST_TIMEOUT_SECONDS
|
||||
)
|
||||
response.raise_for_status()
|
||||
pages_data = response.json()
|
||||
all_pages.extend(pages_data.get("value", []))
|
||||
while page_url:
|
||||
data = self._graph_api_get_json(page_url, params)
|
||||
params = None # nextLink already embeds query params
|
||||
|
||||
logger.debug(f"Found {len(all_pages)} site pages in {site_descriptor.url}")
|
||||
for page in data.get("value", []):
|
||||
if not _site_page_in_time_window(page, start, end):
|
||||
continue
|
||||
total_yielded += 1
|
||||
yield page
|
||||
|
||||
# Filter pages based on time window if specified
|
||||
if start is not None or end is not None:
|
||||
filtered_pages: list[dict[str, Any]] = []
|
||||
for page in all_pages:
|
||||
page_modified = page.get("lastModifiedDateTime")
|
||||
if page_modified:
|
||||
if isinstance(page_modified, str):
|
||||
page_modified = datetime.fromisoformat(
|
||||
page_modified.replace("Z", "+00:00")
|
||||
)
|
||||
page_url = data.get("@odata.nextLink")
|
||||
|
||||
if start is not None and page_modified < start:
|
||||
continue
|
||||
if end is not None and page_modified > end:
|
||||
continue
|
||||
|
||||
filtered_pages.append(page)
|
||||
all_pages = filtered_pages
|
||||
|
||||
return all_pages
|
||||
logger.debug(f"Yielded {total_yielded} site pages for {site_descriptor.url}")
|
||||
|
||||
def _acquire_token(self) -> dict[str, Any]:
|
||||
"""
|
||||
@@ -1196,7 +1198,7 @@ class SharepointConnector(
|
||||
raise RuntimeError("MSAL app is not initialized")
|
||||
|
||||
token = self.msal_app.acquire_token_for_client(
|
||||
scopes=["https://graph.microsoft.com/.default"]
|
||||
scopes=[f"{self.graph_api_host}/.default"]
|
||||
)
|
||||
return token
|
||||
|
||||
@@ -1269,9 +1271,10 @@ class SharepointConnector(
|
||||
Performs BFS folder traversal manually, fetching one page of children
|
||||
at a time so that memory usage stays bounded regardless of drive size.
|
||||
"""
|
||||
base = f"{GRAPH_API_BASE}/drives/{drive_id}"
|
||||
base = f"{self.graph_api_base}/drives/{drive_id}"
|
||||
if folder_path:
|
||||
start_url = f"{base}/root:/{folder_path}:/children"
|
||||
encoded_path = quote(folder_path, safe="/")
|
||||
start_url = f"{base}/root:/{encoded_path}:/children"
|
||||
else:
|
||||
start_url = f"{base}/root/children"
|
||||
|
||||
@@ -1329,7 +1332,7 @@ class SharepointConnector(
|
||||
"""
|
||||
use_timestamp_token = start is not None and start > _EPOCH
|
||||
|
||||
initial_url = f"{GRAPH_API_BASE}/drives/{drive_id}/root/delta"
|
||||
initial_url = f"{self.graph_api_base}/drives/{drive_id}/root/delta"
|
||||
if use_timestamp_token:
|
||||
assert start is not None # mypy
|
||||
token = quote(start.isoformat(timespec="seconds"))
|
||||
@@ -1375,7 +1378,7 @@ class SharepointConnector(
|
||||
drive_id,
|
||||
)
|
||||
yield from self._iter_delta_pages(
|
||||
initial_url=f"{GRAPH_API_BASE}/drives/{drive_id}/root/delta",
|
||||
initial_url=f"{self.graph_api_base}/drives/{drive_id}/root/delta",
|
||||
drive_id=drive_id,
|
||||
start=start,
|
||||
end=end,
|
||||
@@ -1406,6 +1409,87 @@ class SharepointConnector(
|
||||
if not page_url:
|
||||
break
|
||||
|
||||
def _build_delta_start_url(
|
||||
self,
|
||||
drive_id: str,
|
||||
start: datetime | None = None,
|
||||
page_size: int = 200,
|
||||
) -> str:
|
||||
"""Build the initial delta API URL with query parameters embedded.
|
||||
|
||||
Embeds ``$top`` (and optionally a timestamp ``token``) directly in the
|
||||
URL so that the returned string is fully self-contained and can be
|
||||
stored in a checkpoint without needing a separate params dict.
|
||||
"""
|
||||
base_url = f"{self.graph_api_base}/drives/{drive_id}/root/delta"
|
||||
params = [f"$top={page_size}"]
|
||||
if start is not None and start > _EPOCH:
|
||||
token = quote(start.isoformat(timespec="seconds"))
|
||||
params.append(f"token={token}")
|
||||
return f"{base_url}?{'&'.join(params)}"
|
||||
|
||||
def _fetch_one_delta_page(
|
||||
self,
|
||||
page_url: str,
|
||||
drive_id: str,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
page_size: int = 200,
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
"""Fetch a single page of delta API results.
|
||||
|
||||
Returns ``(items, next_page_url)``. *next_page_url* is ``None`` when
|
||||
the delta enumeration is complete (deltaLink with no nextLink).
|
||||
|
||||
On 410 Gone (expired token) returns ``([], full_resync_url)`` so
|
||||
the caller can store the resync URL in the checkpoint and retry on
|
||||
the next cycle.
|
||||
"""
|
||||
try:
|
||||
data = self._graph_api_get_json(page_url)
|
||||
except requests.HTTPError as e:
|
||||
if e.response is not None and e.response.status_code == 410:
|
||||
logger.warning(
|
||||
"Delta token expired (410 Gone) for drive '%s'. "
|
||||
"Will restart with full delta enumeration.",
|
||||
drive_id,
|
||||
)
|
||||
full_url = (
|
||||
f"{self.graph_api_base}/drives/{drive_id}/root/delta"
|
||||
f"?$top={page_size}"
|
||||
)
|
||||
return [], full_url
|
||||
raise
|
||||
|
||||
items: list[DriveItemData] = []
|
||||
for item in data.get("value", []):
|
||||
if "folder" in item or "deleted" in item:
|
||||
continue
|
||||
if start is not None or end is not None:
|
||||
raw_ts = item.get("lastModifiedDateTime")
|
||||
if raw_ts:
|
||||
mod_dt = datetime.fromisoformat(raw_ts.replace("Z", "+00:00"))
|
||||
if start is not None and mod_dt < start:
|
||||
continue
|
||||
if end is not None and mod_dt > end:
|
||||
continue
|
||||
items.append(DriveItemData.from_graph_json(item))
|
||||
|
||||
next_url = data.get("@odata.nextLink")
|
||||
if next_url:
|
||||
return items, next_url
|
||||
return items, None
|
||||
|
||||
@staticmethod
|
||||
def _clear_drive_checkpoint_state(
|
||||
checkpoint: "SharepointConnectorCheckpoint",
|
||||
) -> None:
|
||||
"""Reset all drive-level fields in the checkpoint."""
|
||||
checkpoint.current_drive_name = None
|
||||
checkpoint.current_drive_id = None
|
||||
checkpoint.current_drive_web_url = None
|
||||
checkpoint.current_drive_delta_next_link = None
|
||||
|
||||
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
|
||||
site_descriptors = self.site_descriptors or self.fetch_sites()
|
||||
|
||||
@@ -1492,7 +1576,7 @@ class SharepointConnector(
|
||||
sp_private_key = credentials.get("sp_private_key")
|
||||
sp_certificate_password = credentials.get("sp_certificate_password")
|
||||
|
||||
authority_url = f"https://login.microsoftonline.com/{sp_directory_id}"
|
||||
authority_url = f"{self.authority_host}/{sp_directory_id}"
|
||||
|
||||
if auth_method == SharepointAuthMethod.CERTIFICATE.value:
|
||||
logger.info("Using certificate authentication")
|
||||
@@ -1533,7 +1617,7 @@ class SharepointConnector(
|
||||
raise ConnectorValidationError("MSAL app is not initialized")
|
||||
|
||||
token = self.msal_app.acquire_token_for_client(
|
||||
scopes=["https://graph.microsoft.com/.default"]
|
||||
scopes=[f"{self.graph_api_host}/.default"]
|
||||
)
|
||||
if token is None:
|
||||
raise ConnectorValidationError("Failed to acquire token for graph")
|
||||
@@ -1847,14 +1931,13 @@ class SharepointConnector(
|
||||
# Return checkpoint to allow persistence after drive initialization
|
||||
return checkpoint
|
||||
|
||||
# Phase 3: Process documents from current drive
|
||||
# Phase 3a: Initialize the next drive for processing
|
||||
if (
|
||||
checkpoint.current_site_descriptor
|
||||
and checkpoint.cached_drive_names
|
||||
and len(checkpoint.cached_drive_names) > 0
|
||||
and checkpoint.current_drive_name is None
|
||||
):
|
||||
|
||||
checkpoint.current_drive_name = checkpoint.cached_drive_names.popleft()
|
||||
|
||||
start_dt = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
@@ -1862,7 +1945,8 @@ class SharepointConnector(
|
||||
site_descriptor = checkpoint.current_site_descriptor
|
||||
|
||||
logger.info(
|
||||
f"Processing drive '{checkpoint.current_drive_name}' in site: {site_descriptor.url}"
|
||||
f"Processing drive '{checkpoint.current_drive_name}' "
|
||||
f"in site: {site_descriptor.url}"
|
||||
)
|
||||
logger.debug(f"Time range: {start_dt} to {end_dt}")
|
||||
|
||||
@@ -1871,35 +1955,35 @@ class SharepointConnector(
|
||||
logger.warning("Current drive name is None, skipping")
|
||||
return checkpoint
|
||||
|
||||
driveitems: Iterable[DriveItemData] = iter(())
|
||||
drive_web_url: str | None = None
|
||||
try:
|
||||
logger.info(
|
||||
f"Fetching drive items for drive name: {current_drive_name}"
|
||||
)
|
||||
result = self._resolve_drive(site_descriptor, current_drive_name)
|
||||
if result is not None:
|
||||
drive_id, drive_web_url = result
|
||||
driveitems = self._get_drive_items_for_drive_id(
|
||||
site_descriptor, drive_id, start_dt, end_dt
|
||||
)
|
||||
checkpoint.current_drive_web_url = drive_web_url
|
||||
if result is None:
|
||||
logger.warning(f"Drive '{current_drive_name}' not found, skipping")
|
||||
self._clear_drive_checkpoint_state(checkpoint)
|
||||
return checkpoint
|
||||
|
||||
drive_id, drive_web_url = result
|
||||
checkpoint.current_drive_id = drive_id
|
||||
checkpoint.current_drive_web_url = drive_web_url
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to retrieve items from drive '{current_drive_name}' in site: {site_descriptor.url}: {e}"
|
||||
f"Failed to retrieve items from drive '{current_drive_name}' "
|
||||
f"in site: {site_descriptor.url}: {e}"
|
||||
)
|
||||
yield _create_entity_failure(
|
||||
f"{site_descriptor.url}|{current_drive_name}",
|
||||
f"Failed to access drive '{current_drive_name}' in site '{site_descriptor.url}': {str(e)}",
|
||||
f"Failed to access drive '{current_drive_name}' "
|
||||
f"in site '{site_descriptor.url}': {str(e)}",
|
||||
(start_dt, end_dt),
|
||||
e,
|
||||
)
|
||||
checkpoint.current_drive_name = None
|
||||
checkpoint.current_drive_web_url = None
|
||||
self._clear_drive_checkpoint_state(checkpoint)
|
||||
return checkpoint
|
||||
|
||||
# Normalize drive name (e.g., "Documents" -> "Shared Documents")
|
||||
current_drive_name = SHARED_DOCUMENTS_MAP.get(
|
||||
display_drive_name = SHARED_DOCUMENTS_MAP.get(
|
||||
current_drive_name, current_drive_name
|
||||
)
|
||||
|
||||
@@ -1907,10 +1991,74 @@ class SharepointConnector(
|
||||
yield from self._yield_drive_hierarchy_node(
|
||||
site_descriptor.url,
|
||||
drive_web_url,
|
||||
current_drive_name,
|
||||
display_drive_name,
|
||||
checkpoint,
|
||||
)
|
||||
|
||||
# For non-folder-scoped drives, use delta API with per-page
|
||||
# checkpointing. Build the initial URL and fall through to 3b.
|
||||
if not site_descriptor.folder_path:
|
||||
checkpoint.current_drive_delta_next_link = self._build_delta_start_url(
|
||||
drive_id, start_dt
|
||||
)
|
||||
# else: BFS path — delta_next_link stays None;
|
||||
# Phase 3b will use _iter_drive_items_paged.
|
||||
|
||||
# Phase 3b: Process items from the current drive
|
||||
if (
|
||||
checkpoint.current_site_descriptor
|
||||
and checkpoint.current_drive_name is not None
|
||||
and checkpoint.current_drive_id is not None
|
||||
):
|
||||
site_descriptor = checkpoint.current_site_descriptor
|
||||
start_dt = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_dt = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
current_drive_name = SHARED_DOCUMENTS_MAP.get(
|
||||
checkpoint.current_drive_name, checkpoint.current_drive_name
|
||||
)
|
||||
drive_web_url = checkpoint.current_drive_web_url
|
||||
|
||||
# --- determine item source ---
|
||||
driveitems: Iterable[DriveItemData]
|
||||
has_more_delta_pages = False
|
||||
|
||||
if checkpoint.current_drive_delta_next_link:
|
||||
# Delta path: fetch one page at a time for checkpointing
|
||||
try:
|
||||
page_items, next_url = self._fetch_one_delta_page(
|
||||
page_url=checkpoint.current_drive_delta_next_link,
|
||||
drive_id=checkpoint.current_drive_id,
|
||||
start=start_dt,
|
||||
end=end_dt,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to fetch delta page for drive "
|
||||
f"'{current_drive_name}': {e}"
|
||||
)
|
||||
yield _create_entity_failure(
|
||||
f"{site_descriptor.url}|{current_drive_name}",
|
||||
f"Failed to fetch delta page for drive "
|
||||
f"'{current_drive_name}': {str(e)}",
|
||||
(start_dt, end_dt),
|
||||
e,
|
||||
)
|
||||
self._clear_drive_checkpoint_state(checkpoint)
|
||||
return checkpoint
|
||||
|
||||
driveitems = page_items
|
||||
has_more_delta_pages = next_url is not None
|
||||
if next_url:
|
||||
checkpoint.current_drive_delta_next_link = next_url
|
||||
else:
|
||||
# BFS path (folder-scoped): process all items at once
|
||||
driveitems = self._iter_drive_items_paged(
|
||||
drive_id=checkpoint.current_drive_id,
|
||||
folder_path=site_descriptor.folder_path,
|
||||
start=start_dt,
|
||||
end=end_dt,
|
||||
)
|
||||
|
||||
item_count = 0
|
||||
for driveitem in driveitems:
|
||||
item_count += 1
|
||||
@@ -1952,8 +2100,6 @@ class SharepointConnector(
|
||||
if include_permissions:
|
||||
ctx = self._create_rest_client_context(site_descriptor.url)
|
||||
|
||||
# Re-acquire token in case it expired during a long traversal
|
||||
# MSAL has a cache that returns the same token while still valid.
|
||||
access_token = self._get_graph_access_token()
|
||||
doc_or_failure = _convert_driveitem_to_document_with_permissions(
|
||||
driveitem,
|
||||
@@ -1962,6 +2108,7 @@ class SharepointConnector(
|
||||
self.graph_client,
|
||||
include_permissions=include_permissions,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_url,
|
||||
graph_api_base=self.graph_api_base,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
@@ -1988,8 +2135,11 @@ class SharepointConnector(
|
||||
)
|
||||
|
||||
logger.info(f"Processed {item_count} items in drive '{current_drive_name}'")
|
||||
checkpoint.current_drive_name = None
|
||||
checkpoint.current_drive_web_url = None
|
||||
|
||||
if has_more_delta_pages:
|
||||
return checkpoint
|
||||
|
||||
self._clear_drive_checkpoint_state(checkpoint)
|
||||
|
||||
# Phase 4: Progression logic - determine next step
|
||||
# If we have more drives in current site, continue with current site
|
||||
|
||||
@@ -50,12 +50,15 @@ class TeamsCheckpoint(ConnectorCheckpoint):
|
||||
todo_team_ids: list[str] | None = None
|
||||
|
||||
|
||||
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
|
||||
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
|
||||
|
||||
|
||||
class TeamsConnector(
|
||||
CheckpointedConnectorWithPermSync[TeamsCheckpoint],
|
||||
SlimConnectorWithPermSync,
|
||||
):
|
||||
MAX_WORKERS = 10
|
||||
AUTHORITY_URL_PREFIX = "https://login.microsoftonline.com/"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -63,11 +66,15 @@ class TeamsConnector(
|
||||
# are not necessarily guaranteed to be unique
|
||||
teams: list[str] = [],
|
||||
max_workers: int = MAX_WORKERS,
|
||||
authority_host: str = DEFAULT_AUTHORITY_HOST,
|
||||
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
|
||||
) -> None:
|
||||
self.graph_client: GraphClient | None = None
|
||||
self.msal_app: msal.ConfidentialClientApplication | None = None
|
||||
self.max_workers = max_workers
|
||||
self.requested_team_list: list[str] = teams
|
||||
self.authority_host = authority_host.rstrip("/")
|
||||
self.graph_api_host = graph_api_host.rstrip("/")
|
||||
|
||||
# impls for BaseConnector
|
||||
|
||||
@@ -76,7 +83,7 @@ class TeamsConnector(
|
||||
teams_client_secret = credentials["teams_client_secret"]
|
||||
teams_directory_id = credentials["teams_directory_id"]
|
||||
|
||||
authority_url = f"{TeamsConnector.AUTHORITY_URL_PREFIX}{teams_directory_id}"
|
||||
authority_url = f"{self.authority_host}/{teams_directory_id}"
|
||||
self.msal_app = msal.ConfidentialClientApplication(
|
||||
authority=authority_url,
|
||||
client_id=teams_client_id,
|
||||
@@ -91,7 +98,7 @@ class TeamsConnector(
|
||||
raise RuntimeError("MSAL app is not initialized")
|
||||
|
||||
token = self.msal_app.acquire_token_for_client(
|
||||
scopes=["https://graph.microsoft.com/.default"]
|
||||
scopes=[f"{self.graph_api_host}/.default"]
|
||||
)
|
||||
|
||||
if not isinstance(token, dict):
|
||||
|
||||
@@ -32,6 +32,7 @@ from onyx.context.search.federated.slack_search_utils import should_include_mess
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.document import DocumentSource
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.document_index_utils import (
|
||||
get_multipass_config,
|
||||
@@ -905,13 +906,15 @@ def convert_slack_score(slack_score: float) -> float:
|
||||
def slack_retrieval(
|
||||
query: ChunkIndexRequest,
|
||||
access_token: str,
|
||||
db_session: Session,
|
||||
db_session: Session | None = None,
|
||||
connector: FederatedConnectorDetail | None = None, # noqa: ARG001
|
||||
entities: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
slack_event_context: SlackContext | None = None,
|
||||
bot_token: str | None = None, # Add bot token parameter
|
||||
team_id: str | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB query (no session needed)
|
||||
search_settings: SearchSettings | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
Main entry point for Slack federated search with entity filtering.
|
||||
@@ -925,7 +928,7 @@ def slack_retrieval(
|
||||
Args:
|
||||
query: Search query object
|
||||
access_token: User OAuth access token
|
||||
db_session: Database session
|
||||
db_session: Database session (optional if search_settings provided)
|
||||
connector: Federated connector detail (unused, kept for backwards compat)
|
||||
entities: Connector-level config (entity filtering configuration)
|
||||
limit: Maximum number of results
|
||||
@@ -1153,7 +1156,10 @@ def slack_retrieval(
|
||||
|
||||
# chunk index docs into doc aware chunks
|
||||
# a single index doc can get split into multiple chunks
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
if search_settings is None:
|
||||
if db_session is None:
|
||||
raise ValueError("Either db_session or search_settings must be provided")
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
embedder = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=search_settings
|
||||
)
|
||||
|
||||
@@ -18,8 +18,10 @@ from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.natural_language_processing.english_stopwords import strip_stopwords
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.secondary_llm_flows.source_filter import extract_source_filter
|
||||
from onyx.secondary_llm_flows.time_filter import extract_time_filter
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -41,7 +43,7 @@ def _build_index_filters(
|
||||
user_file_ids: list[UUID] | None,
|
||||
persona_document_sets: list[str] | None,
|
||||
persona_time_cutoff: datetime | None,
|
||||
db_session: Session,
|
||||
db_session: Session | None = None,
|
||||
auto_detect_filters: bool = False,
|
||||
query: str | None = None,
|
||||
llm: LLM | None = None,
|
||||
@@ -49,6 +51,8 @@ def _build_index_filters(
|
||||
# Assistant knowledge filters
|
||||
attached_document_ids: list[str] | None = None,
|
||||
hierarchy_node_ids: list[int] | None = None,
|
||||
# Pre-fetched ACL filters (skips DB query when provided)
|
||||
acl_filters: list[str] | None = None,
|
||||
) -> IndexFilters:
|
||||
if auto_detect_filters and (llm is None or query is None):
|
||||
raise RuntimeError("LLM and query are required for auto detect filters")
|
||||
@@ -103,9 +107,14 @@ def _build_index_filters(
|
||||
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
|
||||
logger.debug("Added USER_FILE to source_filter for user knowledge search")
|
||||
|
||||
user_acl_filters = (
|
||||
None if bypass_acl else build_access_filters_for_user(user, db_session)
|
||||
)
|
||||
if bypass_acl:
|
||||
user_acl_filters = None
|
||||
elif acl_filters is not None:
|
||||
user_acl_filters = acl_filters
|
||||
else:
|
||||
if db_session is None:
|
||||
raise ValueError("Either db_session or acl_filters must be provided")
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
@@ -252,11 +261,15 @@ def search_pipeline(
|
||||
user: User,
|
||||
# Used for default filters and settings
|
||||
persona: Persona | None,
|
||||
db_session: Session,
|
||||
db_session: Session | None = None,
|
||||
auto_detect_filters: bool = False,
|
||||
llm: LLM | None = None,
|
||||
# If a project ID is provided, it will be exclusively scoped to that project
|
||||
project_id: int | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB queries (no session needed)
|
||||
acl_filters: list[str] | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
user_uploaded_persona_files: list[UUID] | None = (
|
||||
[user_file.id for user_file in persona.user_files] if persona else None
|
||||
@@ -297,6 +310,7 @@ def search_pipeline(
|
||||
bypass_acl=chunk_search_request.bypass_acl,
|
||||
attached_document_ids=attached_document_ids,
|
||||
hierarchy_node_ids=hierarchy_node_ids,
|
||||
acl_filters=acl_filters,
|
||||
)
|
||||
|
||||
query_keywords = strip_stopwords(chunk_search_request.query)
|
||||
@@ -315,6 +329,8 @@ def search_pipeline(
|
||||
user_id=user.id if user else None,
|
||||
document_index=document_index,
|
||||
db_session=db_session,
|
||||
embedding_model=embedding_model,
|
||||
prefetched_federated_retrieval_infos=prefetched_federated_retrieval_infos,
|
||||
)
|
||||
|
||||
# For some specific connectors like Salesforce, a user that has access to an object doesn't mean
|
||||
|
||||
@@ -14,9 +14,11 @@ from onyx.context.search.utils import get_query_embedding
|
||||
from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
from onyx.federated_connectors.federated_retrieval import (
|
||||
get_federated_retrieval_functions,
|
||||
)
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
@@ -50,9 +52,14 @@ def combine_retrieval_results(
|
||||
def _embed_and_search(
|
||||
query_request: ChunkIndexRequest,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
db_session: Session | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
query_embedding = get_query_embedding(query_request.query, db_session)
|
||||
query_embedding = get_query_embedding(
|
||||
query_request.query,
|
||||
db_session=db_session,
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
hybrid_alpha = query_request.hybrid_alpha or HYBRID_ALPHA
|
||||
|
||||
@@ -78,7 +85,9 @@ def search_chunks(
|
||||
query_request: ChunkIndexRequest,
|
||||
user_id: UUID | None,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
db_session: Session | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
run_queries: list[tuple[Callable, tuple]] = []
|
||||
|
||||
@@ -88,14 +97,22 @@ def search_chunks(
|
||||
else None
|
||||
)
|
||||
|
||||
# Federated retrieval
|
||||
federated_retrieval_infos = get_federated_retrieval_functions(
|
||||
db_session=db_session,
|
||||
user_id=user_id,
|
||||
source_types=list(source_filters) if source_filters else None,
|
||||
document_set_names=query_request.filters.document_set,
|
||||
user_file_ids=query_request.filters.user_file_ids,
|
||||
)
|
||||
# Federated retrieval — use pre-fetched if available, otherwise query DB
|
||||
if prefetched_federated_retrieval_infos is not None:
|
||||
federated_retrieval_infos = prefetched_federated_retrieval_infos
|
||||
else:
|
||||
if db_session is None:
|
||||
raise ValueError(
|
||||
"Either db_session or prefetched_federated_retrieval_infos "
|
||||
"must be provided"
|
||||
)
|
||||
federated_retrieval_infos = get_federated_retrieval_functions(
|
||||
db_session=db_session,
|
||||
user_id=user_id,
|
||||
source_types=list(source_filters) if source_filters else None,
|
||||
document_set_names=query_request.filters.document_set,
|
||||
user_file_ids=query_request.filters.user_file_ids,
|
||||
)
|
||||
|
||||
federated_sources = set(
|
||||
federated_retrieval_info.source.to_non_federated_source()
|
||||
@@ -114,7 +131,10 @@ def search_chunks(
|
||||
|
||||
if normal_search_enabled:
|
||||
run_queries.append(
|
||||
(_embed_and_search, (query_request, document_index, db_session))
|
||||
(
|
||||
_embed_and_search,
|
||||
(query_request, document_index, db_session, embedding_model),
|
||||
)
|
||||
)
|
||||
|
||||
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
|
||||
|
||||
@@ -64,23 +64,34 @@ def inference_section_from_single_chunk(
|
||||
)
|
||||
|
||||
|
||||
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
def get_query_embeddings(
|
||||
queries: list[str],
|
||||
db_session: Session | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
) -> list[Embedding]:
|
||||
if embedding_model is None:
|
||||
if db_session is None:
|
||||
raise ValueError("Either db_session or embedding_model must be provided")
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
query_embedding = model.encode(queries, text_type=EmbedTextType.QUERY)
|
||||
query_embedding = embedding_model.encode(queries, text_type=EmbedTextType.QUERY)
|
||||
return query_embedding
|
||||
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def get_query_embedding(query: str, db_session: Session) -> Embedding:
|
||||
return get_query_embeddings([query], db_session)[0]
|
||||
def get_query_embedding(
|
||||
query: str,
|
||||
db_session: Session | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
) -> Embedding:
|
||||
return get_query_embeddings(
|
||||
[query], db_session=db_session, embedding_model=embedding_model
|
||||
)[0]
|
||||
|
||||
|
||||
def convert_inference_sections_to_search_docs(
|
||||
|
||||
@@ -4,6 +4,7 @@ from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import ApiKeyDescriptor
|
||||
@@ -54,6 +55,7 @@ async def fetch_user_for_api_key(
|
||||
select(User)
|
||||
.join(ApiKey, ApiKey.user_id == User.id)
|
||||
.where(ApiKey.hashed_api_key == hashed_api_key)
|
||||
.options(selectinload(User.memories))
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from sqlalchemy import func
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
@@ -97,6 +98,11 @@ async def get_user_count(only_admin_users: bool = False) -> int:
|
||||
|
||||
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
|
||||
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
|
||||
async def _get_user(self, statement: Select) -> UP | None:
|
||||
statement = statement.options(selectinload(User.memories))
|
||||
results = await self.session.execute(statement)
|
||||
return results.unique().scalar_one_or_none()
|
||||
|
||||
async def create(
|
||||
self,
|
||||
create_dict: Dict[str, Any],
|
||||
|
||||
@@ -116,12 +116,15 @@ def get_connector_credential_pairs_for_user(
|
||||
order_by_desc: bool = False,
|
||||
source: DocumentSource | None = None,
|
||||
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
|
||||
defer_connector_config: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
"""Get connector credential pairs for a user.
|
||||
|
||||
Args:
|
||||
processing_mode: Filter by processing mode. Defaults to REGULAR to hide
|
||||
FILE_SYSTEM connectors from standard admin UI. Pass None to get all.
|
||||
defer_connector_config: If True, skips loading Connector.connector_specific_config
|
||||
to avoid fetching large JSONB blobs when they aren't needed.
|
||||
"""
|
||||
if eager_load_user:
|
||||
assert (
|
||||
@@ -130,7 +133,10 @@ def get_connector_credential_pairs_for_user(
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
|
||||
if eager_load_connector:
|
||||
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
|
||||
connector_load = selectinload(ConnectorCredentialPair.connector)
|
||||
if defer_connector_config:
|
||||
connector_load = connector_load.defer(Connector.connector_specific_config)
|
||||
stmt = stmt.options(connector_load)
|
||||
|
||||
if eager_load_credential:
|
||||
load_opts = selectinload(ConnectorCredentialPair.credential)
|
||||
@@ -170,6 +176,7 @@ def get_connector_credential_pairs_for_user_parallel(
|
||||
order_by_desc: bool = False,
|
||||
source: DocumentSource | None = None,
|
||||
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
|
||||
defer_connector_config: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
return get_connector_credential_pairs_for_user(
|
||||
@@ -183,6 +190,7 @@ def get_connector_credential_pairs_for_user_parallel(
|
||||
order_by_desc=order_by_desc,
|
||||
source=source,
|
||||
processing_mode=processing_mode,
|
||||
defer_connector_config=defer_connector_config,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -554,10 +554,19 @@ def fetch_all_document_sets_for_user(
|
||||
stmt = (
|
||||
select(DocumentSetDBModel)
|
||||
.distinct()
|
||||
.options(selectinload(DocumentSetDBModel.federated_connectors))
|
||||
.options(
|
||||
selectinload(DocumentSetDBModel.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSetDBModel.users),
|
||||
selectinload(DocumentSetDBModel.groups),
|
||||
selectinload(DocumentSetDBModel.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
)
|
||||
)
|
||||
stmt = _add_user_filters(stmt, user, get_editable=get_editable)
|
||||
return db_session.scalars(stmt).all()
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
def fetch_documents_for_document_set_paginated(
|
||||
|
||||
@@ -287,7 +287,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
|
||||
# relationships
|
||||
credentials: Mapped[list["Credential"]] = relationship(
|
||||
"Credential", back_populates="user", lazy="joined"
|
||||
"Credential", back_populates="user"
|
||||
)
|
||||
chat_sessions: Mapped[list["ChatSession"]] = relationship(
|
||||
"ChatSession", back_populates="user"
|
||||
@@ -321,7 +321,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
"Memory",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="selectin",
|
||||
order_by="desc(Memory.id)",
|
||||
)
|
||||
oauth_user_tokens: Mapped[list["OAuthUserToken"]] = relationship(
|
||||
@@ -4940,6 +4939,7 @@ class ScimUserMapping(Base):
|
||||
user_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
|
||||
)
|
||||
scim_username: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
|
||||
@@ -8,6 +8,7 @@ from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.pat import build_displayable_pat
|
||||
@@ -31,55 +32,61 @@ async def fetch_user_for_pat(
|
||||
|
||||
NOTE: This is async since it's used during auth (which is necessarily async due to FastAPI Users).
|
||||
NOTE: Expired includes both naturally expired and user-revoked tokens (revocation sets expires_at=NOW()).
|
||||
|
||||
Uses select(User) as primary entity so that joined-eager relationships (e.g. oauth_accounts)
|
||||
are loaded correctly — matching the pattern in fetch_user_for_api_key.
|
||||
"""
|
||||
# Single joined query with all filters pushed to database
|
||||
now = datetime.now(timezone.utc)
|
||||
result = await async_db_session.execute(
|
||||
select(PersonalAccessToken, User)
|
||||
.join(User, PersonalAccessToken.user_id == User.id)
|
||||
|
||||
user = await async_db_session.scalar(
|
||||
select(User)
|
||||
.join(PersonalAccessToken, PersonalAccessToken.user_id == User.id)
|
||||
.where(PersonalAccessToken.hashed_token == hashed_token)
|
||||
.where(User.is_active) # type: ignore
|
||||
.where(
|
||||
(PersonalAccessToken.expires_at.is_(None))
|
||||
| (PersonalAccessToken.expires_at > now)
|
||||
)
|
||||
.limit(1)
|
||||
.options(selectinload(User.memories))
|
||||
)
|
||||
row = result.first()
|
||||
|
||||
if not row:
|
||||
if not user:
|
||||
return None
|
||||
|
||||
pat, user = row
|
||||
|
||||
# Throttle last_used_at updates to reduce DB load (5-minute granularity sufficient for auditing)
|
||||
# For request-level auditing, use application logs or a dedicated audit table
|
||||
should_update = (
|
||||
pat.last_used_at is None or (now - pat.last_used_at).total_seconds() > 300
|
||||
)
|
||||
|
||||
if should_update:
|
||||
# Update in separate session to avoid transaction coupling (fire-and-forget)
|
||||
async def _update_last_used() -> None:
|
||||
try:
|
||||
tenant_id = get_current_tenant_id()
|
||||
async with get_async_session_context_manager(
|
||||
tenant_id
|
||||
) as separate_session:
|
||||
await separate_session.execute(
|
||||
update(PersonalAccessToken)
|
||||
.where(PersonalAccessToken.hashed_token == hashed_token)
|
||||
.values(last_used_at=now)
|
||||
)
|
||||
await separate_session.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to update last_used_at for PAT: {e}")
|
||||
|
||||
asyncio.create_task(_update_last_used())
|
||||
|
||||
_schedule_pat_last_used_update(hashed_token, now)
|
||||
return user
|
||||
|
||||
|
||||
def _schedule_pat_last_used_update(hashed_token: str, now: datetime) -> None:
|
||||
"""Fire-and-forget update of last_used_at, throttled to 5-minute granularity."""
|
||||
|
||||
async def _update() -> None:
|
||||
try:
|
||||
tenant_id = get_current_tenant_id()
|
||||
async with get_async_session_context_manager(tenant_id) as session:
|
||||
pat = await session.scalar(
|
||||
select(PersonalAccessToken).where(
|
||||
PersonalAccessToken.hashed_token == hashed_token
|
||||
)
|
||||
)
|
||||
if not pat:
|
||||
return
|
||||
if (
|
||||
pat.last_used_at is not None
|
||||
and (now - pat.last_used_at).total_seconds() <= 300
|
||||
):
|
||||
return
|
||||
await session.execute(
|
||||
update(PersonalAccessToken)
|
||||
.where(PersonalAccessToken.hashed_token == hashed_token)
|
||||
.values(last_used_at=now)
|
||||
)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to update last_used_at for PAT: {e}")
|
||||
|
||||
asyncio.create_task(_update())
|
||||
|
||||
|
||||
def create_pat(
|
||||
db_session: Session,
|
||||
user_id: UUID,
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.db.document_access import get_accessible_documents_by_ids
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import FederatedConnector__DocumentSet
|
||||
from onyx.db.models import HierarchyNode
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
@@ -420,9 +421,16 @@ def get_minimal_persona_snapshots_for_user(
|
||||
stmt = stmt.options(
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.labels),
|
||||
selectinload(Persona.document_sets)
|
||||
.selectinload(DocumentSet.connector_credential_pairs)
|
||||
.selectinload(ConnectorCredentialPair.connector),
|
||||
selectinload(Persona.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(Persona.hierarchy_nodes),
|
||||
selectinload(Persona.attached_documents).selectinload(
|
||||
Document.parent_hierarchy_node
|
||||
@@ -453,7 +461,16 @@ def get_persona_snapshots_for_user(
|
||||
Document.parent_hierarchy_node
|
||||
),
|
||||
selectinload(Persona.labels),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(Persona.user),
|
||||
selectinload(Persona.user_files),
|
||||
selectinload(Persona.users),
|
||||
@@ -550,9 +567,16 @@ def get_minimal_persona_snapshots_paginated(
|
||||
Document.parent_hierarchy_node
|
||||
),
|
||||
selectinload(Persona.labels),
|
||||
selectinload(Persona.document_sets)
|
||||
.selectinload(DocumentSet.connector_credential_pairs)
|
||||
.selectinload(ConnectorCredentialPair.connector),
|
||||
selectinload(Persona.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(Persona.user),
|
||||
)
|
||||
|
||||
@@ -611,7 +635,16 @@ def get_persona_snapshots_paginated(
|
||||
Document.parent_hierarchy_node
|
||||
),
|
||||
selectinload(Persona.labels),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(Persona.user),
|
||||
selectinload(Persona.user_files),
|
||||
selectinload(Persona.users),
|
||||
|
||||
@@ -20,7 +20,20 @@ class ImageGenerationProviderCredentials(BaseModel):
|
||||
custom_config: dict[str, str] | None = None
|
||||
|
||||
|
||||
class ReferenceImage(BaseModel):
|
||||
data: bytes
|
||||
mime_type: str
|
||||
|
||||
|
||||
class ImageGenerationProvider(abc.ABC):
|
||||
@property
|
||||
def supports_reference_images(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def max_reference_images(self) -> int:
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def validate_credentials(
|
||||
@@ -63,6 +76,7 @@ class ImageGenerationProvider(abc.ABC):
|
||||
size: str,
|
||||
n: int,
|
||||
quality: str | None = None,
|
||||
reference_images: list[ReferenceImage] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ImageGenerationResponse:
|
||||
"""Generates an image based on a prompt."""
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from onyx.image_gen.interfaces import ImageGenerationProvider
|
||||
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
|
||||
from onyx.image_gen.interfaces import ReferenceImage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.image_gen.interfaces import ImageGenerationResponse
|
||||
@@ -59,6 +60,7 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
|
||||
size: str,
|
||||
n: int,
|
||||
quality: str | None = None,
|
||||
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
|
||||
**kwargs: Any,
|
||||
) -> ImageGenerationResponse:
|
||||
from litellm import image_generation
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from onyx.image_gen.interfaces import ImageGenerationProvider
|
||||
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
|
||||
from onyx.image_gen.interfaces import ReferenceImage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.image_gen.interfaces import ImageGenerationResponse
|
||||
@@ -45,6 +46,7 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
|
||||
size: str,
|
||||
n: int,
|
||||
quality: str | None = None,
|
||||
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
|
||||
**kwargs: Any,
|
||||
) -> ImageGenerationResponse:
|
||||
from litellm import image_generation
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -9,6 +11,7 @@ from pydantic import BaseModel
|
||||
from onyx.image_gen.exceptions import ImageProviderCredentialsError
|
||||
from onyx.image_gen.interfaces import ImageGenerationProvider
|
||||
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
|
||||
from onyx.image_gen.interfaces import ReferenceImage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.image_gen.interfaces import ImageGenerationResponse
|
||||
@@ -51,6 +54,15 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
|
||||
vertex_credentials=vertex_credentials,
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_reference_images(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def max_reference_images(self) -> int:
|
||||
# Gemini image editing supports up to 14 input images.
|
||||
return 14
|
||||
|
||||
def generate_image(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -58,8 +70,18 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
|
||||
size: str,
|
||||
n: int,
|
||||
quality: str | None = None,
|
||||
reference_images: list[ReferenceImage] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ImageGenerationResponse:
|
||||
if reference_images:
|
||||
return self._generate_image_with_reference_images(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
size=size,
|
||||
n=n,
|
||||
reference_images=reference_images,
|
||||
)
|
||||
|
||||
from litellm import image_generation
|
||||
|
||||
return image_generation(
|
||||
@@ -74,6 +96,99 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _generate_image_with_reference_images(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
size: str,
|
||||
n: int,
|
||||
reference_images: list[ReferenceImage],
|
||||
) -> ImageGenerationResponse:
|
||||
from google import genai
|
||||
from google.genai import types as genai_types
|
||||
from google.oauth2 import service_account
|
||||
from litellm.types.utils import ImageObject
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
service_account_info = json.loads(self._vertex_credentials)
|
||||
credentials = service_account.Credentials.from_service_account_info(
|
||||
service_account_info,
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
|
||||
client = genai.Client(
|
||||
vertexai=True,
|
||||
project=self._vertex_project,
|
||||
location=self._vertex_location,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
parts: list[genai_types.Part] = [
|
||||
genai_types.Part.from_bytes(data=image.data, mime_type=image.mime_type)
|
||||
for image in reference_images
|
||||
]
|
||||
parts.append(genai_types.Part.from_text(text=prompt))
|
||||
|
||||
config = genai_types.GenerateContentConfig(
|
||||
response_modalities=["TEXT", "IMAGE"],
|
||||
candidate_count=max(1, n),
|
||||
image_config=genai_types.ImageConfig(
|
||||
aspect_ratio=_map_size_to_aspect_ratio(size)
|
||||
),
|
||||
)
|
||||
model_name = model.replace("vertex_ai/", "")
|
||||
response = client.models.generate_content(
|
||||
model=model_name,
|
||||
contents=genai_types.Content(
|
||||
role="user",
|
||||
parts=parts,
|
||||
),
|
||||
config=config,
|
||||
)
|
||||
|
||||
generated_data: list[ImageObject] = []
|
||||
for candidate in response.candidates or []:
|
||||
candidate_content = candidate.content
|
||||
if not candidate_content:
|
||||
continue
|
||||
|
||||
for part in candidate_content.parts or []:
|
||||
inline_data = part.inline_data
|
||||
if not inline_data or inline_data.data is None:
|
||||
continue
|
||||
|
||||
if isinstance(inline_data.data, bytes):
|
||||
b64_json = base64.b64encode(inline_data.data).decode("utf-8")
|
||||
elif isinstance(inline_data.data, str):
|
||||
b64_json = inline_data.data
|
||||
else:
|
||||
continue
|
||||
|
||||
generated_data.append(
|
||||
ImageObject(
|
||||
b64_json=b64_json,
|
||||
revised_prompt=prompt,
|
||||
)
|
||||
)
|
||||
|
||||
if not generated_data:
|
||||
raise RuntimeError("No image data returned from Vertex AI.")
|
||||
|
||||
return ImageResponse(
|
||||
created=int(datetime.now().timestamp()),
|
||||
data=generated_data,
|
||||
)
|
||||
|
||||
|
||||
def _map_size_to_aspect_ratio(size: str) -> str:
|
||||
return {
|
||||
"1024x1024": "1:1",
|
||||
"1792x1024": "16:9",
|
||||
"1024x1792": "9:16",
|
||||
"1536x1024": "3:2",
|
||||
"1024x1536": "2:3",
|
||||
}.get(size, "1:1")
|
||||
|
||||
|
||||
def _parse_to_vertex_credentials(
|
||||
credentials: ImageGenerationProviderCredentials,
|
||||
|
||||
@@ -64,21 +64,6 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20241022-v2:0"
|
||||
},
|
||||
"anthropic.claude-3-7-sonnet-20240620-v1:0": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240620-v1:0"
|
||||
},
|
||||
"anthropic.claude-3-7-sonnet-20250219-v1:0": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20250219-v1:0"
|
||||
},
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240307-v1:0"
|
||||
},
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0": {
|
||||
"display_name": "Claude Sonnet 3",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -159,11 +144,6 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20241022-v2:0"
|
||||
},
|
||||
"apac.anthropic.claude-3-haiku-20240307-v1:0": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240307-v1:0"
|
||||
},
|
||||
"apac.anthropic.claude-3-sonnet-20240229-v1:0": {
|
||||
"display_name": "Claude Sonnet 3",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -1320,11 +1300,6 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240620-v1:0"
|
||||
},
|
||||
"bedrock/us-gov-east-1/anthropic.claude-3-haiku-20240307-v1:0": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240307-v1:0"
|
||||
},
|
||||
"bedrock/us-gov-east-1/claude-sonnet-4-5-20250929-v1:0": {
|
||||
"display_name": "Claude Sonnet 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -1365,16 +1340,6 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240620-v1:0"
|
||||
},
|
||||
"bedrock/us-gov-west-1/anthropic.claude-3-7-sonnet-20250219-v1:0": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20250219-v1:0"
|
||||
},
|
||||
"bedrock/us-gov-west-1/anthropic.claude-3-haiku-20240307-v1:0": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240307-v1:0"
|
||||
},
|
||||
"bedrock/us-gov-west-1/claude-sonnet-4-5-20250929-v1:0": {
|
||||
"display_name": "Claude Sonnet 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -1505,26 +1470,6 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "latest"
|
||||
},
|
||||
"claude-3-7-sonnet-20250219": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20250219"
|
||||
},
|
||||
"claude-3-7-sonnet-latest": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "latest"
|
||||
},
|
||||
"claude-3-7-sonnet@20250219": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20250219"
|
||||
},
|
||||
"claude-3-haiku-20240307": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240307"
|
||||
},
|
||||
"claude-4-opus-20250514": {
|
||||
"display_name": "Claude Opus 4",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -1705,16 +1650,6 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20241022-v2:0"
|
||||
},
|
||||
"eu.anthropic.claude-3-7-sonnet-20250219-v1:0": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20250219-v1:0"
|
||||
},
|
||||
"eu.anthropic.claude-3-haiku-20240307-v1:0": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240307-v1:0"
|
||||
},
|
||||
"eu.anthropic.claude-3-sonnet-20240229-v1:0": {
|
||||
"display_name": "Claude Sonnet 3",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -3226,15 +3161,6 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "latest"
|
||||
},
|
||||
"openrouter/anthropic/claude-3-haiku": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"openrouter/anthropic/claude-3-haiku-20240307": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240307"
|
||||
},
|
||||
"openrouter/anthropic/claude-3-sonnet": {
|
||||
"display_name": "Claude Sonnet 3",
|
||||
"model_vendor": "anthropic"
|
||||
@@ -3249,16 +3175,6 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "latest"
|
||||
},
|
||||
"openrouter/anthropic/claude-3.7-sonnet": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "latest"
|
||||
},
|
||||
"openrouter/anthropic/claude-3.7-sonnet:beta": {
|
||||
"display_name": "Claude Sonnet 3.7:beta",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "latest"
|
||||
},
|
||||
"openrouter/anthropic/claude-haiku-4.5": {
|
||||
"display_name": "Claude Haiku 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -3750,16 +3666,6 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20241022"
|
||||
},
|
||||
"us.anthropic.claude-3-7-sonnet-20250219-v1:0": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20250219"
|
||||
},
|
||||
"us.anthropic.claude-3-haiku-20240307-v1:0": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240307"
|
||||
},
|
||||
"us.anthropic.claude-3-sonnet-20240229-v1:0": {
|
||||
"display_name": "Claude Sonnet 3",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -3879,20 +3785,6 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240620"
|
||||
},
|
||||
"vertex_ai/claude-3-7-sonnet@20250219": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20250219"
|
||||
},
|
||||
"vertex_ai/claude-3-haiku": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"vertex_ai/claude-3-haiku@20240307": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240307"
|
||||
},
|
||||
"vertex_ai/claude-3-sonnet": {
|
||||
"display_name": "Claude Sonnet 3",
|
||||
"model_vendor": "anthropic"
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import json
|
||||
import pathlib
|
||||
import threading
|
||||
import time
|
||||
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.constants import PROVIDER_DISPLAY_NAMES
|
||||
@@ -23,6 +25,11 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_RECOMMENDATIONS_CACHE_TTL_SECONDS = 300
|
||||
_recommendations_cache_lock = threading.Lock()
|
||||
_cached_recommendations: LLMRecommendations | None = None
|
||||
_cached_recommendations_time: float = 0.0
|
||||
|
||||
|
||||
def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
"""Lazy-load provider model mappings to avoid importing litellm at module level.
|
||||
@@ -41,19 +48,40 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
}
|
||||
|
||||
|
||||
def get_recommendations() -> LLMRecommendations:
|
||||
"""Get the recommendations from the GitHub config."""
|
||||
recommendations_from_github = fetch_llm_recommendations_from_github()
|
||||
if recommendations_from_github:
|
||||
return recommendations_from_github
|
||||
|
||||
# Fall back to json bundled with code
|
||||
def _load_bundled_recommendations() -> LLMRecommendations:
|
||||
json_path = pathlib.Path(__file__).parent / "recommended-models.json"
|
||||
with open(json_path, "r") as f:
|
||||
json_config = json.load(f)
|
||||
return LLMRecommendations.model_validate(json_config)
|
||||
|
||||
recommendations_from_json = LLMRecommendations.model_validate(json_config)
|
||||
return recommendations_from_json
|
||||
|
||||
def get_recommendations() -> LLMRecommendations:
|
||||
"""Get the recommendations, with an in-memory cache to avoid
|
||||
hitting GitHub on every API request."""
|
||||
global _cached_recommendations, _cached_recommendations_time
|
||||
|
||||
now = time.monotonic()
|
||||
if (
|
||||
_cached_recommendations is not None
|
||||
and (now - _cached_recommendations_time) < _RECOMMENDATIONS_CACHE_TTL_SECONDS
|
||||
):
|
||||
return _cached_recommendations
|
||||
|
||||
with _recommendations_cache_lock:
|
||||
# Double-check after acquiring lock
|
||||
if (
|
||||
_cached_recommendations is not None
|
||||
and (time.monotonic() - _cached_recommendations_time)
|
||||
< _RECOMMENDATIONS_CACHE_TTL_SECONDS
|
||||
):
|
||||
return _cached_recommendations
|
||||
|
||||
recommendations_from_github = fetch_llm_recommendations_from_github()
|
||||
result = recommendations_from_github or _load_bundled_recommendations()
|
||||
|
||||
_cached_recommendations = result
|
||||
_cached_recommendations_time = time.monotonic()
|
||||
return result
|
||||
|
||||
|
||||
def is_obsolete_model(model_name: str, provider: str) -> bool:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# ruff: noqa: E501, W605 start
|
||||
# If there are any tools, this section is included, the sections below are for the available tools
|
||||
TOOL_SECTION_HEADER = "\n\n# Tools\n"
|
||||
TOOL_SECTION_HEADER = "\n# Tools\n\n"
|
||||
|
||||
|
||||
# This section is included if there are search type tools, currently internal_search and web_search
|
||||
@@ -16,11 +16,10 @@ When searching for information, if the initial results cannot fully answer the u
|
||||
Do not repeat the same or very similar queries if it already has been run in the chat history.
|
||||
|
||||
If it is unclear which tool to use, consider using multiple in parallel to be efficient with time.
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
|
||||
INTERNAL_SEARCH_GUIDANCE = """
|
||||
|
||||
## internal_search
|
||||
Use the `internal_search` tool to search connected applications for information. Some examples of when to use `internal_search` include:
|
||||
- Internal information: any time where there may be some information stored in internal applications that could help better answer the query.
|
||||
@@ -28,34 +27,31 @@ Use the `internal_search` tool to search connected applications for information.
|
||||
- Keyword Queries: queries that are heavily keyword based are often internal document search queries.
|
||||
- Ambiguity: questions about something that is not widely known or understood.
|
||||
Never provide more than 3 queries at once to `internal_search`.
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
|
||||
WEB_SEARCH_GUIDANCE = """
|
||||
|
||||
## web_search
|
||||
Use the `web_search` tool to access up-to-date information from the web. Some examples of when to use `web_search` include:
|
||||
- Freshness: when the answer might be enhanced by up-to-date information on a topic. Very important for topics that are changing or evolving.
|
||||
- Accuracy: if the cost of outdated/inaccurate information is high.
|
||||
- Niche Information: when detailed info is not widely known or understood (but is likely found on the internet).{site_colon_disabled}
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
WEB_SEARCH_SITE_DISABLED_GUIDANCE = """
|
||||
Do not use the "site:" operator in your web search queries.
|
||||
""".rstrip()
|
||||
""".lstrip()
|
||||
|
||||
|
||||
OPEN_URLS_GUIDANCE = """
|
||||
|
||||
## open_url
|
||||
Use the `open_url` tool to read the content of one or more URLs. Use this tool to access the contents of the most promising web pages from your web searches or user specified URLs. \
|
||||
You can open many URLs at once by passing multiple URLs in the array if multiple pages seem promising. Prioritize the most promising pages and reputable sources. \
|
||||
Do not open URLs that are image files like .png, .jpg, etc.
|
||||
You should almost always use open_url after a web_search call. Use this tool when a user asks about a specific provided URL.
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
PYTHON_TOOL_GUIDANCE = """
|
||||
|
||||
## python
|
||||
Use the `python` tool to execute Python code in an isolated sandbox. The tool will respond with the output of the execution or time out after 60.0 seconds.
|
||||
Any files uploaded to the chat will be automatically be available in the execution environment's current directory. \
|
||||
@@ -64,21 +60,21 @@ Use this to give the user a way to download the file OR to display generated ima
|
||||
Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.
|
||||
Use `openpyxl` to read and write Excel files. You have access to libraries like numpy, pandas, scipy, matplotlib, and PIL.
|
||||
IMPORTANT: each call to this tool is independent. Variables from previous calls will NOT be available in the current call.
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
GENERATE_IMAGE_GUIDANCE = """
|
||||
|
||||
## generate_image
|
||||
NEVER use generate_image unless the user specifically requests an image.
|
||||
"""
|
||||
For edits/variations of a previously generated image, pass `reference_image_file_ids` with
|
||||
the `file_id` values returned by earlier `generate_image` tool results.
|
||||
""".lstrip()
|
||||
|
||||
MEMORY_GUIDANCE = """
|
||||
|
||||
## add_memory
|
||||
Use the `add_memory` tool for facts shared by the user that should be remembered for future conversations. \
|
||||
Only add memories that are specific, likely to remain true, and likely to be useful later. \
|
||||
Focus on enduring preferences, long-term goals, stable constraints, and explicit "remember this" type requests.
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
TOOL_CALL_FAILURE_PROMPT = """
|
||||
LLM attempted to call a tool but failed. Most likely the tool name or arguments were misspelled.
|
||||
|
||||
@@ -1,40 +1,36 @@
|
||||
# ruff: noqa: E501, W605 start
|
||||
USER_INFORMATION_HEADER = "\n\n# User Information\n"
|
||||
USER_INFORMATION_HEADER = "\n# User Information\n\n"
|
||||
|
||||
BASIC_INFORMATION_PROMPT = """
|
||||
|
||||
## Basic Information
|
||||
User name: {user_name}
|
||||
User email: {user_email}{user_role}
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
# This line only shows up if the user has configured their role.
|
||||
USER_ROLE_PROMPT = """
|
||||
User role: {user_role}
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
# Team information should be a paragraph style description of the user's team.
|
||||
TEAM_INFORMATION_PROMPT = """
|
||||
|
||||
## Team Information
|
||||
{team_information}
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
# User preferences should be a paragraph style description of the user's preferences.
|
||||
USER_PREFERENCES_PROMPT = """
|
||||
|
||||
## User Preferences
|
||||
{user_preferences}
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
# User memories should look something like:
|
||||
# - Memory 1
|
||||
# - Memory 2
|
||||
# - Memory 3
|
||||
USER_MEMORIES_PROMPT = """
|
||||
|
||||
## User Memories
|
||||
{user_memories}
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
# ruff: noqa: E501, W605 end
|
||||
|
||||
@@ -109,6 +109,7 @@ class TenantRedis(redis.Redis):
|
||||
"unlock",
|
||||
"get",
|
||||
"set",
|
||||
"setex",
|
||||
"delete",
|
||||
"exists",
|
||||
"incrby",
|
||||
|
||||
@@ -103,6 +103,7 @@ from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingMode
|
||||
from onyx.db.enums import ProcessingMode
|
||||
from onyx.db.federated import fetch_all_federated_connectors_parallel
|
||||
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
|
||||
from onyx.db.index_attempt import get_latest_index_attempts_by_status
|
||||
@@ -987,6 +988,7 @@ def get_connector_status(
|
||||
user=user,
|
||||
eager_load_connector=True,
|
||||
eager_load_credential=True,
|
||||
eager_load_user=True,
|
||||
get_editable=False,
|
||||
)
|
||||
|
||||
@@ -1000,11 +1002,23 @@ def get_connector_status(
|
||||
relationship.user_group_id
|
||||
)
|
||||
|
||||
# Pre-compute credential_ids per connector to avoid N+1 lazy loads
|
||||
connector_to_credential_ids: dict[int, list[int]] = {}
|
||||
for cc_pair in cc_pairs:
|
||||
connector_to_credential_ids.setdefault(cc_pair.connector_id, []).append(
|
||||
cc_pair.credential_id
|
||||
)
|
||||
|
||||
return [
|
||||
ConnectorStatus(
|
||||
cc_pair_id=cc_pair.id,
|
||||
name=cc_pair.name,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(cc_pair.connector),
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
cc_pair.connector,
|
||||
credential_ids=connector_to_credential_ids.get(
|
||||
cc_pair.connector_id, []
|
||||
),
|
||||
),
|
||||
credential=CredentialSnapshot.from_credential_db_model(cc_pair.credential),
|
||||
access_type=cc_pair.access_type,
|
||||
groups=group_cc_pair_relationships_dict.get(cc_pair.id, []),
|
||||
@@ -1059,15 +1073,27 @@ def get_connector_indexing_status(
|
||||
parallel_functions: list[tuple[CallableProtocol, tuple[Any, ...]]] = [
|
||||
# Get editable connector/credential pairs
|
||||
(
|
||||
get_connector_credential_pairs_for_user_parallel,
|
||||
(user, True, None, True, True, True, True, request.source),
|
||||
lambda: get_connector_credential_pairs_for_user_parallel(
|
||||
user, True, None, True, True, False, True, request.source
|
||||
),
|
||||
(),
|
||||
),
|
||||
# Get federated connectors
|
||||
(fetch_all_federated_connectors_parallel, ()),
|
||||
# Get most recent index attempts
|
||||
(get_latest_index_attempts_parallel, (request.secondary_index, True, False)),
|
||||
(
|
||||
lambda: get_latest_index_attempts_parallel(
|
||||
request.secondary_index, True, False
|
||||
),
|
||||
(),
|
||||
),
|
||||
# Get most recent finished index attempts
|
||||
(get_latest_index_attempts_parallel, (request.secondary_index, True, True)),
|
||||
(
|
||||
lambda: get_latest_index_attempts_parallel(
|
||||
request.secondary_index, True, True
|
||||
),
|
||||
(),
|
||||
),
|
||||
]
|
||||
|
||||
if user and user.role == UserRole.ADMIN:
|
||||
@@ -1084,8 +1110,10 @@ def get_connector_indexing_status(
|
||||
parallel_functions.append(
|
||||
# Get non-editable connector/credential pairs
|
||||
(
|
||||
get_connector_credential_pairs_for_user_parallel,
|
||||
(user, False, None, True, True, True, True, request.source),
|
||||
lambda: get_connector_credential_pairs_for_user_parallel(
|
||||
user, False, None, True, True, False, True, request.source
|
||||
),
|
||||
(),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1911,6 +1939,7 @@ Tenant ID: {tenant_id}
|
||||
class BasicCCPairInfo(BaseModel):
|
||||
has_successful_run: bool
|
||||
source: DocumentSource
|
||||
status: ConnectorCredentialPairStatus
|
||||
|
||||
|
||||
@router.get("/connector-status", tags=PUBLIC_API_TAGS)
|
||||
@@ -1924,13 +1953,17 @@ def get_basic_connector_indexing_status(
|
||||
get_editable=False,
|
||||
user=user,
|
||||
)
|
||||
|
||||
# NOTE: This endpoint excludes Craft connectors
|
||||
return [
|
||||
BasicCCPairInfo(
|
||||
has_successful_run=cc_pair.last_successful_index_time is not None,
|
||||
source=cc_pair.connector.source,
|
||||
status=cc_pair.status,
|
||||
)
|
||||
for cc_pair in cc_pairs
|
||||
if cc_pair.connector.source != DocumentSource.INGESTION_API
|
||||
and cc_pair.processing_mode == ProcessingMode.REGULAR
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -365,7 +365,8 @@ class CCPairFullInfo(BaseModel):
|
||||
in_repeated_error_state=cc_pair_model.in_repeated_error_state,
|
||||
num_docs_indexed=num_docs_indexed,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
cc_pair_model.connector
|
||||
cc_pair_model.connector,
|
||||
credential_ids=[cc_pair_model.credential_id],
|
||||
),
|
||||
credential=CredentialSnapshot.from_credential_db_model(
|
||||
cc_pair_model.credential
|
||||
|
||||
@@ -111,7 +111,8 @@ class DocumentSet(BaseModel):
|
||||
id=cc_pair.id,
|
||||
name=cc_pair.name,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
cc_pair.connector
|
||||
cc_pair.connector,
|
||||
credential_ids=[cc_pair.credential_id],
|
||||
),
|
||||
credential=CredentialSnapshot.from_credential_db_model(
|
||||
cc_pair.credential
|
||||
|
||||
@@ -36,6 +36,8 @@ from onyx.server.query_and_chat.streaming_models import OpenUrlStart
|
||||
from onyx.server.query_and_chat.streaming_models import OpenUrlUrls
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.server.query_and_chat.streaming_models import ResearchAgentStart
|
||||
@@ -50,6 +52,7 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
)
|
||||
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -377,6 +380,37 @@ def create_memory_packets(
|
||||
return packets
|
||||
|
||||
|
||||
def create_python_tool_packets(
|
||||
code: str,
|
||||
stdout: str,
|
||||
stderr: str,
|
||||
file_ids: list[str],
|
||||
turn_index: int,
|
||||
tab_index: int = 0,
|
||||
) -> list[Packet]:
|
||||
"""Recreate PythonToolStart + PythonToolDelta + SectionEnd from the stored
|
||||
tool call data so the frontend can display both the code and its output
|
||||
on page reload."""
|
||||
packets: list[Packet] = []
|
||||
placement = Placement(turn_index=turn_index, tab_index=tab_index)
|
||||
|
||||
packets.append(Packet(placement=placement, obj=PythonToolStart(code=code)))
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
file_ids=file_ids,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(Packet(placement=placement, obj=SectionEnd()))
|
||||
return packets
|
||||
|
||||
|
||||
def create_search_packets(
|
||||
search_queries: list[str],
|
||||
search_docs: list[SavedSearchDoc],
|
||||
@@ -586,6 +620,41 @@ def translate_assistant_message_to_packets(
|
||||
)
|
||||
)
|
||||
|
||||
elif tool.in_code_tool_id == PythonTool.__name__:
|
||||
code = cast(
|
||||
str,
|
||||
tool_call.tool_call_arguments.get("code", ""),
|
||||
)
|
||||
stdout = ""
|
||||
stderr = ""
|
||||
file_ids: list[str] = []
|
||||
if tool_call.tool_call_response:
|
||||
try:
|
||||
response_data = json.loads(tool_call.tool_call_response)
|
||||
stdout = response_data.get("stdout", "")
|
||||
stderr = response_data.get("stderr", "")
|
||||
generated_files = response_data.get(
|
||||
"generated_files", []
|
||||
)
|
||||
file_ids = [
|
||||
f.get("file_link", "").split("/")[-1]
|
||||
for f in generated_files
|
||||
if f.get("file_link")
|
||||
]
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
# Fall back to raw response as stdout
|
||||
stdout = tool_call.tool_call_response
|
||||
turn_tool_packets.extend(
|
||||
create_python_tool_packets(
|
||||
code=code,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
file_ids=file_ids,
|
||||
turn_index=turn_num,
|
||||
tab_index=tool_call.tab_index,
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
# Custom tool or unknown tool
|
||||
turn_tool_packets.extend(
|
||||
|
||||
@@ -24,6 +24,7 @@ from onyx.auth.users import get_user_manager
|
||||
from onyx.auth.users import UserManager
|
||||
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from onyx.configs.app_configs import SAML_CONF_DIR
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.auth import get_user_db
|
||||
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
|
||||
@@ -123,9 +124,12 @@ async def prepare_from_fastapi_request(request: Request) -> dict[str, Any]:
|
||||
if request.client is None:
|
||||
raise ValueError("Invalid request for SAML")
|
||||
|
||||
# Use X-Forwarded headers if available
|
||||
http_host = request.headers.get("X-Forwarded-Host") or request.client.host
|
||||
server_port = request.headers.get("X-Forwarded-Port") or request.url.port
|
||||
# Derive http_host and server_port from WEB_DOMAIN (a trusted env var)
|
||||
# instead of X-Forwarded-* headers, which can be spoofed by an attacker
|
||||
# to poison SAML redirect URLs (host header poisoning).
|
||||
parsed_domain = urlparse(WEB_DOMAIN)
|
||||
http_host = parsed_domain.hostname or request.client.host
|
||||
server_port = parsed_domain.port or (443 if parsed_domain.scheme == "https" else 80)
|
||||
|
||||
rv: dict[str, Any] = {
|
||||
"http_host": http_host,
|
||||
|
||||
@@ -57,6 +57,7 @@ class Settings(BaseModel):
|
||||
anonymous_user_enabled: bool | None = None
|
||||
invite_only_enabled: bool = False
|
||||
deep_research_enabled: bool | None = None
|
||||
search_ui_enabled: bool | None = None
|
||||
|
||||
# Enterprise features flag - set by license enforcement at runtime
|
||||
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status
|
||||
|
||||
@@ -199,6 +199,12 @@ class PythonToolOverrideKwargs(BaseModel):
|
||||
chat_files: list[ChatFile] = []
|
||||
|
||||
|
||||
class ImageGenerationToolOverrideKwargs(BaseModel):
|
||||
"""Override kwargs for image generation tool calls."""
|
||||
|
||||
recent_generated_image_file_ids: list[str] = []
|
||||
|
||||
|
||||
class SearchToolRunContext(BaseModel):
|
||||
emitter: Emitter
|
||||
|
||||
|
||||
@@ -171,10 +171,8 @@ def construct_tools(
|
||||
if not search_tool_config:
|
||||
search_tool_config = SearchToolConfig()
|
||||
|
||||
# TODO concerning passing the db_session here.
|
||||
search_tool = SearchTool(
|
||||
tool_id=db_tool_model.id,
|
||||
db_session=db_session,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona=persona,
|
||||
@@ -422,7 +420,6 @@ def construct_tools(
|
||||
|
||||
search_tool = SearchTool(
|
||||
tool_id=search_tool_db_model.id,
|
||||
db_session=db_session,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona=persona,
|
||||
|
||||
@@ -11,11 +11,14 @@ from onyx.chat.emitter import Emitter
|
||||
from onyx.configs.app_configs import IMAGE_MODEL_NAME
|
||||
from onyx.configs.app_configs import IMAGE_MODEL_PROVIDER
|
||||
from onyx.db.image_generation import get_default_image_generation_config
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.utils import build_frontend_file_url
|
||||
from onyx.file_store.utils import load_chat_file_by_id
|
||||
from onyx.file_store.utils import save_files
|
||||
from onyx.image_gen.factory import get_image_generation_provider
|
||||
from onyx.image_gen.factory import validate_credentials
|
||||
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
|
||||
from onyx.image_gen.interfaces import ReferenceImage
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
|
||||
@@ -23,6 +26,7 @@ from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeart
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ImageGenerationToolOverrideKwargs
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.models import ToolExecutionException
|
||||
from onyx.tools.models import ToolResponse
|
||||
@@ -31,6 +35,7 @@ from onyx.tools.tool_implementations.images.models import (
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.models import ImageGenerationResponse
|
||||
from onyx.tools.tool_implementations.images.models import ImageShape
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
@@ -40,10 +45,10 @@ logger = setup_logger()
|
||||
HEARTBEAT_INTERVAL = 5.0
|
||||
|
||||
PROMPT_FIELD = "prompt"
|
||||
REFERENCE_IMAGE_FILE_IDS_FIELD = "reference_image_file_ids"
|
||||
|
||||
|
||||
# override_kwargs is not supported for image generation tools
|
||||
class ImageGenerationTool(Tool[None]):
|
||||
class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
NAME = "generate_image"
|
||||
DESCRIPTION = "Generate an image based on a prompt. Do not use unless the user specifically requests an image."
|
||||
DISPLAY_NAME = "Image Generation"
|
||||
@@ -59,6 +64,7 @@ class ImageGenerationTool(Tool[None]):
|
||||
) -> None:
|
||||
super().__init__(emitter=emitter)
|
||||
self.model = model
|
||||
self.provider = provider
|
||||
self.num_imgs = num_imgs
|
||||
|
||||
self.img_provider = get_image_generation_provider(
|
||||
@@ -133,6 +139,16 @@ class ImageGenerationTool(Tool[None]):
|
||||
),
|
||||
"enum": [shape.value for shape in ImageShape],
|
||||
},
|
||||
REFERENCE_IMAGE_FILE_IDS_FIELD: {
|
||||
"type": "array",
|
||||
"description": (
|
||||
"Optional image file IDs to use as reference context for edits/variations. "
|
||||
"Use the file_id values returned by previous generate_image calls."
|
||||
),
|
||||
"items": {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": [PROMPT_FIELD],
|
||||
},
|
||||
@@ -148,7 +164,10 @@ class ImageGenerationTool(Tool[None]):
|
||||
)
|
||||
|
||||
def _generate_image(
|
||||
self, prompt: str, shape: ImageShape
|
||||
self,
|
||||
prompt: str,
|
||||
shape: ImageShape,
|
||||
reference_images: list[ReferenceImage] | None = None,
|
||||
) -> tuple[ImageGenerationResponse, Any]:
|
||||
if shape == ImageShape.LANDSCAPE:
|
||||
if "gpt-image-1" in self.model:
|
||||
@@ -169,6 +188,7 @@ class ImageGenerationTool(Tool[None]):
|
||||
model=self.model,
|
||||
size=size,
|
||||
n=1,
|
||||
reference_images=reference_images,
|
||||
# response_format parameter is not supported for gpt-image-1
|
||||
response_format=None if "gpt-image-1" in self.model else "b64_json",
|
||||
)
|
||||
@@ -231,10 +251,117 @@ class ImageGenerationTool(Tool[None]):
|
||||
emit_error_packet=True,
|
||||
)
|
||||
|
||||
def _resolve_reference_image_file_ids(
|
||||
self,
|
||||
llm_kwargs: dict[str, Any],
|
||||
override_kwargs: ImageGenerationToolOverrideKwargs | None,
|
||||
) -> list[str]:
|
||||
raw_reference_ids = llm_kwargs.get(REFERENCE_IMAGE_FILE_IDS_FIELD)
|
||||
if raw_reference_ids is not None:
|
||||
if not isinstance(raw_reference_ids, list) or not all(
|
||||
isinstance(file_id, str) for file_id in raw_reference_ids
|
||||
):
|
||||
raise ToolCallException(
|
||||
message=(
|
||||
f"Invalid {REFERENCE_IMAGE_FILE_IDS_FIELD}: expected array of strings, "
|
||||
f"got {type(raw_reference_ids)}"
|
||||
),
|
||||
llm_facing_message=(
|
||||
f"The '{REFERENCE_IMAGE_FILE_IDS_FIELD}' field must be an array of file_id strings."
|
||||
),
|
||||
)
|
||||
reference_image_file_ids = [
|
||||
file_id.strip() for file_id in raw_reference_ids if file_id.strip()
|
||||
]
|
||||
elif (
|
||||
override_kwargs
|
||||
and override_kwargs.recent_generated_image_file_ids
|
||||
and self.img_provider.supports_reference_images
|
||||
):
|
||||
# If no explicit reference was provided, default to the most recently generated image.
|
||||
reference_image_file_ids = [
|
||||
override_kwargs.recent_generated_image_file_ids[-1]
|
||||
]
|
||||
else:
|
||||
reference_image_file_ids = []
|
||||
|
||||
# Deduplicate while preserving order.
|
||||
deduped_reference_image_ids: list[str] = []
|
||||
seen_ids: set[str] = set()
|
||||
for file_id in reference_image_file_ids:
|
||||
if file_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(file_id)
|
||||
deduped_reference_image_ids.append(file_id)
|
||||
|
||||
if not deduped_reference_image_ids:
|
||||
return []
|
||||
|
||||
if not self.img_provider.supports_reference_images:
|
||||
raise ToolCallException(
|
||||
message=(
|
||||
f"Reference images requested but provider '{self.provider}' "
|
||||
"does not support image-editing context."
|
||||
),
|
||||
llm_facing_message=(
|
||||
"This image provider does not support editing from previous image context. "
|
||||
"Try text-only generation, or switch to a provider/model that supports image edits."
|
||||
),
|
||||
)
|
||||
|
||||
max_reference_images = self.img_provider.max_reference_images
|
||||
if max_reference_images > 0:
|
||||
return deduped_reference_image_ids[-max_reference_images:]
|
||||
return deduped_reference_image_ids
|
||||
|
||||
def _load_reference_images(
|
||||
self,
|
||||
reference_image_file_ids: list[str],
|
||||
) -> list[ReferenceImage]:
|
||||
reference_images: list[ReferenceImage] = []
|
||||
|
||||
for file_id in reference_image_file_ids:
|
||||
try:
|
||||
loaded_file = load_chat_file_by_id(file_id)
|
||||
except Exception as e:
|
||||
raise ToolCallException(
|
||||
message=f"Could not load reference image file '{file_id}': {e}",
|
||||
llm_facing_message=(
|
||||
f"Reference image file '{file_id}' could not be loaded. "
|
||||
"Use file_id values returned by previous generate_image calls."
|
||||
),
|
||||
)
|
||||
|
||||
if loaded_file.file_type != ChatFileType.IMAGE:
|
||||
raise ToolCallException(
|
||||
message=f"Reference file '{file_id}' is not an image",
|
||||
llm_facing_message=f"Reference file '{file_id}' is not an image.",
|
||||
)
|
||||
|
||||
try:
|
||||
mime_type = get_image_type_from_bytes(loaded_file.content)
|
||||
except Exception as e:
|
||||
raise ToolCallException(
|
||||
message=f"Unsupported reference image format for '{file_id}': {e}",
|
||||
llm_facing_message=(
|
||||
f"Reference image '{file_id}' has an unsupported format. "
|
||||
"Only PNG, JPEG, GIF, and WEBP are supported."
|
||||
),
|
||||
)
|
||||
|
||||
reference_images.append(
|
||||
ReferenceImage(
|
||||
data=loaded_file.content,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
)
|
||||
|
||||
return reference_images
|
||||
|
||||
def run(
|
||||
self,
|
||||
placement: Placement,
|
||||
override_kwargs: None = None, # noqa: ARG002
|
||||
override_kwargs: ImageGenerationToolOverrideKwargs | None = None,
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
if PROMPT_FIELD not in llm_kwargs:
|
||||
@@ -247,6 +374,11 @@ class ImageGenerationTool(Tool[None]):
|
||||
)
|
||||
prompt = cast(str, llm_kwargs[PROMPT_FIELD])
|
||||
shape = ImageShape(llm_kwargs.get("shape", ImageShape.SQUARE.value))
|
||||
reference_image_file_ids = self._resolve_reference_image_file_ids(
|
||||
llm_kwargs=llm_kwargs,
|
||||
override_kwargs=override_kwargs,
|
||||
)
|
||||
reference_images = self._load_reference_images(reference_image_file_ids)
|
||||
|
||||
# Use threading to generate images in parallel while emitting heartbeats
|
||||
results: list[tuple[ImageGenerationResponse, Any] | None] = [
|
||||
@@ -267,6 +399,7 @@ class ImageGenerationTool(Tool[None]):
|
||||
(
|
||||
prompt,
|
||||
shape,
|
||||
reference_images or None,
|
||||
),
|
||||
)
|
||||
for _ in range(self.num_imgs)
|
||||
@@ -347,6 +480,7 @@ class ImageGenerationTool(Tool[None]):
|
||||
llm_facing_response = json.dumps(
|
||||
[
|
||||
{
|
||||
"file_id": img.file_id,
|
||||
"revised_prompt": img.revised_prompt,
|
||||
}
|
||||
for img in generated_images_metadata
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from onyx.file_processing.html_utils import ParsedHTML
|
||||
from onyx.file_processing.html_utils import web_html_cleanup
|
||||
@@ -21,10 +22,22 @@ from onyx.utils.web_content import title_from_url
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DEFAULT_TIMEOUT_SECONDS = 15
|
||||
DEFAULT_READ_TIMEOUT_SECONDS = 15
|
||||
DEFAULT_CONNECT_TIMEOUT_SECONDS = 5
|
||||
DEFAULT_USER_AGENT = "OnyxWebCrawler/1.0 (+https://www.onyx.app)"
|
||||
DEFAULT_MAX_PDF_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB
|
||||
DEFAULT_MAX_HTML_SIZE_BYTES = 20 * 1024 * 1024 # 20 MB
|
||||
DEFAULT_MAX_WORKERS = 5
|
||||
|
||||
|
||||
def _failed_result(url: str) -> WebContent:
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
|
||||
|
||||
class OnyxWebCrawler(WebContentProvider):
|
||||
@@ -37,12 +50,14 @@ class OnyxWebCrawler(WebContentProvider):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
|
||||
timeout_seconds: int = DEFAULT_READ_TIMEOUT_SECONDS,
|
||||
connect_timeout_seconds: int = DEFAULT_CONNECT_TIMEOUT_SECONDS,
|
||||
user_agent: str = DEFAULT_USER_AGENT,
|
||||
max_pdf_size_bytes: int | None = None,
|
||||
max_html_size_bytes: int | None = None,
|
||||
) -> None:
|
||||
self._timeout_seconds = timeout_seconds
|
||||
self._read_timeout_seconds = timeout_seconds
|
||||
self._connect_timeout_seconds = connect_timeout_seconds
|
||||
self._max_pdf_size_bytes = max_pdf_size_bytes
|
||||
self._max_html_size_bytes = max_html_size_bytes
|
||||
self._headers = {
|
||||
@@ -51,75 +66,68 @@ class OnyxWebCrawler(WebContentProvider):
|
||||
}
|
||||
|
||||
def contents(self, urls: Sequence[str]) -> list[WebContent]:
|
||||
results: list[WebContent] = []
|
||||
for url in urls:
|
||||
results.append(self._fetch_url(url))
|
||||
return results
|
||||
if not urls:
|
||||
return []
|
||||
|
||||
max_workers = min(DEFAULT_MAX_WORKERS, len(urls))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
return list(executor.map(self._fetch_url_safe, urls))
|
||||
|
||||
def _fetch_url_safe(self, url: str) -> WebContent:
|
||||
"""Wrapper that catches all exceptions so one bad URL doesn't kill the batch."""
|
||||
try:
|
||||
return self._fetch_url(url)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Onyx crawler unexpected error for %s (%s)",
|
||||
url,
|
||||
exc.__class__.__name__,
|
||||
)
|
||||
return _failed_result(url)
|
||||
|
||||
def _fetch_url(self, url: str) -> WebContent:
|
||||
try:
|
||||
# Use SSRF-safe request to prevent DNS rebinding attacks
|
||||
response = ssrf_safe_get(
|
||||
url, headers=self._headers, timeout=self._timeout_seconds
|
||||
url,
|
||||
headers=self._headers,
|
||||
timeout=(self._connect_timeout_seconds, self._read_timeout_seconds),
|
||||
)
|
||||
except SSRFException as exc:
|
||||
logger.error(
|
||||
"SSRF protection blocked request to %s: %s",
|
||||
"SSRF protection blocked request to %s (%s)",
|
||||
url,
|
||||
str(exc),
|
||||
exc.__class__.__name__,
|
||||
)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - network failures vary
|
||||
return _failed_result(url)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Onyx crawler failed to fetch %s (%s)",
|
||||
url,
|
||||
exc.__class__.__name__,
|
||||
)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
return _failed_result(url)
|
||||
|
||||
if response.status_code >= 400:
|
||||
logger.warning("Onyx crawler received %s for %s", response.status_code, url)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
return _failed_result(url)
|
||||
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
content_sniff = response.content[:1024] if response.content else None
|
||||
content = response.content
|
||||
|
||||
content_sniff = content[:1024] if content else None
|
||||
if is_pdf_resource(url, content_type, content_sniff):
|
||||
if (
|
||||
self._max_pdf_size_bytes is not None
|
||||
and len(response.content) > self._max_pdf_size_bytes
|
||||
and len(content) > self._max_pdf_size_bytes
|
||||
):
|
||||
logger.warning(
|
||||
"PDF content too large (%d bytes) for %s, max is %d",
|
||||
len(response.content),
|
||||
len(content),
|
||||
url,
|
||||
self._max_pdf_size_bytes,
|
||||
)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
text_content, metadata = extract_pdf_text(response.content)
|
||||
return _failed_result(url)
|
||||
text_content, metadata = extract_pdf_text(content)
|
||||
title = title_from_pdf_metadata(metadata) or title_from_url(url)
|
||||
return WebContent(
|
||||
title=title,
|
||||
@@ -131,25 +139,19 @@ class OnyxWebCrawler(WebContentProvider):
|
||||
|
||||
if (
|
||||
self._max_html_size_bytes is not None
|
||||
and len(response.content) > self._max_html_size_bytes
|
||||
and len(content) > self._max_html_size_bytes
|
||||
):
|
||||
logger.warning(
|
||||
"HTML content too large (%d bytes) for %s, max is %d",
|
||||
len(response.content),
|
||||
len(content),
|
||||
url,
|
||||
self._max_html_size_bytes,
|
||||
)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
return _failed_result(url)
|
||||
|
||||
try:
|
||||
decoded_html = decode_html_bytes(
|
||||
response.content,
|
||||
content,
|
||||
content_type=content_type,
|
||||
fallback_encoding=response.apparent_encoding or response.encoding,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
@@ -13,6 +14,7 @@ from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ChatFile
|
||||
from onyx.tools.models import ChatMinimalTextMessage
|
||||
from onyx.tools.models import ImageGenerationToolOverrideKwargs
|
||||
from onyx.tools.models import OpenURLToolOverrideKwargs
|
||||
from onyx.tools.models import ParallelToolCallResponse
|
||||
from onyx.tools.models import PythonToolOverrideKwargs
|
||||
@@ -22,6 +24,9 @@ from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolExecutionException
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.models import WebSearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
|
||||
from onyx.tools.tool_implementations.memory.memory_tool import MemoryToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
@@ -105,6 +110,63 @@ def _merge_tool_calls(tool_calls: list[ToolCallKickoff]) -> list[ToolCallKickoff
|
||||
return merged_calls
|
||||
|
||||
|
||||
def _extract_image_file_ids_from_tool_response_message(
|
||||
message: str,
|
||||
) -> list[str]:
|
||||
try:
|
||||
parsed_message = json.loads(message)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
parsed_items: list[Any] = (
|
||||
parsed_message if isinstance(parsed_message, list) else [parsed_message]
|
||||
)
|
||||
file_ids: list[str] = []
|
||||
for item in parsed_items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
file_id = item.get("file_id")
|
||||
if isinstance(file_id, str):
|
||||
file_ids.append(file_id)
|
||||
|
||||
return file_ids
|
||||
|
||||
|
||||
def _extract_recent_generated_image_file_ids(
|
||||
message_history: list[ChatMessageSimple],
|
||||
) -> list[str]:
|
||||
tool_name_by_tool_call_id: dict[str, str] = {}
|
||||
recent_image_file_ids: list[str] = []
|
||||
seen_file_ids: set[str] = set()
|
||||
|
||||
for message in message_history:
|
||||
if message.message_type == MessageType.ASSISTANT and message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
tool_name_by_tool_call_id[tool_call.tool_call_id] = tool_call.tool_name
|
||||
continue
|
||||
|
||||
if (
|
||||
message.message_type != MessageType.TOOL_CALL_RESPONSE
|
||||
or not message.tool_call_id
|
||||
):
|
||||
continue
|
||||
|
||||
tool_name = tool_name_by_tool_call_id.get(message.tool_call_id)
|
||||
if tool_name != ImageGenerationTool.NAME:
|
||||
continue
|
||||
|
||||
for file_id in _extract_image_file_ids_from_tool_response_message(
|
||||
message.message
|
||||
):
|
||||
if file_id in seen_file_ids:
|
||||
continue
|
||||
seen_file_ids.add(file_id)
|
||||
recent_image_file_ids.append(file_id)
|
||||
|
||||
return recent_image_file_ids
|
||||
|
||||
|
||||
def _safe_run_single_tool(
|
||||
tool: Tool,
|
||||
tool_call: ToolCallKickoff,
|
||||
@@ -324,6 +386,9 @@ def run_tool_calls(
|
||||
url_to_citation: dict[str, int] = {
|
||||
url: citation_num for citation_num, url in citation_mapping.items()
|
||||
}
|
||||
recent_generated_image_file_ids = _extract_recent_generated_image_file_ids(
|
||||
message_history
|
||||
)
|
||||
|
||||
# Prepare all tool calls with their override_kwargs
|
||||
# Each tool gets a unique starting citation number to avoid conflicts when running in parallel
|
||||
@@ -340,6 +405,7 @@ def run_tool_calls(
|
||||
| WebSearchToolOverrideKwargs
|
||||
| OpenURLToolOverrideKwargs
|
||||
| PythonToolOverrideKwargs
|
||||
| ImageGenerationToolOverrideKwargs
|
||||
| MemoryToolOverrideKwargs
|
||||
| None
|
||||
) = None
|
||||
@@ -388,6 +454,10 @@ def run_tool_calls(
|
||||
override_kwargs = PythonToolOverrideKwargs(
|
||||
chat_files=chat_files or [],
|
||||
)
|
||||
elif isinstance(tool, ImageGenerationTool):
|
||||
override_kwargs = ImageGenerationToolOverrideKwargs(
|
||||
recent_generated_image_file_ids=recent_generated_image_file_ids
|
||||
)
|
||||
elif isinstance(tool, MemoryTool):
|
||||
override_kwargs = MemoryToolOverrideKwargs(
|
||||
user_name=(
|
||||
|
||||
@@ -146,7 +146,7 @@ MAX_REDIRECTS = 10
|
||||
def _make_ssrf_safe_request(
|
||||
url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: int = 15,
|
||||
timeout: float | tuple[float, float] = 15,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
"""
|
||||
@@ -204,7 +204,7 @@ def _make_ssrf_safe_request(
|
||||
def ssrf_safe_get(
|
||||
url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: int = 15,
|
||||
timeout: float | tuple[float, float] = 15,
|
||||
follow_redirects: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
|
||||
@@ -317,7 +317,7 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.5.7
|
||||
onyx-devtools==0.6.0
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
|
||||
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
External dependency unit tests for user file processing queue protections.
|
||||
|
||||
Verifies that the three mechanisms added to check_user_file_processing work
|
||||
correctly:
|
||||
|
||||
1. Queue depth backpressure – when the broker queue exceeds
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH, no new tasks are enqueued.
|
||||
|
||||
2. Per-file Redis guard key – if the guard key for a file already exists in
|
||||
Redis, that file is skipped even though it is still in PROCESSING status.
|
||||
|
||||
3. Task expiry – every send_task call carries expires=
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES so that stale queued tasks are
|
||||
discarded by workers automatically.
|
||||
|
||||
Also verifies that process_single_user_file clears the guard key the moment
|
||||
it is picked up by a worker.
|
||||
|
||||
Uses real Redis (DB 0 via get_redis_client) and real PostgreSQL for UserFile
|
||||
rows. The Celery app is provided as a MagicMock injected via a PropertyMock
|
||||
on the task class so no real broker is needed.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import PropertyMock
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_user_file_lock_key,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_user_file_queued_key,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
check_user_file_processing,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_single_user_file,
|
||||
)
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PATCH_QUEUE_LEN = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_queue_length"
|
||||
)
|
||||
|
||||
|
||||
def _create_processing_user_file(db_session: Session, user_id: object) -> UserFile:
|
||||
"""Insert a UserFile in PROCESSING status and return it."""
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user_id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=UserFileStatus.PROCESSING,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
|
||||
"""Patch the ``app`` property on *task*'s class so that ``self.app``
|
||||
inside the task function returns *mock_app*.
|
||||
|
||||
With ``bind=True``, ``task.run`` is a bound method whose ``__self__`` is
|
||||
the actual task instance. We patch ``app`` on that instance's class
|
||||
(a unique Celery-generated Task subclass) so the mock is scoped to this
|
||||
task only.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQueueDepthBackpressure:
|
||||
"""Protection 1: skip all enqueuing when the broker queue is too deep."""
|
||||
|
||||
def test_no_tasks_enqueued_when_queue_over_limit(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""When the queue depth exceeds the limit the beat cycle is skipped."""
|
||||
user = create_test_user(db_session, "bp_user")
|
||||
_create_processing_user_file(db_session, user.id)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(
|
||||
_PATCH_QUEUE_LEN, return_value=USER_FILE_PROCESSING_MAX_QUEUE_DEPTH + 1
|
||||
),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
mock_app.send_task.assert_not_called()
|
||||
|
||||
|
||||
class TestPerFileGuardKey:
|
||||
"""Protection 2: per-file Redis guard key prevents duplicate enqueue."""
|
||||
|
||||
def test_guarded_file_not_re_enqueued(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file whose guard key is already set in Redis is skipped."""
|
||||
user = create_test_user(db_session, "guard_user")
|
||||
uf = _create_processing_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(uf.id)
|
||||
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# send_task must not have been called with this specific file's ID
|
||||
for call in mock_app.send_task.call_args_list:
|
||||
kwargs = call.kwargs.get("kwargs", {})
|
||||
assert kwargs.get("user_file_id") != str(
|
||||
uf.id
|
||||
), f"File {uf.id} should have been skipped because its guard key exists"
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
def test_guard_key_exists_in_redis_after_enqueue(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""After a file is enqueued its guard key is present in Redis with a TTL."""
|
||||
user = create_test_user(db_session, "guard_set_user")
|
||||
uf = _create_processing_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(uf.id)
|
||||
redis_client.delete(guard_key) # clean slate
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
assert redis_client.exists(
|
||||
guard_key
|
||||
), "Guard key should be set in Redis after enqueue"
|
||||
ttl = int(redis_client.ttl(guard_key)) # type: ignore[arg-type]
|
||||
assert 0 < ttl <= CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, (
|
||||
f"Guard key TTL {ttl}s is outside the expected range "
|
||||
f"(0, {CELERY_USER_FILE_PROCESSING_TASK_EXPIRES}]"
|
||||
)
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
|
||||
class TestTaskExpiry:
|
||||
"""Protection 3: every send_task call includes an expires value."""
|
||||
|
||||
def test_send_task_called_with_expires(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""send_task is called with the correct queue, task name, and expires."""
|
||||
user = create_test_user(db_session, "expires_user")
|
||||
uf = _create_processing_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(uf.id)
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# At least one task should have been submitted (for our file)
|
||||
assert (
|
||||
mock_app.send_task.call_count >= 1
|
||||
), "Expected at least one task to be submitted"
|
||||
|
||||
# Every submitted task must carry expires
|
||||
for call in mock_app.send_task.call_args_list:
|
||||
assert call.args[0] == OnyxCeleryTask.PROCESS_SINGLE_USER_FILE
|
||||
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_PROCESSING
|
||||
assert (
|
||||
call.kwargs.get("expires")
|
||||
== CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
), (
|
||||
"Task must be submitted with the correct expires value to prevent "
|
||||
"stale task accumulation"
|
||||
)
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
|
||||
class TestWorkerClearsGuardKey:
|
||||
"""process_single_user_file removes the guard key when it picks up a task."""
|
||||
|
||||
def test_guard_key_deleted_on_pickup(
|
||||
self,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""The guard key is deleted before the worker does any real work.
|
||||
|
||||
We simulate an already-locked file so process_single_user_file returns
|
||||
early – but crucially, after the guard key deletion.
|
||||
"""
|
||||
user_file_id = str(uuid4())
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(user_file_id)
|
||||
|
||||
# Simulate the guard key set when the beat enqueued the task
|
||||
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
|
||||
assert redis_client.exists(guard_key), "Guard key must exist before pickup"
|
||||
|
||||
# Hold the per-file processing lock so the worker exits early without
|
||||
# touching the database or file store.
|
||||
lock_key = _user_file_lock_key(user_file_id)
|
||||
processing_lock = redis_client.lock(lock_key, timeout=10)
|
||||
acquired = processing_lock.acquire(blocking=False)
|
||||
assert acquired, "Should be able to acquire the processing lock for this test"
|
||||
|
||||
try:
|
||||
process_single_user_file.run(
|
||||
user_file_id=user_file_id,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
)
|
||||
finally:
|
||||
if processing_lock.owned():
|
||||
processing_lock.release()
|
||||
|
||||
assert not redis_client.exists(
|
||||
guard_key
|
||||
), "Guard key should be deleted when the worker picks up the task"
|
||||
@@ -217,8 +217,8 @@ class TestAutoModeSyncFeature:
|
||||
),
|
||||
additional_visible_models=[
|
||||
SimpleKnownModel(
|
||||
name="claude-3-5-haiku-latest",
|
||||
display_name="Claude 3.5 Haiku",
|
||||
name="claude-haiku-4-5",
|
||||
display_name="Claude Haiku 4.5",
|
||||
)
|
||||
],
|
||||
),
|
||||
@@ -260,7 +260,7 @@ class TestAutoModeSyncFeature:
|
||||
|
||||
# Anthropic models should NOT be present
|
||||
assert "claude-3-5-sonnet-latest" not in model_names
|
||||
assert "claude-3-5-haiku-latest" not in model_names
|
||||
assert "claude-haiku-4-5" not in model_names
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -485,7 +485,7 @@ class TestAutoModeSyncFeature:
|
||||
|
||||
# Provider 2 (Anthropic) config
|
||||
provider_2_default_model = "claude-3-5-sonnet-latest"
|
||||
provider_2_additional_models = ["claude-3-5-haiku-latest"]
|
||||
provider_2_additional_models = ["claude-haiku-4-5"]
|
||||
|
||||
# Create mock recommendations with both providers
|
||||
mock_recommendations = LLMRecommendations(
|
||||
|
||||
@@ -281,15 +281,22 @@ def test_anthropic_prompt_caching_reduces_costs(
|
||||
|
||||
Anthropic requires explicit cache_control parameters.
|
||||
"""
|
||||
# Create Anthropic LLM
|
||||
# NOTE: prompt caching support is model-specific; `claude-3-haiku-20240307` is known
|
||||
# to return cache_creation/cache_read usage metrics, while some newer aliases may not.
|
||||
llm = LitellmLLM(
|
||||
api_key=os.environ["ANTHROPIC_API_KEY"],
|
||||
model_provider="anthropic",
|
||||
model_name="claude-3-haiku-20240307",
|
||||
max_input_tokens=200000,
|
||||
)
|
||||
# Prompt caching support is model/account specific.
|
||||
# Allow override via env var and otherwise try a few non-retired candidates.
|
||||
anthropic_prompt_cache_models_env = os.environ.get("ANTHROPIC_PROMPT_CACHE_MODELS")
|
||||
if anthropic_prompt_cache_models_env:
|
||||
candidate_models = [
|
||||
model.strip()
|
||||
for model in anthropic_prompt_cache_models_env.split(",")
|
||||
if model.strip()
|
||||
]
|
||||
else:
|
||||
candidate_models = [
|
||||
"claude-haiku-4-5-20251001",
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-5-sonnet-latest",
|
||||
]
|
||||
|
||||
import random
|
||||
import string
|
||||
@@ -315,79 +322,107 @@ def test_anthropic_prompt_caching_reduces_costs(
|
||||
UserMessage(role="user", content=long_context)
|
||||
]
|
||||
|
||||
# First call - creates cache
|
||||
print("\n=== First call (cache creation) ===")
|
||||
question1: list[ChatCompletionMessage] = [
|
||||
UserMessage(role="user", content="What are the main topics discussed?")
|
||||
]
|
||||
unavailable_models: list[str] = []
|
||||
non_caching_models: list[str] = []
|
||||
|
||||
# Apply prompt caching
|
||||
processed_messages1, _ = process_with_prompt_cache(
|
||||
llm_config=llm.config,
|
||||
cacheable_prefix=base_messages,
|
||||
suffix=question1,
|
||||
continuation=False,
|
||||
for model_name in candidate_models:
|
||||
llm = LitellmLLM(
|
||||
api_key=os.environ["ANTHROPIC_API_KEY"],
|
||||
model_provider="anthropic",
|
||||
model_name=model_name,
|
||||
max_input_tokens=200000,
|
||||
)
|
||||
|
||||
# First call - creates cache
|
||||
print(f"\n=== First call (cache creation) model={model_name} ===")
|
||||
question1: list[ChatCompletionMessage] = [
|
||||
UserMessage(
|
||||
role="user",
|
||||
content="Reply with exactly one lowercase word: topics",
|
||||
)
|
||||
]
|
||||
|
||||
processed_messages1, _ = process_with_prompt_cache(
|
||||
llm_config=llm.config,
|
||||
cacheable_prefix=base_messages,
|
||||
suffix=question1,
|
||||
continuation=False,
|
||||
)
|
||||
|
||||
try:
|
||||
response1 = llm.invoke(prompt=processed_messages1, max_tokens=8)
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"not_found_error" in error_str
|
||||
or "model_not_found" in error_str
|
||||
or ('"type":"not_found_error"' in error_str and "model:" in error_str)
|
||||
):
|
||||
unavailable_models.append(model_name)
|
||||
continue
|
||||
raise
|
||||
|
||||
cost1 = completion_cost(
|
||||
completion_response=response1.model_dump(),
|
||||
model=f"{llm._model_provider}/{llm._model_version}",
|
||||
)
|
||||
|
||||
usage1 = response1.usage
|
||||
print(f"Response 1 usage: {usage1}")
|
||||
print(f"Cost 1: ${cost1:.10f}")
|
||||
|
||||
# Wait to ensure cache is available
|
||||
time.sleep(2)
|
||||
|
||||
# Second call with same context - should use cache
|
||||
print(f"\n=== Second call (cache read) model={model_name} ===")
|
||||
question2: list[ChatCompletionMessage] = [
|
||||
UserMessage(
|
||||
role="user",
|
||||
content="Reply with exactly one lowercase word: neural",
|
||||
)
|
||||
]
|
||||
|
||||
processed_messages2, _ = process_with_prompt_cache(
|
||||
llm_config=llm.config,
|
||||
cacheable_prefix=base_messages,
|
||||
suffix=question2,
|
||||
continuation=False,
|
||||
)
|
||||
|
||||
response2 = llm.invoke(prompt=processed_messages2, max_tokens=8)
|
||||
cost2 = completion_cost(
|
||||
completion_response=response2.model_dump(),
|
||||
model=f"{llm._model_provider}/{llm._model_version}",
|
||||
)
|
||||
|
||||
usage2 = response2.usage
|
||||
print(f"Response 2 usage: {usage2}")
|
||||
print(f"Cost 2: ${cost2:.10f}")
|
||||
|
||||
cache_creation_tokens = _get_usage_value(usage1, "cache_creation_input_tokens")
|
||||
cache_read_tokens = _get_usage_value(usage2, "cache_read_input_tokens")
|
||||
|
||||
print(f"\nCache creation tokens (call 1): {cache_creation_tokens}")
|
||||
print(f"Cache read tokens (call 2): {cache_read_tokens}")
|
||||
print(f"Cost reduction: ${cost1 - cost2:.10f}")
|
||||
|
||||
# Model is available but does not expose Anthropic cache usage metrics
|
||||
if cache_creation_tokens <= 0 or cache_read_tokens <= 0:
|
||||
non_caching_models.append(model_name)
|
||||
continue
|
||||
|
||||
# Cost should be lower on second call
|
||||
assert (
|
||||
cost2 < cost1
|
||||
), f"Expected lower cost on cached call. Cost 1: ${cost1:.10f}, Cost 2: ${cost2:.10f}"
|
||||
return
|
||||
|
||||
pytest.skip(
|
||||
"No Anthropic model available with observable prompt-cache metrics. "
|
||||
f"Tried models={candidate_models}, unavailable={unavailable_models}, non_caching={non_caching_models}"
|
||||
)
|
||||
|
||||
response1 = llm.invoke(prompt=processed_messages1)
|
||||
cost1 = completion_cost(
|
||||
completion_response=response1.model_dump(),
|
||||
model=f"{llm._model_provider}/{llm._model_version}",
|
||||
)
|
||||
|
||||
usage1 = response1.usage
|
||||
print(f"Response 1 usage: {usage1}")
|
||||
print(f"Cost 1: ${cost1:.10f}")
|
||||
|
||||
# Wait to ensure cache is available
|
||||
time.sleep(2)
|
||||
|
||||
# Second call with same context - should use cache
|
||||
print("\n=== Second call (cache read) ===")
|
||||
question2: list[ChatCompletionMessage] = [
|
||||
UserMessage(role="user", content="Can you elaborate on neural networks?")
|
||||
]
|
||||
|
||||
# Apply prompt caching (same cacheable prefix)
|
||||
processed_messages2, _ = process_with_prompt_cache(
|
||||
llm_config=llm.config,
|
||||
cacheable_prefix=base_messages,
|
||||
suffix=question2,
|
||||
continuation=False,
|
||||
)
|
||||
|
||||
response2 = llm.invoke(prompt=processed_messages2)
|
||||
cost2 = completion_cost(
|
||||
completion_response=response2.model_dump(),
|
||||
model=f"{llm._model_provider}/{llm._model_version}",
|
||||
)
|
||||
|
||||
usage2 = response2.usage
|
||||
print(f"Response 2 usage: {usage2}")
|
||||
print(f"Cost 2: ${cost2:.10f}")
|
||||
|
||||
# Verify caching occurred
|
||||
cache_creation_tokens = _get_usage_value(usage1, "cache_creation_input_tokens")
|
||||
cache_read_tokens = _get_usage_value(usage2, "cache_read_input_tokens")
|
||||
|
||||
print(f"\nCache creation tokens (call 1): {cache_creation_tokens}")
|
||||
print(f"Cache read tokens (call 2): {cache_read_tokens}")
|
||||
print(f"Cost reduction: ${cost1 - cost2:.10f}")
|
||||
|
||||
# For Anthropic, we should see cache creation on first call and cache reads on second
|
||||
assert (
|
||||
cache_creation_tokens > 0
|
||||
), f"Expected cache creation tokens on first call. Got: {cache_creation_tokens}"
|
||||
|
||||
assert (
|
||||
cache_read_tokens > 0
|
||||
), f"Expected cache read tokens on second call. Got: {cache_read_tokens}"
|
||||
|
||||
# Cost should be lower on second call
|
||||
assert (
|
||||
cost2 < cost1
|
||||
), f"Expected lower cost on cached call. Cost 1: ${cost1:.10f}, Cost 2: ${cost2:.10f}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get(VERTEX_CREDENTIALS_ENV),
|
||||
|
||||
@@ -13,6 +13,7 @@ from litellm.types.utils import ImageResponse
|
||||
|
||||
from onyx.image_gen.interfaces import ImageGenerationProvider
|
||||
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
|
||||
from onyx.image_gen.interfaces import ReferenceImage
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
|
||||
|
||||
@@ -62,6 +63,7 @@ class MockImageGenerationProvider(
|
||||
size: str, # noqa: ARG002
|
||||
n: int, # noqa: ARG002
|
||||
quality: str | None = None, # noqa: ARG002
|
||||
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
|
||||
**kwargs: Any, # noqa: ARG002
|
||||
) -> ImageResponse:
|
||||
image_data = self._images.pop(0)
|
||||
|
||||
@@ -2,6 +2,7 @@ from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -12,9 +13,13 @@ from onyx.context.search.models import ChunkSearchRequest
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
|
||||
def run_functions_tuples_sequential(
|
||||
@@ -135,13 +140,25 @@ def use_mock_search_pipeline(
|
||||
document_index: DocumentIndex, # noqa: ARG001
|
||||
user: User | None, # noqa: ARG001
|
||||
persona: Persona | None, # noqa: ARG001
|
||||
db_session: Session, # noqa: ARG001
|
||||
db_session: Session | None = None, # noqa: ARG001
|
||||
auto_detect_filters: bool = False, # noqa: ARG001
|
||||
llm: LLM | None = None, # noqa: ARG001
|
||||
project_id: int | None = None, # noqa: ARG001
|
||||
# Pre-fetched data (used by SearchTool to avoid DB access in parallel)
|
||||
acl_filters: list[str] | None = None, # noqa: ARG001
|
||||
embedding_model: EmbeddingModel | None = None, # noqa: ARG001
|
||||
prefetched_federated_retrieval_infos: ( # noqa: ARG001
|
||||
list[FederatedRetrievalInfo] | None
|
||||
) = None,
|
||||
) -> list[InferenceChunk]:
|
||||
return controller.get_search_results(chunk_search_request.query)
|
||||
|
||||
# Mock the pre-fetch session and DB queries in SearchTool.run() so
|
||||
# tests don't need a fully initialised DB with search settings.
|
||||
@contextmanager
|
||||
def mock_get_session() -> Generator[MagicMock, None, None]:
|
||||
yield MagicMock(spec=Session)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.search_pipeline",
|
||||
@@ -183,5 +200,31 @@ def use_mock_search_pipeline(
|
||||
"onyx.db.connector.fetch_unique_document_sources",
|
||||
new=mock_fetch_unique_document_sources,
|
||||
),
|
||||
# Mock the pre-fetch phase of SearchTool.run()
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.get_session_with_current_tenant",
|
||||
new=mock_get_session,
|
||||
),
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.build_access_filters_for_user",
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.get_current_search_settings",
|
||||
return_value=MagicMock(spec=SearchSettings),
|
||||
),
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.EmbeddingModel.from_db_model",
|
||||
return_value=MagicMock(spec=EmbeddingModel),
|
||||
),
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.get_federated_retrieval_functions",
|
||||
return_value=[],
|
||||
),
|
||||
patch.object(
|
||||
SearchTool,
|
||||
"_prefetch_slack_data",
|
||||
return_value=(None, None, {}),
|
||||
),
|
||||
):
|
||||
yield controller
|
||||
|
||||
@@ -943,10 +943,18 @@ from onyx.db.tools import get_builtin_tool
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.server.features.projects.api import upload_user_files
|
||||
from onyx.server.query_and_chat.chat_backend import get_chat_session
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from tests.external_dependency_unit.answer.stream_test_builder import StreamTestBuilder
|
||||
from tests.external_dependency_unit.answer.stream_test_utils import create_chat_session
|
||||
from tests.external_dependency_unit.answer.stream_test_utils import create_placement
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.mock_llm import LLMAnswerResponse
|
||||
from tests.external_dependency_unit.mock_llm import LLMToolCallResponse
|
||||
from tests.external_dependency_unit.mock_llm import use_mock_llm
|
||||
|
||||
@@ -1174,3 +1182,134 @@ def test_code_interpreter_receives_chat_files(
|
||||
assert execute_body["code"] == code
|
||||
assert len(execute_body["files"]) == 1
|
||||
assert execute_body["files"][0]["path"] == "data.csv"
|
||||
|
||||
|
||||
def test_code_interpreter_replay_packets_include_code_and_output(
|
||||
db_session: Session,
|
||||
mock_ci_server: MockCodeInterpreterServer,
|
||||
_attach_python_tool_to_default_persona: None,
|
||||
initialize_file_store: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""After a code interpreter message completes, retrieving the message
|
||||
via translate_assistant_message_to_packets should emit PythonToolStart
|
||||
(containing the executed code) and PythonToolDelta (containing
|
||||
stdout/stderr), not generic CustomTool packets."""
|
||||
mock_ci_server.captured_requests.clear()
|
||||
mock_ci_server._file_counter = 0
|
||||
mock_url = mock_ci_server.url
|
||||
|
||||
user = create_test_user(db_session, "ci_replay_test")
|
||||
chat_session = create_chat_session(db_session=db_session, user=user)
|
||||
|
||||
code = 'x = 2 + 2\nprint(f"Result: {x}")'
|
||||
msg_req = SendMessageRequest(
|
||||
message="Calculate 2 + 2",
|
||||
chat_session_id=chat_session.id,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
original_defaults = ci_mod.CodeInterpreterClient.__init__.__defaults__
|
||||
with (
|
||||
use_mock_llm() as mock_llm,
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
mock_url,
|
||||
),
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.python.code_interpreter_client.CODE_INTERPRETER_BASE_URL",
|
||||
mock_url,
|
||||
),
|
||||
):
|
||||
answer_tokens = ["The ", "result ", "is ", "4."]
|
||||
|
||||
ci_mod.CodeInterpreterClient.__init__.__defaults__ = (mock_url,)
|
||||
try:
|
||||
handler = StreamTestBuilder(llm_controller=mock_llm)
|
||||
|
||||
stream = handle_stream_message_objects(
|
||||
new_msg_req=msg_req, user=user, db_session=db_session
|
||||
)
|
||||
# First packet is always MessageResponseIDInfo
|
||||
next(stream)
|
||||
|
||||
# Phase 1: LLM requests python tool execution.
|
||||
handler.add_response(
|
||||
LLMToolCallResponse(
|
||||
tool_name="python",
|
||||
tool_call_id="call_replay_test",
|
||||
tool_call_argument_tokens=[json.dumps({"code": code})],
|
||||
)
|
||||
).expect(
|
||||
Packet(
|
||||
placement=create_placement(0),
|
||||
obj=PythonToolStart(code=code),
|
||||
),
|
||||
forward=2,
|
||||
).expect(
|
||||
Packet(
|
||||
placement=create_placement(0),
|
||||
obj=PythonToolDelta(stdout="mock output\n", stderr="", file_ids=[]),
|
||||
),
|
||||
forward=False,
|
||||
).expect(
|
||||
Packet(
|
||||
placement=create_placement(0),
|
||||
obj=SectionEnd(),
|
||||
),
|
||||
forward=False,
|
||||
).run_and_validate(
|
||||
stream=stream
|
||||
)
|
||||
|
||||
# Phase 2: LLM produces a final answer after tool execution.
|
||||
handler.add_response(
|
||||
LLMAnswerResponse(answer_tokens=answer_tokens)
|
||||
).expect_agent_response(
|
||||
answer_tokens=answer_tokens,
|
||||
turn_index=1,
|
||||
).run_and_validate(
|
||||
stream=stream
|
||||
)
|
||||
|
||||
with pytest.raises(StopIteration):
|
||||
next(stream)
|
||||
|
||||
finally:
|
||||
ci_mod.CodeInterpreterClient.__init__.__defaults__ = original_defaults
|
||||
|
||||
# Retrieve the chat session through the same endpoint the frontend uses
|
||||
chat_detail = get_chat_session(
|
||||
session_id=chat_session.id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
assert len(mock_ci_server.get_requests(method="POST", path="/v1/execute")) == 1
|
||||
|
||||
# The response contains `packets` — a list of packet-lists, one per
|
||||
# assistant message. We should have exactly one assistant message.
|
||||
assert (
|
||||
len(chat_detail.packets) == 1
|
||||
), f"Expected 1 assistant packet list, got {len(chat_detail.packets)}"
|
||||
packets = chat_detail.packets[0]
|
||||
|
||||
# Extract PythonToolStart packets – these must contain the code
|
||||
start_packets = [p for p in packets if isinstance(p.obj, PythonToolStart)]
|
||||
assert len(start_packets) == 1, (
|
||||
f"Expected 1 PythonToolStart packet, got {len(start_packets)}. "
|
||||
f"Packet types: {[type(p.obj).__name__ for p in packets]}"
|
||||
)
|
||||
start_obj = start_packets[0].obj
|
||||
assert isinstance(start_obj, PythonToolStart)
|
||||
assert start_obj.code == code
|
||||
|
||||
# Extract PythonToolDelta packets – these must contain stdout/stderr
|
||||
delta_packets = [p for p in packets if isinstance(p.obj, PythonToolDelta)]
|
||||
assert len(delta_packets) >= 1, (
|
||||
f"Expected at least 1 PythonToolDelta packet, got {len(delta_packets)}. "
|
||||
f"Packet types: {[type(p.obj).__name__ for p in packets]}"
|
||||
)
|
||||
# The mock CI server returns "mock output\n" as stdout
|
||||
delta_obj = delta_packets[0].obj
|
||||
assert isinstance(delta_obj, PythonToolDelta)
|
||||
assert "mock output" in delta_obj.stdout
|
||||
|
||||
@@ -13,9 +13,9 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
class APIKeyManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
api_key_role: UserRole = UserRole.ADMIN,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestAPIKey:
|
||||
name = f"{name}-api-key" if name else f"test-api-key-{uuid4()}"
|
||||
api_key_request = APIKeyArgs(
|
||||
@@ -25,11 +25,7 @@ class APIKeyManager:
|
||||
api_key_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/api-key",
|
||||
json=api_key_request.model_dump(),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
api_key_response.raise_for_status()
|
||||
api_key = api_key_response.json()
|
||||
@@ -48,29 +44,21 @@ class APIKeyManager:
|
||||
@staticmethod
|
||||
def delete(
|
||||
api_key: DATestAPIKey,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
api_key_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/api-key/{api_key.api_key_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
api_key_response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[DATestAPIKey]:
|
||||
api_key_response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/api-key",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
api_key_response.raise_for_status()
|
||||
return [DATestAPIKey(**api_key) for api_key in api_key_response.json()]
|
||||
@@ -78,8 +66,8 @@ class APIKeyManager:
|
||||
@staticmethod
|
||||
def verify(
|
||||
api_key: DATestAPIKey,
|
||||
user_performing_action: DATestUser,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
retrieved_keys = APIKeyManager.get_all(
|
||||
user_performing_action=user_performing_action
|
||||
|
||||
@@ -17,7 +17,6 @@ from onyx.server.documents.models import DocumentSource
|
||||
from onyx.server.documents.models import DocumentSyncStatus
|
||||
from tests.integration.common_utils.config import api_config
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
@@ -28,10 +27,10 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
def _cc_pair_creator(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
access_type: AccessType = AccessType.PUBLIC,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestCCPair:
|
||||
name = f"{name}-cc-pair" if name else f"test-cc-pair-{uuid4()}"
|
||||
|
||||
@@ -40,17 +39,12 @@ def _cc_pair_creator(
|
||||
connector_credential_pair_metadata = api.ConnectorCredentialPairMetadata(
|
||||
name=name, access_type=access_type, groups=groups or []
|
||||
)
|
||||
headers = (
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
)
|
||||
api_response: api.StatusResponseInt = (
|
||||
api_instance.associate_credential_to_connector(
|
||||
connector_id,
|
||||
credential_id,
|
||||
connector_credential_pair_metadata,
|
||||
_headers=headers,
|
||||
_headers=user_performing_action.headers,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -67,6 +61,7 @@ def _cc_pair_creator(
|
||||
class CCPairManager:
|
||||
@staticmethod
|
||||
def create_from_scratch(
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
access_type: AccessType = AccessType.PUBLIC,
|
||||
groups: list[int] | None = None,
|
||||
@@ -74,26 +69,25 @@ class CCPairManager:
|
||||
input_type: InputType = InputType.LOAD_STATE,
|
||||
connector_specific_config: dict[str, Any] | None = None,
|
||||
credential_json: dict[str, Any] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
refresh_freq: int | None = None,
|
||||
) -> DATestCCPair:
|
||||
connector = ConnectorManager.create(
|
||||
user_performing_action=user_performing_action,
|
||||
name=name,
|
||||
source=source,
|
||||
input_type=input_type,
|
||||
connector_specific_config=connector_specific_config,
|
||||
access_type=access_type,
|
||||
groups=groups,
|
||||
user_performing_action=user_performing_action,
|
||||
refresh_freq=refresh_freq,
|
||||
)
|
||||
credential = CredentialManager.create(
|
||||
user_performing_action=user_performing_action,
|
||||
credential_json=credential_json,
|
||||
name=name,
|
||||
source=source,
|
||||
curator_public=(access_type == AccessType.PUBLIC),
|
||||
groups=groups,
|
||||
user_performing_action=user_performing_action,
|
||||
)
|
||||
cc_pair = _cc_pair_creator(
|
||||
connector_id=connector.id,
|
||||
@@ -109,10 +103,10 @@ class CCPairManager:
|
||||
def create(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
access_type: AccessType = AccessType.PUBLIC,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestCCPair:
|
||||
cc_pair = _cc_pair_creator(
|
||||
connector_id=connector_id,
|
||||
@@ -127,39 +121,31 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def pause_cc_pair(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
result = requests.put(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status",
|
||||
json={"status": "PAUSED"},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def unpause_cc_pair(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
result = requests.put(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status",
|
||||
json={"status": "ACTIVE"},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
cc_pair_identifier = ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id,
|
||||
@@ -168,26 +154,18 @@ class CCPairManager:
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/deletion-attempt",
|
||||
json=cc_pair_identifier.model_dump(),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_single(
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> CCPairFullInfo | None:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
cc_pair_json = response.json()
|
||||
@@ -196,15 +174,11 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def get_indexing_status_by_id(
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> ConnectorIndexingStatusLite | None:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/indexing-status",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
json={"get_all_connectors": True},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -219,15 +193,11 @@ class CCPairManager:
|
||||
|
||||
@staticmethod
|
||||
def get_indexing_statuses(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[ConnectorIndexingStatusLite]:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/indexing-status",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
json={"get_all_connectors": True},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -241,15 +211,11 @@ class CCPairManager:
|
||||
|
||||
@staticmethod
|
||||
def get_connector_statuses(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[ConnectorStatus]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/status",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [ConnectorStatus(**status) for status in response.json()]
|
||||
@@ -257,8 +223,8 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def verify(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
all_cc_pairs = CCPairManager.get_connector_statuses(user_performing_action)
|
||||
for retrieved_cc_pair in all_cc_pairs:
|
||||
@@ -285,7 +251,7 @@ class CCPairManager:
|
||||
def run_once(
|
||||
cc_pair: DATestCCPair,
|
||||
from_beginning: bool,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
body = {
|
||||
"connector_id": cc_pair.connector_id,
|
||||
@@ -295,19 +261,15 @@ class CCPairManager:
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/connector/run-once",
|
||||
json=body,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def wait_for_indexing_inactive(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser,
|
||||
timeout: float = MAX_DELAY,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""wait for the number of docs to be indexed on the connector.
|
||||
This is used to test pausing a connector in the middle of indexing and
|
||||
@@ -342,9 +304,9 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def wait_for_indexing_in_progress(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser,
|
||||
timeout: float = MAX_DELAY,
|
||||
num_docs: int = 16,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""wait for the number of docs to be indexed on the connector.
|
||||
This is used to test pausing a connector in the middle of indexing and
|
||||
@@ -393,8 +355,8 @@ class CCPairManager:
|
||||
def wait_for_indexing_completion(
|
||||
cc_pair: DATestCCPair,
|
||||
after: datetime,
|
||||
user_performing_action: DATestUser,
|
||||
timeout: float = MAX_DELAY,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""after: Wait for an indexing success time after this time"""
|
||||
start = time.monotonic()
|
||||
@@ -430,30 +392,22 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def prune(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def last_pruned(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> datetime | None:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/last_pruned",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_str = response.json()
|
||||
@@ -471,8 +425,8 @@ class CCPairManager:
|
||||
def wait_for_prune(
|
||||
cc_pair: DATestCCPair,
|
||||
after: datetime,
|
||||
user_performing_action: DATestUser,
|
||||
timeout: float = MAX_DELAY,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""after: The task register time must be after this time."""
|
||||
start = time.monotonic()
|
||||
@@ -496,7 +450,7 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def sync(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
"""This function triggers a permission sync.
|
||||
Naming / intent of this function probably could use improvement, but currently it's letting
|
||||
@@ -504,22 +458,14 @@ class CCPairManager:
|
||||
"""
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-permissions",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
if result.status_code != 409:
|
||||
result.raise_for_status()
|
||||
|
||||
group_sync_result = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-groups",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
if group_sync_result.status_code != 409:
|
||||
group_sync_result.raise_for_status()
|
||||
@@ -528,15 +474,11 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def get_doc_sync_task(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> datetime | None:
|
||||
doc_sync_response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-permissions",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
doc_sync_response.raise_for_status()
|
||||
doc_sync_response_str = doc_sync_response.json()
|
||||
@@ -553,15 +495,11 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def get_group_sync_task(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> datetime | None:
|
||||
group_sync_response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-groups",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
group_sync_response.raise_for_status()
|
||||
group_sync_response_str = group_sync_response.json()
|
||||
@@ -578,15 +516,11 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def get_doc_sync_statuses(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[DocumentSyncStatus]:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/get-docs-sync-status",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
doc_sync_statuses: list[DocumentSyncStatus] = []
|
||||
@@ -613,9 +547,9 @@ class CCPairManager:
|
||||
def wait_for_sync(
|
||||
cc_pair: DATestCCPair,
|
||||
after: datetime,
|
||||
user_performing_action: DATestUser,
|
||||
timeout: float = MAX_DELAY,
|
||||
number_of_updated_docs: int = 0,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
# Sometimes waiting for a group sync is not necessary
|
||||
should_wait_for_group_sync: bool = True,
|
||||
# Sometimes waiting for a vespa sync is not necessary
|
||||
@@ -703,8 +637,8 @@ class CCPairManager:
|
||||
|
||||
@staticmethod
|
||||
def wait_for_deletion_completion(
|
||||
user_performing_action: DATestUser,
|
||||
cc_pair_id: int | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""if cc_pair_id is not specified, just waits until no connectors are in the deleting state.
|
||||
if cc_pair_id is specified, checks to ensure the specific cc_pair_id is gone.
|
||||
|
||||
@@ -17,7 +17,6 @@ from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import StreamingType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestChatMessage
|
||||
from tests.integration.common_utils.test_models import DATestChatSession
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
@@ -74,9 +73,9 @@ class StreamPacketData(TypedDict, total=False):
|
||||
class ChatSessionManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
user_performing_action: DATestUser,
|
||||
persona_id: int = 0,
|
||||
description: str = "Test chat session",
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestChatSession:
|
||||
chat_session_creation_req = ChatSessionCreationRequest(
|
||||
persona_id=persona_id, description=description
|
||||
@@ -84,11 +83,7 @@ class ChatSessionManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/create-chat-session",
|
||||
json=chat_session_creation_req.model_dump(),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
chat_session_id = response.json()["chat_session_id"]
|
||||
@@ -100,8 +95,8 @@ class ChatSessionManager:
|
||||
def send_message(
|
||||
chat_session_id: UUID,
|
||||
message: str,
|
||||
user_performing_action: DATestUser,
|
||||
parent_message_id: int | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
file_descriptors: list[FileDescriptor] | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
forced_tool_ids: list[int] | None = None,
|
||||
@@ -126,19 +121,12 @@ class ChatSessionManager:
|
||||
llm_override=llm_override,
|
||||
)
|
||||
|
||||
headers = (
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
)
|
||||
cookies = user_performing_action.cookies if user_performing_action else None
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-chat-message",
|
||||
json=chat_message_req.model_dump(mode="json"),
|
||||
headers=headers,
|
||||
headers=user_performing_action.headers,
|
||||
stream=True,
|
||||
cookies=cookies,
|
||||
cookies=user_performing_action.cookies,
|
||||
)
|
||||
|
||||
streamed_response = ChatSessionManager.analyze_response(response)
|
||||
@@ -167,9 +155,9 @@ class ChatSessionManager:
|
||||
def send_message_with_disconnect(
|
||||
chat_session_id: UUID,
|
||||
message: str,
|
||||
user_performing_action: DATestUser,
|
||||
disconnect_after_packets: int = 0,
|
||||
parent_message_id: int | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
file_descriptors: list[FileDescriptor] | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
forced_tool_ids: list[int] | None = None,
|
||||
@@ -208,21 +196,14 @@ class ChatSessionManager:
|
||||
llm_override=llm_override,
|
||||
)
|
||||
|
||||
headers = (
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
)
|
||||
cookies = user_performing_action.cookies if user_performing_action else None
|
||||
|
||||
packets_received = 0
|
||||
|
||||
with requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-chat-message",
|
||||
json=chat_message_req.model_dump(mode="json"),
|
||||
headers=headers,
|
||||
headers=user_performing_action.headers,
|
||||
stream=True,
|
||||
cookies=cookies,
|
||||
cookies=user_performing_action.cookies,
|
||||
) as response:
|
||||
for line in response.iter_lines():
|
||||
if not line:
|
||||
@@ -359,15 +340,11 @@ class ChatSessionManager:
|
||||
@staticmethod
|
||||
def get_chat_history(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[DATestChatMessage]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -387,7 +364,7 @@ class ChatSessionManager:
|
||||
def create_chat_message_feedback(
|
||||
message_id: int,
|
||||
is_positive: bool,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
feedback_text: str | None = None,
|
||||
predefined_feedback: str | None = None,
|
||||
) -> None:
|
||||
@@ -399,18 +376,14 @@ class ChatSessionManager:
|
||||
"feedback_text": feedback_text,
|
||||
"predefined_feedback": predefined_feedback,
|
||||
},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete a chat session and all its related records (messages, agent data, etc.)
|
||||
@@ -420,18 +393,14 @@ class ChatSessionManager:
|
||||
"""
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
return response.ok
|
||||
|
||||
@staticmethod
|
||||
def soft_delete(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
"""
|
||||
Soft delete a chat session (marks as deleted but keeps in database).
|
||||
@@ -442,18 +411,14 @@ class ChatSessionManager:
|
||||
# or make a direct call with hard_delete=False parameter via a new endpoint
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}?hard_delete=false",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
return response.ok
|
||||
|
||||
@staticmethod
|
||||
def hard_delete(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
"""
|
||||
Hard delete a chat session (completely removes from database).
|
||||
@@ -462,18 +427,14 @@ class ChatSessionManager:
|
||||
"""
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}?hard_delete=true",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
return response.ok
|
||||
|
||||
@staticmethod
|
||||
def verify_deleted(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
"""
|
||||
Verify that a chat session has been deleted by attempting to retrieve it.
|
||||
@@ -482,11 +443,7 @@ class ChatSessionManager:
|
||||
"""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
# Chat session should return 404 if it doesn't exist or is deleted
|
||||
return response.status_code == 404
|
||||
@@ -494,7 +451,7 @@ class ChatSessionManager:
|
||||
@staticmethod
|
||||
def verify_soft_deleted(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
"""
|
||||
Verify that a chat session has been soft deleted (marked as deleted but still in DB).
|
||||
@@ -504,11 +461,7 @@ class ChatSessionManager:
|
||||
# Try to get the chat session with include_deleted=true
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}?include_deleted=true",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
@@ -520,7 +473,7 @@ class ChatSessionManager:
|
||||
@staticmethod
|
||||
def verify_hard_deleted(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
"""
|
||||
Verify that a chat session has been hard deleted (completely removed from DB).
|
||||
@@ -530,11 +483,7 @@ class ChatSessionManager:
|
||||
# Try to get the chat session with include_deleted=true
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}?include_deleted=true",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
|
||||
# For hard delete, even with include_deleted=true, the record should not exist
|
||||
|
||||
@@ -8,7 +8,6 @@ from onyx.db.enums import AccessType
|
||||
from onyx.server.documents.models import ConnectorUpdateRequest
|
||||
from onyx.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestConnector
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -16,13 +15,13 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
class ConnectorManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
source: DocumentSource = DocumentSource.FILE,
|
||||
input_type: InputType = InputType.LOAD_STATE,
|
||||
connector_specific_config: dict[str, Any] | None = None,
|
||||
access_type: AccessType = AccessType.PUBLIC,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
refresh_freq: int | None = None,
|
||||
) -> DATestConnector:
|
||||
name = f"{name}-connector" if name else f"test-connector-{uuid4()}"
|
||||
@@ -51,11 +50,7 @@ class ConnectorManager:
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/connector",
|
||||
json=connector_update_request.model_dump(),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -73,45 +68,33 @@ class ConnectorManager:
|
||||
@staticmethod
|
||||
def edit(
|
||||
connector: DATestConnector,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
response = requests.patch(
|
||||
url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}",
|
||||
json=connector.model_dump(exclude={"id"}),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
connector: DATestConnector,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
response = requests.delete(
|
||||
url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[DATestConnector]:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/connector",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [
|
||||
@@ -127,15 +110,12 @@ class ConnectorManager:
|
||||
|
||||
@staticmethod
|
||||
def get(
|
||||
connector_id: int, user_performing_action: DATestUser | None = None
|
||||
connector_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
) -> DATestConnector:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/connector/{connector_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
conn = response.json()
|
||||
|
||||
@@ -6,7 +6,6 @@ import requests
|
||||
from onyx.server.documents.models import CredentialSnapshot
|
||||
from onyx.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestCredential
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -14,13 +13,13 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
class CredentialManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
user_performing_action: DATestUser,
|
||||
credential_json: dict[str, Any] | None = None,
|
||||
admin_public: bool = True,
|
||||
name: str | None = None,
|
||||
source: DocumentSource = DocumentSource.FILE,
|
||||
curator_public: bool = True,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestCredential:
|
||||
name = f"{name}-credential" if name else f"test-credential-{uuid4()}"
|
||||
|
||||
@@ -36,11 +35,7 @@ class CredentialManager:
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/credential",
|
||||
json=credential_request,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
@@ -57,61 +52,46 @@ class CredentialManager:
|
||||
@staticmethod
|
||||
def edit(
|
||||
credential: DATestCredential,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
request = credential.model_dump(include={"name", "credential_json"})
|
||||
response = requests.put(
|
||||
url=f"{API_SERVER_URL}/manage/admin/credential/{credential.id}",
|
||||
json=request,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
credential: DATestCredential,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
response = requests.delete(
|
||||
url=f"{API_SERVER_URL}/manage/credential/{credential.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get(
|
||||
credential_id: int, user_performing_action: DATestUser | None = None
|
||||
credential_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
) -> CredentialSnapshot:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/credential/{credential_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return CredentialSnapshot(**response.json())
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[CredentialSnapshot]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/credential",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [CredentialSnapshot(**cred) for cred in response.json()]
|
||||
@@ -119,8 +99,8 @@ class CredentialManager:
|
||||
@staticmethod
|
||||
def verify(
|
||||
credential: DATestCredential,
|
||||
user_performing_action: DATestUser,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
all_credentials = CredentialManager.get_all(user_performing_action)
|
||||
for fetched_credential in all_credentials:
|
||||
|
||||
@@ -10,7 +10,6 @@ from onyx.db.enums import AccessType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import NUM_DOCS
|
||||
from tests.integration.common_utils.managers.api_key import DATestAPIKey
|
||||
from tests.integration.common_utils.managers.cc_pair import DATestCCPair
|
||||
@@ -22,9 +21,9 @@ from tests.integration.common_utils.vespa import vespa_fixture
|
||||
def _verify_document_permissions(
|
||||
retrieved_doc: dict,
|
||||
cc_pair: DATestCCPair,
|
||||
doc_creating_user: DATestUser,
|
||||
doc_set_names: list[str] | None = None,
|
||||
group_names: list[str] | None = None,
|
||||
doc_creating_user: DATestUser | None = None,
|
||||
) -> None:
|
||||
acl_keys = set(retrieved_doc.get("access_control_list", {}).keys())
|
||||
print(f"ACL keys: {acl_keys}")
|
||||
@@ -36,12 +35,11 @@ def _verify_document_permissions(
|
||||
" does not have the PUBLIC ACL key"
|
||||
)
|
||||
|
||||
if doc_creating_user is not None:
|
||||
if f"user_email:{doc_creating_user.email}" not in acl_keys:
|
||||
raise ValueError(
|
||||
f"Document {retrieved_doc['document_id']} was created by user"
|
||||
f" {doc_creating_user.email} but does not have the user_email:{doc_creating_user.email} ACL key"
|
||||
)
|
||||
if f"user_email:{doc_creating_user.email}" not in acl_keys:
|
||||
raise ValueError(
|
||||
f"Document {retrieved_doc['document_id']} was created by user"
|
||||
f" {doc_creating_user.email} but does not have the user_email:{doc_creating_user.email} ACL key"
|
||||
)
|
||||
|
||||
if group_names is not None:
|
||||
expected_group_keys = {f"group:{group_name}" for group_name in group_names}
|
||||
@@ -101,9 +99,9 @@ class DocumentManager:
|
||||
@staticmethod
|
||||
def seed_dummy_docs(
|
||||
cc_pair: DATestCCPair,
|
||||
api_key: DATestAPIKey,
|
||||
num_docs: int = NUM_DOCS,
|
||||
document_ids: list[str] | None = None,
|
||||
api_key: DATestAPIKey | None = None,
|
||||
) -> list[SimpleTestDocument]:
|
||||
# Use provided document_ids if available, otherwise generate random UUIDs
|
||||
if document_ids is None:
|
||||
@@ -118,12 +116,13 @@ class DocumentManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/onyx-api/ingestion",
|
||||
json=document,
|
||||
headers=api_key.headers if api_key else GENERAL_HEADERS,
|
||||
headers=api_key.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
api_key_id = api_key.api_key_id if api_key else ""
|
||||
print(f"Seeding docs for api_key_id={api_key_id} completed successfully.")
|
||||
print(
|
||||
f"Seeding docs for api_key_id={api_key.api_key_id} completed successfully."
|
||||
)
|
||||
return [
|
||||
SimpleTestDocument(
|
||||
id=document["document"]["id"],
|
||||
@@ -136,8 +135,8 @@ class DocumentManager:
|
||||
def seed_doc_with_content(
|
||||
cc_pair: DATestCCPair,
|
||||
content: str,
|
||||
api_key: DATestAPIKey,
|
||||
document_id: str | None = None,
|
||||
api_key: DATestAPIKey | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> SimpleTestDocument:
|
||||
# Use provided document_ids if available, otherwise generate random UUIDs
|
||||
@@ -153,12 +152,13 @@ class DocumentManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/onyx-api/ingestion",
|
||||
json=document,
|
||||
headers=api_key.headers if api_key else GENERAL_HEADERS,
|
||||
headers=api_key.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
api_key_id = api_key.api_key_id if api_key else ""
|
||||
print(f"Seeding doc for api_key_id={api_key_id} completed successfully.")
|
||||
print(
|
||||
f"Seeding doc for api_key_id={api_key.api_key_id} completed successfully."
|
||||
)
|
||||
|
||||
return SimpleTestDocument(
|
||||
id=document["document"]["id"],
|
||||
@@ -169,11 +169,11 @@ class DocumentManager:
|
||||
def verify(
|
||||
vespa_client: vespa_fixture,
|
||||
cc_pair: DATestCCPair,
|
||||
doc_creating_user: DATestUser,
|
||||
# If None, will not check doc sets or groups
|
||||
# If empty list, will check for empty doc sets or groups
|
||||
doc_set_names: list[str] | None = None,
|
||||
group_names: list[str] | None = None,
|
||||
doc_creating_user: DATestUser | None = None,
|
||||
verify_deleted: bool = False,
|
||||
) -> None:
|
||||
doc_ids = [document.id for document in cc_pair.documents]
|
||||
@@ -212,9 +212,9 @@ class DocumentManager:
|
||||
_verify_document_permissions(
|
||||
retrieved_doc,
|
||||
cc_pair,
|
||||
doc_creating_user,
|
||||
doc_set_names,
|
||||
group_names,
|
||||
doc_creating_user,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -268,11 +268,11 @@ class IngestionManager(DocumentManager):
|
||||
|
||||
@staticmethod
|
||||
def list_all_ingestion_docs(
|
||||
api_key: DATestAPIKey | None = None,
|
||||
api_key: DATestAPIKey,
|
||||
) -> list[dict]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/onyx-api/ingestion",
|
||||
headers=api_key.headers if api_key else GENERAL_HEADERS,
|
||||
headers=api_key.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -280,11 +280,11 @@ class IngestionManager(DocumentManager):
|
||||
@staticmethod
|
||||
def delete(
|
||||
document_id: str,
|
||||
api_key: DATestAPIKey | None = None,
|
||||
api_key: DATestAPIKey,
|
||||
) -> None:
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/onyx-api/ingestion/{document_id}",
|
||||
headers=api_key.headers if api_key else GENERAL_HEADERS,
|
||||
headers=api_key.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
print(f"Deleted document {document_id} successfully.")
|
||||
|
||||
@@ -3,7 +3,6 @@ import requests
|
||||
from ee.onyx.server.query_and_chat.models import SearchFullResponse
|
||||
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@@ -11,7 +10,7 @@ class DocumentSearchManager:
|
||||
@staticmethod
|
||||
def search_documents(
|
||||
query: str,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Search for documents using the EE search API.
|
||||
@@ -31,11 +30,7 @@ class DocumentSearchManager:
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/search/send-search-message",
|
||||
json=search_request.model_dump(),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
result_json = result.json()
|
||||
|
||||
@@ -6,7 +6,6 @@ from uuid import uuid4
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.test_models import DATestDocumentSet
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
@@ -15,6 +14,7 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
class DocumentSetManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
cc_pair_ids: list[int] | None = None,
|
||||
@@ -22,7 +22,6 @@ class DocumentSetManager:
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
federated_connectors: list[dict[str, Any]] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestDocumentSet:
|
||||
if name is None:
|
||||
name = f"test_doc_set_{str(uuid4())}"
|
||||
@@ -40,11 +39,7 @@ class DocumentSetManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/document-set",
|
||||
json=doc_set_creation_request,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -63,7 +58,7 @@ class DocumentSetManager:
|
||||
@staticmethod
|
||||
def edit(
|
||||
document_set: DATestDocumentSet,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
doc_set_update_request = {
|
||||
"id": document_set.id,
|
||||
@@ -77,11 +72,7 @@ class DocumentSetManager:
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/manage/admin/document-set",
|
||||
json=doc_set_update_request,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return True
|
||||
@@ -89,30 +80,22 @@ class DocumentSetManager:
|
||||
@staticmethod
|
||||
def delete(
|
||||
document_set: DATestDocumentSet,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/manage/admin/document-set/{document_set.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[DATestDocumentSet]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/document-set",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [
|
||||
@@ -132,8 +115,8 @@ class DocumentSetManager:
|
||||
|
||||
@staticmethod
|
||||
def wait_for_sync(
|
||||
user_performing_action: DATestUser,
|
||||
document_sets_to_check: list[DATestDocumentSet] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
# wait for document sets to be synced
|
||||
start = time.time()
|
||||
@@ -175,8 +158,8 @@ class DocumentSetManager:
|
||||
@staticmethod
|
||||
def verify(
|
||||
document_set: DATestDocumentSet,
|
||||
user_performing_action: DATestUser,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
doc_sets = DocumentSetManager.get_all(user_performing_action)
|
||||
for doc_set in doc_sets:
|
||||
|
||||
@@ -10,7 +10,6 @@ import requests
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.server.documents.models import FileUploadResponse
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@@ -18,13 +17,9 @@ class FileManager:
|
||||
@staticmethod
|
||||
def upload_files(
|
||||
files: List[Tuple[str, IO]],
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> Tuple[List[FileDescriptor], str]:
|
||||
headers = (
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
)
|
||||
headers = user_performing_action.headers
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
files_param = []
|
||||
@@ -67,15 +62,11 @@ class FileManager:
|
||||
@staticmethod
|
||||
def fetch_uploaded_file(
|
||||
file_id: str,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bytes:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/chat/file/{file_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
@@ -6,7 +6,6 @@ from uuid import uuid4
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestImageGenerationConfig
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -26,6 +25,7 @@ def _serialize_custom_config(
|
||||
class ImageGenerationConfigManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
user_performing_action: DATestUser,
|
||||
image_provider_id: str | None = None,
|
||||
model_name: str = "gpt-image-1",
|
||||
provider: str = "openai",
|
||||
@@ -35,7 +35,6 @@ class ImageGenerationConfigManager:
|
||||
deployment_name: str | None = None,
|
||||
custom_config: dict[str, Any] | None = None,
|
||||
is_default: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestImageGenerationConfig:
|
||||
"""Create a new image generation config with new credentials."""
|
||||
image_provider_id = image_provider_id or f"test-provider-{uuid4()}"
|
||||
@@ -53,11 +52,7 @@ class ImageGenerationConfigManager:
|
||||
"custom_config": _serialize_custom_config(custom_config),
|
||||
"is_default": is_default,
|
||||
},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
@@ -74,13 +69,13 @@ class ImageGenerationConfigManager:
|
||||
@staticmethod
|
||||
def create_from_provider(
|
||||
source_llm_provider_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
image_provider_id: str | None = None,
|
||||
model_name: str = "gpt-image-1",
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
deployment_name: str | None = None,
|
||||
is_default: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestImageGenerationConfig:
|
||||
"""Create a new image generation config by cloning from an existing LLM provider."""
|
||||
image_provider_id = image_provider_id or f"test-provider-{uuid4()}"
|
||||
@@ -96,11 +91,7 @@ class ImageGenerationConfigManager:
|
||||
"deployment_name": deployment_name,
|
||||
"is_default": is_default,
|
||||
},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
@@ -116,16 +107,12 @@ class ImageGenerationConfigManager:
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[DATestImageGenerationConfig]:
|
||||
"""Get all image generation configs."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/image-generation/config",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [DATestImageGenerationConfig(**config) for config in response.json()]
|
||||
@@ -133,16 +120,12 @@ class ImageGenerationConfigManager:
|
||||
@staticmethod
|
||||
def get_credentials(
|
||||
image_provider_id: str,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> dict:
|
||||
"""Get credentials for an image generation config."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}/credentials",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -151,13 +134,13 @@ class ImageGenerationConfigManager:
|
||||
def update(
|
||||
image_provider_id: str,
|
||||
model_name: str,
|
||||
user_performing_action: DATestUser,
|
||||
provider: str | None = None,
|
||||
api_key: str | None = None,
|
||||
source_llm_provider_id: int | None = None,
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
deployment_name: str | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestImageGenerationConfig:
|
||||
"""Update an existing image generation config."""
|
||||
payload: dict = {
|
||||
@@ -178,14 +161,10 @@ class ImageGenerationConfigManager:
|
||||
f"Got: source_llm_provider_id={source_llm_provider_id}, provider={provider}, api_key={'***' if api_key else None}"
|
||||
)
|
||||
|
||||
headers = {**GENERAL_HEADERS}
|
||||
if user_performing_action:
|
||||
headers.update(user_performing_action.headers)
|
||||
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}",
|
||||
json=payload,
|
||||
headers=headers,
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
if not response.ok:
|
||||
print(f"Update failed with status {response.status_code}: {response.text}")
|
||||
@@ -204,40 +183,32 @@ class ImageGenerationConfigManager:
|
||||
@staticmethod
|
||||
def delete(
|
||||
image_provider_id: str,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
"""Delete an image generation config."""
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def set_default(
|
||||
image_provider_id: str,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
"""Set an image generation config as the default."""
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}/default",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
config: DATestImageGenerationConfig,
|
||||
user_performing_action: DATestUser,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""Verify that a config exists (or doesn't exist if verify_deleted=True)."""
|
||||
all_configs = ImageGenerationConfigManager.get_all(user_performing_action)
|
||||
|
||||
@@ -14,7 +14,6 @@ from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.server.documents.models import IndexAttemptSnapshot
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.test_models import DATestIndexAttempt
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
@@ -86,9 +85,9 @@ class IndexAttemptManager:
|
||||
@staticmethod
|
||||
def get_index_attempt_page(
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
page: int = 0,
|
||||
page_size: int = 10,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> PaginatedReturn[IndexAttemptSnapshot]:
|
||||
query_params: dict[str, str | int] = {
|
||||
"page_num": page,
|
||||
@@ -101,11 +100,7 @@ class IndexAttemptManager:
|
||||
)
|
||||
response = requests.get(
|
||||
url=url,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
@@ -117,7 +112,7 @@ class IndexAttemptManager:
|
||||
@staticmethod
|
||||
def get_latest_index_attempt_for_cc_pair(
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> IndexAttemptSnapshot | None:
|
||||
"""Get an IndexAttempt by ID"""
|
||||
index_attempts = IndexAttemptManager.get_index_attempt_page(
|
||||
@@ -134,9 +129,9 @@ class IndexAttemptManager:
|
||||
@staticmethod
|
||||
def wait_for_index_attempt_start(
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
index_attempts_to_ignore: list[int] | None = None,
|
||||
timeout: float = MAX_DELAY,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> IndexAttemptSnapshot:
|
||||
"""Wait for an IndexAttempt to start"""
|
||||
start = datetime.now()
|
||||
@@ -164,7 +159,7 @@ class IndexAttemptManager:
|
||||
def get_index_attempt_by_id(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> IndexAttemptSnapshot:
|
||||
page_num = 0
|
||||
page_size = 10
|
||||
@@ -190,8 +185,8 @@ class IndexAttemptManager:
|
||||
def wait_for_index_attempt_completion(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
timeout: float = MAX_DELAY,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""Wait for an IndexAttempt to complete"""
|
||||
start = time.monotonic()
|
||||
@@ -223,19 +218,15 @@ class IndexAttemptManager:
|
||||
@staticmethod
|
||||
def get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
include_resolved: bool = True,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[IndexAttemptErrorPydantic]:
|
||||
url = f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/errors?page_size=100"
|
||||
if include_resolved:
|
||||
url += "&include_resolved=true"
|
||||
response = requests.get(
|
||||
url=url,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
@@ -8,7 +8,6 @@ from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -16,6 +15,7 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
class LLMProviderManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
provider: str | None = None,
|
||||
api_key: str | None = None,
|
||||
@@ -26,13 +26,8 @@ class LLMProviderManager:
|
||||
personas: list[int] | None = None,
|
||||
is_public: bool | None = None,
|
||||
set_as_default: bool = True,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestLLMProvider:
|
||||
email = "Unknown"
|
||||
if user_performing_action:
|
||||
email = user_performing_action.email
|
||||
|
||||
print(f"Seeding LLM Providers for {email}...")
|
||||
print(f"Seeding LLM Providers for {user_performing_action.email}...")
|
||||
|
||||
llm_provider = LLMProviderUpsertRequest(
|
||||
name=name or f"test-provider-{uuid4()}",
|
||||
@@ -60,11 +55,7 @@ class LLMProviderManager:
|
||||
llm_response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
json=llm_provider.model_dump(),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
llm_response.raise_for_status()
|
||||
response_data = llm_response.json()
|
||||
@@ -86,11 +77,7 @@ class LLMProviderManager:
|
||||
if set_as_default:
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
set_default_response.raise_for_status()
|
||||
|
||||
@@ -99,30 +86,22 @@ class LLMProviderManager:
|
||||
@staticmethod
|
||||
def delete(
|
||||
llm_provider: DATestLLMProvider,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{llm_provider.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[LLMProviderView]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [LLMProviderView(**ug) for ug in response.json()]
|
||||
@@ -130,8 +109,8 @@ class LLMProviderManager:
|
||||
@staticmethod
|
||||
def verify(
|
||||
llm_provider: DATestLLMProvider,
|
||||
user_performing_action: DATestUser,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
|
||||
for fetched_llm_provider in all_llm_providers:
|
||||
|
||||
@@ -7,7 +7,6 @@ from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.server.features.persona.models import FullPersonaSnapshot
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestPersona
|
||||
from tests.integration.common_utils.test_models import DATestPersonaLabel
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
@@ -16,6 +15,7 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
class PersonaManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
@@ -34,7 +34,6 @@ class PersonaManager:
|
||||
groups: list[int] | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
user_file_ids: list[str] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
display_priority: int | None = None,
|
||||
) -> DATestPersona:
|
||||
name = name or f"test-persona-{uuid4()}"
|
||||
@@ -67,11 +66,7 @@ class PersonaManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/persona",
|
||||
json=persona_creation_request.model_dump(mode="json"),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
persona_data = response.json()
|
||||
@@ -100,6 +95,7 @@ class PersonaManager:
|
||||
@staticmethod
|
||||
def edit(
|
||||
persona: DATestPersona,
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
@@ -117,7 +113,6 @@ class PersonaManager:
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestPersona:
|
||||
system_prompt = system_prompt or f"System prompt for {persona.name}"
|
||||
task_prompt = task_prompt or f"Task prompt for {persona.name}"
|
||||
@@ -151,11 +146,7 @@ class PersonaManager:
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
json=persona_update_request.model_dump(mode="json"),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
updated_persona_data = response.json()
|
||||
@@ -187,15 +178,11 @@ class PersonaManager:
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[FullPersonaSnapshot]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/persona",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [FullPersonaSnapshot(**persona) for persona in response.json()]
|
||||
@@ -203,15 +190,11 @@ class PersonaManager:
|
||||
@staticmethod
|
||||
def get_one(
|
||||
persona_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[FullPersonaSnapshot]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/persona/{persona_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [FullPersonaSnapshot(**response.json())]
|
||||
@@ -219,7 +202,7 @@ class PersonaManager:
|
||||
@staticmethod
|
||||
def verify(
|
||||
persona: DATestPersona,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
all_personas = PersonaManager.get_one(
|
||||
persona_id=persona.id,
|
||||
@@ -388,15 +371,11 @@ class PersonaManager:
|
||||
@staticmethod
|
||||
def delete(
|
||||
persona: DATestPersona,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
return response.ok
|
||||
|
||||
@@ -405,18 +384,14 @@ class PersonaLabelManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
label: DATestPersonaLabel,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> DATestPersonaLabel:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/persona/labels",
|
||||
json={
|
||||
"name": label.name,
|
||||
},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_data = response.json()
|
||||
@@ -425,15 +400,11 @@ class PersonaLabelManager:
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[DATestPersonaLabel]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/persona/labels",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [DATestPersonaLabel(**label) for label in response.json()]
|
||||
@@ -441,18 +412,14 @@ class PersonaLabelManager:
|
||||
@staticmethod
|
||||
def update(
|
||||
label: DATestPersonaLabel,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> DATestPersonaLabel:
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/admin/persona/label/{label.id}",
|
||||
json={
|
||||
"label_name": label.name,
|
||||
},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return label
|
||||
@@ -460,22 +427,18 @@ class PersonaLabelManager:
|
||||
@staticmethod
|
||||
def delete(
|
||||
label: DATestPersonaLabel,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/persona/label/{label.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
return response.ok
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
label: DATestPersonaLabel,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
all_labels = PersonaLabelManager.get_all(user_performing_action)
|
||||
for fetched_label in all_labels:
|
||||
|
||||
@@ -6,7 +6,6 @@ from onyx.server.features.projects.models import CategorizedFilesSnapshot
|
||||
from onyx.server.features.projects.models import UserFileSnapshot
|
||||
from onyx.server.features.projects.models import UserProjectSnapshot
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@@ -20,7 +19,7 @@ class ProjectManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/user/projects/create",
|
||||
params={"name": name},
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return UserProjectSnapshot.model_validate(response.json())
|
||||
@@ -32,7 +31,7 @@ class ProjectManager:
|
||||
"""Get all projects for a user via API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects",
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [UserProjectSnapshot.model_validate(obj) for obj in response.json()]
|
||||
@@ -45,7 +44,7 @@ class ProjectManager:
|
||||
"""Delete a project via API."""
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/user/projects/{project_id}",
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
return response.status_code == 204
|
||||
|
||||
@@ -57,7 +56,7 @@ class ProjectManager:
|
||||
"""Verify that a project has been deleted by ensuring it's not in list."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects",
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
projects = [UserProjectSnapshot.model_validate(obj) for obj in response.json()]
|
||||
@@ -66,16 +65,12 @@ class ProjectManager:
|
||||
@staticmethod
|
||||
def verify_files_unlinked(
|
||||
project_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
"""Verify that all files have been unlinked from the project via API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/files/{project_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
if response.status_code == 404:
|
||||
return True
|
||||
@@ -87,16 +82,12 @@ class ProjectManager:
|
||||
@staticmethod
|
||||
def verify_chat_sessions_unlinked(
|
||||
project_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
"""Verify that all chat sessions have been unlinked from the project via API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/{project_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
if response.status_code == 404:
|
||||
return True
|
||||
@@ -144,16 +135,12 @@ class ProjectManager:
|
||||
@staticmethod
|
||||
def get_project_files(
|
||||
project_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> List[UserFileSnapshot]:
|
||||
"""Get all files associated with a project via API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/files/{project_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
if response.status_code == 404:
|
||||
return []
|
||||
@@ -170,7 +157,7 @@ class ProjectManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/user/projects/{project_id}/instructions",
|
||||
json={"instructions": instructions},
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return (response.json() or {}).get("instructions") or ""
|
||||
|
||||
@@ -10,19 +10,18 @@ from ee.onyx.server.query_history.models import ChatSessionSnapshot
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
class QueryHistoryManager:
|
||||
@staticmethod
|
||||
def get_query_history_page(
|
||||
user_performing_action: DATestUser,
|
||||
page_num: int = 0,
|
||||
page_size: int = 10,
|
||||
feedback_type: QAFeedbackType | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> PaginatedReturn[ChatSessionMinimal]:
|
||||
query_params: dict[str, str | int] = {
|
||||
"page_num": page_num,
|
||||
@@ -37,11 +36,7 @@ class QueryHistoryManager:
|
||||
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/chat-session-history?{urlencode(query_params, doseq=True)}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
@@ -53,24 +48,20 @@ class QueryHistoryManager:
|
||||
@staticmethod
|
||||
def get_chat_session_admin(
|
||||
chat_session_id: UUID | str,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> ChatSessionSnapshot:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/chat-session-history/{chat_session_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return ChatSessionSnapshot(**response.json())
|
||||
|
||||
@staticmethod
|
||||
def get_query_history_as_csv(
|
||||
user_performing_action: DATestUser,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> tuple[CaseInsensitiveDict[str], str]:
|
||||
query_params: dict[str, str | int] = {}
|
||||
if start_time:
|
||||
@@ -80,11 +71,7 @@ class QueryHistoryManager:
|
||||
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/query-history-csv?{urlencode(query_params, doseq=True)}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.headers, response.content.decode()
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Optional
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestSettings
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -13,13 +12,9 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
class SettingsManager:
|
||||
@staticmethod
|
||||
def get_settings(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> tuple[Dict[str, Any], str]:
|
||||
headers = (
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
)
|
||||
headers = user_performing_action.headers
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
response = requests.get(
|
||||
@@ -38,13 +33,9 @@ class SettingsManager:
|
||||
@staticmethod
|
||||
def update_settings(
|
||||
settings: DATestSettings,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> tuple[Dict[str, Any], str]:
|
||||
headers = (
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
)
|
||||
headers = user_performing_action.headers
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
payload = settings.model_dump()
|
||||
@@ -65,7 +56,7 @@ class SettingsManager:
|
||||
@staticmethod
|
||||
def get_setting(
|
||||
key: str,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> Optional[Any]:
|
||||
settings, error = SettingsManager.get_settings(user_performing_action)
|
||||
if error:
|
||||
|
||||
@@ -8,7 +8,6 @@ from onyx.server.manage.models import AllUsersResponse
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
from onyx.server.models import InvitedUserSnapshot
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@@ -26,15 +25,11 @@ def generate_auth_token() -> str:
|
||||
class TenantManager:
|
||||
@staticmethod
|
||||
def get_all_users(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> AllUsersResponse:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/users",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -50,7 +45,8 @@ class TenantManager:
|
||||
|
||||
@staticmethod
|
||||
def verify_user_in_tenant(
|
||||
user: DATestUser, user_performing_action: DATestUser | None = None
|
||||
user: DATestUser,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
all_users = TenantManager.get_all_users(user_performing_action)
|
||||
for accepted_user in all_users.accepted:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestTool
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -9,15 +8,11 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
class ToolManager:
|
||||
@staticmethod
|
||||
def list_tools(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[DATestTool]:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/tool",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [
|
||||
|
||||
@@ -7,6 +7,8 @@ import requests
|
||||
from requests import HTTPError
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.configs.constants import ANONYMOUS_USER_UUID
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.manage.models import UserInfo
|
||||
@@ -25,6 +27,23 @@ def build_email(name: str) -> str:
|
||||
|
||||
|
||||
class UserManager:
|
||||
@staticmethod
|
||||
def get_anonymous_user() -> DATestUser:
|
||||
"""Get a DATestUser representing the anonymous user.
|
||||
|
||||
Anonymous users are real users in the database with LIMITED role.
|
||||
They don't have login cookies - requests are made with GENERAL_HEADERS.
|
||||
The anonymous_user_enabled setting must be True for these requests to work.
|
||||
"""
|
||||
return DATestUser(
|
||||
id=ANONYMOUS_USER_UUID,
|
||||
email=ANONYMOUS_USER_EMAIL,
|
||||
password="",
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.LIMITED,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
name: str | None = None,
|
||||
@@ -227,12 +246,12 @@ class UserManager:
|
||||
|
||||
@staticmethod
|
||||
def get_user_page(
|
||||
user_performing_action: DATestUser,
|
||||
page_num: int = 0,
|
||||
page_size: int = 10,
|
||||
search_query: str | None = None,
|
||||
role_filter: list[UserRole] | None = None,
|
||||
is_active_filter: bool | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> PaginatedReturn[FullUserSnapshot]:
|
||||
query_params: dict[str, str | list[str] | int] = {
|
||||
"page_num": page_num,
|
||||
@@ -247,11 +266,7 @@ class UserManager:
|
||||
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/users/accepted?{urlencode(query_params, doseq=True)}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import requests
|
||||
|
||||
from ee.onyx.server.user_group.models import UserGroup
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.test_models import DATestUserGroup
|
||||
@@ -14,10 +13,10 @@ from tests.integration.common_utils.test_models import DATestUserGroup
|
||||
class UserGroupManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
user_ids: list[str] | None = None,
|
||||
cc_pair_ids: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestUserGroup:
|
||||
name = f"{name}-user-group" if name else f"test-user-group-{uuid4()}"
|
||||
|
||||
@@ -29,11 +28,7 @@ class UserGroupManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group",
|
||||
json=request,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
test_user_group = DATestUserGroup(
|
||||
@@ -47,31 +42,23 @@ class UserGroupManager:
|
||||
@staticmethod
|
||||
def edit(
|
||||
user_group: DATestUserGroup,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}",
|
||||
json=user_group.model_dump(),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
user_group: DATestUserGroup,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -79,7 +66,7 @@ class UserGroupManager:
|
||||
def add_users(
|
||||
user_group: DATestUserGroup,
|
||||
user_ids: list[str],
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> DATestUserGroup:
|
||||
request = {
|
||||
"user_ids": user_ids,
|
||||
@@ -88,11 +75,7 @@ class UserGroupManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/add-users",
|
||||
json=request,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -107,8 +90,8 @@ class UserGroupManager:
|
||||
def set_curator_status(
|
||||
test_user_group: DATestUserGroup,
|
||||
user_to_set_as_curator: DATestUser,
|
||||
user_performing_action: DATestUser,
|
||||
is_curator: bool = True,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
set_curator_request = {
|
||||
"user_id": user_to_set_as_curator.id,
|
||||
@@ -117,25 +100,17 @@ class UserGroupManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{test_user_group.id}/set-curator",
|
||||
json=set_curator_request,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[UserGroup]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [UserGroup(**ug) for ug in response.json()]
|
||||
@@ -143,8 +118,8 @@ class UserGroupManager:
|
||||
@staticmethod
|
||||
def verify(
|
||||
user_group: DATestUserGroup,
|
||||
user_performing_action: DATestUser,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
all_user_groups = UserGroupManager.get_all(user_performing_action)
|
||||
for fetched_user_group in all_user_groups:
|
||||
@@ -167,8 +142,8 @@ class UserGroupManager:
|
||||
|
||||
@staticmethod
|
||||
def wait_for_sync(
|
||||
user_performing_action: DATestUser,
|
||||
user_groups_to_check: list[DATestUserGroup] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
start = time.time()
|
||||
while True:
|
||||
@@ -198,7 +173,7 @@ class UserGroupManager:
|
||||
@staticmethod
|
||||
def wait_for_deletion_completion(
|
||||
user_groups_to_check: list[DATestUserGroup],
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
start = time.time()
|
||||
user_group_ids_to_check = {user_group.id for user_group in user_groups_to_check}
|
||||
|
||||
@@ -88,11 +88,8 @@ def reset() -> None:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def new_admin_user(reset: None) -> DATestUser | None: # noqa: ARG001
|
||||
try:
|
||||
return UserManager.create(name=ADMIN_USER_NAME)
|
||||
except Exception:
|
||||
return None
|
||||
def new_admin_user(reset: None) -> DATestUser: # noqa: ARG001
|
||||
return UserManager.create(name=ADMIN_USER_NAME)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -182,18 +179,18 @@ def reset_multitenant() -> None:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
|
||||
def llm_provider(admin_user: DATestUser) -> DATestLLMProvider:
|
||||
return LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def image_generation_config(
|
||||
admin_user: DATestUser | None,
|
||||
admin_user: DATestUser,
|
||||
) -> DATestImageGenerationConfig:
|
||||
"""Create a default image generation config for tests."""
|
||||
return ImageGenerationConfigManager.create(
|
||||
is_default=True,
|
||||
user_performing_action=admin_user,
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -60,3 +60,44 @@ def test_me_endpoint_returns_authenticated_user_info(
|
||||
assert data.get("is_anonymous_user") is not True
|
||||
assert data["email"] == admin_user.email
|
||||
assert data["role"] == "admin"
|
||||
|
||||
|
||||
def test_anonymous_user_can_access_persona_when_enabled(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Verify that anonymous users can access limited endpoints when enabled."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
SettingsManager.update_settings(
|
||||
DATestSettings(anonymous_user_enabled=True),
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
anon_user = UserManager.get_anonymous_user()
|
||||
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/persona",
|
||||
headers=anon_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_anonymous_user_denied_persona_when_disabled(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Verify that anonymous users cannot access endpoints when disabled."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
SettingsManager.update_settings(
|
||||
DATestSettings(anonymous_user_enabled=False),
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
anon_user = UserManager.get_anonymous_user()
|
||||
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/persona",
|
||||
headers=anon_user.headers,
|
||||
)
|
||||
# 403 is returned - BasicAuthenticationError uses HTTP 403 for all auth failures
|
||||
assert response.status_code == 403
|
||||
|
||||
@@ -11,8 +11,8 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
def _verify_index_attempt_pagination(
|
||||
cc_pair_id: int,
|
||||
index_attempt_ids: list[int],
|
||||
user_performing_action: DATestUser,
|
||||
page_size: int = 5,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
retrieved_attempts: list[int] = []
|
||||
last_time_started = None # Track the last time_started seen
|
||||
|
||||
@@ -207,7 +207,9 @@ def test_mcp_search_respects_acl_filters(
|
||||
cc_pair_ids=[restricted_cc_pair.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
UserGroupManager.wait_for_sync([user_group], user_performing_action=admin_user)
|
||||
UserGroupManager.wait_for_sync(
|
||||
user_performing_action=admin_user, user_groups_to_check=[user_group]
|
||||
)
|
||||
|
||||
restricted_doc_content = "MCP restricted knowledge base document"
|
||||
_seed_document_and_wait_for_indexing(
|
||||
|
||||
@@ -14,11 +14,11 @@ from tests.integration.tests.query_history.utils import (
|
||||
|
||||
def _verify_query_history_pagination(
|
||||
chat_sessions: list[DAQueryHistoryEntry],
|
||||
user_performing_action: DATestUser,
|
||||
page_size: int = 5,
|
||||
feedback_type: QAFeedbackType | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
retrieved_sessions: list[str] = []
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import pytest
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
@@ -59,7 +58,7 @@ def test_add_users_to_group_invalid_user(reset: None) -> None: # noqa: ARG001
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/add-users",
|
||||
json={"user_ids": [invalid_user_id]},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
@@ -9,11 +9,11 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
# to verify that the pagination and filtering works as expected.
|
||||
def _verify_user_pagination(
|
||||
users: list[DATestUser],
|
||||
user_performing_action: DATestUser,
|
||||
page_size: int = 5,
|
||||
search_query: str | None = None,
|
||||
role_filter: list[UserRole] | None = None,
|
||||
is_active_filter: bool | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
retrieved_users: list[FullUserSnapshot] = []
|
||||
|
||||
|
||||
@@ -0,0 +1,268 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
_enumerate_ad_groups_paginated,
|
||||
)
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
_iter_graph_collection,
|
||||
)
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
_normalize_email,
|
||||
)
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
AD_GROUP_ENUMERATION_THRESHOLD,
|
||||
)
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
get_sharepoint_external_groups,
|
||||
)
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import GroupsResult
|
||||
|
||||
|
||||
MODULE = "ee.onyx.external_permissions.sharepoint.permission_utils"
|
||||
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _fake_token() -> str:
|
||||
return "fake-token"
|
||||
|
||||
|
||||
def _make_graph_page(
|
||||
items: list[dict[str, Any]],
|
||||
next_link: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
page: dict[str, Any] = {"value": items}
|
||||
if next_link:
|
||||
page["@odata.nextLink"] = next_link
|
||||
return page
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _normalize_email
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_normalize_email_strips_onmicrosoft() -> None:
|
||||
assert _normalize_email("user@contoso.onmicrosoft.com") == "user@contoso.com"
|
||||
|
||||
|
||||
def test_normalize_email_noop_for_normal_domain() -> None:
|
||||
assert _normalize_email("user@contoso.com") == "user@contoso.com"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _iter_graph_collection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(f"{MODULE}._graph_api_get")
|
||||
def test_iter_graph_collection_single_page(mock_get: MagicMock) -> None:
|
||||
mock_get.return_value = _make_graph_page([{"id": "1"}, {"id": "2"}])
|
||||
|
||||
items = list(_iter_graph_collection("https://graph/items", _fake_token))
|
||||
assert items == [{"id": "1"}, {"id": "2"}]
|
||||
mock_get.assert_called_once()
|
||||
|
||||
|
||||
@patch(f"{MODULE}._graph_api_get")
|
||||
def test_iter_graph_collection_multi_page(mock_get: MagicMock) -> None:
|
||||
mock_get.side_effect = [
|
||||
_make_graph_page([{"id": "1"}], next_link="https://graph/items?page=2"),
|
||||
_make_graph_page([{"id": "2"}]),
|
||||
]
|
||||
|
||||
items = list(_iter_graph_collection("https://graph/items", _fake_token))
|
||||
assert items == [{"id": "1"}, {"id": "2"}]
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
|
||||
@patch(f"{MODULE}._graph_api_get")
|
||||
def test_iter_graph_collection_empty(mock_get: MagicMock) -> None:
|
||||
mock_get.return_value = _make_graph_page([])
|
||||
assert list(_iter_graph_collection("https://graph/items", _fake_token)) == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _enumerate_ad_groups_paginated
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _mock_graph_get_for_enumeration(
|
||||
groups: list[dict[str, Any]],
|
||||
members_by_group: dict[str, list[dict[str, Any]]],
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
"""Return a side_effect function for _graph_api_get that serves
|
||||
groups on the /groups URL and members on /groups/{id}/members URLs."""
|
||||
|
||||
def side_effect(
|
||||
url: str,
|
||||
get_access_token: Any, # noqa: ARG001
|
||||
params: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
if "/members" in url:
|
||||
group_id = url.split("/groups/")[1].split("/members")[0]
|
||||
return _make_graph_page(members_by_group.get(group_id, []))
|
||||
return _make_graph_page(groups)
|
||||
|
||||
return side_effect # type: ignore[return-value]
|
||||
|
||||
|
||||
@patch(f"{MODULE}._graph_api_get")
|
||||
def test_enumerate_ad_groups_yields_groups(mock_get: MagicMock) -> None:
|
||||
groups = [
|
||||
{"id": "g1", "displayName": "Engineering"},
|
||||
{"id": "g2", "displayName": "Marketing"},
|
||||
]
|
||||
members = {
|
||||
"g1": [{"userPrincipalName": "alice@contoso.com"}],
|
||||
"g2": [{"mail": "bob@contoso.onmicrosoft.com"}],
|
||||
}
|
||||
mock_get.side_effect = _mock_graph_get_for_enumeration(groups, members)
|
||||
|
||||
results = list(
|
||||
_enumerate_ad_groups_paginated(
|
||||
_fake_token, already_resolved=set(), graph_api_base=GRAPH_API_BASE
|
||||
)
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
eng = next(r for r in results if r.id == "Engineering_g1")
|
||||
assert eng.user_emails == ["alice@contoso.com"]
|
||||
mkt = next(r for r in results if r.id == "Marketing_g2")
|
||||
assert mkt.user_emails == ["bob@contoso.com"]
|
||||
|
||||
|
||||
@patch(f"{MODULE}._graph_api_get")
|
||||
def test_enumerate_ad_groups_skips_already_resolved(mock_get: MagicMock) -> None:
|
||||
groups = [{"id": "g1", "displayName": "Engineering"}]
|
||||
mock_get.side_effect = _mock_graph_get_for_enumeration(groups, {})
|
||||
|
||||
results = list(
|
||||
_enumerate_ad_groups_paginated(
|
||||
_fake_token,
|
||||
already_resolved={"Engineering_g1"},
|
||||
graph_api_base=GRAPH_API_BASE,
|
||||
)
|
||||
)
|
||||
assert results == []
|
||||
|
||||
|
||||
@patch(f"{MODULE}._graph_api_get")
|
||||
def test_enumerate_ad_groups_circuit_breaker(mock_get: MagicMock) -> None:
|
||||
"""Enumeration stops after AD_GROUP_ENUMERATION_THRESHOLD groups."""
|
||||
over_limit = AD_GROUP_ENUMERATION_THRESHOLD + 5
|
||||
groups = [{"id": f"g{i}", "displayName": f"Group{i}"} for i in range(over_limit)]
|
||||
mock_get.side_effect = _mock_graph_get_for_enumeration(groups, {})
|
||||
|
||||
results = list(
|
||||
_enumerate_ad_groups_paginated(
|
||||
_fake_token, already_resolved=set(), graph_api_base=GRAPH_API_BASE
|
||||
)
|
||||
)
|
||||
assert len(results) <= AD_GROUP_ENUMERATION_THRESHOLD
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_sharepoint_external_groups
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _stub_role_assignment_resolution(
|
||||
groups_to_emails: dict[str, set[str]],
|
||||
) -> tuple[MagicMock, MagicMock]:
|
||||
"""Return (mock_sleep_and_retry, mock_recursive) pre-configured to
|
||||
simulate role-assignment group resolution."""
|
||||
mock_sleep = MagicMock()
|
||||
mock_recursive = MagicMock(
|
||||
return_value=GroupsResult(
|
||||
groups_to_emails=groups_to_emails,
|
||||
found_public_group=False,
|
||||
)
|
||||
)
|
||||
return mock_sleep, mock_recursive
|
||||
|
||||
|
||||
@patch(f"{MODULE}._get_groups_and_members_recursively")
|
||||
@patch(f"{MODULE}.sleep_and_retry")
|
||||
def test_default_skips_ad_enumeration(
|
||||
mock_sleep: MagicMock, mock_recursive: MagicMock # noqa: ARG001
|
||||
) -> None:
|
||||
mock_recursive.return_value = GroupsResult(
|
||||
groups_to_emails={"SiteGroup_abc": {"alice@contoso.com"}},
|
||||
found_public_group=False,
|
||||
)
|
||||
|
||||
results = get_sharepoint_external_groups(
|
||||
client_context=MagicMock(),
|
||||
graph_client=MagicMock(),
|
||||
graph_api_base=GRAPH_API_BASE,
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].id == "SiteGroup_abc"
|
||||
assert results[0].user_emails == ["alice@contoso.com"]
|
||||
|
||||
|
||||
@patch(f"{MODULE}._enumerate_ad_groups_paginated")
|
||||
@patch(f"{MODULE}._get_groups_and_members_recursively")
|
||||
@patch(f"{MODULE}.sleep_and_retry")
|
||||
def test_enumerate_all_includes_ad_groups(
|
||||
mock_sleep: MagicMock, # noqa: ARG001
|
||||
mock_recursive: MagicMock,
|
||||
mock_enum: MagicMock,
|
||||
) -> None:
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
|
||||
mock_recursive.return_value = GroupsResult(
|
||||
groups_to_emails={"SiteGroup_abc": {"alice@contoso.com"}},
|
||||
found_public_group=False,
|
||||
)
|
||||
mock_enum.return_value = [
|
||||
ExternalUserGroup(id="ADGroup_xyz", user_emails=["bob@contoso.com"]),
|
||||
]
|
||||
|
||||
results = get_sharepoint_external_groups(
|
||||
client_context=MagicMock(),
|
||||
graph_client=MagicMock(),
|
||||
get_access_token=_fake_token,
|
||||
enumerate_all_ad_groups=True,
|
||||
graph_api_base=GRAPH_API_BASE,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
ids = {r.id for r in results}
|
||||
assert ids == {"SiteGroup_abc", "ADGroup_xyz"}
|
||||
mock_enum.assert_called_once()
|
||||
|
||||
|
||||
@patch(f"{MODULE}._enumerate_ad_groups_paginated")
|
||||
@patch(f"{MODULE}._get_groups_and_members_recursively")
|
||||
@patch(f"{MODULE}.sleep_and_retry")
|
||||
def test_enumerate_all_without_token_skips(
|
||||
mock_sleep: MagicMock, # noqa: ARG001
|
||||
mock_recursive: MagicMock,
|
||||
mock_enum: MagicMock,
|
||||
) -> None:
|
||||
"""Even if enumerate_all_ad_groups=True, no token means skip."""
|
||||
mock_recursive.return_value = GroupsResult(
|
||||
groups_to_emails={},
|
||||
found_public_group=False,
|
||||
)
|
||||
|
||||
results = get_sharepoint_external_groups(
|
||||
client_context=MagicMock(),
|
||||
graph_client=MagicMock(),
|
||||
get_access_token=None,
|
||||
enumerate_all_ad_groups=True,
|
||||
graph_api_base=GRAPH_API_BASE,
|
||||
)
|
||||
|
||||
assert results == []
|
||||
mock_enum.assert_not_called()
|
||||
@@ -2,8 +2,10 @@
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onyx.chat.chat_utils import _build_tool_call_response_history_message
|
||||
from onyx.chat.chat_utils import get_custom_agent_prompt
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.prompts.chat_prompts import TOOL_CALL_RESPONSE_CROSS_MESSAGE
|
||||
|
||||
|
||||
class TestGetCustomAgentPrompt:
|
||||
@@ -150,3 +152,21 @@ class TestGetCustomAgentPrompt:
|
||||
|
||||
# Should return None because replace_base_system_prompt=True
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestBuildToolCallResponseHistoryMessage:
|
||||
def test_image_tool_uses_generated_images(self) -> None:
|
||||
message = _build_tool_call_response_history_message(
|
||||
tool_name="generate_image",
|
||||
generated_images=[{"file_id": "img-1", "revised_prompt": "p1"}],
|
||||
tool_call_response=None,
|
||||
)
|
||||
assert message == '[{"file_id": "img-1", "revised_prompt": "p1"}]'
|
||||
|
||||
def test_non_image_tool_uses_placeholder(self) -> None:
|
||||
message = _build_tool_call_response_history_message(
|
||||
tool_name="web_search",
|
||||
generated_images=None,
|
||||
tool_call_response='{"raw":"value"}',
|
||||
)
|
||||
assert message == TOOL_CALL_RESPONSE_CROSS_MESSAGE
|
||||
|
||||
@@ -0,0 +1,459 @@
|
||||
"""Tests for per-page delta checkpointing in the SharePoint connector (P1-1).
|
||||
|
||||
Validates that:
|
||||
- Delta drives process one page per _load_from_checkpoint call
|
||||
- Checkpoints persist the delta next_link for resumption
|
||||
- Crash + resume skips already-processed pages
|
||||
- BFS (folder-scoped) drives process all items in one call
|
||||
- 410 Gone triggers a full-resync URL in the checkpoint
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentSource
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.sharepoint.connector import DriveItemData
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnectorCheckpoint
|
||||
from onyx.connectors.sharepoint.connector import SiteDescriptor
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SITE_URL = "https://example.sharepoint.com/sites/sample"
|
||||
DRIVE_WEB_URL = f"{SITE_URL}/Shared Documents"
|
||||
DRIVE_ID = "fake-drive-id"
|
||||
|
||||
# Use a start time in the future so delta URLs include a timestamp token
|
||||
_START_TS = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
|
||||
_END_TS = datetime(2026, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
|
||||
# For BFS tests we use epoch so no token is generated
|
||||
_EPOCH_START: float = 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_item(item_id: str, name: str = "doc.pdf") -> DriveItemData:
|
||||
return DriveItemData(
|
||||
id=item_id,
|
||||
name=name,
|
||||
web_url=f"{SITE_URL}/{name}",
|
||||
parent_reference_path="/drives/d1/root:",
|
||||
drive_id=DRIVE_ID,
|
||||
)
|
||||
|
||||
|
||||
def _make_document(item: DriveItemData) -> Document:
|
||||
return Document(
|
||||
id=item.id,
|
||||
source=DocumentSource.SHAREPOINT,
|
||||
semantic_identifier=item.name,
|
||||
metadata={},
|
||||
sections=[TextSection(link=item.web_url, text="content")],
|
||||
)
|
||||
|
||||
|
||||
def _consume_generator(
|
||||
gen: Generator[Any, None, SharepointConnectorCheckpoint],
|
||||
) -> tuple[list[Any], SharepointConnectorCheckpoint]:
|
||||
"""Exhaust a _load_from_checkpoint generator.
|
||||
|
||||
Returns (yielded_items, returned_checkpoint).
|
||||
"""
|
||||
yielded: list[Any] = []
|
||||
try:
|
||||
while True:
|
||||
yielded.append(next(gen))
|
||||
except StopIteration as e:
|
||||
return yielded, e.value
|
||||
|
||||
|
||||
def _docs_from(yielded: list[Any]) -> list[Document]:
|
||||
return [y for y in yielded if isinstance(y, Document)]
|
||||
|
||||
|
||||
def _failures_from(yielded: list[Any]) -> list[ConnectorFailure]:
|
||||
return [y for y in yielded if isinstance(y, ConnectorFailure)]
|
||||
|
||||
|
||||
def _build_ready_checkpoint(
|
||||
drive_names: list[str] | None = None,
|
||||
folder_path: str | None = None,
|
||||
) -> SharepointConnectorCheckpoint:
|
||||
"""Checkpoint ready for Phase 3 (sites initialised, drives queued)."""
|
||||
cp = SharepointConnectorCheckpoint(has_more=True)
|
||||
cp.cached_site_descriptors = deque()
|
||||
cp.current_site_descriptor = SiteDescriptor(
|
||||
url=SITE_URL,
|
||||
drive_name=None,
|
||||
folder_path=folder_path,
|
||||
)
|
||||
cp.cached_drive_names = deque(drive_names or ["Documents"])
|
||||
cp.process_site_pages = False
|
||||
return cp
|
||||
|
||||
|
||||
def _setup_connector(monkeypatch: pytest.MonkeyPatch) -> SharepointConnector:
|
||||
"""Create a connector with common methods mocked."""
|
||||
connector = SharepointConnector()
|
||||
connector._graph_client = object()
|
||||
connector.include_site_pages = False
|
||||
|
||||
def fake_resolve_drive(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
site_descriptor: SiteDescriptor, # noqa: ARG001
|
||||
drive_name: str, # noqa: ARG001
|
||||
) -> tuple[str, str | None]:
|
||||
return (DRIVE_ID, DRIVE_WEB_URL)
|
||||
|
||||
def fake_get_access_token(self: SharepointConnector) -> str: # noqa: ARG001
|
||||
return "fake-access-token"
|
||||
|
||||
monkeypatch.setattr(SharepointConnector, "_resolve_drive", fake_resolve_drive)
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_get_graph_access_token", fake_get_access_token
|
||||
)
|
||||
|
||||
return connector
|
||||
|
||||
|
||||
def _mock_convert(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Replace _convert_driveitem_to_document_with_permissions with a trivial stub."""
|
||||
|
||||
def fake_convert(
|
||||
driveitem: DriveItemData,
|
||||
drive_name: str, # noqa: ARG001
|
||||
ctx: Any = None, # noqa: ARG001
|
||||
graph_client: Any = None, # noqa: ARG001
|
||||
graph_api_base: str = "", # noqa: ARG001
|
||||
include_permissions: bool = False, # noqa: ARG001
|
||||
parent_hierarchy_raw_node_id: str | None = None, # noqa: ARG001
|
||||
access_token: str | None = None, # noqa: ARG001
|
||||
) -> Document:
|
||||
return _make_document(driveitem)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.sharepoint.connector"
|
||||
"._convert_driveitem_to_document_with_permissions",
|
||||
fake_convert,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeltaPerPageCheckpointing:
|
||||
"""Delta (non-folder-scoped) drives should process one API page per
|
||||
_load_from_checkpoint call, persisting the next-link in between."""
|
||||
|
||||
def test_processes_one_page_per_cycle(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
connector = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
|
||||
items_p1 = [_make_item("a"), _make_item("b")]
|
||||
items_p2 = [_make_item("c")]
|
||||
items_p3 = [_make_item("d"), _make_item("e")]
|
||||
|
||||
call_count = 0
|
||||
|
||||
def fake_fetch_page(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
page_url: str, # noqa: ARG001
|
||||
drive_id: str, # noqa: ARG001
|
||||
start: datetime | None = None, # noqa: ARG001
|
||||
end: datetime | None = None, # noqa: ARG001
|
||||
page_size: int = 200, # noqa: ARG001
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return items_p1, "https://graph.microsoft.com/next2"
|
||||
if call_count == 2:
|
||||
return items_p2, "https://graph.microsoft.com/next3"
|
||||
return items_p3, None
|
||||
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
|
||||
)
|
||||
|
||||
checkpoint = _build_ready_checkpoint()
|
||||
|
||||
# Call 1: Phase 3a inits drive, Phase 3b processes page 1
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
yielded, checkpoint = _consume_generator(gen)
|
||||
assert len(_docs_from(yielded)) == 2
|
||||
assert (
|
||||
checkpoint.current_drive_delta_next_link
|
||||
== "https://graph.microsoft.com/next2"
|
||||
)
|
||||
assert checkpoint.current_drive_id == DRIVE_ID
|
||||
assert checkpoint.has_more is True
|
||||
|
||||
# Call 2: Phase 3b processes page 2
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
yielded, checkpoint = _consume_generator(gen)
|
||||
assert len(_docs_from(yielded)) == 1
|
||||
assert (
|
||||
checkpoint.current_drive_delta_next_link
|
||||
== "https://graph.microsoft.com/next3"
|
||||
)
|
||||
|
||||
# Call 3: Phase 3b processes page 3 (last)
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
yielded, checkpoint = _consume_generator(gen)
|
||||
assert len(_docs_from(yielded)) == 2
|
||||
assert checkpoint.current_drive_name is None
|
||||
assert checkpoint.current_drive_id is None
|
||||
assert checkpoint.current_drive_delta_next_link is None
|
||||
|
||||
def test_resume_after_simulated_crash(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Serialise the checkpoint after page 1, create a fresh connector,
|
||||
and verify page 2 is fetched using the saved next-link."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
|
||||
captured_urls: list[str] = []
|
||||
call_count = 0
|
||||
|
||||
def fake_fetch_page(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
page_url: str,
|
||||
drive_id: str, # noqa: ARG001
|
||||
start: datetime | None = None, # noqa: ARG001
|
||||
end: datetime | None = None, # noqa: ARG001
|
||||
page_size: int = 200, # noqa: ARG001
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
captured_urls.append(page_url)
|
||||
if call_count == 1:
|
||||
return [_make_item("a")], "https://graph.microsoft.com/next2"
|
||||
return [_make_item("b")], None
|
||||
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
|
||||
)
|
||||
|
||||
# Process page 1
|
||||
checkpoint = _build_ready_checkpoint()
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
_, checkpoint = _consume_generator(gen)
|
||||
assert (
|
||||
checkpoint.current_drive_delta_next_link
|
||||
== "https://graph.microsoft.com/next2"
|
||||
)
|
||||
|
||||
# --- Simulate crash: serialise & deserialise checkpoint ---
|
||||
saved_json = checkpoint.model_dump_json()
|
||||
restored = SharepointConnectorCheckpoint.model_validate_json(saved_json)
|
||||
|
||||
# New connector instance (as if process restarted)
|
||||
connector2 = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
|
||||
)
|
||||
|
||||
# Resume — should pick up from next2
|
||||
gen = connector2._load_from_checkpoint(
|
||||
_START_TS, _END_TS, restored, include_permissions=False
|
||||
)
|
||||
yielded, final_cp = _consume_generator(gen)
|
||||
|
||||
docs = _docs_from(yielded)
|
||||
assert len(docs) == 1
|
||||
assert docs[0].id == "b"
|
||||
assert captured_urls[-1] == "https://graph.microsoft.com/next2"
|
||||
assert final_cp.current_drive_name is None
|
||||
assert final_cp.current_drive_delta_next_link is None
|
||||
|
||||
def test_single_page_drive_completes_in_one_cycle(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""A drive with only one delta page should init + process + clear
|
||||
in a single _load_from_checkpoint call."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
|
||||
def fake_fetch_page(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
page_url: str, # noqa: ARG001
|
||||
drive_id: str, # noqa: ARG001
|
||||
start: datetime | None = None, # noqa: ARG001
|
||||
end: datetime | None = None, # noqa: ARG001
|
||||
page_size: int = 200, # noqa: ARG001
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
return [_make_item("only")], None
|
||||
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
|
||||
)
|
||||
|
||||
checkpoint = _build_ready_checkpoint()
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
yielded, final_cp = _consume_generator(gen)
|
||||
|
||||
assert len(_docs_from(yielded)) == 1
|
||||
assert final_cp.current_drive_name is None
|
||||
assert final_cp.current_drive_id is None
|
||||
assert final_cp.current_drive_delta_next_link is None
|
||||
|
||||
|
||||
class TestBfsPathNoCheckpointing:
|
||||
"""Folder-scoped (BFS) drives should process all items in one call
|
||||
because the BFS queue cannot be cheaply serialised."""
|
||||
|
||||
def test_bfs_processes_all_at_once(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
connector = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
|
||||
items = [_make_item("x"), _make_item("y"), _make_item("z")]
|
||||
|
||||
def fake_iter_paged(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
drive_id: str, # noqa: ARG001
|
||||
folder_path: str | None = None, # noqa: ARG001
|
||||
start: datetime | None = None, # noqa: ARG001
|
||||
end: datetime | None = None, # noqa: ARG001
|
||||
page_size: int = 200, # noqa: ARG001
|
||||
) -> Generator[DriveItemData, None, None]:
|
||||
yield from items
|
||||
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_iter_drive_items_paged", fake_iter_paged
|
||||
)
|
||||
|
||||
checkpoint = _build_ready_checkpoint(folder_path="Engineering/Docs")
|
||||
gen = connector._load_from_checkpoint(
|
||||
_EPOCH_START, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
yielded, final_cp = _consume_generator(gen)
|
||||
|
||||
assert len(_docs_from(yielded)) == 3
|
||||
assert final_cp.current_drive_name is None
|
||||
assert final_cp.current_drive_id is None
|
||||
assert final_cp.current_drive_delta_next_link is None
|
||||
|
||||
|
||||
class TestDelta410GoneResync:
|
||||
"""On 410 Gone the checkpoint should be updated with a full-resync URL
|
||||
and the next cycle should re-enumerate from scratch."""
|
||||
|
||||
def test_410_stores_full_resync_url(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
connector = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
|
||||
call_count = 0
|
||||
|
||||
def fake_fetch_page(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
page_url: str, # noqa: ARG001
|
||||
drive_id: str,
|
||||
start: datetime | None = None, # noqa: ARG001
|
||||
end: datetime | None = None, # noqa: ARG001
|
||||
page_size: int = 200,
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# Simulate the 410 handler returning a full-resync URL
|
||||
full_url = (
|
||||
f"https://graph.microsoft.com/v1.0/drives/{drive_id}"
|
||||
f"/root/delta?$top={page_size}"
|
||||
)
|
||||
return [], full_url
|
||||
return [_make_item("recovered")], None
|
||||
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
|
||||
)
|
||||
|
||||
checkpoint = _build_ready_checkpoint()
|
||||
|
||||
# Call 1: 3a inits, 3b gets empty page + resync URL
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
yielded, checkpoint = _consume_generator(gen)
|
||||
assert len(_docs_from(yielded)) == 0
|
||||
assert checkpoint.current_drive_delta_next_link is not None
|
||||
assert "token=" not in checkpoint.current_drive_delta_next_link
|
||||
|
||||
# Call 2: processes the full resync
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
yielded, checkpoint = _consume_generator(gen)
|
||||
docs = _docs_from(yielded)
|
||||
assert len(docs) == 1
|
||||
assert docs[0].id == "recovered"
|
||||
assert checkpoint.current_drive_name is None
|
||||
|
||||
|
||||
class TestDeltaPageFetchFailure:
|
||||
"""If _fetch_one_delta_page raises, the drive should be abandoned with a
|
||||
ConnectorFailure and the checkpoint should be cleared for the next drive."""
|
||||
|
||||
def test_page_fetch_error_yields_failure_and_clears_state(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
connector = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
|
||||
def fake_fetch_page(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
page_url: str, # noqa: ARG001
|
||||
drive_id: str, # noqa: ARG001
|
||||
start: datetime | None = None, # noqa: ARG001
|
||||
end: datetime | None = None, # noqa: ARG001
|
||||
page_size: int = 200, # noqa: ARG001
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
raise RuntimeError("network blip")
|
||||
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
|
||||
)
|
||||
|
||||
checkpoint = _build_ready_checkpoint()
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
yielded, final_cp = _consume_generator(gen)
|
||||
|
||||
failures = _failures_from(yielded)
|
||||
assert len(failures) == 1
|
||||
assert "network blip" in failures[0].failure_message
|
||||
assert final_cp.current_drive_name is None
|
||||
assert final_cp.current_drive_id is None
|
||||
assert final_cp.current_drive_delta_next_link is None
|
||||
@@ -192,20 +192,22 @@ def test_load_from_checkpoint_maps_drive_name(monkeypatch: pytest.MonkeyPatch) -
|
||||
"https://example.sharepoint.com/sites/sample/Documents",
|
||||
)
|
||||
|
||||
def fake_get_drive_items(
|
||||
def fake_fetch_one_delta_page(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
site_descriptor: SiteDescriptor, # noqa: ARG001
|
||||
page_url: str, # noqa: ARG001
|
||||
drive_id: str, # noqa: ARG001
|
||||
start: datetime | None, # noqa: ARG001
|
||||
end: datetime | None, # noqa: ARG001
|
||||
) -> Generator[DriveItemData, None, None]:
|
||||
yield sample_item
|
||||
start: datetime | None = None, # noqa: ARG001
|
||||
end: datetime | None = None, # noqa: ARG001
|
||||
page_size: int = 200, # noqa: ARG001
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
return [sample_item], None
|
||||
|
||||
def fake_convert(
|
||||
driveitem: DriveItemData, # noqa: ARG001
|
||||
drive_name: str,
|
||||
ctx: Any, # noqa: ARG001
|
||||
graph_client: Any, # noqa: ARG001
|
||||
graph_api_base: str, # noqa: ARG001
|
||||
include_permissions: bool, # noqa: ARG001
|
||||
parent_hierarchy_raw_node_id: str | None = None, # noqa: ARG001
|
||||
access_token: str | None = None, # noqa: ARG001
|
||||
@@ -229,8 +231,8 @@ def test_load_from_checkpoint_maps_drive_name(monkeypatch: pytest.MonkeyPatch) -
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector,
|
||||
"_get_drive_items_for_drive_id",
|
||||
fake_get_drive_items,
|
||||
"_fetch_one_delta_page",
|
||||
fake_fetch_one_delta_page,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.sharepoint.connector._convert_driveitem_to_document_with_permissions",
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_runner import _extract_image_file_ids_from_tool_response_message
|
||||
from onyx.tools.tool_runner import _extract_recent_generated_image_file_ids
|
||||
from onyx.tools.tool_runner import _merge_tool_calls
|
||||
|
||||
|
||||
@@ -307,3 +312,65 @@ class TestMergeToolCalls:
|
||||
assert len(result) == 1
|
||||
# String should be converted to list item
|
||||
assert result[0].tool_args["queries"] == ["single_query", "q2"]
|
||||
|
||||
|
||||
class TestImageHistoryExtraction:
|
||||
def test_extracts_image_file_ids_from_json_response(self) -> None:
|
||||
msg = (
|
||||
'[{"file_id":"img-1","revised_prompt":"v1"},'
|
||||
'{"file_id":"img-2","revised_prompt":"v2"}]'
|
||||
)
|
||||
assert _extract_image_file_ids_from_tool_response_message(msg) == [
|
||||
"img-1",
|
||||
"img-2",
|
||||
]
|
||||
|
||||
def test_extracts_recent_generated_image_ids_from_history(self) -> None:
|
||||
history = [
|
||||
ChatMessageSimple(
|
||||
message="",
|
||||
token_count=1,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
tool_calls=[
|
||||
ToolCallSimple(
|
||||
tool_call_id="call_1",
|
||||
tool_name="generate_image",
|
||||
tool_arguments={"prompt": "test"},
|
||||
token_count=1,
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatMessageSimple(
|
||||
message='[{"file_id":"img-1","revised_prompt":"r1"}]',
|
||||
token_count=1,
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id="call_1",
|
||||
),
|
||||
]
|
||||
|
||||
assert _extract_recent_generated_image_file_ids(history) == ["img-1"]
|
||||
|
||||
def test_ignores_non_image_tool_responses(self) -> None:
|
||||
history = [
|
||||
ChatMessageSimple(
|
||||
message="",
|
||||
token_count=1,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
tool_calls=[
|
||||
ToolCallSimple(
|
||||
tool_call_id="call_1",
|
||||
tool_name="web_search",
|
||||
tool_arguments={"queries": ["q"]},
|
||||
token_count=1,
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatMessageSimple(
|
||||
message='[{"file_id":"img-1","revised_prompt":"r1"}]',
|
||||
token_count=1,
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id="call_1",
|
||||
),
|
||||
]
|
||||
|
||||
assert _extract_recent_generated_image_file_ids(history) == []
|
||||
|
||||
@@ -10,7 +10,7 @@ from onyx.tools.utils import explicit_tool_calling_supported
|
||||
(LlmProviderNames.ANTHROPIC, "claude-4-sonnet-20250514", True),
|
||||
(
|
||||
"another-provider",
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-haiku-4-5-20251001",
|
||||
True,
|
||||
),
|
||||
(
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
import onyx.tools.tool_implementations.open_url.onyx_web_crawler as crawler_module
|
||||
from onyx.tools.tool_implementations.open_url.onyx_web_crawler import (
|
||||
DEFAULT_CONNECT_TIMEOUT_SECONDS,
|
||||
)
|
||||
from onyx.tools.tool_implementations.open_url.onyx_web_crawler import (
|
||||
DEFAULT_READ_TIMEOUT_SECONDS,
|
||||
)
|
||||
from onyx.tools.tool_implementations.open_url.onyx_web_crawler import OnyxWebCrawler
|
||||
|
||||
|
||||
@@ -181,3 +191,163 @@ def test_fetch_url_html_within_size_limit(monkeypatch: pytest.MonkeyPatch) -> No
|
||||
|
||||
assert "hello world" in result.full_content
|
||||
assert result.scrape_successful is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers for parallel / failure-isolation / timeout tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_response(
|
||||
*,
|
||||
status_code: int = 200,
|
||||
content: bytes = b"<html><body>Hello</body></html>",
|
||||
content_type: str = "text/html",
|
||||
delay: float = 0.0,
|
||||
) -> MagicMock:
|
||||
"""Create a mock response that behaves like a requests.Response."""
|
||||
resp = MagicMock()
|
||||
resp.status_code = status_code
|
||||
resp.headers = {"Content-Type": content_type}
|
||||
|
||||
if delay:
|
||||
original_content = content
|
||||
|
||||
@property # type: ignore[misc]
|
||||
def _delayed_content(_self: object) -> bytes:
|
||||
time.sleep(delay)
|
||||
return original_content
|
||||
|
||||
type(resp).content = _delayed_content
|
||||
else:
|
||||
resp.content = content
|
||||
|
||||
resp.apparent_encoding = None
|
||||
resp.encoding = None
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
class TestParallelExecution:
|
||||
"""Verify that contents() fetches URLs in parallel."""
|
||||
|
||||
@patch("onyx.tools.tool_implementations.open_url.onyx_web_crawler.ssrf_safe_get")
|
||||
def test_multiple_urls_fetched_concurrently(self, mock_get: MagicMock) -> None:
|
||||
"""With a per-URL delay, parallel execution should be much faster than sequential."""
|
||||
per_url_delay = 0.3
|
||||
num_urls = 5
|
||||
urls = [f"http://example.com/page{i}" for i in range(num_urls)]
|
||||
|
||||
mock_get.return_value = _make_mock_response(delay=per_url_delay)
|
||||
|
||||
crawler = OnyxWebCrawler()
|
||||
start = time.monotonic()
|
||||
results = crawler.contents(urls)
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
# Sequential would take ~1.5s; parallel should be well under that
|
||||
assert elapsed < per_url_delay * num_urls * 0.7
|
||||
assert len(results) == num_urls
|
||||
assert all(r.scrape_successful for r in results)
|
||||
|
||||
@patch("onyx.tools.tool_implementations.open_url.onyx_web_crawler.ssrf_safe_get")
|
||||
def test_empty_urls_returns_empty(self, mock_get: MagicMock) -> None:
|
||||
crawler = OnyxWebCrawler()
|
||||
results = crawler.contents([])
|
||||
assert results == []
|
||||
mock_get.assert_not_called()
|
||||
|
||||
@patch("onyx.tools.tool_implementations.open_url.onyx_web_crawler.ssrf_safe_get")
|
||||
def test_single_url(self, mock_get: MagicMock) -> None:
|
||||
mock_get.return_value = _make_mock_response()
|
||||
crawler = OnyxWebCrawler()
|
||||
results = crawler.contents(["http://example.com"])
|
||||
assert len(results) == 1
|
||||
assert results[0].scrape_successful
|
||||
|
||||
|
||||
class TestFailureIsolation:
|
||||
"""Verify that one URL failure doesn't affect others in the batch."""
|
||||
|
||||
@patch("onyx.tools.tool_implementations.open_url.onyx_web_crawler.ssrf_safe_get")
|
||||
def test_one_failure_doesnt_kill_batch(self, mock_get: MagicMock) -> None:
|
||||
good_resp = _make_mock_response()
|
||||
bad_resp = _make_mock_response(status_code=500)
|
||||
|
||||
# First and third URLs succeed, second fails
|
||||
mock_get.side_effect = [good_resp, bad_resp, good_resp]
|
||||
|
||||
crawler = OnyxWebCrawler()
|
||||
results = crawler.contents(["http://a.com", "http://b.com", "http://c.com"])
|
||||
|
||||
assert len(results) == 3
|
||||
assert results[0].scrape_successful
|
||||
assert not results[1].scrape_successful
|
||||
assert results[2].scrape_successful
|
||||
|
||||
@patch("onyx.tools.tool_implementations.open_url.onyx_web_crawler.ssrf_safe_get")
|
||||
def test_exception_doesnt_kill_batch(self, mock_get: MagicMock) -> None:
|
||||
good_resp = _make_mock_response()
|
||||
|
||||
# Second URL raises an exception
|
||||
mock_get.side_effect = [
|
||||
good_resp,
|
||||
RuntimeError("connection reset"),
|
||||
_make_mock_response(),
|
||||
]
|
||||
|
||||
crawler = OnyxWebCrawler()
|
||||
results = crawler.contents(["http://a.com", "http://b.com", "http://c.com"])
|
||||
|
||||
assert len(results) == 3
|
||||
assert results[0].scrape_successful
|
||||
assert not results[1].scrape_successful
|
||||
assert results[2].scrape_successful
|
||||
|
||||
@patch("onyx.tools.tool_implementations.open_url.onyx_web_crawler.ssrf_safe_get")
|
||||
def test_ssrf_exception_doesnt_kill_batch(self, mock_get: MagicMock) -> None:
|
||||
from onyx.utils.url import SSRFException
|
||||
|
||||
good_resp = _make_mock_response()
|
||||
mock_get.side_effect = [
|
||||
good_resp,
|
||||
SSRFException("blocked"),
|
||||
_make_mock_response(),
|
||||
]
|
||||
|
||||
crawler = OnyxWebCrawler()
|
||||
results = crawler.contents(
|
||||
["http://a.com", "http://internal.local", "http://c.com"]
|
||||
)
|
||||
|
||||
assert len(results) == 3
|
||||
assert results[0].scrape_successful
|
||||
assert not results[1].scrape_successful
|
||||
assert results[2].scrape_successful
|
||||
|
||||
|
||||
class TestTupleTimeout:
|
||||
"""Verify that separate connect and read timeouts are passed correctly."""
|
||||
|
||||
@patch("onyx.tools.tool_implementations.open_url.onyx_web_crawler.ssrf_safe_get")
|
||||
def test_default_tuple_timeout(self, mock_get: MagicMock) -> None:
|
||||
mock_get.return_value = _make_mock_response()
|
||||
|
||||
crawler = OnyxWebCrawler()
|
||||
crawler.contents(["http://example.com"])
|
||||
|
||||
call_kwargs = mock_get.call_args
|
||||
assert call_kwargs.kwargs["timeout"] == (
|
||||
DEFAULT_CONNECT_TIMEOUT_SECONDS,
|
||||
DEFAULT_READ_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
@patch("onyx.tools.tool_implementations.open_url.onyx_web_crawler.ssrf_safe_get")
|
||||
def test_custom_tuple_timeout(self, mock_get: MagicMock) -> None:
|
||||
mock_get.return_value = _make_mock_response()
|
||||
|
||||
crawler = OnyxWebCrawler(timeout_seconds=30, connect_timeout_seconds=3)
|
||||
crawler.contents(["http://example.com"])
|
||||
|
||||
call_kwargs = mock_get.call_args
|
||||
assert call_kwargs.kwargs["timeout"] == (3, 30)
|
||||
|
||||
@@ -291,7 +291,7 @@ class TestSsrfSafeGet:
|
||||
assert call_args[1]["headers"]["User-Agent"] == "TestBot/1.0"
|
||||
|
||||
def test_passes_timeout(self) -> None:
|
||||
"""Test that timeout is passed through."""
|
||||
"""Test that timeout is passed through, including tuple form."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_redirect = False
|
||||
|
||||
@@ -301,7 +301,7 @@ class TestSsrfSafeGet:
|
||||
with patch("onyx.utils.url.requests.get") as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
ssrf_safe_get("http://example.com/", timeout=30)
|
||||
ssrf_safe_get("http://example.com/", timeout=(5, 15))
|
||||
|
||||
call_args = mock_get.call_args
|
||||
assert call_args[1]["timeout"] == 30
|
||||
assert call_args[1]["timeout"] == (5, 15)
|
||||
|
||||
@@ -147,7 +147,7 @@ Add clear comments:
|
||||
|
||||
## Trunk-based development and feature flags
|
||||
|
||||
- **PRs should contain no more than 500 lines of real change**
|
||||
- **PRs should contain no more than 500 lines of real change.**
|
||||
- **Merge to main frequently.** Avoid long-lived feature branches—they create merge conflicts and integration pain.
|
||||
- **Use feature flags for incremental rollout.**
|
||||
- Large features should be merged in small, shippable increments behind a flag.
|
||||
@@ -155,3 +155,11 @@ Add clear comments:
|
||||
- **Keep flags short-lived.** Once a feature is fully rolled out, remove the flag and dead code paths promptly.
|
||||
- **Flag at the right level.** Prefer flagging at API/UI entry points rather than deep in business logic.
|
||||
- **Test both flag states.** Ensure the codebase works correctly with the flag on and off.
|
||||
|
||||
---
|
||||
|
||||
## Misc
|
||||
|
||||
- Any TODOs you add in the code must be accompanied by either the name/username
|
||||
of the owner of that TODO, or an issue number for an issue referencing that
|
||||
piece of work.
|
||||
|
||||
3
ct.yaml
3
ct.yaml
@@ -12,7 +12,8 @@ chart-repos:
|
||||
- postgresql=https://cloudnative-pg.github.io/charts
|
||||
- redis=https://ot-container-kit.github.io/helm-charts
|
||||
- minio=https://charts.min.io/
|
||||
|
||||
- code-interpreter=https://onyx-dot-app.github.io/python-sandbox/
|
||||
|
||||
# have seen postgres take 10 min to pull ... so 15 min seems like a good timeout?
|
||||
helm-extra-args: --debug --timeout 900s
|
||||
|
||||
|
||||
@@ -4,24 +4,6 @@ log_format custom_main '$remote_addr - $remote_user [$time_local] "$request" '
|
||||
'"$http_user_agent" "$http_x_forwarded_for" '
|
||||
'rt=$request_time';
|
||||
|
||||
# Map X-Forwarded-Proto or fallback to $scheme
|
||||
map $http_x_forwarded_proto $forwarded_proto {
|
||||
default $http_x_forwarded_proto;
|
||||
"" $scheme;
|
||||
}
|
||||
|
||||
# Map X-Forwarded-Host or fallback to $host
|
||||
map $http_x_forwarded_host $forwarded_host {
|
||||
default $http_x_forwarded_host;
|
||||
"" $host;
|
||||
}
|
||||
|
||||
# Map X-Forwarded-Port or fallback to server port
|
||||
map $http_x_forwarded_port $forwarded_port {
|
||||
default $http_x_forwarded_port;
|
||||
"" $server_port;
|
||||
}
|
||||
|
||||
upstream api_server {
|
||||
# fail_timeout=0 means we always retry an upstream even if it failed
|
||||
# to return a good HTTP response
|
||||
@@ -59,9 +41,10 @@ server {
|
||||
# misc headers
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $forwarded_proto;
|
||||
proxy_set_header X-Forwarded-Host $forwarded_host;
|
||||
proxy_set_header X-Forwarded-Port $forwarded_port;
|
||||
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
|
||||
# need to use 1.1 to support chunked transfers
|
||||
@@ -78,9 +61,10 @@ server {
|
||||
# misc headers
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $forwarded_proto;
|
||||
proxy_set_header X-Forwarded-Host $forwarded_host;
|
||||
proxy_set_header X-Forwarded-Port $forwarded_port;
|
||||
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
|
||||
proxy_http_version 1.1;
|
||||
|
||||
@@ -4,24 +4,6 @@ log_format custom_main '$remote_addr - $remote_user [$time_local] "$request" '
|
||||
'"$http_user_agent" "$http_x_forwarded_for" '
|
||||
'rt=$request_time';
|
||||
|
||||
# Map X-Forwarded-Proto or fallback to $scheme
|
||||
map $http_x_forwarded_proto $forwarded_proto {
|
||||
default $http_x_forwarded_proto;
|
||||
"" $scheme;
|
||||
}
|
||||
|
||||
# Map X-Forwarded-Host or fallback to $host
|
||||
map $http_x_forwarded_host $forwarded_host {
|
||||
default $http_x_forwarded_host;
|
||||
"" $host;
|
||||
}
|
||||
|
||||
# Map X-Forwarded-Port or fallback to server port
|
||||
map $http_x_forwarded_port $forwarded_port {
|
||||
default $http_x_forwarded_port;
|
||||
"" $server_port;
|
||||
}
|
||||
|
||||
upstream api_server {
|
||||
# fail_timeout=0 means we always retry an upstream even if it failed
|
||||
# to return a good HTTP response
|
||||
@@ -59,9 +41,10 @@ server {
|
||||
# misc headers
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $forwarded_proto;
|
||||
proxy_set_header X-Forwarded-Host $forwarded_host;
|
||||
proxy_set_header X-Forwarded-Port $forwarded_port;
|
||||
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
|
||||
# need to use 1.1 to support chunked transfers
|
||||
@@ -83,9 +66,10 @@ server {
|
||||
# misc headers
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $forwarded_proto;
|
||||
proxy_set_header X-Forwarded-Host $forwarded_host;
|
||||
proxy_set_header X-Forwarded-Port $forwarded_port;
|
||||
# don't trust client-supplied X-Forwarded-* headers — use nginx's own values
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header X-Forwarded-Host $host;
|
||||
proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_set_header Host $host;
|
||||
|
||||
proxy_http_version 1.1;
|
||||
|
||||
@@ -35,6 +35,10 @@ services:
|
||||
retries: 3
|
||||
start_period: 30s
|
||||
|
||||
opensearch:
|
||||
ports:
|
||||
- "9200:9200"
|
||||
|
||||
inference_model_server:
|
||||
ports:
|
||||
- "9000:9000"
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
services:
|
||||
opensearch:
|
||||
image: opensearchproject/opensearch:3.4.0
|
||||
environment:
|
||||
# We need discovery.type=single-node so that OpenSearch doesn't try
|
||||
# forming a cluster and waiting for other nodes to become live.
|
||||
- discovery.type=single-node
|
||||
- OPENSEARCH_INITIAL_ADMIN_PASSWORD=${OPENSEARCH_ADMIN_PASSWORD:?OPENSEARCH_ADMIN_PASSWORD must be set}
|
||||
ports:
|
||||
- "9200:9200"
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "curl -skf https://localhost:9200 -u admin:${OPENSEARCH_ADMIN_PASSWORD:?OPENSEARCH_ADMIN_PASSWORD must be set}"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 30s
|
||||
@@ -76,6 +76,9 @@ services:
|
||||
- FILE_STORE_BACKEND=${FILE_STORE_BACKEND:-s3}
|
||||
- POSTGRES_HOST=${POSTGRES_HOST:-relational_db}
|
||||
- VESPA_HOST=${VESPA_HOST:-index}
|
||||
- OPENSEARCH_HOST=${OPENSEARCH_HOST:-opensearch}
|
||||
- OPENSEARCH_ADMIN_PASSWORD=${OPENSEARCH_ADMIN_PASSWORD:-StrongPassword123!}
|
||||
- ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=${OPENSEARCH_FOR_ONYX_ENABLED:-false}
|
||||
- REDIS_HOST=${REDIS_HOST:-cache}
|
||||
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
|
||||
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
|
||||
@@ -147,6 +150,9 @@ services:
|
||||
- FILE_STORE_BACKEND=${FILE_STORE_BACKEND:-s3}
|
||||
- POSTGRES_HOST=${POSTGRES_HOST:-relational_db}
|
||||
- VESPA_HOST=${VESPA_HOST:-index}
|
||||
- OPENSEARCH_HOST=${OPENSEARCH_HOST:-opensearch}
|
||||
- OPENSEARCH_ADMIN_PASSWORD=${OPENSEARCH_ADMIN_PASSWORD:-StrongPassword123!}
|
||||
- ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=${OPENSEARCH_FOR_ONYX_ENABLED:-false}
|
||||
- REDIS_HOST=${REDIS_HOST:-cache}
|
||||
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
|
||||
- INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
|
||||
@@ -395,6 +401,45 @@ services:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
|
||||
opensearch:
|
||||
image: opensearchproject/opensearch:3.4.0
|
||||
restart: unless-stopped
|
||||
# Controls whether this service runs. In order to enable it, add
|
||||
# opensearch-enabled to COMPOSE_PROFILES in the environment for this
|
||||
# docker-compose.
|
||||
profiles: ["opensearch-enabled"]
|
||||
environment:
|
||||
# We need discovery.type=single-node so that OpenSearch doesn't try
|
||||
# forming a cluster and waiting for other nodes to become live.
|
||||
- discovery.type=single-node
|
||||
- OPENSEARCH_INITIAL_ADMIN_PASSWORD=${OPENSEARCH_ADMIN_PASSWORD:-StrongPassword123!}
|
||||
# This and the JVM config below come from the example in https://docs.opensearch.org/latest/install-and-configure/install-opensearch/docker/
|
||||
# We do this to avoid unstable performance from page swaps.
|
||||
- bootstrap.memory_lock=true # Disable JVM heap memory swapping.
|
||||
# Java heap should be ~50% of memory limit. For now we assume a limit of
|
||||
# 2g although in practice the container can request more than this.
|
||||
# See https://opster.com/guides/opensearch/opensearch-basics/opensearch-heap-size-usage-and-jvm-garbage-collection/
|
||||
# Xms is the starting size, Xmx is the maximum size. These should be the
|
||||
# same.
|
||||
- "OPENSEARCH_JAVA_OPTS=-Xms1g -Xmx1g"
|
||||
volumes:
|
||||
- opensearch-data:/usr/share/opensearch/data
|
||||
# These come from the example in https://docs.opensearch.org/latest/install-and-configure/install-opensearch/docker/
|
||||
ulimits:
|
||||
# Similarly to bootstrap.memory_lock, we don't want to impose limits on
|
||||
# how much memory a process can lock from being swapped.
|
||||
memlock:
|
||||
soft: -1 # Set memlock to unlimited (no soft or hard limit).
|
||||
hard: -1
|
||||
nofile:
|
||||
soft: 65536 # Maximum number of open files for the opensearch user - set to at least 65536.
|
||||
hard: 65536
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
|
||||
nginx:
|
||||
image: nginx:1.25.5-alpine
|
||||
restart: unless-stopped
|
||||
@@ -535,3 +580,5 @@ volumes:
|
||||
indexing_model_server_logs:
|
||||
# Shared volume for persistent document storage (Craft file-system mode)
|
||||
file-system:
|
||||
# Persistent data for OpenSearch.
|
||||
opensearch-data:
|
||||
|
||||
@@ -64,9 +64,13 @@ POSTGRES_PASSWORD=password
|
||||
|
||||
## File Store Backend: "s3" (default, uses MinIO) or "postgres" (no extra services needed)
|
||||
## COMPOSE_PROFILES activates the MinIO service. To use PostgreSQL file storage instead,
|
||||
## comment out COMPOSE_PROFILES and set FILE_STORE_BACKEND=postgres.
|
||||
## remove s3-filestore from COMPOSE_PROFILES and set FILE_STORE_BACKEND=postgres.
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
FILE_STORE_BACKEND=s3
|
||||
## Settings for enabling OpenSearch. Uncomment these and comment out
|
||||
## COMPOSE_PROFILES above.
|
||||
# COMPOSE_PROFILES=s3-filestore,opensearch-enabled
|
||||
# OPENSEARCH_FOR_ONYX_ENABLED=true
|
||||
|
||||
## MinIO/S3 Configuration (only needed when FILE_STORE_BACKEND=s3)
|
||||
S3_ENDPOINT_URL=http://minio:9000
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user