1
0
forked from github/onyx

Compare commits

...

24 Commits

Author SHA1 Message Date
pablodanswer
4554aedbea rebase + remove log 2024-09-15 18:46:56 -07:00
pablodanswer
d605542e8d update so passes 2024-09-15 18:46:01 -07:00
pablodanswer
4051e51e54 minor settings update 2024-09-15 18:46:01 -07:00
pablodanswer
452c907058 rename 2024-09-15 18:46:01 -07:00
pablodanswer
168b32e400 update ports + pass tests 2024-09-15 18:46:01 -07:00
pablodanswer
5b736b7f29 (squash) 2024-09-15 18:46:01 -07:00
pablodanswer
04ed1ff155 add github config 2024-09-15 18:46:01 -07:00
pablodanswer
2ddbf27a9e temp 2024-09-15 18:46:01 -07:00
pablodanswer
518e5b5351 updated tests 2024-09-15 18:46:01 -07:00
pablodanswer
5a5e06d373 update typing 2024-09-15 18:46:01 -07:00
pablodanswer
f5acd082b8 valid search setting + embedding model testing 2024-09-15 18:46:01 -07:00
pablodanswer
39b5830184 remove logs 2024-09-15 18:46:01 -07:00
pablodanswer
a3f116eb6a squash 2024-09-15 18:46:01 -07:00
pablodanswer
6fd495f905 update ports 2024-09-15 18:46:01 -07:00
pablodanswer
d6a43708bd update for validation 2024-09-15 18:46:01 -07:00
pablodanswer
78e7ad59e9 update typing 2024-09-15 18:46:01 -07:00
pablodanswer
cbdc163597 temporary api update 2024-09-15 18:46:01 -07:00
pablodanswer
9f658176a0 upadate naming for clarity 2024-09-15 18:46:01 -07:00
pablodanswer
2928fbd3be update embedding provider 2024-09-15 18:46:01 -07:00
pablodanswer
bea17000e1 validated tests for cloud embedding provider creation 2024-09-15 18:46:01 -07:00
pablodanswer
f6018de7b0 add proper testing models 2024-09-15 18:46:00 -07:00
pablodanswer
936f3433e3 add search settings 2024-09-15 18:45:54 -07:00
pablodanswer
2522d685a6 initial non-functioning search settings 2024-09-15 18:45:54 -07:00
pablodanswer
27633ac0fc initial integration test steup 2024-09-15 18:45:52 -07:00
12 changed files with 537 additions and 5 deletions

View File

@@ -10,6 +10,11 @@ on:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }}
GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
LITELLM_API_KEY: ${{ secrets.LITELLM_API_KEY }}
LITELLM_API_URL: ${{ secrets.LITELLM_API_URL }}
jobs:
integration-tests:

View File

@@ -113,7 +113,7 @@ def set_new_search_settings(
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(cc_pair, db_session=db_session)
return IdReturn(id=new_search_settings.id)
return IdReturn(search_settings_id=new_search_settings.id, index_name=index_name)
@router.post("/cancel-new-embedding")

View File

@@ -23,7 +23,8 @@ class ApiKey(BaseModel):
class IdReturn(BaseModel):
id: int
search_settings_id: int
index_name: str
class MinimalUserSnapshot(BaseModel):

View File

@@ -5,10 +5,10 @@ The integration tests are designed with a "manager" class and a "test" class for
- **Manager Class**: Contains methods for each type of API call. Responsible for creating, deleting, and verifying the existence of an entity.
- **Test Class**: Stores data for each entity being tested. This is our "expected state" of the object.
The idea is that each test can use the manager class to create (.create()) a "test_" object. It can then perform an operation on the object (e.g., send a request to the API) and then check if the "test_" object is in the expected state by using the manager class (.verify()) function.
The idea is that each test can use the manager class to create (`.create()`) a `test_` object. It can then perform an operation on the object (e.g., send a request to the API) and then check that the `test_` object is in the expected state by calling the manager class's `.verify()` function.
## Instructions for Running Integration Tests Locally
1. Launch danswer (using Docker or running with a debugger), ensuring the API server is running on port 8080.
1. Launch danswer (using Docker or running with a debugger), ensuring the API server is running on port 8080 with the basic auth type enabled.
a. If you'd like to set environment variables, you can do so by creating a `.env` file in the danswer/backend/tests/integration/ directory.
2. Navigate to `danswer/backend`.
3. Run the following command in the terminal:

View File

@@ -132,7 +132,10 @@ class DocumentSetManager:
if not check_ids.issubset(doc_set_ids):
raise RuntimeError("Document set not found")
doc_sets = [doc_set for doc_set in doc_sets if doc_set.id in check_ids]
all_up_to_date = all(doc_set.is_up_to_date for doc_set in doc_sets)
all_up_to_date = True
for doc_set in doc_sets:
if not doc_set.is_up_to_date:
all_up_to_date = False
if all_up_to_date:
break

View File

@@ -0,0 +1,159 @@
import requests
from danswer.db.models import EmbeddingProvider
from danswer.server.manage.embedding.models import TestEmbeddingRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.search_settings import (
SearchSettingsManager,
)
from tests.integration.common_utils.test_models import TestCloudEmbeddingProvider
from tests.integration.common_utils.test_models import TestUser
class EmbeddingProviderManager:
    """Integration-test helper around the admin embedding-provider endpoints.

    Every method is static. Requests are sent with the acting user's headers
    when a user is supplied and with ``GENERAL_HEADERS`` otherwise.
    """

    @staticmethod
    def test(
        user_performing_action: TestUser,
        embedding_provider: TestCloudEmbeddingProvider,
        model_name: str | None,
    ) -> None:
        """POST the provider's credentials to the ``test-embedding`` endpoint.

        Raises:
            ValueError: If the endpoint answers with any status other than 200.
        """
        payload = TestEmbeddingRequest(
            provider_type=embedding_provider.provider_type,
            api_key=embedding_provider.api_key,
            api_url=embedding_provider.api_url,
            model_name=model_name,
        ).model_dump()
        request_headers = (
            user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS
        )

        result = requests.post(
            url=f"{API_SERVER_URL}/admin/embedding/test-embedding",
            json=payload,
            headers=request_headers,
        )
        if result.status_code == 200:
            return
        raise ValueError(f"Failed to test embedding provider: {result.json()}")

    @staticmethod
    def create(
        test_embedding_provider: TestCloudEmbeddingProvider,
        user_performing_action: TestUser | None = None,
    ) -> TestCloudEmbeddingProvider:
        """PUT a new embedding provider and return it as echoed by the API.

        ``api_key`` and ``api_url`` are read leniently from the response, so
        they come back as ``None`` when the server omits them.
        """
        body = {
            "provider_type": test_embedding_provider.provider_type,
            "api_url": test_embedding_provider.api_url,
            "api_key": test_embedding_provider.api_key,
        }
        request_headers = (
            user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS
        )

        result = requests.put(
            url=f"{API_SERVER_URL}/admin/embedding/embedding-provider",
            json=body,
            headers=request_headers,
        )
        result.raise_for_status()

        created = result.json()
        return TestCloudEmbeddingProvider(
            provider_type=created["provider_type"],
            api_key=created.get("api_key"),
            api_url=created.get("api_url"),
        )

    @staticmethod
    def get_all(
        user_performing_action: TestUser | None = None,
    ) -> list[TestCloudEmbeddingProvider]:
        """GET every embedding provider the API reports."""
        request_headers = (
            user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS
        )
        result = requests.get(
            url=f"{API_SERVER_URL}/admin/embedding/embedding-provider",
            headers=request_headers,
        )
        result.raise_for_status()

        providers: list[TestCloudEmbeddingProvider] = []
        for raw_provider in result.json():
            providers.append(TestCloudEmbeddingProvider(**raw_provider))
        return providers

    @staticmethod
    def edit(
        embedding_provider: TestCloudEmbeddingProvider,
        user_performing_action: TestUser | None = None,
    ) -> None:
        """PUT the full provider model to the same endpoint ``create`` uses."""
        request_headers = (
            user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS
        )
        result = requests.put(
            url=f"{API_SERVER_URL}/admin/embedding/embedding-provider",
            json=embedding_provider.model_dump(),
            headers=request_headers,
        )
        result.raise_for_status()

    @staticmethod
    def delete(
        provider_type: EmbeddingProvider,
        user_performing_action: TestUser | None = None,
    ) -> None:
        """DELETE the provider identified by ``provider_type``."""
        request_headers = (
            user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS
        )
        result = requests.delete(
            url=f"{API_SERVER_URL}/admin/embedding/embedding-provider/{provider_type}",
            headers=request_headers,
        )
        result.raise_for_status()

    @staticmethod
    def verify_providers(
        embedding_providers: list[TestCloudEmbeddingProvider],
        user_performing_action: TestUser | None = None,
    ) -> None:
        """Check each expected provider against what the API currently lists.

        Providers are matched by ``provider_type`` (first match wins); the
        matched entry must carry the same ``api_url`` and ``api_key``.

        Raises:
            ValueError: If a provider is missing or its URL/key differs.
        """
        listed = EmbeddingProviderManager.get_all(user_performing_action)

        for expected in embedding_providers:
            for candidate in listed:
                if candidate.provider_type == expected.provider_type:
                    found = candidate
                    break
            else:
                raise ValueError(
                    f"Embedding provider {expected.provider_type} not found in current providers"
                )

            if found.api_url != expected.api_url:
                raise ValueError(
                    f"API URL mismatch for provider {expected.provider_type}"
                )

            if found.api_key != expected.api_key:
                raise ValueError(
                    f"API Key mismatch for provider {expected.provider_type}"
                )

    @staticmethod
    def verify(
        embedding_provider: TestCloudEmbeddingProvider,
        user_performing_action: TestUser | None = None,
    ) -> None:
        """Check that the primary search settings use this provider's type.

        Only ``provider_type`` is compared; the key and URL are not inspected.

        Raises:
            ValueError: If there are no primary settings, they carry no
                provider type, or the type differs from the expected one.
        """
        primary = SearchSettingsManager.get_primary(user_performing_action)
        if primary is None:
            raise ValueError("No current embedding provider found")

        current_provider_type = primary.provider_type
        if current_provider_type is None:
            raise ValueError("No current embedding provider found")

        if current_provider_type == embedding_provider.provider_type:
            return
        raise ValueError(
            f"Current embedding provider {current_provider_type} does not match expected {embedding_provider.provider_type}"
        )

View File

@@ -0,0 +1,157 @@
import time
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import TestFullModelVersionResponse
from tests.integration.common_utils.test_models import TestSearchSettings
from tests.integration.common_utils.test_models import TestUser
class SearchSettingsManager:
    """Integration-test helper around the ``/search-settings`` endpoints.

    Every method is static. Requests are sent with the acting user's headers
    when a user is supplied and with ``GENERAL_HEADERS`` otherwise.
    """

    @staticmethod
    def create_and_set(
        model_name: str = "test-model",
        model_dim: int = 768,
        normalize: bool = True,
        query_prefix: str | None = "",
        passage_prefix: str | None = "",
        index_name: str | None = None,
        provider_type: str | None = None,
        multipass_indexing: bool = False,
        multilingual_expansion: list[str] | None = None,
        disable_rerank_for_streaming: bool = False,
        rerank_model_name: str | None = None,
        rerank_provider_type: str | None = None,
        rerank_api_key: str | None = None,
        num_rerank: int = 50,
        user_performing_action: TestUser | None = None,
    ) -> TestSearchSettings:
        """POST new search settings and return the expected resulting state.

        The returned model is built from the request values overlaid with
        whatever the endpoint returns, so response keys win on conflict.

        Args:
            multilingual_expansion: Languages to expand to; ``None`` (the
                default) is sent as an empty list.
        """
        search_settings_request = {
            "model_name": model_name,
            "model_dim": model_dim,
            "normalize": normalize,
            "query_prefix": query_prefix,
            "passage_prefix": passage_prefix,
            "index_name": index_name,
            "provider_type": provider_type,
            "multipass_indexing": multipass_indexing,
            # A fresh list per call avoids sharing one mutable default object.
            "multilingual_expansion": (
                [] if multilingual_expansion is None else multilingual_expansion
            ),
            "disable_rerank_for_streaming": disable_rerank_for_streaming,
            "rerank_model_name": rerank_model_name,
            "rerank_provider_type": rerank_provider_type,
            "rerank_api_key": rerank_api_key,
            "num_rerank": num_rerank,
        }

        response = requests.post(
            url=f"{API_SERVER_URL}/search-settings/set-new-search-settings",
            json=search_settings_request,
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()
        response_data = response.json()

        # Merge the response data with the original request data
        merged_data = {**search_settings_request, **response_data}
        return TestSearchSettings(**merged_data)

    @staticmethod
    def wait_for_sync(
        new_primary_settings: TestSearchSettings,
        user_performing_action: TestUser | None = None,
    ) -> None:
        """Poll the primary settings until they equal ``new_primary_settings``.

        Polls every 2 seconds.

        Raises:
            TimeoutError: If the settings still differ after ``MAX_DELAY``
                seconds.
        """
        expected_dump = new_primary_settings.model_dump()
        start = time.time()
        while True:
            current_primary_settings = SearchSettingsManager.get_primary(
                user_performing_action
            )
            if current_primary_settings.model_dump() == expected_dump:
                break

            if time.time() - start > MAX_DELAY:
                raise TimeoutError(
                    f"Search settings were not synced within the {MAX_DELAY} seconds"
                )
            print("Search settings were not synced yet, waiting...")
            time.sleep(2)

    @staticmethod
    def edit(
        search_settings: TestSearchSettings,
        user_performing_action: TestUser | None = None,
    ) -> None:
        """POST the settings (minus any ``id`` key) to ``update-inference-settings``."""
        response = requests.post(
            url=f"{API_SERVER_URL}/search-settings/update-inference-settings",
            json=search_settings.model_dump(exclude={"id"}),
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()

    @staticmethod
    def get_primary(
        user_performing_action: TestUser | None = None,
    ) -> TestSearchSettings:
        """GET the current (primary) search settings."""
        response = requests.get(
            url=f"{API_SERVER_URL}/search-settings/get-current-search-settings",
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()
        return TestSearchSettings(**response.json())

    @staticmethod
    def get_secondary(
        user_performing_action: TestUser | None = None,
    ) -> TestSearchSettings | None:
        """GET the secondary search settings, or ``None`` if the body is empty."""
        response = requests.get(
            url=f"{API_SERVER_URL}/search-settings/get-secondary-search-settings",
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()
        data = response.json()
        # ``data`` is already the decoded JSON body; unpack it directly.
        return TestSearchSettings(**data) if data else None

    @staticmethod
    def get_all_current(
        user_performing_action: TestUser | None = None,
    ) -> TestFullModelVersionResponse:
        """GET both current and secondary settings in one response model.

        ``secondary_settings`` is ``None`` when the API returns a falsy value
        for that key.
        """
        response = requests.get(
            url=f"{API_SERVER_URL}/search-settings/get-all-search-settings",
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()
        data = response.json()
        return TestFullModelVersionResponse(
            current_settings=TestSearchSettings(**data["current_settings"]),
            secondary_settings=TestSearchSettings(**data["secondary_settings"])
            if data["secondary_settings"]
            else None,
        )

    @staticmethod
    def verify(
        search_settings: TestSearchSettings,
        user_performing_action: TestUser | None = None,
    ) -> None:
        """Check that the primary settings equal ``search_settings`` field-for-field.

        Raises:
            ValueError: If the dumped models differ.
        """
        current_settings = SearchSettingsManager.get_primary(user_performing_action)
        if current_settings.model_dump() != search_settings.model_dump():
            raise ValueError("Current search settings do not match expected settings")

View File

@@ -8,6 +8,7 @@ from danswer.auth.schemas import UserRole
from danswer.search.enums import RecencyBiasSetting
from danswer.server.documents.models import DocumentSource
from danswer.server.documents.models import InputType
from shared_configs.enums import EmbeddingProvider
"""
These data models are used to represent the data on the testing side of things.
@@ -120,6 +121,39 @@ class TestPersona(BaseModel):
groups: list[int]
class TestSearchSettings(BaseModel):
    """Expected state of a search-settings record on the testing side."""

    model_name: str
    model_dim: int
    normalize: bool
    query_prefix: str | None
    passage_prefix: str | None
    index_name: str
    provider_type: str | None
    multipass_indexing: bool
    multilingual_expansion: list[str]
    disable_rerank_for_streaming: bool
    rerank_model_name: str | None
    rerank_provider_type: str | None
    # The only field with a default, so it may be omitted when constructing.
    rerank_api_key: str | None = None
    num_rerank: int

    def __eq__(self, other: Any) -> bool:
        """Compare field-by-field; any non-``TestSearchSettings`` is unequal."""
        if not isinstance(other, TestSearchSettings):
            return False
        # ``model_dump`` matches how callers compare these models and avoids
        # the deprecated ``dict()`` method.
        return self.model_dump() == other.model_dump()
class TestCloudEmbeddingProvider(BaseModel):
    """Expected state of a cloud embedding provider on the testing side."""

    # Which provider this entry represents.
    provider_type: EmbeddingProvider
    # Credentials / endpoint; both are required fields but may be None.
    api_key: str | None
    api_url: str | None
class TestFullModelVersionResponse(BaseModel):
    """Pair of current and secondary search settings; either may be None."""

    current_settings: TestSearchSettings | None
    secondary_settings: TestSearchSettings | None
#
class TestChatSession(BaseModel):
id: int

View File

@@ -0,0 +1,68 @@
"""
This file tests the creation of cloud embedding providers (OpenAI, Cohere, Voyage, Google, and LiteLLM).
"""
import os
from danswer.db.models import EmbeddingProvider
from tests.integration.common_utils.test_models import TestCloudEmbeddingProvider
from tests.integration.tests.search_settings.utils import (
create_and_test_embedding_provider,
)
def test_creating_openai_embedding_provider(reset: None) -> None:
    """Run an OpenAI provider through ``create_and_test_embedding_provider``.

    Raises:
        ValueError: If ``OPENAI_API_KEY`` is not set in the environment.
    """
    openai_api_key = os.environ.get("OPENAI_API_KEY")
    if not openai_api_key:
        # Fail fast with a clear message rather than sending a null key.
        raise ValueError("OPENAI_API_KEY environment variable is not set")

    test_embedding_provider = TestCloudEmbeddingProvider(
        provider_type=EmbeddingProvider.OPENAI,
        api_key=openai_api_key,
        api_url=None,
    )
    create_and_test_embedding_provider(test_embedding_provider)
def test_creating_cohere_embedding_provider(reset: None) -> None:
    """Run a Cohere provider through ``create_and_test_embedding_provider``.

    Raises:
        ValueError: If ``COHERE_API_KEY`` is not set in the environment.
    """
    cohere_api_key = os.environ.get("COHERE_API_KEY")
    if not cohere_api_key:
        # Fail fast with a clear message rather than sending a null key.
        raise ValueError("COHERE_API_KEY environment variable is not set")

    test_embedding_provider = TestCloudEmbeddingProvider(
        provider_type=EmbeddingProvider.COHERE,
        api_key=cohere_api_key,
        api_url=None,
    )
    create_and_test_embedding_provider(test_embedding_provider)
def test_creating_voyage_embedding_provider(reset: None) -> None:
    """Run a Voyage provider through ``create_and_test_embedding_provider``.

    Raises:
        ValueError: If ``VOYAGE_API_KEY`` is not set in the environment.
    """
    voyage_api_key = os.environ.get("VOYAGE_API_KEY")
    if not voyage_api_key:
        # Fail fast with a clear message rather than sending a null key.
        raise ValueError("VOYAGE_API_KEY environment variable is not set")

    test_embedding_provider = TestCloudEmbeddingProvider(
        provider_type=EmbeddingProvider.VOYAGE,
        api_key=voyage_api_key,
        api_url=None,
    )
    create_and_test_embedding_provider(test_embedding_provider)
def test_creating_google_embedding_provider(reset: None) -> None:
    """Run a Google provider through ``create_and_test_embedding_provider``.

    Raises:
        ValueError: If ``GOOGLE_API_KEY`` is not set in the environment.
    """
    google_api_key = os.environ.get("GOOGLE_API_KEY")
    if not google_api_key:
        # Fail fast with a clear message rather than sending a null key.
        raise ValueError("GOOGLE_API_KEY environment variable is not set")

    test_embedding_provider = TestCloudEmbeddingProvider(
        provider_type=EmbeddingProvider.GOOGLE,
        api_key=google_api_key,
        api_url=None,
    )
    create_and_test_embedding_provider(test_embedding_provider)
def test_creating_litellm_embedding_provider(reset: None) -> None:
    """Run a LiteLLM provider through ``create_and_test_embedding_provider``.

    Key, URL and test model name are all read from the environment; any that
    is unset is passed along as ``None``.
    """
    api_key = os.getenv("LITELLM_API_KEY")
    api_url = os.getenv("LITELLM_API_URL")
    # NOTE(review): confirm LITELLM_TEST_MODEL_NAME is exported wherever these
    # tests run; when it is unset the provider is tested with model_name=None.
    test_model_name = os.getenv("LITELLM_TEST_MODEL_NAME")

    litellm_provider = TestCloudEmbeddingProvider(
        provider_type=EmbeddingProvider.LITELLM,
        api_key=api_key,
        api_url=api_url,
    )
    create_and_test_embedding_provider(
        litellm_provider,
        model_name=test_model_name,
    )

View File

@@ -0,0 +1,73 @@
"""
This file tests search settings creation, upgrades, and embedding provider management.
"""
import os
from danswer.db.models import EmbeddingProvider
from tests.integration.common_utils.managers.embedding_provider import (
EmbeddingProviderManager,
)
from tests.integration.common_utils.managers.search_settings import (
SearchSettingsManager,
)
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import TestCloudEmbeddingProvider
from tests.integration.common_utils.test_models import TestUser
def test_creating_and_upgrading_search_settings(reset: None) -> None:
    """Set OpenAI-backed search settings twice, then edit the provider's key.

    Each round of settings is created, waited on until it becomes primary and
    verified. Afterwards the provider is verified, its API key is replaced,
    and it is verified again.

    Raises:
        ValueError: If ``OPENAI_API_KEY`` is not set in the environment.
    """
    admin: TestUser = UserManager.create(name="admin_user")

    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OPENAI_API_KEY environment variable is not set")

    openai_provider = TestCloudEmbeddingProvider(
        provider_type=EmbeddingProvider.OPENAI, api_key=api_key, api_url=None
    )
    EmbeddingProviderManager.create(
        test_embedding_provider=openai_provider,
        user_performing_action=admin,
    )

    def _set_wait_and_verify() -> None:
        # One full round: request the settings, wait for sync, then verify.
        settings = SearchSettingsManager.create_and_set(
            model_name="text-embedding-3-small",
            model_dim=1536,
            provider_type=EmbeddingProvider.OPENAI,
            user_performing_action=admin,
        )
        SearchSettingsManager.wait_for_sync(
            new_primary_settings=settings, user_performing_action=admin
        )
        SearchSettingsManager.verify(settings, user_performing_action=admin)

    # Initial settings, followed by a second round with the same model.
    _set_wait_and_verify()
    _set_wait_and_verify()

    EmbeddingProviderManager.verify(openai_provider, user_performing_action=admin)

    # Swap the key on the expected-state object and push the edit.
    openai_provider.api_key = "new-openai-api-key-example"
    EmbeddingProviderManager.edit(openai_provider, user_performing_action=admin)
    EmbeddingProviderManager.verify(openai_provider, user_performing_action=admin)

View File

@@ -0,0 +1,32 @@
"""
This module contains utility functions and fixtures for testing search settings.
"""
from tests.integration.common_utils.managers.embedding_provider import (
EmbeddingProviderManager,
)
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import TestCloudEmbeddingProvider
from tests.integration.common_utils.test_models import TestUser
def create_and_test_embedding_provider(
    test_embedding_provider: TestCloudEmbeddingProvider,
    model_name: str | None = None,
) -> None:
    """Test, create and then verify an embedding provider as a new admin user.

    Args:
        test_embedding_provider: Provider definition to exercise.
        model_name: Model name forwarded to the provider test call only.
    """
    acting_admin: TestUser = UserManager.create(name="admin_user")

    # 1. Ask the server to test the credentials.
    EmbeddingProviderManager.test(
        user_performing_action=acting_admin,
        embedding_provider=test_embedding_provider,
        model_name=model_name,
    )

    # 2. Register the provider.
    EmbeddingProviderManager.create(
        test_embedding_provider=test_embedding_provider,
        user_performing_action=acting_admin,
    )

    # 3. Confirm the provider is now listed with matching URL and key.
    expected_providers = [test_embedding_provider]
    EmbeddingProviderManager.verify_providers(
        embedding_providers=expected_providers,
        user_performing_action=acting_admin,
    )