mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-09 00:42:47 +00:00
Compare commits
2 Commits
cli/v0.2.1
...
feat/promp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
17dc64148d | ||
|
|
1c27d414f1 |
34
backend/onyx/llm/prompt_cache/__init__.py
Normal file
34
backend/onyx/llm/prompt_cache/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Prompt caching framework for LLM providers.
|
||||
|
||||
This module provides a framework for enabling prompt caching across different
|
||||
LLM providers. It supports both implicit caching (automatic provider-side caching)
|
||||
and explicit caching (with cache metadata management).
|
||||
"""
|
||||
|
||||
from onyx.llm.prompt_cache.cache_manager import CacheManager
|
||||
from onyx.llm.prompt_cache.cache_manager import generate_cache_key_hash
|
||||
from onyx.llm.prompt_cache.interfaces import CacheMetadata
|
||||
from onyx.llm.prompt_cache.providers.anthropic import AnthropicPromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.base import PromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.factory import get_provider_adapter
|
||||
from onyx.llm.prompt_cache.providers.noop import NoOpPromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.openai import OpenAIPromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.vertex import VertexAIPromptCacheProvider
|
||||
from onyx.llm.prompt_cache.utils import combine_messages_with_continuation
|
||||
from onyx.llm.prompt_cache.utils import normalize_language_model_input
|
||||
from onyx.llm.prompt_cache.utils import prepare_messages_with_cacheable_transform
|
||||
|
||||
__all__ = [
|
||||
"AnthropicPromptCacheProvider",
|
||||
"CacheManager",
|
||||
"CacheMetadata",
|
||||
"combine_messages_with_continuation",
|
||||
"generate_cache_key_hash",
|
||||
"get_provider_adapter",
|
||||
"normalize_language_model_input",
|
||||
"NoOpPromptCacheProvider",
|
||||
"OpenAIPromptCacheProvider",
|
||||
"prepare_messages_with_cacheable_transform",
|
||||
"PromptCacheProvider",
|
||||
"VertexAIPromptCacheProvider",
|
||||
]
|
||||
199
backend/onyx/llm/prompt_cache/cache_manager.py
Normal file
199
backend/onyx/llm/prompt_cache/cache_manager.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Cache manager for storing and retrieving prompt cache metadata."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from onyx.key_value_store.store import PgRedisKVStore
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.prompt_cache.interfaces import CacheMetadata
|
||||
from onyx.llm.prompt_cache.utils import normalize_language_model_input
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
REDIS_KEY_PREFIX = "prompt_cache:"
|
||||
# Cache TTL multiplier - store caches slightly longer than provider TTL
|
||||
# This allows for some clock skew and ensures we don't lose cache metadata prematurely
|
||||
CACHE_TTL_MULTIPLIER = 1.2
|
||||
|
||||
|
||||
class CacheManager:
    """Manages storage and retrieval of prompt cache metadata.

    All operations are best-effort: failures are logged and swallowed so a
    caching problem never breaks the calling LLM request.
    """

    def __init__(self, kv_store: PgRedisKVStore | None = None) -> None:
        """Initialize the cache manager.

        Args:
            kv_store: Optional key-value store. If None, creates a new PgRedisKVStore.
        """
        self._kv_store = kv_store or PgRedisKVStore()

    def _build_cache_key(
        self,
        provider: str,
        model_name: str,
        cache_key_hash: str,
        tenant_id: str | None = None,
    ) -> str:
        """Build a Redis/PostgreSQL key for cache metadata.

        Args:
            provider: LLM provider name (e.g., "openai", "anthropic")
            model_name: Model name
            cache_key_hash: Hash of the cacheable prefix content
            tenant_id: Tenant ID. If None, uses current tenant from context.

        Returns:
            Cache key string
        """
        if tenant_id is None:
            tenant_id = get_current_tenant_id()
        return f"{REDIS_KEY_PREFIX}{tenant_id}:{provider}:{model_name}:{cache_key_hash}"

    def store_cache_metadata(
        self,
        metadata: CacheMetadata,
        ttl_seconds: int | None = None,
    ) -> None:
        """Store cache metadata (best-effort).

        Args:
            metadata: Cache metadata to store
            ttl_seconds: Optional TTL in seconds. If None, uses provider default.
                NOTE(review): currently unused — PgRedisKVStore exposes no TTL,
                so expiry relies on cleanup based on last_accessed.
        """
        try:
            cache_key = self._build_cache_key(
                metadata.provider,
                metadata.model_name,
                metadata.cache_key,
                metadata.tenant_id,
            )

            # Stamp access time so last_accessed-based cleanup stays accurate.
            metadata.last_accessed = datetime.now(timezone.utc)

            # Serialize to a JSON-compatible dict for the KV store.
            metadata_dict = metadata.model_dump(mode="json")

            # Note: PgRedisKVStore doesn't support TTL directly, but Redis will
            # handle expiration. For PostgreSQL persistence, we rely on cleanup
            # based on last_accessed timestamp.
            self._kv_store.store(cache_key, metadata_dict, encrypt=False)

            logger.debug(
                f"Stored cache metadata for provider={metadata.provider}, "
                f"model={metadata.model_name}, cache_key={metadata.cache_key[:16]}..."
            )
        except Exception as e:
            # Best-effort: log and continue
            logger.warning(f"Failed to store cache metadata: {str(e)}")

    def retrieve_cache_metadata(
        self,
        provider: str,
        model_name: str,
        cache_key_hash: str,
        tenant_id: str | None = None,
    ) -> CacheMetadata | None:
        """Retrieve cache metadata (best-effort).

        Args:
            provider: LLM provider name
            model_name: Model name
            cache_key_hash: Hash of the cacheable prefix content
            tenant_id: Tenant ID. If None, uses current tenant from context.

        Returns:
            CacheMetadata if found, None otherwise
        """
        try:
            cache_key = self._build_cache_key(
                provider, model_name, cache_key_hash, tenant_id
            )
            metadata_dict = self._kv_store.load(cache_key, refresh_cache=False)
            if metadata_dict is None:
                # Explicit miss — don't route a plain "not found" through the
                # exception handler below. (A store that raises on a missing
                # key still lands in the except branch.)
                return None

            metadata = CacheMetadata.model_validate(metadata_dict)

            # Refresh last_accessed on read; store_cache_metadata stamps the
            # timestamp itself, so setting it here as well would be redundant.
            self.store_cache_metadata(metadata)

            logger.debug(
                f"Retrieved cache metadata for provider={provider}, "
                f"model={model_name}, cache_key={cache_key_hash[:16]}..."
            )
            return metadata
        except Exception as e:
            # Best-effort: log and continue
            logger.debug(f"Cache metadata not found or error retrieving: {str(e)}")
            return None

    def delete_cache_metadata(
        self,
        provider: str,
        model_name: str,
        cache_key_hash: str,
        tenant_id: str | None = None,
    ) -> None:
        """Delete cache metadata (best-effort).

        Args:
            provider: LLM provider name
            model_name: Model name
            cache_key_hash: Hash of the cacheable prefix content
            tenant_id: Tenant ID. If None, uses current tenant from context.
        """
        try:
            cache_key = self._build_cache_key(
                provider, model_name, cache_key_hash, tenant_id
            )
            self._kv_store.delete(cache_key)
            logger.debug(
                f"Deleted cache metadata for provider={provider}, "
                f"model={model_name}, cache_key={cache_key_hash[:16]}..."
            )
        except Exception as e:
            # Best-effort: log and continue
            logger.warning(f"Failed to delete cache metadata: {str(e)}")
|
||||
|
||||
def generate_cache_key_hash(
    cacheable_prefix: LanguageModelInput,
    provider: str,
    model_name: str,
    tenant_id: str,
) -> str:
    """Generate a deterministic cache key hash from a cacheable prefix.

    Args:
        cacheable_prefix: LanguageModelInput (str or Sequence[ChatCompletionMessage])
        provider: LLM provider name
        model_name: Model name
        tenant_id: Tenant ID

    Returns:
        SHA256 hash as hex string
    """
    # Normalize to a message sequence so string and message inputs that
    # represent the same content hash identically.
    normalized = normalize_language_model_input(cacheable_prefix)

    # Canonical payload: content, roles, and order only — no timestamps or
    # other dynamic fields.
    payload = {
        "messages": [dict(message) for message in normalized],
        "provider": provider,
        "model": model_name,
        "tenant_id": tenant_id,
    }

    # Deterministic serialization: sorted keys, no whitespace.
    canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
20
backend/onyx/llm/prompt_cache/interfaces.py
Normal file
20
backend/onyx/llm/prompt_cache/interfaces.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Interfaces and data structures for prompt caching."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CacheMetadata(BaseModel):
|
||||
"""Metadata for cached prompt prefixes."""
|
||||
|
||||
cache_key: str
|
||||
provider: str
|
||||
model_name: str
|
||||
tenant_id: str
|
||||
created_at: datetime
|
||||
last_accessed: datetime
|
||||
# Provider-specific metadata
|
||||
# TODO: Add explicit caching support in future PR
|
||||
# vertex_block_numbers: dict[str, str] | None = None # message_hash -> block_number
|
||||
# anthropic_cache_id: str | None = None
|
||||
17
backend/onyx/llm/prompt_cache/providers/__init__.py
Normal file
17
backend/onyx/llm/prompt_cache/providers/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Provider adapters for prompt caching."""
|
||||
|
||||
from onyx.llm.prompt_cache.providers.anthropic import AnthropicPromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.base import PromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.factory import get_provider_adapter
|
||||
from onyx.llm.prompt_cache.providers.noop import NoOpPromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.openai import OpenAIPromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.vertex import VertexAIPromptCacheProvider
|
||||
|
||||
__all__ = [
|
||||
"AnthropicPromptCacheProvider",
|
||||
"get_provider_adapter",
|
||||
"NoOpPromptCacheProvider",
|
||||
"OpenAIPromptCacheProvider",
|
||||
"PromptCacheProvider",
|
||||
"VertexAIPromptCacheProvider",
|
||||
]
|
||||
90
backend/onyx/llm/prompt_cache/providers/anthropic.py
Normal file
90
backend/onyx/llm/prompt_cache/providers/anthropic.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Anthropic provider adapter for prompt caching."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.message_types import ChatCompletionMessage
|
||||
from onyx.llm.prompt_cache.interfaces import CacheMetadata
|
||||
from onyx.llm.prompt_cache.providers.base import PromptCacheProvider
|
||||
from onyx.llm.prompt_cache.utils import prepare_messages_with_cacheable_transform
|
||||
|
||||
|
||||
def _add_anthropic_cache_control(
|
||||
messages: Sequence[ChatCompletionMessage],
|
||||
) -> Sequence[ChatCompletionMessage]:
|
||||
"""Add cache_control parameter to messages for Anthropic caching.
|
||||
|
||||
Args:
|
||||
messages: Messages to transform
|
||||
|
||||
Returns:
|
||||
Messages with cache_control added
|
||||
"""
|
||||
cacheable_messages: list[ChatCompletionMessage] = []
|
||||
for msg in messages:
|
||||
msg_dict = dict(msg)
|
||||
# Add cache_control parameter
|
||||
# Anthropic supports up to 4 cache breakpoints
|
||||
msg_dict["cache_control"] = {"type": "ephemeral"}
|
||||
cacheable_messages.append(msg_dict) # type: ignore
|
||||
return cacheable_messages
|
||||
|
||||
|
||||
class AnthropicPromptCacheProvider(PromptCacheProvider):
    """Prompt-cache adapter for Anthropic (explicit caching via cache_control)."""

    def supports_caching(self) -> bool:
        """Return True: Anthropic offers explicit prompt caching."""
        return True

    def prepare_messages_for_caching(
        self,
        cacheable_prefix: LanguageModelInput | None,
        suffix: LanguageModelInput,
        continuation: bool,
        cache_metadata: CacheMetadata | None,
    ) -> LanguageModelInput:
        """Combine prefix and suffix, tagging the cacheable prefix.

        Anthropic needs a cache_control parameter on cacheable content; the
        transform applies cache_control={"type": "ephemeral"} to the prefix
        messages before they are combined with the suffix.

        Args:
            cacheable_prefix: Optional cacheable prefix
            suffix: Non-cacheable suffix
            continuation: Whether to append suffix to last prefix message
            cache_metadata: Cache metadata (for future explicit caching support)

        Returns:
            Combined messages with cache_control applied to the prefix
        """
        return prepare_messages_with_cacheable_transform(
            cacheable_prefix=cacheable_prefix,
            suffix=suffix,
            continuation=continuation,
            transform_cacheable=_add_anthropic_cache_control,
        )

    def extract_cache_metadata(
        self,
        response: dict,
        cache_key: str,
    ) -> CacheMetadata | None:
        """Return None for now; metadata extraction is a future-PR TODO.

        Args:
            response: Anthropic API response dictionary
            cache_key: Cache key used for this request

        Returns:
            Always None until explicit caching is implemented
        """
        # TODO: pull cache identifiers out of the response once explicit
        # caching lands.
        return None

    def get_cache_ttl_seconds(self) -> int:
        """Anthropic's default ephemeral cache TTL is 5 minutes."""
        return 300
70
backend/onyx/llm/prompt_cache/providers/base.py
Normal file
70
backend/onyx/llm/prompt_cache/providers/base.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Base interface for provider-specific prompt caching adapters."""
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.prompt_cache.interfaces import CacheMetadata
|
||||
|
||||
|
||||
class PromptCacheProvider(ABC):
    """Abstract base class for provider-specific prompt caching logic.

    Concrete adapters (OpenAI, Anthropic, Vertex AI, no-op) implement the
    message transformation and metadata handling appropriate for their API.
    """

    @abstractmethod
    def supports_caching(self) -> bool:
        """Whether this provider supports prompt caching.

        Returns:
            True if caching is supported, False otherwise
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_messages_for_caching(
        self,
        cacheable_prefix: LanguageModelInput | None,
        suffix: LanguageModelInput,
        continuation: bool,
        cache_metadata: CacheMetadata | None,
    ) -> LanguageModelInput:
        """Transform messages to enable caching.

        Args:
            cacheable_prefix: Optional cacheable prefix (can be str or Sequence[ChatCompletionMessage])
            suffix: Non-cacheable suffix (can be str or Sequence[ChatCompletionMessage])
            continuation: If True, suffix should be appended to the last message
                of cacheable_prefix rather than being separate messages.
                Note: When cacheable_prefix is a string, it should remain in its own
                content block even if continuation=True.
            cache_metadata: Optional cache metadata from previous requests

        Returns:
            Combined and transformed messages ready for LLM API call
        """
        raise NotImplementedError

    @abstractmethod
    def extract_cache_metadata(
        self,
        response: dict,  # Provider-specific response object
        cache_key: str,
    ) -> CacheMetadata | None:
        """Extract cache metadata from API response.

        Args:
            response: Provider-specific response dictionary
            cache_key: Cache key used for this request

        Returns:
            CacheMetadata if extractable, None otherwise
        """
        raise NotImplementedError

    @abstractmethod
    def get_cache_ttl_seconds(self) -> int:
        """Get cache TTL in seconds for this provider.

        Returns:
            TTL in seconds
        """
        raise NotImplementedError
30
backend/onyx/llm/prompt_cache/providers/factory.py
Normal file
30
backend/onyx/llm/prompt_cache/providers/factory.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Factory for creating provider-specific prompt cache adapters."""
|
||||
|
||||
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
|
||||
from onyx.llm.llm_provider_options import VERTEXAI_PROVIDER_NAME
|
||||
from onyx.llm.prompt_cache.providers.anthropic import AnthropicPromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.base import PromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.noop import NoOpPromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.openai import OpenAIPromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.vertex import VertexAIPromptCacheProvider
|
||||
|
||||
|
||||
def get_provider_adapter(provider: str) -> PromptCacheProvider:
    """Get the appropriate prompt cache provider adapter for a given provider.

    Args:
        provider: Provider name (e.g., "openai", "anthropic", "vertex_ai")

    Returns:
        PromptCacheProvider instance for the given provider
    """
    # Dispatch table instead of an if/elif chain; providers without caching
    # support fall back to the no-op adapter.
    adapter_classes: dict[str, type[PromptCacheProvider]] = {
        OPENAI_PROVIDER_NAME: OpenAIPromptCacheProvider,
        ANTHROPIC_PROVIDER_NAME: AnthropicPromptCacheProvider,
        VERTEXAI_PROVIDER_NAME: VertexAIPromptCacheProvider,
    }
    adapter_cls = adapter_classes.get(provider, NoOpPromptCacheProvider)
    return adapter_cls()
||||
53
backend/onyx/llm/prompt_cache/providers/noop.py
Normal file
53
backend/onyx/llm/prompt_cache/providers/noop.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""No-op provider adapter for providers without caching support."""
|
||||
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.prompt_cache.interfaces import CacheMetadata
|
||||
from onyx.llm.prompt_cache.providers.base import PromptCacheProvider
|
||||
from onyx.llm.prompt_cache.utils import prepare_messages_with_cacheable_transform
|
||||
|
||||
|
||||
class NoOpPromptCacheProvider(PromptCacheProvider):
    """Fallback adapter for providers with no prompt-caching support."""

    def supports_caching(self) -> bool:
        """Return False: this adapter never caches."""
        return False

    def prepare_messages_for_caching(
        self,
        cacheable_prefix: LanguageModelInput | None,
        suffix: LanguageModelInput,
        continuation: bool,
        cache_metadata: CacheMetadata | None,
    ) -> LanguageModelInput:
        """Combine prefix and suffix without any caching transformation.

        Args:
            cacheable_prefix: Optional cacheable prefix (can be str or Sequence[ChatCompletionMessage])
            suffix: Non-cacheable suffix (can be str or Sequence[ChatCompletionMessage])
            continuation: Whether to append suffix to last prefix message.
                Note: When cacheable_prefix is a string, it remains in its own content block.
            cache_metadata: Cache metadata (ignored)

        Returns:
            Combined messages (prefix + suffix), otherwise untouched
        """
        # transform_cacheable=None: pass messages through unmodified.
        return prepare_messages_with_cacheable_transform(
            cacheable_prefix=cacheable_prefix,
            suffix=suffix,
            continuation=continuation,
            transform_cacheable=None,
        )

    def extract_cache_metadata(
        self,
        response: dict,
        cache_key: str,
    ) -> CacheMetadata | None:
        """Nothing to extract for a non-caching provider."""
        return None

    def get_cache_ttl_seconds(self) -> int:
        """TTL is meaningless without caching; return 0."""
        return 0
||||
69
backend/onyx/llm/prompt_cache/providers/openai.py
Normal file
69
backend/onyx/llm/prompt_cache/providers/openai.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""OpenAI provider adapter for prompt caching."""
|
||||
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.prompt_cache.interfaces import CacheMetadata
|
||||
from onyx.llm.prompt_cache.providers.base import PromptCacheProvider
|
||||
from onyx.llm.prompt_cache.utils import prepare_messages_with_cacheable_transform
|
||||
|
||||
|
||||
class OpenAIPromptCacheProvider(PromptCacheProvider):
    """Prompt-cache adapter for OpenAI (implicit, provider-side caching)."""

    def supports_caching(self) -> bool:
        """Return True: OpenAI caches prompt prefixes automatically."""
        return True

    def prepare_messages_for_caching(
        self,
        cacheable_prefix: LanguageModelInput | None,
        suffix: LanguageModelInput,
        continuation: bool,
        cache_metadata: CacheMetadata | None,
    ) -> LanguageModelInput:
        """Normalize and combine messages; OpenAI needs no special markers.

        Caching happens automatically on the provider side for prefixes
        longer than 1024 tokens, so no message transformation is required.

        Args:
            cacheable_prefix: Optional cacheable prefix
            suffix: Non-cacheable suffix
            continuation: Whether to append suffix to last prefix message
            cache_metadata: Ignored for implicit caching

        Returns:
            Combined messages ready for the LLM API call
        """
        # transform_cacheable=None: implicit caching needs no markers.
        return prepare_messages_with_cacheable_transform(
            cacheable_prefix=cacheable_prefix,
            suffix=suffix,
            continuation=continuation,
            transform_cacheable=None,
        )

    def extract_cache_metadata(
        self,
        response: dict,
        cache_key: str,
    ) -> CacheMetadata | None:
        """Return None: implicit caching needs no stored metadata.

        The response's usage field may carry cached_tokens (e.g.
        response.get("usage", {}).get("cached_tokens")), but nothing needs
        to be persisted for OpenAI's automatic caching.

        Args:
            response: OpenAI API response dictionary
            cache_key: Cache key used for this request

        Returns:
            Always None
        """
        return None

    def get_cache_ttl_seconds(self) -> int:
        """OpenAI cache entries live up to one hour."""
        return 3600
||||
80
backend/onyx/llm/prompt_cache/providers/vertex.py
Normal file
80
backend/onyx/llm/prompt_cache/providers/vertex.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Vertex AI provider adapter for prompt caching."""
|
||||
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.prompt_cache.interfaces import CacheMetadata
|
||||
from onyx.llm.prompt_cache.providers.base import PromptCacheProvider
|
||||
from onyx.llm.prompt_cache.utils import prepare_messages_with_cacheable_transform
|
||||
|
||||
|
||||
class VertexAIPromptCacheProvider(PromptCacheProvider):
    """Prompt-cache adapter for Vertex AI (implicit caching only, for now)."""

    def supports_caching(self) -> bool:
        """Return True: Vertex AI offers both implicit and explicit caching."""
        return True

    def prepare_messages_for_caching(
        self,
        cacheable_prefix: LanguageModelInput | None,
        suffix: LanguageModelInput,
        continuation: bool,
        cache_metadata: CacheMetadata | None,
    ) -> LanguageModelInput:
        """Normalize and combine messages for Vertex AI.

        Only implicit caching is implemented here (automatic, like OpenAI),
        so the messages are combined without transformation.

        TODO (explicit caching - future PR):
        - If cache_metadata exists and has vertex_block_numbers: Replace message
          content with {"cache_block_id": "<block_number>"}
        - If not: Add cache_control={"type": "ephemeral"} to cacheable messages

        Args:
            cacheable_prefix: Optional cacheable prefix
            suffix: Non-cacheable suffix
            continuation: Whether to append suffix to last prefix message
            cache_metadata: Cache metadata (for future explicit caching support)

        Returns:
            Combined messages ready for LLM API call
        """
        # Implicit caching: Vertex handles caching automatically, so no
        # transform is applied.
        # TODO (explicit caching - future PR): inspect cache_metadata for
        # vertex_block_numbers and either substitute cache_block_id references
        # or add a cache_control parameter.
        return prepare_messages_with_cacheable_transform(
            cacheable_prefix=cacheable_prefix,
            suffix=suffix,
            continuation=continuation,
            transform_cacheable=None,
        )

    def extract_cache_metadata(
        self,
        response: dict,
        cache_key: str,
    ) -> CacheMetadata | None:
        """Extract cache metadata from a Vertex AI response.

        Implicit caching requires no stored metadata, so this returns None.
        TODO (explicit caching - future PR): extract cache block numbers from
        the response and store them in cache_metadata.vertex_block_numbers.

        Args:
            response: Vertex AI API response dictionary
            cache_key: Cache key used for this request

        Returns:
            Always None until explicit caching is implemented
        """
        return None

    def get_cache_ttl_seconds(self) -> int:
        """Vertex AI cache TTL is 5 minutes."""
        return 300
||||
124
backend/onyx/llm/prompt_cache/utils.py
Normal file
124
backend/onyx/llm/prompt_cache/utils.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Utility functions for prompt caching."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.message_types import ChatCompletionMessage
|
||||
from onyx.llm.message_types import UserMessageWithText
|
||||
|
||||
|
||||
def normalize_language_model_input(
    input: LanguageModelInput,
) -> Sequence[ChatCompletionMessage]:
    """Coerce a LanguageModelInput into a sequence of chat messages.

    Args:
        input: LanguageModelInput (str or Sequence[ChatCompletionMessage])

    Returns:
        Sequence of ChatCompletionMessage objects
    """
    # Message sequences pass through untouched; a bare string becomes a
    # single user message.
    if not isinstance(input, str):
        return input
    return [UserMessageWithText(role="user", content=input)]
||||
|
||||
|
||||
def combine_messages_with_continuation(
    prefix_msgs: Sequence[ChatCompletionMessage],
    suffix_msgs: Sequence[ChatCompletionMessage],
    continuation: bool,
    was_prefix_string: bool,
) -> Sequence[ChatCompletionMessage]:
    """Concatenate prefix and suffix messages, optionally merging at the seam.

    Args:
        prefix_msgs: Normalized cacheable prefix messages
        suffix_msgs: Normalized suffix messages
        continuation: If True and prefix is not a string, append suffix content
            to the last message of prefix
        was_prefix_string: Whether the original prefix was a string (strings
            remain in their own content block even if continuation=True)

    Returns:
        Combined messages
    """
    merge_at_seam = continuation and bool(prefix_msgs) and not was_prefix_string
    if not merge_at_seam:
        # Plain concatenation (also used when the prefix came from a string,
        # which must stay in its own content block).
        return list(prefix_msgs) + list(suffix_msgs)

    combined = list(prefix_msgs)
    seam_msg = dict(combined[-1])
    first_suffix = dict(suffix_msgs[0]) if suffix_msgs else {}

    if "content" in seam_msg and "content" in first_suffix:
        seam_content = seam_msg["content"]
        suffix_content = first_suffix["content"]
        if isinstance(seam_content, str) and isinstance(suffix_content, str):
            # Both plain strings: simple text concatenation.
            seam_msg["content"] = seam_content + suffix_content
        else:
            # At least one side is multimodal: coerce each side to a
            # content-part list and join the lists.
            def as_parts(content: object) -> list:
                if isinstance(content, list):
                    return content
                return [{"type": "text", "text": content}]

            seam_msg["content"] = as_parts(seam_content) + as_parts(suffix_content)

    combined[-1] = cast(ChatCompletionMessage, seam_msg)
    # The first suffix message was merged into the seam; append the rest.
    combined.extend(suffix_msgs[1:])
    return combined
||||
|
||||
|
||||
def prepare_messages_with_cacheable_transform(
    cacheable_prefix: LanguageModelInput | None,
    suffix: LanguageModelInput,
    continuation: bool,
    transform_cacheable: (
        Callable[[Sequence[ChatCompletionMessage]], Sequence[ChatCompletionMessage]]
        | None
    ) = None,
) -> LanguageModelInput:
    """Prepare messages for caching with optional transformation of cacheable prefix.

    Shared flow used by all provider adapters:
    1. Normalize inputs
    2. Optionally transform cacheable messages
    3. Combine with continuation handling

    Args:
        cacheable_prefix: Optional cacheable prefix
        suffix: Non-cacheable suffix
        continuation: Whether to append suffix to last prefix message
        transform_cacheable: Optional function to transform cacheable messages
            (e.g., add cache_control parameter). If None, messages are used as-is.

    Returns:
        Combined messages ready for LLM API call
    """
    # Nothing cacheable: the suffix is the whole input, untouched.
    if cacheable_prefix is None:
        return suffix

    normalized_prefix = normalize_language_model_input(cacheable_prefix)
    normalized_suffix = normalize_language_model_input(suffix)

    if transform_cacheable is not None:
        normalized_prefix = transform_cacheable(normalized_prefix)

    # A string prefix keeps its own content block even on continuation.
    return combine_messages_with_continuation(
        normalized_prefix,
        normalized_suffix,
        continuation,
        isinstance(cacheable_prefix, str),
    )
Reference in New Issue
Block a user