372 lines
12 KiB
Python
372 lines
12 KiB
Python
"""
|
|
Agent 事件管理器
|
|
负责事件的创建、存储和推送
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Optional, Dict, Any, List, AsyncGenerator, Callable
|
|
from datetime import datetime, timezone
|
|
from dataclasses import dataclass
|
|
import uuid
|
|
|
|
# Module-level logger, named after this module; used to report event persistence
# and callback failures.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class AgentEventData:
    """
    Payload of a single agent event.

    Only ``event_type`` is required; all other fields are optional. The
    emitter forwards the fields as keyword arguments via :meth:`to_dict`.
    """

    event_type: str
    phase: Optional[str] = None
    message: Optional[str] = None
    tool_name: Optional[str] = None
    tool_input: Optional[Dict[str, Any]] = None
    tool_output: Optional[Dict[str, Any]] = None
    tool_duration_ms: Optional[int] = None
    finding_id: Optional[str] = None
    tokens_used: int = 0
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a plain dict, keys in declaration order.

        Values are not copied: nested dicts in the result are the very
        objects held by this instance.
        """
        return dict(
            event_type=self.event_type,
            phase=self.phase,
            message=self.message,
            tool_name=self.tool_name,
            tool_input=self.tool_input,
            tool_output=self.tool_output,
            tool_duration_ms=self.tool_duration_ms,
            finding_id=self.finding_id,
            tokens_used=self.tokens_used,
            metadata=self.metadata,
        )
|
|
|
|
|
|
class AgentEventEmitter:
    """
    Agent event emitter.

    Used while an agent runs to publish events for one task. Each event is
    stamped with this emitter's next sequence number and forwarded to
    ``event_manager.add_event`` together with ``task_id``. Events that carry
    no phase inherit the one recorded by the latest ``emit_phase_start``.
    """

    # Non-serialisable tool output is stringified and cut to this many characters.
    _TOOL_OUTPUT_MAX_CHARS = 2000

    def __init__(self, task_id: str, event_manager: 'EventManager'):
        self.task_id = task_id
        self.event_manager = event_manager
        # Count of events emitted so far; emit() pre-increments, so the first
        # event of this emitter carries sequence 1.
        self._sequence = 0
        # Phase inherited by events that do not set one (see emit_phase_start).
        self._current_phase = None

    async def emit(self, event_data: AgentEventData):
        """Forward *event_data* to the event manager with the next sequence number.

        Side effect: when ``event_data.phase`` is falsy it is overwritten, on
        the passed-in object, with the current phase.
        """
        self._sequence += 1
        if not event_data.phase:
            event_data.phase = self._current_phase

        payload = event_data.to_dict()
        await self.event_manager.add_event(
            task_id=self.task_id,
            sequence=self._sequence,
            **payload,
        )

    async def _emit_message(self, event_type: str, message: str, metadata: Optional[Dict] = None):
        """Emit an event made of just a type, a message and optional metadata."""
        await self.emit(AgentEventData(event_type=event_type, message=message, metadata=metadata))

    async def emit_phase_start(self, phase: str, message: Optional[str] = None):
        """Make *phase* the current phase, then emit a ``phase_start`` event."""
        self._current_phase = phase
        text = message if message else f"开始 {phase} 阶段"
        await self.emit(AgentEventData(event_type="phase_start", phase=phase, message=text))

    async def emit_phase_complete(self, phase: str, message: Optional[str] = None):
        """Emit a ``phase_complete`` event; the current phase is left as is."""
        text = message if message else f"{phase} 阶段完成"
        await self.emit(AgentEventData(event_type="phase_complete", phase=phase, message=text))

    async def emit_thinking(self, message: str, metadata: Optional[Dict] = None):
        """Emit a ``thinking`` event."""
        await self._emit_message("thinking", message, metadata)

    async def emit_tool_call(
        self,
        tool_name: str,
        tool_input: Dict[str, Any],
        message: Optional[str] = None,
    ):
        """Emit a ``tool_call`` event describing an invocation of *tool_name*."""
        text = message if message else f"调用工具: {tool_name}"
        await self.emit(AgentEventData(
            event_type="tool_call",
            tool_name=tool_name,
            tool_input=tool_input,
            message=text,
        ))

    @classmethod
    def _serialize_tool_output(cls, tool_output: Any) -> Dict[str, Any]:
        """Turn raw tool output into a dict.

        Objects exposing ``to_dict`` are converted with it (no truncation);
        anything else becomes ``{"result": <text>}``, where the text is
        ``str(tool_output)`` (or the string itself) cut to the size limit.
        """
        if hasattr(tool_output, 'to_dict'):
            return tool_output.to_dict()
        text = tool_output if isinstance(tool_output, str) else str(tool_output)
        return {"result": text[:cls._TOOL_OUTPUT_MAX_CHARS]}

    async def emit_tool_result(
        self,
        tool_name: str,
        tool_output: Any,
        duration_ms: int,
        message: Optional[str] = None,
    ):
        """Emit a ``tool_result`` event with serialised output and its duration."""
        output_data = self._serialize_tool_output(tool_output)
        text = message if message else f"工具 {tool_name} 执行完成 ({duration_ms}ms)"
        await self.emit(AgentEventData(
            event_type="tool_result",
            tool_name=tool_name,
            tool_output=output_data,
            tool_duration_ms=duration_ms,
            message=text,
        ))

    async def emit_finding(
        self,
        finding_id: str,
        title: str,
        severity: str,
        vulnerability_type: str,
        is_verified: bool = False,
    ):
        """Emit ``finding_verified`` or ``finding_new`` for a vulnerability finding."""
        if is_verified:
            event_type, badge = "finding_verified", '✅ 已验证'
        else:
            event_type, badge = "finding_new", '🔍 新发现'
        text = f"{badge}: [{severity.upper()}] {title}"
        details = {
            "title": title,
            "severity": severity,
            "vulnerability_type": vulnerability_type,
            "is_verified": is_verified,
        }
        await self.emit(AgentEventData(
            event_type=event_type,
            finding_id=finding_id,
            message=text,
            metadata=details,
        ))

    async def emit_info(self, message: str, metadata: Optional[Dict] = None):
        """Emit an ``info`` event."""
        await self._emit_message("info", message, metadata)

    async def emit_warning(self, message: str, metadata: Optional[Dict] = None):
        """Emit a ``warning`` event."""
        await self._emit_message("warning", message, metadata)

    async def emit_error(self, message: str, metadata: Optional[Dict] = None):
        """Emit an ``error`` event."""
        await self._emit_message("error", message, metadata)

    async def emit_progress(
        self,
        current: int,
        total: int,
        message: Optional[str] = None,
    ):
        """Emit a ``progress`` event; percentage is 0 when *total* is not positive."""
        if total > 0:
            percentage = current / total * 100
        else:
            percentage = 0
        text = message if message else f"进度: {current}/{total} ({percentage:.1f}%)"
        await self.emit(AgentEventData(
            event_type="progress",
            message=text,
            metadata={
                "current": current,
                "total": total,
                "percentage": percentage,
            },
        ))
|
|
|
|
|
|
class EventManager:
    """
    Event manager.

    Creates events and fans each one out to:

    * the database, when a ``db_session_factory`` was supplied;
    * the per-task ``asyncio.Queue`` consumed by :meth:`stream_events`;
    * the per-task callbacks registered with :meth:`add_callback`.

    Stored events are read back with :meth:`get_events` / :meth:`stream_events`.
    """

    # Rows fetched per query when stream_events replays stored history.
    _history_page_size = 100

    # Event types after which stream_events stops.
    _terminal_event_types = ("task_complete", "task_error", "task_cancel")

    def __init__(self, db_session_factory=None):
        # Used as ``async with db_session_factory() as db``; None disables
        # persistence (and get_events then returns []).
        self.db_session_factory = db_session_factory
        # task_id -> queue of live events read by stream_events.
        self._event_queues: Dict[str, asyncio.Queue] = {}
        # task_id -> callbacks invoked by add_event.
        self._event_callbacks: Dict[str, List[Callable]] = {}

    async def add_event(
        self,
        task_id: str,
        event_type: str,
        sequence: int = 0,
        phase: Optional[str] = None,
        message: Optional[str] = None,
        tool_name: Optional[str] = None,
        tool_input: Optional[Dict] = None,
        tool_output: Optional[Dict] = None,
        tool_duration_ms: Optional[int] = None,
        finding_id: Optional[str] = None,
        tokens_used: int = 0,
        metadata: Optional[Dict] = None,
    ):
        """
        Create an event and fan it out to the database, the task's queue and
        the task's callbacks.

        Database and callback failures are logged and never propagate.

        Returns:
            The new event's id (a random UUID4 string).
        """
        event_id = str(uuid.uuid4())
        timestamp = datetime.now(timezone.utc)

        event_data = {
            "id": event_id,
            "task_id": task_id,
            "event_type": event_type,
            "sequence": sequence,
            "phase": phase,
            "message": message,
            "tool_name": tool_name,
            "tool_input": tool_input,
            "tool_output": tool_output,
            "tool_duration_ms": tool_duration_ms,
            "finding_id": finding_id,
            "tokens_used": tokens_used,
            "metadata": metadata,
            "timestamp": timestamp.isoformat(),
        }

        # Best-effort persistence: a DB failure must not stop live delivery.
        if self.db_session_factory:
            try:
                await self._save_event_to_db(event_data)
            except Exception as e:
                logger.exception("Failed to save event to database: %s", e)

        queue = self._event_queues.get(task_id)
        if queue is not None:
            await queue.put(event_data)

        await self._run_callbacks(task_id, event_data)

        return event_id

    async def _run_callbacks(self, task_id: str, event_data: Dict) -> None:
        """Invoke every callback registered for *task_id*, logging (not raising) errors."""
        import inspect

        # Iterate over a snapshot: a callback that (un)registers callbacks while
        # being dispatched must not cause another callback to be skipped.
        for callback in list(self._event_callbacks.get(task_id, ())):
            try:
                result = callback(event_data)
                # Await anything awaitable: covers ``async def`` functions as well
                # as partials / callable objects that return coroutines.
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                logger.exception("Event callback error: %s", e)

    async def _save_event_to_db(self, event_data: Dict):
        """Insert *event_data* as an ``AgentEvent`` row and commit in a fresh session."""
        from app.models.agent_task import AgentEvent

        async with self.db_session_factory() as db:
            event = AgentEvent(
                id=event_data["id"],
                task_id=event_data["task_id"],
                event_type=event_data["event_type"],
                sequence=event_data["sequence"],
                phase=event_data["phase"],
                message=event_data["message"],
                tool_name=event_data["tool_name"],
                tool_input=event_data["tool_input"],
                tool_output=event_data["tool_output"],
                tool_duration_ms=event_data["tool_duration_ms"],
                finding_id=event_data["finding_id"],
                tokens_used=event_data["tokens_used"],
                event_metadata=event_data["metadata"],
            )
            db.add(event)
            await db.commit()

    def create_queue(self, task_id: str) -> asyncio.Queue:
        """Return the live-event queue for *task_id*, creating it if needed."""
        if task_id not in self._event_queues:
            self._event_queues[task_id] = asyncio.Queue()
        return self._event_queues[task_id]

    def remove_queue(self, task_id: str):
        """Drop the live-event queue for *task_id* (no-op if absent)."""
        self._event_queues.pop(task_id, None)

    def add_callback(self, task_id: str, callback: Callable):
        """Register *callback* (sync or async) to be called with each event of *task_id*."""
        self._event_callbacks.setdefault(task_id, []).append(callback)

    def remove_callback(self, task_id: str, callback: Callable):
        """Unregister *callback* for *task_id*; a no-op if it is not registered."""
        callbacks = self._event_callbacks.get(task_id)
        if not callbacks:
            return
        try:
            callbacks.remove(callback)
        except ValueError:
            # Not registered (or already removed): nothing to do.
            return
        if not callbacks:
            # Don't leave empty lists behind for finished tasks.
            del self._event_callbacks[task_id]

    async def get_events(
        self,
        task_id: str,
        after_sequence: int = 0,
        limit: int = 100,
    ) -> List[Dict]:
        """
        Return up to *limit* stored events of *task_id* with a sequence greater
        than *after_sequence*, ordered by sequence, each converted with
        ``AgentEvent.to_sse_dict()``. Returns [] when persistence is disabled.
        """
        if not self.db_session_factory:
            return []

        from sqlalchemy.future import select
        from app.models.agent_task import AgentEvent

        async with self.db_session_factory() as db:
            result = await db.execute(
                select(AgentEvent)
                .where(AgentEvent.task_id == task_id)
                .where(AgentEvent.sequence > after_sequence)
                .order_by(AgentEvent.sequence)
                .limit(limit)
            )
            events = result.scalars().all()
            return [event.to_sse_dict() for event in events]

    async def stream_events(
        self,
        task_id: str,
        after_sequence: int = 0,
    ) -> AsyncGenerator[Dict, None]:
        """
        Yield stored history after *after_sequence*, then live events.

        A heartbeat dict is yielded whenever no live event arrives for 30
        seconds. The stream ends after a terminal event (task_complete /
        task_error / task_cancel), whether it comes from history or live.
        The task's queue is removed however the stream ends, including when
        the consumer closes the generator early.

        NOTE(review): create_queue returns the existing queue for a task, so
        concurrent streams of the same task share one queue (each live event
        reaches only one of them) and the first stream to finish removes it
        for all of them.
        """
        terminal_types = self._terminal_event_types
        # Create the queue before reading history so events added meanwhile are kept.
        queue = self.create_queue(task_id)
        try:
            # Replay all stored history page by page (get_events caps each call).
            page_size = self._history_page_size
            cursor = after_sequence
            finished = False
            while True:
                page = await self.get_events(task_id, after_sequence=cursor, limit=page_size)
                for event in page:
                    # Assumes history rows are dicts shaped like live events
                    # (event_type / sequence keys) — TODO confirm to_sse_dict.
                    if event.get("event_type") in terminal_types:
                        finished = True
                    yield event
                if len(page) < page_size:
                    break
                last_sequence = page[-1].get("sequence")
                if not isinstance(last_sequence, int) or last_sequence <= cursor:
                    # Cannot advance safely; fall back to a single page.
                    break
                cursor = last_sequence

            if finished:
                return

            # Then push new events as they arrive.
            while True:
                try:
                    event = await asyncio.wait_for(queue.get(), timeout=30)
                except asyncio.TimeoutError:
                    # Heartbeat keeps idle connections alive.
                    yield {"event_type": "heartbeat", "timestamp": datetime.now(timezone.utc).isoformat()}
                    continue

                yield event

                if event.get("event_type") in terminal_types:
                    break
        finally:
            self.remove_queue(task_id)

    def create_emitter(self, task_id: str) -> 'AgentEventEmitter':
        """Create an emitter that publishes events for *task_id* through this manager."""
        return AgentEventEmitter(task_id, self)
|
|
|