Compare commits

...

5 Commits

Author SHA1 Message Date
Jessica Singh
d0c9d36692 changes to ecs fargate 2026-02-18 12:17:13 -08:00
Jessica Singh
e03bf2a6a3 sign 2026-02-17 13:30:24 -08:00
Jessica Singh
7c8c7c9d91 add anon tests 2026-02-17 11:34:17 -08:00
Jessica Singh
89d8521f37 changes 2026-02-16 23:16:41 -08:00
Jessica Singh
24a0e08ee2 changes 2026-02-16 23:15:14 -08:00
42 changed files with 373 additions and 661 deletions

View File

@@ -1,3 +1,4 @@
import os
from datetime import datetime
import jwt
@@ -20,7 +21,13 @@ logger = setup_logger()
def verify_auth_setting() -> None:
# All the Auth flows are valid for EE version
# All the Auth flows are valid for EE version, but warn about deprecated 'disabled'
raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
if raw_auth_type == "disabled":
logger.warning(
"AUTH_TYPE='disabled' is no longer supported. "
"Using 'basic' instead. Please update your configuration."
)
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")

View File

@@ -1,4 +1,5 @@
import json
import os
import random
import secrets
import string
@@ -143,10 +144,22 @@ def is_user_admin(user: User) -> bool:
def verify_auth_setting() -> None:
if AUTH_TYPE == AuthType.CLOUD:
"""Log warnings for AUTH_TYPE issues.
This only runs on app startup not during migrations/scripts.
"""
raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
if raw_auth_type == "cloud":
raise ValueError(
f"{AUTH_TYPE.value} is not a valid auth type for self-hosted deployments."
"'cloud' is not a valid auth type for self-hosted deployments."
)
if raw_auth_type == "disabled":
logger.warning(
"AUTH_TYPE='disabled' is no longer supported. "
"Using 'basic' instead. Please update your configuration."
)
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")

View File

@@ -85,19 +85,12 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
#####
# Auth Configs
#####
# Upgrades users from disabled auth to basic auth and shows warning.
_auth_type_str = (os.environ.get("AUTH_TYPE") or "basic").lower()
if _auth_type_str == "disabled":
logger.warning(
"AUTH_TYPE='disabled' is no longer supported. "
"Defaulting to 'basic'. Please update your configuration. "
"Your existing data will be migrated automatically."
)
_auth_type_str = AuthType.BASIC.value
try:
# Silently default to basic - warnings/errors logged in verify_auth_setting()
# which only runs on app startup, not during migrations/scripts
_auth_type_str = (os.environ.get("AUTH_TYPE") or "").lower()
if _auth_type_str in [auth_type.value for auth_type in AuthType]:
AUTH_TYPE = AuthType(_auth_type_str)
except ValueError:
logger.error(f"Invalid AUTH_TYPE: {_auth_type_str}. Defaulting to 'basic'.")
else:
AUTH_TYPE = AuthType.BASIC
PASSWORD_MIN_LENGTH = int(os.getenv("PASSWORD_MIN_LENGTH", 8))

View File

@@ -1831,7 +1831,7 @@ def get_connector_by_id(
@router.post("/connector-request")
def submit_connector_request(
request_data: ConnectorRequestSubmission,
user: User | None = Depends(current_user),
user: User = Depends(current_user),
) -> StatusResponse:
"""
Submit a connector request for Cloud deployments.
@@ -1844,7 +1844,7 @@ def submit_connector_request(
raise HTTPException(status_code=400, detail="Connector name cannot be empty")
# Get user identifier for telemetry
user_email = user.email if user else None
user_email = user.email
distinct_id = user_email or tenant_id
# Track connector request via PostHog telemetry (Cloud only)

View File

@@ -57,9 +57,6 @@ def list_messages(
db_session: Session = Depends(get_session),
) -> MessageListResponse:
"""Get all messages for a build session."""
if user is None:
raise HTTPException(status_code=401, detail="Authentication required")
session_manager = SessionManager(db_session)
messages = session_manager.list_messages(session_id, user.id)

View File

@@ -43,18 +43,14 @@ def _require_opensearch(db_session: Session) -> None:
)
def _get_user_access_info(
user: User | None, db_session: Session
) -> tuple[str | None, list[str]]:
if not user:
return None, []
def _get_user_access_info(user: User, db_session: Session) -> tuple[str, list[str]]:
return user.email, get_user_external_group_ids(db_session, user)
@router.get(HIERARCHY_NODES_LIST_PATH)
def list_accessible_hierarchy_nodes(
source: DocumentSource,
user: User | None = Depends(current_user),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> HierarchyNodesResponse:
_require_opensearch(db_session)
@@ -81,7 +77,7 @@ def list_accessible_hierarchy_nodes(
@router.post(HIERARCHY_NODE_DOCUMENTS_PATH)
def list_accessible_hierarchy_node_documents(
documents_request: HierarchyNodeDocumentsRequest,
user: User | None = Depends(current_user),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> HierarchyNodeDocumentsResponse:
_require_opensearch(db_session)

View File

@@ -1013,7 +1013,7 @@ def get_mcp_servers_for_assistant(
@router.get("/servers", response_model=MCPServersResponse)
def get_mcp_servers_for_user(
db: Session = Depends(get_session),
user: User | None = Depends(current_user),
user: User = Depends(current_user),
) -> MCPServersResponse:
"""List all MCP servers for use in agent configuration and chat UI.

View File

@@ -13,9 +13,9 @@ from tests.integration.common_utils.test_models import DATestUser
class APIKeyManager:
@staticmethod
def create(
user_performing_action: DATestUser,
name: str | None = None,
api_key_role: UserRole = UserRole.ADMIN,
user_performing_action: DATestUser | None = None,
) -> DATestAPIKey:
name = f"{name}-api-key" if name else f"test-api-key-{uuid4()}"
api_key_request = APIKeyArgs(
@@ -25,11 +25,7 @@ class APIKeyManager:
api_key_response = requests.post(
f"{API_SERVER_URL}/admin/api-key",
json=api_key_request.model_dump(),
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
api_key_response.raise_for_status()
api_key = api_key_response.json()
@@ -48,29 +44,21 @@ class APIKeyManager:
@staticmethod
def delete(
api_key: DATestAPIKey,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> None:
api_key_response = requests.delete(
f"{API_SERVER_URL}/admin/api-key/{api_key.api_key_id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
api_key_response.raise_for_status()
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> list[DATestAPIKey]:
api_key_response = requests.get(
f"{API_SERVER_URL}/admin/api-key",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
api_key_response.raise_for_status()
return [DATestAPIKey(**api_key) for api_key in api_key_response.json()]
@@ -78,8 +66,8 @@ class APIKeyManager:
@staticmethod
def verify(
api_key: DATestAPIKey,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
retrieved_keys = APIKeyManager.get_all(
user_performing_action=user_performing_action

View File

@@ -17,7 +17,6 @@ from onyx.server.documents.models import DocumentSource
from onyx.server.documents.models import DocumentSyncStatus
from tests.integration.common_utils.config import api_config
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
@@ -28,10 +27,10 @@ from tests.integration.common_utils.test_models import DATestUser
def _cc_pair_creator(
connector_id: int,
credential_id: int,
user_performing_action: DATestUser,
name: str | None = None,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestCCPair:
name = f"{name}-cc-pair" if name else f"test-cc-pair-{uuid4()}"
@@ -40,17 +39,12 @@ def _cc_pair_creator(
connector_credential_pair_metadata = api.ConnectorCredentialPairMetadata(
name=name, access_type=access_type, groups=groups or []
)
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
api_response: api.StatusResponseInt = (
api_instance.associate_credential_to_connector(
connector_id,
credential_id,
connector_credential_pair_metadata,
_headers=headers,
_headers=user_performing_action.headers,
)
)
@@ -67,6 +61,7 @@ def _cc_pair_creator(
class CCPairManager:
@staticmethod
def create_from_scratch(
user_performing_action: DATestUser,
name: str | None = None,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
@@ -74,26 +69,25 @@ class CCPairManager:
input_type: InputType = InputType.LOAD_STATE,
connector_specific_config: dict[str, Any] | None = None,
credential_json: dict[str, Any] | None = None,
user_performing_action: DATestUser | None = None,
refresh_freq: int | None = None,
) -> DATestCCPair:
connector = ConnectorManager.create(
user_performing_action=user_performing_action,
name=name,
source=source,
input_type=input_type,
connector_specific_config=connector_specific_config,
access_type=access_type,
groups=groups,
user_performing_action=user_performing_action,
refresh_freq=refresh_freq,
)
credential = CredentialManager.create(
user_performing_action=user_performing_action,
credential_json=credential_json,
name=name,
source=source,
curator_public=(access_type == AccessType.PUBLIC),
groups=groups,
user_performing_action=user_performing_action,
)
cc_pair = _cc_pair_creator(
connector_id=connector.id,
@@ -109,10 +103,10 @@ class CCPairManager:
def create(
connector_id: int,
credential_id: int,
user_performing_action: DATestUser,
name: str | None = None,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestCCPair:
cc_pair = _cc_pair_creator(
connector_id=connector_id,
@@ -127,39 +121,31 @@ class CCPairManager:
@staticmethod
def pause_cc_pair(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> None:
result = requests.put(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status",
json={"status": "PAUSED"},
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
result.raise_for_status()
@staticmethod
def unpause_cc_pair(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> None:
result = requests.put(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status",
json={"status": "ACTIVE"},
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
result.raise_for_status()
@staticmethod
def delete(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> None:
cc_pair_identifier = ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
@@ -168,26 +154,18 @@ class CCPairManager:
result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/deletion-attempt",
json=cc_pair_identifier.model_dump(),
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
result.raise_for_status()
@staticmethod
def get_single(
cc_pair_id: int,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> CCPairFullInfo | None:
response = requests.get(
f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
cc_pair_json = response.json()
@@ -196,15 +174,11 @@ class CCPairManager:
@staticmethod
def get_indexing_status_by_id(
cc_pair_id: int,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> ConnectorIndexingStatusLite | None:
response = requests.post(
f"{API_SERVER_URL}/manage/admin/connector/indexing-status",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
json={"get_all_connectors": True},
)
response.raise_for_status()
@@ -219,15 +193,11 @@ class CCPairManager:
@staticmethod
def get_indexing_statuses(
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> list[ConnectorIndexingStatusLite]:
response = requests.post(
f"{API_SERVER_URL}/manage/admin/connector/indexing-status",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
json={"get_all_connectors": True},
)
response.raise_for_status()
@@ -241,15 +211,11 @@ class CCPairManager:
@staticmethod
def get_connector_statuses(
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> list[ConnectorStatus]:
response = requests.get(
f"{API_SERVER_URL}/manage/admin/connector/status",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return [ConnectorStatus(**status) for status in response.json()]
@@ -257,8 +223,8 @@ class CCPairManager:
@staticmethod
def verify(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
all_cc_pairs = CCPairManager.get_connector_statuses(user_performing_action)
for retrieved_cc_pair in all_cc_pairs:
@@ -285,7 +251,7 @@ class CCPairManager:
def run_once(
cc_pair: DATestCCPair,
from_beginning: bool,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> None:
body = {
"connector_id": cc_pair.connector_id,
@@ -295,19 +261,15 @@ class CCPairManager:
result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/connector/run-once",
json=body,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
result.raise_for_status()
@staticmethod
def wait_for_indexing_inactive(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""wait for the number of docs to be indexed on the connector.
This is used to test pausing a connector in the middle of indexing and
@@ -342,9 +304,9 @@ class CCPairManager:
@staticmethod
def wait_for_indexing_in_progress(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
timeout: float = MAX_DELAY,
num_docs: int = 16,
user_performing_action: DATestUser | None = None,
) -> None:
"""wait for the number of docs to be indexed on the connector.
This is used to test pausing a connector in the middle of indexing and
@@ -393,8 +355,8 @@ class CCPairManager:
def wait_for_indexing_completion(
cc_pair: DATestCCPair,
after: datetime,
user_performing_action: DATestUser,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""after: Wait for an indexing success time after this time"""
start = time.monotonic()
@@ -430,30 +392,22 @@ class CCPairManager:
@staticmethod
def prune(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> None:
result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
result.raise_for_status()
@staticmethod
def last_pruned(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> datetime | None:
response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/last_pruned",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
response_str = response.json()
@@ -471,8 +425,8 @@ class CCPairManager:
def wait_for_prune(
cc_pair: DATestCCPair,
after: datetime,
user_performing_action: DATestUser,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""after: The task register time must be after this time."""
start = time.monotonic()
@@ -496,7 +450,7 @@ class CCPairManager:
@staticmethod
def sync(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> None:
"""This function triggers a permission sync.
Naming / intent of this function probably could use improvement, but currently it's letting
@@ -504,22 +458,14 @@ class CCPairManager:
"""
result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-permissions",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
if result.status_code != 409:
result.raise_for_status()
group_sync_result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-groups",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
if group_sync_result.status_code != 409:
group_sync_result.raise_for_status()
@@ -528,15 +474,11 @@ class CCPairManager:
@staticmethod
def get_doc_sync_task(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> datetime | None:
doc_sync_response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-permissions",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
doc_sync_response.raise_for_status()
doc_sync_response_str = doc_sync_response.json()
@@ -553,15 +495,11 @@ class CCPairManager:
@staticmethod
def get_group_sync_task(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> datetime | None:
group_sync_response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-groups",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
group_sync_response.raise_for_status()
group_sync_response_str = group_sync_response.json()
@@ -578,15 +516,11 @@ class CCPairManager:
@staticmethod
def get_doc_sync_statuses(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> list[DocumentSyncStatus]:
response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/get-docs-sync-status",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
doc_sync_statuses: list[DocumentSyncStatus] = []
@@ -613,9 +547,9 @@ class CCPairManager:
def wait_for_sync(
cc_pair: DATestCCPair,
after: datetime,
user_performing_action: DATestUser,
timeout: float = MAX_DELAY,
number_of_updated_docs: int = 0,
user_performing_action: DATestUser | None = None,
# Sometimes waiting for a group sync is not necessary
should_wait_for_group_sync: bool = True,
# Sometimes waiting for a vespa sync is not necessary
@@ -703,8 +637,8 @@ class CCPairManager:
@staticmethod
def wait_for_deletion_completion(
user_performing_action: DATestUser,
cc_pair_id: int | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
"""if cc_pair_id is not specified, just waits until no connectors are in the deleting state.
if cc_pair_id is specified, checks to ensure the specific cc_pair_id is gone.

View File

@@ -17,7 +17,6 @@ from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.streaming_models import StreamingType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestChatMessage
from tests.integration.common_utils.test_models import DATestChatSession
from tests.integration.common_utils.test_models import DATestUser
@@ -74,9 +73,9 @@ class StreamPacketData(TypedDict, total=False):
class ChatSessionManager:
@staticmethod
def create(
user_performing_action: DATestUser,
persona_id: int = 0,
description: str = "Test chat session",
user_performing_action: DATestUser | None = None,
) -> DATestChatSession:
chat_session_creation_req = ChatSessionCreationRequest(
persona_id=persona_id, description=description
@@ -84,11 +83,7 @@ class ChatSessionManager:
response = requests.post(
f"{API_SERVER_URL}/chat/create-chat-session",
json=chat_session_creation_req.model_dump(),
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
chat_session_id = response.json()["chat_session_id"]
@@ -100,8 +95,8 @@ class ChatSessionManager:
def send_message(
chat_session_id: UUID,
message: str,
user_performing_action: DATestUser,
parent_message_id: int | None = None,
user_performing_action: DATestUser | None = None,
file_descriptors: list[FileDescriptor] | None = None,
allowed_tool_ids: list[int] | None = None,
forced_tool_ids: list[int] | None = None,
@@ -126,19 +121,12 @@ class ChatSessionManager:
llm_override=llm_override,
)
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
cookies = user_performing_action.cookies if user_performing_action else None
response = requests.post(
f"{API_SERVER_URL}/chat/send-chat-message",
json=chat_message_req.model_dump(mode="json"),
headers=headers,
headers=user_performing_action.headers,
stream=True,
cookies=cookies,
cookies=user_performing_action.cookies,
)
streamed_response = ChatSessionManager.analyze_response(response)
@@ -167,9 +155,9 @@ class ChatSessionManager:
def send_message_with_disconnect(
chat_session_id: UUID,
message: str,
user_performing_action: DATestUser,
disconnect_after_packets: int = 0,
parent_message_id: int | None = None,
user_performing_action: DATestUser | None = None,
file_descriptors: list[FileDescriptor] | None = None,
allowed_tool_ids: list[int] | None = None,
forced_tool_ids: list[int] | None = None,
@@ -208,21 +196,14 @@ class ChatSessionManager:
llm_override=llm_override,
)
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
cookies = user_performing_action.cookies if user_performing_action else None
packets_received = 0
with requests.post(
f"{API_SERVER_URL}/chat/send-chat-message",
json=chat_message_req.model_dump(mode="json"),
headers=headers,
headers=user_performing_action.headers,
stream=True,
cookies=cookies,
cookies=user_performing_action.cookies,
) as response:
for line in response.iter_lines():
if not line:
@@ -359,15 +340,11 @@ class ChatSessionManager:
@staticmethod
def get_chat_history(
chat_session: DATestChatSession,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> list[DATestChatMessage]:
response = requests.get(
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
@@ -387,7 +364,7 @@ class ChatSessionManager:
def create_chat_message_feedback(
message_id: int,
is_positive: bool,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
feedback_text: str | None = None,
predefined_feedback: str | None = None,
) -> None:
@@ -399,18 +376,14 @@ class ChatSessionManager:
"feedback_text": feedback_text,
"predefined_feedback": predefined_feedback,
},
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
@staticmethod
def delete(
chat_session: DATestChatSession,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> bool:
"""
Delete a chat session and all its related records (messages, agent data, etc.)
@@ -420,18 +393,14 @@ class ChatSessionManager:
"""
response = requests.delete(
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
return response.ok
@staticmethod
def soft_delete(
chat_session: DATestChatSession,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> bool:
"""
Soft delete a chat session (marks as deleted but keeps in database).
@@ -442,18 +411,14 @@ class ChatSessionManager:
# or make a direct call with hard_delete=False parameter via a new endpoint
response = requests.delete(
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}?hard_delete=false",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
return response.ok
@staticmethod
def hard_delete(
chat_session: DATestChatSession,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> bool:
"""
Hard delete a chat session (completely removes from database).
@@ -462,18 +427,14 @@ class ChatSessionManager:
"""
response = requests.delete(
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}?hard_delete=true",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
return response.ok
@staticmethod
def verify_deleted(
chat_session: DATestChatSession,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> bool:
"""
Verify that a chat session has been deleted by attempting to retrieve it.
@@ -482,11 +443,7 @@ class ChatSessionManager:
"""
response = requests.get(
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
# Chat session should return 404 if it doesn't exist or is deleted
return response.status_code == 404
@@ -494,7 +451,7 @@ class ChatSessionManager:
@staticmethod
def verify_soft_deleted(
chat_session: DATestChatSession,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> bool:
"""
Verify that a chat session has been soft deleted (marked as deleted but still in DB).
@@ -504,11 +461,7 @@ class ChatSessionManager:
# Try to get the chat session with include_deleted=true
response = requests.get(
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}?include_deleted=true",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
if response.status_code == 200:
@@ -520,7 +473,7 @@ class ChatSessionManager:
@staticmethod
def verify_hard_deleted(
chat_session: DATestChatSession,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> bool:
"""
Verify that a chat session has been hard deleted (completely removed from DB).
@@ -530,11 +483,7 @@ class ChatSessionManager:
# Try to get the chat session with include_deleted=true
response = requests.get(
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}?include_deleted=true",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
# For hard delete, even with include_deleted=true, the record should not exist

View File

@@ -8,7 +8,6 @@ from onyx.db.enums import AccessType
from onyx.server.documents.models import ConnectorUpdateRequest
from onyx.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestConnector
from tests.integration.common_utils.test_models import DATestUser
@@ -16,13 +15,13 @@ from tests.integration.common_utils.test_models import DATestUser
class ConnectorManager:
@staticmethod
def create(
user_performing_action: DATestUser,
name: str | None = None,
source: DocumentSource = DocumentSource.FILE,
input_type: InputType = InputType.LOAD_STATE,
connector_specific_config: dict[str, Any] | None = None,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
user_performing_action: DATestUser | None = None,
refresh_freq: int | None = None,
) -> DATestConnector:
name = f"{name}-connector" if name else f"test-connector-{uuid4()}"
@@ -51,11 +50,7 @@ class ConnectorManager:
response = requests.post(
url=f"{API_SERVER_URL}/manage/admin/connector",
json=connector_update_request.model_dump(),
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
@@ -73,45 +68,33 @@ class ConnectorManager:
@staticmethod
def edit(
connector: DATestConnector,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> None:
response = requests.patch(
url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}",
json=connector.model_dump(exclude={"id"}),
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
@staticmethod
def delete(
connector: DATestConnector,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> None:
response = requests.delete(
url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> list[DATestConnector]:
response = requests.get(
url=f"{API_SERVER_URL}/manage/connector",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return [
@@ -127,15 +110,12 @@ class ConnectorManager:
@staticmethod
def get(
connector_id: int, user_performing_action: DATestUser | None = None
connector_id: int,
user_performing_action: DATestUser,
) -> DATestConnector:
response = requests.get(
url=f"{API_SERVER_URL}/manage/connector/{connector_id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
conn = response.json()

View File

@@ -6,7 +6,6 @@ import requests
from onyx.server.documents.models import CredentialSnapshot
from onyx.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestCredential
from tests.integration.common_utils.test_models import DATestUser
@@ -14,13 +13,13 @@ from tests.integration.common_utils.test_models import DATestUser
class CredentialManager:
@staticmethod
def create(
user_performing_action: DATestUser,
credential_json: dict[str, Any] | None = None,
admin_public: bool = True,
name: str | None = None,
source: DocumentSource = DocumentSource.FILE,
curator_public: bool = True,
groups: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestCredential:
name = f"{name}-credential" if name else f"test-credential-{uuid4()}"
@@ -36,11 +35,7 @@ class CredentialManager:
response = requests.post(
url=f"{API_SERVER_URL}/manage/credential",
json=credential_request,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
@@ -57,61 +52,46 @@ class CredentialManager:
@staticmethod
def edit(
credential: DATestCredential,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> None:
request = credential.model_dump(include={"name", "credential_json"})
response = requests.put(
url=f"{API_SERVER_URL}/manage/admin/credential/{credential.id}",
json=request,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
@staticmethod
def delete(
credential: DATestCredential,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> None:
response = requests.delete(
url=f"{API_SERVER_URL}/manage/credential/{credential.id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
@staticmethod
def get(
credential_id: int, user_performing_action: DATestUser | None = None
credential_id: int,
user_performing_action: DATestUser,
) -> CredentialSnapshot:
response = requests.get(
url=f"{API_SERVER_URL}/manage/credential/{credential_id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return CredentialSnapshot(**response.json())
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> list[CredentialSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/manage/credential",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return [CredentialSnapshot(**cred) for cred in response.json()]
@@ -119,8 +99,8 @@ class CredentialManager:
@staticmethod
def verify(
credential: DATestCredential,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
all_credentials = CredentialManager.get_all(user_performing_action)
for fetched_credential in all_credentials:

View File

@@ -10,7 +10,6 @@ from onyx.db.enums import AccessType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import DocumentByConnectorCredentialPair
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.managers.api_key import DATestAPIKey
from tests.integration.common_utils.managers.cc_pair import DATestCCPair
@@ -22,9 +21,9 @@ from tests.integration.common_utils.vespa import vespa_fixture
def _verify_document_permissions(
retrieved_doc: dict,
cc_pair: DATestCCPair,
doc_creating_user: DATestUser,
doc_set_names: list[str] | None = None,
group_names: list[str] | None = None,
doc_creating_user: DATestUser | None = None,
) -> None:
acl_keys = set(retrieved_doc.get("access_control_list", {}).keys())
print(f"ACL keys: {acl_keys}")
@@ -36,12 +35,11 @@ def _verify_document_permissions(
" does not have the PUBLIC ACL key"
)
if doc_creating_user is not None:
if f"user_email:{doc_creating_user.email}" not in acl_keys:
raise ValueError(
f"Document {retrieved_doc['document_id']} was created by user"
f" {doc_creating_user.email} but does not have the user_email:{doc_creating_user.email} ACL key"
)
if f"user_email:{doc_creating_user.email}" not in acl_keys:
raise ValueError(
f"Document {retrieved_doc['document_id']} was created by user"
f" {doc_creating_user.email} but does not have the user_email:{doc_creating_user.email} ACL key"
)
if group_names is not None:
expected_group_keys = {f"group:{group_name}" for group_name in group_names}
@@ -101,9 +99,9 @@ class DocumentManager:
@staticmethod
def seed_dummy_docs(
cc_pair: DATestCCPair,
api_key: DATestAPIKey,
num_docs: int = NUM_DOCS,
document_ids: list[str] | None = None,
api_key: DATestAPIKey | None = None,
) -> list[SimpleTestDocument]:
# Use provided document_ids if available, otherwise generate random UUIDs
if document_ids is None:
@@ -118,12 +116,13 @@ class DocumentManager:
response = requests.post(
f"{API_SERVER_URL}/onyx-api/ingestion",
json=document,
headers=api_key.headers if api_key else GENERAL_HEADERS,
headers=api_key.headers,
)
response.raise_for_status()
api_key_id = api_key.api_key_id if api_key else ""
print(f"Seeding docs for api_key_id={api_key_id} completed successfully.")
print(
f"Seeding docs for api_key_id={api_key.api_key_id} completed successfully."
)
return [
SimpleTestDocument(
id=document["document"]["id"],
@@ -136,8 +135,8 @@ class DocumentManager:
def seed_doc_with_content(
cc_pair: DATestCCPair,
content: str,
api_key: DATestAPIKey,
document_id: str | None = None,
api_key: DATestAPIKey | None = None,
metadata: dict | None = None,
) -> SimpleTestDocument:
# Use provided document_ids if available, otherwise generate random UUIDs
@@ -153,12 +152,13 @@ class DocumentManager:
response = requests.post(
f"{API_SERVER_URL}/onyx-api/ingestion",
json=document,
headers=api_key.headers if api_key else GENERAL_HEADERS,
headers=api_key.headers,
)
response.raise_for_status()
api_key_id = api_key.api_key_id if api_key else ""
print(f"Seeding doc for api_key_id={api_key_id} completed successfully.")
print(
f"Seeding doc for api_key_id={api_key.api_key_id} completed successfully."
)
return SimpleTestDocument(
id=document["document"]["id"],
@@ -169,11 +169,11 @@ class DocumentManager:
def verify(
vespa_client: vespa_fixture,
cc_pair: DATestCCPair,
doc_creating_user: DATestUser,
# If None, will not check doc sets or groups
# If empty list, will check for empty doc sets or groups
doc_set_names: list[str] | None = None,
group_names: list[str] | None = None,
doc_creating_user: DATestUser | None = None,
verify_deleted: bool = False,
) -> None:
doc_ids = [document.id for document in cc_pair.documents]
@@ -212,9 +212,9 @@ class DocumentManager:
_verify_document_permissions(
retrieved_doc,
cc_pair,
doc_creating_user,
doc_set_names,
group_names,
doc_creating_user,
)
@staticmethod
@@ -268,11 +268,11 @@ class IngestionManager(DocumentManager):
@staticmethod
def list_all_ingestion_docs(
api_key: DATestAPIKey | None = None,
api_key: DATestAPIKey,
) -> list[dict]:
response = requests.get(
f"{API_SERVER_URL}/onyx-api/ingestion",
headers=api_key.headers if api_key else GENERAL_HEADERS,
headers=api_key.headers,
)
response.raise_for_status()
return response.json()
@@ -280,11 +280,11 @@ class IngestionManager(DocumentManager):
@staticmethod
def delete(
document_id: str,
api_key: DATestAPIKey | None = None,
api_key: DATestAPIKey,
) -> None:
response = requests.delete(
f"{API_SERVER_URL}/onyx-api/ingestion/{document_id}",
headers=api_key.headers if api_key else GENERAL_HEADERS,
headers=api_key.headers,
)
response.raise_for_status()
print(f"Deleted document {document_id} successfully.")

View File

@@ -3,7 +3,6 @@ import requests
from ee.onyx.server.query_and_chat.models import SearchFullResponse
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
@@ -11,7 +10,7 @@ class DocumentSearchManager:
@staticmethod
def search_documents(
query: str,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> list[str]:
"""
Search for documents using the EE search API.
@@ -31,11 +30,7 @@ class DocumentSearchManager:
result = requests.post(
url=f"{API_SERVER_URL}/search/send-search-message",
json=search_request.model_dump(),
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
result.raise_for_status()
result_json = result.json()

View File

@@ -6,7 +6,6 @@ from uuid import uuid4
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import DATestDocumentSet
from tests.integration.common_utils.test_models import DATestUser
@@ -15,6 +14,7 @@ from tests.integration.common_utils.test_models import DATestUser
class DocumentSetManager:
@staticmethod
def create(
user_performing_action: DATestUser,
name: str | None = None,
description: str | None = None,
cc_pair_ids: list[int] | None = None,
@@ -22,7 +22,6 @@ class DocumentSetManager:
users: list[str] | None = None,
groups: list[int] | None = None,
federated_connectors: list[dict[str, Any]] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestDocumentSet:
if name is None:
name = f"test_doc_set_{str(uuid4())}"
@@ -40,11 +39,7 @@ class DocumentSetManager:
response = requests.post(
f"{API_SERVER_URL}/manage/admin/document-set",
json=doc_set_creation_request,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
@@ -63,7 +58,7 @@ class DocumentSetManager:
@staticmethod
def edit(
document_set: DATestDocumentSet,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> bool:
doc_set_update_request = {
"id": document_set.id,
@@ -77,11 +72,7 @@ class DocumentSetManager:
response = requests.patch(
f"{API_SERVER_URL}/manage/admin/document-set",
json=doc_set_update_request,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return True
@@ -89,30 +80,22 @@ class DocumentSetManager:
@staticmethod
def delete(
document_set: DATestDocumentSet,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> bool:
response = requests.delete(
f"{API_SERVER_URL}/manage/admin/document-set/{document_set.id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return True
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> list[DATestDocumentSet]:
response = requests.get(
f"{API_SERVER_URL}/manage/document-set",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return [
@@ -132,8 +115,8 @@ class DocumentSetManager:
@staticmethod
def wait_for_sync(
user_performing_action: DATestUser,
document_sets_to_check: list[DATestDocumentSet] | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
# wait for document sets to be synced
start = time.time()
@@ -175,8 +158,8 @@ class DocumentSetManager:
@staticmethod
def verify(
document_set: DATestDocumentSet,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
doc_sets = DocumentSetManager.get_all(user_performing_action)
for doc_set in doc_sets:

View File

@@ -10,7 +10,6 @@ import requests
from onyx.file_store.models import FileDescriptor
from onyx.server.documents.models import FileUploadResponse
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
@@ -18,13 +17,9 @@ class FileManager:
@staticmethod
def upload_files(
files: List[Tuple[str, IO]],
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> Tuple[List[FileDescriptor], str]:
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
headers = user_performing_action.headers
headers.pop("Content-Type", None)
files_param = []
@@ -67,15 +62,11 @@ class FileManager:
@staticmethod
def fetch_uploaded_file(
file_id: str,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> bytes:
response = requests.get(
f"{API_SERVER_URL}/chat/file/{file_id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return response.content

View File

@@ -6,7 +6,6 @@ from uuid import uuid4
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestImageGenerationConfig
from tests.integration.common_utils.test_models import DATestUser
@@ -26,6 +25,7 @@ def _serialize_custom_config(
class ImageGenerationConfigManager:
@staticmethod
def create(
user_performing_action: DATestUser,
image_provider_id: str | None = None,
model_name: str = "gpt-image-1",
provider: str = "openai",
@@ -35,7 +35,6 @@ class ImageGenerationConfigManager:
deployment_name: str | None = None,
custom_config: dict[str, Any] | None = None,
is_default: bool = False,
user_performing_action: DATestUser | None = None,
) -> DATestImageGenerationConfig:
"""Create a new image generation config with new credentials."""
image_provider_id = image_provider_id or f"test-provider-{uuid4()}"
@@ -53,11 +52,7 @@ class ImageGenerationConfigManager:
"custom_config": _serialize_custom_config(custom_config),
"is_default": is_default,
},
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
data = response.json()
@@ -74,13 +69,13 @@ class ImageGenerationConfigManager:
@staticmethod
def create_from_provider(
source_llm_provider_id: int,
user_performing_action: DATestUser,
image_provider_id: str | None = None,
model_name: str = "gpt-image-1",
api_base: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
is_default: bool = False,
user_performing_action: DATestUser | None = None,
) -> DATestImageGenerationConfig:
"""Create a new image generation config by cloning from an existing LLM provider."""
image_provider_id = image_provider_id or f"test-provider-{uuid4()}"
@@ -96,11 +91,7 @@ class ImageGenerationConfigManager:
"deployment_name": deployment_name,
"is_default": is_default,
},
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
data = response.json()
@@ -116,16 +107,12 @@ class ImageGenerationConfigManager:
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> list[DATestImageGenerationConfig]:
"""Get all image generation configs."""
response = requests.get(
f"{API_SERVER_URL}/admin/image-generation/config",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return [DATestImageGenerationConfig(**config) for config in response.json()]
@@ -133,16 +120,12 @@ class ImageGenerationConfigManager:
@staticmethod
def get_credentials(
image_provider_id: str,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> dict:
"""Get credentials for an image generation config."""
response = requests.get(
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}/credentials",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return response.json()
@@ -151,13 +134,13 @@ class ImageGenerationConfigManager:
def update(
image_provider_id: str,
model_name: str,
user_performing_action: DATestUser,
provider: str | None = None,
api_key: str | None = None,
source_llm_provider_id: int | None = None,
api_base: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestImageGenerationConfig:
"""Update an existing image generation config."""
payload: dict = {
@@ -178,14 +161,10 @@ class ImageGenerationConfigManager:
f"Got: source_llm_provider_id={source_llm_provider_id}, provider={provider}, api_key={'***' if api_key else None}"
)
headers = {**GENERAL_HEADERS}
if user_performing_action:
headers.update(user_performing_action.headers)
response = requests.put(
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}",
json=payload,
headers=headers,
headers=user_performing_action.headers,
)
if not response.ok:
print(f"Update failed with status {response.status_code}: {response.text}")
@@ -204,40 +183,32 @@ class ImageGenerationConfigManager:
@staticmethod
def delete(
image_provider_id: str,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> None:
"""Delete an image generation config."""
response = requests.delete(
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
@staticmethod
def set_default(
image_provider_id: str,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> None:
"""Set an image generation config as the default."""
response = requests.post(
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}/default",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
@staticmethod
def verify(
config: DATestImageGenerationConfig,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
"""Verify that a config exists (or doesn't exist if verify_deleted=True)."""
all_configs = ImageGenerationConfigManager.get_all(user_performing_action)

View File

@@ -14,7 +14,6 @@ from onyx.db.search_settings import get_current_search_settings
from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import PaginatedReturn
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import DATestIndexAttempt
from tests.integration.common_utils.test_models import DATestUser
@@ -86,9 +85,9 @@ class IndexAttemptManager:
@staticmethod
def get_index_attempt_page(
cc_pair_id: int,
user_performing_action: DATestUser,
page: int = 0,
page_size: int = 10,
user_performing_action: DATestUser | None = None,
) -> PaginatedReturn[IndexAttemptSnapshot]:
query_params: dict[str, str | int] = {
"page_num": page,
@@ -101,11 +100,7 @@ class IndexAttemptManager:
)
response = requests.get(
url=url,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
data = response.json()
@@ -117,7 +112,7 @@ class IndexAttemptManager:
@staticmethod
def get_latest_index_attempt_for_cc_pair(
cc_pair_id: int,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> IndexAttemptSnapshot | None:
"""Get an IndexAttempt by ID"""
index_attempts = IndexAttemptManager.get_index_attempt_page(
@@ -134,9 +129,9 @@ class IndexAttemptManager:
@staticmethod
def wait_for_index_attempt_start(
cc_pair_id: int,
user_performing_action: DATestUser,
index_attempts_to_ignore: list[int] | None = None,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot:
"""Wait for an IndexAttempt to start"""
start = datetime.now()
@@ -164,7 +159,7 @@ class IndexAttemptManager:
def get_index_attempt_by_id(
index_attempt_id: int,
cc_pair_id: int,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> IndexAttemptSnapshot:
page_num = 0
page_size = 10
@@ -190,8 +185,8 @@ class IndexAttemptManager:
def wait_for_index_attempt_completion(
index_attempt_id: int,
cc_pair_id: int,
user_performing_action: DATestUser,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""Wait for an IndexAttempt to complete"""
start = time.monotonic()
@@ -223,19 +218,15 @@ class IndexAttemptManager:
@staticmethod
def get_index_attempt_errors_for_cc_pair(
cc_pair_id: int,
user_performing_action: DATestUser,
include_resolved: bool = True,
user_performing_action: DATestUser | None = None,
) -> list[IndexAttemptErrorPydantic]:
url = f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/errors?page_size=100"
if include_resolved:
url += "&include_resolved=true"
response = requests.get(
url=url,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
data = response.json()

View File

@@ -8,7 +8,6 @@ from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
@@ -16,6 +15,7 @@ from tests.integration.common_utils.test_models import DATestUser
class LLMProviderManager:
@staticmethod
def create(
user_performing_action: DATestUser,
name: str | None = None,
provider: str | None = None,
api_key: str | None = None,
@@ -26,13 +26,8 @@ class LLMProviderManager:
personas: list[int] | None = None,
is_public: bool | None = None,
set_as_default: bool = True,
user_performing_action: DATestUser | None = None,
) -> DATestLLMProvider:
email = "Unknown"
if user_performing_action:
email = user_performing_action.email
print(f"Seeding LLM Providers for {email}...")
print(f"Seeding LLM Providers for {user_performing_action.email}...")
llm_provider = LLMProviderUpsertRequest(
name=name or f"test-provider-{uuid4()}",
@@ -60,11 +55,7 @@ class LLMProviderManager:
llm_response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
json=llm_provider.model_dump(),
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
llm_response.raise_for_status()
response_data = llm_response.json()
@@ -86,11 +77,7 @@ class LLMProviderManager:
if set_as_default:
set_default_response = requests.post(
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
set_default_response.raise_for_status()
@@ -99,30 +86,22 @@ class LLMProviderManager:
@staticmethod
def delete(
llm_provider: DATestLLMProvider,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> bool:
response = requests.delete(
f"{API_SERVER_URL}/admin/llm/provider/{llm_provider.id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return True
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> list[LLMProviderView]:
response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return [LLMProviderView(**ug) for ug in response.json()]
@@ -130,8 +109,8 @@ class LLMProviderManager:
@staticmethod
def verify(
llm_provider: DATestLLMProvider,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
for fetched_llm_provider in all_llm_providers:

View File

@@ -7,7 +7,6 @@ from onyx.context.search.enums import RecencyBiasSetting
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestPersona
from tests.integration.common_utils.test_models import DATestPersonaLabel
from tests.integration.common_utils.test_models import DATestUser
@@ -16,6 +15,7 @@ from tests.integration.common_utils.test_models import DATestUser
class PersonaManager:
@staticmethod
def create(
user_performing_action: DATestUser,
name: str | None = None,
description: str | None = None,
system_prompt: str | None = None,
@@ -34,7 +34,6 @@ class PersonaManager:
groups: list[int] | None = None,
label_ids: list[int] | None = None,
user_file_ids: list[str] | None = None,
user_performing_action: DATestUser | None = None,
display_priority: int | None = None,
) -> DATestPersona:
name = name or f"test-persona-{uuid4()}"
@@ -67,11 +66,7 @@ class PersonaManager:
response = requests.post(
f"{API_SERVER_URL}/persona",
json=persona_creation_request.model_dump(mode="json"),
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
persona_data = response.json()
@@ -100,6 +95,7 @@ class PersonaManager:
@staticmethod
def edit(
persona: DATestPersona,
user_performing_action: DATestUser,
name: str | None = None,
description: str | None = None,
system_prompt: str | None = None,
@@ -117,7 +113,6 @@ class PersonaManager:
users: list[str] | None = None,
groups: list[int] | None = None,
label_ids: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestPersona:
system_prompt = system_prompt or f"System prompt for {persona.name}"
task_prompt = task_prompt or f"Task prompt for {persona.name}"
@@ -151,11 +146,7 @@ class PersonaManager:
response = requests.patch(
f"{API_SERVER_URL}/persona/{persona.id}",
json=persona_update_request.model_dump(mode="json"),
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
updated_persona_data = response.json()
@@ -187,15 +178,11 @@ class PersonaManager:
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> list[FullPersonaSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/admin/persona",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return [FullPersonaSnapshot(**persona) for persona in response.json()]
@@ -203,15 +190,11 @@ class PersonaManager:
@staticmethod
def get_one(
persona_id: int,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> list[FullPersonaSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/persona/{persona_id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return [FullPersonaSnapshot(**response.json())]
@@ -219,7 +202,7 @@ class PersonaManager:
@staticmethod
def verify(
persona: DATestPersona,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> bool:
all_personas = PersonaManager.get_one(
persona_id=persona.id,
@@ -388,15 +371,11 @@ class PersonaManager:
@staticmethod
def delete(
persona: DATestPersona,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> bool:
response = requests.delete(
f"{API_SERVER_URL}/persona/{persona.id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
return response.ok
@@ -405,18 +384,14 @@ class PersonaLabelManager:
@staticmethod
def create(
label: DATestPersonaLabel,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> DATestPersonaLabel:
response = requests.post(
f"{API_SERVER_URL}/persona/labels",
json={
"name": label.name,
},
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
response_data = response.json()
@@ -425,15 +400,11 @@ class PersonaLabelManager:
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> list[DATestPersonaLabel]:
response = requests.get(
f"{API_SERVER_URL}/persona/labels",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return [DATestPersonaLabel(**label) for label in response.json()]
@@ -441,18 +412,14 @@ class PersonaLabelManager:
@staticmethod
def update(
label: DATestPersonaLabel,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> DATestPersonaLabel:
response = requests.patch(
f"{API_SERVER_URL}/admin/persona/label/{label.id}",
json={
"label_name": label.name,
},
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return label
@@ -460,22 +427,18 @@ class PersonaLabelManager:
@staticmethod
def delete(
label: DATestPersonaLabel,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> bool:
response = requests.delete(
f"{API_SERVER_URL}/admin/persona/label/{label.id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
return response.ok
@staticmethod
def verify(
label: DATestPersonaLabel,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> bool:
all_labels = PersonaLabelManager.get_all(user_performing_action)
for fetched_label in all_labels:

View File

@@ -6,7 +6,6 @@ from onyx.server.features.projects.models import CategorizedFilesSnapshot
from onyx.server.features.projects.models import UserFileSnapshot
from onyx.server.features.projects.models import UserProjectSnapshot
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
@@ -20,7 +19,7 @@ class ProjectManager:
response = requests.post(
f"{API_SERVER_URL}/user/projects/create",
params={"name": name},
headers=user_performing_action.headers or GENERAL_HEADERS,
headers=user_performing_action.headers,
)
response.raise_for_status()
return UserProjectSnapshot.model_validate(response.json())
@@ -32,7 +31,7 @@ class ProjectManager:
"""Get all projects for a user via API."""
response = requests.get(
f"{API_SERVER_URL}/user/projects",
headers=user_performing_action.headers or GENERAL_HEADERS,
headers=user_performing_action.headers,
)
response.raise_for_status()
return [UserProjectSnapshot.model_validate(obj) for obj in response.json()]
@@ -45,7 +44,7 @@ class ProjectManager:
"""Delete a project via API."""
response = requests.delete(
f"{API_SERVER_URL}/user/projects/{project_id}",
headers=user_performing_action.headers or GENERAL_HEADERS,
headers=user_performing_action.headers,
)
return response.status_code == 204
@@ -57,7 +56,7 @@ class ProjectManager:
"""Verify that a project has been deleted by ensuring it's not in list."""
response = requests.get(
f"{API_SERVER_URL}/user/projects",
headers=user_performing_action.headers or GENERAL_HEADERS,
headers=user_performing_action.headers,
)
response.raise_for_status()
projects = [UserProjectSnapshot.model_validate(obj) for obj in response.json()]
@@ -66,16 +65,12 @@ class ProjectManager:
@staticmethod
def verify_files_unlinked(
project_id: int,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> bool:
"""Verify that all files have been unlinked from the project via API."""
response = requests.get(
f"{API_SERVER_URL}/user/projects/files/{project_id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
if response.status_code == 404:
return True
@@ -87,16 +82,12 @@ class ProjectManager:
@staticmethod
def verify_chat_sessions_unlinked(
project_id: int,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> bool:
"""Verify that all chat sessions have been unlinked from the project via API."""
response = requests.get(
f"{API_SERVER_URL}/user/projects/{project_id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
if response.status_code == 404:
return True
@@ -144,16 +135,12 @@ class ProjectManager:
@staticmethod
def get_project_files(
project_id: int,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> List[UserFileSnapshot]:
"""Get all files associated with a project via API."""
response = requests.get(
f"{API_SERVER_URL}/user/projects/files/{project_id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
if response.status_code == 404:
return []
@@ -170,7 +157,7 @@ class ProjectManager:
response = requests.post(
f"{API_SERVER_URL}/user/projects/{project_id}/instructions",
json={"instructions": instructions},
headers=user_performing_action.headers or GENERAL_HEADERS,
headers=user_performing_action.headers,
)
response.raise_for_status()
return (response.json() or {}).get("instructions") or ""

View File

@@ -10,19 +10,18 @@ from ee.onyx.server.query_history.models import ChatSessionSnapshot
from onyx.configs.constants import QAFeedbackType
from onyx.server.documents.models import PaginatedReturn
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
class QueryHistoryManager:
@staticmethod
def get_query_history_page(
user_performing_action: DATestUser,
page_num: int = 0,
page_size: int = 10,
feedback_type: QAFeedbackType | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
user_performing_action: DATestUser | None = None,
) -> PaginatedReturn[ChatSessionMinimal]:
query_params: dict[str, str | int] = {
"page_num": page_num,
@@ -37,11 +36,7 @@ class QueryHistoryManager:
response = requests.get(
url=f"{API_SERVER_URL}/admin/chat-session-history?{urlencode(query_params, doseq=True)}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
data = response.json()
@@ -53,24 +48,20 @@ class QueryHistoryManager:
@staticmethod
def get_chat_session_admin(
chat_session_id: UUID | str,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> ChatSessionSnapshot:
response = requests.get(
url=f"{API_SERVER_URL}/admin/chat-session-history/{chat_session_id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return ChatSessionSnapshot(**response.json())
@staticmethod
def get_query_history_as_csv(
user_performing_action: DATestUser,
start_time: datetime | None = None,
end_time: datetime | None = None,
user_performing_action: DATestUser | None = None,
) -> tuple[CaseInsensitiveDict[str], str]:
query_params: dict[str, str | int] = {}
if start_time:
@@ -80,11 +71,7 @@ class QueryHistoryManager:
response = requests.get(
url=f"{API_SERVER_URL}/admin/query-history-csv?{urlencode(query_params, doseq=True)}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return response.headers, response.content.decode()

View File

@@ -5,7 +5,6 @@ from typing import Optional
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestSettings
from tests.integration.common_utils.test_models import DATestUser
@@ -13,13 +12,9 @@ from tests.integration.common_utils.test_models import DATestUser
class SettingsManager:
@staticmethod
def get_settings(
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> tuple[Dict[str, Any], str]:
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
headers = user_performing_action.headers
headers.pop("Content-Type", None)
response = requests.get(
@@ -38,13 +33,9 @@ class SettingsManager:
@staticmethod
def update_settings(
settings: DATestSettings,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> tuple[Dict[str, Any], str]:
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
headers = user_performing_action.headers
headers.pop("Content-Type", None)
payload = settings.model_dump()
@@ -65,7 +56,7 @@ class SettingsManager:
@staticmethod
def get_setting(
key: str,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> Optional[Any]:
settings, error = SettingsManager.get_settings(user_performing_action)
if error:

View File

@@ -8,7 +8,6 @@ from onyx.server.manage.models import AllUsersResponse
from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
@@ -26,15 +25,11 @@ def generate_auth_token() -> str:
class TenantManager:
@staticmethod
def get_all_users(
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> AllUsersResponse:
response = requests.get(
url=f"{API_SERVER_URL}/manage/users",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
@@ -50,7 +45,8 @@ class TenantManager:
@staticmethod
def verify_user_in_tenant(
user: DATestUser, user_performing_action: DATestUser | None = None
user: DATestUser,
user_performing_action: DATestUser,
) -> None:
all_users = TenantManager.get_all_users(user_performing_action)
for accepted_user in all_users.accepted:

View File

@@ -1,7 +1,6 @@
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestTool
from tests.integration.common_utils.test_models import DATestUser
@@ -9,15 +8,11 @@ from tests.integration.common_utils.test_models import DATestUser
class ToolManager:
@staticmethod
def list_tools(
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> list[DATestTool]:
response = requests.get(
url=f"{API_SERVER_URL}/tool",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return [

View File

@@ -7,6 +7,8 @@ import requests
from requests import HTTPError
from onyx.auth.schemas import UserRole
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
from onyx.configs.constants import ANONYMOUS_USER_UUID
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.server.documents.models import PaginatedReturn
from onyx.server.manage.models import UserInfo
@@ -25,6 +27,23 @@ def build_email(name: str) -> str:
class UserManager:
@staticmethod
def get_anonymous_user() -> DATestUser:
"""Get a DATestUser representing the anonymous user.
Anonymous users are real users in the database with LIMITED role.
They don't have login cookies - requests are made with GENERAL_HEADERS.
The anonymous_user_enabled setting must be True for these requests to work.
"""
return DATestUser(
id=ANONYMOUS_USER_UUID,
email=ANONYMOUS_USER_EMAIL,
password="",
headers=GENERAL_HEADERS,
role=UserRole.LIMITED,
is_active=True,
)
@staticmethod
def create(
name: str | None = None,
@@ -227,12 +246,12 @@ class UserManager:
@staticmethod
def get_user_page(
user_performing_action: DATestUser,
page_num: int = 0,
page_size: int = 10,
search_query: str | None = None,
role_filter: list[UserRole] | None = None,
is_active_filter: bool | None = None,
user_performing_action: DATestUser | None = None,
) -> PaginatedReturn[FullUserSnapshot]:
query_params: dict[str, str | list[str] | int] = {
"page_num": page_num,
@@ -247,11 +266,7 @@ class UserManager:
response = requests.get(
url=f"{API_SERVER_URL}/manage/users/accepted?{urlencode(query_params, doseq=True)}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()

View File

@@ -5,7 +5,6 @@ import requests
from ee.onyx.server.user_group.models import UserGroup
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.test_models import DATestUserGroup
@@ -14,10 +13,10 @@ from tests.integration.common_utils.test_models import DATestUserGroup
class UserGroupManager:
@staticmethod
def create(
user_performing_action: DATestUser,
name: str | None = None,
user_ids: list[str] | None = None,
cc_pair_ids: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestUserGroup:
name = f"{name}-user-group" if name else f"test-user-group-{uuid4()}"
@@ -29,11 +28,7 @@ class UserGroupManager:
response = requests.post(
f"{API_SERVER_URL}/manage/admin/user-group",
json=request,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
test_user_group = DATestUserGroup(
@@ -47,31 +42,23 @@ class UserGroupManager:
@staticmethod
def edit(
user_group: DATestUserGroup,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> None:
response = requests.patch(
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}",
json=user_group.model_dump(),
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
@staticmethod
def delete(
user_group: DATestUserGroup,
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> None:
response = requests.delete(
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
@@ -79,7 +66,7 @@ class UserGroupManager:
def add_users(
user_group: DATestUserGroup,
user_ids: list[str],
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> DATestUserGroup:
request = {
"user_ids": user_ids,
@@ -88,11 +75,7 @@ class UserGroupManager:
response = requests.post(
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/add-users",
json=request,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
@@ -107,8 +90,8 @@ class UserGroupManager:
def set_curator_status(
test_user_group: DATestUserGroup,
user_to_set_as_curator: DATestUser,
user_performing_action: DATestUser,
is_curator: bool = True,
user_performing_action: DATestUser | None = None,
) -> None:
set_curator_request = {
"user_id": user_to_set_as_curator.id,
@@ -117,25 +100,17 @@ class UserGroupManager:
response = requests.post(
f"{API_SERVER_URL}/manage/admin/user-group/{test_user_group.id}/set-curator",
json=set_curator_request,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> list[UserGroup]:
response = requests.get(
f"{API_SERVER_URL}/manage/admin/user-group",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
headers=user_performing_action.headers,
)
response.raise_for_status()
return [UserGroup(**ug) for ug in response.json()]
@@ -143,8 +118,8 @@ class UserGroupManager:
@staticmethod
def verify(
user_group: DATestUserGroup,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
all_user_groups = UserGroupManager.get_all(user_performing_action)
for fetched_user_group in all_user_groups:
@@ -167,8 +142,8 @@ class UserGroupManager:
@staticmethod
def wait_for_sync(
user_performing_action: DATestUser,
user_groups_to_check: list[DATestUserGroup] | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
start = time.time()
while True:
@@ -198,7 +173,7 @@ class UserGroupManager:
@staticmethod
def wait_for_deletion_completion(
user_groups_to_check: list[DATestUserGroup],
user_performing_action: DATestUser | None = None,
user_performing_action: DATestUser,
) -> None:
start = time.time()
user_group_ids_to_check = {user_group.id for user_group in user_groups_to_check}

View File

@@ -88,11 +88,8 @@ def reset() -> None:
@pytest.fixture
def new_admin_user(reset: None) -> DATestUser | None: # noqa: ARG001
try:
return UserManager.create(name=ADMIN_USER_NAME)
except Exception:
return None
def new_admin_user(reset: None) -> DATestUser: # noqa: ARG001
return UserManager.create(name=ADMIN_USER_NAME)
@pytest.fixture
@@ -182,18 +179,18 @@ def reset_multitenant() -> None:
@pytest.fixture
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
def llm_provider(admin_user: DATestUser) -> DATestLLMProvider:
return LLMProviderManager.create(user_performing_action=admin_user)
@pytest.fixture
def image_generation_config(
admin_user: DATestUser | None,
admin_user: DATestUser,
) -> DATestImageGenerationConfig:
"""Create a default image generation config for tests."""
return ImageGenerationConfigManager.create(
is_default=True,
user_performing_action=admin_user,
is_default=True,
)

View File

@@ -60,3 +60,44 @@ def test_me_endpoint_returns_authenticated_user_info(
assert data.get("is_anonymous_user") is not True
assert data["email"] == admin_user.email
assert data["role"] == "admin"
def test_anonymous_user_can_access_persona_when_enabled(
reset: None, # noqa: ARG001
) -> None:
"""Verify that anonymous users can access limited endpoints when enabled."""
admin_user: DATestUser = UserManager.create(name="admin_user")
SettingsManager.update_settings(
DATestSettings(anonymous_user_enabled=True),
user_performing_action=admin_user,
)
anon_user = UserManager.get_anonymous_user()
response = requests.get(
f"{API_SERVER_URL}/persona",
headers=anon_user.headers,
)
assert response.status_code == 200
def test_anonymous_user_denied_persona_when_disabled(
reset: None, # noqa: ARG001
) -> None:
"""Verify that anonymous users cannot access endpoints when disabled."""
admin_user: DATestUser = UserManager.create(name="admin_user")
SettingsManager.update_settings(
DATestSettings(anonymous_user_enabled=False),
user_performing_action=admin_user,
)
anon_user = UserManager.get_anonymous_user()
response = requests.get(
f"{API_SERVER_URL}/persona",
headers=anon_user.headers,
)
# 403 is returned - BasicAuthenticationError uses HTTP 403 for all auth failures
assert response.status_code == 403

View File

@@ -11,8 +11,8 @@ from tests.integration.common_utils.test_models import DATestUser
def _verify_index_attempt_pagination(
cc_pair_id: int,
index_attempt_ids: list[int],
user_performing_action: DATestUser,
page_size: int = 5,
user_performing_action: DATestUser | None = None,
) -> None:
retrieved_attempts: list[int] = []
last_time_started = None # Track the last time_started seen

View File

@@ -207,7 +207,9 @@ def test_mcp_search_respects_acl_filters(
cc_pair_ids=[restricted_cc_pair.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync([user_group], user_performing_action=admin_user)
UserGroupManager.wait_for_sync(
user_performing_action=admin_user, user_groups_to_check=[user_group]
)
restricted_doc_content = "MCP restricted knowledge base document"
_seed_document_and_wait_for_indexing(

View File

@@ -14,11 +14,11 @@ from tests.integration.tests.query_history.utils import (
def _verify_query_history_pagination(
chat_sessions: list[DAQueryHistoryEntry],
user_performing_action: DATestUser,
page_size: int = 5,
feedback_type: QAFeedbackType | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
retrieved_sessions: list[str] = []

View File

@@ -5,7 +5,6 @@ import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser
@@ -59,7 +58,7 @@ def test_add_users_to_group_invalid_user(reset: None) -> None: # noqa: ARG001
response = requests.post(
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/add-users",
json={"user_ids": [invalid_user_id]},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
headers=admin_user.headers,
)
assert response.status_code == 404

View File

@@ -9,11 +9,11 @@ from tests.integration.common_utils.test_models import DATestUser
# to verify that the pagination and filtering works as expected.
def _verify_user_pagination(
users: list[DATestUser],
user_performing_action: DATestUser,
page_size: int = 5,
search_query: str | None = None,
role_filter: list[UserRole] | None = None,
is_active_filter: bool | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
retrieved_users: list[FullUserSnapshot] = []

View File

@@ -158,14 +158,14 @@ python ./scripts/dev_run_background_jobs.py
To run the backend API server, navigate back to `onyx/backend` and run:
```bash
AUTH_TYPE=disabled uvicorn onyx.main:app --reload --port 8080
AUTH_TYPE=basic uvicorn onyx.main:app --reload --port 8080
```
_For Windows (for compatibility with both PowerShell and Command Prompt):_
```bash
powershell -Command "
$env:AUTH_TYPE='disabled'
$env:AUTH_TYPE='basic'
uvicorn onyx.main:app --reload --port 8080
"
```

View File

@@ -126,7 +126,9 @@ Resources:
- Effect: Allow
Action:
- secretsmanager:GetSecretValue
Resource: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password-*
Resource:
- !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password-*
- !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret-*
Outputs:
OutputEcsCluster:

View File

@@ -167,10 +167,12 @@ Resources:
- ImportedNamespace: !ImportValue
Fn::Sub: "${Environment}-onyx-cluster-OnyxNamespaceName"
- Name: AUTH_TYPE
Value: disabled
Value: basic
Secrets:
- Name: POSTGRES_PASSWORD
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password
- Name: USER_AUTH_SECRET
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret
VolumesFrom: []
SystemControls: []

View File

@@ -166,9 +166,11 @@ Resources:
- ImportedNamespace: !ImportValue
Fn::Sub: "${Environment}-onyx-cluster-OnyxNamespaceName"
- Name: AUTH_TYPE
Value: disabled
Value: basic
Secrets:
- Name: POSTGRES_PASSWORD
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password
- Name: USER_AUTH_SECRET
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret
VolumesFrom: []
SystemControls: []

View File

@@ -21,7 +21,7 @@ services:
env_file:
- .env_eval
environment:
- AUTH_TYPE=disabled
- AUTH_TYPE=basic
- POSTGRES_HOST=relational_db
- VESPA_HOST=index
- REDIS_HOST=cache
@@ -59,7 +59,7 @@ services:
- .env_eval
environment:
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
- AUTH_TYPE=disabled
- AUTH_TYPE=basic
- POSTGRES_HOST=relational_db
- VESPA_HOST=index
- REDIS_HOST=cache

View File

@@ -20,8 +20,11 @@ IMAGE_TAG=latest
## Auth Settings
### https://docs.onyx.app/deployment/authentication
AUTH_TYPE=disabled
AUTH_TYPE=basic
# SESSION_EXPIRE_TIME_SECONDS=
### Required for basic auth - used for signing password reset and verification tokens
### Generate a secure value with: openssl rand -hex 32
USER_AUTH_SECRET=OnyxDevSecret1!
### Recommend to set this for security
# ENCRYPTION_KEY_SECRET=
### Optional

View File

@@ -654,16 +654,9 @@ else
sed -i.bak "s/^IMAGE_TAG=.*/IMAGE_TAG=$VERSION/" "$ENV_FILE"
print_success "IMAGE_TAG set to $VERSION"
# Configure authentication settings based on selection
if [ "$AUTH_SCHEMA" = "disabled" ]; then
# Disable authentication in .env file
sed -i.bak 's/^AUTH_TYPE=.*/AUTH_TYPE=disabled/' "$ENV_FILE" 2>/dev/null || true
print_success "Authentication disabled in configuration"
else
# Enable basic authentication
sed -i.bak 's/^AUTH_TYPE=.*/AUTH_TYPE=basic/' "$ENV_FILE" 2>/dev/null || true
print_success "Basic authentication enabled in configuration"
fi
# Configure basic authentication (default)
sed -i.bak 's/^AUTH_TYPE=.*/AUTH_TYPE=basic/' "$ENV_FILE" 2>/dev/null || true
print_success "Basic authentication enabled in configuration"
# Configure Craft based on flag or if using a craft-* image tag
# By default, env.template has Craft commented out (disabled)

View File

@@ -1167,10 +1167,25 @@ auth:
values:
opensearch_admin_username: "admin"
opensearch_admin_password: "OnyxDev1!"
userauth:
# -- Required when AUTH_TYPE is "basic". Used for signing password reset
# tokens, email verification tokens, and JWT tokens.
enabled: true
# -- Overwrite the default secret name, ignored if existingSecret is defined
secretName: 'onyx-userauth'
# -- Use a secret specified elsewhere
existingSecret: ""
# -- This defines the env var to secret map
secretKeys:
USER_AUTH_SECRET: user_auth_secret
# -- Secret value. CHANGE THIS FOR PRODUCTION.
# Generate a secure value with: openssl rand -hex 32
values:
user_auth_secret: "OnyxDevSecret1!"
configMap:
# Change this for production uses unless Onyx is only accessible behind VPN
AUTH_TYPE: "disabled"
AUTH_TYPE: "basic"
# 1 Day Default
SESSION_EXPIRE_TIME_SECONDS: "86400"
# Can be something like onyx.app, as an extra double-check