Compare commits

...

4 Commits

Author SHA1 Message Date
Wenxi Onyx
cd34eaf4af support reduced dimensions 2025-08-12 11:05:37 -07:00
Wenxi Onyx
a4d8bf8a64 reduced dim options 2025-08-12 11:05:37 -07:00
Wenxi Onyx
18ae2f54fe ai nits 2025-08-12 11:05:37 -07:00
Wenxi Onyx
7876515aac gemini embedding model 3072 2025-08-12 11:05:37 -07:00
6 changed files with 62 additions and 17 deletions

View File

@@ -61,6 +61,8 @@ _RETRY_TRIES = 10 if INDEXING_ONLY else 2
_OPENAI_MAX_INPUT_LEN = 2048
# Cohere allows up to 96 embeddings in a single embedding call
_COHERE_MAX_INPUT_LEN = 96
# gemini-embedding-001 max batch size is 1
_GEMINI_EMBEDDING_MAX_BATCH_SIZE = 1
# Authentication error string constants
_AUTH_ERROR_401 = "401"
@@ -215,7 +217,11 @@ class CloudEmbedding:
return embeddings
async def _embed_vertex(
self, texts: list[str], model: str | None, embedding_type: str
self,
texts: list[str],
model: str | None,
embedding_type: str,
reduced_dimension: int | None = None,
) -> list[Embedding]:
if not model:
model = DEFAULT_VERTEX_MODEL
@@ -229,17 +235,29 @@ class CloudEmbedding:
inputs = [TextEmbeddingInput(text, embedding_type) for text in texts]
# Split into batches of 25 texts
max_texts_per_batch = VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
# gemini-embedding-001 max batch size is 1. otherwise use default
max_texts_per_batch = (
_GEMINI_EMBEDDING_MAX_BATCH_SIZE
if "gemini-embedding-001" in model
else VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
)
batches = [
inputs[i : i + max_texts_per_batch]
for i in range(0, len(inputs), max_texts_per_batch)
]
# Dispatch all embedding calls asynchronously at once
tasks = [
client.get_embeddings_async(batch, auto_truncate=True) for batch in batches
]
tasks = []
for batch in batches:
# Only pass output_dimensionality if reduced_dimension is specified
if reduced_dimension is not None:
task = client.get_embeddings_async(
batch, auto_truncate=True, output_dimensionality=reduced_dimension
)
else:
task = client.get_embeddings_async(batch, auto_truncate=True)
tasks.append(task)
# Wait for all tasks to complete in parallel
results = await asyncio.gather(*tasks)
@@ -295,7 +313,9 @@ class CloudEmbedding:
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
return await self._embed_vertex(
texts, model_name, embedding_type, reduced_dimension
)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except openai.AuthenticationError:

View File

@@ -61,6 +61,11 @@ _BASE_EMBEDDING_MODELS = [
dim=1536,
index_name="danswer_chunk_text_embedding_3_small",
),
_BaseEmbeddingModel(
name="google/gemini-embedding-001",
dim=3072,
index_name="danswer_chunk_google_gemini_embedding_001",
),
_BaseEmbeddingModel(
name="google/text-embedding-005",
dim=768,

View File

@@ -1526,7 +1526,7 @@ class SearchSettings(Base):
# section here:
# https://platform.openai.com/docs/guides/embeddings#embedding-models
# If not specified, will just use the model_dim without any reduction.
# NOTE: this is only currently available for OpenAI models
# NOTE: this is currently available for OpenAI and Google embedding models
reduced_dimension: Mapped[int | None] = mapped_column(Integer, nullable=True)
# Mini and Large Chunks (large chunk also checks for model max context)

View File

@@ -33,7 +33,7 @@ class EmbedRequest(BaseModel):
api_version: str | None = None
# allows for the truncation of the vector to a lower dimension
# to reduce memory usage. Currently only supported for OpenAI models.
# to reduce memory usage. Currently supported for OpenAI and Google models.
# will be ignored for other providers.
reduced_dimension: int | None = None

View File

@@ -129,10 +129,14 @@ const AdvancedEmbeddingFormPage = forwardRef<
(value) => value === null || value === undefined || value >= 256
)
.test(
"openai",
"Reduced Dimensions is only supported for OpenAI embedding models",
"supported-dim-reduction-providers",
"Reduced Dimensions is only supported for OpenAI and Google embedding models",
(value) => {
return embeddingProviderType === "openai" || value === null;
return (
embeddingProviderType === "openai" ||
embeddingProviderType === "google" ||
value === null
);
}
),
})}
@@ -206,11 +210,13 @@ const AdvancedEmbeddingFormPage = forwardRef<
value === null || value === undefined || value >= 256
)
.test(
"openai",
"Reduced Dimensions is only supported for OpenAI embedding models",
"supported-dim-reduction-providers",
"Reduced Dimensions is only supported for OpenAI and Google embedding models",
(value) => {
return (
embeddingProviderType === "openai" || value === null
embeddingProviderType === "openai" ||
embeddingProviderType === "google" ||
value === null
);
}
),
@@ -354,7 +360,7 @@ const AdvancedEmbeddingFormPage = forwardRef<
description="Number of dimensions to reduce the embedding to.
Will reduce memory usage but may reduce accuracy.
If not specified, will just use the selected model's default dimensionality without any reduction.
Currently only supported for OpenAI embedding models"
Currently supported for OpenAI and Google embedding models"
optional={true}
label="Reduced Dimension"
name="reduced_dimension"

View File

@@ -265,10 +265,24 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
apiLink: "https://console.cloud.google.com/apis/credentials",
costslink: "https://cloud.google.com/vertex-ai/pricing",
embedding_models: [
{
provider_type: EmbeddingProvider.GOOGLE,
model_name: "gemini-embedding-001",
description:
'Google\'s latest text embedding model. Default is 3072 dimensions. Configure dimension reduction in the "Advanced" tab. Recommended smaller dimensions are 1536 or 768.',
pricePerMillion: 0.15,
model_dim: 3072,
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
api_url: null,
},
{
provider_type: EmbeddingProvider.GOOGLE,
model_name: "text-embedding-005",
description: "Google's most recent text embedding model.",
description: "Google's previous generation text embedding model.",
pricePerMillion: 0.025,
model_dim: 768,
normalize: false,