""" Agent 事件管理器 负责事件的创建、存储和推送 """ import asyncio import json import logging from typing import Optional, Dict, Any, List, AsyncGenerator, Callable from datetime import datetime, timezone from dataclasses import dataclass import uuid logger = logging.getLogger(__name__) @dataclass class AgentEventData: """Agent 事件数据""" event_type: str phase: Optional[str] = None message: Optional[str] = None tool_name: Optional[str] = None tool_input: Optional[Dict[str, Any]] = None tool_output: Optional[Dict[str, Any]] = None tool_duration_ms: Optional[int] = None finding_id: Optional[str] = None tokens_used: int = 0 metadata: Optional[Dict[str, Any]] = None def to_dict(self) -> Dict[str, Any]: return { "event_type": self.event_type, "phase": self.phase, "message": self.message, "tool_name": self.tool_name, "tool_input": self.tool_input, "tool_output": self.tool_output, "tool_duration_ms": self.tool_duration_ms, "finding_id": self.finding_id, "tokens_used": self.tokens_used, "metadata": self.metadata, } class AgentEventEmitter: """ Agent 事件发射器 用于在 Agent 执行过程中发射事件 """ def __init__(self, task_id: str, event_manager: 'EventManager'): self.task_id = task_id self.event_manager = event_manager self._sequence = 0 self._current_phase = None async def emit(self, event_data: AgentEventData): """发射事件""" self._sequence += 1 event_data.phase = event_data.phase or self._current_phase await self.event_manager.add_event( task_id=self.task_id, sequence=self._sequence, **event_data.to_dict() ) async def emit_phase_start(self, phase: str, message: Optional[str] = None): """发射阶段开始事件""" self._current_phase = phase await self.emit(AgentEventData( event_type="phase_start", phase=phase, message=message or f"开始 {phase} 阶段", )) async def emit_phase_complete(self, phase: str, message: Optional[str] = None): """发射阶段完成事件""" await self.emit(AgentEventData( event_type="phase_complete", phase=phase, message=message or f"{phase} 阶段完成", )) async def emit_thinking(self, message: str, metadata: Optional[Dict] = None): """发射思考事件""" await self.emit(AgentEventData( event_type="thinking", message=message, metadata=metadata, )) async def emit_tool_call( self, tool_name: str, tool_input: Dict[str, Any], message: Optional[str] = None, ): """发射工具调用事件""" await self.emit(AgentEventData( event_type="tool_call", tool_name=tool_name, tool_input=tool_input, message=message or f"调用工具: {tool_name}", )) async def emit_tool_result( self, tool_name: str, tool_output: Any, duration_ms: int, message: Optional[str] = None, ): """发射工具结果事件""" # 处理输出,确保可序列化 if hasattr(tool_output, 'to_dict'): output_data = tool_output.to_dict() elif isinstance(tool_output, str): output_data = {"result": tool_output[:2000]} # 截断长输出 else: output_data = {"result": str(tool_output)[:2000]} await self.emit(AgentEventData( event_type="tool_result", tool_name=tool_name, tool_output=output_data, tool_duration_ms=duration_ms, message=message or f"工具 {tool_name} 执行完成 ({duration_ms}ms)", )) async def emit_finding( self, finding_id: str, title: str, severity: str, vulnerability_type: str, is_verified: bool = False, ): """发射漏洞发现事件""" event_type = "finding_verified" if is_verified else "finding_new" await self.emit(AgentEventData( event_type=event_type, finding_id=finding_id, message=f"{'✅ 已验证' if is_verified else '🔍 新发现'}: [{severity.upper()}] {title}", metadata={ "title": title, "severity": severity, "vulnerability_type": vulnerability_type, "is_verified": is_verified, }, )) async def emit_info(self, message: str, metadata: Optional[Dict] = None): """发射信息事件""" await self.emit(AgentEventData( event_type="info", message=message, metadata=metadata, )) async def emit_warning(self, message: str, metadata: Optional[Dict] = None): """发射警告事件""" await self.emit(AgentEventData( event_type="warning", message=message, metadata=metadata, )) async def emit_error(self, message: str, metadata: Optional[Dict] = None): """发射错误事件""" await self.emit(AgentEventData( event_type="error", message=message, metadata=metadata, )) async def emit_progress( self, current: int, total: int, message: Optional[str] = None, ): """发射进度事件""" percentage = (current / total * 100) if total > 0 else 0 await self.emit(AgentEventData( event_type="progress", message=message or f"进度: {current}/{total} ({percentage:.1f}%)", metadata={ "current": current, "total": total, "percentage": percentage, }, )) class EventManager: """ 事件管理器 负责事件的存储和检索 """ def __init__(self, db_session_factory=None): self.db_session_factory = db_session_factory self._event_queues: Dict[str, asyncio.Queue] = {} self._event_callbacks: Dict[str, List[Callable]] = {} async def add_event( self, task_id: str, event_type: str, sequence: int = 0, phase: Optional[str] = None, message: Optional[str] = None, tool_name: Optional[str] = None, tool_input: Optional[Dict] = None, tool_output: Optional[Dict] = None, tool_duration_ms: Optional[int] = None, finding_id: Optional[str] = None, tokens_used: int = 0, metadata: Optional[Dict] = None, ): """添加事件""" event_id = str(uuid.uuid4()) timestamp = datetime.now(timezone.utc) event_data = { "id": event_id, "task_id": task_id, "event_type": event_type, "sequence": sequence, "phase": phase, "message": message, "tool_name": tool_name, "tool_input": tool_input, "tool_output": tool_output, "tool_duration_ms": tool_duration_ms, "finding_id": finding_id, "tokens_used": tokens_used, "metadata": metadata, "timestamp": timestamp.isoformat(), } # 保存到数据库 if self.db_session_factory: try: await self._save_event_to_db(event_data) except Exception as e: logger.error(f"Failed to save event to database: {e}") # 推送到队列 if task_id in self._event_queues: await self._event_queues[task_id].put(event_data) # 调用回调 if task_id in self._event_callbacks: for callback in self._event_callbacks[task_id]: try: if asyncio.iscoroutinefunction(callback): await callback(event_data) else: callback(event_data) except Exception as e: logger.error(f"Event callback error: {e}") return event_id async def _save_event_to_db(self, event_data: Dict): """保存事件到数据库""" from app.models.agent_task import AgentEvent async with self.db_session_factory() as db: event = AgentEvent( id=event_data["id"], task_id=event_data["task_id"], event_type=event_data["event_type"], sequence=event_data["sequence"], phase=event_data["phase"], message=event_data["message"], tool_name=event_data["tool_name"], tool_input=event_data["tool_input"], tool_output=event_data["tool_output"], tool_duration_ms=event_data["tool_duration_ms"], finding_id=event_data["finding_id"], tokens_used=event_data["tokens_used"], event_metadata=event_data["metadata"], ) db.add(event) await db.commit() def create_queue(self, task_id: str) -> asyncio.Queue: """创建事件队列""" if task_id not in self._event_queues: self._event_queues[task_id] = asyncio.Queue() return self._event_queues[task_id] def remove_queue(self, task_id: str): """移除事件队列""" if task_id in self._event_queues: del self._event_queues[task_id] def add_callback(self, task_id: str, callback: Callable): """添加事件回调""" if task_id not in self._event_callbacks: self._event_callbacks[task_id] = [] self._event_callbacks[task_id].append(callback) def remove_callback(self, task_id: str, callback: Callable): """移除事件回调""" if task_id in self._event_callbacks: self._event_callbacks[task_id].remove(callback) async def get_events( self, task_id: str, after_sequence: int = 0, limit: int = 100, ) -> List[Dict]: """获取事件列表""" if not self.db_session_factory: return [] from sqlalchemy.future import select from app.models.agent_task import AgentEvent async with self.db_session_factory() as db: result = await db.execute( select(AgentEvent) .where(AgentEvent.task_id == task_id) .where(AgentEvent.sequence > after_sequence) .order_by(AgentEvent.sequence) .limit(limit) ) events = result.scalars().all() return [event.to_sse_dict() for event in events] async def stream_events( self, task_id: str, after_sequence: int = 0, ) -> AsyncGenerator[Dict, None]: """流式获取事件""" queue = self.create_queue(task_id) # 先发送历史事件 history = await self.get_events(task_id, after_sequence) for event in history: yield event # 然后实时推送新事件 try: while True: try: event = await asyncio.wait_for(queue.get(), timeout=30) yield event # 检查是否是结束事件 if event.get("event_type") in ["task_complete", "task_error", "task_cancel"]: break except asyncio.TimeoutError: # 发送心跳 yield {"event_type": "heartbeat", "timestamp": datetime.now(timezone.utc).isoformat()} finally: self.remove_queue(task_id) def create_emitter(self, task_id: str) -> AgentEventEmitter: """创建事件发射器""" return AgentEventEmitter(task_id, self)