mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-25 01:22:45 +00:00
Compare commits
5 Commits
v3.0.3
...
auth-clean
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d0c9d36692 | ||
|
|
e03bf2a6a3 | ||
|
|
7c8c7c9d91 | ||
|
|
89d8521f37 | ||
|
|
24a0e08ee2 |
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import jwt
|
||||
@@ -20,7 +21,13 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
# All the Auth flows are valid for EE version
|
||||
# All the Auth flows are valid for EE version, but warn about deprecated 'disabled'
|
||||
raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
|
||||
if raw_auth_type == "disabled":
|
||||
logger.warning(
|
||||
"AUTH_TYPE='disabled' is no longer supported. "
|
||||
"Using 'basic' instead. Please update your configuration."
|
||||
)
|
||||
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import secrets
|
||||
import string
|
||||
@@ -143,10 +144,22 @@ def is_user_admin(user: User) -> bool:
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
"""Log warnings for AUTH_TYPE issues.
|
||||
|
||||
This only runs on app startup not during migrations/scripts.
|
||||
"""
|
||||
raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
|
||||
|
||||
if raw_auth_type == "cloud":
|
||||
raise ValueError(
|
||||
f"{AUTH_TYPE.value} is not a valid auth type for self-hosted deployments."
|
||||
"'cloud' is not a valid auth type for self-hosted deployments."
|
||||
)
|
||||
if raw_auth_type == "disabled":
|
||||
logger.warning(
|
||||
"AUTH_TYPE='disabled' is no longer supported. "
|
||||
"Using 'basic' instead. Please update your configuration."
|
||||
)
|
||||
|
||||
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
|
||||
@@ -85,19 +85,12 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
|
||||
#####
|
||||
# Auth Configs
|
||||
#####
|
||||
# Upgrades users from disabled auth to basic auth and shows warning.
|
||||
_auth_type_str = (os.environ.get("AUTH_TYPE") or "basic").lower()
|
||||
if _auth_type_str == "disabled":
|
||||
logger.warning(
|
||||
"AUTH_TYPE='disabled' is no longer supported. "
|
||||
"Defaulting to 'basic'. Please update your configuration. "
|
||||
"Your existing data will be migrated automatically."
|
||||
)
|
||||
_auth_type_str = AuthType.BASIC.value
|
||||
try:
|
||||
# Silently default to basic - warnings/errors logged in verify_auth_setting()
|
||||
# which only runs on app startup, not during migrations/scripts
|
||||
_auth_type_str = (os.environ.get("AUTH_TYPE") or "").lower()
|
||||
if _auth_type_str in [auth_type.value for auth_type in AuthType]:
|
||||
AUTH_TYPE = AuthType(_auth_type_str)
|
||||
except ValueError:
|
||||
logger.error(f"Invalid AUTH_TYPE: {_auth_type_str}. Defaulting to 'basic'.")
|
||||
else:
|
||||
AUTH_TYPE = AuthType.BASIC
|
||||
|
||||
PASSWORD_MIN_LENGTH = int(os.getenv("PASSWORD_MIN_LENGTH", 8))
|
||||
|
||||
@@ -1831,7 +1831,7 @@ def get_connector_by_id(
|
||||
@router.post("/connector-request")
|
||||
def submit_connector_request(
|
||||
request_data: ConnectorRequestSubmission,
|
||||
user: User | None = Depends(current_user),
|
||||
user: User = Depends(current_user),
|
||||
) -> StatusResponse:
|
||||
"""
|
||||
Submit a connector request for Cloud deployments.
|
||||
@@ -1844,7 +1844,7 @@ def submit_connector_request(
|
||||
raise HTTPException(status_code=400, detail="Connector name cannot be empty")
|
||||
|
||||
# Get user identifier for telemetry
|
||||
user_email = user.email if user else None
|
||||
user_email = user.email
|
||||
distinct_id = user_email or tenant_id
|
||||
|
||||
# Track connector request via PostHog telemetry (Cloud only)
|
||||
|
||||
@@ -57,9 +57,6 @@ def list_messages(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> MessageListResponse:
|
||||
"""Get all messages for a build session."""
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
messages = session_manager.list_messages(session_id, user.id)
|
||||
|
||||
@@ -43,18 +43,14 @@ def _require_opensearch(db_session: Session) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _get_user_access_info(
|
||||
user: User | None, db_session: Session
|
||||
) -> tuple[str | None, list[str]]:
|
||||
if not user:
|
||||
return None, []
|
||||
def _get_user_access_info(user: User, db_session: Session) -> tuple[str, list[str]]:
|
||||
return user.email, get_user_external_group_ids(db_session, user)
|
||||
|
||||
|
||||
@router.get(HIERARCHY_NODES_LIST_PATH)
|
||||
def list_accessible_hierarchy_nodes(
|
||||
source: DocumentSource,
|
||||
user: User | None = Depends(current_user),
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HierarchyNodesResponse:
|
||||
_require_opensearch(db_session)
|
||||
@@ -81,7 +77,7 @@ def list_accessible_hierarchy_nodes(
|
||||
@router.post(HIERARCHY_NODE_DOCUMENTS_PATH)
|
||||
def list_accessible_hierarchy_node_documents(
|
||||
documents_request: HierarchyNodeDocumentsRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HierarchyNodeDocumentsResponse:
|
||||
_require_opensearch(db_session)
|
||||
|
||||
@@ -1013,7 +1013,7 @@ def get_mcp_servers_for_assistant(
|
||||
@router.get("/servers", response_model=MCPServersResponse)
|
||||
def get_mcp_servers_for_user(
|
||||
db: Session = Depends(get_session),
|
||||
user: User | None = Depends(current_user),
|
||||
user: User = Depends(current_user),
|
||||
) -> MCPServersResponse:
|
||||
"""List all MCP servers for use in agent configuration and chat UI.
|
||||
|
||||
|
||||
@@ -13,9 +13,9 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
class APIKeyManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
api_key_role: UserRole = UserRole.ADMIN,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestAPIKey:
|
||||
name = f"{name}-api-key" if name else f"test-api-key-{uuid4()}"
|
||||
api_key_request = APIKeyArgs(
|
||||
@@ -25,11 +25,7 @@ class APIKeyManager:
|
||||
api_key_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/api-key",
|
||||
json=api_key_request.model_dump(),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
api_key_response.raise_for_status()
|
||||
api_key = api_key_response.json()
|
||||
@@ -48,29 +44,21 @@ class APIKeyManager:
|
||||
@staticmethod
|
||||
def delete(
|
||||
api_key: DATestAPIKey,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
api_key_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/api-key/{api_key.api_key_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
api_key_response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[DATestAPIKey]:
|
||||
api_key_response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/api-key",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
api_key_response.raise_for_status()
|
||||
return [DATestAPIKey(**api_key) for api_key in api_key_response.json()]
|
||||
@@ -78,8 +66,8 @@ class APIKeyManager:
|
||||
@staticmethod
|
||||
def verify(
|
||||
api_key: DATestAPIKey,
|
||||
user_performing_action: DATestUser,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
retrieved_keys = APIKeyManager.get_all(
|
||||
user_performing_action=user_performing_action
|
||||
|
||||
@@ -17,7 +17,6 @@ from onyx.server.documents.models import DocumentSource
|
||||
from onyx.server.documents.models import DocumentSyncStatus
|
||||
from tests.integration.common_utils.config import api_config
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
@@ -28,10 +27,10 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
def _cc_pair_creator(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
access_type: AccessType = AccessType.PUBLIC,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestCCPair:
|
||||
name = f"{name}-cc-pair" if name else f"test-cc-pair-{uuid4()}"
|
||||
|
||||
@@ -40,17 +39,12 @@ def _cc_pair_creator(
|
||||
connector_credential_pair_metadata = api.ConnectorCredentialPairMetadata(
|
||||
name=name, access_type=access_type, groups=groups or []
|
||||
)
|
||||
headers = (
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
)
|
||||
api_response: api.StatusResponseInt = (
|
||||
api_instance.associate_credential_to_connector(
|
||||
connector_id,
|
||||
credential_id,
|
||||
connector_credential_pair_metadata,
|
||||
_headers=headers,
|
||||
_headers=user_performing_action.headers,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -67,6 +61,7 @@ def _cc_pair_creator(
|
||||
class CCPairManager:
|
||||
@staticmethod
|
||||
def create_from_scratch(
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
access_type: AccessType = AccessType.PUBLIC,
|
||||
groups: list[int] | None = None,
|
||||
@@ -74,26 +69,25 @@ class CCPairManager:
|
||||
input_type: InputType = InputType.LOAD_STATE,
|
||||
connector_specific_config: dict[str, Any] | None = None,
|
||||
credential_json: dict[str, Any] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
refresh_freq: int | None = None,
|
||||
) -> DATestCCPair:
|
||||
connector = ConnectorManager.create(
|
||||
user_performing_action=user_performing_action,
|
||||
name=name,
|
||||
source=source,
|
||||
input_type=input_type,
|
||||
connector_specific_config=connector_specific_config,
|
||||
access_type=access_type,
|
||||
groups=groups,
|
||||
user_performing_action=user_performing_action,
|
||||
refresh_freq=refresh_freq,
|
||||
)
|
||||
credential = CredentialManager.create(
|
||||
user_performing_action=user_performing_action,
|
||||
credential_json=credential_json,
|
||||
name=name,
|
||||
source=source,
|
||||
curator_public=(access_type == AccessType.PUBLIC),
|
||||
groups=groups,
|
||||
user_performing_action=user_performing_action,
|
||||
)
|
||||
cc_pair = _cc_pair_creator(
|
||||
connector_id=connector.id,
|
||||
@@ -109,10 +103,10 @@ class CCPairManager:
|
||||
def create(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
access_type: AccessType = AccessType.PUBLIC,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestCCPair:
|
||||
cc_pair = _cc_pair_creator(
|
||||
connector_id=connector_id,
|
||||
@@ -127,39 +121,31 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def pause_cc_pair(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
result = requests.put(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status",
|
||||
json={"status": "PAUSED"},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def unpause_cc_pair(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
result = requests.put(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status",
|
||||
json={"status": "ACTIVE"},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
cc_pair_identifier = ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id,
|
||||
@@ -168,26 +154,18 @@ class CCPairManager:
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/deletion-attempt",
|
||||
json=cc_pair_identifier.model_dump(),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_single(
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> CCPairFullInfo | None:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
cc_pair_json = response.json()
|
||||
@@ -196,15 +174,11 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def get_indexing_status_by_id(
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> ConnectorIndexingStatusLite | None:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/indexing-status",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
json={"get_all_connectors": True},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -219,15 +193,11 @@ class CCPairManager:
|
||||
|
||||
@staticmethod
|
||||
def get_indexing_statuses(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[ConnectorIndexingStatusLite]:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/indexing-status",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
json={"get_all_connectors": True},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -241,15 +211,11 @@ class CCPairManager:
|
||||
|
||||
@staticmethod
|
||||
def get_connector_statuses(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[ConnectorStatus]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/status",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [ConnectorStatus(**status) for status in response.json()]
|
||||
@@ -257,8 +223,8 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def verify(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
all_cc_pairs = CCPairManager.get_connector_statuses(user_performing_action)
|
||||
for retrieved_cc_pair in all_cc_pairs:
|
||||
@@ -285,7 +251,7 @@ class CCPairManager:
|
||||
def run_once(
|
||||
cc_pair: DATestCCPair,
|
||||
from_beginning: bool,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
body = {
|
||||
"connector_id": cc_pair.connector_id,
|
||||
@@ -295,19 +261,15 @@ class CCPairManager:
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/connector/run-once",
|
||||
json=body,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def wait_for_indexing_inactive(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser,
|
||||
timeout: float = MAX_DELAY,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""wait for the number of docs to be indexed on the connector.
|
||||
This is used to test pausing a connector in the middle of indexing and
|
||||
@@ -342,9 +304,9 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def wait_for_indexing_in_progress(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser,
|
||||
timeout: float = MAX_DELAY,
|
||||
num_docs: int = 16,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""wait for the number of docs to be indexed on the connector.
|
||||
This is used to test pausing a connector in the middle of indexing and
|
||||
@@ -393,8 +355,8 @@ class CCPairManager:
|
||||
def wait_for_indexing_completion(
|
||||
cc_pair: DATestCCPair,
|
||||
after: datetime,
|
||||
user_performing_action: DATestUser,
|
||||
timeout: float = MAX_DELAY,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""after: Wait for an indexing success time after this time"""
|
||||
start = time.monotonic()
|
||||
@@ -430,30 +392,22 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def prune(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def last_pruned(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> datetime | None:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/last_pruned",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_str = response.json()
|
||||
@@ -471,8 +425,8 @@ class CCPairManager:
|
||||
def wait_for_prune(
|
||||
cc_pair: DATestCCPair,
|
||||
after: datetime,
|
||||
user_performing_action: DATestUser,
|
||||
timeout: float = MAX_DELAY,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""after: The task register time must be after this time."""
|
||||
start = time.monotonic()
|
||||
@@ -496,7 +450,7 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def sync(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
"""This function triggers a permission sync.
|
||||
Naming / intent of this function probably could use improvement, but currently it's letting
|
||||
@@ -504,22 +458,14 @@ class CCPairManager:
|
||||
"""
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-permissions",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
if result.status_code != 409:
|
||||
result.raise_for_status()
|
||||
|
||||
group_sync_result = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-groups",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
if group_sync_result.status_code != 409:
|
||||
group_sync_result.raise_for_status()
|
||||
@@ -528,15 +474,11 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def get_doc_sync_task(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> datetime | None:
|
||||
doc_sync_response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-permissions",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
doc_sync_response.raise_for_status()
|
||||
doc_sync_response_str = doc_sync_response.json()
|
||||
@@ -553,15 +495,11 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def get_group_sync_task(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> datetime | None:
|
||||
group_sync_response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-groups",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
group_sync_response.raise_for_status()
|
||||
group_sync_response_str = group_sync_response.json()
|
||||
@@ -578,15 +516,11 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def get_doc_sync_statuses(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[DocumentSyncStatus]:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/get-docs-sync-status",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
doc_sync_statuses: list[DocumentSyncStatus] = []
|
||||
@@ -613,9 +547,9 @@ class CCPairManager:
|
||||
def wait_for_sync(
|
||||
cc_pair: DATestCCPair,
|
||||
after: datetime,
|
||||
user_performing_action: DATestUser,
|
||||
timeout: float = MAX_DELAY,
|
||||
number_of_updated_docs: int = 0,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
# Sometimes waiting for a group sync is not necessary
|
||||
should_wait_for_group_sync: bool = True,
|
||||
# Sometimes waiting for a vespa sync is not necessary
|
||||
@@ -703,8 +637,8 @@ class CCPairManager:
|
||||
|
||||
@staticmethod
|
||||
def wait_for_deletion_completion(
|
||||
user_performing_action: DATestUser,
|
||||
cc_pair_id: int | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""if cc_pair_id is not specified, just waits until no connectors are in the deleting state.
|
||||
if cc_pair_id is specified, checks to ensure the specific cc_pair_id is gone.
|
||||
|
||||
@@ -17,7 +17,6 @@ from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import StreamingType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestChatMessage
|
||||
from tests.integration.common_utils.test_models import DATestChatSession
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
@@ -74,9 +73,9 @@ class StreamPacketData(TypedDict, total=False):
|
||||
class ChatSessionManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
user_performing_action: DATestUser,
|
||||
persona_id: int = 0,
|
||||
description: str = "Test chat session",
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestChatSession:
|
||||
chat_session_creation_req = ChatSessionCreationRequest(
|
||||
persona_id=persona_id, description=description
|
||||
@@ -84,11 +83,7 @@ class ChatSessionManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/create-chat-session",
|
||||
json=chat_session_creation_req.model_dump(),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
chat_session_id = response.json()["chat_session_id"]
|
||||
@@ -100,8 +95,8 @@ class ChatSessionManager:
|
||||
def send_message(
|
||||
chat_session_id: UUID,
|
||||
message: str,
|
||||
user_performing_action: DATestUser,
|
||||
parent_message_id: int | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
file_descriptors: list[FileDescriptor] | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
forced_tool_ids: list[int] | None = None,
|
||||
@@ -126,19 +121,12 @@ class ChatSessionManager:
|
||||
llm_override=llm_override,
|
||||
)
|
||||
|
||||
headers = (
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
)
|
||||
cookies = user_performing_action.cookies if user_performing_action else None
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-chat-message",
|
||||
json=chat_message_req.model_dump(mode="json"),
|
||||
headers=headers,
|
||||
headers=user_performing_action.headers,
|
||||
stream=True,
|
||||
cookies=cookies,
|
||||
cookies=user_performing_action.cookies,
|
||||
)
|
||||
|
||||
streamed_response = ChatSessionManager.analyze_response(response)
|
||||
@@ -167,9 +155,9 @@ class ChatSessionManager:
|
||||
def send_message_with_disconnect(
|
||||
chat_session_id: UUID,
|
||||
message: str,
|
||||
user_performing_action: DATestUser,
|
||||
disconnect_after_packets: int = 0,
|
||||
parent_message_id: int | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
file_descriptors: list[FileDescriptor] | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
forced_tool_ids: list[int] | None = None,
|
||||
@@ -208,21 +196,14 @@ class ChatSessionManager:
|
||||
llm_override=llm_override,
|
||||
)
|
||||
|
||||
headers = (
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
)
|
||||
cookies = user_performing_action.cookies if user_performing_action else None
|
||||
|
||||
packets_received = 0
|
||||
|
||||
with requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-chat-message",
|
||||
json=chat_message_req.model_dump(mode="json"),
|
||||
headers=headers,
|
||||
headers=user_performing_action.headers,
|
||||
stream=True,
|
||||
cookies=cookies,
|
||||
cookies=user_performing_action.cookies,
|
||||
) as response:
|
||||
for line in response.iter_lines():
|
||||
if not line:
|
||||
@@ -359,15 +340,11 @@ class ChatSessionManager:
|
||||
@staticmethod
|
||||
def get_chat_history(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[DATestChatMessage]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -387,7 +364,7 @@ class ChatSessionManager:
|
||||
def create_chat_message_feedback(
|
||||
message_id: int,
|
||||
is_positive: bool,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
feedback_text: str | None = None,
|
||||
predefined_feedback: str | None = None,
|
||||
) -> None:
|
||||
@@ -399,18 +376,14 @@ class ChatSessionManager:
|
||||
"feedback_text": feedback_text,
|
||||
"predefined_feedback": predefined_feedback,
|
||||
},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete a chat session and all its related records (messages, agent data, etc.)
|
||||
@@ -420,18 +393,14 @@ class ChatSessionManager:
|
||||
"""
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
return response.ok
|
||||
|
||||
@staticmethod
|
||||
def soft_delete(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
"""
|
||||
Soft delete a chat session (marks as deleted but keeps in database).
|
||||
@@ -442,18 +411,14 @@ class ChatSessionManager:
|
||||
# or make a direct call with hard_delete=False parameter via a new endpoint
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}?hard_delete=false",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
return response.ok
|
||||
|
||||
@staticmethod
|
||||
def hard_delete(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
"""
|
||||
Hard delete a chat session (completely removes from database).
|
||||
@@ -462,18 +427,14 @@ class ChatSessionManager:
|
||||
"""
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}?hard_delete=true",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
return response.ok
|
||||
|
||||
@staticmethod
|
||||
def verify_deleted(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
"""
|
||||
Verify that a chat session has been deleted by attempting to retrieve it.
|
||||
@@ -482,11 +443,7 @@ class ChatSessionManager:
|
||||
"""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
# Chat session should return 404 if it doesn't exist or is deleted
|
||||
return response.status_code == 404
|
||||
@@ -494,7 +451,7 @@ class ChatSessionManager:
|
||||
@staticmethod
|
||||
def verify_soft_deleted(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
"""
|
||||
Verify that a chat session has been soft deleted (marked as deleted but still in DB).
|
||||
@@ -504,11 +461,7 @@ class ChatSessionManager:
|
||||
# Try to get the chat session with include_deleted=true
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}?include_deleted=true",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
@@ -520,7 +473,7 @@ class ChatSessionManager:
|
||||
@staticmethod
|
||||
def verify_hard_deleted(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
"""
|
||||
Verify that a chat session has been hard deleted (completely removed from DB).
|
||||
@@ -530,11 +483,7 @@ class ChatSessionManager:
|
||||
# Try to get the chat session with include_deleted=true
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}?include_deleted=true",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
|
||||
# For hard delete, even with include_deleted=true, the record should not exist
|
||||
|
||||
@@ -8,7 +8,6 @@ from onyx.db.enums import AccessType
|
||||
from onyx.server.documents.models import ConnectorUpdateRequest
|
||||
from onyx.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestConnector
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -16,13 +15,13 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
class ConnectorManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
source: DocumentSource = DocumentSource.FILE,
|
||||
input_type: InputType = InputType.LOAD_STATE,
|
||||
connector_specific_config: dict[str, Any] | None = None,
|
||||
access_type: AccessType = AccessType.PUBLIC,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
refresh_freq: int | None = None,
|
||||
) -> DATestConnector:
|
||||
name = f"{name}-connector" if name else f"test-connector-{uuid4()}"
|
||||
@@ -51,11 +50,7 @@ class ConnectorManager:
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/connector",
|
||||
json=connector_update_request.model_dump(),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -73,45 +68,33 @@ class ConnectorManager:
|
||||
@staticmethod
|
||||
def edit(
|
||||
connector: DATestConnector,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
response = requests.patch(
|
||||
url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}",
|
||||
json=connector.model_dump(exclude={"id"}),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
connector: DATestConnector,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
response = requests.delete(
|
||||
url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[DATestConnector]:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/connector",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [
|
||||
@@ -127,15 +110,12 @@ class ConnectorManager:
|
||||
|
||||
@staticmethod
|
||||
def get(
|
||||
connector_id: int, user_performing_action: DATestUser | None = None
|
||||
connector_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
) -> DATestConnector:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/connector/{connector_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
conn = response.json()
|
||||
|
||||
@@ -6,7 +6,6 @@ import requests
|
||||
from onyx.server.documents.models import CredentialSnapshot
|
||||
from onyx.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestCredential
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -14,13 +13,13 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
class CredentialManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
user_performing_action: DATestUser,
|
||||
credential_json: dict[str, Any] | None = None,
|
||||
admin_public: bool = True,
|
||||
name: str | None = None,
|
||||
source: DocumentSource = DocumentSource.FILE,
|
||||
curator_public: bool = True,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestCredential:
|
||||
name = f"{name}-credential" if name else f"test-credential-{uuid4()}"
|
||||
|
||||
@@ -36,11 +35,7 @@ class CredentialManager:
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/credential",
|
||||
json=credential_request,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
@@ -57,61 +52,46 @@ class CredentialManager:
|
||||
@staticmethod
|
||||
def edit(
|
||||
credential: DATestCredential,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
request = credential.model_dump(include={"name", "credential_json"})
|
||||
response = requests.put(
|
||||
url=f"{API_SERVER_URL}/manage/admin/credential/{credential.id}",
|
||||
json=request,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
credential: DATestCredential,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
response = requests.delete(
|
||||
url=f"{API_SERVER_URL}/manage/credential/{credential.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get(
|
||||
credential_id: int, user_performing_action: DATestUser | None = None
|
||||
credential_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
) -> CredentialSnapshot:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/credential/{credential_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return CredentialSnapshot(**response.json())
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[CredentialSnapshot]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/credential",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [CredentialSnapshot(**cred) for cred in response.json()]
|
||||
@@ -119,8 +99,8 @@ class CredentialManager:
|
||||
@staticmethod
|
||||
def verify(
|
||||
credential: DATestCredential,
|
||||
user_performing_action: DATestUser,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
all_credentials = CredentialManager.get_all(user_performing_action)
|
||||
for fetched_credential in all_credentials:
|
||||
|
||||
@@ -10,7 +10,6 @@ from onyx.db.enums import AccessType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import NUM_DOCS
|
||||
from tests.integration.common_utils.managers.api_key import DATestAPIKey
|
||||
from tests.integration.common_utils.managers.cc_pair import DATestCCPair
|
||||
@@ -22,9 +21,9 @@ from tests.integration.common_utils.vespa import vespa_fixture
|
||||
def _verify_document_permissions(
|
||||
retrieved_doc: dict,
|
||||
cc_pair: DATestCCPair,
|
||||
doc_creating_user: DATestUser,
|
||||
doc_set_names: list[str] | None = None,
|
||||
group_names: list[str] | None = None,
|
||||
doc_creating_user: DATestUser | None = None,
|
||||
) -> None:
|
||||
acl_keys = set(retrieved_doc.get("access_control_list", {}).keys())
|
||||
print(f"ACL keys: {acl_keys}")
|
||||
@@ -36,12 +35,11 @@ def _verify_document_permissions(
|
||||
" does not have the PUBLIC ACL key"
|
||||
)
|
||||
|
||||
if doc_creating_user is not None:
|
||||
if f"user_email:{doc_creating_user.email}" not in acl_keys:
|
||||
raise ValueError(
|
||||
f"Document {retrieved_doc['document_id']} was created by user"
|
||||
f" {doc_creating_user.email} but does not have the user_email:{doc_creating_user.email} ACL key"
|
||||
)
|
||||
if f"user_email:{doc_creating_user.email}" not in acl_keys:
|
||||
raise ValueError(
|
||||
f"Document {retrieved_doc['document_id']} was created by user"
|
||||
f" {doc_creating_user.email} but does not have the user_email:{doc_creating_user.email} ACL key"
|
||||
)
|
||||
|
||||
if group_names is not None:
|
||||
expected_group_keys = {f"group:{group_name}" for group_name in group_names}
|
||||
@@ -101,9 +99,9 @@ class DocumentManager:
|
||||
@staticmethod
|
||||
def seed_dummy_docs(
|
||||
cc_pair: DATestCCPair,
|
||||
api_key: DATestAPIKey,
|
||||
num_docs: int = NUM_DOCS,
|
||||
document_ids: list[str] | None = None,
|
||||
api_key: DATestAPIKey | None = None,
|
||||
) -> list[SimpleTestDocument]:
|
||||
# Use provided document_ids if available, otherwise generate random UUIDs
|
||||
if document_ids is None:
|
||||
@@ -118,12 +116,13 @@ class DocumentManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/onyx-api/ingestion",
|
||||
json=document,
|
||||
headers=api_key.headers if api_key else GENERAL_HEADERS,
|
||||
headers=api_key.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
api_key_id = api_key.api_key_id if api_key else ""
|
||||
print(f"Seeding docs for api_key_id={api_key_id} completed successfully.")
|
||||
print(
|
||||
f"Seeding docs for api_key_id={api_key.api_key_id} completed successfully."
|
||||
)
|
||||
return [
|
||||
SimpleTestDocument(
|
||||
id=document["document"]["id"],
|
||||
@@ -136,8 +135,8 @@ class DocumentManager:
|
||||
def seed_doc_with_content(
|
||||
cc_pair: DATestCCPair,
|
||||
content: str,
|
||||
api_key: DATestAPIKey,
|
||||
document_id: str | None = None,
|
||||
api_key: DATestAPIKey | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> SimpleTestDocument:
|
||||
# Use provided document_ids if available, otherwise generate random UUIDs
|
||||
@@ -153,12 +152,13 @@ class DocumentManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/onyx-api/ingestion",
|
||||
json=document,
|
||||
headers=api_key.headers if api_key else GENERAL_HEADERS,
|
||||
headers=api_key.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
api_key_id = api_key.api_key_id if api_key else ""
|
||||
print(f"Seeding doc for api_key_id={api_key_id} completed successfully.")
|
||||
print(
|
||||
f"Seeding doc for api_key_id={api_key.api_key_id} completed successfully."
|
||||
)
|
||||
|
||||
return SimpleTestDocument(
|
||||
id=document["document"]["id"],
|
||||
@@ -169,11 +169,11 @@ class DocumentManager:
|
||||
def verify(
|
||||
vespa_client: vespa_fixture,
|
||||
cc_pair: DATestCCPair,
|
||||
doc_creating_user: DATestUser,
|
||||
# If None, will not check doc sets or groups
|
||||
# If empty list, will check for empty doc sets or groups
|
||||
doc_set_names: list[str] | None = None,
|
||||
group_names: list[str] | None = None,
|
||||
doc_creating_user: DATestUser | None = None,
|
||||
verify_deleted: bool = False,
|
||||
) -> None:
|
||||
doc_ids = [document.id for document in cc_pair.documents]
|
||||
@@ -212,9 +212,9 @@ class DocumentManager:
|
||||
_verify_document_permissions(
|
||||
retrieved_doc,
|
||||
cc_pair,
|
||||
doc_creating_user,
|
||||
doc_set_names,
|
||||
group_names,
|
||||
doc_creating_user,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -268,11 +268,11 @@ class IngestionManager(DocumentManager):
|
||||
|
||||
@staticmethod
|
||||
def list_all_ingestion_docs(
|
||||
api_key: DATestAPIKey | None = None,
|
||||
api_key: DATestAPIKey,
|
||||
) -> list[dict]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/onyx-api/ingestion",
|
||||
headers=api_key.headers if api_key else GENERAL_HEADERS,
|
||||
headers=api_key.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -280,11 +280,11 @@ class IngestionManager(DocumentManager):
|
||||
@staticmethod
|
||||
def delete(
|
||||
document_id: str,
|
||||
api_key: DATestAPIKey | None = None,
|
||||
api_key: DATestAPIKey,
|
||||
) -> None:
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/onyx-api/ingestion/{document_id}",
|
||||
headers=api_key.headers if api_key else GENERAL_HEADERS,
|
||||
headers=api_key.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
print(f"Deleted document {document_id} successfully.")
|
||||
|
||||
@@ -3,7 +3,6 @@ import requests
|
||||
from ee.onyx.server.query_and_chat.models import SearchFullResponse
|
||||
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@@ -11,7 +10,7 @@ class DocumentSearchManager:
|
||||
@staticmethod
|
||||
def search_documents(
|
||||
query: str,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Search for documents using the EE search API.
|
||||
@@ -31,11 +30,7 @@ class DocumentSearchManager:
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/search/send-search-message",
|
||||
json=search_request.model_dump(),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
result_json = result.json()
|
||||
|
||||
@@ -6,7 +6,6 @@ from uuid import uuid4
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.test_models import DATestDocumentSet
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
@@ -15,6 +14,7 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
class DocumentSetManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
cc_pair_ids: list[int] | None = None,
|
||||
@@ -22,7 +22,6 @@ class DocumentSetManager:
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
federated_connectors: list[dict[str, Any]] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestDocumentSet:
|
||||
if name is None:
|
||||
name = f"test_doc_set_{str(uuid4())}"
|
||||
@@ -40,11 +39,7 @@ class DocumentSetManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/document-set",
|
||||
json=doc_set_creation_request,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -63,7 +58,7 @@ class DocumentSetManager:
|
||||
@staticmethod
|
||||
def edit(
|
||||
document_set: DATestDocumentSet,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
doc_set_update_request = {
|
||||
"id": document_set.id,
|
||||
@@ -77,11 +72,7 @@ class DocumentSetManager:
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/manage/admin/document-set",
|
||||
json=doc_set_update_request,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return True
|
||||
@@ -89,30 +80,22 @@ class DocumentSetManager:
|
||||
@staticmethod
|
||||
def delete(
|
||||
document_set: DATestDocumentSet,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/manage/admin/document-set/{document_set.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[DATestDocumentSet]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/document-set",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [
|
||||
@@ -132,8 +115,8 @@ class DocumentSetManager:
|
||||
|
||||
@staticmethod
|
||||
def wait_for_sync(
|
||||
user_performing_action: DATestUser,
|
||||
document_sets_to_check: list[DATestDocumentSet] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
# wait for document sets to be synced
|
||||
start = time.time()
|
||||
@@ -175,8 +158,8 @@ class DocumentSetManager:
|
||||
@staticmethod
|
||||
def verify(
|
||||
document_set: DATestDocumentSet,
|
||||
user_performing_action: DATestUser,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
doc_sets = DocumentSetManager.get_all(user_performing_action)
|
||||
for doc_set in doc_sets:
|
||||
|
||||
@@ -10,7 +10,6 @@ import requests
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.server.documents.models import FileUploadResponse
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@@ -18,13 +17,9 @@ class FileManager:
|
||||
@staticmethod
|
||||
def upload_files(
|
||||
files: List[Tuple[str, IO]],
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> Tuple[List[FileDescriptor], str]:
|
||||
headers = (
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
)
|
||||
headers = user_performing_action.headers
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
files_param = []
|
||||
@@ -67,15 +62,11 @@ class FileManager:
|
||||
@staticmethod
|
||||
def fetch_uploaded_file(
|
||||
file_id: str,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bytes:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/chat/file/{file_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
@@ -6,7 +6,6 @@ from uuid import uuid4
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestImageGenerationConfig
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -26,6 +25,7 @@ def _serialize_custom_config(
|
||||
class ImageGenerationConfigManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
user_performing_action: DATestUser,
|
||||
image_provider_id: str | None = None,
|
||||
model_name: str = "gpt-image-1",
|
||||
provider: str = "openai",
|
||||
@@ -35,7 +35,6 @@ class ImageGenerationConfigManager:
|
||||
deployment_name: str | None = None,
|
||||
custom_config: dict[str, Any] | None = None,
|
||||
is_default: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestImageGenerationConfig:
|
||||
"""Create a new image generation config with new credentials."""
|
||||
image_provider_id = image_provider_id or f"test-provider-{uuid4()}"
|
||||
@@ -53,11 +52,7 @@ class ImageGenerationConfigManager:
|
||||
"custom_config": _serialize_custom_config(custom_config),
|
||||
"is_default": is_default,
|
||||
},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
@@ -74,13 +69,13 @@ class ImageGenerationConfigManager:
|
||||
@staticmethod
|
||||
def create_from_provider(
|
||||
source_llm_provider_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
image_provider_id: str | None = None,
|
||||
model_name: str = "gpt-image-1",
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
deployment_name: str | None = None,
|
||||
is_default: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestImageGenerationConfig:
|
||||
"""Create a new image generation config by cloning from an existing LLM provider."""
|
||||
image_provider_id = image_provider_id or f"test-provider-{uuid4()}"
|
||||
@@ -96,11 +91,7 @@ class ImageGenerationConfigManager:
|
||||
"deployment_name": deployment_name,
|
||||
"is_default": is_default,
|
||||
},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
@@ -116,16 +107,12 @@ class ImageGenerationConfigManager:
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[DATestImageGenerationConfig]:
|
||||
"""Get all image generation configs."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/image-generation/config",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [DATestImageGenerationConfig(**config) for config in response.json()]
|
||||
@@ -133,16 +120,12 @@ class ImageGenerationConfigManager:
|
||||
@staticmethod
|
||||
def get_credentials(
|
||||
image_provider_id: str,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> dict:
|
||||
"""Get credentials for an image generation config."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}/credentials",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -151,13 +134,13 @@ class ImageGenerationConfigManager:
|
||||
def update(
|
||||
image_provider_id: str,
|
||||
model_name: str,
|
||||
user_performing_action: DATestUser,
|
||||
provider: str | None = None,
|
||||
api_key: str | None = None,
|
||||
source_llm_provider_id: int | None = None,
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
deployment_name: str | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestImageGenerationConfig:
|
||||
"""Update an existing image generation config."""
|
||||
payload: dict = {
|
||||
@@ -178,14 +161,10 @@ class ImageGenerationConfigManager:
|
||||
f"Got: source_llm_provider_id={source_llm_provider_id}, provider={provider}, api_key={'***' if api_key else None}"
|
||||
)
|
||||
|
||||
headers = {**GENERAL_HEADERS}
|
||||
if user_performing_action:
|
||||
headers.update(user_performing_action.headers)
|
||||
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}",
|
||||
json=payload,
|
||||
headers=headers,
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
if not response.ok:
|
||||
print(f"Update failed with status {response.status_code}: {response.text}")
|
||||
@@ -204,40 +183,32 @@ class ImageGenerationConfigManager:
|
||||
@staticmethod
|
||||
def delete(
|
||||
image_provider_id: str,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
"""Delete an image generation config."""
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def set_default(
|
||||
image_provider_id: str,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
"""Set an image generation config as the default."""
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}/default",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
config: DATestImageGenerationConfig,
|
||||
user_performing_action: DATestUser,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""Verify that a config exists (or doesn't exist if verify_deleted=True)."""
|
||||
all_configs = ImageGenerationConfigManager.get_all(user_performing_action)
|
||||
|
||||
@@ -14,7 +14,6 @@ from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.server.documents.models import IndexAttemptSnapshot
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.test_models import DATestIndexAttempt
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
@@ -86,9 +85,9 @@ class IndexAttemptManager:
|
||||
@staticmethod
|
||||
def get_index_attempt_page(
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
page: int = 0,
|
||||
page_size: int = 10,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> PaginatedReturn[IndexAttemptSnapshot]:
|
||||
query_params: dict[str, str | int] = {
|
||||
"page_num": page,
|
||||
@@ -101,11 +100,7 @@ class IndexAttemptManager:
|
||||
)
|
||||
response = requests.get(
|
||||
url=url,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
@@ -117,7 +112,7 @@ class IndexAttemptManager:
|
||||
@staticmethod
|
||||
def get_latest_index_attempt_for_cc_pair(
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> IndexAttemptSnapshot | None:
|
||||
"""Get an IndexAttempt by ID"""
|
||||
index_attempts = IndexAttemptManager.get_index_attempt_page(
|
||||
@@ -134,9 +129,9 @@ class IndexAttemptManager:
|
||||
@staticmethod
|
||||
def wait_for_index_attempt_start(
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
index_attempts_to_ignore: list[int] | None = None,
|
||||
timeout: float = MAX_DELAY,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> IndexAttemptSnapshot:
|
||||
"""Wait for an IndexAttempt to start"""
|
||||
start = datetime.now()
|
||||
@@ -164,7 +159,7 @@ class IndexAttemptManager:
|
||||
def get_index_attempt_by_id(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> IndexAttemptSnapshot:
|
||||
page_num = 0
|
||||
page_size = 10
|
||||
@@ -190,8 +185,8 @@ class IndexAttemptManager:
|
||||
def wait_for_index_attempt_completion(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
timeout: float = MAX_DELAY,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""Wait for an IndexAttempt to complete"""
|
||||
start = time.monotonic()
|
||||
@@ -223,19 +218,15 @@ class IndexAttemptManager:
|
||||
@staticmethod
|
||||
def get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
include_resolved: bool = True,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[IndexAttemptErrorPydantic]:
|
||||
url = f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/errors?page_size=100"
|
||||
if include_resolved:
|
||||
url += "&include_resolved=true"
|
||||
response = requests.get(
|
||||
url=url,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
@@ -8,7 +8,6 @@ from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -16,6 +15,7 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
class LLMProviderManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
provider: str | None = None,
|
||||
api_key: str | None = None,
|
||||
@@ -26,13 +26,8 @@ class LLMProviderManager:
|
||||
personas: list[int] | None = None,
|
||||
is_public: bool | None = None,
|
||||
set_as_default: bool = True,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestLLMProvider:
|
||||
email = "Unknown"
|
||||
if user_performing_action:
|
||||
email = user_performing_action.email
|
||||
|
||||
print(f"Seeding LLM Providers for {email}...")
|
||||
print(f"Seeding LLM Providers for {user_performing_action.email}...")
|
||||
|
||||
llm_provider = LLMProviderUpsertRequest(
|
||||
name=name or f"test-provider-{uuid4()}",
|
||||
@@ -60,11 +55,7 @@ class LLMProviderManager:
|
||||
llm_response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
json=llm_provider.model_dump(),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
llm_response.raise_for_status()
|
||||
response_data = llm_response.json()
|
||||
@@ -86,11 +77,7 @@ class LLMProviderManager:
|
||||
if set_as_default:
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
set_default_response.raise_for_status()
|
||||
|
||||
@@ -99,30 +86,22 @@ class LLMProviderManager:
|
||||
@staticmethod
|
||||
def delete(
|
||||
llm_provider: DATestLLMProvider,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{llm_provider.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[LLMProviderView]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [LLMProviderView(**ug) for ug in response.json()]
|
||||
@@ -130,8 +109,8 @@ class LLMProviderManager:
|
||||
@staticmethod
|
||||
def verify(
|
||||
llm_provider: DATestLLMProvider,
|
||||
user_performing_action: DATestUser,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
|
||||
for fetched_llm_provider in all_llm_providers:
|
||||
|
||||
@@ -7,7 +7,6 @@ from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.server.features.persona.models import FullPersonaSnapshot
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestPersona
|
||||
from tests.integration.common_utils.test_models import DATestPersonaLabel
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
@@ -16,6 +15,7 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
class PersonaManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
@@ -34,7 +34,6 @@ class PersonaManager:
|
||||
groups: list[int] | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
user_file_ids: list[str] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
display_priority: int | None = None,
|
||||
) -> DATestPersona:
|
||||
name = name or f"test-persona-{uuid4()}"
|
||||
@@ -67,11 +66,7 @@ class PersonaManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/persona",
|
||||
json=persona_creation_request.model_dump(mode="json"),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
persona_data = response.json()
|
||||
@@ -100,6 +95,7 @@ class PersonaManager:
|
||||
@staticmethod
|
||||
def edit(
|
||||
persona: DATestPersona,
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
@@ -117,7 +113,6 @@ class PersonaManager:
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestPersona:
|
||||
system_prompt = system_prompt or f"System prompt for {persona.name}"
|
||||
task_prompt = task_prompt or f"Task prompt for {persona.name}"
|
||||
@@ -151,11 +146,7 @@ class PersonaManager:
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
json=persona_update_request.model_dump(mode="json"),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
updated_persona_data = response.json()
|
||||
@@ -187,15 +178,11 @@ class PersonaManager:
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[FullPersonaSnapshot]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/persona",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [FullPersonaSnapshot(**persona) for persona in response.json()]
|
||||
@@ -203,15 +190,11 @@ class PersonaManager:
|
||||
@staticmethod
|
||||
def get_one(
|
||||
persona_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[FullPersonaSnapshot]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/persona/{persona_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [FullPersonaSnapshot(**response.json())]
|
||||
@@ -219,7 +202,7 @@ class PersonaManager:
|
||||
@staticmethod
|
||||
def verify(
|
||||
persona: DATestPersona,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
all_personas = PersonaManager.get_one(
|
||||
persona_id=persona.id,
|
||||
@@ -388,15 +371,11 @@ class PersonaManager:
|
||||
@staticmethod
|
||||
def delete(
|
||||
persona: DATestPersona,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
return response.ok
|
||||
|
||||
@@ -405,18 +384,14 @@ class PersonaLabelManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
label: DATestPersonaLabel,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> DATestPersonaLabel:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/persona/labels",
|
||||
json={
|
||||
"name": label.name,
|
||||
},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_data = response.json()
|
||||
@@ -425,15 +400,11 @@ class PersonaLabelManager:
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[DATestPersonaLabel]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/persona/labels",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [DATestPersonaLabel(**label) for label in response.json()]
|
||||
@@ -441,18 +412,14 @@ class PersonaLabelManager:
|
||||
@staticmethod
|
||||
def update(
|
||||
label: DATestPersonaLabel,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> DATestPersonaLabel:
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/admin/persona/label/{label.id}",
|
||||
json={
|
||||
"label_name": label.name,
|
||||
},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return label
|
||||
@@ -460,22 +427,18 @@ class PersonaLabelManager:
|
||||
@staticmethod
|
||||
def delete(
|
||||
label: DATestPersonaLabel,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/persona/label/{label.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
return response.ok
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
label: DATestPersonaLabel,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
all_labels = PersonaLabelManager.get_all(user_performing_action)
|
||||
for fetched_label in all_labels:
|
||||
|
||||
@@ -6,7 +6,6 @@ from onyx.server.features.projects.models import CategorizedFilesSnapshot
|
||||
from onyx.server.features.projects.models import UserFileSnapshot
|
||||
from onyx.server.features.projects.models import UserProjectSnapshot
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@@ -20,7 +19,7 @@ class ProjectManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/user/projects/create",
|
||||
params={"name": name},
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return UserProjectSnapshot.model_validate(response.json())
|
||||
@@ -32,7 +31,7 @@ class ProjectManager:
|
||||
"""Get all projects for a user via API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects",
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [UserProjectSnapshot.model_validate(obj) for obj in response.json()]
|
||||
@@ -45,7 +44,7 @@ class ProjectManager:
|
||||
"""Delete a project via API."""
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/user/projects/{project_id}",
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
return response.status_code == 204
|
||||
|
||||
@@ -57,7 +56,7 @@ class ProjectManager:
|
||||
"""Verify that a project has been deleted by ensuring it's not in list."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects",
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
projects = [UserProjectSnapshot.model_validate(obj) for obj in response.json()]
|
||||
@@ -66,16 +65,12 @@ class ProjectManager:
|
||||
@staticmethod
|
||||
def verify_files_unlinked(
|
||||
project_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
"""Verify that all files have been unlinked from the project via API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/files/{project_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
if response.status_code == 404:
|
||||
return True
|
||||
@@ -87,16 +82,12 @@ class ProjectManager:
|
||||
@staticmethod
|
||||
def verify_chat_sessions_unlinked(
|
||||
project_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
"""Verify that all chat sessions have been unlinked from the project via API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/{project_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
if response.status_code == 404:
|
||||
return True
|
||||
@@ -144,16 +135,12 @@ class ProjectManager:
|
||||
@staticmethod
|
||||
def get_project_files(
|
||||
project_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> List[UserFileSnapshot]:
|
||||
"""Get all files associated with a project via API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/files/{project_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
if response.status_code == 404:
|
||||
return []
|
||||
@@ -170,7 +157,7 @@ class ProjectManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/user/projects/{project_id}/instructions",
|
||||
json={"instructions": instructions},
|
||||
headers=user_performing_action.headers or GENERAL_HEADERS,
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return (response.json() or {}).get("instructions") or ""
|
||||
|
||||
@@ -10,19 +10,18 @@ from ee.onyx.server.query_history.models import ChatSessionSnapshot
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
class QueryHistoryManager:
|
||||
@staticmethod
|
||||
def get_query_history_page(
|
||||
user_performing_action: DATestUser,
|
||||
page_num: int = 0,
|
||||
page_size: int = 10,
|
||||
feedback_type: QAFeedbackType | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> PaginatedReturn[ChatSessionMinimal]:
|
||||
query_params: dict[str, str | int] = {
|
||||
"page_num": page_num,
|
||||
@@ -37,11 +36,7 @@ class QueryHistoryManager:
|
||||
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/chat-session-history?{urlencode(query_params, doseq=True)}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
@@ -53,24 +48,20 @@ class QueryHistoryManager:
|
||||
@staticmethod
|
||||
def get_chat_session_admin(
|
||||
chat_session_id: UUID | str,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> ChatSessionSnapshot:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/chat-session-history/{chat_session_id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return ChatSessionSnapshot(**response.json())
|
||||
|
||||
@staticmethod
|
||||
def get_query_history_as_csv(
|
||||
user_performing_action: DATestUser,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> tuple[CaseInsensitiveDict[str], str]:
|
||||
query_params: dict[str, str | int] = {}
|
||||
if start_time:
|
||||
@@ -80,11 +71,7 @@ class QueryHistoryManager:
|
||||
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/query-history-csv?{urlencode(query_params, doseq=True)}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.headers, response.content.decode()
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Optional
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestSettings
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -13,13 +12,9 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
class SettingsManager:
|
||||
@staticmethod
|
||||
def get_settings(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> tuple[Dict[str, Any], str]:
|
||||
headers = (
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
)
|
||||
headers = user_performing_action.headers
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
response = requests.get(
|
||||
@@ -38,13 +33,9 @@ class SettingsManager:
|
||||
@staticmethod
|
||||
def update_settings(
|
||||
settings: DATestSettings,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> tuple[Dict[str, Any], str]:
|
||||
headers = (
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
)
|
||||
headers = user_performing_action.headers
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
payload = settings.model_dump()
|
||||
@@ -65,7 +56,7 @@ class SettingsManager:
|
||||
@staticmethod
|
||||
def get_setting(
|
||||
key: str,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> Optional[Any]:
|
||||
settings, error = SettingsManager.get_settings(user_performing_action)
|
||||
if error:
|
||||
|
||||
@@ -8,7 +8,6 @@ from onyx.server.manage.models import AllUsersResponse
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
from onyx.server.models import InvitedUserSnapshot
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@@ -26,15 +25,11 @@ def generate_auth_token() -> str:
|
||||
class TenantManager:
|
||||
@staticmethod
|
||||
def get_all_users(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> AllUsersResponse:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/users",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -50,7 +45,8 @@ class TenantManager:
|
||||
|
||||
@staticmethod
|
||||
def verify_user_in_tenant(
|
||||
user: DATestUser, user_performing_action: DATestUser | None = None
|
||||
user: DATestUser,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
all_users = TenantManager.get_all_users(user_performing_action)
|
||||
for accepted_user in all_users.accepted:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestTool
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -9,15 +8,11 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
class ToolManager:
|
||||
@staticmethod
|
||||
def list_tools(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[DATestTool]:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/tool",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [
|
||||
|
||||
@@ -7,6 +7,8 @@ import requests
|
||||
from requests import HTTPError
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.configs.constants import ANONYMOUS_USER_UUID
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.manage.models import UserInfo
|
||||
@@ -25,6 +27,23 @@ def build_email(name: str) -> str:
|
||||
|
||||
|
||||
class UserManager:
|
||||
@staticmethod
|
||||
def get_anonymous_user() -> DATestUser:
|
||||
"""Get a DATestUser representing the anonymous user.
|
||||
|
||||
Anonymous users are real users in the database with LIMITED role.
|
||||
They don't have login cookies - requests are made with GENERAL_HEADERS.
|
||||
The anonymous_user_enabled setting must be True for these requests to work.
|
||||
"""
|
||||
return DATestUser(
|
||||
id=ANONYMOUS_USER_UUID,
|
||||
email=ANONYMOUS_USER_EMAIL,
|
||||
password="",
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.LIMITED,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
name: str | None = None,
|
||||
@@ -227,12 +246,12 @@ class UserManager:
|
||||
|
||||
@staticmethod
|
||||
def get_user_page(
|
||||
user_performing_action: DATestUser,
|
||||
page_num: int = 0,
|
||||
page_size: int = 10,
|
||||
search_query: str | None = None,
|
||||
role_filter: list[UserRole] | None = None,
|
||||
is_active_filter: bool | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> PaginatedReturn[FullUserSnapshot]:
|
||||
query_params: dict[str, str | list[str] | int] = {
|
||||
"page_num": page_num,
|
||||
@@ -247,11 +266,7 @@ class UserManager:
|
||||
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/users/accepted?{urlencode(query_params, doseq=True)}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import requests
|
||||
|
||||
from ee.onyx.server.user_group.models import UserGroup
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.test_models import DATestUserGroup
|
||||
@@ -14,10 +13,10 @@ from tests.integration.common_utils.test_models import DATestUserGroup
|
||||
class UserGroupManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
user_performing_action: DATestUser,
|
||||
name: str | None = None,
|
||||
user_ids: list[str] | None = None,
|
||||
cc_pair_ids: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestUserGroup:
|
||||
name = f"{name}-user-group" if name else f"test-user-group-{uuid4()}"
|
||||
|
||||
@@ -29,11 +28,7 @@ class UserGroupManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group",
|
||||
json=request,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
test_user_group = DATestUserGroup(
|
||||
@@ -47,31 +42,23 @@ class UserGroupManager:
|
||||
@staticmethod
|
||||
def edit(
|
||||
user_group: DATestUserGroup,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}",
|
||||
json=user_group.model_dump(),
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
user_group: DATestUserGroup,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -79,7 +66,7 @@ class UserGroupManager:
|
||||
def add_users(
|
||||
user_group: DATestUserGroup,
|
||||
user_ids: list[str],
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> DATestUserGroup:
|
||||
request = {
|
||||
"user_ids": user_ids,
|
||||
@@ -88,11 +75,7 @@ class UserGroupManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/add-users",
|
||||
json=request,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -107,8 +90,8 @@ class UserGroupManager:
|
||||
def set_curator_status(
|
||||
test_user_group: DATestUserGroup,
|
||||
user_to_set_as_curator: DATestUser,
|
||||
user_performing_action: DATestUser,
|
||||
is_curator: bool = True,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
set_curator_request = {
|
||||
"user_id": user_to_set_as_curator.id,
|
||||
@@ -117,25 +100,17 @@ class UserGroupManager:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{test_user_group.id}/set-curator",
|
||||
json=set_curator_request,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[UserGroup]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [UserGroup(**ug) for ug in response.json()]
|
||||
@@ -143,8 +118,8 @@ class UserGroupManager:
|
||||
@staticmethod
|
||||
def verify(
|
||||
user_group: DATestUserGroup,
|
||||
user_performing_action: DATestUser,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
all_user_groups = UserGroupManager.get_all(user_performing_action)
|
||||
for fetched_user_group in all_user_groups:
|
||||
@@ -167,8 +142,8 @@ class UserGroupManager:
|
||||
|
||||
@staticmethod
|
||||
def wait_for_sync(
|
||||
user_performing_action: DATestUser,
|
||||
user_groups_to_check: list[DATestUserGroup] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
start = time.time()
|
||||
while True:
|
||||
@@ -198,7 +173,7 @@ class UserGroupManager:
|
||||
@staticmethod
|
||||
def wait_for_deletion_completion(
|
||||
user_groups_to_check: list[DATestUserGroup],
|
||||
user_performing_action: DATestUser | None = None,
|
||||
user_performing_action: DATestUser,
|
||||
) -> None:
|
||||
start = time.time()
|
||||
user_group_ids_to_check = {user_group.id for user_group in user_groups_to_check}
|
||||
|
||||
@@ -88,11 +88,8 @@ def reset() -> None:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def new_admin_user(reset: None) -> DATestUser | None: # noqa: ARG001
|
||||
try:
|
||||
return UserManager.create(name=ADMIN_USER_NAME)
|
||||
except Exception:
|
||||
return None
|
||||
def new_admin_user(reset: None) -> DATestUser: # noqa: ARG001
|
||||
return UserManager.create(name=ADMIN_USER_NAME)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -182,18 +179,18 @@ def reset_multitenant() -> None:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
|
||||
def llm_provider(admin_user: DATestUser) -> DATestLLMProvider:
|
||||
return LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def image_generation_config(
|
||||
admin_user: DATestUser | None,
|
||||
admin_user: DATestUser,
|
||||
) -> DATestImageGenerationConfig:
|
||||
"""Create a default image generation config for tests."""
|
||||
return ImageGenerationConfigManager.create(
|
||||
is_default=True,
|
||||
user_performing_action=admin_user,
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -60,3 +60,44 @@ def test_me_endpoint_returns_authenticated_user_info(
|
||||
assert data.get("is_anonymous_user") is not True
|
||||
assert data["email"] == admin_user.email
|
||||
assert data["role"] == "admin"
|
||||
|
||||
|
||||
def test_anonymous_user_can_access_persona_when_enabled(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Verify that anonymous users can access limited endpoints when enabled."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
SettingsManager.update_settings(
|
||||
DATestSettings(anonymous_user_enabled=True),
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
anon_user = UserManager.get_anonymous_user()
|
||||
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/persona",
|
||||
headers=anon_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_anonymous_user_denied_persona_when_disabled(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Verify that anonymous users cannot access endpoints when disabled."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
SettingsManager.update_settings(
|
||||
DATestSettings(anonymous_user_enabled=False),
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
anon_user = UserManager.get_anonymous_user()
|
||||
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/persona",
|
||||
headers=anon_user.headers,
|
||||
)
|
||||
# 403 is returned - BasicAuthenticationError uses HTTP 403 for all auth failures
|
||||
assert response.status_code == 403
|
||||
|
||||
@@ -11,8 +11,8 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
def _verify_index_attempt_pagination(
|
||||
cc_pair_id: int,
|
||||
index_attempt_ids: list[int],
|
||||
user_performing_action: DATestUser,
|
||||
page_size: int = 5,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
retrieved_attempts: list[int] = []
|
||||
last_time_started = None # Track the last time_started seen
|
||||
|
||||
@@ -207,7 +207,9 @@ def test_mcp_search_respects_acl_filters(
|
||||
cc_pair_ids=[restricted_cc_pair.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
UserGroupManager.wait_for_sync([user_group], user_performing_action=admin_user)
|
||||
UserGroupManager.wait_for_sync(
|
||||
user_performing_action=admin_user, user_groups_to_check=[user_group]
|
||||
)
|
||||
|
||||
restricted_doc_content = "MCP restricted knowledge base document"
|
||||
_seed_document_and_wait_for_indexing(
|
||||
|
||||
@@ -14,11 +14,11 @@ from tests.integration.tests.query_history.utils import (
|
||||
|
||||
def _verify_query_history_pagination(
|
||||
chat_sessions: list[DAQueryHistoryEntry],
|
||||
user_performing_action: DATestUser,
|
||||
page_size: int = 5,
|
||||
feedback_type: QAFeedbackType | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
retrieved_sessions: list[str] = []
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import pytest
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
@@ -59,7 +58,7 @@ def test_add_users_to_group_invalid_user(reset: None) -> None: # noqa: ARG001
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/add-users",
|
||||
json={"user_ids": [invalid_user_id]},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
@@ -9,11 +9,11 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
# to verify that the pagination and filtering works as expected.
|
||||
def _verify_user_pagination(
|
||||
users: list[DATestUser],
|
||||
user_performing_action: DATestUser,
|
||||
page_size: int = 5,
|
||||
search_query: str | None = None,
|
||||
role_filter: list[UserRole] | None = None,
|
||||
is_active_filter: bool | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
retrieved_users: list[FullUserSnapshot] = []
|
||||
|
||||
|
||||
@@ -158,14 +158,14 @@ python ./scripts/dev_run_background_jobs.py
|
||||
To run the backend API server, navigate back to `onyx/backend` and run:
|
||||
|
||||
```bash
|
||||
AUTH_TYPE=disabled uvicorn onyx.main:app --reload --port 8080
|
||||
AUTH_TYPE=basic uvicorn onyx.main:app --reload --port 8080
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
|
||||
```bash
|
||||
powershell -Command "
|
||||
$env:AUTH_TYPE='disabled'
|
||||
$env:AUTH_TYPE='basic'
|
||||
uvicorn onyx.main:app --reload --port 8080
|
||||
"
|
||||
```
|
||||
|
||||
@@ -126,7 +126,9 @@ Resources:
|
||||
- Effect: Allow
|
||||
Action:
|
||||
- secretsmanager:GetSecretValue
|
||||
Resource: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password-*
|
||||
Resource:
|
||||
- !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password-*
|
||||
- !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret-*
|
||||
|
||||
Outputs:
|
||||
OutputEcsCluster:
|
||||
|
||||
@@ -167,10 +167,12 @@ Resources:
|
||||
- ImportedNamespace: !ImportValue
|
||||
Fn::Sub: "${Environment}-onyx-cluster-OnyxNamespaceName"
|
||||
- Name: AUTH_TYPE
|
||||
Value: disabled
|
||||
Value: basic
|
||||
Secrets:
|
||||
- Name: POSTGRES_PASSWORD
|
||||
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password
|
||||
- Name: USER_AUTH_SECRET
|
||||
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret
|
||||
VolumesFrom: []
|
||||
SystemControls: []
|
||||
|
||||
|
||||
@@ -166,9 +166,11 @@ Resources:
|
||||
- ImportedNamespace: !ImportValue
|
||||
Fn::Sub: "${Environment}-onyx-cluster-OnyxNamespaceName"
|
||||
- Name: AUTH_TYPE
|
||||
Value: disabled
|
||||
Value: basic
|
||||
Secrets:
|
||||
- Name: POSTGRES_PASSWORD
|
||||
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password
|
||||
- Name: USER_AUTH_SECRET
|
||||
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret
|
||||
VolumesFrom: []
|
||||
SystemControls: []
|
||||
|
||||
@@ -21,7 +21,7 @@ services:
|
||||
env_file:
|
||||
- .env_eval
|
||||
environment:
|
||||
- AUTH_TYPE=disabled
|
||||
- AUTH_TYPE=basic
|
||||
- POSTGRES_HOST=relational_db
|
||||
- VESPA_HOST=index
|
||||
- REDIS_HOST=cache
|
||||
@@ -59,7 +59,7 @@ services:
|
||||
- .env_eval
|
||||
environment:
|
||||
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
|
||||
- AUTH_TYPE=disabled
|
||||
- AUTH_TYPE=basic
|
||||
- POSTGRES_HOST=relational_db
|
||||
- VESPA_HOST=index
|
||||
- REDIS_HOST=cache
|
||||
|
||||
@@ -20,8 +20,11 @@ IMAGE_TAG=latest
|
||||
|
||||
## Auth Settings
|
||||
### https://docs.onyx.app/deployment/authentication
|
||||
AUTH_TYPE=disabled
|
||||
AUTH_TYPE=basic
|
||||
# SESSION_EXPIRE_TIME_SECONDS=
|
||||
### Required for basic auth - used for signing password reset and verification tokens
|
||||
### Generate a secure value with: openssl rand -hex 32
|
||||
USER_AUTH_SECRET=OnyxDevSecret1!
|
||||
### Recommend to set this for security
|
||||
# ENCRYPTION_KEY_SECRET=
|
||||
### Optional
|
||||
|
||||
@@ -654,16 +654,9 @@ else
|
||||
sed -i.bak "s/^IMAGE_TAG=.*/IMAGE_TAG=$VERSION/" "$ENV_FILE"
|
||||
print_success "IMAGE_TAG set to $VERSION"
|
||||
|
||||
# Configure authentication settings based on selection
|
||||
if [ "$AUTH_SCHEMA" = "disabled" ]; then
|
||||
# Disable authentication in .env file
|
||||
sed -i.bak 's/^AUTH_TYPE=.*/AUTH_TYPE=disabled/' "$ENV_FILE" 2>/dev/null || true
|
||||
print_success "Authentication disabled in configuration"
|
||||
else
|
||||
# Enable basic authentication
|
||||
sed -i.bak 's/^AUTH_TYPE=.*/AUTH_TYPE=basic/' "$ENV_FILE" 2>/dev/null || true
|
||||
print_success "Basic authentication enabled in configuration"
|
||||
fi
|
||||
# Configure basic authentication (default)
|
||||
sed -i.bak 's/^AUTH_TYPE=.*/AUTH_TYPE=basic/' "$ENV_FILE" 2>/dev/null || true
|
||||
print_success "Basic authentication enabled in configuration"
|
||||
|
||||
# Configure Craft based on flag or if using a craft-* image tag
|
||||
# By default, env.template has Craft commented out (disabled)
|
||||
|
||||
@@ -1167,10 +1167,25 @@ auth:
|
||||
values:
|
||||
opensearch_admin_username: "admin"
|
||||
opensearch_admin_password: "OnyxDev1!"
|
||||
userauth:
|
||||
# -- Required when AUTH_TYPE is "basic". Used for signing password reset
|
||||
# tokens, email verification tokens, and JWT tokens.
|
||||
enabled: true
|
||||
# -- Overwrite the default secret name, ignored if existingSecret is defined
|
||||
secretName: 'onyx-userauth'
|
||||
# -- Use a secret specified elsewhere
|
||||
existingSecret: ""
|
||||
# -- This defines the env var to secret map
|
||||
secretKeys:
|
||||
USER_AUTH_SECRET: user_auth_secret
|
||||
# -- Secret value. CHANGE THIS FOR PRODUCTION.
|
||||
# Generate a secure value with: openssl rand -hex 32
|
||||
values:
|
||||
user_auth_secret: "OnyxDevSecret1!"
|
||||
|
||||
configMap:
|
||||
# Change this for production uses unless Onyx is only accessible behind VPN
|
||||
AUTH_TYPE: "disabled"
|
||||
AUTH_TYPE: "basic"
|
||||
# 1 Day Default
|
||||
SESSION_EXPIRE_TIME_SECONDS: "86400"
|
||||
# Can be something like onyx.app, as an extra double-check
|
||||
|
||||
Reference in New Issue
Block a user