mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-01 13:45:44 +00:00
Compare commits
3 Commits
v2.0.0-clo
...
feature/ll
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2b9057bdd7 | ||
|
|
e53ceb093f | ||
|
|
2525b8d53f |
@@ -271,6 +271,31 @@ class CloudEmbedding:
|
||||
result = response.json()
|
||||
return [embedding["embedding"] for embedding in result["data"]]
|
||||
|
||||
async def _embed_llama_stack(
    self, texts: list[str], model_name: str | None
) -> list[Embedding]:
    """Embed ``texts`` by POSTing them to the configured Llama Stack URL.

    The request body is ``{"model": model_name, "input": texts}`` and is sent
    to ``self.api_url`` as-is (no path is appended here). A bearer
    ``Authorization`` header is added only when ``self.api_key`` is set.

    Args:
        texts: Strings to embed, sent in a single request.
        model_name: Name of the embedding model; must be non-empty.

    Returns:
        The ``"embedding"`` value of every item in the response's ``"data"``
        list, in the order the server returned them.

    Raises:
        ValueError: If ``model_name`` or ``self.api_url`` is missing/empty.
        Exception: Whatever ``response.raise_for_status()`` raises for an
            error status.
    """
    if not model_name:
        raise ValueError("Model name is required for Llama Stack embedding.")
    if not self.api_url:
        raise ValueError("API URL is required for Llama Stack embedding.")

    # Authentication is optional: send no headers at all without an API key.
    auth_headers: dict[str, str] = {}
    if self.api_key:
        auth_headers["Authorization"] = f"Bearer {self.api_key}"

    payload = {"model": model_name, "input": texts}
    response = await self.http_client.post(
        self.api_url, json=payload, headers=auth_headers
    )
    response.raise_for_status()

    return [item["embedding"] for item in response.json()["data"]]
|
||||
|
||||
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
|
||||
async def embed(
|
||||
self,
|
||||
@@ -288,6 +313,8 @@ class CloudEmbedding:
|
||||
return await self._embed_azure(texts, f"azure/{deployment_name}")
|
||||
elif self.provider == EmbeddingProvider.LITELLM:
|
||||
return await self._embed_litellm_proxy(texts, model_name)
|
||||
elif self.provider == EmbeddingProvider.LLAMA_STACK:
|
||||
return await self._embed_llama_stack(texts, model_name)
|
||||
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
|
||||
@@ -787,3 +787,15 @@ S3_AWS_SECRET_ACCESS_KEY = os.environ.get("S3_AWS_SECRET_ACCESS_KEY")
|
||||
# Forcing Vespa Language
|
||||
# English: en, German:de, etc. See: https://docs.vespa.ai/en/linguistics.html
|
||||
VESPA_LANGUAGE_OVERRIDE = os.environ.get("VESPA_LANGUAGE_OVERRIDE")
|
||||
|
||||
#####
|
||||
# Llama Stack Server Configs
|
||||
#####
|
||||
# Host and port of the Llama Stack server. An unset variable and an empty one
# are both treated as "use the default": with `os.environ.get(name, default)`
# an empty value would be accepted and produce a malformed default URL such
# as "http://:8321".
LLAMA_STACK_SERVER_HOST = (
    os.environ.get("LLAMA_STACK_SERVER_HOST") or "llama_stack_server"
)
LLAMA_STACK_SERVER_PORT = os.environ.get("LLAMA_STACK_SERVER_PORT") or "8321"
# Base URL of the server. A non-empty LLAMA_STACK_SERVER_URL takes precedence;
# otherwise the URL is composed from the host and port above.
LLAMA_STACK_SERVER_URL = (
    os.environ.get("LLAMA_STACK_SERVER_URL")
    or f"http://{LLAMA_STACK_SERVER_HOST}:{LLAMA_STACK_SERVER_PORT}"
)
|
||||
|
||||
@@ -3,6 +3,7 @@ from enum import Enum
|
||||
import litellm # type: ignore
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.app_configs import LLAMA_STACK_SERVER_URL
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.server.manage.llm.models import ModelConfigurationView
|
||||
|
||||
@@ -24,6 +25,7 @@ class CustomConfigKey(BaseModel):
|
||||
is_required: bool = True
|
||||
is_secret: bool = False
|
||||
key_type: CustomConfigKeyType = CustomConfigKeyType.TEXT_INPUT
|
||||
default_value: str | None = None
|
||||
|
||||
|
||||
class WellKnownLLMProviderDescriptor(BaseModel):
|
||||
@@ -138,12 +140,15 @@ VERTEXAI_VISIBLE_MODEL_NAMES = [
|
||||
VERTEXAI_DEFAULT_FAST_MODEL,
|
||||
]
|
||||
|
||||
# Llama Stack is registered under the provider name "openai".
# NOTE(review): if OPENAI_PROVIDER_NAME is also "openai" (its value is not
# visible from here — confirm), then the LLAMA_STACK_PROVIDER_NAME entry in
# _PROVIDER_TO_MODELS_MAP below is a duplicate key and, being the later one,
# replaces the OpenAI model list with the empty Llama Stack list.
LLAMA_STACK_PROVIDER_NAME = "openai"
# No statically known models for Llama Stack; the list starts out empty.
LLAMA_STACK_MODEL_NAMES: list[str] = []

# Provider name -> every model name known for that provider.
_PROVIDER_TO_MODELS_MAP = {
    OPENAI_PROVIDER_NAME: OPEN_AI_MODEL_NAMES,
    BEDROCK_PROVIDER_NAME: BEDROCK_MODEL_NAMES,
    ANTHROPIC_PROVIDER_NAME: ANTHROPIC_MODEL_NAMES,
    VERTEXAI_PROVIDER_NAME: VERTEXAI_MODEL_NAMES,
    LLAMA_STACK_PROVIDER_NAME: LLAMA_STACK_MODEL_NAMES,
}
|
||||
|
||||
_PROVIDER_TO_VISIBLE_MODELS_MAP = {
|
||||
@@ -151,6 +156,7 @@ _PROVIDER_TO_VISIBLE_MODELS_MAP = {
|
||||
BEDROCK_PROVIDER_NAME: [BEDROCK_DEFAULT_MODEL],
|
||||
ANTHROPIC_PROVIDER_NAME: ANTHROPIC_VISIBLE_MODEL_NAMES,
|
||||
VERTEXAI_PROVIDER_NAME: VERTEXAI_VISIBLE_MODEL_NAMES,
|
||||
LLAMA_STACK_PROVIDER_NAME: [], # No pre-selected visible models
|
||||
}
|
||||
|
||||
|
||||
@@ -172,6 +178,35 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
|
||||
default_model="gpt-4o",
|
||||
default_fast_model="gpt-4o-mini",
|
||||
),
|
||||
WellKnownLLMProviderDescriptor(
|
||||
name=LLAMA_STACK_PROVIDER_NAME,
|
||||
display_name="Llama Stack",
|
||||
api_key_required=False,
|
||||
api_base_required=False,
|
||||
api_version_required=False,
|
||||
custom_config_keys=[
|
||||
CustomConfigKey(
|
||||
name="LLAMA_STACK_SERVER_URL",
|
||||
display_name="Llama Stack Server URL",
|
||||
description="The base URL for your Llama Stack server.",
|
||||
is_required=True,
|
||||
is_secret=False,
|
||||
default_value=LLAMA_STACK_SERVER_URL,
|
||||
),
|
||||
CustomConfigKey(
|
||||
name="LLAMA_STACK_API_KEY", # TODO: need to accept multiple API keys
|
||||
display_name="Llama Stack LLM API Key",
|
||||
description="The API key for your LLM provider on Llama Stack.",
|
||||
is_required=False,
|
||||
is_secret=True,
|
||||
),
|
||||
],
|
||||
model_configurations=fetch_model_configurations_for_provider(
|
||||
LLAMA_STACK_PROVIDER_NAME
|
||||
),
|
||||
default_model=None,
|
||||
default_fast_model=None,
|
||||
),
|
||||
WellKnownLLMProviderDescriptor(
|
||||
name=ANTHROPIC_PROVIDER_NAME,
|
||||
display_name="Anthropic",
|
||||
|
||||
@@ -2,10 +2,12 @@ from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -386,3 +388,50 @@ def get_provider_contextual_cost(
|
||||
)
|
||||
|
||||
return costs
|
||||
|
||||
|
||||
@admin_router.post("/llama-stack-models")
async def fetch_llama_stack_models(
    request: Request,
    _: User | None = Depends(current_admin_user),
) -> dict:
    """Proxy endpoint to fetch models from a Llama Stack server URL, filtered by model type.

    Request body should contain:
    - server_url: URL of the Llama Stack server (must start with http:// or https://)
    - model_type: Optional type to filter by ("llm" or "embedding"), defaults to "llm"

    Returns:
        ``{"models": [...]}`` holding the entries of the server's ``data``
        list whose ``model_type`` equals the requested type.

    Raises:
        HTTPException: 400 when ``server_url`` is missing or not an http(s)
            URL, when the request to ``<server_url>/v1/models`` fails, or
            when the response is not JSON containing a ``data`` list.
    """
    failure_prefix = "Failed to fetch models from Llama Stack server: "

    # A missing, malformed or non-object JSON body is treated like an empty
    # one, so the caller gets the "Missing server_url" error below.
    try:
        body = await request.json()
    except ValueError:  # JSON decode errors are ValueError subclasses
        body = None
    if not isinstance(body, dict):
        body = {}

    server_url = body.get("server_url")
    model_type = body.get("model_type", "llm")

    if not server_url:
        raise HTTPException(
            status_code=400, detail="Missing server_url in request body."
        )
    # NOTE(security): this endpoint makes the backend issue a request to a
    # caller-supplied URL. It is admin-only, but at minimum reject values
    # that are not plain http(s) URLs before handing them to the HTTP client.
    if not isinstance(server_url, str) or not server_url.lower().startswith(
        ("http://", "https://")
    ):
        raise HTTPException(
            status_code=400,
            detail="server_url must be an http:// or https:// URL.",
        )

    llama_url = server_url.rstrip("/") + "/v1/models"
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.get(llama_url, timeout=10)
            resp.raise_for_status()
            resp_json = resp.json()
    except (httpx.HTTPError, httpx.InvalidURL, ValueError) as e:
        # Transport errors, error statuses, bad URLs and non-JSON bodies are
        # all reported to the caller as a 400 with the underlying reason.
        raise HTTPException(
            status_code=400,
            detail=f"{failure_prefix}{str(e)}",
        ) from e

    models = resp_json.get("data") if isinstance(resp_json, dict) else None
    if not isinstance(models, list):
        raise HTTPException(
            status_code=400,
            detail=(
                f"{failure_prefix}"
                "Invalid response from Llama Stack server: missing 'data' list."
            ),
        )

    # Filter models to only include those with the specified model_type;
    # entries that are not JSON objects cannot match and are skipped.
    filtered_models = [
        model
        for model in models
        if isinstance(model, dict) and model.get("model_type") == model_type
    ]

    return {"models": filtered_models}
|
||||
|
||||
@@ -7,6 +7,7 @@ class EmbeddingProvider(str, Enum):
|
||||
VOYAGE = "voyage"
|
||||
GOOGLE = "google"
|
||||
LITELLM = "litellm"
|
||||
LLAMA_STACK = "llama_stack"
|
||||
AZURE = "azure"
|
||||
|
||||
|
||||
|
||||
BIN
web/public/Meta.png
Normal file
BIN
web/public/Meta.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 103 KiB |
@@ -51,6 +51,17 @@ export function LLMProviderUpdateForm({
|
||||
|
||||
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
|
||||
|
||||
// Track dynamically fetched model configurations
|
||||
const [dynamicModelConfigurations, setDynamicModelConfigurations] = useState<
|
||||
{ name: string; is_visible: boolean }[]
|
||||
>([]);
|
||||
|
||||
// Combine original and dynamically fetched model configurations
|
||||
const allModelConfigurations = [
|
||||
...llmProviderDescriptor.model_configurations,
|
||||
...dynamicModelConfigurations,
|
||||
];
|
||||
|
||||
// Define the initial values based on the provider's requirements
|
||||
const initialValues = {
|
||||
name:
|
||||
@@ -69,7 +80,7 @@ export function LLMProviderUpdateForm({
|
||||
existingLlmProvider?.custom_config ??
|
||||
llmProviderDescriptor.custom_config_keys?.reduce(
|
||||
(acc, customConfigKey) => {
|
||||
acc[customConfigKey.name] = "";
|
||||
acc[customConfigKey.name] = customConfigKey.default_value || "";
|
||||
return acc;
|
||||
},
|
||||
{} as { [key: string]: string }
|
||||
@@ -148,11 +159,28 @@ export function LLMProviderUpdateForm({
|
||||
...rest
|
||||
} = values;
|
||||
|
||||
// Transform api_base for Llama Stack to append /v1/openai/v1 if not already present
|
||||
let llamaStackOpenAiApiBase =
|
||||
rest.custom_config?.LLAMA_STACK_SERVER_URL;
|
||||
if (
|
||||
llmProviderDescriptor.display_name === "Llama Stack" &&
|
||||
rest.custom_config?.LLAMA_STACK_SERVER_URL
|
||||
) {
|
||||
const baseUrl = rest.custom_config?.LLAMA_STACK_SERVER_URL.replace(
|
||||
/\/+$/,
|
||||
""
|
||||
); // Remove trailing slashes
|
||||
if (!baseUrl.endsWith("/v1/openai/v1")) {
|
||||
llamaStackOpenAiApiBase = `${baseUrl}/v1/openai/v1`;
|
||||
}
|
||||
}
|
||||
|
||||
// Create the final payload with proper typing
|
||||
const finalValues = {
|
||||
...rest,
|
||||
api_base: llamaStackOpenAiApiBase,
|
||||
api_key_changed: values.api_key !== initialValues.api_key,
|
||||
model_configurations: llmProviderDescriptor.model_configurations.map(
|
||||
model_configurations: allModelConfigurations.map(
|
||||
(modelConfiguration): ModelConfigurationUpsertRequest => ({
|
||||
name: modelConfiguration.name,
|
||||
is_visible: visibleModels.includes(modelConfiguration.name),
|
||||
@@ -330,23 +358,112 @@ export function LLMProviderUpdateForm({
|
||||
}
|
||||
})}
|
||||
|
||||
{llmProviderDescriptor.display_name === "Llama Stack" && (
|
||||
<div className="my-4">
|
||||
<Button
|
||||
type="button"
|
||||
variant="secondary"
|
||||
onClick={async () => {
|
||||
const serverUrl =
|
||||
formikProps.values.custom_config?.LLAMA_STACK_SERVER_URL;
|
||||
if (!serverUrl) {
|
||||
setPopup?.({
|
||||
type: "error",
|
||||
message: "Please enter the Llama Stack Server URL first.",
|
||||
});
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const response = await fetch(
|
||||
"/api/admin/llm/llama-stack-models",
|
||||
{
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
server_url: serverUrl,
|
||||
model_type: "llm",
|
||||
}),
|
||||
}
|
||||
);
|
||||
if (!response.ok) {
|
||||
throw new Error("Failed to fetch models from backend.");
|
||||
}
|
||||
const data = await response.json();
|
||||
if (!Array.isArray(data.models)) {
|
||||
throw new Error("Invalid response from backend.");
|
||||
}
|
||||
// Map models to the expected format for model_configurations
|
||||
console.log("data.models", data.models);
|
||||
const newModelConfigurations = data.models.map(
|
||||
(model: any) => ({
|
||||
name: model.identifier,
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
})
|
||||
);
|
||||
|
||||
// Update the dynamic model configurations state
|
||||
setDynamicModelConfigurations(newModelConfigurations);
|
||||
|
||||
// Update form fields
|
||||
formikProps.setFieldValue(
|
||||
"model_configurations",
|
||||
newModelConfigurations
|
||||
);
|
||||
|
||||
// Update selected model names to include all fetched models by default
|
||||
const modelNames = newModelConfigurations.map(
|
||||
(config: { name: string; is_visible: boolean }) =>
|
||||
config.name
|
||||
);
|
||||
formikProps.setFieldValue(
|
||||
"selected_model_names",
|
||||
modelNames
|
||||
);
|
||||
|
||||
// Set the first model as default if no default is set
|
||||
if (
|
||||
!formikProps.values.default_model_name &&
|
||||
modelNames.length > 0
|
||||
) {
|
||||
formikProps.setFieldValue(
|
||||
"default_model_name",
|
||||
modelNames[0]
|
||||
);
|
||||
}
|
||||
|
||||
setPopup?.({
|
||||
type: "success",
|
||||
message: `Fetched ${newModelConfigurations.length} models from Llama Stack server.`,
|
||||
});
|
||||
} catch (err: any) {
|
||||
setPopup?.({
|
||||
type: "error",
|
||||
message: err.message || "Error fetching models.",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Fetch available models
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!firstTimeConfiguration && (
|
||||
<>
|
||||
<Separator />
|
||||
|
||||
{llmProviderDescriptor.model_configurations.length > 0 ? (
|
||||
{allModelConfigurations.length > 0 ? (
|
||||
<SelectorFormField
|
||||
name="default_model_name"
|
||||
subtext="The model to use by default for this provider unless otherwise specified."
|
||||
label="Default Model"
|
||||
options={llmProviderDescriptor.model_configurations.map(
|
||||
(modelConfiguration) => ({
|
||||
// don't clean up names here to give admins descriptive names / handle duplicates
|
||||
// like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0
|
||||
name: modelConfiguration.name,
|
||||
value: modelConfiguration.name,
|
||||
})
|
||||
)}
|
||||
options={allModelConfigurations.map((modelConfiguration) => ({
|
||||
// don't clean up names here to give admins descriptive names / handle duplicates
|
||||
// like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0
|
||||
name: modelConfiguration.name,
|
||||
value: modelConfiguration.name,
|
||||
}))}
|
||||
maxHeight="max-h-56"
|
||||
/>
|
||||
) : (
|
||||
@@ -367,14 +484,14 @@ export function LLMProviderUpdateForm({
|
||||
)}
|
||||
|
||||
{!llmProviderDescriptor.single_model_supported &&
|
||||
(llmProviderDescriptor.model_configurations.length > 0 ? (
|
||||
(allModelConfigurations.length > 0 ? (
|
||||
<SelectorFormField
|
||||
name="fast_default_model_name"
|
||||
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
|
||||
for this provider. If \`Default\` is specified, will use
|
||||
the Default Model configured above.`}
|
||||
label="[Optional] Fast Model"
|
||||
options={llmProviderDescriptor.model_configurations.map(
|
||||
options={allModelConfigurations.map(
|
||||
(modelConfiguration) => ({
|
||||
// don't clean up names here to give admins descriptive names / handle duplicates
|
||||
// like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0
|
||||
@@ -404,7 +521,7 @@ export function LLMProviderUpdateForm({
|
||||
/>
|
||||
{showAdvancedOptions && (
|
||||
<>
|
||||
{llmProviderDescriptor.model_configurations.length > 0 && (
|
||||
{allModelConfigurations.length > 0 && (
|
||||
<div className="w-full">
|
||||
<MultiSelectField
|
||||
selectedInitially={
|
||||
@@ -413,7 +530,7 @@ export function LLMProviderUpdateForm({
|
||||
name="selected_model_names"
|
||||
label="Display Models"
|
||||
subtext="Select the models to make available to users. Unselected models will not be available."
|
||||
options={llmProviderDescriptor.model_configurations.map(
|
||||
options={allModelConfigurations.map(
|
||||
(modelConfiguration) => ({
|
||||
value: modelConfiguration.name,
|
||||
// don't clean up names here to give admins descriptive names / handle duplicates
|
||||
|
||||
@@ -5,6 +5,7 @@ export interface CustomConfigKey {
|
||||
is_required: boolean;
|
||||
is_secret: boolean;
|
||||
key_type: "text_input" | "file_input";
|
||||
default_value?: string | null;
|
||||
}
|
||||
|
||||
export interface ModelConfigurationUpsertRequest {
|
||||
|
||||
@@ -130,7 +130,8 @@ export function EmbeddingModelSelection({
|
||||
<ProviderCreationModal
|
||||
updateCurrentModel={updateCurrentModel}
|
||||
isProxy={
|
||||
showTentativeProvider.provider_type == EmbeddingProvider.LITELLM
|
||||
showTentativeProvider.provider_type == EmbeddingProvider.LITELLM ||
|
||||
showTentativeProvider.provider_type == EmbeddingProvider.LLAMA_STACK
|
||||
}
|
||||
isAzure={
|
||||
showTentativeProvider.provider_type == EmbeddingProvider.AZURE
|
||||
@@ -153,7 +154,10 @@ export function EmbeddingModelSelection({
|
||||
{changeCredentialsProvider && (
|
||||
<ChangeCredentialsModal
|
||||
isProxy={
|
||||
changeCredentialsProvider.provider_type == EmbeddingProvider.LITELLM
|
||||
changeCredentialsProvider.provider_type ==
|
||||
EmbeddingProvider.LITELLM ||
|
||||
changeCredentialsProvider.provider_type ==
|
||||
EmbeddingProvider.LLAMA_STACK
|
||||
}
|
||||
isAzure={
|
||||
changeCredentialsProvider.provider_type == EmbeddingProvider.AZURE
|
||||
|
||||
@@ -18,6 +18,7 @@ export interface RerankingDetails {
|
||||
export enum RerankerProvider {
|
||||
COHERE = "cohere",
|
||||
LITELLM = "litellm",
|
||||
LLAMA_STACK = "llama_stack",
|
||||
BEDROCK = "bedrock",
|
||||
}
|
||||
|
||||
|
||||
@@ -113,6 +113,18 @@ export function ProviderCreationModal({
|
||||
const customConfig = Object.fromEntries(values.custom_config);
|
||||
const providerType = values.provider_type.toLowerCase().split(" ")[0];
|
||||
const isOpenAI = providerType === "openai";
|
||||
const isLlamaStack = providerType === "llama_stack";
|
||||
|
||||
// Normalize API URL for Llama Stack
|
||||
let normalizedApiUrl = values.api_url;
|
||||
if (isLlamaStack && normalizedApiUrl) {
|
||||
// Remove trailing slash if present
|
||||
normalizedApiUrl = normalizedApiUrl.replace(/\/$/, "");
|
||||
// Append the normalized endpoint if not already present
|
||||
if (!normalizedApiUrl.endsWith("/v1/openai/v1/embeddings")) {
|
||||
normalizedApiUrl = `${normalizedApiUrl}/v1/openai/v1/embeddings`;
|
||||
}
|
||||
}
|
||||
|
||||
const testModelName =
|
||||
isOpenAI || isAzure ? "text-embedding-3-small" : values.model_name;
|
||||
@@ -120,7 +132,7 @@ export function ProviderCreationModal({
|
||||
const testEmbeddingPayload = {
|
||||
provider_type: providerType,
|
||||
api_key: values.api_key,
|
||||
api_url: values.api_url,
|
||||
api_url: normalizedApiUrl,
|
||||
model_name: testModelName,
|
||||
api_version: values.api_version,
|
||||
deployment_name: values.deployment_name,
|
||||
@@ -148,6 +160,7 @@ export function ProviderCreationModal({
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
...values,
|
||||
api_url: normalizedApiUrl, // Use normalized API URL
|
||||
api_version: values.api_version,
|
||||
deployment_name: values.deployment_name,
|
||||
provider_type: values.provider_type.toLowerCase().split(" ")[0],
|
||||
@@ -222,7 +235,12 @@ export function ProviderCreationModal({
|
||||
<TextFormField
|
||||
name="api_url"
|
||||
label="API URL"
|
||||
placeholder="API URL"
|
||||
placeholder={
|
||||
selectedProvider.provider_type ===
|
||||
EmbeddingProvider.LLAMA_STACK
|
||||
? "Your Llama Stack server base URL (e.g. http://0.0.0.0:8321)"
|
||||
: "API URL"
|
||||
}
|
||||
type="text"
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
EmbeddingModelDescriptor,
|
||||
EmbeddingProvider,
|
||||
LITELLM_CLOUD_PROVIDER,
|
||||
LLAMA_STACK_CLOUD_PROVIDER,
|
||||
AZURE_CLOUD_PROVIDER,
|
||||
} from "../../../../components/embedding/interfaces";
|
||||
import { EmbeddingDetails } from "../EmbeddingModelSelectionForm";
|
||||
@@ -69,6 +70,10 @@ export default function CloudEmbeddingPage({
|
||||
EmbeddingDetails | undefined
|
||||
>(undefined);
|
||||
|
||||
const [llamaStackProvider, setLlamaStackProvider] = useState<
|
||||
EmbeddingDetails | undefined
|
||||
>(undefined);
|
||||
|
||||
const [azureProvider, setAzureProvider] = useState<
|
||||
EmbeddingDetails | undefined
|
||||
>(undefined);
|
||||
@@ -79,6 +84,11 @@ export default function CloudEmbeddingPage({
|
||||
provider.provider_type === EmbeddingProvider.LITELLM.toLowerCase()
|
||||
);
|
||||
setLiteLLMProvider(liteLLMProvider);
|
||||
const llamaStackProvider = embeddingProviderDetails?.find(
|
||||
(provider) =>
|
||||
provider.provider_type === EmbeddingProvider.LLAMA_STACK.toLowerCase()
|
||||
);
|
||||
setLlamaStackProvider(llamaStackProvider);
|
||||
const azureProvider = embeddingProviderDetails?.find(
|
||||
(provider) =>
|
||||
provider.provider_type === EmbeddingProvider.AZURE.toLowerCase()
|
||||
@@ -154,6 +164,133 @@ export default function CloudEmbeddingPage({
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
<Text className="mt-6">
|
||||
You can set up a Llama Stack server to use self-hosted and cloud-based
|
||||
models.{" "}
|
||||
<a
|
||||
href="https://llama-stack.readthedocs.io/"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-blue-500 hover:underline"
|
||||
>
|
||||
Learn more about Llama Stack
|
||||
</a>
|
||||
</Text>
|
||||
|
||||
<div
|
||||
key={LLAMA_STACK_CLOUD_PROVIDER.provider_type}
|
||||
className="mt-4 w-full"
|
||||
>
|
||||
<div className="flex items-center mb-2">
|
||||
{LLAMA_STACK_CLOUD_PROVIDER.icon({ size: 40 })}
|
||||
<h2 className="ml-2 mt-2 text-xl font-bold">
|
||||
{LLAMA_STACK_CLOUD_PROVIDER.provider_type}{" "}
|
||||
{LLAMA_STACK_CLOUD_PROVIDER.provider_type ==
|
||||
EmbeddingProvider.COHERE && "(recommended)"}
|
||||
</h2>
|
||||
<HoverPopup
|
||||
mainContent={
|
||||
<FiInfo className="ml-2 mt-2 cursor-pointer" size={18} />
|
||||
}
|
||||
popupContent={
|
||||
<div className="text-sm text-text-800 w-52">
|
||||
<div className="my-auto">
|
||||
{LLAMA_STACK_CLOUD_PROVIDER.description}
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
style="dark"
|
||||
/>
|
||||
</div>
|
||||
<div className="w-full flex flex-col items-start">
|
||||
{!llamaStackProvider ? (
|
||||
<button
|
||||
onClick={() =>
|
||||
setShowTentativeProvider(LLAMA_STACK_CLOUD_PROVIDER)
|
||||
}
|
||||
className="mb-2 px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 text-sm cursor-pointer"
|
||||
>
|
||||
Set API Configuration
|
||||
</button>
|
||||
) : (
|
||||
<button
|
||||
onClick={() =>
|
||||
setChangeCredentialsProvider(LLAMA_STACK_CLOUD_PROVIDER)
|
||||
}
|
||||
className="mb-2 hover:underline text-sm cursor-pointer"
|
||||
>
|
||||
Modify API Configuration
|
||||
</button>
|
||||
)}
|
||||
|
||||
{!llamaStackProvider && (
|
||||
<CardSection className="mt-2 w-full max-w-4xl bg-background-50 border border-background-200">
|
||||
<div className="p-4">
|
||||
<Text className="text-lg font-semibold mb-2">
|
||||
API URL Required
|
||||
</Text>
|
||||
<Text className="text-sm text-text-600 mb-4">
|
||||
Before you can add models, you need to provide an API URL
|
||||
for your Llama Stack server. Click the "Set API
|
||||
Configuration" button above to set up your Llama Stack
|
||||
configuration.
|
||||
</Text>
|
||||
<div className="flex items-center">
|
||||
<FiInfo className="text-blue-500 mr-2" size={18} />
|
||||
<Text className="text-sm text-blue-500">
|
||||
Once configured, you'll be able to add and manage
|
||||
your Llama Stack models here.
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
</CardSection>
|
||||
)}
|
||||
{llamaStackProvider && (
|
||||
<>
|
||||
<div className="flex mb-4 flex-wrap gap-4">
|
||||
{embeddingModelDetails
|
||||
?.filter(
|
||||
(model) =>
|
||||
model.provider_type ===
|
||||
EmbeddingProvider.LLAMA_STACK.toLowerCase()
|
||||
)
|
||||
.map((model) => (
|
||||
<CloudModelCard
|
||||
key={model.model_name}
|
||||
model={model}
|
||||
provider={LLAMA_STACK_CLOUD_PROVIDER}
|
||||
currentModel={currentModel}
|
||||
setAlreadySelectedModel={setAlreadySelectedModel}
|
||||
setShowTentativeModel={setShowTentativeModel}
|
||||
setShowModelInQueue={setShowModelInQueue}
|
||||
setShowTentativeProvider={setShowTentativeProvider}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
|
||||
<CardSection
|
||||
className={`mt-2 w-full max-w-4xl ${
|
||||
currentModel.provider_type === EmbeddingProvider.LLAMA_STACK
|
||||
? "border-2 border-blue-500"
|
||||
: ""
|
||||
}`}
|
||||
>
|
||||
<CustomEmbeddingModelForm
|
||||
embeddingType={EmbeddingProvider.LLAMA_STACK}
|
||||
provider={llamaStackProvider}
|
||||
currentValues={
|
||||
currentModel.provider_type ===
|
||||
EmbeddingProvider.LLAMA_STACK
|
||||
? (currentModel as CloudEmbeddingModel)
|
||||
: null
|
||||
}
|
||||
setShowTentativeModel={setShowTentativeModel}
|
||||
/>
|
||||
</CardSection>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Text className="mt-6">
|
||||
Alternatively, you can use a self-hosted model using the LiteLLM
|
||||
|
||||
@@ -57,15 +57,17 @@ export function CustomEmbeddingModelForm({
|
||||
<Form>
|
||||
<Text className="text-xl text-text-900 font-bold mb-4">
|
||||
Specify details for your{" "}
|
||||
{embeddingType === EmbeddingProvider.AZURE ? "Azure" : "LiteLLM"}{" "}
|
||||
{embeddingType === EmbeddingProvider.AZURE
|
||||
? "Azure"
|
||||
: embeddingType === EmbeddingProvider.LLAMA_STACK
|
||||
? "Llama Stack"
|
||||
: "LiteLLM"}{" "}
|
||||
Provider's model
|
||||
</Text>
|
||||
<TextFormField
|
||||
name="model_name"
|
||||
label="Model Name:"
|
||||
subtext={`The name of the ${
|
||||
embeddingType === EmbeddingProvider.AZURE ? "Azure" : "LiteLLM"
|
||||
} model`}
|
||||
subtext={`The name of the embedding model`}
|
||||
placeholder="e.g. 'all-MiniLM-L6-v2'"
|
||||
autoCompleteDisabled={true}
|
||||
/>
|
||||
@@ -106,7 +108,11 @@ export function CustomEmbeddingModelForm({
|
||||
className="w-64 mx-auto"
|
||||
>
|
||||
Configure{" "}
|
||||
{embeddingType === EmbeddingProvider.AZURE ? "Azure" : "LiteLLM"}{" "}
|
||||
{embeddingType === EmbeddingProvider.AZURE
|
||||
? "Azure"
|
||||
: embeddingType === EmbeddingProvider.LLAMA_STACK
|
||||
? "Llama Stack"
|
||||
: "LiteLLM"}{" "}
|
||||
Model
|
||||
</Button>
|
||||
</Form>
|
||||
|
||||
@@ -4,6 +4,7 @@ import {
|
||||
GoogleIcon,
|
||||
IconProps,
|
||||
LiteLLMIcon,
|
||||
LlamaStackIcon,
|
||||
MicrosoftIcon,
|
||||
NomicIcon,
|
||||
OpenAIISVG,
|
||||
@@ -17,6 +18,7 @@ export enum EmbeddingProvider {
|
||||
VOYAGE = "voyage",
|
||||
GOOGLE = "google",
|
||||
LITELLM = "litellm",
|
||||
LLAMA_STACK = "llama_stack",
|
||||
AZURE = "azure",
|
||||
}
|
||||
|
||||
@@ -159,6 +161,15 @@ export const LITELLM_CLOUD_PROVIDER: CloudEmbeddingProvider = {
|
||||
embedding_models: [], // No default embedding models
|
||||
};
|
||||
|
||||
/**
 * Static descriptor for the Llama Stack embedding provider.
 *
 * Ships with no predefined embedding models: `embedding_models` is empty.
 * `website` and `apiLink` both point at the Llama Stack documentation site.
 */
export const LLAMA_STACK_CLOUD_PROVIDER: CloudEmbeddingProvider = {
  provider_type: EmbeddingProvider.LLAMA_STACK,
  website: "https://llama-stack.readthedocs.io/",
  icon: LlamaStackIcon,
  description: "Llama Stack inference server for running Llama models",
  apiLink: "https://llama-stack.readthedocs.io/",
  embedding_models: [], // No default embedding models
};
|
||||
|
||||
export const AZURE_CLOUD_PROVIDER: CloudEmbeddingProvider = {
|
||||
provider_type: EmbeddingProvider.AZURE,
|
||||
website:
|
||||
|
||||
@@ -44,6 +44,7 @@ import mistralSVG from "../../../public/Mistral.svg";
|
||||
import qwenSVG from "../../../public/Qwen.svg";
|
||||
import openSourceIcon from "../../../public/OpenSource.png";
|
||||
import litellmIcon from "../../../public/litellm.png";
|
||||
import llamaStackIcon from "../../../public/Meta.png"; // Meta logo used as a placeholder icon for Llama Stack
|
||||
import azureIcon from "../../../public/Azure.png";
|
||||
import asanaIcon from "../../../public/Asana.png";
|
||||
import anthropicSVG from "../../../public/Anthropic.svg";
|
||||
@@ -283,6 +284,13 @@ export const LiteLLMIcon = ({
|
||||
return <LogoIcon size={size} className={className} src={litellmIcon} />;
|
||||
};
|
||||
|
||||
/**
 * Llama Stack logo, rendered through the shared `LogoIcon` wrapper.
 *
 * `size` defaults to 16 and `className` to `defaultTailwindCSS`.
 */
export const LlamaStackIcon = ({
  size = 16,
  className = defaultTailwindCSS,
}: IconProps) => (
  <LogoIcon size={size} className={className} src={llamaStackIcon} />
);
|
||||
|
||||
export const OpenSourceIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
|
||||
Reference in New Issue
Block a user