"""
|
|||
|
|
Memory Compressor - 对话历史压缩器
|
|||
|
|
|
|||
|
|
当对话历史变得很长时,自动进行压缩,保持语义完整性的同时降低Token消耗。
|
|||
|
|
|
|||
|
|
压缩策略:
|
|||
|
|
1. 保留所有系统消息
|
|||
|
|
2. 保留最近的N条消息
|
|||
|
|
3. 对较早的消息进行摘要压缩
|
|||
|
|
4. 保留关键信息(发现、决策点、错误)
|
|||
|
|
"""
|
|||
|
|
|

import logging
import re
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


# Configuration constants
MAX_TOTAL_TOKENS = 100_000  # Maximum total token count
MIN_RECENT_MESSAGES = 15  # Minimum number of recent messages to keep
COMPRESSION_THRESHOLD = 0.9  # Compression trigger threshold (90%)
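# With these defaults, compression triggers once the estimated history size
# exceeds 100_000 * 0.9 = 90_000 tokens (see should_compress below).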


def estimate_tokens(text: str) -> int:
    """
    Estimate the token count of a piece of text.

    Rough heuristic: about 4 characters per token for English text and
    about 2 characters per token for Chinese text.
    """
    if not text:
        return 0

    # Rough estimate
    ascii_chars = sum(1 for c in text if ord(c) < 128)
    non_ascii_chars = len(text) - ascii_chars

    return (ascii_chars // 4) + (non_ascii_chars // 2) + 1
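
# Illustrative spot-checks of the heuristic above (rough estimates, not real
# tokenizer counts):
#   estimate_tokens("hello world")  ->  11 // 4 + 0 // 2 + 1 == 3
#   estimate_tokens("你好世界")      ->  0 // 4 + 4 // 2 + 1 == 3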


def get_message_tokens(msg: Dict[str, Any]) -> int:
    """Get the token count of a single message."""
    content = msg.get("content", "")

    if isinstance(content, str):
        return estimate_tokens(content)

    if isinstance(content, list):
        total = 0
        for item in content:
            if isinstance(item, dict) and item.get("type") == "text":
                total += estimate_tokens(item.get("text", ""))
        return total

    return 0
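
# For reference, the list-shaped (multimodal) content handled above looks
# like the following (an assumed OpenAI-style layout; only "text" parts
# contribute to the estimate):
#   {"role": "user", "content": [
#       {"type": "text", "text": "please review this screenshot"},
#       {"type": "image_url", "image_url": {"url": "..."}},
#   ]}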


def extract_message_text(msg: Dict[str, Any]) -> str:
    """Extract the text content of a message."""
    content = msg.get("content", "")

    if isinstance(content, str):
        return content

    if isinstance(content, list):
        parts = []
        for item in content:
            if isinstance(item, dict):
                if item.get("type") == "text":
                    parts.append(item.get("text", ""))
                elif item.get("type") == "image_url":
                    parts.append("[IMAGE]")
        return " ".join(parts)

    return str(content)


class MemoryCompressor:
    """
    Conversation history compressor.

    When the conversation history exceeds the token limit, older messages
    are compressed automatically while key security-audit context is
    preserved.
    """

    def __init__(
        self,
        max_total_tokens: int = MAX_TOTAL_TOKENS,
        min_recent_messages: int = MIN_RECENT_MESSAGES,
        llm_service=None,
    ):
        """
        Initialize the compressor.

        Args:
            max_total_tokens: Maximum total token count
            min_recent_messages: Minimum number of recent messages to keep
            llm_service: LLM service for generating summaries (optional;
                not used by the current heuristic summarizer)
        """
        self.max_total_tokens = max_total_tokens
        self.min_recent_messages = min_recent_messages
        self.llm_service = llm_service

    def compress_history(
        self,
        messages: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """
        Compress the conversation history.

        Strategy:
        1. Keep all system messages
        2. Keep the most recent N messages
        3. Summarize older messages
        4. Preserve key information

        Args:
            messages: Original message list

        Returns:
            Compressed message list
        """
        if not messages:
            return messages

        # Separate system messages from regular messages
        system_msgs = []
        regular_msgs = []

        for msg in messages:
            if msg.get("role") == "system":
                system_msgs.append(msg)
            else:
                regular_msgs.append(msg)

        # Compute the current total token count
        total_tokens = sum(get_message_tokens(msg) for msg in messages)

        # No compression needed while below the threshold
        if total_tokens <= self.max_total_tokens * COMPRESSION_THRESHOLD:
            return messages

        logger.info(
            f"Compressing conversation history: {total_tokens} tokens -> "
            f"target: {int(self.max_total_tokens * 0.7)}"
        )

        # Split into recent and older messages
        recent_msgs = regular_msgs[-self.min_recent_messages:]
        old_msgs = (
            regular_msgs[:-self.min_recent_messages]
            if len(regular_msgs) > self.min_recent_messages
            else []
        )

        if not old_msgs:
            return messages

        # Compress the older messages
        compressed = self._compress_messages(old_msgs)

        # Reassemble
        result = system_msgs + compressed + recent_msgs

        new_total = sum(get_message_tokens(msg) for msg in result)
        logger.info(
            f"Compression complete: {total_tokens} -> {new_total} tokens "
            f"({100 - new_total * 100 // total_tokens}% reduction)"
        )

        return result

    def _compress_messages(
        self,
        messages: List[Dict[str, Any]],
        chunk_size: int = 10,
    ) -> List[Dict[str, Any]]:
        """
        Compress a list of messages.

        Args:
            messages: Messages to compress
            chunk_size: Number of messages summarized per chunk

        Returns:
            Compressed message list
        """
        if not messages:
            return []

        compressed = []

        # Summarize the messages chunk by chunk
        for i in range(0, len(messages), chunk_size):
            chunk = messages[i:i + chunk_size]
            summary = self._summarize_chunk(chunk)
            if summary:
                compressed.append(summary)

        return compressed
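
    # Worked example: 25 older messages with chunk_size=10 collapse into at
    # most 3 summary messages (chunks of 10, 10, and 5).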

    def _summarize_chunk(self, messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """
        Summarize a group of messages.

        Args:
            messages: Messages to summarize

        Returns:
            A single summary message
        """
        if not messages:
            return None

        # Extract the key information
        key_info = self._extract_key_info(messages)

        # Build the summary
        summary_parts = []

        if key_info["findings"]:
            summary_parts.append(f"Findings: {', '.join(key_info['findings'][:5])}")

        if key_info["tools_used"]:
            summary_parts.append(f"Tools used: {', '.join(key_info['tools_used'][:5])}")

        if key_info["decisions"]:
            summary_parts.append(f"Decisions: {', '.join(key_info['decisions'][:3])}")

        if key_info["errors"]:
            summary_parts.append(f"Errors: {', '.join(key_info['errors'][:2])}")

        if not summary_parts:
            # No key information extracted; fall back to a simple summary
            summary_parts.append(f"[Compressed {len(messages)} earlier messages]")

        summary_text = " | ".join(summary_parts)

        return {
            "role": "assistant",
            "content": f"<context_summary message_count='{len(messages)}'>{summary_text}</context_summary>",
        }
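
    # A produced summary message might look like this (illustrative values):
    #   {"role": "assistant",
    #    "content": "<context_summary message_count='10'>Findings: SQL injection"
    #               " | Tools used: read_file</context_summary>"}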

    def _extract_key_info(self, messages: List[Dict[str, Any]]) -> Dict[str, List[str]]:
        """
        Extract key information from the messages.

        Args:
            messages: Message list

        Returns:
            Dictionary of key information
        """
        key_info = {
            "findings": [],
            "tools_used": [],
            "decisions": [],
            "errors": [],
            "files_analyzed": [],
        }

        for msg in messages:
            text = extract_message_text(msg).lower()

            # Detect the vulnerability types that were found
            vuln_patterns = {
                "sql": "SQL injection",
                "xss": "XSS",
                "ssrf": "SSRF",
                "idor": "IDOR",
                "auth": "authentication issue",
                "injection": "injection vulnerability",
                "traversal": "path traversal",
                "deserialization": "deserialization",
                "hardcoded": "hardcoded credentials",
                "secret": "secret leak",
            }

            # The Chinese keywords below ("发现" = found, "漏洞" = vulnerability)
            # match transcripts written in Chinese.
            for pattern, label in vuln_patterns.items():
                if pattern in text and ("发现" in text or "漏洞" in text or "finding" in text or "vulnerability" in text):
                    if label not in key_info["findings"]:
                        key_info["findings"].append(label)

            # Detect tool usage
            tool_match = re.search(r'action:\s*(\w+)', text, re.IGNORECASE)
            if tool_match:
                tool = tool_match.group(1)
                if tool not in key_info["tools_used"]:
                    key_info["tools_used"].append(tool)

            # Detect analyzed files ("读取文件" = read file, "分析文件" = analyze file)
            file_patterns = [
                r'读取文件[::]\s*([^\s\n]+)',
                r'分析文件[::]\s*([^\s\n]+)',
                r'file[_\s]?path[::]\s*["\']?([^\s\n"\']+)',
            ]
            for pattern in file_patterns:
                matches = re.findall(pattern, text)
                for match in matches:
                    if match not in key_info["files_analyzed"]:
                        key_info["files_analyzed"].append(match)

            # Detect decisions ("决定"/"决策" = decision, "选择" = choose, "采用" = adopt)
            if any(kw in text for kw in ["决定", "决策", "decision", "选择", "采用"]):
                # Try to capture the decision itself
                decision_match = re.search(r'(决定|决策|decision)[::\s]*([^\n。.]{10,50})', text)
                if decision_match:
                    key_info["decisions"].append(decision_match.group(2)[:50])
                else:
                    key_info["decisions"].append("decision made")

            # Detect errors ("错误" = error, "失败" = failure)
            if any(kw in text for kw in ["错误", "失败", "error", "failed", "exception"]):
                error_match = re.search(r'(错误|error|failed)[::\s]*([^\n]{10,50})', text, re.IGNORECASE)
                if error_match:
                    key_info["errors"].append(error_match.group(2)[:50])
                else:
                    key_info["errors"].append("encountered an error")

        # Deduplicate (order-preserving) and cap the list sizes
        for key in key_info:
            key_info[key] = list(dict.fromkeys(key_info[key]))[:5]

        return key_info
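
    # Illustrative shape of the returned dictionary (tool and file names
    # here are hypothetical):
    #   {"findings": ["SQL injection"], "tools_used": ["read_file"],
    #    "decisions": [], "errors": [], "files_analyzed": ["app/db.py"]}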

    def should_compress(self, messages: List[Dict[str, Any]]) -> bool:
        """
        Check whether compression is needed.

        Args:
            messages: Message list

        Returns:
            Whether compression is needed
        """
        total_tokens = sum(get_message_tokens(msg) for msg in messages)
        return total_tokens > self.max_total_tokens * COMPRESSION_THRESHOLD


# Convenience function
def compress_conversation(
    messages: List[Dict[str, Any]],
    max_tokens: int = MAX_TOTAL_TOKENS,
) -> List[Dict[str, Any]]:
    """
    Convenience function for compressing a conversation history.

    Args:
        messages: Message list
        max_tokens: Maximum token count

    Returns:
        Compressed message list
    """
    compressor = MemoryCompressor(max_total_tokens=max_tokens)
    return compressor.compress_history(messages)
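

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the module API):
    # a synthetic history is compressed against a deliberately small token
    # budget so the compression path actually runs.
    logging.basicConfig(level=logging.INFO)

    history: List[Dict[str, Any]] = [
        {"role": "system", "content": "You are a security audit assistant."}
    ]
    for i in range(40):
        history.append(
            {"role": "user", "content": f"Please analyze file app_{i}.py"}
        )
        history.append(
            {
                "role": "assistant",
                "content": "finding: possible sql injection in the query builder. " * 20,
            }
        )

    compressed = compress_conversation(history, max_tokens=2_000)
    print(f"{len(history)} messages -> {len(compressed)} messages")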