mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-26 12:15:48 +00:00
Compare commits
7 Commits
llm_provid
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3bc1b89fee | ||
|
|
01743d99d4 | ||
|
|
092c1db7e0 | ||
|
|
40ac0d859a | ||
|
|
929e58361f | ||
|
|
6d472df7c5 | ||
|
|
cfa7acd904 |
@@ -58,16 +58,27 @@ class OAuthTokenManager:
|
||||
if not user_token.token_data:
|
||||
raise ValueError("No token data available for refresh")
|
||||
|
||||
if (
|
||||
self.oauth_config.client_id is None
|
||||
or self.oauth_config.client_secret is None
|
||||
):
|
||||
raise ValueError(
|
||||
"OAuth client_id and client_secret are required for token refresh"
|
||||
)
|
||||
|
||||
token_data = self._unwrap_token_data(user_token.token_data)
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token_data["refresh_token"],
|
||||
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
|
||||
"client_secret": self._unwrap_sensitive_str(
|
||||
self.oauth_config.client_secret
|
||||
),
|
||||
}
|
||||
response = requests.post(
|
||||
self.oauth_config.token_url,
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token_data["refresh_token"],
|
||||
"client_id": self.oauth_config.client_id,
|
||||
"client_secret": self.oauth_config.client_secret,
|
||||
},
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -115,15 +126,26 @@ class OAuthTokenManager:
|
||||
|
||||
def exchange_code_for_token(self, code: str, redirect_uri: str) -> dict[str, Any]:
|
||||
"""Exchange authorization code for access token"""
|
||||
if (
|
||||
self.oauth_config.client_id is None
|
||||
or self.oauth_config.client_secret is None
|
||||
):
|
||||
raise ValueError(
|
||||
"OAuth client_id and client_secret are required for code exchange"
|
||||
)
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
|
||||
"client_secret": self._unwrap_sensitive_str(
|
||||
self.oauth_config.client_secret
|
||||
),
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
response = requests.post(
|
||||
self.oauth_config.token_url,
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"client_id": self.oauth_config.client_id,
|
||||
"client_secret": self.oauth_config.client_secret,
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -141,8 +163,13 @@ class OAuthTokenManager:
|
||||
oauth_config: OAuthConfig, redirect_uri: str, state: str
|
||||
) -> str:
|
||||
"""Build OAuth authorization URL"""
|
||||
if oauth_config.client_id is None:
|
||||
raise ValueError("OAuth client_id is required to build authorization URL")
|
||||
|
||||
params: dict[str, Any] = {
|
||||
"client_id": oauth_config.client_id,
|
||||
"client_id": OAuthTokenManager._unwrap_sensitive_str(
|
||||
oauth_config.client_id
|
||||
),
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"state": state,
|
||||
@@ -161,6 +188,12 @@ class OAuthTokenManager:
|
||||
|
||||
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_sensitive_str(value: SensitiveValue[str] | str) -> str:
|
||||
if isinstance(value, SensitiveValue):
|
||||
return value.get_value(apply_mask=False)
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_token_data(
|
||||
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],
|
||||
|
||||
@@ -48,6 +48,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -149,8 +150,12 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
opensearch_document_index = OpenSearchDocumentIndex(
|
||||
index_name=search_settings.index_name, tenant_state=tenant_state
|
||||
tenant_state=tenant_state,
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
)
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
|
||||
@@ -294,6 +294,12 @@ ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
|
||||
)
|
||||
# Whether we should check for and create an index if necessary every time we
|
||||
# instantiate an OpenSearchDocumentIndex on multitenant cloud. Defaults to True.
|
||||
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
|
||||
os.environ.get("VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT", "true").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
# NOTE: this is used if and only if the vespa config server is accessible via a
|
||||
|
||||
@@ -488,22 +488,6 @@ def fetch_existing_llm_provider(
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_existing_llm_provider_by_id(
|
||||
id: int, db_session: Session
|
||||
) -> LLMProviderModel | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel)
|
||||
.where(LLMProviderModel.id == id)
|
||||
.options(
|
||||
selectinload(LLMProviderModel.model_configurations),
|
||||
selectinload(LLMProviderModel.groups),
|
||||
selectinload(LLMProviderModel.personas),
|
||||
)
|
||||
)
|
||||
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_embedding_provider(
|
||||
db_session: Session, provider_type: EmbeddingProvider
|
||||
) -> CloudEmbeddingProviderModel | None:
|
||||
|
||||
@@ -11,6 +11,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchOldDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
@@ -49,8 +50,11 @@ def get_default_document_index(
|
||||
|
||||
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
|
||||
if opensearch_retrieval_enabled:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
return OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=secondary_index_name,
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
|
||||
@@ -118,8 +122,11 @@ def get_all_document_indices(
|
||||
)
|
||||
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
opensearch_document_index = OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from contextlib import AbstractContextManager
|
||||
from contextlib import nullcontext
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
@@ -83,22 +85,26 @@ def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
|
||||
return new_body
|
||||
|
||||
|
||||
class OpenSearchClient:
|
||||
"""Client for interacting with OpenSearch.
|
||||
class OpenSearchClient(AbstractContextManager):
|
||||
"""Client for interacting with OpenSearch for cluster-level operations.
|
||||
|
||||
OpenSearch's Python module has pretty bad typing support so this client
|
||||
attempts to protect the rest of the codebase from this. As a consequence,
|
||||
most methods here return the minimum data needed for the rest of Onyx, and
|
||||
tend to rely on Exceptions to handle errors.
|
||||
|
||||
TODO(andrei): This class currently assumes the structure of the database
|
||||
schema when it returns a DocumentChunk. Make the class, or at least the
|
||||
search method, templated on the structure the caller can expect.
|
||||
Args:
|
||||
host: The host of the OpenSearch cluster.
|
||||
port: The port of the OpenSearch cluster.
|
||||
auth: The authentication credentials for the OpenSearch cluster. A tuple
|
||||
of (username, password).
|
||||
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
|
||||
True.
|
||||
verify_certs: Whether to verify the SSL certificates for the OpenSearch
|
||||
cluster. Defaults to False.
|
||||
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
|
||||
to False.
|
||||
timeout: The timeout for the OpenSearch cluster. Defaults to
|
||||
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
host: str = OPENSEARCH_HOST,
|
||||
port: int = OPENSEARCH_REST_API_PORT,
|
||||
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
|
||||
@@ -107,9 +113,8 @@ class OpenSearchClient:
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
):
|
||||
self._index_name = index_name
|
||||
logger.debug(
|
||||
f"Creating OpenSearch client for index {index_name} with host {host} and port {port} and timeout {timeout} seconds."
|
||||
f"Creating OpenSearch client with host {host}, port {port} and timeout {timeout} seconds."
|
||||
)
|
||||
self._client = OpenSearch(
|
||||
hosts=[{"host": host, "port": port}],
|
||||
@@ -125,6 +130,142 @@ class OpenSearchClient:
|
||||
# your request body that is less than this value.
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def __exit__(self, *_: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def create_search_pipeline(
|
||||
self,
|
||||
pipeline_id: str,
|
||||
pipeline_body: dict[str, Any],
|
||||
) -> None:
|
||||
"""Creates a search pipeline.
|
||||
|
||||
See the OpenSearch documentation for more information on the search
|
||||
pipeline body.
|
||||
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to create.
|
||||
pipeline_body: The body of the search pipeline to create.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error creating the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def delete_search_pipeline(self, pipeline_id: str) -> None:
|
||||
"""Deletes a search pipeline.
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to delete.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error deleting the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
|
||||
"""Puts cluster settings.
|
||||
|
||||
Args:
|
||||
settings: The settings to put.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error putting the cluster settings.
|
||||
|
||||
Returns:
|
||||
True if the settings were put successfully, False otherwise.
|
||||
"""
|
||||
response = self._client.cluster.put_settings(body=settings)
|
||||
if response.get("acknowledged", False):
|
||||
logger.info("Successfully put cluster settings.")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to put cluster settings: {response}.")
|
||||
return False
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
Returns:
|
||||
True if OpenSearch could be reached, False if it could not.
|
||||
"""
|
||||
return self._client.ping()
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def close(self) -> None:
|
||||
"""Closes the client.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error closing the client.
|
||||
"""
|
||||
self._client.close()
|
||||
|
||||
|
||||
class OpenSearchIndexClient(OpenSearchClient):
|
||||
"""Client for interacting with OpenSearch for index-level operations.
|
||||
|
||||
OpenSearch's Python module has pretty bad typing support so this client
|
||||
attempts to protect the rest of the codebase from this. As a consequence,
|
||||
most methods here return the minimum data needed for the rest of Onyx, and
|
||||
tend to rely on Exceptions to handle errors.
|
||||
|
||||
TODO(andrei): This class currently assumes the structure of the database
|
||||
schema when it returns a DocumentChunk. Make the class, or at least the
|
||||
search method, templated on the structure the caller can expect.
|
||||
|
||||
Args:
|
||||
index_name: The name of the index to interact with.
|
||||
host: The host of the OpenSearch cluster.
|
||||
port: The port of the OpenSearch cluster.
|
||||
auth: The authentication credentials for the OpenSearch cluster. A tuple
|
||||
of (username, password).
|
||||
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
|
||||
True.
|
||||
verify_certs: Whether to verify the SSL certificates for the OpenSearch
|
||||
cluster. Defaults to False.
|
||||
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
|
||||
to False.
|
||||
timeout: The timeout for the OpenSearch cluster. Defaults to
|
||||
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
host: str = OPENSEARCH_HOST,
|
||||
port: int = OPENSEARCH_REST_API_PORT,
|
||||
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
|
||||
use_ssl: bool = True,
|
||||
verify_certs: bool = False,
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
):
|
||||
super().__init__(
|
||||
host=host,
|
||||
port=port,
|
||||
auth=auth,
|
||||
use_ssl=use_ssl,
|
||||
verify_certs=verify_certs,
|
||||
ssl_show_warn=ssl_show_warn,
|
||||
timeout=timeout,
|
||||
)
|
||||
self._index_name = index_name
|
||||
logger.debug(
|
||||
f"OpenSearch client created successfully for index {self._index_name}."
|
||||
)
|
||||
@@ -192,6 +333,38 @@ class OpenSearchClient:
|
||||
"""
|
||||
return self._client.indices.exists(index=self._index_name)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def put_mapping(self, mappings: dict[str, Any]) -> None:
|
||||
"""Updates the index mapping in an idempotent manner.
|
||||
|
||||
- Existing fields with the same definition: No-op (succeeds silently).
|
||||
- New fields: Added to the index.
|
||||
- Existing fields with different types: Raises exception (requires
|
||||
reindex).
|
||||
|
||||
See the OpenSearch documentation for more information:
|
||||
https://docs.opensearch.org/latest/api-reference/index-apis/put-mapping/
|
||||
|
||||
Args:
|
||||
mappings: The complete mapping definition to apply. This will be
|
||||
merged with existing mappings in the index.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error updating the mappings, such as
|
||||
attempting to change the type of an existing field.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Putting mappings for index {self._index_name} with mappings {mappings}."
|
||||
)
|
||||
response = self._client.indices.put_mapping(
|
||||
index=self._index_name, body=mappings
|
||||
)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(
|
||||
f"Failed to put the mapping update for index {self._index_name}."
|
||||
)
|
||||
logger.debug(f"Successfully put mappings for index {self._index_name}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def validate_index(self, expected_mappings: dict[str, Any]) -> bool:
|
||||
"""Validates the index.
|
||||
@@ -610,43 +783,6 @@ class OpenSearchClient:
|
||||
)
|
||||
return DocumentChunk.model_validate(document_chunk_source)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def create_search_pipeline(
|
||||
self,
|
||||
pipeline_id: str,
|
||||
pipeline_body: dict[str, Any],
|
||||
) -> None:
|
||||
"""Creates a search pipeline.
|
||||
|
||||
See the OpenSearch documentation for more information on the search
|
||||
pipeline body.
|
||||
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to create.
|
||||
pipeline_body: The body of the search pipeline to create.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error creating the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def delete_search_pipeline(self, pipeline_id: str) -> None:
|
||||
"""Deletes a search pipeline.
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to delete.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error deleting the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def search(
|
||||
self, body: dict[str, Any], search_pipeline_id: str | None
|
||||
@@ -807,48 +943,6 @@ class OpenSearchClient:
|
||||
"""
|
||||
self._client.indices.refresh(index=self._index_name)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
|
||||
"""Puts cluster settings.
|
||||
|
||||
Args:
|
||||
settings: The settings to put.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error putting the cluster settings.
|
||||
|
||||
Returns:
|
||||
True if the settings were put successfully, False otherwise.
|
||||
"""
|
||||
response = self._client.cluster.put_settings(body=settings)
|
||||
if response.get("acknowledged", False):
|
||||
logger.info("Successfully put cluster settings.")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to put cluster settings: {response}.")
|
||||
return False
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
Returns:
|
||||
True if OpenSearch could be reached, False if it could not.
|
||||
"""
|
||||
return self._client.ping()
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def close(self) -> None:
|
||||
"""Closes the client.
|
||||
|
||||
TODO(andrei): Can we have some way to auto close when the client no
|
||||
longer has any references?
|
||||
|
||||
Raises:
|
||||
Exception: There was an error closing the client.
|
||||
"""
|
||||
self._client.close()
|
||||
|
||||
def _get_hits_and_profile_from_search_result(
|
||||
self, result: dict[str, Any]
|
||||
) -> tuple[list[Any], int | None, bool | None, dict[str, Any], dict[str, Any]]:
|
||||
@@ -945,14 +1039,7 @@ def wait_for_opensearch_with_timeout(
|
||||
Returns:
|
||||
True if OpenSearch is ready, False otherwise.
|
||||
"""
|
||||
made_client = False
|
||||
try:
|
||||
if client is None:
|
||||
# NOTE: index_name does not matter because we are only using this object
|
||||
# to ping.
|
||||
# TODO(andrei): Make this better.
|
||||
client = OpenSearchClient(index_name="")
|
||||
made_client = True
|
||||
with nullcontext(client) if client else OpenSearchClient() as client:
|
||||
time_start = time.monotonic()
|
||||
while True:
|
||||
if client.ping():
|
||||
@@ -969,7 +1056,3 @@ def wait_for_opensearch_with_timeout(
|
||||
f"[OpenSearch] Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={wait_limit_s:.1f}"
|
||||
)
|
||||
time.sleep(wait_interval_s)
|
||||
finally:
|
||||
if made_client:
|
||||
assert client is not None
|
||||
client.close()
|
||||
|
||||
@@ -7,6 +7,7 @@ from opensearchpy import NotFoundError
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.configs.constants import PUBLIC_DOC_PAT
|
||||
@@ -40,6 +41,7 @@ from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.interfaces_new import MetadataUpdateRequest
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import SearchHit
|
||||
from onyx.document_index.opensearch.cluster_settings import OPENSEARCH_CLUSTER_SETTINGS
|
||||
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
@@ -93,6 +95,25 @@ def generate_opensearch_filtered_access_control_list(
|
||||
return list(access_control_list)
|
||||
|
||||
|
||||
def set_cluster_state(client: OpenSearchClient) -> None:
|
||||
if not client.put_cluster_settings(settings=OPENSEARCH_CLUSTER_SETTINGS):
|
||||
logger.error(
|
||||
"Failed to put cluster settings. If the settings have never been set before, "
|
||||
"this may cause unexpected index creation when indexing documents into an "
|
||||
"index that does not exist, or may cause expected logs to not appear. If this "
|
||||
"is not the first time running Onyx against this instance of OpenSearch, these "
|
||||
"settings have likely already been set. Not taking any further action..."
|
||||
)
|
||||
client.create_search_pipeline(
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
client.create_search_pipeline(
|
||||
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
|
||||
|
||||
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
chunk: DocumentChunk,
|
||||
score: float | None,
|
||||
@@ -248,6 +269,8 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
embedding_dim: int,
|
||||
embedding_precision: EmbeddingPrecision,
|
||||
secondary_index_name: str | None,
|
||||
large_chunks_enabled: bool, # noqa: ARG002
|
||||
secondary_large_chunks_enabled: bool | None, # noqa: ARG002
|
||||
@@ -258,10 +281,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
index_name=index_name,
|
||||
secondary_index_name=secondary_index_name,
|
||||
)
|
||||
if multitenant:
|
||||
raise ValueError(
|
||||
"Bug: OpenSearch is not yet ready for multitenant environments but something tried to use it."
|
||||
)
|
||||
if multitenant != MULTI_TENANT:
|
||||
raise ValueError(
|
||||
"Bug: Multitenant mismatch when initializing an OpenSearchDocumentIndex. "
|
||||
@@ -269,8 +288,10 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
)
|
||||
tenant_id = get_current_tenant_id()
|
||||
self._real_index = OpenSearchDocumentIndex(
|
||||
index_name=index_name,
|
||||
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
|
||||
index_name=index_name,
|
||||
embedding_dim=embedding_dim,
|
||||
embedding_precision=embedding_precision,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -279,9 +300,8 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
embedding_dims: list[int],
|
||||
embedding_precisions: list[EmbeddingPrecision],
|
||||
) -> None:
|
||||
# TODO(andrei): Implement.
|
||||
raise NotImplementedError(
|
||||
"Multitenant index registration is not yet implemented for OpenSearch."
|
||||
"Bug: Multitenant index registration is not supported for OpenSearch."
|
||||
)
|
||||
|
||||
def ensure_indices_exist(
|
||||
@@ -471,19 +491,37 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
for an OpenSearch search engine instance. It handles the complete lifecycle
|
||||
of document chunks within a specific OpenSearch index/schema.
|
||||
|
||||
Although not yet used in this way in the codebase, each kind of embedding
|
||||
used should correspond to a different instance of this class, and therefore
|
||||
a different index in OpenSearch.
|
||||
Each kind of embedding used should correspond to a different instance of
|
||||
this class, and therefore a different index in OpenSearch.
|
||||
|
||||
If in a multitenant environment and
|
||||
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT, will verify and create the index
|
||||
if necessary on initialization. This is because there is no logic which runs
|
||||
on cluster restart which scans through all search settings over all tenants
|
||||
and creates the relevant indices.
|
||||
|
||||
Args:
|
||||
tenant_state: The tenant state of the caller.
|
||||
index_name: The name of the index to interact with.
|
||||
embedding_dim: The dimensionality of the embeddings used for the index.
|
||||
embedding_precision: The precision of the embeddings used for the index.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
tenant_state: TenantState,
|
||||
index_name: str,
|
||||
embedding_dim: int,
|
||||
embedding_precision: EmbeddingPrecision,
|
||||
) -> None:
|
||||
self._index_name: str = index_name
|
||||
self._tenant_state: TenantState = tenant_state
|
||||
self._os_client = OpenSearchClient(index_name=self._index_name)
|
||||
self._client = OpenSearchIndexClient(index_name=self._index_name)
|
||||
|
||||
if self._tenant_state.multitenant and VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT:
|
||||
self.verify_and_create_index_if_necessary(
|
||||
embedding_dim=embedding_dim, embedding_precision=embedding_precision
|
||||
)
|
||||
|
||||
def verify_and_create_index_if_necessary(
|
||||
self,
|
||||
@@ -492,10 +530,15 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
) -> None:
|
||||
"""Verifies and creates the index if necessary.
|
||||
|
||||
Also puts the desired cluster settings.
|
||||
Also puts the desired cluster settings if not in a multitenant
|
||||
environment.
|
||||
|
||||
Also puts the desired search pipeline state, creating the pipelines if
|
||||
they do not exist and updating them otherwise.
|
||||
Also puts the desired search pipeline state if not in a multitenant
|
||||
environment, creating the pipelines if they do not exist and updating
|
||||
them otherwise.
|
||||
|
||||
In a multitenant environment, the above steps happen explicitly on
|
||||
setup.
|
||||
|
||||
Args:
|
||||
embedding_dim: Vector dimensionality for the vector similarity part
|
||||
@@ -508,47 +551,38 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search pipelines.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if necessary, "
|
||||
f"with embedding dimension {embedding_dim}."
|
||||
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if "
|
||||
f"necessary, with embedding dimension {embedding_dim}."
|
||||
)
|
||||
|
||||
if not self._tenant_state.multitenant:
|
||||
set_cluster_state(self._client)
|
||||
|
||||
expected_mappings = DocumentSchema.get_document_schema(
|
||||
embedding_dim, self._tenant_state.multitenant
|
||||
)
|
||||
if not self._os_client.put_cluster_settings(
|
||||
settings=OPENSEARCH_CLUSTER_SETTINGS
|
||||
):
|
||||
logger.error(
|
||||
f"Failed to put cluster settings for index {self._index_name}. If the settings have never been set before this "
|
||||
"may cause unexpected index creation when indexing documents into an index that does not exist, or may cause "
|
||||
"expected logs to not appear. If this is not the first time running Onyx against this instance of OpenSearch, "
|
||||
"these settings have likely already been set. Not taking any further action..."
|
||||
)
|
||||
if not self._os_client.index_exists():
|
||||
|
||||
if not self._client.index_exists():
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
index_settings = (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch()
|
||||
)
|
||||
else:
|
||||
index_settings = DocumentSchema.get_index_settings()
|
||||
self._os_client.create_index(
|
||||
self._client.create_index(
|
||||
mappings=expected_mappings,
|
||||
settings=index_settings,
|
||||
)
|
||||
if not self._os_client.validate_index(
|
||||
expected_mappings=expected_mappings,
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"The index {self._index_name} is not valid. The expected mappings do not match the actual mappings."
|
||||
)
|
||||
|
||||
self._os_client.create_search_pipeline(
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
self._os_client.create_search_pipeline(
|
||||
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
else:
|
||||
# Ensure schema is up to date by applying the current mappings.
|
||||
try:
|
||||
self._client.put_mapping(expected_mappings)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update mappings for index {self._index_name}. This likely means a "
|
||||
f"field type was changed which requires reindexing. Error: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
def index(
|
||||
self,
|
||||
@@ -620,7 +654,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
)
|
||||
# Now index. This will raise if a chunk of the same ID exists, which
|
||||
# we do not expect because we should have deleted all chunks.
|
||||
self._os_client.bulk_index_documents(
|
||||
self._client.bulk_index_documents(
|
||||
documents=chunk_batch,
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
@@ -660,7 +694,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
|
||||
return self._os_client.delete_by_query(query_body)
|
||||
return self._client.delete_by_query(query_body)
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -760,7 +794,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
document_id=doc_id,
|
||||
chunk_index=chunk_index,
|
||||
)
|
||||
self._os_client.update_document(
|
||||
self._client.update_document(
|
||||
document_chunk_id=document_chunk_id,
|
||||
properties_to_update=properties_to_update,
|
||||
)
|
||||
@@ -799,7 +833,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
min_chunk_index=chunk_request.min_chunk_ind,
|
||||
max_chunk_index=chunk_request.max_chunk_ind,
|
||||
)
|
||||
search_hits = self._os_client.search(
|
||||
search_hits = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
@@ -849,7 +883,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# NOTE: Using z-score normalization here because it's better for hybrid search from a theoretical standpoint.
|
||||
# Empirically on a small dataset of up to 10K docs, it's not very different. Likely more impactful at scale.
|
||||
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
)
|
||||
@@ -881,7 +915,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
index_filters=filters,
|
||||
num_to_retrieve=num_to_retrieve,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
@@ -909,6 +943,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# Do not raise if the document already exists, just update. This is
|
||||
# because the document may already have been indexed during the
|
||||
# OpenSearch transition period.
|
||||
self._os_client.bulk_index_documents(
|
||||
self._client.bulk_index_documents(
|
||||
documents=chunks, tenant_state=self._tenant_state, update_if_exists=True
|
||||
)
|
||||
|
||||
@@ -22,10 +22,7 @@ from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_default_vision_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_provider_by_id
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_existing_models
|
||||
from onyx.db.llm import fetch_persona_with_groups
|
||||
@@ -55,10 +52,8 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
)
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LLMCost
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
@@ -238,9 +233,12 @@ def test_llm_configuration(
|
||||
|
||||
test_api_key = test_llm_request.api_key
|
||||
test_custom_config = test_llm_request.custom_config
|
||||
if test_llm_request.id:
|
||||
existing_provider = fetch_existing_llm_provider_by_id(
|
||||
id=test_llm_request.id, db_session=db_session
|
||||
if test_llm_request.name:
|
||||
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
|
||||
# as it turns out the name is not editable in the UI and other code also keys off name,
|
||||
# so we won't rock the boat just yet.
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=test_llm_request.name, db_session=db_session
|
||||
)
|
||||
if existing_provider:
|
||||
test_custom_config = _restore_masked_custom_config_values(
|
||||
@@ -270,7 +268,7 @@ def test_llm_configuration(
|
||||
|
||||
llm = get_llm(
|
||||
provider=test_llm_request.provider,
|
||||
model=test_llm_request.model,
|
||||
model=test_llm_request.default_model_name,
|
||||
api_key=test_api_key,
|
||||
api_base=test_llm_request.api_base,
|
||||
api_version=test_llm_request.api_version,
|
||||
@@ -305,7 +303,7 @@ def list_llm_providers(
|
||||
include_image_gen: bool = Query(False),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[LLMProviderView]:
|
||||
) -> list[LLMProviderView]:
|
||||
start_time = datetime.now(timezone.utc)
|
||||
logger.debug("Starting to fetch LLM providers")
|
||||
|
||||
@@ -330,15 +328,7 @@ def list_llm_providers(
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
|
||||
|
||||
return LLMProviderResponse[LLMProviderView].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=DefaultModel.from_model_config(
|
||||
fetch_default_llm_model(db_session)
|
||||
),
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
return llm_provider_list
|
||||
|
||||
|
||||
@admin_router.put("/provider")
|
||||
@@ -526,7 +516,7 @@ def get_auto_config(
|
||||
def get_vision_capable_providers(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[VisionProviderResponse]:
|
||||
) -> list[VisionProviderResponse]:
|
||||
"""Return a list of LLM providers and their models that support image input"""
|
||||
vision_models = fetch_existing_models(
|
||||
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
|
||||
@@ -555,13 +545,7 @@ def get_vision_capable_providers(
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
|
||||
|
||||
return LLMProviderResponse[VisionProviderResponse].from_models(
|
||||
providers=vision_provider_response,
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
return vision_provider_response
|
||||
|
||||
|
||||
"""Endpoints for all"""
|
||||
@@ -571,7 +555,7 @@ def get_vision_capable_providers(
|
||||
def list_llm_provider_basics(
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
"""Get LLM providers accessible to the current user.
|
||||
|
||||
Returns:
|
||||
@@ -608,15 +592,7 @@ def list_llm_provider_basics(
|
||||
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=accessible_providers,
|
||||
default_text=DefaultModel.from_model_config(
|
||||
fetch_default_llm_model(db_session)
|
||||
),
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
return accessible_providers
|
||||
|
||||
|
||||
def get_valid_model_names_for_persona(
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -25,8 +21,6 @@ if TYPE_CHECKING:
|
||||
ModelConfiguration as ModelConfigurationModel,
|
||||
)
|
||||
|
||||
T = TypeVar("T", "LLMProviderDescriptor", "LLMProviderView")
|
||||
|
||||
|
||||
# TODO: Clear this up on api refactor
|
||||
# There is still logic that requires sending each providers default model name
|
||||
@@ -58,17 +52,19 @@ def get_default_vision_model_name(llm_provider_model: "LLMProviderModel") -> str
|
||||
|
||||
class TestLLMRequest(BaseModel):
|
||||
# provider level
|
||||
id: int | None = None
|
||||
name: str | None = None
|
||||
provider: str
|
||||
model: str
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
|
||||
# model level
|
||||
default_model_name: str
|
||||
deployment_name: str | None = None
|
||||
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"]
|
||||
|
||||
# if try and use the existing API/custom config key
|
||||
api_key_changed: bool
|
||||
custom_config_changed: bool
|
||||
@@ -429,38 +425,3 @@ class OpenRouterFinalModelResponse(BaseModel):
|
||||
int | None
|
||||
) # From OpenRouter API context_length (may be missing for some models)
|
||||
supports_image_input: bool
|
||||
|
||||
|
||||
class DefaultModel(BaseModel):
|
||||
provider_id: int
|
||||
model_name: str
|
||||
|
||||
@classmethod
|
||||
def from_model_config(
|
||||
cls, model_config: ModelConfigurationModel | None
|
||||
) -> DefaultModel | None:
|
||||
if not model_config:
|
||||
return None
|
||||
return cls(
|
||||
provider_id=model_config.llm_provider_id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
|
||||
class LLMProviderResponse(BaseModel, Generic[T]):
|
||||
providers: list[T]
|
||||
default_text: DefaultModel | None = None
|
||||
default_vision: DefaultModel | None = None
|
||||
|
||||
@classmethod
|
||||
def from_models(
|
||||
cls,
|
||||
providers: list[T],
|
||||
default_text: DefaultModel | None = None,
|
||||
default_vision: DefaultModel | None = None,
|
||||
) -> LLMProviderResponse[T]:
|
||||
return cls(
|
||||
providers=providers,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
|
||||
@@ -32,6 +33,9 @@ from onyx.db.search_settings import update_current_search_settings
|
||||
from onyx.db.swap_index import check_and_perform_index_swap
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.opensearch_document_index import set_cluster_state
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
@@ -311,7 +315,14 @@ def setup_multitenant_onyx() -> None:
|
||||
logger.notice("DISABLE_VECTOR_DB is set — skipping multitenant Vespa setup.")
|
||||
return
|
||||
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
opensearch_client = OpenSearchClient()
|
||||
if not wait_for_opensearch_with_timeout(client=opensearch_client):
|
||||
raise RuntimeError("Failed to connect to OpenSearch.")
|
||||
set_cluster_state(opensearch_client)
|
||||
|
||||
# For Managed Vespa, the schema is sent over via the Vespa Console manually.
|
||||
# NOTE: Pretty sure this code is never hit in any production environment.
|
||||
if not MANAGED_VESPA:
|
||||
setup_vespa_multitenant(SUPPORTED_EMBEDDING_MODELS)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
|
||||
# Prepare the test request payload
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"model": _DEFAULT_BEDROCK_MODEL,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
@@ -44,7 +44,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
|
||||
# Prepare the test request payload with invalid credentials
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"model": _DEFAULT_BEDROCK_MODEL,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
|
||||
@@ -29,7 +29,6 @@ from onyx.server.manage.llm.api import (
|
||||
test_llm_configuration as run_test_llm_configuration,
|
||||
)
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import TestLLMRequest as LLMTestRequest
|
||||
|
||||
@@ -45,9 +44,9 @@ def _create_test_provider(
|
||||
db_session: Session,
|
||||
name: str,
|
||||
api_key: str = "sk-test-key-00000000000000000000000000000000000",
|
||||
) -> LLMProviderView:
|
||||
) -> None:
|
||||
"""Helper to create a test LLM provider in the database."""
|
||||
return upsert_llm_provider(
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
@@ -103,11 +102,17 @@ class TestLLMConfigurationEndpoint:
|
||||
# This should complete without exception
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=None, # New provider (not in DB)
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-test-key-0000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -147,11 +152,17 @@ class TestLLMConfigurationEndpoint:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=None,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-invalid-key-00000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -183,9 +194,7 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_key=original_api_key
|
||||
)
|
||||
_create_test_provider(db_session, provider_name, api_key=original_api_key)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -193,12 +202,17 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with api_key_changed=False - should use stored key
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=provider_name, # Existing provider
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None, # Not providing a new key
|
||||
api_key_changed=False, # Using existing key
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -232,9 +246,7 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_key=original_api_key
|
||||
)
|
||||
_create_test_provider(db_session, provider_name, api_key=original_api_key)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -242,12 +254,17 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with api_key_changed=True - should use new key
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=provider_name, # Existing provider
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=new_api_key, # Providing a new key
|
||||
api_key_changed=True, # Key is being changed
|
||||
custom_config_changed=False,
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -280,7 +297,7 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database with custom_config
|
||||
provider = upsert_llm_provider(
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
@@ -304,13 +321,18 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with custom_config_changed=False - should use stored config
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
custom_config=None, # Not providing new config
|
||||
custom_config_changed=False, # Using existing config
|
||||
model="gpt-4o-mini",
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -346,11 +368,17 @@ class TestLLMConfigurationEndpoint:
|
||||
for model_name in test_models:
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=None,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
model=model_name,
|
||||
default_model_name=model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
|
||||
@@ -530,8 +530,14 @@ def test_upload_with_custom_config_then_change(
|
||||
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
model=default_model_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
api_key_changed=False,
|
||||
custom_config_changed=True,
|
||||
custom_config=custom_config,
|
||||
@@ -540,7 +546,7 @@ def test_upload_with_custom_config_then_change(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
provider = put_llm_provider(
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
@@ -563,9 +569,14 @@ def test_upload_with_custom_config_then_change(
|
||||
# Turn auto mode off
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
model=default_model_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
api_key_changed=False,
|
||||
custom_config_changed=False,
|
||||
),
|
||||
@@ -605,13 +616,13 @@ def test_upload_with_custom_config_then_change(
|
||||
)
|
||||
|
||||
# Check inside the database and check that custom_config is the same as the original
|
||||
db_provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
if not db_provider:
|
||||
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
if not provider:
|
||||
assert False, "Provider not found in the database"
|
||||
|
||||
assert db_provider.custom_config == custom_config, (
|
||||
assert provider.custom_config == custom_config, (
|
||||
f"Expected custom_config {custom_config}, "
|
||||
f"but got {db_provider.custom_config}"
|
||||
f"but got {provider.custom_config}"
|
||||
)
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -695,7 +706,7 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
) -> None:
|
||||
"""LLM test should restore masked sensitive custom config values before invocation."""
|
||||
name = f"test-provider-vertex-test-{uuid4().hex[:8]}"
|
||||
provider_name = LlmProviderNames.VERTEX_AI.value
|
||||
provider = LlmProviderNames.VERTEX_AI.value
|
||||
default_model_name = "gemini-2.5-pro"
|
||||
original_custom_config = {
|
||||
"vertex_credentials": '{"type":"service_account","private_key":"REAL_PRIVATE_KEY"}',
|
||||
@@ -708,10 +719,10 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
return ""
|
||||
|
||||
try:
|
||||
provider = put_llm_provider(
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=original_custom_config,
|
||||
model_configurations=[
|
||||
@@ -731,9 +742,14 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
id=provider.id,
|
||||
provider=provider_name,
|
||||
model=default_model_name,
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
api_key_changed=False,
|
||||
custom_config_changed=True,
|
||||
custom_config={
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""External dependency unit tests for OpenSearchClient.
|
||||
"""External dependency unit tests for OpenSearchIndexClient.
|
||||
|
||||
These tests assume OpenSearch is running and test all implemented methods
|
||||
using real schemas, pipelines, and search queries from the codebase.
|
||||
@@ -19,7 +19,7 @@ from onyx.access.utils import prefix_user_email
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
@@ -125,10 +125,10 @@ def opensearch_available() -> None:
|
||||
@pytest.fixture(scope="function")
|
||||
def test_client(
|
||||
opensearch_available: None, # noqa: ARG001
|
||||
) -> Generator[OpenSearchClient, None, None]:
|
||||
) -> Generator[OpenSearchIndexClient, None, None]:
|
||||
"""Creates an OpenSearch client for testing with automatic cleanup."""
|
||||
test_index_name = f"test_index_{uuid.uuid4().hex[:8]}"
|
||||
client = OpenSearchClient(index_name=test_index_name)
|
||||
client = OpenSearchIndexClient(index_name=test_index_name)
|
||||
|
||||
yield client # Test runs here.
|
||||
|
||||
@@ -142,7 +142,7 @@ def test_client(
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def search_pipeline(test_client: OpenSearchClient) -> Generator[None, None, None]:
|
||||
def search_pipeline(test_client: OpenSearchIndexClient) -> Generator[None, None, None]:
|
||||
"""Creates a search pipeline for testing with automatic cleanup."""
|
||||
test_client.create_search_pipeline(
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
@@ -158,9 +158,9 @@ def search_pipeline(test_client: OpenSearchClient) -> Generator[None, None, None
|
||||
|
||||
|
||||
class TestOpenSearchClient:
|
||||
"""Tests for OpenSearchClient."""
|
||||
"""Tests for OpenSearchIndexClient."""
|
||||
|
||||
def test_create_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_create_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests creating an index with a real schema."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -176,7 +176,7 @@ class TestOpenSearchClient:
|
||||
# Verify index exists.
|
||||
assert test_client.validate_index(expected_mappings=mappings) is True
|
||||
|
||||
def test_delete_existing_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_delete_existing_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests deleting an existing index returns True."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -193,7 +193,7 @@ class TestOpenSearchClient:
|
||||
assert result is True
|
||||
assert test_client.validate_index(expected_mappings=mappings) is False
|
||||
|
||||
def test_delete_nonexistent_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_delete_nonexistent_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests deleting a nonexistent index returns False."""
|
||||
# Under test.
|
||||
# Don't create index, just try to delete.
|
||||
@@ -202,7 +202,7 @@ class TestOpenSearchClient:
|
||||
# Postcondition.
|
||||
assert result is False
|
||||
|
||||
def test_index_exists(self, test_client: OpenSearchClient) -> None:
|
||||
def test_index_exists(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests checking if an index exists."""
|
||||
# Precondition.
|
||||
# Index should not exist before creation.
|
||||
@@ -219,7 +219,7 @@ class TestOpenSearchClient:
|
||||
# Index should exist after creation.
|
||||
assert test_client.index_exists() is True
|
||||
|
||||
def test_validate_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_validate_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests validating an index."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -239,7 +239,120 @@ class TestOpenSearchClient:
|
||||
# Should return True after creation.
|
||||
assert test_client.validate_index(expected_mappings=mappings) is True
|
||||
|
||||
def test_create_duplicate_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_put_mapping_idempotent(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests put_mapping with same schema is idempotent."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Under test.
|
||||
# Applying the same mappings again should succeed.
|
||||
test_client.put_mapping(mappings)
|
||||
|
||||
# Postcondition.
|
||||
# Index should still be valid.
|
||||
assert test_client.validate_index(expected_mappings=mappings)
|
||||
|
||||
def test_put_mapping_adds_new_field(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests put_mapping successfully adds new fields to existing index."""
|
||||
# Precondition.
|
||||
# Create index with minimal schema (just required fields).
|
||||
initial_mappings = {
|
||||
"dynamic": "strict",
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"chunk_index": {"type": "integer"},
|
||||
"content": {"type": "text"},
|
||||
"content_vector": {
|
||||
"type": "knn_vector",
|
||||
"dimension": 128,
|
||||
"method": {
|
||||
"name": "hnsw",
|
||||
"space_type": "cosinesimil",
|
||||
"engine": "lucene",
|
||||
"parameters": {"ef_construction": 512, "m": 16},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=initial_mappings, settings=settings)
|
||||
|
||||
# Under test.
|
||||
# Add a new field using put_mapping.
|
||||
updated_mappings = {
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"chunk_index": {"type": "integer"},
|
||||
"content": {"type": "text"},
|
||||
"content_vector": {
|
||||
"type": "knn_vector",
|
||||
"dimension": 128,
|
||||
"method": {
|
||||
"name": "hnsw",
|
||||
"space_type": "cosinesimil",
|
||||
"engine": "lucene",
|
||||
"parameters": {"ef_construction": 512, "m": 16},
|
||||
},
|
||||
},
|
||||
# New field
|
||||
"new_test_field": {"type": "keyword"},
|
||||
},
|
||||
}
|
||||
# Should not raise.
|
||||
test_client.put_mapping(updated_mappings)
|
||||
|
||||
# Postcondition.
|
||||
# Validate the new schema includes the new field.
|
||||
assert test_client.validate_index(expected_mappings=updated_mappings)
|
||||
|
||||
def test_put_mapping_fails_on_type_change(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests put_mapping fails when trying to change existing field type."""
|
||||
# Precondition.
|
||||
initial_mappings = {
|
||||
"dynamic": "strict",
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"test_field": {"type": "keyword"},
|
||||
},
|
||||
}
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=initial_mappings, settings=settings)
|
||||
|
||||
# Under test and postcondition.
|
||||
# Try to change test_field type from keyword to text.
|
||||
conflicting_mappings = {
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"test_field": {"type": "text"}, # Changed from keyword to text
|
||||
},
|
||||
}
|
||||
# Should raise because field type cannot be changed.
|
||||
with pytest.raises(Exception, match="mapper|illegal_argument_exception"):
|
||||
test_client.put_mapping(conflicting_mappings)
|
||||
|
||||
def test_put_mapping_on_nonexistent_index(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests put_mapping on non-existent index raises an error."""
|
||||
# Precondition.
|
||||
# Index does not exist yet.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
|
||||
# Under test and postcondition.
|
||||
with pytest.raises(Exception, match="index_not_found_exception|404"):
|
||||
test_client.put_mapping(mappings)
|
||||
|
||||
def test_create_duplicate_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests creating an index twice raises an error."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -254,14 +367,14 @@ class TestOpenSearchClient:
|
||||
with pytest.raises(Exception, match="already exists"):
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
def test_update_settings(self, test_client: OpenSearchClient) -> None:
|
||||
def test_update_settings(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests that update_settings raises NotImplementedError."""
|
||||
# Under test and postcondition.
|
||||
with pytest.raises(NotImplementedError):
|
||||
test_client.update_settings(settings={})
|
||||
|
||||
def test_create_and_delete_search_pipeline(
|
||||
self, test_client: OpenSearchClient
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests creating and deleting a search pipeline."""
|
||||
# Under test and postcondition.
|
||||
@@ -278,7 +391,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_index_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests indexing a document."""
|
||||
# Precondition.
|
||||
@@ -306,7 +419,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_bulk_index_documents(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests bulk indexing documents."""
|
||||
# Precondition.
|
||||
@@ -337,7 +450,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_index_duplicate_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests indexing a duplicate document raises an error."""
|
||||
# Precondition.
|
||||
@@ -365,7 +478,7 @@ class TestOpenSearchClient:
|
||||
test_client.index_document(document=doc, tenant_state=tenant_state)
|
||||
|
||||
def test_get_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests getting a document."""
|
||||
# Precondition.
|
||||
@@ -401,7 +514,7 @@ class TestOpenSearchClient:
|
||||
assert retrieved_doc == original_doc
|
||||
|
||||
def test_get_nonexistent_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests getting a nonexistent document raises an error."""
|
||||
# Precondition.
|
||||
@@ -419,7 +532,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_delete_existing_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests deleting an existing document returns True."""
|
||||
# Precondition.
|
||||
@@ -455,7 +568,7 @@ class TestOpenSearchClient:
|
||||
test_client.get_document(document_chunk_id=doc_chunk_id)
|
||||
|
||||
def test_delete_nonexistent_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests deleting a nonexistent document returns False."""
|
||||
# Precondition.
|
||||
@@ -476,7 +589,7 @@ class TestOpenSearchClient:
|
||||
assert result is False
|
||||
|
||||
def test_delete_by_query(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests deleting documents by query."""
|
||||
# Precondition.
|
||||
@@ -552,7 +665,7 @@ class TestOpenSearchClient:
|
||||
assert len(keep_ids) == 1
|
||||
|
||||
def test_update_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests updating a document's properties."""
|
||||
# Precondition.
|
||||
@@ -601,7 +714,7 @@ class TestOpenSearchClient:
|
||||
assert updated_doc.public == doc.public
|
||||
|
||||
def test_update_nonexistent_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests updating a nonexistent document raises an error."""
|
||||
# Precondition.
|
||||
@@ -623,7 +736,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_hybrid_search_with_pipeline(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -704,7 +817,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_search_empty_index(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -743,7 +856,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_hybrid_search_with_pipeline_and_filters(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -863,7 +976,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_hybrid_search_with_pipeline_and_filters_returns_chunks_with_related_content_first(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -993,7 +1106,7 @@ class TestOpenSearchClient:
|
||||
previous_score = current_score
|
||||
|
||||
def test_delete_by_query_multitenant_isolation(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests delete_by_query respects tenant boundaries in multi-tenant mode.
|
||||
@@ -1087,7 +1200,7 @@ class TestOpenSearchClient:
|
||||
assert set(remaining_y_ids) == expected_y_ids
|
||||
|
||||
def test_delete_by_query_nonexistent_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests delete_by_query for non-existent document returns 0 deleted.
|
||||
@@ -1116,7 +1229,7 @@ class TestOpenSearchClient:
|
||||
assert num_deleted == 0
|
||||
|
||||
def test_search_for_document_ids(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests search_for_document_ids method returns correct chunk IDs."""
|
||||
# Precondition.
|
||||
@@ -1181,7 +1294,7 @@ class TestOpenSearchClient:
|
||||
assert set(chunk_ids) == expected_ids
|
||||
|
||||
def test_search_with_no_document_access_can_retrieve_all_documents(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests search with no document access can retrieve all documents, even
|
||||
@@ -1259,7 +1372,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_time_cutoff_filter(
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
test_client: OpenSearchIndexClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -1352,7 +1465,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_random_search(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests the random search query works."""
|
||||
# Precondition.
|
||||
|
||||
@@ -37,6 +37,7 @@ from onyx.db.opensearch_migration import build_sanitized_to_original_doc_id_mapp
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
@@ -74,7 +75,7 @@ CHUNK_COUNT = 5
|
||||
|
||||
|
||||
def _get_document_chunks_from_opensearch(
|
||||
opensearch_client: OpenSearchClient, document_id: str, current_tenant_id: str
|
||||
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
|
||||
) -> list[DocumentChunk]:
|
||||
opensearch_client.refresh_index()
|
||||
filters = IndexFilters(access_control_list=None, tenant_id=current_tenant_id)
|
||||
@@ -95,7 +96,7 @@ def _get_document_chunks_from_opensearch(
|
||||
|
||||
|
||||
def _delete_document_chunks_from_opensearch(
|
||||
opensearch_client: OpenSearchClient, document_id: str, current_tenant_id: str
|
||||
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
|
||||
) -> None:
|
||||
opensearch_client.refresh_index()
|
||||
query_body = DocumentQuery.delete_from_document_id_query(
|
||||
@@ -283,10 +284,10 @@ def vespa_document_index(
|
||||
def opensearch_client(
|
||||
db_session: Session,
|
||||
full_deployment_setup: None, # noqa: ARG001
|
||||
) -> Generator[OpenSearchClient, None, None]:
|
||||
) -> Generator[OpenSearchIndexClient, None, None]:
|
||||
"""Creates an OpenSearch client for the test tenant."""
|
||||
active = get_active_search_settings(db_session)
|
||||
yield OpenSearchClient(index_name=active.primary.index_name) # Test runs here.
|
||||
yield OpenSearchIndexClient(index_name=active.primary.index_name) # Test runs here.
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -330,7 +331,7 @@ def patch_get_vespa_chunks_page_size() -> Generator[int, None, None]:
|
||||
def test_documents(
|
||||
db_session: Session,
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
patch_get_vespa_chunks_page_size: int,
|
||||
) -> Generator[list[Document], None, None]:
|
||||
"""
|
||||
@@ -411,7 +412,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
@@ -480,7 +481,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
@@ -618,7 +619,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
@@ -712,7 +713,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchClient,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
|
||||
@@ -20,6 +20,7 @@ from onyx.auth.oauth_token_manager import OAuthTokenManager
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.oauth_config import create_oauth_config
|
||||
from onyx.db.oauth_config import upsert_user_oauth_token
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
|
||||
|
||||
@@ -491,3 +492,19 @@ class TestOAuthTokenManagerURLBuilding:
|
||||
# Should use & instead of ? since URL already has query params
|
||||
assert "foo=bar&" in url or "?foo=bar" in url
|
||||
assert "client_id=custom_client_id" in url
|
||||
|
||||
|
||||
class TestUnwrapSensitiveStr:
|
||||
"""Tests for _unwrap_sensitive_str static method"""
|
||||
|
||||
def test_unwrap_sensitive_str(self) -> None:
|
||||
"""Test that both SensitiveValue and plain str inputs are handled"""
|
||||
# SensitiveValue input
|
||||
sensitive = SensitiveValue[str](
|
||||
encrypted_bytes=b"test_client_id",
|
||||
decrypt_fn=lambda b: b.decode(),
|
||||
)
|
||||
assert OAuthTokenManager._unwrap_sensitive_str(sensitive) == "test_client_id"
|
||||
|
||||
# Plain str input
|
||||
assert OAuthTokenManager._unwrap_sensitive_str("plain_string") == "plain_string"
|
||||
|
||||
@@ -72,7 +72,7 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: int) -> dict:
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
for provider in response.json()["providers"]:
|
||||
for provider in response.json():
|
||||
if provider["id"] == provider_id:
|
||||
return provider
|
||||
raise ValueError(f"Provider with id {provider_id} not found")
|
||||
|
||||
@@ -23,7 +23,7 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()["providers"]
|
||||
providers = response.json()
|
||||
return next((p for p in providers if p["id"] == provider_id), None)
|
||||
|
||||
|
||||
@@ -578,57 +578,6 @@ def test_model_visibility_preserved_on_edit(reset: None) -> None: # noqa: ARG00
|
||||
assert visible_models[0]["name"] == "gpt-4o"
|
||||
|
||||
|
||||
def _get_provider_by_name(providers: list[dict], provider_name: str) -> dict | None:
|
||||
return next((p for p in providers if p["name"] == provider_name), None)
|
||||
|
||||
|
||||
def _get_providers_admin(
|
||||
admin_user: DATestUser,
|
||||
) -> dict | None:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
resp_json = response.json()
|
||||
|
||||
return resp_json
|
||||
|
||||
|
||||
def _unpack_data(data: dict) -> tuple[list[dict], dict | None, dict | None]:
|
||||
providers = data["providers"]
|
||||
text_default = data.get("default_text")
|
||||
vision_default = data.get("default_vision")
|
||||
|
||||
return providers, text_default, vision_default
|
||||
|
||||
|
||||
def _get_providers_basic(
|
||||
user: DATestUser,
|
||||
) -> dict | None:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/llm/provider",
|
||||
headers=user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
resp_json = response.json()
|
||||
|
||||
return resp_json
|
||||
|
||||
|
||||
def _validate_default_model(
|
||||
default: dict | None,
|
||||
provider_id: int | None = None,
|
||||
model_name: str | None = None,
|
||||
) -> None:
|
||||
if default is None:
|
||||
assert provider_id is None and model_name is None
|
||||
return
|
||||
|
||||
assert default["provider_id"] == provider_id
|
||||
assert default["model_name"] == model_name
|
||||
|
||||
|
||||
def _get_provider_by_name_admin(
|
||||
admin_user: DATestUser, provider_name: str
|
||||
) -> dict | None:
|
||||
@@ -649,7 +598,7 @@ def _get_provider_by_name_basic(user: DATestUser, provider_name: str) -> dict |
|
||||
headers=user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()["providers"]
|
||||
providers = response.json()
|
||||
return next((p for p in providers if p["name"] == provider_name), None)
|
||||
|
||||
|
||||
@@ -833,22 +782,9 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
|
||||
# Capture initial defaults (setup_postgres may have created a DevEnvPresetOpenAI default)
|
||||
initial_data = _get_providers_admin(admin_user)
|
||||
assert initial_data is not None
|
||||
_, initial_text_default, initial_vision_default = _unpack_data(initial_data)
|
||||
|
||||
# Step 2: Verify via admin endpoint that all provider data is correct
|
||||
admin_data = _get_providers_admin(admin_user)
|
||||
assert admin_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_data)
|
||||
# Defaults should be unchanged from initial state (new provider not set as default)
|
||||
assert text_default == initial_text_default
|
||||
assert vision_default == initial_vision_default
|
||||
|
||||
admin_provider_data = _get_provider_by_name(providers, provider_name)
|
||||
admin_provider_data = _get_provider_by_name_admin(admin_user, provider_name)
|
||||
assert admin_provider_data is not None
|
||||
|
||||
_validate_provider_data(
|
||||
admin_provider_data,
|
||||
expected_name=provider_name,
|
||||
@@ -861,13 +797,7 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Step 3: Verify via basic endpoint (admin user) that all provider data is correct
|
||||
admin_basic_data = _get_providers_basic(admin_user)
|
||||
assert admin_basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_basic_data)
|
||||
assert text_default == initial_text_default
|
||||
assert vision_default == initial_vision_default
|
||||
|
||||
admin_basic_provider_data = _get_provider_by_name(providers, provider_name)
|
||||
admin_basic_provider_data = _get_provider_by_name_basic(admin_user, provider_name)
|
||||
assert admin_basic_provider_data is not None
|
||||
_validate_provider_data(
|
||||
admin_basic_provider_data,
|
||||
@@ -880,13 +810,7 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Step 4: Verify non-admin user sees the same provider data via basic endpoint
|
||||
basic_user_data = _get_providers_basic(basic_user)
|
||||
assert basic_user_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(basic_user_data)
|
||||
assert text_default == initial_text_default
|
||||
assert vision_default == initial_vision_default
|
||||
|
||||
basic_user_provider_data = _get_provider_by_name(providers, provider_name)
|
||||
basic_user_provider_data = _get_provider_by_name_basic(basic_user, provider_name)
|
||||
assert basic_user_provider_data is not None
|
||||
_validate_provider_data(
|
||||
basic_user_provider_data,
|
||||
@@ -922,17 +846,7 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
|
||||
assert default_provider_response.status_code == 200
|
||||
|
||||
# Step 6a: Verify the updated provider data via admin endpoint
|
||||
admin_data = _get_providers_admin(admin_user)
|
||||
assert admin_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_data)
|
||||
_validate_default_model(
|
||||
text_default,
|
||||
provider_id=update_response.json()["id"],
|
||||
model_name=updated_default_model,
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
|
||||
admin_provider_data = _get_provider_by_name(providers, provider_name)
|
||||
admin_provider_data = _get_provider_by_name_admin(admin_user, provider_name)
|
||||
assert admin_provider_data is not None
|
||||
_validate_provider_data(
|
||||
admin_provider_data,
|
||||
@@ -946,17 +860,7 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Step 6b: Verify the updated provider data via basic endpoint (admin user)
|
||||
admin_basic_data = _get_providers_basic(admin_user)
|
||||
assert admin_basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_basic_data)
|
||||
_validate_default_model(
|
||||
text_default,
|
||||
provider_id=update_response.json()["id"],
|
||||
model_name=updated_default_model,
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
|
||||
admin_basic_provider_data = _get_provider_by_name(providers, provider_name)
|
||||
admin_basic_provider_data = _get_provider_by_name_basic(admin_user, provider_name)
|
||||
assert admin_basic_provider_data is not None
|
||||
_validate_provider_data(
|
||||
admin_basic_provider_data,
|
||||
@@ -969,17 +873,7 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Step 7: Verify non-admin user sees the updated provider data
|
||||
basic_user_data = _get_providers_basic(basic_user)
|
||||
assert basic_user_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(basic_user_data)
|
||||
_validate_default_model(
|
||||
text_default,
|
||||
provider_id=update_response.json()["id"],
|
||||
model_name=updated_default_model,
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
|
||||
basic_user_provider_data = _get_provider_by_name(providers, provider_name)
|
||||
basic_user_provider_data = _get_provider_by_name_basic(basic_user, provider_name)
|
||||
assert basic_user_provider_data is not None
|
||||
_validate_provider_data(
|
||||
basic_user_provider_data,
|
||||
@@ -999,7 +893,7 @@ def _get_all_providers_basic(user: DATestUser) -> list[dict]:
|
||||
headers=user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["providers"]
|
||||
return response.json()
|
||||
|
||||
|
||||
def _get_all_providers_admin(admin_user: DATestUser) -> list[dict]:
|
||||
@@ -1009,7 +903,7 @@ def _get_all_providers_admin(admin_user: DATestUser) -> list[dict]:
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["providers"]
|
||||
return response.json()
|
||||
|
||||
|
||||
def _set_default_provider(admin_user: DATestUser, provider_id: int) -> None:
|
||||
@@ -1145,17 +1039,11 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
|
||||
# Step 3: Both admin and basic_user query and verify they see the same default
|
||||
# Validate via admin endpoint
|
||||
admin_data = _get_providers_admin(admin_user)
|
||||
assert admin_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=provider_1["id"], model_name=shared_model_name
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
admin_provider_data = _get_provider_by_name(providers, provider_1_name)
|
||||
assert admin_provider_data is not None
|
||||
admin_providers = _get_all_providers_admin(admin_user)
|
||||
admin_default = _find_default_provider(admin_providers)
|
||||
assert admin_default is not None
|
||||
_validate_provider_data(
|
||||
admin_provider_data,
|
||||
admin_default,
|
||||
expected_name=provider_1_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=shared_model_name,
|
||||
@@ -1166,7 +1054,9 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Validate provider 2 via admin endpoint (should not be default)
|
||||
admin_provider_2 = _get_provider_by_name(providers, provider_2_name)
|
||||
admin_provider_2 = next(
|
||||
(p for p in admin_providers if p["name"] == provider_2_name), None
|
||||
)
|
||||
assert admin_provider_2 is not None
|
||||
_validate_provider_data(
|
||||
admin_provider_2,
|
||||
@@ -1180,17 +1070,11 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Validate via basic endpoint (basic_user)
|
||||
basic_data = _get_providers_basic(basic_user)
|
||||
assert basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(basic_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=provider_1["id"], model_name=shared_model_name
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
basic_provider_data = _get_provider_by_name(providers, provider_1_name)
|
||||
assert basic_provider_data is not None
|
||||
basic_providers = _get_all_providers_basic(basic_user)
|
||||
basic_default = _find_default_provider(basic_providers)
|
||||
assert basic_default is not None
|
||||
_validate_provider_data(
|
||||
basic_provider_data,
|
||||
basic_default,
|
||||
expected_name=provider_1_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=shared_model_name,
|
||||
@@ -1200,17 +1084,11 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Also verify admin sees the same via basic endpoint
|
||||
admin_basic_data = _get_providers_basic(admin_user)
|
||||
assert admin_basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_basic_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=provider_1["id"], model_name=shared_model_name
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
admin_basic_provider_data = _get_provider_by_name(providers, provider_1_name)
|
||||
assert admin_basic_provider_data is not None
|
||||
admin_basic_providers = _get_all_providers_basic(admin_user)
|
||||
admin_basic_default = _find_default_provider(admin_basic_providers)
|
||||
assert admin_basic_default is not None
|
||||
_validate_provider_data(
|
||||
admin_basic_provider_data,
|
||||
admin_basic_default,
|
||||
expected_name=provider_1_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=shared_model_name,
|
||||
@@ -1243,17 +1121,11 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
|
||||
# Step 5: Both admin and basic_user verify they see the updated default
|
||||
# Validate via admin endpoint
|
||||
admin_data = _get_providers_admin(admin_user)
|
||||
assert admin_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=provider_2["id"], model_name=provider_2_unique_model
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
admin_provider_data = _get_provider_by_name(providers, provider_2_name)
|
||||
assert admin_provider_data is not None
|
||||
admin_providers = _get_all_providers_admin(admin_user)
|
||||
admin_default = _find_default_provider(admin_providers)
|
||||
assert admin_default is not None
|
||||
_validate_provider_data(
|
||||
admin_provider_data,
|
||||
admin_default,
|
||||
expected_name=provider_2_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=provider_2_unique_model,
|
||||
@@ -1264,7 +1136,9 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Validate provider 1 via admin endpoint (should no longer be default)
|
||||
admin_provider_1 = _get_provider_by_name(providers, provider_1_name)
|
||||
admin_provider_1 = next(
|
||||
(p for p in admin_providers if p["name"] == provider_1_name), None
|
||||
)
|
||||
assert admin_provider_1 is not None
|
||||
_validate_provider_data(
|
||||
admin_provider_1,
|
||||
@@ -1278,17 +1152,11 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Validate via basic endpoint (basic_user)
|
||||
basic_data = _get_providers_basic(basic_user)
|
||||
assert basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(basic_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=provider_2["id"], model_name=provider_2_unique_model
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
basic_provider_data = _get_provider_by_name(providers, provider_2_name)
|
||||
assert basic_provider_data is not None
|
||||
basic_providers = _get_all_providers_basic(basic_user)
|
||||
basic_default = _find_default_provider(basic_providers)
|
||||
assert basic_default is not None
|
||||
_validate_provider_data(
|
||||
basic_provider_data,
|
||||
basic_default,
|
||||
expected_name=provider_2_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=provider_2_unique_model,
|
||||
@@ -1298,17 +1166,11 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Validate via basic endpoint (admin_user)
|
||||
admin_basic_data = _get_providers_basic(admin_user)
|
||||
assert admin_basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_basic_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=provider_2["id"], model_name=provider_2_unique_model
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
admin_basic_provider_data = _get_provider_by_name(providers, provider_2_name)
|
||||
assert admin_basic_provider_data is not None
|
||||
admin_basic_providers = _get_all_providers_basic(admin_user)
|
||||
admin_basic_default = _find_default_provider(admin_basic_providers)
|
||||
assert admin_basic_default is not None
|
||||
_validate_provider_data(
|
||||
admin_basic_provider_data,
|
||||
admin_basic_default,
|
||||
expected_name=provider_2_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=provider_2_unique_model,
|
||||
@@ -1338,17 +1200,11 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
|
||||
# Step 7: Both users verify they see provider 2 as default with the shared model name
|
||||
# Validate via admin endpoint
|
||||
admin_data = _get_providers_admin(admin_user)
|
||||
assert admin_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=provider_2["id"], model_name=shared_model_name
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
admin_provider_data = _get_provider_by_name(providers, provider_2_name)
|
||||
assert admin_provider_data is not None
|
||||
admin_providers = _get_all_providers_admin(admin_user)
|
||||
admin_default = _find_default_provider(admin_providers)
|
||||
assert admin_default is not None
|
||||
_validate_provider_data(
|
||||
admin_provider_data,
|
||||
admin_default,
|
||||
expected_name=provider_2_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=shared_model_name,
|
||||
@@ -1359,17 +1215,11 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Validate via basic endpoint (basic_user)
|
||||
basic_data = _get_providers_basic(basic_user)
|
||||
assert basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(basic_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=provider_2["id"], model_name=shared_model_name
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
basic_provider_data = _get_provider_by_name(providers, provider_2_name)
|
||||
assert basic_provider_data is not None
|
||||
basic_providers = _get_all_providers_basic(basic_user)
|
||||
basic_default = _find_default_provider(basic_providers)
|
||||
assert basic_default is not None
|
||||
_validate_provider_data(
|
||||
basic_provider_data,
|
||||
basic_default,
|
||||
expected_name=provider_2_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=shared_model_name,
|
||||
@@ -1379,17 +1229,11 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Validate via basic endpoint (admin_user)
|
||||
admin_basic_data = _get_providers_basic(admin_user)
|
||||
assert admin_basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_basic_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=provider_2["id"], model_name=shared_model_name
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
admin_basic_provider_data = _get_provider_by_name(providers, provider_2_name)
|
||||
assert admin_basic_provider_data is not None
|
||||
admin_basic_providers = _get_all_providers_basic(admin_user)
|
||||
admin_basic_default = _find_default_provider(admin_basic_providers)
|
||||
assert admin_basic_default is not None
|
||||
_validate_provider_data(
|
||||
admin_basic_provider_data,
|
||||
admin_basic_default,
|
||||
expected_name=provider_2_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=shared_model_name,
|
||||
@@ -1399,10 +1243,12 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Verify provider 1 is no longer the default and has correct data
|
||||
admin_provider_1 = _get_provider_by_name(providers, provider_1_name)
|
||||
assert admin_provider_1 is not None
|
||||
provider_1_admin = next(
|
||||
(p for p in admin_providers if p["name"] == provider_1_name), None
|
||||
)
|
||||
assert provider_1_admin is not None
|
||||
_validate_provider_data(
|
||||
admin_provider_1,
|
||||
provider_1_admin,
|
||||
expected_name=provider_1_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=shared_model_name,
|
||||
@@ -1412,10 +1258,12 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
expected_is_public=True,
|
||||
)
|
||||
|
||||
basic_provider_1 = _get_provider_by_name(providers, provider_1_name)
|
||||
assert basic_provider_1 is not None
|
||||
provider_1_basic = next(
|
||||
(p for p in basic_providers if p["name"] == provider_1_name), None
|
||||
)
|
||||
assert provider_1_basic is not None
|
||||
_validate_provider_data(
|
||||
basic_provider_1,
|
||||
provider_1_basic,
|
||||
expected_name=provider_1_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=shared_model_name,
|
||||
@@ -1543,22 +1391,10 @@ def test_default_provider_and_vision_provider_selection(
|
||||
)
|
||||
|
||||
# Step 5: Verify via admin endpoint
|
||||
admin_data = _get_providers_admin(admin_user)
|
||||
assert admin_data is not None
|
||||
admin_providers = _get_all_providers_admin(admin_user)
|
||||
|
||||
# Find and validate the default provider (provider 1)
|
||||
providers, text_default, vision_default = _unpack_data(admin_data)
|
||||
_validate_default_model(
|
||||
text_default,
|
||||
provider_id=provider_1["id"],
|
||||
model_name=provider_1_non_vision_model,
|
||||
)
|
||||
_validate_default_model(
|
||||
vision_default,
|
||||
provider_id=provider_2["id"],
|
||||
model_name=provider_2_vision_model_1,
|
||||
)
|
||||
admin_default = _get_provider_by_name(providers, provider_1_name)
|
||||
admin_default = _find_default_provider(admin_providers)
|
||||
assert admin_default is not None
|
||||
_validate_provider_data(
|
||||
admin_default,
|
||||
@@ -1573,7 +1409,7 @@ def test_default_provider_and_vision_provider_selection(
|
||||
)
|
||||
|
||||
# Find and validate the default vision provider (provider 2)
|
||||
admin_vision_default = _get_provider_by_name(providers, provider_2_name)
|
||||
admin_vision_default = _find_default_vision_provider(admin_providers)
|
||||
assert admin_vision_default is not None
|
||||
_validate_provider_data(
|
||||
admin_vision_default,
|
||||
@@ -1589,21 +1425,10 @@ def test_default_provider_and_vision_provider_selection(
|
||||
)
|
||||
|
||||
# Step 6: Verify via basic endpoint (basic_user)
|
||||
basic_providers = _get_all_providers_basic(basic_user)
|
||||
|
||||
# Find and validate the default provider (provider 1)
|
||||
basic_data = _get_providers_basic(basic_user)
|
||||
assert basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(basic_data)
|
||||
_validate_default_model(
|
||||
text_default,
|
||||
provider_id=provider_1["id"],
|
||||
model_name=provider_1_non_vision_model,
|
||||
)
|
||||
_validate_default_model(
|
||||
vision_default,
|
||||
provider_id=provider_2["id"],
|
||||
model_name=provider_2_vision_model_1,
|
||||
)
|
||||
basic_default = _get_provider_by_name(providers, provider_1_name)
|
||||
basic_default = _find_default_provider(basic_providers)
|
||||
assert basic_default is not None
|
||||
_validate_provider_data(
|
||||
basic_default,
|
||||
@@ -1617,7 +1442,7 @@ def test_default_provider_and_vision_provider_selection(
|
||||
)
|
||||
|
||||
# Find and validate the default vision provider (provider 2)
|
||||
basic_vision_default = _get_provider_by_name(providers, provider_2_name)
|
||||
basic_vision_default = _find_default_vision_provider(basic_providers)
|
||||
assert basic_vision_default is not None
|
||||
_validate_provider_data(
|
||||
basic_vision_default,
|
||||
@@ -1632,20 +1457,9 @@ def test_default_provider_and_vision_provider_selection(
|
||||
)
|
||||
|
||||
# Step 7: Verify via basic endpoint (admin_user sees same as basic_user)
|
||||
admin_basic_data = _get_providers_basic(admin_user)
|
||||
assert admin_basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_basic_data)
|
||||
_validate_default_model(
|
||||
text_default,
|
||||
provider_id=provider_1["id"],
|
||||
model_name=provider_1_non_vision_model,
|
||||
)
|
||||
_validate_default_model(
|
||||
vision_default,
|
||||
provider_id=provider_2["id"],
|
||||
model_name=provider_2_vision_model_1,
|
||||
)
|
||||
admin_basic_default = _get_provider_by_name(providers, provider_1_name)
|
||||
admin_basic_providers = _get_all_providers_basic(admin_user)
|
||||
|
||||
admin_basic_default = _find_default_provider(admin_basic_providers)
|
||||
assert admin_basic_default is not None
|
||||
_validate_provider_data(
|
||||
admin_basic_default,
|
||||
@@ -1658,7 +1472,7 @@ def test_default_provider_and_vision_provider_selection(
|
||||
expected_is_default_vision=None,
|
||||
)
|
||||
|
||||
admin_basic_vision_default = _get_provider_by_name(providers, provider_2_name)
|
||||
admin_basic_vision_default = _find_default_vision_provider(admin_basic_providers)
|
||||
assert admin_basic_vision_default is not None
|
||||
_validate_provider_data(
|
||||
admin_basic_vision_default,
|
||||
@@ -1735,14 +1549,7 @@ def test_default_provider_is_not_default_vision_provider(
|
||||
_set_default_provider(admin_user, created_provider["id"])
|
||||
|
||||
# Step 3 & 4: Verify via admin endpoint
|
||||
admin_data = _get_providers_admin(admin_user)
|
||||
assert admin_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=created_provider["id"], model_name="gpt-4"
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
admin_provider_data = _get_provider_by_name(providers, provider_name)
|
||||
admin_provider_data = _get_provider_by_name_admin(admin_user, provider_name)
|
||||
assert admin_provider_data is not None
|
||||
|
||||
# Verify it IS the default provider
|
||||
@@ -1777,14 +1584,7 @@ def test_default_provider_is_not_default_vision_provider(
|
||||
)
|
||||
|
||||
# Also verify via basic endpoint
|
||||
basic_data = _get_providers_basic(admin_user)
|
||||
assert basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(basic_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=created_provider["id"], model_name="gpt-4"
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
basic_provider_data = _get_provider_by_name(providers, provider_name)
|
||||
basic_provider_data = _get_provider_by_name_basic(admin_user, provider_name)
|
||||
assert basic_provider_data is not None
|
||||
|
||||
assert (
|
||||
@@ -1969,52 +1769,37 @@ def test_all_three_provider_types_no_mixup(reset: None) -> None: # noqa: ARG001
|
||||
# Step 4: Verify all three types are correctly tracked
|
||||
|
||||
# Get all LLM providers (via admin endpoint)
|
||||
admin_data = _get_providers_admin(admin_user)
|
||||
assert admin_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=regular_provider["id"], model_name="gpt-4"
|
||||
)
|
||||
_validate_default_model(
|
||||
vision_default,
|
||||
provider_id=vision_provider["id"],
|
||||
model_name="gpt-4-vision-preview",
|
||||
)
|
||||
_validate_default_model(
|
||||
vision_default, vision_provider["id"], "gpt-4-vision-preview"
|
||||
)
|
||||
_get_provider_by_name(providers, regular_provider_name)
|
||||
admin_providers = _get_all_providers_admin(admin_user)
|
||||
|
||||
# Get all image generation configs
|
||||
image_gen_configs = _get_all_image_gen_configs(admin_user)
|
||||
|
||||
# Verify the regular provider is the default provider
|
||||
admin_regular_provider_data = _get_provider_by_name(
|
||||
providers, regular_provider_name
|
||||
regular_provider_data = next(
|
||||
(p for p in admin_providers if p["name"] == regular_provider_name), None
|
||||
)
|
||||
assert admin_regular_provider_data is not None
|
||||
_validate_provider_data(
|
||||
admin_regular_provider_data,
|
||||
expected_name=regular_provider_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_model_names=[c.name for c in regular_model_configs],
|
||||
expected_visible={c.name: True for c in regular_model_configs},
|
||||
expected_default_model="gpt-4",
|
||||
expected_is_default=True,
|
||||
)
|
||||
admin_vision_provider_data = _get_provider_by_name(providers, vision_provider_name)
|
||||
assert admin_vision_provider_data is not None
|
||||
_validate_provider_data(
|
||||
admin_vision_provider_data,
|
||||
expected_name=vision_provider_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_model_names=[c.name for c in vision_model_configs],
|
||||
expected_visible={c.name: True for c in vision_model_configs},
|
||||
expected_default_model="gpt-4-vision-preview",
|
||||
expected_is_default=False,
|
||||
expected_default_vision_model="gpt-4-vision-preview",
|
||||
expected_is_default_vision=True,
|
||||
assert regular_provider_data is not None, "Regular provider not found"
|
||||
assert (
|
||||
regular_provider_data["is_default_provider"] is True
|
||||
), "Regular provider should be the default provider"
|
||||
assert (
|
||||
regular_provider_data.get("is_default_vision_provider") is not True
|
||||
), "Regular provider should NOT be the default vision provider"
|
||||
|
||||
# Verify the vision provider is the default vision provider
|
||||
vision_provider_data = next(
|
||||
(p for p in admin_providers if p["name"] == vision_provider_name), None
|
||||
)
|
||||
assert vision_provider_data is not None, "Vision provider not found"
|
||||
assert (
|
||||
vision_provider_data.get("is_default_provider") is not True
|
||||
), "Vision provider should NOT be the default provider"
|
||||
assert (
|
||||
vision_provider_data["is_default_vision_provider"] is True
|
||||
), "Vision provider should be the default vision provider"
|
||||
assert (
|
||||
vision_provider_data["default_vision_model"] == "gpt-4-vision-preview"
|
||||
), "Vision provider should have correct default vision model"
|
||||
|
||||
# Verify the image gen config is the default image generation config
|
||||
image_gen_config_data = next(
|
||||
@@ -2034,53 +1819,97 @@ def test_all_three_provider_types_no_mixup(reset: None) -> None: # noqa: ARG001
|
||||
), "Image gen config should have correct model name"
|
||||
|
||||
# Step 5: Verify no mixup - image gen providers don't appear in LLM provider lists
|
||||
# Image gen provider should not appear in the list
|
||||
assert image_gen_provider_id not in [p["name"] for p in providers]
|
||||
|
||||
# The image gen config creates an LLM provider with name "Image Gen - {image_provider_id}"
|
||||
# This should NOT be returned by the regular LLM provider endpoints
|
||||
[p["name"] for p in admin_providers]
|
||||
image_gen_llm_provider_name = f"Image Gen - {image_gen_provider_id}"
|
||||
|
||||
# Note: The image gen provider IS an LLM provider internally, so it may appear in the list
|
||||
# But it should NOT be marked as default provider or default vision provider
|
||||
image_gen_llm_provider = next(
|
||||
(p for p in admin_providers if p["name"] == image_gen_llm_provider_name), None
|
||||
)
|
||||
if image_gen_llm_provider:
|
||||
# If it appears, verify it's not marked as default for either type
|
||||
assert (
|
||||
image_gen_llm_provider.get("is_default_provider") is not True
|
||||
), "Image gen's internal LLM provider should NOT be the default provider"
|
||||
assert (
|
||||
image_gen_llm_provider.get("is_default_vision_provider") is not True
|
||||
), "Image gen's internal LLM provider should NOT be the default vision provider"
|
||||
|
||||
# Step 6: Verify via basic endpoint (non-admin user)
|
||||
basic_data = _get_providers_basic(basic_user)
|
||||
assert basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(basic_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=regular_provider["id"], model_name="gpt-4"
|
||||
basic_providers = _get_all_providers_basic(basic_user)
|
||||
|
||||
# Verify regular provider is default for basic user
|
||||
basic_regular = next(
|
||||
(p for p in basic_providers if p["name"] == regular_provider_name), None
|
||||
)
|
||||
_validate_default_model(
|
||||
vision_default,
|
||||
provider_id=vision_provider["id"],
|
||||
model_name="gpt-4-vision-preview",
|
||||
)
|
||||
_validate_default_model(
|
||||
vision_default, vision_provider["id"], "gpt-4-vision-preview"
|
||||
)
|
||||
basic_provider_data = _get_provider_by_name(providers, regular_provider_name)
|
||||
assert basic_provider_data is not None
|
||||
_validate_provider_data(
|
||||
basic_provider_data,
|
||||
expected_name=regular_provider_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_model_names=[c.name for c in regular_model_configs],
|
||||
expected_visible={c.name: True for c in regular_model_configs},
|
||||
expected_is_default=True,
|
||||
expected_default_model="gpt-4",
|
||||
)
|
||||
basic_vision_provider_data = _get_provider_by_name(providers, vision_provider_name)
|
||||
assert basic_vision_provider_data is not None
|
||||
_validate_provider_data(
|
||||
basic_vision_provider_data,
|
||||
expected_name=vision_provider_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_model_names=[c.name for c in vision_model_configs],
|
||||
expected_visible={c.name: True for c in vision_model_configs},
|
||||
expected_is_default=False,
|
||||
expected_default_model="gpt-4-vision-preview",
|
||||
expected_default_vision_model="gpt-4-vision-preview",
|
||||
expected_is_default_vision=True,
|
||||
assert basic_regular is not None, "Regular provider not visible to basic user"
|
||||
assert (
|
||||
basic_regular["is_default_provider"] is True
|
||||
), "Regular provider should be default for basic user"
|
||||
|
||||
# Verify vision provider is default vision for basic user
|
||||
basic_vision = next(
|
||||
(p for p in basic_providers if p["name"] == vision_provider_name), None
|
||||
)
|
||||
assert basic_vision is not None, "Vision provider not visible to basic user"
|
||||
assert (
|
||||
basic_vision["is_default_vision_provider"] is True
|
||||
), "Vision provider should be default vision for basic user"
|
||||
|
||||
# Step 7: Verify the counts are as expected
|
||||
# We should have at least 2 user-created providers (setup_postgres may add more)
|
||||
assert len(providers) >= 2
|
||||
assert len(image_gen_configs) == 1
|
||||
# We should have at least 2 user-created providers plus the image gen internal provider
|
||||
user_created_providers = [
|
||||
p
|
||||
for p in admin_providers
|
||||
if p["name"] in [regular_provider_name, vision_provider_name]
|
||||
]
|
||||
assert (
|
||||
len(user_created_providers) == 2
|
||||
), f"Expected 2 user-created providers, got {len(user_created_providers)}"
|
||||
|
||||
# We should have exactly 1 image gen config
|
||||
assert (
|
||||
len(
|
||||
[
|
||||
c
|
||||
for c in image_gen_configs
|
||||
if c["image_provider_id"] == image_gen_provider_id
|
||||
]
|
||||
)
|
||||
== 1
|
||||
), "Expected exactly 1 image gen config with our ID"
|
||||
|
||||
# Verify that our explicitly created providers are tracked correctly:
|
||||
# - Only ONE provider has is_default_provider=True
|
||||
default_providers = [
|
||||
p for p in admin_providers if p.get("is_default_provider") is True
|
||||
]
|
||||
assert (
|
||||
len(default_providers) == 1
|
||||
), f"Expected exactly 1 default provider, got {len(default_providers)}"
|
||||
assert default_providers[0]["name"] == regular_provider_name
|
||||
|
||||
# - Only ONE provider has is_default_vision_provider=True
|
||||
default_vision_providers = [
|
||||
p for p in admin_providers if p.get("is_default_vision_provider") is True
|
||||
]
|
||||
assert (
|
||||
len(default_vision_providers) == 1
|
||||
), f"Expected exactly 1 default vision provider, got {len(default_vision_providers)}"
|
||||
assert default_vision_providers[0]["name"] == vision_provider_name
|
||||
|
||||
# - Only ONE image gen config has is_default=True
|
||||
default_image_gen_configs = [
|
||||
c for c in image_gen_configs if c.get("is_default") is True
|
||||
]
|
||||
assert (
|
||||
len(default_image_gen_configs) == 1
|
||||
), f"Expected exactly 1 default image gen config, got {len(default_image_gen_configs)}"
|
||||
assert default_image_gen_configs[0]["image_provider_id"] == image_gen_provider_id
|
||||
|
||||
# Clean up: Delete the image gen config (to clean up the internal LLM provider)
|
||||
_delete_image_gen_config(admin_user, image_gen_provider_id)
|
||||
|
||||
@@ -272,19 +272,6 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
|
||||
# Set up ModelConfiguration + LLMModelFlow so get_default_llm() can
|
||||
# resolve the default provider when the fallback path is triggered.
|
||||
# First, clear any existing CHAT defaults (setup_postgres may have created one)
|
||||
existing_defaults = (
|
||||
db_session.query(LLMModelFlow)
|
||||
.filter(
|
||||
LLMModelFlow.llm_model_flow_type == LLMModelFlowType.CHAT,
|
||||
LLMModelFlow.is_default == True, # noqa: E712
|
||||
)
|
||||
.all()
|
||||
)
|
||||
for existing in existing_defaults:
|
||||
existing.is_default = False
|
||||
db_session.flush()
|
||||
|
||||
default_model_config = ModelConfiguration(
|
||||
llm_provider_id=default_provider.id,
|
||||
name=default_provider.default_model_name,
|
||||
@@ -378,7 +365,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()["providers"]
|
||||
providers = response.json()
|
||||
provider_names = [p["name"] for p in providers]
|
||||
|
||||
# Public provider should be visible
|
||||
@@ -393,7 +380,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert admin_response.status_code == 200
|
||||
admin_providers = admin_response.json()["providers"]
|
||||
admin_providers = admin_response.json()
|
||||
admin_provider_names = [p["name"] for p in admin_providers]
|
||||
|
||||
assert public_provider.name in admin_provider_names
|
||||
|
||||
@@ -12,7 +12,7 @@ const SvgOrganization = ({ size, ...props }: IconProps) => (
|
||||
>
|
||||
<path
|
||||
d="M7.5 14H13.5C14.0523 14 14.5 13.5523 14.5 13V6C14.5 5.44772 14.0523 5 13.5 5H7.5M7.5 14V11M7.5 14H4.5M7.5 5V3C7.5 2.44772 7.05228 2 6.5 2H4.5M7.5 5H1.5M7.5 5V8M1.5 5V3C1.5 2.44772 1.94772 2 2.5 2H4.5M1.5 5V8M7.5 8V11M7.5 8H4.5M1.5 8V11M1.5 8H4.5M7.5 11H4.5M1.5 11V13C1.5 13.5523 1.94772 14 2.5 14H4.5M1.5 11H4.5M4.5 2V8M4.5 14V11M4.5 11V8M10 8H12M10 11H12"
|
||||
strokeWidth={1}
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
|
||||
@@ -17,6 +17,7 @@ import {
|
||||
SvgPlus,
|
||||
SvgWallet,
|
||||
SvgFileText,
|
||||
SvgOrganization,
|
||||
} from "@opal/icons";
|
||||
import { BillingInformation, LicenseStatus } from "@/lib/billing/interfaces";
|
||||
import {
|
||||
@@ -143,17 +144,20 @@ function SubscriptionCard({
|
||||
license,
|
||||
onViewPlans,
|
||||
disabled,
|
||||
isManualLicenseOnly,
|
||||
onReconnect,
|
||||
}: {
|
||||
billing?: BillingInformation;
|
||||
license?: LicenseStatus;
|
||||
onViewPlans: () => void;
|
||||
disabled?: boolean;
|
||||
isManualLicenseOnly?: boolean;
|
||||
onReconnect?: () => Promise<void>;
|
||||
}) {
|
||||
const [isReconnecting, setIsReconnecting] = useState(false);
|
||||
|
||||
const planName = "Business Plan";
|
||||
const planName = isManualLicenseOnly ? "Enterprise Plan" : "Business Plan";
|
||||
const PlanIcon = isManualLicenseOnly ? SvgOrganization : SvgUsers;
|
||||
const expirationDate = billing?.current_period_end ?? license?.expires_at;
|
||||
const formattedDate = formatDateShort(expirationDate);
|
||||
|
||||
@@ -211,7 +215,7 @@ function SubscriptionCard({
|
||||
height="auto"
|
||||
>
|
||||
<Section gap={0.25} alignItems="start" height="auto" width="auto">
|
||||
<SvgUsers className="w-5 h-5 stroke-text-03" />
|
||||
<PlanIcon className="w-5 h-5" />
|
||||
<Text headingH3Muted text04>
|
||||
{planName}
|
||||
</Text>
|
||||
@@ -226,7 +230,19 @@ function SubscriptionCard({
|
||||
height="auto"
|
||||
width="fit"
|
||||
>
|
||||
{disabled ? (
|
||||
{isManualLicenseOnly ? (
|
||||
<Text secondaryBody text03 className="text-right">
|
||||
Your plan is managed through sales.
|
||||
<br />
|
||||
<a
|
||||
href="mailto:support@onyx.app?subject=Billing%20change%20request"
|
||||
className="underline"
|
||||
>
|
||||
Contact billing
|
||||
</a>{" "}
|
||||
to make changes.
|
||||
</Text>
|
||||
) : disabled ? (
|
||||
<Button
|
||||
main
|
||||
secondary
|
||||
@@ -266,11 +282,13 @@ function SeatsCard({
|
||||
license,
|
||||
onRefresh,
|
||||
disabled,
|
||||
hideUpdateSeats,
|
||||
}: {
|
||||
billing?: BillingInformation;
|
||||
license?: LicenseStatus;
|
||||
onRefresh?: () => Promise<void>;
|
||||
disabled?: boolean;
|
||||
hideUpdateSeats?: boolean;
|
||||
}) {
|
||||
const [isEditing, setIsEditing] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
@@ -484,15 +502,17 @@ function SeatsCard({
|
||||
<Button main tertiary href="/admin/users" leftIcon={SvgExternalLink}>
|
||||
View Users
|
||||
</Button>
|
||||
<Button
|
||||
main
|
||||
secondary
|
||||
onClick={handleStartEdit}
|
||||
leftIcon={SvgPlus}
|
||||
disabled={isLoadingUsers || disabled || !billing}
|
||||
>
|
||||
Update Seats
|
||||
</Button>
|
||||
{!hideUpdateSeats && (
|
||||
<Button
|
||||
main
|
||||
secondary
|
||||
onClick={handleStartEdit}
|
||||
leftIcon={SvgPlus}
|
||||
disabled={isLoadingUsers || disabled || !billing}
|
||||
>
|
||||
Update Seats
|
||||
</Button>
|
||||
)}
|
||||
</Section>
|
||||
</Section>
|
||||
</Card>
|
||||
@@ -593,7 +613,9 @@ interface BillingDetailsViewProps {
|
||||
onViewPlans: () => void;
|
||||
onRefresh?: () => Promise<void>;
|
||||
isAirGapped?: boolean;
|
||||
isManualLicenseOnly?: boolean;
|
||||
hasStripeError?: boolean;
|
||||
licenseCard?: React.ReactNode;
|
||||
}
|
||||
|
||||
export default function BillingDetailsView({
|
||||
@@ -602,10 +624,13 @@ export default function BillingDetailsView({
|
||||
onViewPlans,
|
||||
onRefresh,
|
||||
isAirGapped,
|
||||
isManualLicenseOnly,
|
||||
hasStripeError,
|
||||
licenseCard,
|
||||
}: BillingDetailsViewProps) {
|
||||
const expirationState = billing ? getExpirationState(billing, license) : null;
|
||||
const disableBillingActions = isAirGapped || hasStripeError;
|
||||
const disableBillingActions =
|
||||
isAirGapped || hasStripeError || isManualLicenseOnly;
|
||||
|
||||
return (
|
||||
<Section gap={1} height="auto" width="full">
|
||||
@@ -622,7 +647,7 @@ export default function BillingDetailsView({
|
||||
)}
|
||||
|
||||
{/* Air-gapped mode info banner */}
|
||||
{isAirGapped && !hasStripeError && (
|
||||
{isAirGapped && !hasStripeError && !isManualLicenseOnly && (
|
||||
<Message
|
||||
static
|
||||
info
|
||||
@@ -665,16 +690,21 @@ export default function BillingDetailsView({
|
||||
license={license}
|
||||
onViewPlans={onViewPlans}
|
||||
disabled={disableBillingActions}
|
||||
isManualLicenseOnly={isManualLicenseOnly}
|
||||
onReconnect={onRefresh}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* License card (inline for manual license users) */}
|
||||
{licenseCard}
|
||||
|
||||
{/* Seats card */}
|
||||
<SeatsCard
|
||||
billing={billing}
|
||||
license={license}
|
||||
onRefresh={onRefresh}
|
||||
disabled={disableBillingActions}
|
||||
hideUpdateSeats={isManualLicenseOnly}
|
||||
/>
|
||||
|
||||
{/* Payment section */}
|
||||
|
||||
@@ -19,6 +19,7 @@ interface LicenseActivationCardProps {
|
||||
onClose: () => void;
|
||||
onSuccess: () => void;
|
||||
license?: LicenseStatus;
|
||||
hideClose?: boolean;
|
||||
}
|
||||
|
||||
export default function LicenseActivationCard({
|
||||
@@ -26,6 +27,7 @@ export default function LicenseActivationCard({
|
||||
onClose,
|
||||
onSuccess,
|
||||
license,
|
||||
hideClose,
|
||||
}: LicenseActivationCardProps) {
|
||||
const [licenseKey, setLicenseKey] = useState("");
|
||||
const [isActivating, setIsActivating] = useState(false);
|
||||
@@ -120,9 +122,11 @@ export default function LicenseActivationCard({
|
||||
<Button main secondary onClick={() => setShowInput(true)}>
|
||||
Update Key
|
||||
</Button>
|
||||
<Button main tertiary onClick={handleClose}>
|
||||
Close
|
||||
</Button>
|
||||
{!hideClose && (
|
||||
<Button main tertiary onClick={handleClose}>
|
||||
Close
|
||||
</Button>
|
||||
)}
|
||||
</Section>
|
||||
</Section>
|
||||
</Card>
|
||||
|
||||
@@ -121,11 +121,12 @@ export default function BillingPage() {
|
||||
const billing = hasSubscription ? (billingData as BillingInformation) : null;
|
||||
const isSelfHosted = !NEXT_PUBLIC_CLOUD_ENABLED;
|
||||
|
||||
// User is only air-gapped if they have a manual license AND Stripe is not connected
|
||||
// Once Stripe connects successfully, they're no longer air-gapped
|
||||
const hasManualLicense = licenseData?.source === "manual_upload";
|
||||
const stripeConnected = billingData && !billingError;
|
||||
const isAirGapped = hasManualLicense && !stripeConnected;
|
||||
|
||||
// Air-gapped: billing endpoint is unreachable (manual license + connectivity error)
|
||||
const isAirGapped = !!(hasManualLicense && billingError);
|
||||
|
||||
// Stripe error: auto-fetched license but billing endpoint is unreachable
|
||||
const hasStripeError = !!(
|
||||
isSelfHosted &&
|
||||
licenseData?.has_license &&
|
||||
@@ -133,6 +134,10 @@ export default function BillingPage() {
|
||||
!hasManualLicense
|
||||
);
|
||||
|
||||
// Manual license without active Stripe subscription
|
||||
// Stripe-dependent actions (manage plan, update seats) won't work
|
||||
const isManualLicenseOnly = !!(hasManualLicense && !hasSubscription);
|
||||
|
||||
// Set initial view based on subscription status (only once when data first loads)
|
||||
useEffect(() => {
|
||||
if (!isLoading && view === null) {
|
||||
@@ -243,7 +248,10 @@ export default function BillingPage() {
|
||||
return {
|
||||
icon: hasSubscription ? SvgWallet : SvgArrowUpCircle,
|
||||
title: hasSubscription ? "View Plans" : "Upgrade Plan",
|
||||
showBackButton: !!hasSubscription,
|
||||
showBackButton: !!(
|
||||
hasSubscription ||
|
||||
(isSelfHosted && licenseData?.has_license)
|
||||
),
|
||||
};
|
||||
case "details":
|
||||
return {
|
||||
@@ -271,9 +279,11 @@ export default function BillingPage() {
|
||||
};
|
||||
|
||||
const handleBack = () => {
|
||||
const hasEntitlement =
|
||||
hasSubscription || (isSelfHosted && licenseData?.has_license);
|
||||
if (view === "checkout") {
|
||||
changeView(hasSubscription ? "details" : "plans");
|
||||
} else if (view === "plans" && hasSubscription) {
|
||||
changeView(hasEntitlement ? "details" : "plans");
|
||||
} else if (view === "plans" && hasEntitlement) {
|
||||
changeView("details");
|
||||
}
|
||||
};
|
||||
@@ -305,7 +315,19 @@ export default function BillingPage() {
|
||||
onViewPlans={() => changeView("plans")}
|
||||
onRefresh={handleRefresh}
|
||||
isAirGapped={isAirGapped}
|
||||
isManualLicenseOnly={isManualLicenseOnly}
|
||||
hasStripeError={hasStripeError}
|
||||
licenseCard={
|
||||
isManualLicenseOnly ? (
|
||||
<LicenseActivationCard
|
||||
isOpen
|
||||
onSuccess={handleLicenseActivated}
|
||||
license={licenseData ?? undefined}
|
||||
onClose={() => {}}
|
||||
hideClose
|
||||
/>
|
||||
) : undefined
|
||||
}
|
||||
/>
|
||||
),
|
||||
};
|
||||
@@ -322,7 +344,7 @@ export default function BillingPage() {
|
||||
if (isLoading || view === null) return null;
|
||||
return (
|
||||
<>
|
||||
{showLicenseActivationInput && (
|
||||
{showLicenseActivationInput && !isManualLicenseOnly && (
|
||||
<div className="w-full billing-card-enter">
|
||||
<LicenseActivationCard
|
||||
isOpen={showLicenseActivationInput}
|
||||
@@ -341,6 +363,7 @@ export default function BillingPage() {
|
||||
isSelfHosted ? () => setShowLicenseActivationInput(true) : undefined
|
||||
}
|
||||
hideLicenseLink={
|
||||
isManualLicenseOnly ||
|
||||
showLicenseActivationInput ||
|
||||
(view === "plans" &&
|
||||
(!!hasSubscription || !!licenseData?.has_license))
|
||||
|
||||
@@ -14,6 +14,11 @@ import {
|
||||
TimelineRendererComponent,
|
||||
TimelineRendererOutput,
|
||||
} from "./TimelineRendererComponent";
|
||||
import {
|
||||
isReasoningPackets,
|
||||
isDeepResearchPlanPackets,
|
||||
isMemoryToolPackets,
|
||||
} from "./packetHelpers";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
import { SvgBranch, SvgFold, SvgExpand } from "@opal/icons";
|
||||
import { Button } from "@opal/components";
|
||||
@@ -60,6 +65,13 @@ export function ParallelTimelineTabs({
|
||||
[turnGroup.steps, activeTab]
|
||||
);
|
||||
|
||||
// Determine if the active step needs full-width content (no right padding)
|
||||
const noPaddingRight = activeStep
|
||||
? isReasoningPackets(activeStep.packets) ||
|
||||
isDeepResearchPlanPackets(activeStep.packets) ||
|
||||
isMemoryToolPackets(activeStep.packets)
|
||||
: false;
|
||||
|
||||
// Memoized loading states for each step
|
||||
const loadingStates = useMemo(
|
||||
() =>
|
||||
@@ -82,9 +94,10 @@ export function ParallelTimelineTabs({
|
||||
isFirstStep={false}
|
||||
isSingleStep={false}
|
||||
collapsible={true}
|
||||
noPaddingRight={noPaddingRight}
|
||||
/>
|
||||
),
|
||||
[isLastTurnGroup]
|
||||
[isLastTurnGroup, noPaddingRight]
|
||||
);
|
||||
|
||||
const hasActivePackets = Boolean(activeStep && activeStep.packets.length > 0);
|
||||
|
||||
@@ -50,7 +50,7 @@ export function TimelineStepComposer({
|
||||
header={result.status}
|
||||
isExpanded={result.isExpanded}
|
||||
onToggle={result.onToggle}
|
||||
collapsible={collapsible}
|
||||
collapsible={collapsible && !isSingleStep}
|
||||
supportsCollapsible={result.supportsCollapsible}
|
||||
isLastStep={index === results.length - 1 && isLastStep}
|
||||
isFirstStep={index === 0 && isFirstStep}
|
||||
|
||||
@@ -63,7 +63,7 @@ export const FetchToolRenderer: MessageRenderer<FetchToolPacket, {}> = ({
|
||||
return children([
|
||||
{
|
||||
icon: SvgCircle,
|
||||
status: null,
|
||||
status: "Reading",
|
||||
content: <div />,
|
||||
supportsCollapsible: false,
|
||||
timelineLayout: "timeline",
|
||||
|
||||
@@ -46,7 +46,7 @@ export const MemoryToolRenderer: MessageRenderer<MemoryToolPacket, {}> = ({
|
||||
return children([
|
||||
{
|
||||
icon: SvgEditBig,
|
||||
status: null,
|
||||
status: "Memory",
|
||||
content: <div />,
|
||||
supportsCollapsible: false,
|
||||
timelineLayout: "timeline",
|
||||
|
||||
@@ -169,7 +169,9 @@ export const ReasoningRenderer: MessageRenderer<
|
||||
);
|
||||
|
||||
if (!hasStart && !hasEnd && content.length === 0) {
|
||||
return children([{ icon: SvgCircle, status: null, content: <></> }]);
|
||||
return children([
|
||||
{ icon: SvgCircle, status: THINKING_STATUS, content: <></> },
|
||||
]);
|
||||
}
|
||||
|
||||
const reasoningContent = (
|
||||
|
||||
@@ -61,7 +61,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
|
||||
children,
|
||||
}) => {
|
||||
const searchState = constructCurrentSearchState(packets);
|
||||
const { queries, results } = searchState;
|
||||
const { queries, results, isComplete } = searchState;
|
||||
|
||||
const isCompact = renderType === RenderType.COMPACT;
|
||||
const isHighlight = renderType === RenderType.HIGHLIGHT;
|
||||
@@ -75,7 +75,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
|
||||
return children([
|
||||
{
|
||||
icon: SvgSearchMenu,
|
||||
status: null,
|
||||
status: queriesHeader,
|
||||
content: <></>,
|
||||
supportsCollapsible: true,
|
||||
timelineLayout: "timeline",
|
||||
@@ -109,7 +109,15 @@ export const InternalSearchToolRenderer: MessageRenderer<
|
||||
window.open(doc.link, "_blank", "noopener,noreferrer");
|
||||
}
|
||||
}}
|
||||
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
|
||||
emptyState={
|
||||
!isComplete ? (
|
||||
<BlinkingBar />
|
||||
) : (
|
||||
<Text as="p" text04 mainUiMuted>
|
||||
No results found
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
),
|
||||
@@ -164,7 +172,15 @@ export const InternalSearchToolRenderer: MessageRenderer<
|
||||
window.open(doc.link, "_blank", "noopener,noreferrer");
|
||||
}
|
||||
}}
|
||||
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
|
||||
emptyState={
|
||||
!isComplete ? (
|
||||
<BlinkingBar />
|
||||
) : (
|
||||
<Text as="p" text04 mainUiMuted>
|
||||
No results found
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
/>
|
||||
),
|
||||
},
|
||||
@@ -213,7 +229,15 @@ export const InternalSearchToolRenderer: MessageRenderer<
|
||||
window.open(doc.link, "_blank", "noopener,noreferrer");
|
||||
}
|
||||
}}
|
||||
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
|
||||
emptyState={
|
||||
!isComplete ? (
|
||||
<BlinkingBar />
|
||||
) : (
|
||||
<Text as="p" text03 mainUiMuted>
|
||||
No results found
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
@@ -53,7 +53,7 @@ export const WebSearchToolRenderer: MessageRenderer<SearchToolPacket, {}> = ({
|
||||
return children([
|
||||
{
|
||||
icon: SvgGlobe,
|
||||
status: null,
|
||||
status: "Searching the web",
|
||||
content: <div />,
|
||||
supportsCollapsible: false,
|
||||
timelineLayout: "timeline",
|
||||
|
||||
@@ -64,6 +64,7 @@ function MemoryItem({
|
||||
if (!shouldHighlight) return;
|
||||
|
||||
wrapperRef.current?.scrollIntoView({ block: "center", behavior: "smooth" });
|
||||
textareaRef.current?.focus();
|
||||
setIsHighlighting(true);
|
||||
|
||||
const timer = setTimeout(() => {
|
||||
|
||||
@@ -115,14 +115,9 @@ export default function ActionLineItem({
|
||||
<Section gap={0.25} flexDirection="row">
|
||||
{!isUnavailable && tool?.oauth_config_id && toolAuthStatus && (
|
||||
<Button
|
||||
icon={({ className }) => (
|
||||
<SvgKey
|
||||
className={cn(
|
||||
className,
|
||||
"stroke-yellow-500 hover:stroke-yellow-600"
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
icon={SvgKey}
|
||||
prominence="secondary"
|
||||
size="sm"
|
||||
onClick={noProp(() => {
|
||||
if (
|
||||
!toolAuthStatus.hasToken ||
|
||||
|
||||
Reference in New Issue
Block a user