# CodeReview/backend/app/services/llm/memory_compressor.py

"""
Memory Compressor - 对话历史压缩器
当对话历史变得很长时自动进行压缩保持语义完整性的同时降低Token消耗。
压缩策略:
1. 保留所有系统消息
2. 保留最近的N条消息
3. 对较早的消息进行摘要压缩
4. 保留关键信息(发现、决策点、错误)
"""
import logging
import re
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)

# Configuration constants
MAX_TOTAL_TOKENS = 100_000     # maximum total token budget
MIN_RECENT_MESSAGES = 15       # minimum number of recent messages to keep
COMPRESSION_THRESHOLD = 0.9    # compression triggers at 90% of the budget
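
# With the defaults above, for example, compression triggers once the
# estimated history size exceeds 100_000 * 0.9 = 90_000 tokens.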


def estimate_tokens(text: str) -> int:
    """
    Estimate the token count of a text.

    Rough heuristic: English is about 4 characters per token, Chinese about
    2 characters per token.
    """
    if not text:
        return 0
    # Count ASCII vs. non-ASCII characters and apply the heuristic.
    ascii_chars = sum(1 for c in text if ord(c) < 128)
    non_ascii_chars = len(text) - ascii_chars
    return (ascii_chars // 4) + (non_ascii_chars // 2) + 1
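
# Worked example of the heuristic (illustrative): "hello 世界" has 6 ASCII
# characters and 2 non-ASCII characters, so the estimate is
# 6 // 4 + 2 // 2 + 1 == 3 tokens.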


def get_message_tokens(msg: Dict[str, Any]) -> int:
    """Get the token count of a single message."""
    content = msg.get("content", "")
    if isinstance(content, str):
        return estimate_tokens(content)
    if isinstance(content, list):
        total = 0
        for item in content:
            if isinstance(item, dict) and item.get("type") == "text":
                total += estimate_tokens(item.get("text", ""))
        return total
    return 0
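
# For reference, the two content shapes handled above and in
# extract_message_text below (OpenAI-style chat messages; the exact
# multimodal layout is an assumption):
#   {"role": "user", "content": "plain text"}
#   {"role": "user", "content": [{"type": "text", "text": "..."},
#                                {"type": "image_url", "image_url": {...}}]}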


def extract_message_text(msg: Dict[str, Any]) -> str:
    """Extract the text content of a message."""
    content = msg.get("content", "")
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for item in content:
            if isinstance(item, dict):
                if item.get("type") == "text":
                    parts.append(item.get("text", ""))
                elif item.get("type") == "image_url":
                    parts.append("[IMAGE]")
        return " ".join(parts)
    return str(content)


class MemoryCompressor:
    """
    Conversation-history compressor.

    When the conversation history exceeds the token limit, older messages
    are compressed automatically while the key security-audit context is
    preserved.
    """

    def __init__(
        self,
        max_total_tokens: int = MAX_TOTAL_TOKENS,
        min_recent_messages: int = MIN_RECENT_MESSAGES,
        llm_service=None,
    ):
        """
        Initialize the compressor.

        Args:
            max_total_tokens: Maximum total token budget.
            min_recent_messages: Minimum number of recent messages to keep.
            llm_service: Optional LLM service used to generate summaries.
        """
        self.max_total_tokens = max_total_tokens
        self.min_recent_messages = min_recent_messages
        self.llm_service = llm_service

    def compress_history(
        self,
        messages: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """
        Compress the conversation history.

        Strategy:
        1. Keep all system messages.
        2. Keep the most recent N messages.
        3. Summarize and compress older messages.
        4. Preserve key information.

        Args:
            messages: The original message list.

        Returns:
            The compressed message list.
        """
        if not messages:
            return messages

        # Separate system messages from regular messages.
        system_msgs = []
        regular_msgs = []
        for msg in messages:
            if msg.get("role") == "system":
                system_msgs.append(msg)
            else:
                regular_msgs.append(msg)

        # Compute the current total token count.
        total_tokens = sum(get_message_tokens(msg) for msg in messages)

        # No compression is needed while we are below the threshold.
        if total_tokens <= self.max_total_tokens * COMPRESSION_THRESHOLD:
            return messages

        logger.info(
            f"Compressing conversation history: {total_tokens} tokens -> "
            f"target: {int(self.max_total_tokens * 0.7)}"
        )

        # Split off the most recent messages from the older ones.
        recent_msgs = regular_msgs[-self.min_recent_messages:]
        old_msgs = (
            regular_msgs[:-self.min_recent_messages]
            if len(regular_msgs) > self.min_recent_messages
            else []
        )
        if not old_msgs:
            return messages

        # Compress the older messages.
        compressed = self._compress_messages(old_msgs)

        # Reassemble: system messages, then summaries, then recent messages.
        result = system_msgs + compressed + recent_msgs
        new_total = sum(get_message_tokens(msg) for msg in result)
        logger.info(
            f"Compression complete: {total_tokens} -> {new_total} tokens "
            f"({100 - new_total * 100 // total_tokens}% reduction)"
        )
        return result
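
    # Illustrative shape of a compressed result (assuming two summary chunks
    # were produced):
    #   [system..., <context_summary/>, <context_summary/>, 15 recent messages]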

    def _compress_messages(
        self,
        messages: List[Dict[str, Any]],
        chunk_size: int = 10,
    ) -> List[Dict[str, Any]]:
        """
        Compress a list of messages.

        Args:
            messages: The messages to compress.
            chunk_size: Number of messages summarized per chunk.

        Returns:
            The compressed message list.
        """
        if not messages:
            return []

        compressed = []
        # Summarize the messages chunk by chunk.
        for i in range(0, len(messages), chunk_size):
            chunk = messages[i:i + chunk_size]
            summary = self._summarize_chunk(chunk)
            if summary:
                compressed.append(summary)
        return compressed
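
    # For example, 25 older messages with the default chunk_size of 10 yield
    # ceil(25 / 10) = 3 summary messages.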

    def _summarize_chunk(self, messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """
        Summarize a group of messages.

        Args:
            messages: The messages to summarize.

        Returns:
            The summary message, or None if there is nothing to summarize.
        """
        if not messages:
            return None

        # Extract the key information.
        key_info = self._extract_key_info(messages)

        # Build the summary.
        summary_parts = []
        if key_info["findings"]:
            summary_parts.append(f"Findings: {', '.join(key_info['findings'][:5])}")
        if key_info["tools_used"]:
            summary_parts.append(f"Tools used: {', '.join(key_info['tools_used'][:5])}")
        if key_info["decisions"]:
            summary_parts.append(f"Decisions: {', '.join(key_info['decisions'][:3])}")
        if key_info["errors"]:
            summary_parts.append(f"Errors: {', '.join(key_info['errors'][:2])}")
        if not summary_parts:
            # No key information was extracted; fall back to a simple summary.
            summary_parts.append(f"[Compressed {len(messages)} earlier messages]")

        summary_text = " | ".join(summary_parts)
        return {
            "role": "assistant",
            "content": f"<context_summary message_count='{len(messages)}'>{summary_text}</context_summary>",
        }
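
    # Illustrative summary message for a 10-message chunk (contents made up):
    #   {"role": "assistant",
    #    "content": "<context_summary message_count='10'>Findings: SQL "
    #               "injection, XSS | Tools used: read_file</context_summary>"}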

    def _extract_key_info(self, messages: List[Dict[str, Any]]) -> Dict[str, List[str]]:
        """
        Extract key information from the messages.

        Args:
            messages: The message list.

        Returns:
            A dict of key information.
        """
        key_info = {
            "findings": [],
            "tools_used": [],
            "decisions": [],
            "errors": [],
            "files_analyzed": [],
        }

        # Vulnerability keywords mapped to display labels (built once,
        # outside the message loop).
        vuln_patterns = {
            "sql": "SQL injection",
            "xss": "XSS",
            "ssrf": "SSRF",
            "idor": "IDOR",
            "auth": "authentication issue",
            "injection": "injection vulnerability",
            "traversal": "path traversal",
            "deserialization": "deserialization",
            "hardcoded": "hardcoded credentials",
            "secret": "secret leakage",
        }

        # File-path patterns; the Chinese ones match messages such as
        # "读取文件: app/main.py" ("read file: ...") and "分析文件: ..."
        # ("analyze file: ..."), accepting both ASCII and full-width colons.
        file_patterns = [
            r'读取文件[:：]\s*([^\s\n]+)',
            r'分析文件[:：]\s*([^\s\n]+)',
            r'file[_\s]?path[:]\s*["\']?([^\s\n"\']+)',
        ]

        for msg in messages:
            text = extract_message_text(msg).lower()

            # Detect the vulnerability types that were found ("发现" = "found",
            # "漏洞" = "vulnerability").
            for pattern, label in vuln_patterns.items():
                if pattern in text and ("发现" in text or "漏洞" in text or "finding" in text or "vulnerability" in text):
                    if label not in key_info["findings"]:
                        key_info["findings"].append(label)

            # Detect tool usage.
            tool_match = re.search(r'action:\s*(\w+)', text, re.IGNORECASE)
            if tool_match:
                tool = tool_match.group(1)
                if tool not in key_info["tools_used"]:
                    key_info["tools_used"].append(tool)

            # Detect analyzed files.
            for pattern in file_patterns:
                matches = re.findall(pattern, text)
                for match in matches:
                    if match not in key_info["files_analyzed"]:
                        key_info["files_analyzed"].append(match)

            # Detect decisions ("决定"/"决策" = "decision", "选择" = "choose",
            # "采用" = "adopt").
            if any(kw in text for kw in ["决定", "决策", "decision", "选择", "采用"]):
                # Try to extract the decision content itself.
                decision_match = re.search(r'(决定|决策|decision)[:\s]*([^\n。.]{10,50})', text)
                if decision_match:
                    key_info["decisions"].append(decision_match.group(2)[:50])
                else:
                    key_info["decisions"].append("made a decision")

            # Detect errors ("错误" = "error", "失败" = "failure").
            if any(kw in text for kw in ["错误", "失败", "error", "failed", "exception"]):
                error_match = re.search(r'(错误|error|failed)[:\s]*([^\n]{10,50})', text, re.IGNORECASE)
                if error_match:
                    key_info["errors"].append(error_match.group(2)[:50])
                else:
                    key_info["errors"].append("encountered an error")

        # Deduplicate (preserving first-seen order) and cap each list.
        for key in key_info:
            key_info[key] = list(dict.fromkeys(key_info[key]))[:5]
        return key_info

    def should_compress(self, messages: List[Dict[str, Any]]) -> bool:
        """
        Check whether compression is needed.

        Args:
            messages: The message list.

        Returns:
            Whether compression is needed.
        """
        total_tokens = sum(get_message_tokens(msg) for msg in messages)
        return total_tokens > self.max_total_tokens * COMPRESSION_THRESHOLD
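
# Typical call pattern (illustrative): callers can gate the compression pass
# behind the cheap threshold check, e.g.
#     if compressor.should_compress(history):
#         history = compressor.compress_history(history)
# compress_history re-checks the threshold itself, so the guard is optional.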


# Convenience function
def compress_conversation(
    messages: List[Dict[str, Any]],
    max_tokens: int = MAX_TOTAL_TOKENS,
) -> List[Dict[str, Any]]:
    """
    Convenience function for compressing a conversation history.

    Args:
        messages: The message list.
        max_tokens: Maximum token budget.

    Returns:
        The compressed message list.
    """
    compressor = MemoryCompressor(max_total_tokens=max_tokens)
    return compressor.compress_history(messages)
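

# Minimal usage sketch (illustrative; the message contents are made up):
if __name__ == "__main__":
    history = [{"role": "system", "content": "You are a security auditor."}]
    history += [
        {"role": "user", "content": f"Please analyze chunk {i}: " + "x" * 4000}
        for i in range(120)
    ]
    compressed = compress_conversation(history, max_tokens=10_000)
    # System messages and the most recent messages survive; older ones are
    # replaced by <context_summary> messages.
    print(len(history), "->", len(compressed), "messages")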