"""
|
|
流式事件处理器
|
|
处理 LangGraph 的各种流式事件并转换为前端可消费的格式
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
from enum import Enum
|
|
from typing import Any, Dict, Optional, AsyncGenerator, List
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
|
|
logger = logging.getLogger(__name__)


class StreamEventType(str, Enum):
    """Types of streaming events."""

    # LLM-related
    THINKING_START = "thinking_start"      # thinking started
    THINKING_TOKEN = "thinking_token"      # thinking token
    THINKING_END = "thinking_end"          # thinking finished

    # Tool-call related
    TOOL_CALL_START = "tool_call_start"    # tool call started
    TOOL_CALL_INPUT = "tool_call_input"    # tool input arguments
    TOOL_CALL_OUTPUT = "tool_call_output"  # tool output
    TOOL_CALL_END = "tool_call_end"        # tool call finished
    TOOL_CALL_ERROR = "tool_call_error"    # tool call error

    # Node-related
    NODE_START = "node_start"              # node started
    NODE_END = "node_end"                  # node finished

    # Phase-related
    PHASE_START = "phase_start"
    PHASE_END = "phase_end"

    # Finding-related
    FINDING_NEW = "finding_new"            # new finding
    FINDING_VERIFIED = "finding_verified"  # finding verified

    # Status-related
    PROGRESS = "progress"
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"

    # Task-related
    TASK_START = "task_start"
    TASK_COMPLETE = "task_complete"
    TASK_ERROR = "task_error"
    TASK_CANCEL = "task_cancel"

    # Heartbeat
    HEARTBEAT = "heartbeat"


@dataclass
class StreamEvent:
    """A single streaming event."""

    event_type: StreamEventType
    data: Dict[str, Any] = field(default_factory=dict)
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
    sequence: int = 0

    # Optional fields
    node_name: Optional[str] = None
    phase: Optional[str] = None
    tool_name: Optional[str] = None

    def to_sse(self) -> str:
        """Serialize the event as a Server-Sent Events (SSE) frame."""
        event_data = {
            "type": self.event_type.value,
            "data": self.data,
            "timestamp": self.timestamp,
            "sequence": self.sequence,
        }

        if self.node_name:
            event_data["node"] = self.node_name
        if self.phase:
            event_data["phase"] = self.phase
        if self.tool_name:
            event_data["tool"] = self.tool_name

        return f"event: {self.event_type.value}\ndata: {json.dumps(event_data, ensure_ascii=False)}\n\n"

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the event as a plain dictionary."""
        return {
            "event_type": self.event_type.value,
            "data": self.data,
            "timestamp": self.timestamp,
            "sequence": self.sequence,
            "node_name": self.node_name,
            "phase": self.phase,
            "tool_name": self.tool_name,
        }
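

# Example (illustrative): a StreamEvent serialized with ``to_sse()`` becomes one
# SSE frame such as
#
#     event: thinking_token
#     data: {"type": "thinking_token", "data": {"token": "def"}, "timestamp": "...", "sequence": 3}
#
# followed by a blank line, which is the frame delimiter SSE clients expect.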


class StreamHandler:
    """
    Streaming event handler.

    Best practices:
    1. Use ``astream_events`` to capture all LangGraph events.
    2. Convert internal events into a frontend-friendly format.
    3. Dispatch on multiple event types.
    """

    def __init__(self, task_id: str):
        self.task_id = task_id
        self._sequence = 0
        self._current_phase: Optional[str] = None
        self._current_node: Optional[str] = None
        self._thinking_buffer: List[str] = []
        self._tool_states: Dict[str, Dict] = {}

    def _next_sequence(self) -> int:
        """Return the next sequence number."""
        self._sequence += 1
        return self._sequence

    async def process_langgraph_event(self, event: Dict[str, Any]) -> Optional[StreamEvent]:
        """
        Process a single LangGraph event.

        Supported event kinds:
        - on_chain_start: chain/node started
        - on_chain_end: chain/node finished
        - on_chat_model_start: model started
        - on_chat_model_stream: model token stream
        - on_chat_model_end: model finished
        - on_tool_start: tool started
        - on_tool_end: tool finished
        - on_custom_event: custom event
        """
        event_kind = event.get("event", "")
        event_name = event.get("name", "")
        event_data = event.get("data", {})

        # LLM token stream
        if event_kind == "on_chat_model_stream":
            return await self._handle_llm_stream(event_data, event_name)

        # LLM started
        elif event_kind == "on_chat_model_start":
            return await self._handle_llm_start(event_data, event_name)

        # LLM finished
        elif event_kind == "on_chat_model_end":
            return await self._handle_llm_end(event_data, event_name)

        # Tool started
        elif event_kind == "on_tool_start":
            return await self._handle_tool_start(event_name, event_data)

        # Tool finished
        elif event_kind == "on_tool_end":
            return await self._handle_tool_end(event_name, event_data)

        # Node started
        elif event_kind == "on_chain_start" and self._is_node_event(event_name):
            return await self._handle_node_start(event_name, event_data)

        # Node finished
        elif event_kind == "on_chain_end" and self._is_node_event(event_name):
            return await self._handle_node_end(event_name, event_data)

        # Custom event
        elif event_kind == "on_custom_event":
            return await self._handle_custom_event(event_name, event_data)

        return None

    def _is_node_event(self, name: str) -> bool:
        """Return True if the event name refers to a workflow node."""
        node_names = [
            "recon", "analysis", "verification", "report",
            "ReconNode", "AnalysisNode", "VerificationNode", "ReportNode",
        ]
        return any(n.lower() in name.lower() for n in node_names)

    async def _handle_llm_start(self, data: Dict, name: str) -> StreamEvent:
        """Handle an LLM start event."""
        self._thinking_buffer = []

        return StreamEvent(
            event_type=StreamEventType.THINKING_START,
            sequence=self._next_sequence(),
            node_name=self._current_node,
            phase=self._current_phase,
            data={
                "model": name,
                "message": "🤔 Thinking...",
            },
        )

    async def _handle_llm_stream(self, data: Dict, name: str) -> Optional[StreamEvent]:
        """Handle an LLM token-stream event."""
        chunk = data.get("chunk")
        if not chunk:
            return None

        # Extract the token content
        content = ""
        if hasattr(chunk, "content"):
            content = chunk.content
        elif isinstance(chunk, dict):
            content = chunk.get("content", "")

        if not content:
            return None

        # Append to the thinking buffer
        self._thinking_buffer.append(content)

        return StreamEvent(
            event_type=StreamEventType.THINKING_TOKEN,
            sequence=self._next_sequence(),
            node_name=self._current_node,
            phase=self._current_phase,
            data={
                "token": content,
                "accumulated": "".join(self._thinking_buffer),
            },
        )

    async def _handle_llm_end(self, data: Dict, name: str) -> StreamEvent:
        """Handle an LLM end event."""
        full_response = "".join(self._thinking_buffer)
        self._thinking_buffer = []

        # Extract token usage; usage_metadata may be a dict (TypedDict) or an
        # object with attributes, depending on the message class in use.
        usage = {}
        output = data.get("output")
        usage_metadata = getattr(output, "usage_metadata", None) if output else None
        if isinstance(usage_metadata, dict):
            usage = {
                "input_tokens": usage_metadata.get("input_tokens", 0),
                "output_tokens": usage_metadata.get("output_tokens", 0),
            }
        elif usage_metadata is not None:
            usage = {
                "input_tokens": getattr(usage_metadata, "input_tokens", 0),
                "output_tokens": getattr(usage_metadata, "output_tokens", 0),
            }

        return StreamEvent(
            event_type=StreamEventType.THINKING_END,
            sequence=self._next_sequence(),
            node_name=self._current_node,
            phase=self._current_phase,
            data={
                "response": full_response[:2000],  # truncate long responses
                "usage": usage,
                "message": "💡 Thinking finished",
            },
        )

    async def _handle_tool_start(self, tool_name: str, data: Dict) -> StreamEvent:
        """Handle a tool start event."""
        tool_input = data.get("input", {})

        # Record the tool state so the duration can be computed on completion
        self._tool_states[tool_name] = {
            "start_time": time.time(),
            "input": tool_input,
        }

        return StreamEvent(
            event_type=StreamEventType.TOOL_CALL_START,
            sequence=self._next_sequence(),
            node_name=self._current_node,
            phase=self._current_phase,
            tool_name=tool_name,
            data={
                "tool_name": tool_name,
                "input": self._truncate_data(tool_input),
                "message": f"🔧 Calling tool: {tool_name}",
            },
        )

    async def _handle_tool_end(self, tool_name: str, data: Dict) -> StreamEvent:
        """Handle a tool end event."""
        # Compute the execution time
        duration_ms = 0
        if tool_name in self._tool_states:
            start_time = self._tool_states[tool_name].get("start_time", time.time())
            duration_ms = int((time.time() - start_time) * 1000)
            del self._tool_states[tool_name]

        # Extract the output
        output = data.get("output", "")
        if hasattr(output, "content"):
            output = output.content

        return StreamEvent(
            event_type=StreamEventType.TOOL_CALL_END,
            sequence=self._next_sequence(),
            node_name=self._current_node,
            phase=self._current_phase,
            tool_name=tool_name,
            data={
                "tool_name": tool_name,
                "output": self._truncate_data(output),
                "duration_ms": duration_ms,
                "message": f"✅ Tool {tool_name} finished ({duration_ms}ms)",
            },
        )

    async def _handle_node_start(self, node_name: str, data: Dict) -> StreamEvent:
        """Handle a node start event."""
        self._current_node = node_name

        # Map the node to a phase
        phase_map = {
            "recon": "reconnaissance",
            "analysis": "analysis",
            "verification": "verification",
            "report": "reporting",
        }

        for key, phase in phase_map.items():
            if key in node_name.lower():
                self._current_phase = phase
                break

        return StreamEvent(
            event_type=StreamEventType.NODE_START,
            sequence=self._next_sequence(),
            node_name=node_name,
            phase=self._current_phase,
            data={
                "node": node_name,
                "phase": self._current_phase,
                "message": f"▶️ Node started: {node_name}",
            },
        )

    async def _handle_node_end(self, node_name: str, data: Dict) -> StreamEvent:
        """Handle a node end event."""
        # Extract output information
        output = data.get("output", {})

        summary = {}
        if isinstance(output, dict):
            # Pull out the key counts
            if "findings" in output:
                summary["findings_count"] = len(output["findings"])
            if "entry_points" in output:
                summary["entry_points_count"] = len(output["entry_points"])
            if "high_risk_areas" in output:
                summary["high_risk_areas_count"] = len(output["high_risk_areas"])
            if "verified_findings" in output:
                summary["verified_count"] = len(output["verified_findings"])

        return StreamEvent(
            event_type=StreamEventType.NODE_END,
            sequence=self._next_sequence(),
            node_name=node_name,
            phase=self._current_phase,
            data={
                "node": node_name,
                "phase": self._current_phase,
                "summary": summary,
                "message": f"⏹️ Node finished: {node_name}",
            },
        )

    async def _handle_custom_event(self, event_name: str, data: Dict) -> StreamEvent:
        """Handle a custom event."""
        # Map custom event names to event types
        event_type_map = {
            "finding": StreamEventType.FINDING_NEW,
            "finding_verified": StreamEventType.FINDING_VERIFIED,
            "progress": StreamEventType.PROGRESS,
            "warning": StreamEventType.WARNING,
            "error": StreamEventType.ERROR,
        }

        event_type = event_type_map.get(event_name, StreamEventType.INFO)

        return StreamEvent(
            event_type=event_type,
            sequence=self._next_sequence(),
            node_name=self._current_node,
            phase=self._current_phase,
            data=data,
        )
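
    # Note (illustrative, an assumption about the emitting side): graph nodes can
    # surface these custom events with langchain-core's ``adispatch_custom_event``
    # helper, if the installed version provides it, e.g.
    #
    #     from langchain_core.callbacks import adispatch_custom_event
    #     await adispatch_custom_event("finding", {"title": "...", "severity": "high"})
    #
    # Events dispatched that way arrive here as ``on_custom_event`` with the
    # dispatched name and payload.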

    def _truncate_data(self, data: Any, max_length: int = 1000) -> Any:
        """Truncate data so events stay reasonably small."""
        if isinstance(data, str):
            return (data[:max_length] + "...") if len(data) > max_length else data
        elif isinstance(data, dict):
            return {k: self._truncate_data(v, max_length // 2) for k, v in list(data.items())[:10]}
        elif isinstance(data, list):
            return [self._truncate_data(item, max_length // len(data)) for item in data[:10]]
        else:
            return str(data)[:max_length]

    def create_progress_event(
        self,
        current: int,
        total: int,
        message: Optional[str] = None,
    ) -> StreamEvent:
        """Create a progress event."""
        percentage = (current / total * 100) if total > 0 else 0

        return StreamEvent(
            event_type=StreamEventType.PROGRESS,
            sequence=self._next_sequence(),
            node_name=self._current_node,
            phase=self._current_phase,
            data={
                "current": current,
                "total": total,
                "percentage": round(percentage, 1),
                "message": message or f"Progress: {current}/{total}",
            },
        )

    def create_finding_event(
        self,
        finding: Dict[str, Any],
        is_verified: bool = False,
    ) -> StreamEvent:
        """Create a finding event."""
        event_type = StreamEventType.FINDING_VERIFIED if is_verified else StreamEventType.FINDING_NEW

        return StreamEvent(
            event_type=event_type,
            sequence=self._next_sequence(),
            node_name=self._current_node,
            phase=self._current_phase,
            data={
                "title": finding.get("title", "Unknown"),
                "severity": finding.get("severity", "medium"),
                "vulnerability_type": finding.get("vulnerability_type", "other"),
                "file_path": finding.get("file_path"),
                "line_start": finding.get("line_start"),
                "is_verified": is_verified,
                "message": f"{'✅ Verified' if is_verified else '🔍 New finding'}: "
                           f"[{finding.get('severity', 'medium').upper()}] {finding.get('title', 'Unknown')}",
            },
        )

    def create_heartbeat(self) -> StreamEvent:
        """Create a heartbeat event."""
        return StreamEvent(
            event_type=StreamEventType.HEARTBEAT,
            sequence=self._sequence,  # heartbeats do not advance the sequence counter
            data={"message": "ping"},
        )
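

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the handler): how the handler is
# typically driven from a LangGraph event stream and exposed as SSE. ``graph``
# is assumed to be a compiled LangGraph graph (a Runnable exposing
# ``astream_events``); adjust the call to match the installed langgraph /
# langchain-core versions.
#
#     async def stream_task(graph, task_id: str, inputs: dict) -> AsyncGenerator[str, None]:
#         handler = StreamHandler(task_id)
#         async for raw_event in graph.astream_events(inputs, version="v2"):
#             out = await handler.process_langgraph_event(raw_event)
#             if out is not None:
#                 yield out.to_sse()
# ---------------------------------------------------------------------------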