CodeReview/backend/app/services/agent/event_manager.py

372 lines
12 KiB
Python

"""
Agent 事件管理器
负责事件的创建、存储和推送
"""
import asyncio
import inspect
import json
import logging
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional
logger = logging.getLogger(__name__)
@dataclass
class AgentEventData:
    """Payload of a single agent event.

    Only ``event_type`` is required. Every other field defaults to ``None``,
    except ``tokens_used``, which defaults to ``0``.
    """

    event_type: str
    phase: Optional[str] = None
    message: Optional[str] = None
    tool_name: Optional[str] = None
    tool_input: Optional[Dict[str, Any]] = None
    tool_output: Optional[Dict[str, Any]] = None
    tool_duration_ms: Optional[int] = None
    finding_id: Optional[str] = None
    tokens_used: int = 0
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a plain dict, keyed by field name.

        Keys appear in field-declaration order. Nested dict values
        (``tool_input``, ``tool_output``, ``metadata``) are returned by
        reference, not copied.
        """
        return dict(
            event_type=self.event_type,
            phase=self.phase,
            message=self.message,
            tool_name=self.tool_name,
            tool_input=self.tool_input,
            tool_output=self.tool_output,
            tool_duration_ms=self.tool_duration_ms,
            finding_id=self.finding_id,
            tokens_used=self.tokens_used,
            metadata=self.metadata,
        )
class AgentEventEmitter:
    """
    Agent event emitter.

    Bound to one task. It emits events while an agent runs. Each event gets
    a per-emitter, increasing sequence number and is forwarded to the event
    manager's ``add_event``.
    """

    def __init__(self, task_id: str, event_manager: 'EventManager'):
        self.task_id = task_id
        self.event_manager = event_manager
        # Sequence number of the most recently emitted event (0 = none yet).
        self._sequence = 0
        # Phase stamped onto events that do not name one; set by emit_phase_start().
        self._current_phase = None

    async def emit(self, event_data: AgentEventData):
        """Number ``event_data`` and forward it to the event manager.

        A falsy ``event_data.phase`` is overwritten in place with the current phase.
        """
        if not event_data.phase:
            event_data.phase = self._current_phase
        self._sequence += 1
        payload = event_data.to_dict()
        await self.event_manager.add_event(
            task_id=self.task_id,
            sequence=self._sequence,
            **payload,
        )

    async def _emit_message(
        self,
        event_type: str,
        message: str,
        metadata: Optional[Dict] = None,
    ):
        """Emit an event that carries only a message and optional metadata."""
        await self.emit(AgentEventData(event_type=event_type, message=message, metadata=metadata))

    async def emit_phase_start(self, phase: str, message: Optional[str] = None):
        """Make ``phase`` the current phase and emit a ``phase_start`` event."""
        self._current_phase = phase
        text = message if message else f"开始 {phase} 阶段"
        await self.emit(AgentEventData(event_type="phase_start", phase=phase, message=text))

    async def emit_phase_complete(self, phase: str, message: Optional[str] = None):
        """Emit a ``phase_complete`` event (the current phase is left unchanged)."""
        text = message if message else f"{phase} 阶段完成"
        await self.emit(AgentEventData(event_type="phase_complete", phase=phase, message=text))

    async def emit_thinking(self, message: str, metadata: Optional[Dict] = None):
        """Emit a ``thinking`` event."""
        await self._emit_message("thinking", message, metadata)

    async def emit_tool_call(
        self,
        tool_name: str,
        tool_input: Dict[str, Any],
        message: Optional[str] = None,
    ):
        """Emit a ``tool_call`` event describing a tool invocation."""
        event = AgentEventData(
            event_type="tool_call",
            tool_name=tool_name,
            tool_input=tool_input,
            message=message if message else f"调用工具: {tool_name}",
        )
        await self.emit(event)

    async def emit_tool_result(
        self,
        tool_name: str,
        tool_output: Any,
        duration_ms: int,
        message: Optional[str] = None,
    ):
        """Emit a ``tool_result`` event.

        Objects exposing ``to_dict()`` are converted with it as-is. Anything
        else is wrapped as ``{"result": <text>}``, with the text cut to its
        first 2000 characters.
        """
        if hasattr(tool_output, 'to_dict'):
            output_data = tool_output.to_dict()
        else:
            text = tool_output if isinstance(tool_output, str) else str(tool_output)
            output_data = {"result": text[:2000]}

        event = AgentEventData(
            event_type="tool_result",
            tool_name=tool_name,
            tool_output=output_data,
            tool_duration_ms=duration_ms,
            message=message if message else f"工具 {tool_name} 执行完成 ({duration_ms}ms)",
        )
        await self.emit(event)

    async def emit_finding(
        self,
        finding_id: str,
        title: str,
        severity: str,
        vulnerability_type: str,
        is_verified: bool = False,
    ):
        """Emit ``finding_verified`` or ``finding_new`` for a vulnerability finding."""
        if is_verified:
            event_type, label = "finding_verified", '✅ 已验证'
        else:
            event_type, label = "finding_new", '🔍 新发现'

        details = {
            "title": title,
            "severity": severity,
            "vulnerability_type": vulnerability_type,
            "is_verified": is_verified,
        }
        await self.emit(AgentEventData(
            event_type=event_type,
            finding_id=finding_id,
            message=f"{label}: [{severity.upper()}] {title}",
            metadata=details,
        ))

    async def emit_info(self, message: str, metadata: Optional[Dict] = None):
        """Emit an ``info`` event."""
        await self._emit_message("info", message, metadata)

    async def emit_warning(self, message: str, metadata: Optional[Dict] = None):
        """Emit a ``warning`` event."""
        await self._emit_message("warning", message, metadata)

    async def emit_error(self, message: str, metadata: Optional[Dict] = None):
        """Emit an ``error`` event."""
        await self._emit_message("error", message, metadata)

    async def emit_progress(
        self,
        current: int,
        total: int,
        message: Optional[str] = None,
    ):
        """Emit a ``progress`` event; the percentage is 0 when ``total`` is not positive."""
        if total > 0:
            percentage = current / total * 100
        else:
            percentage = 0

        progress = {
            "current": current,
            "total": total,
            "percentage": percentage,
        }
        await self.emit(AgentEventData(
            event_type="progress",
            message=message if message else f"进度: {current}/{total} ({percentage:.1f}%)",
            metadata=progress,
        ))
class EventManager:
    """
    Event manager.

    Persists events (when ``db_session_factory`` is set) and delivers them to
    per-task in-memory queues and callbacks. History is read back from the
    database; ``stream_events`` replays history and then follows live events.
    """

    # Event types after which a task's event stream is closed.
    _TERMINAL_EVENT_TYPES = frozenset({"task_complete", "task_error", "task_cancel"})

    def __init__(self, db_session_factory=None):
        # Used as ``async with db_session_factory() as db`` below; presumably an
        # async SQLAlchemy session factory -- confirm against the caller.
        # When None, events are not persisted and get_events() returns [].
        self.db_session_factory = db_session_factory
        self._event_queues: Dict[str, asyncio.Queue] = {}
        self._event_callbacks: Dict[str, List[Callable]] = {}

    async def add_event(
        self,
        task_id: str,
        event_type: str,
        sequence: int = 0,
        phase: Optional[str] = None,
        message: Optional[str] = None,
        tool_name: Optional[str] = None,
        tool_input: Optional[Dict] = None,
        tool_output: Optional[Dict] = None,
        tool_duration_ms: Optional[int] = None,
        finding_id: Optional[str] = None,
        tokens_used: int = 0,
        metadata: Optional[Dict] = None,
    ):
        """Record an event and deliver it to the task's queue and callbacks.

        Saving to the database and running callbacks are best-effort. Their
        failures are logged with a traceback and never reach the caller.

        Returns:
            The new event's id (a random UUID4 string).
        """
        event_id = str(uuid.uuid4())
        timestamp = datetime.now(timezone.utc)

        event_data = {
            "id": event_id,
            "task_id": task_id,
            "event_type": event_type,
            "sequence": sequence,
            "phase": phase,
            "message": message,
            "tool_name": tool_name,
            "tool_input": tool_input,
            "tool_output": tool_output,
            "tool_duration_ms": tool_duration_ms,
            "finding_id": finding_id,
            "tokens_used": tokens_used,
            "metadata": metadata,
            "timestamp": timestamp.isoformat(),
        }

        # Persist first, so a live subscriber never sees an event that a later
        # history query could not return.
        if self.db_session_factory:
            try:
                await self._save_event_to_db(event_data)
            except Exception as e:
                logger.exception("Failed to save event to database: %s", e)

        queue = self._event_queues.get(task_id)
        if queue is not None:
            await queue.put(event_data)

        # Iterate over a snapshot. A callback may (un)register callbacks for
        # this task, e.g. remove itself. Mutating the live list while iterating
        # it would silently skip the following callback.
        for callback in list(self._event_callbacks.get(task_id, ())):
            try:
                result = callback(event_data)
                # Await any awaitable result. This also covers async callables
                # that are not ``async def`` functions, e.g. an instance with an
                # ``async __call__``.
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                logger.exception("Event callback error: %s", e)

        return event_id

    async def _save_event_to_db(self, event_data: Dict):
        """Insert ``event_data`` as an ``AgentEvent`` row and commit it."""
        from app.models.agent_task import AgentEvent

        async with self.db_session_factory() as db:
            event = AgentEvent(
                id=event_data["id"],
                task_id=event_data["task_id"],
                event_type=event_data["event_type"],
                sequence=event_data["sequence"],
                phase=event_data["phase"],
                message=event_data["message"],
                tool_name=event_data["tool_name"],
                tool_input=event_data["tool_input"],
                tool_output=event_data["tool_output"],
                tool_duration_ms=event_data["tool_duration_ms"],
                finding_id=event_data["finding_id"],
                tokens_used=event_data["tokens_used"],
                event_metadata=event_data["metadata"],
            )
            db.add(event)
            await db.commit()

    def create_queue(self, task_id: str) -> asyncio.Queue:
        """Return the task's live-event queue, creating it if needed.

        NOTE(review): a single queue is shared per task. Concurrent
        ``stream_events`` calls for the same task therefore split events between
        them, and the first to finish removes the queue for all. Consider
        per-subscriber queues.
        """
        if task_id not in self._event_queues:
            self._event_queues[task_id] = asyncio.Queue()
        return self._event_queues[task_id]

    def remove_queue(self, task_id: str):
        """Drop the task's live-event queue, if any."""
        self._event_queues.pop(task_id, None)

    def add_callback(self, task_id: str, callback: Callable):
        """Register ``callback`` (sync or async) to receive the task's events."""
        self._event_callbacks.setdefault(task_id, []).append(callback)

    def remove_callback(self, task_id: str, callback: Callable):
        """Unregister ``callback``; unknown tasks or callbacks are ignored.

        The task's entry is dropped once its last callback is removed, so the
        mapping does not grow without bound.
        """
        callbacks = self._event_callbacks.get(task_id)
        if not callbacks:
            return
        try:
            callbacks.remove(callback)
        except ValueError:
            return
        if not callbacks:
            del self._event_callbacks[task_id]

    async def get_events(
        self,
        task_id: str,
        after_sequence: int = 0,
        limit: Optional[int] = 100,
    ) -> List[Dict]:
        """Return persisted events with ``sequence > after_sequence``, in sequence order.

        Args:
            task_id: Task whose events to load.
            after_sequence: Only events with a strictly greater sequence are returned.
            limit: Maximum number of events; ``None`` returns all of them.

        Returns:
            Each event's ``to_sse_dict()``, or ``[]`` when no database is configured.
        """
        if not self.db_session_factory:
            return []

        from sqlalchemy.future import select
        from app.models.agent_task import AgentEvent

        query = (
            select(AgentEvent)
            .where(AgentEvent.task_id == task_id)
            .where(AgentEvent.sequence > after_sequence)
            .order_by(AgentEvent.sequence)
        )
        if limit is not None:
            query = query.limit(limit)

        async with self.db_session_factory() as db:
            result = await db.execute(query)
            events = result.scalars().all()
            return [event.to_sse_dict() for event in events]

    async def stream_events(
        self,
        task_id: str,
        after_sequence: int = 0,
    ) -> AsyncGenerator[Dict, None]:
        """Yield the task's history after ``after_sequence``, then its live events.

        The stream ends after a terminal event (see ``_TERMINAL_EVENT_TYPES``),
        including one already present in history. While idle, a ``heartbeat``
        event is yielded every 30 seconds.
        """
        # Subscribe before reading history so events added while the query runs
        # are not lost. Such events may then arrive both ways, hence the
        # id-based de-duplication below.
        queue = self.create_queue(task_id)
        try:
            # Replay the full history: a page limit here would silently drop
            # everything past the first page.
            history = await self.get_events(task_id, after_sequence, limit=None)
            replayed_ids = set()
            for event in history:
                replayed_ids.add(event.get("id"))
                yield event

            # The task already finished; waiting would only produce heartbeats forever.
            if any(event.get("event_type") in self._TERMINAL_EVENT_TYPES for event in history):
                return

            while True:
                try:
                    event = await asyncio.wait_for(queue.get(), timeout=30)
                except asyncio.TimeoutError:
                    # Heartbeat keeps idle connections alive.
                    yield {"event_type": "heartbeat", "timestamp": datetime.now(timezone.utc).isoformat()}
                    continue

                if event.get("id") in replayed_ids:
                    continue
                yield event

                if event.get("event_type") in self._TERMINAL_EVENT_TYPES:
                    break
        finally:
            # Also runs if the consumer stops during history replay.
            self.remove_queue(task_id)

    def create_emitter(self, task_id: str) -> 'AgentEventEmitter':
        """Create an event emitter bound to ``task_id`` and this manager."""
        return AgentEventEmitter(task_id, self)