Compare commits

..

7 Commits

32 changed files with 942 additions and 759 deletions

View File

@@ -58,16 +58,27 @@ class OAuthTokenManager:
if not user_token.token_data:
raise ValueError("No token data available for refresh")
if (
self.oauth_config.client_id is None
or self.oauth_config.client_secret is None
):
raise ValueError(
"OAuth client_id and client_secret are required for token refresh"
)
token_data = self._unwrap_token_data(user_token.token_data)
data: dict[str, str] = {
"grant_type": "refresh_token",
"refresh_token": token_data["refresh_token"],
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
"client_secret": self._unwrap_sensitive_str(
self.oauth_config.client_secret
),
}
response = requests.post(
self.oauth_config.token_url,
data={
"grant_type": "refresh_token",
"refresh_token": token_data["refresh_token"],
"client_id": self.oauth_config.client_id,
"client_secret": self.oauth_config.client_secret,
},
data=data,
headers={"Accept": "application/json"},
)
response.raise_for_status()
@@ -115,15 +126,26 @@ class OAuthTokenManager:
def exchange_code_for_token(self, code: str, redirect_uri: str) -> dict[str, Any]:
"""Exchange authorization code for access token"""
if (
self.oauth_config.client_id is None
or self.oauth_config.client_secret is None
):
raise ValueError(
"OAuth client_id and client_secret are required for code exchange"
)
data: dict[str, str] = {
"grant_type": "authorization_code",
"code": code,
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
"client_secret": self._unwrap_sensitive_str(
self.oauth_config.client_secret
),
"redirect_uri": redirect_uri,
}
response = requests.post(
self.oauth_config.token_url,
data={
"grant_type": "authorization_code",
"code": code,
"client_id": self.oauth_config.client_id,
"client_secret": self.oauth_config.client_secret,
"redirect_uri": redirect_uri,
},
data=data,
headers={"Accept": "application/json"},
)
response.raise_for_status()
@@ -141,8 +163,13 @@ class OAuthTokenManager:
oauth_config: OAuthConfig, redirect_uri: str, state: str
) -> str:
"""Build OAuth authorization URL"""
if oauth_config.client_id is None:
raise ValueError("OAuth client_id is required to build authorization URL")
params: dict[str, Any] = {
"client_id": oauth_config.client_id,
"client_id": OAuthTokenManager._unwrap_sensitive_str(
oauth_config.client_id
),
"redirect_uri": redirect_uri,
"response_type": "code",
"state": state,
@@ -161,6 +188,12 @@ class OAuthTokenManager:
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
@staticmethod
def _unwrap_sensitive_str(value: SensitiveValue[str] | str) -> str:
if isinstance(value, SensitiveValue):
return value.get_value(apply_mask=False)
return value
@staticmethod
def _unwrap_token_data(
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],

View File

@@ -48,6 +48,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchDocumentIndex,
)
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import IndexingSetting
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
@@ -149,8 +150,12 @@ def migrate_chunks_from_vespa_to_opensearch_task(
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
search_settings = get_current_search_settings(db_session)
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
indexing_setting = IndexingSetting.from_db_model(search_settings)
opensearch_document_index = OpenSearchDocumentIndex(
index_name=search_settings.index_name, tenant_state=tenant_state
tenant_state=tenant_state,
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
)
vespa_document_index = VespaDocumentIndex(
index_name=search_settings.index_name,

View File

@@ -294,6 +294,12 @@ ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
)
# Whether we should check for and create an index if necessary every time we
# instantiate an OpenSearchDocumentIndex on multitenant cloud. Defaults to True.
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
os.environ.get("VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT", "true").lower()
== "true"
)
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
# NOTE: this is used if and only if the vespa config server is accessible via a

View File

@@ -488,22 +488,6 @@ def fetch_existing_llm_provider(
return provider_model
def fetch_existing_llm_provider_by_id(
id: int, db_session: Session
) -> LLMProviderModel | None:
provider_model = db_session.scalar(
select(LLMProviderModel)
.where(LLMProviderModel.id == id)
.options(
selectinload(LLMProviderModel.model_configurations),
selectinload(LLMProviderModel.groups),
selectinload(LLMProviderModel.personas),
)
)
return provider_model
def fetch_embedding_provider(
db_session: Session, provider_type: EmbeddingProvider
) -> CloudEmbeddingProviderModel | None:

View File

@@ -11,6 +11,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchOldDocumentIndex,
)
from onyx.document_index.vespa.index import VespaIndex
from onyx.indexing.models import IndexingSetting
from shared_configs.configs import MULTI_TENANT
@@ -49,8 +50,11 @@ def get_default_document_index(
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
if opensearch_retrieval_enabled:
indexing_setting = IndexingSetting.from_db_model(search_settings)
return OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
secondary_index_name=secondary_index_name,
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
@@ -118,8 +122,11 @@ def get_all_document_indices(
)
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
indexing_setting = IndexingSetting.from_db_model(search_settings)
opensearch_document_index = OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
secondary_index_name=None,
large_chunks_enabled=False,
secondary_large_chunks_enabled=None,

View File

@@ -1,5 +1,7 @@
import logging
import time
from contextlib import AbstractContextManager
from contextlib import nullcontext
from typing import Any
from typing import Generic
from typing import TypeVar
@@ -83,22 +85,26 @@ def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
return new_body
class OpenSearchClient:
"""Client for interacting with OpenSearch.
class OpenSearchClient(AbstractContextManager):
"""Client for interacting with OpenSearch for cluster-level operations.
OpenSearch's Python module has pretty bad typing support so this client
attempts to protect the rest of the codebase from this. As a consequence,
most methods here return the minimum data needed for the rest of Onyx, and
tend to rely on Exceptions to handle errors.
TODO(andrei): This class currently assumes the structure of the database
schema when it returns a DocumentChunk. Make the class, or at least the
search method, templated on the structure the caller can expect.
Args:
host: The host of the OpenSearch cluster.
port: The port of the OpenSearch cluster.
auth: The authentication credentials for the OpenSearch cluster. A tuple
of (username, password).
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
True.
verify_certs: Whether to verify the SSL certificates for the OpenSearch
cluster. Defaults to False.
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
to False.
timeout: The timeout for the OpenSearch cluster. Defaults to
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
"""
def __init__(
self,
index_name: str,
host: str = OPENSEARCH_HOST,
port: int = OPENSEARCH_REST_API_PORT,
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
@@ -107,9 +113,8 @@ class OpenSearchClient:
ssl_show_warn: bool = False,
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
):
self._index_name = index_name
logger.debug(
f"Creating OpenSearch client for index {index_name} with host {host} and port {port} and timeout {timeout} seconds."
f"Creating OpenSearch client with host {host}, port {port} and timeout {timeout} seconds."
)
self._client = OpenSearch(
hosts=[{"host": host, "port": port}],
@@ -125,6 +130,142 @@ class OpenSearchClient:
# your request body that is less than this value.
timeout=timeout,
)
def __exit__(self, *_: Any) -> None:
self.close()
def __del__(self) -> None:
try:
self.close()
except Exception:
pass
@log_function_time(print_only=True, debug_only=True, include_args=True)
def create_search_pipeline(
self,
pipeline_id: str,
pipeline_body: dict[str, Any],
) -> None:
"""Creates a search pipeline.
See the OpenSearch documentation for more information on the search
pipeline body.
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
Args:
pipeline_id: The ID of the search pipeline to create.
pipeline_body: The body of the search pipeline to create.
Raises:
Exception: There was an error creating the search pipeline.
"""
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
if not result.get("acknowledged", False):
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
@log_function_time(print_only=True, debug_only=True, include_args=True)
def delete_search_pipeline(self, pipeline_id: str) -> None:
"""Deletes a search pipeline.
Args:
pipeline_id: The ID of the search pipeline to delete.
Raises:
Exception: There was an error deleting the search pipeline.
"""
result = self._client.search_pipeline.delete(id=pipeline_id)
if not result.get("acknowledged", False):
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
@log_function_time(print_only=True, debug_only=True, include_args=True)
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
"""Puts cluster settings.
Args:
settings: The settings to put.
Raises:
Exception: There was an error putting the cluster settings.
Returns:
True if the settings were put successfully, False otherwise.
"""
response = self._client.cluster.put_settings(body=settings)
if response.get("acknowledged", False):
logger.info("Successfully put cluster settings.")
return True
else:
logger.error(f"Failed to put cluster settings: {response}.")
return False
@log_function_time(print_only=True, debug_only=True)
def ping(self) -> bool:
"""Pings the OpenSearch cluster.
Returns:
True if OpenSearch could be reached, False if it could not.
"""
return self._client.ping()
@log_function_time(print_only=True, debug_only=True)
def close(self) -> None:
"""Closes the client.
Raises:
Exception: There was an error closing the client.
"""
self._client.close()
class OpenSearchIndexClient(OpenSearchClient):
"""Client for interacting with OpenSearch for index-level operations.
OpenSearch's Python module has pretty bad typing support so this client
attempts to protect the rest of the codebase from this. As a consequence,
most methods here return the minimum data needed for the rest of Onyx, and
tend to rely on Exceptions to handle errors.
TODO(andrei): This class currently assumes the structure of the database
schema when it returns a DocumentChunk. Make the class, or at least the
search method, templated on the structure the caller can expect.
Args:
index_name: The name of the index to interact with.
host: The host of the OpenSearch cluster.
port: The port of the OpenSearch cluster.
auth: The authentication credentials for the OpenSearch cluster. A tuple
of (username, password).
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
True.
verify_certs: Whether to verify the SSL certificates for the OpenSearch
cluster. Defaults to False.
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
to False.
timeout: The timeout for the OpenSearch cluster. Defaults to
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
"""
def __init__(
self,
index_name: str,
host: str = OPENSEARCH_HOST,
port: int = OPENSEARCH_REST_API_PORT,
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
use_ssl: bool = True,
verify_certs: bool = False,
ssl_show_warn: bool = False,
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
):
super().__init__(
host=host,
port=port,
auth=auth,
use_ssl=use_ssl,
verify_certs=verify_certs,
ssl_show_warn=ssl_show_warn,
timeout=timeout,
)
self._index_name = index_name
logger.debug(
f"OpenSearch client created successfully for index {self._index_name}."
)
@@ -192,6 +333,38 @@ class OpenSearchClient:
"""
return self._client.indices.exists(index=self._index_name)
@log_function_time(print_only=True, debug_only=True, include_args=True)
def put_mapping(self, mappings: dict[str, Any]) -> None:
"""Updates the index mapping in an idempotent manner.
- Existing fields with the same definition: No-op (succeeds silently).
- New fields: Added to the index.
- Existing fields with different types: Raises exception (requires
reindex).
See the OpenSearch documentation for more information:
https://docs.opensearch.org/latest/api-reference/index-apis/put-mapping/
Args:
mappings: The complete mapping definition to apply. This will be
merged with existing mappings in the index.
Raises:
Exception: There was an error updating the mappings, such as
attempting to change the type of an existing field.
"""
logger.debug(
f"Putting mappings for index {self._index_name} with mappings {mappings}."
)
response = self._client.indices.put_mapping(
index=self._index_name, body=mappings
)
if not response.get("acknowledged", False):
raise RuntimeError(
f"Failed to put the mapping update for index {self._index_name}."
)
logger.debug(f"Successfully put mappings for index {self._index_name}.")
@log_function_time(print_only=True, debug_only=True, include_args=True)
def validate_index(self, expected_mappings: dict[str, Any]) -> bool:
"""Validates the index.
@@ -610,43 +783,6 @@ class OpenSearchClient:
)
return DocumentChunk.model_validate(document_chunk_source)
@log_function_time(print_only=True, debug_only=True, include_args=True)
def create_search_pipeline(
self,
pipeline_id: str,
pipeline_body: dict[str, Any],
) -> None:
"""Creates a search pipeline.
See the OpenSearch documentation for more information on the search
pipeline body.
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
Args:
pipeline_id: The ID of the search pipeline to create.
pipeline_body: The body of the search pipeline to create.
Raises:
Exception: There was an error creating the search pipeline.
"""
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
if not result.get("acknowledged", False):
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
@log_function_time(print_only=True, debug_only=True, include_args=True)
def delete_search_pipeline(self, pipeline_id: str) -> None:
"""Deletes a search pipeline.
Args:
pipeline_id: The ID of the search pipeline to delete.
Raises:
Exception: There was an error deleting the search pipeline.
"""
result = self._client.search_pipeline.delete(id=pipeline_id)
if not result.get("acknowledged", False):
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
@log_function_time(print_only=True, debug_only=True)
def search(
self, body: dict[str, Any], search_pipeline_id: str | None
@@ -807,48 +943,6 @@ class OpenSearchClient:
"""
self._client.indices.refresh(index=self._index_name)
@log_function_time(print_only=True, debug_only=True, include_args=True)
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
"""Puts cluster settings.
Args:
settings: The settings to put.
Raises:
Exception: There was an error putting the cluster settings.
Returns:
True if the settings were put successfully, False otherwise.
"""
response = self._client.cluster.put_settings(body=settings)
if response.get("acknowledged", False):
logger.info("Successfully put cluster settings.")
return True
else:
logger.error(f"Failed to put cluster settings: {response}.")
return False
@log_function_time(print_only=True, debug_only=True)
def ping(self) -> bool:
"""Pings the OpenSearch cluster.
Returns:
True if OpenSearch could be reached, False if it could not.
"""
return self._client.ping()
@log_function_time(print_only=True, debug_only=True)
def close(self) -> None:
"""Closes the client.
TODO(andrei): Can we have some way to auto close when the client no
longer has any references?
Raises:
Exception: There was an error closing the client.
"""
self._client.close()
def _get_hits_and_profile_from_search_result(
self, result: dict[str, Any]
) -> tuple[list[Any], int | None, bool | None, dict[str, Any], dict[str, Any]]:
@@ -945,14 +1039,7 @@ def wait_for_opensearch_with_timeout(
Returns:
True if OpenSearch is ready, False otherwise.
"""
made_client = False
try:
if client is None:
# NOTE: index_name does not matter because we are only using this object
# to ping.
# TODO(andrei): Make this better.
client = OpenSearchClient(index_name="")
made_client = True
with nullcontext(client) if client else OpenSearchClient() as client:
time_start = time.monotonic()
while True:
if client.ping():
@@ -969,7 +1056,3 @@ def wait_for_opensearch_with_timeout(
f"[OpenSearch] Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={wait_limit_s:.1f}"
)
time.sleep(wait_interval_s)
finally:
if made_client:
assert client is not None
client.close()

View File

@@ -7,6 +7,7 @@ from opensearchpy import NotFoundError
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
from onyx.configs.constants import PUBLIC_DOC_PAT
@@ -40,6 +41,7 @@ from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import MetadataUpdateRequest
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.client import OpenSearchClient
from onyx.document_index.opensearch.client import OpenSearchIndexClient
from onyx.document_index.opensearch.client import SearchHit
from onyx.document_index.opensearch.cluster_settings import OPENSEARCH_CLUSTER_SETTINGS
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
@@ -93,6 +95,25 @@ def generate_opensearch_filtered_access_control_list(
return list(access_control_list)
def set_cluster_state(client: OpenSearchClient) -> None:
if not client.put_cluster_settings(settings=OPENSEARCH_CLUSTER_SETTINGS):
logger.error(
"Failed to put cluster settings. If the settings have never been set before, "
"this may cause unexpected index creation when indexing documents into an "
"index that does not exist, or may cause expected logs to not appear. If this "
"is not the first time running Onyx against this instance of OpenSearch, these "
"settings have likely already been set. Not taking any further action..."
)
client.create_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
)
client.create_search_pipeline(
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
)
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
chunk: DocumentChunk,
score: float | None,
@@ -248,6 +269,8 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
def __init__(
self,
index_name: str,
embedding_dim: int,
embedding_precision: EmbeddingPrecision,
secondary_index_name: str | None,
large_chunks_enabled: bool, # noqa: ARG002
secondary_large_chunks_enabled: bool | None, # noqa: ARG002
@@ -258,10 +281,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
index_name=index_name,
secondary_index_name=secondary_index_name,
)
if multitenant:
raise ValueError(
"Bug: OpenSearch is not yet ready for multitenant environments but something tried to use it."
)
if multitenant != MULTI_TENANT:
raise ValueError(
"Bug: Multitenant mismatch when initializing an OpenSearchDocumentIndex. "
@@ -269,8 +288,10 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
)
tenant_id = get_current_tenant_id()
self._real_index = OpenSearchDocumentIndex(
index_name=index_name,
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
index_name=index_name,
embedding_dim=embedding_dim,
embedding_precision=embedding_precision,
)
@staticmethod
@@ -279,9 +300,8 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
embedding_dims: list[int],
embedding_precisions: list[EmbeddingPrecision],
) -> None:
# TODO(andrei): Implement.
raise NotImplementedError(
"Multitenant index registration is not yet implemented for OpenSearch."
"Bug: Multitenant index registration is not supported for OpenSearch."
)
def ensure_indices_exist(
@@ -471,19 +491,37 @@ class OpenSearchDocumentIndex(DocumentIndex):
for an OpenSearch search engine instance. It handles the complete lifecycle
of document chunks within a specific OpenSearch index/schema.
Although not yet used in this way in the codebase, each kind of embedding
used should correspond to a different instance of this class, and therefore
a different index in OpenSearch.
Each kind of embedding used should correspond to a different instance of
this class, and therefore a different index in OpenSearch.
If in a multitenant environment and
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT, will verify and create the index
if necessary on initialization. This is because there is no logic which runs
on cluster restart which scans through all search settings over all tenants
and creates the relevant indices.
Args:
tenant_state: The tenant state of the caller.
index_name: The name of the index to interact with.
embedding_dim: The dimensionality of the embeddings used for the index.
embedding_precision: The precision of the embeddings used for the index.
"""
def __init__(
self,
index_name: str,
tenant_state: TenantState,
index_name: str,
embedding_dim: int,
embedding_precision: EmbeddingPrecision,
) -> None:
self._index_name: str = index_name
self._tenant_state: TenantState = tenant_state
self._os_client = OpenSearchClient(index_name=self._index_name)
self._client = OpenSearchIndexClient(index_name=self._index_name)
if self._tenant_state.multitenant and VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT:
self.verify_and_create_index_if_necessary(
embedding_dim=embedding_dim, embedding_precision=embedding_precision
)
def verify_and_create_index_if_necessary(
self,
@@ -492,10 +530,15 @@ class OpenSearchDocumentIndex(DocumentIndex):
) -> None:
"""Verifies and creates the index if necessary.
Also puts the desired cluster settings.
Also puts the desired cluster settings if not in a multitenant
environment.
Also puts the desired search pipeline state, creating the pipelines if
they do not exist and updating them otherwise.
Also puts the desired search pipeline state if not in a multitenant
environment, creating the pipelines if they do not exist and updating
them otherwise.
In a multitenant environment, the above steps happen explicitly on
setup.
Args:
embedding_dim: Vector dimensionality for the vector similarity part
@@ -508,47 +551,38 @@ class OpenSearchDocumentIndex(DocumentIndex):
search pipelines.
"""
logger.debug(
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if necessary, "
f"with embedding dimension {embedding_dim}."
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if "
f"necessary, with embedding dimension {embedding_dim}."
)
if not self._tenant_state.multitenant:
set_cluster_state(self._client)
expected_mappings = DocumentSchema.get_document_schema(
embedding_dim, self._tenant_state.multitenant
)
if not self._os_client.put_cluster_settings(
settings=OPENSEARCH_CLUSTER_SETTINGS
):
logger.error(
f"Failed to put cluster settings for index {self._index_name}. If the settings have never been set before this "
"may cause unexpected index creation when indexing documents into an index that does not exist, or may cause "
"expected logs to not appear. If this is not the first time running Onyx against this instance of OpenSearch, "
"these settings have likely already been set. Not taking any further action..."
)
if not self._os_client.index_exists():
if not self._client.index_exists():
if USING_AWS_MANAGED_OPENSEARCH:
index_settings = (
DocumentSchema.get_index_settings_for_aws_managed_opensearch()
)
else:
index_settings = DocumentSchema.get_index_settings()
self._os_client.create_index(
self._client.create_index(
mappings=expected_mappings,
settings=index_settings,
)
if not self._os_client.validate_index(
expected_mappings=expected_mappings,
):
raise RuntimeError(
f"The index {self._index_name} is not valid. The expected mappings do not match the actual mappings."
)
self._os_client.create_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
)
self._os_client.create_search_pipeline(
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
)
else:
# Ensure schema is up to date by applying the current mappings.
try:
self._client.put_mapping(expected_mappings)
except Exception as e:
logger.error(
f"Failed to update mappings for index {self._index_name}. This likely means a "
f"field type was changed which requires reindexing. Error: {e}"
)
raise
def index(
self,
@@ -620,7 +654,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
)
# Now index. This will raise if a chunk of the same ID exists, which
# we do not expect because we should have deleted all chunks.
self._os_client.bulk_index_documents(
self._client.bulk_index_documents(
documents=chunk_batch,
tenant_state=self._tenant_state,
)
@@ -660,7 +694,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
tenant_state=self._tenant_state,
)
return self._os_client.delete_by_query(query_body)
return self._client.delete_by_query(query_body)
def update(
self,
@@ -760,7 +794,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
document_id=doc_id,
chunk_index=chunk_index,
)
self._os_client.update_document(
self._client.update_document(
document_chunk_id=document_chunk_id,
properties_to_update=properties_to_update,
)
@@ -799,7 +833,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
min_chunk_index=chunk_request.min_chunk_ind,
max_chunk_index=chunk_request.max_chunk_ind,
)
search_hits = self._os_client.search(
search_hits = self._client.search(
body=query_body,
search_pipeline_id=None,
)
@@ -849,7 +883,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
# NOTE: Using z-score normalization here because it's better for hybrid search from a theoretical standpoint.
# Empirically on a small dataset of up to 10K docs, it's not very different. Likely more impactful at scale.
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
body=query_body,
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
)
@@ -881,7 +915,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
index_filters=filters,
num_to_retrieve=num_to_retrieve,
)
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
body=query_body,
search_pipeline_id=None,
)
@@ -909,6 +943,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
# Do not raise if the document already exists, just update. This is
# because the document may already have been indexed during the
# OpenSearch transition period.
self._os_client.bulk_index_documents(
self._client.bulk_index_documents(
documents=chunks, tenant_state=self._tenant_state, update_if_exists=True
)

View File

@@ -22,10 +22,7 @@ from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_default_vision_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_provider_by_id
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_models
from onyx.db.llm import fetch_persona_with_groups
@@ -55,10 +52,8 @@ from onyx.llm.well_known_providers.llm_provider_options import (
)
from onyx.server.manage.llm.models import BedrockFinalModelResponse
from onyx.server.manage.llm.models import BedrockModelsRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMCost
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderResponse
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
@@ -238,9 +233,12 @@ def test_llm_configuration(
test_api_key = test_llm_request.api_key
test_custom_config = test_llm_request.custom_config
if test_llm_request.id:
existing_provider = fetch_existing_llm_provider_by_id(
id=test_llm_request.id, db_session=db_session
if test_llm_request.name:
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
# as it turns out the name is not editable in the UI and other code also keys off name,
# so we won't rock the boat just yet.
existing_provider = fetch_existing_llm_provider(
name=test_llm_request.name, db_session=db_session
)
if existing_provider:
test_custom_config = _restore_masked_custom_config_values(
@@ -270,7 +268,7 @@ def test_llm_configuration(
llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.model,
model=test_llm_request.default_model_name,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
@@ -305,7 +303,7 @@ def list_llm_providers(
include_image_gen: bool = Query(False),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LLMProviderResponse[LLMProviderView]:
) -> list[LLMProviderView]:
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch LLM providers")
@@ -330,15 +328,7 @@ def list_llm_providers(
duration = (end_time - start_time).total_seconds()
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
return LLMProviderResponse[LLMProviderView].from_models(
providers=llm_provider_list,
default_text=DefaultModel.from_model_config(
fetch_default_llm_model(db_session)
),
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
return llm_provider_list
@admin_router.put("/provider")
@@ -526,7 +516,7 @@ def get_auto_config(
def get_vision_capable_providers(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LLMProviderResponse[VisionProviderResponse]:
) -> list[VisionProviderResponse]:
"""Return a list of LLM providers and their models that support image input"""
vision_models = fetch_existing_models(
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
@@ -555,13 +545,7 @@ def get_vision_capable_providers(
]
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
return LLMProviderResponse[VisionProviderResponse].from_models(
providers=vision_provider_response,
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
return vision_provider_response
"""Endpoints for all"""
@@ -571,7 +555,7 @@ def get_vision_capable_providers(
def list_llm_provider_basics(
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> LLMProviderResponse[LLMProviderDescriptor]:
) -> list[LLMProviderDescriptor]:
"""Get LLM providers accessible to the current user.
Returns:
@@ -608,15 +592,7 @@ def list_llm_provider_basics(
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
)
return LLMProviderResponse[LLMProviderDescriptor].from_models(
providers=accessible_providers,
default_text=DefaultModel.from_model_config(
fetch_default_llm_model(db_session)
),
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
return accessible_providers
def get_valid_model_names_for_persona(

View File

@@ -1,9 +1,5 @@
from __future__ import annotations
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar
from pydantic import BaseModel
from pydantic import Field
@@ -25,8 +21,6 @@ if TYPE_CHECKING:
ModelConfiguration as ModelConfigurationModel,
)
T = TypeVar("T", "LLMProviderDescriptor", "LLMProviderView")
# TODO: Clear this up on api refactor
# There is still logic that requires sending each providers default model name
@@ -58,17 +52,19 @@ def get_default_vision_model_name(llm_provider_model: "LLMProviderModel") -> str
class TestLLMRequest(BaseModel):
# provider level
id: int | None = None
name: str | None = None
provider: str
model: str
api_key: str | None = None
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
# model level
default_model_name: str
deployment_name: str | None = None
model_configurations: list["ModelConfigurationUpsertRequest"]
# if try and use the existing API/custom config key
api_key_changed: bool
custom_config_changed: bool
@@ -429,38 +425,3 @@ class OpenRouterFinalModelResponse(BaseModel):
int | None
) # From OpenRouter API context_length (may be missing for some models)
supports_image_input: bool
class DefaultModel(BaseModel):
provider_id: int
model_name: str
@classmethod
def from_model_config(
cls, model_config: ModelConfigurationModel | None
) -> DefaultModel | None:
if not model_config:
return None
return cls(
provider_id=model_config.llm_provider_id,
model_name=model_config.name,
)
class LLMProviderResponse(BaseModel, Generic[T]):
providers: list[T]
default_text: DefaultModel | None = None
default_vision: DefaultModel | None = None
@classmethod
def from_models(
cls,
providers: list[T],
default_text: DefaultModel | None = None,
default_vision: DefaultModel | None = None,
) -> LLMProviderResponse[T]:
return cls(
providers=providers,
default_text=default_text,
default_vision=default_vision,
)

View File

@@ -4,6 +4,7 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
@@ -32,6 +33,9 @@ from onyx.db.search_settings import update_current_search_settings
from onyx.db.swap_index import check_and_perform_index_swap
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.opensearch.client import OpenSearchClient
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
from onyx.document_index.opensearch.opensearch_document_index import set_cluster_state
from onyx.document_index.vespa.index import VespaIndex
from onyx.indexing.models import IndexingSetting
from onyx.key_value_store.factory import get_kv_store
@@ -311,7 +315,14 @@ def setup_multitenant_onyx() -> None:
logger.notice("DISABLE_VECTOR_DB is set — skipping multitenant Vespa setup.")
return
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
opensearch_client = OpenSearchClient()
if not wait_for_opensearch_with_timeout(client=opensearch_client):
raise RuntimeError("Failed to connect to OpenSearch.")
set_cluster_state(opensearch_client)
# For Managed Vespa, the schema is sent over via the Vespa Console manually.
# NOTE: Pretty sure this code is never hit in any production environment.
if not MANAGED_VESPA:
setup_vespa_multitenant(SUPPORTED_EMBEDDING_MODELS)

View File

@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
# Prepare the test request payload
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"model": _DEFAULT_BEDROCK_MODEL,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,
@@ -44,7 +44,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
# Prepare the test request payload with invalid credentials
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"model": _DEFAULT_BEDROCK_MODEL,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,

View File

@@ -29,7 +29,6 @@ from onyx.server.manage.llm.api import (
test_llm_configuration as run_test_llm_configuration,
)
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import TestLLMRequest as LLMTestRequest
@@ -45,9 +44,9 @@ def _create_test_provider(
db_session: Session,
name: str,
api_key: str = "sk-test-key-00000000000000000000000000000000000",
) -> LLMProviderView:
) -> None:
"""Helper to create a test LLM provider in the database."""
return upsert_llm_provider(
upsert_llm_provider(
LLMProviderUpsertRequest(
name=name,
provider=LlmProviderNames.OPENAI,
@@ -103,11 +102,17 @@ class TestLLMConfigurationEndpoint:
# This should complete without exception
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None, # New provider (not in DB)
provider=LlmProviderNames.OPENAI,
api_key="sk-new-test-key-0000000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
model="gpt-4o-mini",
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
),
_=_create_mock_admin(),
db_session=db_session,
@@ -147,11 +152,17 @@ class TestLLMConfigurationEndpoint:
with pytest.raises(HTTPException) as exc_info:
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None,
provider=LlmProviderNames.OPENAI,
api_key="sk-invalid-key-00000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
model="gpt-4o-mini",
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
),
_=_create_mock_admin(),
db_session=db_session,
@@ -183,9 +194,7 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database
provider = _create_test_provider(
db_session, provider_name, api_key=original_api_key
)
_create_test_provider(db_session, provider_name, api_key=original_api_key)
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -193,12 +202,17 @@ class TestLLMConfigurationEndpoint:
# Test with api_key_changed=False - should use stored key
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
id=provider.id,
name=provider_name, # Existing provider
provider=LlmProviderNames.OPENAI,
api_key=None, # Not providing a new key
api_key_changed=False, # Using existing key
custom_config_changed=False,
model="gpt-4o-mini",
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
),
_=_create_mock_admin(),
db_session=db_session,
@@ -232,9 +246,7 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database
provider = _create_test_provider(
db_session, provider_name, api_key=original_api_key
)
_create_test_provider(db_session, provider_name, api_key=original_api_key)
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -242,12 +254,17 @@ class TestLLMConfigurationEndpoint:
# Test with api_key_changed=True - should use new key
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
id=provider.id,
name=provider_name, # Existing provider
provider=LlmProviderNames.OPENAI,
api_key=new_api_key, # Providing a new key
api_key_changed=True, # Key is being changed
custom_config_changed=False,
model="gpt-4o-mini",
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
),
_=_create_mock_admin(),
db_session=db_session,
@@ -280,7 +297,7 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database with custom_config
provider = upsert_llm_provider(
upsert_llm_provider(
LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
@@ -304,13 +321,18 @@ class TestLLMConfigurationEndpoint:
# Test with custom_config_changed=False - should use stored config
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_key_changed=False,
custom_config=None, # Not providing new config
custom_config_changed=False, # Using existing config
model="gpt-4o-mini",
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
),
_=_create_mock_admin(),
db_session=db_session,
@@ -346,11 +368,17 @@ class TestLLMConfigurationEndpoint:
for model_name in test_models:
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
model=model_name,
default_model_name=model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=model_name, is_visible=True
)
],
),
_=_create_mock_admin(),
db_session=db_session,

View File

@@ -530,8 +530,14 @@ def test_upload_with_custom_config_then_change(
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
run_llm_config_test(
LLMTestRequest(
name=name,
provider=provider_name,
model=default_model_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
api_key_changed=False,
custom_config_changed=True,
custom_config=custom_config,
@@ -540,7 +546,7 @@ def test_upload_with_custom_config_then_change(
db_session=db_session,
)
provider = put_llm_provider(
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider_name,
@@ -563,9 +569,14 @@ def test_upload_with_custom_config_then_change(
# Turn auto mode off
run_llm_config_test(
LLMTestRequest(
id=provider.id,
name=name,
provider=provider_name,
model=default_model_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
api_key_changed=False,
custom_config_changed=False,
),
@@ -605,13 +616,13 @@ def test_upload_with_custom_config_then_change(
)
# Check inside the database and check that custom_config is the same as the original
db_provider = fetch_existing_llm_provider(name=name, db_session=db_session)
if not db_provider:
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
if not provider:
assert False, "Provider not found in the database"
assert db_provider.custom_config == custom_config, (
assert provider.custom_config == custom_config, (
f"Expected custom_config {custom_config}, "
f"but got {db_provider.custom_config}"
f"but got {provider.custom_config}"
)
finally:
db_session.rollback()
@@ -695,7 +706,7 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
) -> None:
"""LLM test should restore masked sensitive custom config values before invocation."""
name = f"test-provider-vertex-test-{uuid4().hex[:8]}"
provider_name = LlmProviderNames.VERTEX_AI.value
provider = LlmProviderNames.VERTEX_AI.value
default_model_name = "gemini-2.5-pro"
original_custom_config = {
"vertex_credentials": '{"type":"service_account","private_key":"REAL_PRIVATE_KEY"}',
@@ -708,10 +719,10 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
return ""
try:
provider = put_llm_provider(
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider_name,
provider=provider,
default_model_name=default_model_name,
custom_config=original_custom_config,
model_configurations=[
@@ -731,9 +742,14 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
run_llm_config_test(
LLMTestRequest(
id=provider.id,
provider=provider_name,
model=default_model_name,
name=name,
provider=provider,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
api_key_changed=False,
custom_config_changed=True,
custom_config={

View File

@@ -1,4 +1,4 @@
"""External dependency unit tests for OpenSearchClient.
"""External dependency unit tests for OpenSearchIndexClient.
These tests assume OpenSearch is running and test all implemented methods
using real schemas, pipelines, and search queries from the codebase.
@@ -19,7 +19,7 @@ from onyx.access.utils import prefix_user_email
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import IndexFilters
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.client import OpenSearchClient
from onyx.document_index.opensearch.client import OpenSearchIndexClient
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.opensearch_document_index import (
@@ -125,10 +125,10 @@ def opensearch_available() -> None:
@pytest.fixture(scope="function")
def test_client(
opensearch_available: None, # noqa: ARG001
) -> Generator[OpenSearchClient, None, None]:
) -> Generator[OpenSearchIndexClient, None, None]:
"""Creates an OpenSearch client for testing with automatic cleanup."""
test_index_name = f"test_index_{uuid.uuid4().hex[:8]}"
client = OpenSearchClient(index_name=test_index_name)
client = OpenSearchIndexClient(index_name=test_index_name)
yield client # Test runs here.
@@ -142,7 +142,7 @@ def test_client(
@pytest.fixture(scope="function")
def search_pipeline(test_client: OpenSearchClient) -> Generator[None, None, None]:
def search_pipeline(test_client: OpenSearchIndexClient) -> Generator[None, None, None]:
"""Creates a search pipeline for testing with automatic cleanup."""
test_client.create_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
@@ -158,9 +158,9 @@ def search_pipeline(test_client: OpenSearchClient) -> Generator[None, None, None
class TestOpenSearchClient:
"""Tests for OpenSearchClient."""
"""Tests for OpenSearchIndexClient."""
def test_create_index(self, test_client: OpenSearchClient) -> None:
def test_create_index(self, test_client: OpenSearchIndexClient) -> None:
"""Tests creating an index with a real schema."""
# Precondition.
mappings = DocumentSchema.get_document_schema(
@@ -176,7 +176,7 @@ class TestOpenSearchClient:
# Verify index exists.
assert test_client.validate_index(expected_mappings=mappings) is True
def test_delete_existing_index(self, test_client: OpenSearchClient) -> None:
def test_delete_existing_index(self, test_client: OpenSearchIndexClient) -> None:
"""Tests deleting an existing index returns True."""
# Precondition.
mappings = DocumentSchema.get_document_schema(
@@ -193,7 +193,7 @@ class TestOpenSearchClient:
assert result is True
assert test_client.validate_index(expected_mappings=mappings) is False
def test_delete_nonexistent_index(self, test_client: OpenSearchClient) -> None:
def test_delete_nonexistent_index(self, test_client: OpenSearchIndexClient) -> None:
"""Tests deleting a nonexistent index returns False."""
# Under test.
# Don't create index, just try to delete.
@@ -202,7 +202,7 @@ class TestOpenSearchClient:
# Postcondition.
assert result is False
def test_index_exists(self, test_client: OpenSearchClient) -> None:
def test_index_exists(self, test_client: OpenSearchIndexClient) -> None:
"""Tests checking if an index exists."""
# Precondition.
# Index should not exist before creation.
@@ -219,7 +219,7 @@ class TestOpenSearchClient:
# Index should exist after creation.
assert test_client.index_exists() is True
def test_validate_index(self, test_client: OpenSearchClient) -> None:
def test_validate_index(self, test_client: OpenSearchIndexClient) -> None:
"""Tests validating an index."""
# Precondition.
mappings = DocumentSchema.get_document_schema(
@@ -239,7 +239,120 @@ class TestOpenSearchClient:
# Should return True after creation.
assert test_client.validate_index(expected_mappings=mappings) is True
def test_create_duplicate_index(self, test_client: OpenSearchClient) -> None:
def test_put_mapping_idempotent(self, test_client: OpenSearchIndexClient) -> None:
"""Tests put_mapping with same schema is idempotent."""
# Precondition.
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
# Under test.
# Applying the same mappings again should succeed.
test_client.put_mapping(mappings)
# Postcondition.
# Index should still be valid.
assert test_client.validate_index(expected_mappings=mappings)
def test_put_mapping_adds_new_field(
self, test_client: OpenSearchIndexClient
) -> None:
"""Tests put_mapping successfully adds new fields to existing index."""
# Precondition.
# Create index with minimal schema (just required fields).
initial_mappings = {
"dynamic": "strict",
"properties": {
"document_id": {"type": "keyword"},
"chunk_index": {"type": "integer"},
"content": {"type": "text"},
"content_vector": {
"type": "knn_vector",
"dimension": 128,
"method": {
"name": "hnsw",
"space_type": "cosinesimil",
"engine": "lucene",
"parameters": {"ef_construction": 512, "m": 16},
},
},
},
}
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=initial_mappings, settings=settings)
# Under test.
# Add a new field using put_mapping.
updated_mappings = {
"properties": {
"document_id": {"type": "keyword"},
"chunk_index": {"type": "integer"},
"content": {"type": "text"},
"content_vector": {
"type": "knn_vector",
"dimension": 128,
"method": {
"name": "hnsw",
"space_type": "cosinesimil",
"engine": "lucene",
"parameters": {"ef_construction": 512, "m": 16},
},
},
# New field
"new_test_field": {"type": "keyword"},
},
}
# Should not raise.
test_client.put_mapping(updated_mappings)
# Postcondition.
# Validate the new schema includes the new field.
assert test_client.validate_index(expected_mappings=updated_mappings)
def test_put_mapping_fails_on_type_change(
self, test_client: OpenSearchIndexClient
) -> None:
"""Tests put_mapping fails when trying to change existing field type."""
# Precondition.
initial_mappings = {
"dynamic": "strict",
"properties": {
"document_id": {"type": "keyword"},
"test_field": {"type": "keyword"},
},
}
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=initial_mappings, settings=settings)
# Under test and postcondition.
# Try to change test_field type from keyword to text.
conflicting_mappings = {
"properties": {
"document_id": {"type": "keyword"},
"test_field": {"type": "text"}, # Changed from keyword to text
},
}
# Should raise because field type cannot be changed.
with pytest.raises(Exception, match="mapper|illegal_argument_exception"):
test_client.put_mapping(conflicting_mappings)
def test_put_mapping_on_nonexistent_index(
self, test_client: OpenSearchIndexClient
) -> None:
"""Tests put_mapping on non-existent index raises an error."""
# Precondition.
# Index does not exist yet.
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
# Under test and postcondition.
with pytest.raises(Exception, match="index_not_found_exception|404"):
test_client.put_mapping(mappings)
def test_create_duplicate_index(self, test_client: OpenSearchIndexClient) -> None:
"""Tests creating an index twice raises an error."""
# Precondition.
mappings = DocumentSchema.get_document_schema(
@@ -254,14 +367,14 @@ class TestOpenSearchClient:
with pytest.raises(Exception, match="already exists"):
test_client.create_index(mappings=mappings, settings=settings)
def test_update_settings(self, test_client: OpenSearchClient) -> None:
def test_update_settings(self, test_client: OpenSearchIndexClient) -> None:
"""Tests that update_settings raises NotImplementedError."""
# Under test and postcondition.
with pytest.raises(NotImplementedError):
test_client.update_settings(settings={})
def test_create_and_delete_search_pipeline(
self, test_client: OpenSearchClient
self, test_client: OpenSearchIndexClient
) -> None:
"""Tests creating and deleting a search pipeline."""
# Under test and postcondition.
@@ -278,7 +391,7 @@ class TestOpenSearchClient:
)
def test_index_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests indexing a document."""
# Precondition.
@@ -306,7 +419,7 @@ class TestOpenSearchClient:
)
def test_bulk_index_documents(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests bulk indexing documents."""
# Precondition.
@@ -337,7 +450,7 @@ class TestOpenSearchClient:
)
def test_index_duplicate_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests indexing a duplicate document raises an error."""
# Precondition.
@@ -365,7 +478,7 @@ class TestOpenSearchClient:
test_client.index_document(document=doc, tenant_state=tenant_state)
def test_get_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests getting a document."""
# Precondition.
@@ -401,7 +514,7 @@ class TestOpenSearchClient:
assert retrieved_doc == original_doc
def test_get_nonexistent_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests getting a nonexistent document raises an error."""
# Precondition.
@@ -419,7 +532,7 @@ class TestOpenSearchClient:
)
def test_delete_existing_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests deleting an existing document returns True."""
# Precondition.
@@ -455,7 +568,7 @@ class TestOpenSearchClient:
test_client.get_document(document_chunk_id=doc_chunk_id)
def test_delete_nonexistent_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests deleting a nonexistent document returns False."""
# Precondition.
@@ -476,7 +589,7 @@ class TestOpenSearchClient:
assert result is False
def test_delete_by_query(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests deleting documents by query."""
# Precondition.
@@ -552,7 +665,7 @@ class TestOpenSearchClient:
assert len(keep_ids) == 1
def test_update_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests updating a document's properties."""
# Precondition.
@@ -601,7 +714,7 @@ class TestOpenSearchClient:
assert updated_doc.public == doc.public
def test_update_nonexistent_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests updating a nonexistent document raises an error."""
# Precondition.
@@ -623,7 +736,7 @@ class TestOpenSearchClient:
def test_hybrid_search_with_pipeline(
self,
test_client: OpenSearchClient,
test_client: OpenSearchIndexClient,
search_pipeline: None, # noqa: ARG002
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -704,7 +817,7 @@ class TestOpenSearchClient:
def test_search_empty_index(
self,
test_client: OpenSearchClient,
test_client: OpenSearchIndexClient,
search_pipeline: None, # noqa: ARG002
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -743,7 +856,7 @@ class TestOpenSearchClient:
def test_hybrid_search_with_pipeline_and_filters(
self,
test_client: OpenSearchClient,
test_client: OpenSearchIndexClient,
search_pipeline: None, # noqa: ARG002
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -863,7 +976,7 @@ class TestOpenSearchClient:
def test_hybrid_search_with_pipeline_and_filters_returns_chunks_with_related_content_first(
self,
test_client: OpenSearchClient,
test_client: OpenSearchIndexClient,
search_pipeline: None, # noqa: ARG002
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -993,7 +1106,7 @@ class TestOpenSearchClient:
previous_score = current_score
def test_delete_by_query_multitenant_isolation(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""
Tests delete_by_query respects tenant boundaries in multi-tenant mode.
@@ -1087,7 +1200,7 @@ class TestOpenSearchClient:
assert set(remaining_y_ids) == expected_y_ids
def test_delete_by_query_nonexistent_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""
Tests delete_by_query for non-existent document returns 0 deleted.
@@ -1116,7 +1229,7 @@ class TestOpenSearchClient:
assert num_deleted == 0
def test_search_for_document_ids(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests search_for_document_ids method returns correct chunk IDs."""
# Precondition.
@@ -1181,7 +1294,7 @@ class TestOpenSearchClient:
assert set(chunk_ids) == expected_ids
def test_search_with_no_document_access_can_retrieve_all_documents(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""
Tests search with no document access can retrieve all documents, even
@@ -1259,7 +1372,7 @@ class TestOpenSearchClient:
def test_time_cutoff_filter(
self,
test_client: OpenSearchClient,
test_client: OpenSearchIndexClient,
search_pipeline: None, # noqa: ARG002
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -1352,7 +1465,7 @@ class TestOpenSearchClient:
)
def test_random_search(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests the random search query works."""
# Precondition.

View File

@@ -37,6 +37,7 @@ from onyx.db.opensearch_migration import build_sanitized_to_original_doc_id_mapp
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.client import OpenSearchClient
from onyx.document_index.opensearch.client import OpenSearchIndexClient
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.schema import DocumentChunk
@@ -74,7 +75,7 @@ CHUNK_COUNT = 5
def _get_document_chunks_from_opensearch(
opensearch_client: OpenSearchClient, document_id: str, current_tenant_id: str
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
) -> list[DocumentChunk]:
opensearch_client.refresh_index()
filters = IndexFilters(access_control_list=None, tenant_id=current_tenant_id)
@@ -95,7 +96,7 @@ def _get_document_chunks_from_opensearch(
def _delete_document_chunks_from_opensearch(
opensearch_client: OpenSearchClient, document_id: str, current_tenant_id: str
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
) -> None:
opensearch_client.refresh_index()
query_body = DocumentQuery.delete_from_document_id_query(
@@ -283,10 +284,10 @@ def vespa_document_index(
def opensearch_client(
db_session: Session,
full_deployment_setup: None, # noqa: ARG001
) -> Generator[OpenSearchClient, None, None]:
) -> Generator[OpenSearchIndexClient, None, None]:
"""Creates an OpenSearch client for the test tenant."""
active = get_active_search_settings(db_session)
yield OpenSearchClient(index_name=active.primary.index_name) # Test runs here.
yield OpenSearchIndexClient(index_name=active.primary.index_name) # Test runs here.
@pytest.fixture(scope="module")
@@ -330,7 +331,7 @@ def patch_get_vespa_chunks_page_size() -> Generator[int, None, None]:
def test_documents(
db_session: Session,
vespa_document_index: VespaDocumentIndex,
opensearch_client: OpenSearchClient,
opensearch_client: OpenSearchIndexClient,
patch_get_vespa_chunks_page_size: int,
) -> Generator[list[Document], None, None]:
"""
@@ -411,7 +412,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
db_session: Session,
test_documents: list[Document],
vespa_document_index: VespaDocumentIndex,
opensearch_client: OpenSearchClient,
opensearch_client: OpenSearchIndexClient,
test_embedding_dimension: int,
clean_migration_tables: None, # noqa: ARG002
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
@@ -480,7 +481,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
db_session: Session,
test_documents: list[Document],
vespa_document_index: VespaDocumentIndex,
opensearch_client: OpenSearchClient,
opensearch_client: OpenSearchIndexClient,
test_embedding_dimension: int,
clean_migration_tables: None, # noqa: ARG002
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
@@ -618,7 +619,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
db_session: Session,
test_documents: list[Document],
vespa_document_index: VespaDocumentIndex,
opensearch_client: OpenSearchClient,
opensearch_client: OpenSearchIndexClient,
test_embedding_dimension: int,
clean_migration_tables: None, # noqa: ARG002
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
@@ -712,7 +713,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
db_session: Session,
test_documents: list[Document],
vespa_document_index: VespaDocumentIndex,
opensearch_client: OpenSearchClient,
opensearch_client: OpenSearchIndexClient,
test_embedding_dimension: int,
clean_migration_tables: None, # noqa: ARG002
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002

View File

@@ -20,6 +20,7 @@ from onyx.auth.oauth_token_manager import OAuthTokenManager
from onyx.db.models import OAuthConfig
from onyx.db.oauth_config import create_oauth_config
from onyx.db.oauth_config import upsert_user_oauth_token
from onyx.utils.sensitive import SensitiveValue
from tests.external_dependency_unit.conftest import create_test_user
@@ -491,3 +492,19 @@ class TestOAuthTokenManagerURLBuilding:
# Should use & instead of ? since URL already has query params
assert "foo=bar&" in url or "?foo=bar" in url
assert "client_id=custom_client_id" in url
class TestUnwrapSensitiveStr:
    """Tests for _unwrap_sensitive_str static method"""

    def test_unwrap_sensitive_str(self) -> None:
        """Test that both SensitiveValue and plain str inputs are handled"""
        # A wrapped value should be decrypted and returned as plain text.
        wrapped = SensitiveValue[str](
            encrypted_bytes=b"test_client_id",
            decrypt_fn=lambda raw: raw.decode(),
        )
        unwrapped = OAuthTokenManager._unwrap_sensitive_str(wrapped)
        assert unwrapped == "test_client_id"

        # A bare string should pass through untouched.
        passthrough = OAuthTokenManager._unwrap_sensitive_str("plain_string")
        assert passthrough == "plain_string"

View File

@@ -72,7 +72,7 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: int) -> dict:
headers=admin_user.headers,
)
response.raise_for_status()
for provider in response.json()["providers"]:
for provider in response.json():
if provider["id"] == provider_id:
return provider
raise ValueError(f"Provider with id {provider_id} not found")

View File

@@ -23,7 +23,7 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None
headers=admin_user.headers,
)
assert response.status_code == 200
providers = response.json()["providers"]
providers = response.json()
return next((p for p in providers if p["id"] == provider_id), None)
@@ -578,57 +578,6 @@ def test_model_visibility_preserved_on_edit(reset: None) -> None: # noqa: ARG00
assert visible_models[0]["name"] == "gpt-4o"
def _get_provider_by_name(providers: list[dict], provider_name: str) -> dict | None:
return next((p for p in providers if p["name"] == provider_name), None)
def _get_providers_admin(
    admin_user: DATestUser,
) -> dict | None:
    """Fetch the admin LLM-provider listing and return its parsed JSON payload.

    Fails the calling test (via assert) on any non-200 response.
    """
    url = f"{API_SERVER_URL}/admin/llm/provider"
    resp = requests.get(url, headers=admin_user.headers)
    assert resp.status_code == 200
    return resp.json()
def _unpack_data(data: dict) -> tuple[list[dict], dict | None, dict | None]:
providers = data["providers"]
text_default = data.get("default_text")
vision_default = data.get("default_vision")
return providers, text_default, vision_default
def _get_providers_basic(
    user: DATestUser,
) -> dict | None:
    """Fetch the non-admin LLM-provider listing and return its parsed JSON payload.

    Fails the calling test (via assert) on any non-200 response.
    """
    url = f"{API_SERVER_URL}/llm/provider"
    resp = requests.get(url, headers=user.headers)
    assert resp.status_code == 200
    return resp.json()
def _validate_default_model(
default: dict | None,
provider_id: int | None = None,
model_name: str | None = None,
) -> None:
if default is None:
assert provider_id is None and model_name is None
return
assert default["provider_id"] == provider_id
assert default["model_name"] == model_name
def _get_provider_by_name_admin(
admin_user: DATestUser, provider_name: str
) -> dict | None:
@@ -649,7 +598,7 @@ def _get_provider_by_name_basic(user: DATestUser, provider_name: str) -> dict |
headers=user.headers,
)
assert response.status_code == 200
providers = response.json()["providers"]
providers = response.json()
return next((p for p in providers if p["name"] == provider_name), None)
@@ -833,22 +782,9 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
)
assert create_response.status_code == 200
# Capture initial defaults (setup_postgres may have created a DevEnvPresetOpenAI default)
initial_data = _get_providers_admin(admin_user)
assert initial_data is not None
_, initial_text_default, initial_vision_default = _unpack_data(initial_data)
# Step 2: Verify via admin endpoint that all provider data is correct
admin_data = _get_providers_admin(admin_user)
assert admin_data is not None
providers, text_default, vision_default = _unpack_data(admin_data)
# Defaults should be unchanged from initial state (new provider not set as default)
assert text_default == initial_text_default
assert vision_default == initial_vision_default
admin_provider_data = _get_provider_by_name(providers, provider_name)
admin_provider_data = _get_provider_by_name_admin(admin_user, provider_name)
assert admin_provider_data is not None
_validate_provider_data(
admin_provider_data,
expected_name=provider_name,
@@ -861,13 +797,7 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
)
# Step 3: Verify via basic endpoint (admin user) that all provider data is correct
admin_basic_data = _get_providers_basic(admin_user)
assert admin_basic_data is not None
providers, text_default, vision_default = _unpack_data(admin_basic_data)
assert text_default == initial_text_default
assert vision_default == initial_vision_default
admin_basic_provider_data = _get_provider_by_name(providers, provider_name)
admin_basic_provider_data = _get_provider_by_name_basic(admin_user, provider_name)
assert admin_basic_provider_data is not None
_validate_provider_data(
admin_basic_provider_data,
@@ -880,13 +810,7 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
)
# Step 4: Verify non-admin user sees the same provider data via basic endpoint
basic_user_data = _get_providers_basic(basic_user)
assert basic_user_data is not None
providers, text_default, vision_default = _unpack_data(basic_user_data)
assert text_default == initial_text_default
assert vision_default == initial_vision_default
basic_user_provider_data = _get_provider_by_name(providers, provider_name)
basic_user_provider_data = _get_provider_by_name_basic(basic_user, provider_name)
assert basic_user_provider_data is not None
_validate_provider_data(
basic_user_provider_data,
@@ -922,17 +846,7 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
assert default_provider_response.status_code == 200
# Step 6a: Verify the updated provider data via admin endpoint
admin_data = _get_providers_admin(admin_user)
assert admin_data is not None
providers, text_default, vision_default = _unpack_data(admin_data)
_validate_default_model(
text_default,
provider_id=update_response.json()["id"],
model_name=updated_default_model,
)
_validate_default_model(vision_default) # None
admin_provider_data = _get_provider_by_name(providers, provider_name)
admin_provider_data = _get_provider_by_name_admin(admin_user, provider_name)
assert admin_provider_data is not None
_validate_provider_data(
admin_provider_data,
@@ -946,17 +860,7 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
)
# Step 6b: Verify the updated provider data via basic endpoint (admin user)
admin_basic_data = _get_providers_basic(admin_user)
assert admin_basic_data is not None
providers, text_default, vision_default = _unpack_data(admin_basic_data)
_validate_default_model(
text_default,
provider_id=update_response.json()["id"],
model_name=updated_default_model,
)
_validate_default_model(vision_default) # None
admin_basic_provider_data = _get_provider_by_name(providers, provider_name)
admin_basic_provider_data = _get_provider_by_name_basic(admin_user, provider_name)
assert admin_basic_provider_data is not None
_validate_provider_data(
admin_basic_provider_data,
@@ -969,17 +873,7 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
)
# Step 7: Verify non-admin user sees the updated provider data
basic_user_data = _get_providers_basic(basic_user)
assert basic_user_data is not None
providers, text_default, vision_default = _unpack_data(basic_user_data)
_validate_default_model(
text_default,
provider_id=update_response.json()["id"],
model_name=updated_default_model,
)
_validate_default_model(vision_default) # None
basic_user_provider_data = _get_provider_by_name(providers, provider_name)
basic_user_provider_data = _get_provider_by_name_basic(basic_user, provider_name)
assert basic_user_provider_data is not None
_validate_provider_data(
basic_user_provider_data,
@@ -999,7 +893,7 @@ def _get_all_providers_basic(user: DATestUser) -> list[dict]:
headers=user.headers,
)
assert response.status_code == 200
return response.json()["providers"]
return response.json()
def _get_all_providers_admin(admin_user: DATestUser) -> list[dict]:
@@ -1009,7 +903,7 @@ def _get_all_providers_admin(admin_user: DATestUser) -> list[dict]:
headers=admin_user.headers,
)
assert response.status_code == 200
return response.json()["providers"]
return response.json()
def _set_default_provider(admin_user: DATestUser, provider_id: int) -> None:
@@ -1145,17 +1039,11 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
# Step 3: Both admin and basic_user query and verify they see the same default
# Validate via admin endpoint
admin_data = _get_providers_admin(admin_user)
assert admin_data is not None
providers, text_default, vision_default = _unpack_data(admin_data)
_validate_default_model(
text_default, provider_id=provider_1["id"], model_name=shared_model_name
)
_validate_default_model(vision_default) # None
admin_provider_data = _get_provider_by_name(providers, provider_1_name)
assert admin_provider_data is not None
admin_providers = _get_all_providers_admin(admin_user)
admin_default = _find_default_provider(admin_providers)
assert admin_default is not None
_validate_provider_data(
admin_provider_data,
admin_default,
expected_name=provider_1_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=shared_model_name,
@@ -1166,7 +1054,9 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
)
# Validate provider 2 via admin endpoint (should not be default)
admin_provider_2 = _get_provider_by_name(providers, provider_2_name)
admin_provider_2 = next(
(p for p in admin_providers if p["name"] == provider_2_name), None
)
assert admin_provider_2 is not None
_validate_provider_data(
admin_provider_2,
@@ -1180,17 +1070,11 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
)
# Validate via basic endpoint (basic_user)
basic_data = _get_providers_basic(basic_user)
assert basic_data is not None
providers, text_default, vision_default = _unpack_data(basic_data)
_validate_default_model(
text_default, provider_id=provider_1["id"], model_name=shared_model_name
)
_validate_default_model(vision_default) # None
basic_provider_data = _get_provider_by_name(providers, provider_1_name)
assert basic_provider_data is not None
basic_providers = _get_all_providers_basic(basic_user)
basic_default = _find_default_provider(basic_providers)
assert basic_default is not None
_validate_provider_data(
basic_provider_data,
basic_default,
expected_name=provider_1_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=shared_model_name,
@@ -1200,17 +1084,11 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
)
# Also verify admin sees the same via basic endpoint
admin_basic_data = _get_providers_basic(admin_user)
assert admin_basic_data is not None
providers, text_default, vision_default = _unpack_data(admin_basic_data)
_validate_default_model(
text_default, provider_id=provider_1["id"], model_name=shared_model_name
)
_validate_default_model(vision_default) # None
admin_basic_provider_data = _get_provider_by_name(providers, provider_1_name)
assert admin_basic_provider_data is not None
admin_basic_providers = _get_all_providers_basic(admin_user)
admin_basic_default = _find_default_provider(admin_basic_providers)
assert admin_basic_default is not None
_validate_provider_data(
admin_basic_provider_data,
admin_basic_default,
expected_name=provider_1_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=shared_model_name,
@@ -1243,17 +1121,11 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
# Step 5: Both admin and basic_user verify they see the updated default
# Validate via admin endpoint
admin_data = _get_providers_admin(admin_user)
assert admin_data is not None
providers, text_default, vision_default = _unpack_data(admin_data)
_validate_default_model(
text_default, provider_id=provider_2["id"], model_name=provider_2_unique_model
)
_validate_default_model(vision_default) # None
admin_provider_data = _get_provider_by_name(providers, provider_2_name)
assert admin_provider_data is not None
admin_providers = _get_all_providers_admin(admin_user)
admin_default = _find_default_provider(admin_providers)
assert admin_default is not None
_validate_provider_data(
admin_provider_data,
admin_default,
expected_name=provider_2_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=provider_2_unique_model,
@@ -1264,7 +1136,9 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
)
# Validate provider 1 via admin endpoint (should no longer be default)
admin_provider_1 = _get_provider_by_name(providers, provider_1_name)
admin_provider_1 = next(
(p for p in admin_providers if p["name"] == provider_1_name), None
)
assert admin_provider_1 is not None
_validate_provider_data(
admin_provider_1,
@@ -1278,17 +1152,11 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
)
# Validate via basic endpoint (basic_user)
basic_data = _get_providers_basic(basic_user)
assert basic_data is not None
providers, text_default, vision_default = _unpack_data(basic_data)
_validate_default_model(
text_default, provider_id=provider_2["id"], model_name=provider_2_unique_model
)
_validate_default_model(vision_default) # None
basic_provider_data = _get_provider_by_name(providers, provider_2_name)
assert basic_provider_data is not None
basic_providers = _get_all_providers_basic(basic_user)
basic_default = _find_default_provider(basic_providers)
assert basic_default is not None
_validate_provider_data(
basic_provider_data,
basic_default,
expected_name=provider_2_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=provider_2_unique_model,
@@ -1298,17 +1166,11 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
)
# Validate via basic endpoint (admin_user)
admin_basic_data = _get_providers_basic(admin_user)
assert admin_basic_data is not None
providers, text_default, vision_default = _unpack_data(admin_basic_data)
_validate_default_model(
text_default, provider_id=provider_2["id"], model_name=provider_2_unique_model
)
_validate_default_model(vision_default) # None
admin_basic_provider_data = _get_provider_by_name(providers, provider_2_name)
assert admin_basic_provider_data is not None
admin_basic_providers = _get_all_providers_basic(admin_user)
admin_basic_default = _find_default_provider(admin_basic_providers)
assert admin_basic_default is not None
_validate_provider_data(
admin_basic_provider_data,
admin_basic_default,
expected_name=provider_2_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=provider_2_unique_model,
@@ -1338,17 +1200,11 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
# Step 7: Both users verify they see provider 2 as default with the shared model name
# Validate via admin endpoint
admin_data = _get_providers_admin(admin_user)
assert admin_data is not None
providers, text_default, vision_default = _unpack_data(admin_data)
_validate_default_model(
text_default, provider_id=provider_2["id"], model_name=shared_model_name
)
_validate_default_model(vision_default) # None
admin_provider_data = _get_provider_by_name(providers, provider_2_name)
assert admin_provider_data is not None
admin_providers = _get_all_providers_admin(admin_user)
admin_default = _find_default_provider(admin_providers)
assert admin_default is not None
_validate_provider_data(
admin_provider_data,
admin_default,
expected_name=provider_2_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=shared_model_name,
@@ -1359,17 +1215,11 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
)
# Validate via basic endpoint (basic_user)
basic_data = _get_providers_basic(basic_user)
assert basic_data is not None
providers, text_default, vision_default = _unpack_data(basic_data)
_validate_default_model(
text_default, provider_id=provider_2["id"], model_name=shared_model_name
)
_validate_default_model(vision_default) # None
basic_provider_data = _get_provider_by_name(providers, provider_2_name)
assert basic_provider_data is not None
basic_providers = _get_all_providers_basic(basic_user)
basic_default = _find_default_provider(basic_providers)
assert basic_default is not None
_validate_provider_data(
basic_provider_data,
basic_default,
expected_name=provider_2_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=shared_model_name,
@@ -1379,17 +1229,11 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
)
# Validate via basic endpoint (admin_user)
admin_basic_data = _get_providers_basic(admin_user)
assert admin_basic_data is not None
providers, text_default, vision_default = _unpack_data(admin_basic_data)
_validate_default_model(
text_default, provider_id=provider_2["id"], model_name=shared_model_name
)
_validate_default_model(vision_default) # None
admin_basic_provider_data = _get_provider_by_name(providers, provider_2_name)
assert admin_basic_provider_data is not None
admin_basic_providers = _get_all_providers_basic(admin_user)
admin_basic_default = _find_default_provider(admin_basic_providers)
assert admin_basic_default is not None
_validate_provider_data(
admin_basic_provider_data,
admin_basic_default,
expected_name=provider_2_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=shared_model_name,
@@ -1399,10 +1243,12 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
)
# Verify provider 1 is no longer the default and has correct data
admin_provider_1 = _get_provider_by_name(providers, provider_1_name)
assert admin_provider_1 is not None
provider_1_admin = next(
(p for p in admin_providers if p["name"] == provider_1_name), None
)
assert provider_1_admin is not None
_validate_provider_data(
admin_provider_1,
provider_1_admin,
expected_name=provider_1_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=shared_model_name,
@@ -1412,10 +1258,12 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
expected_is_public=True,
)
basic_provider_1 = _get_provider_by_name(providers, provider_1_name)
assert basic_provider_1 is not None
provider_1_basic = next(
(p for p in basic_providers if p["name"] == provider_1_name), None
)
assert provider_1_basic is not None
_validate_provider_data(
basic_provider_1,
provider_1_basic,
expected_name=provider_1_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=shared_model_name,
@@ -1543,22 +1391,10 @@ def test_default_provider_and_vision_provider_selection(
)
# Step 5: Verify via admin endpoint
admin_data = _get_providers_admin(admin_user)
assert admin_data is not None
admin_providers = _get_all_providers_admin(admin_user)
# Find and validate the default provider (provider 1)
providers, text_default, vision_default = _unpack_data(admin_data)
_validate_default_model(
text_default,
provider_id=provider_1["id"],
model_name=provider_1_non_vision_model,
)
_validate_default_model(
vision_default,
provider_id=provider_2["id"],
model_name=provider_2_vision_model_1,
)
admin_default = _get_provider_by_name(providers, provider_1_name)
admin_default = _find_default_provider(admin_providers)
assert admin_default is not None
_validate_provider_data(
admin_default,
@@ -1573,7 +1409,7 @@ def test_default_provider_and_vision_provider_selection(
)
# Find and validate the default vision provider (provider 2)
admin_vision_default = _get_provider_by_name(providers, provider_2_name)
admin_vision_default = _find_default_vision_provider(admin_providers)
assert admin_vision_default is not None
_validate_provider_data(
admin_vision_default,
@@ -1589,21 +1425,10 @@ def test_default_provider_and_vision_provider_selection(
)
# Step 6: Verify via basic endpoint (basic_user)
basic_providers = _get_all_providers_basic(basic_user)
# Find and validate the default provider (provider 1)
basic_data = _get_providers_basic(basic_user)
assert basic_data is not None
providers, text_default, vision_default = _unpack_data(basic_data)
_validate_default_model(
text_default,
provider_id=provider_1["id"],
model_name=provider_1_non_vision_model,
)
_validate_default_model(
vision_default,
provider_id=provider_2["id"],
model_name=provider_2_vision_model_1,
)
basic_default = _get_provider_by_name(providers, provider_1_name)
basic_default = _find_default_provider(basic_providers)
assert basic_default is not None
_validate_provider_data(
basic_default,
@@ -1617,7 +1442,7 @@ def test_default_provider_and_vision_provider_selection(
)
# Find and validate the default vision provider (provider 2)
basic_vision_default = _get_provider_by_name(providers, provider_2_name)
basic_vision_default = _find_default_vision_provider(basic_providers)
assert basic_vision_default is not None
_validate_provider_data(
basic_vision_default,
@@ -1632,20 +1457,9 @@ def test_default_provider_and_vision_provider_selection(
)
# Step 7: Verify via basic endpoint (admin_user sees same as basic_user)
admin_basic_data = _get_providers_basic(admin_user)
assert admin_basic_data is not None
providers, text_default, vision_default = _unpack_data(admin_basic_data)
_validate_default_model(
text_default,
provider_id=provider_1["id"],
model_name=provider_1_non_vision_model,
)
_validate_default_model(
vision_default,
provider_id=provider_2["id"],
model_name=provider_2_vision_model_1,
)
admin_basic_default = _get_provider_by_name(providers, provider_1_name)
admin_basic_providers = _get_all_providers_basic(admin_user)
admin_basic_default = _find_default_provider(admin_basic_providers)
assert admin_basic_default is not None
_validate_provider_data(
admin_basic_default,
@@ -1658,7 +1472,7 @@ def test_default_provider_and_vision_provider_selection(
expected_is_default_vision=None,
)
admin_basic_vision_default = _get_provider_by_name(providers, provider_2_name)
admin_basic_vision_default = _find_default_vision_provider(admin_basic_providers)
assert admin_basic_vision_default is not None
_validate_provider_data(
admin_basic_vision_default,
@@ -1735,14 +1549,7 @@ def test_default_provider_is_not_default_vision_provider(
_set_default_provider(admin_user, created_provider["id"])
# Step 3 & 4: Verify via admin endpoint
admin_data = _get_providers_admin(admin_user)
assert admin_data is not None
providers, text_default, vision_default = _unpack_data(admin_data)
_validate_default_model(
text_default, provider_id=created_provider["id"], model_name="gpt-4"
)
_validate_default_model(vision_default) # None
admin_provider_data = _get_provider_by_name(providers, provider_name)
admin_provider_data = _get_provider_by_name_admin(admin_user, provider_name)
assert admin_provider_data is not None
# Verify it IS the default provider
@@ -1777,14 +1584,7 @@ def test_default_provider_is_not_default_vision_provider(
)
# Also verify via basic endpoint
basic_data = _get_providers_basic(admin_user)
assert basic_data is not None
providers, text_default, vision_default = _unpack_data(basic_data)
_validate_default_model(
text_default, provider_id=created_provider["id"], model_name="gpt-4"
)
_validate_default_model(vision_default) # None
basic_provider_data = _get_provider_by_name(providers, provider_name)
basic_provider_data = _get_provider_by_name_basic(admin_user, provider_name)
assert basic_provider_data is not None
assert (
@@ -1969,52 +1769,37 @@ def test_all_three_provider_types_no_mixup(reset: None) -> None: # noqa: ARG001
# Step 4: Verify all three types are correctly tracked
# Get all LLM providers (via admin endpoint)
admin_data = _get_providers_admin(admin_user)
assert admin_data is not None
providers, text_default, vision_default = _unpack_data(admin_data)
_validate_default_model(
text_default, provider_id=regular_provider["id"], model_name="gpt-4"
)
_validate_default_model(
vision_default,
provider_id=vision_provider["id"],
model_name="gpt-4-vision-preview",
)
_validate_default_model(
vision_default, vision_provider["id"], "gpt-4-vision-preview"
)
_get_provider_by_name(providers, regular_provider_name)
admin_providers = _get_all_providers_admin(admin_user)
# Get all image generation configs
image_gen_configs = _get_all_image_gen_configs(admin_user)
# Verify the regular provider is the default provider
admin_regular_provider_data = _get_provider_by_name(
providers, regular_provider_name
regular_provider_data = next(
(p for p in admin_providers if p["name"] == regular_provider_name), None
)
assert admin_regular_provider_data is not None
_validate_provider_data(
admin_regular_provider_data,
expected_name=regular_provider_name,
expected_provider=LlmProviderNames.OPENAI,
expected_model_names=[c.name for c in regular_model_configs],
expected_visible={c.name: True for c in regular_model_configs},
expected_default_model="gpt-4",
expected_is_default=True,
)
admin_vision_provider_data = _get_provider_by_name(providers, vision_provider_name)
assert admin_vision_provider_data is not None
_validate_provider_data(
admin_vision_provider_data,
expected_name=vision_provider_name,
expected_provider=LlmProviderNames.OPENAI,
expected_model_names=[c.name for c in vision_model_configs],
expected_visible={c.name: True for c in vision_model_configs},
expected_default_model="gpt-4-vision-preview",
expected_is_default=False,
expected_default_vision_model="gpt-4-vision-preview",
expected_is_default_vision=True,
assert regular_provider_data is not None, "Regular provider not found"
assert (
regular_provider_data["is_default_provider"] is True
), "Regular provider should be the default provider"
assert (
regular_provider_data.get("is_default_vision_provider") is not True
), "Regular provider should NOT be the default vision provider"
# Verify the vision provider is the default vision provider
vision_provider_data = next(
(p for p in admin_providers if p["name"] == vision_provider_name), None
)
assert vision_provider_data is not None, "Vision provider not found"
assert (
vision_provider_data.get("is_default_provider") is not True
), "Vision provider should NOT be the default provider"
assert (
vision_provider_data["is_default_vision_provider"] is True
), "Vision provider should be the default vision provider"
assert (
vision_provider_data["default_vision_model"] == "gpt-4-vision-preview"
), "Vision provider should have correct default vision model"
# Verify the image gen config is the default image generation config
image_gen_config_data = next(
@@ -2034,53 +1819,97 @@ def test_all_three_provider_types_no_mixup(reset: None) -> None: # noqa: ARG001
), "Image gen config should have correct model name"
# Step 5: Verify no mixup - image gen providers don't appear in LLM provider lists
# Image gen provider should not appear in the list
assert image_gen_provider_id not in [p["name"] for p in providers]
# The image gen config creates an LLM provider with name "Image Gen - {image_provider_id}"
# This should NOT be returned by the regular LLM provider endpoints
[p["name"] for p in admin_providers]
image_gen_llm_provider_name = f"Image Gen - {image_gen_provider_id}"
# Note: The image gen provider IS an LLM provider internally, so it may appear in the list
# But it should NOT be marked as default provider or default vision provider
image_gen_llm_provider = next(
(p for p in admin_providers if p["name"] == image_gen_llm_provider_name), None
)
if image_gen_llm_provider:
# If it appears, verify it's not marked as default for either type
assert (
image_gen_llm_provider.get("is_default_provider") is not True
), "Image gen's internal LLM provider should NOT be the default provider"
assert (
image_gen_llm_provider.get("is_default_vision_provider") is not True
), "Image gen's internal LLM provider should NOT be the default vision provider"
# Step 6: Verify via basic endpoint (non-admin user)
basic_data = _get_providers_basic(basic_user)
assert basic_data is not None
providers, text_default, vision_default = _unpack_data(basic_data)
_validate_default_model(
text_default, provider_id=regular_provider["id"], model_name="gpt-4"
basic_providers = _get_all_providers_basic(basic_user)
# Verify regular provider is default for basic user
basic_regular = next(
(p for p in basic_providers if p["name"] == regular_provider_name), None
)
_validate_default_model(
vision_default,
provider_id=vision_provider["id"],
model_name="gpt-4-vision-preview",
)
_validate_default_model(
vision_default, vision_provider["id"], "gpt-4-vision-preview"
)
basic_provider_data = _get_provider_by_name(providers, regular_provider_name)
assert basic_provider_data is not None
_validate_provider_data(
basic_provider_data,
expected_name=regular_provider_name,
expected_provider=LlmProviderNames.OPENAI,
expected_model_names=[c.name for c in regular_model_configs],
expected_visible={c.name: True for c in regular_model_configs},
expected_is_default=True,
expected_default_model="gpt-4",
)
basic_vision_provider_data = _get_provider_by_name(providers, vision_provider_name)
assert basic_vision_provider_data is not None
_validate_provider_data(
basic_vision_provider_data,
expected_name=vision_provider_name,
expected_provider=LlmProviderNames.OPENAI,
expected_model_names=[c.name for c in vision_model_configs],
expected_visible={c.name: True for c in vision_model_configs},
expected_is_default=False,
expected_default_model="gpt-4-vision-preview",
expected_default_vision_model="gpt-4-vision-preview",
expected_is_default_vision=True,
assert basic_regular is not None, "Regular provider not visible to basic user"
assert (
basic_regular["is_default_provider"] is True
), "Regular provider should be default for basic user"
# Verify vision provider is default vision for basic user
basic_vision = next(
(p for p in basic_providers if p["name"] == vision_provider_name), None
)
assert basic_vision is not None, "Vision provider not visible to basic user"
assert (
basic_vision["is_default_vision_provider"] is True
), "Vision provider should be default vision for basic user"
# Step 7: Verify the counts are as expected
# We should have at least 2 user-created providers (setup_postgres may add more)
assert len(providers) >= 2
assert len(image_gen_configs) == 1
# We should have at least 2 user-created providers plus the image gen internal provider
user_created_providers = [
p
for p in admin_providers
if p["name"] in [regular_provider_name, vision_provider_name]
]
assert (
len(user_created_providers) == 2
), f"Expected 2 user-created providers, got {len(user_created_providers)}"
# We should have exactly 1 image gen config
assert (
len(
[
c
for c in image_gen_configs
if c["image_provider_id"] == image_gen_provider_id
]
)
== 1
), "Expected exactly 1 image gen config with our ID"
# Verify that our explicitly created providers are tracked correctly:
# - Only ONE provider has is_default_provider=True
default_providers = [
p for p in admin_providers if p.get("is_default_provider") is True
]
assert (
len(default_providers) == 1
), f"Expected exactly 1 default provider, got {len(default_providers)}"
assert default_providers[0]["name"] == regular_provider_name
# - Only ONE provider has is_default_vision_provider=True
default_vision_providers = [
p for p in admin_providers if p.get("is_default_vision_provider") is True
]
assert (
len(default_vision_providers) == 1
), f"Expected exactly 1 default vision provider, got {len(default_vision_providers)}"
assert default_vision_providers[0]["name"] == vision_provider_name
# - Only ONE image gen config has is_default=True
default_image_gen_configs = [
c for c in image_gen_configs if c.get("is_default") is True
]
assert (
len(default_image_gen_configs) == 1
), f"Expected exactly 1 default image gen config, got {len(default_image_gen_configs)}"
assert default_image_gen_configs[0]["image_provider_id"] == image_gen_provider_id
# Clean up: Delete the image gen config (to clean up the internal LLM provider)
_delete_image_gen_config(admin_user, image_gen_provider_id)

View File

@@ -272,19 +272,6 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
# Set up ModelConfiguration + LLMModelFlow so get_default_llm() can
# resolve the default provider when the fallback path is triggered.
# First, clear any existing CHAT defaults (setup_postgres may have created one)
existing_defaults = (
db_session.query(LLMModelFlow)
.filter(
LLMModelFlow.llm_model_flow_type == LLMModelFlowType.CHAT,
LLMModelFlow.is_default == True, # noqa: E712
)
.all()
)
for existing in existing_defaults:
existing.is_default = False
db_session.flush()
default_model_config = ModelConfiguration(
llm_provider_id=default_provider.id,
name=default_provider.default_model_name,
@@ -378,7 +365,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
headers=basic_user.headers,
)
assert response.status_code == 200
providers = response.json()["providers"]
providers = response.json()
provider_names = [p["name"] for p in providers]
# Public provider should be visible
@@ -393,7 +380,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
headers=admin_user.headers,
)
assert admin_response.status_code == 200
admin_providers = admin_response.json()["providers"]
admin_providers = admin_response.json()
admin_provider_names = [p["name"] for p in admin_providers]
assert public_provider.name in admin_provider_names

View File

@@ -12,7 +12,7 @@ const SvgOrganization = ({ size, ...props }: IconProps) => (
>
<path
d="M7.5 14H13.5C14.0523 14 14.5 13.5523 14.5 13V6C14.5 5.44772 14.0523 5 13.5 5H7.5M7.5 14V11M7.5 14H4.5M7.5 5V3C7.5 2.44772 7.05228 2 6.5 2H4.5M7.5 5H1.5M7.5 5V8M1.5 5V3C1.5 2.44772 1.94772 2 2.5 2H4.5M1.5 5V8M7.5 8V11M7.5 8H4.5M1.5 8V11M1.5 8H4.5M7.5 11H4.5M1.5 11V13C1.5 13.5523 1.94772 14 2.5 14H4.5M1.5 11H4.5M4.5 2V8M4.5 14V11M4.5 11V8M10 8H12M10 11H12"
strokeWidth={1}
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>

View File

@@ -17,6 +17,7 @@ import {
SvgPlus,
SvgWallet,
SvgFileText,
SvgOrganization,
} from "@opal/icons";
import { BillingInformation, LicenseStatus } from "@/lib/billing/interfaces";
import {
@@ -143,17 +144,20 @@ function SubscriptionCard({
license,
onViewPlans,
disabled,
isManualLicenseOnly,
onReconnect,
}: {
billing?: BillingInformation;
license?: LicenseStatus;
onViewPlans: () => void;
disabled?: boolean;
isManualLicenseOnly?: boolean;
onReconnect?: () => Promise<void>;
}) {
const [isReconnecting, setIsReconnecting] = useState(false);
const planName = "Business Plan";
const planName = isManualLicenseOnly ? "Enterprise Plan" : "Business Plan";
const PlanIcon = isManualLicenseOnly ? SvgOrganization : SvgUsers;
const expirationDate = billing?.current_period_end ?? license?.expires_at;
const formattedDate = formatDateShort(expirationDate);
@@ -211,7 +215,7 @@ function SubscriptionCard({
height="auto"
>
<Section gap={0.25} alignItems="start" height="auto" width="auto">
<SvgUsers className="w-5 h-5 stroke-text-03" />
<PlanIcon className="w-5 h-5" />
<Text headingH3Muted text04>
{planName}
</Text>
@@ -226,7 +230,19 @@ function SubscriptionCard({
height="auto"
width="fit"
>
{disabled ? (
{isManualLicenseOnly ? (
<Text secondaryBody text03 className="text-right">
Your plan is managed through sales.
<br />
<a
href="mailto:support@onyx.app?subject=Billing%20change%20request"
className="underline"
>
Contact billing
</a>{" "}
to make changes.
</Text>
) : disabled ? (
<Button
main
secondary
@@ -266,11 +282,13 @@ function SeatsCard({
license,
onRefresh,
disabled,
hideUpdateSeats,
}: {
billing?: BillingInformation;
license?: LicenseStatus;
onRefresh?: () => Promise<void>;
disabled?: boolean;
hideUpdateSeats?: boolean;
}) {
const [isEditing, setIsEditing] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
@@ -484,15 +502,17 @@ function SeatsCard({
<Button main tertiary href="/admin/users" leftIcon={SvgExternalLink}>
View Users
</Button>
<Button
main
secondary
onClick={handleStartEdit}
leftIcon={SvgPlus}
disabled={isLoadingUsers || disabled || !billing}
>
Update Seats
</Button>
{!hideUpdateSeats && (
<Button
main
secondary
onClick={handleStartEdit}
leftIcon={SvgPlus}
disabled={isLoadingUsers || disabled || !billing}
>
Update Seats
</Button>
)}
</Section>
</Section>
</Card>
@@ -593,7 +613,9 @@ interface BillingDetailsViewProps {
onViewPlans: () => void;
onRefresh?: () => Promise<void>;
isAirGapped?: boolean;
isManualLicenseOnly?: boolean;
hasStripeError?: boolean;
licenseCard?: React.ReactNode;
}
export default function BillingDetailsView({
@@ -602,10 +624,13 @@ export default function BillingDetailsView({
onViewPlans,
onRefresh,
isAirGapped,
isManualLicenseOnly,
hasStripeError,
licenseCard,
}: BillingDetailsViewProps) {
const expirationState = billing ? getExpirationState(billing, license) : null;
const disableBillingActions = isAirGapped || hasStripeError;
const disableBillingActions =
isAirGapped || hasStripeError || isManualLicenseOnly;
return (
<Section gap={1} height="auto" width="full">
@@ -622,7 +647,7 @@ export default function BillingDetailsView({
)}
{/* Air-gapped mode info banner */}
{isAirGapped && !hasStripeError && (
{isAirGapped && !hasStripeError && !isManualLicenseOnly && (
<Message
static
info
@@ -665,16 +690,21 @@ export default function BillingDetailsView({
license={license}
onViewPlans={onViewPlans}
disabled={disableBillingActions}
isManualLicenseOnly={isManualLicenseOnly}
onReconnect={onRefresh}
/>
)}
{/* License card (inline for manual license users) */}
{licenseCard}
{/* Seats card */}
<SeatsCard
billing={billing}
license={license}
onRefresh={onRefresh}
disabled={disableBillingActions}
hideUpdateSeats={isManualLicenseOnly}
/>
{/* Payment section */}

View File

@@ -19,6 +19,7 @@ interface LicenseActivationCardProps {
onClose: () => void;
onSuccess: () => void;
license?: LicenseStatus;
hideClose?: boolean;
}
export default function LicenseActivationCard({
@@ -26,6 +27,7 @@ export default function LicenseActivationCard({
onClose,
onSuccess,
license,
hideClose,
}: LicenseActivationCardProps) {
const [licenseKey, setLicenseKey] = useState("");
const [isActivating, setIsActivating] = useState(false);
@@ -120,9 +122,11 @@ export default function LicenseActivationCard({
<Button main secondary onClick={() => setShowInput(true)}>
Update Key
</Button>
<Button main tertiary onClick={handleClose}>
Close
</Button>
{!hideClose && (
<Button main tertiary onClick={handleClose}>
Close
</Button>
)}
</Section>
</Section>
</Card>

View File

@@ -121,11 +121,12 @@ export default function BillingPage() {
const billing = hasSubscription ? (billingData as BillingInformation) : null;
const isSelfHosted = !NEXT_PUBLIC_CLOUD_ENABLED;
// User is only air-gapped if they have a manual license AND Stripe is not connected
// Once Stripe connects successfully, they're no longer air-gapped
const hasManualLicense = licenseData?.source === "manual_upload";
const stripeConnected = billingData && !billingError;
const isAirGapped = hasManualLicense && !stripeConnected;
// Air-gapped: billing endpoint is unreachable (manual license + connectivity error)
const isAirGapped = !!(hasManualLicense && billingError);
// Stripe error: auto-fetched license but billing endpoint is unreachable
const hasStripeError = !!(
isSelfHosted &&
licenseData?.has_license &&
@@ -133,6 +134,10 @@ export default function BillingPage() {
!hasManualLicense
);
// Manual license without active Stripe subscription
// Stripe-dependent actions (manage plan, update seats) won't work
const isManualLicenseOnly = !!(hasManualLicense && !hasSubscription);
// Set initial view based on subscription status (only once when data first loads)
useEffect(() => {
if (!isLoading && view === null) {
@@ -243,7 +248,10 @@ export default function BillingPage() {
return {
icon: hasSubscription ? SvgWallet : SvgArrowUpCircle,
title: hasSubscription ? "View Plans" : "Upgrade Plan",
showBackButton: !!hasSubscription,
showBackButton: !!(
hasSubscription ||
(isSelfHosted && licenseData?.has_license)
),
};
case "details":
return {
@@ -271,9 +279,11 @@ export default function BillingPage() {
};
const handleBack = () => {
const hasEntitlement =
hasSubscription || (isSelfHosted && licenseData?.has_license);
if (view === "checkout") {
changeView(hasSubscription ? "details" : "plans");
} else if (view === "plans" && hasSubscription) {
changeView(hasEntitlement ? "details" : "plans");
} else if (view === "plans" && hasEntitlement) {
changeView("details");
}
};
@@ -305,7 +315,19 @@ export default function BillingPage() {
onViewPlans={() => changeView("plans")}
onRefresh={handleRefresh}
isAirGapped={isAirGapped}
isManualLicenseOnly={isManualLicenseOnly}
hasStripeError={hasStripeError}
licenseCard={
isManualLicenseOnly ? (
<LicenseActivationCard
isOpen
onSuccess={handleLicenseActivated}
license={licenseData ?? undefined}
onClose={() => {}}
hideClose
/>
) : undefined
}
/>
),
};
@@ -322,7 +344,7 @@ export default function BillingPage() {
if (isLoading || view === null) return null;
return (
<>
{showLicenseActivationInput && (
{showLicenseActivationInput && !isManualLicenseOnly && (
<div className="w-full billing-card-enter">
<LicenseActivationCard
isOpen={showLicenseActivationInput}
@@ -341,6 +363,7 @@ export default function BillingPage() {
isSelfHosted ? () => setShowLicenseActivationInput(true) : undefined
}
hideLicenseLink={
isManualLicenseOnly ||
showLicenseActivationInput ||
(view === "plans" &&
(!!hasSubscription || !!licenseData?.has_license))

View File

@@ -14,6 +14,11 @@ import {
TimelineRendererComponent,
TimelineRendererOutput,
} from "./TimelineRendererComponent";
import {
isReasoningPackets,
isDeepResearchPlanPackets,
isMemoryToolPackets,
} from "./packetHelpers";
import Tabs from "@/refresh-components/Tabs";
import { SvgBranch, SvgFold, SvgExpand } from "@opal/icons";
import { Button } from "@opal/components";
@@ -60,6 +65,13 @@ export function ParallelTimelineTabs({
[turnGroup.steps, activeTab]
);
// Determine if the active step needs full-width content (no right padding)
const noPaddingRight = activeStep
? isReasoningPackets(activeStep.packets) ||
isDeepResearchPlanPackets(activeStep.packets) ||
isMemoryToolPackets(activeStep.packets)
: false;
// Memoized loading states for each step
const loadingStates = useMemo(
() =>
@@ -82,9 +94,10 @@ export function ParallelTimelineTabs({
isFirstStep={false}
isSingleStep={false}
collapsible={true}
noPaddingRight={noPaddingRight}
/>
),
[isLastTurnGroup]
[isLastTurnGroup, noPaddingRight]
);
const hasActivePackets = Boolean(activeStep && activeStep.packets.length > 0);

View File

@@ -50,7 +50,7 @@ export function TimelineStepComposer({
header={result.status}
isExpanded={result.isExpanded}
onToggle={result.onToggle}
collapsible={collapsible}
collapsible={collapsible && !isSingleStep}
supportsCollapsible={result.supportsCollapsible}
isLastStep={index === results.length - 1 && isLastStep}
isFirstStep={index === 0 && isFirstStep}

View File

@@ -63,7 +63,7 @@ export const FetchToolRenderer: MessageRenderer<FetchToolPacket, {}> = ({
return children([
{
icon: SvgCircle,
status: null,
status: "Reading",
content: <div />,
supportsCollapsible: false,
timelineLayout: "timeline",

View File

@@ -46,7 +46,7 @@ export const MemoryToolRenderer: MessageRenderer<MemoryToolPacket, {}> = ({
return children([
{
icon: SvgEditBig,
status: null,
status: "Memory",
content: <div />,
supportsCollapsible: false,
timelineLayout: "timeline",

View File

@@ -169,7 +169,9 @@ export const ReasoningRenderer: MessageRenderer<
);
if (!hasStart && !hasEnd && content.length === 0) {
return children([{ icon: SvgCircle, status: null, content: <></> }]);
return children([
{ icon: SvgCircle, status: THINKING_STATUS, content: <></> },
]);
}
const reasoningContent = (

View File

@@ -61,7 +61,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
children,
}) => {
const searchState = constructCurrentSearchState(packets);
const { queries, results } = searchState;
const { queries, results, isComplete } = searchState;
const isCompact = renderType === RenderType.COMPACT;
const isHighlight = renderType === RenderType.HIGHLIGHT;
@@ -75,7 +75,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
return children([
{
icon: SvgSearchMenu,
status: null,
status: queriesHeader,
content: <></>,
supportsCollapsible: true,
timelineLayout: "timeline",
@@ -109,7 +109,15 @@ export const InternalSearchToolRenderer: MessageRenderer<
window.open(doc.link, "_blank", "noopener,noreferrer");
}
}}
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
emptyState={
!isComplete ? (
<BlinkingBar />
) : (
<Text as="p" text04 mainUiMuted>
No results found
</Text>
)
}
/>
</div>
),
@@ -164,7 +172,15 @@ export const InternalSearchToolRenderer: MessageRenderer<
window.open(doc.link, "_blank", "noopener,noreferrer");
}
}}
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
emptyState={
!isComplete ? (
<BlinkingBar />
) : (
<Text as="p" text04 mainUiMuted>
No results found
</Text>
)
}
/>
),
},
@@ -213,7 +229,15 @@ export const InternalSearchToolRenderer: MessageRenderer<
window.open(doc.link, "_blank", "noopener,noreferrer");
}
}}
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
emptyState={
!isComplete ? (
<BlinkingBar />
) : (
<Text as="p" text03 mainUiMuted>
No results found
</Text>
)
}
/>
</>
)}

View File

@@ -53,7 +53,7 @@ export const WebSearchToolRenderer: MessageRenderer<SearchToolPacket, {}> = ({
return children([
{
icon: SvgGlobe,
status: null,
status: "Searching the web",
content: <div />,
supportsCollapsible: false,
timelineLayout: "timeline",

View File

@@ -64,6 +64,7 @@ function MemoryItem({
if (!shouldHighlight) return;
wrapperRef.current?.scrollIntoView({ block: "center", behavior: "smooth" });
textareaRef.current?.focus();
setIsHighlighting(true);
const timer = setTimeout(() => {

View File

@@ -115,14 +115,9 @@ export default function ActionLineItem({
<Section gap={0.25} flexDirection="row">
{!isUnavailable && tool?.oauth_config_id && toolAuthStatus && (
<Button
icon={({ className }) => (
<SvgKey
className={cn(
className,
"stroke-yellow-500 hover:stroke-yellow-600"
)}
/>
)}
icon={SvgKey}
prominence="secondary"
size="sm"
onClick={noProp(() => {
if (
!toolAuthStatus.hasToken ||