Compare commits

...

12 Commits

Author SHA1 Message Date
pablodanswer
fb99b30930 close to functioning 2024-09-19 19:18:57 -07:00
pablodanswer
ba6296bc81 update some types 2024-09-19 17:47:13 -07:00
pablodanswer
9abde19e44 update interfaces to standardize 2024-09-19 17:41:57 -07:00
pablodanswer
3cc99cf79a squash 2024-09-19 16:07:03 -07:00
pablodanswer
14871a62ad new llm providers 2024-09-19 13:06:11 -07:00
pablodanswer
fe5df3c8ee update integration test models 2024-09-19 12:54:27 -07:00
pablodanswer
3f0895682c additional clarity for llm provider creation / updates 2024-09-14 11:33:21 -07:00
pablodanswer
3b70a94fcd update typing 2024-09-14 11:00:13 -07:00
pablodanswer
074342165b new model structure concept 2024-09-13 20:38:03 -07:00
pablodanswer
7b91beb3b2 quick update to some naming 2024-09-13 16:52:25 -07:00
pablodanswer
4b2e4ca159 update deployment 2024-09-13 16:40:57 -07:00
pablodanswer
001bbb89cc obfuscate api keys 2024-09-13 16:39:54 -07:00
20 changed files with 289 additions and 56 deletions

View File

@@ -12,6 +12,9 @@ from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import FullLLMProviderSnapshot
from danswer.server.manage.llm.models import LLMProviderCreationRequest
from danswer.server.manage.llm.models import LLMProviderUpdateRequest
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from shared_configs.enums import EmbeddingProvider
@@ -59,14 +62,69 @@ def upsert_cloud_embedding_provider(
return CloudEmbeddingProvider.from_request(existing_provider)
def update_llm_provider(
llm_provider_update: LLMProviderUpdateRequest,
db_session: Session,
) -> FullLLMProviderSnapshot:
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.name == llm_provider_update.name
)
)
if not existing_llm_provider:
raise ValueError(
f"LLM Provider with name {llm_provider_update.name} does not exist"
)
return FullLLMProviderSnapshot.from_full_llm_provider(
FullLLMProvider.from_model(existing_llm_provider)
)
def create_llm_provider(
llm_provider_creation: LLMProviderCreationRequest,
db_session: Session,
) -> FullLLMProviderSnapshot:
new_llm_provider = LLMProviderModel(name=llm_provider_creation.name)
db_session.add(new_llm_provider)
db_session.commit()
return FullLLMProviderSnapshot.from_full_llm_provider(
FullLLMProvider.from_model(new_llm_provider)
)
def get_llm_provider(
llm_provider_name: str, db_session: Session, user: User | None = None
) -> FullLLMProviderSnapshot:
if not user or not user.is_admin:
raise ValueError("User does not have access to this LLM Provider")
return FullLLMProviderSnapshot.from_full_llm_provider(
FullLLMProvider.from_model(
db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.name == llm_provider_name
)
)
)
)
def upsert_llm_provider(
llm_provider: LLMProviderUpsertRequest, db_session: Session
) -> FullLLMProvider:
llm_provider: LLMProviderUpsertRequest,
db_session: Session,
is_creation: bool = True,
) -> FullLLMProviderSnapshot:
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
)
if existing_llm_provider and is_creation:
raise ValueError(f"LLM Provider with name {llm_provider.name} already exists")
if not existing_llm_provider:
if not is_creation:
raise ValueError(
f"LLM Provider with name {llm_provider.name} does not exist"
)
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
db_session.add(existing_llm_provider)
@@ -94,7 +152,9 @@ def upsert_llm_provider(
db_session.commit()
return FullLLMProvider.from_model(existing_llm_provider)
return FullLLMProviderSnapshot.from_full_llm_provider(
FullLLMProvider.from_model(existing_llm_provider)
)
def fetch_existing_embedding_providers(

View File

@@ -183,10 +183,11 @@ def update_current_search_settings(
# Whenever we update the current search settings, we should ensure that the local reranking model is warmed up.
if (
current_settings.provider_type is None
search_settings.rerank_provider_type is None
and search_settings.rerank_model_name is not None
and current_settings.rerank_model_name != search_settings.rerank_model_name
):
print("WARMIGN THIS STUFF UP!")
warm_up_cross_encoder(search_settings.rerank_model_name)
update_search_settings(current_settings, search_settings, preserved_fields)

View File

@@ -17,7 +17,6 @@ from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import SearchType
from shared_configs.enums import RerankerProvider
MAX_METRICS_CONTENT = (
200 # Just need enough characters to identify where in the doc the chunk is
)
@@ -53,14 +52,23 @@ class InferenceSettings(RerankingDetails):
class SearchSettingsCreationRequest(InferenceSettings, IndexingSetting):
api_key_set: bool = False
rerank_api_key_set: bool = False
@classmethod
def from_db_model(
cls, search_settings: SearchSettings
) -> "SearchSettingsCreationRequest":
inference_settings = InferenceSettings.from_db_model(search_settings)
reranking_details = RerankingDetails.from_db_model(search_settings)
indexing_setting = IndexingSetting.from_db_model(search_settings)
return cls(**inference_settings.dict(), **indexing_setting.dict())
return cls(
**reranking_details.dict(),
**indexing_setting.dict(),
api_key_set=bool(search_settings.api_key),
rerank_api_key_set=bool(search_settings.rerank_api_key),
multilingual_expansion=search_settings.multilingual_expansion,
)
class SavedSearchSettings(InferenceSettings, IndexingSetting):
@@ -87,6 +95,27 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
)
class SearchSettingsSnapshot(SavedSearchSettings):
rerank_api_key: None = None
api_key: None = None
rerank_api_key_set: bool
api_key_set: bool
@classmethod
def from_saved_settings(
cls, settings: SavedSearchSettings
) -> "SearchSettingsSnapshot":
data = settings.dict(exclude={"rerank_api_key", "api_key"})
data["rerank_api_key_set"] = bool(settings.rerank_api_key)
data["api_key_set"] = bool(settings.api_key)
return cls(**data)
@classmethod
def from_db_model(cls, settings: SearchSettings) -> "SearchSettingsSnapshot":
return cls.from_saved_settings(SavedSearchSettings.from_db_model(settings))
class Tag(BaseModel):
tag_key: str
tag_value: str

View File

@@ -3,14 +3,18 @@ from collections.abc import Callable
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.llm import create_llm_provider
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.llm import get_llm_provider
from danswer.db.llm import remove_llm_provider
from danswer.db.llm import update_default_provider
from danswer.db.llm import update_llm_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.models import User
from danswer.llm.factory import get_default_llms
@@ -18,7 +22,8 @@ from danswer.llm.factory import get_llm
from danswer.llm.llm_provider_options import fetch_available_well_known_llms
from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from danswer.llm.utils import test_llm
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import FullLLMProviderSnapshot
from danswer.server.manage.llm.models import LLMProviderCreationRequest
from danswer.server.manage.llm.models import LLMProviderDescriptor
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.server.manage.llm.models import TestLLMRequest
@@ -42,8 +47,15 @@ def fetch_llm_options(
@admin_router.post("/test")
def test_llm_configuration(
test_llm_request: TestLLMRequest,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_admin_user),
) -> None:
if test_llm_request.existing_api_key and not test_llm_request.api_key:
llm_provider = get_llm_provider(
test_llm_request.provider.name, db_session=db_session
)
test_llm_request.api_key = llm_provider.api_key
llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.default_model_name,
@@ -108,20 +120,72 @@ def test_default_provider(
def list_llm_providers(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[FullLLMProvider]:
) -> list[FullLLMProviderSnapshot]:
print(
[
FullLLMProviderSnapshot.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session)
]
)
return [
FullLLMProvider.from_model(llm_provider_model)
FullLLMProviderSnapshot.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session)
]
@admin_router.patch("/provider/{provider_id}")
def patch_existing_llm_provider(
llm_provider: LLMProviderUpsertRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> FullLLMProviderSnapshot:
return FullLLMProviderSnapshot.from_full_llm_provider(
update_llm_provider(llm_provider=llm_provider, db_session=db_session)
)
@admin_router.post("/provider")
def create_new_llm_provider(
llm_provider: LLMProviderCreationRequest,
is_creation: bool = Query(
True,
description="True if updating an existing provider, False if creating a new one",
),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> FullLLMProviderSnapshot:
return FullLLMProviderSnapshot.from_full_llm_provider(
create_llm_provider(
llm_provider=llm_provider, db_session=db_session, is_creation=is_creation
)
)
@admin_router.put("/provider")
def put_llm_provider(
llm_provider: LLMProviderUpsertRequest,
is_creation: bool = Query(
True,
description="True if updating an existing provider, False if creating a new one",
),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> FullLLMProvider:
return upsert_llm_provider(llm_provider=llm_provider, db_session=db_session)
) -> FullLLMProviderSnapshot:
try:
print("hitting htis function")
value = FullLLMProviderSnapshot.from_full_llm_provider(
upsert_llm_provider(
llm_provider=llm_provider,
db_session=db_session,
is_creation=is_creation,
)
)
print(value)
return value
except ValueError as e:
logger.exception("Failed to upsert LLM Provider")
raise HTTPException(status_code=400, detail=str(e))
@admin_router.delete("/provider/{provider_id}")

View File

@@ -5,13 +5,13 @@ from pydantic import Field
from danswer.llm.llm_provider_options import fetch_models_for_provider
if TYPE_CHECKING:
from danswer.db.models import LLMProvider as LLMProviderModel
class TestLLMRequest(BaseModel):
# provider level
existing_api_key: bool = False
provider: str
api_key: str | None = None
api_base: str | None = None
@@ -74,6 +74,14 @@ class LLMProviderUpsertRequest(LLMProvider):
model_names: list[str] | None = None
class LLMProviderUpdateRequest(LLMProvider):
api_key_set: bool
class LLMProviderCreationRequest(LLMProvider):
pass
class FullLLMProvider(LLMProvider):
id: int
is_default_provider: bool | None = None
@@ -82,10 +90,10 @@ class FullLLMProvider(LLMProvider):
@classmethod
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider":
return cls(
api_key=llm_provider_model.api_key,
id=llm_provider_model.id,
name=llm_provider_model.name,
provider=llm_provider_model.provider,
api_key=llm_provider_model.api_key,
api_base=llm_provider_model.api_base,
api_version=llm_provider_model.api_version,
custom_config=llm_provider_model.custom_config,
@@ -101,3 +109,23 @@ class FullLLMProvider(LLMProvider):
is_public=llm_provider_model.is_public,
groups=[group.id for group in llm_provider_model.groups],
)
class FullLLMProviderSnapshot(FullLLMProvider):
api_key: None = None
api_key_set: bool
@classmethod
def from_full_llm_provider(
cls, settings: FullLLMProvider
) -> "FullLLMProviderSnapshot":
data = settings.dict(exclude={"api_key"})
data["api_key_set"] = bool(settings.api_key)
return cls(**data)
@classmethod
def from_model(
cls, llm_provider_model: "LLMProviderModel"
) -> "FullLLMProviderSnapshot":
full_provider = FullLLMProvider.from_model(llm_provider_model)
return cls.from_full_llm_provider(full_provider)

View File

@@ -24,6 +24,7 @@ from danswer.document_index.factory import get_default_document_index
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.search.models import SavedSearchSettings
from danswer.search.models import SearchSettingsCreationRequest
from danswer.search.models import SearchSettingsSnapshot
from danswer.server.manage.embedding.models import SearchSettingsDeleteRequest
from danswer.server.manage.models import FullModelVersionResponse
from danswer.server.models import IdReturn
@@ -73,7 +74,13 @@ def set_new_search_settings(
search_values["index_name"] = index_name
new_search_settings_request = SavedSearchSettings(**search_values)
else:
new_search_settings_request = SavedSearchSettings(**search_settings_new.dict())
new_search_settings_request = SavedSearchSettings(
**{
k: v
for k, v in search_settings_new.dict().items()
if k not in ["api_key_set", "rerank_api_key_set"]
}
)
secondary_search_settings = get_secondary_search_settings(db_session)
@@ -154,21 +161,21 @@ def delete_search_settings_endpoint(
def get_current_search_settings_endpoint(
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SavedSearchSettings:
) -> SearchSettingsSnapshot:
current_search_settings = get_current_search_settings(db_session)
return SavedSearchSettings.from_db_model(current_search_settings)
return SearchSettingsSnapshot.from_db_model(current_search_settings)
@router.get("/get-secondary-search-settings")
def get_secondary_search_settings_endpoint(
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SavedSearchSettings | None:
) -> SearchSettingsSnapshot | None:
secondary_search_settings = get_secondary_search_settings(db_session)
if not secondary_search_settings:
return None
return SavedSearchSettings.from_db_model(secondary_search_settings)
return SearchSettingsSnapshot.from_db_model(secondary_search_settings)
@router.get("/get-all-search-settings")

View File

@@ -9,3 +9,9 @@ def batch_list(
batch_size: int,
) -> list[list[T]]:
return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]
def obfuscate_api_key(api_key: str | None) -> str | None:
if api_key is None:
return None
return "*" * len(api_key)

View File

@@ -7,6 +7,7 @@ from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import TestLLMProvider
from tests.integration.common_utils.test_models import TestLLMProviderResponse
from tests.integration.common_utils.test_models import TestUser
@@ -22,7 +23,7 @@ class LLMProviderManager:
groups: list[int] | None = None,
is_public: bool | None = None,
user_performing_action: TestUser | None = None,
) -> TestLLMProvider:
) -> TestLLMProviderResponse:
print("Seeding LLM Providers...")
llm_provider = LLMProviderUpsertRequest(
@@ -49,11 +50,11 @@ class LLMProviderManager:
)
llm_response.raise_for_status()
response_data = llm_response.json()
result_llm = TestLLMProvider(
result_llm = TestLLMProviderResponse(
id=response_data["id"],
name=response_data["name"],
provider=response_data["provider"],
api_key=response_data["api_key"],
api_key_set=response_data["api_key_set"],
default_model_name=response_data["default_model_name"],
is_public=response_data["is_public"],
groups=response_data["groups"],

View File

@@ -91,6 +91,18 @@ class TestLLMProvider(BaseModel):
api_version: str | None = None
class TestLLMProviderResponse(BaseModel):
id: int
name: str
provider: str
api_key_set: bool
default_model_name: str
is_public: bool
groups: list[TestUserGroup]
api_base: str | None = None
api_version: str | None = None
class TestDocumentSet(BaseModel):
id: int
name: str

View File

@@ -292,7 +292,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432:5432"
- "5433:5432"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@@ -302,7 +302,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432:5432"
- "5433:5432"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@@ -154,7 +154,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432"
- "5433"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@@ -7,24 +7,17 @@ import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
import {
SelectorFormField,
TextFormField,
BooleanFormField,
MultiSelectField,
} from "@/components/admin/connectors/Field";
import { useState } from "react";
import { Bubble } from "@/components/Bubble";
import { GroupsIcon } from "@/components/icons/icons";
import { useSWRConfig } from "swr";
import {
defaultModelsByProvider,
getDisplayNameForModel,
useUserGroups,
} from "@/lib/hooks";
import { defaultModelsByProvider, getDisplayNameForModel } from "@/lib/hooks";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import * as Yup from "yup";
import isEqual from "lodash/isEqual";
import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector";
import { defaultPasswordMask } from "@/lib/llm/utils";
export function LLMProviderUpdateForm({
llmProviderDescriptor,
@@ -33,6 +26,7 @@ export function LLMProviderUpdateForm({
shouldMarkAsDefault,
setPopup,
hideAdvanced,
llmProviderFlow,
}: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor;
onClose: () => void;
@@ -40,14 +34,10 @@ export function LLMProviderUpdateForm({
shouldMarkAsDefault?: boolean;
hideAdvanced?: boolean;
setPopup?: (popup: PopupSpec) => void;
llmProviderFlow: "create" | "update";
}) {
const { mutate } = useSWRConfig();
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
// EE only
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
const [isTesting, setIsTesting] = useState(false);
const [testError, setTestError] = useState<string>("");
@@ -56,7 +46,7 @@ export function LLMProviderUpdateForm({
// Define the initial values based on the provider's requirements
const initialValues = {
name: existingLlmProvider?.name || (hideAdvanced ? "Default" : ""),
api_key: existingLlmProvider?.api_key ?? "",
api_key: null,
api_base: existingLlmProvider?.api_base ?? "",
api_version: existingLlmProvider?.api_version ?? "",
default_model_name:
@@ -86,9 +76,10 @@ export function LLMProviderUpdateForm({
// Setup validation schema if required
const validationSchema = Yup.object({
name: Yup.string().required("Display Name is required"),
api_key: llmProviderDescriptor.api_key_required
? Yup.string().required("API Key is required")
: Yup.string(),
api_key:
llmProviderDescriptor.api_key_required && llmProviderFlow == "create"
? Yup.string().required("API Key is required")
: Yup.string().nullable(),
api_base: llmProviderDescriptor.api_base_required
? Yup.string().required("API Base is required")
: Yup.string(),
@@ -120,6 +111,10 @@ export function LLMProviderUpdateForm({
display_model_names: Yup.array().of(Yup.string()),
});
const apiKeyDefault = existingLlmProvider?.api_key_set
? defaultPasswordMask
: "API key";
return (
<Formik
initialValues={initialValues}
@@ -137,6 +132,10 @@ export function LLMProviderUpdateForm({
"Content-Type": "application/json",
},
body: JSON.stringify({
existing_api_key:
llmProviderFlow == "update" &&
values.api_key == null &&
existingLlmProvider?.name != undefined,
provider: llmProviderDescriptor.name,
...values,
}),
@@ -151,7 +150,7 @@ export function LLMProviderUpdateForm({
}
const response = await fetch(LLM_PROVIDERS_ADMIN_URL, {
method: "PUT",
method: llmProviderFlow == "create" ? "POST" : "PUT",
headers: {
"Content-Type": "application/json",
},
@@ -237,8 +236,8 @@ export function LLMProviderUpdateForm({
small={hideAdvanced}
name="api_key"
label="API Key"
placeholder="API Key"
type="password"
placeholder={formikProps.values.api_key ?? apiKeyDefault}
/>
)}

View File

@@ -34,7 +34,7 @@ export interface WellKnownLLMProviderDescriptor {
export interface LLMProvider {
name: string;
provider: string;
api_key: string | null;
api_key_set: boolean;
api_base: string | null;
api_version: string | null;
custom_config: { [key: string]: string } | null;

View File

@@ -24,7 +24,7 @@ export interface EmbeddingDetails {
import { EmbeddingIcon } from "@/components/icons/icons";
import Link from "next/link";
import { SavedSearchSettings } from "../../embeddings/interfaces";
import { SearchSettingsSnapshot } from "../../embeddings/interfaces";
import UpgradingPage from "./UpgradingPage";
import { useContext } from "react";
import { SettingsContext } from "@/components/settings/SettingsProvider";
@@ -42,7 +42,7 @@ function Main() {
);
const { data: searchSettings, isLoading: isLoadingSearchSettings } =
useSWR<SavedSearchSettings | null>(
useSWR<SearchSettingsSnapshot | null>(
"/api/search-settings/get-current-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds

View File

@@ -15,6 +15,7 @@ import {
import { Modal } from "@/components/Modal";
import { Button } from "@tremor/react";
import { TextFormField } from "@/components/admin/connectors/Field";
import { defaultPasswordMask } from "@/lib/llm/utils";
interface RerankingDetailsFormProps {
setRerankingDetails: Dispatch<SetStateAction<RerankingDetails>>;
@@ -38,6 +39,7 @@ const RerankingDetailsForm = forwardRef<
},
ref
) => {
console.log(originalRerankingDetails);
const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false);
const [showLiteLLMConfigurationModal, setShowLiteLLMConfigurationModal] =
useState(false);
@@ -254,8 +256,8 @@ const RerankingDetailsForm = forwardRef<
<TextFormField
subtext="Set the key to access your LiteLLM Proxy"
placeholder={
values.rerank_api_key
? "*".repeat(values.rerank_api_key.length)
originalRerankingDetails.rerank_api_key_set
? defaultPasswordMask
: undefined
}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
@@ -327,8 +329,8 @@ const RerankingDetailsForm = forwardRef<
<div className="w-full px-4">
<TextFormField
placeholder={
values.rerank_api_key
? "*".repeat(values.rerank_api_key.length)
originalRerankingDetails.rerank_api_key_set
? defaultPasswordMask
: undefined
}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {

View File

@@ -7,6 +7,11 @@ export interface RerankingDetails {
rerank_provider_type: RerankerProvider | null;
rerank_api_key: string | null;
rerank_api_url: string | null;
rerank_api_key_set?: boolean;
}
export interface RerankingDetailsSnapshot
extends Omit<RerankingDetails, "rerank_api_key"> {
rerank_api_key_set: boolean;
}
export enum RerankerProvider {
@@ -34,6 +39,14 @@ export interface SavedSearchSettings
provider_type: EmbeddingProvider | null;
}
export interface SearchSettingsSnapshot
extends Omit<RerankingDetails, "rerank_api_key">,
Omit<AdvancedSearchConfiguration, "api_url"> {
provider_type: EmbeddingProvider | null;
rerank_api_key_set: boolean;
api_key_set: boolean;
}
export interface RerankingModel {
rerank_provider_type: RerankerProvider | null;
modelName?: string;

View File

@@ -20,6 +20,7 @@ import {
AdvancedSearchConfiguration,
RerankingDetails,
SavedSearchSettings,
SearchSettingsSnapshot,
} from "../interfaces";
import RerankingDetailsForm from "../RerankingFormPage";
import { useEmbeddingFormContext } from "@/components/context/EmbeddingContext";
@@ -98,7 +99,7 @@ export default function EmbeddingForm() {
>(currentEmbeddingModel!);
const { data: searchSettings, isLoading: isLoadingSearchSettings } =
useSWR<SavedSearchSettings | null>(
useSWR<SearchSettingsSnapshot | null>(
"/api/search-settings/get-current-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
@@ -122,26 +123,29 @@ export default function EmbeddingForm() {
});
setRerankingDetails({
rerank_api_key: searchSettings.rerank_api_key,
rerank_api_key: null,
rerank_provider_type: searchSettings.rerank_provider_type,
rerank_model_name: searchSettings.rerank_model_name,
rerank_api_url: searchSettings.rerank_api_url,
rerank_api_key_set: searchSettings.rerank_api_key_set,
});
}
}, [searchSettings]);
const originalRerankingDetails: RerankingDetails = searchSettings
? {
rerank_api_key: searchSettings.rerank_api_key,
rerank_api_key: null,
rerank_provider_type: searchSettings.rerank_provider_type,
rerank_model_name: searchSettings.rerank_model_name,
rerank_api_url: searchSettings.rerank_api_url,
rerank_api_key_set: searchSettings.rerank_api_key_set,
}
: {
rerank_api_key: "",
rerank_provider_type: null,
rerank_model_name: "",
rerank_api_url: null,
rerank_api_key_set: false,
};
useEffect(() => {
@@ -173,7 +177,7 @@ export default function EmbeddingForm() {
const response = await updateSearchSettings(values);
if (response.ok) {
setPopup({
message: "Updated search settings succesffuly",
message: "Updated search settings successfully",
type: "success",
});
mutate("/api/search-settings/get-current-search-settings");

View File

@@ -52,6 +52,11 @@ export interface EmbeddingModelDescriptor {
index_name: string | null;
}
export interface EmbeddingModelSnapshot
extends Omit<CloudEmbeddingModel, "api_key"> {
api_key_set: boolean;
}
export interface CloudEmbeddingModel extends EmbeddingModelDescriptor {
pricePerMillion: number;
enabled?: boolean;

View File

@@ -102,3 +102,5 @@ export const destructureValue = (value: string): LlmOverride => {
modelName,
};
};
export const defaultPasswordMask = "**************************";