CodeReview/backend/app/services/llm/prompt_cache.py

334 lines
10 KiB
Python
Raw Normal View History

"""
Prompt Caching 模块
为支持缓存的 LLM Anthropic Claude提供 Prompt 缓存功能
通过在系统提示词和早期对话中添加缓存标记减少重复处理
显著降低 Token 消耗和响应延迟
支持的 LLM:
- Anthropic Claude (claude-3-5-sonnet, claude-3-opus, claude-3-haiku)
- OpenAI (部分模型支持)
缓存策略:
- 短对话<10: 仅缓存系统提示词
- 中等对话10-30: 缓存系统提示词 + 前5轮对话
- 长对话>30: 多个缓存点动态调整
"""
import logging
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass, field
from enum import Enum
logger = logging.getLogger(__name__)
class CacheStrategy(str, Enum):
"""缓存策略"""
NONE = "none" # 不缓存
SYSTEM_ONLY = "system_only" # 仅缓存系统提示词
SYSTEM_AND_EARLY = "system_early" # 缓存系统提示词和早期对话
MULTI_POINT = "multi_point" # 多缓存点
@dataclass
class CacheConfig:
"""缓存配置"""
enabled: bool = True
strategy: CacheStrategy = CacheStrategy.SYSTEM_AND_EARLY
# 缓存阈值
min_system_prompt_tokens: int = 1000 # 系统提示词最小 token 数才启用缓存
early_messages_count: int = 5 # 早期对话缓存的消息数
# 多缓存点配置
multi_point_interval: int = 10 # 多缓存点间隔(消息数)
max_cache_points: int = 4 # 最大缓存点数量
@dataclass
class CacheStats:
"""缓存统计"""
cache_hits: int = 0
cache_misses: int = 0
cached_tokens: int = 0
total_tokens: int = 0
@property
def hit_rate(self) -> float:
total = self.cache_hits + self.cache_misses
return self.cache_hits / total if total > 0 else 0.0
@property
def token_savings(self) -> float:
return self.cached_tokens / self.total_tokens if self.total_tokens > 0 else 0.0
class PromptCacheManager:
"""
Prompt 缓存管理器
负责:
1. 检测 LLM 是否支持缓存
2. 根据对话长度选择缓存策略
3. 为消息添加缓存标记
4. 统计缓存效果
"""
# 支持缓存的模型
CACHEABLE_MODELS = {
# Anthropic Claude
"claude-3-5-sonnet": True,
"claude-3-5-sonnet-20241022": True,
"claude-3-opus": True,
"claude-3-opus-20240229": True,
"claude-3-haiku": True,
"claude-3-haiku-20240307": True,
"claude-3-sonnet": True,
"claude-3-sonnet-20240229": True,
# OpenAI (部分支持)
"gpt-4-turbo": False, # 暂不支持
"gpt-4o": False,
"gpt-4o-mini": False,
}
# Anthropic 缓存标记
ANTHROPIC_CACHE_CONTROL = {"type": "ephemeral"}
def __init__(self, config: Optional[CacheConfig] = None):
self.config = config or CacheConfig()
self.stats = CacheStats()
self._cache_enabled_for_session = True
def supports_caching(self, model: str, provider: str) -> bool:
"""
检查模型是否支持缓存
Args:
model: 模型名称
provider: 提供商名称
Returns:
是否支持缓存
"""
if not self.config.enabled:
return False
# Anthropic Claude 支持缓存
if provider.lower() in ["anthropic", "claude"]:
# 检查模型名称
for cacheable_model in self.CACHEABLE_MODELS:
if cacheable_model in model.lower():
return self.CACHEABLE_MODELS.get(cacheable_model, False)
return False
def determine_strategy(
self,
messages: List[Dict[str, Any]],
system_prompt_tokens: int = 0,
) -> CacheStrategy:
"""
根据对话状态确定缓存策略
Args:
messages: 消息列表
system_prompt_tokens: 系统提示词的 token
Returns:
缓存策略
"""
if not self.config.enabled:
return CacheStrategy.NONE
# 系统提示词太短,不值得缓存
if system_prompt_tokens < self.config.min_system_prompt_tokens:
return CacheStrategy.NONE
message_count = len(messages)
# 短对话:仅缓存系统提示词
if message_count < 10:
return CacheStrategy.SYSTEM_ONLY
# 中等对话:缓存系统提示词和早期对话
if message_count < 30:
return CacheStrategy.SYSTEM_AND_EARLY
# 长对话:多缓存点
return CacheStrategy.MULTI_POINT
def add_cache_markers_anthropic(
self,
messages: List[Dict[str, Any]],
strategy: CacheStrategy,
) -> List[Dict[str, Any]]:
"""
Anthropic Claude 消息添加缓存标记
Anthropic 的缓存格式:
- content 中使用 cache_control 字段
- 支持 text 类型的 content block
Args:
messages: 原始消息列表
strategy: 缓存策略
Returns:
添加了缓存标记的消息列表
"""
if strategy == CacheStrategy.NONE:
return messages
cached_messages = []
for i, msg in enumerate(messages):
new_msg = msg.copy()
# 系统提示词缓存
if msg.get("role") == "system":
new_msg = self._add_cache_to_message(new_msg)
cached_messages.append(new_msg)
continue
# 早期对话缓存
if strategy in [CacheStrategy.SYSTEM_AND_EARLY, CacheStrategy.MULTI_POINT]:
if i <= self.config.early_messages_count:
new_msg = self._add_cache_to_message(new_msg)
# 多缓存点
if strategy == CacheStrategy.MULTI_POINT:
if i > 0 and i % self.config.multi_point_interval == 0:
cache_point_count = i // self.config.multi_point_interval
if cache_point_count <= self.config.max_cache_points:
new_msg = self._add_cache_to_message(new_msg)
cached_messages.append(new_msg)
return cached_messages
def _add_cache_to_message(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""
为单条消息添加缓存标记
Args:
msg: 原始消息
Returns:
添加了缓存标记的消息
"""
content = msg.get("content", "")
# 如果 content 是字符串,转换为 content block 格式
if isinstance(content, str):
msg["content"] = [
{
"type": "text",
"text": content,
"cache_control": self.ANTHROPIC_CACHE_CONTROL,
}
]
elif isinstance(content, list):
# 已经是 content block 格式,为最后一个 block 添加缓存
if content:
last_block = content[-1]
if isinstance(last_block, dict):
last_block["cache_control"] = self.ANTHROPIC_CACHE_CONTROL
return msg
def process_messages(
self,
messages: List[Dict[str, Any]],
model: str,
provider: str,
system_prompt_tokens: int = 0,
) -> Tuple[List[Dict[str, Any]], bool]:
"""
处理消息添加缓存标记
Args:
messages: 原始消息列表
model: 模型名称
provider: 提供商名称
system_prompt_tokens: 系统提示词 token
Returns:
(处理后的消息列表, 是否启用了缓存)
"""
if not self.supports_caching(model, provider):
return messages, False
strategy = self.determine_strategy(messages, system_prompt_tokens)
if strategy == CacheStrategy.NONE:
return messages, False
# 根据提供商选择缓存方法
if provider.lower() in ["anthropic", "claude"]:
cached_messages = self.add_cache_markers_anthropic(messages, strategy)
logger.debug(f"Applied {strategy.value} caching strategy for Anthropic")
return cached_messages, True
return messages, False
def update_stats(
self,
cache_creation_input_tokens: int = 0,
cache_read_input_tokens: int = 0,
total_input_tokens: int = 0,
):
"""
更新缓存统计
Args:
cache_creation_input_tokens: 缓存创建的 token
cache_read_input_tokens: 缓存读取的 token
total_input_tokens: 总输入 token
"""
if cache_read_input_tokens > 0:
self.stats.cache_hits += 1
self.stats.cached_tokens += cache_read_input_tokens
else:
self.stats.cache_misses += 1
self.stats.total_tokens += total_input_tokens
def get_stats_summary(self) -> Dict[str, Any]:
"""获取缓存统计摘要"""
return {
"cache_hits": self.stats.cache_hits,
"cache_misses": self.stats.cache_misses,
"hit_rate": f"{self.stats.hit_rate:.2%}",
"cached_tokens": self.stats.cached_tokens,
"total_tokens": self.stats.total_tokens,
"token_savings": f"{self.stats.token_savings:.2%}",
}
# 全局缓存管理器实例
prompt_cache_manager = PromptCacheManager()
def estimate_tokens(text: str) -> int:
"""
估算文本的 token 数量
简单估算英文约 4 字符/token中文约 2 字符/token
Args:
text: 文本内容
Returns:
估算的 token
"""
if not text:
return 0
# 统计中文字符
chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
other_chars = len(text) - chinese_chars
# 中文约 2 字符/token其他约 4 字符/token
return int(chinese_chars / 2 + other_chars / 4)