mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-16 23:35:46 +00:00
Compare commits
24 Commits
v2.9.6
...
test_searc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4554aedbea | ||
|
|
d605542e8d | ||
|
|
4051e51e54 | ||
|
|
452c907058 | ||
|
|
168b32e400 | ||
|
|
5b736b7f29 | ||
|
|
04ed1ff155 | ||
|
|
2ddbf27a9e | ||
|
|
518e5b5351 | ||
|
|
5a5e06d373 | ||
|
|
f5acd082b8 | ||
|
|
39b5830184 | ||
|
|
a3f116eb6a | ||
|
|
6fd495f905 | ||
|
|
d6a43708bd | ||
|
|
78e7ad59e9 | ||
|
|
cbdc163597 | ||
|
|
9f658176a0 | ||
|
|
2928fbd3be | ||
|
|
bea17000e1 | ||
|
|
f6018de7b0 | ||
|
|
936f3433e3 | ||
|
|
2522d685a6 | ||
|
|
27633ac0fc |
5
.github/workflows/run-it.yml
vendored
5
.github/workflows/run-it.yml
vendored
@@ -10,6 +10,11 @@ on:
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
|
||||
VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }}
|
||||
GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
|
||||
LITELLM_API_KEY: ${{ secrets.LITELLM_API_KEY }}
|
||||
LITELLM_API_URL: ${{ secrets.LITELLM_API_URL }}
|
||||
|
||||
jobs:
|
||||
integration-tests:
|
||||
|
||||
Binary file not shown.
@@ -113,7 +113,7 @@ def set_new_search_settings(
|
||||
for cc_pair in get_connector_credential_pairs(db_session):
|
||||
resync_cc_pair(cc_pair, db_session=db_session)
|
||||
|
||||
return IdReturn(id=new_search_settings.id)
|
||||
return IdReturn(search_settings_id=new_search_settings.id, index_name=index_name)
|
||||
|
||||
|
||||
@router.post("/cancel-new-embedding")
|
||||
|
||||
@@ -23,7 +23,8 @@ class ApiKey(BaseModel):
|
||||
|
||||
|
||||
class IdReturn(BaseModel):
|
||||
id: int
|
||||
search_settings_id: int
|
||||
index_name: str
|
||||
|
||||
|
||||
class MinimalUserSnapshot(BaseModel):
|
||||
|
||||
@@ -5,10 +5,10 @@ The integration tests are designed with a "manager" class and a "test" class for
|
||||
- **Manager Class**: Contains methods for each type of API call. Responsible for creating, deleting, and verifying the existence of an entity.
|
||||
- **Test Class**: Stores data for each entity being tested. This is our "expected state" of the object.
|
||||
|
||||
The idea is that each test can use the manager class to create (.create()) a "test_" object. It can then perform an operation on the object (e.g., send a request to the API) and then check if the "test_" object is in the expected state by using the manager class (.verify()) function.
|
||||
The idea is that each test can use the manager class to create (`.create()`) a `test_` object. It can then perform an operation on the object (e.g., send a request to the API) and then check if the `test_` object is in the expected state by using the manager class (`.verify()`) function.
|
||||
|
||||
## Instructions for Running Integration Tests Locally
|
||||
1. Launch danswer (using Docker or running with a debugger), ensuring the API server is running on port 8080.
|
||||
1. Launch danswer (using Docker or running with a debugger), ensuring the API server is running on port 8080 with basic AUTH TYPE enabled.
|
||||
a. If you'd like to set environment variables, you can do so by creating a `.env` file in the danswer/backend/tests/integration/ directory.
|
||||
2. Navigate to `danswer/backend`.
|
||||
3. Run the following command in the terminal:
|
||||
|
||||
@@ -132,7 +132,10 @@ class DocumentSetManager:
|
||||
if not check_ids.issubset(doc_set_ids):
|
||||
raise RuntimeError("Document set not found")
|
||||
doc_sets = [doc_set for doc_set in doc_sets if doc_set.id in check_ids]
|
||||
all_up_to_date = all(doc_set.is_up_to_date for doc_set in doc_sets)
|
||||
all_up_to_date = True
|
||||
for doc_set in doc_sets:
|
||||
if not doc_set.is_up_to_date:
|
||||
all_up_to_date = False
|
||||
|
||||
if all_up_to_date:
|
||||
break
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
import requests
|
||||
|
||||
from danswer.db.models import EmbeddingProvider
|
||||
from danswer.server.manage.embedding.models import TestEmbeddingRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.search_settings import (
|
||||
SearchSettingsManager,
|
||||
)
|
||||
from tests.integration.common_utils.test_models import TestCloudEmbeddingProvider
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
|
||||
|
||||
class EmbeddingProviderManager:
|
||||
@staticmethod
|
||||
def test(
|
||||
user_performing_action: TestUser,
|
||||
embedding_provider: TestCloudEmbeddingProvider,
|
||||
model_name: str | None,
|
||||
) -> None:
|
||||
test_embedding_request = TestEmbeddingRequest(
|
||||
provider_type=embedding_provider.provider_type,
|
||||
api_key=embedding_provider.api_key,
|
||||
api_url=embedding_provider.api_url,
|
||||
model_name=model_name,
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/admin/embedding/test-embedding",
|
||||
json=test_embedding_request.model_dump(),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to test embedding provider: {response.json()}")
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
test_embedding_provider: TestCloudEmbeddingProvider,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestCloudEmbeddingProvider:
|
||||
embedding_provider_request = {
|
||||
"provider_type": test_embedding_provider.provider_type,
|
||||
"api_url": test_embedding_provider.api_url,
|
||||
"api_key": test_embedding_provider.api_key,
|
||||
}
|
||||
|
||||
response = requests.put(
|
||||
url=f"{API_SERVER_URL}/admin/embedding/embedding-provider",
|
||||
json=embedding_provider_request,
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
return TestCloudEmbeddingProvider(
|
||||
provider_type=response_data["provider_type"],
|
||||
api_key=response_data.get("api_key"),
|
||||
api_url=response_data.get("api_url"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> list[TestCloudEmbeddingProvider]:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/embedding/embedding-provider",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
return [
|
||||
TestCloudEmbeddingProvider(**embedding_provider)
|
||||
for embedding_provider in response.json()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def edit(
|
||||
embedding_provider: TestCloudEmbeddingProvider,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
response = requests.put(
|
||||
url=f"{API_SERVER_URL}/admin/embedding/embedding-provider",
|
||||
json=embedding_provider.model_dump(),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
provider_type: EmbeddingProvider,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
response = requests.delete(
|
||||
url=f"{API_SERVER_URL}/admin/embedding/embedding-provider/{provider_type}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def verify_providers(
|
||||
embedding_providers: list[TestCloudEmbeddingProvider],
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
current_providers = EmbeddingProviderManager.get_all(user_performing_action)
|
||||
|
||||
for expected_provider in embedding_providers:
|
||||
matching_provider = next(
|
||||
(
|
||||
p
|
||||
for p in current_providers
|
||||
if p.provider_type == expected_provider.provider_type
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if matching_provider is None:
|
||||
raise ValueError(
|
||||
f"Embedding provider {expected_provider.provider_type} not found in current providers"
|
||||
)
|
||||
|
||||
if matching_provider.api_url != expected_provider.api_url:
|
||||
raise ValueError(
|
||||
f"API URL mismatch for provider {expected_provider.provider_type}"
|
||||
)
|
||||
|
||||
if matching_provider.api_key != expected_provider.api_key:
|
||||
raise ValueError(
|
||||
f"API Key mismatch for provider {expected_provider.provider_type}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
embedding_provider: TestCloudEmbeddingProvider,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
current_settings = SearchSettingsManager.get_primary(user_performing_action)
|
||||
if current_settings is None:
|
||||
raise ValueError("No current embedding provider found")
|
||||
|
||||
current_provider_type = current_settings.provider_type
|
||||
|
||||
if current_provider_type is None:
|
||||
raise ValueError("No current embedding provider found")
|
||||
|
||||
if current_provider_type != embedding_provider.provider_type:
|
||||
raise ValueError(
|
||||
f"Current embedding provider {current_provider_type} does not match expected {embedding_provider.provider_type}"
|
||||
)
|
||||
@@ -0,0 +1,157 @@
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.test_models import TestFullModelVersionResponse
|
||||
from tests.integration.common_utils.test_models import TestSearchSettings
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
|
||||
|
||||
class SearchSettingsManager:
|
||||
@staticmethod
|
||||
def create_and_set(
|
||||
model_name: str = "test-model",
|
||||
model_dim: int = 768,
|
||||
normalize: bool = True,
|
||||
query_prefix: str | None = "",
|
||||
passage_prefix: str | None = "",
|
||||
index_name: str | None = None,
|
||||
provider_type: str | None = None,
|
||||
multipass_indexing: bool = False,
|
||||
multilingual_expansion: list[str] = [],
|
||||
disable_rerank_for_streaming: bool = False,
|
||||
rerank_model_name: str | None = None,
|
||||
rerank_provider_type: str | None = None,
|
||||
rerank_api_key: str | None = None,
|
||||
num_rerank: int = 50,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestSearchSettings:
|
||||
search_settings_request = {
|
||||
"model_name": model_name,
|
||||
"model_dim": model_dim,
|
||||
"normalize": normalize,
|
||||
"query_prefix": query_prefix,
|
||||
"passage_prefix": passage_prefix,
|
||||
"index_name": index_name,
|
||||
"provider_type": provider_type,
|
||||
"multipass_indexing": multipass_indexing,
|
||||
"multilingual_expansion": multilingual_expansion,
|
||||
"disable_rerank_for_streaming": disable_rerank_for_streaming,
|
||||
"rerank_model_name": rerank_model_name,
|
||||
"rerank_provider_type": rerank_provider_type,
|
||||
"rerank_api_key": rerank_api_key,
|
||||
"num_rerank": num_rerank,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/search-settings/set-new-search-settings",
|
||||
json=search_settings_request,
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
# Merge the response data with the original request data
|
||||
merged_data = {**search_settings_request, **response_data}
|
||||
|
||||
return TestSearchSettings(**merged_data)
|
||||
|
||||
@staticmethod
|
||||
def wait_for_sync(
|
||||
new_primary_settings: TestSearchSettings,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
start = time.time()
|
||||
while True:
|
||||
current_primary_settings = SearchSettingsManager.get_primary(
|
||||
user_performing_action
|
||||
)
|
||||
if (
|
||||
current_primary_settings.model_dump()
|
||||
== new_primary_settings.model_dump()
|
||||
):
|
||||
break
|
||||
|
||||
if time.time() - start > MAX_DELAY:
|
||||
raise TimeoutError(
|
||||
f"Search settings were not synced within the {MAX_DELAY} seconds"
|
||||
)
|
||||
else:
|
||||
print("Search settings were not synced yet, waiting...")
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
@staticmethod
|
||||
def edit(
|
||||
search_settings: TestSearchSettings,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/search-settings/update-inference-settings",
|
||||
json=search_settings.model_dump(exclude={"id"}),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_primary(
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestSearchSettings:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/search-settings/get-current-search-settings",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return TestSearchSettings(**response.json())
|
||||
|
||||
@staticmethod
|
||||
def get_secondary(
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestSearchSettings | None:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/search-settings/get-secondary-search-settings",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return TestSearchSettings(**data.json()) if data else None
|
||||
|
||||
@staticmethod
|
||||
def get_all_current(
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestFullModelVersionResponse:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/search-settings/get-all-search-settings",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return TestFullModelVersionResponse(
|
||||
current_settings=TestSearchSettings(**data["current_settings"]),
|
||||
secondary_settings=TestSearchSettings(**data["secondary_settings"])
|
||||
if data["secondary_settings"]
|
||||
else None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
search_settings: TestSearchSettings,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
current_settings = SearchSettingsManager.get_primary(user_performing_action)
|
||||
if current_settings.model_dump() != search_settings.model_dump():
|
||||
raise ValueError("Current search settings do not match expected settings")
|
||||
@@ -8,6 +8,7 @@ from danswer.auth.schemas import UserRole
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.server.documents.models import DocumentSource
|
||||
from danswer.server.documents.models import InputType
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
"""
|
||||
These data models are used to represent the data on the testing side of things.
|
||||
@@ -120,6 +121,39 @@ class TestPersona(BaseModel):
|
||||
groups: list[int]
|
||||
|
||||
|
||||
class TestSearchSettings(BaseModel):
|
||||
model_name: str
|
||||
model_dim: int
|
||||
normalize: bool
|
||||
query_prefix: str | None
|
||||
passage_prefix: str | None
|
||||
index_name: str
|
||||
provider_type: str | None
|
||||
multipass_indexing: bool
|
||||
multilingual_expansion: list[str]
|
||||
disable_rerank_for_streaming: bool
|
||||
rerank_model_name: str | None
|
||||
rerank_provider_type: str | None
|
||||
rerank_api_key: str | None = None
|
||||
num_rerank: int
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, TestSearchSettings):
|
||||
return False
|
||||
return self.dict() == other.dict()
|
||||
|
||||
|
||||
class TestCloudEmbeddingProvider(BaseModel):
|
||||
provider_type: EmbeddingProvider
|
||||
api_key: str | None
|
||||
api_url: str | None
|
||||
|
||||
|
||||
class TestFullModelVersionResponse(BaseModel):
|
||||
current_settings: TestSearchSettings | None
|
||||
secondary_settings: TestSearchSettings | None
|
||||
|
||||
|
||||
#
|
||||
class TestChatSession(BaseModel):
|
||||
id: int
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
This file tests search settings creation, upgrades, and embedding provider management.
|
||||
"""
|
||||
import os
|
||||
|
||||
from danswer.db.models import EmbeddingProvider
|
||||
from tests.integration.common_utils.test_models import TestCloudEmbeddingProvider
|
||||
from tests.integration.tests.search_settings.utils import (
|
||||
create_and_test_embedding_provider,
|
||||
)
|
||||
|
||||
|
||||
def test_creating_openai_embedding_provider(reset: None) -> None:
|
||||
openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
test_embedding_provider = TestCloudEmbeddingProvider(
|
||||
provider_type=EmbeddingProvider.OPENAI,
|
||||
api_key=openai_api_key,
|
||||
api_url=None,
|
||||
)
|
||||
create_and_test_embedding_provider(test_embedding_provider)
|
||||
|
||||
|
||||
def test_creating_cohere_embedding_provider(reset: None) -> None:
|
||||
cohere_api_key = os.getenv("COHERE_API_KEY")
|
||||
test_embedding_provider = TestCloudEmbeddingProvider(
|
||||
provider_type=EmbeddingProvider.COHERE,
|
||||
api_key=cohere_api_key,
|
||||
api_url=None,
|
||||
)
|
||||
create_and_test_embedding_provider(test_embedding_provider)
|
||||
|
||||
|
||||
def test_creating_voyage_embedding_provider(reset: None) -> None:
|
||||
voyage_api_key = os.getenv("VOYAGE_API_KEY")
|
||||
test_embedding_provider = TestCloudEmbeddingProvider(
|
||||
provider_type=EmbeddingProvider.VOYAGE,
|
||||
api_key=voyage_api_key,
|
||||
api_url=None,
|
||||
)
|
||||
create_and_test_embedding_provider(test_embedding_provider)
|
||||
|
||||
|
||||
def test_creating_google_embedding_provider(reset: None) -> None:
|
||||
google_api_key = os.getenv("GOOGLE_API_KEY")
|
||||
|
||||
test_embedding_provider = TestCloudEmbeddingProvider(
|
||||
provider_type=EmbeddingProvider.GOOGLE,
|
||||
api_key=google_api_key,
|
||||
api_url=None,
|
||||
)
|
||||
create_and_test_embedding_provider(test_embedding_provider)
|
||||
|
||||
|
||||
def test_creating_litellm_embedding_provider(reset: None) -> None:
|
||||
litellm_api_key = os.getenv("LITELLM_API_KEY")
|
||||
litellm_api_url = os.getenv("LITELLM_API_URL")
|
||||
litellm_test_model_name = os.getenv("LITELLM_TEST_MODEL_NAME")
|
||||
|
||||
test_embedding_provider = TestCloudEmbeddingProvider(
|
||||
provider_type=EmbeddingProvider.LITELLM,
|
||||
api_key=litellm_api_key,
|
||||
api_url=litellm_api_url,
|
||||
)
|
||||
|
||||
create_and_test_embedding_provider(
|
||||
test_embedding_provider,
|
||||
model_name=litellm_test_model_name,
|
||||
)
|
||||
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
This file tests search settings creation, upgrades, and embedding provider management.
|
||||
"""
|
||||
import os
|
||||
|
||||
from danswer.db.models import EmbeddingProvider
|
||||
from tests.integration.common_utils.managers.embedding_provider import (
|
||||
EmbeddingProviderManager,
|
||||
)
|
||||
from tests.integration.common_utils.managers.search_settings import (
|
||||
SearchSettingsManager,
|
||||
)
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import TestCloudEmbeddingProvider
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
|
||||
|
||||
def test_creating_and_upgrading_search_settings(reset: None) -> None:
|
||||
admin_user: TestUser = UserManager.create(name="admin_user")
|
||||
|
||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if not openai_api_key:
|
||||
raise ValueError("OPENAI_API_KEY environment variable is not set")
|
||||
|
||||
current_embedding_provider = TestCloudEmbeddingProvider(
|
||||
provider_type=EmbeddingProvider.OPENAI, api_key=openai_api_key, api_url=None
|
||||
)
|
||||
|
||||
EmbeddingProviderManager.create(
|
||||
test_embedding_provider=current_embedding_provider,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
initial_settings = SearchSettingsManager.create_and_set(
|
||||
model_name="text-embedding-3-small",
|
||||
model_dim=1536,
|
||||
provider_type=EmbeddingProvider.OPENAI,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
SearchSettingsManager.wait_for_sync(
|
||||
new_primary_settings=initial_settings, user_performing_action=admin_user
|
||||
)
|
||||
SearchSettingsManager.verify(initial_settings, user_performing_action=admin_user)
|
||||
|
||||
new_settings = SearchSettingsManager.create_and_set(
|
||||
model_name="text-embedding-3-small",
|
||||
model_dim=1536,
|
||||
provider_type=EmbeddingProvider.OPENAI,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
SearchSettingsManager.wait_for_sync(
|
||||
new_primary_settings=new_settings, user_performing_action=admin_user
|
||||
)
|
||||
|
||||
SearchSettingsManager.verify(new_settings, user_performing_action=admin_user)
|
||||
|
||||
EmbeddingProviderManager.verify(
|
||||
current_embedding_provider, user_performing_action=admin_user
|
||||
)
|
||||
|
||||
new_api_key = "new-openai-api-key-example"
|
||||
|
||||
current_embedding_provider.api_key = new_api_key
|
||||
|
||||
EmbeddingProviderManager.edit(
|
||||
current_embedding_provider, user_performing_action=admin_user
|
||||
)
|
||||
|
||||
EmbeddingProviderManager.verify(
|
||||
current_embedding_provider, user_performing_action=admin_user
|
||||
)
|
||||
32
backend/tests/integration/tests/search_settings/utils.py
Normal file
32
backend/tests/integration/tests/search_settings/utils.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
This module contains utility functions and fixtures for testing search settings.
|
||||
"""
|
||||
from tests.integration.common_utils.managers.embedding_provider import (
|
||||
EmbeddingProviderManager,
|
||||
)
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import TestCloudEmbeddingProvider
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
|
||||
|
||||
def create_and_test_embedding_provider(
|
||||
test_embedding_provider: TestCloudEmbeddingProvider,
|
||||
model_name: str | None = None,
|
||||
) -> None:
|
||||
admin_user: TestUser = UserManager.create(name="admin_user")
|
||||
|
||||
EmbeddingProviderManager.test(
|
||||
user_performing_action=admin_user,
|
||||
embedding_provider=test_embedding_provider,
|
||||
model_name=model_name,
|
||||
)
|
||||
|
||||
EmbeddingProviderManager.create(
|
||||
test_embedding_provider=test_embedding_provider,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
EmbeddingProviderManager.verify_providers(
|
||||
embedding_providers=[test_embedding_provider],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
Reference in New Issue
Block a user