"""
Prompt Caching 模块

为支持缓存的 LLM(如 Anthropic Claude)提供 Prompt 缓存功能。
通过在系统提示词和早期对话中添加缓存标记,减少重复处理,
显著降低 Token 消耗和响应延迟。

支持的 LLM:
- Anthropic Claude (claude-3-5-sonnet, claude-3-opus, claude-3-haiku)
- OpenAI (部分模型支持)

缓存策略:
- 短对话(<10轮): 仅缓存系统提示词
- 中等对话(10-30轮): 缓存系统提示词 + 前5轮对话
- 长对话(>30轮): 多个缓存点,动态调整
"""
|
|||
|
|
|
|||
|
|
import logging
|
|||
|
|
from typing import Dict, Any, List, Optional, Tuple
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from enum import Enum
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class CacheStrategy(str, Enum):
    """Available prompt-cache strategies."""

    NONE = "none"                       # caching disabled
    SYSTEM_ONLY = "system_only"         # cache the system prompt only
    SYSTEM_AND_EARLY = "system_early"   # cache system prompt + early turns
    MULTI_POINT = "multi_point"         # multiple cache breakpoints
@dataclass
class CacheConfig:
    """Configuration knobs for prompt caching."""

    enabled: bool = True
    strategy: CacheStrategy = CacheStrategy.SYSTEM_AND_EARLY

    # Thresholds
    min_system_prompt_tokens: int = 1000  # only cache system prompts at least this many tokens
    early_messages_count: int = 5         # number of early messages eligible for caching

    # Multi-point configuration
    multi_point_interval: int = 10  # messages between consecutive cache points
    max_cache_points: int = 4       # maximum number of cache points
@dataclass
class CacheStats:
    """Running statistics about cache usage."""

    cache_hits: int = 0
    cache_misses: int = 0
    cached_tokens: int = 0
    total_tokens: int = 0

    @property
    def hit_rate(self) -> float:
        """Fraction of requests served from cache (0.0 when no requests)."""
        attempts = self.cache_hits + self.cache_misses
        if not attempts:
            return 0.0
        return self.cache_hits / attempts

    @property
    def token_savings(self) -> float:
        """Fraction of input tokens that were read from cache (0.0 when none)."""
        if not self.total_tokens:
            return 0.0
        return self.cached_tokens / self.total_tokens
class PromptCacheManager:
    """
    Prompt cache manager.

    Responsibilities:
    1. Detect whether a given LLM supports prompt caching.
    2. Choose a cache strategy based on conversation length.
    3. Attach cache markers (``cache_control``) to messages.
    4. Track cache-effectiveness statistics.
    """

    # Models known to support (True) or not support (False) prompt caching.
    # Lookup is by substring match against the lowercased model name.
    CACHEABLE_MODELS = {
        # Anthropic Claude
        "claude-3-5-sonnet": True,
        "claude-3-5-sonnet-20241022": True,
        "claude-3-opus": True,
        "claude-3-opus-20240229": True,
        "claude-3-haiku": True,
        "claude-3-haiku-20240307": True,
        "claude-3-sonnet": True,
        "claude-3-sonnet-20240229": True,
        # OpenAI (not supported by this marker mechanism yet)
        "gpt-4-turbo": False,
        "gpt-4o": False,
        "gpt-4o-mini": False,
    }

    # Anthropic cache marker attached to a content block.
    ANTHROPIC_CACHE_CONTROL = {"type": "ephemeral"}

    # Anthropic rejects requests carrying more than 4 cache_control
    # breakpoints, so the total number of markers per request is capped.
    MAX_ANTHROPIC_CACHE_BLOCKS = 4

    def __init__(self, config: Optional[CacheConfig] = None):
        """
        Args:
            config: Cache configuration; defaults to ``CacheConfig()``.
        """
        self.config = config or CacheConfig()
        self.stats = CacheStats()
        self._cache_enabled_for_session = True

    def supports_caching(self, model: str, provider: str) -> bool:
        """
        Check whether a model supports prompt caching.

        Args:
            model: Model name (substring-matched against CACHEABLE_MODELS).
            provider: Provider name.

        Returns:
            True when caching is enabled and the model supports it.
        """
        if not self.config.enabled:
            return False

        # Only Anthropic Claude supports this caching mechanism for now.
        if provider.lower() in ["anthropic", "claude"]:
            for cacheable_model in self.CACHEABLE_MODELS:
                if cacheable_model in model.lower():
                    return self.CACHEABLE_MODELS.get(cacheable_model, False)

        return False

    def determine_strategy(
        self,
        messages: List[Dict[str, Any]],
        system_prompt_tokens: int = 0,
    ) -> CacheStrategy:
        """
        Choose a cache strategy from the conversation state.

        Args:
            messages: Message list.
            system_prompt_tokens: Token count of the system prompt.

        Returns:
            The cache strategy to apply.
        """
        if not self.config.enabled:
            return CacheStrategy.NONE

        # A short system prompt is not worth caching.
        if system_prompt_tokens < self.config.min_system_prompt_tokens:
            return CacheStrategy.NONE

        message_count = len(messages)

        # Short conversation: cache the system prompt only.
        if message_count < 10:
            return CacheStrategy.SYSTEM_ONLY

        # Medium conversation: cache the system prompt and early turns.
        if message_count < 30:
            return CacheStrategy.SYSTEM_AND_EARLY

        # Long conversation: multiple cache points.
        return CacheStrategy.MULTI_POINT

    def add_cache_markers_anthropic(
        self,
        messages: List[Dict[str, Any]],
        strategy: CacheStrategy,
    ) -> List[Dict[str, Any]]:
        """
        Add cache markers to messages for Anthropic Claude.

        Anthropic's cache format:
        - a ``cache_control`` field on a content block
        - only ``text``-style content blocks are marked here

        Fixes vs. the previous revision:
        - the input list and its messages are never mutated; marked
          messages are rebuilt with fresh content blocks (the old shallow
          ``msg.copy()`` leaked ``cache_control`` into the caller's
          nested content blocks);
        - markers are budgeted to Anthropic's 4-breakpoint limit: the
          system prompt first, then the end of the early-conversation
          prefix (marking the last message of the prefix caches the
          entire prefix, so per-message early markers were redundant),
          then periodic multi-point markers.

        Args:
            messages: Original message list.
            strategy: Cache strategy.

        Returns:
            A new message list with cache markers applied.
        """
        if strategy == CacheStrategy.NONE:
            return messages

        marks_left = self.MAX_ANTHROPIC_CACHE_BLOCKS
        # Last index of the "early conversation" prefix (clamped for short lists).
        early_cutoff = min(self.config.early_messages_count, len(messages) - 1)

        cached_messages: List[Dict[str, Any]] = []
        for i, msg in enumerate(messages):
            should_mark = False

            if msg.get("role") == "system":
                # Always cache the system prompt.
                should_mark = True
            elif (
                strategy in (CacheStrategy.SYSTEM_AND_EARLY, CacheStrategy.MULTI_POINT)
                and i == early_cutoff
            ):
                # One marker at the end of the early prefix caches everything before it.
                should_mark = True
            elif (
                strategy == CacheStrategy.MULTI_POINT
                and i > 0
                and i % self.config.multi_point_interval == 0
                and i // self.config.multi_point_interval <= self.config.max_cache_points
            ):
                should_mark = True

            if should_mark and marks_left > 0:
                cached_messages.append(self._add_cache_to_message(msg))
                marks_left -= 1
            else:
                cached_messages.append(msg)

        return cached_messages

    def _add_cache_to_message(self, msg: Dict[str, Any]) -> Dict[str, Any]:
        """
        Return a copy of *msg* whose content carries a cache marker.

        String content is wrapped into content-block format; list content
        gets ``cache_control`` on a copy of its last dict block. The input
        message is left untouched.

        Args:
            msg: Original message.

        Returns:
            A new message dict with the cache marker, or *msg* unchanged
            when the content shape is not recognized.
        """
        content = msg.get("content", "")

        # Plain-string content: convert to a single marked text block.
        if isinstance(content, str):
            return {
                **msg,
                "content": [
                    {
                        "type": "text",
                        "text": content,
                        "cache_control": self.ANTHROPIC_CACHE_CONTROL,
                    }
                ],
            }

        # Already content-block format: mark a copy of the last block.
        if isinstance(content, list) and content and isinstance(content[-1], dict):
            marked_last = {**content[-1], "cache_control": self.ANTHROPIC_CACHE_CONTROL}
            return {**msg, "content": content[:-1] + [marked_last]}

        return msg

    def process_messages(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        provider: str,
        system_prompt_tokens: int = 0,
    ) -> Tuple[List[Dict[str, Any]], bool]:
        """
        Process messages, adding cache markers when supported.

        Args:
            messages: Original message list.
            model: Model name.
            provider: Provider name.
            system_prompt_tokens: Token count of the system prompt.

        Returns:
            (processed message list, whether caching was applied)
        """
        if not self.supports_caching(model, provider):
            return messages, False

        strategy = self.determine_strategy(messages, system_prompt_tokens)

        if strategy == CacheStrategy.NONE:
            return messages, False

        # Dispatch by provider; only Anthropic has a marker format here.
        if provider.lower() in ["anthropic", "claude"]:
            cached_messages = self.add_cache_markers_anthropic(messages, strategy)
            # Lazy %-formatting: the message is only rendered if DEBUG is on.
            logger.debug("Applied %s caching strategy for Anthropic", strategy.value)
            return cached_messages, True

        return messages, False

    def update_stats(
        self,
        cache_creation_input_tokens: int = 0,
        cache_read_input_tokens: int = 0,
        total_input_tokens: int = 0,
    ):
        """
        Update cache statistics from one API response.

        Args:
            cache_creation_input_tokens: Tokens written to cache. Accepted
                for API compatibility; not folded into the stats yet.
            cache_read_input_tokens: Tokens read from cache.
            total_input_tokens: Total input tokens.
        """
        if cache_read_input_tokens > 0:
            self.stats.cache_hits += 1
            self.stats.cached_tokens += cache_read_input_tokens
        else:
            self.stats.cache_misses += 1

        self.stats.total_tokens += total_input_tokens

    def get_stats_summary(self) -> Dict[str, Any]:
        """Return a human-readable summary of the cache statistics."""
        return {
            "cache_hits": self.stats.cache_hits,
            "cache_misses": self.stats.cache_misses,
            "hit_rate": f"{self.stats.hit_rate:.2%}",
            "cached_tokens": self.stats.cached_tokens,
            "total_tokens": self.stats.total_tokens,
            "token_savings": f"{self.stats.token_savings:.2%}",
        }
# Global cache-manager instance shared across the application.
prompt_cache_manager = PromptCacheManager()
def estimate_tokens(text: str) -> int:
    """
    Roughly estimate the token count of *text*.

    Heuristic: CJK ideographs (U+4E00..U+9FFF) average about 2 characters
    per token, everything else about 4 characters per token.

    Args:
        text: Text to estimate.

    Returns:
        Estimated token count (0 for empty input).
    """
    if not text:
        return 0

    # Count CJK ideographs explicitly; everything else is "other".
    cjk_count = 0
    for ch in text:
        if '\u4e00' <= ch <= '\u9fff':
            cjk_count += 1
    other_count = len(text) - cjk_count

    # ~2 chars/token for CJK, ~4 chars/token otherwise.
    return int(cjk_count / 2 + other_count / 4)