mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-01 13:45:44 +00:00
Compare commits
12 Commits
v0.20.0-cl
...
obfuscate_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fb99b30930 | ||
|
|
ba6296bc81 | ||
|
|
9abde19e44 | ||
|
|
3cc99cf79a | ||
|
|
14871a62ad | ||
|
|
fe5df3c8ee | ||
|
|
3f0895682c | ||
|
|
3b70a94fcd | ||
|
|
074342165b | ||
|
|
7b91beb3b2 | ||
|
|
4b2e4ca159 | ||
|
|
001bbb89cc |
@@ -12,6 +12,9 @@ from danswer.db.models import User__UserGroup
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from danswer.server.manage.llm.models import FullLLMProvider
|
||||
from danswer.server.manage.llm.models import FullLLMProviderSnapshot
|
||||
from danswer.server.manage.llm.models import LLMProviderCreationRequest
|
||||
from danswer.server.manage.llm.models import LLMProviderUpdateRequest
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
@@ -59,14 +62,69 @@ def upsert_cloud_embedding_provider(
|
||||
return CloudEmbeddingProvider.from_request(existing_provider)
|
||||
|
||||
|
||||
def update_llm_provider(
|
||||
llm_provider_update: LLMProviderUpdateRequest,
|
||||
db_session: Session,
|
||||
) -> FullLLMProviderSnapshot:
|
||||
existing_llm_provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.name == llm_provider_update.name
|
||||
)
|
||||
)
|
||||
if not existing_llm_provider:
|
||||
raise ValueError(
|
||||
f"LLM Provider with name {llm_provider_update.name} does not exist"
|
||||
)
|
||||
return FullLLMProviderSnapshot.from_full_llm_provider(
|
||||
FullLLMProvider.from_model(existing_llm_provider)
|
||||
)
|
||||
|
||||
|
||||
def create_llm_provider(
|
||||
llm_provider_creation: LLMProviderCreationRequest,
|
||||
db_session: Session,
|
||||
) -> FullLLMProviderSnapshot:
|
||||
new_llm_provider = LLMProviderModel(name=llm_provider_creation.name)
|
||||
db_session.add(new_llm_provider)
|
||||
db_session.commit()
|
||||
return FullLLMProviderSnapshot.from_full_llm_provider(
|
||||
FullLLMProvider.from_model(new_llm_provider)
|
||||
)
|
||||
|
||||
|
||||
def get_llm_provider(
|
||||
llm_provider_name: str, db_session: Session, user: User | None = None
|
||||
) -> FullLLMProviderSnapshot:
|
||||
if not user or not user.is_admin:
|
||||
raise ValueError("User does not have access to this LLM Provider")
|
||||
|
||||
return FullLLMProviderSnapshot.from_full_llm_provider(
|
||||
FullLLMProvider.from_model(
|
||||
db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.name == llm_provider_name
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def upsert_llm_provider(
|
||||
llm_provider: LLMProviderUpsertRequest, db_session: Session
|
||||
) -> FullLLMProvider:
|
||||
llm_provider: LLMProviderUpsertRequest,
|
||||
db_session: Session,
|
||||
is_creation: bool = True,
|
||||
) -> FullLLMProviderSnapshot:
|
||||
existing_llm_provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
|
||||
)
|
||||
if existing_llm_provider and is_creation:
|
||||
raise ValueError(f"LLM Provider with name {llm_provider.name} already exists")
|
||||
|
||||
if not existing_llm_provider:
|
||||
if not is_creation:
|
||||
raise ValueError(
|
||||
f"LLM Provider with name {llm_provider.name} does not exist"
|
||||
)
|
||||
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
|
||||
db_session.add(existing_llm_provider)
|
||||
|
||||
@@ -94,7 +152,9 @@ def upsert_llm_provider(
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return FullLLMProvider.from_model(existing_llm_provider)
|
||||
return FullLLMProviderSnapshot.from_full_llm_provider(
|
||||
FullLLMProvider.from_model(existing_llm_provider)
|
||||
)
|
||||
|
||||
|
||||
def fetch_existing_embedding_providers(
|
||||
|
||||
@@ -183,10 +183,11 @@ def update_current_search_settings(
|
||||
|
||||
# Whenever we update the current search settings, we should ensure that the local reranking model is warmed up.
|
||||
if (
|
||||
current_settings.provider_type is None
|
||||
search_settings.rerank_provider_type is None
|
||||
and search_settings.rerank_model_name is not None
|
||||
and current_settings.rerank_model_name != search_settings.rerank_model_name
|
||||
):
|
||||
print("WARMIGN THIS STUFF UP!")
|
||||
warm_up_cross_encoder(search_settings.rerank_model_name)
|
||||
|
||||
update_search_settings(current_settings, search_settings, preserved_fields)
|
||||
|
||||
@@ -17,7 +17,6 @@ from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.enums import SearchType
|
||||
from shared_configs.enums import RerankerProvider
|
||||
|
||||
|
||||
MAX_METRICS_CONTENT = (
|
||||
200 # Just need enough characters to identify where in the doc the chunk is
|
||||
)
|
||||
@@ -53,14 +52,23 @@ class InferenceSettings(RerankingDetails):
|
||||
|
||||
|
||||
class SearchSettingsCreationRequest(InferenceSettings, IndexingSetting):
|
||||
api_key_set: bool = False
|
||||
rerank_api_key_set: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_db_model(
|
||||
cls, search_settings: SearchSettings
|
||||
) -> "SearchSettingsCreationRequest":
|
||||
inference_settings = InferenceSettings.from_db_model(search_settings)
|
||||
reranking_details = RerankingDetails.from_db_model(search_settings)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
|
||||
return cls(**inference_settings.dict(), **indexing_setting.dict())
|
||||
return cls(
|
||||
**reranking_details.dict(),
|
||||
**indexing_setting.dict(),
|
||||
api_key_set=bool(search_settings.api_key),
|
||||
rerank_api_key_set=bool(search_settings.rerank_api_key),
|
||||
multilingual_expansion=search_settings.multilingual_expansion,
|
||||
)
|
||||
|
||||
|
||||
class SavedSearchSettings(InferenceSettings, IndexingSetting):
|
||||
@@ -87,6 +95,27 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
|
||||
)
|
||||
|
||||
|
||||
class SearchSettingsSnapshot(SavedSearchSettings):
|
||||
rerank_api_key: None = None
|
||||
api_key: None = None
|
||||
rerank_api_key_set: bool
|
||||
api_key_set: bool
|
||||
|
||||
@classmethod
|
||||
def from_saved_settings(
|
||||
cls, settings: SavedSearchSettings
|
||||
) -> "SearchSettingsSnapshot":
|
||||
data = settings.dict(exclude={"rerank_api_key", "api_key"})
|
||||
data["rerank_api_key_set"] = bool(settings.rerank_api_key)
|
||||
data["api_key_set"] = bool(settings.api_key)
|
||||
|
||||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def from_db_model(cls, settings: SearchSettings) -> "SearchSettingsSnapshot":
|
||||
return cls.from_saved_settings(SavedSearchSettings.from_db_model(settings))
|
||||
|
||||
|
||||
class Tag(BaseModel):
|
||||
tag_key: str
|
||||
tag_value: str
|
||||
|
||||
@@ -3,14 +3,18 @@ from collections.abc import Callable
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.llm import create_llm_provider
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.llm import get_llm_provider
|
||||
from danswer.db.llm import remove_llm_provider
|
||||
from danswer.db.llm import update_default_provider
|
||||
from danswer.db.llm import update_llm_provider
|
||||
from danswer.db.llm import upsert_llm_provider
|
||||
from danswer.db.models import User
|
||||
from danswer.llm.factory import get_default_llms
|
||||
@@ -18,7 +22,8 @@ from danswer.llm.factory import get_llm
|
||||
from danswer.llm.llm_provider_options import fetch_available_well_known_llms
|
||||
from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor
|
||||
from danswer.llm.utils import test_llm
|
||||
from danswer.server.manage.llm.models import FullLLMProvider
|
||||
from danswer.server.manage.llm.models import FullLLMProviderSnapshot
|
||||
from danswer.server.manage.llm.models import LLMProviderCreationRequest
|
||||
from danswer.server.manage.llm.models import LLMProviderDescriptor
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from danswer.server.manage.llm.models import TestLLMRequest
|
||||
@@ -42,8 +47,15 @@ def fetch_llm_options(
|
||||
@admin_router.post("/test")
|
||||
def test_llm_configuration(
|
||||
test_llm_request: TestLLMRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
if test_llm_request.existing_api_key and not test_llm_request.api_key:
|
||||
llm_provider = get_llm_provider(
|
||||
test_llm_request.provider.name, db_session=db_session
|
||||
)
|
||||
test_llm_request.api_key = llm_provider.api_key
|
||||
|
||||
llm = get_llm(
|
||||
provider=test_llm_request.provider,
|
||||
model=test_llm_request.default_model_name,
|
||||
@@ -108,20 +120,72 @@ def test_default_provider(
|
||||
def list_llm_providers(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[FullLLMProvider]:
|
||||
) -> list[FullLLMProviderSnapshot]:
|
||||
print(
|
||||
[
|
||||
FullLLMProviderSnapshot.from_model(llm_provider_model)
|
||||
for llm_provider_model in fetch_existing_llm_providers(db_session)
|
||||
]
|
||||
)
|
||||
return [
|
||||
FullLLMProvider.from_model(llm_provider_model)
|
||||
FullLLMProviderSnapshot.from_model(llm_provider_model)
|
||||
for llm_provider_model in fetch_existing_llm_providers(db_session)
|
||||
]
|
||||
|
||||
|
||||
@admin_router.patch("/provider/{provider_id}")
|
||||
def patch_existing_llm_provider(
|
||||
llm_provider: LLMProviderUpsertRequest,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FullLLMProviderSnapshot:
|
||||
return FullLLMProviderSnapshot.from_full_llm_provider(
|
||||
update_llm_provider(llm_provider=llm_provider, db_session=db_session)
|
||||
)
|
||||
|
||||
|
||||
@admin_router.post("/provider")
|
||||
def create_new_llm_provider(
|
||||
llm_provider: LLMProviderCreationRequest,
|
||||
is_creation: bool = Query(
|
||||
True,
|
||||
description="True if updating an existing provider, False if creating a new one",
|
||||
),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FullLLMProviderSnapshot:
|
||||
return FullLLMProviderSnapshot.from_full_llm_provider(
|
||||
create_llm_provider(
|
||||
llm_provider=llm_provider, db_session=db_session, is_creation=is_creation
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@admin_router.put("/provider")
|
||||
def put_llm_provider(
|
||||
llm_provider: LLMProviderUpsertRequest,
|
||||
is_creation: bool = Query(
|
||||
True,
|
||||
description="True if updating an existing provider, False if creating a new one",
|
||||
),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FullLLMProvider:
|
||||
return upsert_llm_provider(llm_provider=llm_provider, db_session=db_session)
|
||||
) -> FullLLMProviderSnapshot:
|
||||
try:
|
||||
print("hitting htis function")
|
||||
|
||||
value = FullLLMProviderSnapshot.from_full_llm_provider(
|
||||
upsert_llm_provider(
|
||||
llm_provider=llm_provider,
|
||||
db_session=db_session,
|
||||
is_creation=is_creation,
|
||||
)
|
||||
)
|
||||
print(value)
|
||||
return value
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to upsert LLM Provider")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.delete("/provider/{provider_id}")
|
||||
|
||||
@@ -5,13 +5,13 @@ from pydantic import Field
|
||||
|
||||
from danswer.llm.llm_provider_options import fetch_models_for_provider
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||
|
||||
|
||||
class TestLLMRequest(BaseModel):
|
||||
# provider level
|
||||
existing_api_key: bool = False
|
||||
provider: str
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
@@ -74,6 +74,14 @@ class LLMProviderUpsertRequest(LLMProvider):
|
||||
model_names: list[str] | None = None
|
||||
|
||||
|
||||
class LLMProviderUpdateRequest(LLMProvider):
|
||||
api_key_set: bool
|
||||
|
||||
|
||||
class LLMProviderCreationRequest(LLMProvider):
|
||||
pass
|
||||
|
||||
|
||||
class FullLLMProvider(LLMProvider):
|
||||
id: int
|
||||
is_default_provider: bool | None = None
|
||||
@@ -82,10 +90,10 @@ class FullLLMProvider(LLMProvider):
|
||||
@classmethod
|
||||
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider":
|
||||
return cls(
|
||||
api_key=llm_provider_model.api_key,
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
provider=llm_provider_model.provider,
|
||||
api_key=llm_provider_model.api_key,
|
||||
api_base=llm_provider_model.api_base,
|
||||
api_version=llm_provider_model.api_version,
|
||||
custom_config=llm_provider_model.custom_config,
|
||||
@@ -101,3 +109,23 @@ class FullLLMProvider(LLMProvider):
|
||||
is_public=llm_provider_model.is_public,
|
||||
groups=[group.id for group in llm_provider_model.groups],
|
||||
)
|
||||
|
||||
|
||||
class FullLLMProviderSnapshot(FullLLMProvider):
|
||||
api_key: None = None
|
||||
api_key_set: bool
|
||||
|
||||
@classmethod
|
||||
def from_full_llm_provider(
|
||||
cls, settings: FullLLMProvider
|
||||
) -> "FullLLMProviderSnapshot":
|
||||
data = settings.dict(exclude={"api_key"})
|
||||
data["api_key_set"] = bool(settings.api_key)
|
||||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, llm_provider_model: "LLMProviderModel"
|
||||
) -> "FullLLMProviderSnapshot":
|
||||
full_provider = FullLLMProvider.from_model(llm_provider_model)
|
||||
return cls.from_full_llm_provider(full_provider)
|
||||
|
||||
@@ -24,6 +24,7 @@ from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.natural_language_processing.search_nlp_models import clean_model_name
|
||||
from danswer.search.models import SavedSearchSettings
|
||||
from danswer.search.models import SearchSettingsCreationRequest
|
||||
from danswer.search.models import SearchSettingsSnapshot
|
||||
from danswer.server.manage.embedding.models import SearchSettingsDeleteRequest
|
||||
from danswer.server.manage.models import FullModelVersionResponse
|
||||
from danswer.server.models import IdReturn
|
||||
@@ -73,7 +74,13 @@ def set_new_search_settings(
|
||||
search_values["index_name"] = index_name
|
||||
new_search_settings_request = SavedSearchSettings(**search_values)
|
||||
else:
|
||||
new_search_settings_request = SavedSearchSettings(**search_settings_new.dict())
|
||||
new_search_settings_request = SavedSearchSettings(
|
||||
**{
|
||||
k: v
|
||||
for k, v in search_settings_new.dict().items()
|
||||
if k not in ["api_key_set", "rerank_api_key_set"]
|
||||
}
|
||||
)
|
||||
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
@@ -154,21 +161,21 @@ def delete_search_settings_endpoint(
|
||||
def get_current_search_settings_endpoint(
|
||||
_: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SavedSearchSettings:
|
||||
) -> SearchSettingsSnapshot:
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
return SavedSearchSettings.from_db_model(current_search_settings)
|
||||
return SearchSettingsSnapshot.from_db_model(current_search_settings)
|
||||
|
||||
|
||||
@router.get("/get-secondary-search-settings")
|
||||
def get_secondary_search_settings_endpoint(
|
||||
_: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SavedSearchSettings | None:
|
||||
) -> SearchSettingsSnapshot | None:
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if not secondary_search_settings:
|
||||
return None
|
||||
|
||||
return SavedSearchSettings.from_db_model(secondary_search_settings)
|
||||
return SearchSettingsSnapshot.from_db_model(secondary_search_settings)
|
||||
|
||||
|
||||
@router.get("/get-all-search-settings")
|
||||
|
||||
@@ -9,3 +9,9 @@ def batch_list(
|
||||
batch_size: int,
|
||||
) -> list[list[T]]:
|
||||
return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]
|
||||
|
||||
|
||||
def obfuscate_api_key(api_key: str | None) -> str | None:
|
||||
if api_key is None:
|
||||
return None
|
||||
return "*" * len(api_key)
|
||||
|
||||
@@ -7,6 +7,7 @@ from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import TestLLMProvider
|
||||
from tests.integration.common_utils.test_models import TestLLMProviderResponse
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
|
||||
|
||||
@@ -22,7 +23,7 @@ class LLMProviderManager:
|
||||
groups: list[int] | None = None,
|
||||
is_public: bool | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestLLMProvider:
|
||||
) -> TestLLMProviderResponse:
|
||||
print("Seeding LLM Providers...")
|
||||
|
||||
llm_provider = LLMProviderUpsertRequest(
|
||||
@@ -49,11 +50,11 @@ class LLMProviderManager:
|
||||
)
|
||||
llm_response.raise_for_status()
|
||||
response_data = llm_response.json()
|
||||
result_llm = TestLLMProvider(
|
||||
result_llm = TestLLMProviderResponse(
|
||||
id=response_data["id"],
|
||||
name=response_data["name"],
|
||||
provider=response_data["provider"],
|
||||
api_key=response_data["api_key"],
|
||||
api_key_set=response_data["api_key_set"],
|
||||
default_model_name=response_data["default_model_name"],
|
||||
is_public=response_data["is_public"],
|
||||
groups=response_data["groups"],
|
||||
|
||||
@@ -91,6 +91,18 @@ class TestLLMProvider(BaseModel):
|
||||
api_version: str | None = None
|
||||
|
||||
|
||||
class TestLLMProviderResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
provider: str
|
||||
api_key_set: bool
|
||||
default_model_name: str
|
||||
is_public: bool
|
||||
groups: list[TestUserGroup]
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
|
||||
|
||||
class TestDocumentSet(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
@@ -292,7 +292,7 @@ services:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
- "5433:5432"
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
|
||||
|
||||
@@ -302,7 +302,7 @@ services:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
- "5433:5432"
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
|
||||
|
||||
@@ -154,7 +154,7 @@ services:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
ports:
|
||||
- "5432"
|
||||
- "5433"
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
|
||||
|
||||
@@ -7,24 +7,17 @@ import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
|
||||
import {
|
||||
SelectorFormField,
|
||||
TextFormField,
|
||||
BooleanFormField,
|
||||
MultiSelectField,
|
||||
} from "@/components/admin/connectors/Field";
|
||||
import { useState } from "react";
|
||||
import { Bubble } from "@/components/Bubble";
|
||||
import { GroupsIcon } from "@/components/icons/icons";
|
||||
import { useSWRConfig } from "swr";
|
||||
import {
|
||||
defaultModelsByProvider,
|
||||
getDisplayNameForModel,
|
||||
useUserGroups,
|
||||
} from "@/lib/hooks";
|
||||
import { defaultModelsByProvider, getDisplayNameForModel } from "@/lib/hooks";
|
||||
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import * as Yup from "yup";
|
||||
import isEqual from "lodash/isEqual";
|
||||
import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector";
|
||||
import { defaultPasswordMask } from "@/lib/llm/utils";
|
||||
|
||||
export function LLMProviderUpdateForm({
|
||||
llmProviderDescriptor,
|
||||
@@ -33,6 +26,7 @@ export function LLMProviderUpdateForm({
|
||||
shouldMarkAsDefault,
|
||||
setPopup,
|
||||
hideAdvanced,
|
||||
llmProviderFlow,
|
||||
}: {
|
||||
llmProviderDescriptor: WellKnownLLMProviderDescriptor;
|
||||
onClose: () => void;
|
||||
@@ -40,14 +34,10 @@ export function LLMProviderUpdateForm({
|
||||
shouldMarkAsDefault?: boolean;
|
||||
hideAdvanced?: boolean;
|
||||
setPopup?: (popup: PopupSpec) => void;
|
||||
llmProviderFlow: "create" | "update";
|
||||
}) {
|
||||
const { mutate } = useSWRConfig();
|
||||
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
|
||||
// EE only
|
||||
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
|
||||
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const [testError, setTestError] = useState<string>("");
|
||||
|
||||
@@ -56,7 +46,7 @@ export function LLMProviderUpdateForm({
|
||||
// Define the initial values based on the provider's requirements
|
||||
const initialValues = {
|
||||
name: existingLlmProvider?.name || (hideAdvanced ? "Default" : ""),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_key: null,
|
||||
api_base: existingLlmProvider?.api_base ?? "",
|
||||
api_version: existingLlmProvider?.api_version ?? "",
|
||||
default_model_name:
|
||||
@@ -86,9 +76,10 @@ export function LLMProviderUpdateForm({
|
||||
// Setup validation schema if required
|
||||
const validationSchema = Yup.object({
|
||||
name: Yup.string().required("Display Name is required"),
|
||||
api_key: llmProviderDescriptor.api_key_required
|
||||
? Yup.string().required("API Key is required")
|
||||
: Yup.string(),
|
||||
api_key:
|
||||
llmProviderDescriptor.api_key_required && llmProviderFlow == "create"
|
||||
? Yup.string().required("API Key is required")
|
||||
: Yup.string().nullable(),
|
||||
api_base: llmProviderDescriptor.api_base_required
|
||||
? Yup.string().required("API Base is required")
|
||||
: Yup.string(),
|
||||
@@ -120,6 +111,10 @@ export function LLMProviderUpdateForm({
|
||||
display_model_names: Yup.array().of(Yup.string()),
|
||||
});
|
||||
|
||||
const apiKeyDefault = existingLlmProvider?.api_key_set
|
||||
? defaultPasswordMask
|
||||
: "API key";
|
||||
|
||||
return (
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
@@ -137,6 +132,10 @@ export function LLMProviderUpdateForm({
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
existing_api_key:
|
||||
llmProviderFlow == "update" &&
|
||||
values.api_key == null &&
|
||||
existingLlmProvider?.name != undefined,
|
||||
provider: llmProviderDescriptor.name,
|
||||
...values,
|
||||
}),
|
||||
@@ -151,7 +150,7 @@ export function LLMProviderUpdateForm({
|
||||
}
|
||||
|
||||
const response = await fetch(LLM_PROVIDERS_ADMIN_URL, {
|
||||
method: "PUT",
|
||||
method: llmProviderFlow == "create" ? "POST" : "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
@@ -237,8 +236,8 @@ export function LLMProviderUpdateForm({
|
||||
small={hideAdvanced}
|
||||
name="api_key"
|
||||
label="API Key"
|
||||
placeholder="API Key"
|
||||
type="password"
|
||||
placeholder={formikProps.values.api_key ?? apiKeyDefault}
|
||||
/>
|
||||
)}
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ export interface WellKnownLLMProviderDescriptor {
|
||||
export interface LLMProvider {
|
||||
name: string;
|
||||
provider: string;
|
||||
api_key: string | null;
|
||||
api_key_set: boolean;
|
||||
api_base: string | null;
|
||||
api_version: string | null;
|
||||
custom_config: { [key: string]: string } | null;
|
||||
|
||||
@@ -24,7 +24,7 @@ export interface EmbeddingDetails {
|
||||
import { EmbeddingIcon } from "@/components/icons/icons";
|
||||
|
||||
import Link from "next/link";
|
||||
import { SavedSearchSettings } from "../../embeddings/interfaces";
|
||||
import { SearchSettingsSnapshot } from "../../embeddings/interfaces";
|
||||
import UpgradingPage from "./UpgradingPage";
|
||||
import { useContext } from "react";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
@@ -42,7 +42,7 @@ function Main() {
|
||||
);
|
||||
|
||||
const { data: searchSettings, isLoading: isLoadingSearchSettings } =
|
||||
useSWR<SavedSearchSettings | null>(
|
||||
useSWR<SearchSettingsSnapshot | null>(
|
||||
"/api/search-settings/get-current-search-settings",
|
||||
errorHandlingFetcher,
|
||||
{ refreshInterval: 5000 } // 5 seconds
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { Button } from "@tremor/react";
|
||||
import { TextFormField } from "@/components/admin/connectors/Field";
|
||||
import { defaultPasswordMask } from "@/lib/llm/utils";
|
||||
|
||||
interface RerankingDetailsFormProps {
|
||||
setRerankingDetails: Dispatch<SetStateAction<RerankingDetails>>;
|
||||
@@ -38,6 +39,7 @@ const RerankingDetailsForm = forwardRef<
|
||||
},
|
||||
ref
|
||||
) => {
|
||||
console.log(originalRerankingDetails);
|
||||
const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false);
|
||||
const [showLiteLLMConfigurationModal, setShowLiteLLMConfigurationModal] =
|
||||
useState(false);
|
||||
@@ -254,8 +256,8 @@ const RerankingDetailsForm = forwardRef<
|
||||
<TextFormField
|
||||
subtext="Set the key to access your LiteLLM Proxy"
|
||||
placeholder={
|
||||
values.rerank_api_key
|
||||
? "*".repeat(values.rerank_api_key.length)
|
||||
originalRerankingDetails.rerank_api_key_set
|
||||
? defaultPasswordMask
|
||||
: undefined
|
||||
}
|
||||
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
@@ -327,8 +329,8 @@ const RerankingDetailsForm = forwardRef<
|
||||
<div className="w-full px-4">
|
||||
<TextFormField
|
||||
placeholder={
|
||||
values.rerank_api_key
|
||||
? "*".repeat(values.rerank_api_key.length)
|
||||
originalRerankingDetails.rerank_api_key_set
|
||||
? defaultPasswordMask
|
||||
: undefined
|
||||
}
|
||||
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
|
||||
@@ -7,6 +7,11 @@ export interface RerankingDetails {
|
||||
rerank_provider_type: RerankerProvider | null;
|
||||
rerank_api_key: string | null;
|
||||
rerank_api_url: string | null;
|
||||
rerank_api_key_set?: boolean;
|
||||
}
|
||||
export interface RerankingDetailsSnapshot
|
||||
extends Omit<RerankingDetails, "rerank_api_key"> {
|
||||
rerank_api_key_set: boolean;
|
||||
}
|
||||
|
||||
export enum RerankerProvider {
|
||||
@@ -34,6 +39,14 @@ export interface SavedSearchSettings
|
||||
provider_type: EmbeddingProvider | null;
|
||||
}
|
||||
|
||||
export interface SearchSettingsSnapshot
|
||||
extends Omit<RerankingDetails, "rerank_api_key">,
|
||||
Omit<AdvancedSearchConfiguration, "api_url"> {
|
||||
provider_type: EmbeddingProvider | null;
|
||||
rerank_api_key_set: boolean;
|
||||
api_key_set: boolean;
|
||||
}
|
||||
|
||||
export interface RerankingModel {
|
||||
rerank_provider_type: RerankerProvider | null;
|
||||
modelName?: string;
|
||||
|
||||
@@ -20,6 +20,7 @@ import {
|
||||
AdvancedSearchConfiguration,
|
||||
RerankingDetails,
|
||||
SavedSearchSettings,
|
||||
SearchSettingsSnapshot,
|
||||
} from "../interfaces";
|
||||
import RerankingDetailsForm from "../RerankingFormPage";
|
||||
import { useEmbeddingFormContext } from "@/components/context/EmbeddingContext";
|
||||
@@ -98,7 +99,7 @@ export default function EmbeddingForm() {
|
||||
>(currentEmbeddingModel!);
|
||||
|
||||
const { data: searchSettings, isLoading: isLoadingSearchSettings } =
|
||||
useSWR<SavedSearchSettings | null>(
|
||||
useSWR<SearchSettingsSnapshot | null>(
|
||||
"/api/search-settings/get-current-search-settings",
|
||||
errorHandlingFetcher,
|
||||
{ refreshInterval: 5000 } // 5 seconds
|
||||
@@ -122,26 +123,29 @@ export default function EmbeddingForm() {
|
||||
});
|
||||
|
||||
setRerankingDetails({
|
||||
rerank_api_key: searchSettings.rerank_api_key,
|
||||
rerank_api_key: null,
|
||||
rerank_provider_type: searchSettings.rerank_provider_type,
|
||||
rerank_model_name: searchSettings.rerank_model_name,
|
||||
rerank_api_url: searchSettings.rerank_api_url,
|
||||
rerank_api_key_set: searchSettings.rerank_api_key_set,
|
||||
});
|
||||
}
|
||||
}, [searchSettings]);
|
||||
|
||||
const originalRerankingDetails: RerankingDetails = searchSettings
|
||||
? {
|
||||
rerank_api_key: searchSettings.rerank_api_key,
|
||||
rerank_api_key: null,
|
||||
rerank_provider_type: searchSettings.rerank_provider_type,
|
||||
rerank_model_name: searchSettings.rerank_model_name,
|
||||
rerank_api_url: searchSettings.rerank_api_url,
|
||||
rerank_api_key_set: searchSettings.rerank_api_key_set,
|
||||
}
|
||||
: {
|
||||
rerank_api_key: "",
|
||||
rerank_provider_type: null,
|
||||
rerank_model_name: "",
|
||||
rerank_api_url: null,
|
||||
rerank_api_key_set: false,
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
@@ -173,7 +177,7 @@ export default function EmbeddingForm() {
|
||||
const response = await updateSearchSettings(values);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Updated search settings succesffuly",
|
||||
message: "Updated search settings successfully",
|
||||
type: "success",
|
||||
});
|
||||
mutate("/api/search-settings/get-current-search-settings");
|
||||
|
||||
@@ -52,6 +52,11 @@ export interface EmbeddingModelDescriptor {
|
||||
index_name: string | null;
|
||||
}
|
||||
|
||||
export interface EmbeddingModelSnapshot
|
||||
extends Omit<CloudEmbeddingModel, "api_key"> {
|
||||
api_key_set: boolean;
|
||||
}
|
||||
|
||||
export interface CloudEmbeddingModel extends EmbeddingModelDescriptor {
|
||||
pricePerMillion: number;
|
||||
enabled?: boolean;
|
||||
|
||||
@@ -102,3 +102,5 @@ export const destructureValue = (value: string): LlmOverride => {
|
||||
modelName,
|
||||
};
|
||||
};
|
||||
|
||||
export const defaultPasswordMask = "**************************";
|
||||
|
||||
Reference in New Issue
Block a user