"""
|
||
Agent 事件管理器
|
||
负责事件的创建、存储和推送
|
||
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
from typing import Optional, Dict, Any, List, AsyncGenerator, Callable
|
||
from datetime import datetime, timezone
|
||
from dataclasses import dataclass
|
||
import uuid
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
|
||
class AgentEventData:
|
||
"""Agent 事件数据"""
|
||
event_type: str
|
||
phase: Optional[str] = None
|
||
message: Optional[str] = None
|
||
tool_name: Optional[str] = None
|
||
tool_input: Optional[Dict[str, Any]] = None
|
||
tool_output: Optional[Dict[str, Any]] = None
|
||
tool_duration_ms: Optional[int] = None
|
||
finding_id: Optional[str] = None
|
||
tokens_used: int = 0
|
||
metadata: Optional[Dict[str, Any]] = None
|
||
|
||
def to_dict(self) -> Dict[str, Any]:
|
||
return {
|
||
"event_type": self.event_type,
|
||
"phase": self.phase,
|
||
"message": self.message,
|
||
"tool_name": self.tool_name,
|
||
"tool_input": self.tool_input,
|
||
"tool_output": self.tool_output,
|
||
"tool_duration_ms": self.tool_duration_ms,
|
||
"finding_id": self.finding_id,
|
||
"tokens_used": self.tokens_used,
|
||
"metadata": self.metadata,
|
||
}
|
||
|
||
|
||
class AgentEventEmitter:
    """
    Agent event emitter.

    Used to emit events while an agent task is executing.
    """

    def __init__(self, task_id: str, event_manager: 'EventManager'):
        self.task_id = task_id
        self.event_manager = event_manager
        self._sequence = 0
        self._current_phase = None

    async def emit(self, event_data: AgentEventData):
        """Emit an event."""
        self._sequence += 1
        event_data.phase = event_data.phase or self._current_phase

        await self.event_manager.add_event(
            task_id=self.task_id,
            sequence=self._sequence,
            **event_data.to_dict()
        )

    async def emit_phase_start(self, phase: str, message: Optional[str] = None):
        """Emit a phase-start event."""
        self._current_phase = phase
        await self.emit(AgentEventData(
            event_type="phase_start",
            phase=phase,
            message=message or f"Starting phase: {phase}",
        ))

    async def emit_phase_complete(self, phase: str, message: Optional[str] = None):
        """Emit a phase-complete event."""
        await self.emit(AgentEventData(
            event_type="phase_complete",
            phase=phase,
            message=message or f"Phase {phase} complete",
        ))

    async def emit_thinking(self, message: str, metadata: Optional[Dict] = None):
        """Emit a thinking event."""
        await self.emit(AgentEventData(
            event_type="thinking",
            message=message,
            metadata=metadata,
        ))

    async def emit_llm_thought(self, thought: str, iteration: int = 0):
        """Emit an LLM-thought event -- the key signal showing what the LLM is reasoning about."""
        display = thought[:500] + "..." if len(thought) > 500 else thought
        await self.emit(AgentEventData(
            event_type="llm_thought",
            message=f"💭 LLM thought:\n{display}",
            metadata={"thought": thought, "iteration": iteration},
        ))

    async def emit_llm_decision(self, decision: str, reason: str = ""):
        """Emit an LLM-decision event."""
        await self.emit(AgentEventData(
            event_type="llm_decision",
            message=f"💡 LLM decision: {decision}" + (f" ({reason})" if reason else ""),
            metadata={"decision": decision, "reason": reason},
        ))

    async def emit_llm_action(self, action: str, action_input: Dict):
        """Emit an LLM-action event."""
        # json is imported at module level; truncate the serialized input for display.
        input_str = json.dumps(action_input, ensure_ascii=False)[:200]
        await self.emit(AgentEventData(
            event_type="llm_action",
            message=f"⚡ LLM action: {action}\n  Input: {input_str}",
            metadata={"action": action, "action_input": action_input},
        ))

    async def emit_tool_call(
        self,
        tool_name: str,
        tool_input: Dict[str, Any],
        message: Optional[str] = None,
    ):
        """Emit a tool-call event."""
        await self.emit(AgentEventData(
            event_type="tool_call",
            tool_name=tool_name,
            tool_input=tool_input,
            message=message or f"Calling tool: {tool_name}",
        ))

    async def emit_tool_result(
        self,
        tool_name: str,
        tool_output: Any,
        duration_ms: int,
        message: Optional[str] = None,
    ):
        """Emit a tool-result event."""
        # Normalize the output so it is JSON-serializable.
        if hasattr(tool_output, 'to_dict'):
            output_data = tool_output.to_dict()
        elif isinstance(tool_output, str):
            output_data = {"result": tool_output[:2000]}  # truncate long output
        else:
            output_data = {"result": str(tool_output)[:2000]}

        await self.emit(AgentEventData(
            event_type="tool_result",
            tool_name=tool_name,
            tool_output=output_data,
            tool_duration_ms=duration_ms,
            message=message or f"Tool {tool_name} finished ({duration_ms}ms)",
        ))

    async def emit_finding(
        self,
        finding_id: str,
        title: str,
        severity: str,
        vulnerability_type: str,
        is_verified: bool = False,
    ):
        """Emit a vulnerability-finding event."""
        event_type = "finding_verified" if is_verified else "finding_new"
        await self.emit(AgentEventData(
            event_type=event_type,
            finding_id=finding_id,
            message=f"{'✅ Verified' if is_verified else '🔍 New finding'}: [{severity.upper()}] {title}",
            metadata={
                "id": finding_id,  # include the id so the frontend can reference it
                "title": title,
                "severity": severity,
                "vulnerability_type": vulnerability_type,
                "is_verified": is_verified,
            },
        ))

    async def emit_info(self, message: str, metadata: Optional[Dict] = None):
        """Emit an info event."""
        await self.emit(AgentEventData(
            event_type="info",
            message=message,
            metadata=metadata,
        ))

    async def emit_warning(self, message: str, metadata: Optional[Dict] = None):
        """Emit a warning event."""
        await self.emit(AgentEventData(
            event_type="warning",
            message=message,
            metadata=metadata,
        ))

    async def emit_error(self, message: str, metadata: Optional[Dict] = None):
        """Emit an error event."""
        await self.emit(AgentEventData(
            event_type="error",
            message=message,
            metadata=metadata,
        ))

    async def emit_progress(
        self,
        current: int,
        total: int,
        message: Optional[str] = None,
    ):
        """Emit a progress event."""
        percentage = (current / total * 100) if total > 0 else 0
        await self.emit(AgentEventData(
            event_type="progress",
            message=message or f"Progress: {current}/{total} ({percentage:.1f}%)",
            metadata={
                "current": current,
                "total": total,
                "percentage": percentage,
            },
        ))

    async def emit_task_complete(
        self,
        findings_count: int,
        duration_ms: int,
        message: Optional[str] = None,
    ):
        """Emit a task-complete event."""
        await self.emit(AgentEventData(
            event_type="task_complete",
            message=message or f"✅ Audit complete! Found {findings_count} vulnerabilities in {duration_ms/1000:.1f}s",
            metadata={
                "findings_count": findings_count,
                "duration_ms": duration_ms,
            },
        ))

    async def emit_task_error(self, error: str, message: Optional[str] = None):
        """Emit a task-error event."""
        await self.emit(AgentEventData(
            event_type="task_error",
            message=message or f"❌ Task failed: {error}",
            metadata={"error": error},
        ))

    async def emit_task_cancelled(self, message: Optional[str] = None):
        """Emit a task-cancelled event."""
        await self.emit(AgentEventData(
            event_type="task_cancel",
            message=message or "⚠️ Task cancelled",
        ))


class EventManager:
    """
    Event manager.

    Handles event storage and retrieval.
    """

    def __init__(self, db_session_factory=None):
        self.db_session_factory = db_session_factory
        self._event_queues: Dict[str, asyncio.Queue] = {}
        self._event_callbacks: Dict[str, List[Callable]] = {}

    async def add_event(
        self,
        task_id: str,
        event_type: str,
        sequence: int = 0,
        phase: Optional[str] = None,
        message: Optional[str] = None,
        tool_name: Optional[str] = None,
        tool_input: Optional[Dict] = None,
        tool_output: Optional[Dict] = None,
        tool_duration_ms: Optional[int] = None,
        finding_id: Optional[str] = None,
        tokens_used: int = 0,
        metadata: Optional[Dict] = None,
    ):
        """Add an event: persist it, enqueue it for streaming, and notify callbacks."""
        event_id = str(uuid.uuid4())
        timestamp = datetime.now(timezone.utc)

        event_data = {
            "id": event_id,
            "task_id": task_id,
            "event_type": event_type,
            "sequence": sequence,
            "phase": phase,
            "message": message,
            "tool_name": tool_name,
            "tool_input": tool_input,
            "tool_output": tool_output,
            "tool_duration_ms": tool_duration_ms,
            "finding_id": finding_id,
            "tokens_used": tokens_used,
            "metadata": metadata,
            "timestamp": timestamp.isoformat(),
        }

        # Persist to the database (skip high-frequency events such as thinking_token).
        skip_db_events = {"thinking_token"}
        if self.db_session_factory and event_type not in skip_db_events:
            try:
                await self._save_event_to_db(event_data)
            except Exception as e:
                logger.error(f"Failed to save event to database: {e}")

        # Push to the per-task queue (non-blocking).
        if task_id in self._event_queues:
            try:
                self._event_queues[task_id].put_nowait(event_data)
                # DEBUG: log when important events are enqueued.
                if event_type in ["thinking_start", "thinking_end", "dispatch", "task_complete", "task_error"]:
                    logger.info(f"[EventQueue] Added {event_type} to queue for task {task_id}, queue size: {self._event_queues[task_id].qsize()}")
                elif event_type == "thinking_token":
                    # Log only every 10th token.
                    if sequence % 10 == 0:
                        logger.debug(f"[EventQueue] Added thinking_token #{sequence} to queue, size: {self._event_queues[task_id].qsize()}")
            except asyncio.QueueFull:
                logger.warning(f"Event queue full for task {task_id}, dropping event: {event_type}")

        # Invoke registered callbacks.
        if task_id in self._event_callbacks:
            for callback in self._event_callbacks[task_id]:
                try:
                    if asyncio.iscoroutinefunction(callback):
                        await callback(event_data)
                    else:
                        callback(event_data)
                except Exception as e:
                    logger.error(f"Event callback error: {e}")

        return event_id

    async def _save_event_to_db(self, event_data: Dict):
        """Persist an event to the database."""
        from app.models.agent_task import AgentEvent

        # Strip characters that cannot be stored as text (e.g. binary tool output).
        def sanitize_string(s):
            """Remove NULL bytes and non-printable control characters from a string."""
            if s is None:
                return None
            if not isinstance(s, str):
                s = str(s)
            # Keep newlines and tabs; drop other control characters.
            return ''.join(
                char for char in s
                if char in '\n\r\t' or (ord(char) >= 32 and ord(char) != 127)
            )

        def sanitize_dict(d):
            """Recursively sanitize string values inside dicts and lists."""
            if d is None:
                return None
            if isinstance(d, dict):
                return {k: sanitize_dict(v) for k, v in d.items()}
            elif isinstance(d, list):
                return [sanitize_dict(item) for item in d]
            elif isinstance(d, str):
                return sanitize_string(d)
            return d

        async with self.db_session_factory() as db:
            event = AgentEvent(
                id=event_data["id"],
                task_id=event_data["task_id"],
                event_type=event_data["event_type"],
                sequence=event_data["sequence"],
                phase=event_data["phase"],
                message=sanitize_string(event_data["message"]),
                tool_name=event_data["tool_name"],
                tool_input=sanitize_dict(event_data["tool_input"]),
                tool_output=sanitize_dict(event_data["tool_output"]),
                tool_duration_ms=event_data["tool_duration_ms"],
                finding_id=event_data["finding_id"],
                tokens_used=event_data["tokens_used"],
                event_metadata=sanitize_dict(event_data["metadata"]),
            )
            db.add(event)
            await db.commit()

    def create_queue(self, task_id: str) -> asyncio.Queue:
        """Create the event queue for a task, or return the existing one."""
        if task_id not in self._event_queues:
            # Use a fairly large capacity so that many token events can be buffered.
            self._event_queues[task_id] = asyncio.Queue(maxsize=5000)
        return self._event_queues[task_id]

    def remove_queue(self, task_id: str):
        """Remove the event queue for a task."""
        if task_id in self._event_queues:
            del self._event_queues[task_id]

    def add_callback(self, task_id: str, callback: Callable):
        """Register an event callback for a task."""
        if task_id not in self._event_callbacks:
            self._event_callbacks[task_id] = []
        self._event_callbacks[task_id].append(callback)

    def remove_callback(self, task_id: str, callback: Callable):
        """Remove an event callback for a task."""
        if task_id in self._event_callbacks:
            self._event_callbacks[task_id].remove(callback)

    async def get_events(
        self,
        task_id: str,
        after_sequence: int = 0,
        limit: int = 100,
    ) -> List[Dict]:
        """Return persisted events for a task, ordered by sequence."""
        if not self.db_session_factory:
            return []

        from sqlalchemy.future import select
        from app.models.agent_task import AgentEvent

        async with self.db_session_factory() as db:
            result = await db.execute(
                select(AgentEvent)
                .where(AgentEvent.task_id == task_id)
                .where(AgentEvent.sequence > after_sequence)
                .order_by(AgentEvent.sequence)
                .limit(limit)
            )
            events = result.scalars().all()
            return [event.to_sse_dict() for event in events]

    async def stream_events(
        self,
        task_id: str,
        after_sequence: int = 0,
    ) -> AsyncGenerator[Dict, None]:
        """Stream events for a task.

        Important: this first drains events that were buffered in the queue before the
        SSE connection was established, then keeps pushing new events in real time.
        Only events with sequence > after_sequence are yielded.
        """
        logger.info(f"[StreamEvents] Task {task_id}: Starting stream with after_sequence={after_sequence}")

        # Get the existing queue (created by the AgentRunner during initialization).
        queue = self._event_queues.get(task_id)
        if not queue:
            # Fallback: create a new queue if one does not exist yet.
            queue = self.create_queue(task_id)
            logger.warning(f"Queue not found for task {task_id}, created new one")

        # CRITICAL FIX: snapshot the current queue size and drain only that many events.
        # The previous `while not queue.empty()` loop could spin forever because the LLM
        # keeps adding events while we drain.
        initial_queue_size = queue.qsize()
        logger.info(f"[StreamEvents] Task {task_id}: Draining {initial_queue_size} buffered events...")

        # Drain only the events that were already buffered when the connection was made.
        buffered_count = 0
        skipped_count = 0
        max_drain = initial_queue_size  # consume at most this many events to avoid an endless loop

        for _ in range(max_drain):
            try:
                buffered_event = queue.get_nowait()

                # Skip events with sequence <= after_sequence.
                event_sequence = buffered_event.get("sequence", 0)
                if event_sequence <= after_sequence:
                    skipped_count += 1
                    continue

                buffered_count += 1
                yield buffered_event

                # Small delay for buffered token events (much shorter than before, to avoid lag).
                event_type = buffered_event.get("event_type")
                if event_type == "thinking_token":
                    await asyncio.sleep(0.005)  # 5ms for tokens (reduced from 15ms)
                # Other event types are flushed immediately, with no delay.

                # Stop if this is a terminal event.
                if event_type in ["task_complete", "task_error", "task_cancel"]:
                    logger.info(f"[StreamEvents] Task {task_id} already completed, sent {buffered_count} buffered events (skipped {skipped_count})")
                    return
            except asyncio.QueueEmpty:
                break

        if buffered_count > 0 or skipped_count > 0:
            logger.info(f"[StreamEvents] Task {task_id}: Drained {buffered_count} buffered events, skipped {skipped_count}")

        # DEBUG: log the transition into the real-time loop.
        logger.info(f"[StreamEvents] Task {task_id}: Entering real-time loop, queue size: {queue.qsize()}")

        # Now push new events in real time.
        try:
            while True:
                try:
                    logger.debug(f"[StreamEvents] Task {task_id}: Waiting for next event from queue...")
                    event = await asyncio.wait_for(queue.get(), timeout=30)
                    logger.debug(f"[StreamEvents] Task {task_id}: Got event from queue: {event.get('event_type')}")

                    # Skip events with sequence <= after_sequence.
                    event_sequence = event.get("sequence", 0)
                    if event_sequence <= after_sequence:
                        logger.debug(f"[StreamEvents] Task {task_id}: Skipping event seq={event_sequence} (after_sequence={after_sequence})")
                        continue

                    # DEBUG: log when important events are yielded.
                    event_type = event.get("event_type")
                    if event_type in ["thinking_start", "thinking_end", "dispatch", "task_complete", "task_error"]:
                        logger.info(f"[StreamEvents] Yielding {event_type} (seq={event_sequence}) for task {task_id}")

                    yield event

                    # Tiny delay for thinking_token events to preserve the streaming effect.
                    if event_type == "thinking_token":
                        await asyncio.sleep(0.01)  # 10ms

                    # Stop on terminal events.
                    if event.get("event_type") in ["task_complete", "task_error", "task_cancel"]:
                        break

                except asyncio.TimeoutError:
                    # Send a heartbeat so the client knows the stream is still alive.
                    yield {"event_type": "heartbeat", "timestamp": datetime.now(timezone.utc).isoformat()}

        except GeneratorExit:
            # The SSE connection was closed by the client.
            logger.debug(f"SSE stream closed for task {task_id}")
            # Do not remove the queue here; the AgentRunner owns its lifecycle.

    def create_emitter(self, task_id: str) -> AgentEventEmitter:
        """Create an event emitter bound to a task."""
        return AgentEventEmitter(task_id, self)

    async def close(self):
        """Shut down the event manager and release resources."""
        # Remove all queues.
        for task_id in list(self._event_queues.keys()):
            self.remove_queue(task_id)

        # Clear all callbacks.
        self._event_callbacks.clear()

        logger.debug("EventManager closed")
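

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumption: this demo block is illustrative, not part of
# the production pipeline; it runs without a database and uses a made-up task id
# "demo-task", whereas in the real system the AgentRunner creates the queue and
# the SSE endpoint consumes stream_events()).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    async def _demo():
        manager = EventManager(db_session_factory=None)  # no persistence in this sketch
        manager.create_queue("demo-task")                # normally done by the AgentRunner
        emitter = manager.create_emitter("demo-task")

        # Emit a few events the way an agent run would.
        await emitter.emit_phase_start("recon", "Collecting project metadata")
        await emitter.emit_info("Scanning 3 contracts")
        await emitter.emit_task_complete(findings_count=0, duration_ms=1200)

        # Drain the buffered events the same way the SSE endpoint would; the stream
        # ends as soon as the terminal task_complete event is reached.
        async for event in manager.stream_events("demo-task"):
            print(event["sequence"], event["event_type"], event["message"])

        await manager.close()

    asyncio.run(_demo())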