From 15605fea16328be3728ef7c5544a8639be4ce512 Mon Sep 17 00:00:00 2001 From: lintsinghua Date: Thu, 25 Dec 2025 17:58:14 +0800 Subject: [PATCH] =?UTF-8?q?refactor(agent):=20=E7=A7=BB=E9=99=A4LangGraph?= =?UTF-8?q?=E5=B7=A5=E4=BD=9C=E6=B5=81=E5=B9=B6=E8=BF=81=E7=A7=BB=E5=88=B0?= =?UTF-8?q?=E5=8A=A8=E6=80=81Agent=E6=A0=91=E6=9E=B6=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 重构Agent服务架构,从基于LangGraph的状态图迁移到动态Agent树结构。主要变更包括: - 删除graph模块及相关测试 - 更新agent/__init__.py导入和文档 - 在projects端点添加对新AgentTask模型的统计支持 - 简化工作流描述为START→Orchestrator→[Recon/Analysis/Verification]→Report→END 新架构使用OrchestratorAgent作为编排层,动态调度子Agent完成任务,提高灵活性和可扩展性。 --- backend/app/api/v1/endpoints/projects.py | 44 +- backend/app/services/agent/__init__.py | 62 +- backend/app/services/agent/graph/__init__.py | 28 - .../app/services/agent/graph/audit_graph.py | 677 ----------- backend/app/services/agent/graph/nodes.py | 556 --------- backend/app/services/agent/graph/runner.py | 1042 ----------------- backend/tests/agent/test_integration.py | 355 ------ 7 files changed, 56 insertions(+), 2708 deletions(-) delete mode 100644 backend/app/services/agent/graph/__init__.py delete mode 100644 backend/app/services/agent/graph/audit_graph.py delete mode 100644 backend/app/services/agent/graph/nodes.py delete mode 100644 backend/app/services/agent/graph/runner.py delete mode 100644 backend/tests/agent/test_integration.py diff --git a/backend/app/api/v1/endpoints/projects.py b/backend/app/api/v1/endpoints/projects.py index 66b0710..cca092a 100644 --- a/backend/app/api/v1/endpoints/projects.py +++ b/backend/app/api/v1/endpoints/projects.py @@ -16,6 +16,7 @@ from app.db.session import get_db, AsyncSessionLocal from app.models.project import Project from app.models.user import User from app.models.audit import AuditTask, AuditIssue +from app.models.agent_task import AgentTask, AgentTaskStatus, AgentFinding from app.models.user_config import UserConfig import zipfile from app.services.scanner import scan_repo_task, get_github_files, get_gitlab_files, get_github_branches, get_gitlab_branches, get_gitea_branches, should_exclude, is_text_file @@ -161,27 +162,52 @@ async def get_stats( ) projects = projects_result.scalars().all() project_ids = [p.id for p in projects] - - # 只统计当前用户项目的任务 + + # 统计旧的 AuditTask tasks_result = await db.execute( select(AuditTask).where(AuditTask.project_id.in_(project_ids)) if project_ids else select(AuditTask).where(False) ) tasks = tasks_result.scalars().all() task_ids = [t.id for t in tasks] - - # 只统计当前用户任务的问题 + + # 统计旧的 AuditIssue issues_result = await db.execute( select(AuditIssue).where(AuditIssue.task_id.in_(task_ids)) if task_ids else select(AuditIssue).where(False) ) issues = issues_result.scalars().all() - + + # 🔥 同时统计新的 AgentTask + agent_tasks_result = await db.execute( + select(AgentTask).where(AgentTask.project_id.in_(project_ids)) if project_ids else select(AgentTask).where(False) + ) + agent_tasks = agent_tasks_result.scalars().all() + agent_task_ids = [t.id for t in agent_tasks] + + # 🔥 统计 AgentFinding + agent_findings_result = await db.execute( + select(AgentFinding).where(AgentFinding.task_id.in_(agent_task_ids)) if agent_task_ids else select(AgentFinding).where(False) + ) + agent_findings = agent_findings_result.scalars().all() + + # 合并统计(旧任务 + 新 Agent 任务) + total_tasks = len(tasks) + len(agent_tasks) + completed_tasks = ( + len([t for t in tasks if t.status == "completed"]) + + len([t for t in agent_tasks if t.status == AgentTaskStatus.COMPLETED]) + ) + total_issues = len(issues) + len(agent_findings) + resolved_issues = ( + len([i for i in issues if i.status == "resolved"]) + + len([f for f in agent_findings if f.status == "resolved"]) + ) + return { "total_projects": len(projects), "active_projects": len([p for p in projects if p.is_active]), - "total_tasks": len(tasks), - "completed_tasks": len([t for t in tasks if t.status == "completed"]), - "total_issues": len(issues), - "resolved_issues": len([i for i in issues if i.status == "resolved"]), + "total_tasks": total_tasks, + "completed_tasks": completed_tasks, + "total_issues": total_issues, + "resolved_issues": resolved_issues, } @router.get("/{id}", response_model=ProjectResponse) diff --git a/backend/app/services/agent/__init__.py b/backend/app/services/agent/__init__.py index fee2a92..f169c80 100644 --- a/backend/app/services/agent/__init__.py +++ b/backend/app/services/agent/__init__.py @@ -1,29 +1,19 @@ """ DeepAudit Agent 服务模块 -基于 LangGraph 的 AI Agent 代码安全审计 +基于动态 Agent 树架构的 AI 代码安全审计 -架构升级版本 - 支持: -- 动态Agent树结构 -- 专业知识模块系统 -- Agent间通信机制 -- 完整状态管理 -- Think工具和漏洞报告工具 +架构: +- OrchestratorAgent 作为编排层,动态调度子 Agent +- ReconAgent 负责侦察和文件分析 +- AnalysisAgent 负责漏洞分析 +- VerificationAgent 负责验证发现 工作流: - START → Recon → Analysis ⟲ → Verification → Report → END - + START → Orchestrator → [Recon/Analysis/Verification] → Report → END + 支持动态创建子Agent进行专业化分析 """ -# 从 graph 模块导入主要组件 -from .graph import ( - AgentRunner, - run_agent_task, - LLMService, - AuditState, - create_audit_graph, -) - # 事件管理 from .event_manager import EventManager, AgentEventEmitter @@ -33,14 +23,14 @@ from .agents import ( OrchestratorAgent, ReconAgent, AnalysisAgent, VerificationAgent, ) -# 🔥 新增:核心模块(状态管理、注册表、消息) +# 核心模块(状态管理、注册表、消息) from .core import ( AgentState, AgentStatus, AgentRegistry, agent_registry, AgentMessage, MessageType, MessagePriority, MessageBus, ) -# 🔥 新增:知识模块系统(基于RAG) +# 知识模块系统(基于RAG) from .knowledge import ( KnowledgeLoader, knowledge_loader, get_available_modules, get_module_content, @@ -48,7 +38,7 @@ from .knowledge import ( SecurityKnowledgeQueryTool, GetVulnerabilityKnowledgeTool, ) -# 🔥 新增:协作工具 +# 协作工具 from .tools import ( ThinkTool, ReflectTool, CreateVulnerabilityReportTool, @@ -57,24 +47,15 @@ from .tools import ( WaitForMessageTool, AgentFinishTool, ) -# 🔥 新增:遥测模块 +# 遥测模块 from .telemetry import Tracer, get_global_tracer, set_global_tracer __all__ = [ - # 核心 Runner - "AgentRunner", - "run_agent_task", - "LLMService", - - # LangGraph - "AuditState", - "create_audit_graph", - # 事件管理 "EventManager", "AgentEventEmitter", - + # Agent 类 "BaseAgent", "AgentConfig", @@ -83,8 +64,8 @@ __all__ = [ "ReconAgent", "AnalysisAgent", "VerificationAgent", - - # 🔥 核心模块 + + # 核心模块 "AgentState", "AgentStatus", "AgentRegistry", @@ -93,8 +74,8 @@ __all__ = [ "MessageType", "MessagePriority", "MessageBus", - - # 🔥 知识模块(基于RAG) + + # 知识模块(基于RAG) "KnowledgeLoader", "knowledge_loader", "get_available_modules", @@ -103,8 +84,8 @@ __all__ = [ "security_knowledge_rag", "SecurityKnowledgeQueryTool", "GetVulnerabilityKnowledgeTool", - - # 🔥 协作工具 + + # 协作工具 "ThinkTool", "ReflectTool", "CreateVulnerabilityReportTool", @@ -114,10 +95,9 @@ __all__ = [ "ViewAgentGraphTool", "WaitForMessageTool", "AgentFinishTool", - - # 🔥 遥测模块 + + # 遥测模块 "Tracer", "get_global_tracer", "set_global_tracer", ] - diff --git a/backend/app/services/agent/graph/__init__.py b/backend/app/services/agent/graph/__init__.py deleted file mode 100644 index 5bc17d1..0000000 --- a/backend/app/services/agent/graph/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -LangGraph 工作流模块 -使用状态图构建混合 Agent 审计流程 -""" - -from .audit_graph import AuditState, create_audit_graph, create_audit_graph_with_human -from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode, HumanReviewNode -from .runner import AgentRunner, run_agent_task, LLMService - -__all__ = [ - # 状态和图 - "AuditState", - "create_audit_graph", - "create_audit_graph_with_human", - - # 节点 - "ReconNode", - "AnalysisNode", - "VerificationNode", - "ReportNode", - "HumanReviewNode", - - # Runner - "AgentRunner", - "run_agent_task", - "LLMService", -] - diff --git a/backend/app/services/agent/graph/audit_graph.py b/backend/app/services/agent/graph/audit_graph.py deleted file mode 100644 index 4b0eada..0000000 --- a/backend/app/services/agent/graph/audit_graph.py +++ /dev/null @@ -1,677 +0,0 @@ -""" -DeepAudit 审计工作流图 - LLM 驱动版 -使用 LangGraph 构建 LLM 驱动的 Agent 协作流程 - -重要改变:路由决策由 LLM 参与,而不是硬编码条件! -""" - -from typing import TypedDict, Annotated, List, Dict, Any, Optional, Literal -from datetime import datetime -import operator -import logging -import json - -from langgraph.graph import StateGraph, END -from langgraph.checkpoint.memory import MemorySaver -from langgraph.prebuilt import ToolNode - -logger = logging.getLogger(__name__) - - -# ============ 状态定义 ============ - -class Finding(TypedDict): - """漏洞发现""" - id: str - vulnerability_type: str - severity: str - title: str - description: str - file_path: Optional[str] - line_start: Optional[int] - code_snippet: Optional[str] - is_verified: bool - confidence: float - source: str - - -class AuditState(TypedDict): - """ - 审计状态 - 在整个工作流中传递和更新 - """ - # 输入 - project_root: str - project_info: Dict[str, Any] - config: Dict[str, Any] - task_id: str - - # Recon 阶段输出 - tech_stack: Dict[str, Any] - entry_points: List[Dict[str, Any]] - high_risk_areas: List[str] - dependencies: Dict[str, Any] - - # Analysis 阶段输出 - findings: Annotated[List[Finding], operator.add] # 使用 add 合并多轮发现 - - # Verification 阶段输出 - verified_findings: List[Finding] - false_positives: List[str] - # 🔥 NEW: 验证后的完整 findings(用于替换原始 findings) - _verified_findings_update: Optional[List[Finding]] - - # 控制流 - 🔥 关键:LLM 可以设置这些来影响路由 - current_phase: str - iteration: int - max_iterations: int - should_continue_analysis: bool - - # 🔥 新增:LLM 的路由决策 - llm_next_action: Optional[str] # LLM 建议的下一步: "continue_analysis", "verify", "report", "end" - llm_routing_reason: Optional[str] # LLM 的决策理由 - - # 🔥 新增:Agent 间协作的任务交接信息 - recon_handoff: Optional[Dict[str, Any]] # Recon -> Analysis 的交接 - analysis_handoff: Optional[Dict[str, Any]] # Analysis -> Verification 的交接 - verification_handoff: Optional[Dict[str, Any]] # Verification -> Report 的交接 - - # 消息和事件 - messages: Annotated[List[Dict], operator.add] - events: Annotated[List[Dict], operator.add] - - # 最终输出 - summary: Optional[Dict[str, Any]] - security_score: Optional[int] - error: Optional[str] - - -# ============ LLM 路由决策器 ============ - -class LLMRouter: - """ - LLM 路由决策器 - 让 LLM 来决定下一步应该做什么 - """ - - def __init__(self, llm_service): - self.llm_service = llm_service - - async def decide_after_recon(self, state: AuditState) -> Dict[str, Any]: - """Recon 后让 LLM 决定下一步""" - entry_points = state.get("entry_points", []) - high_risk_areas = state.get("high_risk_areas", []) - tech_stack = state.get("tech_stack", {}) - initial_findings = state.get("findings", []) - - prompt = f"""作为安全审计的决策者,基于以下信息收集结果,决定下一步行动。 - -## 信息收集结果 -- 入口点数量: {len(entry_points)} -- 高风险区域: {high_risk_areas[:10]} -- 技术栈: {tech_stack} -- 初步发现: {len(initial_findings)} 个 - -## 选项 -1. "analysis" - 继续进行漏洞分析(推荐:有入口点或高风险区域时) -2. "end" - 结束审计(仅当没有任何可分析内容时) - -请返回 JSON 格式: -{{"action": "analysis或end", "reason": "决策理由"}}""" - - try: - response = await self.llm_service.chat_completion_raw( - messages=[ - {"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"}, - {"role": "user", "content": prompt}, - ], - # 🔥 不传递 temperature 和 max_tokens,使用用户配置 - ) - - content = response.get("content", "") - # 提取 JSON - import re - json_match = re.search(r'\{.*\}', content, re.DOTALL) - if json_match: - result = json.loads(json_match.group()) - return result - except Exception as e: - logger.warning(f"LLM routing decision failed: {e}") - - # 默认决策 - if entry_points or high_risk_areas: - return {"action": "analysis", "reason": "有可分析内容"} - return {"action": "end", "reason": "没有发现入口点或高风险区域"} - - async def decide_after_analysis(self, state: AuditState) -> Dict[str, Any]: - """Analysis 后让 LLM 决定下一步""" - findings = state.get("findings", []) - iteration = state.get("iteration", 0) - max_iterations = state.get("max_iterations", 3) - - # 统计发现 - severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0} - for f in findings: - # 跳过非字典类型的 finding - if not isinstance(f, dict): - continue - sev = f.get("severity", "medium") - severity_counts[sev] = severity_counts.get(sev, 0) + 1 - - prompt = f"""作为安全审计的决策者,基于以下分析结果,决定下一步行动。 - -## 分析结果 -- 总发现数: {len(findings)} -- 严重程度分布: {severity_counts} -- 当前迭代: {iteration}/{max_iterations} - -## 选项 -1. "verification" - 验证发现的漏洞(推荐:有发现需要验证时) -2. "analysis" - 继续深入分析(推荐:发现较少但还有迭代次数时) -3. "report" - 生成报告(推荐:没有发现或已充分分析时) - -请返回 JSON 格式: -{{"action": "verification/analysis/report", "reason": "决策理由"}}""" - - try: - response = await self.llm_service.chat_completion_raw( - messages=[ - {"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"}, - {"role": "user", "content": prompt}, - ], - # 🔥 不传递 temperature 和 max_tokens,使用用户配置 - ) - - content = response.get("content", "") - import re - json_match = re.search(r'\{.*\}', content, re.DOTALL) - if json_match: - result = json.loads(json_match.group()) - return result - except Exception as e: - logger.warning(f"LLM routing decision failed: {e}") - - # 默认决策 - if not findings: - return {"action": "report", "reason": "没有发现漏洞"} - if len(findings) >= 3 or iteration >= max_iterations: - return {"action": "verification", "reason": "有足够的发现需要验证"} - return {"action": "analysis", "reason": "发现较少,继续分析"} - - async def decide_after_verification(self, state: AuditState) -> Dict[str, Any]: - """Verification 后让 LLM 决定下一步""" - verified_findings = state.get("verified_findings", []) - false_positives = state.get("false_positives", []) - iteration = state.get("iteration", 0) - max_iterations = state.get("max_iterations", 3) - - prompt = f"""作为安全审计的决策者,基于以下验证结果,决定下一步行动。 - -## 验证结果 -- 已确认漏洞: {len(verified_findings)} -- 误报数量: {len(false_positives)} -- 当前迭代: {iteration}/{max_iterations} - -## 选项 -1. "analysis" - 回到分析阶段重新分析(推荐:误报率太高时) -2. "report" - 生成最终报告(推荐:验证完成时) - -请返回 JSON 格式: -{{"action": "analysis/report", "reason": "决策理由"}}""" - - try: - response = await self.llm_service.chat_completion_raw( - messages=[ - {"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"}, - {"role": "user", "content": prompt}, - ], - # 🔥 不传递 temperature 和 max_tokens,使用用户配置 - ) - - content = response.get("content", "") - import re - json_match = re.search(r'\{.*\}', content, re.DOTALL) - if json_match: - result = json.loads(json_match.group()) - return result - except Exception as e: - logger.warning(f"LLM routing decision failed: {e}") - - # 默认决策 - if len(false_positives) > len(verified_findings) and iteration < max_iterations: - return {"action": "analysis", "reason": "误报率较高,需要重新分析"} - return {"action": "report", "reason": "验证完成,生成报告"} - - -# ============ 路由函数 (结合 LLM 决策) ============ - -def route_after_recon(state: AuditState) -> Literal["analysis", "end"]: - """ - Recon 后的路由决策 - 优先使用 LLM 的决策,否则使用默认逻辑 - """ - # 🔥 检查是否有错误 - if state.get("error") or state.get("current_phase") == "error": - logger.error(f"Recon phase has error, routing to end: {state.get('error')}") - return "end" - - # 检查 LLM 是否有决策 - llm_action = state.get("llm_next_action") - if llm_action: - logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}") - if llm_action == "end": - return "end" - return "analysis" - - # 默认逻辑(作为 fallback) - if not state.get("entry_points") and not state.get("high_risk_areas"): - return "end" - return "analysis" - - -def route_after_analysis(state: AuditState) -> Literal["verification", "analysis", "report"]: - """ - Analysis 后的路由决策 - 优先使用 LLM 的决策 - """ - # 检查 LLM 是否有决策 - llm_action = state.get("llm_next_action") - if llm_action: - logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}") - if llm_action == "verification": - return "verification" - elif llm_action == "analysis": - return "analysis" - elif llm_action == "report": - return "report" - - # 默认逻辑 - findings = state.get("findings", []) - iteration = state.get("iteration", 0) - max_iterations = state.get("max_iterations", 3) - should_continue = state.get("should_continue_analysis", False) - - if not findings: - return "report" - - if should_continue and iteration < max_iterations: - return "analysis" - - return "verification" - - -def route_after_verification(state: AuditState) -> Literal["analysis", "report"]: - """ - Verification 后的路由决策 - 优先使用 LLM 的决策 - """ - # 检查 LLM 是否有决策 - llm_action = state.get("llm_next_action") - if llm_action: - logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}") - if llm_action == "analysis": - return "analysis" - return "report" - - # 默认逻辑 - false_positives = state.get("false_positives", []) - iteration = state.get("iteration", 0) - max_iterations = state.get("max_iterations", 3) - - if len(false_positives) > len(state.get("verified_findings", [])) and iteration < max_iterations: - return "analysis" - - return "report" - - -# ============ 创建审计图 ============ - -def create_audit_graph( - recon_node, - analysis_node, - verification_node, - report_node, - checkpointer: Optional[MemorySaver] = None, - llm_service=None, # 用于 LLM 路由决策 -) -> StateGraph: - """ - 创建审计工作流图 - - Args: - recon_node: 信息收集节点 - analysis_node: 漏洞分析节点 - verification_node: 漏洞验证节点 - report_node: 报告生成节点 - checkpointer: 检查点存储器 - llm_service: LLM 服务(用于路由决策) - - Returns: - 编译后的 StateGraph - - 工作流结构: - - START - │ - ▼ - ┌──────┐ - │Recon │ 信息收集 (LLM 驱动) - └──┬───┘ - │ LLM 决定 - ▼ - ┌──────────┐ - │ Analysis │◄─────┐ 漏洞分析 (LLM 驱动,可循环) - └────┬─────┘ │ - │ LLM 决定 │ - ▼ │ - ┌────────────┐ │ - │Verification│────┘ 漏洞验证 (LLM 驱动,可回溯) - └─────┬──────┘ - │ LLM 决定 - ▼ - ┌──────────┐ - │ Report │ 报告生成 - └────┬─────┘ - │ - ▼ - END - """ - - # 创建状态图 - workflow = StateGraph(AuditState) - - # 如果有 LLM 服务,创建路由决策器 - llm_router = LLMRouter(llm_service) if llm_service else None - - # 包装节点以添加 LLM 路由决策 - async def recon_with_routing(state): - result = await recon_node(state) - - # LLM 决定下一步 - if llm_router: - decision = await llm_router.decide_after_recon({**state, **result}) - result["llm_next_action"] = decision.get("action") - result["llm_routing_reason"] = decision.get("reason") - - return result - - async def analysis_with_routing(state): - result = await analysis_node(state) - - # LLM 决定下一步 - if llm_router: - decision = await llm_router.decide_after_analysis({**state, **result}) - result["llm_next_action"] = decision.get("action") - result["llm_routing_reason"] = decision.get("reason") - - return result - - async def verification_with_routing(state): - result = await verification_node(state) - - # LLM 决定下一步 - if llm_router: - decision = await llm_router.decide_after_verification({**state, **result}) - result["llm_next_action"] = decision.get("action") - result["llm_routing_reason"] = decision.get("reason") - - return result - - # 添加节点 - if llm_router: - workflow.add_node("recon", recon_with_routing) - workflow.add_node("analysis", analysis_with_routing) - workflow.add_node("verification", verification_with_routing) - else: - workflow.add_node("recon", recon_node) - workflow.add_node("analysis", analysis_node) - workflow.add_node("verification", verification_node) - - workflow.add_node("report", report_node) - - # 设置入口点 - workflow.set_entry_point("recon") - - # 添加条件边 - workflow.add_conditional_edges( - "recon", - route_after_recon, - { - "analysis": "analysis", - "end": END, - } - ) - - workflow.add_conditional_edges( - "analysis", - route_after_analysis, - { - "verification": "verification", - "analysis": "analysis", - "report": "report", - } - ) - - workflow.add_conditional_edges( - "verification", - route_after_verification, - { - "analysis": "analysis", - "report": "report", - } - ) - - # Report -> END - workflow.add_edge("report", END) - - # 编译图 - if checkpointer: - return workflow.compile(checkpointer=checkpointer) - else: - return workflow.compile() - - -# ============ 带人机协作的审计图 ============ - -def create_audit_graph_with_human( - recon_node, - analysis_node, - verification_node, - report_node, - human_review_node, - checkpointer: Optional[MemorySaver] = None, - llm_service=None, -) -> StateGraph: - """ - 创建带人机协作的审计工作流图 - - 在验证阶段后增加人工审核节点 - """ - - workflow = StateGraph(AuditState) - llm_router = LLMRouter(llm_service) if llm_service else None - - # 包装节点 - async def recon_with_routing(state): - result = await recon_node(state) - if llm_router: - decision = await llm_router.decide_after_recon({**state, **result}) - result["llm_next_action"] = decision.get("action") - result["llm_routing_reason"] = decision.get("reason") - return result - - async def analysis_with_routing(state): - result = await analysis_node(state) - if llm_router: - decision = await llm_router.decide_after_analysis({**state, **result}) - result["llm_next_action"] = decision.get("action") - result["llm_routing_reason"] = decision.get("reason") - return result - - # 添加节点 - if llm_router: - workflow.add_node("recon", recon_with_routing) - workflow.add_node("analysis", analysis_with_routing) - else: - workflow.add_node("recon", recon_node) - workflow.add_node("analysis", analysis_node) - - workflow.add_node("verification", verification_node) - workflow.add_node("human_review", human_review_node) - workflow.add_node("report", report_node) - - workflow.set_entry_point("recon") - - workflow.add_conditional_edges( - "recon", - route_after_recon, - {"analysis": "analysis", "end": END} - ) - - workflow.add_conditional_edges( - "analysis", - route_after_analysis, - { - "verification": "verification", - "analysis": "analysis", - "report": "report", - } - ) - - # Verification -> Human Review - workflow.add_edge("verification", "human_review") - - # Human Review 后的路由 - def route_after_human(state: AuditState) -> Literal["analysis", "report"]: - if state.get("should_continue_analysis"): - return "analysis" - return "report" - - workflow.add_conditional_edges( - "human_review", - route_after_human, - {"analysis": "analysis", "report": "report"} - ) - - workflow.add_edge("report", END) - - if checkpointer: - return workflow.compile(checkpointer=checkpointer, interrupt_before=["human_review"]) - else: - return workflow.compile() - - -# ============ 执行器 ============ - -class AuditGraphRunner: - """ - 审计图执行器 - 封装 LangGraph 工作流的执行 - """ - - def __init__( - self, - graph: StateGraph, - event_emitter=None, - ): - self.graph = graph - self.event_emitter = event_emitter - - async def run( - self, - project_root: str, - project_info: Dict[str, Any], - config: Dict[str, Any], - task_id: str, - ) -> Dict[str, Any]: - """ - 执行审计工作流 - """ - # 初始状态 - initial_state: AuditState = { - "project_root": project_root, - "project_info": project_info, - "config": config, - "task_id": task_id, - "tech_stack": {}, - "entry_points": [], - "high_risk_areas": [], - "dependencies": {}, - "findings": [], - "verified_findings": [], - "false_positives": [], - "_verified_findings_update": None, # 🔥 NEW: 验证后的 findings 更新 - "current_phase": "start", - "iteration": 0, - "max_iterations": config.get("max_iterations", 3), - "should_continue_analysis": False, - "llm_next_action": None, - "llm_routing_reason": None, - "messages": [], - "events": [], - "summary": None, - "security_score": None, - "error": None, - } - - run_config = { - "configurable": { - "thread_id": task_id, - } - } - - try: - async for event in self.graph.astream(initial_state, config=run_config): - if self.event_emitter: - for node_name, node_state in event.items(): - await self.event_emitter.emit_info( - f"节点 {node_name} 完成" - ) - - # 发射 LLM 路由决策事件 - if node_state.get("llm_routing_reason"): - await self.event_emitter.emit_info( - f"🧠 LLM 决策: {node_state.get('llm_next_action')} - {node_state.get('llm_routing_reason')}" - ) - - if node_name == "analysis" and node_state.get("findings"): - new_findings = node_state["findings"] - await self.event_emitter.emit_info( - f"发现 {len(new_findings)} 个潜在漏洞" - ) - - final_state = self.graph.get_state(run_config) - return final_state.values - - except Exception as e: - logger.error(f"Graph execution failed: {e}", exc_info=True) - raise - - async def run_with_human_review( - self, - initial_state: AuditState, - human_feedback_callback, - ) -> Dict[str, Any]: - """带人机协作的执行""" - run_config = { - "configurable": { - "thread_id": initial_state["task_id"], - } - } - - async for event in self.graph.astream(initial_state, config=run_config): - pass - - current_state = self.graph.get_state(run_config) - - if current_state.next == ("human_review",): - human_decision = await human_feedback_callback(current_state.values) - - updated_state = { - **current_state.values, - "should_continue_analysis": human_decision.get("continue_analysis", False), - } - - async for event in self.graph.astream(updated_state, config=run_config): - pass - - return self.graph.get_state(run_config).values diff --git a/backend/app/services/agent/graph/nodes.py b/backend/app/services/agent/graph/nodes.py deleted file mode 100644 index 0f8b034..0000000 --- a/backend/app/services/agent/graph/nodes.py +++ /dev/null @@ -1,556 +0,0 @@ -""" -LangGraph 节点实现 -每个节点封装一个 Agent 的执行逻辑 - -协作增强:节点之间通过 TaskHandoff 传递结构化的上下文和洞察 -""" - -from typing import Dict, Any, List, Optional -import logging - -logger = logging.getLogger(__name__) - -# 延迟导入避免循环依赖 -def get_audit_state_type(): - from .audit_graph import AuditState - return AuditState - - -class BaseNode: - """节点基类""" - - def __init__(self, agent=None, event_emitter=None): - self.agent = agent - self.event_emitter = event_emitter - - async def emit_event(self, event_type: str, message: str, **kwargs): - """发射事件""" - if self.event_emitter: - try: - await self.event_emitter.emit_info(message) - except Exception as e: - logger.warning(f"Failed to emit event: {e}") - - def _extract_handoff_from_state(self, state: Dict[str, Any], from_phase: str): - """从状态中提取前序 Agent 的 handoff""" - handoff_data = state.get(f"{from_phase}_handoff") - if handoff_data: - from ..agents.base import TaskHandoff - return TaskHandoff.from_dict(handoff_data) - return None - - -class ReconNode(BaseNode): - """ - 信息收集节点 - - 输入: project_root, project_info, config - 输出: tech_stack, entry_points, high_risk_areas, dependencies, recon_handoff - """ - - async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]: - """执行信息收集""" - await self.emit_event("phase_start", "🔍 开始信息收集阶段") - - try: - # 调用 Recon Agent - result = await self.agent.run({ - "project_info": state["project_info"], - "config": state["config"], - }) - - if result.success and result.data: - data = result.data - - # 🔥 创建交接信息给 Analysis Agent - handoff = self.agent.create_handoff( - to_agent="Analysis", - summary=f"项目信息收集完成。发现 {len(data.get('entry_points', []))} 个入口点,{len(data.get('high_risk_areas', []))} 个高风险区域。", - key_findings=data.get("initial_findings", []), - suggested_actions=[ - { - "type": "deep_analysis", - "description": f"深入分析高风险区域: {', '.join(data.get('high_risk_areas', [])[:5])}", - "priority": "high", - }, - { - "type": "entry_point_audit", - "description": "审计所有入口点的输入验证", - "priority": "high", - }, - ], - attention_points=[ - f"技术栈: {data.get('tech_stack', {}).get('frameworks', [])}", - f"主要语言: {data.get('tech_stack', {}).get('languages', [])}", - ], - priority_areas=data.get("high_risk_areas", [])[:10], - context_data={ - "tech_stack": data.get("tech_stack", {}), - "entry_points": data.get("entry_points", []), - "dependencies": data.get("dependencies", {}), - }, - ) - - await self.emit_event( - "phase_complete", - f"✅ 信息收集完成: 发现 {len(data.get('entry_points', []))} 个入口点" - ) - - return { - "tech_stack": data.get("tech_stack", {}), - "entry_points": data.get("entry_points", []), - "high_risk_areas": data.get("high_risk_areas", []), - "dependencies": data.get("dependencies", {}), - "current_phase": "recon_complete", - "findings": data.get("initial_findings", []), - # 🔥 保存交接信息 - "recon_handoff": handoff.to_dict(), - "events": [{ - "type": "recon_complete", - "data": { - "entry_points_count": len(data.get("entry_points", [])), - "high_risk_areas_count": len(data.get("high_risk_areas", [])), - "handoff_summary": handoff.summary, - } - }], - } - else: - return { - "error": result.error or "Recon failed", - "current_phase": "error", - } - - except Exception as e: - logger.error(f"Recon node failed: {e}", exc_info=True) - return { - "error": str(e), - "current_phase": "error", - } - - -class AnalysisNode(BaseNode): - """ - 漏洞分析节点 - - 输入: tech_stack, entry_points, high_risk_areas, recon_handoff - 输出: findings (累加), should_continue_analysis, analysis_handoff - """ - - async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]: - """执行漏洞分析""" - iteration = state.get("iteration", 0) + 1 - - await self.emit_event( - "phase_start", - f"🔬 开始漏洞分析阶段 (迭代 {iteration})" - ) - - try: - # 🔥 提取 Recon 的交接信息 - recon_handoff = self._extract_handoff_from_state(state, "recon") - if recon_handoff: - self.agent.receive_handoff(recon_handoff) - await self.emit_event( - "handoff_received", - f"📨 收到 Recon Agent 交接: {recon_handoff.summary[:50]}..." - ) - - # 构建分析输入 - analysis_input = { - "phase_name": "analysis", - "project_info": state["project_info"], - "config": state["config"], - "plan": { - "high_risk_areas": state.get("high_risk_areas", []), - }, - "previous_results": { - "recon": { - "data": { - "tech_stack": state.get("tech_stack", {}), - "entry_points": state.get("entry_points", []), - "high_risk_areas": state.get("high_risk_areas", []), - } - } - }, - # 🔥 传递交接信息 - "handoff": recon_handoff, - } - - # 调用 Analysis Agent - result = await self.agent.run(analysis_input) - - if result.success and result.data: - new_findings = result.data.get("findings", []) - logger.info(f"[AnalysisNode] Agent returned {len(new_findings)} findings") - - # 判断是否需要继续分析 - should_continue = ( - len(new_findings) >= 5 and - iteration < state.get("max_iterations", 3) - ) - - # 🔥 创建交接信息给 Verification Agent - # 统计严重程度 - severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0} - for f in new_findings: - if isinstance(f, dict): - sev = f.get("severity", "medium") - severity_counts[sev] = severity_counts.get(sev, 0) + 1 - - handoff = self.agent.create_handoff( - to_agent="Verification", - summary=f"漏洞分析完成。发现 {len(new_findings)} 个潜在漏洞 (Critical: {severity_counts['critical']}, High: {severity_counts['high']}, Medium: {severity_counts['medium']}, Low: {severity_counts['low']})", - key_findings=new_findings[:20], # 传递前20个发现 - suggested_actions=[ - { - "type": "verify_critical", - "description": "优先验证 Critical 和 High 级别的漏洞", - "priority": "critical", - }, - { - "type": "poc_generation", - "description": "为确认的漏洞生成 PoC", - "priority": "high", - }, - ], - attention_points=[ - f"共 {severity_counts['critical']} 个 Critical 级别漏洞需要立即验证", - f"共 {severity_counts['high']} 个 High 级别漏洞需要优先验证", - "注意检查是否有误报,特别是静态分析工具的结果", - ], - priority_areas=[ - f.get("file_path", "") for f in new_findings - if f.get("severity") in ["critical", "high"] - ][:10], - context_data={ - "severity_distribution": severity_counts, - "total_findings": len(new_findings), - "iteration": iteration, - }, - ) - - await self.emit_event( - "phase_complete", - f"✅ 分析迭代 {iteration} 完成: 发现 {len(new_findings)} 个潜在漏洞" - ) - - return { - "findings": new_findings, - "iteration": iteration, - "should_continue_analysis": should_continue, - "current_phase": "analysis_complete", - # 🔥 保存交接信息 - "analysis_handoff": handoff.to_dict(), - "events": [{ - "type": "analysis_iteration", - "data": { - "iteration": iteration, - "findings_count": len(new_findings), - "severity_distribution": severity_counts, - "handoff_summary": handoff.summary, - } - }], - } - else: - return { - "iteration": iteration, - "should_continue_analysis": False, - "current_phase": "analysis_complete", - } - - except Exception as e: - logger.error(f"Analysis node failed: {e}", exc_info=True) - return { - "error": str(e), - "should_continue_analysis": False, - "current_phase": "error", - } - - -class VerificationNode(BaseNode): - """ - 漏洞验证节点 - - 输入: findings, analysis_handoff - 输出: verified_findings, false_positives, verification_handoff - """ - - async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]: - """执行漏洞验证""" - findings = state.get("findings", []) - logger.info(f"[VerificationNode] Received {len(findings)} findings to verify") - - if not findings: - return { - "verified_findings": [], - "false_positives": [], - "current_phase": "verification_complete", - } - - await self.emit_event( - "phase_start", - f"🔐 开始漏洞验证阶段 ({len(findings)} 个待验证)" - ) - - try: - # 🔥 提取 Analysis 的交接信息 - analysis_handoff = self._extract_handoff_from_state(state, "analysis") - if analysis_handoff: - self.agent.receive_handoff(analysis_handoff) - await self.emit_event( - "handoff_received", - f"📨 收到 Analysis Agent 交接: {analysis_handoff.summary[:50]}..." - ) - - # 构建验证输入 - verification_input = { - "previous_results": { - "analysis": { - "data": { - "findings": findings, - } - } - }, - "config": state["config"], - # 🔥 传递交接信息 - "handoff": analysis_handoff, - } - - # 调用 Verification Agent - result = await self.agent.run(verification_input) - - if result.success and result.data: - all_verified_findings = result.data.get("findings", []) - verified = [f for f in all_verified_findings if f.get("is_verified")] - false_pos = [f.get("id", f.get("title", "unknown")) for f in all_verified_findings - if f.get("verdict") == "false_positive"] - - # 🔥 CRITICAL FIX: 用验证结果更新原始 findings - # 创建 findings 的更新映射,基于 (file_path, line_start, vulnerability_type) - verified_map = {} - for vf in all_verified_findings: - key = ( - vf.get("file_path", ""), - vf.get("line_start", 0), - vf.get("vulnerability_type", ""), - ) - verified_map[key] = vf - - # 合并验证结果到原始 findings - updated_findings = [] - seen_keys = set() - - # 首先处理原始 findings,用验证结果更新 - for f in findings: - if not isinstance(f, dict): - continue - key = ( - f.get("file_path", ""), - f.get("line_start", 0), - f.get("vulnerability_type", ""), - ) - if key in verified_map: - # 使用验证后的版本 - updated_findings.append(verified_map[key]) - seen_keys.add(key) - else: - # 保留原始(未验证) - updated_findings.append(f) - seen_keys.add(key) - - # 添加验证结果中的新发现(如果有) - for key, vf in verified_map.items(): - if key not in seen_keys: - updated_findings.append(vf) - - logger.info(f"[VerificationNode] Updated findings: {len(updated_findings)} total, {len(verified)} verified") - - # 🔥 创建交接信息给 Report 节点 - handoff = self.agent.create_handoff( - to_agent="Report", - summary=f"漏洞验证完成。{len(verified)} 个漏洞已确认,{len(false_pos)} 个误报已排除。", - key_findings=verified, - suggested_actions=[ - { - "type": "generate_report", - "description": "生成详细的安全审计报告", - "priority": "high", - }, - { - "type": "remediation_plan", - "description": "为确认的漏洞制定修复计划", - "priority": "high", - }, - ], - attention_points=[ - f"共 {len(verified)} 个漏洞已确认存在", - f"共 {len(false_pos)} 个误报已排除", - "建议按严重程度优先修复 Critical 和 High 级别漏洞", - ], - context_data={ - "verified_count": len(verified), - "false_positive_count": len(false_pos), - "total_analyzed": len(findings), - "verification_rate": len(verified) / len(findings) if findings else 0, - }, - ) - - await self.emit_event( - "phase_complete", - f"✅ 验证完成: {len(verified)} 已确认, {len(false_pos)} 误报" - ) - - return { - # 🔥 CRITICAL: 返回更新后的 findings,这会替换状态中的 findings - # 注意:由于 LangGraph 使用 operator.add,我们需要在 runner 中处理合并 - # 这里我们返回 _verified_findings_update 作为特殊字段 - "_verified_findings_update": updated_findings, - "verified_findings": verified, - "false_positives": false_pos, - "current_phase": "verification_complete", - # 🔥 保存交接信息 - "verification_handoff": handoff.to_dict(), - "events": [{ - "type": "verification_complete", - "data": { - "verified_count": len(verified), - "false_positive_count": len(false_pos), - "total_findings": len(updated_findings), - "handoff_summary": handoff.summary, - } - }], - } - else: - return { - "verified_findings": [], - "false_positives": [], - "current_phase": "verification_complete", - } - - except Exception as e: - logger.error(f"Verification node failed: {e}", exc_info=True) - return { - "error": str(e), - "current_phase": "error", - } - - -class ReportNode(BaseNode): - """ - 报告生成节点 - - 输入: all state - 输出: summary, security_score - """ - - async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]: - """生成审计报告""" - await self.emit_event("phase_start", "📊 生成审计报告") - - try: - # 🔥 CRITICAL FIX: 优先使用验证后的 findings 更新 - findings = state.get("_verified_findings_update") or state.get("findings", []) - verified = state.get("verified_findings", []) - false_positives = state.get("false_positives", []) - - logger.info(f"[ReportNode] State contains {len(findings)} findings, {len(verified)} verified") - - # 统计漏洞分布 - severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0} - type_counts = {} - - for finding in findings: - # 跳过非字典类型的 finding(防止数据格式异常) - if not isinstance(finding, dict): - logger.warning(f"Skipping invalid finding (not a dict): {type(finding)}") - continue - - sev = finding.get("severity", "medium") - severity_counts[sev] = severity_counts.get(sev, 0) + 1 - - vtype = finding.get("vulnerability_type", "other") - type_counts[vtype] = type_counts.get(vtype, 0) + 1 - - # 计算安全评分 - base_score = 100 - deductions = ( - severity_counts["critical"] * 25 + - severity_counts["high"] * 15 + - severity_counts["medium"] * 8 + - severity_counts["low"] * 3 - ) - security_score = max(0, base_score - deductions) - - # 生成摘要 - summary = { - "total_findings": len(findings), - "verified_count": len(verified), - "false_positive_count": len(false_positives), - "severity_distribution": severity_counts, - "vulnerability_types": type_counts, - "tech_stack": state.get("tech_stack", {}), - "entry_points_analyzed": len(state.get("entry_points", [])), - "high_risk_areas": state.get("high_risk_areas", []), - "iterations": state.get("iteration", 1), - } - - await self.emit_event( - "phase_complete", - f"报告生成完成: 安全评分 {security_score}/100" - ) - - return { - "summary": summary, - "security_score": security_score, - "current_phase": "complete", - "events": [{ - "type": "audit_complete", - "data": { - "security_score": security_score, - "total_findings": len(findings), - "verified_count": len(verified), - } - }], - } - - except Exception as e: - logger.error(f"Report node failed: {e}", exc_info=True) - return { - "error": str(e), - "current_phase": "error", - } - - -class HumanReviewNode(BaseNode): - """ - 人工审核节点 - - 在此节点暂停,等待人工反馈 - """ - - async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]: - """ - 人工审核节点 - - 这个节点会被 interrupt_before 暂停 - 用户可以: - 1. 确认发现 - 2. 标记误报 - 3. 请求重新分析 - """ - await self.emit_event( - "human_review", - f"⏸️ 等待人工审核 ({len(state.get('verified_findings', []))} 个待确认)" - ) - - # 返回当前状态,不做修改 - # 人工反馈会通过 update_state 传入 - return { - "current_phase": "human_review", - "messages": [{ - "role": "system", - "content": "等待人工审核", - "findings_for_review": state.get("verified_findings", []), - }], - } - diff --git a/backend/app/services/agent/graph/runner.py b/backend/app/services/agent/graph/runner.py deleted file mode 100644 index e0afd12..0000000 --- a/backend/app/services/agent/graph/runner.py +++ /dev/null @@ -1,1042 +0,0 @@ -""" -DeepAudit LangGraph Runner -基于 LangGraph 的 Agent 审计执行器 -""" - -import asyncio -import logging -import os -import uuid -from datetime import datetime, timezone -from typing import Dict, List, Optional, Any, AsyncGenerator - -from sqlalchemy.ext.asyncio import AsyncSession - -from langgraph.graph import StateGraph, END -from langgraph.checkpoint.memory import MemorySaver - -from app.services.agent.streaming import StreamHandler, StreamEvent, StreamEventType -from app.models.agent_task import ( - AgentTask, AgentEvent, AgentFinding, - AgentTaskStatus, AgentTaskPhase, AgentEventType, - VulnerabilitySeverity, VulnerabilityType, FindingStatus, -) -from app.services.agent.event_manager import EventManager, AgentEventEmitter -from app.services.agent.tools import ( - RAGQueryTool, SecurityCodeSearchTool, FunctionContextTool, - PatternMatchTool, CodeAnalysisTool, DataFlowAnalysisTool, VulnerabilityValidationTool, - FileReadTool, FileSearchTool, ListFilesTool, - SandboxTool, SandboxHttpTool, VulnerabilityVerifyTool, SandboxManager, - SemgrepTool, BanditTool, GitleaksTool, NpmAuditTool, SafetyTool, - TruffleHogTool, OSVScannerTool, -) -from app.services.rag import CodeIndexer, CodeRetriever, EmbeddingService -from app.core.config import settings - -from .audit_graph import AuditState, create_audit_graph -from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode - -logger = logging.getLogger(__name__) - - -# 🔥 使用系统统一的 LLMService(支持用户配置) -from app.services.llm.service import LLMService - - -class AgentRunner: - """ - DeepAudit LangGraph Agent Runner - - 基于 LangGraph 状态图的审计执行器 - - 工作流: - START → Recon → Analysis ⟲ → Verification → Report → END - """ - - def __init__( - self, - db: AsyncSession, - task: AgentTask, - project_root: str, - user_config: Optional[Dict[str, Any]] = None, - ): - self.db = db - self.task = task - self.project_root = project_root - - # 🔥 保存用户配置,供 RAG 初始化使用 - self.user_config = user_config or {} - - # 事件管理 - 传入 db_session_factory 以持久化事件 - from app.db.session import async_session_factory - self.event_manager = EventManager(db_session_factory=async_session_factory) - self.event_emitter = AgentEventEmitter(task.id, self.event_manager) - - # 🔥 CRITICAL: 立即创建事件队列,确保在 Agent 开始执行前队列就存在 - # 这样即使前端 SSE 连接稍晚,token 事件也不会丢失 - self.event_manager.create_queue(task.id) - - # 🔥 LLM 服务 - 使用用户配置(从系统配置页面获取) - self.llm_service = LLMService(user_config=self.user_config) - - # 工具集 - self.tools: Dict[str, Any] = {} - - # RAG 组件 - self.retriever: Optional[CodeRetriever] = None - self.indexer: Optional[CodeIndexer] = None - - # 沙箱 - self.sandbox_manager: Optional[SandboxManager] = None - - # LangGraph - self.graph: Optional[StateGraph] = None - self.checkpointer = MemorySaver() - - # 状态 - self._cancelled = False - self._running_task: Optional[asyncio.Task] = None - - # Agent 引用(用于取消传播) - self._agents: List[Any] = [] - - # 流式处理器 - self.stream_handler = StreamHandler(task.id) - - def cancel(self): - """取消任务""" - self._cancelled = True - - # 🔥 取消所有 Agent - for agent in self._agents: - if hasattr(agent, 'cancel'): - agent.cancel() - logger.debug(f"Cancelled agent: {agent.name if hasattr(agent, 'name') else 'unknown'}") - - # 取消运行中的任务 - if self._running_task and not self._running_task.done(): - self._running_task.cancel() - - logger.info(f"Task {self.task.id} cancellation requested") - - @property - def is_cancelled(self) -> bool: - """检查是否已取消""" - return self._cancelled - - async def initialize(self): - """初始化 Runner""" - await self.event_emitter.emit_info("🚀 正在初始化 DeepAudit LangGraph Agent...") - - # 1. 初始化 RAG 系统 - await self._initialize_rag() - - # 2. 初始化工具 - await self._initialize_tools() - - # 3. 构建 LangGraph - await self._build_graph() - - await self.event_emitter.emit_info("✅ LangGraph 系统初始化完成") - - async def _initialize_rag(self): - """初始化 RAG 系统""" - await self.event_emitter.emit_info("📚 初始化 RAG 代码检索系统...") - - try: - # 🔥 从用户配置中获取配置 - # 优先级:用户嵌入配置 > 用户 LLM 配置 > 环境变量 - user_llm_config = self.user_config.get('llmConfig', {}) - user_other_config = self.user_config.get('otherConfig', {}) - user_embedding_config = user_other_config.get('embedding_config', {}) - - # 🔥 Embedding Provider 优先级:用户嵌入配置 > 环境变量 - embedding_provider = ( - user_embedding_config.get('provider') or - getattr(settings, 'EMBEDDING_PROVIDER', 'openai') - ) - - # 🔥 Embedding Model 优先级:用户嵌入配置 > 环境变量 - embedding_model = ( - user_embedding_config.get('model') or - getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small') - ) - - # 🔥 API Key 优先级:用户嵌入配置 > 用户 LLM 配置 > 环境变量 - embedding_api_key = ( - user_embedding_config.get('api_key') or - user_llm_config.get('llmApiKey') or - getattr(settings, 'LLM_API_KEY', '') or - '' - ) - - # 🔥 Base URL 优先级:用户嵌入配置 > 用户 LLM 配置 > 环境变量 - embedding_base_url = ( - user_embedding_config.get('base_url') or - user_llm_config.get('llmBaseUrl') or - getattr(settings, 'LLM_BASE_URL', None) or - None - ) - - logger.info(f"RAG 配置: provider={embedding_provider}, model={embedding_model}") - await self.event_emitter.emit_info(f"嵌入模型: {embedding_provider}/{embedding_model}") - - embedding_service = EmbeddingService( - provider=embedding_provider, - model=embedding_model, - api_key=embedding_api_key, - base_url=embedding_base_url, - ) - - self.indexer = CodeIndexer( - collection_name=f"project_{self.task.project_id}", - embedding_service=embedding_service, - persist_directory=settings.VECTOR_DB_PATH, - ) - - self.retriever = CodeRetriever( - collection_name=f"project_{self.task.project_id}", - embedding_service=embedding_service, - persist_directory=settings.VECTOR_DB_PATH, - ) - - except Exception as e: - logger.warning(f"RAG initialization failed: {e}") - await self.event_emitter.emit_warning(f"RAG 系统初始化失败: {e}") - - async def _initialize_tools(self): - """初始化工具集""" - await self.event_emitter.emit_info("初始化 Agent 工具集...") - - # 🔥 导入新工具 - from app.services.agent.tools import ( - ThinkTool, ReflectTool, - CreateVulnerabilityReportTool, - # 多语言代码测试工具 - PhpTestTool, PythonTestTool, JavaScriptTestTool, JavaTestTool, - GoTestTool, RubyTestTool, ShellTestTool, UniversalCodeTestTool, - # 漏洞验证专用工具 - CommandInjectionTestTool, SqlInjectionTestTool, XssTestTool, - PathTraversalTestTool, SstiTestTool, DeserializationTestTool, - UniversalVulnTestTool, - # Kunlun-M 静态代码分析工具 (MIT License) - KunlunMTool, KunlunRuleListTool, KunlunPluginTool, - ) - # 🔥 导入知识查询工具 - from app.services.agent.knowledge import ( - SecurityKnowledgeQueryTool, - GetVulnerabilityKnowledgeTool, - ) - - # 🔥 获取排除模式和目标文件 - exclude_patterns = self.task.exclude_patterns or [] - target_files = self.task.target_files or None - - # ============ 🔥 提前初始化 SandboxManager(供所有外部工具共享)============ - self.sandbox_manager = None - try: - from app.services.agent.tools.sandbox_tool import SandboxConfig - sandbox_config = SandboxConfig( - image=settings.SANDBOX_IMAGE, - memory_limit=settings.SANDBOX_MEMORY_LIMIT, - cpu_limit=settings.SANDBOX_CPU_LIMIT, - timeout=settings.SANDBOX_TIMEOUT, - network_mode=settings.SANDBOX_NETWORK_MODE, - ) - self.sandbox_manager = SandboxManager(config=sandbox_config) - # 🔥 必须调用 initialize() 来连接 Docker - await self.sandbox_manager.initialize() - logger.info(f"✅ SandboxManager initialized early (Docker available: {self.sandbox_manager.is_available})") - except Exception as e: - logger.warning(f"❌ Early Sandbox Manager initialization failed: {e}") - import traceback - logger.warning(f"Traceback: {traceback.format_exc()}") - # 尝试创建默认管理器作为后备 - try: - self.sandbox_manager = SandboxManager() - await self.sandbox_manager.initialize() - logger.info(f"⚠️ Created fallback SandboxManager (Docker available: {self.sandbox_manager.is_available})") - except Exception as e2: - logger.error(f"❌ Failed to create fallback SandboxManager: {e2}") - - # ============ 基础工具(所有 Agent 共享)============ - base_tools = { - "read_file": FileReadTool(self.project_root, exclude_patterns, target_files), - "list_files": ListFilesTool(self.project_root, exclude_patterns, target_files), - # 🔥 新增:思考工具(所有Agent可用) - "think": ThinkTool(), - } - - # ============ Recon Agent 专属工具 ============ - # 职责:信息收集、项目结构分析、技术栈识别 - # 🔥 新增:外部工具也可用于Recon阶段的快速扫描 - self.recon_tools = { - **base_tools, - "search_code": FileSearchTool(self.project_root, exclude_patterns, target_files), - # 🔥 新增:反思工具 - "reflect": ReflectTool(), - # 🔥 外部安全工具(共享 SandboxManager 实例) - "semgrep_scan": SemgrepTool(self.project_root, self.sandbox_manager), - "bandit_scan": BanditTool(self.project_root, self.sandbox_manager), - "gitleaks_scan": GitleaksTool(self.project_root, self.sandbox_manager), - "safety_scan": SafetyTool(self.project_root, self.sandbox_manager), - "npm_audit": NpmAuditTool(self.project_root, self.sandbox_manager), - } - - # RAG 工具(Recon 用于语义搜索) - if self.retriever: - self.recon_tools["rag_query"] = RAGQueryTool(self.retriever) - logger.info("✅ RAG 工具已注册到 Recon Agent") - else: - logger.warning("⚠️ RAG 未初始化,rag_query 工具不可用") - - # ============ Analysis Agent 专属工具 ============ - # 职责:漏洞分析、代码审计、模式匹配 - self.analysis_tools = { - **base_tools, - "search_code": FileSearchTool(self.project_root, exclude_patterns, target_files), - # 模式匹配和代码分析 - "pattern_match": PatternMatchTool(self.project_root), - # TODO: code_analysis 工具暂时禁用,因为 LLM 调用经常失败 - # "code_analysis": CodeAnalysisTool(self.llm_service), - "dataflow_analysis": DataFlowAnalysisTool(self.llm_service), - # 🔥 外部静态分析工具(共享 SandboxManager 实例) - "semgrep_scan": SemgrepTool(self.project_root, self.sandbox_manager), - "bandit_scan": BanditTool(self.project_root, self.sandbox_manager), - "gitleaks_scan": GitleaksTool(self.project_root, self.sandbox_manager), - "trufflehog_scan": TruffleHogTool(self.project_root, self.sandbox_manager), - "npm_audit": NpmAuditTool(self.project_root, self.sandbox_manager), - "safety_scan": SafetyTool(self.project_root, self.sandbox_manager), - "osv_scan": OSVScannerTool(self.project_root, self.sandbox_manager), - # 🔥 Kunlun-M 静态代码分析工具 (MIT License - https://github.com/LoRexxar/Kunlun-M) - "kunlun_scan": KunlunMTool(self.project_root), - "kunlun_list_rules": KunlunRuleListTool(self.project_root), - "kunlun_plugin": KunlunPluginTool(self.project_root), - # 🔥 新增:反思工具 - "reflect": ReflectTool(), - # 🔥 新增:安全知识查询工具(基于RAG) - "query_security_knowledge": SecurityKnowledgeQueryTool(), - "get_vulnerability_knowledge": GetVulnerabilityKnowledgeTool(), - } - - # RAG 工具(Analysis 用于安全相关代码搜索) - if self.retriever: - self.analysis_tools["rag_query"] = RAGQueryTool(self.retriever) # 通用语义搜索 - self.analysis_tools["security_search"] = SecurityCodeSearchTool(self.retriever) # 安全代码搜索 - self.analysis_tools["function_context"] = FunctionContextTool(self.retriever) # 函数上下文 - logger.info("✅ RAG 工具已注册到 Analysis Agent (rag_query, security_search, function_context)") - - # ============ Verification Agent 专属工具 ============ - # 职责:漏洞验证、PoC 执行、误报排除 - self.verification_tools = { - **base_tools, - # 验证工具 - 移除旧的 vulnerability_validation 和 dataflow_analysis,强制使用沙箱 - # 🔥 新增:漏洞报告工具(仅Verification可用)- v2.1: 传递 project_root - "create_vulnerability_report": CreateVulnerabilityReportTool(self.project_root), - # 🔥 新增:反思工具 - "reflect": ReflectTool(), - } - - # 🔥 注册沙箱工具(使用提前初始化的 SandboxManager) - if self.sandbox_manager: - # 🔥 沙箱核心工具 - self.verification_tools["sandbox_exec"] = SandboxTool(self.sandbox_manager) - self.verification_tools["sandbox_http"] = SandboxHttpTool(self.sandbox_manager) - self.verification_tools["verify_vulnerability"] = VulnerabilityVerifyTool(self.sandbox_manager) - - # 🔥 多语言代码测试工具 - self.verification_tools["php_test"] = PhpTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["python_test"] = PythonTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["javascript_test"] = JavaScriptTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["java_test"] = JavaTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["go_test"] = GoTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["ruby_test"] = RubyTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["shell_test"] = ShellTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["universal_code_test"] = UniversalCodeTestTool(self.sandbox_manager, self.project_root) - - # 🔥 漏洞验证专用工具 - self.verification_tools["test_command_injection"] = CommandInjectionTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["test_sql_injection"] = SqlInjectionTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["test_xss"] = XssTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["test_path_traversal"] = PathTraversalTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["test_ssti"] = SstiTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["test_deserialization"] = DeserializationTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["universal_vuln_test"] = UniversalVulnTestTool(self.sandbox_manager, self.project_root) - - logger.info(f"✅ Sandbox tools initialized (Docker available: {self.sandbox_manager.is_available})") - else: - logger.error("❌ Sandbox tools NOT initialized due to critical manager failure") - - logger.info(f"✅ Verification tools: {list(self.verification_tools.keys())}") - - # 统计总工具数 - total_tools = len(set( - list(self.recon_tools.keys()) + - list(self.analysis_tools.keys()) + - list(self.verification_tools.keys()) - )) - await self.event_emitter.emit_info(f"已加载 {total_tools} 个工具") - - async def _build_graph(self): - """构建 LangGraph 审计图""" - await self.event_emitter.emit_info("📊 构建 LangGraph 审计工作流...") - - # 导入 Agent - from app.services.agent.agents import ReconAgent, AnalysisAgent, VerificationAgent - - # 创建 Agent 实例(每个 Agent 使用专属工具集) - recon_agent = ReconAgent( - llm_service=self.llm_service, - tools=self.recon_tools, # Recon 专属工具 - event_emitter=self.event_emitter, - ) - - analysis_agent = AnalysisAgent( - llm_service=self.llm_service, - tools=self.analysis_tools, # Analysis 专属工具 - event_emitter=self.event_emitter, - ) - - verification_agent = VerificationAgent( - llm_service=self.llm_service, - tools=self.verification_tools, # Verification 专属工具 - event_emitter=self.event_emitter, - ) - - # 🔥 保存 Agent 引用以便取消时传播信号 - self._agents = [recon_agent, analysis_agent, verification_agent] - - # 创建节点 - recon_node = ReconNode(recon_agent, self.event_emitter) - analysis_node = AnalysisNode(analysis_agent, self.event_emitter) - verification_node = VerificationNode(verification_agent, self.event_emitter) - report_node = ReportNode(None, self.event_emitter) - - # 构建图 - self.graph = create_audit_graph( - recon_node=recon_node, - analysis_node=analysis_node, - verification_node=verification_node, - report_node=report_node, - checkpointer=self.checkpointer, - ) - - await self.event_emitter.emit_info("✅ LangGraph 工作流构建完成") - - async def run(self) -> Dict[str, Any]: - """ - 执行 LangGraph 审计 - - Returns: - 最终状态 - """ - final_state = {} - try: - async for event in self.run_with_streaming(): - # 收集最终状态 - if event.event_type == StreamEventType.TASK_COMPLETE: - final_state = event.data - elif event.event_type == StreamEventType.TASK_ERROR: - final_state = {"success": False, "error": event.data.get("error")} - except Exception as e: - logger.error(f"Agent run failed: {e}", exc_info=True) - final_state = {"success": False, "error": str(e)} - - return final_state - - async def run_with_streaming(self) -> AsyncGenerator[StreamEvent, None]: - """ - 带流式输出的审计执行 - - Yields: - StreamEvent: 流式事件(包含 LLM 思考、工具调用等) - """ - import time - start_time = time.time() - - try: - # 初始化 - await self.initialize() - - # 更新任务状态 - await self._update_task_status(AgentTaskStatus.RUNNING) - - # 发射任务开始事件 - yield StreamEvent( - event_type=StreamEventType.TASK_START, - sequence=self.stream_handler._next_sequence(), - data={"task_id": self.task.id, "message": "🚀 审计任务开始"}, - ) - - # 1. 索引代码 - await self._index_code() - - if self._cancelled: - yield StreamEvent( - event_type=StreamEventType.TASK_CANCEL, - sequence=self.stream_handler._next_sequence(), - data={"message": "任务已取消"}, - ) - return - - # 2. 收集项目信息 - project_info = await self._collect_project_info() - - # 3. 构建初始状态 - task_config = { - "target_vulnerabilities": self.task.target_vulnerabilities or [], - "verification_level": self.task.verification_level or "sandbox", - "exclude_patterns": self.task.exclude_patterns or [], - "target_files": self.task.target_files or [], - "max_iterations": self.task.max_iterations or 50, - "timeout_seconds": self.task.timeout_seconds or 1800, - } - - initial_state: AuditState = { - "project_root": self.project_root, - "project_info": project_info, - "config": task_config, - "task_id": self.task.id, - "tech_stack": {}, - "entry_points": [], - "high_risk_areas": [], - "dependencies": {}, - "findings": [], - "verified_findings": [], - "false_positives": [], - "_verified_findings_update": None, # 🔥 NEW: 验证后的 findings 更新 - "current_phase": "start", - "iteration": 0, - "max_iterations": self.task.max_iterations or 50, - "should_continue_analysis": False, - # 🔥 Agent 协作交接信息 - "recon_handoff": None, - "analysis_handoff": None, - "verification_handoff": None, - "messages": [], - "events": [], - "summary": None, - "security_score": None, - "error": None, - } - - # 4. 执行 LangGraph with astream_events - await self.event_emitter.emit_phase_start("langgraph", "🔄 启动 LangGraph 工作流") - - run_config = { - "configurable": { - "thread_id": self.task.id, - } - } - - final_state = None - - # 使用 astream_events 获取详细事件流 - try: - async for event in self.graph.astream_events( - initial_state, - config=run_config, - version="v2", - ): - if self._cancelled: - break - - # 处理 LangGraph 事件 - stream_event = await self.stream_handler.process_langgraph_event(event) - if stream_event: - # 同步到 event_emitter 以持久化 - await self._sync_stream_event_to_db(stream_event) - yield stream_event - - # 更新最终状态 - if event.get("event") == "on_chain_end": - output = event.get("data", {}).get("output") - if isinstance(output, dict): - final_state = output - - except Exception as e: - # 如果 astream_events 不可用,回退到 astream - logger.warning(f"astream_events not available, falling back to astream: {e}") - async for event in self.graph.astream(initial_state, config=run_config): - if self._cancelled: - break - - for node_name, node_output in event.items(): - await self._handle_node_output(node_name, node_output) - - # 发射节点事件 - yield StreamEvent( - event_type=StreamEventType.NODE_END, - sequence=self.stream_handler._next_sequence(), - node_name=node_name, - data={"message": f"节点 {node_name} 完成"}, - ) - - phase_map = { - "recon": AgentTaskPhase.RECONNAISSANCE, - "analysis": AgentTaskPhase.ANALYSIS, - "verification": AgentTaskPhase.VERIFICATION, - "report": AgentTaskPhase.REPORTING, - } - if node_name in phase_map: - await self._update_task_phase(phase_map[node_name]) - - final_state = node_output - - # 5. 获取最终状态 - # 🔥 CRITICAL FIX: 始终从 graph 获取完整的累积状态 - # 因为每个节点只返回自己的输出,findings 等字段是通过 operator.add 累积的 - # 直接使用 node_output 会丢失之前节点累积的 findings - graph_state = self.graph.get_state(run_config) - if graph_state and graph_state.values: - # 合并完整状态和最后节点的输出 - full_state = graph_state.values - if final_state: - # 保留最后节点的输出(如 summary, security_score) - full_state = {**full_state, **final_state} - final_state = full_state - logger.info(f"[Runner] Got full state from graph with {len(final_state.get('findings', []))} findings") - elif not final_state: - final_state = {} - logger.warning("[Runner] No final state available from graph") - - # 🔥 CRITICAL FIX: 如果有验证后的 findings 更新,使用它替换原始 findings - # 这是因为 LangGraph 的 operator.add 累积器不适合更新已有 findings - verified_findings_update = final_state.get("_verified_findings_update") - if verified_findings_update: - logger.info(f"[Runner] Using verified findings update: {len(verified_findings_update)} findings") - final_state["findings"] = verified_findings_update - else: - # 🔥 FALLBACK: 如果没有 _verified_findings_update,尝试从 verified_findings 合并 - findings = final_state.get("findings", []) - verified_findings = final_state.get("verified_findings", []) - - if verified_findings and findings: - # 创建合并后的 findings 列表 - merged_findings = self._merge_findings_with_verification(findings, verified_findings) - final_state["findings"] = merged_findings - logger.info(f"[Runner] Merged findings: {len(merged_findings)} total") - elif verified_findings and not findings: - # 如果只有 verified_findings,直接使用 - final_state["findings"] = verified_findings - logger.info(f"[Runner] Using verified_findings directly: {len(verified_findings)}") - - logger.info(f"[Runner] Final findings count: {len(final_state.get('findings', []))}") - - # 🔥 检查是否有错误 - error = final_state.get("error") - if error: - # 检查是否是 LLM 认证错误 - error_str = str(error) - if "AuthenticationError" in error_str or "API key" in error_str or "invalid_api_key" in error_str: - error_message = "LLM API 密钥配置错误。请检查环境变量 LLM_API_KEY 或配置中的 API 密钥是否正确。" - logger.error(f"LLM authentication error: {error}") - else: - error_message = error_str - - duration_ms = int((time.time() - start_time) * 1000) - - # 标记任务为失败 - await self._update_task_status(AgentTaskStatus.FAILED, error_message) - await self.event_emitter.emit_task_error(error_message) - - yield StreamEvent( - event_type=StreamEventType.TASK_ERROR, - sequence=self.stream_handler._next_sequence(), - data={ - "error": error_message, - "message": f"❌ 任务失败: {error_message}", - }, - ) - return - - # 6. 保存发现 - findings = final_state.get("findings", []) - await self._save_findings(findings) - - # 发射发现事件 - for finding in findings[:10]: # 限制数量 - yield self.stream_handler.create_finding_event( - finding, - is_verified=finding.get("is_verified", False), - ) - - # 7. 更新任务摘要 - summary = final_state.get("summary", {}) - security_score = final_state.get("security_score", 100) - - await self._update_task_summary( - total_findings=len(findings), - verified_count=len(final_state.get("verified_findings", [])), - security_score=security_score, - ) - - # 8. 完成 - duration_ms = int((time.time() - start_time) * 1000) - - await self._update_task_status(AgentTaskStatus.COMPLETED) - await self.event_emitter.emit_task_complete( - findings_count=len(findings), - duration_ms=duration_ms, - ) - - yield StreamEvent( - event_type=StreamEventType.TASK_COMPLETE, - sequence=self.stream_handler._next_sequence(), - data={ - "findings_count": len(findings), - "verified_count": len(final_state.get("verified_findings", [])), - "security_score": security_score, - "duration_ms": duration_ms, - "message": f"✅ 审计完成!发现 {len(findings)} 个漏洞", - }, - ) - - except asyncio.CancelledError: - await self._update_task_status(AgentTaskStatus.CANCELLED) - yield StreamEvent( - event_type=StreamEventType.TASK_CANCEL, - sequence=self.stream_handler._next_sequence(), - data={"message": "任务已取消"}, - ) - - except Exception as e: - logger.error(f"LangGraph run failed: {e}", exc_info=True) - await self._update_task_status(AgentTaskStatus.FAILED, str(e)) - await self.event_emitter.emit_error(str(e)) - - yield StreamEvent( - event_type=StreamEventType.TASK_ERROR, - sequence=self.stream_handler._next_sequence(), - data={"error": str(e), "message": f"❌ 审计失败: {e}"}, - ) - - finally: - await self._cleanup() - - async def _sync_stream_event_to_db(self, event: StreamEvent): - """同步流式事件到数据库""" - try: - # 将 StreamEvent 转换为 AgentEventData - await self.event_manager.add_event( - task_id=self.task.id, - event_type=event.event_type.value, - sequence=event.sequence, - phase=event.phase, - message=event.data.get("message"), - tool_name=event.tool_name, - tool_input=event.data.get("input") or event.data.get("input_params"), - tool_output=event.data.get("output") or event.data.get("output_data"), - tool_duration_ms=event.data.get("duration_ms"), - metadata=event.data, - ) - except Exception as e: - logger.warning(f"Failed to sync stream event to db: {e}") - - async def _handle_node_output(self, node_name: str, output: Dict[str, Any]): - """处理节点输出""" - # 发射节点事件 - events = output.get("events", []) - for evt in events: - await self.event_emitter.emit_info( - f"[{node_name}] {evt.get('type', 'event')}: {evt.get('data', {})}" - ) - - # 处理新发现 - if node_name == "analysis": - new_findings = output.get("findings", []) - if new_findings: - for finding in new_findings[:5]: # 限制事件数量 - await self.event_emitter.emit_finding( - title=finding.get("title", "Unknown"), - severity=finding.get("severity", "medium"), - file_path=finding.get("file_path"), - ) - - # 处理验证结果 - if node_name == "verification": - verified = output.get("verified_findings", []) - for v in verified[:5]: - await self.event_emitter.emit_info( - f"✅ 已验证: {v.get('title', 'Unknown')}" - ) - - # 处理错误 - if output.get("error"): - await self.event_emitter.emit_error(output["error"]) - - async def _index_code(self): - """索引代码""" - if not self.indexer: - await self.event_emitter.emit_warning("RAG 未初始化,跳过代码索引") - return - - await self._update_task_phase(AgentTaskPhase.INDEXING) - await self.event_emitter.emit_phase_start("indexing", "📝 开始代码索引") - - try: - async for progress in self.indexer.index_directory(self.project_root): - if self._cancelled: - return - - await self.event_emitter.emit_progress( - progress.processed_files, - progress.total_files, - f"正在索引: {progress.current_file or 'N/A'}" - ) - - await self.event_emitter.emit_phase_complete("indexing", "✅ 代码索引完成") - - except Exception as e: - logger.warning(f"Code indexing failed: {e}") - await self.event_emitter.emit_warning(f"代码索引失败: {e}") - - async def _collect_project_info(self) -> Dict[str, Any]: - """收集项目信息""" - info = { - "name": self.task.project.name if self.task.project else "unknown", - "root": self.project_root, - "languages": [], - "file_count": 0, - } - - try: - exclude_dirs = { - "node_modules", "__pycache__", ".git", "venv", ".venv", - "build", "dist", "target", ".idea", ".vscode", - } - - for root, dirs, files in os.walk(self.project_root): - dirs[:] = [d for d in dirs if d not in exclude_dirs] - info["file_count"] += len(files) - - lang_map = { - ".py": "Python", ".js": "JavaScript", ".ts": "TypeScript", - ".java": "Java", ".go": "Go", ".php": "PHP", - ".rb": "Ruby", ".rs": "Rust", ".c": "C", ".cpp": "C++", - } - - for f in files: - ext = os.path.splitext(f)[1].lower() - if ext in lang_map and lang_map[ext] not in info["languages"]: - info["languages"].append(lang_map[ext]) - - except Exception as e: - logger.warning(f"Failed to collect project info: {e}") - - return info - - async def _save_findings(self, findings: List[Dict]): - """保存发现到数据库""" - logger.info(f"[Runner] Saving {len(findings)} findings to database for task {self.task.id}") - - if not findings: - logger.info("[Runner] No findings to save") - return - - severity_map = { - "critical": VulnerabilitySeverity.CRITICAL, - "high": VulnerabilitySeverity.HIGH, - "medium": VulnerabilitySeverity.MEDIUM, - "low": VulnerabilitySeverity.LOW, - "info": VulnerabilitySeverity.INFO, - } - - type_map = { - "sql_injection": VulnerabilityType.SQL_INJECTION, - "nosql_injection": VulnerabilityType.NOSQL_INJECTION, - "xss": VulnerabilityType.XSS, - "command_injection": VulnerabilityType.COMMAND_INJECTION, - "code_injection": VulnerabilityType.CODE_INJECTION, - "path_traversal": VulnerabilityType.PATH_TRAVERSAL, - "file_inclusion": VulnerabilityType.FILE_INCLUSION, - "ssrf": VulnerabilityType.SSRF, - "xxe": VulnerabilityType.XXE, - "deserialization": VulnerabilityType.DESERIALIZATION, - "auth_bypass": VulnerabilityType.AUTH_BYPASS, - "idor": VulnerabilityType.IDOR, - "sensitive_data_exposure": VulnerabilityType.SENSITIVE_DATA_EXPOSURE, - "hardcoded_secret": VulnerabilityType.HARDCODED_SECRET, - "weak_crypto": VulnerabilityType.WEAK_CRYPTO, - "race_condition": VulnerabilityType.RACE_CONDITION, - "business_logic": VulnerabilityType.BUSINESS_LOGIC, - "memory_corruption": VulnerabilityType.MEMORY_CORRUPTION, - } - - for finding in findings: - try: - # 确保 finding 是字典 - if not isinstance(finding, dict): - logger.warning(f"Skipping invalid finding (not a dict): {finding}") - continue - - db_finding = AgentFinding( - id=str(uuid.uuid4()), - task_id=self.task.id, - vulnerability_type=type_map.get( - finding.get("vulnerability_type", "other"), - VulnerabilityType.OTHER - ), - severity=severity_map.get( - finding.get("severity", "medium"), - VulnerabilitySeverity.MEDIUM - ), - title=finding.get("title", "Unknown"), - description=finding.get("description", ""), - file_path=finding.get("file_path"), - line_start=finding.get("line_start"), - line_end=finding.get("line_end"), - code_snippet=finding.get("code_snippet"), - source=finding.get("source"), - sink=finding.get("sink"), - suggestion=finding.get("suggestion") or finding.get("recommendation"), - is_verified=finding.get("is_verified", False), - confidence=finding.get("confidence", 0.5), - poc=finding.get("poc"), - status=FindingStatus.VERIFIED if finding.get("is_verified") else FindingStatus.NEW, - ) - - self.db.add(db_finding) - - except Exception as e: - logger.warning(f"Failed to save finding: {e}") - - try: - await self.db.commit() - logger.info(f"[Runner] Successfully saved {len(findings)} findings to database") - except Exception as e: - logger.error(f"Failed to commit findings: {e}") - await self.db.rollback() - - async def _update_task_status( - self, - status: AgentTaskStatus, - error: Optional[str] = None - ): - """更新任务状态""" - self.task.status = status - - if status == AgentTaskStatus.RUNNING: - self.task.started_at = datetime.now(timezone.utc) - elif status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]: - self.task.finished_at = datetime.now(timezone.utc) - - if error: - self.task.error_message = error - - try: - await self.db.commit() - except Exception as e: - logger.error(f"Failed to update task status: {e}") - - async def _update_task_phase(self, phase: AgentTaskPhase): - """更新任务阶段""" - self.task.current_phase = phase - try: - await self.db.commit() - except Exception as e: - logger.error(f"Failed to update task phase: {e}") - - async def _update_task_summary( - self, - total_findings: int, - verified_count: int, - security_score: int, - ): - """更新任务摘要""" - self.task.total_findings = total_findings - self.task.verified_findings = verified_count - self.task.security_score = security_score - - try: - await self.db.commit() - except Exception as e: - logger.error(f"Failed to update task summary: {e}") - - def _merge_findings_with_verification( - self, - findings: List[Dict], - verified_findings: List[Dict], - ) -> List[Dict]: - """ - 合并原始 findings 和验证结果 - - Args: - findings: 原始 findings 列表 - verified_findings: 验证后的 findings 列表 - - Returns: - 合并后的 findings 列表 - """ - # 创建验证结果的查找映射 - verified_map = {} - for vf in verified_findings: - if not isinstance(vf, dict): - continue - key = ( - vf.get("file_path", ""), - vf.get("line_start", 0), - vf.get("vulnerability_type", ""), - ) - verified_map[key] = vf - - merged = [] - seen_keys = set() - - # 首先处理原始 findings - for f in findings: - if not isinstance(f, dict): - continue - - key = ( - f.get("file_path", ""), - f.get("line_start", 0), - f.get("vulnerability_type", ""), - ) - - if key in verified_map: - # 使用验证后的版本(包含 is_verified, poc 等) - merged.append(verified_map[key]) - else: - # 保留原始 finding - merged.append(f) - - seen_keys.add(key) - - # 添加验证结果中的新发现(如果有) - for key, vf in verified_map.items(): - if key not in seen_keys: - merged.append(vf) - - return merged - - async def _cleanup(self): - """清理资源""" - try: - if self.sandbox_manager: - await self.sandbox_manager.cleanup() - await self.event_manager.close() - except Exception as e: - logger.warning(f"Cleanup error: {e}") - - -# 便捷函数 -async def run_agent_task( - db: AsyncSession, - task: AgentTask, - project_root: str, -) -> Dict[str, Any]: - """ - 运行 Agent 审计任务 - - Args: - db: 数据库会话 - task: Agent 任务 - project_root: 项目根目录 - - Returns: - 审计结果 - """ - runner = AgentRunner(db, task, project_root) - return await runner.run() - diff --git a/backend/tests/agent/test_integration.py b/backend/tests/agent/test_integration.py deleted file mode 100644 index d314404..0000000 --- a/backend/tests/agent/test_integration.py +++ /dev/null @@ -1,355 +0,0 @@ -""" -Agent 集成测试 -测试完整的审计流程 -""" - -import pytest -import asyncio -import os -from unittest.mock import MagicMock, AsyncMock, patch -from datetime import datetime - -from app.services.agent.graph.runner import AgentRunner, LLMService -from app.services.agent.graph.audit_graph import AuditState, create_audit_graph -from app.services.agent.graph.nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode -from app.services.agent.event_manager import EventManager, AgentEventEmitter - - -class TestLLMService: - """LLM 服务测试""" - - @pytest.mark.asyncio - async def test_llm_service_initialization(self): - """测试 LLM 服务初始化""" - with patch("app.core.config.settings") as mock_settings: - mock_settings.LLM_MODEL = "gpt-4o-mini" - mock_settings.LLM_API_KEY = "test-key" - - service = LLMService() - - assert service.model == "gpt-4o-mini" - - -class TestEventManager: - """事件管理器测试""" - - def test_event_manager_initialization(self): - """测试事件管理器初始化""" - manager = EventManager() - - assert manager._event_queues == {} - assert manager._event_callbacks == {} - - @pytest.mark.asyncio - async def test_event_emitter(self): - """测试事件发射器""" - manager = EventManager() - emitter = AgentEventEmitter("test-task-id", manager) - - await emitter.emit_info("Test message") - - assert emitter._sequence == 1 - - @pytest.mark.asyncio - async def test_event_emitter_phase_tracking(self): - """测试事件发射器阶段跟踪""" - manager = EventManager() - emitter = AgentEventEmitter("test-task-id", manager) - - await emitter.emit_phase_start("recon", "开始信息收集") - - assert emitter._current_phase == "recon" - - @pytest.mark.asyncio - async def test_event_emitter_task_complete(self): - """测试任务完成事件""" - manager = EventManager() - emitter = AgentEventEmitter("test-task-id", manager) - - await emitter.emit_task_complete(findings_count=5, duration_ms=1000) - - assert emitter._sequence == 1 - - -class TestAuditGraph: - """审计图测试""" - - def test_create_audit_graph(self, mock_event_emitter): - """测试创建审计图""" - # 创建模拟节点 - recon_node = MagicMock() - analysis_node = MagicMock() - verification_node = MagicMock() - report_node = MagicMock() - - graph = create_audit_graph( - recon_node=recon_node, - analysis_node=analysis_node, - verification_node=verification_node, - report_node=report_node, - ) - - assert graph is not None - - -class TestReconNode: - """Recon 节点测试""" - - @pytest.fixture - def recon_node_with_mock_agent(self, mock_event_emitter): - """创建带模拟 Agent 的 Recon 节点""" - mock_agent = MagicMock() - mock_agent.run = AsyncMock(return_value=MagicMock( - success=True, - data={ - "tech_stack": {"languages": ["Python"]}, - "entry_points": [{"path": "src/app.py", "type": "api"}], - "high_risk_areas": ["src/sql_vuln.py"], - "dependencies": {}, - "initial_findings": [], - } - )) - - return ReconNode(mock_agent, mock_event_emitter) - - @pytest.mark.asyncio - async def test_recon_node_success(self, recon_node_with_mock_agent): - """测试 Recon 节点成功执行""" - state = { - "project_info": {"name": "Test"}, - "config": {}, - } - - result = await recon_node_with_mock_agent(state) - - assert "tech_stack" in result - assert "entry_points" in result - assert result["current_phase"] == "recon_complete" - - @pytest.mark.asyncio - async def test_recon_node_failure(self, mock_event_emitter): - """测试 Recon 节点失败处理""" - mock_agent = MagicMock() - mock_agent.run = AsyncMock(return_value=MagicMock( - success=False, - error="Test error", - data=None, - )) - - node = ReconNode(mock_agent, mock_event_emitter) - - result = await node({ - "project_info": {}, - "config": {}, - }) - - assert "error" in result - assert result["current_phase"] == "error" - - -class TestAnalysisNode: - """Analysis 节点测试""" - - @pytest.fixture - def analysis_node_with_mock_agent(self, mock_event_emitter): - """创建带模拟 Agent 的 Analysis 节点""" - mock_agent = MagicMock() - mock_agent.run = AsyncMock(return_value=MagicMock( - success=True, - data={ - "findings": [ - { - "id": "finding-1", - "title": "SQL Injection", - "severity": "high", - "vulnerability_type": "sql_injection", - "file_path": "src/sql_vuln.py", - "line_start": 10, - "description": "SQL injection vulnerability", - } - ], - "should_continue": False, - } - )) - - return AnalysisNode(mock_agent, mock_event_emitter) - - @pytest.mark.asyncio - async def test_analysis_node_success(self, analysis_node_with_mock_agent): - """测试 Analysis 节点成功执行""" - state = { - "project_info": {"name": "Test"}, - "tech_stack": {"languages": ["Python"]}, - "entry_points": [], - "high_risk_areas": ["src/sql_vuln.py"], - "config": {}, - "iteration": 0, - "findings": [], - } - - result = await analysis_node_with_mock_agent(state) - - assert "findings" in result - assert len(result["findings"]) > 0 - assert result["iteration"] == 1 - - -class TestIntegrationFlow: - """完整流程集成测试""" - - @pytest.mark.asyncio - async def test_full_audit_flow_mock(self, temp_project_dir, mock_db_session, mock_task): - """测试完整审计流程(使用模拟)""" - # 这个测试验证整个流程的连接性 - - # 创建事件管理器 - event_manager = EventManager() - emitter = AgentEventEmitter(mock_task.id, event_manager) - - # 模拟 LLM 服务 - mock_llm = MagicMock() - mock_llm.chat_completion_raw = AsyncMock(return_value={ - "content": "Analysis complete", - "usage": {"total_tokens": 100}, - }) - - # 验证事件发射 - await emitter.emit_phase_start("init", "初始化") - await emitter.emit_info("测试消息") - await emitter.emit_phase_complete("init", "初始化完成") - - assert emitter._sequence == 3 - - @pytest.mark.asyncio - async def test_audit_state_typing(self): - """测试审计状态类型定义""" - state: AuditState = { - "project_root": "/tmp/test", - "project_info": {"name": "Test"}, - "config": {}, - "task_id": "test-id", - "tech_stack": {}, - "entry_points": [], - "high_risk_areas": [], - "dependencies": {}, - "findings": [], - "verified_findings": [], - "false_positives": [], - "current_phase": "start", - "iteration": 0, - "max_iterations": 50, - "should_continue_analysis": False, - "messages": [], - "events": [], - "summary": None, - "security_score": None, - "error": None, - } - - assert state["current_phase"] == "start" - assert state["max_iterations"] == 50 - - -class TestToolIntegration: - """工具集成测试""" - - @pytest.mark.asyncio - async def test_tools_work_together(self, temp_project_dir): - """测试工具协同工作""" - from app.services.agent.tools import ( - FileReadTool, FileSearchTool, ListFilesTool, PatternMatchTool, - ) - - # 1. 列出文件 - list_tool = ListFilesTool(temp_project_dir) - list_result = await list_tool.execute(directory="src", recursive=False) - assert list_result.success is True - - # 2. 搜索关键代码 - search_tool = FileSearchTool(temp_project_dir) - search_result = await search_tool.execute(keyword="execute") - assert search_result.success is True - - # 3. 读取文件内容 - read_tool = FileReadTool(temp_project_dir) - read_result = await read_tool.execute(file_path="src/sql_vuln.py") - assert read_result.success is True - - # 4. 模式匹配 - pattern_tool = PatternMatchTool(temp_project_dir) - pattern_result = await pattern_tool.execute( - code=read_result.data, - file_path="src/sql_vuln.py", - language="python" - ) - assert pattern_result.success is True - - -class TestErrorHandling: - """错误处理测试""" - - @pytest.mark.asyncio - async def test_tool_error_handling(self, temp_project_dir): - """测试工具错误处理""" - from app.services.agent.tools import FileReadTool - - tool = FileReadTool(temp_project_dir) - - # 尝试读取不存在的文件 - result = await tool.execute(file_path="nonexistent/file.py") - - assert result.success is False - assert result.error is not None - - @pytest.mark.asyncio - async def test_agent_graceful_degradation(self, mock_event_emitter): - """测试 Agent 优雅降级""" - # 创建一个会失败的 Agent - mock_agent = MagicMock() - mock_agent.run = AsyncMock(side_effect=Exception("Simulated error")) - - node = ReconNode(mock_agent, mock_event_emitter) - - result = await node({ - "project_info": {}, - "config": {}, - }) - - # 应该返回错误状态而不是崩溃 - assert "error" in result - assert result["current_phase"] == "error" - - -class TestPerformance: - """性能测试""" - - @pytest.mark.asyncio - async def test_tool_response_time(self, temp_project_dir): - """测试工具响应时间""" - from app.services.agent.tools import ListFilesTool - import time - - tool = ListFilesTool(temp_project_dir) - - start = time.time() - await tool.execute(directory=".", recursive=True) - duration = time.time() - start - - # 工具应该在合理时间内响应 - assert duration < 5.0 # 5 秒内 - - @pytest.mark.asyncio - async def test_multiple_tool_calls(self, temp_project_dir): - """测试多次工具调用""" - from app.services.agent.tools import FileSearchTool - - tool = FileSearchTool(temp_project_dir) - - # 执行多次调用 - for _ in range(5): - result = await tool.execute(keyword="def") - assert result.success is True - - # 验证调用计数 - assert tool._call_count == 5 -