334 lines
10 KiB
Python
334 lines
10 KiB
Python
"""
|
||
Prompt Caching 模块
|
||
|
||
为支持缓存的 LLM(如 Anthropic Claude)提供 Prompt 缓存功能。
|
||
通过在系统提示词和早期对话中添加缓存标记,减少重复处理,
|
||
显著降低 Token 消耗和响应延迟。
|
||
|
||
支持的 LLM:
|
||
- Anthropic Claude (claude-3-5-sonnet, claude-3-opus, claude-3-haiku)
|
||
- OpenAI (暂不支持:当前仅对 Anthropic 提供商添加缓存标记)
|
||
|
||
缓存策略:
|
||
- 短对话(<10 条消息): 仅缓存系统提示词
|
||
- 中等对话(10-29 条消息): 缓存系统提示词 + 早期对话(默认前 5 条消息)
|
||
- 长对话(≥30 条消息): 多个缓存点,按固定消息间隔设置
|
||
"""
|
||
|
||
import logging
|
||
from typing import Dict, Any, List, Optional, Tuple
|
||
from dataclasses import dataclass, field
|
||
from enum import Enum
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
class CacheStrategy(str, Enum):
    """Which parts of a conversation receive prompt-cache markers."""

    # No caching at all.
    NONE = "none"
    # Cache only the system prompt.
    SYSTEM_ONLY = "system_only"
    # Cache the system prompt plus the opening messages of the conversation.
    SYSTEM_AND_EARLY = "system_early"
    # Cache at several points spread through the conversation.
    MULTI_POINT = "multi_point"
|
||
|
||
|
||
@dataclass
class CacheConfig:
    """Tunable settings for prompt caching."""

    # Master switch; when False the manager reports no caching support and
    # picks the NONE strategy.
    enabled: bool = field(default=True)
    # NOTE(review): not read by PromptCacheManager in this file; the strategy is
    # chosen from the message count instead. Confirm whether callers use it.
    strategy: CacheStrategy = field(default=CacheStrategy.SYSTEM_AND_EARLY)

    # --- Thresholds ---
    # Minimum system-prompt token count before caching is enabled.
    min_system_prompt_tokens: int = field(default=1000)
    # How many early conversation messages to cache.
    early_messages_count: int = field(default=5)

    # --- Multi-point caching ---
    # Spacing, in messages, between cache points.
    multi_point_interval: int = field(default=10)
    # Upper bound on the number of multi-point cache points.
    max_cache_points: int = field(default=4)
|
||
|
||
|
||
@dataclass
class CacheStats:
    """Running prompt-cache counters and the ratios derived from them."""

    # Hit / miss counters.
    cache_hits: int = 0
    cache_misses: int = 0
    # Accumulated cached tokens and accumulated input tokens.
    cached_tokens: int = 0
    total_tokens: int = 0

    @property
    def hit_rate(self) -> float:
        """Fraction of recorded lookups that were hits; 0.0 when nothing was recorded."""
        lookups = self.cache_hits + self.cache_misses
        if lookups <= 0:
            return 0.0
        return self.cache_hits / lookups

    @property
    def token_savings(self) -> float:
        """Ratio of cached tokens to total tokens; 0.0 when no tokens were recorded."""
        if self.total_tokens <= 0:
            return 0.0
        return self.cached_tokens / self.total_tokens
|
||
|
||
|
||
class PromptCacheManager:
    """
    Prompt cache manager.

    Responsibilities:
    1. Decide whether a provider/model combination supports prompt caching.
    2. Choose a caching strategy from the conversation length.
    3. Attach Anthropic ``cache_control`` markers to messages.
    4. Accumulate cache statistics.

    Anthropic accepts at most ``MAX_CACHE_BREAKPOINTS`` blocks carrying
    ``cache_control`` per request, so marker placement never exceeds that cap.
    """

    # Models known to support prompt caching. ``supports_caching`` matches these
    # keys as substrings of the lower-cased model name; the first match wins.
    CACHEABLE_MODELS = {
        # Anthropic Claude
        "claude-3-5-sonnet": True,
        "claude-3-5-sonnet-20241022": True,
        "claude-3-opus": True,
        "claude-3-opus-20240229": True,
        "claude-3-haiku": True,
        "claude-3-haiku-20240307": True,
        "claude-3-sonnet": True,
        "claude-3-sonnet-20240229": True,
        # OpenAI (not supported yet)
        "gpt-4-turbo": False,
        "gpt-4o": False,
        "gpt-4o-mini": False,
    }

    # Anthropic cache marker. Used as a template: every marked block receives
    # its own copy, so blocks never share (and cannot mutate) one dict.
    ANTHROPIC_CACHE_CONTROL = {"type": "ephemeral"}

    # Anthropic rejects a request carrying more than 4 cache_control breakpoints.
    MAX_CACHE_BREAKPOINTS = 4

    # Message-count thresholds used by ``determine_strategy``.
    SHORT_CONVERSATION_MESSAGES = 10
    MEDIUM_CONVERSATION_MESSAGES = 30

    def __init__(self, config: Optional[CacheConfig] = None):
        """
        Args:
            config: Cache settings; a default ``CacheConfig()`` is used when
                omitted (or falsy).
        """
        self.config = config or CacheConfig()
        self.stats = CacheStats()
        # Not read by any method of this class; kept so external code that
        # references it keeps working.
        self._cache_enabled_for_session = True

    def supports_caching(self, model: str, provider: str) -> bool:
        """
        Check whether ``model`` served by ``provider`` supports prompt caching.

        Only the Anthropic provider (``"anthropic"`` or ``"claude"``,
        case-insensitive) is considered. The model name must contain one of the
        ``CACHEABLE_MODELS`` keys, and that key's value decides the result.

        Args:
            model: Model name; ``None``/empty is treated as unsupported.
            provider: Provider name; ``None``/empty is treated as unsupported.

        Returns:
            True if caching markers can be applied.
        """
        if not self.config.enabled:
            return False

        if (provider or "").lower() not in ("anthropic", "claude"):
            return False

        model_name = (model or "").lower()
        for cacheable_model, supported in self.CACHEABLE_MODELS.items():
            if cacheable_model in model_name:
                return supported

        return False

    def determine_strategy(
        self,
        messages: List[Dict[str, Any]],
        system_prompt_tokens: int = 0,
    ) -> CacheStrategy:
        """
        Pick a caching strategy from the current conversation state.

        Length is counted in messages (every entry of ``messages``, including
        system messages), not in user/assistant turns: fewer than 10 ->
        SYSTEM_ONLY, fewer than 30 -> SYSTEM_AND_EARLY, otherwise MULTI_POINT.

        Args:
            messages: Message list.
            system_prompt_tokens: Token count of the system prompt.

        Returns:
            The chosen strategy; NONE when caching is disabled or the system
            prompt is shorter than ``config.min_system_prompt_tokens``.
        """
        if not self.config.enabled:
            return CacheStrategy.NONE

        # A short system prompt is not worth caching.
        if system_prompt_tokens < self.config.min_system_prompt_tokens:
            return CacheStrategy.NONE

        message_count = len(messages)
        if message_count < self.SHORT_CONVERSATION_MESSAGES:
            return CacheStrategy.SYSTEM_ONLY
        if message_count < self.MEDIUM_CONVERSATION_MESSAGES:
            return CacheStrategy.SYSTEM_AND_EARLY
        return CacheStrategy.MULTI_POINT

    def add_cache_markers_anthropic(
        self,
        messages: List[Dict[str, Any]],
        strategy: CacheStrategy,
    ) -> List[Dict[str, Any]]:
        """
        Attach Anthropic cache markers to a copy of ``messages``.

        Anthropic format: a ``cache_control`` field on a content block; string
        content is converted to a single ``text`` block to carry it. Which
        messages are marked is decided by ``_select_cache_breakpoints``.

        The input list and its message dicts/content blocks are not mutated.

        Args:
            messages: Original message list.
            strategy: Caching strategy.

        Returns:
            The original list unchanged for ``CacheStrategy.NONE``. Otherwise a
            new list of (shallow-copied) messages, some carrying markers.
        """
        if strategy == CacheStrategy.NONE:
            return messages

        breakpoints = set(self._select_cache_breakpoints(messages, strategy))
        return [
            self._add_cache_to_message(msg) if i in breakpoints else msg.copy()
            for i, msg in enumerate(messages)
        ]

    def _select_cache_breakpoints(
        self,
        messages: List[Dict[str, Any]],
        strategy: CacheStrategy,
    ) -> List[int]:
        """
        Return the sorted indices of the messages that should carry a marker.

        - Every system message is a candidate (all strategies).
        - SYSTEM_AND_EARLY / MULTI_POINT: the last non-system message with
          index <= ``config.early_messages_count``. Caching is prefix-based, so
          one marker at the end of the early window already covers all earlier
          messages.
        - MULTI_POINT: non-system messages at multiples of
          ``config.multi_point_interval``, up to ``config.max_cache_points``
          of them (a non-positive interval disables these points).

        The result never exceeds ``MAX_CACHE_BREAKPOINTS``. System messages
        are kept first. The rest of the budget goes to the deepest
        conversation breakpoints, which cache the longest stable prefix.
        """
        system_indices = [
            i for i, msg in enumerate(messages) if msg.get("role") == "system"
        ]

        conversation_indices: List[int] = []
        if strategy in (CacheStrategy.SYSTEM_AND_EARLY, CacheStrategy.MULTI_POINT):
            early = [
                i
                for i, msg in enumerate(messages)
                if i <= self.config.early_messages_count and msg.get("role") != "system"
            ]
            if early:
                conversation_indices.append(early[-1])

        interval = self.config.multi_point_interval
        if strategy == CacheStrategy.MULTI_POINT and interval > 0:
            for i in range(interval, len(messages), interval):
                if i // interval > self.config.max_cache_points:
                    break
                if messages[i].get("role") != "system":
                    conversation_indices.append(i)

        limit = self.MAX_CACHE_BREAKPOINTS
        chosen = system_indices[:limit]
        remaining = limit - len(chosen)
        if remaining > 0:
            chosen += sorted(set(conversation_indices))[-remaining:]
        return sorted(chosen)

    def _add_cache_to_message(self, msg: Dict[str, Any]) -> Dict[str, Any]:
        """
        Return a copy of ``msg`` with a cache marker on its last content block.

        String content becomes a single ``text`` block carrying the marker. For
        list content, the last block (if it is a dict) is copied and given the
        marker; the original block is left untouched. Other content (e.g.
        ``None``) is left as is.

        Args:
            msg: Original message; not mutated.

        Returns:
            The marked message copy.
        """
        new_msg = dict(msg)
        content = new_msg.get("content", "")
        cache_control = dict(self.ANTHROPIC_CACHE_CONTROL)

        if isinstance(content, str):
            new_msg["content"] = [
                {
                    "type": "text",
                    "text": content,
                    "cache_control": cache_control,
                }
            ]
        elif isinstance(content, list) and content and isinstance(content[-1], dict):
            # Copy the last block instead of mutating the caller's dict.
            new_msg["content"] = content[:-1] + [
                {**content[-1], "cache_control": cache_control}
            ]

        return new_msg

    def process_messages(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        provider: str,
        system_prompt_tokens: int = 0,
    ) -> Tuple[List[Dict[str, Any]], bool]:
        """
        Add cache markers to ``messages`` when the model/provider supports it.

        Args:
            messages: Original message list.
            model: Model name.
            provider: Provider name.
            system_prompt_tokens: Token count of the system prompt.

        Returns:
            ``(messages, applied)``. ``applied`` is True whenever a caching
            strategy was applied, even if no message qualified for a marker.
        """
        if not self.supports_caching(model, provider):
            return messages, False

        strategy = self.determine_strategy(messages, system_prompt_tokens)
        if strategy == CacheStrategy.NONE:
            return messages, False

        # Select the caching method by provider.
        if provider.lower() in ("anthropic", "claude"):
            cached_messages = self.add_cache_markers_anthropic(messages, strategy)
            logger.debug("Applied %s caching strategy for Anthropic", strategy.value)
            return cached_messages, True

        return messages, False

    def update_stats(
        self,
        cache_creation_input_tokens: int = 0,
        cache_read_input_tokens: int = 0,
        total_input_tokens: int = 0,
    ):
        """
        Record one response's cache usage.

        A response with ``cache_read_input_tokens > 0`` counts as a hit,
        otherwise as a miss.

        Args:
            cache_creation_input_tokens: Tokens written to the cache. Accepted
                but currently not recorded.
            cache_read_input_tokens: Tokens read from the cache.
            total_input_tokens: Total input tokens. NOTE(review): whether this
                already includes cached tokens depends on the caller; it
                affects how ``token_savings`` should be read.
        """
        if cache_read_input_tokens > 0:
            self.stats.cache_hits += 1
            self.stats.cached_tokens += cache_read_input_tokens
        else:
            self.stats.cache_misses += 1

        self.stats.total_tokens += total_input_tokens

    def get_stats_summary(self) -> Dict[str, Any]:
        """Return the cache statistics, with ratios formatted as percentages."""
        return {
            "cache_hits": self.stats.cache_hits,
            "cache_misses": self.stats.cache_misses,
            "hit_rate": f"{self.stats.hit_rate:.2%}",
            "cached_tokens": self.stats.cached_tokens,
            "total_tokens": self.stats.total_tokens,
            "token_savings": f"{self.stats.token_savings:.2%}",
        }
|
||
|
||
|
||
# Module-wide shared manager instance, built with the default CacheConfig.
# Because update_stats mutates this object's stats, counts accumulate across
# every caller that uses this shared instance.
prompt_cache_manager = PromptCacheManager(config=None)
|
||
|
||
|
||
def estimate_tokens(text: str) -> int:
    """
    Roughly estimate the number of tokens in ``text``.

    Heuristic: CJK Unified Ideographs (U+4E00..U+9FFF) count as about 2
    characters per token; every other character counts as about 4 characters
    per token. The result is truncated to an int. Any non-empty text is
    reported as at least 1 token, because truncation alone would report short
    strings such as "abc" as 0 tokens.

    Args:
        text: Text to measure. ``None`` or an empty string yields 0.

    Returns:
        Estimated token count.
    """
    if not text:
        return 0

    chinese_chars = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
    other_chars = len(text) - chinese_chars

    # ~2 chars/token for Chinese, ~4 chars/token otherwise; never 0 for non-empty text.
    return max(1, int(chinese_chars / 2 + other_chars / 4))
|