"""
Agent base class.

Defines the basic interface and shared functionality for Agents.

Core principles:
1. The LLM is the Agent's brain and takes part in every decision.
2. Agents pass structured context to one another via TaskHandoff.
3. Events split into streaming events (frontend display) and persisted
   events (database records).
4. Supports a dynamic Agent tree and specialist knowledge modules.
5. Full state management and inter-Agent communication.
"""
|
||
|
||
from abc import ABC, abstractmethod
|
||
from typing import List, Dict, Any, Optional, AsyncGenerator, Tuple
|
||
from dataclasses import dataclass, field
|
||
from enum import Enum
|
||
from datetime import datetime, timezone
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import uuid
|
||
|
||
from ..core.state import AgentState, AgentStatus
|
||
from ..core.registry import agent_registry
|
||
from ..core.message import message_bus, MessageType, AgentMessage
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class AgentType(Enum):
    """Agent type."""
    ORCHESTRATOR = "orchestrator"
    RECON = "recon"
    ANALYSIS = "analysis"
    VERIFICATION = "verification"
|
||
|
||
|
||
class AgentPattern(Enum):
    """Agent execution pattern."""
    REACT = "react"  # reactive: think-act-observe loop
    PLAN_AND_EXECUTE = "plan_execute"  # plan first, then execute
|
||
|
||
|
||
@dataclass
class AgentConfig:
    """Configuration for an Agent."""
    name: str
    agent_type: AgentType
    pattern: AgentPattern = AgentPattern.REACT

    # LLM settings
    model: Optional[str] = None
    temperature: float = 0.1
    max_tokens: int = 8192

    # Execution limits
    max_iterations: int = 20
    timeout_seconds: int = 600

    # Names of the tools this agent may use
    tools: List[str] = field(default_factory=list)

    # System prompt (may be augmented with knowledge modules at init time)
    system_prompt: Optional[str] = None

    # Free-form metadata
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||
|
||
|
||
@dataclass
class AgentResult:
    """Result of an Agent run."""
    success: bool
    data: Any = None
    error: Optional[str] = None

    # Execution statistics
    iterations: int = 0
    tool_calls: int = 0
    tokens_used: int = 0
    duration_ms: int = 0

    # Intermediate steps (NOTE(review): deliberately or not, these are NOT
    # included in to_dict() below — confirm before relying on serialization)
    intermediate_steps: List[Dict[str, Any]] = field(default_factory=list)

    # Free-form metadata
    metadata: Dict[str, Any] = field(default_factory=dict)

    # Collaboration info — structured handoff passed to the next Agent
    handoff: Optional["TaskHandoff"] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (omits intermediate_steps)."""
        return {
            "success": self.success,
            "data": self.data,
            "error": self.error,
            "iterations": self.iterations,
            "tool_calls": self.tool_calls,
            "tokens_used": self.tokens_used,
            "duration_ms": self.duration_ms,
            "metadata": self.metadata,
            "handoff": self.handoff.to_dict() if self.handoff else None,
        }
|
||
|
||
|
||
@dataclass
class TaskHandoff:
    """
    Task handoff protocol — structured information passed between Agents.

    Design principles:
    1. Carry enough context for the next Agent to understand prior work.
    2. Provide explicit suggestions and points of attention.
    3. Convertible directly into an LLM-readable prompt.
    """
    # Basic routing info
    from_agent: str
    to_agent: str

    # Work summary
    summary: str
    work_completed: List[str] = field(default_factory=list)

    # Key findings and insights
    key_findings: List[Dict[str, Any]] = field(default_factory=list)
    insights: List[str] = field(default_factory=list)

    # Suggestions and points of attention
    suggested_actions: List[Dict[str, Any]] = field(default_factory=list)
    attention_points: List[str] = field(default_factory=list)
    priority_areas: List[str] = field(default_factory=list)

    # Context data
    context_data: Dict[str, Any] = field(default_factory=dict)

    # Confidence
    confidence: float = 0.8

    # Timestamp (UTC)
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (timestamp as ISO-8601 string)."""
        return {
            "from_agent": self.from_agent,
            "to_agent": self.to_agent,
            "summary": self.summary,
            "work_completed": self.work_completed,
            "key_findings": self.key_findings,
            "insights": self.insights,
            "suggested_actions": self.suggested_actions,
            "attention_points": self.attention_points,
            "priority_areas": self.priority_areas,
            "context_data": self.context_data,
            "confidence": self.confidence,
            "timestamp": self.timestamp.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TaskHandoff":
        """Rebuild a TaskHandoff from a dict produced by to_dict().

        Fix: the timestamp is now restored from its ISO-8601 form when
        present, so a to_dict()/from_dict() round trip no longer silently
        resets it to "now".
        """
        handoff = cls(
            from_agent=data.get("from_agent", ""),
            to_agent=data.get("to_agent", ""),
            summary=data.get("summary", ""),
            work_completed=data.get("work_completed", []),
            key_findings=data.get("key_findings", []),
            insights=data.get("insights", []),
            suggested_actions=data.get("suggested_actions", []),
            attention_points=data.get("attention_points", []),
            priority_areas=data.get("priority_areas", []),
            context_data=data.get("context_data", {}),
            confidence=data.get("confidence", 0.8),
        )
        raw_ts = data.get("timestamp")
        if raw_ts:
            try:
                handoff.timestamp = datetime.fromisoformat(raw_ts)
            except (TypeError, ValueError):
                # Malformed timestamp: keep the default (now, UTC).
                pass
        return handoff

    def to_prompt_context(self) -> str:
        """
        Render as LLM-readable context.

        This is the key piece that lets an LLM understand the previous
        Agent's work.
        """
        lines = [
            f"## 来自 {self.from_agent} Agent 的任务交接",
            "",
            "### 工作摘要",
            self.summary,
            "",
        ]

        if self.work_completed:
            lines.append("### 已完成的工作")
            for work in self.work_completed:
                lines.append(f"- {work}")
            lines.append("")

        if self.key_findings:
            lines.append("### 关键发现")
            # Cap at 15 findings to keep the prompt bounded.
            for i, finding in enumerate(self.key_findings[:15], 1):
                severity = finding.get("severity", "medium")
                title = finding.get("title", "Unknown")
                file_path = finding.get("file_path", "")
                lines.append(f"{i}. [{severity.upper()}] {title}")
                if file_path:
                    lines.append(f" 位置: {file_path}:{finding.get('line_start', '')}")
                if finding.get("description"):
                    lines.append(f" 描述: {finding['description'][:100]}")
            lines.append("")

        if self.insights:
            lines.append("### 洞察和分析")
            for insight in self.insights:
                lines.append(f"- {insight}")
            lines.append("")

        if self.suggested_actions:
            lines.append("### 建议的下一步行动")
            for action in self.suggested_actions:
                action_type = action.get("type", "general")
                description = action.get("description", "")
                priority = action.get("priority", "medium")
                lines.append(f"- [{priority.upper()}] {action_type}: {description}")
            lines.append("")

        if self.attention_points:
            lines.append("### ⚠️ 需要特别关注")
            for point in self.attention_points:
                lines.append(f"- {point}")
            lines.append("")

        if self.priority_areas:
            lines.append("### 优先分析区域")
            for area in self.priority_areas:
                lines.append(f"- {area}")

        return "\n".join(lines)
|
||
|
||
|
||
class BaseAgent(ABC):
    """
    Base class for Agents.

    Core principles:
    1. The LLM is the Agent's brain and takes part in every decision.
    2. All logs should reflect the LLM's reasoning process.
    3. Tool calls are the result of LLM decisions.

    Collaboration principles:
    1. Receive context from predecessor Agents via TaskHandoff.
    2. Produce a TaskHandoff for the next Agent once execution finishes.
    3. Insights and findings are recorded in structured form.

    Dynamic Agent tree:
    1. Supports dynamically creating child Agents.
    2. Agents communicate through the message bus.
    3. Full state management and lifecycle handling.
    """
|
||
|
||
    def __init__(
        self,
        config: AgentConfig,
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
        parent_id: Optional[str] = None,
        knowledge_modules: Optional[List[str]] = None,
    ):
        """
        Initialize the Agent.

        Args:
            config: Agent configuration
            llm_service: LLM service
            tools: Dict of available tools (name -> tool object)
            event_emitter: Event emitter (optional; events are no-ops without it)
            parent_id: Parent agent ID (for the dynamic Agent tree)
            knowledge_modules: Knowledge modules to load into the system prompt
        """
        self.config = config
        self.llm_service = llm_service
        self.tools = tools
        self.event_emitter = event_emitter
        self.parent_id = parent_id
        self.knowledge_modules = knowledge_modules or []

        # Generate a unique agent ID
        self._agent_id = f"agent_{uuid.uuid4().hex[:8]}"

        # Rich state management object
        self._state = AgentState(
            agent_id=self._agent_id,
            agent_name=config.name,
            agent_type=config.agent_type.value,
            parent_id=parent_id,
            max_iterations=config.max_iterations,
            knowledge_modules=self.knowledge_modules,
        )

        # Run counters (kept for backward compatibility)
        self._iteration = 0
        self._total_tokens = 0
        self._tool_calls = 0
        self._cancelled = False

        # Collaboration state
        self._incoming_handoff: Optional[TaskHandoff] = None
        self._insights: List[str] = []  # collected insights
        self._work_completed: List[str] = []  # record of completed work

        # Whether this agent has been registered in the registry
        self._registered = False

        # Merge knowledge modules into the system prompt
        if self.knowledge_modules:
            self._load_knowledge_modules()
|
||
|
||
    def _register_to_registry(self, task: Optional[str] = None) -> None:
        """Register with the Agent registry (deferred; called from run())."""
        logger.debug(f"[AgentTree] _register_to_registry 被调用: {self.config.name} (id={self._agent_id}, parent={self.parent_id}, _registered={self._registered})")

        # Idempotent: skip if already registered.
        if self._registered:
            logger.debug(f"[AgentTree] {self.config.name} 已注册,跳过 (id={self._agent_id})")
            return

        logger.debug(f"[AgentTree] 正在注册 Agent: {self.config.name} (id={self._agent_id}, parent={self.parent_id})")

        agent_registry.register_agent(
            agent_id=self._agent_id,
            agent_name=self.config.name,
            agent_type=self.config.agent_type.value,
            task=task or self._state.task or "Initializing",
            parent_id=self.parent_id,
            agent_instance=self,
            state=self._state,
            knowledge_modules=self.knowledge_modules,
        )

        # Create this agent's message queue on the bus.
        message_bus.create_queue(self._agent_id)
        self._registered = True

        tree = agent_registry.get_agent_tree()
        logger.debug(f"[AgentTree] Agent 注册完成: {self.config.name}, 当前树节点数: {len(tree['nodes'])}")
|
||
|
||
    def set_parent_id(self, parent_id: str) -> None:
        """Set the parent agent ID (called at dispatch time); kept in sync on the state object."""
        self.parent_id = parent_id
        self._state.parent_id = parent_id
|
||
|
||
    def _load_knowledge_modules(self) -> None:
        """Fold the configured knowledge modules into the system prompt."""
        if not self.knowledge_modules:
            return

        try:
            from ..knowledge import knowledge_loader

            enhanced_prompt = knowledge_loader.build_system_prompt_with_modules(
                self.config.system_prompt or "",
                self.knowledge_modules,
            )
            # NOTE: mutates config.system_prompt in place.
            self.config.system_prompt = enhanced_prompt

            logger.info(f"[{self.name}] Loaded knowledge modules: {self.knowledge_modules}")
        except Exception as e:
            # Best-effort: a missing/broken knowledge package must not stop the agent.
            logger.warning(f"Failed to load knowledge modules: {e}")
|
||
|
||
    @property
    def name(self) -> str:
        """Agent display name (from config)."""
        return self.config.name

    @property
    def agent_id(self) -> str:
        """Unique agent ID generated at construction time."""
        return self._agent_id

    @property
    def state(self) -> AgentState:
        """Rich state object for this agent."""
        return self._state

    @property
    def agent_type(self) -> AgentType:
        """Agent type (from config)."""
        return self.config.agent_type
|
||
|
||
    # ============ Inter-agent message handling ============

    def check_messages(self) -> List[AgentMessage]:
        """
        Check for and process received messages.

        Returns:
            The list of previously-unread messages (marked read as a side effect).
        """
        messages = message_bus.get_messages(
            self._agent_id,
            unread_only=True,
            mark_as_read=True,
        )

        for msg in messages:
            # Process the message
            if msg.from_agent == "user":
                # User messages go straight into the conversation history
                self._state.add_message("user", msg.content)
            else:
                # Agent-to-agent messages use an XML representation
                self._state.add_message("user", msg.to_xml())

            # If we were waiting for input, resume execution
            if self._state.is_waiting_for_input():
                self._state.resume_from_waiting()
                agent_registry.update_agent_status(self._agent_id, "running")

        return messages

    def has_pending_messages(self) -> bool:
        """Return True if there are unread messages on the bus for this agent."""
        return message_bus.has_unread_messages(self._agent_id)

    def send_message_to_parent(
        self,
        content: str,
        message_type: MessageType = MessageType.INFORMATION,
    ) -> None:
        """Send a message to the parent agent (no-op when there is no parent)."""
        if self.parent_id:
            message_bus.send_message(
                from_agent=self._agent_id,
                to_agent=self.parent_id,
                content=content,
                message_type=message_type,
            )

    def send_message_to_agent(
        self,
        target_id: str,
        content: str,
        message_type: MessageType = MessageType.INFORMATION,
    ) -> None:
        """Send a message to a specific agent by ID."""
        message_bus.send_message(
            from_agent=self._agent_id,
            to_agent=target_id,
            content=content,
            message_type=message_type,
        )
|
||
|
||
    # ============ Lifecycle management ============

    def on_start(self) -> None:
        """Called when the agent starts executing."""
        self._state.start()
        agent_registry.update_agent_status(self._agent_id, "running")

    def on_complete(self, result: Dict[str, Any]) -> None:
        """Called when the agent finishes; reports completion to the parent."""
        self._state.set_completed(result)
        agent_registry.update_agent_status(self._agent_id, "completed", result)

        # Report completion to the parent agent
        if self.parent_id:
            message_bus.send_completion_report(
                from_agent=self._agent_id,
                to_agent=self.parent_id,
                summary=result.get("summary", "Task completed"),
                findings=result.get("findings", []),
                success=True,
            )

    def on_error(self, error: str) -> None:
        """Called when the agent fails; records the error in state and registry."""
        self._state.set_failed(error)
        agent_registry.update_agent_status(self._agent_id, "failed", {"error": error})
|
||
|
||
    @abstractmethod
    async def run(self, input_data: Dict[str, Any]) -> AgentResult:
        """
        Execute the agent's task.

        Args:
            input_data: Input data.

        Returns:
            The agent's execution result.
        """
        pass
|
||
|
||
    def cancel(self):
        """Request cancellation; the actual stop happens at the next is_cancelled check."""
        self._cancelled = True
        logger.info(f"[{self.name}] Cancel requested")

    # External cancellation-check callback (class-level default; installed
    # via set_cancel_callback, consulted by is_cancelled)
    _cancel_callback = None

    def set_cancel_callback(self, callback) -> None:
        """Install an external zero-arg callback; a truthy return means cancellation was requested."""
        self._cancel_callback = callback
|
||
|
||
    @property
    def is_cancelled(self) -> bool:
        """Whether cancellation was requested (internal flag or external callback).

        A truthy callback result is latched into the internal flag so later
        checks stay cheap.
        """
        if self._cancelled:
            return True
        # Consult the external callback, if one was installed
        if self._cancel_callback and self._cancel_callback():
            self._cancelled = True
            logger.info(f"[{self.name}] Detected cancellation from callback")
            return True
        return False
|
||
|
||
    # ============ Collaboration methods ============

    def receive_handoff(self, handoff: TaskHandoff):
        """
        Receive a task handoff from a predecessor Agent.

        Args:
            handoff: The handoff object.
        """
        self._incoming_handoff = handoff
        logger.info(
            f"[{self.name}] Received handoff from {handoff.from_agent}: "
            f"{handoff.summary[:50]}..."
        )

    def get_handoff_context(self) -> str:
        """
        Get the handoff context (for building the LLM prompt).

        Returns:
            A formatted context string ("" when no handoff was received).
        """
        if not self._incoming_handoff:
            return ""
        return self._incoming_handoff.to_prompt_context()

    def add_insight(self, insight: str):
        """Record an insight (copied into outgoing handoffs by create_handoff)."""
        self._insights.append(insight)

    def record_work(self, work: str):
        """Record a completed piece of work (copied into outgoing handoffs by create_handoff)."""
        self._work_completed.append(work)
|
||
|
||
    def create_handoff(
        self,
        to_agent: str,
        summary: str,
        key_findings: Optional[List[Dict[str, Any]]] = None,
        suggested_actions: Optional[List[Dict[str, Any]]] = None,
        attention_points: Optional[List[str]] = None,
        priority_areas: Optional[List[str]] = None,
        context_data: Optional[Dict[str, Any]] = None,
    ) -> TaskHandoff:
        """
        Create a task handoff for the next Agent.

        Recorded insights and completed work are copied in automatically.

        Args:
            to_agent: Target Agent.
            summary: Work summary.
            key_findings: Key findings.
            suggested_actions: Suggested actions.
            attention_points: Points needing attention.
            priority_areas: Priority analysis areas.
            context_data: Context data.

        Returns:
            A TaskHandoff object.
        """
        return TaskHandoff(
            from_agent=self.name,
            to_agent=to_agent,
            summary=summary,
            work_completed=self._work_completed.copy(),
            key_findings=key_findings or [],
            insights=self._insights.copy(),
            suggested_actions=suggested_actions or [],
            attention_points=attention_points or [],
            priority_areas=priority_areas or [],
            context_data=context_data or {},
        )
|
||
|
||
    def build_prompt_with_handoff(self, base_prompt: str) -> str:
        """
        Build a prompt that includes the handoff context.

        Args:
            base_prompt: The base prompt.

        Returns:
            The augmented prompt (unchanged when no handoff was received).
        """
        handoff_context = self.get_handoff_context()
        if not handoff_context:
            return base_prompt

        return f"""{base_prompt}

---
## 前序 Agent 交接信息

{handoff_context}

---
请基于以上来自前序 Agent 的信息,结合你的专业能力开展工作。
"""
|
||
|
||
    # ============ Core event emission ============

    async def emit_event(
        self,
        event_type: str,
        message: str,
        **kwargs
    ):
        """Emit an event through the configured emitter (no-op without one).

        kwargs matching known AgentEventData fields pass through as-is;
        everything else is folded into the event's metadata dict.
        """
        if self.event_emitter:
            from ..event_manager import AgentEventData

            # Prepare metadata; always stamp the agent name
            metadata = kwargs.get("metadata", {}) or {}
            if "agent_name" not in metadata:
                metadata["agent_name"] = self.name

            # Split known event fields from everything else
            known_fields = {
                "phase", "tool_name", "tool_input", "tool_output",
                "tool_duration_ms", "finding_id", "tokens_used"
            }

            event_kwargs = {}
            for k, v in kwargs.items():
                if k in known_fields:
                    event_kwargs[k] = v
                elif k != "metadata":
                    # Unknown fields go into metadata
                    metadata[k] = v

            await self.event_emitter.emit(AgentEventData(
                event_type=event_type,
                message=message,
                metadata=metadata,
                **event_kwargs
            ))
|
||
|
||
    # ============ LLM thinking events ============

    async def emit_thinking(self, message: str):
        """Emit an LLM thinking event."""
        await self.emit_event("thinking", message)

    async def emit_llm_start(self, iteration: int):
        """Emit an LLM iteration-start event."""
        await self.emit_event(
            "llm_start",
            f"[{self.name}] 第 {iteration} 轮迭代开始",
            metadata={"iteration": iteration}
        )

    async def emit_llm_thought(self, thought: str, iteration: int):
        """Emit an LLM thought-content event — the core: shows what the LLM is thinking."""
        # Truncate overly long thought content for display
        display_thought = thought[:500] + "..." if len(thought) > 500 else thought
        await self.emit_event(
            "llm_thought",
            f"[{self.name}] 思考: {display_thought}",
            metadata={
                "thought": thought,
                "iteration": iteration,
            }
        )

    async def emit_thinking_start(self):
        """Emit a thinking-started event (for streaming output)."""
        await self.emit_event("thinking_start", "开始思考...")

    async def emit_thinking_token(self, token: str, accumulated: str):
        """Emit a thinking-token event (for streaming output)."""
        await self.emit_event(
            "thinking_token",
            "",  # no message needed; the frontend reads from metadata
            metadata={
                "token": token,
                "accumulated": accumulated,
            }
        )

    async def emit_thinking_end(self, full_response: str):
        """Emit a thinking-finished event (for streaming output)."""
        await self.emit_event(
            "thinking_end",
            "思考完成",
            metadata={"accumulated": full_response}
        )

    async def emit_llm_decision(self, decision: str, reason: str = ""):
        """Emit an LLM decision event — shows what the LLM decided."""
        await self.emit_event(
            "llm_decision",
            f"[{self.name}] 决策: {decision}" + (f" ({reason})" if reason else ""),
            metadata={
                "decision": decision,
                "reason": reason,
            }
        )

    async def emit_llm_complete(self, result_summary: str, tokens_used: int):
        """Emit an LLM completion event."""
        await self.emit_event(
            "llm_complete",
            f"[{self.name}] 完成: {result_summary} (消耗 {tokens_used} tokens)",
            metadata={
                "tokens_used": tokens_used,
            }
        )

    async def emit_llm_action(self, action: str, action_input: Dict):
        """Emit an LLM action-decision event."""
        await self.emit_event(
            "llm_action",
            f"[{self.name}] 执行动作: {action}",
            metadata={
                "action": action,
                "action_input": action_input,
            }
        )

    async def emit_llm_observation(self, observation: str):
        """Emit an LLM observation event."""
        # Truncate overly long observations for display
        display_obs = observation[:300] + "..." if len(observation) > 300 else observation
        await self.emit_event(
            "llm_observation",
            f"[{self.name}] 观察结果: {display_obs}",
            metadata={
                "observation": observation[:2000],  # cap stored length
            }
        )
|
||
|
||
    # ============ Tool-call events ============

    async def emit_tool_call(self, tool_name: str, tool_input: Dict):
        """Emit a tool-call event."""
        await self.emit_event(
            "tool_call",
            f"[{self.name}] 调用工具: {tool_name}",
            tool_name=tool_name,
            tool_input=tool_input,
        )

    async def emit_tool_result(self, tool_name: str, result: str, duration_ms: int):
        """Emit a tool-result event."""
        # Guard against None so the UI never shows a literal "None" string
        safe_result = result if result and result != "None" else ""
        tool_output_dict = {"result": safe_result[:2000] if safe_result else ""}  # truncate long output
        await self.emit_event(
            "tool_result",
            f"[{self.name}] 工具 {tool_name} 完成 ({duration_ms}ms)",
            tool_name=tool_name,
            tool_output=tool_output_dict,
            tool_duration_ms=duration_ms,
        )
|
||
|
||
# ============ 发现相关事件 ============
|
||
|
||
async def emit_finding(self, title: str, severity: str, vuln_type: str, file_path: str = "", is_verified: bool = False):
|
||
"""发射漏洞发现事件"""
|
||
import uuid
|
||
finding_id = str(uuid.uuid4())
|
||
|
||
# 🔥 使用 EventManager.emit_finding 发送正确的事件类型
|
||
if self.event_emitter and hasattr(self.event_emitter, 'emit_finding'):
|
||
await self.event_emitter.emit_finding(
|
||
finding_id=finding_id,
|
||
title=title,
|
||
severity=severity,
|
||
vulnerability_type=vuln_type,
|
||
is_verified=is_verified,
|
||
)
|
||
else:
|
||
# 回退:使用通用事件发射
|
||
severity_emoji = {
|
||
"critical": "🔴",
|
||
"high": "🟠",
|
||
"medium": "🟡",
|
||
"low": "🟢",
|
||
}.get(severity.lower(), "⚪")
|
||
|
||
event_type = "finding_verified" if is_verified else "finding_new"
|
||
await self.emit_event(
|
||
event_type,
|
||
f"{severity_emoji} [{self.name}] 发现漏洞: [{severity.upper()}] {title}\n 类型: {vuln_type}\n 位置: {file_path}",
|
||
metadata={
|
||
"id": finding_id,
|
||
"title": title,
|
||
"severity": severity,
|
||
"vulnerability_type": vuln_type,
|
||
"file_path": file_path,
|
||
"is_verified": is_verified,
|
||
}
|
||
)
|
||
|
||
    # ============ General helpers ============

    async def call_tool(self, tool_name: str, **kwargs) -> Any:
        """
        Call a tool (simple variant; see execute_tool for timeout/cancel handling).

        Args:
            tool_name: Tool name.
            **kwargs: Tool arguments.

        Returns:
            The tool's result object, or None when the tool is unknown.
        """
        tool = self.tools.get(tool_name)
        if not tool:
            logger.warning(f"Tool not found: {tool_name}")
            return None

        self._tool_calls += 1
        await self.emit_tool_call(tool_name, kwargs)

        import time
        start = time.time()

        # NOTE(review): assumes tool.execute returns an object with a .data
        # attribute — confirm against the tool base class.
        result = await tool.execute(**kwargs)

        duration_ms = int((time.time() - start) * 1000)
        await self.emit_tool_result(tool_name, str(result.data)[:500], duration_ms)

        return result
|
||
|
||
    async def call_llm(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict]] = None,
    ) -> Dict[str, Any]:
        """
        Call the LLM (non-streaming).

        Args:
            messages: Message list.
            tools: Available tool descriptions.

        Returns:
            The LLM response dict.

        Raises:
            Exception: re-raised from the underlying LLM service on failure.
        """
        # Each LLM call counts as one iteration.
        self._iteration += 1

        try:
            response = await self.llm_service.chat_completion(
                messages=messages,
                temperature=self.config.temperature,
                max_tokens=self.config.max_tokens,
                tools=tools,
            )

            # Track token usage when the provider reports it.
            if response.get("usage"):
                self._total_tokens += response["usage"].get("total_tokens", 0)

            return response

        except Exception as e:
            logger.error(f"LLM call failed: {e}")
            raise
|
||
|
||
    def get_tool_descriptions(self) -> List[Dict[str, Any]]:
        """Build OpenAI-style function descriptions of the tools (for the LLM).

        Tools whose name starts with "_" are treated as private and skipped.
        """
        descriptions = []

        for name, tool in self.tools.items():
            if name.startswith("_"):
                continue

            desc = {
                "type": "function",
                "function": {
                    "name": name,
                    "description": tool.description,
                }
            }

            # Attach the parameter schema when the tool declares one.
            # NOTE(review): .schema() is the pydantic v1 API — confirm the
            # project pins pydantic<2 (v2 uses model_json_schema()).
            if hasattr(tool, 'args_schema') and tool.args_schema:
                desc["function"]["parameters"] = tool.args_schema.schema()

            descriptions.append(desc)

        return descriptions
|
||
|
||
def get_stats(self) -> Dict[str, Any]:
|
||
"""获取执行统计"""
|
||
return {
|
||
"agent": self.name,
|
||
"type": self.agent_type.value,
|
||
"iterations": self._iteration,
|
||
"tool_calls": self._tool_calls,
|
||
"tokens_used": self._total_tokens,
|
||
}
|
||
|
||
    # ============ Memory Compression ============

    def compress_messages_if_needed(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 100000,
    ) -> List[Dict[str, str]]:
        """
        Compress the message history automatically when it grows too long.

        Args:
            messages: Message list.
            max_tokens: Token budget handed to the compressor.

        Returns:
            The (possibly compressed) message list.
        """
        from ...llm.memory_compressor import MemoryCompressor

        compressor = MemoryCompressor(max_total_tokens=max_tokens)

        if compressor.should_compress(messages):
            logger.info(f"[{self.name}] Compressing conversation history...")
            compressed = compressor.compress_history(messages)
            logger.info(f"[{self.name}] Compressed {len(messages)} -> {len(compressed)} messages")
            return compressed

        return messages
|
||
|
||
    # ============ Unified streaming LLM call ============

    async def stream_llm_call(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.1,
        max_tokens: int = 2048,
        auto_compress: bool = True,
    ) -> Tuple[str, int]:
        """
        Unified streaming LLM call.

        Shared by all Agents to avoid duplicated streaming code.

        Args:
            messages: Message list.
            temperature: Sampling temperature.
            max_tokens: Max tokens to generate.
            auto_compress: Whether to auto-compress an over-long history.

        Returns:
            (full response text, token count)
        """
        # Auto-compress an over-long message history
        if auto_compress:
            messages = self.compress_messages_if_needed(messages)

        accumulated = ""
        total_tokens = 0

        # Check for cancellation before starting the LLM call
        if self.is_cancelled:
            logger.info(f"[{self.name}] Cancelled before LLM call")
            return "", 0

        logger.info(f"[{self.name}] 🚀 Starting stream_llm_call, emitting thinking_start...")
        await self.emit_thinking_start()
        logger.info(f"[{self.name}] ✅ thinking_start emitted, starting LLM stream...")

        try:
            # Obtain the streaming iterator
            stream = self.llm_service.chat_completion_stream(
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            # Compatible across python async-generator versions
            iterator = stream.__aiter__()

            import time
            first_token_received = False
            last_activity = time.time()

            while True:
                # Check for cancellation
                if self.is_cancelled:
                    logger.info(f"[{self.name}] Cancelled during LLM streaming loop")
                    break

                try:
                    # 30s timeout for the first token, 60s for subsequent ones.
                    # An application-level safety net in case the underlying
                    # LLM client hangs.
                    timeout = 30.0 if not first_token_received else 60.0

                    chunk = await asyncio.wait_for(iterator.__anext__(), timeout=timeout)

                    last_activity = time.time()

                    if chunk["type"] == "token":
                        first_token_received = True
                        token = chunk["content"]
                        # Keep the accumulated content up to date.
                        # Some adapters already return the running total in
                        # chunk["accumulated"]; trust it when present.
                        if "accumulated" in chunk:
                            accumulated = chunk["accumulated"]
                        else:
                            # The service layer's chat_completion_stream is
                            # expected to always supply "accumulated";
                            # nothing to do here.
                            pass

                        # Double check: accumulated empty but we got a token
                        if not accumulated and token:
                            accumulated += token  # Fallback

                        await self.emit_thinking_token(token, accumulated)
                        # CRITICAL: yield to the event loop so SSE can flush events
                        await asyncio.sleep(0)

                    elif chunk["type"] == "done":
                        accumulated = chunk["content"]
                        if chunk.get("usage"):
                            total_tokens = chunk["usage"].get("total_tokens", 0)
                        break

                    elif chunk["type"] == "error":
                        accumulated = chunk.get("accumulated", "")
                        error_msg = chunk.get("error", "Unknown error")
                        logger.error(f"[{self.name}] Stream error: {error_msg}")
                        if accumulated:
                            total_tokens = chunk.get("usage", {}).get("total_tokens", 0)
                        else:
                            accumulated = f"[系统错误: {error_msg}] 请重新思考并输出你的决策。"
                        break

                except StopAsyncIteration:
                    break
                except asyncio.TimeoutError:
                    timeout_type = "First Token" if not first_token_received else "Stream"
                    logger.error(f"[{self.name}] LLM {timeout_type} Timeout ({timeout}s)")
                    error_msg = f"LLM 响应超时 ({timeout_type}, {timeout}s)"
                    await self.emit_event("error", error_msg)
                    if not accumulated:
                        accumulated = f"[超时错误: {timeout}s 无响应] 请尝试简化请求或重试。"
                    break

        except asyncio.CancelledError:
            logger.info(f"[{self.name}] LLM call cancelled")
            raise
        except Exception as e:
            # Robust handling so errors are never silently swallowed
            logger.error(f"[{self.name}] Unexpected error in stream_llm_call: {e}", exc_info=True)
            await self.emit_event("error", f"LLM 调用错误: {str(e)}")
            accumulated = f"[LLM调用错误: {str(e)}] 请重试。"
        finally:
            await self.emit_thinking_end(accumulated)

        # Warn on empty responses to help debugging
        if not accumulated or not accumulated.strip():
            logger.warning(f"[{self.name}] Empty LLM response returned (total_tokens: {total_tokens})")

        return accumulated, total_tokens
|
||
|
||
    async def execute_tool(self, tool_name: str, tool_input: Dict) -> str:
        """
        Unified tool execution with cancellation and timeout support.

        Args:
            tool_name: Tool name.
            tool_input: Tool arguments.

        Returns:
            The tool result rendered as a string (error text on failure).
        """
        # Check for cancellation before executing
        if self.is_cancelled:
            return "⚠️ 任务已取消"

        tool = self.tools.get(tool_name)

        if not tool:
            return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}"

        try:
            self._tool_calls += 1
            await self.emit_tool_call(tool_name, tool_input)

            import time
            start = time.time()

            # Per-tool timeout budgets
            tool_timeouts = {
                "semgrep_scan": 120,  # external scanners need more time
                "bandit_scan": 90,
                "gitleaks_scan": 60,
                "npm_audit": 90,
                "safety_scan": 60,
                "kunlun_scan": 180,
                "osv_scanner": 60,
                "trufflehog_scan": 90,
                "sandbox_exec": 60,
                "php_test": 30,
                "command_injection_test": 30,
                "sql_injection_test": 30,
                "xss_test": 30,
            }
            timeout = tool_timeouts.get(tool_name, 30)  # 30s default

            # asyncio.wait_for gives the overall timeout; the wrapper below
            # adds cooperative cancellation while the tool runs.
            async def execute_with_cancel_check():
                """Wrap tool execution and poll the cancel flag periodically."""
                # Run the tool as a task so it can be cancelled
                execute_task = asyncio.create_task(tool.execute(**tool_input))

                try:
                    # Poll for cancellation while the task runs
                    while not execute_task.done():
                        if self.is_cancelled:
                            execute_task.cancel()
                            try:
                                await execute_task
                            except asyncio.CancelledError:
                                pass
                            raise asyncio.CancelledError("任务已取消")

                        # Wait in short slices; shield() keeps the inner task
                        # alive across each 0.5s poll timeout
                        try:
                            return await asyncio.wait_for(
                                asyncio.shield(execute_task),
                                timeout=0.5  # poll the cancel flag every 0.5s
                            )
                        except asyncio.TimeoutError:
                            continue  # keep polling

                    return await execute_task
                except asyncio.CancelledError:
                    if not execute_task.done():
                        execute_task.cancel()
                    raise

            try:
                result = await asyncio.wait_for(
                    execute_with_cancel_check(),
                    timeout=timeout
                )
            except asyncio.TimeoutError:
                duration_ms = int((time.time() - start) * 1000)
                await self.emit_tool_result(tool_name, f"超时 ({timeout}s)", duration_ms)
                return f"⚠️ 工具 '{tool_name}' 执行超时 ({timeout}秒),请尝试其他方法或减小操作范围。"
            except asyncio.CancelledError:
                duration_ms = int((time.time() - start) * 1000)
                await self.emit_tool_result(tool_name, "已取消", duration_ms)
                return "⚠️ 任务已取消"

            duration_ms = int((time.time() - start) * 1000)
            # Pass a meaningful preview string, never the literal "None"
            result_preview = str(result.data)[:200] if result.data is not None else (result.error[:200] if result.error else "")
            await self.emit_tool_result(tool_name, result_preview, duration_ms)

            # Check for cancellation again after the tool ran
            if self.is_cancelled:
                return "⚠️ 任务已取消"

            if result.success:
                output = str(result.data)

                # Include extra information carried in metadata
                if result.metadata:
                    if "issues" in result.metadata:
                        output += f"\n\n发现的问题:\n{json.dumps(result.metadata['issues'], ensure_ascii=False, indent=2)}"
                    if "findings" in result.metadata:
                        output += f"\n\n发现:\n{json.dumps(result.metadata['findings'][:10], ensure_ascii=False, indent=2)}"

                # Truncate over-long output
                if len(output) > 6000:
                    output = output[:6000] + f"\n\n... [输出已截断,共 {len(str(result.data))} 字符]"
                return output
            else:
                # Surface the detailed failure, including the original error
                error_msg = f"""⚠️ 工具执行失败

**工具**: {tool_name}
**参数**: {json.dumps(tool_input, ensure_ascii=False, indent=2) if tool_input else '无'}
**错误**: {result.error}

请根据错误信息调整参数或尝试其他方法。"""
                return error_msg

        except asyncio.CancelledError:
            logger.info(f"[{self.name}] Tool '{tool_name}' execution cancelled")
            return "⚠️ 任务已取消"
        except Exception as e:
            import traceback
            logger.error(f"Tool execution error: {e}")
            # Include the full original error, with traceback
            error_msg = f"""❌ 工具执行异常

**工具**: {tool_name}
**参数**: {json.dumps(tool_input, ensure_ascii=False, indent=2) if tool_input else '无'}
**错误类型**: {type(e).__name__}
**错误信息**: {str(e)}
**堆栈跟踪**:
```
{traceback.format_exc()}
```

请分析错误原因,可能需要:
1. 检查参数格式是否正确
2. 尝试使用其他工具
3. 如果是权限或资源问题,跳过该操作"""
            return error_msg
|
||
|
||
def get_tools_description(self) -> str:
|
||
"""生成工具描述文本(用于 prompt)"""
|
||
tools_info = []
|
||
for name, tool in self.tools.items():
|
||
if name.startswith("_"):
|
||
continue
|
||
desc = f"- {name}: {getattr(tool, 'description', 'No description')}"
|
||
tools_info.append(desc)
|
||
return "\n".join(tools_info)
|