306 lines
10 KiB
Python
306 lines
10 KiB
Python
"""
|
||
扫描完成工具
|
||
|
||
用于主Agent结束安全审计任务,确保所有子Agent已完成。
|
||
"""
|
||
|
||
import logging
|
||
from datetime import datetime, timezone
|
||
from typing import Optional, List, Dict, Any
|
||
from pydantic import BaseModel, Field
|
||
|
||
from .base import AgentTool, ToolResult
|
||
from ..core.registry import agent_registry
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class FinishScanInput(BaseModel):
|
||
"""扫描完成输入参数"""
|
||
content: str = Field(
|
||
...,
|
||
description="最终扫描报告内容,包含所有发现的漏洞总结"
|
||
)
|
||
success: bool = Field(
|
||
default=True,
|
||
description="扫描是否成功完成"
|
||
)
|
||
|
||
|
||
class FinishScanTool(AgentTool):
|
||
"""
|
||
扫描完成工具
|
||
|
||
只有根Agent(主Agent)才能使用此工具来正式结束安全审计任务。
|
||
|
||
使用前置条件:
|
||
1. 所有子Agent必须已完成(completed, failed, 或 stopped 状态)
|
||
2. 必须提供最终报告内容
|
||
|
||
使用约束:
|
||
- 子Agent必须使用 agent_finish 工具
|
||
- 根Agent必须使用此工具
|
||
"""
|
||
|
||
def __init__(self, agent_id: str, agent_state=None, tracer=None):
|
||
super().__init__()
|
||
self.agent_id = agent_id
|
||
self.agent_state = agent_state
|
||
self.tracer = tracer
|
||
|
||
@property
|
||
def name(self) -> str:
|
||
return "finish_scan"
|
||
|
||
@property
|
||
def description(self) -> str:
|
||
return """完成整个安全扫描并生成最终报告。
|
||
|
||
只有根Agent(主编排Agent)才能使用此工具。
|
||
|
||
使用条件:
|
||
1. 所有子Agent必须已完成
|
||
2. 必须提供完整的扫描总结
|
||
|
||
参数:
|
||
- content: 最终扫描报告内容,包含:
|
||
- 扫描概述
|
||
- 发现的漏洞列表
|
||
- 风险评估
|
||
- 修复建议
|
||
- success: 扫描是否成功完成
|
||
|
||
不要使用此工具:
|
||
- 如果你是子Agent(请使用 agent_finish)
|
||
- 如果还有子Agent正在运行"""
|
||
|
||
@property
|
||
def args_schema(self):
|
||
return FinishScanInput
|
||
|
||
async def _execute(
|
||
self,
|
||
content: str,
|
||
success: bool = True,
|
||
**kwargs
|
||
) -> ToolResult:
|
||
"""执行扫描完成"""
|
||
|
||
# 验证是否为根Agent
|
||
validation_error = self._validate_root_agent()
|
||
if validation_error:
|
||
return validation_error
|
||
|
||
# 验证内容
|
||
if not content or not content.strip():
|
||
return ToolResult(success=False, error="报告内容不能为空")
|
||
|
||
# 检查是否有活跃的子Agent
|
||
active_check = self._check_active_agents()
|
||
if active_check:
|
||
return active_check
|
||
|
||
# 更新状态
|
||
final_result = {
|
||
"scan_completed": True,
|
||
"content": content.strip(),
|
||
"success": success,
|
||
"completed_at": datetime.now(timezone.utc).isoformat(),
|
||
}
|
||
|
||
# 收集所有发现
|
||
all_findings = self._collect_all_findings()
|
||
final_result["total_findings"] = len(all_findings)
|
||
final_result["findings_summary"] = self._summarize_findings(all_findings)
|
||
|
||
# 更新Agent状态
|
||
if self.agent_state:
|
||
self.agent_state.set_completed(final_result)
|
||
|
||
agent_registry.update_agent_status(
|
||
self.agent_id,
|
||
"completed" if success else "failed",
|
||
final_result,
|
||
)
|
||
|
||
# 保存到追踪器
|
||
if self.tracer:
|
||
try:
|
||
self.tracer.set_final_scan_result(
|
||
content=content.strip(),
|
||
success=success,
|
||
)
|
||
except Exception as e:
|
||
logger.warning(f"Failed to update tracer: {e}")
|
||
|
||
# 获取统计信息
|
||
stats = agent_registry.get_statistics()
|
||
|
||
return ToolResult(
|
||
success=True,
|
||
data={
|
||
"scan_completed": True,
|
||
"message": "扫描已成功完成" if success else "扫描完成但有错误",
|
||
"report_length": len(content),
|
||
"total_findings": len(all_findings),
|
||
"agent_statistics": stats,
|
||
},
|
||
metadata=final_result,
|
||
)
|
||
|
||
def _validate_root_agent(self) -> Optional[ToolResult]:
|
||
"""验证是否为根Agent"""
|
||
# 检查是否有父Agent
|
||
parent_id = agent_registry.get_parent(self.agent_id)
|
||
|
||
if parent_id is not None:
|
||
return ToolResult(
|
||
success=False,
|
||
error="此工具只能由根Agent使用。子Agent请使用 agent_finish 工具。"
|
||
)
|
||
|
||
# 检查是否为注册的根Agent
|
||
root_id = agent_registry.get_root_agent_id()
|
||
if root_id and root_id != self.agent_id:
|
||
return ToolResult(
|
||
success=False,
|
||
error=f"当前Agent不是根Agent。根Agent ID: {root_id}"
|
||
)
|
||
|
||
return None
|
||
|
||
def _check_active_agents(self) -> Optional[ToolResult]:
|
||
"""检查是否有活跃的子Agent"""
|
||
try:
|
||
tree = agent_registry.get_agent_tree()
|
||
|
||
running_agents = []
|
||
waiting_agents = []
|
||
stopping_agents = []
|
||
|
||
for agent_id, node in tree["nodes"].items():
|
||
# 跳过当前Agent
|
||
if agent_id == self.agent_id:
|
||
continue
|
||
|
||
status = node.get("status", "")
|
||
|
||
if status == "running":
|
||
running_agents.append({
|
||
"id": agent_id,
|
||
"name": node.get("name", "Unknown"),
|
||
"task": node.get("task", "No task description")[:80],
|
||
})
|
||
elif status == "waiting":
|
||
waiting_agents.append({
|
||
"id": agent_id,
|
||
"name": node.get("name", "Unknown"),
|
||
})
|
||
elif status == "stopping":
|
||
stopping_agents.append({
|
||
"id": agent_id,
|
||
"name": node.get("name", "Unknown"),
|
||
})
|
||
|
||
if running_agents or stopping_agents:
|
||
message_parts = ["无法完成扫描,还有活跃的子Agent:"]
|
||
|
||
if running_agents:
|
||
message_parts.append("\n\n运行中的Agent:")
|
||
for agent in running_agents:
|
||
message_parts.append(
|
||
f" - {agent['name']} ({agent['id']}): {agent['task']}"
|
||
)
|
||
|
||
if stopping_agents:
|
||
message_parts.append("\n\n正在停止的Agent:")
|
||
for agent in stopping_agents:
|
||
message_parts.append(f" - {agent['name']} ({agent['id']})")
|
||
|
||
message_parts.extend([
|
||
"\n\n建议操作:",
|
||
"1. 使用 wait_for_message 等待所有Agent完成",
|
||
"2. 使用 view_agent_graph 查看Agent状态",
|
||
"3. 如果需要紧急结束,发送消息要求Agent完成",
|
||
])
|
||
|
||
return ToolResult(
|
||
success=False,
|
||
error="\n".join(message_parts),
|
||
metadata={
|
||
"running_count": len(running_agents),
|
||
"stopping_count": len(stopping_agents),
|
||
"waiting_count": len(waiting_agents),
|
||
}
|
||
)
|
||
|
||
return None
|
||
|
||
except Exception as e:
|
||
logger.warning(f"Failed to check active agents: {e}")
|
||
return None
|
||
|
||
def _collect_all_findings(self) -> List[Dict[str, Any]]:
|
||
"""收集所有子Agent的发现"""
|
||
all_findings = []
|
||
|
||
try:
|
||
tree = agent_registry.get_agent_tree()
|
||
|
||
for agent_id, node in tree["nodes"].items():
|
||
result = node.get("result")
|
||
if not result:
|
||
continue
|
||
|
||
# 从result中提取findings
|
||
if isinstance(result, dict):
|
||
findings = result.get("findings", [])
|
||
if isinstance(findings, list):
|
||
for finding in findings:
|
||
if isinstance(finding, dict):
|
||
finding["discovered_by"] = {
|
||
"agent_id": agent_id,
|
||
"agent_name": node.get("name", "Unknown"),
|
||
}
|
||
all_findings.append(finding)
|
||
elif isinstance(finding, str):
|
||
all_findings.append({
|
||
"description": finding,
|
||
"discovered_by": {
|
||
"agent_id": agent_id,
|
||
"agent_name": node.get("name", "Unknown"),
|
||
}
|
||
})
|
||
except Exception as e:
|
||
logger.warning(f"Failed to collect findings: {e}")
|
||
|
||
return all_findings
|
||
|
||
def _summarize_findings(self, findings: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||
"""生成发现摘要"""
|
||
severity_counts = {
|
||
"critical": 0,
|
||
"high": 0,
|
||
"medium": 0,
|
||
"low": 0,
|
||
"info": 0,
|
||
}
|
||
|
||
type_counts = {}
|
||
|
||
for finding in findings:
|
||
# 统计严重性
|
||
severity = finding.get("severity", "medium").lower()
|
||
if severity in severity_counts:
|
||
severity_counts[severity] += 1
|
||
|
||
# 统计类型
|
||
vuln_type = finding.get("vulnerability_type", finding.get("type", "unknown"))
|
||
type_counts[vuln_type] = type_counts.get(vuln_type, 0) + 1
|
||
|
||
return {
|
||
"total": len(findings),
|
||
"by_severity": severity_counts,
|
||
"by_type": type_counts,
|
||
}
|