mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-26 20:25:46 +00:00
Compare commits
8 Commits
main
...
llm_provid
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3081be33d7 | ||
|
|
4822c89b34 | ||
|
|
1c1e53bc06 | ||
|
|
c22c71251a | ||
|
|
ce7863bde7 | ||
|
|
97d9e506b6 | ||
|
|
44dca099da | ||
|
|
18e75ca9ca |
@@ -9,8 +9,7 @@ inputs:
|
||||
required: true
|
||||
provider-api-key:
|
||||
description: "API key for NIGHTLY_LLM_API_KEY"
|
||||
required: false
|
||||
default: ""
|
||||
required: true
|
||||
strict:
|
||||
description: "String true/false for NIGHTLY_LLM_STRICT"
|
||||
required: true
|
||||
@@ -92,6 +91,11 @@ runs:
|
||||
max_attempts: 2
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ -z "${MODELS}" ]; then
|
||||
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
name: Nightly LLM Provider Chat Tests
|
||||
name: Nightly LLM Provider Chat Tests (OpenAI)
|
||||
concurrency:
|
||||
group: Nightly-LLM-Provider-Chat-${{ github.workflow }}-${{ github.ref_name }}
|
||||
group: Nightly-LLM-Provider-Chat-OpenAI-${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
@@ -13,20 +13,19 @@ permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
openai-provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
with:
|
||||
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
|
||||
provider: openai
|
||||
models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
strict: true
|
||||
secrets:
|
||||
openai_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
provider_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [provider-chat-test]
|
||||
needs: [openai-provider-chat-test]
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
@@ -40,6 +39,6 @@ jobs:
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
|
||||
failed-jobs: provider-chat-test
|
||||
title: "🚨 Scheduled LLM Provider Chat Tests failed!"
|
||||
failed-jobs: openai-provider-chat-test
|
||||
title: "🚨 Scheduled OpenAI Provider Chat Tests failed!"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
@@ -3,26 +3,33 @@ name: Reusable Nightly LLM Provider Chat Tests
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
openai_models:
|
||||
description: "Comma-separated models for openai"
|
||||
required: false
|
||||
default: ""
|
||||
provider:
|
||||
description: "Provider slug passed to NIGHTLY_LLM_PROVIDER (e.g. openai, anthropic)"
|
||||
required: true
|
||||
type: string
|
||||
anthropic_models:
|
||||
description: "Comma-separated models for anthropic"
|
||||
required: false
|
||||
default: ""
|
||||
models:
|
||||
description: "Comma-separated model list passed to NIGHTLY_LLM_MODELS"
|
||||
required: true
|
||||
type: string
|
||||
strict:
|
||||
description: "Default NIGHTLY_LLM_STRICT passed to tests"
|
||||
description: "Pass-through value for NIGHTLY_LLM_STRICT"
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
api_base:
|
||||
description: "Optional NIGHTLY_LLM_API_BASE override"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
custom_config_json:
|
||||
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON override"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
secrets:
|
||||
openai_api_key:
|
||||
required: false
|
||||
anthropic_api_key:
|
||||
required: false
|
||||
provider_api_key:
|
||||
description: "Provider API key passed to NIGHTLY_LLM_API_KEY"
|
||||
required: true
|
||||
DOCKER_USERNAME:
|
||||
required: true
|
||||
DOCKER_TOKEN:
|
||||
@@ -31,8 +38,29 @@ on:
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
|
||||
NIGHTLY_LLM_MODELS: ${{ inputs.models }}
|
||||
NIGHTLY_LLM_API_KEY: ${{ secrets.provider_api_key }}
|
||||
NIGHTLY_LLM_API_BASE: ${{ inputs.api_base }}
|
||||
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom_config_json }}
|
||||
NIGHTLY_LLM_STRICT: ${{ inputs.strict && 'true' || 'false' }}
|
||||
|
||||
jobs:
|
||||
validate-inputs:
|
||||
# NOTE: Keep this cheap and fail before image builds if required inputs are missing.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Validate required nightly provider inputs
|
||||
run: |
|
||||
if [ -z "${NIGHTLY_LLM_MODELS}" ]; then
|
||||
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
build-backend-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
@@ -62,6 +90,7 @@ jobs:
|
||||
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
|
||||
|
||||
build-model-server-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
@@ -90,6 +119,7 @@ jobs:
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
build-integration-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
@@ -119,25 +149,11 @@ jobs:
|
||||
|
||||
provider-chat-test:
|
||||
needs:
|
||||
[
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- provider: openai
|
||||
models: ${{ inputs.openai_models }}
|
||||
api_key_secret: openai_api_key
|
||||
- provider: anthropic
|
||||
models: ${{ inputs.anthropic_models }}
|
||||
api_key_secret: anthropic_api_key
|
||||
[build-backend-image, build-model-server-image, build-integration-image]
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ inputs.provider }}-provider-chat-test"
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
@@ -151,10 +167,12 @@ jobs:
|
||||
- name: Run nightly provider chat test
|
||||
uses: ./.github/actions/run-nightly-provider-chat-test
|
||||
with:
|
||||
provider: ${{ matrix.provider }}
|
||||
models: ${{ matrix.models }}
|
||||
provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }}
|
||||
strict: ${{ inputs.strict && 'true' || 'false' }}
|
||||
provider: ${{ env.NIGHTLY_LLM_PROVIDER }}
|
||||
models: ${{ env.NIGHTLY_LLM_MODELS }}
|
||||
provider-api-key: ${{ secrets.provider_api_key }}
|
||||
strict: ${{ env.NIGHTLY_LLM_STRICT }}
|
||||
api-base: ${{ env.NIGHTLY_LLM_API_BASE }}
|
||||
custom-config-json: ${{ env.NIGHTLY_LLM_CUSTOM_CONFIG_JSON }}
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
@@ -176,7 +194,7 @@ jobs:
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-nightly-${{ matrix.provider }}-llm-provider
|
||||
name: docker-all-logs-nightly-${{ inputs.provider }}-llm-provider
|
||||
path: |
|
||||
${{ github.workspace }}/api_server.log
|
||||
${{ github.workspace }}/docker-compose.log
|
||||
|
||||
@@ -58,27 +58,16 @@ class OAuthTokenManager:
|
||||
if not user_token.token_data:
|
||||
raise ValueError("No token data available for refresh")
|
||||
|
||||
if (
|
||||
self.oauth_config.client_id is None
|
||||
or self.oauth_config.client_secret is None
|
||||
):
|
||||
raise ValueError(
|
||||
"OAuth client_id and client_secret are required for token refresh"
|
||||
)
|
||||
|
||||
token_data = self._unwrap_token_data(user_token.token_data)
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token_data["refresh_token"],
|
||||
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
|
||||
"client_secret": self._unwrap_sensitive_str(
|
||||
self.oauth_config.client_secret
|
||||
),
|
||||
}
|
||||
response = requests.post(
|
||||
self.oauth_config.token_url,
|
||||
data=data,
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token_data["refresh_token"],
|
||||
"client_id": self.oauth_config.client_id,
|
||||
"client_secret": self.oauth_config.client_secret,
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -126,26 +115,15 @@ class OAuthTokenManager:
|
||||
|
||||
def exchange_code_for_token(self, code: str, redirect_uri: str) -> dict[str, Any]:
|
||||
"""Exchange authorization code for access token"""
|
||||
if (
|
||||
self.oauth_config.client_id is None
|
||||
or self.oauth_config.client_secret is None
|
||||
):
|
||||
raise ValueError(
|
||||
"OAuth client_id and client_secret are required for code exchange"
|
||||
)
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
|
||||
"client_secret": self._unwrap_sensitive_str(
|
||||
self.oauth_config.client_secret
|
||||
),
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
response = requests.post(
|
||||
self.oauth_config.token_url,
|
||||
data=data,
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"client_id": self.oauth_config.client_id,
|
||||
"client_secret": self.oauth_config.client_secret,
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -163,13 +141,8 @@ class OAuthTokenManager:
|
||||
oauth_config: OAuthConfig, redirect_uri: str, state: str
|
||||
) -> str:
|
||||
"""Build OAuth authorization URL"""
|
||||
if oauth_config.client_id is None:
|
||||
raise ValueError("OAuth client_id is required to build authorization URL")
|
||||
|
||||
params: dict[str, Any] = {
|
||||
"client_id": OAuthTokenManager._unwrap_sensitive_str(
|
||||
oauth_config.client_id
|
||||
),
|
||||
"client_id": oauth_config.client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"state": state,
|
||||
@@ -188,12 +161,6 @@ class OAuthTokenManager:
|
||||
|
||||
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_sensitive_str(value: SensitiveValue[str] | str) -> str:
|
||||
if isinstance(value, SensitiveValue):
|
||||
return value.get_value(apply_mask=False)
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_token_data(
|
||||
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],
|
||||
|
||||
@@ -48,7 +48,6 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -150,12 +149,8 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
opensearch_document_index = OpenSearchDocumentIndex(
|
||||
tenant_state=tenant_state,
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
index_name=search_settings.index_name, tenant_state=tenant_state
|
||||
)
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
|
||||
@@ -294,12 +294,6 @@ ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
|
||||
)
|
||||
# Whether we should check for and create an index if necessary every time we
|
||||
# instantiate an OpenSearchDocumentIndex on multitenant cloud. Defaults to True.
|
||||
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
|
||||
os.environ.get("VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT", "true").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
# NOTE: this is used if and only if the vespa config server is accessible via a
|
||||
|
||||
@@ -488,6 +488,22 @@ def fetch_existing_llm_provider(
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_existing_llm_provider_by_id(
|
||||
id: int, db_session: Session
|
||||
) -> LLMProviderModel | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel)
|
||||
.where(LLMProviderModel.id == id)
|
||||
.options(
|
||||
selectinload(LLMProviderModel.model_configurations),
|
||||
selectinload(LLMProviderModel.groups),
|
||||
selectinload(LLMProviderModel.personas),
|
||||
)
|
||||
)
|
||||
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_embedding_provider(
|
||||
db_session: Session, provider_type: EmbeddingProvider
|
||||
) -> CloudEmbeddingProviderModel | None:
|
||||
|
||||
@@ -335,7 +335,6 @@ def update_persona_shared(
|
||||
db_session: Session,
|
||||
group_ids: list[int] | None = None,
|
||||
is_public: bool | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
"""Simplified version of `create_update_persona` which only touches the
|
||||
accessibility rather than any of the logic (e.g. prompt, connected data sources,
|
||||
@@ -345,7 +344,9 @@ def update_persona_shared(
|
||||
)
|
||||
|
||||
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
|
||||
raise PermissionError("You don't have permission to modify this persona")
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have permission to modify this persona"
|
||||
)
|
||||
|
||||
versioned_update_persona_access = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "update_persona_access"
|
||||
@@ -359,15 +360,6 @@ def update_persona_shared(
|
||||
group_ids=group_ids,
|
||||
)
|
||||
|
||||
if label_ids is not None:
|
||||
labels = (
|
||||
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
|
||||
)
|
||||
if len(labels) != len(label_ids):
|
||||
raise ValueError("Some label IDs were not found in the database")
|
||||
persona.labels.clear()
|
||||
persona.labels = labels
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -973,8 +965,6 @@ def upsert_persona(
|
||||
labels = (
|
||||
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
|
||||
)
|
||||
if len(labels) != len(label_ids):
|
||||
raise ValueError("Some label IDs were not found in the database")
|
||||
|
||||
# Fetch and attach hierarchy_nodes by IDs
|
||||
hierarchy_nodes = None
|
||||
|
||||
@@ -11,7 +11,6 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchOldDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
@@ -50,11 +49,8 @@ def get_default_document_index(
|
||||
|
||||
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
|
||||
if opensearch_retrieval_enabled:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
return OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=secondary_index_name,
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
|
||||
@@ -122,11 +118,8 @@ def get_all_document_indices(
|
||||
)
|
||||
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
opensearch_document_index = OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import logging
|
||||
import time
|
||||
from contextlib import AbstractContextManager
|
||||
from contextlib import nullcontext
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
@@ -85,26 +83,22 @@ def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
|
||||
return new_body
|
||||
|
||||
|
||||
class OpenSearchClient(AbstractContextManager):
|
||||
"""Client for interacting with OpenSearch for cluster-level operations.
|
||||
class OpenSearchClient:
|
||||
"""Client for interacting with OpenSearch.
|
||||
|
||||
Args:
|
||||
host: The host of the OpenSearch cluster.
|
||||
port: The port of the OpenSearch cluster.
|
||||
auth: The authentication credentials for the OpenSearch cluster. A tuple
|
||||
of (username, password).
|
||||
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
|
||||
True.
|
||||
verify_certs: Whether to verify the SSL certificates for the OpenSearch
|
||||
cluster. Defaults to False.
|
||||
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
|
||||
to False.
|
||||
timeout: The timeout for the OpenSearch cluster. Defaults to
|
||||
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
|
||||
OpenSearch's Python module has pretty bad typing support so this client
|
||||
attempts to protect the rest of the codebase from this. As a consequence,
|
||||
most methods here return the minimum data needed for the rest of Onyx, and
|
||||
tend to rely on Exceptions to handle errors.
|
||||
|
||||
TODO(andrei): This class currently assumes the structure of the database
|
||||
schema when it returns a DocumentChunk. Make the class, or at least the
|
||||
search method, templated on the structure the caller can expect.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
host: str = OPENSEARCH_HOST,
|
||||
port: int = OPENSEARCH_REST_API_PORT,
|
||||
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
|
||||
@@ -113,8 +107,9 @@ class OpenSearchClient(AbstractContextManager):
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
):
|
||||
self._index_name = index_name
|
||||
logger.debug(
|
||||
f"Creating OpenSearch client with host {host}, port {port} and timeout {timeout} seconds."
|
||||
f"Creating OpenSearch client for index {index_name} with host {host} and port {port} and timeout {timeout} seconds."
|
||||
)
|
||||
self._client = OpenSearch(
|
||||
hosts=[{"host": host, "port": port}],
|
||||
@@ -130,142 +125,6 @@ class OpenSearchClient(AbstractContextManager):
|
||||
# your request body that is less than this value.
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def __exit__(self, *_: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def create_search_pipeline(
|
||||
self,
|
||||
pipeline_id: str,
|
||||
pipeline_body: dict[str, Any],
|
||||
) -> None:
|
||||
"""Creates a search pipeline.
|
||||
|
||||
See the OpenSearch documentation for more information on the search
|
||||
pipeline body.
|
||||
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to create.
|
||||
pipeline_body: The body of the search pipeline to create.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error creating the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def delete_search_pipeline(self, pipeline_id: str) -> None:
|
||||
"""Deletes a search pipeline.
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to delete.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error deleting the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
|
||||
"""Puts cluster settings.
|
||||
|
||||
Args:
|
||||
settings: The settings to put.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error putting the cluster settings.
|
||||
|
||||
Returns:
|
||||
True if the settings were put successfully, False otherwise.
|
||||
"""
|
||||
response = self._client.cluster.put_settings(body=settings)
|
||||
if response.get("acknowledged", False):
|
||||
logger.info("Successfully put cluster settings.")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to put cluster settings: {response}.")
|
||||
return False
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
Returns:
|
||||
True if OpenSearch could be reached, False if it could not.
|
||||
"""
|
||||
return self._client.ping()
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def close(self) -> None:
|
||||
"""Closes the client.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error closing the client.
|
||||
"""
|
||||
self._client.close()
|
||||
|
||||
|
||||
class OpenSearchIndexClient(OpenSearchClient):
|
||||
"""Client for interacting with OpenSearch for index-level operations.
|
||||
|
||||
OpenSearch's Python module has pretty bad typing support so this client
|
||||
attempts to protect the rest of the codebase from this. As a consequence,
|
||||
most methods here return the minimum data needed for the rest of Onyx, and
|
||||
tend to rely on Exceptions to handle errors.
|
||||
|
||||
TODO(andrei): This class currently assumes the structure of the database
|
||||
schema when it returns a DocumentChunk. Make the class, or at least the
|
||||
search method, templated on the structure the caller can expect.
|
||||
|
||||
Args:
|
||||
index_name: The name of the index to interact with.
|
||||
host: The host of the OpenSearch cluster.
|
||||
port: The port of the OpenSearch cluster.
|
||||
auth: The authentication credentials for the OpenSearch cluster. A tuple
|
||||
of (username, password).
|
||||
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
|
||||
True.
|
||||
verify_certs: Whether to verify the SSL certificates for the OpenSearch
|
||||
cluster. Defaults to False.
|
||||
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
|
||||
to False.
|
||||
timeout: The timeout for the OpenSearch cluster. Defaults to
|
||||
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
host: str = OPENSEARCH_HOST,
|
||||
port: int = OPENSEARCH_REST_API_PORT,
|
||||
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
|
||||
use_ssl: bool = True,
|
||||
verify_certs: bool = False,
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
):
|
||||
super().__init__(
|
||||
host=host,
|
||||
port=port,
|
||||
auth=auth,
|
||||
use_ssl=use_ssl,
|
||||
verify_certs=verify_certs,
|
||||
ssl_show_warn=ssl_show_warn,
|
||||
timeout=timeout,
|
||||
)
|
||||
self._index_name = index_name
|
||||
logger.debug(
|
||||
f"OpenSearch client created successfully for index {self._index_name}."
|
||||
)
|
||||
@@ -333,38 +192,6 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
"""
|
||||
return self._client.indices.exists(index=self._index_name)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def put_mapping(self, mappings: dict[str, Any]) -> None:
|
||||
"""Updates the index mapping in an idempotent manner.
|
||||
|
||||
- Existing fields with the same definition: No-op (succeeds silently).
|
||||
- New fields: Added to the index.
|
||||
- Existing fields with different types: Raises exception (requires
|
||||
reindex).
|
||||
|
||||
See the OpenSearch documentation for more information:
|
||||
https://docs.opensearch.org/latest/api-reference/index-apis/put-mapping/
|
||||
|
||||
Args:
|
||||
mappings: The complete mapping definition to apply. This will be
|
||||
merged with existing mappings in the index.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error updating the mappings, such as
|
||||
attempting to change the type of an existing field.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Putting mappings for index {self._index_name} with mappings {mappings}."
|
||||
)
|
||||
response = self._client.indices.put_mapping(
|
||||
index=self._index_name, body=mappings
|
||||
)
|
||||
if not response.get("acknowledged", False):
|
||||
raise RuntimeError(
|
||||
f"Failed to put the mapping update for index {self._index_name}."
|
||||
)
|
||||
logger.debug(f"Successfully put mappings for index {self._index_name}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def validate_index(self, expected_mappings: dict[str, Any]) -> bool:
|
||||
"""Validates the index.
|
||||
@@ -783,6 +610,43 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
)
|
||||
return DocumentChunk.model_validate(document_chunk_source)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def create_search_pipeline(
|
||||
self,
|
||||
pipeline_id: str,
|
||||
pipeline_body: dict[str, Any],
|
||||
) -> None:
|
||||
"""Creates a search pipeline.
|
||||
|
||||
See the OpenSearch documentation for more information on the search
|
||||
pipeline body.
|
||||
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to create.
|
||||
pipeline_body: The body of the search pipeline to create.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error creating the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def delete_search_pipeline(self, pipeline_id: str) -> None:
|
||||
"""Deletes a search pipeline.
|
||||
|
||||
Args:
|
||||
pipeline_id: The ID of the search pipeline to delete.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error deleting the search pipeline.
|
||||
"""
|
||||
result = self._client.search_pipeline.delete(id=pipeline_id)
|
||||
if not result.get("acknowledged", False):
|
||||
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def search(
|
||||
self, body: dict[str, Any], search_pipeline_id: str | None
|
||||
@@ -943,6 +807,48 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
"""
|
||||
self._client.indices.refresh(index=self._index_name)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
|
||||
"""Puts cluster settings.
|
||||
|
||||
Args:
|
||||
settings: The settings to put.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error putting the cluster settings.
|
||||
|
||||
Returns:
|
||||
True if the settings were put successfully, False otherwise.
|
||||
"""
|
||||
response = self._client.cluster.put_settings(body=settings)
|
||||
if response.get("acknowledged", False):
|
||||
logger.info("Successfully put cluster settings.")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to put cluster settings: {response}.")
|
||||
return False
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
Returns:
|
||||
True if OpenSearch could be reached, False if it could not.
|
||||
"""
|
||||
return self._client.ping()
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def close(self) -> None:
|
||||
"""Closes the client.
|
||||
|
||||
TODO(andrei): Can we have some way to auto close when the client no
|
||||
longer has any references?
|
||||
|
||||
Raises:
|
||||
Exception: There was an error closing the client.
|
||||
"""
|
||||
self._client.close()
|
||||
|
||||
def _get_hits_and_profile_from_search_result(
|
||||
self, result: dict[str, Any]
|
||||
) -> tuple[list[Any], int | None, bool | None, dict[str, Any], dict[str, Any]]:
|
||||
@@ -1039,7 +945,14 @@ def wait_for_opensearch_with_timeout(
|
||||
Returns:
|
||||
True if OpenSearch is ready, False otherwise.
|
||||
"""
|
||||
with nullcontext(client) if client else OpenSearchClient() as client:
|
||||
made_client = False
|
||||
try:
|
||||
if client is None:
|
||||
# NOTE: index_name does not matter because we are only using this object
|
||||
# to ping.
|
||||
# TODO(andrei): Make this better.
|
||||
client = OpenSearchClient(index_name="")
|
||||
made_client = True
|
||||
time_start = time.monotonic()
|
||||
while True:
|
||||
if client.ping():
|
||||
@@ -1056,3 +969,7 @@ def wait_for_opensearch_with_timeout(
|
||||
f"[OpenSearch] Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={wait_limit_s:.1f}"
|
||||
)
|
||||
time.sleep(wait_interval_s)
|
||||
finally:
|
||||
if made_client:
|
||||
assert client is not None
|
||||
client.close()
|
||||
|
||||
@@ -7,7 +7,6 @@ from opensearchpy import NotFoundError
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.configs.constants import PUBLIC_DOC_PAT
|
||||
@@ -41,7 +40,6 @@ from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.interfaces_new import MetadataUpdateRequest
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import SearchHit
|
||||
from onyx.document_index.opensearch.cluster_settings import OPENSEARCH_CLUSTER_SETTINGS
|
||||
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
@@ -95,25 +93,6 @@ def generate_opensearch_filtered_access_control_list(
|
||||
return list(access_control_list)
|
||||
|
||||
|
||||
def set_cluster_state(client: OpenSearchClient) -> None:
|
||||
if not client.put_cluster_settings(settings=OPENSEARCH_CLUSTER_SETTINGS):
|
||||
logger.error(
|
||||
"Failed to put cluster settings. If the settings have never been set before, "
|
||||
"this may cause unexpected index creation when indexing documents into an "
|
||||
"index that does not exist, or may cause expected logs to not appear. If this "
|
||||
"is not the first time running Onyx against this instance of OpenSearch, these "
|
||||
"settings have likely already been set. Not taking any further action..."
|
||||
)
|
||||
client.create_search_pipeline(
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
client.create_search_pipeline(
|
||||
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
|
||||
|
||||
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
chunk: DocumentChunk,
|
||||
score: float | None,
|
||||
@@ -269,8 +248,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
embedding_dim: int,
|
||||
embedding_precision: EmbeddingPrecision,
|
||||
secondary_index_name: str | None,
|
||||
large_chunks_enabled: bool, # noqa: ARG002
|
||||
secondary_large_chunks_enabled: bool | None, # noqa: ARG002
|
||||
@@ -281,6 +258,10 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
index_name=index_name,
|
||||
secondary_index_name=secondary_index_name,
|
||||
)
|
||||
if multitenant:
|
||||
raise ValueError(
|
||||
"Bug: OpenSearch is not yet ready for multitenant environments but something tried to use it."
|
||||
)
|
||||
if multitenant != MULTI_TENANT:
|
||||
raise ValueError(
|
||||
"Bug: Multitenant mismatch when initializing an OpenSearchDocumentIndex. "
|
||||
@@ -288,10 +269,8 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
)
|
||||
tenant_id = get_current_tenant_id()
|
||||
self._real_index = OpenSearchDocumentIndex(
|
||||
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
|
||||
index_name=index_name,
|
||||
embedding_dim=embedding_dim,
|
||||
embedding_precision=embedding_precision,
|
||||
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -300,8 +279,9 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
embedding_dims: list[int],
|
||||
embedding_precisions: list[EmbeddingPrecision],
|
||||
) -> None:
|
||||
# TODO(andrei): Implement.
|
||||
raise NotImplementedError(
|
||||
"Bug: Multitenant index registration is not supported for OpenSearch."
|
||||
"Multitenant index registration is not yet implemented for OpenSearch."
|
||||
)
|
||||
|
||||
def ensure_indices_exist(
|
||||
@@ -491,37 +471,19 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
for an OpenSearch search engine instance. It handles the complete lifecycle
|
||||
of document chunks within a specific OpenSearch index/schema.
|
||||
|
||||
Each kind of embedding used should correspond to a different instance of
|
||||
this class, and therefore a different index in OpenSearch.
|
||||
|
||||
If in a multitenant environment and
|
||||
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT, will verify and create the index
|
||||
if necessary on initialization. This is because there is no logic which runs
|
||||
on cluster restart which scans through all search settings over all tenants
|
||||
and creates the relevant indices.
|
||||
|
||||
Args:
|
||||
tenant_state: The tenant state of the caller.
|
||||
index_name: The name of the index to interact with.
|
||||
embedding_dim: The dimensionality of the embeddings used for the index.
|
||||
embedding_precision: The precision of the embeddings used for the index.
|
||||
Although not yet used in this way in the codebase, each kind of embedding
|
||||
used should correspond to a different instance of this class, and therefore
|
||||
a different index in OpenSearch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_state: TenantState,
|
||||
index_name: str,
|
||||
embedding_dim: int,
|
||||
embedding_precision: EmbeddingPrecision,
|
||||
tenant_state: TenantState,
|
||||
) -> None:
|
||||
self._index_name: str = index_name
|
||||
self._tenant_state: TenantState = tenant_state
|
||||
self._client = OpenSearchIndexClient(index_name=self._index_name)
|
||||
|
||||
if self._tenant_state.multitenant and VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT:
|
||||
self.verify_and_create_index_if_necessary(
|
||||
embedding_dim=embedding_dim, embedding_precision=embedding_precision
|
||||
)
|
||||
self._os_client = OpenSearchClient(index_name=self._index_name)
|
||||
|
||||
def verify_and_create_index_if_necessary(
|
||||
self,
|
||||
@@ -530,15 +492,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
) -> None:
|
||||
"""Verifies and creates the index if necessary.
|
||||
|
||||
Also puts the desired cluster settings if not in a multitenant
|
||||
environment.
|
||||
Also puts the desired cluster settings.
|
||||
|
||||
Also puts the desired search pipeline state if not in a multitenant
|
||||
environment, creating the pipelines if they do not exist and updating
|
||||
them otherwise.
|
||||
|
||||
In a multitenant environment, the above steps happen explicitly on
|
||||
setup.
|
||||
Also puts the desired search pipeline state, creating the pipelines if
|
||||
they do not exist and updating them otherwise.
|
||||
|
||||
Args:
|
||||
embedding_dim: Vector dimensionality for the vector similarity part
|
||||
@@ -551,38 +508,47 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search pipelines.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if "
|
||||
f"necessary, with embedding dimension {embedding_dim}."
|
||||
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if necessary, "
|
||||
f"with embedding dimension {embedding_dim}."
|
||||
)
|
||||
|
||||
if not self._tenant_state.multitenant:
|
||||
set_cluster_state(self._client)
|
||||
|
||||
expected_mappings = DocumentSchema.get_document_schema(
|
||||
embedding_dim, self._tenant_state.multitenant
|
||||
)
|
||||
|
||||
if not self._client.index_exists():
|
||||
if not self._os_client.put_cluster_settings(
|
||||
settings=OPENSEARCH_CLUSTER_SETTINGS
|
||||
):
|
||||
logger.error(
|
||||
f"Failed to put cluster settings for index {self._index_name}. If the settings have never been set before this "
|
||||
"may cause unexpected index creation when indexing documents into an index that does not exist, or may cause "
|
||||
"expected logs to not appear. If this is not the first time running Onyx against this instance of OpenSearch, "
|
||||
"these settings have likely already been set. Not taking any further action..."
|
||||
)
|
||||
if not self._os_client.index_exists():
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
index_settings = (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch()
|
||||
)
|
||||
else:
|
||||
index_settings = DocumentSchema.get_index_settings()
|
||||
self._client.create_index(
|
||||
self._os_client.create_index(
|
||||
mappings=expected_mappings,
|
||||
settings=index_settings,
|
||||
)
|
||||
else:
|
||||
# Ensure schema is up to date by applying the current mappings.
|
||||
try:
|
||||
self._client.put_mapping(expected_mappings)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update mappings for index {self._index_name}. This likely means a "
|
||||
f"field type was changed which requires reindexing. Error: {e}"
|
||||
)
|
||||
raise
|
||||
if not self._os_client.validate_index(
|
||||
expected_mappings=expected_mappings,
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"The index {self._index_name} is not valid. The expected mappings do not match the actual mappings."
|
||||
)
|
||||
|
||||
self._os_client.create_search_pipeline(
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
self._os_client.create_search_pipeline(
|
||||
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
|
||||
def index(
|
||||
self,
|
||||
@@ -654,7 +620,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
)
|
||||
# Now index. This will raise if a chunk of the same ID exists, which
|
||||
# we do not expect because we should have deleted all chunks.
|
||||
self._client.bulk_index_documents(
|
||||
self._os_client.bulk_index_documents(
|
||||
documents=chunk_batch,
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
@@ -694,7 +660,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
|
||||
return self._client.delete_by_query(query_body)
|
||||
return self._os_client.delete_by_query(query_body)
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -794,7 +760,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
document_id=doc_id,
|
||||
chunk_index=chunk_index,
|
||||
)
|
||||
self._client.update_document(
|
||||
self._os_client.update_document(
|
||||
document_chunk_id=document_chunk_id,
|
||||
properties_to_update=properties_to_update,
|
||||
)
|
||||
@@ -833,7 +799,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
min_chunk_index=chunk_request.min_chunk_ind,
|
||||
max_chunk_index=chunk_request.max_chunk_ind,
|
||||
)
|
||||
search_hits = self._client.search(
|
||||
search_hits = self._os_client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
@@ -883,7 +849,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# NOTE: Using z-score normalization here because it's better for hybrid search from a theoretical standpoint.
|
||||
# Empirically on a small dataset of up to 10K docs, it's not very different. Likely more impactful at scale.
|
||||
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
)
|
||||
@@ -915,7 +881,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
index_filters=filters,
|
||||
num_to_retrieve=num_to_retrieve,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
@@ -943,6 +909,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# Do not raise if the document already exists, just update. This is
|
||||
# because the document may already have been indexed during the
|
||||
# OpenSearch transition period.
|
||||
self._client.bulk_index_documents(
|
||||
self._os_client.bulk_index_documents(
|
||||
documents=chunks, tenant_state=self._tenant_state, update_if_exists=True
|
||||
)
|
||||
|
||||
@@ -405,7 +405,6 @@ class PersonaShareRequest(BaseModel):
|
||||
user_ids: list[UUID] | None = None
|
||||
group_ids: list[int] | None = None
|
||||
is_public: bool | None = None
|
||||
label_ids: list[int] | None = None
|
||||
|
||||
|
||||
# We notify each user when a user is shared with them
|
||||
@@ -416,22 +415,14 @@ def share_persona(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
update_persona_shared(
|
||||
persona_id=persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
user_ids=persona_share_request.user_ids,
|
||||
group_ids=persona_share_request.group_ids,
|
||||
is_public=persona_share_request.is_public,
|
||||
label_ids=persona_share_request.label_ids,
|
||||
)
|
||||
except PermissionError as e:
|
||||
logger.exception("Failed to share persona")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to share persona")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
update_persona_shared(
|
||||
persona_id=persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
user_ids=persona_share_request.user_ids,
|
||||
group_ids=persona_share_request.group_ids,
|
||||
is_public=persona_share_request.is_public,
|
||||
)
|
||||
|
||||
|
||||
@basic_router.delete("/{persona_id}", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@@ -22,7 +22,10 @@ from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_default_vision_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_provider_by_id
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_existing_models
|
||||
from onyx.db.llm import fetch_persona_with_groups
|
||||
@@ -52,8 +55,10 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
)
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LLMCost
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
@@ -233,12 +238,9 @@ def test_llm_configuration(
|
||||
|
||||
test_api_key = test_llm_request.api_key
|
||||
test_custom_config = test_llm_request.custom_config
|
||||
if test_llm_request.name:
|
||||
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
|
||||
# as it turns out the name is not editable in the UI and other code also keys off name,
|
||||
# so we won't rock the boat just yet.
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=test_llm_request.name, db_session=db_session
|
||||
if test_llm_request.id:
|
||||
existing_provider = fetch_existing_llm_provider_by_id(
|
||||
id=test_llm_request.id, db_session=db_session
|
||||
)
|
||||
if existing_provider:
|
||||
test_custom_config = _restore_masked_custom_config_values(
|
||||
@@ -268,7 +270,7 @@ def test_llm_configuration(
|
||||
|
||||
llm = get_llm(
|
||||
provider=test_llm_request.provider,
|
||||
model=test_llm_request.default_model_name,
|
||||
model=test_llm_request.model,
|
||||
api_key=test_api_key,
|
||||
api_base=test_llm_request.api_base,
|
||||
api_version=test_llm_request.api_version,
|
||||
@@ -303,7 +305,7 @@ def list_llm_providers(
|
||||
include_image_gen: bool = Query(False),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderView]:
|
||||
) -> LLMProviderResponse[LLMProviderView]:
|
||||
start_time = datetime.now(timezone.utc)
|
||||
logger.debug("Starting to fetch LLM providers")
|
||||
|
||||
@@ -328,7 +330,15 @@ def list_llm_providers(
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
|
||||
|
||||
return llm_provider_list
|
||||
return LLMProviderResponse[LLMProviderView].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=DefaultModel.from_model_config(
|
||||
fetch_default_llm_model(db_session)
|
||||
),
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@admin_router.put("/provider")
|
||||
@@ -516,7 +526,7 @@ def get_auto_config(
|
||||
def get_vision_capable_providers(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[VisionProviderResponse]:
|
||||
) -> LLMProviderResponse[VisionProviderResponse]:
|
||||
"""Return a list of LLM providers and their models that support image input"""
|
||||
vision_models = fetch_existing_models(
|
||||
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
|
||||
@@ -545,7 +555,13 @@ def get_vision_capable_providers(
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
|
||||
return vision_provider_response
|
||||
|
||||
return LLMProviderResponse[VisionProviderResponse].from_models(
|
||||
providers=vision_provider_response,
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
"""Endpoints for all"""
|
||||
@@ -555,7 +571,7 @@ def get_vision_capable_providers(
|
||||
def list_llm_provider_basics(
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
"""Get LLM providers accessible to the current user.
|
||||
|
||||
Returns:
|
||||
@@ -592,7 +608,15 @@ def list_llm_provider_basics(
|
||||
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return accessible_providers
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=accessible_providers,
|
||||
default_text=DefaultModel.from_model_config(
|
||||
fetch_default_llm_model(db_session)
|
||||
),
|
||||
default_vision=DefaultModel.from_model_config(
|
||||
fetch_default_vision_model(db_session)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_valid_model_names_for_persona(
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -21,6 +25,8 @@ if TYPE_CHECKING:
|
||||
ModelConfiguration as ModelConfigurationModel,
|
||||
)
|
||||
|
||||
T = TypeVar("T", "LLMProviderDescriptor", "LLMProviderView")
|
||||
|
||||
|
||||
# TODO: Clear this up on api refactor
|
||||
# There is still logic that requires sending each providers default model name
|
||||
@@ -52,19 +58,17 @@ def get_default_vision_model_name(llm_provider_model: "LLMProviderModel") -> str
|
||||
|
||||
class TestLLMRequest(BaseModel):
|
||||
# provider level
|
||||
name: str | None = None
|
||||
id: int | None = None
|
||||
provider: str
|
||||
model: str
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
|
||||
# model level
|
||||
default_model_name: str
|
||||
deployment_name: str | None = None
|
||||
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"]
|
||||
|
||||
# if try and use the existing API/custom config key
|
||||
api_key_changed: bool
|
||||
custom_config_changed: bool
|
||||
@@ -425,3 +429,38 @@ class OpenRouterFinalModelResponse(BaseModel):
|
||||
int | None
|
||||
) # From OpenRouter API context_length (may be missing for some models)
|
||||
supports_image_input: bool
|
||||
|
||||
|
||||
class DefaultModel(BaseModel):
|
||||
provider_id: int
|
||||
model_name: str
|
||||
|
||||
@classmethod
|
||||
def from_model_config(
|
||||
cls, model_config: ModelConfigurationModel | None
|
||||
) -> DefaultModel | None:
|
||||
if not model_config:
|
||||
return None
|
||||
return cls(
|
||||
provider_id=model_config.llm_provider_id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
|
||||
class LLMProviderResponse(BaseModel, Generic[T]):
|
||||
providers: list[T]
|
||||
default_text: DefaultModel | None = None
|
||||
default_vision: DefaultModel | None = None
|
||||
|
||||
@classmethod
|
||||
def from_models(
|
||||
cls,
|
||||
providers: list[T],
|
||||
default_text: DefaultModel | None = None,
|
||||
default_vision: DefaultModel | None = None,
|
||||
) -> LLMProviderResponse[T]:
|
||||
return cls(
|
||||
providers=providers,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
@@ -4,7 +4,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
|
||||
@@ -33,9 +32,6 @@ from onyx.db.search_settings import update_current_search_settings
|
||||
from onyx.db.swap_index import check_and_perform_index_swap
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.opensearch_document_index import set_cluster_state
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
@@ -315,14 +311,7 @@ def setup_multitenant_onyx() -> None:
|
||||
logger.notice("DISABLE_VECTOR_DB is set — skipping multitenant Vespa setup.")
|
||||
return
|
||||
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
opensearch_client = OpenSearchClient()
|
||||
if not wait_for_opensearch_with_timeout(client=opensearch_client):
|
||||
raise RuntimeError("Failed to connect to OpenSearch.")
|
||||
set_cluster_state(opensearch_client)
|
||||
|
||||
# For Managed Vespa, the schema is sent over via the Vespa Console manually.
|
||||
# NOTE: Pretty sure this code is never hit in any production environment.
|
||||
if not MANAGED_VESPA:
|
||||
setup_vespa_multitenant(SUPPORTED_EMBEDDING_MODELS)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
|
||||
# Prepare the test request payload
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"model": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
@@ -44,7 +44,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
|
||||
# Prepare the test request payload with invalid credentials
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"model": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
|
||||
@@ -29,6 +29,7 @@ from onyx.server.manage.llm.api import (
|
||||
test_llm_configuration as run_test_llm_configuration,
|
||||
)
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import TestLLMRequest as LLMTestRequest
|
||||
|
||||
@@ -44,9 +45,9 @@ def _create_test_provider(
|
||||
db_session: Session,
|
||||
name: str,
|
||||
api_key: str = "sk-test-key-00000000000000000000000000000000000",
|
||||
) -> None:
|
||||
) -> LLMProviderView:
|
||||
"""Helper to create a test LLM provider in the database."""
|
||||
upsert_llm_provider(
|
||||
return upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
@@ -102,17 +103,11 @@ class TestLLMConfigurationEndpoint:
|
||||
# This should complete without exception
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=None, # New provider (not in DB)
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-test-key-0000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -152,17 +147,11 @@ class TestLLMConfigurationEndpoint:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=None,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-invalid-key-00000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -194,7 +183,9 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database
|
||||
_create_test_provider(db_session, provider_name, api_key=original_api_key)
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_key=original_api_key
|
||||
)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -202,17 +193,12 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with api_key_changed=False - should use stored key
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=provider_name, # Existing provider
|
||||
id=provider.id,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None, # Not providing a new key
|
||||
api_key_changed=False, # Using existing key
|
||||
custom_config_changed=False,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -246,7 +232,9 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database
|
||||
_create_test_provider(db_session, provider_name, api_key=original_api_key)
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_key=original_api_key
|
||||
)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -254,17 +242,12 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with api_key_changed=True - should use new key
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=provider_name, # Existing provider
|
||||
id=provider.id,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=new_api_key, # Providing a new key
|
||||
api_key_changed=True, # Key is being changed
|
||||
custom_config_changed=False,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -297,7 +280,7 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database with custom_config
|
||||
upsert_llm_provider(
|
||||
provider = upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
@@ -321,18 +304,13 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with custom_config_changed=False - should use stored config
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=provider_name,
|
||||
id=provider.id,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
custom_config=None, # Not providing new config
|
||||
custom_config_changed=False, # Using existing config
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -368,17 +346,11 @@ class TestLLMConfigurationEndpoint:
|
||||
for model_name in test_models:
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
name=None,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
default_model_name=model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
model=model_name,
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
|
||||
@@ -530,14 +530,8 @@ def test_upload_with_custom_config_then_change(
|
||||
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
model=default_model_name,
|
||||
api_key_changed=False,
|
||||
custom_config_changed=True,
|
||||
custom_config=custom_config,
|
||||
@@ -546,7 +540,7 @@ def test_upload_with_custom_config_then_change(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
put_llm_provider(
|
||||
provider = put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
@@ -569,14 +563,9 @@ def test_upload_with_custom_config_then_change(
|
||||
# Turn auto mode off
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
name=name,
|
||||
id=provider.id,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
model=default_model_name,
|
||||
api_key_changed=False,
|
||||
custom_config_changed=False,
|
||||
),
|
||||
@@ -616,13 +605,13 @@ def test_upload_with_custom_config_then_change(
|
||||
)
|
||||
|
||||
# Check inside the database and check that custom_config is the same as the original
|
||||
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
if not provider:
|
||||
db_provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
if not db_provider:
|
||||
assert False, "Provider not found in the database"
|
||||
|
||||
assert provider.custom_config == custom_config, (
|
||||
assert db_provider.custom_config == custom_config, (
|
||||
f"Expected custom_config {custom_config}, "
|
||||
f"but got {provider.custom_config}"
|
||||
f"but got {db_provider.custom_config}"
|
||||
)
|
||||
finally:
|
||||
db_session.rollback()
|
||||
@@ -706,7 +695,7 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
) -> None:
|
||||
"""LLM test should restore masked sensitive custom config values before invocation."""
|
||||
name = f"test-provider-vertex-test-{uuid4().hex[:8]}"
|
||||
provider = LlmProviderNames.VERTEX_AI.value
|
||||
provider_name = LlmProviderNames.VERTEX_AI.value
|
||||
default_model_name = "gemini-2.5-pro"
|
||||
original_custom_config = {
|
||||
"vertex_credentials": '{"type":"service_account","private_key":"REAL_PRIVATE_KEY"}',
|
||||
@@ -719,10 +708,10 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
return ""
|
||||
|
||||
try:
|
||||
put_llm_provider(
|
||||
provider = put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=original_custom_config,
|
||||
model_configurations=[
|
||||
@@ -742,14 +731,9 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
id=provider.id,
|
||||
provider=provider_name,
|
||||
model=default_model_name,
|
||||
api_key_changed=False,
|
||||
custom_config_changed=True,
|
||||
custom_config={
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""External dependency unit tests for OpenSearchIndexClient.
|
||||
"""External dependency unit tests for OpenSearchClient.
|
||||
|
||||
These tests assume OpenSearch is running and test all implemented methods
|
||||
using real schemas, pipelines, and search queries from the codebase.
|
||||
@@ -19,7 +19,7 @@ from onyx.access.utils import prefix_user_email
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
@@ -125,10 +125,10 @@ def opensearch_available() -> None:
|
||||
@pytest.fixture(scope="function")
|
||||
def test_client(
|
||||
opensearch_available: None, # noqa: ARG001
|
||||
) -> Generator[OpenSearchIndexClient, None, None]:
|
||||
) -> Generator[OpenSearchClient, None, None]:
|
||||
"""Creates an OpenSearch client for testing with automatic cleanup."""
|
||||
test_index_name = f"test_index_{uuid.uuid4().hex[:8]}"
|
||||
client = OpenSearchIndexClient(index_name=test_index_name)
|
||||
client = OpenSearchClient(index_name=test_index_name)
|
||||
|
||||
yield client # Test runs here.
|
||||
|
||||
@@ -142,7 +142,7 @@ def test_client(
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def search_pipeline(test_client: OpenSearchIndexClient) -> Generator[None, None, None]:
|
||||
def search_pipeline(test_client: OpenSearchClient) -> Generator[None, None, None]:
|
||||
"""Creates a search pipeline for testing with automatic cleanup."""
|
||||
test_client.create_search_pipeline(
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
@@ -158,9 +158,9 @@ def search_pipeline(test_client: OpenSearchIndexClient) -> Generator[None, None,
|
||||
|
||||
|
||||
class TestOpenSearchClient:
|
||||
"""Tests for OpenSearchIndexClient."""
|
||||
"""Tests for OpenSearchClient."""
|
||||
|
||||
def test_create_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
def test_create_index(self, test_client: OpenSearchClient) -> None:
|
||||
"""Tests creating an index with a real schema."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -176,7 +176,7 @@ class TestOpenSearchClient:
|
||||
# Verify index exists.
|
||||
assert test_client.validate_index(expected_mappings=mappings) is True
|
||||
|
||||
def test_delete_existing_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
def test_delete_existing_index(self, test_client: OpenSearchClient) -> None:
|
||||
"""Tests deleting an existing index returns True."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -193,7 +193,7 @@ class TestOpenSearchClient:
|
||||
assert result is True
|
||||
assert test_client.validate_index(expected_mappings=mappings) is False
|
||||
|
||||
def test_delete_nonexistent_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
def test_delete_nonexistent_index(self, test_client: OpenSearchClient) -> None:
|
||||
"""Tests deleting a nonexistent index returns False."""
|
||||
# Under test.
|
||||
# Don't create index, just try to delete.
|
||||
@@ -202,7 +202,7 @@ class TestOpenSearchClient:
|
||||
# Postcondition.
|
||||
assert result is False
|
||||
|
||||
def test_index_exists(self, test_client: OpenSearchIndexClient) -> None:
|
||||
def test_index_exists(self, test_client: OpenSearchClient) -> None:
|
||||
"""Tests checking if an index exists."""
|
||||
# Precondition.
|
||||
# Index should not exist before creation.
|
||||
@@ -219,7 +219,7 @@ class TestOpenSearchClient:
|
||||
# Index should exist after creation.
|
||||
assert test_client.index_exists() is True
|
||||
|
||||
def test_validate_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
def test_validate_index(self, test_client: OpenSearchClient) -> None:
|
||||
"""Tests validating an index."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -239,120 +239,7 @@ class TestOpenSearchClient:
|
||||
# Should return True after creation.
|
||||
assert test_client.validate_index(expected_mappings=mappings) is True
|
||||
|
||||
def test_put_mapping_idempotent(self, test_client: OpenSearchIndexClient) -> None:
|
||||
"""Tests put_mapping with same schema is idempotent."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Under test.
|
||||
# Applying the same mappings again should succeed.
|
||||
test_client.put_mapping(mappings)
|
||||
|
||||
# Postcondition.
|
||||
# Index should still be valid.
|
||||
assert test_client.validate_index(expected_mappings=mappings)
|
||||
|
||||
def test_put_mapping_adds_new_field(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests put_mapping successfully adds new fields to existing index."""
|
||||
# Precondition.
|
||||
# Create index with minimal schema (just required fields).
|
||||
initial_mappings = {
|
||||
"dynamic": "strict",
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"chunk_index": {"type": "integer"},
|
||||
"content": {"type": "text"},
|
||||
"content_vector": {
|
||||
"type": "knn_vector",
|
||||
"dimension": 128,
|
||||
"method": {
|
||||
"name": "hnsw",
|
||||
"space_type": "cosinesimil",
|
||||
"engine": "lucene",
|
||||
"parameters": {"ef_construction": 512, "m": 16},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=initial_mappings, settings=settings)
|
||||
|
||||
# Under test.
|
||||
# Add a new field using put_mapping.
|
||||
updated_mappings = {
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"chunk_index": {"type": "integer"},
|
||||
"content": {"type": "text"},
|
||||
"content_vector": {
|
||||
"type": "knn_vector",
|
||||
"dimension": 128,
|
||||
"method": {
|
||||
"name": "hnsw",
|
||||
"space_type": "cosinesimil",
|
||||
"engine": "lucene",
|
||||
"parameters": {"ef_construction": 512, "m": 16},
|
||||
},
|
||||
},
|
||||
# New field
|
||||
"new_test_field": {"type": "keyword"},
|
||||
},
|
||||
}
|
||||
# Should not raise.
|
||||
test_client.put_mapping(updated_mappings)
|
||||
|
||||
# Postcondition.
|
||||
# Validate the new schema includes the new field.
|
||||
assert test_client.validate_index(expected_mappings=updated_mappings)
|
||||
|
||||
def test_put_mapping_fails_on_type_change(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests put_mapping fails when trying to change existing field type."""
|
||||
# Precondition.
|
||||
initial_mappings = {
|
||||
"dynamic": "strict",
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"test_field": {"type": "keyword"},
|
||||
},
|
||||
}
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=initial_mappings, settings=settings)
|
||||
|
||||
# Under test and postcondition.
|
||||
# Try to change test_field type from keyword to text.
|
||||
conflicting_mappings = {
|
||||
"properties": {
|
||||
"document_id": {"type": "keyword"},
|
||||
"test_field": {"type": "text"}, # Changed from keyword to text
|
||||
},
|
||||
}
|
||||
# Should raise because field type cannot be changed.
|
||||
with pytest.raises(Exception, match="mapper|illegal_argument_exception"):
|
||||
test_client.put_mapping(conflicting_mappings)
|
||||
|
||||
def test_put_mapping_on_nonexistent_index(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
) -> None:
|
||||
"""Tests put_mapping on non-existent index raises an error."""
|
||||
# Precondition.
|
||||
# Index does not exist yet.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
)
|
||||
|
||||
# Under test and postcondition.
|
||||
with pytest.raises(Exception, match="index_not_found_exception|404"):
|
||||
test_client.put_mapping(mappings)
|
||||
|
||||
def test_create_duplicate_index(self, test_client: OpenSearchIndexClient) -> None:
|
||||
def test_create_duplicate_index(self, test_client: OpenSearchClient) -> None:
|
||||
"""Tests creating an index twice raises an error."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
@@ -367,14 +254,14 @@ class TestOpenSearchClient:
|
||||
with pytest.raises(Exception, match="already exists"):
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
def test_update_settings(self, test_client: OpenSearchIndexClient) -> None:
|
||||
def test_update_settings(self, test_client: OpenSearchClient) -> None:
|
||||
"""Tests that update_settings raises NotImplementedError."""
|
||||
# Under test and postcondition.
|
||||
with pytest.raises(NotImplementedError):
|
||||
test_client.update_settings(settings={})
|
||||
|
||||
def test_create_and_delete_search_pipeline(
|
||||
self, test_client: OpenSearchIndexClient
|
||||
self, test_client: OpenSearchClient
|
||||
) -> None:
|
||||
"""Tests creating and deleting a search pipeline."""
|
||||
# Under test and postcondition.
|
||||
@@ -391,7 +278,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_index_document(
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests indexing a document."""
|
||||
# Precondition.
|
||||
@@ -419,7 +306,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_bulk_index_documents(
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests bulk indexing documents."""
|
||||
# Precondition.
|
||||
@@ -450,7 +337,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_index_duplicate_document(
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests indexing a duplicate document raises an error."""
|
||||
# Precondition.
|
||||
@@ -478,7 +365,7 @@ class TestOpenSearchClient:
|
||||
test_client.index_document(document=doc, tenant_state=tenant_state)
|
||||
|
||||
def test_get_document(
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests getting a document."""
|
||||
# Precondition.
|
||||
@@ -514,7 +401,7 @@ class TestOpenSearchClient:
|
||||
assert retrieved_doc == original_doc
|
||||
|
||||
def test_get_nonexistent_document(
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests getting a nonexistent document raises an error."""
|
||||
# Precondition.
|
||||
@@ -532,7 +419,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_delete_existing_document(
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests deleting an existing document returns True."""
|
||||
# Precondition.
|
||||
@@ -568,7 +455,7 @@ class TestOpenSearchClient:
|
||||
test_client.get_document(document_chunk_id=doc_chunk_id)
|
||||
|
||||
def test_delete_nonexistent_document(
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests deleting a nonexistent document returns False."""
|
||||
# Precondition.
|
||||
@@ -589,7 +476,7 @@ class TestOpenSearchClient:
|
||||
assert result is False
|
||||
|
||||
def test_delete_by_query(
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests deleting documents by query."""
|
||||
# Precondition.
|
||||
@@ -665,7 +552,7 @@ class TestOpenSearchClient:
|
||||
assert len(keep_ids) == 1
|
||||
|
||||
def test_update_document(
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests updating a document's properties."""
|
||||
# Precondition.
|
||||
@@ -714,7 +601,7 @@ class TestOpenSearchClient:
|
||||
assert updated_doc.public == doc.public
|
||||
|
||||
def test_update_nonexistent_document(
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests updating a nonexistent document raises an error."""
|
||||
# Precondition.
|
||||
@@ -736,7 +623,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_hybrid_search_with_pipeline(
|
||||
self,
|
||||
test_client: OpenSearchIndexClient,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -817,7 +704,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_search_empty_index(
|
||||
self,
|
||||
test_client: OpenSearchIndexClient,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -856,7 +743,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_hybrid_search_with_pipeline_and_filters(
|
||||
self,
|
||||
test_client: OpenSearchIndexClient,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -976,7 +863,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_hybrid_search_with_pipeline_and_filters_returns_chunks_with_related_content_first(
|
||||
self,
|
||||
test_client: OpenSearchIndexClient,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -1106,7 +993,7 @@ class TestOpenSearchClient:
|
||||
previous_score = current_score
|
||||
|
||||
def test_delete_by_query_multitenant_isolation(
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests delete_by_query respects tenant boundaries in multi-tenant mode.
|
||||
@@ -1200,7 +1087,7 @@ class TestOpenSearchClient:
|
||||
assert set(remaining_y_ids) == expected_y_ids
|
||||
|
||||
def test_delete_by_query_nonexistent_document(
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests delete_by_query for non-existent document returns 0 deleted.
|
||||
@@ -1229,7 +1116,7 @@ class TestOpenSearchClient:
|
||||
assert num_deleted == 0
|
||||
|
||||
def test_search_for_document_ids(
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests search_for_document_ids method returns correct chunk IDs."""
|
||||
# Precondition.
|
||||
@@ -1294,7 +1181,7 @@ class TestOpenSearchClient:
|
||||
assert set(chunk_ids) == expected_ids
|
||||
|
||||
def test_search_with_no_document_access_can_retrieve_all_documents(
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests search with no document access can retrieve all documents, even
|
||||
@@ -1372,7 +1259,7 @@ class TestOpenSearchClient:
|
||||
|
||||
def test_time_cutoff_filter(
|
||||
self,
|
||||
test_client: OpenSearchIndexClient,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None, # noqa: ARG002
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@@ -1465,7 +1352,7 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
def test_random_search(
|
||||
self, test_client: OpenSearchIndexClient, monkeypatch: pytest.MonkeyPatch
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests the random search query works."""
|
||||
# Precondition.
|
||||
|
||||
@@ -37,7 +37,6 @@ from onyx.db.opensearch_migration import build_sanitized_to_original_doc_id_mapp
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
from onyx.document_index.opensearch.client import OpenSearchIndexClient
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
@@ -75,7 +74,7 @@ CHUNK_COUNT = 5
|
||||
|
||||
|
||||
def _get_document_chunks_from_opensearch(
|
||||
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
|
||||
opensearch_client: OpenSearchClient, document_id: str, current_tenant_id: str
|
||||
) -> list[DocumentChunk]:
|
||||
opensearch_client.refresh_index()
|
||||
filters = IndexFilters(access_control_list=None, tenant_id=current_tenant_id)
|
||||
@@ -96,7 +95,7 @@ def _get_document_chunks_from_opensearch(
|
||||
|
||||
|
||||
def _delete_document_chunks_from_opensearch(
|
||||
opensearch_client: OpenSearchIndexClient, document_id: str, current_tenant_id: str
|
||||
opensearch_client: OpenSearchClient, document_id: str, current_tenant_id: str
|
||||
) -> None:
|
||||
opensearch_client.refresh_index()
|
||||
query_body = DocumentQuery.delete_from_document_id_query(
|
||||
@@ -284,10 +283,10 @@ def vespa_document_index(
|
||||
def opensearch_client(
|
||||
db_session: Session,
|
||||
full_deployment_setup: None, # noqa: ARG001
|
||||
) -> Generator[OpenSearchIndexClient, None, None]:
|
||||
) -> Generator[OpenSearchClient, None, None]:
|
||||
"""Creates an OpenSearch client for the test tenant."""
|
||||
active = get_active_search_settings(db_session)
|
||||
yield OpenSearchIndexClient(index_name=active.primary.index_name) # Test runs here.
|
||||
yield OpenSearchClient(index_name=active.primary.index_name) # Test runs here.
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -331,7 +330,7 @@ def patch_get_vespa_chunks_page_size() -> Generator[int, None, None]:
|
||||
def test_documents(
|
||||
db_session: Session,
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
opensearch_client: OpenSearchClient,
|
||||
patch_get_vespa_chunks_page_size: int,
|
||||
) -> Generator[list[Document], None, None]:
|
||||
"""
|
||||
@@ -412,7 +411,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
opensearch_client: OpenSearchClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
@@ -481,7 +480,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
opensearch_client: OpenSearchClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
@@ -619,7 +618,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
opensearch_client: OpenSearchClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
@@ -713,7 +712,7 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
opensearch_client: OpenSearchClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
|
||||
@@ -20,7 +20,6 @@ from onyx.auth.oauth_token_manager import OAuthTokenManager
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.oauth_config import create_oauth_config
|
||||
from onyx.db.oauth_config import upsert_user_oauth_token
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
|
||||
|
||||
@@ -492,19 +491,3 @@ class TestOAuthTokenManagerURLBuilding:
|
||||
# Should use & instead of ? since URL already has query params
|
||||
assert "foo=bar&" in url or "?foo=bar" in url
|
||||
assert "client_id=custom_client_id" in url
|
||||
|
||||
|
||||
class TestUnwrapSensitiveStr:
|
||||
"""Tests for _unwrap_sensitive_str static method"""
|
||||
|
||||
def test_unwrap_sensitive_str(self) -> None:
|
||||
"""Test that both SensitiveValue and plain str inputs are handled"""
|
||||
# SensitiveValue input
|
||||
sensitive = SensitiveValue[str](
|
||||
encrypted_bytes=b"test_client_id",
|
||||
decrypt_fn=lambda b: b.decode(),
|
||||
)
|
||||
assert OAuthTokenManager._unwrap_sensitive_str(sensitive) == "test_client_id"
|
||||
|
||||
# Plain str input
|
||||
assert OAuthTokenManager._unwrap_sensitive_str("plain_string") == "plain_string"
|
||||
|
||||
@@ -72,7 +72,7 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: int) -> dict:
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
for provider in response.json():
|
||||
for provider in response.json()["providers"]:
|
||||
if provider["id"] == provider_id:
|
||||
return provider
|
||||
raise ValueError(f"Provider with id {provider_id} not found")
|
||||
|
||||
@@ -23,7 +23,7 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
providers = response.json()["providers"]
|
||||
return next((p for p in providers if p["id"] == provider_id), None)
|
||||
|
||||
|
||||
@@ -578,6 +578,57 @@ def test_model_visibility_preserved_on_edit(reset: None) -> None: # noqa: ARG00
|
||||
assert visible_models[0]["name"] == "gpt-4o"
|
||||
|
||||
|
||||
def _get_provider_by_name(providers: list[dict], provider_name: str) -> dict | None:
|
||||
return next((p for p in providers if p["name"] == provider_name), None)
|
||||
|
||||
|
||||
def _get_providers_admin(
|
||||
admin_user: DATestUser,
|
||||
) -> dict | None:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
resp_json = response.json()
|
||||
|
||||
return resp_json
|
||||
|
||||
|
||||
def _unpack_data(data: dict) -> tuple[list[dict], dict | None, dict | None]:
|
||||
providers = data["providers"]
|
||||
text_default = data.get("default_text")
|
||||
vision_default = data.get("default_vision")
|
||||
|
||||
return providers, text_default, vision_default
|
||||
|
||||
|
||||
def _get_providers_basic(
|
||||
user: DATestUser,
|
||||
) -> dict | None:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/llm/provider",
|
||||
headers=user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
resp_json = response.json()
|
||||
|
||||
return resp_json
|
||||
|
||||
|
||||
def _validate_default_model(
|
||||
default: dict | None,
|
||||
provider_id: int | None = None,
|
||||
model_name: str | None = None,
|
||||
) -> None:
|
||||
if default is None:
|
||||
assert provider_id is None and model_name is None
|
||||
return
|
||||
|
||||
assert default["provider_id"] == provider_id
|
||||
assert default["model_name"] == model_name
|
||||
|
||||
|
||||
def _get_provider_by_name_admin(
|
||||
admin_user: DATestUser, provider_name: str
|
||||
) -> dict | None:
|
||||
@@ -598,7 +649,7 @@ def _get_provider_by_name_basic(user: DATestUser, provider_name: str) -> dict |
|
||||
headers=user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
providers = response.json()["providers"]
|
||||
return next((p for p in providers if p["name"] == provider_name), None)
|
||||
|
||||
|
||||
@@ -782,9 +833,22 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
|
||||
# Capture initial defaults (setup_postgres may have created a DevEnvPresetOpenAI default)
|
||||
initial_data = _get_providers_admin(admin_user)
|
||||
assert initial_data is not None
|
||||
_, initial_text_default, initial_vision_default = _unpack_data(initial_data)
|
||||
|
||||
# Step 2: Verify via admin endpoint that all provider data is correct
|
||||
admin_provider_data = _get_provider_by_name_admin(admin_user, provider_name)
|
||||
admin_data = _get_providers_admin(admin_user)
|
||||
assert admin_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_data)
|
||||
# Defaults should be unchanged from initial state (new provider not set as default)
|
||||
assert text_default == initial_text_default
|
||||
assert vision_default == initial_vision_default
|
||||
|
||||
admin_provider_data = _get_provider_by_name(providers, provider_name)
|
||||
assert admin_provider_data is not None
|
||||
|
||||
_validate_provider_data(
|
||||
admin_provider_data,
|
||||
expected_name=provider_name,
|
||||
@@ -797,7 +861,13 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Step 3: Verify via basic endpoint (admin user) that all provider data is correct
|
||||
admin_basic_provider_data = _get_provider_by_name_basic(admin_user, provider_name)
|
||||
admin_basic_data = _get_providers_basic(admin_user)
|
||||
assert admin_basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_basic_data)
|
||||
assert text_default == initial_text_default
|
||||
assert vision_default == initial_vision_default
|
||||
|
||||
admin_basic_provider_data = _get_provider_by_name(providers, provider_name)
|
||||
assert admin_basic_provider_data is not None
|
||||
_validate_provider_data(
|
||||
admin_basic_provider_data,
|
||||
@@ -810,7 +880,13 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Step 4: Verify non-admin user sees the same provider data via basic endpoint
|
||||
basic_user_provider_data = _get_provider_by_name_basic(basic_user, provider_name)
|
||||
basic_user_data = _get_providers_basic(basic_user)
|
||||
assert basic_user_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(basic_user_data)
|
||||
assert text_default == initial_text_default
|
||||
assert vision_default == initial_vision_default
|
||||
|
||||
basic_user_provider_data = _get_provider_by_name(providers, provider_name)
|
||||
assert basic_user_provider_data is not None
|
||||
_validate_provider_data(
|
||||
basic_user_provider_data,
|
||||
@@ -846,7 +922,17 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
|
||||
assert default_provider_response.status_code == 200
|
||||
|
||||
# Step 6a: Verify the updated provider data via admin endpoint
|
||||
admin_provider_data = _get_provider_by_name_admin(admin_user, provider_name)
|
||||
admin_data = _get_providers_admin(admin_user)
|
||||
assert admin_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_data)
|
||||
_validate_default_model(
|
||||
text_default,
|
||||
provider_id=update_response.json()["id"],
|
||||
model_name=updated_default_model,
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
|
||||
admin_provider_data = _get_provider_by_name(providers, provider_name)
|
||||
assert admin_provider_data is not None
|
||||
_validate_provider_data(
|
||||
admin_provider_data,
|
||||
@@ -860,7 +946,17 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Step 6b: Verify the updated provider data via basic endpoint (admin user)
|
||||
admin_basic_provider_data = _get_provider_by_name_basic(admin_user, provider_name)
|
||||
admin_basic_data = _get_providers_basic(admin_user)
|
||||
assert admin_basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_basic_data)
|
||||
_validate_default_model(
|
||||
text_default,
|
||||
provider_id=update_response.json()["id"],
|
||||
model_name=updated_default_model,
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
|
||||
admin_basic_provider_data = _get_provider_by_name(providers, provider_name)
|
||||
assert admin_basic_provider_data is not None
|
||||
_validate_provider_data(
|
||||
admin_basic_provider_data,
|
||||
@@ -873,7 +969,17 @@ def test_default_model_persistence_and_update(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Step 7: Verify non-admin user sees the updated provider data
|
||||
basic_user_provider_data = _get_provider_by_name_basic(basic_user, provider_name)
|
||||
basic_user_data = _get_providers_basic(basic_user)
|
||||
assert basic_user_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(basic_user_data)
|
||||
_validate_default_model(
|
||||
text_default,
|
||||
provider_id=update_response.json()["id"],
|
||||
model_name=updated_default_model,
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
|
||||
basic_user_provider_data = _get_provider_by_name(providers, provider_name)
|
||||
assert basic_user_provider_data is not None
|
||||
_validate_provider_data(
|
||||
basic_user_provider_data,
|
||||
@@ -893,7 +999,7 @@ def _get_all_providers_basic(user: DATestUser) -> list[dict]:
|
||||
headers=user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()
|
||||
return response.json()["providers"]
|
||||
|
||||
|
||||
def _get_all_providers_admin(admin_user: DATestUser) -> list[dict]:
|
||||
@@ -903,7 +1009,7 @@ def _get_all_providers_admin(admin_user: DATestUser) -> list[dict]:
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()
|
||||
return response.json()["providers"]
|
||||
|
||||
|
||||
def _set_default_provider(admin_user: DATestUser, provider_id: int) -> None:
|
||||
@@ -1039,11 +1145,17 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
|
||||
# Step 3: Both admin and basic_user query and verify they see the same default
|
||||
# Validate via admin endpoint
|
||||
admin_providers = _get_all_providers_admin(admin_user)
|
||||
admin_default = _find_default_provider(admin_providers)
|
||||
assert admin_default is not None
|
||||
admin_data = _get_providers_admin(admin_user)
|
||||
assert admin_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=provider_1["id"], model_name=shared_model_name
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
admin_provider_data = _get_provider_by_name(providers, provider_1_name)
|
||||
assert admin_provider_data is not None
|
||||
_validate_provider_data(
|
||||
admin_default,
|
||||
admin_provider_data,
|
||||
expected_name=provider_1_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=shared_model_name,
|
||||
@@ -1054,9 +1166,7 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Validate provider 2 via admin endpoint (should not be default)
|
||||
admin_provider_2 = next(
|
||||
(p for p in admin_providers if p["name"] == provider_2_name), None
|
||||
)
|
||||
admin_provider_2 = _get_provider_by_name(providers, provider_2_name)
|
||||
assert admin_provider_2 is not None
|
||||
_validate_provider_data(
|
||||
admin_provider_2,
|
||||
@@ -1070,11 +1180,17 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Validate via basic endpoint (basic_user)
|
||||
basic_providers = _get_all_providers_basic(basic_user)
|
||||
basic_default = _find_default_provider(basic_providers)
|
||||
assert basic_default is not None
|
||||
basic_data = _get_providers_basic(basic_user)
|
||||
assert basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(basic_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=provider_1["id"], model_name=shared_model_name
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
basic_provider_data = _get_provider_by_name(providers, provider_1_name)
|
||||
assert basic_provider_data is not None
|
||||
_validate_provider_data(
|
||||
basic_default,
|
||||
basic_provider_data,
|
||||
expected_name=provider_1_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=shared_model_name,
|
||||
@@ -1084,11 +1200,17 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Also verify admin sees the same via basic endpoint
|
||||
admin_basic_providers = _get_all_providers_basic(admin_user)
|
||||
admin_basic_default = _find_default_provider(admin_basic_providers)
|
||||
assert admin_basic_default is not None
|
||||
admin_basic_data = _get_providers_basic(admin_user)
|
||||
assert admin_basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_basic_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=provider_1["id"], model_name=shared_model_name
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
admin_basic_provider_data = _get_provider_by_name(providers, provider_1_name)
|
||||
assert admin_basic_provider_data is not None
|
||||
_validate_provider_data(
|
||||
admin_basic_default,
|
||||
admin_basic_provider_data,
|
||||
expected_name=provider_1_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=shared_model_name,
|
||||
@@ -1121,11 +1243,17 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
|
||||
# Step 5: Both admin and basic_user verify they see the updated default
|
||||
# Validate via admin endpoint
|
||||
admin_providers = _get_all_providers_admin(admin_user)
|
||||
admin_default = _find_default_provider(admin_providers)
|
||||
assert admin_default is not None
|
||||
admin_data = _get_providers_admin(admin_user)
|
||||
assert admin_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=provider_2["id"], model_name=provider_2_unique_model
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
admin_provider_data = _get_provider_by_name(providers, provider_2_name)
|
||||
assert admin_provider_data is not None
|
||||
_validate_provider_data(
|
||||
admin_default,
|
||||
admin_provider_data,
|
||||
expected_name=provider_2_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=provider_2_unique_model,
|
||||
@@ -1136,9 +1264,7 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Validate provider 1 via admin endpoint (should no longer be default)
|
||||
admin_provider_1 = next(
|
||||
(p for p in admin_providers if p["name"] == provider_1_name), None
|
||||
)
|
||||
admin_provider_1 = _get_provider_by_name(providers, provider_1_name)
|
||||
assert admin_provider_1 is not None
|
||||
_validate_provider_data(
|
||||
admin_provider_1,
|
||||
@@ -1152,11 +1278,17 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Validate via basic endpoint (basic_user)
|
||||
basic_providers = _get_all_providers_basic(basic_user)
|
||||
basic_default = _find_default_provider(basic_providers)
|
||||
assert basic_default is not None
|
||||
basic_data = _get_providers_basic(basic_user)
|
||||
assert basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(basic_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=provider_2["id"], model_name=provider_2_unique_model
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
basic_provider_data = _get_provider_by_name(providers, provider_2_name)
|
||||
assert basic_provider_data is not None
|
||||
_validate_provider_data(
|
||||
basic_default,
|
||||
basic_provider_data,
|
||||
expected_name=provider_2_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=provider_2_unique_model,
|
||||
@@ -1166,11 +1298,17 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Validate via basic endpoint (admin_user)
|
||||
admin_basic_providers = _get_all_providers_basic(admin_user)
|
||||
admin_basic_default = _find_default_provider(admin_basic_providers)
|
||||
assert admin_basic_default is not None
|
||||
admin_basic_data = _get_providers_basic(admin_user)
|
||||
assert admin_basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_basic_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=provider_2["id"], model_name=provider_2_unique_model
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
admin_basic_provider_data = _get_provider_by_name(providers, provider_2_name)
|
||||
assert admin_basic_provider_data is not None
|
||||
_validate_provider_data(
|
||||
admin_basic_default,
|
||||
admin_basic_provider_data,
|
||||
expected_name=provider_2_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=provider_2_unique_model,
|
||||
@@ -1200,11 +1338,17 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
|
||||
# Step 7: Both users verify they see provider 2 as default with the shared model name
|
||||
# Validate via admin endpoint
|
||||
admin_providers = _get_all_providers_admin(admin_user)
|
||||
admin_default = _find_default_provider(admin_providers)
|
||||
assert admin_default is not None
|
||||
admin_data = _get_providers_admin(admin_user)
|
||||
assert admin_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=provider_2["id"], model_name=shared_model_name
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
admin_provider_data = _get_provider_by_name(providers, provider_2_name)
|
||||
assert admin_provider_data is not None
|
||||
_validate_provider_data(
|
||||
admin_default,
|
||||
admin_provider_data,
|
||||
expected_name=provider_2_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=shared_model_name,
|
||||
@@ -1215,11 +1359,17 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Validate via basic endpoint (basic_user)
|
||||
basic_providers = _get_all_providers_basic(basic_user)
|
||||
basic_default = _find_default_provider(basic_providers)
|
||||
assert basic_default is not None
|
||||
basic_data = _get_providers_basic(basic_user)
|
||||
assert basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(basic_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=provider_2["id"], model_name=shared_model_name
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
basic_provider_data = _get_provider_by_name(providers, provider_2_name)
|
||||
assert basic_provider_data is not None
|
||||
_validate_provider_data(
|
||||
basic_default,
|
||||
basic_provider_data,
|
||||
expected_name=provider_2_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=shared_model_name,
|
||||
@@ -1229,11 +1379,17 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Validate via basic endpoint (admin_user)
|
||||
admin_basic_providers = _get_all_providers_basic(admin_user)
|
||||
admin_basic_default = _find_default_provider(admin_basic_providers)
|
||||
assert admin_basic_default is not None
|
||||
admin_basic_data = _get_providers_basic(admin_user)
|
||||
assert admin_basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_basic_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=provider_2["id"], model_name=shared_model_name
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
admin_basic_provider_data = _get_provider_by_name(providers, provider_2_name)
|
||||
assert admin_basic_provider_data is not None
|
||||
_validate_provider_data(
|
||||
admin_basic_default,
|
||||
admin_basic_provider_data,
|
||||
expected_name=provider_2_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=shared_model_name,
|
||||
@@ -1243,12 +1399,10 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
)
|
||||
|
||||
# Verify provider 1 is no longer the default and has correct data
|
||||
provider_1_admin = next(
|
||||
(p for p in admin_providers if p["name"] == provider_1_name), None
|
||||
)
|
||||
assert provider_1_admin is not None
|
||||
admin_provider_1 = _get_provider_by_name(providers, provider_1_name)
|
||||
assert admin_provider_1 is not None
|
||||
_validate_provider_data(
|
||||
provider_1_admin,
|
||||
admin_provider_1,
|
||||
expected_name=provider_1_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=shared_model_name,
|
||||
@@ -1258,12 +1412,10 @@ def test_multiple_providers_default_switching(reset: None) -> None: # noqa: ARG
|
||||
expected_is_public=True,
|
||||
)
|
||||
|
||||
provider_1_basic = next(
|
||||
(p for p in basic_providers if p["name"] == provider_1_name), None
|
||||
)
|
||||
assert provider_1_basic is not None
|
||||
basic_provider_1 = _get_provider_by_name(providers, provider_1_name)
|
||||
assert basic_provider_1 is not None
|
||||
_validate_provider_data(
|
||||
provider_1_basic,
|
||||
basic_provider_1,
|
||||
expected_name=provider_1_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_default_model=shared_model_name,
|
||||
@@ -1391,10 +1543,22 @@ def test_default_provider_and_vision_provider_selection(
|
||||
)
|
||||
|
||||
# Step 5: Verify via admin endpoint
|
||||
admin_providers = _get_all_providers_admin(admin_user)
|
||||
admin_data = _get_providers_admin(admin_user)
|
||||
assert admin_data is not None
|
||||
|
||||
# Find and validate the default provider (provider 1)
|
||||
admin_default = _find_default_provider(admin_providers)
|
||||
providers, text_default, vision_default = _unpack_data(admin_data)
|
||||
_validate_default_model(
|
||||
text_default,
|
||||
provider_id=provider_1["id"],
|
||||
model_name=provider_1_non_vision_model,
|
||||
)
|
||||
_validate_default_model(
|
||||
vision_default,
|
||||
provider_id=provider_2["id"],
|
||||
model_name=provider_2_vision_model_1,
|
||||
)
|
||||
admin_default = _get_provider_by_name(providers, provider_1_name)
|
||||
assert admin_default is not None
|
||||
_validate_provider_data(
|
||||
admin_default,
|
||||
@@ -1409,7 +1573,7 @@ def test_default_provider_and_vision_provider_selection(
|
||||
)
|
||||
|
||||
# Find and validate the default vision provider (provider 2)
|
||||
admin_vision_default = _find_default_vision_provider(admin_providers)
|
||||
admin_vision_default = _get_provider_by_name(providers, provider_2_name)
|
||||
assert admin_vision_default is not None
|
||||
_validate_provider_data(
|
||||
admin_vision_default,
|
||||
@@ -1425,10 +1589,21 @@ def test_default_provider_and_vision_provider_selection(
|
||||
)
|
||||
|
||||
# Step 6: Verify via basic endpoint (basic_user)
|
||||
basic_providers = _get_all_providers_basic(basic_user)
|
||||
|
||||
# Find and validate the default provider (provider 1)
|
||||
basic_default = _find_default_provider(basic_providers)
|
||||
basic_data = _get_providers_basic(basic_user)
|
||||
assert basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(basic_data)
|
||||
_validate_default_model(
|
||||
text_default,
|
||||
provider_id=provider_1["id"],
|
||||
model_name=provider_1_non_vision_model,
|
||||
)
|
||||
_validate_default_model(
|
||||
vision_default,
|
||||
provider_id=provider_2["id"],
|
||||
model_name=provider_2_vision_model_1,
|
||||
)
|
||||
basic_default = _get_provider_by_name(providers, provider_1_name)
|
||||
assert basic_default is not None
|
||||
_validate_provider_data(
|
||||
basic_default,
|
||||
@@ -1442,7 +1617,7 @@ def test_default_provider_and_vision_provider_selection(
|
||||
)
|
||||
|
||||
# Find and validate the default vision provider (provider 2)
|
||||
basic_vision_default = _find_default_vision_provider(basic_providers)
|
||||
basic_vision_default = _get_provider_by_name(providers, provider_2_name)
|
||||
assert basic_vision_default is not None
|
||||
_validate_provider_data(
|
||||
basic_vision_default,
|
||||
@@ -1457,9 +1632,20 @@ def test_default_provider_and_vision_provider_selection(
|
||||
)
|
||||
|
||||
# Step 7: Verify via basic endpoint (admin_user sees same as basic_user)
|
||||
admin_basic_providers = _get_all_providers_basic(admin_user)
|
||||
|
||||
admin_basic_default = _find_default_provider(admin_basic_providers)
|
||||
admin_basic_data = _get_providers_basic(admin_user)
|
||||
assert admin_basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_basic_data)
|
||||
_validate_default_model(
|
||||
text_default,
|
||||
provider_id=provider_1["id"],
|
||||
model_name=provider_1_non_vision_model,
|
||||
)
|
||||
_validate_default_model(
|
||||
vision_default,
|
||||
provider_id=provider_2["id"],
|
||||
model_name=provider_2_vision_model_1,
|
||||
)
|
||||
admin_basic_default = _get_provider_by_name(providers, provider_1_name)
|
||||
assert admin_basic_default is not None
|
||||
_validate_provider_data(
|
||||
admin_basic_default,
|
||||
@@ -1472,7 +1658,7 @@ def test_default_provider_and_vision_provider_selection(
|
||||
expected_is_default_vision=None,
|
||||
)
|
||||
|
||||
admin_basic_vision_default = _find_default_vision_provider(admin_basic_providers)
|
||||
admin_basic_vision_default = _get_provider_by_name(providers, provider_2_name)
|
||||
assert admin_basic_vision_default is not None
|
||||
_validate_provider_data(
|
||||
admin_basic_vision_default,
|
||||
@@ -1549,7 +1735,14 @@ def test_default_provider_is_not_default_vision_provider(
|
||||
_set_default_provider(admin_user, created_provider["id"])
|
||||
|
||||
# Step 3 & 4: Verify via admin endpoint
|
||||
admin_provider_data = _get_provider_by_name_admin(admin_user, provider_name)
|
||||
admin_data = _get_providers_admin(admin_user)
|
||||
assert admin_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=created_provider["id"], model_name="gpt-4"
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
admin_provider_data = _get_provider_by_name(providers, provider_name)
|
||||
assert admin_provider_data is not None
|
||||
|
||||
# Verify it IS the default provider
|
||||
@@ -1584,7 +1777,14 @@ def test_default_provider_is_not_default_vision_provider(
|
||||
)
|
||||
|
||||
# Also verify via basic endpoint
|
||||
basic_provider_data = _get_provider_by_name_basic(admin_user, provider_name)
|
||||
basic_data = _get_providers_basic(admin_user)
|
||||
assert basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(basic_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=created_provider["id"], model_name="gpt-4"
|
||||
)
|
||||
_validate_default_model(vision_default) # None
|
||||
basic_provider_data = _get_provider_by_name(providers, provider_name)
|
||||
assert basic_provider_data is not None
|
||||
|
||||
assert (
|
||||
@@ -1769,37 +1969,52 @@ def test_all_three_provider_types_no_mixup(reset: None) -> None: # noqa: ARG001
|
||||
# Step 4: Verify all three types are correctly tracked
|
||||
|
||||
# Get all LLM providers (via admin endpoint)
|
||||
admin_providers = _get_all_providers_admin(admin_user)
|
||||
admin_data = _get_providers_admin(admin_user)
|
||||
assert admin_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(admin_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=regular_provider["id"], model_name="gpt-4"
|
||||
)
|
||||
_validate_default_model(
|
||||
vision_default,
|
||||
provider_id=vision_provider["id"],
|
||||
model_name="gpt-4-vision-preview",
|
||||
)
|
||||
_validate_default_model(
|
||||
vision_default, vision_provider["id"], "gpt-4-vision-preview"
|
||||
)
|
||||
_get_provider_by_name(providers, regular_provider_name)
|
||||
|
||||
# Get all image generation configs
|
||||
image_gen_configs = _get_all_image_gen_configs(admin_user)
|
||||
|
||||
# Verify the regular provider is the default provider
|
||||
regular_provider_data = next(
|
||||
(p for p in admin_providers if p["name"] == regular_provider_name), None
|
||||
admin_regular_provider_data = _get_provider_by_name(
|
||||
providers, regular_provider_name
|
||||
)
|
||||
assert regular_provider_data is not None, "Regular provider not found"
|
||||
assert (
|
||||
regular_provider_data["is_default_provider"] is True
|
||||
), "Regular provider should be the default provider"
|
||||
assert (
|
||||
regular_provider_data.get("is_default_vision_provider") is not True
|
||||
), "Regular provider should NOT be the default vision provider"
|
||||
|
||||
# Verify the vision provider is the default vision provider
|
||||
vision_provider_data = next(
|
||||
(p for p in admin_providers if p["name"] == vision_provider_name), None
|
||||
assert admin_regular_provider_data is not None
|
||||
_validate_provider_data(
|
||||
admin_regular_provider_data,
|
||||
expected_name=regular_provider_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_model_names=[c.name for c in regular_model_configs],
|
||||
expected_visible={c.name: True for c in regular_model_configs},
|
||||
expected_default_model="gpt-4",
|
||||
expected_is_default=True,
|
||||
)
|
||||
admin_vision_provider_data = _get_provider_by_name(providers, vision_provider_name)
|
||||
assert admin_vision_provider_data is not None
|
||||
_validate_provider_data(
|
||||
admin_vision_provider_data,
|
||||
expected_name=vision_provider_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_model_names=[c.name for c in vision_model_configs],
|
||||
expected_visible={c.name: True for c in vision_model_configs},
|
||||
expected_default_model="gpt-4-vision-preview",
|
||||
expected_is_default=False,
|
||||
expected_default_vision_model="gpt-4-vision-preview",
|
||||
expected_is_default_vision=True,
|
||||
)
|
||||
assert vision_provider_data is not None, "Vision provider not found"
|
||||
assert (
|
||||
vision_provider_data.get("is_default_provider") is not True
|
||||
), "Vision provider should NOT be the default provider"
|
||||
assert (
|
||||
vision_provider_data["is_default_vision_provider"] is True
|
||||
), "Vision provider should be the default vision provider"
|
||||
assert (
|
||||
vision_provider_data["default_vision_model"] == "gpt-4-vision-preview"
|
||||
), "Vision provider should have correct default vision model"
|
||||
|
||||
# Verify the image gen config is the default image generation config
|
||||
image_gen_config_data = next(
|
||||
@@ -1819,97 +2034,53 @@ def test_all_three_provider_types_no_mixup(reset: None) -> None: # noqa: ARG001
|
||||
), "Image gen config should have correct model name"
|
||||
|
||||
# Step 5: Verify no mixup - image gen providers don't appear in LLM provider lists
|
||||
|
||||
# The image gen config creates an LLM provider with name "Image Gen - {image_provider_id}"
|
||||
# This should NOT be returned by the regular LLM provider endpoints
|
||||
[p["name"] for p in admin_providers]
|
||||
image_gen_llm_provider_name = f"Image Gen - {image_gen_provider_id}"
|
||||
|
||||
# Note: The image gen provider IS an LLM provider internally, so it may appear in the list
|
||||
# But it should NOT be marked as default provider or default vision provider
|
||||
image_gen_llm_provider = next(
|
||||
(p for p in admin_providers if p["name"] == image_gen_llm_provider_name), None
|
||||
)
|
||||
if image_gen_llm_provider:
|
||||
# If it appears, verify it's not marked as default for either type
|
||||
assert (
|
||||
image_gen_llm_provider.get("is_default_provider") is not True
|
||||
), "Image gen's internal LLM provider should NOT be the default provider"
|
||||
assert (
|
||||
image_gen_llm_provider.get("is_default_vision_provider") is not True
|
||||
), "Image gen's internal LLM provider should NOT be the default vision provider"
|
||||
# Image gen provider should not appear in the list
|
||||
assert image_gen_provider_id not in [p["name"] for p in providers]
|
||||
|
||||
# Step 6: Verify via basic endpoint (non-admin user)
|
||||
basic_providers = _get_all_providers_basic(basic_user)
|
||||
|
||||
# Verify regular provider is default for basic user
|
||||
basic_regular = next(
|
||||
(p for p in basic_providers if p["name"] == regular_provider_name), None
|
||||
basic_data = _get_providers_basic(basic_user)
|
||||
assert basic_data is not None
|
||||
providers, text_default, vision_default = _unpack_data(basic_data)
|
||||
_validate_default_model(
|
||||
text_default, provider_id=regular_provider["id"], model_name="gpt-4"
|
||||
)
|
||||
assert basic_regular is not None, "Regular provider not visible to basic user"
|
||||
assert (
|
||||
basic_regular["is_default_provider"] is True
|
||||
), "Regular provider should be default for basic user"
|
||||
|
||||
# Verify vision provider is default vision for basic user
|
||||
basic_vision = next(
|
||||
(p for p in basic_providers if p["name"] == vision_provider_name), None
|
||||
_validate_default_model(
|
||||
vision_default,
|
||||
provider_id=vision_provider["id"],
|
||||
model_name="gpt-4-vision-preview",
|
||||
)
|
||||
_validate_default_model(
|
||||
vision_default, vision_provider["id"], "gpt-4-vision-preview"
|
||||
)
|
||||
basic_provider_data = _get_provider_by_name(providers, regular_provider_name)
|
||||
assert basic_provider_data is not None
|
||||
_validate_provider_data(
|
||||
basic_provider_data,
|
||||
expected_name=regular_provider_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_model_names=[c.name for c in regular_model_configs],
|
||||
expected_visible={c.name: True for c in regular_model_configs},
|
||||
expected_is_default=True,
|
||||
expected_default_model="gpt-4",
|
||||
)
|
||||
basic_vision_provider_data = _get_provider_by_name(providers, vision_provider_name)
|
||||
assert basic_vision_provider_data is not None
|
||||
_validate_provider_data(
|
||||
basic_vision_provider_data,
|
||||
expected_name=vision_provider_name,
|
||||
expected_provider=LlmProviderNames.OPENAI,
|
||||
expected_model_names=[c.name for c in vision_model_configs],
|
||||
expected_visible={c.name: True for c in vision_model_configs},
|
||||
expected_is_default=False,
|
||||
expected_default_model="gpt-4-vision-preview",
|
||||
expected_default_vision_model="gpt-4-vision-preview",
|
||||
expected_is_default_vision=True,
|
||||
)
|
||||
assert basic_vision is not None, "Vision provider not visible to basic user"
|
||||
assert (
|
||||
basic_vision["is_default_vision_provider"] is True
|
||||
), "Vision provider should be default vision for basic user"
|
||||
|
||||
# Step 7: Verify the counts are as expected
|
||||
# We should have at least 2 user-created providers plus the image gen internal provider
|
||||
user_created_providers = [
|
||||
p
|
||||
for p in admin_providers
|
||||
if p["name"] in [regular_provider_name, vision_provider_name]
|
||||
]
|
||||
assert (
|
||||
len(user_created_providers) == 2
|
||||
), f"Expected 2 user-created providers, got {len(user_created_providers)}"
|
||||
|
||||
# We should have exactly 1 image gen config
|
||||
assert (
|
||||
len(
|
||||
[
|
||||
c
|
||||
for c in image_gen_configs
|
||||
if c["image_provider_id"] == image_gen_provider_id
|
||||
]
|
||||
)
|
||||
== 1
|
||||
), "Expected exactly 1 image gen config with our ID"
|
||||
|
||||
# Verify that our explicitly created providers are tracked correctly:
|
||||
# - Only ONE provider has is_default_provider=True
|
||||
default_providers = [
|
||||
p for p in admin_providers if p.get("is_default_provider") is True
|
||||
]
|
||||
assert (
|
||||
len(default_providers) == 1
|
||||
), f"Expected exactly 1 default provider, got {len(default_providers)}"
|
||||
assert default_providers[0]["name"] == regular_provider_name
|
||||
|
||||
# - Only ONE provider has is_default_vision_provider=True
|
||||
default_vision_providers = [
|
||||
p for p in admin_providers if p.get("is_default_vision_provider") is True
|
||||
]
|
||||
assert (
|
||||
len(default_vision_providers) == 1
|
||||
), f"Expected exactly 1 default vision provider, got {len(default_vision_providers)}"
|
||||
assert default_vision_providers[0]["name"] == vision_provider_name
|
||||
|
||||
# - Only ONE image gen config has is_default=True
|
||||
default_image_gen_configs = [
|
||||
c for c in image_gen_configs if c.get("is_default") is True
|
||||
]
|
||||
assert (
|
||||
len(default_image_gen_configs) == 1
|
||||
), f"Expected exactly 1 default image gen config, got {len(default_image_gen_configs)}"
|
||||
assert default_image_gen_configs[0]["image_provider_id"] == image_gen_provider_id
|
||||
# We should have at least 2 user-created providers (setup_postgres may add more)
|
||||
assert len(providers) >= 2
|
||||
assert len(image_gen_configs) == 1
|
||||
|
||||
# Clean up: Delete the image gen config (to clean up the internal LLM provider)
|
||||
_delete_image_gen_config(admin_user, image_gen_provider_id)
|
||||
|
||||
@@ -272,6 +272,19 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
|
||||
# Set up ModelConfiguration + LLMModelFlow so get_default_llm() can
|
||||
# resolve the default provider when the fallback path is triggered.
|
||||
# First, clear any existing CHAT defaults (setup_postgres may have created one)
|
||||
existing_defaults = (
|
||||
db_session.query(LLMModelFlow)
|
||||
.filter(
|
||||
LLMModelFlow.llm_model_flow_type == LLMModelFlowType.CHAT,
|
||||
LLMModelFlow.is_default == True, # noqa: E712
|
||||
)
|
||||
.all()
|
||||
)
|
||||
for existing in existing_defaults:
|
||||
existing.is_default = False
|
||||
db_session.flush()
|
||||
|
||||
default_model_config = ModelConfiguration(
|
||||
llm_provider_id=default_provider.id,
|
||||
name=default_provider.default_model_name,
|
||||
@@ -365,7 +378,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
providers = response.json()["providers"]
|
||||
provider_names = [p["name"] for p in providers]
|
||||
|
||||
# Public provider should be visible
|
||||
@@ -380,7 +393,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert admin_response.status_code == 200
|
||||
admin_providers = admin_response.json()
|
||||
admin_providers = admin_response.json()["providers"]
|
||||
admin_provider_names = [p["name"] for p in admin_providers]
|
||||
|
||||
assert public_provider.name in admin_provider_names
|
||||
|
||||
@@ -45,25 +45,15 @@ def _env_true(env_var: str, default: bool = False) -> bool:
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _parse_models_env(env_var: str) -> list[str]:
|
||||
raw_value = os.environ.get(env_var, "").strip()
|
||||
if not raw_value:
|
||||
return []
|
||||
|
||||
try:
|
||||
parsed_json = json.loads(raw_value)
|
||||
except json.JSONDecodeError:
|
||||
parsed_json = None
|
||||
|
||||
if isinstance(parsed_json, list):
|
||||
return [str(model).strip() for model in parsed_json if str(model).strip()]
|
||||
|
||||
return [part.strip() for part in raw_value.split(",") if part.strip()]
|
||||
def _split_csv_env(env_var: str) -> list[str]:
|
||||
return [
|
||||
part.strip() for part in os.environ.get(env_var, "").split(",") if part.strip()
|
||||
]
|
||||
|
||||
|
||||
def _load_provider_config() -> NightlyProviderConfig:
|
||||
provider = os.environ.get(_ENV_PROVIDER, "").strip().lower()
|
||||
model_names = _parse_models_env(_ENV_MODELS)
|
||||
model_names = _split_csv_env(_ENV_MODELS)
|
||||
api_key = os.environ.get(_ENV_API_KEY) or None
|
||||
api_base = os.environ.get(_ENV_API_BASE) or None
|
||||
strict = _env_true(_ENV_STRICT, default=False)
|
||||
@@ -323,21 +313,10 @@ def test_nightly_provider_chat_workflow(admin_user: DATestUser) -> None:
|
||||
_seed_connector_for_search_tool(admin_user)
|
||||
search_tool_id = _get_internal_search_tool_id(admin_user)
|
||||
|
||||
failures: list[str] = []
|
||||
for model_name in config.model_names:
|
||||
try:
|
||||
_create_and_test_provider_for_model(
|
||||
admin_user=admin_user,
|
||||
config=config,
|
||||
model_name=model_name,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
except BaseException as exc:
|
||||
if isinstance(exc, (KeyboardInterrupt, SystemExit)):
|
||||
raise
|
||||
failures.append(
|
||||
f"provider={config.provider} model={model_name} error={type(exc).__name__}: {exc}"
|
||||
)
|
||||
|
||||
if failures:
|
||||
pytest.fail("Nightly provider chat failures:\n" + "\n".join(failures))
|
||||
_create_and_test_provider_for_model(
|
||||
admin_user=admin_user,
|
||||
config=config,
|
||||
model_name=model_name,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
|
||||
@@ -534,10 +534,9 @@ services:
|
||||
required: false
|
||||
|
||||
# Below is needed for the `docker-out-of-docker` execution mode
|
||||
# For Linux rootless Docker, set DOCKER_SOCK_PATH=${XDG_RUNTIME_DIR}/docker.sock
|
||||
user: root
|
||||
volumes:
|
||||
- ${DOCKER_SOCK_PATH:-/var/run/docker.sock}:/var/run/docker.sock
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
|
||||
# uncomment below + comment out the above to use the `docker-in-docker` execution mode
|
||||
# privileged: true
|
||||
|
||||
@@ -1,233 +0,0 @@
|
||||
import "@opal/core/hoverable/styles.css";
|
||||
import React, { createContext, useContext, useState, useCallback } from "react";
|
||||
import { cn } from "@opal/utils";
|
||||
import type { WithoutStyles } from "@opal/types";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Context-per-group registry
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Lazily-created map of group names to React contexts.
|
||||
*
|
||||
* Each group gets its own `React.Context<boolean | null>` so that a
|
||||
* `Hoverable.Item` only re-renders when its *own* group's hover state
|
||||
* changes — not when any unrelated group changes.
|
||||
*
|
||||
* The default value is `null` (no provider found), which lets
|
||||
* `Hoverable.Item` distinguish "no Root ancestor" from "Root says
|
||||
* not hovered" and throw when `group` was explicitly specified.
|
||||
*/
|
||||
const contextMap = new Map<string, React.Context<boolean | null>>();
|
||||
|
||||
function getOrCreateContext(group: string): React.Context<boolean | null> {
|
||||
let ctx = contextMap.get(group);
|
||||
if (!ctx) {
|
||||
ctx = createContext<boolean | null>(null);
|
||||
ctx.displayName = `HoverableContext(${group})`;
|
||||
contextMap.set(group, ctx);
|
||||
}
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface HoverableRootProps
|
||||
extends WithoutStyles<React.HTMLAttributes<HTMLDivElement>> {
|
||||
children: React.ReactNode;
|
||||
group: string;
|
||||
}
|
||||
|
||||
type HoverableItemVariant = "opacity-on-hover";
|
||||
|
||||
interface HoverableItemProps
|
||||
extends WithoutStyles<React.HTMLAttributes<HTMLDivElement>> {
|
||||
children: React.ReactNode;
|
||||
group?: string;
|
||||
variant?: HoverableItemVariant;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HoverableRoot
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Hover-tracking container for a named group.
|
||||
*
|
||||
* Wraps children in a `<div>` that tracks mouse-enter / mouse-leave and
|
||||
* provides the hover state via a per-group React context.
|
||||
*
|
||||
* Nesting works because each `Hoverable.Root` creates a **new** context
|
||||
* provider that shadows the parent — so an inner `Hoverable.Item group="b"`
|
||||
* reads from the inner provider, not the outer `group="a"` provider.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <Hoverable.Root group="card">
|
||||
* <Card>
|
||||
* <Hoverable.Item group="card" variant="opacity-on-hover">
|
||||
* <TrashIcon />
|
||||
* </Hoverable.Item>
|
||||
* </Card>
|
||||
* </Hoverable.Root>
|
||||
* ```
|
||||
*/
|
||||
function HoverableRoot({
|
||||
group,
|
||||
children,
|
||||
onMouseEnter: consumerMouseEnter,
|
||||
onMouseLeave: consumerMouseLeave,
|
||||
...props
|
||||
}: HoverableRootProps) {
|
||||
const [hovered, setHovered] = useState(false);
|
||||
|
||||
const onMouseEnter = useCallback(
|
||||
(e: React.MouseEvent<HTMLDivElement>) => {
|
||||
setHovered(true);
|
||||
consumerMouseEnter?.(e);
|
||||
},
|
||||
[consumerMouseEnter]
|
||||
);
|
||||
|
||||
const onMouseLeave = useCallback(
|
||||
(e: React.MouseEvent<HTMLDivElement>) => {
|
||||
setHovered(false);
|
||||
consumerMouseLeave?.(e);
|
||||
},
|
||||
[consumerMouseLeave]
|
||||
);
|
||||
|
||||
const GroupContext = getOrCreateContext(group);
|
||||
|
||||
return (
|
||||
<GroupContext.Provider value={hovered}>
|
||||
<div {...props} onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
|
||||
{children}
|
||||
</div>
|
||||
</GroupContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HoverableItem
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* An element whose visibility is controlled by hover state.
|
||||
*
|
||||
* **Local mode** (`group` omitted): the item handles hover on its own
|
||||
* element via CSS `:hover`. This is the core abstraction.
|
||||
*
|
||||
* **Group mode** (`group` provided): visibility is driven by a matching
|
||||
* `Hoverable.Root` ancestor's hover state via React context. If no
|
||||
* matching Root is found, an error is thrown.
|
||||
*
|
||||
* Uses data-attributes for variant styling (see `styles.css`).
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* // Local mode — hover on the item itself
|
||||
* <Hoverable.Item variant="opacity-on-hover">
|
||||
* <TrashIcon />
|
||||
* </Hoverable.Item>
|
||||
*
|
||||
* // Group mode — hover on the Root reveals the item
|
||||
* <Hoverable.Root group="card">
|
||||
* <Hoverable.Item group="card" variant="opacity-on-hover">
|
||||
* <TrashIcon />
|
||||
* </Hoverable.Item>
|
||||
* </Hoverable.Root>
|
||||
* ```
|
||||
*
|
||||
* @throws If `group` is specified but no matching `Hoverable.Root` ancestor exists.
|
||||
*/
|
||||
function HoverableItem({
|
||||
group,
|
||||
variant = "opacity-on-hover",
|
||||
children,
|
||||
...props
|
||||
}: HoverableItemProps) {
|
||||
const contextValue = useContext(
|
||||
group ? getOrCreateContext(group) : NOOP_CONTEXT
|
||||
);
|
||||
|
||||
if (group && contextValue === null) {
|
||||
throw new Error(
|
||||
`Hoverable.Item group="${group}" has no matching Hoverable.Root ancestor. ` +
|
||||
`Either wrap it in <Hoverable.Root group="${group}"> or remove the group prop for local hover.`
|
||||
);
|
||||
}
|
||||
|
||||
const isLocal = group === undefined;
|
||||
|
||||
return (
|
||||
<div
|
||||
{...props}
|
||||
className={cn("hoverable-item")}
|
||||
data-hoverable-variant={variant}
|
||||
data-hoverable-active={
|
||||
isLocal ? undefined : contextValue ? "true" : undefined
|
||||
}
|
||||
data-hoverable-local={isLocal ? "true" : undefined}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/** Stable context used when no group is specified (local mode). */
|
||||
const NOOP_CONTEXT = createContext<boolean | null>(null);
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Compound export
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Hoverable compound component for hover-to-reveal patterns.
|
||||
*
|
||||
* Provides two sub-components:
|
||||
*
|
||||
* - `Hoverable.Root` — A container that tracks hover state for a named group
|
||||
* and provides it via React context.
|
||||
*
|
||||
* - `Hoverable.Item` — The core abstraction. On its own (no `group`), it
|
||||
* applies local CSS `:hover` for the variant effect. When `group` is
|
||||
* specified, it reads hover state from the nearest matching
|
||||
* `Hoverable.Root` — and throws if no matching Root is found.
|
||||
*
|
||||
* Supports nesting: a child `Hoverable.Root` shadows the parent's context,
|
||||
* so each group's items only respond to their own root's hover.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* import { Hoverable } from "@opal/core";
|
||||
*
|
||||
* // Group mode — hovering the card reveals the trash icon
|
||||
* <Hoverable.Root group="card">
|
||||
* <Card>
|
||||
* <span>Card content</span>
|
||||
* <Hoverable.Item group="card" variant="opacity-on-hover">
|
||||
* <TrashIcon />
|
||||
* </Hoverable.Item>
|
||||
* </Card>
|
||||
* </Hoverable.Root>
|
||||
*
|
||||
* // Local mode — hovering the item itself reveals it
|
||||
* <Hoverable.Item variant="opacity-on-hover">
|
||||
* <TrashIcon />
|
||||
* </Hoverable.Item>
|
||||
* ```
|
||||
*/
|
||||
const Hoverable = {
|
||||
Root: HoverableRoot,
|
||||
Item: HoverableItem,
|
||||
};
|
||||
|
||||
export {
|
||||
Hoverable,
|
||||
type HoverableRootProps,
|
||||
type HoverableItemProps,
|
||||
type HoverableItemVariant,
|
||||
};
|
||||
@@ -1,18 +0,0 @@
|
||||
/* Hoverable — item transitions */
|
||||
.hoverable-item {
|
||||
transition: opacity 200ms ease-in-out;
|
||||
}
|
||||
|
||||
.hoverable-item[data-hoverable-variant="opacity-on-hover"] {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
/* Group mode — Root controls visibility via React context */
|
||||
.hoverable-item[data-hoverable-variant="opacity-on-hover"][data-hoverable-active="true"] {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* Local mode — item handles its own :hover */
|
||||
.hoverable-item[data-hoverable-variant="opacity-on-hover"][data-hoverable-local="true"]:hover {
|
||||
opacity: 1;
|
||||
}
|
||||
@@ -1,11 +1,3 @@
|
||||
/* Hoverable */
|
||||
export {
|
||||
Hoverable,
|
||||
type HoverableRootProps,
|
||||
type HoverableItemProps,
|
||||
type HoverableItemVariant,
|
||||
} from "@opal/core/hoverable/components";
|
||||
|
||||
/* Interactive */
|
||||
export {
|
||||
Interactive,
|
||||
|
||||
@@ -12,7 +12,7 @@ const SvgOrganization = ({ size, ...props }: IconProps) => (
|
||||
>
|
||||
<path
|
||||
d="M7.5 14H13.5C14.0523 14 14.5 13.5523 14.5 13V6C14.5 5.44772 14.0523 5 13.5 5H7.5M7.5 14V11M7.5 14H4.5M7.5 5V3C7.5 2.44772 7.05228 2 6.5 2H4.5M7.5 5H1.5M7.5 5V8M1.5 5V3C1.5 2.44772 1.94772 2 2.5 2H4.5M1.5 5V8M7.5 8V11M7.5 8H4.5M1.5 8V11M1.5 8H4.5M7.5 11H4.5M1.5 11V13C1.5 13.5523 1.94772 14 2.5 14H4.5M1.5 11H4.5M4.5 2V8M4.5 14V11M4.5 11V8M10 8H12M10 11H12"
|
||||
strokeWidth={1.5}
|
||||
strokeWidth={1}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
|
||||
@@ -17,7 +17,6 @@ import {
|
||||
SvgPlus,
|
||||
SvgWallet,
|
||||
SvgFileText,
|
||||
SvgOrganization,
|
||||
} from "@opal/icons";
|
||||
import { BillingInformation, LicenseStatus } from "@/lib/billing/interfaces";
|
||||
import {
|
||||
@@ -144,20 +143,17 @@ function SubscriptionCard({
|
||||
license,
|
||||
onViewPlans,
|
||||
disabled,
|
||||
isManualLicenseOnly,
|
||||
onReconnect,
|
||||
}: {
|
||||
billing?: BillingInformation;
|
||||
license?: LicenseStatus;
|
||||
onViewPlans: () => void;
|
||||
disabled?: boolean;
|
||||
isManualLicenseOnly?: boolean;
|
||||
onReconnect?: () => Promise<void>;
|
||||
}) {
|
||||
const [isReconnecting, setIsReconnecting] = useState(false);
|
||||
|
||||
const planName = isManualLicenseOnly ? "Enterprise Plan" : "Business Plan";
|
||||
const PlanIcon = isManualLicenseOnly ? SvgOrganization : SvgUsers;
|
||||
const planName = "Business Plan";
|
||||
const expirationDate = billing?.current_period_end ?? license?.expires_at;
|
||||
const formattedDate = formatDateShort(expirationDate);
|
||||
|
||||
@@ -215,7 +211,7 @@ function SubscriptionCard({
|
||||
height="auto"
|
||||
>
|
||||
<Section gap={0.25} alignItems="start" height="auto" width="auto">
|
||||
<PlanIcon className="w-5 h-5" />
|
||||
<SvgUsers className="w-5 h-5 stroke-text-03" />
|
||||
<Text headingH3Muted text04>
|
||||
{planName}
|
||||
</Text>
|
||||
@@ -230,19 +226,7 @@ function SubscriptionCard({
|
||||
height="auto"
|
||||
width="fit"
|
||||
>
|
||||
{isManualLicenseOnly ? (
|
||||
<Text secondaryBody text03 className="text-right">
|
||||
Your plan is managed through sales.
|
||||
<br />
|
||||
<a
|
||||
href="mailto:support@onyx.app?subject=Billing%20change%20request"
|
||||
className="underline"
|
||||
>
|
||||
Contact billing
|
||||
</a>{" "}
|
||||
to make changes.
|
||||
</Text>
|
||||
) : disabled ? (
|
||||
{disabled ? (
|
||||
<Button
|
||||
main
|
||||
secondary
|
||||
@@ -282,13 +266,11 @@ function SeatsCard({
|
||||
license,
|
||||
onRefresh,
|
||||
disabled,
|
||||
hideUpdateSeats,
|
||||
}: {
|
||||
billing?: BillingInformation;
|
||||
license?: LicenseStatus;
|
||||
onRefresh?: () => Promise<void>;
|
||||
disabled?: boolean;
|
||||
hideUpdateSeats?: boolean;
|
||||
}) {
|
||||
const [isEditing, setIsEditing] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
@@ -502,17 +484,15 @@ function SeatsCard({
|
||||
<Button main tertiary href="/admin/users" leftIcon={SvgExternalLink}>
|
||||
View Users
|
||||
</Button>
|
||||
{!hideUpdateSeats && (
|
||||
<Button
|
||||
main
|
||||
secondary
|
||||
onClick={handleStartEdit}
|
||||
leftIcon={SvgPlus}
|
||||
disabled={isLoadingUsers || disabled || !billing}
|
||||
>
|
||||
Update Seats
|
||||
</Button>
|
||||
)}
|
||||
<Button
|
||||
main
|
||||
secondary
|
||||
onClick={handleStartEdit}
|
||||
leftIcon={SvgPlus}
|
||||
disabled={isLoadingUsers || disabled || !billing}
|
||||
>
|
||||
Update Seats
|
||||
</Button>
|
||||
</Section>
|
||||
</Section>
|
||||
</Card>
|
||||
@@ -613,9 +593,7 @@ interface BillingDetailsViewProps {
|
||||
onViewPlans: () => void;
|
||||
onRefresh?: () => Promise<void>;
|
||||
isAirGapped?: boolean;
|
||||
isManualLicenseOnly?: boolean;
|
||||
hasStripeError?: boolean;
|
||||
licenseCard?: React.ReactNode;
|
||||
}
|
||||
|
||||
export default function BillingDetailsView({
|
||||
@@ -624,13 +602,10 @@ export default function BillingDetailsView({
|
||||
onViewPlans,
|
||||
onRefresh,
|
||||
isAirGapped,
|
||||
isManualLicenseOnly,
|
||||
hasStripeError,
|
||||
licenseCard,
|
||||
}: BillingDetailsViewProps) {
|
||||
const expirationState = billing ? getExpirationState(billing, license) : null;
|
||||
const disableBillingActions =
|
||||
isAirGapped || hasStripeError || isManualLicenseOnly;
|
||||
const disableBillingActions = isAirGapped || hasStripeError;
|
||||
|
||||
return (
|
||||
<Section gap={1} height="auto" width="full">
|
||||
@@ -647,7 +622,7 @@ export default function BillingDetailsView({
|
||||
)}
|
||||
|
||||
{/* Air-gapped mode info banner */}
|
||||
{isAirGapped && !hasStripeError && !isManualLicenseOnly && (
|
||||
{isAirGapped && !hasStripeError && (
|
||||
<Message
|
||||
static
|
||||
info
|
||||
@@ -690,21 +665,16 @@ export default function BillingDetailsView({
|
||||
license={license}
|
||||
onViewPlans={onViewPlans}
|
||||
disabled={disableBillingActions}
|
||||
isManualLicenseOnly={isManualLicenseOnly}
|
||||
onReconnect={onRefresh}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* License card (inline for manual license users) */}
|
||||
{licenseCard}
|
||||
|
||||
{/* Seats card */}
|
||||
<SeatsCard
|
||||
billing={billing}
|
||||
license={license}
|
||||
onRefresh={onRefresh}
|
||||
disabled={disableBillingActions}
|
||||
hideUpdateSeats={isManualLicenseOnly}
|
||||
/>
|
||||
|
||||
{/* Payment section */}
|
||||
|
||||
@@ -19,7 +19,6 @@ interface LicenseActivationCardProps {
|
||||
onClose: () => void;
|
||||
onSuccess: () => void;
|
||||
license?: LicenseStatus;
|
||||
hideClose?: boolean;
|
||||
}
|
||||
|
||||
export default function LicenseActivationCard({
|
||||
@@ -27,7 +26,6 @@ export default function LicenseActivationCard({
|
||||
onClose,
|
||||
onSuccess,
|
||||
license,
|
||||
hideClose,
|
||||
}: LicenseActivationCardProps) {
|
||||
const [licenseKey, setLicenseKey] = useState("");
|
||||
const [isActivating, setIsActivating] = useState(false);
|
||||
@@ -122,11 +120,9 @@ export default function LicenseActivationCard({
|
||||
<Button main secondary onClick={() => setShowInput(true)}>
|
||||
Update Key
|
||||
</Button>
|
||||
{!hideClose && (
|
||||
<Button main tertiary onClick={handleClose}>
|
||||
Close
|
||||
</Button>
|
||||
)}
|
||||
<Button main tertiary onClick={handleClose}>
|
||||
Close
|
||||
</Button>
|
||||
</Section>
|
||||
</Section>
|
||||
</Card>
|
||||
|
||||
@@ -121,12 +121,11 @@ export default function BillingPage() {
|
||||
const billing = hasSubscription ? (billingData as BillingInformation) : null;
|
||||
const isSelfHosted = !NEXT_PUBLIC_CLOUD_ENABLED;
|
||||
|
||||
// User is only air-gapped if they have a manual license AND Stripe is not connected
|
||||
// Once Stripe connects successfully, they're no longer air-gapped
|
||||
const hasManualLicense = licenseData?.source === "manual_upload";
|
||||
|
||||
// Air-gapped: billing endpoint is unreachable (manual license + connectivity error)
|
||||
const isAirGapped = !!(hasManualLicense && billingError);
|
||||
|
||||
// Stripe error: auto-fetched license but billing endpoint is unreachable
|
||||
const stripeConnected = billingData && !billingError;
|
||||
const isAirGapped = hasManualLicense && !stripeConnected;
|
||||
const hasStripeError = !!(
|
||||
isSelfHosted &&
|
||||
licenseData?.has_license &&
|
||||
@@ -134,10 +133,6 @@ export default function BillingPage() {
|
||||
!hasManualLicense
|
||||
);
|
||||
|
||||
// Manual license without active Stripe subscription
|
||||
// Stripe-dependent actions (manage plan, update seats) won't work
|
||||
const isManualLicenseOnly = !!(hasManualLicense && !hasSubscription);
|
||||
|
||||
// Set initial view based on subscription status (only once when data first loads)
|
||||
useEffect(() => {
|
||||
if (!isLoading && view === null) {
|
||||
@@ -248,10 +243,7 @@ export default function BillingPage() {
|
||||
return {
|
||||
icon: hasSubscription ? SvgWallet : SvgArrowUpCircle,
|
||||
title: hasSubscription ? "View Plans" : "Upgrade Plan",
|
||||
showBackButton: !!(
|
||||
hasSubscription ||
|
||||
(isSelfHosted && licenseData?.has_license)
|
||||
),
|
||||
showBackButton: !!hasSubscription,
|
||||
};
|
||||
case "details":
|
||||
return {
|
||||
@@ -279,11 +271,9 @@ export default function BillingPage() {
|
||||
};
|
||||
|
||||
const handleBack = () => {
|
||||
const hasEntitlement =
|
||||
hasSubscription || (isSelfHosted && licenseData?.has_license);
|
||||
if (view === "checkout") {
|
||||
changeView(hasEntitlement ? "details" : "plans");
|
||||
} else if (view === "plans" && hasEntitlement) {
|
||||
changeView(hasSubscription ? "details" : "plans");
|
||||
} else if (view === "plans" && hasSubscription) {
|
||||
changeView("details");
|
||||
}
|
||||
};
|
||||
@@ -315,19 +305,7 @@ export default function BillingPage() {
|
||||
onViewPlans={() => changeView("plans")}
|
||||
onRefresh={handleRefresh}
|
||||
isAirGapped={isAirGapped}
|
||||
isManualLicenseOnly={isManualLicenseOnly}
|
||||
hasStripeError={hasStripeError}
|
||||
licenseCard={
|
||||
isManualLicenseOnly ? (
|
||||
<LicenseActivationCard
|
||||
isOpen
|
||||
onSuccess={handleLicenseActivated}
|
||||
license={licenseData ?? undefined}
|
||||
onClose={() => {}}
|
||||
hideClose
|
||||
/>
|
||||
) : undefined
|
||||
}
|
||||
/>
|
||||
),
|
||||
};
|
||||
@@ -344,7 +322,7 @@ export default function BillingPage() {
|
||||
if (isLoading || view === null) return null;
|
||||
return (
|
||||
<>
|
||||
{showLicenseActivationInput && !isManualLicenseOnly && (
|
||||
{showLicenseActivationInput && (
|
||||
<div className="w-full billing-card-enter">
|
||||
<LicenseActivationCard
|
||||
isOpen={showLicenseActivationInput}
|
||||
@@ -363,7 +341,6 @@ export default function BillingPage() {
|
||||
isSelfHosted ? () => setShowLicenseActivationInput(true) : undefined
|
||||
}
|
||||
hideLicenseLink={
|
||||
isManualLicenseOnly ||
|
||||
showLicenseActivationInput ||
|
||||
(view === "plans" &&
|
||||
(!!hasSubscription || !!licenseData?.has_license))
|
||||
|
||||
@@ -14,11 +14,6 @@ import {
|
||||
TimelineRendererComponent,
|
||||
TimelineRendererOutput,
|
||||
} from "./TimelineRendererComponent";
|
||||
import {
|
||||
isReasoningPackets,
|
||||
isDeepResearchPlanPackets,
|
||||
isMemoryToolPackets,
|
||||
} from "./packetHelpers";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
import { SvgBranch, SvgFold, SvgExpand } from "@opal/icons";
|
||||
import { Button } from "@opal/components";
|
||||
@@ -65,13 +60,6 @@ export function ParallelTimelineTabs({
|
||||
[turnGroup.steps, activeTab]
|
||||
);
|
||||
|
||||
// Determine if the active step needs full-width content (no right padding)
|
||||
const noPaddingRight = activeStep
|
||||
? isReasoningPackets(activeStep.packets) ||
|
||||
isDeepResearchPlanPackets(activeStep.packets) ||
|
||||
isMemoryToolPackets(activeStep.packets)
|
||||
: false;
|
||||
|
||||
// Memoized loading states for each step
|
||||
const loadingStates = useMemo(
|
||||
() =>
|
||||
@@ -94,10 +82,9 @@ export function ParallelTimelineTabs({
|
||||
isFirstStep={false}
|
||||
isSingleStep={false}
|
||||
collapsible={true}
|
||||
noPaddingRight={noPaddingRight}
|
||||
/>
|
||||
),
|
||||
[isLastTurnGroup, noPaddingRight]
|
||||
[isLastTurnGroup]
|
||||
);
|
||||
|
||||
const hasActivePackets = Boolean(activeStep && activeStep.packets.length > 0);
|
||||
|
||||
@@ -50,7 +50,7 @@ export function TimelineStepComposer({
|
||||
header={result.status}
|
||||
isExpanded={result.isExpanded}
|
||||
onToggle={result.onToggle}
|
||||
collapsible={collapsible && !isSingleStep}
|
||||
collapsible={collapsible}
|
||||
supportsCollapsible={result.supportsCollapsible}
|
||||
isLastStep={index === results.length - 1 && isLastStep}
|
||||
isFirstStep={index === 0 && isFirstStep}
|
||||
|
||||
@@ -63,7 +63,7 @@ export const FetchToolRenderer: MessageRenderer<FetchToolPacket, {}> = ({
|
||||
return children([
|
||||
{
|
||||
icon: SvgCircle,
|
||||
status: "Reading",
|
||||
status: null,
|
||||
content: <div />,
|
||||
supportsCollapsible: false,
|
||||
timelineLayout: "timeline",
|
||||
|
||||
@@ -46,7 +46,7 @@ export const MemoryToolRenderer: MessageRenderer<MemoryToolPacket, {}> = ({
|
||||
return children([
|
||||
{
|
||||
icon: SvgEditBig,
|
||||
status: "Memory",
|
||||
status: null,
|
||||
content: <div />,
|
||||
supportsCollapsible: false,
|
||||
timelineLayout: "timeline",
|
||||
|
||||
@@ -169,9 +169,7 @@ export const ReasoningRenderer: MessageRenderer<
|
||||
);
|
||||
|
||||
if (!hasStart && !hasEnd && content.length === 0) {
|
||||
return children([
|
||||
{ icon: SvgCircle, status: THINKING_STATUS, content: <></> },
|
||||
]);
|
||||
return children([{ icon: SvgCircle, status: null, content: <></> }]);
|
||||
}
|
||||
|
||||
const reasoningContent = (
|
||||
|
||||
@@ -61,7 +61,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
|
||||
children,
|
||||
}) => {
|
||||
const searchState = constructCurrentSearchState(packets);
|
||||
const { queries, results, isComplete } = searchState;
|
||||
const { queries, results } = searchState;
|
||||
|
||||
const isCompact = renderType === RenderType.COMPACT;
|
||||
const isHighlight = renderType === RenderType.HIGHLIGHT;
|
||||
@@ -75,7 +75,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
|
||||
return children([
|
||||
{
|
||||
icon: SvgSearchMenu,
|
||||
status: queriesHeader,
|
||||
status: null,
|
||||
content: <></>,
|
||||
supportsCollapsible: true,
|
||||
timelineLayout: "timeline",
|
||||
@@ -109,15 +109,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
|
||||
window.open(doc.link, "_blank", "noopener,noreferrer");
|
||||
}
|
||||
}}
|
||||
emptyState={
|
||||
!isComplete ? (
|
||||
<BlinkingBar />
|
||||
) : (
|
||||
<Text as="p" text04 mainUiMuted>
|
||||
No results found
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
|
||||
/>
|
||||
</div>
|
||||
),
|
||||
@@ -172,15 +164,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
|
||||
window.open(doc.link, "_blank", "noopener,noreferrer");
|
||||
}
|
||||
}}
|
||||
emptyState={
|
||||
!isComplete ? (
|
||||
<BlinkingBar />
|
||||
) : (
|
||||
<Text as="p" text04 mainUiMuted>
|
||||
No results found
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
|
||||
/>
|
||||
),
|
||||
},
|
||||
@@ -229,15 +213,7 @@ export const InternalSearchToolRenderer: MessageRenderer<
|
||||
window.open(doc.link, "_blank", "noopener,noreferrer");
|
||||
}
|
||||
}}
|
||||
emptyState={
|
||||
!isComplete ? (
|
||||
<BlinkingBar />
|
||||
) : (
|
||||
<Text as="p" text03 mainUiMuted>
|
||||
No results found
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
emptyState={!stopPacketSeen ? <BlinkingBar /> : undefined}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
@@ -53,7 +53,7 @@ export const WebSearchToolRenderer: MessageRenderer<SearchToolPacket, {}> = ({
|
||||
return children([
|
||||
{
|
||||
icon: SvgGlobe,
|
||||
status: "Searching the web",
|
||||
status: null,
|
||||
content: <div />,
|
||||
supportsCollapsible: false,
|
||||
timelineLayout: "timeline",
|
||||
|
||||
@@ -131,8 +131,7 @@ export async function updateAgentSharedStatus(
|
||||
userIds: string[],
|
||||
groupIds: number[],
|
||||
isPublic: boolean | undefined,
|
||||
isPaidEnterpriseFeaturesEnabled: boolean,
|
||||
labelIds?: number[]
|
||||
isPaidEnterpriseFeaturesEnabled: boolean
|
||||
): Promise<null | string> {
|
||||
// MIT versions should not send group_ids - warn if caller provided non-empty groups
|
||||
if (!isPaidEnterpriseFeaturesEnabled && groupIds.length > 0) {
|
||||
@@ -153,7 +152,6 @@ export async function updateAgentSharedStatus(
|
||||
// Only include group_ids for enterprise versions
|
||||
group_ids: isPaidEnterpriseFeaturesEnabled ? groupIds : undefined,
|
||||
is_public: isPublic,
|
||||
label_ids: labelIds,
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -168,63 +166,3 @@ export async function updateAgentSharedStatus(
|
||||
return "Network error. Please check your connection and try again.";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the labels assigned to an agent via the share endpoint.
|
||||
*
|
||||
* @param agentId - The ID of the agent to update
|
||||
* @param labelIds - Array of label IDs to assign to the agent
|
||||
* @returns null on success, or an error message string on failure
|
||||
*/
|
||||
export async function updateAgentLabels(
|
||||
agentId: number,
|
||||
labelIds: number[]
|
||||
): Promise<string | null> {
|
||||
try {
|
||||
const response = await fetch(`/api/persona/${agentId}/share`, {
|
||||
method: "PATCH",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ label_ids: labelIds }),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const errorMessage = (await response.json()).detail || "Unknown error";
|
||||
return errorMessage;
|
||||
} catch (error) {
|
||||
console.error("updateAgentLabels: Network error", error);
|
||||
return "Network error. Please check your connection and try again.";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the featured (default) status of an agent.
|
||||
*
|
||||
* @param agentId - The ID of the agent to update
|
||||
* @param isFeatured - Whether the agent should be featured
|
||||
* @returns null on success, or an error message string on failure
|
||||
*/
|
||||
export async function updateAgentFeaturedStatus(
|
||||
agentId: number,
|
||||
isFeatured: boolean
|
||||
): Promise<string | null> {
|
||||
try {
|
||||
const response = await fetch(`/api/admin/persona/${agentId}/default`, {
|
||||
method: "PATCH",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ is_default_persona: isFeatured }),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const errorMessage = (await response.json()).detail || "Unknown error";
|
||||
return errorMessage;
|
||||
} catch (error) {
|
||||
console.error("updateAgentFeaturedStatus: Network error", error);
|
||||
return "Network error. Please check your connection and try again.";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,27 +257,19 @@ export const useLabels = () => {
|
||||
return mutate("/api/persona/labels");
|
||||
};
|
||||
|
||||
const createLabel = async (name: string): Promise<PersonaLabel | null> => {
|
||||
const createLabel = async (name: string) => {
|
||||
const response = await fetch("/api/persona/labels", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ name }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
return null;
|
||||
if (response.ok) {
|
||||
const newLabel = await response.json();
|
||||
mutate("/api/persona/labels", [...(labels || []), newLabel], false);
|
||||
}
|
||||
|
||||
const newLabel: PersonaLabel = await response.json();
|
||||
mutate(
|
||||
"/api/persona/labels",
|
||||
(currentLabels: PersonaLabel[] | undefined) => [
|
||||
...(currentLabels || []),
|
||||
newLabel,
|
||||
],
|
||||
false
|
||||
);
|
||||
return newLabel;
|
||||
return response;
|
||||
};
|
||||
|
||||
const updateLabel = async (id: number, name: string) => {
|
||||
|
||||
@@ -1,51 +1,29 @@
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgX } from "@opal/icons";
|
||||
import { Button } from "@opal/components";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
export interface ChipProps {
|
||||
children?: string;
|
||||
icon?: React.FunctionComponent<IconProps>;
|
||||
onRemove?: () => void;
|
||||
smallLabel?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* A simple chip/tag component for displaying metadata.
|
||||
* Supports an optional remove button via the `onRemove` prop.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <Chip>Tag Name</Chip>
|
||||
* <Chip icon={SvgUser}>John Doe</Chip>
|
||||
* <Chip onRemove={() => removeTag(id)}>Removable</Chip>
|
||||
* ```
|
||||
*/
|
||||
export default function Chip({
|
||||
children,
|
||||
icon: Icon,
|
||||
onRemove,
|
||||
smallLabel = true,
|
||||
}: ChipProps) {
|
||||
export default function Chip({ children, icon: Icon }: ChipProps) {
|
||||
return (
|
||||
<div className="flex items-center gap-1 px-1.5 py-0.5 rounded-08 bg-background-tint-02">
|
||||
{Icon && <Icon size={12} className="text-text-03" />}
|
||||
{children && (
|
||||
<Text figureSmallLabel={smallLabel} text03>
|
||||
<Text figureSmallLabel text03>
|
||||
{children}
|
||||
</Text>
|
||||
)}
|
||||
{onRemove && (
|
||||
<Button
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
onRemove();
|
||||
}}
|
||||
prominence="tertiary"
|
||||
icon={SvgX}
|
||||
size="xs"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import * as React from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import Chip from "@/refresh-components/Chip";
|
||||
import {
|
||||
innerClasses,
|
||||
textClasses,
|
||||
Variants,
|
||||
wrapperClasses,
|
||||
} from "@/refresh-components/inputs/styles";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
export interface ChipItem {
|
||||
id: string;
|
||||
label: string;
|
||||
}
|
||||
|
||||
export interface InputChipFieldProps {
|
||||
chips: ChipItem[];
|
||||
onRemoveChip: (id: string) => void;
|
||||
onAdd: (value: string) => void;
|
||||
|
||||
value: string;
|
||||
onChange: (value: string) => void;
|
||||
|
||||
placeholder?: string;
|
||||
disabled?: boolean;
|
||||
variant?: Variants;
|
||||
icon?: React.FunctionComponent<IconProps>;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* A tag/chip input field that renders chips inline alongside a text input.
|
||||
*
|
||||
* Pressing Enter adds a chip via `onAdd`. Pressing Backspace on an empty
|
||||
* input removes the last chip. Each chip has a remove button.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <InputChipField
|
||||
* chips={[{ id: "1", label: "Search" }]}
|
||||
* onRemoveChip={(id) => remove(id)}
|
||||
* onAdd={(value) => add(value)}
|
||||
* value={inputValue}
|
||||
* onChange={setInputValue}
|
||||
* placeholder="Add labels..."
|
||||
* icon={SvgTag}
|
||||
* />
|
||||
* ```
|
||||
*/
|
||||
function InputChipField({
|
||||
chips,
|
||||
onRemoveChip,
|
||||
onAdd,
|
||||
value,
|
||||
onChange,
|
||||
placeholder,
|
||||
disabled = false,
|
||||
variant = "primary",
|
||||
icon: Icon,
|
||||
className,
|
||||
}: InputChipFieldProps) {
|
||||
const inputRef = React.useRef<HTMLInputElement>(null);
|
||||
|
||||
function handleKeyDown(e: React.KeyboardEvent<HTMLInputElement>) {
|
||||
if (disabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (e.key === "Enter") {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
const trimmed = value.trim();
|
||||
if (trimmed) {
|
||||
onAdd(trimmed);
|
||||
}
|
||||
}
|
||||
if (e.key === "Backspace" && value === "") {
|
||||
const lastChip = chips[chips.length - 1];
|
||||
if (lastChip) {
|
||||
onRemoveChip(lastChip.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex flex-row items-center flex-wrap gap-1 p-1.5 rounded-08 cursor-text w-full",
|
||||
wrapperClasses[variant],
|
||||
className
|
||||
)}
|
||||
onClick={() => inputRef.current?.focus()}
|
||||
>
|
||||
{Icon && <Icon size={16} className="text-text-04 shrink-0" />}
|
||||
{chips.map((chip) => (
|
||||
<Chip
|
||||
key={chip.id}
|
||||
onRemove={disabled ? undefined : () => onRemoveChip(chip.id)}
|
||||
smallLabel={false}
|
||||
>
|
||||
{chip.label}
|
||||
</Chip>
|
||||
))}
|
||||
<input
|
||||
ref={inputRef}
|
||||
type="text"
|
||||
disabled={disabled}
|
||||
value={value}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder={chips.length === 0 ? placeholder : undefined}
|
||||
className={cn(
|
||||
"flex-1 min-w-[80px] h-[1.5rem] bg-transparent p-0.5 focus:outline-none",
|
||||
innerClasses[variant],
|
||||
textClasses[variant]
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default InputChipField;
|
||||
@@ -64,7 +64,6 @@ function MemoryItem({
|
||||
if (!shouldHighlight) return;
|
||||
|
||||
wrapperRef.current?.scrollIntoView({ block: "center", behavior: "smooth" });
|
||||
textareaRef.current?.focus();
|
||||
setIsHighlighting(true);
|
||||
|
||||
const timer = setTimeout(() => {
|
||||
|
||||
@@ -115,9 +115,14 @@ export default function ActionLineItem({
|
||||
<Section gap={0.25} flexDirection="row">
|
||||
{!isUnavailable && tool?.oauth_config_id && toolAuthStatus && (
|
||||
<Button
|
||||
icon={SvgKey}
|
||||
prominence="secondary"
|
||||
size="sm"
|
||||
icon={({ className }) => (
|
||||
<SvgKey
|
||||
className={cn(
|
||||
className,
|
||||
"stroke-yellow-500 hover:stroke-yellow-600"
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
onClick={noProp(() => {
|
||||
if (
|
||||
!toolAuthStatus.hasToken ||
|
||||
|
||||
@@ -651,8 +651,6 @@ export default function AgentEditorPage({
|
||||
shared_user_ids: existingAgent?.users?.map((user) => user.id) ?? [],
|
||||
shared_group_ids: existingAgent?.groups ?? [],
|
||||
is_public: existingAgent?.is_public ?? true,
|
||||
label_ids: existingAgent?.labels?.map((l) => l.id) ?? [],
|
||||
is_default_persona: existingAgent?.is_default_persona ?? false,
|
||||
};
|
||||
|
||||
const validationSchema = Yup.object().shape({
|
||||
@@ -814,8 +812,8 @@ export default function AgentEditorPage({
|
||||
uploaded_image_id: values.uploaded_image_id,
|
||||
icon_name: values.icon_name,
|
||||
search_start_date: values.knowledge_cutoff_date || null,
|
||||
label_ids: values.label_ids,
|
||||
is_default_persona: values.is_default_persona,
|
||||
label_ids: null,
|
||||
is_default_persona: false,
|
||||
// display_priority: ...,
|
||||
|
||||
user_file_ids: values.enable_knowledge ? values.user_file_ids : [],
|
||||
@@ -1056,20 +1054,10 @@ export default function AgentEditorPage({
|
||||
userIds={values.shared_user_ids}
|
||||
groupIds={values.shared_group_ids}
|
||||
isPublic={values.is_public}
|
||||
isFeatured={values.is_default_persona}
|
||||
labelIds={values.label_ids}
|
||||
onShare={(
|
||||
userIds,
|
||||
groupIds,
|
||||
isPublic,
|
||||
isFeatured,
|
||||
labelIds
|
||||
) => {
|
||||
onShare={(userIds, groupIds, isPublic) => {
|
||||
setFieldValue("shared_user_ids", userIds);
|
||||
setFieldValue("shared_group_ids", groupIds);
|
||||
setFieldValue("is_public", isPublic);
|
||||
setFieldValue("is_default_persona", isFeatured);
|
||||
setFieldValue("label_ids", labelIds);
|
||||
shareAgentModal.toggle(false);
|
||||
}}
|
||||
/>
|
||||
|
||||
@@ -11,11 +11,7 @@ import { cn, noProp } from "@/lib/utils";
|
||||
import { useRouter } from "next/navigation";
|
||||
import type { Route } from "next";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import {
|
||||
checkUserOwnsAssistant,
|
||||
updateAgentSharedStatus,
|
||||
updateAgentFeaturedStatus,
|
||||
} from "@/lib/agents";
|
||||
import { checkUserOwnsAssistant, updateAgentSharedStatus } from "@/lib/agents";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import {
|
||||
SvgActions,
|
||||
@@ -47,9 +43,8 @@ export default function AgentCard({ agent }: AgentCardProps) {
|
||||
() => pinnedAgents.some((pinnedAgent) => pinnedAgent.id === agent.id),
|
||||
[agent.id, pinnedAgents]
|
||||
);
|
||||
const { user, isAdmin, isCurator } = useUser();
|
||||
const { user } = useUser();
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
const canUpdateFeaturedStatus = isAdmin || isCurator;
|
||||
const isOwnedByUser = checkUserOwnsAssistant(user, agent);
|
||||
const shareAgentModal = useCreateModal();
|
||||
const agentViewerModal = useCreateModal();
|
||||
@@ -63,49 +58,26 @@ export default function AgentCard({ agent }: AgentCardProps) {
|
||||
route({ agentId: agent.id });
|
||||
}, [pinned, togglePinnedAgent, agent, route]);
|
||||
|
||||
// Handle sharing agent
|
||||
const handleShare = useCallback(
|
||||
async (
|
||||
userIds: string[],
|
||||
groupIds: number[],
|
||||
isPublic: boolean,
|
||||
isFeatured: boolean,
|
||||
labelIds: number[]
|
||||
) => {
|
||||
const shareError = await updateAgentSharedStatus(
|
||||
async (userIds: string[], groupIds: number[], isPublic: boolean) => {
|
||||
const error = await updateAgentSharedStatus(
|
||||
agent.id,
|
||||
userIds,
|
||||
groupIds,
|
||||
isPublic,
|
||||
isPaidEnterpriseFeaturesEnabled,
|
||||
labelIds
|
||||
isPaidEnterpriseFeaturesEnabled
|
||||
);
|
||||
|
||||
if (shareError) {
|
||||
toast.error(`Failed to share agent: ${shareError}`);
|
||||
return;
|
||||
if (error) {
|
||||
toast.error(`Failed to share agent: ${error}`);
|
||||
} else {
|
||||
// Revalidate the agent data to reflect the changes
|
||||
refreshAgent();
|
||||
shareAgentModal.toggle(false);
|
||||
}
|
||||
|
||||
if (canUpdateFeaturedStatus) {
|
||||
const featuredError = await updateAgentFeaturedStatus(
|
||||
agent.id,
|
||||
isFeatured
|
||||
);
|
||||
if (featuredError) {
|
||||
toast.error(`Failed to update featured status: ${featuredError}`);
|
||||
refreshAgent();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
refreshAgent();
|
||||
shareAgentModal.toggle(false);
|
||||
},
|
||||
[
|
||||
agent.id,
|
||||
canUpdateFeaturedStatus,
|
||||
isPaidEnterpriseFeaturesEnabled,
|
||||
refreshAgent,
|
||||
]
|
||||
[agent.id, isPaidEnterpriseFeaturesEnabled, refreshAgent]
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -116,8 +88,6 @@ export default function AgentCard({ agent }: AgentCardProps) {
|
||||
userIds={fullAgent?.users?.map((u) => u.id) ?? []}
|
||||
groupIds={fullAgent?.groups ?? []}
|
||||
isPublic={fullAgent?.is_public ?? false}
|
||||
isFeatured={fullAgent?.is_default_persona ?? false}
|
||||
labelIds={fullAgent?.labels?.map((l) => l.id) ?? []}
|
||||
onShare={handleShare}
|
||||
/>
|
||||
</shareAgentModal.Provider>
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useMemo, useState } from "react";
|
||||
import { useMemo } from "react";
|
||||
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import {
|
||||
SvgLink,
|
||||
SvgOrganization,
|
||||
SvgShare,
|
||||
SvgTag,
|
||||
SvgUsers,
|
||||
SvgX,
|
||||
} from "@opal/icons";
|
||||
import InputChipField from "@/refresh-components/inputs/InputChipField";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
import { Card } from "@/refresh-components/cards";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox/InputComboBox";
|
||||
@@ -28,8 +26,6 @@ import { useUser } from "@/providers/UserProvider";
|
||||
import { Formik, useFormikContext } from "formik";
|
||||
import { useAgent } from "@/hooks/useAgents";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { useLabels } from "@/lib/hooks";
|
||||
import { PersonaLabel } from "@/app/admin/assistants/interfaces";
|
||||
|
||||
const YOUR_ORGANIZATION_TAB = "Your Organization";
|
||||
const USERS_AND_GROUPS_TAB = "Users & Groups";
|
||||
@@ -42,8 +38,6 @@ interface ShareAgentFormValues {
|
||||
selectedUserIds: string[];
|
||||
selectedGroupIds: number[];
|
||||
isPublic: boolean;
|
||||
isFeatured: boolean;
|
||||
labelIds: number[];
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
@@ -59,15 +53,12 @@ function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
|
||||
useFormikContext<ShareAgentFormValues>();
|
||||
const { data: usersData } = useShareableUsers({ includeApiKeys: true });
|
||||
const { data: groupsData } = useShareableGroups();
|
||||
const { user: currentUser, isAdmin, isCurator } = useUser();
|
||||
const { user: currentUser } = useUser();
|
||||
const { agent: fullAgent } = useAgent(agentId ?? null);
|
||||
const shareAgentModal = useModal();
|
||||
const { labels: allLabels, createLabel } = useLabels();
|
||||
const [labelInputValue, setLabelInputValue] = useState("");
|
||||
|
||||
const acceptedUsers = usersData ?? [];
|
||||
const groups = groupsData ?? [];
|
||||
const canUpdateFeaturedStatus = isAdmin || isCurator;
|
||||
|
||||
// Create options for InputComboBox from all accepted users and groups
|
||||
const comboBoxOptions = useMemo(() => {
|
||||
@@ -146,50 +137,6 @@ function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
|
||||
);
|
||||
}
|
||||
|
||||
const selectedLabels: PersonaLabel[] = useMemo(() => {
|
||||
if (!allLabels) return [];
|
||||
return allLabels.filter((label) => values.labelIds.includes(label.id));
|
||||
}, [allLabels, values.labelIds]);
|
||||
|
||||
function handleRemoveLabel(labelId: number) {
|
||||
setFieldValue(
|
||||
"labelIds",
|
||||
values.labelIds.filter((id) => id !== labelId)
|
||||
);
|
||||
}
|
||||
|
||||
const addLabel = useCallback(
|
||||
async (name: string) => {
|
||||
const trimmed = name.trim();
|
||||
if (!trimmed) return;
|
||||
|
||||
const existing = allLabels?.find(
|
||||
(l) => l.name.toLowerCase() === trimmed.toLowerCase()
|
||||
);
|
||||
if (existing) {
|
||||
if (!values.labelIds.includes(existing.id)) {
|
||||
setFieldValue("labelIds", [...values.labelIds, existing.id]);
|
||||
}
|
||||
} else {
|
||||
const newLabel = await createLabel(trimmed);
|
||||
if (newLabel) {
|
||||
setFieldValue("labelIds", [...values.labelIds, newLabel.id]);
|
||||
}
|
||||
}
|
||||
setLabelInputValue("");
|
||||
},
|
||||
[allLabels, values.labelIds, setFieldValue, createLabel]
|
||||
);
|
||||
|
||||
const chipItems = useMemo(
|
||||
() =>
|
||||
selectedLabels.map((label) => ({
|
||||
id: String(label.id),
|
||||
label: label.name,
|
||||
})),
|
||||
[selectedLabels]
|
||||
);
|
||||
|
||||
return (
|
||||
<Modal.Content width="sm" height="lg">
|
||||
<Modal.Header icon={SvgShare} title="Share Agent" onClose={handleClose} />
|
||||
@@ -286,41 +233,12 @@ function ShareAgentFormContent({ agentId }: ShareAgentFormContentProps) {
|
||||
</Tabs.Content>
|
||||
|
||||
<Tabs.Content value={YOUR_ORGANIZATION_TAB} padding={0.5}>
|
||||
<Section gap={1} alignItems="stretch">
|
||||
<InputLayouts.Horizontal
|
||||
title="Publish This Agent"
|
||||
description="Make this agent available to everyone in your organization."
|
||||
>
|
||||
<SwitchField name="isPublic" />
|
||||
</InputLayouts.Horizontal>
|
||||
|
||||
{canUpdateFeaturedStatus && (
|
||||
<>
|
||||
<div className="border-t border-border-02" />
|
||||
|
||||
<InputLayouts.Horizontal
|
||||
title="Feature This Agent"
|
||||
description="Show this agent at the top of the explore agents list and automatically pin it to the sidebar for new users with access."
|
||||
>
|
||||
<SwitchField name="isFeatured" />
|
||||
</InputLayouts.Horizontal>
|
||||
</>
|
||||
)}
|
||||
|
||||
<InputChipField
|
||||
chips={chipItems}
|
||||
onRemoveChip={(id) => handleRemoveLabel(Number(id))}
|
||||
onAdd={addLabel}
|
||||
value={labelInputValue}
|
||||
onChange={setLabelInputValue}
|
||||
placeholder="Add labels..."
|
||||
icon={SvgTag}
|
||||
/>
|
||||
<Text secondaryBody text04>
|
||||
Add labels and categories to help people better discover this
|
||||
agent.
|
||||
</Text>
|
||||
</Section>
|
||||
<InputLayouts.Horizontal
|
||||
title="Publish This Agent"
|
||||
description="Make this agent available to everyone in your organization."
|
||||
>
|
||||
<SwitchField name="isPublic" />
|
||||
</InputLayouts.Horizontal>
|
||||
</Tabs.Content>
|
||||
</Tabs>
|
||||
</Card>
|
||||
@@ -360,15 +278,7 @@ export interface ShareAgentModalProps {
|
||||
userIds: string[];
|
||||
groupIds: number[];
|
||||
isPublic: boolean;
|
||||
isFeatured: boolean;
|
||||
labelIds: number[];
|
||||
onShare?: (
|
||||
userIds: string[],
|
||||
groupIds: number[],
|
||||
isPublic: boolean,
|
||||
isFeatured: boolean,
|
||||
labelIds: number[]
|
||||
) => void;
|
||||
onShare?: (userIds: string[], groupIds: number[], isPublic: boolean) => void;
|
||||
}
|
||||
|
||||
export default function ShareAgentModal({
|
||||
@@ -376,8 +286,6 @@ export default function ShareAgentModal({
|
||||
userIds,
|
||||
groupIds,
|
||||
isPublic,
|
||||
isFeatured,
|
||||
labelIds,
|
||||
onShare,
|
||||
}: ShareAgentModalProps) {
|
||||
const shareAgentModal = useModal();
|
||||
@@ -386,18 +294,10 @@ export default function ShareAgentModal({
|
||||
selectedUserIds: userIds,
|
||||
selectedGroupIds: groupIds,
|
||||
isPublic: isPublic,
|
||||
isFeatured: isFeatured,
|
||||
labelIds: labelIds,
|
||||
};
|
||||
|
||||
function handleSubmit(values: ShareAgentFormValues) {
|
||||
onShare?.(
|
||||
values.selectedUserIds,
|
||||
values.selectedGroupIds,
|
||||
values.isPublic,
|
||||
values.isFeatured,
|
||||
values.labelIds
|
||||
);
|
||||
onShare?.(values.selectedUserIds, values.selectedGroupIds, values.isPublic);
|
||||
}
|
||||
|
||||
return (
|
||||
|
||||
Reference in New Issue
Block a user