Compare commits

..

8 Commits

Author SHA1 Message Date
Dane Urban
3081be33d7 nits 2026-02-25 21:43:16 -08:00
Dane Urban
4822c89b34 . 2026-02-25 21:43:16 -08:00
Dane Urban
1c1e53bc06 . 2026-02-25 21:43:16 -08:00
Dane Urban
c22c71251a . 2026-02-25 21:43:16 -08:00
Dane Urban
ce7863bde7 Fix tests 2026-02-25 21:43:16 -08:00
Dane Urban
97d9e506b6 get code working 2026-02-25 21:43:16 -08:00
Dane Urban
44dca099da . 2026-02-25 21:38:52 -08:00
Dane Urban
18e75ca9ca update test llm 2026-02-25 21:38:52 -08:00
49 changed files with 883 additions and 1704 deletions

View File

@@ -9,8 +9,7 @@ inputs:
required: true
provider-api-key:
description: "API key for NIGHTLY_LLM_API_KEY"
required: false
default: ""
required: true
strict:
description: "String true/false for NIGHTLY_LLM_STRICT"
required: true
@@ -92,6 +91,11 @@ runs:
max_attempts: 2
retry_wait_seconds: 10
command: |
if [ -z "${MODELS}" ]; then
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
exit 1
fi
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \

View File

@@ -1,6 +1,6 @@
name: Nightly LLM Provider Chat Tests
name: Nightly LLM Provider Chat Tests (OpenAI)
concurrency:
group: Nightly-LLM-Provider-Chat-${{ github.workflow }}-${{ github.ref_name }}
group: Nightly-LLM-Provider-Chat-OpenAI-${{ github.workflow }}-${{ github.ref_name }}
cancel-in-progress: true
on:
@@ -13,20 +13,19 @@ permissions:
contents: read
jobs:
provider-chat-test:
openai-provider-chat-test:
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
with:
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
provider: openai
models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
strict: true
secrets:
openai_api_key: ${{ secrets.OPENAI_API_KEY }}
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
provider_api_key: ${{ secrets.OPENAI_API_KEY }}
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
notify-slack-on-failure:
needs: [provider-chat-test]
needs: [openai-provider-chat-test]
if: failure() && github.event_name == 'schedule'
runs-on: ubuntu-slim
timeout-minutes: 5
@@ -40,6 +39,6 @@ jobs:
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
failed-jobs: provider-chat-test
title: "🚨 Scheduled LLM Provider Chat Tests failed!"
failed-jobs: openai-provider-chat-test
title: "🚨 Scheduled OpenAI Provider Chat Tests failed!"
ref-name: ${{ github.ref_name }}

View File

@@ -3,26 +3,33 @@ name: Reusable Nightly LLM Provider Chat Tests
on:
workflow_call:
inputs:
openai_models:
description: "Comma-separated models for openai"
required: false
default: ""
provider:
description: "Provider slug passed to NIGHTLY_LLM_PROVIDER (e.g. openai, anthropic)"
required: true
type: string
anthropic_models:
description: "Comma-separated models for anthropic"
required: false
default: ""
models:
description: "Comma-separated model list passed to NIGHTLY_LLM_MODELS"
required: true
type: string
strict:
description: "Default NIGHTLY_LLM_STRICT passed to tests"
description: "Pass-through value for NIGHTLY_LLM_STRICT"
required: false
default: true
type: boolean
api_base:
description: "Optional NIGHTLY_LLM_API_BASE override"
required: false
default: ""
type: string
custom_config_json:
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON override"
required: false
default: ""
type: string
secrets:
openai_api_key:
required: false
anthropic_api_key:
required: false
provider_api_key:
description: "Provider API key passed to NIGHTLY_LLM_API_KEY"
required: true
DOCKER_USERNAME:
required: true
DOCKER_TOKEN:
@@ -31,8 +38,29 @@ on:
permissions:
contents: read
env:
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
NIGHTLY_LLM_MODELS: ${{ inputs.models }}
NIGHTLY_LLM_API_KEY: ${{ secrets.provider_api_key }}
NIGHTLY_LLM_API_BASE: ${{ inputs.api_base }}
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom_config_json }}
NIGHTLY_LLM_STRICT: ${{ inputs.strict && 'true' || 'false' }}
jobs:
validate-inputs:
# NOTE: Keep this cheap and fail before image builds if required inputs are missing.
runs-on: ubuntu-slim
timeout-minutes: 5
steps:
- name: Validate required nightly provider inputs
run: |
if [ -z "${NIGHTLY_LLM_MODELS}" ]; then
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
exit 1
fi
build-backend-image:
needs: [validate-inputs]
runs-on:
[
runs-on,
@@ -62,6 +90,7 @@ jobs:
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
build-model-server-image:
needs: [validate-inputs]
runs-on:
[
runs-on,
@@ -90,6 +119,7 @@ jobs:
docker-token: ${{ secrets.DOCKER_TOKEN }}
build-integration-image:
needs: [validate-inputs]
runs-on:
[
runs-on,
@@ -119,25 +149,11 @@ jobs:
provider-chat-test:
needs:
[
build-backend-image,
build-model-server-image,
build-integration-image,
]
strategy:
fail-fast: false
matrix:
include:
- provider: openai
models: ${{ inputs.openai_models }}
api_key_secret: openai_api_key
- provider: anthropic
models: ${{ inputs.anthropic_models }}
api_key_secret: anthropic_api_key
[build-backend-image, build-model-server-image, build-integration-image]
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
- "run-id=${{ github.run_id }}-nightly-${{ inputs.provider }}-provider-chat-test"
- extras=ecr-cache
timeout-minutes: 45
steps:
@@ -151,10 +167,12 @@ jobs:
- name: Run nightly provider chat test
uses: ./.github/actions/run-nightly-provider-chat-test
with:
provider: ${{ matrix.provider }}
models: ${{ matrix.models }}
provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }}
strict: ${{ inputs.strict && 'true' || 'false' }}
provider: ${{ env.NIGHTLY_LLM_PROVIDER }}
models: ${{ env.NIGHTLY_LLM_MODELS }}
provider-api-key: ${{ secrets.provider_api_key }}
strict: ${{ env.NIGHTLY_LLM_STRICT }}
api-base: ${{ env.NIGHTLY_LLM_API_BASE }}
custom-config-json: ${{ env.NIGHTLY_LLM_CUSTOM_CONFIG_JSON }}
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
@@ -176,7 +194,7 @@ jobs:
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-nightly-${{ matrix.provider }}-llm-provider
name: docker-all-logs-nightly-${{ inputs.provider }}-llm-provider
path: |
${{ github.workspace }}/api_server.log
${{ github.workspace }}/docker-compose.log

View File

@@ -58,27 +58,16 @@ class OAuthTokenManager:
if not user_token.token_data:
raise ValueError("No token data available for refresh")
if (
self.oauth_config.client_id is None
or self.oauth_config.client_secret is None
):
raise ValueError(
"OAuth client_id and client_secret are required for token refresh"
)
token_data = self._unwrap_token_data(user_token.token_data)
data: dict[str, str] = {
"grant_type": "refresh_token",
"refresh_token": token_data["refresh_token"],
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
"client_secret": self._unwrap_sensitive_str(
self.oauth_config.client_secret
),
}
response = requests.post(
self.oauth_config.token_url,
data=data,
data={
"grant_type": "refresh_token",
"refresh_token": token_data["refresh_token"],
"client_id": self.oauth_config.client_id,
"client_secret": self.oauth_config.client_secret,
},
headers={"Accept": "application/json"},
)
response.raise_for_status()
@@ -126,26 +115,15 @@ class OAuthTokenManager:
def exchange_code_for_token(self, code: str, redirect_uri: str) -> dict[str, Any]:
"""Exchange authorization code for access token"""
if (
self.oauth_config.client_id is None
or self.oauth_config.client_secret is None
):
raise ValueError(
"OAuth client_id and client_secret are required for code exchange"
)
data: dict[str, str] = {
"grant_type": "authorization_code",
"code": code,
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
"client_secret": self._unwrap_sensitive_str(
self.oauth_config.client_secret
),
"redirect_uri": redirect_uri,
}
response = requests.post(
self.oauth_config.token_url,
data=data,
data={
"grant_type": "authorization_code",
"code": code,
"client_id": self.oauth_config.client_id,
"client_secret": self.oauth_config.client_secret,
"redirect_uri": redirect_uri,
},
headers={"Accept": "application/json"},
)
response.raise_for_status()
@@ -163,13 +141,8 @@ class OAuthTokenManager:
oauth_config: OAuthConfig, redirect_uri: str, state: str
) -> str:
"""Build OAuth authorization URL"""
if oauth_config.client_id is None:
raise ValueError("OAuth client_id is required to build authorization URL")
params: dict[str, Any] = {
"client_id": OAuthTokenManager._unwrap_sensitive_str(
oauth_config.client_id
),
"client_id": oauth_config.client_id,
"redirect_uri": redirect_uri,
"response_type": "code",
"state": state,
@@ -188,12 +161,6 @@ class OAuthTokenManager:
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
@staticmethod
def _unwrap_sensitive_str(value: SensitiveValue[str] | str) -> str:
if isinstance(value, SensitiveValue):
return value.get_value(apply_mask=False)
return value
@staticmethod
def _unwrap_token_data(
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],

View File

@@ -48,7 +48,6 @@ from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchDocumentIndex,
)
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import IndexingSetting
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
@@ -150,12 +149,8 @@ def migrate_chunks_from_vespa_to_opensearch_task(
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
search_settings = get_current_search_settings(db_session)
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
indexing_setting = IndexingSetting.from_db_model(search_settings)
opensearch_document_index = OpenSearchDocumentIndex(
tenant_state=tenant_state,
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
index_name=search_settings.index_name, tenant_state=tenant_state
)
vespa_document_index = VespaDocumentIndex(
index_name=search_settings.index_name,

View File

@@ -294,12 +294,6 @@ ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
)
# Whether we should check for and create an index if necessary every time we
# instantiate an OpenSearchDocumentIndex on multitenant cloud. Defaults to True.
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
os.environ.get("VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT", "true").lower()
== "true"
)
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
# NOTE: this is used if and only if the vespa config server is accessible via a

View File

@@ -488,6 +488,22 @@ def fetch_existing_llm_provider(
return provider_model
def fetch_existing_llm_provider_by_id(
id: int, db_session: Session
) -> LLMProviderModel | None:
provider_model = db_session.scalar(
select(LLMProviderModel)
.where(LLMProviderModel.id == id)
.options(
selectinload(LLMProviderModel.model_configurations),
selectinload(LLMProviderModel.groups),
selectinload(LLMProviderModel.personas),
)
)
return provider_model
def fetch_embedding_provider(
db_session: Session, provider_type: EmbeddingProvider
) -> CloudEmbeddingProviderModel | None:

View File

@@ -335,7 +335,6 @@ def update_persona_shared(
db_session: Session,
group_ids: list[int] | None = None,
is_public: bool | None = None,
label_ids: list[int] | None = None,
) -> None:
"""Simplified version of `create_update_persona` which only touches the
accessibility rather than any of the logic (e.g. prompt, connected data sources,
@@ -345,7 +344,9 @@ def update_persona_shared(
)
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
raise PermissionError("You don't have permission to modify this persona")
raise HTTPException(
status_code=403, detail="You don't have permission to modify this persona"
)
versioned_update_persona_access = fetch_versioned_implementation(
"onyx.db.persona", "update_persona_access"
@@ -359,15 +360,6 @@ def update_persona_shared(
group_ids=group_ids,
)
if label_ids is not None:
labels = (
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
)
if len(labels) != len(label_ids):
raise ValueError("Some label IDs were not found in the database")
persona.labels.clear()
persona.labels = labels
db_session.commit()
@@ -973,8 +965,6 @@ def upsert_persona(
labels = (
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
)
if len(labels) != len(label_ids):
raise ValueError("Some label IDs were not found in the database")
# Fetch and attach hierarchy_nodes by IDs
hierarchy_nodes = None

View File

@@ -11,7 +11,6 @@ from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchOldDocumentIndex,
)
from onyx.document_index.vespa.index import VespaIndex
from onyx.indexing.models import IndexingSetting
from shared_configs.configs import MULTI_TENANT
@@ -50,11 +49,8 @@ def get_default_document_index(
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
if opensearch_retrieval_enabled:
indexing_setting = IndexingSetting.from_db_model(search_settings)
return OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
secondary_index_name=secondary_index_name,
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
@@ -122,11 +118,8 @@ def get_all_document_indices(
)
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
indexing_setting = IndexingSetting.from_db_model(search_settings)
opensearch_document_index = OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
secondary_index_name=None,
large_chunks_enabled=False,
secondary_large_chunks_enabled=None,

View File

@@ -1,7 +1,5 @@
import logging
import time
from contextlib import AbstractContextManager
from contextlib import nullcontext
from typing import Any
from typing import Generic
from typing import TypeVar
@@ -85,26 +83,22 @@ def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
return new_body
class OpenSearchClient(AbstractContextManager):
"""Client for interacting with OpenSearch for cluster-level operations.
class OpenSearchClient:
"""Client for interacting with OpenSearch.
Args:
host: The host of the OpenSearch cluster.
port: The port of the OpenSearch cluster.
auth: The authentication credentials for the OpenSearch cluster. A tuple
of (username, password).
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
True.
verify_certs: Whether to verify the SSL certificates for the OpenSearch
cluster. Defaults to False.
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
to False.
timeout: The timeout for the OpenSearch cluster. Defaults to
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
OpenSearch's Python module has pretty bad typing support so this client
attempts to protect the rest of the codebase from this. As a consequence,
most methods here return the minimum data needed for the rest of Onyx, and
tend to rely on Exceptions to handle errors.
TODO(andrei): This class currently assumes the structure of the database
schema when it returns a DocumentChunk. Make the class, or at least the
search method, templated on the structure the caller can expect.
"""
def __init__(
self,
index_name: str,
host: str = OPENSEARCH_HOST,
port: int = OPENSEARCH_REST_API_PORT,
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
@@ -113,8 +107,9 @@ class OpenSearchClient(AbstractContextManager):
ssl_show_warn: bool = False,
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
):
self._index_name = index_name
logger.debug(
f"Creating OpenSearch client with host {host}, port {port} and timeout {timeout} seconds."
f"Creating OpenSearch client for index {index_name} with host {host} and port {port} and timeout {timeout} seconds."
)
self._client = OpenSearch(
hosts=[{"host": host, "port": port}],
@@ -130,142 +125,6 @@ class OpenSearchClient(AbstractContextManager):
# your request body that is less than this value.
timeout=timeout,
)
def __exit__(self, *_: Any) -> None:
self.close()
def __del__(self) -> None:
try:
self.close()
except Exception:
pass
@log_function_time(print_only=True, debug_only=True, include_args=True)
def create_search_pipeline(
self,
pipeline_id: str,
pipeline_body: dict[str, Any],
) -> None:
"""Creates a search pipeline.
See the OpenSearch documentation for more information on the search
pipeline body.
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
Args:
pipeline_id: The ID of the search pipeline to create.
pipeline_body: The body of the search pipeline to create.
Raises:
Exception: There was an error creating the search pipeline.
"""
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
if not result.get("acknowledged", False):
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
@log_function_time(print_only=True, debug_only=True, include_args=True)
def delete_search_pipeline(self, pipeline_id: str) -> None:
"""Deletes a search pipeline.
Args:
pipeline_id: The ID of the search pipeline to delete.
Raises:
Exception: There was an error deleting the search pipeline.
"""
result = self._client.search_pipeline.delete(id=pipeline_id)
if not result.get("acknowledged", False):
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
@log_function_time(print_only=True, debug_only=True, include_args=True)
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
"""Puts cluster settings.
Args:
settings: The settings to put.
Raises:
Exception: There was an error putting the cluster settings.
Returns:
True if the settings were put successfully, False otherwise.
"""
response = self._client.cluster.put_settings(body=settings)
if response.get("acknowledged", False):
logger.info("Successfully put cluster settings.")
return True
else:
logger.error(f"Failed to put cluster settings: {response}.")
return False
@log_function_time(print_only=True, debug_only=True)
def ping(self) -> bool:
"""Pings the OpenSearch cluster.
Returns:
True if OpenSearch could be reached, False if it could not.
"""
return self._client.ping()
@log_function_time(print_only=True, debug_only=True)
def close(self) -> None:
"""Closes the client.
Raises:
Exception: There was an error closing the client.
"""
self._client.close()
class OpenSearchIndexClient(OpenSearchClient):
"""Client for interacting with OpenSearch for index-level operations.
OpenSearch's Python module has pretty bad typing support so this client
attempts to protect the rest of the codebase from this. As a consequence,
most methods here return the minimum data needed for the rest of Onyx, and
tend to rely on Exceptions to handle errors.
TODO(andrei): This class currently assumes the structure of the database
schema when it returns a DocumentChunk. Make the class, or at least the
search method, templated on the structure the caller can expect.
Args:
index_name: The name of the index to interact with.
host: The host of the OpenSearch cluster.
port: The port of the OpenSearch cluster.
auth: The authentication credentials for the OpenSearch cluster. A tuple
of (username, password).
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
True.
verify_certs: Whether to verify the SSL certificates for the OpenSearch
cluster. Defaults to False.
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
to False.
timeout: The timeout for the OpenSearch cluster. Defaults to
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
"""
def __init__(
self,
index_name: str,
host: str = OPENSEARCH_HOST,
port: int = OPENSEARCH_REST_API_PORT,
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
use_ssl: bool = True,
verify_certs: bool = False,
ssl_show_warn: bool = False,
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
):
super().__init__(
host=host,
port=port,
auth=auth,
use_ssl=use_ssl,
verify_certs=verify_certs,
ssl_show_warn=ssl_show_warn,
timeout=timeout,
)
self._index_name = index_name
logger.debug(
f"OpenSearch client created successfully for index {self._index_name}."
)
@@ -333,38 +192,6 @@ class OpenSearchIndexClient(OpenSearchClient):
"""
return self._client.indices.exists(index=self._index_name)
@log_function_time(print_only=True, debug_only=True, include_args=True)
def put_mapping(self, mappings: dict[str, Any]) -> None:
"""Updates the index mapping in an idempotent manner.
- Existing fields with the same definition: No-op (succeeds silently).
- New fields: Added to the index.
- Existing fields with different types: Raises exception (requires
reindex).
See the OpenSearch documentation for more information:
https://docs.opensearch.org/latest/api-reference/index-apis/put-mapping/
Args:
mappings: The complete mapping definition to apply. This will be
merged with existing mappings in the index.
Raises:
Exception: There was an error updating the mappings, such as
attempting to change the type of an existing field.
"""
logger.debug(
f"Putting mappings for index {self._index_name} with mappings {mappings}."
)
response = self._client.indices.put_mapping(
index=self._index_name, body=mappings
)
if not response.get("acknowledged", False):
raise RuntimeError(
f"Failed to put the mapping update for index {self._index_name}."
)
logger.debug(f"Successfully put mappings for index {self._index_name}.")
@log_function_time(print_only=True, debug_only=True, include_args=True)
def validate_index(self, expected_mappings: dict[str, Any]) -> bool:
"""Validates the index.
@@ -783,6 +610,43 @@ class OpenSearchIndexClient(OpenSearchClient):
)
return DocumentChunk.model_validate(document_chunk_source)
@log_function_time(print_only=True, debug_only=True, include_args=True)
def create_search_pipeline(
self,
pipeline_id: str,
pipeline_body: dict[str, Any],
) -> None:
"""Creates a search pipeline.
See the OpenSearch documentation for more information on the search
pipeline body.
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
Args:
pipeline_id: The ID of the search pipeline to create.
pipeline_body: The body of the search pipeline to create.
Raises:
Exception: There was an error creating the search pipeline.
"""
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
if not result.get("acknowledged", False):
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
@log_function_time(print_only=True, debug_only=True, include_args=True)
def delete_search_pipeline(self, pipeline_id: str) -> None:
"""Deletes a search pipeline.
Args:
pipeline_id: The ID of the search pipeline to delete.
Raises:
Exception: There was an error deleting the search pipeline.
"""
result = self._client.search_pipeline.delete(id=pipeline_id)
if not result.get("acknowledged", False):
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
@log_function_time(print_only=True, debug_only=True)
def search(
self, body: dict[str, Any], search_pipeline_id: str | None
@@ -943,6 +807,48 @@ class OpenSearchIndexClient(OpenSearchClient):
"""
self._client.indices.refresh(index=self._index_name)
@log_function_time(print_only=True, debug_only=True, include_args=True)
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
"""Puts cluster settings.
Args:
settings: The settings to put.
Raises:
Exception: There was an error putting the cluster settings.
Returns:
True if the settings were put successfully, False otherwise.
"""
response = self._client.cluster.put_settings(body=settings)
if response.get("acknowledged", False):
logger.info("Successfully put cluster settings.")
return True
else:
logger.error(f"Failed to put cluster settings: {response}.")
return False
@log_function_time(print_only=True, debug_only=True)
def ping(self) -> bool:
"""Pings the OpenSearch cluster.
Returns:
True if OpenSearch could be reached, False if it could not.
"""
return self._client.ping()
@log_function_time(print_only=True, debug_only=True)
def close(self) -> None:
"""Closes the client.
TODO(andrei): Can we have some way to auto close when the client no
longer has any references?
Raises:
Exception: There was an error closing the client.
"""
self._client.close()
def _get_hits_and_profile_from_search_result(
self, result: dict[str, Any]
) -> tuple[list[Any], int | None, bool | None, dict[str, Any], dict[str, Any]]:
@@ -1039,7 +945,14 @@ def wait_for_opensearch_with_timeout(
Returns:
True if OpenSearch is ready, False otherwise.
"""
with nullcontext(client) if client else OpenSearchClient() as client:
made_client = False
try:
if client is None:
# NOTE: index_name does not matter because we are only using this object
# to ping.
# TODO(andrei): Make this better.
client = OpenSearchClient(index_name="")
made_client = True
time_start = time.monotonic()
while True:
if client.ping():
@@ -1056,3 +969,7 @@ def wait_for_opensearch_with_timeout(
f"[OpenSearch] Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={wait_limit_s:.1f}"
)
time.sleep(wait_interval_s)
finally:
if made_client:
assert client is not None
client.close()

View File

@@ -7,7 +7,6 @@ from opensearchpy import NotFoundError
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
from onyx.configs.constants import PUBLIC_DOC_PAT
@@ -41,7 +40,6 @@ from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import MetadataUpdateRequest
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.client import OpenSearchClient
from onyx.document_index.opensearch.client import OpenSearchIndexClient
from onyx.document_index.opensearch.client import SearchHit
from onyx.document_index.opensearch.cluster_settings import OPENSEARCH_CLUSTER_SETTINGS
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
@@ -95,25 +93,6 @@ def generate_opensearch_filtered_access_control_list(
return list(access_control_list)
def set_cluster_state(client: OpenSearchClient) -> None:
if not client.put_cluster_settings(settings=OPENSEARCH_CLUSTER_SETTINGS):
logger.error(
"Failed to put cluster settings. If the settings have never been set before, "
"this may cause unexpected index creation when indexing documents into an "
"index that does not exist, or may cause expected logs to not appear. If this "
"is not the first time running Onyx against this instance of OpenSearch, these "
"settings have likely already been set. Not taking any further action..."
)
client.create_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
)
client.create_search_pipeline(
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
)
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
chunk: DocumentChunk,
score: float | None,
@@ -269,8 +248,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
def __init__(
self,
index_name: str,
embedding_dim: int,
embedding_precision: EmbeddingPrecision,
secondary_index_name: str | None,
large_chunks_enabled: bool, # noqa: ARG002
secondary_large_chunks_enabled: bool | None, # noqa: ARG002
@@ -281,6 +258,10 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
index_name=index_name,
secondary_index_name=secondary_index_name,
)
if multitenant:
raise ValueError(
"Bug: OpenSearch is not yet ready for multitenant environments but something tried to use it."
)
if multitenant != MULTI_TENANT:
raise ValueError(
"Bug: Multitenant mismatch when initializing an OpenSearchDocumentIndex. "
@@ -288,10 +269,8 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
)
tenant_id = get_current_tenant_id()
self._real_index = OpenSearchDocumentIndex(
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
index_name=index_name,
embedding_dim=embedding_dim,
embedding_precision=embedding_precision,
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
)
@staticmethod
@@ -300,8 +279,9 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
embedding_dims: list[int],
embedding_precisions: list[EmbeddingPrecision],
) -> None:
# TODO(andrei): Implement.
raise NotImplementedError(
"Bug: Multitenant index registration is not supported for OpenSearch."
"Multitenant index registration is not yet implemented for OpenSearch."
)
def ensure_indices_exist(
@@ -491,37 +471,19 @@ class OpenSearchDocumentIndex(DocumentIndex):
for an OpenSearch search engine instance. It handles the complete lifecycle
of document chunks within a specific OpenSearch index/schema.
Each kind of embedding used should correspond to a different instance of
this class, and therefore a different index in OpenSearch.
If in a multitenant environment and
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT, will verify and create the index
if necessary on initialization. This is because there is no logic which runs
on cluster restart which scans through all search settings over all tenants
and creates the relevant indices.
Args:
tenant_state: The tenant state of the caller.
index_name: The name of the index to interact with.
embedding_dim: The dimensionality of the embeddings used for the index.
embedding_precision: The precision of the embeddings used for the index.
Although not yet used in this way in the codebase, each kind of embedding
used should correspond to a different instance of this class, and therefore
a different index in OpenSearch.
"""
def __init__(
self,
tenant_state: TenantState,
index_name: str,
embedding_dim: int,
embedding_precision: EmbeddingPrecision,
tenant_state: TenantState,
) -> None:
self._index_name: str = index_name
self._tenant_state: TenantState = tenant_state
self._client = OpenSearchIndexClient(index_name=self._index_name)
if self._tenant_state.multitenant and VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT:
self.verify_and_create_index_if_necessary(
embedding_dim=embedding_dim, embedding_precision=embedding_precision
)
self._os_client = OpenSearchClient(index_name=self._index_name)
def verify_and_create_index_if_necessary(
self,
@@ -530,15 +492,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
) -> None:
"""Verifies and creates the index if necessary.
Also puts the desired cluster settings if not in a multitenant
environment.
Also puts the desired cluster settings.
Also puts the desired search pipeline state if not in a multitenant
environment, creating the pipelines if they do not exist and updating
them otherwise.
In a multitenant environment, the above steps happen explicitly on
setup.
Also puts the desired search pipeline state, creating the pipelines if
they do not exist and updating them otherwise.
Args:
embedding_dim: Vector dimensionality for the vector similarity part
@@ -551,38 +508,47 @@ class OpenSearchDocumentIndex(DocumentIndex):
search pipelines.
"""
logger.debug(
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if "
f"necessary, with embedding dimension {embedding_dim}."
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if necessary, "
f"with embedding dimension {embedding_dim}."
)
if not self._tenant_state.multitenant:
set_cluster_state(self._client)
expected_mappings = DocumentSchema.get_document_schema(
embedding_dim, self._tenant_state.multitenant
)
if not self._client.index_exists():
if not self._os_client.put_cluster_settings(
settings=OPENSEARCH_CLUSTER_SETTINGS
):
logger.error(
f"Failed to put cluster settings for index {self._index_name}. If the settings have never been set before this "
"may cause unexpected index creation when indexing documents into an index that does not exist, or may cause "
"expected logs to not appear. If this is not the first time running Onyx against this instance of OpenSearch, "
"these settings have likely already been set. Not taking any further action..."
)
if not self._os_client.index_exists():
if USING_AWS_MANAGED_OPENSEARCH:
index_settings = (
DocumentSchema.get_index_settings_for_aws_managed_opensearch()
)
else:
index_settings = DocumentSchema.get_index_settings()
self._client.create_index(
self._os_client.create_index(
mappings=expected_mappings,
settings=index_settings,
)
else:
# Ensure schema is up to date by applying the current mappings.
try:
self._client.put_mapping(expected_mappings)
except Exception as e:
logger.error(
f"Failed to update mappings for index {self._index_name}. This likely means a "
f"field type was changed which requires reindexing. Error: {e}"
)
raise
if not self._os_client.validate_index(
expected_mappings=expected_mappings,
):
raise RuntimeError(
f"The index {self._index_name} is not valid. The expected mappings do not match the actual mappings."
)
self._os_client.create_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
)
self._os_client.create_search_pipeline(
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
)
def index(
self,
@@ -654,7 +620,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
)
# Now index. This will raise if a chunk of the same ID exists, which
# we do not expect because we should have deleted all chunks.
self._client.bulk_index_documents(
self._os_client.bulk_index_documents(
documents=chunk_batch,
tenant_state=self._tenant_state,
)
@@ -694,7 +660,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
tenant_state=self._tenant_state,
)
return self._client.delete_by_query(query_body)
return self._os_client.delete_by_query(query_body)
def update(
self,
@@ -794,7 +760,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
document_id=doc_id,
chunk_index=chunk_index,
)
self._client.update_document(
self._os_client.update_document(
document_chunk_id=document_chunk_id,
properties_to_update=properties_to_update,
)
@@ -833,7 +799,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
min_chunk_index=chunk_request.min_chunk_ind,
max_chunk_index=chunk_request.max_chunk_ind,
)
search_hits = self._client.search(
search_hits = self._os_client.search(
body=query_body,
search_pipeline_id=None,
)
@@ -883,7 +849,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
# NOTE: Using z-score normalization here because it's better for hybrid search from a theoretical standpoint.
# Empirically on a small dataset of up to 10K docs, it's not very different. Likely more impactful at scale.
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
body=query_body,
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
)
@@ -915,7 +881,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
index_filters=filters,
num_to_retrieve=num_to_retrieve,
)
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
body=query_body,
search_pipeline_id=None,
)
@@ -943,6 +909,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
# Do not raise if the document already exists, just update. This is
# because the document may already have been indexed during the
# OpenSearch transition period.
self._client.bulk_index_documents(
self._os_client.bulk_index_documents(
documents=chunks, tenant_state=self._tenant_state, update_if_exists=True
)

View File

@@ -405,7 +405,6 @@ class PersonaShareRequest(BaseModel):
user_ids: list[UUID] | None = None
group_ids: list[int] | None = None
is_public: bool | None = None
label_ids: list[int] | None = None
# We notify each user when a user is shared with them
@@ -416,22 +415,14 @@ def share_persona(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
update_persona_shared(
persona_id=persona_id,
user=user,
db_session=db_session,
user_ids=persona_share_request.user_ids,
group_ids=persona_share_request.group_ids,
is_public=persona_share_request.is_public,
label_ids=persona_share_request.label_ids,
)
except PermissionError as e:
logger.exception("Failed to share persona")
raise HTTPException(status_code=403, detail=str(e))
except ValueError as e:
logger.exception("Failed to share persona")
raise HTTPException(status_code=400, detail=str(e))
update_persona_shared(
persona_id=persona_id,
user=user,
db_session=db_session,
user_ids=persona_share_request.user_ids,
group_ids=persona_share_request.group_ids,
is_public=persona_share_request.is_public,
)
@basic_router.delete("/{persona_id}", tags=PUBLIC_API_TAGS)

View File

@@ -22,7 +22,10 @@ from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_default_vision_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_provider_by_id
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_models
from onyx.db.llm import fetch_persona_with_groups
@@ -52,8 +55,10 @@ from onyx.llm.well_known_providers.llm_provider_options import (
)
from onyx.server.manage.llm.models import BedrockFinalModelResponse
from onyx.server.manage.llm.models import BedrockModelsRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMCost
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderResponse
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
@@ -233,12 +238,9 @@ def test_llm_configuration(
test_api_key = test_llm_request.api_key
test_custom_config = test_llm_request.custom_config
if test_llm_request.name:
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
# as it turns out the name is not editable in the UI and other code also keys off name,
# so we won't rock the boat just yet.
existing_provider = fetch_existing_llm_provider(
name=test_llm_request.name, db_session=db_session
if test_llm_request.id:
existing_provider = fetch_existing_llm_provider_by_id(
id=test_llm_request.id, db_session=db_session
)
if existing_provider:
test_custom_config = _restore_masked_custom_config_values(
@@ -268,7 +270,7 @@ def test_llm_configuration(
llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.default_model_name,
model=test_llm_request.model,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
@@ -303,7 +305,7 @@ def list_llm_providers(
include_image_gen: bool = Query(False),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderView]:
) -> LLMProviderResponse[LLMProviderView]:
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch LLM providers")
@@ -328,7 +330,15 @@ def list_llm_providers(
duration = (end_time - start_time).total_seconds()
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
return llm_provider_list
return LLMProviderResponse[LLMProviderView].from_models(
providers=llm_provider_list,
default_text=DefaultModel.from_model_config(
fetch_default_llm_model(db_session)
),
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
@admin_router.put("/provider")
@@ -516,7 +526,7 @@ def get_auto_config(
def get_vision_capable_providers(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[VisionProviderResponse]:
) -> LLMProviderResponse[VisionProviderResponse]:
"""Return a list of LLM providers and their models that support image input"""
vision_models = fetch_existing_models(
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
@@ -545,7 +555,13 @@ def get_vision_capable_providers(
]
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
return vision_provider_response
return LLMProviderResponse[VisionProviderResponse].from_models(
providers=vision_provider_response,
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
"""Endpoints for all"""
@@ -555,7 +571,7 @@ def get_vision_capable_providers(
def list_llm_provider_basics(
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
) -> LLMProviderResponse[LLMProviderDescriptor]:
"""Get LLM providers accessible to the current user.
Returns:
@@ -592,7 +608,15 @@ def list_llm_provider_basics(
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
)
return accessible_providers
return LLMProviderResponse[LLMProviderDescriptor].from_models(
providers=accessible_providers,
default_text=DefaultModel.from_model_config(
fetch_default_llm_model(db_session)
),
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
def get_valid_model_names_for_persona(

View File

@@ -1,5 +1,9 @@
from __future__ import annotations
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar
from pydantic import BaseModel
from pydantic import Field
@@ -21,6 +25,8 @@ if TYPE_CHECKING:
ModelConfiguration as ModelConfigurationModel,
)
T = TypeVar("T", "LLMProviderDescriptor", "LLMProviderView")
# TODO: Clear this up on api refactor
# There is still logic that requires sending each providers default model name
@@ -52,19 +58,17 @@ def get_default_vision_model_name(llm_provider_model: "LLMProviderModel") -> str
class TestLLMRequest(BaseModel):
# provider level
name: str | None = None
id: int | None = None
provider: str
model: str
api_key: str | None = None
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
# model level
default_model_name: str
deployment_name: str | None = None
model_configurations: list["ModelConfigurationUpsertRequest"]
# if try and use the existing API/custom config key
api_key_changed: bool
custom_config_changed: bool
@@ -425,3 +429,38 @@ class OpenRouterFinalModelResponse(BaseModel):
int | None
) # From OpenRouter API context_length (may be missing for some models)
supports_image_input: bool
class DefaultModel(BaseModel):
provider_id: int
model_name: str
@classmethod
def from_model_config(
cls, model_config: ModelConfigurationModel | None
) -> DefaultModel | None:
if not model_config:
return None
return cls(
provider_id=model_config.llm_provider_id,
model_name=model_config.name,
)
class LLMProviderResponse(BaseModel, Generic[T]):
providers: list[T]
default_text: DefaultModel | None = None
default_vision: DefaultModel | None = None
@classmethod
def from_models(
cls,
providers: list[T],
default_text: DefaultModel | None = None,
default_vision: DefaultModel | None = None,
) -> LLMProviderResponse[T]:
return cls(
providers=providers,
default_text=default_text,
default_vision=default_vision,
)

View File

@@ -4,7 +4,6 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
@@ -33,9 +32,6 @@ from onyx.db.search_settings import update_current_search_settings
from onyx.db.swap_index import check_and_perform_index_swap
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.opensearch.client import OpenSearchClient
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
from onyx.document_index.opensearch.opensearch_document_index import set_cluster_state
from onyx.document_index.vespa.index import VespaIndex
from onyx.indexing.models import IndexingSetting
from onyx.key_value_store.factory import get_kv_store
@@ -315,14 +311,7 @@ def setup_multitenant_onyx() -> None:
logger.notice("DISABLE_VECTOR_DB is set — skipping multitenant Vespa setup.")
return
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
opensearch_client = OpenSearchClient()
if not wait_for_opensearch_with_timeout(client=opensearch_client):
raise RuntimeError("Failed to connect to OpenSearch.")
set_cluster_state(opensearch_client)
# For Managed Vespa, the schema is sent over via the Vespa Console manually.
# NOTE: Pretty sure this code is never hit in any production environment.
if not MANAGED_VESPA:
setup_vespa_multitenant(SUPPORTED_EMBEDDING_MODELS)

View File

@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
# Prepare the test request payload
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"model": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,
@@ -44,7 +44,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
# Prepare the test request payload with invalid credentials
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"model": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,

View File

@@ -29,6 +29,7 @@ from onyx.server.manage.llm.api import (
test_llm_configuration as run_test_llm_configuration,
)
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import TestLLMRequest as LLMTestRequest
@@ -44,9 +45,9 @@ def _create_test_provider(
db_session: Session,
name: str,
api_key: str = "sk-test-key-00000000000000000000000000000000000",
) -> None:
) -> LLMProviderView:
"""Helper to create a test LLM provider in the database."""
upsert_llm_provider(
return upsert_llm_provider(
LLMProviderUpsertRequest(
name=name,
provider=LlmProviderNames.OPENAI,
@@ -102,17 +103,11 @@ class TestLLMConfigurationEndpoint:
# This should complete without exception
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None, # New provider (not in DB)
provider=LlmProviderNames.OPENAI,
api_key="sk-new-test-key-0000000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -152,17 +147,11 @@ class TestLLMConfigurationEndpoint:
with pytest.raises(HTTPException) as exc_info:
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None,
provider=LlmProviderNames.OPENAI,
api_key="sk-invalid-key-00000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -194,7 +183,9 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database
_create_test_provider(db_session, provider_name, api_key=original_api_key)
provider = _create_test_provider(
db_session, provider_name, api_key=original_api_key
)
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -202,17 +193,12 @@ class TestLLMConfigurationEndpoint:
# Test with api_key_changed=False - should use stored key
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=provider_name, # Existing provider
id=provider.id,
provider=LlmProviderNames.OPENAI,
api_key=None, # Not providing a new key
api_key_changed=False, # Using existing key
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -246,7 +232,9 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database
_create_test_provider(db_session, provider_name, api_key=original_api_key)
provider = _create_test_provider(
db_session, provider_name, api_key=original_api_key
)
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -254,17 +242,12 @@ class TestLLMConfigurationEndpoint:
# Test with api_key_changed=True - should use new key
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=provider_name, # Existing provider
id=provider.id,
provider=LlmProviderNames.OPENAI,
api_key=new_api_key, # Providing a new key
api_key_changed=True, # Key is being changed
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -297,7 +280,7 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database with custom_config
upsert_llm_provider(
provider = upsert_llm_provider(
LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
@@ -321,18 +304,13 @@ class TestLLMConfigurationEndpoint:
# Test with custom_config_changed=False - should use stored config
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=provider_name,
id=provider.id,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_key_changed=False,
custom_config=None, # Not providing new config
custom_config_changed=False, # Using existing config
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -368,17 +346,11 @@ class TestLLMConfigurationEndpoint:
for model_name in test_models:
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
default_model_name=model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=model_name, is_visible=True
)
],
model=model_name,
),
_=_create_mock_admin(),
db_session=db_session,

View File

@@ -530,14 +530,8 @@ def test_upload_with_custom_config_then_change(
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
run_llm_config_test(
LLMTestRequest(
name=name,
provider=provider_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
model=default_model_name,
api_key_changed=False,
custom_config_changed=True,
custom_config=custom_config,
@@ -546,7 +540,7 @@ def test_upload_with_custom_config_then_change(
db_session=db_session,
)
put_llm_provider(
provider = put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider_name,
@@ -569,14 +563,9 @@ def test_upload_with_custom_config_then_change(
# Turn auto mode off
run_llm_config_test(
LLMTestRequest(
name=name,
id=provider.id,
provider=provider_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
model=default_model_name,
api_key_changed=False,
custom_config_changed=False,
),
@@ -616,13 +605,13 @@ def test_upload_with_custom_config_then_change(
)
# Check inside the database and check that custom_config is the same as the original
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
if not provider:
db_provider = fetch_existing_llm_provider(name=name, db_session=db_session)
if not db_provider:
assert False, "Provider not found in the database"
assert provider.custom_config == custom_config, (
assert db_provider.custom_config == custom_config, (
f"Expected custom_config {custom_config}, "
f"but got {provider.custom_config}"
f"but got {db_provider.custom_config}"
)
finally:
db_session.rollback()
@@ -706,7 +695,7 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
) -> None:
"""LLM test should restore masked sensitive custom config values before invocation."""
name = f"test-provider-vertex-test-{uuid4().hex[:8]}"
provider = LlmProviderNames.VERTEX_AI.value
provider_name = LlmProviderNames.VERTEX_AI.value
default_model_name = "gemini-2.5-pro"
original_custom_config = {
"vertex_credentials": '{"type":"service_account","private_key":"REAL_PRIVATE_KEY"}',
@@ -719,10 +708,10 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
return ""
try:
put_llm_provider(
provider = put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider,
provider=provider_name,
default_model_name=default_model_name,
custom_config=original_custom_config,
model_configurations=[
@@ -742,14 +731,9 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
run_llm_config_test(
LLMTestRequest(
name=name,
provider=provider,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
id=provider.id,
provider=provider_name,
model=default_model_name,
api_key_changed=False,
custom_config_changed=True,
custom_config={

View File

@@ -1,4 +1,4 @@
"""External dependency unit tests for OpenSearchIndexClient.
"""External dependency unit tests for OpenSearchClient.
These tests assume OpenSearch is running and test all implemented methods
using real schemas, pipelines, and search queries from the codebase.
@@ -19,7 +19,7 @@ from onyx.access.utils import prefix_user_email
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import IndexFilters
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.client import OpenSearchIndexClient
from onyx.document_index.opensearch.client import OpenSearchClient
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.opensearch_document_index import (
@@ -125,10 +125,10 @@ def opensearch_available() -> None:
@pytest.fixture(scope="function")
def test_client(
opensearch_available: None, # noqa: ARG001
) -> Generator[OpenSearchIndexClient, None, None]:
) -> Generator[OpenSearchClient, None, None]:
"""Creates an OpenSearch client for testing with automatic cleanup."""
test_index_name = f"test_index_{uuid.uuid4().hex[:8]}"
client = OpenSearchIndexClient(index_name=test_index_name)
client = OpenSearchClient(index_name=test_index_name)
yield client # Test runs here.
@@ -142,7 +142,7 @@ def test_client(
@pytest.fixture(scope="function")
def search_pipeline(test_client: OpenSearchIndexClient) -> Generator[None, None, None]:
def search_pipeline(test_client: OpenSearchClient) -> Generator[None, None, None]:
"""Creates a search pipeline for testing with automatic cleanup."""
test_client.create_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
@@ -158,9 +158,9 @@ def search_pipeline(test_client: OpenSearchIndexClient) -> Generator[None, None,
class TestOpenSearchClient:
"""Tests for OpenSearchIndexClient."""
"""Tests for OpenSearchClient."""
def test_create_index(self, test_client: OpenSearchIndexClient) -> None:
def test_create_index(self, test_client: OpenSearchClient) -> None:
"""Tests creating an index with a real schema."""
# Precondition.
mappings = DocumentSchema.get_document_schema(
@@ -176,7 +176,7 @@ class TestOpenSearchClient:
# Verify index exists.
assert test_client.validate_index(expected_mappings=mappings) is True
def test_delete_existing_index(self, test_client: OpenSearchIndexClient) -> None:
def test_delete_existing_index(self, test_client: OpenSearchClient) -> None:
"""Tests deleting an existing index returns True."""
# Precondition.
mappings = DocumentSchema.get_document_schema(
@@ -193,7 +193,7 @@ class TestOpenSearchClient:
assert result is True
assert test_client.validate_index(expected_mappings=mappings) is False
def test_delete_nonexistent_index(self, test_client: OpenSearchIndexClient) -> None:
def test_delete_nonexistent_index(self, test_client: OpenSearchClient) -> None:
"""Tests deleting a nonexistent index returns False."""
# Under test.
# Don't create index, just try to delete.
@@ -202,7 +202,7 @@ class TestOpenSearchClient:
# Postcondition.
assert result is False
def test_index_exists(self, test_client: OpenSearchIndexClient) -> None:
def test_index_exists(self, test_client: OpenSearchClient) -> None:
"""Tests checking if an index exists."""
# Precondition.
# Index should not exist before creation.
@@ -219,7 +219,7 @@ class TestOpenSearchClient:
# Index should exist after creation.
assert test_client.index_exists() is True
def test_validate_index(self, test_client: OpenSearchIndexClient) -> None:
def test_validate_index(self, test_client: OpenSearchClient) -> None:
"""Tests validating an index."""
# Precondition.
mappings = DocumentSchema.get_document_schema(
@@ -239,120 +239,7 @@ class TestOpenSearchClient:
# Should return True after creation.
assert test_client.validate_index(expected_mappings=mappings) is True
def test_put_mapping_idempotent(self, test_client: OpenSearchIndexClient) -> None:
"""Tests put_mapping with same schema is idempotent."""
# Precondition.
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
# Under test.
# Applying the same mappings again should succeed.
test_client.put_mapping(mappings)
# Postcondition.
# Index should still be valid.
assert test_client.validate_index(expected_mappings=mappings)
def test_put_mapping_adds_new_field(
self, test_client: OpenSearchIndexClient
) -> None:
"""Tests put_mapping successfully adds new fields to existing index."""
# Precondition.
# Create index with minimal schema (just required fields).
initial_mappings = {
"dynamic": "strict",
"properties": {
"document_id": {"type": "keyword"},
"chunk_index": {"type": "integer"},
"content": {"type": "text"},
"content_vector": {
"type": "knn_vector",
"dimension": 128,
"method": {
"name": "hnsw",
"space_type": "cosinesimil",
"engine": "lucene",
"parameters": {"ef_construction": 512, "m": 16},
},
},
},
}
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=initial_mappings, settings=settings)
# Under test.
# Add a new field using put_mapping.
updated_mappings = {
"properties": {
"document_id": {"type": "keyword"},
"chunk_index": {"type": "integer"},
"content": {"type": "text"},
"content_vector": {
"type": "knn_vector",
"dimension": 128,
"method": {
"name": "hnsw",
"space_type": "cosinesimil",
"engine": "lucene",
"parameters": {"ef_construction": 512, "m": 16},
},
},
# New field
"new_test_field": {"type": "keyword"},
},
}
# Should not raise.
test_client.put_mapping(updated_mappings)
# Postcondition.
# Validate the new schema includes the new field.
assert test_client.validate_index(expected_mappings=updated_mappings)
def test_put_mapping_fails_on_type_change(
self, test_client: OpenSearchIndexClient
) -> None:
"""Tests put_mapping fails when trying to change existing field type."""
# Precondition.
initial_mappings = {
"dynamic": "strict",
"properties": {
"document_id": {"type": "keyword"},
"test_field": {"type": "keyword"},
},
}
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=initial_mappings, settings=settings)
# Under test and postcondition.
# Try to change test_field type from keyword to text.
conflicting_mappings = {
"properties": {
"document_id": {"type": "keyword"},
"test_field": {"type": "text"}, # Changed from keyword to text
},
}
# Should raise because field type cannot be changed.
with pytest.raises(Exception, match="mapper|illegal_argument_exception"):
test_client.put_mapping(conflicting_mappings)
def test_put_mapping_on_nonexistent_index(
self, test_client: OpenSearchIndexClient
) -> None:
"""Tests put_mapping on non-existent index raises an error."""
# Precondition.
# Index does not exist yet.
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
)
# Under test and postcondition.
with pytest.raises(Exception, match="index_not_found_exception|404"):
test_client.put_mapping(mappings)
def test_create_duplicate_index(self, test_client: OpenSearchIndexClient) -> None:
def test_create_duplicate_index(self, test_client: OpenSearchClient) -> None:
"""Tests creating an index twice raises an error."""
# Precondition.
mappings = DocumentSchema.get_document_schema(
@@ -367,14 +254,14 @@ class TestOpenSearchClient:
with pytest.raises(Exception, match="already exists"):
test_client.create_index(mappings=mappings, settings=settings)
def test_update_settings(self, test_client: OpenSearchIndexClient) -> None:
def test_update_settings(self, test_client: OpenSearchClient) -> None:
"""Tests that update_settings raises NotImplementedError."""
# Under test and postcondition.
with pytest.raises(NotImplementedError):
test_client.update_settings(settings={})
def test_create_and_delete_search_pipeline(
self, test_client: OpenSearchIndexClient
self, test_client: OpenSearchClient
) -> None:
"""Tests creating and deleting a search pipeline."""
# Under test and postcondition.
@@ -391,7 +278,7 @@ class TestOpenSearchClient:
)
def test_index_document(
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests indexing a document."""
# Precondition.
@@ -419,7 +306,7 @@ class TestOpenSearchClient:
)
def test_bulk_index_documents(
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests bulk indexing documents."""
# Precondition.
@@ -450,7 +337,7 @@ class TestOpenSearchClient:
)
def test_index_duplicate_document(
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests indexing a duplicate document raises an error."""
# Precondition.
@@ -478,7 +365,7 @@ class TestOpenSearchClient:
test_client.index_document(document=doc, tenant_state=tenant_state)
def test_get_document(
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests getting a document."""
# Precondition.
@@ -514,7 +401,7 @@ class TestOpenSearchClient:
assert retrieved_doc == original_doc
def test_get_nonexistent_document(
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests getting a nonexistent document raises an error."""
# Precondition.
@@ -532,7 +419,7 @@ class TestOpenSearchClient:
)
def test_delete_existing_document(
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests deleting an existing document returns True."""
# Precondition.
@@ -568,7 +455,7 @@ class TestOpenSearchClient:
test_client.get_document(document_chunk_id=doc_chunk_id)
def test_delete_nonexistent_document(
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests deleting a nonexistent document returns False."""
# Precondition.
@@ -589,7 +476,7 @@ class TestOpenSearchClient:
assert result is False
def test_delete_by_query(
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests deleting documents by query."""
# Precondition.
@@ -665,7 +552,7 @@ class TestOpenSearchClient:
assert len(keep_ids) == 1
def test_update_document(
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests updating a document's properties."""
# Precondition.
@@ -714,7 +601,7 @@ class TestOpenSearchClient:
assert updated_doc.public == doc.public
def test_update_nonexistent_document(
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests updating a nonexistent document raises an error."""
# Precondition.
@@ -736,7 +623,7 @@ class TestOpenSearchClient:
def test_hybrid_search_with_pipeline(
self,
test_client: OpenSearchIndexClient,
test_client: OpenSearchClient,
search_pipeline: None, # noqa: ARG002
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -817,7 +704,7 @@ class TestOpenSearchClient:
def test_search_empty_index(
self,
test_client: OpenSearchIndexClient,
test_client: OpenSearchClient,
search_pipeline: None, # noqa: ARG002
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -856,7 +743,7 @@ class TestOpenSearchClient:
def test_hybrid_search_with_pipeline_and_filters(
self,
test_client: OpenSearchIndexClient,
test_client: OpenSearchClient,
search_pipeline: None, # noqa: ARG002
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -976,7 +863,7 @@ class TestOpenSearchClient:
def test_hybrid_search_with_pipeline_and_filters_returns_chunks_with_related_content_first(
self,
test_client: OpenSearchIndexClient,
test_client: OpenSearchClient,
search_pipeline: None, # noqa: ARG002
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -1106,7 +993,7 @@ class TestOpenSearchClient:
previous_score = current_score
def test_delete_by_query_multitenant_isolation(
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""
Tests delete_by_query respects tenant boundaries in multi-tenant mode.
@@ -1200,7 +1087,7 @@ class TestOpenSearchClient:
assert set(remaining_y_ids) == expected_y_ids
def test_delete_by_query_nonexistent_document(
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""
Tests delete_by_query for non-existent document returns 0 deleted.
@@ -1229,7 +1116,7 @@ class TestOpenSearchClient:
assert num_deleted == 0
def test_search_for_document_ids(
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests search_for_document_ids method returns correct chunk IDs."""
# Precondition.
@@ -1294,7 +1181,7 @@ class TestOpenSearchClient:
assert set(chunk_ids) == expected_ids
def test_search_with_no_document_access_can_retrieve_all_documents(
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""
Tests search with no document access can retrieve all documents, even
@@ -1372,7 +1259,7 @@ class TestOpenSearchClient:
def test_time_cutoff_filter(
self,
test_client: OpenSearchIndexClient,
test_client: OpenSearchClient,
search_pipeline: None, # noqa: ARG002
monkeypatch: pytest.MonkeyPatch,
) -> None:
@@ -1465,7 +1352,7 @@ class TestOpenSearchClient:
)
def test_random_search(
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests the random search query works."""
# Precondition.

View File

@@ -37,7 +37,6 @@ from onyx.db.opensearch_migration import build_sanitized_to_original_doc_id_mapp
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.client import OpenSearchClient
from onyx.document_index.opensearch.client import OpenSearchIndexClient
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.schema import DocumentChunk
@@ -75,7 +74,7 @@ CHUNK_COUNT = 5
def _get_document_chunks_from_opensearch(
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
opensearch_client: OpenSearchClient, document_id: str, current_tenant_id: str
) -> list[DocumentChunk]:
opensearch_client.refresh_index()
filters = IndexFilters(access_control_list=None, tenant_id=current_tenant_id)
@@ -96,7 +95,7 @@ def _get_document_chunks_from_opensearch(
def _delete_document_chunks_from_opensearch(
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
opensearch_client: OpenSearchClient, document_id: str, current_tenant_id: str
) -> None:
opensearch_client.refresh_index()
query_body = DocumentQuery.delete_from_document_id_query(
@@ -284,10 +283,10 @@ def vespa_document_index(
def opensearch_client(
db_session: Session,
full_deployment_setup: None, # noqa: ARG001
) -> Generator[OpenSearchIndexClient, None, None]:
) -> Generator[OpenSearchClient, None, None]:
"""Creates an OpenSearch client for the test tenant."""
active = get_active_search_settings(db_session)
yield OpenSearchIndexClient(index_name=active.primary.index_name) # Test runs here.
yield OpenSearchClient(index_name=active.primary.index_name) # Test runs here.
@pytest.fixture(scope="module")
@@ -331,7 +330,7 @@ def patch_get_vespa_chunks_page_size() -> Generator[int, None, None]:
def test_documents(
db_session: Session,
vespa_document_index: VespaDocumentIndex,
opensearch_client: OpenSearchIndexClient,
opensearch_client: OpenSearchClient,
patch_get_vespa_chunks_page_size: int,
) -> Generator[list[Document], None, None]:
"""
@@ -412,7 +411,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
db_session: Session,
test_documents: list[Document],
vespa_document_index: VespaDocumentIndex,
opensearch_client: OpenSearchIndexClient,
opensearch_client: OpenSearchClient,
test_embedding_dimension: int,
clean_migration_tables: None, # noqa: ARG002
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
@@ -481,7 +480,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
db_session: Session,
test_documents: list[Document],
vespa_document_index: VespaDocumentIndex,
opensearch_client: OpenSearchIndexClient,
opensearch_client: OpenSearchClient,
test_embedding_dimension: int,
clean_migration_tables: None, # noqa: ARG002
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
@@ -619,7 +618,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
db_session: Session,
test_documents: list[Document],
vespa_document_index: VespaDocumentIndex,
opensearch_client: OpenSearchIndexClient,
opensearch_client: OpenSearchClient,
test_embedding_dimension: int,
clean_migration_tables: None, # noqa: ARG002
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
@@ -713,7 +712,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
db_session: Session,
test_documents: list[Document],
vespa_document_index: VespaDocumentIndex,
opensearch_client: OpenSearchIndexClient,
opensearch_client: OpenSearchClient,
test_embedding_dimension: int,
clean_migration_tables: None, # noqa: ARG002
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002

View File

@@ -20,7 +20,6 @@ from onyx.auth.oauth_token_manager import OAuthTokenManager
from onyx.db.models import OAuthConfig
from onyx.db.oauth_config import create_oauth_config
from onyx.db.oauth_config import upsert_user_oauth_token
from onyx.utils.sensitive import SensitiveValue
from tests.external_dependency_unit.conftest import create_test_user
@@ -492,19 +491,3 @@ class TestOAuthTokenManagerURLBuilding:
# Should use & instead of ? since URL already has query params
assert "foo=bar&" in url or "?foo=bar" in url
assert "client_id=custom_client_id" in url
class TestUnwrapSensitiveStr:
"""Tests for _unwrap_sensitive_str static method"""
def test_unwrap_sensitive_str(self) -> None:
"""Test that both SensitiveValue and plain str inputs are handled"""
# SensitiveValue input
sensitive = SensitiveValue[str](
encrypted_bytes=b"test_client_id",
decrypt_fn=lambda b: b.decode(),
)
assert OAuthTokenManager._unwrap_sensitive_str(sensitive) == "test_client_id"
# Plain str input
assert OAuthTokenManager._unwrap_sensitive_str("plain_string") == "plain_string"

View File

@@ -72,7 +72,7 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: int) -> dict:
headers=admin_user.headers,
)
response.raise_for_status()
for provider in response.json():
for provider in response.json()["providers"]:
if provider["id"] == provider_id:
return provider
raise ValueError(f"Provider with id {provider_id} not found")

View File

@@ -23,7 +23,7 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None
headers=admin_user.headers,
)
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
return next((p for p in providers if p["id"] == provider_id), None)
@@ -578,6 +578,57 @@ def test_model_visibility_preserved_on_edit(reset: None) -> None: # noqa: ARG00
assert visible_models[0]["name"] == "gpt-4o"
def _get_provider_by_name(providers: list[dict], provider_name: str) -> dict | None:
return next((p for p in providers if p["name"] == provider_name), None)
def _get_providers_admin(
admin_user: DATestUser,
) -> dict | None:
response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
)
assert response.status_code == 200
resp_json = response.json()
return resp_json
def _unpack_data(data: dict) -> tuple[list[dict], dict | None, dict | None]:
providers = data["providers"]
text_default = data.get("default_text")
vision_default = data.get("default_vision")
return providers, text_default, vision_default
def _get_providers_basic(
user: DATestUser,
) -> dict | None:
response = requests.get(
f"{API_SERVER_URL}/llm/provider",
headers=user.headers,
)
assert response.status_code == 200
resp_json = response.json()
return resp_json
def _validate_default_model(
default: dict | None,
provider_id: int | None = None,
model_name: str | None = None,
) -> None:
if default is None:
assert provider_id is None and model_name is None
return
assert default["provider_id"] == provider_id
assert default["model_name"] == model_name
def _get_provider_by_name_admin(
admin_user: DATestUser, provider_name: str
) -> dict | None:
@@ -598,7 +649,7 @@ def _get_provider_by_name_basic(user: DATestUser, provider_name: str) -> dict |
headers=user.headers,
)
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
return next((p for p in providers if p["name"] == provider_name), None)
@@ -782,9 +833,22 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
)
assert create_response.status_code == 200
# Capture initial defaults (setup_postgres may have created a DevEnvPresetOpenAI default)
initial_data = _get_providers_admin(admin_user)
assert initial_data is not None
_, initial_text_default, initial_vision_default = _unpack_data(initial_data)
# Step 2: Verify via admin endpoint that all provider data is correct
admin_provider_data = _get_provider_by_name_admin(admin_user, provider_name)
admin_data = _get_providers_admin(admin_user)
assert admin_data is not None
providers, text_default, vision_default = _unpack_data(admin_data)
# Defaults should be unchanged from initial state (new provider not set as default)
assert text_default == initial_text_default
assert vision_default == initial_vision_default
admin_provider_data = _get_provider_by_name(providers, provider_name)
assert admin_provider_data is not None
_validate_provider_data(
admin_provider_data,
expected_name=provider_name,
@@ -797,7 +861,13 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
)
# Step 3: Verify via basic endpoint (admin user) that all provider data is correct
admin_basic_provider_data = _get_provider_by_name_basic(admin_user, provider_name)
admin_basic_data = _get_providers_basic(admin_user)
assert admin_basic_data is not None
providers, text_default, vision_default = _unpack_data(admin_basic_data)
assert text_default == initial_text_default
assert vision_default == initial_vision_default
admin_basic_provider_data = _get_provider_by_name(providers, provider_name)
assert admin_basic_provider_data is not None
_validate_provider_data(
admin_basic_provider_data,
@@ -810,7 +880,13 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
)
# Step 4: Verify non-admin user sees the same provider data via basic endpoint
basic_user_provider_data = _get_provider_by_name_basic(basic_user, provider_name)
basic_user_data = _get_providers_basic(basic_user)
assert basic_user_data is not None
providers, text_default, vision_default = _unpack_data(basic_user_data)
assert text_default == initial_text_default
assert vision_default == initial_vision_default
basic_user_provider_data = _get_provider_by_name(providers, provider_name)
assert basic_user_provider_data is not None
_validate_provider_data(
basic_user_provider_data,
@@ -846,7 +922,17 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
assert default_provider_response.status_code == 200
# Step 6a: Verify the updated provider data via admin endpoint
admin_provider_data = _get_provider_by_name_admin(admin_user, provider_name)
admin_data = _get_providers_admin(admin_user)
assert admin_data is not None
providers, text_default, vision_default = _unpack_data(admin_data)
_validate_default_model(
text_default,
provider_id=update_response.json()["id"],
model_name=updated_default_model,
)
_validate_default_model(vision_default) # None
admin_provider_data = _get_provider_by_name(providers, provider_name)
assert admin_provider_data is not None
_validate_provider_data(
admin_provider_data,
@@ -860,7 +946,17 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
)
# Step 6b: Verify the updated provider data via basic endpoint (admin user)
admin_basic_provider_data = _get_provider_by_name_basic(admin_user, provider_name)
admin_basic_data = _get_providers_basic(admin_user)
assert admin_basic_data is not None
providers, text_default, vision_default = _unpack_data(admin_basic_data)
_validate_default_model(
text_default,
provider_id=update_response.json()["id"],
model_name=updated_default_model,
)
_validate_default_model(vision_default) # None
admin_basic_provider_data = _get_provider_by_name(providers, provider_name)
assert admin_basic_provider_data is not None
_validate_provider_data(
admin_basic_provider_data,
@@ -873,7 +969,17 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
)
# Step 7: Verify non-admin user sees the updated provider data
basic_user_provider_data = _get_provider_by_name_basic(basic_user, provider_name)
basic_user_data = _get_providers_basic(basic_user)
assert basic_user_data is not None
providers, text_default, vision_default = _unpack_data(basic_user_data)
_validate_default_model(
text_default,
provider_id=update_response.json()["id"],
model_name=updated_default_model,
)
_validate_default_model(vision_default) # None
basic_user_provider_data = _get_provider_by_name(providers, provider_name)
assert basic_user_provider_data is not None
_validate_provider_data(
basic_user_provider_data,
@@ -893,7 +999,7 @@ def _get_all_providers_basic(user: DATestUser) -> list[dict]:
headers=user.headers,
)
assert response.status_code == 200
return response.json()
return response.json()["providers"]
def _get_all_providers_admin(admin_user: DATestUser) -> list[dict]:
@@ -903,7 +1009,7 @@ def _get_all_providers_admin(admin_user: DATestUser) -> list[dict]:
headers=admin_user.headers,
)
assert response.status_code == 200
return response.json()
return response.json()["providers"]
def _set_default_provider(admin_user: DATestUser, provider_id: int) -> None:
@@ -1039,11 +1145,17 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
# Step 3: Both admin and basic_user query and verify they see the same default
# Validate via admin endpoint
admin_providers = _get_all_providers_admin(admin_user)
admin_default = _find_default_provider(admin_providers)
assert admin_default is not None
admin_data = _get_providers_admin(admin_user)
assert admin_data is not None
providers, text_default, vision_default = _unpack_data(admin_data)
_validate_default_model(
text_default, provider_id=provider_1["id"], model_name=shared_model_name
)
_validate_default_model(vision_default) # None
admin_provider_data = _get_provider_by_name(providers, provider_1_name)
assert admin_provider_data is not None
_validate_provider_data(
admin_default,
admin_provider_data,
expected_name=provider_1_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=shared_model_name,
@@ -1054,9 +1166,7 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
)
# Validate provider 2 via admin endpoint (should not be default)
admin_provider_2 = next(
(p for p in admin_providers if p["name"] == provider_2_name), None
)
admin_provider_2 = _get_provider_by_name(providers, provider_2_name)
assert admin_provider_2 is not None
_validate_provider_data(
admin_provider_2,
@@ -1070,11 +1180,17 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
)
# Validate via basic endpoint (basic_user)
basic_providers = _get_all_providers_basic(basic_user)
basic_default = _find_default_provider(basic_providers)
assert basic_default is not None
basic_data = _get_providers_basic(basic_user)
assert basic_data is not None
providers, text_default, vision_default = _unpack_data(basic_data)
_validate_default_model(
text_default, provider_id=provider_1["id"], model_name=shared_model_name
)
_validate_default_model(vision_default) # None
basic_provider_data = _get_provider_by_name(providers, provider_1_name)
assert basic_provider_data is not None
_validate_provider_data(
basic_default,
basic_provider_data,
expected_name=provider_1_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=shared_model_name,
@@ -1084,11 +1200,17 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
)
# Also verify admin sees the same via basic endpoint
admin_basic_providers = _get_all_providers_basic(admin_user)
admin_basic_default = _find_default_provider(admin_basic_providers)
assert admin_basic_default is not None
admin_basic_data = _get_providers_basic(admin_user)
assert admin_basic_data is not None
providers, text_default, vision_default = _unpack_data(admin_basic_data)
_validate_default_model(
text_default, provider_id=provider_1["id"], model_name=shared_model_name
)
_validate_default_model(vision_default) # None
admin_basic_provider_data = _get_provider_by_name(providers, provider_1_name)
assert admin_basic_provider_data is not None
_validate_provider_data(
admin_basic_default,
admin_basic_provider_data,
expected_name=provider_1_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=shared_model_name,
@@ -1121,11 +1243,17 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
# Step 5: Both admin and basic_user verify they see the updated default
# Validate via admin endpoint
admin_providers = _get_all_providers_admin(admin_user)
admin_default = _find_default_provider(admin_providers)
assert admin_default is not None
admin_data = _get_providers_admin(admin_user)
assert admin_data is not None
providers, text_default, vision_default = _unpack_data(admin_data)
_validate_default_model(
text_default, provider_id=provider_2["id"], model_name=provider_2_unique_model
)
_validate_default_model(vision_default) # None
admin_provider_data = _get_provider_by_name(providers, provider_2_name)
assert admin_provider_data is not None
_validate_provider_data(
admin_default,
admin_provider_data,
expected_name=provider_2_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=provider_2_unique_model,
@@ -1136,9 +1264,7 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
)
# Validate provider 1 via admin endpoint (should no longer be default)
admin_provider_1 = next(
(p for p in admin_providers if p["name"] == provider_1_name), None
)
admin_provider_1 = _get_provider_by_name(providers, provider_1_name)
assert admin_provider_1 is not None
_validate_provider_data(
admin_provider_1,
@@ -1152,11 +1278,17 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
)
# Validate via basic endpoint (basic_user)
basic_providers = _get_all_providers_basic(basic_user)
basic_default = _find_default_provider(basic_providers)
assert basic_default is not None
basic_data = _get_providers_basic(basic_user)
assert basic_data is not None
providers, text_default, vision_default = _unpack_data(basic_data)
_validate_default_model(
text_default, provider_id=provider_2["id"], model_name=provider_2_unique_model
)
_validate_default_model(vision_default) # None
basic_provider_data = _get_provider_by_name(providers, provider_2_name)
assert basic_provider_data is not None
_validate_provider_data(
basic_default,
basic_provider_data,
expected_name=provider_2_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=provider_2_unique_model,
@@ -1166,11 +1298,17 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
)
# Validate via basic endpoint (admin_user)
admin_basic_providers = _get_all_providers_basic(admin_user)
admin_basic_default = _find_default_provider(admin_basic_providers)
assert admin_basic_default is not None
admin_basic_data = _get_providers_basic(admin_user)
assert admin_basic_data is not None
providers, text_default, vision_default = _unpack_data(admin_basic_data)
_validate_default_model(
text_default, provider_id=provider_2["id"], model_name=provider_2_unique_model
)
_validate_default_model(vision_default) # None
admin_basic_provider_data = _get_provider_by_name(providers, provider_2_name)
assert admin_basic_provider_data is not None
_validate_provider_data(
admin_basic_default,
admin_basic_provider_data,
expected_name=provider_2_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=provider_2_unique_model,
@@ -1200,11 +1338,17 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
# Step 7: Both users verify they see provider 2 as default with the shared model name
# Validate via admin endpoint
admin_providers = _get_all_providers_admin(admin_user)
admin_default = _find_default_provider(admin_providers)
assert admin_default is not None
admin_data = _get_providers_admin(admin_user)
assert admin_data is not None
providers, text_default, vision_default = _unpack_data(admin_data)
_validate_default_model(
text_default, provider_id=provider_2["id"], model_name=shared_model_name
)
_validate_default_model(vision_default) # None
admin_provider_data = _get_provider_by_name(providers, provider_2_name)
assert admin_provider_data is not None
_validate_provider_data(
admin_default,
admin_provider_data,
expected_name=provider_2_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=shared_model_name,
@@ -1215,11 +1359,17 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
)
# Validate via basic endpoint (basic_user)
basic_providers = _get_all_providers_basic(basic_user)
basic_default = _find_default_provider(basic_providers)
assert basic_default is not None
basic_data = _get_providers_basic(basic_user)
assert basic_data is not None
providers, text_default, vision_default = _unpack_data(basic_data)
_validate_default_model(
text_default, provider_id=provider_2["id"], model_name=shared_model_name
)
_validate_default_model(vision_default) # None
basic_provider_data = _get_provider_by_name(providers, provider_2_name)
assert basic_provider_data is not None
_validate_provider_data(
basic_default,
basic_provider_data,
expected_name=provider_2_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=shared_model_name,
@@ -1229,11 +1379,17 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
)
# Validate via basic endpoint (admin_user)
admin_basic_providers = _get_all_providers_basic(admin_user)
admin_basic_default = _find_default_provider(admin_basic_providers)
assert admin_basic_default is not None
admin_basic_data = _get_providers_basic(admin_user)
assert admin_basic_data is not None
providers, text_default, vision_default = _unpack_data(admin_basic_data)
_validate_default_model(
text_default, provider_id=provider_2["id"], model_name=shared_model_name
)
_validate_default_model(vision_default) # None
admin_basic_provider_data = _get_provider_by_name(providers, provider_2_name)
assert admin_basic_provider_data is not None
_validate_provider_data(
admin_basic_default,
admin_basic_provider_data,
expected_name=provider_2_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=shared_model_name,
@@ -1243,12 +1399,10 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
)
# Verify provider 1 is no longer the default and has correct data
provider_1_admin = next(
(p for p in admin_providers if p["name"] == provider_1_name), None
)
assert provider_1_admin is not None
admin_provider_1 = _get_provider_by_name(providers, provider_1_name)
assert admin_provider_1 is not None
_validate_provider_data(
provider_1_admin,
admin_provider_1,
expected_name=provider_1_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=shared_model_name,
@@ -1258,12 +1412,10 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
expected_is_public=True,
)
provider_1_basic = next(
(p for p in basic_providers if p["name"] == provider_1_name), None
)
assert provider_1_basic is not None
basic_provider_1 = _get_provider_by_name(providers, provider_1_name)
assert basic_provider_1 is not None
_validate_provider_data(
provider_1_basic,
basic_provider_1,
expected_name=provider_1_name,
expected_provider=LlmProviderNames.OPENAI,
expected_default_model=shared_model_name,
@@ -1391,10 +1543,22 @@ def test_default_provider_and_vision_provider_selection(
)
# Step 5: Verify via admin endpoint
admin_providers = _get_all_providers_admin(admin_user)
admin_data = _get_providers_admin(admin_user)
assert admin_data is not None
# Find and validate the default provider (provider 1)
admin_default = _find_default_provider(admin_providers)
providers, text_default, vision_default = _unpack_data(admin_data)
_validate_default_model(
text_default,
provider_id=provider_1["id"],
model_name=provider_1_non_vision_model,
)
_validate_default_model(
vision_default,
provider_id=provider_2["id"],
model_name=provider_2_vision_model_1,
)
admin_default = _get_provider_by_name(providers, provider_1_name)
assert admin_default is not None
_validate_provider_data(
admin_default,
@@ -1409,7 +1573,7 @@ def test_default_provider_and_vision_provider_selection(
)
# Find and validate the default vision provider (provider 2)
admin_vision_default = _find_default_vision_provider(admin_providers)
admin_vision_default = _get_provider_by_name(providers, provider_2_name)
assert admin_vision_default is not None
_validate_provider_data(
admin_vision_default,
@@ -1425,10 +1589,21 @@ def test_default_provider_and_vision_provider_selection(
)
# Step 6: Verify via basic endpoint (basic_user)
basic_providers = _get_all_providers_basic(basic_user)
# Find and validate the default provider (provider 1)
basic_default = _find_default_provider(basic_providers)
basic_data = _get_providers_basic(basic_user)
assert basic_data is not None
providers, text_default, vision_default = _unpack_data(basic_data)
_validate_default_model(
text_default,
provider_id=provider_1["id"],
model_name=provider_1_non_vision_model,
)
_validate_default_model(
vision_default,
provider_id=provider_2["id"],
model_name=provider_2_vision_model_1,
)
basic_default = _get_provider_by_name(providers, provider_1_name)
assert basic_default is not None
_validate_provider_data(
basic_default,
@@ -1442,7 +1617,7 @@ def test_default_provider_and_vision_provider_selection(
)
# Find and validate the default vision provider (provider 2)
basic_vision_default = _find_default_vision_provider(basic_providers)
basic_vision_default = _get_provider_by_name(providers, provider_2_name)
assert basic_vision_default is not None
_validate_provider_data(
basic_vision_default,
@@ -1457,9 +1632,20 @@ def test_default_provider_and_vision_provider_selection(
)
# Step 7: Verify via basic endpoint (admin_user sees same as basic_user)
admin_basic_providers = _get_all_providers_basic(admin_user)
admin_basic_default = _find_default_provider(admin_basic_providers)
admin_basic_data = _get_providers_basic(admin_user)
assert admin_basic_data is not None
providers, text_default, vision_default = _unpack_data(admin_basic_data)
_validate_default_model(
text_default,
provider_id=provider_1["id"],
model_name=provider_1_non_vision_model,
)
_validate_default_model(
vision_default,
provider_id=provider_2["id"],
model_name=provider_2_vision_model_1,
)
admin_basic_default = _get_provider_by_name(providers, provider_1_name)
assert admin_basic_default is not None
_validate_provider_data(
admin_basic_default,
@@ -1472,7 +1658,7 @@ def test_default_provider_and_vision_provider_selection(
expected_is_default_vision=None,
)
admin_basic_vision_default = _find_default_vision_provider(admin_basic_providers)
admin_basic_vision_default = _get_provider_by_name(providers, provider_2_name)
assert admin_basic_vision_default is not None
_validate_provider_data(
admin_basic_vision_default,
@@ -1549,7 +1735,14 @@ def test_default_provider_is_not_default_vision_provider(
_set_default_provider(admin_user, created_provider["id"])
# Step 3 & 4: Verify via admin endpoint
admin_provider_data = _get_provider_by_name_admin(admin_user, provider_name)
admin_data = _get_providers_admin(admin_user)
assert admin_data is not None
providers, text_default, vision_default = _unpack_data(admin_data)
_validate_default_model(
text_default, provider_id=created_provider["id"], model_name="gpt-4"
)
_validate_default_model(vision_default) # None
admin_provider_data = _get_provider_by_name(providers, provider_name)
assert admin_provider_data is not None
# Verify it IS the default provider
@@ -1584,7 +1777,14 @@ def test_default_provider_is_not_default_vision_provider(
)
# Also verify via basic endpoint
basic_provider_data = _get_provider_by_name_basic(admin_user, provider_name)
basic_data = _get_providers_basic(admin_user)
assert basic_data is not None
providers, text_default, vision_default = _unpack_data(basic_data)
_validate_default_model(
text_default, provider_id=created_provider["id"], model_name="gpt-4"
)
_validate_default_model(vision_default) # None
basic_provider_data = _get_provider_by_name(providers, provider_name)
assert basic_provider_data is not None
assert (
@@ -1769,37 +1969,52 @@ def test_all_three_provider_types_no_mixup(reset: None) -> None: # noqa: ARG001
# Step 4: Verify all three types are correctly tracked
# Get all LLM providers (via admin endpoint)
admin_providers = _get_all_providers_admin(admin_user)
admin_data = _get_providers_admin(admin_user)
assert admin_data is not None
providers, text_default, vision_default = _unpack_data(admin_data)
_validate_default_model(
text_default, provider_id=regular_provider["id"], model_name="gpt-4"
)
_validate_default_model(
vision_default,
provider_id=vision_provider["id"],
model_name="gpt-4-vision-preview",
)
_validate_default_model(
vision_default, vision_provider["id"], "gpt-4-vision-preview"
)
_get_provider_by_name(providers, regular_provider_name)
# Get all image generation configs
image_gen_configs = _get_all_image_gen_configs(admin_user)
# Verify the regular provider is the default provider
regular_provider_data = next(
(p for p in admin_providers if p["name"] == regular_provider_name), None
admin_regular_provider_data = _get_provider_by_name(
providers, regular_provider_name
)
assert regular_provider_data is not None, "Regular provider not found"
assert (
regular_provider_data["is_default_provider"] is True
), "Regular provider should be the default provider"
assert (
regular_provider_data.get("is_default_vision_provider") is not True
), "Regular provider should NOT be the default vision provider"
# Verify the vision provider is the default vision provider
vision_provider_data = next(
(p for p in admin_providers if p["name"] == vision_provider_name), None
assert admin_regular_provider_data is not None
_validate_provider_data(
admin_regular_provider_data,
expected_name=regular_provider_name,
expected_provider=LlmProviderNames.OPENAI,
expected_model_names=[c.name for c in regular_model_configs],
expected_visible={c.name: True for c in regular_model_configs},
expected_default_model="gpt-4",
expected_is_default=True,
)
admin_vision_provider_data = _get_provider_by_name(providers, vision_provider_name)
assert admin_vision_provider_data is not None
_validate_provider_data(
admin_vision_provider_data,
expected_name=vision_provider_name,
expected_provider=LlmProviderNames.OPENAI,
expected_model_names=[c.name for c in vision_model_configs],
expected_visible={c.name: True for c in vision_model_configs},
expected_default_model="gpt-4-vision-preview",
expected_is_default=False,
expected_default_vision_model="gpt-4-vision-preview",
expected_is_default_vision=True,
)
assert vision_provider_data is not None, "Vision provider not found"
assert (
vision_provider_data.get("is_default_provider") is not True
), "Vision provider should NOT be the default provider"
assert (
vision_provider_data["is_default_vision_provider"] is True
), "Vision provider should be the default vision provider"
assert (
vision_provider_data["default_vision_model"] == "gpt-4-vision-preview"
), "Vision provider should have correct default vision model"
# Verify the image gen config is the default image generation config
image_gen_config_data = next(
@@ -1819,97 +2034,53 @@ def test_all_three_provider_types_no_mixup(reset: None) -> None: # noqa: ARG001
), "Image gen config should have correct model name"
# Step 5: Verify no mixup - image gen providers don't appear in LLM provider lists
# The image gen config creates an LLM provider with name "Image Gen - {image_provider_id}"
# This should NOT be returned by the regular LLM provider endpoints
[p["name"] for p in admin_providers]
image_gen_llm_provider_name = f"Image Gen - {image_gen_provider_id}"
# Note: The image gen provider IS an LLM provider internally, so it may appear in the list
# But it should NOT be marked as default provider or default vision provider
image_gen_llm_provider = next(
(p for p in admin_providers if p["name"] == image_gen_llm_provider_name), None
)
if image_gen_llm_provider:
# If it appears, verify it's not marked as default for either type
assert (
image_gen_llm_provider.get("is_default_provider") is not True
), "Image gen's internal LLM provider should NOT be the default provider"
assert (
image_gen_llm_provider.get("is_default_vision_provider") is not True
), "Image gen's internal LLM provider should NOT be the default vision provider"
# Image gen provider should not appear in the list
assert image_gen_provider_id not in [p["name"] for p in providers]
# Step 6: Verify via basic endpoint (non-admin user)
basic_providers = _get_all_providers_basic(basic_user)
# Verify regular provider is default for basic user
basic_regular = next(
(p for p in basic_providers if p["name"] == regular_provider_name), None
basic_data = _get_providers_basic(basic_user)
assert basic_data is not None
providers, text_default, vision_default = _unpack_data(basic_data)
_validate_default_model(
text_default, provider_id=regular_provider["id"], model_name="gpt-4"
)
assert basic_regular is not None, "Regular provider not visible to basic user"
assert (
basic_regular["is_default_provider"] is True
), "Regular provider should be default for basic user"
# Verify vision provider is default vision for basic user
basic_vision = next(
(p for p in basic_providers if p["name"] == vision_provider_name), None
_validate_default_model(
vision_default,
provider_id=vision_provider["id"],
model_name="gpt-4-vision-preview",
)
_validate_default_model(
vision_default, vision_provider["id"], "gpt-4-vision-preview"
)
basic_provider_data = _get_provider_by_name(providers, regular_provider_name)
assert basic_provider_data is not None
_validate_provider_data(
basic_provider_data,
expected_name=regular_provider_name,
expected_provider=LlmProviderNames.OPENAI,
expected_model_names=[c.name for c in regular_model_configs],
expected_visible={c.name: True for c in regular_model_configs},
expected_is_default=True,
expected_default_model="gpt-4",
)
basic_vision_provider_data = _get_provider_by_name(providers, vision_provider_name)
assert basic_vision_provider_data is not None
_validate_provider_data(
basic_vision_provider_data,
expected_name=vision_provider_name,
expected_provider=LlmProviderNames.OPENAI,
expected_model_names=[c.name for c in vision_model_configs],
expected_visible={c.name: True for c in vision_model_configs},
expected_is_default=False,
expected_default_model="gpt-4-vision-preview",
expected_default_vision_model="gpt-4-vision-preview",
expected_is_default_vision=True,
)
assert basic_vision is not None, "Vision provider not visible to basic user"
assert (
basic_vision["is_default_vision_provider"] is True
), "Vision provider should be default vision for basic user"
# Step 7: Verify the counts are as expected
# We should have at least 2 user-created providers plus the image gen internal provider
user_created_providers = [
p
for p in admin_providers
if p["name"] in [regular_provider_name, vision_provider_name]
]
assert (
len(user_created_providers) == 2
), f"Expected 2 user-created providers, got {len(user_created_providers)}"
# We should have exactly 1 image gen config
assert (
len(
[
c
for c in image_gen_configs
if c["image_provider_id"] == image_gen_provider_id
]
)
== 1
), "Expected exactly 1 image gen config with our ID"
# Verify that our explicitly created providers are tracked correctly:
# - Only ONE provider has is_default_provider=True
default_providers = [
p for p in admin_providers if p.get("is_default_provider") is True
]
assert (
len(default_providers) == 1
), f"Expected exactly 1 default provider, got {len(default_providers)}"
assert default_providers[0]["name"] == regular_provider_name
# - Only ONE provider has is_default_vision_provider=True
default_vision_providers = [
p for p in admin_providers if p.get("is_default_vision_provider") is True
]
assert (
len(default_vision_providers) == 1
), f"Expected exactly 1 default vision provider, got {len(default_vision_providers)}"
assert default_vision_providers[0]["name"] == vision_provider_name
# - Only ONE image gen config has is_default=True
default_image_gen_configs = [
c for c in image_gen_configs if c.get("is_default") is True
]
assert (
len(default_image_gen_configs) == 1
), f"Expected exactly 1 default image gen config, got {len(default_image_gen_configs)}"
assert default_image_gen_configs[0]["image_provider_id"] == image_gen_provider_id
# We should have at least 2 user-created providers (setup_postgres may add more)
assert len(providers) >= 2
assert len(image_gen_configs) == 1
# Clean up: Delete the image gen config (to clean up the internal LLM provider)
_delete_image_gen_config(admin_user, image_gen_provider_id)

View File

@@ -272,6 +272,19 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
# Set up ModelConfiguration + LLMModelFlow so get_default_llm() can
# resolve the default provider when the fallback path is triggered.
# First, clear any existing CHAT defaults (setup_postgres may have created one)
existing_defaults = (
db_session.query(LLMModelFlow)
.filter(
LLMModelFlow.llm_model_flow_type == LLMModelFlowType.CHAT,
LLMModelFlow.is_default == True, # noqa: E712
)
.all()
)
for existing in existing_defaults:
existing.is_default = False
db_session.flush()
default_model_config = ModelConfiguration(
llm_provider_id=default_provider.id,
name=default_provider.default_model_name,
@@ -365,7 +378,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
headers=basic_user.headers,
)
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
provider_names = [p["name"] for p in providers]
# Public provider should be visible
@@ -380,7 +393,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
headers=admin_user.headers,
)
assert admin_response.status_code == 200
admin_providers = admin_response.json()
admin_providers = admin_response.json()["providers"]
admin_provider_names = [p["name"] for p in admin_providers]
assert public_provider.name in admin_provider_names

View File

@@ -45,25 +45,15 @@ def _env_true(env_var: str, default: bool = False) -> bool:
return value.strip().lower() in {"1", "true", "yes", "on"}
def _parse_models_env(env_var: str) -> list[str]:
raw_value = os.environ.get(env_var, "").strip()
if not raw_value:
return []
try:
parsed_json = json.loads(raw_value)
except json.JSONDecodeError:
parsed_json = None
if isinstance(parsed_json, list):
return [str(model).strip() for model in parsed_json if str(model).strip()]
return [part.strip() for part in raw_value.split(",") if part.strip()]
def _split_csv_env(env_var: str) -> list[str]:
return [
part.strip() for part in os.environ.get(env_var, "").split(",") if part.strip()
]
def _load_provider_config() -> NightlyProviderConfig:
provider = os.environ.get(_ENV_PROVIDER, "").strip().lower()
model_names = _parse_models_env(_ENV_MODELS)
model_names = _split_csv_env(_ENV_MODELS)
api_key = os.environ.get(_ENV_API_KEY) or None
api_base = os.environ.get(_ENV_API_BASE) or None
strict = _env_true(_ENV_STRICT, default=False)
@@ -323,21 +313,10 @@ def test_nightly_provider_chat_workflow(admin_user: DATestUser) -> None:
_seed_connector_for_search_tool(admin_user)
search_tool_id = _get_internal_search_tool_id(admin_user)
failures: list[str] = []
for model_name in config.model_names:
try:
_create_and_test_provider_for_model(
admin_user=admin_user,
config=config,
model_name=model_name,
search_tool_id=search_tool_id,
)
except BaseException as exc:
if isinstance(exc, (KeyboardInterrupt, SystemExit)):
raise
failures.append(
f"provider={config.provider} model={model_name} error={type(exc).__name__}: {exc}"
)
if failures:
pytest.fail("Nightly provider chat failures:\n" + "\n".join(failures))
_create_and_test_provider_for_model(
admin_user=admin_user,
config=config,
model_name=model_name,
search_tool_id=search_tool_id,
)

View File

@@ -534,10 +534,9 @@ services:
required: false
# Below is needed for the `docker-out-of-docker` execution mode
# For Linux rootless Docker, set DOCKER_SOCK_PATH=${XDG_RUNTIME_DIR}/docker.sock
user: root
volumes:
- ${DOCKER_SOCK_PATH:-/var/run/docker.sock}:/var/run/docker.sock
- /var/run/docker.sock:/var/run/docker.sock
# uncomment below + comment out the above to use the `docker-in-docker` execution mode
# privileged: true

View File

@@ -1,233 +0,0 @@
import "@opal/core/hoverable/styles.css";
import React, { createContext, useContext, useState, useCallback } from "react";
import { cn } from "@opal/utils";
import type { WithoutStyles } from "@opal/types";
// ---------------------------------------------------------------------------
// Context-per-group registry
// ---------------------------------------------------------------------------
/**
* Lazily-created map of group names to React contexts.
*
* Each group gets its own `React.Context<boolean | null>` so that a
* `Hoverable.Item` only re-renders when its *own* group's hover state
* changes — not when any unrelated group changes.
*
* The default value is `null` (no provider found), which lets
* `Hoverable.Item` distinguish "no Root ancestor" from "Root says
* not hovered" and throw when `group` was explicitly specified.
*/
const contextMap = new Map<string, React.Context<boolean | null>>();
function getOrCreateContext(group: string): React.Context<boolean | null> {
let ctx = contextMap.get(group);
if (!ctx) {
ctx = createContext<boolean | null>(null);
ctx.displayName = `HoverableContext(${group})`;
contextMap.set(group, ctx);
}
return ctx;
}
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
interface HoverableRootProps
extends WithoutStyles<React.HTMLAttributes<HTMLDivElement>> {
children: React.ReactNode;
group: string;
}
type HoverableItemVariant = "opacity-on-hover";
interface HoverableItemProps
extends WithoutStyles<React.HTMLAttributes<HTMLDivElement>> {
children: React.ReactNode;
group?: string;
variant?: HoverableItemVariant;
}
// ---------------------------------------------------------------------------
// HoverableRoot
// ---------------------------------------------------------------------------
/**
* Hover-tracking container for a named group.
*
* Wraps children in a `<div>` that tracks mouse-enter / mouse-leave and
* provides the hover state via a per-group React context.
*
* Nesting works because each `Hoverable.Root` creates a **new** context
* provider that shadows the parent — so an inner `Hoverable.Item group="b"`
* reads from the inner provider, not the outer `group="a"` provider.
*
* @example
* ```tsx
* <Hoverable.Root group="card">
* <Card>
* <Hoverable.Item group="card" variant="opacity-on-hover">
* <TrashIcon />
* </Hoverable.Item>
* </Card>
* </Hoverable.Root>
* ```
*/
function HoverableRoot({
group,
children,
onMouseEnter: consumerMouseEnter,
onMouseLeave: consumerMouseLeave,
...props
}: HoverableRootProps) {
const [hovered, setHovered] = useState(false);
const onMouseEnter = useCallback(
(e: React.MouseEvent<HTMLDivElement>) => {
setHovered(true);
consumerMouseEnter?.(e);
},
[consumerMouseEnter]
);
const onMouseLeave = useCallback(
(e: React.MouseEvent<HTMLDivElement>) => {
setHovered(false);
consumerMouseLeave?.(e);
},
[consumerMouseLeave]
);
const GroupContext = getOrCreateContext(group);
return (
<GroupContext.Provider value={hovered}>
<div {...props} onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
{children}
</div>
</GroupContext.Provider>
);
}
// ---------------------------------------------------------------------------
// HoverableItem
// ---------------------------------------------------------------------------
/**
* An element whose visibility is controlled by hover state.
*
* **Local mode** (`group` omitted): the item handles hover on its own
* element via CSS `:hover`. This is the core abstraction.
*
* **Group mode** (`group` provided): visibility is driven by a matching
* `Hoverable.Root` ancestor's hover state via React context. If no
* matching Root is found, an error is thrown.
*
* Uses data-attributes for variant styling (see `styles.css`).
*
* @example
* ```tsx
* // Local mode — hover on the item itself
* <Hoverable.Item variant="opacity-on-hover">
* <TrashIcon />
* </Hoverable.Item>
*
* // Group mode — hover on the Root reveals the item
* <Hoverable.Root group="card">
* <Hoverable.Item group="card" variant="opacity-on-hover">
* <TrashIcon />
* </Hoverable.Item>
* </Hoverable.Root>
* ```
*
* @throws If `group` is specified but no matching `Hoverable.Root` ancestor exists.
*/
function HoverableItem({
group,
variant = "opacity-on-hover",
children,
...props
}: HoverableItemProps) {
const contextValue = useContext(
group ? getOrCreateContext(group) : NOOP_CONTEXT
);
if (group && contextValue === null) {
throw new Error(
`Hoverable.Item group="${group}" has no matching Hoverable.Root ancestor. ` +
`Either wrap it in <Hoverable.Root group="${group}"> or remove the group prop for local hover.`
);
}
const isLocal = group === undefined;
return (
<div
{...props}
className={cn("hoverable-item")}
data-hoverable-variant={variant}
data-hoverable-active={
isLocal ? undefined : contextValue ? "true" : undefined
}
data-hoverable-local={isLocal ? "true" : undefined}
>
{children}
</div>
);
}
/** Stable context used when no group is specified (local mode). */
const NOOP_CONTEXT = createContext<boolean | null>(null);
// ---------------------------------------------------------------------------
// Compound export
// ---------------------------------------------------------------------------
/**
* Hoverable compound component for hover-to-reveal patterns.
*
* Provides two sub-components:
*
* - `Hoverable.Root` — A container that tracks hover state for a named group
* and provides it via React context.
*
* - `Hoverable.Item` — The core abstraction. On its own (no `group`), it
* applies local CSS `:hover` for the variant effect. When `group` is
* specified, it reads hover state from the nearest matching
* `Hoverable.Root` — and throws if no matching Root is found.
*
* Supports nesting: a child `Hoverable.Root` shadows the parent's context,
* so each group's items only respond to their own root's hover.
*
* @example
* ```tsx
* import { Hoverable } from "@opal/core";
*
* // Group mode — hovering the card reveals the trash icon
* <Hoverable.Root group="card">
* <Card>
* <span>Card content</span>
* <Hoverable.Item group="card" variant="opacity-on-hover">
* <TrashIcon />
* </Hoverable.Item>
* </Card>
* </Hoverable.Root>
*
* // Local mode — hovering the item itself reveals it
* <Hoverable.Item variant="opacity-on-hover">
* <TrashIcon />
* </Hoverable.Item>
* ```
*/
const Hoverable = {
Root: HoverableRoot,
Item: HoverableItem,
};
export {
Hoverable,
type HoverableRootProps,
type HoverableItemProps,
type HoverableItemVariant,
};

View File

@@ -1,18 +0,0 @@
/* Hoverable — item transitions */
.hoverable-item {
transition: opacity 200ms ease-in-out;
}
.hoverable-item[data-hoverable-variant="opacity-on-hover"] {
opacity: 0;
}
/* Group mode — Root controls visibility via React context */
.hoverable-item[data-hoverable-variant="opacity-on-hover"][data-hoverable-active="true"] {
opacity: 1;
}
/* Local mode — item handles its own :hover */
.hoverable-item[data-hoverable-variant="opacity-on-hover"][data-hoverable-local="true"]:hover {
opacity: 1;
}

View File

@@ -1,11 +1,3 @@
/* Hoverable */
export {
Hoverable,
type HoverableRootProps,
type HoverableItemProps,
type HoverableItemVariant,
} from "@opal/core/hoverable/components";
/* Interactive */
export {
Interactive,

View File

@@ -12,7 +12,7 @@ const SvgOrganization = ({ size, ...props }: IconProps) => (
>
<path
d="M7.5 14H13.5C14.0523 14 14.5 13.5523 14.5 13V6C14.5 5.44772 14.0523 5 13.5 5H7.5M7.5 14V11M7.5 14H4.5M7.5 5V3C7.5 2.44772 7.05228 2 6.5 2H4.5M7.5 5H1.5M7.5 5V8M1.5 5V3C1.5 2.44772 1.94772 2 2.5 2H4.5M1.5 5V8M7.5 8V11M7.5 8H4.5M1.5 8V11M1.5 8H4.5M7.5 11H4.5M1.5 11V13C1.5 13.5523 1.94772 14 2.5 14H4.5M1.5 11H4.5M4.5 2V8M4.5 14V11M4.5 11V8M10 8H12M10 11H12"
strokeWidth={1.5}
strokeWidth={1}
strokeLinecap="round"
strokeLinejoin="round"
/>

View File

@@ -17,7 +17,6 @@ import {
SvgPlus,
SvgWallet,
SvgFileText,
SvgOrganization,
} from "@opal/icons";
import { BillingInformation, LicenseStatus } from "@/lib/billing/interfaces";
import {
@@ -144,20 +143,17 @@ function SubscriptionCard({
license,
onViewPlans,
disabled,
isManualLicenseOnly,
onReconnect,
}: {
billing?: BillingInformation;
license?: LicenseStatus;
onViewPlans: () => void;
disabled?: boolean;
isManualLicenseOnly?: boolean;
onReconnect?: () => Promise<void>;
}) {
const [isReconnecting, setIsReconnecting] = useState(false);
const planName = isManualLicenseOnly ? "Enterprise Plan" : "Business Plan";
const PlanIcon = isManualLicenseOnly ? SvgOrganization : SvgUsers;
const planName = "Business Plan";
const expirationDate = billing?.current_period_end ?? license?.expires_at;
const formattedDate = formatDateShort(expirationDate);
@@ -215,7 +211,7 @@ function SubscriptionCard({
height="auto"
>
<Section gap={0.25} alignItems="start" height="auto" width="auto">
<PlanIcon className="w-5 h-5" />
<SvgUsers className="w-5 h-5 stroke-text-03" />
<Text headingH3Muted text04>
{planName}
</Text>
@@ -230,19 +226,7 @@ function SubscriptionCard({
height="auto"
width="fit"
>
{isManualLicenseOnly ? (
<Text secondaryBody text03 className="text-right">
Your plan is managed through sales.
<br />
<a
href="mailto:support@onyx.app?subject=Billing%20change%20request"
className="underline"
>
Contact billing
</a>{" "}
to make changes.
</Text>
) : disabled ? (
{disabled ? (
<Button
main
secondary
@@ -282,13 +266,11 @@ function SeatsCard({
license,
onRefresh,
disabled,
hideUpdateSeats,
}: {
billing?: BillingInformation;
license?: LicenseStatus;
onRefresh?: () => Promise<void>;
disabled?: boolean;
hideUpdateSeats?: boolean;
}) {
const [isEditing, setIsEditing] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
@@ -502,17 +484,15 @@ function SeatsCard({
<Button main tertiary href="/admin/users" leftIcon={SvgExternalLink}>
View Users
</Button>
{!hideUpdateSeats && (
<Button
main
secondary
onClick={handleStartEdit}
leftIcon={SvgPlus}
disabled={isLoadingUsers || disabled || !billing}
>
Update Seats
</Button>
)}
<Button
main
secondary
onClick={handleStartEdit}
leftIcon={SvgPlus}
disabled={isLoadingUsers || disabled || !billing}
>
Update Seats
</Button>
</Section>
</Section>
</Card>
@@ -613,9 +593,7 @@ interface BillingDetailsViewProps {
onViewPlans: () => void;
onRefresh?: () => Promise<void>;
isAirGapped?: boolean;
isManualLicenseOnly?: boolean;
hasStripeError?: boolean;
licenseCard?: React.ReactNode;
}
export default function BillingDetailsView({
@@ -624,13 +602,10 @@ export default function BillingDetailsView({
onViewPlans,
onRefresh,
isAirGapped,
isManualLicenseOnly,
hasStripeError,
licenseCard,
}: BillingDetailsViewProps) {
const expirationState = billing ? getExpirationState(billing, license) : null;
const disableBillingActions =
isAirGapped || hasStripeError || isManualLicenseOnly;
const disableBillingActions = isAirGapped || hasStripeError;
return (
<Section gap={1} height="auto" width="full">
@@ -647,7 +622,7 @@ export default function BillingDetailsView({
)}
{/* Air-gapped mode info banner */}
{isAirGapped && !hasStripeError && !isManualLicenseOnly && (
{isAirGapped && !hasStripeError && (
<Message
static
info
@@ -690,21 +665,16 @@ export default function BillingDetailsView({
license={license}
onViewPlans={onViewPlans}
disabled={disableBillingActions}
isManualLicenseOnly={isManualLicenseOnly}
onReconnect={onRefresh}
/>
)}
{/* License card (inline for manual license users) */}
{licenseCard}
{/* Seats card */}
<SeatsCard
billing={billing}
license={license}
onRefresh={onRefresh}
disabled={disableBillingActions}
hideUpdateSeats={isManualLicenseOnly}
/>
{/* Payment section */}

View File

@@ -19,7 +19,6 @@ interface LicenseActivationCardProps {
onClose: () => void;
onSuccess: () => void;
license?: LicenseStatus;
hideClose?: boolean;
}
export default function LicenseActivationCard({
@@ -27,7 +26,6 @@ export default function LicenseActivationCard({
onClose,
onSuccess,
license,
hideClose,
}: LicenseActivationCardProps) {
const [licenseKey, setLicenseKey] = useState("");
const [isActivating, setIsActivating] = useState(false);
@@ -122,11 +120,9 @@ export default function LicenseActivationCard({
<Button main secondary onClick={() => setShowInput(true)}>
Update Key
</Button>
{!hideClose && (
<Button main tertiary onClick={handleClose}>
Close
</Button>
)}
<Button main tertiary onClick={handleClose}>
Close
</Button>
</Section>
</Section>
</Card>

View File

@@ -121,12 +121,11 @@ export default function BillingPage() {
const billing = hasSubscription ? (billingData as BillingInformation) : null;
const isSelfHosted = !NEXT_PUBLIC_CLOUD_ENABLED;
// User is only air-gapped if they have a manual license AND Stripe is not connected
// Once Stripe connects successfully, they're no longer air-gapped
const hasManualLicense = licenseData?.source === "manual_upload";
// Air-gapped: billing endpoint is unreachable (manual license + connectivity error)
const isAirGapped = !!(hasManualLicense && billingError);
// Stripe error: auto-fetched license but billing endpoint is unreachable
const stripeConnected = billingData && !billingError;
const isAirGapped = hasManualLicense && !stripeConnected;
const hasStripeError = !!(
isSelfHosted &&
licenseData?.has_license &&
@@ -134,10 +133,6 @@ export default function BillingPage() {
!hasManualLicense
);
// Manual license without active Stripe subscription
// Stripe-dependent actions (manage plan, update seats) won't work
const isManualLicenseOnly = !!(hasManualLicense && !hasSubscription);
// Set initial view based on subscription status (only once when data first loads)
useEffect(() => {
if (!isLoading && view === null) {
@@ -248,10 +243,7 @@ export default function BillingPage() {
return {
icon: hasSubscription ? SvgWallet : SvgArrowUpCircle,
title: hasSubscription ? "View Plans" : "Upgrade Plan",
showBackButton: !!(
hasSubscription ||
(isSelfHosted && licenseData?.has_license)
),
showBackButton: !!hasSubscription,
};
case "details":
return {
@@ -279,11 +271,9 @@ export default function BillingPage() {
};
const handleBack = () => {
const hasEntitlement =
hasSubscription || (isSelfHosted && licenseData?.has_license);
if (view === "checkout") {
changeView(hasEntitlement ? "details" : "plans");
} else if (view === "plans" && hasEntitlement) {
changeView(hasSubscription ? "details" : "plans");
} else if (view === "plans" && hasSubscription) {
changeView("details");
}
};
@@ -315,19 +305,7 @@ export default function BillingPage() {
onViewPlans={() => changeView("plans")}
onRefresh={handleRefresh}
isAirGapped={isAirGapped}
isManualLicenseOnly={isManualLicenseOnly}
hasStripeError={hasStripeError}
licenseCard={
isManualLicenseOnly ? (
<LicenseActivationCard
isOpen
onSuccess={handleLicenseActivated}
license={licenseData ?? undefined}
onClose={() => {}}
hideClose
/>
) : undefined
}
/>
),
};
@@ -344,7 +322,7 @@ export default function BillingPage() {
if (isLoading || view === null) return null;
return (
<>
{showLicenseActivationInput && !isManualLicenseOnly && (
{showLicenseActivationInput && (
<div className="w-full billing-card-enter">
<LicenseActivationCard
isOpen={showLicenseActivationInput}
@@ -363,7 +341,6 @@ export default function BillingPage() {
isSelfHosted ? () => setShowLicenseActivationInput(true) : undefined
}
hideLicenseLink={
isManualLicenseOnly ||
showLicenseActivationInput ||
(view === "plans" &&
(!!hasSubscription || !!licenseData?.has_license))

View File

@@ -14,11 +14,6 @@ import {
TimelineRendererComponent,
TimelineRendererOutput,
} from "./TimelineRendererComponent";
import {
isReasoningPackets,
isDeepResearchPlanPackets,
isMemoryToolPackets,
} from "./packetHelpers";
import Tabs from "@/refresh-components/Tabs";
import { SvgBranch, SvgFold, SvgExpand } from "@opal/icons";
import { Button } from "@opal/components";
@@ -65,13 +60,6 @@ export function ParallelTimelineTabs({
[turnGroup.steps, activeTab]
);
// Determine if the active step needs full-width content (no right padding)
const noPaddingRight = activeStep
? isReasoningPackets(activeStep.packets) ||
isDeepResearchPlanPackets(activeStep.packets) ||
isMemoryToolPackets(activeStep.packets)
: false;
// Memoized loading states for each step
const loadingStates = useMemo(
() =>
@@ -94,10 +82,9 @@ export function ParallelTimelineTabs({
isFirstStep={false}
isSingleStep={false}
collapsible={true}
noPaddingRight={noPaddingRight}
/>
),
[isLastTurnGroup, noPaddingRight]
[isLastTurnGroup]
);
const hasActivePackets = Boolean(activeStep && activeStep.packets.length > 0);

View File

@@ -50,7 +50,7 @@ export function TimelineStepComposer({
header={result.status}
isExpanded={result.isExpanded}
onToggle={result.onToggle}
collapsible={collapsible && !isSingleStep}
collapsible={collapsible}
supportsCollapsible={result.supportsCollapsible}
isLastStep={index === results.length - 1 && isLastStep}
isFirstStep={index === 0 && isFirstStep}

View File

@@ -63,7 +63,7 @@ export const FetchToolRenderer: MessageRenderer<FetchToolPacket, {}> = ({
return children([
{
icon: SvgCircle,
status: "Reading",
status: null,
content: <div />,
supportsCollapsible: false,
timelineLayout: "timeline",

View File

@@ -46,7 +46,7 @@ export const MemoryToolRenderer: MessageRenderer<MemoryToolPacket, {}> = ({
return children([
{
icon: SvgEditBig,
status: "Memory",
status: null,
content: <div />,
supportsCollapsible: false,
timelineLayout: "timeline",

View File

@@ -169,9 +169,7 @@ export const ReasoningRenderer: MessageRenderer<
);
if (!hasStart && !hasEnd && content.length === 0) {
return children([
{ icon: SvgCircle, status: THINKING_STATUS, content: <></> },
]);
return children([{ icon: SvgCircle, status: null, content: <></> }]);
}
const reasoningContent = (

View File

@@ -61,7 +61,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
children,
}) => {
const searchState = constructCurrentSearchState(packets);
const { queries, results, isComplete } = searchState;
const { queries, results } = searchState;
const isCompact = renderType === RenderType.COMPACT;
const isHighlight = renderType === RenderType.HIGHLIGHT;
@@ -75,7 +75,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
return children([
{
icon: SvgSearchMenu,
status: queriesHeader,
status: null,
content: <></>,
supportsCollapsible: true,
timelineLayout: "timeline",
@@ -109,15 +109,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
window.open(doc.link, "_blank", "noopener,noreferrer");
}
}}
emptyState={
!isComplete ? (
<BlinkingBar />
) : (
<Text as="p" text04 mainUiMuted>
No results found
</Text>
)
}
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
/>
</div>
),
@@ -172,15 +164,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
window.open(doc.link, "_blank", "noopener,noreferrer");
}
}}
emptyState={
!isComplete ? (
<BlinkingBar />
) : (
<Text as="p" text04 mainUiMuted>
No results found
</Text>
)
}
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
/>
),
},
@@ -229,15 +213,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
window.open(doc.link, "_blank", "noopener,noreferrer");
}
}}
emptyState={
!isComplete ? (
<BlinkingBar />
) : (
<Text as="p" text03 mainUiMuted>
No results found
</Text>
)
}
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
/>
</>
)}

View File

@@ -53,7 +53,7 @@ export const WebSearchToolRenderer: MessageRenderer<SearchToolPacket, {}> = ({
return children([
{
icon: SvgGlobe,
status: "Searching the web",
status: null,
content: <div />,
supportsCollapsible: false,
timelineLayout: "timeline",

View File

@@ -131,8 +131,7 @@ export async function updateAgentSharedStatus(
userIds: string[],
groupIds: number[],
isPublic: boolean | undefined,
isPaidEnterpriseFeaturesEnabled: boolean,
labelIds?: number[]
isPaidEnterpriseFeaturesEnabled: boolean
): Promise<null | string> {
// MIT versions should not send group_ids - warn if caller provided non-empty groups
if (!isPaidEnterpriseFeaturesEnabled && groupIds.length > 0) {
@@ -153,7 +152,6 @@ export async function updateAgentSharedStatus(
// Only include group_ids for enterprise versions
group_ids: isPaidEnterpriseFeaturesEnabled ? groupIds : undefined,
is_public: isPublic,
label_ids: labelIds,
}),
});
@@ -168,63 +166,3 @@ export async function updateAgentSharedStatus(
return "Network error. Please check your connection and try again.";
}
}
/**
 * Updates the labels assigned to an agent via the share endpoint.
 *
 * @param agentId - The ID of the agent to update
 * @param labelIds - Array of label IDs to assign to the agent
 * @returns null on success, or an error message string on failure
 */
export async function updateAgentLabels(
  agentId: number,
  labelIds: number[]
): Promise<string | null> {
  try {
    const response = await fetch(`/api/persona/${agentId}/share`, {
      method: "PATCH",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ label_ids: labelIds }),
    });
    if (response.ok) {
      return null;
    }
    // The error body may not be valid JSON (e.g. an HTML error page from a
    // proxy). Guard the parse so a body-decode failure is reported as an
    // HTTP error rather than falling through to the "Network error" branch.
    try {
      const body = await response.json();
      return body?.detail || "Unknown error";
    } catch {
      return `Request failed with status ${response.status}`;
    }
  } catch (error) {
    console.error("updateAgentLabels: Network error", error);
    return "Network error. Please check your connection and try again.";
  }
}
/**
 * Updates the featured (default) status of an agent.
 *
 * @param agentId - The ID of the agent to update
 * @param isFeatured - Whether the agent should be featured
 * @returns null on success, or an error message string on failure
 */
export async function updateAgentFeaturedStatus(
  agentId: number,
  isFeatured: boolean
): Promise<string | null> {
  try {
    const response = await fetch(`/api/admin/persona/${agentId}/default`, {
      method: "PATCH",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ is_default_persona: isFeatured }),
    });
    if (response.ok) {
      return null;
    }
    // The error body may not be valid JSON (e.g. an HTML error page from a
    // proxy). Guard the parse so a body-decode failure is reported as an
    // HTTP error rather than falling through to the "Network error" branch.
    try {
      const body = await response.json();
      return body?.detail || "Unknown error";
    } catch {
      return `Request failed with status ${response.status}`;
    }
  } catch (error) {
    console.error("updateAgentFeaturedStatus: Network error", error);
    return "Network error. Please check your connection and try again.";
  }
}

View File

@@ -257,27 +257,19 @@ export const useLabels = () => {
return mutate("/api/persona/labels");
};
const createLabel = async (name: string): Promise<PersonaLabel | null> => {
const createLabel = async (name: string) => {
const response = await fetch("/api/persona/labels", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ name }),
});
if (!response.ok) {
return null;
if (response.ok) {
const newLabel = await response.json();
mutate("/api/persona/labels", [...(labels || []), newLabel], false);
}
const newLabel: PersonaLabel = await response.json();
mutate(
"/api/persona/labels",
(currentLabels: PersonaLabel[] | undefined) => [
...(currentLabels || []),
newLabel,
],
false
);
return newLabel;
return response;
};
const updateLabel = async (id: number, name: string) => {

View File

@@ -1,51 +1,29 @@
import Text from "@/refresh-components/texts/Text";
import { SvgX } from "@opal/icons";
import { Button } from "@opal/components";
import type { IconProps } from "@opal/types";
export interface ChipProps {
children?: string;
icon?: React.FunctionComponent<IconProps>;
onRemove?: () => void;
smallLabel?: boolean;
}
/**
* A simple chip/tag component for displaying metadata.
* Supports an optional remove button via the `onRemove` prop.
*
* @example
* ```tsx
* <Chip>Tag Name</Chip>
* <Chip icon={SvgUser}>John Doe</Chip>
* <Chip onRemove={() => removeTag(id)}>Removable</Chip>
* ```
*/
export default function Chip({
children,
icon: Icon,
onRemove,
smallLabel = true,
}: ChipProps) {
export default function Chip({ children, icon: Icon }: ChipProps) {
return (
<div className="flex items-center gap-1 px-1.5 py-0.5 rounded-08 bg-background-tint-02">
{Icon && <Icon size={12} className="text-text-03" />}
{children && (
<Text figureSmallLabel={smallLabel} text03>
<Text figureSmallLabel text03>
{children}
</Text>
)}
{onRemove && (
<Button
onClick={(e) => {
e.stopPropagation();
onRemove();
}}
prominence="tertiary"
icon={SvgX}
size="xs"
/>
)}
</div>
);
}

View File

@@ -1,125 +0,0 @@
"use client";
import * as React from "react";
import { cn } from "@/lib/utils";
import Chip from "@/refresh-components/Chip";
import {
innerClasses,
textClasses,
Variants,
wrapperClasses,
} from "@/refresh-components/inputs/styles";
import type { IconProps } from "@opal/types";
// A single chip entry rendered inside InputChipField.
export interface ChipItem {
  // Stable identifier; used as the React key and passed to `onRemoveChip`.
  id: string;
  // Text displayed inside the chip.
  label: string;
}
// Props for the InputChipField tag/chip input component.
export interface InputChipFieldProps {
  // Chips currently displayed before the text input.
  chips: ChipItem[];
  // Called with a chip's id when its remove button is clicked, or when
  // Backspace is pressed while the input is empty (removes the last chip).
  onRemoveChip: (id: string) => void;
  // Called with the trimmed input value when Enter is pressed.
  onAdd: (value: string) => void;
  // Controlled value of the inner text input.
  value: string;
  // Called with the new input text on every change.
  onChange: (value: string) => void;
  // Placeholder shown only while no chips are present.
  placeholder?: string;
  // Disables typing, chip removal, and key handling.
  disabled?: boolean;
  // Visual variant keyed into the shared input style maps.
  variant?: Variants;
  // Optional leading icon rendered before the chips.
  icon?: React.FunctionComponent<IconProps>;
  // Extra classes merged onto the outer wrapper.
  className?: string;
}
/**
 * A tag/chip input field that renders chips inline alongside a text input.
 *
 * Enter commits the trimmed input via `onAdd`; Backspace on an empty input
 * removes the trailing chip; every chip carries its own remove button.
 * Clicking anywhere in the wrapper focuses the inner input.
 *
 * @example
 * ```tsx
 * <InputChipField
 *   chips={[{ id: "1", label: "Search" }]}
 *   onRemoveChip={(id) => remove(id)}
 *   onAdd={(value) => add(value)}
 *   value={inputValue}
 *   onChange={setInputValue}
 *   placeholder="Add labels..."
 *   icon={SvgTag}
 * />
 * ```
 */
function InputChipField({
  chips,
  onRemoveChip,
  onAdd,
  value,
  onChange,
  placeholder,
  disabled = false,
  variant = "primary",
  icon: Icon,
  className,
}: InputChipFieldProps) {
  const textInputRef = React.useRef<HTMLInputElement>(null);

  function handleKeyDown(event: React.KeyboardEvent<HTMLInputElement>) {
    // All keyboard interactions are suppressed while disabled.
    if (disabled) {
      return;
    }
    switch (event.key) {
      case "Enter": {
        event.preventDefault();
        event.stopPropagation();
        const candidate = value.trim();
        if (candidate) {
          onAdd(candidate);
        }
        break;
      }
      case "Backspace": {
        // Only delete a chip when there is no text left to erase.
        if (value !== "") {
          break;
        }
        const tailChip = chips[chips.length - 1];
        if (tailChip) {
          onRemoveChip(tailChip.id);
        }
        break;
      }
    }
  }

  const wrapperClassName = cn(
    "flex flex-row items-center flex-wrap gap-1 p-1.5 rounded-08 cursor-text w-full",
    wrapperClasses[variant],
    className
  );
  const inputClassName = cn(
    "flex-1 min-w-[80px] h-[1.5rem] bg-transparent p-0.5 focus:outline-none",
    innerClasses[variant],
    textClasses[variant]
  );

  return (
    <div
      className={wrapperClassName}
      onClick={() => textInputRef.current?.focus()}
    >
      {Icon && <Icon size={16} className="text-text-04 shrink-0" />}
      {chips.map((chip) => (
        <Chip
          key={chip.id}
          onRemove={disabled ? undefined : () => onRemoveChip(chip.id)}
          smallLabel={false}
        >
          {chip.label}
        </Chip>
      ))}
      <input
        ref={textInputRef}
        type="text"
        disabled={disabled}
        value={value}
        onChange={(event) => onChange(event.target.value)}
        onKeyDown={handleKeyDown}
        placeholder={chips.length === 0 ? placeholder : undefined}
        className={inputClassName}
      />
    </div>
  );
}
export default InputChipField;

View File

@@ -64,7 +64,6 @@ function MemoryItem({
if (!shouldHighlight) return;
wrapperRef.current?.scrollIntoView({ block: "center", behavior: "smooth" });
textareaRef.current?.focus();
setIsHighlighting(true);
const timer = setTimeout(() => {

View File

@@ -115,9 +115,14 @@ export default function ActionLineItem({
<Section gap={0.25} flexDirection="row">
{!isUnavailable && tool?.oauth_config_id && toolAuthStatus && (
<Button
icon={SvgKey}
prominence="secondary"
size="sm"
icon={({ className }) => (
<SvgKey
className={cn(
className,
"stroke-yellow-500 hover:stroke-yellow-600"
)}
/>
)}
onClick={noProp(() => {
if (
!toolAuthStatus.hasToken ||

View File

@@ -651,8 +651,6 @@ export default function AgentEditorPage({
shared_user_ids: existingAgent?.users?.map((user) => user.id) ?? [],
shared_group_ids: existingAgent?.groups ?? [],
is_public: existingAgent?.is_public ?? true,
label_ids: existingAgent?.labels?.map((l) => l.id) ?? [],
is_default_persona: existingAgent?.is_default_persona ?? false,
};
const validationSchema = Yup.object().shape({
@@ -814,8 +812,8 @@ export default function AgentEditorPage({
uploaded_image_id: values.uploaded_image_id,
icon_name: values.icon_name,
search_start_date: values.knowledge_cutoff_date || null,
label_ids: values.label_ids,
is_default_persona: values.is_default_persona,
label_ids: null,
is_default_persona: false,
// display_priority: ...,
user_file_ids: values.enable_knowledge ? values.user_file_ids : [],
@@ -1056,20 +1054,10 @@ export default function AgentEditorPage({
userIds={values.shared_user_ids}
groupIds={values.shared_group_ids}
isPublic={values.is_public}
isFeatured={values.is_default_persona}
labelIds={values.label_ids}
onShare={(
userIds,
groupIds,
isPublic,
isFeatured,
labelIds
) => {
onShare={(userIds, groupIds, isPublic) => {
setFieldValue("shared_user_ids", userIds);
setFieldValue("shared_group_ids", groupIds);
setFieldValue("is_public", isPublic);
setFieldValue("is_default_persona", isFeatured);
setFieldValue("label_ids", labelIds);
shareAgentModal.toggle(false);
}}
/>

View File

@@ -11,11 +11,7 @@ import { cn, noProp } from "@/lib/utils";
import { useRouter } from "next/navigation";
import type { Route } from "next";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import {
checkUserOwnsAssistant,
updateAgentSharedStatus,
updateAgentFeaturedStatus,
} from "@/lib/agents";
import { checkUserOwnsAssistant, updateAgentSharedStatus } from "@/lib/agents";
import { useUser } from "@/providers/UserProvider";
import {
SvgActions,
@@ -47,9 +43,8 @@ export default function AgentCard({ agent }: AgentCardProps) {
() => pinnedAgents.some((pinnedAgent) => pinnedAgent.id === agent.id),
[agent.id, pinnedAgents]
);
const { user, isAdmin, isCurator } = useUser();
const { user } = useUser();
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
const canUpdateFeaturedStatus = isAdmin || isCurator;
const isOwnedByUser = checkUserOwnsAssistant(user, agent);
const shareAgentModal = useCreateModal();
const agentViewerModal = useCreateModal();
@@ -63,49 +58,26 @@ export default function AgentCard({ agent }: AgentCardProps) {
route({ agentId: agent.id });
}, [pinned, togglePinnedAgent, agent, route]);
// Handle sharing agent
const handleShare = useCallback(
async (
userIds: string[],
groupIds: number[],
isPublic: boolean,
isFeatured: boolean,
labelIds: number[]
) => {
const shareError = await updateAgentSharedStatus(
async (userIds: string[], groupIds: number[], isPublic: boolean) => {
const error = await updateAgentSharedStatus(
agent.id,
userIds,
groupIds,
isPublic,
isPaidEnterpriseFeaturesEnabled,
labelIds
isPaidEnterpriseFeaturesEnabled
);
if (shareError) {
toast.error(`Failed to share agent: ${shareError}`);
return;
if (error) {
toast.error(`Failed to share agent: ${error}`);
} else {
// Revalidate the agent data to reflect the changes
refreshAgent();
shareAgentModal.toggle(false);
}
if (canUpdateFeaturedStatus) {
const featuredError = await updateAgentFeaturedStatus(
agent.id,
isFeatured
);
if (featuredError) {
toast.error(`Failed to update featured status: ${featuredError}`);
refreshAgent();
return;
}
}
refreshAgent();
shareAgentModal.toggle(false);
},
[
agent.id,
canUpdateFeaturedStatus,
isPaidEnterpriseFeaturesEnabled,
refreshAgent,
]
[agent.id, isPaidEnterpriseFeaturesEnabled, refreshAgent]
);
return (
@@ -116,8 +88,6 @@ export default function AgentCard({ agent }: AgentCardProps) {
userIds={fullAgent?.users?.map((u) => u.id) ?? []}
groupIds={fullAgent?.groups ?? []}
isPublic={fullAgent?.is_public ?? false}
isFeatured={fullAgent?.is_default_persona ?? false}
labelIds={fullAgent?.labels?.map((l) => l.id) ?? []}
onShare={handleShare}
/>
</shareAgentModal.Provider>

View File

@@ -1,17 +1,15 @@
"use client";
import { useCallback, useMemo, useState } from "react";
import { useMemo } from "react";
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
import Button from "@/refresh-components/buttons/Button";
import {
SvgLink,
SvgOrganization,
SvgShare,
SvgTag,
SvgUsers,
SvgX,
} from "@opal/icons";
import InputChipField from "@/refresh-components/inputs/InputChipField";
import Tabs from "@/refresh-components/Tabs";
import { Card } from "@/refresh-components/cards";
import InputComboBox from "@/refresh-components/inputs/InputComboBox/InputComboBox";
@@ -28,8 +26,6 @@ import { useUser } from "@/providers/UserProvider";
import { Formik, useFormikContext } from "formik";
import { useAgent } from "@/hooks/useAgents";
import { Button as OpalButton } from "@opal/components";
import { useLabels } from "@/lib/hooks";
import { PersonaLabel } from "@/app/admin/assistants/interfaces";
const YOUR_ORGANIZATION_TAB = "Your Organization";
const USERS_AND_GROUPS_TAB = "Users & Groups";
@@ -42,8 +38,6 @@ interface ShareAgentFormValues {
selectedUserIds: string[];
selectedGroupIds: number[];
isPublic: boolean;
isFeatured: boolean;
labelIds: number[];
}
// ============================================================================
@@ -59,15 +53,12 @@ function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
useFormikContext<ShareAgentFormValues>();
const { data: usersData } = useShareableUsers({ includeApiKeys: true });
const { data: groupsData } = useShareableGroups();
const { user: currentUser, isAdmin, isCurator } = useUser();
const { user: currentUser } = useUser();
const { agent: fullAgent } = useAgent(agentId ?? null);
const shareAgentModal = useModal();
const { labels: allLabels, createLabel } = useLabels();
const [labelInputValue, setLabelInputValue] = useState("");
const acceptedUsers = usersData ?? [];
const groups = groupsData ?? [];
const canUpdateFeaturedStatus = isAdmin || isCurator;
// Create options for InputComboBox from all accepted users and groups
const comboBoxOptions = useMemo(() => {
@@ -146,50 +137,6 @@ function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
);
}
const selectedLabels: PersonaLabel[] = useMemo(() => {
if (!allLabels) return [];
return allLabels.filter((label) => values.labelIds.includes(label.id));
}, [allLabels, values.labelIds]);
function handleRemoveLabel(labelId: number) {
setFieldValue(
"labelIds",
values.labelIds.filter((id) => id !== labelId)
);
}
const addLabel = useCallback(
async (name: string) => {
const trimmed = name.trim();
if (!trimmed) return;
const existing = allLabels?.find(
(l) => l.name.toLowerCase() === trimmed.toLowerCase()
);
if (existing) {
if (!values.labelIds.includes(existing.id)) {
setFieldValue("labelIds", [...values.labelIds, existing.id]);
}
} else {
const newLabel = await createLabel(trimmed);
if (newLabel) {
setFieldValue("labelIds", [...values.labelIds, newLabel.id]);
}
}
setLabelInputValue("");
},
[allLabels, values.labelIds, setFieldValue, createLabel]
);
const chipItems = useMemo(
() =>
selectedLabels.map((label) => ({
id: String(label.id),
label: label.name,
})),
[selectedLabels]
);
return (
<Modal.Content width="sm" height="lg">
<Modal.Header icon={SvgShare} title="Share Agent" onClose={handleClose} />
@@ -286,41 +233,12 @@ function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
</Tabs.Content>
<Tabs.Content value={YOUR_ORGANIZATION_TAB} padding={0.5}>
<Section gap={1} alignItems="stretch">
<InputLayouts.Horizontal
title="Publish This Agent"
description="Make this agent available to everyone in your organization."
>
<SwitchField name="isPublic" />
</InputLayouts.Horizontal>
{canUpdateFeaturedStatus && (
<>
<div className="border-t border-border-02" />
<InputLayouts.Horizontal
title="Feature This Agent"
description="Show this agent at the top of the explore agents list and automatically pin it to the sidebar for new users with access."
>
<SwitchField name="isFeatured" />
</InputLayouts.Horizontal>
</>
)}
<InputChipField
chips={chipItems}
onRemoveChip={(id) => handleRemoveLabel(Number(id))}
onAdd={addLabel}
value={labelInputValue}
onChange={setLabelInputValue}
placeholder="Add labels..."
icon={SvgTag}
/>
<Text secondaryBody text04>
Add labels and categories to help people better discover this
agent.
</Text>
</Section>
<InputLayouts.Horizontal
title="Publish This Agent"
description="Make this agent available to everyone in your organization."
>
<SwitchField name="isPublic" />
</InputLayouts.Horizontal>
</Tabs.Content>
</Tabs>
</Card>
@@ -360,15 +278,7 @@ export interface ShareAgentModalProps {
userIds: string[];
groupIds: number[];
isPublic: boolean;
isFeatured: boolean;
labelIds: number[];
onShare?: (
userIds: string[],
groupIds: number[],
isPublic: boolean,
isFeatured: boolean,
labelIds: number[]
) => void;
onShare?: (userIds: string[], groupIds: number[], isPublic: boolean) => void;
}
export default function ShareAgentModal({
@@ -376,8 +286,6 @@ export default function ShareAgentModal({
userIds,
groupIds,
isPublic,
isFeatured,
labelIds,
onShare,
}: ShareAgentModalProps) {
const shareAgentModal = useModal();
@@ -386,18 +294,10 @@ export default function ShareAgentModal({
selectedUserIds: userIds,
selectedGroupIds: groupIds,
isPublic: isPublic,
isFeatured: isFeatured,
labelIds: labelIds,
};
function handleSubmit(values: ShareAgentFormValues) {
onShare?.(
values.selectedUserIds,
values.selectedGroupIds,
values.isPublic,
values.isFeatured,
values.labelIds
);
onShare?.(values.selectedUserIds, values.selectedGroupIds, values.isPublic);
}
return (