refactor(agent): remove LangGraph workflow and migrate to dynamic agent-tree architecture

Refactor the agent service architecture, migrating from a LangGraph-based state graph to a dynamic agent tree. Main changes:
- Remove the graph module and its tests
- Update imports and docs in agent/__init__.py
- Add stats support for the new AgentTask model to the projects endpoint
- Simplify the workflow description to START → Orchestrator → [Recon/Analysis/Verification] → Report → END

The new architecture uses OrchestratorAgent as the orchestration layer, dynamically dispatching sub-agents to complete tasks, improving flexibility and extensibility.
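As an illustration of the control flow this message describes, a minimal sketch of the dispatch pattern (all class and method names below are hypothetical placeholders, not the actual OrchestratorAgent API from this repository):

import asyncio
from typing import Any, Dict, List

class SubAgent:
    """Hypothetical sub-agent: runs one specialized task and returns findings."""
    def __init__(self, name: str):
        self.name = name

    async def run(self, task: Dict[str, Any]) -> List[Dict[str, Any]]:
        # A real agent would call tools and LLMs here; this stub echoes its input.
        return [{"agent": self.name, "target": task.get("target")}]

class Orchestrator:
    """Hypothetical orchestration layer: spawns sub-agents per task instead of walking a fixed state graph."""
    def __init__(self) -> None:
        self.children: List[SubAgent] = []

    def spawn(self, name: str) -> SubAgent:
        agent = SubAgent(name)
        self.children.append(agent)
        return agent

    async def audit(self, targets: List[str]) -> List[Dict[str, Any]]:
        findings: List[Dict[str, Any]] = []
        for target in targets:
            # START → Orchestrator → [Recon/Analysis/Verification] → Report → END
            for name in ("recon", "analysis", "verification"):
                findings += await self.spawn(name).run({"target": target})
        return findings

if __name__ == "__main__":
    print(asyncio.run(Orchestrator().audit(["src/app.py"])))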
lintsinghua 2025-12-25 17:58:14 +08:00
parent 39e2f43210
commit 15605fea16
7 changed files with 56 additions and 2708 deletions

@@ -16,6 +16,7 @@ from app.db.session import get_db, AsyncSessionLocal
 from app.models.project import Project
 from app.models.user import User
 from app.models.audit import AuditTask, AuditIssue
+from app.models.agent_task import AgentTask, AgentTaskStatus, AgentFinding
 from app.models.user_config import UserConfig
 import zipfile
 from app.services.scanner import scan_repo_task, get_github_files, get_gitlab_files, get_github_branches, get_gitlab_branches, get_gitea_branches, should_exclude, is_text_file
@@ -162,26 +163,51 @@ async def get_stats(
     projects = projects_result.scalars().all()
     project_ids = [p.id for p in projects]
-    # Only count tasks belonging to the current user's projects
+    # Count the legacy AuditTask records
     tasks_result = await db.execute(
         select(AuditTask).where(AuditTask.project_id.in_(project_ids)) if project_ids else select(AuditTask).where(False)
     )
     tasks = tasks_result.scalars().all()
     task_ids = [t.id for t in tasks]
-    # Only count issues belonging to the current user's tasks
+    # Count the legacy AuditIssue records
     issues_result = await db.execute(
         select(AuditIssue).where(AuditIssue.task_id.in_(task_ids)) if task_ids else select(AuditIssue).where(False)
     )
     issues = issues_result.scalars().all()
+    # 🔥 Also count the new AgentTask records
+    agent_tasks_result = await db.execute(
+        select(AgentTask).where(AgentTask.project_id.in_(project_ids)) if project_ids else select(AgentTask).where(False)
+    )
+    agent_tasks = agent_tasks_result.scalars().all()
+    agent_task_ids = [t.id for t in agent_tasks]
+    # 🔥 Count AgentFinding records
+    agent_findings_result = await db.execute(
+        select(AgentFinding).where(AgentFinding.task_id.in_(agent_task_ids)) if agent_task_ids else select(AgentFinding).where(False)
+    )
+    agent_findings = agent_findings_result.scalars().all()
+    # Merged stats (legacy tasks + new agent tasks)
+    total_tasks = len(tasks) + len(agent_tasks)
+    completed_tasks = (
+        len([t for t in tasks if t.status == "completed"]) +
+        len([t for t in agent_tasks if t.status == AgentTaskStatus.COMPLETED])
+    )
+    total_issues = len(issues) + len(agent_findings)
+    resolved_issues = (
+        len([i for i in issues if i.status == "resolved"]) +
+        len([f for f in agent_findings if f.status == "resolved"])
+    )
     return {
         "total_projects": len(projects),
         "active_projects": len([p for p in projects if p.is_active]),
-        "total_tasks": len(tasks),
-        "completed_tasks": len([t for t in tasks if t.status == "completed"]),
-        "total_issues": len(issues),
-        "resolved_issues": len([i for i in issues if i.status == "resolved"]),
+        "total_tasks": total_tasks,
+        "completed_tasks": completed_tasks,
+        "total_issues": total_issues,
+        "resolved_issues": resolved_issues,
     }

 @router.get("/{id}", response_model=ProjectResponse)

app/services/agent/__init__.py

@@ -1,29 +1,19 @@
 """
 DeepAudit agent service module
-AI code security auditing based on LangGraph
+AI code security auditing based on a dynamic agent-tree architecture

-Upgraded architecture - supports:
-- Dynamic agent tree structure
-- Expert knowledge module system
-- Inter-agent communication
-- Full state management
-- Think tool and vulnerability-report tool
+Architecture:
+- OrchestratorAgent is the orchestration layer and dynamically dispatches sub-agents
+- ReconAgent handles reconnaissance and file analysis
+- AnalysisAgent handles vulnerability analysis
+- VerificationAgent verifies findings

 Workflow:
-START → Recon → Analysis → Verification → Report → END
+START → Orchestrator → [Recon/Analysis/Verification] → Report → END

 Supports dynamically creating sub-agents for specialized analysis
 """

-# Import the main components from the graph module
-from .graph import (
-    AgentRunner,
-    run_agent_task,
-    LLMService,
-    AuditState,
-    create_audit_graph,
-)

 # Event management
 from .event_manager import EventManager, AgentEventEmitter
@@ -33,14 +23,14 @@ from .agents import (
     OrchestratorAgent, ReconAgent, AnalysisAgent, VerificationAgent,
 )

-# 🔥 New: core modules (state management, registry, messages)
+# Core modules (state management, registry, messages)
 from .core import (
     AgentState, AgentStatus,
     AgentRegistry, agent_registry,
     AgentMessage, MessageType, MessagePriority, MessageBus,
 )

-# 🔥 New: knowledge module system (RAG-based)
+# Knowledge module system (RAG-based)
 from .knowledge import (
     KnowledgeLoader, knowledge_loader,
     get_available_modules, get_module_content,
@@ -48,7 +38,7 @@ from .knowledge import (
     SecurityKnowledgeQueryTool, GetVulnerabilityKnowledgeTool,
 )

-# 🔥 New: collaboration tools
+# Collaboration tools
 from .tools import (
     ThinkTool, ReflectTool,
     CreateVulnerabilityReportTool,
@@ -57,20 +47,11 @@ from .tools import (
     WaitForMessageTool, AgentFinishTool,
 )

-# 🔥 New: telemetry module
+# Telemetry module
 from .telemetry import Tracer, get_global_tracer, set_global_tracer

 __all__ = [
-    # Core runner
-    "AgentRunner",
-    "run_agent_task",
-    "LLMService",
-    # LangGraph
-    "AuditState",
-    "create_audit_graph",
     # Event management
     "EventManager",
     "AgentEventEmitter",
@@ -84,7 +65,7 @@ __all__ = [
     "AnalysisAgent",
     "VerificationAgent",
-    # 🔥 Core modules
+    # Core modules
     "AgentState",
     "AgentStatus",
     "AgentRegistry",
@@ -94,7 +75,7 @@ __all__ = [
     "MessagePriority",
     "MessageBus",
-    # 🔥 Knowledge modules (RAG-based)
+    # Knowledge modules (RAG-based)
     "KnowledgeLoader",
     "knowledge_loader",
     "get_available_modules",
@@ -104,7 +85,7 @@ __all__ = [
     "SecurityKnowledgeQueryTool",
     "GetVulnerabilityKnowledgeTool",
-    # 🔥 Collaboration tools
+    # Collaboration tools
     "ThinkTool",
     "ReflectTool",
     "CreateVulnerabilityReportTool",
@@ -115,9 +96,8 @@ __all__ = [
     "WaitForMessageTool",
     "AgentFinishTool",
-    # 🔥 Telemetry module
+    # Telemetry module
     "Tracer",
     "get_global_tracer",
     "set_global_tracer",
 ]
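For reference, a hypothetical consumer-side import after this refactor: the graph module is gone, so callers pull the orchestration pieces straight from the package (a sketch; only names exported in the diff above are used):

from app.services.agent import (
    OrchestratorAgent,
    EventManager,
    AgentEventEmitter,
    agent_registry,
)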

app/services/agent/graph/__init__.py

@@ -1,28 +0,0 @@
"""
LangGraph workflow module
Builds the hybrid-agent audit flow as a state graph
"""
from .audit_graph import AuditState, create_audit_graph, create_audit_graph_with_human
from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode, HumanReviewNode
from .runner import AgentRunner, run_agent_task, LLMService

__all__ = [
    # State and graph
    "AuditState",
    "create_audit_graph",
    "create_audit_graph_with_human",
    # Nodes
    "ReconNode",
    "AnalysisNode",
    "VerificationNode",
    "ReportNode",
    "HumanReviewNode",
    # Runner
    "AgentRunner",
    "run_agent_task",
    "LLMService",
]

app/services/agent/graph/audit_graph.py

@@ -1,677 +0,0 @@
"""
DeepAudit audit workflow graph - LLM-driven version
Uses LangGraph to build an LLM-driven agent collaboration flow
Key change: routing decisions involve the LLM instead of hard-coded conditions
"""
from typing import TypedDict, Annotated, List, Dict, Any, Optional, Literal
from datetime import datetime
import operator
import logging
import json

from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import ToolNode

logger = logging.getLogger(__name__)


# ============ State definitions ============

class Finding(TypedDict):
    """A vulnerability finding"""
    id: str
    vulnerability_type: str
    severity: str
    title: str
    description: str
    file_path: Optional[str]
    line_start: Optional[int]
    code_snippet: Optional[str]
    is_verified: bool
    confidence: float
    source: str


class AuditState(TypedDict):
    """
    Audit state
    Passed through and updated across the whole workflow
    """
    # Inputs
    project_root: str
    project_info: Dict[str, Any]
    config: Dict[str, Any]
    task_id: str
    # Recon phase outputs
    tech_stack: Dict[str, Any]
    entry_points: List[Dict[str, Any]]
    high_risk_areas: List[str]
    dependencies: Dict[str, Any]
    # Analysis phase outputs
    findings: Annotated[List[Finding], operator.add]  # merged across iterations via add
    # Verification phase outputs
    verified_findings: List[Finding]
    false_positives: List[str]
    # 🔥 NEW: full post-verification findings, used to replace the original findings
    _verified_findings_update: Optional[List[Finding]]
    # Control flow - 🔥 key: the LLM can set these to influence routing
    current_phase: str
    iteration: int
    max_iterations: int
    should_continue_analysis: bool
    # 🔥 New: the LLM's routing decision
    llm_next_action: Optional[str]  # the LLM's suggested next step: "continue_analysis", "verify", "report", "end"
    llm_routing_reason: Optional[str]  # the LLM's rationale
    # 🔥 New: task handoff info for inter-agent collaboration
    recon_handoff: Optional[Dict[str, Any]]  # Recon -> Analysis handoff
    analysis_handoff: Optional[Dict[str, Any]]  # Analysis -> Verification handoff
    verification_handoff: Optional[Dict[str, Any]]  # Verification -> Report handoff
    # Messages and events
    messages: Annotated[List[Dict], operator.add]
    events: Annotated[List[Dict], operator.add]
    # Final outputs
    summary: Optional[Dict[str, Any]]
    security_score: Optional[int]
    error: Optional[str]


# ============ LLM routing decider ============

class LLMRouter:
    """
    LLM routing decider
    Lets the LLM decide what should happen next
    """

    def __init__(self, llm_service):
        self.llm_service = llm_service

    async def decide_after_recon(self, state: AuditState) -> Dict[str, Any]:
        """After Recon, let the LLM decide the next step"""
        entry_points = state.get("entry_points", [])
        high_risk_areas = state.get("high_risk_areas", [])
        tech_stack = state.get("tech_stack", {})
        initial_findings = state.get("findings", [])

        prompt = f"""As the decision maker of a security audit, decide the next action based on these reconnaissance results.

## Reconnaissance results
- Entry point count: {len(entry_points)}
- High-risk areas: {high_risk_areas[:10]}
- Tech stack: {tech_stack}
- Initial findings: {len(initial_findings)}

## Options
1. "analysis" - proceed to vulnerability analysis (recommended when there are entry points or high-risk areas)
2. "end" - end the audit (only when there is nothing to analyze)

Return JSON:
{{"action": "analysis or end", "reason": "rationale"}}"""

        try:
            response = await self.llm_service.chat_completion_raw(
                messages=[
                    {"role": "system", "content": "You are the decision maker of the security audit workflow, responsible for choosing the next action."},
                    {"role": "user", "content": prompt},
                ],
                # 🔥 Do not pass temperature or max_tokens; use the user's configuration
            )
            content = response.get("content", "")
            # Extract the JSON
            import re
            json_match = re.search(r'\{.*\}', content, re.DOTALL)
            if json_match:
                result = json.loads(json_match.group())
                return result
        except Exception as e:
            logger.warning(f"LLM routing decision failed: {e}")

        # Default decision
        if entry_points or high_risk_areas:
            return {"action": "analysis", "reason": "There is content to analyze"}
        return {"action": "end", "reason": "No entry points or high-risk areas found"}

    async def decide_after_analysis(self, state: AuditState) -> Dict[str, Any]:
        """After Analysis, let the LLM decide the next step"""
        findings = state.get("findings", [])
        iteration = state.get("iteration", 0)
        max_iterations = state.get("max_iterations", 3)

        # Tally the findings
        severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
        for f in findings:
            # Skip findings that are not dicts
            if not isinstance(f, dict):
                continue
            sev = f.get("severity", "medium")
            severity_counts[sev] = severity_counts.get(sev, 0) + 1

        prompt = f"""As the decision maker of a security audit, decide the next action based on these analysis results.

## Analysis results
- Total findings: {len(findings)}
- Severity distribution: {severity_counts}
- Current iteration: {iteration}/{max_iterations}

## Options
1. "verification" - verify the discovered vulnerabilities (recommended when there are findings to verify)
2. "analysis" - continue deeper analysis (recommended when findings are few but iterations remain)
3. "report" - generate the report (recommended when there are no findings or analysis is sufficient)

Return JSON:
{{"action": "verification/analysis/report", "reason": "rationale"}}"""

        try:
            response = await self.llm_service.chat_completion_raw(
                messages=[
                    {"role": "system", "content": "You are the decision maker of the security audit workflow, responsible for choosing the next action."},
                    {"role": "user", "content": prompt},
                ],
                # 🔥 Do not pass temperature or max_tokens; use the user's configuration
            )
            content = response.get("content", "")
            import re
            json_match = re.search(r'\{.*\}', content, re.DOTALL)
            if json_match:
                result = json.loads(json_match.group())
                return result
        except Exception as e:
            logger.warning(f"LLM routing decision failed: {e}")

        # Default decision
        if not findings:
            return {"action": "report", "reason": "No vulnerabilities found"}
        if len(findings) >= 3 or iteration >= max_iterations:
            return {"action": "verification", "reason": "Enough findings to verify"}
        return {"action": "analysis", "reason": "Few findings; keep analyzing"}

    async def decide_after_verification(self, state: AuditState) -> Dict[str, Any]:
        """After Verification, let the LLM decide the next step"""
        verified_findings = state.get("verified_findings", [])
        false_positives = state.get("false_positives", [])
        iteration = state.get("iteration", 0)
        max_iterations = state.get("max_iterations", 3)

        prompt = f"""As the decision maker of a security audit, decide the next action based on these verification results.

## Verification results
- Confirmed vulnerabilities: {len(verified_findings)}
- False positives: {len(false_positives)}
- Current iteration: {iteration}/{max_iterations}

## Options
1. "analysis" - go back to the analysis phase and re-analyze (recommended when the false-positive rate is too high)
2. "report" - generate the final report (recommended when verification is done)

Return JSON:
{{"action": "analysis/report", "reason": "rationale"}}"""

        try:
            response = await self.llm_service.chat_completion_raw(
                messages=[
                    {"role": "system", "content": "You are the decision maker of the security audit workflow, responsible for choosing the next action."},
                    {"role": "user", "content": prompt},
                ],
                # 🔥 Do not pass temperature or max_tokens; use the user's configuration
            )
            content = response.get("content", "")
            import re
            json_match = re.search(r'\{.*\}', content, re.DOTALL)
            if json_match:
                result = json.loads(json_match.group())
                return result
        except Exception as e:
            logger.warning(f"LLM routing decision failed: {e}")

        # Default decision
        if len(false_positives) > len(verified_findings) and iteration < max_iterations:
            return {"action": "analysis", "reason": "High false-positive rate; re-analysis needed"}
        return {"action": "report", "reason": "Verification complete; generate the report"}


# ============ Routing functions (combined with LLM decisions) ============

def route_after_recon(state: AuditState) -> Literal["analysis", "end"]:
    """
    Routing decision after Recon
    Prefer the LLM's decision; otherwise fall back to the default logic
    """
    # 🔥 Check for errors
    if state.get("error") or state.get("current_phase") == "error":
        logger.error(f"Recon phase has error, routing to end: {state.get('error')}")
        return "end"
    # Check whether the LLM made a decision
    llm_action = state.get("llm_next_action")
    if llm_action:
        logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
        if llm_action == "end":
            return "end"
        return "analysis"
    # Default logic (fallback)
    if not state.get("entry_points") and not state.get("high_risk_areas"):
        return "end"
    return "analysis"


def route_after_analysis(state: AuditState) -> Literal["verification", "analysis", "report"]:
    """
    Routing decision after Analysis
    Prefer the LLM's decision
    """
    # Check whether the LLM made a decision
    llm_action = state.get("llm_next_action")
    if llm_action:
        logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
        if llm_action == "verification":
            return "verification"
        elif llm_action == "analysis":
            return "analysis"
        elif llm_action == "report":
            return "report"
    # Default logic
    findings = state.get("findings", [])
    iteration = state.get("iteration", 0)
    max_iterations = state.get("max_iterations", 3)
    should_continue = state.get("should_continue_analysis", False)
    if not findings:
        return "report"
    if should_continue and iteration < max_iterations:
        return "analysis"
    return "verification"


def route_after_verification(state: AuditState) -> Literal["analysis", "report"]:
    """
    Routing decision after Verification
    Prefer the LLM's decision
    """
    # Check whether the LLM made a decision
    llm_action = state.get("llm_next_action")
    if llm_action:
        logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
        if llm_action == "analysis":
            return "analysis"
        return "report"
    # Default logic
    false_positives = state.get("false_positives", [])
    iteration = state.get("iteration", 0)
    max_iterations = state.get("max_iterations", 3)
    if len(false_positives) > len(state.get("verified_findings", [])) and iteration < max_iterations:
        return "analysis"
    return "report"


# ============ Building the audit graph ============

def create_audit_graph(
    recon_node,
    analysis_node,
    verification_node,
    report_node,
    checkpointer: Optional[MemorySaver] = None,
    llm_service=None,  # used for LLM routing decisions
) -> StateGraph:
    """
    Create the audit workflow graph

    Args:
        recon_node: reconnaissance node
        analysis_node: vulnerability analysis node
        verification_node: vulnerability verification node
        report_node: report generation node
        checkpointer: checkpoint store
        llm_service: LLM service (for routing decisions)

    Returns:
        The compiled StateGraph

    Workflow structure:
        START
          ↓
        Recon (reconnaissance, LLM-driven)
          ↓ (LLM decides)
        Analysis (vulnerability analysis, LLM-driven, may loop)
          ↓ (LLM decides)
        Verification (vulnerability verification, LLM-driven, may backtrack)
          ↓ (LLM decides)
        Report (report generation)
          ↓
        END
    """
    # Create the state graph
    workflow = StateGraph(AuditState)

    # If an LLM service is available, create the routing decider
    llm_router = LLMRouter(llm_service) if llm_service else None

    # Wrap the nodes to add LLM routing decisions
    async def recon_with_routing(state):
        result = await recon_node(state)
        # Let the LLM decide the next step
        if llm_router:
            decision = await llm_router.decide_after_recon({**state, **result})
            result["llm_next_action"] = decision.get("action")
            result["llm_routing_reason"] = decision.get("reason")
        return result

    async def analysis_with_routing(state):
        result = await analysis_node(state)
        # Let the LLM decide the next step
        if llm_router:
            decision = await llm_router.decide_after_analysis({**state, **result})
            result["llm_next_action"] = decision.get("action")
            result["llm_routing_reason"] = decision.get("reason")
        return result

    async def verification_with_routing(state):
        result = await verification_node(state)
        # Let the LLM decide the next step
        if llm_router:
            decision = await llm_router.decide_after_verification({**state, **result})
            result["llm_next_action"] = decision.get("action")
            result["llm_routing_reason"] = decision.get("reason")
        return result

    # Add the nodes
    if llm_router:
        workflow.add_node("recon", recon_with_routing)
        workflow.add_node("analysis", analysis_with_routing)
        workflow.add_node("verification", verification_with_routing)
    else:
        workflow.add_node("recon", recon_node)
        workflow.add_node("analysis", analysis_node)
        workflow.add_node("verification", verification_node)
    workflow.add_node("report", report_node)

    # Set the entry point
    workflow.set_entry_point("recon")

    # Add conditional edges
    workflow.add_conditional_edges(
        "recon",
        route_after_recon,
        {
            "analysis": "analysis",
            "end": END,
        }
    )
    workflow.add_conditional_edges(
        "analysis",
        route_after_analysis,
        {
            "verification": "verification",
            "analysis": "analysis",
            "report": "report",
        }
    )
    workflow.add_conditional_edges(
        "verification",
        route_after_verification,
        {
            "analysis": "analysis",
            "report": "report",
        }
    )
    # Report -> END
    workflow.add_edge("report", END)

    # Compile the graph
    if checkpointer:
        return workflow.compile(checkpointer=checkpointer)
    else:
        return workflow.compile()


# ============ Audit graph with human-in-the-loop ============

def create_audit_graph_with_human(
    recon_node,
    analysis_node,
    verification_node,
    report_node,
    human_review_node,
    checkpointer: Optional[MemorySaver] = None,
    llm_service=None,
) -> StateGraph:
    """
    Create the audit workflow graph with human-in-the-loop review
    Adds a human review node after the verification phase
    """
    workflow = StateGraph(AuditState)
    llm_router = LLMRouter(llm_service) if llm_service else None

    # Wrap the nodes
    async def recon_with_routing(state):
        result = await recon_node(state)
        if llm_router:
            decision = await llm_router.decide_after_recon({**state, **result})
            result["llm_next_action"] = decision.get("action")
            result["llm_routing_reason"] = decision.get("reason")
        return result

    async def analysis_with_routing(state):
        result = await analysis_node(state)
        if llm_router:
            decision = await llm_router.decide_after_analysis({**state, **result})
            result["llm_next_action"] = decision.get("action")
            result["llm_routing_reason"] = decision.get("reason")
        return result

    # Add the nodes
    if llm_router:
        workflow.add_node("recon", recon_with_routing)
        workflow.add_node("analysis", analysis_with_routing)
    else:
        workflow.add_node("recon", recon_node)
        workflow.add_node("analysis", analysis_node)
    workflow.add_node("verification", verification_node)
    workflow.add_node("human_review", human_review_node)
    workflow.add_node("report", report_node)

    workflow.set_entry_point("recon")
    workflow.add_conditional_edges(
        "recon",
        route_after_recon,
        {"analysis": "analysis", "end": END}
    )
    workflow.add_conditional_edges(
        "analysis",
        route_after_analysis,
        {
            "verification": "verification",
            "analysis": "analysis",
            "report": "report",
        }
    )
    # Verification -> Human Review
    workflow.add_edge("verification", "human_review")

    # Routing after human review
    def route_after_human(state: AuditState) -> Literal["analysis", "report"]:
        if state.get("should_continue_analysis"):
            return "analysis"
        return "report"

    workflow.add_conditional_edges(
        "human_review",
        route_after_human,
        {"analysis": "analysis", "report": "report"}
    )
    workflow.add_edge("report", END)

    if checkpointer:
        return workflow.compile(checkpointer=checkpointer, interrupt_before=["human_review"])
    else:
        return workflow.compile()


# ============ Runner ============

class AuditGraphRunner:
    """
    Audit graph runner
    Wraps execution of the LangGraph workflow
    """

    def __init__(
        self,
        graph: StateGraph,
        event_emitter=None,
    ):
        self.graph = graph
        self.event_emitter = event_emitter

    async def run(
        self,
        project_root: str,
        project_info: Dict[str, Any],
        config: Dict[str, Any],
        task_id: str,
    ) -> Dict[str, Any]:
        """
        Execute the audit workflow
        """
        # Initial state
        initial_state: AuditState = {
            "project_root": project_root,
            "project_info": project_info,
            "config": config,
            "task_id": task_id,
            "tech_stack": {},
            "entry_points": [],
            "high_risk_areas": [],
            "dependencies": {},
            "findings": [],
            "verified_findings": [],
            "false_positives": [],
            "_verified_findings_update": None,  # 🔥 NEW: post-verification findings update
            "current_phase": "start",
            "iteration": 0,
            "max_iterations": config.get("max_iterations", 3),
            "should_continue_analysis": False,
            "llm_next_action": None,
            "llm_routing_reason": None,
            "messages": [],
            "events": [],
            "summary": None,
            "security_score": None,
            "error": None,
        }
        run_config = {
            "configurable": {
                "thread_id": task_id,
            }
        }
        try:
            async for event in self.graph.astream(initial_state, config=run_config):
                if self.event_emitter:
                    for node_name, node_state in event.items():
                        await self.event_emitter.emit_info(
                            f"Node {node_name} finished"
                        )
                        # Emit LLM routing-decision events
                        if node_state.get("llm_routing_reason"):
                            await self.event_emitter.emit_info(
                                f"🧠 LLM decision: {node_state.get('llm_next_action')} - {node_state.get('llm_routing_reason')}"
                            )
                        if node_name == "analysis" and node_state.get("findings"):
                            new_findings = node_state["findings"]
                            await self.event_emitter.emit_info(
                                f"Found {len(new_findings)} potential vulnerabilities"
                            )
            final_state = self.graph.get_state(run_config)
            return final_state.values
        except Exception as e:
            logger.error(f"Graph execution failed: {e}", exc_info=True)
            raise

    async def run_with_human_review(
        self,
        initial_state: AuditState,
        human_feedback_callback,
    ) -> Dict[str, Any]:
        """Execute with human-in-the-loop review"""
        run_config = {
            "configurable": {
                "thread_id": initial_state["task_id"],
            }
        }
        async for event in self.graph.astream(initial_state, config=run_config):
            pass
        current_state = self.graph.get_state(run_config)
        if current_state.next == ("human_review",):
            human_decision = await human_feedback_callback(current_state.values)
            updated_state = {
                **current_state.values,
                "should_continue_analysis": human_decision.get("continue_analysis", False),
            }
            async for event in self.graph.astream(updated_state, config=run_config):
                pass
        return self.graph.get_state(run_config).values

app/services/agent/graph/nodes.py

@@ -1,556 +0,0 @@
"""
LangGraph node implementations
Each node wraps the execution logic of one agent
Collaboration enhancement: nodes pass structured context and insights to each other via TaskHandoff
"""
from typing import Dict, Any, List, Optional
import logging

logger = logging.getLogger(__name__)


# Deferred import to avoid a circular dependency
def get_audit_state_type():
    from .audit_graph import AuditState
    return AuditState


class BaseNode:
    """Node base class"""

    def __init__(self, agent=None, event_emitter=None):
        self.agent = agent
        self.event_emitter = event_emitter

    async def emit_event(self, event_type: str, message: str, **kwargs):
        """Emit an event"""
        if self.event_emitter:
            try:
                await self.event_emitter.emit_info(message)
            except Exception as e:
                logger.warning(f"Failed to emit event: {e}")

    def _extract_handoff_from_state(self, state: Dict[str, Any], from_phase: str):
        """Extract the preceding agent's handoff from the state"""
        handoff_data = state.get(f"{from_phase}_handoff")
        if handoff_data:
            from ..agents.base import TaskHandoff
            return TaskHandoff.from_dict(handoff_data)
        return None


class ReconNode(BaseNode):
    """
    Reconnaissance node
    Input: project_root, project_info, config
    Output: tech_stack, entry_points, high_risk_areas, dependencies, recon_handoff
    """

    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Run reconnaissance"""
        await self.emit_event("phase_start", "🔍 Starting reconnaissance phase")
        try:
            # Invoke the Recon agent
            result = await self.agent.run({
                "project_info": state["project_info"],
                "config": state["config"],
            })
            if result.success and result.data:
                data = result.data
                # 🔥 Create the handoff for the Analysis agent
                handoff = self.agent.create_handoff(
                    to_agent="Analysis",
                    summary=f"Reconnaissance complete. Found {len(data.get('entry_points', []))} entry points and {len(data.get('high_risk_areas', []))} high-risk areas.",
                    key_findings=data.get("initial_findings", []),
                    suggested_actions=[
                        {
                            "type": "deep_analysis",
                            "description": f"Analyze high-risk areas in depth: {', '.join(data.get('high_risk_areas', [])[:5])}",
                            "priority": "high",
                        },
                        {
                            "type": "entry_point_audit",
                            "description": "Audit input validation at all entry points",
                            "priority": "high",
                        },
                    ],
                    attention_points=[
                        f"Tech stack: {data.get('tech_stack', {}).get('frameworks', [])}",
                        f"Primary languages: {data.get('tech_stack', {}).get('languages', [])}",
                    ],
                    priority_areas=data.get("high_risk_areas", [])[:10],
                    context_data={
                        "tech_stack": data.get("tech_stack", {}),
                        "entry_points": data.get("entry_points", []),
                        "dependencies": data.get("dependencies", {}),
                    },
                )
                await self.emit_event(
                    "phase_complete",
                    f"✅ Reconnaissance complete: found {len(data.get('entry_points', []))} entry points"
                )
                return {
                    "tech_stack": data.get("tech_stack", {}),
                    "entry_points": data.get("entry_points", []),
                    "high_risk_areas": data.get("high_risk_areas", []),
                    "dependencies": data.get("dependencies", {}),
                    "current_phase": "recon_complete",
                    "findings": data.get("initial_findings", []),
                    # 🔥 Persist the handoff
                    "recon_handoff": handoff.to_dict(),
                    "events": [{
                        "type": "recon_complete",
                        "data": {
                            "entry_points_count": len(data.get("entry_points", [])),
                            "high_risk_areas_count": len(data.get("high_risk_areas", [])),
                            "handoff_summary": handoff.summary,
                        }
                    }],
                }
            else:
                return {
                    "error": result.error or "Recon failed",
                    "current_phase": "error",
                }
        except Exception as e:
            logger.error(f"Recon node failed: {e}", exc_info=True)
            return {
                "error": str(e),
                "current_phase": "error",
            }


class AnalysisNode(BaseNode):
    """
    Vulnerability analysis node
    Input: tech_stack, entry_points, high_risk_areas, recon_handoff
    Output: findings (accumulated), should_continue_analysis, analysis_handoff
    """

    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Run vulnerability analysis"""
        iteration = state.get("iteration", 0) + 1
        await self.emit_event(
            "phase_start",
            f"🔬 Starting vulnerability analysis phase (iteration {iteration})"
        )
        try:
            # 🔥 Extract Recon's handoff
            recon_handoff = self._extract_handoff_from_state(state, "recon")
            if recon_handoff:
                self.agent.receive_handoff(recon_handoff)
                await self.emit_event(
                    "handoff_received",
                    f"📨 Received handoff from the Recon agent: {recon_handoff.summary[:50]}..."
                )
            # Build the analysis input
            analysis_input = {
                "phase_name": "analysis",
                "project_info": state["project_info"],
                "config": state["config"],
                "plan": {
                    "high_risk_areas": state.get("high_risk_areas", []),
                },
                "previous_results": {
                    "recon": {
                        "data": {
                            "tech_stack": state.get("tech_stack", {}),
                            "entry_points": state.get("entry_points", []),
                            "high_risk_areas": state.get("high_risk_areas", []),
                        }
                    }
                },
                # 🔥 Pass the handoff along
                "handoff": recon_handoff,
            }
            # Invoke the Analysis agent
            result = await self.agent.run(analysis_input)
            if result.success and result.data:
                new_findings = result.data.get("findings", [])
                logger.info(f"[AnalysisNode] Agent returned {len(new_findings)} findings")
                # Decide whether analysis should continue
                should_continue = (
                    len(new_findings) >= 5 and
                    iteration < state.get("max_iterations", 3)
                )
                # 🔥 Create the handoff for the Verification agent
                # Tally severities
                severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
                for f in new_findings:
                    if isinstance(f, dict):
                        sev = f.get("severity", "medium")
                        severity_counts[sev] = severity_counts.get(sev, 0) + 1
                handoff = self.agent.create_handoff(
                    to_agent="Verification",
                    summary=f"Vulnerability analysis complete. Found {len(new_findings)} potential vulnerabilities (Critical: {severity_counts['critical']}, High: {severity_counts['high']}, Medium: {severity_counts['medium']}, Low: {severity_counts['low']})",
                    key_findings=new_findings[:20],  # pass along the first 20 findings
                    suggested_actions=[
                        {
                            "type": "verify_critical",
                            "description": "Verify Critical and High severity vulnerabilities first",
                            "priority": "critical",
                        },
                        {
                            "type": "poc_generation",
                            "description": "Generate PoCs for confirmed vulnerabilities",
                            "priority": "high",
                        },
                    ],
                    attention_points=[
                        f"{severity_counts['critical']} Critical vulnerabilities need immediate verification",
                        f"{severity_counts['high']} High vulnerabilities need priority verification",
                        "Watch for false positives, especially in static-analysis results",
                    ],
                    priority_areas=[
                        f.get("file_path", "") for f in new_findings
                        if f.get("severity") in ["critical", "high"]
                    ][:10],
                    context_data={
                        "severity_distribution": severity_counts,
                        "total_findings": len(new_findings),
                        "iteration": iteration,
                    },
                )
                await self.emit_event(
                    "phase_complete",
                    f"✅ Analysis iteration {iteration} complete: found {len(new_findings)} potential vulnerabilities"
                )
                return {
                    "findings": new_findings,
                    "iteration": iteration,
                    "should_continue_analysis": should_continue,
                    "current_phase": "analysis_complete",
                    # 🔥 Persist the handoff
                    "analysis_handoff": handoff.to_dict(),
                    "events": [{
                        "type": "analysis_iteration",
                        "data": {
                            "iteration": iteration,
                            "findings_count": len(new_findings),
                            "severity_distribution": severity_counts,
                            "handoff_summary": handoff.summary,
                        }
                    }],
                }
            else:
                return {
                    "iteration": iteration,
                    "should_continue_analysis": False,
                    "current_phase": "analysis_complete",
                }
        except Exception as e:
            logger.error(f"Analysis node failed: {e}", exc_info=True)
            return {
                "error": str(e),
                "should_continue_analysis": False,
                "current_phase": "error",
            }


class VerificationNode(BaseNode):
    """
    Vulnerability verification node
    Input: findings, analysis_handoff
    Output: verified_findings, false_positives, verification_handoff
    """

    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Run vulnerability verification"""
        findings = state.get("findings", [])
        logger.info(f"[VerificationNode] Received {len(findings)} findings to verify")
        if not findings:
            return {
                "verified_findings": [],
                "false_positives": [],
                "current_phase": "verification_complete",
            }
        await self.emit_event(
            "phase_start",
            f"🔐 Starting vulnerability verification phase ({len(findings)} to verify)"
        )
        try:
            # 🔥 Extract Analysis's handoff
            analysis_handoff = self._extract_handoff_from_state(state, "analysis")
            if analysis_handoff:
                self.agent.receive_handoff(analysis_handoff)
                await self.emit_event(
                    "handoff_received",
                    f"📨 Received handoff from the Analysis agent: {analysis_handoff.summary[:50]}..."
                )
            # Build the verification input
            verification_input = {
                "previous_results": {
                    "analysis": {
                        "data": {
                            "findings": findings,
                        }
                    }
                },
                "config": state["config"],
                # 🔥 Pass the handoff along
                "handoff": analysis_handoff,
            }
            # Invoke the Verification agent
            result = await self.agent.run(verification_input)
            if result.success and result.data:
                all_verified_findings = result.data.get("findings", [])
                verified = [f for f in all_verified_findings if f.get("is_verified")]
                false_pos = [f.get("id", f.get("title", "unknown")) for f in all_verified_findings
                             if f.get("verdict") == "false_positive"]
                # 🔥 CRITICAL FIX: update the original findings with the verification results
                # Build an update map keyed by (file_path, line_start, vulnerability_type)
                verified_map = {}
                for vf in all_verified_findings:
                    key = (
                        vf.get("file_path", ""),
                        vf.get("line_start", 0),
                        vf.get("vulnerability_type", ""),
                    )
                    verified_map[key] = vf
                # Merge the verification results into the original findings
                updated_findings = []
                seen_keys = set()
                # First process the original findings, updating with the verified versions
                for f in findings:
                    if not isinstance(f, dict):
                        continue
                    key = (
                        f.get("file_path", ""),
                        f.get("line_start", 0),
                        f.get("vulnerability_type", ""),
                    )
                    if key in verified_map:
                        # Use the verified version
                        updated_findings.append(verified_map[key])
                        seen_keys.add(key)
                    else:
                        # Keep the original (unverified)
                        updated_findings.append(f)
                        seen_keys.add(key)
                # Append any new findings from the verification results
                for key, vf in verified_map.items():
                    if key not in seen_keys:
                        updated_findings.append(vf)
                logger.info(f"[VerificationNode] Updated findings: {len(updated_findings)} total, {len(verified)} verified")
                # 🔥 Create the handoff for the Report node
                handoff = self.agent.create_handoff(
                    to_agent="Report",
                    summary=f"Vulnerability verification complete. {len(verified)} vulnerabilities confirmed, {len(false_pos)} false positives ruled out.",
                    key_findings=verified,
                    suggested_actions=[
                        {
                            "type": "generate_report",
                            "description": "Generate a detailed security audit report",
                            "priority": "high",
                        },
                        {
                            "type": "remediation_plan",
                            "description": "Draft a remediation plan for the confirmed vulnerabilities",
                            "priority": "high",
                        },
                    ],
                    attention_points=[
                        f"{len(verified)} vulnerabilities confirmed",
                        f"{len(false_pos)} false positives ruled out",
                        "Recommend fixing Critical and High severity vulnerabilities first",
                    ],
                    context_data={
                        "verified_count": len(verified),
                        "false_positive_count": len(false_pos),
                        "total_analyzed": len(findings),
                        "verification_rate": len(verified) / len(findings) if findings else 0,
                    },
                )
                await self.emit_event(
                    "phase_complete",
                    f"✅ Verification complete: {len(verified)} confirmed, {len(false_pos)} false positives"
                )
                return {
                    # 🔥 CRITICAL: return the updated findings; this replaces findings in the state
                    # Note: because LangGraph merges findings via operator.add, the merge must be handled in the runner;
                    # here we return _verified_findings_update as a special field
                    "_verified_findings_update": updated_findings,
                    "verified_findings": verified,
                    "false_positives": false_pos,
                    "current_phase": "verification_complete",
                    # 🔥 Persist the handoff
                    "verification_handoff": handoff.to_dict(),
                    "events": [{
                        "type": "verification_complete",
                        "data": {
                            "verified_count": len(verified),
                            "false_positive_count": len(false_pos),
                            "total_findings": len(updated_findings),
                            "handoff_summary": handoff.summary,
                        }
                    }],
                }
            else:
                return {
                    "verified_findings": [],
                    "false_positives": [],
                    "current_phase": "verification_complete",
                }
        except Exception as e:
            logger.error(f"Verification node failed: {e}", exc_info=True)
            return {
                "error": str(e),
                "current_phase": "error",
            }


class ReportNode(BaseNode):
    """
    Report generation node
    Input: all state
    Output: summary, security_score
    """

    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Generate the audit report"""
        await self.emit_event("phase_start", "📊 Generating the audit report")
        try:
            # 🔥 CRITICAL FIX: prefer the post-verification findings update
            findings = state.get("_verified_findings_update") or state.get("findings", [])
            verified = state.get("verified_findings", [])
            false_positives = state.get("false_positives", [])
            logger.info(f"[ReportNode] State contains {len(findings)} findings, {len(verified)} verified")
            # Tally the vulnerability distribution
            severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
            type_counts = {}
            for finding in findings:
                # Skip findings that are not dicts (guards against malformed data)
                if not isinstance(finding, dict):
                    logger.warning(f"Skipping invalid finding (not a dict): {type(finding)}")
                    continue
                sev = finding.get("severity", "medium")
                severity_counts[sev] = severity_counts.get(sev, 0) + 1
                vtype = finding.get("vulnerability_type", "other")
                type_counts[vtype] = type_counts.get(vtype, 0) + 1
            # Compute the security score
            base_score = 100
            deductions = (
                severity_counts["critical"] * 25 +
                severity_counts["high"] * 15 +
                severity_counts["medium"] * 8 +
                severity_counts["low"] * 3
            )
            security_score = max(0, base_score - deductions)
            # Build the summary
            summary = {
                "total_findings": len(findings),
                "verified_count": len(verified),
                "false_positive_count": len(false_positives),
                "severity_distribution": severity_counts,
                "vulnerability_types": type_counts,
                "tech_stack": state.get("tech_stack", {}),
                "entry_points_analyzed": len(state.get("entry_points", [])),
                "high_risk_areas": state.get("high_risk_areas", []),
                "iterations": state.get("iteration", 1),
            }
            await self.emit_event(
                "phase_complete",
                f"Report generated: security score {security_score}/100"
            )
            return {
                "summary": summary,
                "security_score": security_score,
                "current_phase": "complete",
                "events": [{
                    "type": "audit_complete",
                    "data": {
                        "security_score": security_score,
                        "total_findings": len(findings),
                        "verified_count": len(verified),
                    }
                }],
            }
        except Exception as e:
            logger.error(f"Report node failed: {e}", exc_info=True)
            return {
                "error": str(e),
                "current_phase": "error",
            }


class HumanReviewNode(BaseNode):
    """
    Human review node
    Execution pauses at this node to wait for human feedback
    """

    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Human review node
        This node is paused via interrupt_before
        The user can:
        1. Confirm findings
        2. Mark false positives
        3. Request re-analysis
        """
        await self.emit_event(
            "human_review",
            f"⏸️ Waiting for human review ({len(state.get('verified_findings', []))} to confirm)"
        )
        # Return the current state unchanged;
        # human feedback arrives via update_state
        return {
            "current_phase": "human_review",
            "messages": [{
                "role": "system",
                "content": "Waiting for human review",
                "findings_for_review": state.get("verified_findings", []),
            }],
        }

File diff suppressed because it is too large


@@ -1,355 +0,0 @@
"""
Agent integration tests
Exercise the complete audit flow
"""
import pytest
import asyncio
import os
from unittest.mock import MagicMock, AsyncMock, patch
from datetime import datetime

from app.services.agent.graph.runner import AgentRunner, LLMService
from app.services.agent.graph.audit_graph import AuditState, create_audit_graph
from app.services.agent.graph.nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode
from app.services.agent.event_manager import EventManager, AgentEventEmitter


class TestLLMService:
    """LLM service tests"""

    @pytest.mark.asyncio
    async def test_llm_service_initialization(self):
        """Test LLM service initialization"""
        with patch("app.core.config.settings") as mock_settings:
            mock_settings.LLM_MODEL = "gpt-4o-mini"
            mock_settings.LLM_API_KEY = "test-key"
            service = LLMService()
            assert service.model == "gpt-4o-mini"


class TestEventManager:
    """Event manager tests"""

    def test_event_manager_initialization(self):
        """Test event manager initialization"""
        manager = EventManager()
        assert manager._event_queues == {}
        assert manager._event_callbacks == {}

    @pytest.mark.asyncio
    async def test_event_emitter(self):
        """Test the event emitter"""
        manager = EventManager()
        emitter = AgentEventEmitter("test-task-id", manager)
        await emitter.emit_info("Test message")
        assert emitter._sequence == 1

    @pytest.mark.asyncio
    async def test_event_emitter_phase_tracking(self):
        """Test event emitter phase tracking"""
        manager = EventManager()
        emitter = AgentEventEmitter("test-task-id", manager)
        await emitter.emit_phase_start("recon", "Starting reconnaissance")
        assert emitter._current_phase == "recon"

    @pytest.mark.asyncio
    async def test_event_emitter_task_complete(self):
        """Test the task-complete event"""
        manager = EventManager()
        emitter = AgentEventEmitter("test-task-id", manager)
        await emitter.emit_task_complete(findings_count=5, duration_ms=1000)
        assert emitter._sequence == 1


class TestAuditGraph:
    """Audit graph tests"""

    def test_create_audit_graph(self, mock_event_emitter):
        """Test creating the audit graph"""
        # Create mock nodes
        recon_node = MagicMock()
        analysis_node = MagicMock()
        verification_node = MagicMock()
        report_node = MagicMock()
        graph = create_audit_graph(
            recon_node=recon_node,
            analysis_node=analysis_node,
            verification_node=verification_node,
            report_node=report_node,
        )
        assert graph is not None


class TestReconNode:
    """Recon node tests"""

    @pytest.fixture
    def recon_node_with_mock_agent(self, mock_event_emitter):
        """Create a Recon node with a mock agent"""
        mock_agent = MagicMock()
        mock_agent.run = AsyncMock(return_value=MagicMock(
            success=True,
            data={
                "tech_stack": {"languages": ["Python"]},
                "entry_points": [{"path": "src/app.py", "type": "api"}],
                "high_risk_areas": ["src/sql_vuln.py"],
                "dependencies": {},
                "initial_findings": [],
            }
        ))
        return ReconNode(mock_agent, mock_event_emitter)

    @pytest.mark.asyncio
    async def test_recon_node_success(self, recon_node_with_mock_agent):
        """Test successful Recon node execution"""
        state = {
            "project_info": {"name": "Test"},
            "config": {},
        }
        result = await recon_node_with_mock_agent(state)
        assert "tech_stack" in result
        assert "entry_points" in result
        assert result["current_phase"] == "recon_complete"

    @pytest.mark.asyncio
    async def test_recon_node_failure(self, mock_event_emitter):
        """Test Recon node failure handling"""
        mock_agent = MagicMock()
        mock_agent.run = AsyncMock(return_value=MagicMock(
            success=False,
            error="Test error",
            data=None,
        ))
        node = ReconNode(mock_agent, mock_event_emitter)
        result = await node({
            "project_info": {},
            "config": {},
        })
        assert "error" in result
        assert result["current_phase"] == "error"


class TestAnalysisNode:
    """Analysis node tests"""

    @pytest.fixture
    def analysis_node_with_mock_agent(self, mock_event_emitter):
        """Create an Analysis node with a mock agent"""
        mock_agent = MagicMock()
        mock_agent.run = AsyncMock(return_value=MagicMock(
            success=True,
            data={
                "findings": [
                    {
                        "id": "finding-1",
                        "title": "SQL Injection",
                        "severity": "high",
                        "vulnerability_type": "sql_injection",
                        "file_path": "src/sql_vuln.py",
                        "line_start": 10,
                        "description": "SQL injection vulnerability",
                    }
                ],
                "should_continue": False,
            }
        ))
        return AnalysisNode(mock_agent, mock_event_emitter)

    @pytest.mark.asyncio
    async def test_analysis_node_success(self, analysis_node_with_mock_agent):
        """Test successful Analysis node execution"""
        state = {
            "project_info": {"name": "Test"},
            "tech_stack": {"languages": ["Python"]},
            "entry_points": [],
            "high_risk_areas": ["src/sql_vuln.py"],
            "config": {},
            "iteration": 0,
            "findings": [],
        }
        result = await analysis_node_with_mock_agent(state)
        assert "findings" in result
        assert len(result["findings"]) > 0
        assert result["iteration"] == 1


class TestIntegrationFlow:
    """Full-flow integration tests"""

    @pytest.mark.asyncio
    async def test_full_audit_flow_mock(self, temp_project_dir, mock_db_session, mock_task):
        """Test the complete audit flow (with mocks)"""
        # This test verifies the wiring of the whole flow
        # Create the event manager
        event_manager = EventManager()
        emitter = AgentEventEmitter(mock_task.id, event_manager)
        # Mock the LLM service
        mock_llm = MagicMock()
        mock_llm.chat_completion_raw = AsyncMock(return_value={
            "content": "Analysis complete",
            "usage": {"total_tokens": 100},
        })
        # Verify event emission
        await emitter.emit_phase_start("init", "Initializing")
        await emitter.emit_info("Test message")
        await emitter.emit_phase_complete("init", "Initialization complete")
        assert emitter._sequence == 3

    @pytest.mark.asyncio
    async def test_audit_state_typing(self):
        """Test the audit state type definition"""
        state: AuditState = {
            "project_root": "/tmp/test",
            "project_info": {"name": "Test"},
            "config": {},
            "task_id": "test-id",
            "tech_stack": {},
            "entry_points": [],
            "high_risk_areas": [],
            "dependencies": {},
            "findings": [],
            "verified_findings": [],
            "false_positives": [],
            "current_phase": "start",
            "iteration": 0,
            "max_iterations": 50,
            "should_continue_analysis": False,
            "messages": [],
            "events": [],
            "summary": None,
            "security_score": None,
            "error": None,
        }
        assert state["current_phase"] == "start"
        assert state["max_iterations"] == 50


class TestToolIntegration:
    """Tool integration tests"""

    @pytest.mark.asyncio
    async def test_tools_work_together(self, temp_project_dir):
        """Test that the tools work together"""
        from app.services.agent.tools import (
            FileReadTool, FileSearchTool, ListFilesTool, PatternMatchTool,
        )
        # 1. List files
        list_tool = ListFilesTool(temp_project_dir)
        list_result = await list_tool.execute(directory="src", recursive=False)
        assert list_result.success is True
        # 2. Search for key code
        search_tool = FileSearchTool(temp_project_dir)
        search_result = await search_tool.execute(keyword="execute")
        assert search_result.success is True
        # 3. Read file contents
        read_tool = FileReadTool(temp_project_dir)
        read_result = await read_tool.execute(file_path="src/sql_vuln.py")
        assert read_result.success is True
        # 4. Pattern matching
        pattern_tool = PatternMatchTool(temp_project_dir)
        pattern_result = await pattern_tool.execute(
            code=read_result.data,
            file_path="src/sql_vuln.py",
            language="python"
        )
        assert pattern_result.success is True


class TestErrorHandling:
    """Error handling tests"""

    @pytest.mark.asyncio
    async def test_tool_error_handling(self, temp_project_dir):
        """Test tool error handling"""
        from app.services.agent.tools import FileReadTool
        tool = FileReadTool(temp_project_dir)
        # Try to read a file that does not exist
        result = await tool.execute(file_path="nonexistent/file.py")
        assert result.success is False
        assert result.error is not None

    @pytest.mark.asyncio
    async def test_agent_graceful_degradation(self, mock_event_emitter):
        """Test agent graceful degradation"""
        # Create an agent that always fails
        mock_agent = MagicMock()
        mock_agent.run = AsyncMock(side_effect=Exception("Simulated error"))
        node = ReconNode(mock_agent, mock_event_emitter)
        result = await node({
            "project_info": {},
            "config": {},
        })
        # It should return an error state instead of crashing
        assert "error" in result
        assert result["current_phase"] == "error"


class TestPerformance:
    """Performance tests"""

    @pytest.mark.asyncio
    async def test_tool_response_time(self, temp_project_dir):
        """Test tool response time"""
        from app.services.agent.tools import ListFilesTool
        import time
        tool = ListFilesTool(temp_project_dir)
        start = time.time()
        await tool.execute(directory=".", recursive=True)
        duration = time.time() - start
        # The tool should respond within a reasonable time
        assert duration < 5.0  # within 5 seconds

    @pytest.mark.asyncio
    async def test_multiple_tool_calls(self, temp_project_dir):
        """Test repeated tool calls"""
        from app.services.agent.tools import FileSearchTool
        tool = FileSearchTool(temp_project_dir)
        # Run multiple calls
        for _ in range(5):
            result = await tool.execute(keyword="def")
            assert result.success is True
        # Verify the call count
        assert tool._call_count == 5