refactor(agent): 移除LangGraph工作流并迁移到动态Agent树架构
重构Agent服务架构,从基于LangGraph的状态图迁移到动态Agent树结构。主要变更包括: - 删除graph模块及相关测试 - 更新agent/__init__.py导入和文档 - 在projects端点添加对新AgentTask模型的统计支持 - 简化工作流描述为START→Orchestrator→[Recon/Analysis/Verification]→Report→END 新架构使用OrchestratorAgent作为编排层,动态调度子Agent完成任务,提高灵活性和可扩展性。
This commit is contained in:
parent
39e2f43210
commit
15605fea16
|
|
@ -16,6 +16,7 @@ from app.db.session import get_db, AsyncSessionLocal
|
||||||
from app.models.project import Project
|
from app.models.project import Project
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.audit import AuditTask, AuditIssue
|
from app.models.audit import AuditTask, AuditIssue
|
||||||
|
from app.models.agent_task import AgentTask, AgentTaskStatus, AgentFinding
|
||||||
from app.models.user_config import UserConfig
|
from app.models.user_config import UserConfig
|
||||||
import zipfile
|
import zipfile
|
||||||
from app.services.scanner import scan_repo_task, get_github_files, get_gitlab_files, get_github_branches, get_gitlab_branches, get_gitea_branches, should_exclude, is_text_file
|
from app.services.scanner import scan_repo_task, get_github_files, get_gitlab_files, get_github_branches, get_gitlab_branches, get_gitea_branches, should_exclude, is_text_file
|
||||||
|
|
@ -162,26 +163,51 @@ async def get_stats(
|
||||||
projects = projects_result.scalars().all()
|
projects = projects_result.scalars().all()
|
||||||
project_ids = [p.id for p in projects]
|
project_ids = [p.id for p in projects]
|
||||||
|
|
||||||
# 只统计当前用户项目的任务
|
# 统计旧的 AuditTask
|
||||||
tasks_result = await db.execute(
|
tasks_result = await db.execute(
|
||||||
select(AuditTask).where(AuditTask.project_id.in_(project_ids)) if project_ids else select(AuditTask).where(False)
|
select(AuditTask).where(AuditTask.project_id.in_(project_ids)) if project_ids else select(AuditTask).where(False)
|
||||||
)
|
)
|
||||||
tasks = tasks_result.scalars().all()
|
tasks = tasks_result.scalars().all()
|
||||||
task_ids = [t.id for t in tasks]
|
task_ids = [t.id for t in tasks]
|
||||||
|
|
||||||
# 只统计当前用户任务的问题
|
# 统计旧的 AuditIssue
|
||||||
issues_result = await db.execute(
|
issues_result = await db.execute(
|
||||||
select(AuditIssue).where(AuditIssue.task_id.in_(task_ids)) if task_ids else select(AuditIssue).where(False)
|
select(AuditIssue).where(AuditIssue.task_id.in_(task_ids)) if task_ids else select(AuditIssue).where(False)
|
||||||
)
|
)
|
||||||
issues = issues_result.scalars().all()
|
issues = issues_result.scalars().all()
|
||||||
|
|
||||||
|
# 🔥 同时统计新的 AgentTask
|
||||||
|
agent_tasks_result = await db.execute(
|
||||||
|
select(AgentTask).where(AgentTask.project_id.in_(project_ids)) if project_ids else select(AgentTask).where(False)
|
||||||
|
)
|
||||||
|
agent_tasks = agent_tasks_result.scalars().all()
|
||||||
|
agent_task_ids = [t.id for t in agent_tasks]
|
||||||
|
|
||||||
|
# 🔥 统计 AgentFinding
|
||||||
|
agent_findings_result = await db.execute(
|
||||||
|
select(AgentFinding).where(AgentFinding.task_id.in_(agent_task_ids)) if agent_task_ids else select(AgentFinding).where(False)
|
||||||
|
)
|
||||||
|
agent_findings = agent_findings_result.scalars().all()
|
||||||
|
|
||||||
|
# 合并统计(旧任务 + 新 Agent 任务)
|
||||||
|
total_tasks = len(tasks) + len(agent_tasks)
|
||||||
|
completed_tasks = (
|
||||||
|
len([t for t in tasks if t.status == "completed"]) +
|
||||||
|
len([t for t in agent_tasks if t.status == AgentTaskStatus.COMPLETED])
|
||||||
|
)
|
||||||
|
total_issues = len(issues) + len(agent_findings)
|
||||||
|
resolved_issues = (
|
||||||
|
len([i for i in issues if i.status == "resolved"]) +
|
||||||
|
len([f for f in agent_findings if f.status == "resolved"])
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"total_projects": len(projects),
|
"total_projects": len(projects),
|
||||||
"active_projects": len([p for p in projects if p.is_active]),
|
"active_projects": len([p for p in projects if p.is_active]),
|
||||||
"total_tasks": len(tasks),
|
"total_tasks": total_tasks,
|
||||||
"completed_tasks": len([t for t in tasks if t.status == "completed"]),
|
"completed_tasks": completed_tasks,
|
||||||
"total_issues": len(issues),
|
"total_issues": total_issues,
|
||||||
"resolved_issues": len([i for i in issues if i.status == "resolved"]),
|
"resolved_issues": resolved_issues,
|
||||||
}
|
}
|
||||||
|
|
||||||
@router.get("/{id}", response_model=ProjectResponse)
|
@router.get("/{id}", response_model=ProjectResponse)
|
||||||
|
|
|
||||||
|
|
@ -1,29 +1,19 @@
|
||||||
"""
|
"""
|
||||||
DeepAudit Agent 服务模块
|
DeepAudit Agent 服务模块
|
||||||
基于 LangGraph 的 AI Agent 代码安全审计
|
基于动态 Agent 树架构的 AI 代码安全审计
|
||||||
|
|
||||||
架构升级版本 - 支持:
|
架构:
|
||||||
- 动态Agent树结构
|
- OrchestratorAgent 作为编排层,动态调度子 Agent
|
||||||
- 专业知识模块系统
|
- ReconAgent 负责侦察和文件分析
|
||||||
- Agent间通信机制
|
- AnalysisAgent 负责漏洞分析
|
||||||
- 完整状态管理
|
- VerificationAgent 负责验证发现
|
||||||
- Think工具和漏洞报告工具
|
|
||||||
|
|
||||||
工作流:
|
工作流:
|
||||||
START → Recon → Analysis ⟲ → Verification → Report → END
|
START → Orchestrator → [Recon/Analysis/Verification] → Report → END
|
||||||
|
|
||||||
支持动态创建子Agent进行专业化分析
|
支持动态创建子Agent进行专业化分析
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 从 graph 模块导入主要组件
|
|
||||||
from .graph import (
|
|
||||||
AgentRunner,
|
|
||||||
run_agent_task,
|
|
||||||
LLMService,
|
|
||||||
AuditState,
|
|
||||||
create_audit_graph,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 事件管理
|
# 事件管理
|
||||||
from .event_manager import EventManager, AgentEventEmitter
|
from .event_manager import EventManager, AgentEventEmitter
|
||||||
|
|
||||||
|
|
@ -33,14 +23,14 @@ from .agents import (
|
||||||
OrchestratorAgent, ReconAgent, AnalysisAgent, VerificationAgent,
|
OrchestratorAgent, ReconAgent, AnalysisAgent, VerificationAgent,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 🔥 新增:核心模块(状态管理、注册表、消息)
|
# 核心模块(状态管理、注册表、消息)
|
||||||
from .core import (
|
from .core import (
|
||||||
AgentState, AgentStatus,
|
AgentState, AgentStatus,
|
||||||
AgentRegistry, agent_registry,
|
AgentRegistry, agent_registry,
|
||||||
AgentMessage, MessageType, MessagePriority, MessageBus,
|
AgentMessage, MessageType, MessagePriority, MessageBus,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 🔥 新增:知识模块系统(基于RAG)
|
# 知识模块系统(基于RAG)
|
||||||
from .knowledge import (
|
from .knowledge import (
|
||||||
KnowledgeLoader, knowledge_loader,
|
KnowledgeLoader, knowledge_loader,
|
||||||
get_available_modules, get_module_content,
|
get_available_modules, get_module_content,
|
||||||
|
|
@ -48,7 +38,7 @@ from .knowledge import (
|
||||||
SecurityKnowledgeQueryTool, GetVulnerabilityKnowledgeTool,
|
SecurityKnowledgeQueryTool, GetVulnerabilityKnowledgeTool,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 🔥 新增:协作工具
|
# 协作工具
|
||||||
from .tools import (
|
from .tools import (
|
||||||
ThinkTool, ReflectTool,
|
ThinkTool, ReflectTool,
|
||||||
CreateVulnerabilityReportTool,
|
CreateVulnerabilityReportTool,
|
||||||
|
|
@ -57,20 +47,11 @@ from .tools import (
|
||||||
WaitForMessageTool, AgentFinishTool,
|
WaitForMessageTool, AgentFinishTool,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 🔥 新增:遥测模块
|
# 遥测模块
|
||||||
from .telemetry import Tracer, get_global_tracer, set_global_tracer
|
from .telemetry import Tracer, get_global_tracer, set_global_tracer
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# 核心 Runner
|
|
||||||
"AgentRunner",
|
|
||||||
"run_agent_task",
|
|
||||||
"LLMService",
|
|
||||||
|
|
||||||
# LangGraph
|
|
||||||
"AuditState",
|
|
||||||
"create_audit_graph",
|
|
||||||
|
|
||||||
# 事件管理
|
# 事件管理
|
||||||
"EventManager",
|
"EventManager",
|
||||||
"AgentEventEmitter",
|
"AgentEventEmitter",
|
||||||
|
|
@ -84,7 +65,7 @@ __all__ = [
|
||||||
"AnalysisAgent",
|
"AnalysisAgent",
|
||||||
"VerificationAgent",
|
"VerificationAgent",
|
||||||
|
|
||||||
# 🔥 核心模块
|
# 核心模块
|
||||||
"AgentState",
|
"AgentState",
|
||||||
"AgentStatus",
|
"AgentStatus",
|
||||||
"AgentRegistry",
|
"AgentRegistry",
|
||||||
|
|
@ -94,7 +75,7 @@ __all__ = [
|
||||||
"MessagePriority",
|
"MessagePriority",
|
||||||
"MessageBus",
|
"MessageBus",
|
||||||
|
|
||||||
# 🔥 知识模块(基于RAG)
|
# 知识模块(基于RAG)
|
||||||
"KnowledgeLoader",
|
"KnowledgeLoader",
|
||||||
"knowledge_loader",
|
"knowledge_loader",
|
||||||
"get_available_modules",
|
"get_available_modules",
|
||||||
|
|
@ -104,7 +85,7 @@ __all__ = [
|
||||||
"SecurityKnowledgeQueryTool",
|
"SecurityKnowledgeQueryTool",
|
||||||
"GetVulnerabilityKnowledgeTool",
|
"GetVulnerabilityKnowledgeTool",
|
||||||
|
|
||||||
# 🔥 协作工具
|
# 协作工具
|
||||||
"ThinkTool",
|
"ThinkTool",
|
||||||
"ReflectTool",
|
"ReflectTool",
|
||||||
"CreateVulnerabilityReportTool",
|
"CreateVulnerabilityReportTool",
|
||||||
|
|
@ -115,9 +96,8 @@ __all__ = [
|
||||||
"WaitForMessageTool",
|
"WaitForMessageTool",
|
||||||
"AgentFinishTool",
|
"AgentFinishTool",
|
||||||
|
|
||||||
# 🔥 遥测模块
|
# 遥测模块
|
||||||
"Tracer",
|
"Tracer",
|
||||||
"get_global_tracer",
|
"get_global_tracer",
|
||||||
"set_global_tracer",
|
"set_global_tracer",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,28 +0,0 @@
|
||||||
"""
|
|
||||||
LangGraph 工作流模块
|
|
||||||
使用状态图构建混合 Agent 审计流程
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .audit_graph import AuditState, create_audit_graph, create_audit_graph_with_human
|
|
||||||
from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode, HumanReviewNode
|
|
||||||
from .runner import AgentRunner, run_agent_task, LLMService
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
# 状态和图
|
|
||||||
"AuditState",
|
|
||||||
"create_audit_graph",
|
|
||||||
"create_audit_graph_with_human",
|
|
||||||
|
|
||||||
# 节点
|
|
||||||
"ReconNode",
|
|
||||||
"AnalysisNode",
|
|
||||||
"VerificationNode",
|
|
||||||
"ReportNode",
|
|
||||||
"HumanReviewNode",
|
|
||||||
|
|
||||||
# Runner
|
|
||||||
"AgentRunner",
|
|
||||||
"run_agent_task",
|
|
||||||
"LLMService",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
@ -1,677 +0,0 @@
|
||||||
"""
|
|
||||||
DeepAudit 审计工作流图 - LLM 驱动版
|
|
||||||
使用 LangGraph 构建 LLM 驱动的 Agent 协作流程
|
|
||||||
|
|
||||||
重要改变:路由决策由 LLM 参与,而不是硬编码条件!
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import TypedDict, Annotated, List, Dict, Any, Optional, Literal
|
|
||||||
from datetime import datetime
|
|
||||||
import operator
|
|
||||||
import logging
|
|
||||||
import json
|
|
||||||
|
|
||||||
from langgraph.graph import StateGraph, END
|
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
|
||||||
from langgraph.prebuilt import ToolNode
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# ============ 状态定义 ============
|
|
||||||
|
|
||||||
class Finding(TypedDict):
|
|
||||||
"""漏洞发现"""
|
|
||||||
id: str
|
|
||||||
vulnerability_type: str
|
|
||||||
severity: str
|
|
||||||
title: str
|
|
||||||
description: str
|
|
||||||
file_path: Optional[str]
|
|
||||||
line_start: Optional[int]
|
|
||||||
code_snippet: Optional[str]
|
|
||||||
is_verified: bool
|
|
||||||
confidence: float
|
|
||||||
source: str
|
|
||||||
|
|
||||||
|
|
||||||
class AuditState(TypedDict):
|
|
||||||
"""
|
|
||||||
审计状态
|
|
||||||
在整个工作流中传递和更新
|
|
||||||
"""
|
|
||||||
# 输入
|
|
||||||
project_root: str
|
|
||||||
project_info: Dict[str, Any]
|
|
||||||
config: Dict[str, Any]
|
|
||||||
task_id: str
|
|
||||||
|
|
||||||
# Recon 阶段输出
|
|
||||||
tech_stack: Dict[str, Any]
|
|
||||||
entry_points: List[Dict[str, Any]]
|
|
||||||
high_risk_areas: List[str]
|
|
||||||
dependencies: Dict[str, Any]
|
|
||||||
|
|
||||||
# Analysis 阶段输出
|
|
||||||
findings: Annotated[List[Finding], operator.add] # 使用 add 合并多轮发现
|
|
||||||
|
|
||||||
# Verification 阶段输出
|
|
||||||
verified_findings: List[Finding]
|
|
||||||
false_positives: List[str]
|
|
||||||
# 🔥 NEW: 验证后的完整 findings(用于替换原始 findings)
|
|
||||||
_verified_findings_update: Optional[List[Finding]]
|
|
||||||
|
|
||||||
# 控制流 - 🔥 关键:LLM 可以设置这些来影响路由
|
|
||||||
current_phase: str
|
|
||||||
iteration: int
|
|
||||||
max_iterations: int
|
|
||||||
should_continue_analysis: bool
|
|
||||||
|
|
||||||
# 🔥 新增:LLM 的路由决策
|
|
||||||
llm_next_action: Optional[str] # LLM 建议的下一步: "continue_analysis", "verify", "report", "end"
|
|
||||||
llm_routing_reason: Optional[str] # LLM 的决策理由
|
|
||||||
|
|
||||||
# 🔥 新增:Agent 间协作的任务交接信息
|
|
||||||
recon_handoff: Optional[Dict[str, Any]] # Recon -> Analysis 的交接
|
|
||||||
analysis_handoff: Optional[Dict[str, Any]] # Analysis -> Verification 的交接
|
|
||||||
verification_handoff: Optional[Dict[str, Any]] # Verification -> Report 的交接
|
|
||||||
|
|
||||||
# 消息和事件
|
|
||||||
messages: Annotated[List[Dict], operator.add]
|
|
||||||
events: Annotated[List[Dict], operator.add]
|
|
||||||
|
|
||||||
# 最终输出
|
|
||||||
summary: Optional[Dict[str, Any]]
|
|
||||||
security_score: Optional[int]
|
|
||||||
error: Optional[str]
|
|
||||||
|
|
||||||
|
|
||||||
# ============ LLM 路由决策器 ============
|
|
||||||
|
|
||||||
class LLMRouter:
|
|
||||||
"""
|
|
||||||
LLM 路由决策器
|
|
||||||
让 LLM 来决定下一步应该做什么
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, llm_service):
|
|
||||||
self.llm_service = llm_service
|
|
||||||
|
|
||||||
async def decide_after_recon(self, state: AuditState) -> Dict[str, Any]:
|
|
||||||
"""Recon 后让 LLM 决定下一步"""
|
|
||||||
entry_points = state.get("entry_points", [])
|
|
||||||
high_risk_areas = state.get("high_risk_areas", [])
|
|
||||||
tech_stack = state.get("tech_stack", {})
|
|
||||||
initial_findings = state.get("findings", [])
|
|
||||||
|
|
||||||
prompt = f"""作为安全审计的决策者,基于以下信息收集结果,决定下一步行动。
|
|
||||||
|
|
||||||
## 信息收集结果
|
|
||||||
- 入口点数量: {len(entry_points)}
|
|
||||||
- 高风险区域: {high_risk_areas[:10]}
|
|
||||||
- 技术栈: {tech_stack}
|
|
||||||
- 初步发现: {len(initial_findings)} 个
|
|
||||||
|
|
||||||
## 选项
|
|
||||||
1. "analysis" - 继续进行漏洞分析(推荐:有入口点或高风险区域时)
|
|
||||||
2. "end" - 结束审计(仅当没有任何可分析内容时)
|
|
||||||
|
|
||||||
请返回 JSON 格式:
|
|
||||||
{{"action": "analysis或end", "reason": "决策理由"}}"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self.llm_service.chat_completion_raw(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
|
|
||||||
{"role": "user", "content": prompt},
|
|
||||||
],
|
|
||||||
# 🔥 不传递 temperature 和 max_tokens,使用用户配置
|
|
||||||
)
|
|
||||||
|
|
||||||
content = response.get("content", "")
|
|
||||||
# 提取 JSON
|
|
||||||
import re
|
|
||||||
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
|
||||||
if json_match:
|
|
||||||
result = json.loads(json_match.group())
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"LLM routing decision failed: {e}")
|
|
||||||
|
|
||||||
# 默认决策
|
|
||||||
if entry_points or high_risk_areas:
|
|
||||||
return {"action": "analysis", "reason": "有可分析内容"}
|
|
||||||
return {"action": "end", "reason": "没有发现入口点或高风险区域"}
|
|
||||||
|
|
||||||
async def decide_after_analysis(self, state: AuditState) -> Dict[str, Any]:
|
|
||||||
"""Analysis 后让 LLM 决定下一步"""
|
|
||||||
findings = state.get("findings", [])
|
|
||||||
iteration = state.get("iteration", 0)
|
|
||||||
max_iterations = state.get("max_iterations", 3)
|
|
||||||
|
|
||||||
# 统计发现
|
|
||||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
|
|
||||||
for f in findings:
|
|
||||||
# 跳过非字典类型的 finding
|
|
||||||
if not isinstance(f, dict):
|
|
||||||
continue
|
|
||||||
sev = f.get("severity", "medium")
|
|
||||||
severity_counts[sev] = severity_counts.get(sev, 0) + 1
|
|
||||||
|
|
||||||
prompt = f"""作为安全审计的决策者,基于以下分析结果,决定下一步行动。
|
|
||||||
|
|
||||||
## 分析结果
|
|
||||||
- 总发现数: {len(findings)}
|
|
||||||
- 严重程度分布: {severity_counts}
|
|
||||||
- 当前迭代: {iteration}/{max_iterations}
|
|
||||||
|
|
||||||
## 选项
|
|
||||||
1. "verification" - 验证发现的漏洞(推荐:有发现需要验证时)
|
|
||||||
2. "analysis" - 继续深入分析(推荐:发现较少但还有迭代次数时)
|
|
||||||
3. "report" - 生成报告(推荐:没有发现或已充分分析时)
|
|
||||||
|
|
||||||
请返回 JSON 格式:
|
|
||||||
{{"action": "verification/analysis/report", "reason": "决策理由"}}"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self.llm_service.chat_completion_raw(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
|
|
||||||
{"role": "user", "content": prompt},
|
|
||||||
],
|
|
||||||
# 🔥 不传递 temperature 和 max_tokens,使用用户配置
|
|
||||||
)
|
|
||||||
|
|
||||||
content = response.get("content", "")
|
|
||||||
import re
|
|
||||||
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
|
||||||
if json_match:
|
|
||||||
result = json.loads(json_match.group())
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"LLM routing decision failed: {e}")
|
|
||||||
|
|
||||||
# 默认决策
|
|
||||||
if not findings:
|
|
||||||
return {"action": "report", "reason": "没有发现漏洞"}
|
|
||||||
if len(findings) >= 3 or iteration >= max_iterations:
|
|
||||||
return {"action": "verification", "reason": "有足够的发现需要验证"}
|
|
||||||
return {"action": "analysis", "reason": "发现较少,继续分析"}
|
|
||||||
|
|
||||||
async def decide_after_verification(self, state: AuditState) -> Dict[str, Any]:
|
|
||||||
"""Verification 后让 LLM 决定下一步"""
|
|
||||||
verified_findings = state.get("verified_findings", [])
|
|
||||||
false_positives = state.get("false_positives", [])
|
|
||||||
iteration = state.get("iteration", 0)
|
|
||||||
max_iterations = state.get("max_iterations", 3)
|
|
||||||
|
|
||||||
prompt = f"""作为安全审计的决策者,基于以下验证结果,决定下一步行动。
|
|
||||||
|
|
||||||
## 验证结果
|
|
||||||
- 已确认漏洞: {len(verified_findings)}
|
|
||||||
- 误报数量: {len(false_positives)}
|
|
||||||
- 当前迭代: {iteration}/{max_iterations}
|
|
||||||
|
|
||||||
## 选项
|
|
||||||
1. "analysis" - 回到分析阶段重新分析(推荐:误报率太高时)
|
|
||||||
2. "report" - 生成最终报告(推荐:验证完成时)
|
|
||||||
|
|
||||||
请返回 JSON 格式:
|
|
||||||
{{"action": "analysis/report", "reason": "决策理由"}}"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self.llm_service.chat_completion_raw(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"},
|
|
||||||
{"role": "user", "content": prompt},
|
|
||||||
],
|
|
||||||
# 🔥 不传递 temperature 和 max_tokens,使用用户配置
|
|
||||||
)
|
|
||||||
|
|
||||||
content = response.get("content", "")
|
|
||||||
import re
|
|
||||||
json_match = re.search(r'\{.*\}', content, re.DOTALL)
|
|
||||||
if json_match:
|
|
||||||
result = json.loads(json_match.group())
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"LLM routing decision failed: {e}")
|
|
||||||
|
|
||||||
# 默认决策
|
|
||||||
if len(false_positives) > len(verified_findings) and iteration < max_iterations:
|
|
||||||
return {"action": "analysis", "reason": "误报率较高,需要重新分析"}
|
|
||||||
return {"action": "report", "reason": "验证完成,生成报告"}
|
|
||||||
|
|
||||||
|
|
||||||
# ============ 路由函数 (结合 LLM 决策) ============
|
|
||||||
|
|
||||||
def route_after_recon(state: AuditState) -> Literal["analysis", "end"]:
|
|
||||||
"""
|
|
||||||
Recon 后的路由决策
|
|
||||||
优先使用 LLM 的决策,否则使用默认逻辑
|
|
||||||
"""
|
|
||||||
# 🔥 检查是否有错误
|
|
||||||
if state.get("error") or state.get("current_phase") == "error":
|
|
||||||
logger.error(f"Recon phase has error, routing to end: {state.get('error')}")
|
|
||||||
return "end"
|
|
||||||
|
|
||||||
# 检查 LLM 是否有决策
|
|
||||||
llm_action = state.get("llm_next_action")
|
|
||||||
if llm_action:
|
|
||||||
logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
|
|
||||||
if llm_action == "end":
|
|
||||||
return "end"
|
|
||||||
return "analysis"
|
|
||||||
|
|
||||||
# 默认逻辑(作为 fallback)
|
|
||||||
if not state.get("entry_points") and not state.get("high_risk_areas"):
|
|
||||||
return "end"
|
|
||||||
return "analysis"
|
|
||||||
|
|
||||||
|
|
||||||
def route_after_analysis(state: AuditState) -> Literal["verification", "analysis", "report"]:
|
|
||||||
"""
|
|
||||||
Analysis 后的路由决策
|
|
||||||
优先使用 LLM 的决策
|
|
||||||
"""
|
|
||||||
# 检查 LLM 是否有决策
|
|
||||||
llm_action = state.get("llm_next_action")
|
|
||||||
if llm_action:
|
|
||||||
logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
|
|
||||||
if llm_action == "verification":
|
|
||||||
return "verification"
|
|
||||||
elif llm_action == "analysis":
|
|
||||||
return "analysis"
|
|
||||||
elif llm_action == "report":
|
|
||||||
return "report"
|
|
||||||
|
|
||||||
# 默认逻辑
|
|
||||||
findings = state.get("findings", [])
|
|
||||||
iteration = state.get("iteration", 0)
|
|
||||||
max_iterations = state.get("max_iterations", 3)
|
|
||||||
should_continue = state.get("should_continue_analysis", False)
|
|
||||||
|
|
||||||
if not findings:
|
|
||||||
return "report"
|
|
||||||
|
|
||||||
if should_continue and iteration < max_iterations:
|
|
||||||
return "analysis"
|
|
||||||
|
|
||||||
return "verification"
|
|
||||||
|
|
||||||
|
|
||||||
def route_after_verification(state: AuditState) -> Literal["analysis", "report"]:
|
|
||||||
"""
|
|
||||||
Verification 后的路由决策
|
|
||||||
优先使用 LLM 的决策
|
|
||||||
"""
|
|
||||||
# 检查 LLM 是否有决策
|
|
||||||
llm_action = state.get("llm_next_action")
|
|
||||||
if llm_action:
|
|
||||||
logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
|
|
||||||
if llm_action == "analysis":
|
|
||||||
return "analysis"
|
|
||||||
return "report"
|
|
||||||
|
|
||||||
# 默认逻辑
|
|
||||||
false_positives = state.get("false_positives", [])
|
|
||||||
iteration = state.get("iteration", 0)
|
|
||||||
max_iterations = state.get("max_iterations", 3)
|
|
||||||
|
|
||||||
if len(false_positives) > len(state.get("verified_findings", [])) and iteration < max_iterations:
|
|
||||||
return "analysis"
|
|
||||||
|
|
||||||
return "report"
|
|
||||||
|
|
||||||
|
|
||||||
# ============ 创建审计图 ============
|
|
||||||
|
|
||||||
def create_audit_graph(
|
|
||||||
recon_node,
|
|
||||||
analysis_node,
|
|
||||||
verification_node,
|
|
||||||
report_node,
|
|
||||||
checkpointer: Optional[MemorySaver] = None,
|
|
||||||
llm_service=None, # 用于 LLM 路由决策
|
|
||||||
) -> StateGraph:
|
|
||||||
"""
|
|
||||||
创建审计工作流图
|
|
||||||
|
|
||||||
Args:
|
|
||||||
recon_node: 信息收集节点
|
|
||||||
analysis_node: 漏洞分析节点
|
|
||||||
verification_node: 漏洞验证节点
|
|
||||||
report_node: 报告生成节点
|
|
||||||
checkpointer: 检查点存储器
|
|
||||||
llm_service: LLM 服务(用于路由决策)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
编译后的 StateGraph
|
|
||||||
|
|
||||||
工作流结构:
|
|
||||||
|
|
||||||
START
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
┌──────┐
|
|
||||||
│Recon │ 信息收集 (LLM 驱动)
|
|
||||||
└──┬───┘
|
|
||||||
│ LLM 决定
|
|
||||||
▼
|
|
||||||
┌──────────┐
|
|
||||||
│ Analysis │◄─────┐ 漏洞分析 (LLM 驱动,可循环)
|
|
||||||
└────┬─────┘ │
|
|
||||||
│ LLM 决定 │
|
|
||||||
▼ │
|
|
||||||
┌────────────┐ │
|
|
||||||
│Verification│────┘ 漏洞验证 (LLM 驱动,可回溯)
|
|
||||||
└─────┬──────┘
|
|
||||||
│ LLM 决定
|
|
||||||
▼
|
|
||||||
┌──────────┐
|
|
||||||
│ Report │ 报告生成
|
|
||||||
└────┬─────┘
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
END
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 创建状态图
|
|
||||||
workflow = StateGraph(AuditState)
|
|
||||||
|
|
||||||
# 如果有 LLM 服务,创建路由决策器
|
|
||||||
llm_router = LLMRouter(llm_service) if llm_service else None
|
|
||||||
|
|
||||||
# 包装节点以添加 LLM 路由决策
|
|
||||||
async def recon_with_routing(state):
|
|
||||||
result = await recon_node(state)
|
|
||||||
|
|
||||||
# LLM 决定下一步
|
|
||||||
if llm_router:
|
|
||||||
decision = await llm_router.decide_after_recon({**state, **result})
|
|
||||||
result["llm_next_action"] = decision.get("action")
|
|
||||||
result["llm_routing_reason"] = decision.get("reason")
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def analysis_with_routing(state):
|
|
||||||
result = await analysis_node(state)
|
|
||||||
|
|
||||||
# LLM 决定下一步
|
|
||||||
if llm_router:
|
|
||||||
decision = await llm_router.decide_after_analysis({**state, **result})
|
|
||||||
result["llm_next_action"] = decision.get("action")
|
|
||||||
result["llm_routing_reason"] = decision.get("reason")
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def verification_with_routing(state):
|
|
||||||
result = await verification_node(state)
|
|
||||||
|
|
||||||
# LLM 决定下一步
|
|
||||||
if llm_router:
|
|
||||||
decision = await llm_router.decide_after_verification({**state, **result})
|
|
||||||
result["llm_next_action"] = decision.get("action")
|
|
||||||
result["llm_routing_reason"] = decision.get("reason")
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
# 添加节点
|
|
||||||
if llm_router:
|
|
||||||
workflow.add_node("recon", recon_with_routing)
|
|
||||||
workflow.add_node("analysis", analysis_with_routing)
|
|
||||||
workflow.add_node("verification", verification_with_routing)
|
|
||||||
else:
|
|
||||||
workflow.add_node("recon", recon_node)
|
|
||||||
workflow.add_node("analysis", analysis_node)
|
|
||||||
workflow.add_node("verification", verification_node)
|
|
||||||
|
|
||||||
workflow.add_node("report", report_node)
|
|
||||||
|
|
||||||
# 设置入口点
|
|
||||||
workflow.set_entry_point("recon")
|
|
||||||
|
|
||||||
# 添加条件边
|
|
||||||
workflow.add_conditional_edges(
|
|
||||||
"recon",
|
|
||||||
route_after_recon,
|
|
||||||
{
|
|
||||||
"analysis": "analysis",
|
|
||||||
"end": END,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
workflow.add_conditional_edges(
|
|
||||||
"analysis",
|
|
||||||
route_after_analysis,
|
|
||||||
{
|
|
||||||
"verification": "verification",
|
|
||||||
"analysis": "analysis",
|
|
||||||
"report": "report",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
workflow.add_conditional_edges(
|
|
||||||
"verification",
|
|
||||||
route_after_verification,
|
|
||||||
{
|
|
||||||
"analysis": "analysis",
|
|
||||||
"report": "report",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Report -> END
|
|
||||||
workflow.add_edge("report", END)
|
|
||||||
|
|
||||||
# 编译图
|
|
||||||
if checkpointer:
|
|
||||||
return workflow.compile(checkpointer=checkpointer)
|
|
||||||
else:
|
|
||||||
return workflow.compile()
|
|
||||||
|
|
||||||
|
|
||||||
# ============ 带人机协作的审计图 ============
|
|
||||||
|
|
||||||
def create_audit_graph_with_human(
|
|
||||||
recon_node,
|
|
||||||
analysis_node,
|
|
||||||
verification_node,
|
|
||||||
report_node,
|
|
||||||
human_review_node,
|
|
||||||
checkpointer: Optional[MemorySaver] = None,
|
|
||||||
llm_service=None,
|
|
||||||
) -> StateGraph:
|
|
||||||
"""
|
|
||||||
创建带人机协作的审计工作流图
|
|
||||||
|
|
||||||
在验证阶段后增加人工审核节点
|
|
||||||
"""
|
|
||||||
|
|
||||||
workflow = StateGraph(AuditState)
|
|
||||||
llm_router = LLMRouter(llm_service) if llm_service else None
|
|
||||||
|
|
||||||
# 包装节点
|
|
||||||
async def recon_with_routing(state):
|
|
||||||
result = await recon_node(state)
|
|
||||||
if llm_router:
|
|
||||||
decision = await llm_router.decide_after_recon({**state, **result})
|
|
||||||
result["llm_next_action"] = decision.get("action")
|
|
||||||
result["llm_routing_reason"] = decision.get("reason")
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def analysis_with_routing(state):
|
|
||||||
result = await analysis_node(state)
|
|
||||||
if llm_router:
|
|
||||||
decision = await llm_router.decide_after_analysis({**state, **result})
|
|
||||||
result["llm_next_action"] = decision.get("action")
|
|
||||||
result["llm_routing_reason"] = decision.get("reason")
|
|
||||||
return result
|
|
||||||
|
|
||||||
# 添加节点
|
|
||||||
if llm_router:
|
|
||||||
workflow.add_node("recon", recon_with_routing)
|
|
||||||
workflow.add_node("analysis", analysis_with_routing)
|
|
||||||
else:
|
|
||||||
workflow.add_node("recon", recon_node)
|
|
||||||
workflow.add_node("analysis", analysis_node)
|
|
||||||
|
|
||||||
workflow.add_node("verification", verification_node)
|
|
||||||
workflow.add_node("human_review", human_review_node)
|
|
||||||
workflow.add_node("report", report_node)
|
|
||||||
|
|
||||||
workflow.set_entry_point("recon")
|
|
||||||
|
|
||||||
workflow.add_conditional_edges(
|
|
||||||
"recon",
|
|
||||||
route_after_recon,
|
|
||||||
{"analysis": "analysis", "end": END}
|
|
||||||
)
|
|
||||||
|
|
||||||
workflow.add_conditional_edges(
|
|
||||||
"analysis",
|
|
||||||
route_after_analysis,
|
|
||||||
{
|
|
||||||
"verification": "verification",
|
|
||||||
"analysis": "analysis",
|
|
||||||
"report": "report",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verification -> Human Review
|
|
||||||
workflow.add_edge("verification", "human_review")
|
|
||||||
|
|
||||||
# Human Review 后的路由
|
|
||||||
def route_after_human(state: AuditState) -> Literal["analysis", "report"]:
|
|
||||||
if state.get("should_continue_analysis"):
|
|
||||||
return "analysis"
|
|
||||||
return "report"
|
|
||||||
|
|
||||||
workflow.add_conditional_edges(
|
|
||||||
"human_review",
|
|
||||||
route_after_human,
|
|
||||||
{"analysis": "analysis", "report": "report"}
|
|
||||||
)
|
|
||||||
|
|
||||||
workflow.add_edge("report", END)
|
|
||||||
|
|
||||||
if checkpointer:
|
|
||||||
return workflow.compile(checkpointer=checkpointer, interrupt_before=["human_review"])
|
|
||||||
else:
|
|
||||||
return workflow.compile()
|
|
||||||
|
|
||||||
|
|
||||||
# ============ 执行器 ============
|
|
||||||
|
|
||||||
class AuditGraphRunner:
|
|
||||||
"""
|
|
||||||
审计图执行器
|
|
||||||
封装 LangGraph 工作流的执行
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
graph: StateGraph,
|
|
||||||
event_emitter=None,
|
|
||||||
):
|
|
||||||
self.graph = graph
|
|
||||||
self.event_emitter = event_emitter
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
project_root: str,
|
|
||||||
project_info: Dict[str, Any],
|
|
||||||
config: Dict[str, Any],
|
|
||||||
task_id: str,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
执行审计工作流
|
|
||||||
"""
|
|
||||||
# 初始状态
|
|
||||||
initial_state: AuditState = {
|
|
||||||
"project_root": project_root,
|
|
||||||
"project_info": project_info,
|
|
||||||
"config": config,
|
|
||||||
"task_id": task_id,
|
|
||||||
"tech_stack": {},
|
|
||||||
"entry_points": [],
|
|
||||||
"high_risk_areas": [],
|
|
||||||
"dependencies": {},
|
|
||||||
"findings": [],
|
|
||||||
"verified_findings": [],
|
|
||||||
"false_positives": [],
|
|
||||||
"_verified_findings_update": None, # 🔥 NEW: 验证后的 findings 更新
|
|
||||||
"current_phase": "start",
|
|
||||||
"iteration": 0,
|
|
||||||
"max_iterations": config.get("max_iterations", 3),
|
|
||||||
"should_continue_analysis": False,
|
|
||||||
"llm_next_action": None,
|
|
||||||
"llm_routing_reason": None,
|
|
||||||
"messages": [],
|
|
||||||
"events": [],
|
|
||||||
"summary": None,
|
|
||||||
"security_score": None,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
run_config = {
|
|
||||||
"configurable": {
|
|
||||||
"thread_id": task_id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
async for event in self.graph.astream(initial_state, config=run_config):
|
|
||||||
if self.event_emitter:
|
|
||||||
for node_name, node_state in event.items():
|
|
||||||
await self.event_emitter.emit_info(
|
|
||||||
f"节点 {node_name} 完成"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 发射 LLM 路由决策事件
|
|
||||||
if node_state.get("llm_routing_reason"):
|
|
||||||
await self.event_emitter.emit_info(
|
|
||||||
f"🧠 LLM 决策: {node_state.get('llm_next_action')} - {node_state.get('llm_routing_reason')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if node_name == "analysis" and node_state.get("findings"):
|
|
||||||
new_findings = node_state["findings"]
|
|
||||||
await self.event_emitter.emit_info(
|
|
||||||
f"发现 {len(new_findings)} 个潜在漏洞"
|
|
||||||
)
|
|
||||||
|
|
||||||
final_state = self.graph.get_state(run_config)
|
|
||||||
return final_state.values
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Graph execution failed: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def run_with_human_review(
|
|
||||||
self,
|
|
||||||
initial_state: AuditState,
|
|
||||||
human_feedback_callback,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""带人机协作的执行"""
|
|
||||||
run_config = {
|
|
||||||
"configurable": {
|
|
||||||
"thread_id": initial_state["task_id"],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async for event in self.graph.astream(initial_state, config=run_config):
|
|
||||||
pass
|
|
||||||
|
|
||||||
current_state = self.graph.get_state(run_config)
|
|
||||||
|
|
||||||
if current_state.next == ("human_review",):
|
|
||||||
human_decision = await human_feedback_callback(current_state.values)
|
|
||||||
|
|
||||||
updated_state = {
|
|
||||||
**current_state.values,
|
|
||||||
"should_continue_analysis": human_decision.get("continue_analysis", False),
|
|
||||||
}
|
|
||||||
|
|
||||||
async for event in self.graph.astream(updated_state, config=run_config):
|
|
||||||
pass
|
|
||||||
|
|
||||||
return self.graph.get_state(run_config).values
|
|
||||||
|
|
@ -1,556 +0,0 @@
|
||||||
"""
|
|
||||||
LangGraph 节点实现
|
|
||||||
每个节点封装一个 Agent 的执行逻辑
|
|
||||||
|
|
||||||
协作增强:节点之间通过 TaskHandoff 传递结构化的上下文和洞察
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, Any, List, Optional
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# 延迟导入避免循环依赖
|
|
||||||
def get_audit_state_type():
|
|
||||||
from .audit_graph import AuditState
|
|
||||||
return AuditState
|
|
||||||
|
|
||||||
|
|
||||||
class BaseNode:
|
|
||||||
"""节点基类"""
|
|
||||||
|
|
||||||
def __init__(self, agent=None, event_emitter=None):
|
|
||||||
self.agent = agent
|
|
||||||
self.event_emitter = event_emitter
|
|
||||||
|
|
||||||
async def emit_event(self, event_type: str, message: str, **kwargs):
|
|
||||||
"""发射事件"""
|
|
||||||
if self.event_emitter:
|
|
||||||
try:
|
|
||||||
await self.event_emitter.emit_info(message)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to emit event: {e}")
|
|
||||||
|
|
||||||
def _extract_handoff_from_state(self, state: Dict[str, Any], from_phase: str):
|
|
||||||
"""从状态中提取前序 Agent 的 handoff"""
|
|
||||||
handoff_data = state.get(f"{from_phase}_handoff")
|
|
||||||
if handoff_data:
|
|
||||||
from ..agents.base import TaskHandoff
|
|
||||||
return TaskHandoff.from_dict(handoff_data)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class ReconNode(BaseNode):
|
|
||||||
"""
|
|
||||||
信息收集节点
|
|
||||||
|
|
||||||
输入: project_root, project_info, config
|
|
||||||
输出: tech_stack, entry_points, high_risk_areas, dependencies, recon_handoff
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""执行信息收集"""
|
|
||||||
await self.emit_event("phase_start", "🔍 开始信息收集阶段")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 调用 Recon Agent
|
|
||||||
result = await self.agent.run({
|
|
||||||
"project_info": state["project_info"],
|
|
||||||
"config": state["config"],
|
|
||||||
})
|
|
||||||
|
|
||||||
if result.success and result.data:
|
|
||||||
data = result.data
|
|
||||||
|
|
||||||
# 🔥 创建交接信息给 Analysis Agent
|
|
||||||
handoff = self.agent.create_handoff(
|
|
||||||
to_agent="Analysis",
|
|
||||||
summary=f"项目信息收集完成。发现 {len(data.get('entry_points', []))} 个入口点,{len(data.get('high_risk_areas', []))} 个高风险区域。",
|
|
||||||
key_findings=data.get("initial_findings", []),
|
|
||||||
suggested_actions=[
|
|
||||||
{
|
|
||||||
"type": "deep_analysis",
|
|
||||||
"description": f"深入分析高风险区域: {', '.join(data.get('high_risk_areas', [])[:5])}",
|
|
||||||
"priority": "high",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "entry_point_audit",
|
|
||||||
"description": "审计所有入口点的输入验证",
|
|
||||||
"priority": "high",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
attention_points=[
|
|
||||||
f"技术栈: {data.get('tech_stack', {}).get('frameworks', [])}",
|
|
||||||
f"主要语言: {data.get('tech_stack', {}).get('languages', [])}",
|
|
||||||
],
|
|
||||||
priority_areas=data.get("high_risk_areas", [])[:10],
|
|
||||||
context_data={
|
|
||||||
"tech_stack": data.get("tech_stack", {}),
|
|
||||||
"entry_points": data.get("entry_points", []),
|
|
||||||
"dependencies": data.get("dependencies", {}),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.emit_event(
|
|
||||||
"phase_complete",
|
|
||||||
f"✅ 信息收集完成: 发现 {len(data.get('entry_points', []))} 个入口点"
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"tech_stack": data.get("tech_stack", {}),
|
|
||||||
"entry_points": data.get("entry_points", []),
|
|
||||||
"high_risk_areas": data.get("high_risk_areas", []),
|
|
||||||
"dependencies": data.get("dependencies", {}),
|
|
||||||
"current_phase": "recon_complete",
|
|
||||||
"findings": data.get("initial_findings", []),
|
|
||||||
# 🔥 保存交接信息
|
|
||||||
"recon_handoff": handoff.to_dict(),
|
|
||||||
"events": [{
|
|
||||||
"type": "recon_complete",
|
|
||||||
"data": {
|
|
||||||
"entry_points_count": len(data.get("entry_points", [])),
|
|
||||||
"high_risk_areas_count": len(data.get("high_risk_areas", [])),
|
|
||||||
"handoff_summary": handoff.summary,
|
|
||||||
}
|
|
||||||
}],
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"error": result.error or "Recon failed",
|
|
||||||
"current_phase": "error",
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Recon node failed: {e}", exc_info=True)
|
|
||||||
return {
|
|
||||||
"error": str(e),
|
|
||||||
"current_phase": "error",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class AnalysisNode(BaseNode):
|
|
||||||
"""
|
|
||||||
漏洞分析节点
|
|
||||||
|
|
||||||
输入: tech_stack, entry_points, high_risk_areas, recon_handoff
|
|
||||||
输出: findings (累加), should_continue_analysis, analysis_handoff
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""执行漏洞分析"""
|
|
||||||
iteration = state.get("iteration", 0) + 1
|
|
||||||
|
|
||||||
await self.emit_event(
|
|
||||||
"phase_start",
|
|
||||||
f"🔬 开始漏洞分析阶段 (迭代 {iteration})"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 🔥 提取 Recon 的交接信息
|
|
||||||
recon_handoff = self._extract_handoff_from_state(state, "recon")
|
|
||||||
if recon_handoff:
|
|
||||||
self.agent.receive_handoff(recon_handoff)
|
|
||||||
await self.emit_event(
|
|
||||||
"handoff_received",
|
|
||||||
f"📨 收到 Recon Agent 交接: {recon_handoff.summary[:50]}..."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建分析输入
|
|
||||||
analysis_input = {
|
|
||||||
"phase_name": "analysis",
|
|
||||||
"project_info": state["project_info"],
|
|
||||||
"config": state["config"],
|
|
||||||
"plan": {
|
|
||||||
"high_risk_areas": state.get("high_risk_areas", []),
|
|
||||||
},
|
|
||||||
"previous_results": {
|
|
||||||
"recon": {
|
|
||||||
"data": {
|
|
||||||
"tech_stack": state.get("tech_stack", {}),
|
|
||||||
"entry_points": state.get("entry_points", []),
|
|
||||||
"high_risk_areas": state.get("high_risk_areas", []),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
# 🔥 传递交接信息
|
|
||||||
"handoff": recon_handoff,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 调用 Analysis Agent
|
|
||||||
result = await self.agent.run(analysis_input)
|
|
||||||
|
|
||||||
if result.success and result.data:
|
|
||||||
new_findings = result.data.get("findings", [])
|
|
||||||
logger.info(f"[AnalysisNode] Agent returned {len(new_findings)} findings")
|
|
||||||
|
|
||||||
# 判断是否需要继续分析
|
|
||||||
should_continue = (
|
|
||||||
len(new_findings) >= 5 and
|
|
||||||
iteration < state.get("max_iterations", 3)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 🔥 创建交接信息给 Verification Agent
|
|
||||||
# 统计严重程度
|
|
||||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
|
|
||||||
for f in new_findings:
|
|
||||||
if isinstance(f, dict):
|
|
||||||
sev = f.get("severity", "medium")
|
|
||||||
severity_counts[sev] = severity_counts.get(sev, 0) + 1
|
|
||||||
|
|
||||||
handoff = self.agent.create_handoff(
|
|
||||||
to_agent="Verification",
|
|
||||||
summary=f"漏洞分析完成。发现 {len(new_findings)} 个潜在漏洞 (Critical: {severity_counts['critical']}, High: {severity_counts['high']}, Medium: {severity_counts['medium']}, Low: {severity_counts['low']})",
|
|
||||||
key_findings=new_findings[:20], # 传递前20个发现
|
|
||||||
suggested_actions=[
|
|
||||||
{
|
|
||||||
"type": "verify_critical",
|
|
||||||
"description": "优先验证 Critical 和 High 级别的漏洞",
|
|
||||||
"priority": "critical",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "poc_generation",
|
|
||||||
"description": "为确认的漏洞生成 PoC",
|
|
||||||
"priority": "high",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
attention_points=[
|
|
||||||
f"共 {severity_counts['critical']} 个 Critical 级别漏洞需要立即验证",
|
|
||||||
f"共 {severity_counts['high']} 个 High 级别漏洞需要优先验证",
|
|
||||||
"注意检查是否有误报,特别是静态分析工具的结果",
|
|
||||||
],
|
|
||||||
priority_areas=[
|
|
||||||
f.get("file_path", "") for f in new_findings
|
|
||||||
if f.get("severity") in ["critical", "high"]
|
|
||||||
][:10],
|
|
||||||
context_data={
|
|
||||||
"severity_distribution": severity_counts,
|
|
||||||
"total_findings": len(new_findings),
|
|
||||||
"iteration": iteration,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.emit_event(
|
|
||||||
"phase_complete",
|
|
||||||
f"✅ 分析迭代 {iteration} 完成: 发现 {len(new_findings)} 个潜在漏洞"
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"findings": new_findings,
|
|
||||||
"iteration": iteration,
|
|
||||||
"should_continue_analysis": should_continue,
|
|
||||||
"current_phase": "analysis_complete",
|
|
||||||
# 🔥 保存交接信息
|
|
||||||
"analysis_handoff": handoff.to_dict(),
|
|
||||||
"events": [{
|
|
||||||
"type": "analysis_iteration",
|
|
||||||
"data": {
|
|
||||||
"iteration": iteration,
|
|
||||||
"findings_count": len(new_findings),
|
|
||||||
"severity_distribution": severity_counts,
|
|
||||||
"handoff_summary": handoff.summary,
|
|
||||||
}
|
|
||||||
}],
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"iteration": iteration,
|
|
||||||
"should_continue_analysis": False,
|
|
||||||
"current_phase": "analysis_complete",
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Analysis node failed: {e}", exc_info=True)
|
|
||||||
return {
|
|
||||||
"error": str(e),
|
|
||||||
"should_continue_analysis": False,
|
|
||||||
"current_phase": "error",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class VerificationNode(BaseNode):
|
|
||||||
"""
|
|
||||||
漏洞验证节点
|
|
||||||
|
|
||||||
输入: findings, analysis_handoff
|
|
||||||
输出: verified_findings, false_positives, verification_handoff
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""执行漏洞验证"""
|
|
||||||
findings = state.get("findings", [])
|
|
||||||
logger.info(f"[VerificationNode] Received {len(findings)} findings to verify")
|
|
||||||
|
|
||||||
if not findings:
|
|
||||||
return {
|
|
||||||
"verified_findings": [],
|
|
||||||
"false_positives": [],
|
|
||||||
"current_phase": "verification_complete",
|
|
||||||
}
|
|
||||||
|
|
||||||
await self.emit_event(
|
|
||||||
"phase_start",
|
|
||||||
f"🔐 开始漏洞验证阶段 ({len(findings)} 个待验证)"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 🔥 提取 Analysis 的交接信息
|
|
||||||
analysis_handoff = self._extract_handoff_from_state(state, "analysis")
|
|
||||||
if analysis_handoff:
|
|
||||||
self.agent.receive_handoff(analysis_handoff)
|
|
||||||
await self.emit_event(
|
|
||||||
"handoff_received",
|
|
||||||
f"📨 收到 Analysis Agent 交接: {analysis_handoff.summary[:50]}..."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建验证输入
|
|
||||||
verification_input = {
|
|
||||||
"previous_results": {
|
|
||||||
"analysis": {
|
|
||||||
"data": {
|
|
||||||
"findings": findings,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"config": state["config"],
|
|
||||||
# 🔥 传递交接信息
|
|
||||||
"handoff": analysis_handoff,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 调用 Verification Agent
|
|
||||||
result = await self.agent.run(verification_input)
|
|
||||||
|
|
||||||
if result.success and result.data:
|
|
||||||
all_verified_findings = result.data.get("findings", [])
|
|
||||||
verified = [f for f in all_verified_findings if f.get("is_verified")]
|
|
||||||
false_pos = [f.get("id", f.get("title", "unknown")) for f in all_verified_findings
|
|
||||||
if f.get("verdict") == "false_positive"]
|
|
||||||
|
|
||||||
# 🔥 CRITICAL FIX: 用验证结果更新原始 findings
|
|
||||||
# 创建 findings 的更新映射,基于 (file_path, line_start, vulnerability_type)
|
|
||||||
verified_map = {}
|
|
||||||
for vf in all_verified_findings:
|
|
||||||
key = (
|
|
||||||
vf.get("file_path", ""),
|
|
||||||
vf.get("line_start", 0),
|
|
||||||
vf.get("vulnerability_type", ""),
|
|
||||||
)
|
|
||||||
verified_map[key] = vf
|
|
||||||
|
|
||||||
# 合并验证结果到原始 findings
|
|
||||||
updated_findings = []
|
|
||||||
seen_keys = set()
|
|
||||||
|
|
||||||
# 首先处理原始 findings,用验证结果更新
|
|
||||||
for f in findings:
|
|
||||||
if not isinstance(f, dict):
|
|
||||||
continue
|
|
||||||
key = (
|
|
||||||
f.get("file_path", ""),
|
|
||||||
f.get("line_start", 0),
|
|
||||||
f.get("vulnerability_type", ""),
|
|
||||||
)
|
|
||||||
if key in verified_map:
|
|
||||||
# 使用验证后的版本
|
|
||||||
updated_findings.append(verified_map[key])
|
|
||||||
seen_keys.add(key)
|
|
||||||
else:
|
|
||||||
# 保留原始(未验证)
|
|
||||||
updated_findings.append(f)
|
|
||||||
seen_keys.add(key)
|
|
||||||
|
|
||||||
# 添加验证结果中的新发现(如果有)
|
|
||||||
for key, vf in verified_map.items():
|
|
||||||
if key not in seen_keys:
|
|
||||||
updated_findings.append(vf)
|
|
||||||
|
|
||||||
logger.info(f"[VerificationNode] Updated findings: {len(updated_findings)} total, {len(verified)} verified")
|
|
||||||
|
|
||||||
# 🔥 创建交接信息给 Report 节点
|
|
||||||
handoff = self.agent.create_handoff(
|
|
||||||
to_agent="Report",
|
|
||||||
summary=f"漏洞验证完成。{len(verified)} 个漏洞已确认,{len(false_pos)} 个误报已排除。",
|
|
||||||
key_findings=verified,
|
|
||||||
suggested_actions=[
|
|
||||||
{
|
|
||||||
"type": "generate_report",
|
|
||||||
"description": "生成详细的安全审计报告",
|
|
||||||
"priority": "high",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "remediation_plan",
|
|
||||||
"description": "为确认的漏洞制定修复计划",
|
|
||||||
"priority": "high",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
attention_points=[
|
|
||||||
f"共 {len(verified)} 个漏洞已确认存在",
|
|
||||||
f"共 {len(false_pos)} 个误报已排除",
|
|
||||||
"建议按严重程度优先修复 Critical 和 High 级别漏洞",
|
|
||||||
],
|
|
||||||
context_data={
|
|
||||||
"verified_count": len(verified),
|
|
||||||
"false_positive_count": len(false_pos),
|
|
||||||
"total_analyzed": len(findings),
|
|
||||||
"verification_rate": len(verified) / len(findings) if findings else 0,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.emit_event(
|
|
||||||
"phase_complete",
|
|
||||||
f"✅ 验证完成: {len(verified)} 已确认, {len(false_pos)} 误报"
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
# 🔥 CRITICAL: 返回更新后的 findings,这会替换状态中的 findings
|
|
||||||
# 注意:由于 LangGraph 使用 operator.add,我们需要在 runner 中处理合并
|
|
||||||
# 这里我们返回 _verified_findings_update 作为特殊字段
|
|
||||||
"_verified_findings_update": updated_findings,
|
|
||||||
"verified_findings": verified,
|
|
||||||
"false_positives": false_pos,
|
|
||||||
"current_phase": "verification_complete",
|
|
||||||
# 🔥 保存交接信息
|
|
||||||
"verification_handoff": handoff.to_dict(),
|
|
||||||
"events": [{
|
|
||||||
"type": "verification_complete",
|
|
||||||
"data": {
|
|
||||||
"verified_count": len(verified),
|
|
||||||
"false_positive_count": len(false_pos),
|
|
||||||
"total_findings": len(updated_findings),
|
|
||||||
"handoff_summary": handoff.summary,
|
|
||||||
}
|
|
||||||
}],
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"verified_findings": [],
|
|
||||||
"false_positives": [],
|
|
||||||
"current_phase": "verification_complete",
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Verification node failed: {e}", exc_info=True)
|
|
||||||
return {
|
|
||||||
"error": str(e),
|
|
||||||
"current_phase": "error",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ReportNode(BaseNode):
|
|
||||||
"""
|
|
||||||
报告生成节点
|
|
||||||
|
|
||||||
输入: all state
|
|
||||||
输出: summary, security_score
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""生成审计报告"""
|
|
||||||
await self.emit_event("phase_start", "📊 生成审计报告")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 🔥 CRITICAL FIX: 优先使用验证后的 findings 更新
|
|
||||||
findings = state.get("_verified_findings_update") or state.get("findings", [])
|
|
||||||
verified = state.get("verified_findings", [])
|
|
||||||
false_positives = state.get("false_positives", [])
|
|
||||||
|
|
||||||
logger.info(f"[ReportNode] State contains {len(findings)} findings, {len(verified)} verified")
|
|
||||||
|
|
||||||
# 统计漏洞分布
|
|
||||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
|
|
||||||
type_counts = {}
|
|
||||||
|
|
||||||
for finding in findings:
|
|
||||||
# 跳过非字典类型的 finding(防止数据格式异常)
|
|
||||||
if not isinstance(finding, dict):
|
|
||||||
logger.warning(f"Skipping invalid finding (not a dict): {type(finding)}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
sev = finding.get("severity", "medium")
|
|
||||||
severity_counts[sev] = severity_counts.get(sev, 0) + 1
|
|
||||||
|
|
||||||
vtype = finding.get("vulnerability_type", "other")
|
|
||||||
type_counts[vtype] = type_counts.get(vtype, 0) + 1
|
|
||||||
|
|
||||||
# 计算安全评分
|
|
||||||
base_score = 100
|
|
||||||
deductions = (
|
|
||||||
severity_counts["critical"] * 25 +
|
|
||||||
severity_counts["high"] * 15 +
|
|
||||||
severity_counts["medium"] * 8 +
|
|
||||||
severity_counts["low"] * 3
|
|
||||||
)
|
|
||||||
security_score = max(0, base_score - deductions)
|
|
||||||
|
|
||||||
# 生成摘要
|
|
||||||
summary = {
|
|
||||||
"total_findings": len(findings),
|
|
||||||
"verified_count": len(verified),
|
|
||||||
"false_positive_count": len(false_positives),
|
|
||||||
"severity_distribution": severity_counts,
|
|
||||||
"vulnerability_types": type_counts,
|
|
||||||
"tech_stack": state.get("tech_stack", {}),
|
|
||||||
"entry_points_analyzed": len(state.get("entry_points", [])),
|
|
||||||
"high_risk_areas": state.get("high_risk_areas", []),
|
|
||||||
"iterations": state.get("iteration", 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
await self.emit_event(
|
|
||||||
"phase_complete",
|
|
||||||
f"报告生成完成: 安全评分 {security_score}/100"
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"summary": summary,
|
|
||||||
"security_score": security_score,
|
|
||||||
"current_phase": "complete",
|
|
||||||
"events": [{
|
|
||||||
"type": "audit_complete",
|
|
||||||
"data": {
|
|
||||||
"security_score": security_score,
|
|
||||||
"total_findings": len(findings),
|
|
||||||
"verified_count": len(verified),
|
|
||||||
}
|
|
||||||
}],
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Report node failed: {e}", exc_info=True)
|
|
||||||
return {
|
|
||||||
"error": str(e),
|
|
||||||
"current_phase": "error",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class HumanReviewNode(BaseNode):
|
|
||||||
"""
|
|
||||||
人工审核节点
|
|
||||||
|
|
||||||
在此节点暂停,等待人工反馈
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
人工审核节点
|
|
||||||
|
|
||||||
这个节点会被 interrupt_before 暂停
|
|
||||||
用户可以:
|
|
||||||
1. 确认发现
|
|
||||||
2. 标记误报
|
|
||||||
3. 请求重新分析
|
|
||||||
"""
|
|
||||||
await self.emit_event(
|
|
||||||
"human_review",
|
|
||||||
f"⏸️ 等待人工审核 ({len(state.get('verified_findings', []))} 个待确认)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 返回当前状态,不做修改
|
|
||||||
# 人工反馈会通过 update_state 传入
|
|
||||||
return {
|
|
||||||
"current_phase": "human_review",
|
|
||||||
"messages": [{
|
|
||||||
"role": "system",
|
|
||||||
"content": "等待人工审核",
|
|
||||||
"findings_for_review": state.get("verified_findings", []),
|
|
||||||
}],
|
|
||||||
}
|
|
||||||
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,355 +0,0 @@
|
||||||
"""
|
|
||||||
Agent 集成测试
|
|
||||||
测试完整的审计流程
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
from unittest.mock import MagicMock, AsyncMock, patch
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from app.services.agent.graph.runner import AgentRunner, LLMService
|
|
||||||
from app.services.agent.graph.audit_graph import AuditState, create_audit_graph
|
|
||||||
from app.services.agent.graph.nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode
|
|
||||||
from app.services.agent.event_manager import EventManager, AgentEventEmitter
|
|
||||||
|
|
||||||
|
|
||||||
class TestLLMService:
|
|
||||||
"""LLM 服务测试"""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_llm_service_initialization(self):
|
|
||||||
"""测试 LLM 服务初始化"""
|
|
||||||
with patch("app.core.config.settings") as mock_settings:
|
|
||||||
mock_settings.LLM_MODEL = "gpt-4o-mini"
|
|
||||||
mock_settings.LLM_API_KEY = "test-key"
|
|
||||||
|
|
||||||
service = LLMService()
|
|
||||||
|
|
||||||
assert service.model == "gpt-4o-mini"
|
|
||||||
|
|
||||||
|
|
||||||
class TestEventManager:
|
|
||||||
"""事件管理器测试"""
|
|
||||||
|
|
||||||
def test_event_manager_initialization(self):
|
|
||||||
"""测试事件管理器初始化"""
|
|
||||||
manager = EventManager()
|
|
||||||
|
|
||||||
assert manager._event_queues == {}
|
|
||||||
assert manager._event_callbacks == {}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_event_emitter(self):
|
|
||||||
"""测试事件发射器"""
|
|
||||||
manager = EventManager()
|
|
||||||
emitter = AgentEventEmitter("test-task-id", manager)
|
|
||||||
|
|
||||||
await emitter.emit_info("Test message")
|
|
||||||
|
|
||||||
assert emitter._sequence == 1
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_event_emitter_phase_tracking(self):
|
|
||||||
"""测试事件发射器阶段跟踪"""
|
|
||||||
manager = EventManager()
|
|
||||||
emitter = AgentEventEmitter("test-task-id", manager)
|
|
||||||
|
|
||||||
await emitter.emit_phase_start("recon", "开始信息收集")
|
|
||||||
|
|
||||||
assert emitter._current_phase == "recon"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_event_emitter_task_complete(self):
|
|
||||||
"""测试任务完成事件"""
|
|
||||||
manager = EventManager()
|
|
||||||
emitter = AgentEventEmitter("test-task-id", manager)
|
|
||||||
|
|
||||||
await emitter.emit_task_complete(findings_count=5, duration_ms=1000)
|
|
||||||
|
|
||||||
assert emitter._sequence == 1
|
|
||||||
|
|
||||||
|
|
||||||
class TestAuditGraph:
|
|
||||||
"""审计图测试"""
|
|
||||||
|
|
||||||
def test_create_audit_graph(self, mock_event_emitter):
|
|
||||||
"""测试创建审计图"""
|
|
||||||
# 创建模拟节点
|
|
||||||
recon_node = MagicMock()
|
|
||||||
analysis_node = MagicMock()
|
|
||||||
verification_node = MagicMock()
|
|
||||||
report_node = MagicMock()
|
|
||||||
|
|
||||||
graph = create_audit_graph(
|
|
||||||
recon_node=recon_node,
|
|
||||||
analysis_node=analysis_node,
|
|
||||||
verification_node=verification_node,
|
|
||||||
report_node=report_node,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert graph is not None
|
|
||||||
|
|
||||||
|
|
||||||
class TestReconNode:
|
|
||||||
"""Recon 节点测试"""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def recon_node_with_mock_agent(self, mock_event_emitter):
|
|
||||||
"""创建带模拟 Agent 的 Recon 节点"""
|
|
||||||
mock_agent = MagicMock()
|
|
||||||
mock_agent.run = AsyncMock(return_value=MagicMock(
|
|
||||||
success=True,
|
|
||||||
data={
|
|
||||||
"tech_stack": {"languages": ["Python"]},
|
|
||||||
"entry_points": [{"path": "src/app.py", "type": "api"}],
|
|
||||||
"high_risk_areas": ["src/sql_vuln.py"],
|
|
||||||
"dependencies": {},
|
|
||||||
"initial_findings": [],
|
|
||||||
}
|
|
||||||
))
|
|
||||||
|
|
||||||
return ReconNode(mock_agent, mock_event_emitter)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_recon_node_success(self, recon_node_with_mock_agent):
|
|
||||||
"""测试 Recon 节点成功执行"""
|
|
||||||
state = {
|
|
||||||
"project_info": {"name": "Test"},
|
|
||||||
"config": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await recon_node_with_mock_agent(state)
|
|
||||||
|
|
||||||
assert "tech_stack" in result
|
|
||||||
assert "entry_points" in result
|
|
||||||
assert result["current_phase"] == "recon_complete"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_recon_node_failure(self, mock_event_emitter):
|
|
||||||
"""测试 Recon 节点失败处理"""
|
|
||||||
mock_agent = MagicMock()
|
|
||||||
mock_agent.run = AsyncMock(return_value=MagicMock(
|
|
||||||
success=False,
|
|
||||||
error="Test error",
|
|
||||||
data=None,
|
|
||||||
))
|
|
||||||
|
|
||||||
node = ReconNode(mock_agent, mock_event_emitter)
|
|
||||||
|
|
||||||
result = await node({
|
|
||||||
"project_info": {},
|
|
||||||
"config": {},
|
|
||||||
})
|
|
||||||
|
|
||||||
assert "error" in result
|
|
||||||
assert result["current_phase"] == "error"
|
|
||||||
|
|
||||||
|
|
||||||
class TestAnalysisNode:
|
|
||||||
"""Analysis 节点测试"""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def analysis_node_with_mock_agent(self, mock_event_emitter):
|
|
||||||
"""创建带模拟 Agent 的 Analysis 节点"""
|
|
||||||
mock_agent = MagicMock()
|
|
||||||
mock_agent.run = AsyncMock(return_value=MagicMock(
|
|
||||||
success=True,
|
|
||||||
data={
|
|
||||||
"findings": [
|
|
||||||
{
|
|
||||||
"id": "finding-1",
|
|
||||||
"title": "SQL Injection",
|
|
||||||
"severity": "high",
|
|
||||||
"vulnerability_type": "sql_injection",
|
|
||||||
"file_path": "src/sql_vuln.py",
|
|
||||||
"line_start": 10,
|
|
||||||
"description": "SQL injection vulnerability",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"should_continue": False,
|
|
||||||
}
|
|
||||||
))
|
|
||||||
|
|
||||||
return AnalysisNode(mock_agent, mock_event_emitter)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_analysis_node_success(self, analysis_node_with_mock_agent):
|
|
||||||
"""测试 Analysis 节点成功执行"""
|
|
||||||
state = {
|
|
||||||
"project_info": {"name": "Test"},
|
|
||||||
"tech_stack": {"languages": ["Python"]},
|
|
||||||
"entry_points": [],
|
|
||||||
"high_risk_areas": ["src/sql_vuln.py"],
|
|
||||||
"config": {},
|
|
||||||
"iteration": 0,
|
|
||||||
"findings": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await analysis_node_with_mock_agent(state)
|
|
||||||
|
|
||||||
assert "findings" in result
|
|
||||||
assert len(result["findings"]) > 0
|
|
||||||
assert result["iteration"] == 1
|
|
||||||
|
|
||||||
|
|
||||||
class TestIntegrationFlow:
|
|
||||||
"""完整流程集成测试"""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_full_audit_flow_mock(self, temp_project_dir, mock_db_session, mock_task):
|
|
||||||
"""测试完整审计流程(使用模拟)"""
|
|
||||||
# 这个测试验证整个流程的连接性
|
|
||||||
|
|
||||||
# 创建事件管理器
|
|
||||||
event_manager = EventManager()
|
|
||||||
emitter = AgentEventEmitter(mock_task.id, event_manager)
|
|
||||||
|
|
||||||
# 模拟 LLM 服务
|
|
||||||
mock_llm = MagicMock()
|
|
||||||
mock_llm.chat_completion_raw = AsyncMock(return_value={
|
|
||||||
"content": "Analysis complete",
|
|
||||||
"usage": {"total_tokens": 100},
|
|
||||||
})
|
|
||||||
|
|
||||||
# 验证事件发射
|
|
||||||
await emitter.emit_phase_start("init", "初始化")
|
|
||||||
await emitter.emit_info("测试消息")
|
|
||||||
await emitter.emit_phase_complete("init", "初始化完成")
|
|
||||||
|
|
||||||
assert emitter._sequence == 3
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_audit_state_typing(self):
|
|
||||||
"""测试审计状态类型定义"""
|
|
||||||
state: AuditState = {
|
|
||||||
"project_root": "/tmp/test",
|
|
||||||
"project_info": {"name": "Test"},
|
|
||||||
"config": {},
|
|
||||||
"task_id": "test-id",
|
|
||||||
"tech_stack": {},
|
|
||||||
"entry_points": [],
|
|
||||||
"high_risk_areas": [],
|
|
||||||
"dependencies": {},
|
|
||||||
"findings": [],
|
|
||||||
"verified_findings": [],
|
|
||||||
"false_positives": [],
|
|
||||||
"current_phase": "start",
|
|
||||||
"iteration": 0,
|
|
||||||
"max_iterations": 50,
|
|
||||||
"should_continue_analysis": False,
|
|
||||||
"messages": [],
|
|
||||||
"events": [],
|
|
||||||
"summary": None,
|
|
||||||
"security_score": None,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
assert state["current_phase"] == "start"
|
|
||||||
assert state["max_iterations"] == 50
|
|
||||||
|
|
||||||
|
|
||||||
class TestToolIntegration:
|
|
||||||
"""工具集成测试"""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tools_work_together(self, temp_project_dir):
|
|
||||||
"""测试工具协同工作"""
|
|
||||||
from app.services.agent.tools import (
|
|
||||||
FileReadTool, FileSearchTool, ListFilesTool, PatternMatchTool,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1. 列出文件
|
|
||||||
list_tool = ListFilesTool(temp_project_dir)
|
|
||||||
list_result = await list_tool.execute(directory="src", recursive=False)
|
|
||||||
assert list_result.success is True
|
|
||||||
|
|
||||||
# 2. 搜索关键代码
|
|
||||||
search_tool = FileSearchTool(temp_project_dir)
|
|
||||||
search_result = await search_tool.execute(keyword="execute")
|
|
||||||
assert search_result.success is True
|
|
||||||
|
|
||||||
# 3. 读取文件内容
|
|
||||||
read_tool = FileReadTool(temp_project_dir)
|
|
||||||
read_result = await read_tool.execute(file_path="src/sql_vuln.py")
|
|
||||||
assert read_result.success is True
|
|
||||||
|
|
||||||
# 4. 模式匹配
|
|
||||||
pattern_tool = PatternMatchTool(temp_project_dir)
|
|
||||||
pattern_result = await pattern_tool.execute(
|
|
||||||
code=read_result.data,
|
|
||||||
file_path="src/sql_vuln.py",
|
|
||||||
language="python"
|
|
||||||
)
|
|
||||||
assert pattern_result.success is True
|
|
||||||
|
|
||||||
|
|
||||||
class TestErrorHandling:
|
|
||||||
"""错误处理测试"""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_error_handling(self, temp_project_dir):
|
|
||||||
"""测试工具错误处理"""
|
|
||||||
from app.services.agent.tools import FileReadTool
|
|
||||||
|
|
||||||
tool = FileReadTool(temp_project_dir)
|
|
||||||
|
|
||||||
# 尝试读取不存在的文件
|
|
||||||
result = await tool.execute(file_path="nonexistent/file.py")
|
|
||||||
|
|
||||||
assert result.success is False
|
|
||||||
assert result.error is not None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_agent_graceful_degradation(self, mock_event_emitter):
|
|
||||||
"""测试 Agent 优雅降级"""
|
|
||||||
# 创建一个会失败的 Agent
|
|
||||||
mock_agent = MagicMock()
|
|
||||||
mock_agent.run = AsyncMock(side_effect=Exception("Simulated error"))
|
|
||||||
|
|
||||||
node = ReconNode(mock_agent, mock_event_emitter)
|
|
||||||
|
|
||||||
result = await node({
|
|
||||||
"project_info": {},
|
|
||||||
"config": {},
|
|
||||||
})
|
|
||||||
|
|
||||||
# 应该返回错误状态而不是崩溃
|
|
||||||
assert "error" in result
|
|
||||||
assert result["current_phase"] == "error"
|
|
||||||
|
|
||||||
|
|
||||||
class TestPerformance:
|
|
||||||
"""性能测试"""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_response_time(self, temp_project_dir):
|
|
||||||
"""测试工具响应时间"""
|
|
||||||
from app.services.agent.tools import ListFilesTool
|
|
||||||
import time
|
|
||||||
|
|
||||||
tool = ListFilesTool(temp_project_dir)
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
await tool.execute(directory=".", recursive=True)
|
|
||||||
duration = time.time() - start
|
|
||||||
|
|
||||||
# 工具应该在合理时间内响应
|
|
||||||
assert duration < 5.0 # 5 秒内
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_multiple_tool_calls(self, temp_project_dir):
|
|
||||||
"""测试多次工具调用"""
|
|
||||||
from app.services.agent.tools import FileSearchTool
|
|
||||||
|
|
||||||
tool = FileSearchTool(temp_project_dir)
|
|
||||||
|
|
||||||
# 执行多次调用
|
|
||||||
for _ in range(5):
|
|
||||||
result = await tool.execute(keyword="def")
|
|
||||||
assert result.success is True
|
|
||||||
|
|
||||||
# 验证调用计数
|
|
||||||
assert tool._call_count == 5
|
|
||||||
|
|
||||||
Loading…
Reference in New Issue