# Source: CodeReview/backend/app/services/agent/graph/runner.py
"""
DeepAudit LangGraph Runner
基于 LangGraph 的 Agent 审计执行器
"""
import asyncio
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Dict, List, Optional, Any, AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from app.services.agent.streaming import StreamHandler, StreamEvent, StreamEventType
from app.models.agent_task import (
AgentTask, AgentEvent, AgentFinding,
AgentTaskStatus, AgentTaskPhase, AgentEventType,
VulnerabilitySeverity, VulnerabilityType, FindingStatus,
)
from app.services.agent.event_manager import EventManager, AgentEventEmitter
from app.services.agent.tools import (
RAGQueryTool, SecurityCodeSearchTool, FunctionContextTool,
PatternMatchTool, CodeAnalysisTool, DataFlowAnalysisTool, VulnerabilityValidationTool,
FileReadTool, FileSearchTool, ListFilesTool,
SandboxTool, SandboxHttpTool, VulnerabilityVerifyTool, SandboxManager,
SemgrepTool, BanditTool, GitleaksTool, NpmAuditTool, SafetyTool,
TruffleHogTool, OSVScannerTool,
)
from app.services.rag import CodeIndexer, CodeRetriever, EmbeddingService
from app.core.config import settings
from .audit_graph import AuditState, create_audit_graph
from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode
logger = logging.getLogger(__name__)
# 🔥 使用系统统一的 LLMService支持用户配置
from app.services.llm.service import LLMService
class AgentRunner:
    """
    DeepAudit LangGraph Agent Runner.

    Audit executor built on a LangGraph state graph.

    Workflow:
        START → Recon → Analysis ⟲ → Verification → Report → END
    """
    def __init__(
        self,
        db: AsyncSession,
        task: AgentTask,
        project_root: str,
        user_config: Optional[Dict[str, Any]] = None,
    ):
        self.db = db
        self.task = task
        self.project_root = project_root
        # Keep the raw user config; RAG initialization reads it later.
        self.user_config = user_config or {}
        # Event management — pass a db_session_factory so events are persisted.
        from app.db.session import async_session_factory
        self.event_manager = EventManager(db_session_factory=async_session_factory)
        self.event_emitter = AgentEventEmitter(task.id, self.event_manager)
        # CRITICAL: create the event queue now, before the agent starts, so
        # token events are not lost even if the frontend SSE connection
        # attaches slightly late.
        self.event_manager.create_queue(task.id)
        # LLM service — driven by the user's configuration (system settings page).
        self.llm_service = LLMService(user_config=self.user_config)
        # Tool registry (per-agent tool sets are built in _initialize_tools).
        self.tools: Dict[str, Any] = {}
        # RAG components (set in _initialize_rag; may remain None on failure).
        self.retriever: Optional[CodeRetriever] = None
        self.indexer: Optional[CodeIndexer] = None
        # Shared sandbox manager (set in _initialize_tools).
        self.sandbox_manager: Optional[SandboxManager] = None
        # LangGraph graph and in-memory checkpointer.
        self.graph: Optional[StateGraph] = None
        self.checkpointer = MemorySaver()
        # Cancellation state.
        self._cancelled = False
        self._running_task: Optional[asyncio.Task] = None
        # Agent references (used to propagate cancellation).
        self._agents: List[Any] = []
        # Translates LangGraph events into StreamEvents.
        self.stream_handler = StreamHandler(task.id)
def cancel(self):
"""取消任务"""
self._cancelled = True
# 🔥 取消所有 Agent
for agent in self._agents:
if hasattr(agent, 'cancel'):
agent.cancel()
logger.debug(f"Cancelled agent: {agent.name if hasattr(agent, 'name') else 'unknown'}")
# 取消运行中的任务
if self._running_task and not self._running_task.done():
self._running_task.cancel()
logger.info(f"Task {self.task.id} cancellation requested")
    @property
    def is_cancelled(self) -> bool:
        """Whether cancellation has been requested via cancel()."""
        return self._cancelled
    async def initialize(self):
        """Initialize the runner: RAG system, tool sets, then the LangGraph."""
        await self.event_emitter.emit_info("🚀 正在初始化 DeepAudit LangGraph Agent...")
        # 1. RAG system (non-fatal if it fails).
        await self._initialize_rag()
        # 2. Per-agent tool sets + shared sandbox manager.
        await self._initialize_tools()
        # 3. Build the LangGraph workflow.
        await self._build_graph()
        await self.event_emitter.emit_info("✅ LangGraph 系统初始化完成")
    async def _initialize_rag(self):
        """
        Initialize the RAG code-retrieval system.

        Builds an EmbeddingService from layered configuration and creates the
        per-project CodeIndexer / CodeRetriever. Failures are logged and
        reported but non-fatal: the runner continues without RAG.
        """
        await self.event_emitter.emit_info("📚 初始化 RAG 代码检索系统...")
        try:
            # Configuration priority: user embedding config > user LLM config
            # > environment variables (settings).
            user_llm_config = self.user_config.get('llmConfig', {})
            user_other_config = self.user_config.get('otherConfig', {})
            user_embedding_config = user_other_config.get('embedding_config', {})
            # Embedding provider: user embedding config > environment.
            embedding_provider = (
                user_embedding_config.get('provider') or
                getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
            )
            # Embedding model: user embedding config > environment.
            embedding_model = (
                user_embedding_config.get('model') or
                getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
            )
            # API key: user embedding config > user LLM config > environment.
            embedding_api_key = (
                user_embedding_config.get('api_key') or
                user_llm_config.get('llmApiKey') or
                getattr(settings, 'LLM_API_KEY', '') or
                ''
            )
            # Base URL: user embedding config > user LLM config > environment.
            embedding_base_url = (
                user_embedding_config.get('base_url') or
                user_llm_config.get('llmBaseUrl') or
                getattr(settings, 'LLM_BASE_URL', None) or
                None
            )
            logger.info(f"RAG 配置: provider={embedding_provider}, model={embedding_model}")
            await self.event_emitter.emit_info(f"嵌入模型: {embedding_provider}/{embedding_model}")
            embedding_service = EmbeddingService(
                provider=embedding_provider,
                model=embedding_model,
                api_key=embedding_api_key,
                base_url=embedding_base_url,
            )
            # Indexer and retriever share one per-project vector collection.
            self.indexer = CodeIndexer(
                collection_name=f"project_{self.task.project_id}",
                embedding_service=embedding_service,
                persist_directory=settings.VECTOR_DB_PATH,
            )
            self.retriever = CodeRetriever(
                collection_name=f"project_{self.task.project_id}",
                embedding_service=embedding_service,
                persist_directory=settings.VECTOR_DB_PATH,
            )
        except Exception as e:
            # Non-fatal: downstream code checks self.retriever / self.indexer.
            logger.warning(f"RAG initialization failed: {e}")
            await self.event_emitter.emit_warning(f"RAG 系统初始化失败: {e}")
    async def _initialize_tools(self):
        """
        Build the per-agent tool sets (recon / analysis / verification).

        Also eagerly initializes the shared SandboxManager so every external
        scanner tool reuses a single Docker-backed sandbox.
        """
        await self.event_emitter.emit_info("初始化 Agent 工具集...")
        # Extra tools imported lazily (avoids import cost/cycles at module load).
        from app.services.agent.tools import (
            ThinkTool, ReflectTool,
            CreateVulnerabilityReportTool,
            # Multi-language code test tools.
            PhpTestTool, PythonTestTool, JavaScriptTestTool, JavaTestTool,
            GoTestTool, RubyTestTool, ShellTestTool, UniversalCodeTestTool,
            # Vulnerability-specific verification tools.
            CommandInjectionTestTool, SqlInjectionTestTool, XssTestTool,
            PathTraversalTestTool, SstiTestTool, DeserializationTestTool,
            UniversalVulnTestTool,
            # Kunlun-M static code analysis tools (MIT License).
            KunlunMTool, KunlunRuleListTool, KunlunPluginTool,
        )
        # Security-knowledge query tools.
        from app.services.agent.knowledge import (
            SecurityKnowledgeQueryTool,
            GetVulnerabilityKnowledgeTool,
        )
        # Task-level scoping: exclusion patterns and optional target files.
        exclude_patterns = self.task.exclude_patterns or []
        target_files = self.task.target_files or None
        # ===== Eagerly initialize the SandboxManager (shared by all external tools) =====
        self.sandbox_manager = None
        try:
            from app.services.agent.tools.sandbox_tool import SandboxConfig
            sandbox_config = SandboxConfig(
                image=settings.SANDBOX_IMAGE,
                memory_limit=settings.SANDBOX_MEMORY_LIMIT,
                cpu_limit=settings.SANDBOX_CPU_LIMIT,
                timeout=settings.SANDBOX_TIMEOUT,
                network_mode=settings.SANDBOX_NETWORK_MODE,
            )
            self.sandbox_manager = SandboxManager(config=sandbox_config)
            # initialize() must be called to connect to Docker.
            await self.sandbox_manager.initialize()
            logger.info(f"✅ SandboxManager initialized early (Docker available: {self.sandbox_manager.is_available})")
        except Exception as e:
            logger.warning(f"❌ Early Sandbox Manager initialization failed: {e}")
            import traceback
            logger.warning(f"Traceback: {traceback.format_exc()}")
            # Fall back to a default-configured manager.
            try:
                self.sandbox_manager = SandboxManager()
                await self.sandbox_manager.initialize()
                logger.info(f"⚠️ Created fallback SandboxManager (Docker available: {self.sandbox_manager.is_available})")
            except Exception as e2:
                logger.error(f"❌ Failed to create fallback SandboxManager: {e2}")
        # ===== Base tools (shared by every agent) =====
        base_tools = {
            "read_file": FileReadTool(self.project_root, exclude_patterns, target_files),
            "list_files": ListFilesTool(self.project_root, exclude_patterns, target_files),
            # Thinking tool, available to all agents.
            "think": ThinkTool(),
        }
        # ===== Recon agent tools =====
        # Responsibilities: information gathering, project structure analysis,
        # tech-stack identification. External scanners included for quick scans.
        self.recon_tools = {
            **base_tools,
            "search_code": FileSearchTool(self.project_root, exclude_patterns, target_files),
            # Reflection tool.
            "reflect": ReflectTool(),
            # External security tools (share the SandboxManager instance).
            "semgrep_scan": SemgrepTool(self.project_root, self.sandbox_manager),
            "bandit_scan": BanditTool(self.project_root, self.sandbox_manager),
            "gitleaks_scan": GitleaksTool(self.project_root, self.sandbox_manager),
            "safety_scan": SafetyTool(self.project_root, self.sandbox_manager),
            "npm_audit": NpmAuditTool(self.project_root, self.sandbox_manager),
        }
        # RAG tool (semantic search for recon).
        if self.retriever:
            self.recon_tools["rag_query"] = RAGQueryTool(self.retriever)
            logger.info("✅ RAG 工具已注册到 Recon Agent")
        else:
            logger.warning("⚠️ RAG 未初始化rag_query 工具不可用")
        # ===== Analysis agent tools =====
        # Responsibilities: vulnerability analysis, code audit, pattern matching.
        self.analysis_tools = {
            **base_tools,
            "search_code": FileSearchTool(self.project_root, exclude_patterns, target_files),
            # Pattern matching and code analysis.
            "pattern_match": PatternMatchTool(self.project_root),
            # TODO: code_analysis disabled for now — its LLM calls fail often.
            # "code_analysis": CodeAnalysisTool(self.llm_service),
            "dataflow_analysis": DataFlowAnalysisTool(self.llm_service),
            # External static analysis tools (share the SandboxManager instance).
            "semgrep_scan": SemgrepTool(self.project_root, self.sandbox_manager),
            "bandit_scan": BanditTool(self.project_root, self.sandbox_manager),
            "gitleaks_scan": GitleaksTool(self.project_root, self.sandbox_manager),
            "trufflehog_scan": TruffleHogTool(self.project_root, self.sandbox_manager),
            "npm_audit": NpmAuditTool(self.project_root, self.sandbox_manager),
            "safety_scan": SafetyTool(self.project_root, self.sandbox_manager),
            "osv_scan": OSVScannerTool(self.project_root, self.sandbox_manager),
            # Kunlun-M static analysis (MIT License - https://github.com/LoRexxar/Kunlun-M).
            "kunlun_scan": KunlunMTool(self.project_root),
            "kunlun_list_rules": KunlunRuleListTool(self.project_root),
            "kunlun_plugin": KunlunPluginTool(self.project_root),
            # Reflection tool.
            "reflect": ReflectTool(),
            # RAG-backed security-knowledge query tools.
            "query_security_knowledge": SecurityKnowledgeQueryTool(),
            "get_vulnerability_knowledge": GetVulnerabilityKnowledgeTool(),
        }
        # RAG tools (security-oriented code search for analysis).
        if self.retriever:
            self.analysis_tools["rag_query"] = RAGQueryTool(self.retriever)  # general semantic search
            self.analysis_tools["security_search"] = SecurityCodeSearchTool(self.retriever)  # security code search
            self.analysis_tools["function_context"] = FunctionContextTool(self.retriever)  # function context
            logger.info("✅ RAG 工具已注册到 Analysis Agent (rag_query, security_search, function_context)")
        # ===== Verification agent tools =====
        # Responsibilities: vulnerability verification, PoC execution, false-positive pruning.
        self.verification_tools = {
            **base_tools,
            # Old vulnerability_validation / dataflow_analysis removed here —
            # verification is forced through the sandbox.
            # Report tool (verification only) — v2.1: receives project_root.
            "create_vulnerability_report": CreateVulnerabilityReportTool(self.project_root),
            # Reflection tool.
            "reflect": ReflectTool(),
        }
        # Register sandbox tools (reusing the eagerly-initialized SandboxManager).
        if self.sandbox_manager:
            # Core sandbox tools.
            self.verification_tools["sandbox_exec"] = SandboxTool(self.sandbox_manager)
            self.verification_tools["sandbox_http"] = SandboxHttpTool(self.sandbox_manager)
            self.verification_tools["verify_vulnerability"] = VulnerabilityVerifyTool(self.sandbox_manager)
            # Multi-language code test tools.
            self.verification_tools["php_test"] = PhpTestTool(self.sandbox_manager, self.project_root)
            self.verification_tools["python_test"] = PythonTestTool(self.sandbox_manager, self.project_root)
            self.verification_tools["javascript_test"] = JavaScriptTestTool(self.sandbox_manager, self.project_root)
            self.verification_tools["java_test"] = JavaTestTool(self.sandbox_manager, self.project_root)
            self.verification_tools["go_test"] = GoTestTool(self.sandbox_manager, self.project_root)
            self.verification_tools["ruby_test"] = RubyTestTool(self.sandbox_manager, self.project_root)
            self.verification_tools["shell_test"] = ShellTestTool(self.sandbox_manager, self.project_root)
            self.verification_tools["universal_code_test"] = UniversalCodeTestTool(self.sandbox_manager, self.project_root)
            # Vulnerability-specific verification tools.
            self.verification_tools["test_command_injection"] = CommandInjectionTestTool(self.sandbox_manager, self.project_root)
            self.verification_tools["test_sql_injection"] = SqlInjectionTestTool(self.sandbox_manager, self.project_root)
            self.verification_tools["test_xss"] = XssTestTool(self.sandbox_manager, self.project_root)
            self.verification_tools["test_path_traversal"] = PathTraversalTestTool(self.sandbox_manager, self.project_root)
            self.verification_tools["test_ssti"] = SstiTestTool(self.sandbox_manager, self.project_root)
            self.verification_tools["test_deserialization"] = DeserializationTestTool(self.sandbox_manager, self.project_root)
            self.verification_tools["universal_vuln_test"] = UniversalVulnTestTool(self.sandbox_manager, self.project_root)
            logger.info(f"✅ Sandbox tools initialized (Docker available: {self.sandbox_manager.is_available})")
        else:
            logger.error("❌ Sandbox tools NOT initialized due to critical manager failure")
        logger.info(f"✅ Verification tools: {list(self.verification_tools.keys())}")
        # Report the distinct tool count across all three sets.
        total_tools = len(set(
            list(self.recon_tools.keys()) +
            list(self.analysis_tools.keys()) +
            list(self.verification_tools.keys())
        ))
        await self.event_emitter.emit_info(f"已加载 {total_tools} 个工具")
    async def _build_graph(self):
        """Build the LangGraph audit graph from the three agents plus report node."""
        await self.event_emitter.emit_info("📊 构建 LangGraph 审计工作流...")
        # Imported lazily to avoid import cycles.
        from app.services.agent.agents import ReconAgent, AnalysisAgent, VerificationAgent
        # One agent per phase, each with its dedicated tool set.
        recon_agent = ReconAgent(
            llm_service=self.llm_service,
            tools=self.recon_tools,  # recon-only tools
            event_emitter=self.event_emitter,
        )
        analysis_agent = AnalysisAgent(
            llm_service=self.llm_service,
            tools=self.analysis_tools,  # analysis-only tools
            event_emitter=self.event_emitter,
        )
        verification_agent = VerificationAgent(
            llm_service=self.llm_service,
            tools=self.verification_tools,  # verification-only tools
            event_emitter=self.event_emitter,
        )
        # Keep references so cancel() can propagate to the agents.
        self._agents = [recon_agent, analysis_agent, verification_agent]
        # Wrap agents in graph nodes.
        recon_node = ReconNode(recon_agent, self.event_emitter)
        analysis_node = AnalysisNode(analysis_agent, self.event_emitter)
        verification_node = VerificationNode(verification_agent, self.event_emitter)
        report_node = ReportNode(None, self.event_emitter)
        # Assemble the graph with checkpointing.
        self.graph = create_audit_graph(
            recon_node=recon_node,
            analysis_node=analysis_node,
            verification_node=verification_node,
            report_node=report_node,
            checkpointer=self.checkpointer,
        )
        await self.event_emitter.emit_info("✅ LangGraph 工作流构建完成")
async def run(self) -> Dict[str, Any]:
"""
执行 LangGraph 审计
Returns:
最终状态
"""
final_state = {}
try:
async for event in self.run_with_streaming():
# 收集最终状态
if event.event_type == StreamEventType.TASK_COMPLETE:
final_state = event.data
elif event.event_type == StreamEventType.TASK_ERROR:
final_state = {"success": False, "error": event.data.get("error")}
except Exception as e:
logger.error(f"Agent run failed: {e}", exc_info=True)
final_state = {"success": False, "error": str(e)}
return final_state
    async def run_with_streaming(self) -> AsyncGenerator[StreamEvent, None]:
        """
        Execute the audit and yield streaming events.

        Yields:
            StreamEvent: stream events (LLM thinking, tool calls, phase and
            task lifecycle events).
        """
        import time
        start_time = time.time()
        try:
            # Initialize RAG, tools and the graph.
            await self.initialize()
            # Mark the task as running.
            await self._update_task_status(AgentTaskStatus.RUNNING)
            # Emit the task-start event.
            yield StreamEvent(
                event_type=StreamEventType.TASK_START,
                sequence=self.stream_handler._next_sequence(),
                data={"task_id": self.task.id, "message": "🚀 审计任务开始"},
            )
            # 1. Index the code for RAG retrieval.
            await self._index_code()
            if self._cancelled:
                yield StreamEvent(
                    event_type=StreamEventType.TASK_CANCEL,
                    sequence=self.stream_handler._next_sequence(),
                    data={"message": "任务已取消"},
                )
                return
            # 2. Collect basic project information.
            project_info = await self._collect_project_info()
            # 3. Build the initial graph state.
            task_config = {
                "target_vulnerabilities": self.task.target_vulnerabilities or [],
                "verification_level": self.task.verification_level or "sandbox",
                "exclude_patterns": self.task.exclude_patterns or [],
                "target_files": self.task.target_files or [],
                "max_iterations": self.task.max_iterations or 50,
                "timeout_seconds": self.task.timeout_seconds or 1800,
            }
            initial_state: AuditState = {
                "project_root": self.project_root,
                "project_info": project_info,
                "config": task_config,
                "task_id": self.task.id,
                "tech_stack": {},
                "entry_points": [],
                "high_risk_areas": [],
                "dependencies": {},
                "findings": [],
                "verified_findings": [],
                "false_positives": [],
                "_verified_findings_update": None,  # findings replacement written by verification
                "current_phase": "start",
                "iteration": 0,
                "max_iterations": self.task.max_iterations or 50,
                "should_continue_analysis": False,
                # Agent-to-agent handoff payloads.
                "recon_handoff": None,
                "analysis_handoff": None,
                "verification_handoff": None,
                "messages": [],
                "events": [],
                "summary": None,
                "security_score": None,
                "error": None,
            }
            # 4. Run the LangGraph with astream_events.
            await self.event_emitter.emit_phase_start("langgraph", "🔄 启动 LangGraph 工作流")
            run_config = {
                "configurable": {
                    "thread_id": self.task.id,
                }
            }
            final_state = None
            # Preferred path: astream_events yields a detailed event stream.
            try:
                async for event in self.graph.astream_events(
                    initial_state,
                    config=run_config,
                    version="v2",
                ):
                    if self._cancelled:
                        break
                    # Translate LangGraph events into StreamEvents.
                    stream_event = await self.stream_handler.process_langgraph_event(event)
                    if stream_event:
                        # Mirror the event into the event manager for persistence.
                        await self._sync_stream_event_to_db(stream_event)
                        yield stream_event
                    # Track the latest chain output as the candidate final state.
                    if event.get("event") == "on_chain_end":
                        output = event.get("data", {}).get("output")
                        if isinstance(output, dict):
                            final_state = output
            except Exception as e:
                # Fallback: astream_events unavailable — use plain astream.
                logger.warning(f"astream_events not available, falling back to astream: {e}")
                async for event in self.graph.astream(initial_state, config=run_config):
                    if self._cancelled:
                        break
                    for node_name, node_output in event.items():
                        await self._handle_node_output(node_name, node_output)
                        # Emit a node-completion event.
                        yield StreamEvent(
                            event_type=StreamEventType.NODE_END,
                            sequence=self.stream_handler._next_sequence(),
                            node_name=node_name,
                            data={"message": f"节点 {node_name} 完成"},
                        )
                        phase_map = {
                            "recon": AgentTaskPhase.RECONNAISSANCE,
                            "analysis": AgentTaskPhase.ANALYSIS,
                            "verification": AgentTaskPhase.VERIFICATION,
                            "report": AgentTaskPhase.REPORTING,
                        }
                        if node_name in phase_map:
                            await self._update_task_phase(phase_map[node_name])
                        final_state = node_output
            # 5. Resolve the final state.
            # CRITICAL FIX: always read the accumulated state from the graph —
            # each node only returns its own output, and fields such as
            # findings are accumulated via operator.add; using node_output
            # alone would drop findings from earlier nodes.
            graph_state = self.graph.get_state(run_config)
            if graph_state and graph_state.values:
                # Merge the full accumulated state with the last node output
                # (which carries e.g. summary, security_score).
                full_state = graph_state.values
                if final_state:
                    full_state = {**full_state, **final_state}
                final_state = full_state
                logger.info(f"[Runner] Got full state from graph with {len(final_state.get('findings', []))} findings")
            elif not final_state:
                final_state = {}
                logger.warning("[Runner] No final state available from graph")
            # CRITICAL FIX: when verification produced an updated findings
            # list, it replaces the raw findings — the operator.add accumulator
            # cannot update existing entries in place.
            verified_findings_update = final_state.get("_verified_findings_update")
            if verified_findings_update:
                logger.info(f"[Runner] Using verified findings update: {len(verified_findings_update)} findings")
                final_state["findings"] = verified_findings_update
            else:
                # FALLBACK: merge verified_findings into findings by key.
                findings = final_state.get("findings", [])
                verified_findings = final_state.get("verified_findings", [])
                if verified_findings and findings:
                    # Build the merged findings list.
                    merged_findings = self._merge_findings_with_verification(findings, verified_findings)
                    final_state["findings"] = merged_findings
                    logger.info(f"[Runner] Merged findings: {len(merged_findings)} total")
                elif verified_findings and not findings:
                    # Only verified findings exist — use them directly.
                    final_state["findings"] = verified_findings
                    logger.info(f"[Runner] Using verified_findings directly: {len(verified_findings)}")
            logger.info(f"[Runner] Final findings count: {len(final_state.get('findings', []))}")
            # Abort with a failure event when the graph reported an error.
            error = final_state.get("error")
            if error:
                # Detect LLM authentication errors for a friendlier message.
                error_str = str(error)
                if "AuthenticationError" in error_str or "API key" in error_str or "invalid_api_key" in error_str:
                    error_message = "LLM API 密钥配置错误。请检查环境变量 LLM_API_KEY 或配置中的 API 密钥是否正确。"
                    logger.error(f"LLM authentication error: {error}")
                else:
                    error_message = error_str
                duration_ms = int((time.time() - start_time) * 1000)  # NOTE(review): unused in this branch
                # Mark the task as failed and emit the error.
                await self._update_task_status(AgentTaskStatus.FAILED, error_message)
                await self.event_emitter.emit_task_error(error_message)
                yield StreamEvent(
                    event_type=StreamEventType.TASK_ERROR,
                    sequence=self.stream_handler._next_sequence(),
                    data={
                        "error": error_message,
                        "message": f"❌ 任务失败: {error_message}",
                    },
                )
                return
            # 6. Persist findings.
            findings = final_state.get("findings", [])
            await self._save_findings(findings)
            # Emit finding events (capped to limit volume).
            for finding in findings[:10]:
                yield self.stream_handler.create_finding_event(
                    finding,
                    is_verified=finding.get("is_verified", False),
                )
            # 7. Update the task summary.
            summary = final_state.get("summary", {})
            security_score = final_state.get("security_score", 100)
            await self._update_task_summary(
                total_findings=len(findings),
                verified_count=len(final_state.get("verified_findings", [])),
                security_score=security_score,
            )
            # 8. Complete.
            duration_ms = int((time.time() - start_time) * 1000)
            await self._update_task_status(AgentTaskStatus.COMPLETED)
            await self.event_emitter.emit_task_complete(
                findings_count=len(findings),
                duration_ms=duration_ms,
            )
            yield StreamEvent(
                event_type=StreamEventType.TASK_COMPLETE,
                sequence=self.stream_handler._next_sequence(),
                data={
                    "findings_count": len(findings),
                    "verified_count": len(final_state.get("verified_findings", [])),
                    "security_score": security_score,
                    "duration_ms": duration_ms,
                    "message": f"✅ 审计完成!发现 {len(findings)} 个漏洞",
                },
            )
        except asyncio.CancelledError:
            await self._update_task_status(AgentTaskStatus.CANCELLED)
            yield StreamEvent(
                event_type=StreamEventType.TASK_CANCEL,
                sequence=self.stream_handler._next_sequence(),
                data={"message": "任务已取消"},
            )
        except Exception as e:
            logger.error(f"LangGraph run failed: {e}", exc_info=True)
            await self._update_task_status(AgentTaskStatus.FAILED, str(e))
            await self.event_emitter.emit_error(str(e))
            yield StreamEvent(
                event_type=StreamEventType.TASK_ERROR,
                sequence=self.stream_handler._next_sequence(),
                data={"error": str(e), "message": f"❌ 审计失败: {e}"},
            )
        finally:
            await self._cleanup()
    async def _sync_stream_event_to_db(self, event: StreamEvent):
        """
        Mirror a StreamEvent into the EventManager for persistence.

        Best-effort: failures are logged and swallowed so streaming continues.
        """
        try:
            # Map StreamEvent fields onto the AgentEventData shape.
            await self.event_manager.add_event(
                task_id=self.task.id,
                event_type=event.event_type.value,
                sequence=event.sequence,
                phase=event.phase,
                message=event.data.get("message"),
                tool_name=event.tool_name,
                tool_input=event.data.get("input") or event.data.get("input_params"),
                tool_output=event.data.get("output") or event.data.get("output_data"),
                tool_duration_ms=event.data.get("duration_ms"),
                metadata=event.data,
            )
        except Exception as e:
            logger.warning(f"Failed to sync stream event to db: {e}")
    async def _handle_node_output(self, node_name: str, output: Dict[str, Any]):
        """
        Emit user-facing events for one node's output (astream fallback path).

        Args:
            node_name: Graph node that produced the output.
            output: The node's partial state update.
        """
        # Re-emit node-collected events as info messages.
        events = output.get("events", [])
        for evt in events:
            await self.event_emitter.emit_info(
                f"[{node_name}] {evt.get('type', 'event')}: {evt.get('data', {})}"
            )
        # New findings from the analysis node (capped to limit event volume).
        if node_name == "analysis":
            new_findings = output.get("findings", [])
            if new_findings:
                for finding in new_findings[:5]:
                    await self.event_emitter.emit_finding(
                        title=finding.get("title", "Unknown"),
                        severity=finding.get("severity", "medium"),
                        file_path=finding.get("file_path"),
                    )
        # Verification results (capped).
        if node_name == "verification":
            verified = output.get("verified_findings", [])
            for v in verified[:5]:
                await self.event_emitter.emit_info(
                    f"✅ 已验证: {v.get('title', 'Unknown')}"
                )
        # Surface node-level errors.
        if output.get("error"):
            await self.event_emitter.emit_error(output["error"])
    async def _index_code(self):
        """
        Index the project's code into the vector store for RAG retrieval.

        Skipped (with a warning) when RAG failed to initialize; indexing
        errors are non-fatal.
        """
        if not self.indexer:
            await self.event_emitter.emit_warning("RAG 未初始化,跳过代码索引")
            return
        await self._update_task_phase(AgentTaskPhase.INDEXING)
        await self.event_emitter.emit_phase_start("indexing", "📝 开始代码索引")
        try:
            # Stream per-file progress to the frontend while indexing.
            async for progress in self.indexer.index_directory(self.project_root):
                if self._cancelled:
                    return
                await self.event_emitter.emit_progress(
                    progress.processed_files,
                    progress.total_files,
                    f"正在索引: {progress.current_file or 'N/A'}"
                )
            await self.event_emitter.emit_phase_complete("indexing", "✅ 代码索引完成")
        except Exception as e:
            logger.warning(f"Code indexing failed: {e}")
            await self.event_emitter.emit_warning(f"代码索引失败: {e}")
async def _collect_project_info(self) -> Dict[str, Any]:
"""收集项目信息"""
info = {
"name": self.task.project.name if self.task.project else "unknown",
"root": self.project_root,
"languages": [],
"file_count": 0,
}
try:
exclude_dirs = {
"node_modules", "__pycache__", ".git", "venv", ".venv",
"build", "dist", "target", ".idea", ".vscode",
}
for root, dirs, files in os.walk(self.project_root):
dirs[:] = [d for d in dirs if d not in exclude_dirs]
info["file_count"] += len(files)
lang_map = {
".py": "Python", ".js": "JavaScript", ".ts": "TypeScript",
".java": "Java", ".go": "Go", ".php": "PHP",
".rb": "Ruby", ".rs": "Rust", ".c": "C", ".cpp": "C++",
}
for f in files:
ext = os.path.splitext(f)[1].lower()
if ext in lang_map and lang_map[ext] not in info["languages"]:
info["languages"].append(lang_map[ext])
except Exception as e:
logger.warning(f"Failed to collect project info: {e}")
return info
    async def _save_findings(self, findings: List[Dict]):
        """
        Persist findings to the database as AgentFinding rows.

        Unknown severity / vulnerability-type strings fall back to
        MEDIUM / OTHER. Non-dict entries are skipped. A commit failure
        rolls the session back.

        Args:
            findings: Finding dicts produced by the graph.
        """
        logger.info(f"[Runner] Saving {len(findings)} findings to database for task {self.task.id}")
        if not findings:
            logger.info("[Runner] No findings to save")
            return
        # String → enum lookup tables.
        severity_map = {
            "critical": VulnerabilitySeverity.CRITICAL,
            "high": VulnerabilitySeverity.HIGH,
            "medium": VulnerabilitySeverity.MEDIUM,
            "low": VulnerabilitySeverity.LOW,
            "info": VulnerabilitySeverity.INFO,
        }
        type_map = {
            "sql_injection": VulnerabilityType.SQL_INJECTION,
            "nosql_injection": VulnerabilityType.NOSQL_INJECTION,
            "xss": VulnerabilityType.XSS,
            "command_injection": VulnerabilityType.COMMAND_INJECTION,
            "code_injection": VulnerabilityType.CODE_INJECTION,
            "path_traversal": VulnerabilityType.PATH_TRAVERSAL,
            "file_inclusion": VulnerabilityType.FILE_INCLUSION,
            "ssrf": VulnerabilityType.SSRF,
            "xxe": VulnerabilityType.XXE,
            "deserialization": VulnerabilityType.DESERIALIZATION,
            "auth_bypass": VulnerabilityType.AUTH_BYPASS,
            "idor": VulnerabilityType.IDOR,
            "sensitive_data_exposure": VulnerabilityType.SENSITIVE_DATA_EXPOSURE,
            "hardcoded_secret": VulnerabilityType.HARDCODED_SECRET,
            "weak_crypto": VulnerabilityType.WEAK_CRYPTO,
            "race_condition": VulnerabilityType.RACE_CONDITION,
            "business_logic": VulnerabilityType.BUSINESS_LOGIC,
            "memory_corruption": VulnerabilityType.MEMORY_CORRUPTION,
        }
        for finding in findings:
            try:
                # Skip anything that is not a dict.
                if not isinstance(finding, dict):
                    logger.warning(f"Skipping invalid finding (not a dict): {finding}")
                    continue
                db_finding = AgentFinding(
                    id=str(uuid.uuid4()),
                    task_id=self.task.id,
                    vulnerability_type=type_map.get(
                        finding.get("vulnerability_type", "other"),
                        VulnerabilityType.OTHER
                    ),
                    severity=severity_map.get(
                        finding.get("severity", "medium"),
                        VulnerabilitySeverity.MEDIUM
                    ),
                    title=finding.get("title", "Unknown"),
                    description=finding.get("description", ""),
                    file_path=finding.get("file_path"),
                    line_start=finding.get("line_start"),
                    line_end=finding.get("line_end"),
                    code_snippet=finding.get("code_snippet"),
                    source=finding.get("source"),
                    sink=finding.get("sink"),
                    suggestion=finding.get("suggestion") or finding.get("recommendation"),
                    is_verified=finding.get("is_verified", False),
                    confidence=finding.get("confidence", 0.5),
                    poc=finding.get("poc"),
                    status=FindingStatus.VERIFIED if finding.get("is_verified") else FindingStatus.NEW,
                )
                self.db.add(db_finding)
            except Exception as e:
                logger.warning(f"Failed to save finding: {e}")
        try:
            await self.db.commit()
            # NOTE(review): reports the full list length even when some entries
            # were skipped above.
            logger.info(f"[Runner] Successfully saved {len(findings)} findings to database")
        except Exception as e:
            logger.error(f"Failed to commit findings: {e}")
            await self.db.rollback()
    async def _update_task_status(
        self,
        status: AgentTaskStatus,
        error: Optional[str] = None
    ):
        """
        Persist a task status transition, stamping start/finish times.

        Args:
            status: New task status.
            error: Optional error message stored on the task.
        """
        self.task.status = status
        if status == AgentTaskStatus.RUNNING:
            self.task.started_at = datetime.now(timezone.utc)
        elif status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
            self.task.finished_at = datetime.now(timezone.utc)
        if error:
            self.task.error_message = error
        try:
            await self.db.commit()
        except Exception as e:
            logger.error(f"Failed to update task status: {e}")
    async def _update_task_phase(self, phase: AgentTaskPhase):
        """Persist the task's current phase; commit failures are logged only."""
        self.task.current_phase = phase
        try:
            await self.db.commit()
        except Exception as e:
            logger.error(f"Failed to update task phase: {e}")
async def _update_task_summary(
self,
total_findings: int,
verified_count: int,
security_score: int,
):
"""更新任务摘要"""
self.task.total_findings = total_findings
self.task.verified_findings = verified_count
self.task.security_score = security_score
try:
await self.db.commit()
except Exception as e:
logger.error(f"Failed to update task summary: {e}")
def _merge_findings_with_verification(
self,
findings: List[Dict],
verified_findings: List[Dict],
) -> List[Dict]:
"""
合并原始 findings 和验证结果
Args:
findings: 原始 findings 列表
verified_findings: 验证后的 findings 列表
Returns:
合并后的 findings 列表
"""
# 创建验证结果的查找映射
verified_map = {}
for vf in verified_findings:
if not isinstance(vf, dict):
continue
key = (
vf.get("file_path", ""),
vf.get("line_start", 0),
vf.get("vulnerability_type", ""),
)
verified_map[key] = vf
merged = []
seen_keys = set()
# 首先处理原始 findings
for f in findings:
if not isinstance(f, dict):
continue
key = (
f.get("file_path", ""),
f.get("line_start", 0),
f.get("vulnerability_type", ""),
)
if key in verified_map:
# 使用验证后的版本(包含 is_verified, poc 等)
merged.append(verified_map[key])
else:
# 保留原始 finding
merged.append(f)
seen_keys.add(key)
# 添加验证结果中的新发现(如果有)
for key, vf in verified_map.items():
if key not in seen_keys:
merged.append(vf)
return merged
    async def _cleanup(self):
        """Release resources (sandbox containers, event manager); best-effort."""
        try:
            if self.sandbox_manager:
                await self.sandbox_manager.cleanup()
            await self.event_manager.close()
        except Exception as e:
            logger.warning(f"Cleanup error: {e}")
# Convenience entry point
async def run_agent_task(
    db: AsyncSession,
    task: AgentTask,
    project_root: str,
) -> Dict[str, Any]:
    """
    Run an agent audit task to completion.

    Args:
        db: Database session.
        task: Agent task to execute.
        project_root: Project root directory.

    Returns:
        The audit result state.
    """
    return await AgentRunner(db, task, project_root).run()