diff --git a/README.md b/README.md
index 8d64351..1d45b61 100644
--- a/README.md
+++ b/README.md
@@ -428,7 +428,7 @@ DeepSeek-Coder · Codestral
 **Feel free to reach out and discuss! Whether it's technical questions, feature suggestions, or collaboration proposals, I look forward to hearing from you~**
-
+(For business inquiries such as project development or investment/incubation, please get in touch by email)
 
 | Contact | |
 |:---:|:---:|
 | 📧 **Email** | **lintsinghua@qq.com** |
diff --git a/backend/app/api/v1/endpoints/agent_tasks.py b/backend/app/api/v1/endpoints/agent_tasks.py
index a9052f6..b68b478 100644
--- a/backend/app/api/v1/endpoints/agent_tasks.py
+++ b/backend/app/api/v1/endpoints/agent_tasks.py
@@ -294,6 +294,7 @@ async def _execute_agent_task(task_id: str):
     other_config = (user_config or {}).get('otherConfig', {})
     github_token = other_config.get('githubToken') or settings.GITHUB_TOKEN
     gitlab_token = other_config.get('gitlabToken') or settings.GITLAB_TOKEN
+    gitea_token = other_config.get('giteaToken') or settings.GITEA_TOKEN
 
     # Decrypt the SSH private key
     ssh_private_key = None
@@ -313,6 +314,7 @@ async def _execute_agent_task(task_id: str):
             task.branch_name,
             github_token=github_token,
             gitlab_token=gitlab_token,
+            gitea_token=gitea_token,  # 🔥 New
             ssh_private_key=ssh_private_key,  # 🔥 New: SSH key
             event_emitter=event_emitter,  # 🔥 New
         )
@@ -2226,6 +2228,7 @@ async def _get_project_root(
     branch_name: Optional[str] = None,
     github_token: Optional[str] = None,
     gitlab_token: Optional[str] = None,
+    gitea_token: Optional[str] = None,  # 🔥 New
     ssh_private_key: Optional[str] = None,  # 🔥 New: SSH private key (for SSH authentication)
     event_emitter: Optional[Any] = None,  # 🔥 New: used to emit real-time logs
 ) -> str:
@@ -2242,6 +2245,7 @@ async def _get_project_root(
         branch_name: branch name (used by repository projects; takes precedence over project.default_branch)
         github_token: GitHub access token (for private repositories)
         gitlab_token: GitLab access token (for private repositories)
+        gitea_token: Gitea access token (for private repositories)
         ssh_private_key: SSH private key (for SSH authentication)
         event_emitter: event emitter (for real-time logs)
@@ -2503,9 +2507,19 @@ async def _get_project_root(
                 parsed.fragment
             ))
             await emit(f"🔐 Authenticating with GitLab token")
+        elif repo_type == "gitea" and gitea_token:
+            auth_url = urlunparse((
+                parsed.scheme,
+                f"{gitea_token}@{parsed.netloc}",
+                parsed.path,
+                parsed.params,
+                parsed.query,
+                parsed.fragment
+            ))
+            await emit(f"🔐 Authenticating with Gitea token")
         elif is_ssh_url and ssh_private_key:
             await emit(f"🔐 Authenticating with SSH key")
-
+
         for branch in branches_to_try:
             check_cancelled()
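Reviewer note: the new `gitea` branch reuses the same userinfo-injection trick as the GitHub/GitLab branches above. A minimal sketch of what the rewritten clone URL comes out as (the host and token below are hypothetical placeholders, not values from this PR):

```python
from urllib.parse import urlparse, urlunparse

repo_url = "https://gitea.example.com/acme/webapp.git"  # hypothetical repo
gitea_token = "gta_example123"                          # hypothetical token

parsed = urlparse(repo_url)
auth_url = urlunparse((
    parsed.scheme,
    f"{gitea_token}@{parsed.netloc}",  # token lands in the userinfo slot
    parsed.path,
    parsed.params,
    parsed.query,
    parsed.fragment,
))
print(auth_url)  # -> https://gta_example123@gitea.example.com/acme/webapp.git
```

One caveat worth flagging: the token ends up inside the URL, so it can leak into logs or the clone's `.git/config`; the existing GitHub/GitLab branches share the same trade-off.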
diff --git a/backend/app/api/v1/endpoints/projects.py b/backend/app/api/v1/endpoints/projects.py
index dac1138..e26143a 100644
--- a/backend/app/api/v1/endpoints/projects.py
+++ b/backend/app/api/v1/endpoints/projects.py
@@ -16,6 +16,7 @@ from app.db.session import get_db, AsyncSessionLocal
 from app.models.project import Project
 from app.models.user import User
 from app.models.audit import AuditTask, AuditIssue
+from app.models.agent_task import AgentTask, AgentTaskStatus, AgentFinding
 from app.models.user_config import UserConfig
 import zipfile
 from app.services.scanner import scan_repo_task, get_github_files, get_gitlab_files, get_github_branches, get_gitlab_branches, get_gitea_branches, should_exclude, is_text_file
@@ -161,27 +162,52 @@ async def get_stats(
     )
     projects = projects_result.scalars().all()
     project_ids = [p.id for p in projects]
-
-    # Only count tasks that belong to the current user's projects
+
+    # Count the legacy AuditTask records
     tasks_result = await db.execute(
         select(AuditTask).where(AuditTask.project_id.in_(project_ids)) if project_ids else select(AuditTask).where(False)
     )
     tasks = tasks_result.scalars().all()
     task_ids = [t.id for t in tasks]
-
-    # Only count issues that belong to the current user's tasks
+
+    # Count the legacy AuditIssue records
     issues_result = await db.execute(
         select(AuditIssue).where(AuditIssue.task_id.in_(task_ids)) if task_ids else select(AuditIssue).where(False)
     )
     issues = issues_result.scalars().all()
-
+
+    # 🔥 Also count the new AgentTask records
+    agent_tasks_result = await db.execute(
+        select(AgentTask).where(AgentTask.project_id.in_(project_ids)) if project_ids else select(AgentTask).where(False)
+    )
+    agent_tasks = agent_tasks_result.scalars().all()
+    agent_task_ids = [t.id for t in agent_tasks]
+
+    # 🔥 Count AgentFinding records
+    agent_findings_result = await db.execute(
+        select(AgentFinding).where(AgentFinding.task_id.in_(agent_task_ids)) if agent_task_ids else select(AgentFinding).where(False)
+    )
+    agent_findings = agent_findings_result.scalars().all()
+
+    # Merge the statistics (legacy tasks + new Agent tasks)
+    total_tasks = len(tasks) + len(agent_tasks)
+    completed_tasks = (
+        len([t for t in tasks if t.status == "completed"]) +
+        len([t for t in agent_tasks if t.status == AgentTaskStatus.COMPLETED])
+    )
+    total_issues = len(issues) + len(agent_findings)
+    resolved_issues = (
+        len([i for i in issues if i.status == "resolved"]) +
+        len([f for f in agent_findings if f.status == "resolved"])
+    )
+
     return {
         "total_projects": len(projects),
         "active_projects": len([p for p in projects if p.is_active]),
-        "total_tasks": len(tasks),
-        "completed_tasks": len([t for t in tasks if t.status == "completed"]),
-        "total_issues": len(issues),
-        "resolved_issues": len([i for i in issues if i.status == "resolved"]),
+        "total_tasks": total_tasks,
+        "completed_tasks": completed_tasks,
+        "total_issues": total_issues,
+        "resolved_issues": resolved_issues,
     }
 
 
 @router.get("/{id}", response_model=ProjectResponse)
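Reviewer note: the `... if project_ids else select(...).where(False)` pattern repeated in this hunk guards against issuing `IN ()` over an empty list, which several databases reject; the `where(False)` branch is simply a query that can never match. A hedged sketch of the same guard factored into a helper (names are illustrative, not part of this PR):

```python
from sqlalchemy import select

def select_in_or_nothing(model, column, ids):
    """SELECT over `model`, or a never-matching SELECT when `ids` is empty."""
    if ids:
        return select(model).where(column.in_(ids))
    return select(model).where(False)

# Usage, mirroring the hunk above:
# stmt = select_in_or_nothing(AgentTask, AgentTask.project_id, project_ids)
# agent_tasks = (await db.execute(stmt)).scalars().all()
```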
diff --git a/backend/app/services/agent/__init__.py b/backend/app/services/agent/__init__.py
index fee2a92..f169c80 100644
--- a/backend/app/services/agent/__init__.py
+++ b/backend/app/services/agent/__init__.py
@@ -1,29 +1,19 @@
 """
 DeepAudit Agent service module
-LangGraph-based AI Agent code security auditing
+AI code security auditing built on a dynamic Agent tree architecture
 
-Architecture upgrade - supports:
-- Dynamic Agent tree structure
-- Specialized knowledge module system
-- Inter-Agent communication
-- Full state management
-- Think tool and vulnerability report tool
+Architecture:
+- OrchestratorAgent is the orchestration layer and dynamically schedules sub-agents
+- ReconAgent handles reconnaissance and file analysis
+- AnalysisAgent handles vulnerability analysis
+- VerificationAgent verifies findings
 
 Workflow:
-    START → Recon → Analysis ⟲ → Verification → Report → END
-
+    START → Orchestrator → [Recon/Analysis/Verification] → Report → END
+    Supports dynamically spawning sub-agents for specialized analysis
 """
 
-# Import the main components from the graph module
-from .graph import (
-    AgentRunner,
-    run_agent_task,
-    LLMService,
-    AuditState,
-    create_audit_graph,
-)
-
 # Event management
 from .event_manager import EventManager, AgentEventEmitter
 
@@ -33,14 +23,14 @@ from .agents import (
     OrchestratorAgent,
     ReconAgent,
     AnalysisAgent,
     VerificationAgent,
 )
 
-# 🔥 New: core modules (state management, registry, messages)
+# Core modules (state management, registry, messages)
 from .core import (
     AgentState, AgentStatus,
     AgentRegistry, agent_registry,
     AgentMessage, MessageType, MessagePriority, MessageBus,
 )
 
-# 🔥 New: knowledge module system (RAG-based)
+# Knowledge module system (RAG-based)
 from .knowledge import (
     KnowledgeLoader, knowledge_loader,
     get_available_modules, get_module_content,
@@ -48,7 +38,7 @@ from .knowledge import (
     SecurityKnowledgeQueryTool,
     GetVulnerabilityKnowledgeTool,
 )
 
-# 🔥 New: collaboration tools
+# Collaboration tools
 from .tools import (
     ThinkTool, ReflectTool,
     CreateVulnerabilityReportTool,
@@ -57,24 +47,15 @@ from .tools import (
     WaitForMessageTool,
     AgentFinishTool,
 )
 
-# 🔥 New: telemetry module
+# Telemetry module
 from .telemetry import Tracer, get_global_tracer, set_global_tracer
 
 __all__ = [
-    # Core runner
-    "AgentRunner",
-    "run_agent_task",
-    "LLMService",
-
-    # LangGraph
-    "AuditState",
-    "create_audit_graph",
-
     # Event management
     "EventManager",
     "AgentEventEmitter",
-    
+
     # Agent classes
     "BaseAgent",
     "AgentConfig",
@@ -83,8 +64,8 @@ __all__ = [
     "ReconAgent",
     "AnalysisAgent",
     "VerificationAgent",
-    
-    # 🔥 Core modules
+
+    # Core modules
     "AgentState",
     "AgentStatus",
     "AgentRegistry",
@@ -93,8 +74,8 @@ __all__ = [
     "MessageType",
     "MessagePriority",
     "MessageBus",
-    
-    # 🔥 Knowledge modules (RAG-based)
+
+    # Knowledge modules (RAG-based)
     "KnowledgeLoader",
     "knowledge_loader",
     "get_available_modules",
@@ -103,8 +84,8 @@ __all__ = [
     "security_knowledge_rag",
     "SecurityKnowledgeQueryTool",
     "GetVulnerabilityKnowledgeTool",
-    
-    # 🔥 Collaboration tools
+
+    # Collaboration tools
     "ThinkTool",
     "ReflectTool",
     "CreateVulnerabilityReportTool",
@@ -114,10 +95,9 @@ __all__ = [
     "ViewAgentGraphTool",
     "WaitForMessageTool",
     "AgentFinishTool",
-    
-    # 🔥 Telemetry module
+
+    # Telemetry module
     "Tracer",
     "get_global_tracer",
     "set_global_tracer",
 ]
-
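Reviewer note: after this change the package root still re-exports the dynamic-tree surface, so downstream code only breaks if it imported the removed LangGraph symbols. A quick smoke test of the surviving surface, assuming the package is importable in your environment:

```python
# Still works after this PR (all names remain in __all__):
from app.services.agent import (
    OrchestratorAgent,   # orchestration layer
    EventManager,        # event management
    agent_registry,      # core registry singleton
    ThinkTool,           # collaboration tools
)

# No longer works (the graph/ package is deleted further below):
# from app.services.agent import AgentRunner, AuditState, create_audit_graph
```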
diff --git a/backend/app/services/agent/agents/base.py b/backend/app/services/agent/agents/base.py
index e0da612..bdc5188 100644
--- a/backend/app/services/agent/agents/base.py
+++ b/backend/app/services/agent/agents/base.py
@@ -1024,10 +1024,18 @@ class BaseAgent(ABC):
                 elif chunk["type"] == "error":
                     accumulated = chunk.get("accumulated", "")
                     error_msg = chunk.get("error", "Unknown error")
-                    logger.error(f"[{self.name}] Stream error: {error_msg}")
-                    if accumulated:
-                        total_tokens = chunk.get("usage", {}).get("total_tokens", 0)
-                    else:
+                    error_type = chunk.get("error_type", "unknown")
+                    user_message = chunk.get("user_message", error_msg)
+                    logger.error(f"[{self.name}] Stream error ({error_type}): {error_msg}")
+
+                    if chunk.get("usage"):
+                        total_tokens = chunk["usage"].get("total_tokens", 0)
+
+                    # Mark API errors with a special prefix so callers can recognize them
+                    # Format: [API_ERROR:error_type] user_message
+                    if error_type in ("rate_limit", "quota_exceeded", "authentication", "connection"):
+                        accumulated = f"[API_ERROR:{error_type}] {user_message}"
+                    elif not accumulated:
                         accumulated = f"[System error: {error_msg}] Please rethink and output your decision."
                     break
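Reviewer note: the `[API_ERROR:<type>]` prefix is an in-band sentinel. The streaming layer cannot raise across the accumulated-text interface, so it encodes the failure into the returned string and lets the caller branch on it (the orchestrator hunk below is that caller). A small self-contained sketch of the consumer side of the contract:

```python
import re

API_ERROR_RE = re.compile(r"\[API_ERROR:(\w+)\]\s*(.*)", re.DOTALL)

def classify_llm_output(llm_output: str):
    """Return (error_type, message); error_type is None for normal output."""
    match = API_ERROR_RE.match(llm_output)
    if match:
        return match.group(1), match.group(2)
    return None, llm_output

assert classify_llm_output("[API_ERROR:rate_limit] slow down")[0] == "rate_limit"
assert classify_llm_output("Thought: inspect app.py")[0] is None
```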
diff --git a/backend/app/services/agent/agents/orchestrator.py b/backend/app/services/agent/agents/orchestrator.py
index 73c5e41..1e3c8c1 100644
--- a/backend/app/services/agent/agents/orchestrator.py
+++ b/backend/app/services/agent/agents/orchestrator.py
@@ -284,7 +284,56 @@ Action Input: {{"参数": "值"}}
 
             # Reset the empty-response counter
             self._empty_retry_count = 0
-
+
+            # 🔥 Check whether this is an API error (rather than a format error)
+            if llm_output.startswith("[API_ERROR:"):
+                # Extract the error type and message
+                match = re.match(r"\[API_ERROR:(\w+)\]\s*(.*)", llm_output)
+                if match:
+                    error_type = match.group(1)
+                    error_message = match.group(2)
+
+                    if error_type == "rate_limit":
+                        # Rate limited - wait and retry
+                        api_retry_count = getattr(self, '_api_retry_count', 0) + 1
+                        self._api_retry_count = api_retry_count
+                        if api_retry_count >= 3:
+                            logger.error(f"[{self.name}] Too many rate limit errors, stopping")
+                            await self.emit_event("error", f"Too many API rate limit retries: {error_message}")
+                            break
+                        logger.warning(f"[{self.name}] Rate limit hit, waiting before retry ({api_retry_count}/3)")
+                        await self.emit_event("warning", f"API rate limited, retrying after a wait ({api_retry_count}/3)")
+                        await asyncio.sleep(30)  # wait 30 seconds before retrying
+                        continue
+
+                    elif error_type == "quota_exceeded":
+                        # Quota exhausted - abort the task
+                        logger.error(f"[{self.name}] API quota exceeded: {error_message}")
+                        await self.emit_event("error", f"API quota exhausted: {error_message}")
+                        break
+
+                    elif error_type == "authentication":
+                        # Authentication error - abort the task
+                        logger.error(f"[{self.name}] API authentication error: {error_message}")
+                        await self.emit_event("error", f"API authentication failed: {error_message}")
+                        break
+
+                    elif error_type == "connection":
+                        # Connection error - retry
+                        api_retry_count = getattr(self, '_api_retry_count', 0) + 1
+                        self._api_retry_count = api_retry_count
+                        if api_retry_count >= 3:
+                            logger.error(f"[{self.name}] Too many connection errors, stopping")
+                            await self.emit_event("error", f"Too many API connection error retries: {error_message}")
+                            break
+                        logger.warning(f"[{self.name}] Connection error, retrying ({api_retry_count}/3)")
+                        await self.emit_event("warning", f"API connection error, retrying ({api_retry_count}/3)")
+                        await asyncio.sleep(5)  # wait 5 seconds before retrying
+                        continue
+
+            # Reset the API retry counter (after a successful response)
+            self._api_retry_count = 0
+
             # Parse the LLM's decision
             step = self._parse_llm_response(llm_output)
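Reviewer note: the branching above amounts to a small retry policy: fatal error types abort, retryable ones sleep a fixed 30s (rate limit) or 5s (connection) and share a three-attempt budget. A sketch of the same policy pulled into one helper; the helper name is illustrative, and note the diff uses flat waits, not backoff:

```python
import asyncio

RETRY_WAITS = {"rate_limit": 30, "connection": 5}   # seconds, as in the diff
FATAL = {"quota_exceeded", "authentication"}
MAX_API_RETRIES = 3

async def should_retry(error_type: str, attempt: int) -> bool:
    """Sleep and return True if the error is retryable and budget remains."""
    if error_type in FATAL or error_type not in RETRY_WAITS:
        return False
    if attempt > MAX_API_RETRIES:
        return False
    # The diff sleeps a flat wait; multiplying by 2 ** (attempt - 1)
    # would add exponential backoff if rate limits persist.
    await asyncio.sleep(RETRY_WAITS[error_type])
    return True
```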
"report" - 生成最终报告(推荐:验证完成时) - -请返回 JSON 格式: -{{"action": "analysis/report", "reason": "决策理由"}}""" - - try: - response = await self.llm_service.chat_completion_raw( - messages=[ - {"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"}, - {"role": "user", "content": prompt}, - ], - # 🔥 不传递 temperature 和 max_tokens,使用用户配置 - ) - - content = response.get("content", "") - import re - json_match = re.search(r'\{.*\}', content, re.DOTALL) - if json_match: - result = json.loads(json_match.group()) - return result - except Exception as e: - logger.warning(f"LLM routing decision failed: {e}") - - # 默认决策 - if len(false_positives) > len(verified_findings) and iteration < max_iterations: - return {"action": "analysis", "reason": "误报率较高,需要重新分析"} - return {"action": "report", "reason": "验证完成,生成报告"} - - -# ============ 路由函数 (结合 LLM 决策) ============ - -def route_after_recon(state: AuditState) -> Literal["analysis", "end"]: - """ - Recon 后的路由决策 - 优先使用 LLM 的决策,否则使用默认逻辑 - """ - # 🔥 检查是否有错误 - if state.get("error") or state.get("current_phase") == "error": - logger.error(f"Recon phase has error, routing to end: {state.get('error')}") - return "end" - - # 检查 LLM 是否有决策 - llm_action = state.get("llm_next_action") - if llm_action: - logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}") - if llm_action == "end": - return "end" - return "analysis" - - # 默认逻辑(作为 fallback) - if not state.get("entry_points") and not state.get("high_risk_areas"): - return "end" - return "analysis" - - -def route_after_analysis(state: AuditState) -> Literal["verification", "analysis", "report"]: - """ - Analysis 后的路由决策 - 优先使用 LLM 的决策 - """ - # 检查 LLM 是否有决策 - llm_action = state.get("llm_next_action") - if llm_action: - logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}") - if llm_action == "verification": - return "verification" - elif llm_action == "analysis": - return "analysis" - elif llm_action == "report": - return "report" - - # 默认逻辑 - findings = state.get("findings", []) - iteration = state.get("iteration", 0) - max_iterations = state.get("max_iterations", 3) - should_continue = state.get("should_continue_analysis", False) - - if not findings: - return "report" - - if should_continue and iteration < max_iterations: - return "analysis" - - return "verification" - - -def route_after_verification(state: AuditState) -> Literal["analysis", "report"]: - """ - Verification 后的路由决策 - 优先使用 LLM 的决策 - """ - # 检查 LLM 是否有决策 - llm_action = state.get("llm_next_action") - if llm_action: - logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}") - if llm_action == "analysis": - return "analysis" - return "report" - - # 默认逻辑 - false_positives = state.get("false_positives", []) - iteration = state.get("iteration", 0) - max_iterations = state.get("max_iterations", 3) - - if len(false_positives) > len(state.get("verified_findings", [])) and iteration < max_iterations: - return "analysis" - - return "report" - - -# ============ 创建审计图 ============ - -def create_audit_graph( - recon_node, - analysis_node, - verification_node, - report_node, - checkpointer: Optional[MemorySaver] = None, - llm_service=None, # 用于 LLM 路由决策 -) -> StateGraph: - """ - 创建审计工作流图 - - Args: - recon_node: 信息收集节点 - analysis_node: 漏洞分析节点 - verification_node: 漏洞验证节点 - report_node: 报告生成节点 - checkpointer: 检查点存储器 - llm_service: LLM 服务(用于路由决策) - - Returns: - 编译后的 StateGraph - - 工作流结构: - - START - │ - ▼ - ┌──────┐ - │Recon │ 信息收集 (LLM 驱动) - └──┬───┘ - │ LLM 
diff --git a/backend/app/services/agent/graph/audit_graph.py b/backend/app/services/agent/graph/audit_graph.py
deleted file mode 100644
index 4b0eada..0000000
--- a/backend/app/services/agent/graph/audit_graph.py
+++ /dev/null
@@ -1,677 +0,0 @@
-"""
-DeepAudit audit workflow graph - LLM-driven version
-Uses LangGraph to build an LLM-driven Agent collaboration flow
-
-Key change: routing decisions involve the LLM instead of hard-coded conditions!
-"""
-
-from typing import TypedDict, Annotated, List, Dict, Any, Optional, Literal
-from datetime import datetime
-import operator
-import logging
-import json
-
-from langgraph.graph import StateGraph, END
-from langgraph.checkpoint.memory import MemorySaver
-from langgraph.prebuilt import ToolNode
-
-logger = logging.getLogger(__name__)
-
-
-# ============ State definitions ============
-
-class Finding(TypedDict):
-    """A vulnerability finding"""
-    id: str
-    vulnerability_type: str
-    severity: str
-    title: str
-    description: str
-    file_path: Optional[str]
-    line_start: Optional[int]
-    code_snippet: Optional[str]
-    is_verified: bool
-    confidence: float
-    source: str
-
-
-class AuditState(TypedDict):
-    """
-    Audit state - passed along and updated throughout the workflow
-    """
-    # Inputs
-    project_root: str
-    project_info: Dict[str, Any]
-    config: Dict[str, Any]
-    task_id: str
-
-    # Recon phase output
-    tech_stack: Dict[str, Any]
-    entry_points: List[Dict[str, Any]]
-    high_risk_areas: List[str]
-    dependencies: Dict[str, Any]
-
-    # Analysis phase output
-    findings: Annotated[List[Finding], operator.add]  # merged across iterations via add
-
-    # Verification phase output
-    verified_findings: List[Finding]
-    false_positives: List[str]
-    # 🔥 NEW: the full findings after verification (used to replace the original findings)
-    _verified_findings_update: Optional[List[Finding]]
-
-    # Control flow - 🔥 key point: the LLM can set these to influence routing
-    current_phase: str
-    iteration: int
-    max_iterations: int
-    should_continue_analysis: bool
-
-    # 🔥 New: the LLM's routing decision
-    llm_next_action: Optional[str]  # LLM-suggested next step: "continue_analysis", "verify", "report", "end"
-    llm_routing_reason: Optional[str]  # the LLM's rationale
-
-    # 🔥 New: task handoff info for inter-Agent collaboration
-    recon_handoff: Optional[Dict[str, Any]]  # Recon -> Analysis handoff
-    analysis_handoff: Optional[Dict[str, Any]]  # Analysis -> Verification handoff
-    verification_handoff: Optional[Dict[str, Any]]  # Verification -> Report handoff
-
-    # Messages and events
-    messages: Annotated[List[Dict], operator.add]
-    events: Annotated[List[Dict], operator.add]
-
-    # Final output
-    summary: Optional[Dict[str, Any]]
-    security_score: Optional[int]
-    error: Optional[str]
-
-
-# ============ LLM routing decider ============
-
-class LLMRouter:
-    """
-    LLM routing decider - lets the LLM decide what to do next
-    """
-
-    def __init__(self, llm_service):
-        self.llm_service = llm_service
-
-    async def decide_after_recon(self, state: AuditState) -> Dict[str, Any]:
-        """Let the LLM decide the next step after Recon"""
-        entry_points = state.get("entry_points", [])
-        high_risk_areas = state.get("high_risk_areas", [])
-        tech_stack = state.get("tech_stack", {})
-        initial_findings = state.get("findings", [])
-
-        prompt = f"""As the decision-maker for this security audit, choose the next action based on the reconnaissance results below.
-
-## Reconnaissance results
-- Entry point count: {len(entry_points)}
-- High-risk areas: {high_risk_areas[:10]}
-- Tech stack: {tech_stack}
-- Preliminary findings: {len(initial_findings)}
-
-## Options
-1. "analysis" - proceed to vulnerability analysis (recommended when there are entry points or high-risk areas)
-2. "end" - end the audit (only when there is nothing at all to analyze)
-
-Reply in JSON:
-{{"action": "analysis or end", "reason": "rationale"}}"""
-
-        try:
-            response = await self.llm_service.chat_completion_raw(
-                messages=[
-                    {"role": "system", "content": "You are the decision-maker of the security audit workflow, responsible for choosing the next action."},
-                    {"role": "user", "content": prompt},
-                ],
-                # 🔥 Do not pass temperature or max_tokens; use the user's configuration
-            )
-
-            content = response.get("content", "")
-            # Extract the JSON
-            import re
-            json_match = re.search(r'\{.*\}', content, re.DOTALL)
-            if json_match:
-                result = json.loads(json_match.group())
-                return result
-        except Exception as e:
-            logger.warning(f"LLM routing decision failed: {e}")
-
-        # Default decision
-        if entry_points or high_risk_areas:
-            return {"action": "analysis", "reason": "there is content to analyze"}
-        return {"action": "end", "reason": "no entry points or high-risk areas found"}
-
-    async def decide_after_analysis(self, state: AuditState) -> Dict[str, Any]:
-        """Let the LLM decide the next step after Analysis"""
-        findings = state.get("findings", [])
-        iteration = state.get("iteration", 0)
-        max_iterations = state.get("max_iterations", 3)
-
-        # Tally the findings
-        severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
-        for f in findings:
-            # Skip findings that are not dicts
-            if not isinstance(f, dict):
-                continue
-            sev = f.get("severity", "medium")
-            severity_counts[sev] = severity_counts.get(sev, 0) + 1
-
-        prompt = f"""As the decision-maker for this security audit, choose the next action based on the analysis results below.
-
-## Analysis results
-- Total findings: {len(findings)}
-- Severity distribution: {severity_counts}
-- Current iteration: {iteration}/{max_iterations}
-
-## Options
-1. "verification" - verify the discovered vulnerabilities (recommended when there are findings to verify)
-2. "analysis" - keep analyzing in depth (recommended when findings are few but iterations remain)
-3. "report" - generate the report (recommended when there are no findings or the analysis is already thorough)
-
-Reply in JSON:
-{{"action": "verification/analysis/report", "reason": "rationale"}}"""
-
-        try:
-            response = await self.llm_service.chat_completion_raw(
-                messages=[
-                    {"role": "system", "content": "You are the decision-maker of the security audit workflow, responsible for choosing the next action."},
-                    {"role": "user", "content": prompt},
-                ],
-                # 🔥 Do not pass temperature or max_tokens; use the user's configuration
-            )
-
-            content = response.get("content", "")
-            import re
-            json_match = re.search(r'\{.*\}', content, re.DOTALL)
-            if json_match:
-                result = json.loads(json_match.group())
-                return result
-        except Exception as e:
-            logger.warning(f"LLM routing decision failed: {e}")
-
-        # Default decision
-        if not findings:
-            return {"action": "report", "reason": "no vulnerabilities found"}
-        if len(findings) >= 3 or iteration >= max_iterations:
-            return {"action": "verification", "reason": "enough findings to verify"}
-        return {"action": "analysis", "reason": "few findings; keep analyzing"}
-
-    async def decide_after_verification(self, state: AuditState) -> Dict[str, Any]:
-        """Let the LLM decide the next step after Verification"""
-        verified_findings = state.get("verified_findings", [])
-        false_positives = state.get("false_positives", [])
-        iteration = state.get("iteration", 0)
-        max_iterations = state.get("max_iterations", 3)
-
-        prompt = f"""As the decision-maker for this security audit, choose the next action based on the verification results below.
-
-## Verification results
-- Confirmed vulnerabilities: {len(verified_findings)}
-- False positives: {len(false_positives)}
-- Current iteration: {iteration}/{max_iterations}
-
-## Options
-1. "analysis" - go back to the analysis phase and re-analyze (recommended when the false-positive rate is too high)
-2. "report" - generate the final report (recommended when verification is complete)
-
-Reply in JSON:
-{{"action": "analysis/report", "reason": "rationale"}}"""
-
-        try:
-            response = await self.llm_service.chat_completion_raw(
-                messages=[
-                    {"role": "system", "content": "You are the decision-maker of the security audit workflow, responsible for choosing the next action."},
-                    {"role": "user", "content": prompt},
-                ],
-                # 🔥 Do not pass temperature or max_tokens; use the user's configuration
-            )
-
-            content = response.get("content", "")
-            import re
-            json_match = re.search(r'\{.*\}', content, re.DOTALL)
-            if json_match:
-                result = json.loads(json_match.group())
-                return result
-        except Exception as e:
-            logger.warning(f"LLM routing decision failed: {e}")
-
-        # Default decision
-        if len(false_positives) > len(verified_findings) and iteration < max_iterations:
-            return {"action": "analysis", "reason": "high false-positive rate; re-analysis needed"}
-        return {"action": "report", "reason": "verification complete; generate the report"}
-
-
-# ============ Routing functions (combined with LLM decisions) ============
-
-def route_after_recon(state: AuditState) -> Literal["analysis", "end"]:
-    """
-    Routing decision after Recon
-    Prefer the LLM's decision; otherwise fall back to the default logic
-    """
-    # 🔥 Check for errors
-    if state.get("error") or state.get("current_phase") == "error":
-        logger.error(f"Recon phase has error, routing to end: {state.get('error')}")
-        return "end"
-
-    # Check whether the LLM made a decision
-    llm_action = state.get("llm_next_action")
-    if llm_action:
-        logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
-        if llm_action == "end":
-            return "end"
-        return "analysis"
-
-    # Default logic (as a fallback)
-    if not state.get("entry_points") and not state.get("high_risk_areas"):
-        return "end"
-    return "analysis"
-
-
-def route_after_analysis(state: AuditState) -> Literal["verification", "analysis", "report"]:
-    """
-    Routing decision after Analysis
-    Prefer the LLM's decision
-    """
-    # Check whether the LLM made a decision
-    llm_action = state.get("llm_next_action")
-    if llm_action:
-        logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
-        if llm_action == "verification":
-            return "verification"
-        elif llm_action == "analysis":
-            return "analysis"
-        elif llm_action == "report":
-            return "report"
-
-    # Default logic
-    findings = state.get("findings", [])
-    iteration = state.get("iteration", 0)
-    max_iterations = state.get("max_iterations", 3)
-    should_continue = state.get("should_continue_analysis", False)
-
-    if not findings:
-        return "report"
-
-    if should_continue and iteration < max_iterations:
-        return "analysis"
-
-    return "verification"
-
-
-def route_after_verification(state: AuditState) -> Literal["analysis", "report"]:
-    """
-    Routing decision after Verification
-    Prefer the LLM's decision
-    """
-    # Check whether the LLM made a decision
-    llm_action = state.get("llm_next_action")
-    if llm_action:
-        logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}")
-        if llm_action == "analysis":
-            return "analysis"
-        return "report"
-
-    # Default logic
-    false_positives = state.get("false_positives", [])
-    iteration = state.get("iteration", 0)
-    max_iterations = state.get("max_iterations", 3)
-
-    if len(false_positives) > len(state.get("verified_findings", [])) and iteration < max_iterations:
-        return "analysis"
-
-    return "report"
-
-
-# ============ Create the audit graph ============
-
-def create_audit_graph(
-    recon_node,
-    analysis_node,
-    verification_node,
-    report_node,
-    checkpointer: Optional[MemorySaver] = None,
-    llm_service=None,  # for LLM routing decisions
-) -> StateGraph:
-    """
-    Create the audit workflow graph
-
-    Args:
-        recon_node: reconnaissance node
-        analysis_node: vulnerability analysis node
-        verification_node: vulnerability verification node
-        report_node: report generation node
-        checkpointer: checkpoint store
-        llm_service: LLM service (for routing decisions)
-
-    Returns:
-        The compiled StateGraph
-
-    Workflow structure:
-
-    START
-      │
-      ▼
-    ┌──────┐
-    │Recon │  reconnaissance (LLM-driven)
-    └──┬───┘
-       │ LLM decides
-       ▼
-    ┌──────────┐
-    │ Analysis │◄─────┐  vulnerability analysis (LLM-driven, can loop)
-    └────┬─────┘      │
-         │ LLM decides│
-         ▼            │
-    ┌────────────┐    │
-    │Verification│────┘  vulnerability verification (LLM-driven, can backtrack)
-    └─────┬──────┘
-          │ LLM decides
-          ▼
-    ┌──────────┐
-    │  Report  │  report generation
-    └────┬─────┘
-         │
-         ▼
-        END
-    """
-
-    # Create the state graph
-    workflow = StateGraph(AuditState)
-
-    # If an LLM service is available, create the routing decider
-    llm_router = LLMRouter(llm_service) if llm_service else None
-
-    # Wrap the nodes to add LLM routing decisions
-    async def recon_with_routing(state):
-        result = await recon_node(state)
-
-        # The LLM decides the next step
-        if llm_router:
-            decision = await llm_router.decide_after_recon({**state, **result})
-            result["llm_next_action"] = decision.get("action")
-            result["llm_routing_reason"] = decision.get("reason")
-
-        return result
-
-    async def analysis_with_routing(state):
-        result = await analysis_node(state)
-
-        # The LLM decides the next step
-        if llm_router:
-            decision = await llm_router.decide_after_analysis({**state, **result})
-            result["llm_next_action"] = decision.get("action")
-            result["llm_routing_reason"] = decision.get("reason")
-
-        return result
-
-    async def verification_with_routing(state):
-        result = await verification_node(state)
-
-        # The LLM decides the next step
-        if llm_router:
-            decision = await llm_router.decide_after_verification({**state, **result})
-            result["llm_next_action"] = decision.get("action")
-            result["llm_routing_reason"] = decision.get("reason")
-
-        return result
-
-    # Add the nodes
-    if llm_router:
-        workflow.add_node("recon", recon_with_routing)
-        workflow.add_node("analysis", analysis_with_routing)
-        workflow.add_node("verification", verification_with_routing)
-    else:
-        workflow.add_node("recon", recon_node)
-        workflow.add_node("analysis", analysis_node)
-        workflow.add_node("verification", verification_node)
-
-    workflow.add_node("report", report_node)
-
-    # Set the entry point
-    workflow.set_entry_point("recon")
-
-    # Add the conditional edges
-    workflow.add_conditional_edges(
-        "recon",
-        route_after_recon,
-        {
-            "analysis": "analysis",
-            "end": END,
-        }
-    )
-
-    workflow.add_conditional_edges(
-        "analysis",
-        route_after_analysis,
-        {
-            "verification": "verification",
-            "analysis": "analysis",
-            "report": "report",
-        }
-    )
-
-    workflow.add_conditional_edges(
-        "verification",
-        route_after_verification,
-        {
-            "analysis": "analysis",
-            "report": "report",
-        }
-    )
-
-    # Report -> END
-    workflow.add_edge("report", END)
-
-    # Compile the graph
-    if checkpointer:
-        return workflow.compile(checkpointer=checkpointer)
-    else:
-        return workflow.compile()
"analysis_complete", - } - - except Exception as e: - logger.error(f"Analysis node failed: {e}", exc_info=True) - return { - "error": str(e), - "should_continue_analysis": False, - "current_phase": "error", - } - - -class VerificationNode(BaseNode): - """ - 漏洞验证节点 - - 输入: findings, analysis_handoff - 输出: verified_findings, false_positives, verification_handoff - """ - - async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]: - """执行漏洞验证""" - findings = state.get("findings", []) - logger.info(f"[VerificationNode] Received {len(findings)} findings to verify") - - if not findings: - return { - "verified_findings": [], - "false_positives": [], - "current_phase": "verification_complete", - } - - await self.emit_event( - "phase_start", - f"🔐 开始漏洞验证阶段 ({len(findings)} 个待验证)" - ) - - try: - # 🔥 提取 Analysis 的交接信息 - analysis_handoff = self._extract_handoff_from_state(state, "analysis") - if analysis_handoff: - self.agent.receive_handoff(analysis_handoff) - await self.emit_event( - "handoff_received", - f"📨 收到 Analysis Agent 交接: {analysis_handoff.summary[:50]}..." - ) - - # 构建验证输入 - verification_input = { - "previous_results": { - "analysis": { - "data": { - "findings": findings, - } - } - }, - "config": state["config"], - # 🔥 传递交接信息 - "handoff": analysis_handoff, - } - - # 调用 Verification Agent - result = await self.agent.run(verification_input) - - if result.success and result.data: - all_verified_findings = result.data.get("findings", []) - verified = [f for f in all_verified_findings if f.get("is_verified")] - false_pos = [f.get("id", f.get("title", "unknown")) for f in all_verified_findings - if f.get("verdict") == "false_positive"] - - # 🔥 CRITICAL FIX: 用验证结果更新原始 findings - # 创建 findings 的更新映射,基于 (file_path, line_start, vulnerability_type) - verified_map = {} - for vf in all_verified_findings: - key = ( - vf.get("file_path", ""), - vf.get("line_start", 0), - vf.get("vulnerability_type", ""), - ) - verified_map[key] = vf - - # 合并验证结果到原始 findings - updated_findings = [] - seen_keys = set() - - # 首先处理原始 findings,用验证结果更新 - for f in findings: - if not isinstance(f, dict): - continue - key = ( - f.get("file_path", ""), - f.get("line_start", 0), - f.get("vulnerability_type", ""), - ) - if key in verified_map: - # 使用验证后的版本 - updated_findings.append(verified_map[key]) - seen_keys.add(key) - else: - # 保留原始(未验证) - updated_findings.append(f) - seen_keys.add(key) - - # 添加验证结果中的新发现(如果有) - for key, vf in verified_map.items(): - if key not in seen_keys: - updated_findings.append(vf) - - logger.info(f"[VerificationNode] Updated findings: {len(updated_findings)} total, {len(verified)} verified") - - # 🔥 创建交接信息给 Report 节点 - handoff = self.agent.create_handoff( - to_agent="Report", - summary=f"漏洞验证完成。{len(verified)} 个漏洞已确认,{len(false_pos)} 个误报已排除。", - key_findings=verified, - suggested_actions=[ - { - "type": "generate_report", - "description": "生成详细的安全审计报告", - "priority": "high", - }, - { - "type": "remediation_plan", - "description": "为确认的漏洞制定修复计划", - "priority": "high", - }, - ], - attention_points=[ - f"共 {len(verified)} 个漏洞已确认存在", - f"共 {len(false_pos)} 个误报已排除", - "建议按严重程度优先修复 Critical 和 High 级别漏洞", - ], - context_data={ - "verified_count": len(verified), - "false_positive_count": len(false_pos), - "total_analyzed": len(findings), - "verification_rate": len(verified) / len(findings) if findings else 0, - }, - ) - - await self.emit_event( - "phase_complete", - f"✅ 验证完成: {len(verified)} 已确认, {len(false_pos)} 误报" - ) - - return { - # 🔥 CRITICAL: 返回更新后的 findings,这会替换状态中的 findings - # 注意:由于 LangGraph 使用 
diff --git a/backend/app/services/agent/graph/nodes.py b/backend/app/services/agent/graph/nodes.py
deleted file mode 100644
index 0f8b034..0000000
--- a/backend/app/services/agent/graph/nodes.py
+++ /dev/null
@@ -1,556 +0,0 @@
-"""
-LangGraph node implementations
-Each node wraps one Agent's execution logic
-
-Collaboration upgrade: nodes pass structured context and insights to each other via TaskHandoff
-"""
-
-from typing import Dict, Any, List, Optional
-import logging
-
-logger = logging.getLogger(__name__)
-
-# Deferred import to avoid a circular dependency
-def get_audit_state_type():
-    from .audit_graph import AuditState
-    return AuditState
-
-
-class BaseNode:
-    """Node base class"""
-
-    def __init__(self, agent=None, event_emitter=None):
-        self.agent = agent
-        self.event_emitter = event_emitter
-
-    async def emit_event(self, event_type: str, message: str, **kwargs):
-        """Emit an event"""
-        if self.event_emitter:
-            try:
-                await self.event_emitter.emit_info(message)
-            except Exception as e:
-                logger.warning(f"Failed to emit event: {e}")
-
-    def _extract_handoff_from_state(self, state: Dict[str, Any], from_phase: str):
-        """Extract the preceding Agent's handoff from the state"""
-        handoff_data = state.get(f"{from_phase}_handoff")
-        if handoff_data:
-            from ..agents.base import TaskHandoff
-            return TaskHandoff.from_dict(handoff_data)
-        return None
-
-
-class ReconNode(BaseNode):
-    """
-    Reconnaissance node
-
-    Input: project_root, project_info, config
-    Output: tech_stack, entry_points, high_risk_areas, dependencies, recon_handoff
-    """
-
-    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
-        """Run reconnaissance"""
-        await self.emit_event("phase_start", "🔍 Starting the reconnaissance phase")
-
-        try:
-            # Call the Recon Agent
-            result = await self.agent.run({
-                "project_info": state["project_info"],
-                "config": state["config"],
-            })
-
-            if result.success and result.data:
-                data = result.data
-
-                # 🔥 Create the handoff for the Analysis Agent
-                handoff = self.agent.create_handoff(
-                    to_agent="Analysis",
-                    summary=f"Reconnaissance complete. Found {len(data.get('entry_points', []))} entry points and {len(data.get('high_risk_areas', []))} high-risk areas.",
-                    key_findings=data.get("initial_findings", []),
-                    suggested_actions=[
-                        {
-                            "type": "deep_analysis",
-                            "description": f"Analyze the high-risk areas in depth: {', '.join(data.get('high_risk_areas', [])[:5])}",
-                            "priority": "high",
-                        },
-                        {
-                            "type": "entry_point_audit",
-                            "description": "Audit input validation at every entry point",
-                            "priority": "high",
-                        },
-                    ],
-                    attention_points=[
-                        f"Tech stack: {data.get('tech_stack', {}).get('frameworks', [])}",
-                        f"Primary languages: {data.get('tech_stack', {}).get('languages', [])}",
-                    ],
-                    priority_areas=data.get("high_risk_areas", [])[:10],
-                    context_data={
-                        "tech_stack": data.get("tech_stack", {}),
-                        "entry_points": data.get("entry_points", []),
-                        "dependencies": data.get("dependencies", {}),
-                    },
-                )
-
-                await self.emit_event(
-                    "phase_complete",
-                    f"✅ Reconnaissance complete: found {len(data.get('entry_points', []))} entry points"
-                )
-
-                return {
-                    "tech_stack": data.get("tech_stack", {}),
-                    "entry_points": data.get("entry_points", []),
-                    "high_risk_areas": data.get("high_risk_areas", []),
-                    "dependencies": data.get("dependencies", {}),
-                    "current_phase": "recon_complete",
-                    "findings": data.get("initial_findings", []),
-                    # 🔥 Save the handoff
-                    "recon_handoff": handoff.to_dict(),
-                    "events": [{
-                        "type": "recon_complete",
-                        "data": {
-                            "entry_points_count": len(data.get("entry_points", [])),
-                            "high_risk_areas_count": len(data.get("high_risk_areas", [])),
-                            "handoff_summary": handoff.summary,
-                        }
-                    }],
-                }
-            else:
-                return {
-                    "error": result.error or "Recon failed",
-                    "current_phase": "error",
-                }
-
-        except Exception as e:
-            logger.error(f"Recon node failed: {e}", exc_info=True)
-            return {
-                "error": str(e),
-                "current_phase": "error",
-            }
-
-
-class AnalysisNode(BaseNode):
-    """
-    Vulnerability analysis node
-
-    Input: tech_stack, entry_points, high_risk_areas, recon_handoff
-    Output: findings (accumulated), should_continue_analysis, analysis_handoff
-    """
-
-    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
-        """Run vulnerability analysis"""
-        iteration = state.get("iteration", 0) + 1
-
-        await self.emit_event(
-            "phase_start",
-            f"🔬 Starting the vulnerability analysis phase (iteration {iteration})"
-        )
-
-        try:
-            # 🔥 Extract the Recon handoff
-            recon_handoff = self._extract_handoff_from_state(state, "recon")
-            if recon_handoff:
-                self.agent.receive_handoff(recon_handoff)
-                await self.emit_event(
-                    "handoff_received",
-                    f"📨 Received handoff from the Recon Agent: {recon_handoff.summary[:50]}..."
-                )
-
-            # Build the analysis input
-            analysis_input = {
-                "phase_name": "analysis",
-                "project_info": state["project_info"],
-                "config": state["config"],
-                "plan": {
-                    "high_risk_areas": state.get("high_risk_areas", []),
-                },
-                "previous_results": {
-                    "recon": {
-                        "data": {
-                            "tech_stack": state.get("tech_stack", {}),
-                            "entry_points": state.get("entry_points", []),
-                            "high_risk_areas": state.get("high_risk_areas", []),
-                        }
-                    }
-                },
-                # 🔥 Pass along the handoff
-                "handoff": recon_handoff,
-            }
-
-            # Call the Analysis Agent
-            result = await self.agent.run(analysis_input)
-
-            if result.success and result.data:
-                new_findings = result.data.get("findings", [])
-                logger.info(f"[AnalysisNode] Agent returned {len(new_findings)} findings")
-
-                # Decide whether analysis should continue
-                should_continue = (
-                    len(new_findings) >= 5 and
-                    iteration < state.get("max_iterations", 3)
-                )
-
-                # 🔥 Create the handoff for the Verification Agent
-                # Tally severities
-                severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
-                for f in new_findings:
-                    if isinstance(f, dict):
-                        sev = f.get("severity", "medium")
-                        severity_counts[sev] = severity_counts.get(sev, 0) + 1
-
-                handoff = self.agent.create_handoff(
-                    to_agent="Verification",
-                    summary=f"Vulnerability analysis complete. Found {len(new_findings)} potential vulnerabilities (Critical: {severity_counts['critical']}, High: {severity_counts['high']}, Medium: {severity_counts['medium']}, Low: {severity_counts['low']})",
-                    key_findings=new_findings[:20],  # pass the first 20 findings
-                    suggested_actions=[
-                        {
-                            "type": "verify_critical",
-                            "description": "Verify Critical and High severity vulnerabilities first",
-                            "priority": "critical",
-                        },
-                        {
-                            "type": "poc_generation",
-                            "description": "Generate PoCs for confirmed vulnerabilities",
-                            "priority": "high",
-                        },
-                    ],
-                    attention_points=[
-                        f"{severity_counts['critical']} Critical vulnerabilities need immediate verification",
-                        f"{severity_counts['high']} High vulnerabilities need priority verification",
-                        "Watch for false positives, especially from static analysis tools",
-                    ],
-                    priority_areas=[
-                        f.get("file_path", "") for f in new_findings
-                        if f.get("severity") in ["critical", "high"]
-                    ][:10],
-                    context_data={
-                        "severity_distribution": severity_counts,
-                        "total_findings": len(new_findings),
-                        "iteration": iteration,
-                    },
-                )
-
-                await self.emit_event(
-                    "phase_complete",
-                    f"✅ Analysis iteration {iteration} complete: found {len(new_findings)} potential vulnerabilities"
-                )
-
-                return {
-                    "findings": new_findings,
-                    "iteration": iteration,
-                    "should_continue_analysis": should_continue,
-                    "current_phase": "analysis_complete",
-                    # 🔥 Save the handoff
-                    "analysis_handoff": handoff.to_dict(),
-                    "events": [{
-                        "type": "analysis_iteration",
-                        "data": {
-                            "iteration": iteration,
-                            "findings_count": len(new_findings),
-                            "severity_distribution": severity_counts,
-                            "handoff_summary": handoff.summary,
-                        }
-                    }],
-                }
-            else:
-                return {
-                    "iteration": iteration,
-                    "should_continue_analysis": False,
-                    "current_phase": "analysis_complete",
-                }
-
-        except Exception as e:
-            logger.error(f"Analysis node failed: {e}", exc_info=True)
-            return {
-                "error": str(e),
-                "should_continue_analysis": False,
-                "current_phase": "error",
-            }
-
-
-class VerificationNode(BaseNode):
-    """
-    Vulnerability verification node
-
-    Input: findings, analysis_handoff
-    Output: verified_findings, false_positives, verification_handoff
-    """
-
-    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
-        """Run vulnerability verification"""
-        findings = state.get("findings", [])
-        logger.info(f"[VerificationNode] Received {len(findings)} findings to verify")
-
-        if not findings:
-            return {
-                "verified_findings": [],
-                "false_positives": [],
-                "current_phase": "verification_complete",
-            }
-
-        await self.emit_event(
-            "phase_start",
-            f"🔐 Starting the vulnerability verification phase ({len(findings)} pending)"
-        )
-
-        try:
-            # 🔥 Extract the Analysis handoff
-            analysis_handoff = self._extract_handoff_from_state(state, "analysis")
-            if analysis_handoff:
-                self.agent.receive_handoff(analysis_handoff)
-                await self.emit_event(
-                    "handoff_received",
-                    f"📨 Received handoff from the Analysis Agent: {analysis_handoff.summary[:50]}..."
-                )
-
-            # Build the verification input
-            verification_input = {
-                "previous_results": {
-                    "analysis": {
-                        "data": {
-                            "findings": findings,
-                        }
-                    }
-                },
-                "config": state["config"],
-                # 🔥 Pass along the handoff
-                "handoff": analysis_handoff,
-            }
-
-            # Call the Verification Agent
-            result = await self.agent.run(verification_input)
-
-            if result.success and result.data:
-                all_verified_findings = result.data.get("findings", [])
-                verified = [f for f in all_verified_findings if f.get("is_verified")]
-                false_pos = [f.get("id", f.get("title", "unknown")) for f in all_verified_findings
-                             if f.get("verdict") == "false_positive"]
-
-                # 🔥 CRITICAL FIX: update the original findings with the verification results
-                # Build an update map keyed by (file_path, line_start, vulnerability_type)
-                verified_map = {}
-                for vf in all_verified_findings:
-                    key = (
-                        vf.get("file_path", ""),
-                        vf.get("line_start", 0),
-                        vf.get("vulnerability_type", ""),
-                    )
-                    verified_map[key] = vf
-
-                # Merge the verification results into the original findings
-                updated_findings = []
-                seen_keys = set()
-
-                # First pass the original findings, updating with verification results
-                for f in findings:
-                    if not isinstance(f, dict):
-                        continue
-                    key = (
-                        f.get("file_path", ""),
-                        f.get("line_start", 0),
-                        f.get("vulnerability_type", ""),
-                    )
-                    if key in verified_map:
-                        # Use the verified version
-                        updated_findings.append(verified_map[key])
-                        seen_keys.add(key)
-                    else:
-                        # Keep the original (unverified)
-                        updated_findings.append(f)
-                        seen_keys.add(key)
-
-                # Add any new findings from the verification results
-                for key, vf in verified_map.items():
-                    if key not in seen_keys:
-                        updated_findings.append(vf)
-
-                logger.info(f"[VerificationNode] Updated findings: {len(updated_findings)} total, {len(verified)} verified")
-
-                # 🔥 Create the handoff for the Report node
-                handoff = self.agent.create_handoff(
-                    to_agent="Report",
-                    summary=f"Vulnerability verification complete. {len(verified)} vulnerabilities confirmed, {len(false_pos)} false positives eliminated.",
-                    key_findings=verified,
-                    suggested_actions=[
-                        {
-                            "type": "generate_report",
-                            "description": "Generate a detailed security audit report",
-                            "priority": "high",
-                        },
-                        {
-                            "type": "remediation_plan",
-                            "description": "Draft a remediation plan for the confirmed vulnerabilities",
-                            "priority": "high",
-                        },
-                    ],
-                    attention_points=[
-                        f"{len(verified)} vulnerabilities confirmed to exist",
-                        f"{len(false_pos)} false positives eliminated",
-                        "Fix Critical and High severity vulnerabilities first",
-                    ],
-                    context_data={
-                        "verified_count": len(verified),
-                        "false_positive_count": len(false_pos),
-                        "total_analyzed": len(findings),
-                        "verification_rate": len(verified) / len(findings) if findings else 0,
-                    },
-                )
-
-                await self.emit_event(
-                    "phase_complete",
-                    f"✅ Verification complete: {len(verified)} confirmed, {len(false_pos)} false positives"
-                )
-
-                return {
-                    # 🔥 CRITICAL: return the updated findings, which replace the findings in state
-                    # Note: since LangGraph uses operator.add, the merge has to be handled in the runner;
-                    # here we return _verified_findings_update as a dedicated field
-                    "_verified_findings_update": updated_findings,
-                    "verified_findings": verified,
-                    "false_positives": false_pos,
-                    "current_phase": "verification_complete",
-                    # 🔥 Save the handoff
-                    "verification_handoff": handoff.to_dict(),
-                    "events": [{
-                        "type": "verification_complete",
-                        "data": {
-                            "verified_count": len(verified),
-                            "false_positive_count": len(false_pos),
-                            "total_findings": len(updated_findings),
-                            "handoff_summary": handoff.summary,
-                        }
-                    }],
-                }
-            else:
-                return {
-                    "verified_findings": [],
-                    "false_positives": [],
-                    "current_phase": "verification_complete",
-                }
-
-        except Exception as e:
-            logger.error(f"Verification node failed: {e}", exc_info=True)
-            return {
-                "error": str(e),
-                "current_phase": "error",
-            }
-
-
-class ReportNode(BaseNode):
-    """
-    Report generation node
-
-    Input: all state
-    Output: summary, security_score
-    """
-
-    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
-        """Generate the audit report"""
-        await self.emit_event("phase_start", "📊 Generating the audit report")
-
-        try:
-            # 🔥 CRITICAL FIX: prefer the post-verification findings update
-            findings = state.get("_verified_findings_update") or state.get("findings", [])
-            verified = state.get("verified_findings", [])
-            false_positives = state.get("false_positives", [])
-
-            logger.info(f"[ReportNode] State contains {len(findings)} findings, {len(verified)} verified")
-
-            # Tally the vulnerability distribution
-            severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
-            type_counts = {}
-
-            for finding in findings:
-                # Skip findings that are not dicts (guards against malformed data)
-                if not isinstance(finding, dict):
-                    logger.warning(f"Skipping invalid finding (not a dict): {type(finding)}")
-                    continue
-
-                sev = finding.get("severity", "medium")
-                severity_counts[sev] = severity_counts.get(sev, 0) + 1
-
-                vtype = finding.get("vulnerability_type", "other")
-                type_counts[vtype] = type_counts.get(vtype, 0) + 1
-
-            # Compute the security score
-            base_score = 100
-            deductions = (
-                severity_counts["critical"] * 25 +
-                severity_counts["high"] * 15 +
-                severity_counts["medium"] * 8 +
-                severity_counts["low"] * 3
-            )
-            security_score = max(0, base_score - deductions)
-
-            # Build the summary
-            summary = {
-                "total_findings": len(findings),
-                "verified_count": len(verified),
-                "false_positive_count": len(false_positives),
-                "severity_distribution": severity_counts,
-                "vulnerability_types": type_counts,
-                "tech_stack": state.get("tech_stack", {}),
-                "entry_points_analyzed": len(state.get("entry_points", [])),
-                "high_risk_areas": state.get("high_risk_areas", []),
-                "iterations": state.get("iteration", 1),
-            }
-
-            await self.emit_event(
-                "phase_complete",
-                f"Report generated: security score {security_score}/100"
-            )
-
-            return {
-                "summary": summary,
-                "security_score": security_score,
-                "current_phase": "complete",
-                "events": [{
-                    "type": "audit_complete",
-                    "data": {
-                        "security_score": security_score,
-                        "total_findings": len(findings),
-                        "verified_count": len(verified),
-                    }
-                }],
-            }
-
-        except Exception as e:
-            logger.error(f"Report node failed: {e}", exc_info=True)
-            return {
-                "error": str(e),
-                "current_phase": "error",
-            }
-
-
-class HumanReviewNode(BaseNode):
-    """
-    Human review node
-
-    Pauses here and waits for human feedback
-    """
-
-    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Human review node
-
-        This node is paused via interrupt_before.
-        The user can:
-        1. Confirm findings
-        2. Mark false positives
-        3. Request re-analysis
-        """
-        await self.emit_event(
-            "human_review",
-            f"⏸️ Waiting for human review ({len(state.get('verified_findings', []))} pending confirmation)"
-        )
-
-        # Return the current state unchanged
-        # Human feedback is injected via update_state
-        return {
-            "current_phase": "human_review",
-            "messages": [{
-                "role": "system",
-                "content": "Waiting for human review",
-                "findings_for_review": state.get("verified_findings", []),
-            }],
-        }
-
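Reviewer note: a recurring pain point in this deleted code is that `findings` was declared as `Annotated[List[Finding], operator.add]`, so a node's return value appends to the list rather than replacing it; that is why `VerificationNode` smuggles its rewritten list through the `_verified_findings_update` side channel. A minimal, runnable sketch of the append-only reducer behavior (assumes `langgraph` is installed; built against its documented `StateGraph` interface):

```python
import operator
from typing import Annotated, TypedDict

from langgraph.graph import END, StateGraph

class S(TypedDict):
    findings: Annotated[list, operator.add]  # reducer: merge by list concat

def analysis(state: S):
    return {"findings": [{"id": "f1", "verified": False}]}

def verification(state: S):
    # This APPENDS a second entry; it cannot rewrite analysis's entry in place.
    return {"findings": [{"id": "f1", "verified": True}]}

g = StateGraph(S)
g.add_node("analysis", analysis)
g.add_node("verification", verification)
g.set_entry_point("analysis")
g.add_edge("analysis", "verification")
g.add_edge("verification", END)

print(g.compile().invoke({"findings": []}))
# Both entries survive; the duplicate illustrates why a side channel was needed.
```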
"kunlun_plugin": KunlunPluginTool(self.project_root), - # 🔥 新增:反思工具 - "reflect": ReflectTool(), - # 🔥 新增:安全知识查询工具(基于RAG) - "query_security_knowledge": SecurityKnowledgeQueryTool(), - "get_vulnerability_knowledge": GetVulnerabilityKnowledgeTool(), - } - - # RAG 工具(Analysis 用于安全相关代码搜索) - if self.retriever: - self.analysis_tools["rag_query"] = RAGQueryTool(self.retriever) # 通用语义搜索 - self.analysis_tools["security_search"] = SecurityCodeSearchTool(self.retriever) # 安全代码搜索 - self.analysis_tools["function_context"] = FunctionContextTool(self.retriever) # 函数上下文 - logger.info("✅ RAG 工具已注册到 Analysis Agent (rag_query, security_search, function_context)") - - # ============ Verification Agent 专属工具 ============ - # 职责:漏洞验证、PoC 执行、误报排除 - self.verification_tools = { - **base_tools, - # 验证工具 - 移除旧的 vulnerability_validation 和 dataflow_analysis,强制使用沙箱 - # 🔥 新增:漏洞报告工具(仅Verification可用)- v2.1: 传递 project_root - "create_vulnerability_report": CreateVulnerabilityReportTool(self.project_root), - # 🔥 新增:反思工具 - "reflect": ReflectTool(), - } - - # 🔥 注册沙箱工具(使用提前初始化的 SandboxManager) - if self.sandbox_manager: - # 🔥 沙箱核心工具 - self.verification_tools["sandbox_exec"] = SandboxTool(self.sandbox_manager) - self.verification_tools["sandbox_http"] = SandboxHttpTool(self.sandbox_manager) - self.verification_tools["verify_vulnerability"] = VulnerabilityVerifyTool(self.sandbox_manager) - - # 🔥 多语言代码测试工具 - self.verification_tools["php_test"] = PhpTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["python_test"] = PythonTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["javascript_test"] = JavaScriptTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["java_test"] = JavaTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["go_test"] = GoTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["ruby_test"] = RubyTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["shell_test"] = ShellTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["universal_code_test"] = UniversalCodeTestTool(self.sandbox_manager, self.project_root) - - # 🔥 漏洞验证专用工具 - self.verification_tools["test_command_injection"] = CommandInjectionTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["test_sql_injection"] = SqlInjectionTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["test_xss"] = XssTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["test_path_traversal"] = PathTraversalTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["test_ssti"] = SstiTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["test_deserialization"] = DeserializationTestTool(self.sandbox_manager, self.project_root) - self.verification_tools["universal_vuln_test"] = UniversalVulnTestTool(self.sandbox_manager, self.project_root) - - logger.info(f"✅ Sandbox tools initialized (Docker available: {self.sandbox_manager.is_available})") - else: - logger.error("❌ Sandbox tools NOT initialized due to critical manager failure") - - logger.info(f"✅ Verification tools: {list(self.verification_tools.keys())}") - - # 统计总工具数 - total_tools = len(set( - list(self.recon_tools.keys()) + - list(self.analysis_tools.keys()) + - list(self.verification_tools.keys()) - )) - await self.event_emitter.emit_info(f"已加载 {total_tools} 个工具") - - async def _build_graph(self): - """构建 LangGraph 审计图""" - await self.event_emitter.emit_info("📊 构建 
-
-    async def _initialize_tools(self):
-        """Initialize the toolsets"""
-        await self.event_emitter.emit_info("Initializing the Agent toolsets...")
-
-        # 🔥 Import the new tools
-        from app.services.agent.tools import (
-            ThinkTool, ReflectTool,
-            CreateVulnerabilityReportTool,
-            # Multi-language code test tools
-            PhpTestTool, PythonTestTool, JavaScriptTestTool, JavaTestTool,
-            GoTestTool, RubyTestTool, ShellTestTool, UniversalCodeTestTool,
-            # Dedicated vulnerability verification tools
-            CommandInjectionTestTool, SqlInjectionTestTool, XssTestTool,
-            PathTraversalTestTool, SstiTestTool, DeserializationTestTool,
-            UniversalVulnTestTool,
-            # Kunlun-M static code analysis tools (MIT License)
-            KunlunMTool, KunlunRuleListTool, KunlunPluginTool,
-        )
-        # 🔥 Import the knowledge query tools
-        from app.services.agent.knowledge import (
-            SecurityKnowledgeQueryTool,
-            GetVulnerabilityKnowledgeTool,
-        )
-
-        # 🔥 Fetch the exclude patterns and target files
-        exclude_patterns = self.task.exclude_patterns or []
-        target_files = self.task.target_files or None
-
-        # ============ 🔥 Initialize the SandboxManager early (shared by all external tools) ============
-        self.sandbox_manager = None
-        try:
-            from app.services.agent.tools.sandbox_tool import SandboxConfig
-            sandbox_config = SandboxConfig(
-                image=settings.SANDBOX_IMAGE,
-                memory_limit=settings.SANDBOX_MEMORY_LIMIT,
-                cpu_limit=settings.SANDBOX_CPU_LIMIT,
-                timeout=settings.SANDBOX_TIMEOUT,
-                network_mode=settings.SANDBOX_NETWORK_MODE,
-            )
-            self.sandbox_manager = SandboxManager(config=sandbox_config)
-            # 🔥 initialize() must be called to connect to Docker
-            await self.sandbox_manager.initialize()
-            logger.info(f"✅ SandboxManager initialized early (Docker available: {self.sandbox_manager.is_available})")
-        except Exception as e:
-            logger.warning(f"❌ Early Sandbox Manager initialization failed: {e}")
-            import traceback
-            logger.warning(f"Traceback: {traceback.format_exc()}")
-            # Try a default manager as a fallback
-            try:
-                self.sandbox_manager = SandboxManager()
-                await self.sandbox_manager.initialize()
-                logger.info(f"⚠️ Created fallback SandboxManager (Docker available: {self.sandbox_manager.is_available})")
-            except Exception as e2:
-                logger.error(f"❌ Failed to create fallback SandboxManager: {e2}")
-
-        # ============ Base tools (shared by all Agents) ============
-        base_tools = {
-            "read_file": FileReadTool(self.project_root, exclude_patterns, target_files),
-            "list_files": ListFilesTool(self.project_root, exclude_patterns, target_files),
-            # 🔥 New: think tool (available to every Agent)
-            "think": ThinkTool(),
-        }
-
-        # ============ Recon Agent tools ============
-        # Responsibilities: information gathering, project structure analysis, tech stack identification
-        # 🔥 New: external tools are also usable for quick scans during Recon
-        self.recon_tools = {
-            **base_tools,
-            "search_code": FileSearchTool(self.project_root, exclude_patterns, target_files),
-            # 🔥 New: reflection tool
-            "reflect": ReflectTool(),
-            # 🔥 External security tools (share the SandboxManager instance)
-            "semgrep_scan": SemgrepTool(self.project_root, self.sandbox_manager),
-            "bandit_scan": BanditTool(self.project_root, self.sandbox_manager),
-            "gitleaks_scan": GitleaksTool(self.project_root, self.sandbox_manager),
-            "safety_scan": SafetyTool(self.project_root, self.sandbox_manager),
-            "npm_audit": NpmAuditTool(self.project_root, self.sandbox_manager),
-        }
-
-        # RAG tools (Recon uses them for semantic search)
-        if self.retriever:
-            self.recon_tools["rag_query"] = RAGQueryTool(self.retriever)
-            logger.info("✅ RAG tool registered with the Recon Agent")
-        else:
-            logger.warning("⚠️ RAG not initialized; the rag_query tool is unavailable")
-
-        # ============ Analysis Agent tools ============
-        # Responsibilities: vulnerability analysis, code auditing, pattern matching
-        self.analysis_tools = {
-            **base_tools,
-            "search_code": FileSearchTool(self.project_root, exclude_patterns, target_files),
-            # Pattern matching and code analysis
-            "pattern_match": PatternMatchTool(self.project_root),
-            # TODO: code_analysis is temporarily disabled because its LLM calls often fail
-            # "code_analysis": CodeAnalysisTool(self.llm_service),
-            "dataflow_analysis": DataFlowAnalysisTool(self.llm_service),
-            # 🔥 External static analysis tools (share the SandboxManager instance)
-            "semgrep_scan": SemgrepTool(self.project_root, self.sandbox_manager),
-            "bandit_scan": BanditTool(self.project_root, self.sandbox_manager),
-            "gitleaks_scan": GitleaksTool(self.project_root, self.sandbox_manager),
-            "trufflehog_scan": TruffleHogTool(self.project_root, self.sandbox_manager),
-            "npm_audit": NpmAuditTool(self.project_root, self.sandbox_manager),
-            "safety_scan": SafetyTool(self.project_root, self.sandbox_manager),
-            "osv_scan": OSVScannerTool(self.project_root, self.sandbox_manager),
-            # 🔥 Kunlun-M static code analysis tools (MIT License - https://github.com/LoRexxar/Kunlun-M)
-            "kunlun_scan": KunlunMTool(self.project_root),
-            "kunlun_list_rules": KunlunRuleListTool(self.project_root),
-            "kunlun_plugin": KunlunPluginTool(self.project_root),
-            # 🔥 New: reflection tool
-            "reflect": ReflectTool(),
-            # 🔥 New: security knowledge query tools (RAG-based)
-            "query_security_knowledge": SecurityKnowledgeQueryTool(),
-            "get_vulnerability_knowledge": GetVulnerabilityKnowledgeTool(),
-        }
-
-        # RAG tools (Analysis uses them for security-focused code search)
-        if self.retriever:
-            self.analysis_tools["rag_query"] = RAGQueryTool(self.retriever)  # general semantic search
-            self.analysis_tools["security_search"] = SecurityCodeSearchTool(self.retriever)  # security code search
-            self.analysis_tools["function_context"] = FunctionContextTool(self.retriever)  # function context
-            logger.info("✅ RAG tools registered with the Analysis Agent (rag_query, security_search, function_context)")
-
-        # ============ Verification Agent tools ============
-        # Responsibilities: vulnerability verification, PoC execution, false-positive elimination
-        self.verification_tools = {
-            **base_tools,
-            # Verification tools - the old vulnerability_validation and dataflow_analysis were removed; the sandbox is mandatory
-            # 🔥 New: vulnerability report tool (Verification only) - v2.1: pass project_root
-            "create_vulnerability_report": CreateVulnerabilityReportTool(self.project_root),
-            # 🔥 New: reflection tool
-            "reflect": ReflectTool(),
-        }
-
-        # 🔥 Register the sandbox tools (using the early-initialized SandboxManager)
-        if self.sandbox_manager:
-            # 🔥 Core sandbox tools
-            self.verification_tools["sandbox_exec"] = SandboxTool(self.sandbox_manager)
-            self.verification_tools["sandbox_http"] = SandboxHttpTool(self.sandbox_manager)
-            self.verification_tools["verify_vulnerability"] = VulnerabilityVerifyTool(self.sandbox_manager)
-
-            # 🔥 Multi-language code test tools
-            self.verification_tools["php_test"] = PhpTestTool(self.sandbox_manager, self.project_root)
-            self.verification_tools["python_test"] = PythonTestTool(self.sandbox_manager, self.project_root)
-            self.verification_tools["javascript_test"] = JavaScriptTestTool(self.sandbox_manager, self.project_root)
-            self.verification_tools["java_test"] = JavaTestTool(self.sandbox_manager, self.project_root)
-            self.verification_tools["go_test"] = GoTestTool(self.sandbox_manager, self.project_root)
-            self.verification_tools["ruby_test"] = RubyTestTool(self.sandbox_manager, self.project_root)
-            self.verification_tools["shell_test"] = ShellTestTool(self.sandbox_manager, self.project_root)
-            self.verification_tools["universal_code_test"] = UniversalCodeTestTool(self.sandbox_manager, self.project_root)
-
-            # 🔥 Dedicated vulnerability verification tools
-            self.verification_tools["test_command_injection"] = CommandInjectionTestTool(self.sandbox_manager, self.project_root)
-            self.verification_tools["test_sql_injection"] = SqlInjectionTestTool(self.sandbox_manager, self.project_root)
-            self.verification_tools["test_xss"] = XssTestTool(self.sandbox_manager, self.project_root)
-            self.verification_tools["test_path_traversal"] = PathTraversalTestTool(self.sandbox_manager, self.project_root)
-            self.verification_tools["test_ssti"] = SstiTestTool(self.sandbox_manager, self.project_root)
-            self.verification_tools["test_deserialization"] = DeserializationTestTool(self.sandbox_manager, self.project_root)
-            self.verification_tools["universal_vuln_test"] = UniversalVulnTestTool(self.sandbox_manager, self.project_root)
-
-            logger.info(f"✅ Sandbox tools initialized (Docker available: {self.sandbox_manager.is_available})")
-        else:
-            logger.error("❌ Sandbox tools NOT initialized due to critical manager failure")
-
-        logger.info(f"✅ Verification tools: {list(self.verification_tools.keys())}")
-
-        # Count the total number of tools
-        total_tools = len(set(
-            list(self.recon_tools.keys()) +
-            list(self.analysis_tools.keys()) +
-            list(self.verification_tools.keys())
-        ))
-        await self.event_emitter.emit_info(f"Loaded {total_tools} tools")
-            # 3. 构建初始状态
-            task_config = {
-                "target_vulnerabilities": self.task.target_vulnerabilities or [],
-                "verification_level": self.task.verification_level or "sandbox",
-                "exclude_patterns": self.task.exclude_patterns or [],
-                "target_files": self.task.target_files or [],
-                "max_iterations": self.task.max_iterations or 50,
-                "timeout_seconds": self.task.timeout_seconds or 1800,
-            }
-
-            initial_state: AuditState = {
-                "project_root": self.project_root,
-                "project_info": project_info,
-                "config": task_config,
-                "task_id": self.task.id,
-                "tech_stack": {},
-                "entry_points": [],
-                "high_risk_areas": [],
-                "dependencies": {},
-                "findings": [],
-                "verified_findings": [],
-                "false_positives": [],
-                "_verified_findings_update": None,  # 🔥 NEW: 验证后的 findings 更新
-                "current_phase": "start",
-                "iteration": 0,
-                "max_iterations": self.task.max_iterations or 50,
-                "should_continue_analysis": False,
-                # 🔥 Agent 协作交接信息
-                "recon_handoff": None,
-                "analysis_handoff": None,
-                "verification_handoff": None,
-                "messages": [],
-                "events": [],
-                "summary": None,
-                "security_score": None,
-                "error": None,
-            }
-
-            # 4. 执行 LangGraph with astream_events
-            await self.event_emitter.emit_phase_start("langgraph", "🔄 启动 LangGraph 工作流")
-
-            run_config = {
-                "configurable": {
-                    "thread_id": self.task.id,
-                }
-            }
-
-            final_state = None
-
-            # 使用 astream_events 获取详细事件流
-            try:
-                async for event in self.graph.astream_events(
-                    initial_state,
-                    config=run_config,
-                    version="v2",
-                ):
-                    if self._cancelled:
-                        break
-
-                    # 处理 LangGraph 事件
-                    stream_event = await self.stream_handler.process_langgraph_event(event)
-                    if stream_event:
-                        # 同步到 event_emitter 以持久化
-                        await self._sync_stream_event_to_db(stream_event)
-                        yield stream_event
-
-                    # 更新最终状态
-                    if event.get("event") == "on_chain_end":
-                        output = event.get("data", {}).get("output")
-                        if isinstance(output, dict):
-                            final_state = output
-
-            except Exception as e:
-                # 如果 astream_events 不可用,回退到 astream
-                logger.warning(f"astream_events not available, falling back to astream: {e}")
-                async for event in self.graph.astream(initial_state, config=run_config):
-                    if self._cancelled:
-                        break
-
-                    for node_name, node_output in event.items():
-                        await self._handle_node_output(node_name, node_output)
-
-                        # 发射节点事件
-                        yield StreamEvent(
-                            event_type=StreamEventType.NODE_END,
-                            sequence=self.stream_handler._next_sequence(),
-                            node_name=node_name,
-                            data={"message": f"节点 {node_name} 完成"},
-                        )
-
-                        phase_map = {
-                            "recon": AgentTaskPhase.RECONNAISSANCE,
-                            "analysis": AgentTaskPhase.ANALYSIS,
-                            "verification": AgentTaskPhase.VERIFICATION,
-                            "report": AgentTaskPhase.REPORTING,
-                        }
-                        if node_name in phase_map:
-                            await self._update_task_phase(phase_map[node_name])
-
-                        final_state = node_output
-
-            # 5. 获取最终状态
-            # 🔥 CRITICAL FIX: 始终从 graph 获取完整的累积状态
-            # 因为每个节点只返回自己的输出,findings 等字段是通过 operator.add 累积的
-            # 直接使用 node_output 会丢失之前节点累积的 findings
-            graph_state = self.graph.get_state(run_config)
-            if graph_state and graph_state.values:
-                # 合并完整状态和最后节点的输出
-                full_state = graph_state.values
-                if final_state:
-                    # 保留最后节点的输出(如 summary, security_score)
-                    full_state = {**full_state, **final_state}
-                final_state = full_state
-                logger.info(f"[Runner] Got full state from graph with {len(final_state.get('findings', []))} findings")
-            elif not final_state:
-                final_state = {}
-                logger.warning("[Runner] No final state available from graph")
-
-            # 🔥 CRITICAL FIX: 如果有验证后的 findings 更新,使用它替换原始 findings
-            # 这是因为 LangGraph 的 operator.add 累积器不适合更新已有 findings
-            verified_findings_update = final_state.get("_verified_findings_update")
-            if verified_findings_update:
-                logger.info(f"[Runner] Using verified findings update: {len(verified_findings_update)} findings")
-                final_state["findings"] = verified_findings_update
-            else:
-                # 🔥 FALLBACK: 如果没有 _verified_findings_update,尝试从 verified_findings 合并
-                findings = final_state.get("findings", [])
-                verified_findings = final_state.get("verified_findings", [])
-
-                if verified_findings and findings:
-                    # 创建合并后的 findings 列表
-                    merged_findings = self._merge_findings_with_verification(findings, verified_findings)
-                    final_state["findings"] = merged_findings
-                    logger.info(f"[Runner] Merged findings: {len(merged_findings)} total")
-                elif verified_findings and not findings:
-                    # 如果只有 verified_findings,直接使用
-                    final_state["findings"] = verified_findings
-                    logger.info(f"[Runner] Using verified_findings directly: {len(verified_findings)}")
-
-            logger.info(f"[Runner] Final findings count: {len(final_state.get('findings', []))}")
-
-            # 🔥 检查是否有错误
-            error = final_state.get("error")
-            if error:
-                # 检查是否是 LLM 认证错误
-                error_str = str(error)
-                if "AuthenticationError" in error_str or "API key" in error_str or "invalid_api_key" in error_str:
-                    error_message = "LLM API 密钥配置错误。请检查环境变量 LLM_API_KEY 或配置中的 API 密钥是否正确。"
-                    logger.error(f"LLM authentication error: {error}")
-                else:
-                    error_message = error_str
-
-                duration_ms = int((time.time() - start_time) * 1000)
-
-                # 标记任务为失败
-                await self._update_task_status(AgentTaskStatus.FAILED, error_message)
-                await self.event_emitter.emit_task_error(error_message)
-
-                yield StreamEvent(
-                    event_type=StreamEventType.TASK_ERROR,
-                    sequence=self.stream_handler._next_sequence(),
-                    data={
-                        "error": error_message,
-                        "message": f"❌ 任务失败: {error_message}",
-                    },
-                )
-                return
-
-            # 6. 保存发现
-            findings = final_state.get("findings", [])
-            await self._save_findings(findings)
-
-            # 发射发现事件
-            for finding in findings[:10]:  # 限制数量
-                yield self.stream_handler.create_finding_event(
-                    finding,
-                    is_verified=finding.get("is_verified", False),
-                )
-
-            # 7. 更新任务摘要
-            summary = final_state.get("summary", {})
-            security_score = final_state.get("security_score", 100)
-
-            await self._update_task_summary(
-                total_findings=len(findings),
-                verified_count=len(final_state.get("verified_findings", [])),
-                security_score=security_score,
-            )
-
-            # 8. 完成
-            duration_ms = int((time.time() - start_time) * 1000)
-
-            await self._update_task_status(AgentTaskStatus.COMPLETED)
-            await self.event_emitter.emit_task_complete(
-                findings_count=len(findings),
-                duration_ms=duration_ms,
-            )
-
-            yield StreamEvent(
-                event_type=StreamEventType.TASK_COMPLETE,
-                sequence=self.stream_handler._next_sequence(),
-                data={
-                    "findings_count": len(findings),
-                    "verified_count": len(final_state.get("verified_findings", [])),
-                    "security_score": security_score,
-                    "duration_ms": duration_ms,
-                    "message": f"✅ 审计完成!发现 {len(findings)} 个漏洞",
-                },
-            )
-
-        except asyncio.CancelledError:
-            await self._update_task_status(AgentTaskStatus.CANCELLED)
-            yield StreamEvent(
-                event_type=StreamEventType.TASK_CANCEL,
-                sequence=self.stream_handler._next_sequence(),
-                data={"message": "任务已取消"},
-            )
-
-        except Exception as e:
-            logger.error(f"LangGraph run failed: {e}", exc_info=True)
-            await self._update_task_status(AgentTaskStatus.FAILED, str(e))
-            await self.event_emitter.emit_error(str(e))
-
-            yield StreamEvent(
-                event_type=StreamEventType.TASK_ERROR,
-                sequence=self.stream_handler._next_sequence(),
-                data={"error": str(e), "message": f"❌ 审计失败: {e}"},
-            )
-
-        finally:
-            await self._cleanup()
-
-    async def _sync_stream_event_to_db(self, event: StreamEvent):
-        """同步流式事件到数据库"""
-        try:
-            # 将 StreamEvent 转换为 AgentEventData
-            await self.event_manager.add_event(
-                task_id=self.task.id,
-                event_type=event.event_type.value,
-                sequence=event.sequence,
-                phase=event.phase,
-                message=event.data.get("message"),
-                tool_name=event.tool_name,
-                tool_input=event.data.get("input") or event.data.get("input_params"),
-                tool_output=event.data.get("output") or event.data.get("output_data"),
-                tool_duration_ms=event.data.get("duration_ms"),
-                metadata=event.data,
-            )
-        except Exception as e:
-            logger.warning(f"Failed to sync stream event to db: {e}")
-
-    async def _handle_node_output(self, node_name: str, output: Dict[str, Any]):
-        """处理节点输出"""
-        # 发射节点事件
-        events = output.get("events", [])
-        for evt in events:
-            await self.event_emitter.emit_info(
-                f"[{node_name}] {evt.get('type', 'event')}: {evt.get('data', {})}"
-            )
-
-        # 处理新发现
-        if node_name == "analysis":
-            new_findings = output.get("findings", [])
-            if new_findings:
-                for finding in new_findings[:5]:  # 限制事件数量
-                    await self.event_emitter.emit_finding(
-                        title=finding.get("title", "Unknown"),
-                        severity=finding.get("severity", "medium"),
-                        file_path=finding.get("file_path"),
-                    )
-
-        # 处理验证结果
-        if node_name == "verification":
-            verified = output.get("verified_findings", [])
-            for v in verified[:5]:
-                await self.event_emitter.emit_info(
-                    f"✅ 已验证: {v.get('title', 'Unknown')}"
-                )
-
-        # 处理错误
-        if output.get("error"):
-            await self.event_emitter.emit_error(output["error"])
-
-    async def _index_code(self):
-        """索引代码"""
-        if not self.indexer:
-            await self.event_emitter.emit_warning("RAG 未初始化,跳过代码索引")
-            return
-
-        await self._update_task_phase(AgentTaskPhase.INDEXING)
-        await self.event_emitter.emit_phase_start("indexing", "📝 开始代码索引")
-
-        try:
-            async for progress in self.indexer.index_directory(self.project_root):
-                if self._cancelled:
-                    return
-
-                await self.event_emitter.emit_progress(
-                    progress.processed_files,
-                    progress.total_files,
-                    f"正在索引: {progress.current_file or 'N/A'}"
-                )
-
-            await self.event_emitter.emit_phase_complete("indexing", "✅ 代码索引完成")
-
-        except Exception as e:
-            logger.warning(f"Code indexing failed: {e}")
-            await self.event_emitter.emit_warning(f"代码索引失败: {e}")
-
-    async def _collect_project_info(self) -> Dict[str, Any]:
"""收集项目信息""" - info = { - "name": self.task.project.name if self.task.project else "unknown", - "root": self.project_root, - "languages": [], - "file_count": 0, - } - - try: - exclude_dirs = { - "node_modules", "__pycache__", ".git", "venv", ".venv", - "build", "dist", "target", ".idea", ".vscode", - } - - for root, dirs, files in os.walk(self.project_root): - dirs[:] = [d for d in dirs if d not in exclude_dirs] - info["file_count"] += len(files) - - lang_map = { - ".py": "Python", ".js": "JavaScript", ".ts": "TypeScript", - ".java": "Java", ".go": "Go", ".php": "PHP", - ".rb": "Ruby", ".rs": "Rust", ".c": "C", ".cpp": "C++", - } - - for f in files: - ext = os.path.splitext(f)[1].lower() - if ext in lang_map and lang_map[ext] not in info["languages"]: - info["languages"].append(lang_map[ext]) - - except Exception as e: - logger.warning(f"Failed to collect project info: {e}") - - return info - - async def _save_findings(self, findings: List[Dict]): - """保存发现到数据库""" - logger.info(f"[Runner] Saving {len(findings)} findings to database for task {self.task.id}") - - if not findings: - logger.info("[Runner] No findings to save") - return - - severity_map = { - "critical": VulnerabilitySeverity.CRITICAL, - "high": VulnerabilitySeverity.HIGH, - "medium": VulnerabilitySeverity.MEDIUM, - "low": VulnerabilitySeverity.LOW, - "info": VulnerabilitySeverity.INFO, - } - - type_map = { - "sql_injection": VulnerabilityType.SQL_INJECTION, - "nosql_injection": VulnerabilityType.NOSQL_INJECTION, - "xss": VulnerabilityType.XSS, - "command_injection": VulnerabilityType.COMMAND_INJECTION, - "code_injection": VulnerabilityType.CODE_INJECTION, - "path_traversal": VulnerabilityType.PATH_TRAVERSAL, - "file_inclusion": VulnerabilityType.FILE_INCLUSION, - "ssrf": VulnerabilityType.SSRF, - "xxe": VulnerabilityType.XXE, - "deserialization": VulnerabilityType.DESERIALIZATION, - "auth_bypass": VulnerabilityType.AUTH_BYPASS, - "idor": VulnerabilityType.IDOR, - "sensitive_data_exposure": VulnerabilityType.SENSITIVE_DATA_EXPOSURE, - "hardcoded_secret": VulnerabilityType.HARDCODED_SECRET, - "weak_crypto": VulnerabilityType.WEAK_CRYPTO, - "race_condition": VulnerabilityType.RACE_CONDITION, - "business_logic": VulnerabilityType.BUSINESS_LOGIC, - "memory_corruption": VulnerabilityType.MEMORY_CORRUPTION, - } - - for finding in findings: - try: - # 确保 finding 是字典 - if not isinstance(finding, dict): - logger.warning(f"Skipping invalid finding (not a dict): {finding}") - continue - - db_finding = AgentFinding( - id=str(uuid.uuid4()), - task_id=self.task.id, - vulnerability_type=type_map.get( - finding.get("vulnerability_type", "other"), - VulnerabilityType.OTHER - ), - severity=severity_map.get( - finding.get("severity", "medium"), - VulnerabilitySeverity.MEDIUM - ), - title=finding.get("title", "Unknown"), - description=finding.get("description", ""), - file_path=finding.get("file_path"), - line_start=finding.get("line_start"), - line_end=finding.get("line_end"), - code_snippet=finding.get("code_snippet"), - source=finding.get("source"), - sink=finding.get("sink"), - suggestion=finding.get("suggestion") or finding.get("recommendation"), - is_verified=finding.get("is_verified", False), - confidence=finding.get("confidence", 0.5), - poc=finding.get("poc"), - status=FindingStatus.VERIFIED if finding.get("is_verified") else FindingStatus.NEW, - ) - - self.db.add(db_finding) - - except Exception as e: - logger.warning(f"Failed to save finding: {e}") - - try: - await self.db.commit() - logger.info(f"[Runner] Successfully saved 
{len(findings)} findings to database") - except Exception as e: - logger.error(f"Failed to commit findings: {e}") - await self.db.rollback() - - async def _update_task_status( - self, - status: AgentTaskStatus, - error: Optional[str] = None - ): - """更新任务状态""" - self.task.status = status - - if status == AgentTaskStatus.RUNNING: - self.task.started_at = datetime.now(timezone.utc) - elif status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]: - self.task.finished_at = datetime.now(timezone.utc) - - if error: - self.task.error_message = error - - try: - await self.db.commit() - except Exception as e: - logger.error(f"Failed to update task status: {e}") - - async def _update_task_phase(self, phase: AgentTaskPhase): - """更新任务阶段""" - self.task.current_phase = phase - try: - await self.db.commit() - except Exception as e: - logger.error(f"Failed to update task phase: {e}") - - async def _update_task_summary( - self, - total_findings: int, - verified_count: int, - security_score: int, - ): - """更新任务摘要""" - self.task.total_findings = total_findings - self.task.verified_findings = verified_count - self.task.security_score = security_score - - try: - await self.db.commit() - except Exception as e: - logger.error(f"Failed to update task summary: {e}") - - def _merge_findings_with_verification( - self, - findings: List[Dict], - verified_findings: List[Dict], - ) -> List[Dict]: - """ - 合并原始 findings 和验证结果 - - Args: - findings: 原始 findings 列表 - verified_findings: 验证后的 findings 列表 - - Returns: - 合并后的 findings 列表 - """ - # 创建验证结果的查找映射 - verified_map = {} - for vf in verified_findings: - if not isinstance(vf, dict): - continue - key = ( - vf.get("file_path", ""), - vf.get("line_start", 0), - vf.get("vulnerability_type", ""), - ) - verified_map[key] = vf - - merged = [] - seen_keys = set() - - # 首先处理原始 findings - for f in findings: - if not isinstance(f, dict): - continue - - key = ( - f.get("file_path", ""), - f.get("line_start", 0), - f.get("vulnerability_type", ""), - ) - - if key in verified_map: - # 使用验证后的版本(包含 is_verified, poc 等) - merged.append(verified_map[key]) - else: - # 保留原始 finding - merged.append(f) - - seen_keys.add(key) - - # 添加验证结果中的新发现(如果有) - for key, vf in verified_map.items(): - if key not in seen_keys: - merged.append(vf) - - return merged - - async def _cleanup(self): - """清理资源""" - try: - if self.sandbox_manager: - await self.sandbox_manager.cleanup() - await self.event_manager.close() - except Exception as e: - logger.warning(f"Cleanup error: {e}") - - -# 便捷函数 -async def run_agent_task( - db: AsyncSession, - task: AgentTask, - project_root: str, -) -> Dict[str, Any]: - """ - 运行 Agent 审计任务 - - Args: - db: 数据库会话 - task: Agent 任务 - project_root: 项目根目录 - - Returns: - 审计结果 - """ - runner = AgentRunner(db, task, project_root) - return await runner.run() - diff --git a/backend/app/services/agent/tools/file_tool.py b/backend/app/services/agent/tools/file_tool.py index f8b1d49..009cf44 100644 --- a/backend/app/services/agent/tools/file_tool.py +++ b/backend/app/services/agent/tools/file_tool.py @@ -6,6 +6,7 @@ import os import re import fnmatch +import asyncio from typing import Optional, List, Dict, Any from pydantic import BaseModel, Field @@ -44,7 +45,37 @@ class FileReadTool(AgentTool): self.project_root = project_root self.exclude_patterns = exclude_patterns or [] self.target_files = set(target_files) if target_files else None - + + @staticmethod + def _read_file_lines_sync(file_path: str, start_idx: int, end_idx: int) -> tuple: + """同步读取文件指定行范围(用于 
asyncio.to_thread)""" + selected_lines = [] + total_lines = 0 + file_size = os.path.getsize(file_path) + + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + for i, line in enumerate(f): + total_lines = i + 1 + if i >= start_idx and i < end_idx: + selected_lines.append(line) + elif i >= end_idx: + if i < end_idx + 1000: + continue + else: + remaining_bytes = file_size - f.tell() + avg_line_size = f.tell() / (i + 1) + estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0 + total_lines = i + 1 + estimated_remaining_lines + break + + return selected_lines, total_lines + + @staticmethod + def _read_all_lines_sync(file_path: str) -> List[str]: + """同步读取文件所有行(用于 asyncio.to_thread)""" + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + return f.readlines() + @property def name(self) -> str: return "read_file" @@ -136,51 +167,34 @@ class FileReadTool(AgentTool): # 🔥 对于大文件,使用流式读取指定行范围 if is_large_file and (start_line is not None or end_line is not None): - # 流式读取,避免一次性加载整个文件 - selected_lines = [] - total_lines = 0 - # 计算实际的起始和结束行 start_idx = max(0, (start_line or 1) - 1) end_idx = end_line if end_line else start_idx + max_lines - - with open(full_path, 'r', encoding='utf-8', errors='ignore') as f: - for i, line in enumerate(f): - total_lines = i + 1 - if i >= start_idx and i < end_idx: - selected_lines.append(line) - elif i >= end_idx: - # 继续计数以获取总行数,但限制读取量 - if i < end_idx + 1000: # 最多再读1000行来估算总行数 - continue - else: - # 估算剩余行数 - remaining_bytes = file_size - f.tell() - avg_line_size = f.tell() / (i + 1) - estimated_remaining_lines = int(remaining_bytes / avg_line_size) if avg_line_size > 0 else 0 - total_lines = i + 1 + estimated_remaining_lines - break - + + # 异步读取文件,避免阻塞事件循环 + selected_lines, total_lines = await asyncio.to_thread( + self._read_file_lines_sync, full_path, start_idx, end_idx + ) + # 更新实际的结束索引 end_idx = min(end_idx, start_idx + len(selected_lines)) else: - # 正常读取小文件 - with open(full_path, 'r', encoding='utf-8', errors='ignore') as f: - lines = f.readlines() - + # 异步读取小文件,避免阻塞事件循环 + lines = await asyncio.to_thread(self._read_all_lines_sync, full_path) + total_lines = len(lines) - + # 处理行范围 if start_line is not None: start_idx = max(0, start_line - 1) else: start_idx = 0 - + if end_line is not None: end_idx = min(total_lines, end_line) else: end_idx = min(total_lines, start_idx + max_lines) - + # 截取指定行 selected_lines = lines[start_idx:end_idx] @@ -259,7 +273,7 @@ class FileSearchTool(AgentTool): self.project_root = project_root self.exclude_patterns = exclude_patterns or [] self.target_files = set(target_files) if target_files else None - + # 从 exclude_patterns 中提取目录排除 self.exclude_dirs = set(self.DEFAULT_EXCLUDE_DIRS) for pattern in self.exclude_patterns: @@ -267,7 +281,13 @@ class FileSearchTool(AgentTool): self.exclude_dirs.add(pattern[:-3]) elif "/" not in pattern and "*" not in pattern: self.exclude_dirs.add(pattern) - + + @staticmethod + def _read_file_lines_sync(file_path: str) -> List[str]: + """同步读取文件所有行(用于 asyncio.to_thread)""" + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + return f.readlines() + @property def name(self) -> str: return "search_code" @@ -360,11 +380,13 @@ class FileSearchTool(AgentTool): continue try: - with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: - lines = f.readlines() - + # 异步读取文件,避免阻塞事件循环 + lines = await asyncio.to_thread( + self._read_file_lines_sync, file_path + ) + files_searched += 1 - + for i, line in enumerate(lines): if 
diff --git a/backend/app/services/llm/adapters/litellm_adapter.py b/backend/app/services/llm/adapters/litellm_adapter.py
index 842406c..3a78549 100644
--- a/backend/app/services/llm/adapters/litellm_adapter.py
+++ b/backend/app/services/llm/adapters/litellm_adapter.py
@@ -416,13 +416,93 @@ class LiteLLMAdapter(BaseLLMAdapter):
                 "finish_reason": "complete",
             }
 
-        except Exception as e:
-            # 🔥 即使出错,也尝试返回估算的 usage
-            logger.error(f"Stream error: {e}")
+        except litellm.exceptions.RateLimitError as e:
+            # 速率限制错误 - 需要特殊处理
+            logger.error(f"Stream rate limit error: {e}")
+            error_msg = str(e)
+            # 区分"余额不足"和"频率超限"
+            if any(keyword in error_msg.lower() for keyword in ["余额不足", "资源包", "充值", "quota", "exceeded", "billing"]):
+                error_type = "quota_exceeded"
+                user_message = "API 配额已用尽,请检查账户余额或升级计划"
+            else:
+                error_type = "rate_limit"
+                # 尝试从错误消息中提取重试时间
+                import re
+                retry_match = re.search(r"retry\s*(?:in|after)\s*(\d+(?:\.\d+)?)\s*s", error_msg, re.IGNORECASE)
+                retry_seconds = float(retry_match.group(1)) if retry_match else 60
+                user_message = f"API 调用频率超限,建议等待 {int(retry_seconds)} 秒后重试"
+
             output_tokens_estimate = estimate_tokens(accumulated_content) if accumulated_content else 0
             yield {
                 "type": "error",
+                "error_type": error_type,
+                "error": error_msg,
+                "user_message": user_message,
+                "accumulated": accumulated_content,
+                "usage": {
+                    "prompt_tokens": input_tokens_estimate,
+                    "completion_tokens": output_tokens_estimate,
+                    "total_tokens": input_tokens_estimate + output_tokens_estimate,
+                } if accumulated_content else None,
+            }
+
+        except litellm.exceptions.AuthenticationError as e:
+            # 认证错误 - API Key 无效
+            logger.error(f"Stream authentication error: {e}")
+            yield {
+                "type": "error",
+                "error_type": "authentication",
                 "error": str(e),
+                "user_message": "API Key 无效或已过期,请检查配置",
+                "accumulated": accumulated_content,
+                "usage": None,
+            }
+
+        except litellm.exceptions.APIConnectionError as e:
+            # 连接错误 - 网络问题
+            logger.error(f"Stream connection error: {e}")
+            yield {
+                "type": "error",
+                "error_type": "connection",
+                "error": str(e),
+                "user_message": "无法连接到 API 服务,请检查网络连接",
+                "accumulated": accumulated_content,
+                "usage": None,
+            }
+
+        except Exception as e:
+            # 其他错误 - 检查是否是包装的速率限制错误
+            error_msg = str(e)
+            logger.error(f"Stream error: {e}")
+
+            # 检查是否是包装的速率限制错误(如 ServiceUnavailableError 包装 RateLimitError)
+            is_rate_limit = any(keyword in error_msg.lower() for keyword in [
+                "ratelimiterror", "rate limit", "429", "resource_exhausted",
+                "quota exceeded", "too many requests"
+            ])
+
+            if is_rate_limit:
+                # 按速率限制错误处理
+                import re
+                # 检查是否是配额用尽
+                if any(keyword in error_msg.lower() for keyword in ["quota", "exceeded", "billing"]):
+                    error_type = "quota_exceeded"
+                    user_message = "API 配额已用尽,请检查账户余额或升级计划"
+                else:
+                    error_type = "rate_limit"
+                    retry_match = re.search(r"retry\s*(?:in|after)\s*(\d+(?:\.\d+)?)\s*s", error_msg, re.IGNORECASE)
+                    retry_seconds = float(retry_match.group(1)) if retry_match else 60
+                    user_message = f"API 调用频率超限,建议等待 {int(retry_seconds)} 秒后重试"
+            else:
+                error_type = "unknown"
+                user_message = "LLM 调用发生错误,请重试"
+
+            output_tokens_estimate = estimate_tokens(accumulated_content) if accumulated_content else 0
+            yield {
+                "type": "error",
+                "error_type": error_type,
+                "error": error_msg,
+                "user_message": user_message,
                 "accumulated": accumulated_content,
                 "usage": {
                     "prompt_tokens": input_tokens_estimate,
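The adapter hunk above distinguishes quota exhaustion from transient rate limiting and extracts a suggested retry delay from the provider message. The classification logic, isolated as a testable sketch (keyword lists mirror the hunk; "retry in Ns" formats vary across providers, so the regex is a heuristic):

```python
# Classify a rate-limit error message as quota exhaustion vs. transient
# throttling, returning a suggested wait in seconds for the latter.
import re

QUOTA_KEYWORDS = ("quota", "exceeded", "billing", "余额不足", "资源包", "充值")

def classify_rate_limit(message: str) -> tuple[str, float | None]:
    lowered = message.lower()
    if any(k in lowered for k in QUOTA_KEYWORDS):
        return "quota_exceeded", None  # waiting will not help
    match = re.search(r"retry\s*(?:in|after)\s*(\d+(?:\.\d+)?)\s*s", message, re.IGNORECASE)
    return "rate_limit", (float(match.group(1)) if match else 60.0)

assert classify_rate_limit("Rate limit reached, retry in 12s") == ("rate_limit", 12.0)
assert classify_rate_limit("Monthly quota exceeded")[0] == "quota_exceeded"
```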
diff --git a/backend/app/services/rag/indexer.py b/backend/app/services/rag/indexer.py
index d82ba68..bdb15ed 100644
--- a/backend/app/services/rag/indexer.py
+++ b/backend/app/services/rag/indexer.py
@@ -739,6 +739,20 @@ class CodeIndexer:
         self._needs_rebuild = False
         self._rebuild_reason = ""
 
+    @staticmethod
+    def _read_file_sync(file_path: str) -> str:
+        """
+        同步读取文件内容(用于 asyncio.to_thread 包装)
+
+        Args:
+            file_path: 文件路径
+
+        Returns:
+            文件内容
+        """
+        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+            return f.read()
+
     async def initialize(self, force_rebuild: bool = False) -> Tuple[bool, str]:
         """
         初始化索引器,检测是否需要重建索引
@@ -916,8 +930,10 @@ class CodeIndexer:
             try:
                 relative_path = os.path.relpath(file_path, directory)
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    content = f.read()
+                # 异步读取文件,避免阻塞事件循环
+                content = await asyncio.to_thread(
+                    self._read_file_sync, file_path
+                )
 
                 if not content.strip():
                     progress.processed_files += 1
@@ -932,8 +948,8 @@ class CodeIndexer:
                 if len(content) > 500000:
                     content = content[:500000]
 
-                # 分块
-                chunks = self.splitter.split_file(content, relative_path)
+                # 异步分块,避免 Tree-sitter 解析阻塞事件循环
+                chunks = await self.splitter.split_file_async(content, relative_path)
 
                 # 为每个 chunk 添加 file_hash
                 for chunk in chunks:
@@ -1018,8 +1034,10 @@ class CodeIndexer:
         for relative_path in files_to_check:
             file_path = current_file_map[relative_path]
             try:
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    content = f.read()
+                # 异步读取文件,避免阻塞事件循环
+                content = await asyncio.to_thread(
+                    self._read_file_sync, file_path
+                )
                 current_hash = hashlib.md5(content.encode()).hexdigest()
                 if current_hash != indexed_file_hashes.get(relative_path):
                     files_to_update.add(relative_path)
@@ -1055,8 +1073,10 @@ class CodeIndexer:
             is_update = relative_path in files_to_update
 
             try:
-                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
-                    content = f.read()
+                # 异步读取文件,避免阻塞事件循环
+                content = await asyncio.to_thread(
+                    self._read_file_sync, file_path
+                )
 
                 if not content.strip():
                     progress.processed_files += 1
@@ -1075,8 +1095,8 @@ class CodeIndexer:
                 if len(content) > 500000:
                     content = content[:500000]
 
-                # 分块
-                chunks = self.splitter.split_file(content, relative_path)
+                # 异步分块,避免 Tree-sitter 解析阻塞事件循环
+                chunks = await self.splitter.split_file_async(content, relative_path)
 
                 # 为每个 chunk 添加 file_hash
                 for chunk in chunks:
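The indexer hunks above re-chunk only files whose md5 changed since the last run, reading off-loop. A compact sketch of that change check under the same `to_thread` pattern (function names are illustrative, not the indexer's API):

```python
# Decide whether a file needs re-indexing by comparing its current content
# hash against the hash recorded at last index time.
import asyncio
import hashlib

def _read_sync(path: str) -> str:
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        return f.read()

async def needs_reindex(path: str, indexed_hashes: dict[str, str]) -> bool:
    content = await asyncio.to_thread(_read_sync, path)  # off-loop read
    return hashlib.md5(content.encode()).hexdigest() != indexed_hashes.get(path)

print(asyncio.run(needs_reindex(__file__, {})))  # True: never indexed before
```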
diff --git a/backend/app/services/rag/splitter.py b/backend/app/services/rag/splitter.py
index 184db35..5c350b0 100644
--- a/backend/app/services/rag/splitter.py
+++ b/backend/app/services/rag/splitter.py
@@ -4,6 +4,7 @@
 """
 
 import re
+import asyncio
 import hashlib
 import logging
 from typing import List, Dict, Any, Optional, Tuple, Set
@@ -154,7 +155,7 @@ class TreeSitterParser:
         ".c": "c",
         ".h": "c",
         ".hpp": "cpp",
-        ".cs": "c_sharp",
+        ".cs": "csharp",
         ".php": "php",
         ".rb": "ruby",
         ".kt": "kotlin",
@@ -197,7 +198,7 @@ class TreeSitterParser:
     # tree-sitter-languages 支持的语言列表
     SUPPORTED_LANGUAGES = {
         "python", "javascript", "typescript", "tsx", "java", "go", "rust",
-        "c", "cpp", "c_sharp", "php", "ruby", "kotlin", "swift", "bash",
+        "c", "cpp", "csharp", "php", "ruby", "kotlin", "swift", "bash",
         "json", "yaml", "html", "css", "sql", "markdown",
     }
@@ -230,21 +231,30 @@ class TreeSitterParser:
         return False
 
     def parse(self, code: str, language: str) -> Optional[Any]:
-        """解析代码返回 AST"""
+        """解析代码返回 AST(同步方法)"""
         if not self._ensure_initialized(language):
             return None
-
+
         parser = self._parsers.get(language)
         if not parser:
             return None
-
+
         try:
             tree = parser.parse(code.encode())
             return tree
         except Exception as e:
             logger.warning(f"Failed to parse code: {e}")
             return None
-
+
+    async def parse_async(self, code: str, language: str) -> Optional[Any]:
+        """
+        异步解析代码返回 AST
+
+        将 CPU 密集型的 Tree-sitter 解析操作放到线程池中执行,
+        避免阻塞事件循环
+        """
+        return await asyncio.to_thread(self.parse, code, language)
+
     def extract_definitions(self, tree: Any, code: str, language: str) -> List[Dict[str, Any]]:
         """从 AST 提取定义"""
         if tree is None:
@@ -449,9 +459,31 @@ class CodeSplitter:
         except Exception as e:
             logger.warning(f"分块失败 {file_path}: {e}, 使用简单分块")
             chunks = self._split_by_lines(content, file_path, language)
-
+
         return chunks
-
+
+    async def split_file_async(
+        self,
+        content: str,
+        file_path: str,
+        language: Optional[str] = None
+    ) -> List[CodeChunk]:
+        """
+        异步分割单个文件
+
+        将 CPU 密集型的分块操作(包括 Tree-sitter 解析)放到线程池中执行,
+        避免阻塞事件循环。
+
+        Args:
+            content: 文件内容
+            file_path: 文件路径
+            language: 编程语言(可选)
+
+        Returns:
+            代码块列表
+        """
+        return await asyncio.to_thread(self.split_file, content, file_path, language)
+
     def _split_by_ast(
         self,
         content: str,
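One caveat worth recording for `parse_async`/`split_file_async`: `asyncio.to_thread` only keeps the event loop responsive while the worker releases the GIL — always true for blocking I/O, true for C extensions like Tree-sitter only to the extent they drop it, and never true for pure-Python CPU work. If profiling still shows the loop stalling, a process pool is the heavier escape hatch. A sketch under that assumption, not the repository's implementation:

```python
# CPU-bound work moved to a separate process so the event loop (and the GIL
# of the main process) stays free; the worker function must be picklable.
import asyncio
from concurrent.futures import ProcessPoolExecutor

def parse(code: str) -> int:
    # stand-in for CPU-heavy parsing work
    return sum(ord(c) for c in code)

async def parse_async(code: str) -> int:
    loop = asyncio.get_running_loop()
    with ProcessPoolExecutor(max_workers=1) as pool:
        return await loop.run_in_executor(pool, parse, code)

if __name__ == "__main__":
    print(asyncio.run(parse_async("def f(): pass")))
```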
diff --git a/backend/tests/agent/test_integration.py b/backend/tests/agent/test_integration.py
deleted file mode 100644
index d314404..0000000
--- a/backend/tests/agent/test_integration.py
+++ /dev/null
@@ -1,355 +0,0 @@
-"""
-Agent 集成测试
-测试完整的审计流程
-"""
-
-import pytest
-import asyncio
-import os
-from unittest.mock import MagicMock, AsyncMock, patch
-from datetime import datetime
-
-from app.services.agent.graph.runner import AgentRunner, LLMService
-from app.services.agent.graph.audit_graph import AuditState, create_audit_graph
-from app.services.agent.graph.nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode
-from app.services.agent.event_manager import EventManager, AgentEventEmitter
-
-
-class TestLLMService:
-    """LLM 服务测试"""
-
-    @pytest.mark.asyncio
-    async def test_llm_service_initialization(self):
-        """测试 LLM 服务初始化"""
-        with patch("app.core.config.settings") as mock_settings:
-            mock_settings.LLM_MODEL = "gpt-4o-mini"
-            mock_settings.LLM_API_KEY = "test-key"
-
-            service = LLMService()
-
-            assert service.model == "gpt-4o-mini"
-
-
-class TestEventManager:
-    """事件管理器测试"""
-
-    def test_event_manager_initialization(self):
-        """测试事件管理器初始化"""
-        manager = EventManager()
-
-        assert manager._event_queues == {}
-        assert manager._event_callbacks == {}
-
-    @pytest.mark.asyncio
-    async def test_event_emitter(self):
-        """测试事件发射器"""
-        manager = EventManager()
-        emitter = AgentEventEmitter("test-task-id", manager)
-
-        await emitter.emit_info("Test message")
-
-        assert emitter._sequence == 1
-
-    @pytest.mark.asyncio
-    async def test_event_emitter_phase_tracking(self):
-        """测试事件发射器阶段跟踪"""
-        manager = EventManager()
-        emitter = AgentEventEmitter("test-task-id", manager)
-
-        await emitter.emit_phase_start("recon", "开始信息收集")
-
-        assert emitter._current_phase == "recon"
-
-    @pytest.mark.asyncio
-    async def test_event_emitter_task_complete(self):
-        """测试任务完成事件"""
-        manager = EventManager()
-        emitter = AgentEventEmitter("test-task-id", manager)
-
-        await emitter.emit_task_complete(findings_count=5, duration_ms=1000)
-
-        assert emitter._sequence == 1
-
-
-class TestAuditGraph:
-    """审计图测试"""
-
-    def test_create_audit_graph(self, mock_event_emitter):
-        """测试创建审计图"""
-        # 创建模拟节点
-        recon_node = MagicMock()
-        analysis_node = MagicMock()
-        verification_node = MagicMock()
-        report_node = MagicMock()
-
-        graph = create_audit_graph(
-            recon_node=recon_node,
-            analysis_node=analysis_node,
-            verification_node=verification_node,
-            report_node=report_node,
-        )
-
-        assert graph is not None
-
-
-class TestReconNode:
-    """Recon 节点测试"""
-
-    @pytest.fixture
-    def recon_node_with_mock_agent(self, mock_event_emitter):
-        """创建带模拟 Agent 的 Recon 节点"""
-        mock_agent = MagicMock()
-        mock_agent.run = AsyncMock(return_value=MagicMock(
-            success=True,
-            data={
-                "tech_stack": {"languages": ["Python"]},
-                "entry_points": [{"path": "src/app.py", "type": "api"}],
-                "high_risk_areas": ["src/sql_vuln.py"],
-                "dependencies": {},
-                "initial_findings": [],
-            }
-        ))
-
-        return ReconNode(mock_agent, mock_event_emitter)
-
-    @pytest.mark.asyncio
-    async def test_recon_node_success(self, recon_node_with_mock_agent):
-        """测试 Recon 节点成功执行"""
-        state = {
-            "project_info": {"name": "Test"},
-            "config": {},
-        }
-
-        result = await recon_node_with_mock_agent(state)
-
-        assert "tech_stack" in result
-        assert "entry_points" in result
-        assert result["current_phase"] == "recon_complete"
-
-    @pytest.mark.asyncio
-    async def test_recon_node_failure(self, mock_event_emitter):
-        """测试 Recon 节点失败处理"""
-        mock_agent = MagicMock()
-        mock_agent.run = AsyncMock(return_value=MagicMock(
-            success=False,
-            error="Test error",
-            data=None,
-        ))
-
-        node = ReconNode(mock_agent, mock_event_emitter)
-
-        result = await node({
-            "project_info": {},
-            "config": {},
-        })
-
-        assert "error" in result
-        assert result["current_phase"] == "error"
-
-
-class TestAnalysisNode:
-    """Analysis 节点测试"""
-
-    @pytest.fixture
-    def analysis_node_with_mock_agent(self, mock_event_emitter):
-        """创建带模拟 Agent 的 Analysis 节点"""
-        mock_agent = MagicMock()
-        mock_agent.run = AsyncMock(return_value=MagicMock(
-            success=True,
-            data={
-                "findings": [
-                    {
-                        "id": "finding-1",
-                        "title": "SQL Injection",
-                        "severity": "high",
-                        "vulnerability_type": "sql_injection",
-                        "file_path": "src/sql_vuln.py",
-                        "line_start": 10,
-                        "description": "SQL injection vulnerability",
-                    }
-                ],
-                "should_continue": False,
-            }
-        ))
-
-        return AnalysisNode(mock_agent, mock_event_emitter)
-
-    @pytest.mark.asyncio
-    async def test_analysis_node_success(self, analysis_node_with_mock_agent):
-        """测试 Analysis 节点成功执行"""
-        state = {
-            "project_info": {"name": "Test"},
-            "tech_stack": {"languages": ["Python"]},
-            "entry_points": [],
-            "high_risk_areas": ["src/sql_vuln.py"],
-            "config": {},
-            "iteration": 0,
-            "findings": [],
-        }
-
-        result = await analysis_node_with_mock_agent(state)
-
-        assert "findings" in result
-        assert len(result["findings"]) > 0
-        assert result["iteration"] == 1
-
-
-class TestIntegrationFlow:
-    """完整流程集成测试"""
-
-    @pytest.mark.asyncio
-    async def test_full_audit_flow_mock(self, temp_project_dir, mock_db_session, mock_task):
-        """测试完整审计流程(使用模拟)"""
-        # 这个测试验证整个流程的连接性
-
-        # 创建事件管理器
-        event_manager = EventManager()
-        emitter = AgentEventEmitter(mock_task.id, event_manager)
-
-        # 模拟 LLM 服务
-        mock_llm = MagicMock()
-        mock_llm.chat_completion_raw = AsyncMock(return_value={
-            "content": "Analysis complete",
-            "usage": {"total_tokens": 100},
-        })
-
-        # 验证事件发射
-        await emitter.emit_phase_start("init", "初始化")
-        await emitter.emit_info("测试消息")
-        await emitter.emit_phase_complete("init", "初始化完成")
-
-        assert emitter._sequence == 3
-
-    @pytest.mark.asyncio
-    async def test_audit_state_typing(self):
-        """测试审计状态类型定义"""
-        state: AuditState = {
-            "project_root": "/tmp/test",
-            "project_info": {"name": "Test"},
-            "config": {},
-            "task_id": "test-id",
-            "tech_stack": {},
-            "entry_points": [],
-            "high_risk_areas": [],
-            "dependencies": {},
-            "findings": [],
-            "verified_findings": [],
-            "false_positives": [],
-            "current_phase": "start",
-            "iteration": 0,
-            "max_iterations": 50,
-            "should_continue_analysis": False,
-            "messages": [],
-            "events": [],
-            "summary": None,
-            "security_score": None,
- "error": None, - } - - assert state["current_phase"] == "start" - assert state["max_iterations"] == 50 - - -class TestToolIntegration: - """工具集成测试""" - - @pytest.mark.asyncio - async def test_tools_work_together(self, temp_project_dir): - """测试工具协同工作""" - from app.services.agent.tools import ( - FileReadTool, FileSearchTool, ListFilesTool, PatternMatchTool, - ) - - # 1. 列出文件 - list_tool = ListFilesTool(temp_project_dir) - list_result = await list_tool.execute(directory="src", recursive=False) - assert list_result.success is True - - # 2. 搜索关键代码 - search_tool = FileSearchTool(temp_project_dir) - search_result = await search_tool.execute(keyword="execute") - assert search_result.success is True - - # 3. 读取文件内容 - read_tool = FileReadTool(temp_project_dir) - read_result = await read_tool.execute(file_path="src/sql_vuln.py") - assert read_result.success is True - - # 4. 模式匹配 - pattern_tool = PatternMatchTool(temp_project_dir) - pattern_result = await pattern_tool.execute( - code=read_result.data, - file_path="src/sql_vuln.py", - language="python" - ) - assert pattern_result.success is True - - -class TestErrorHandling: - """错误处理测试""" - - @pytest.mark.asyncio - async def test_tool_error_handling(self, temp_project_dir): - """测试工具错误处理""" - from app.services.agent.tools import FileReadTool - - tool = FileReadTool(temp_project_dir) - - # 尝试读取不存在的文件 - result = await tool.execute(file_path="nonexistent/file.py") - - assert result.success is False - assert result.error is not None - - @pytest.mark.asyncio - async def test_agent_graceful_degradation(self, mock_event_emitter): - """测试 Agent 优雅降级""" - # 创建一个会失败的 Agent - mock_agent = MagicMock() - mock_agent.run = AsyncMock(side_effect=Exception("Simulated error")) - - node = ReconNode(mock_agent, mock_event_emitter) - - result = await node({ - "project_info": {}, - "config": {}, - }) - - # 应该返回错误状态而不是崩溃 - assert "error" in result - assert result["current_phase"] == "error" - - -class TestPerformance: - """性能测试""" - - @pytest.mark.asyncio - async def test_tool_response_time(self, temp_project_dir): - """测试工具响应时间""" - from app.services.agent.tools import ListFilesTool - import time - - tool = ListFilesTool(temp_project_dir) - - start = time.time() - await tool.execute(directory=".", recursive=True) - duration = time.time() - start - - # 工具应该在合理时间内响应 - assert duration < 5.0 # 5 秒内 - - @pytest.mark.asyncio - async def test_multiple_tool_calls(self, temp_project_dir): - """测试多次工具调用""" - from app.services.agent.tools import FileSearchTool - - tool = FileSearchTool(temp_project_dir) - - # 执行多次调用 - for _ in range(5): - result = await tool.execute(keyword="def") - assert result.success is True - - # 验证调用计数 - assert tool._call_count == 5 - diff --git a/frontend/src/components/audit/components/BasicConfig.tsx b/frontend/src/components/audit/components/BasicConfig.tsx index c82968e..b613f85 100644 --- a/frontend/src/components/audit/components/BasicConfig.tsx +++ b/frontend/src/components/audit/components/BasicConfig.tsx @@ -9,7 +9,7 @@ import { } from "@/components/ui/select"; import { GitBranch, Zap, Info } from "lucide-react"; import type { Project, CreateAuditTaskForm } from "@/shared/types"; -import { isRepositoryProject, isZipProject } from "@/shared/utils/projectUtils"; +import { isRepositoryProject, isZipProject, getRepositoryPlatformLabel } from "@/shared/utils/projectUtils"; import ZipFileSection from "./ZipFileSection"; import type { ZipFileMeta } from "@/shared/utils/zipStorage"; @@ -138,7 +138,7 @@ function ProjectInfoCard({ project }: { project: 
diff --git a/frontend/src/components/audit/components/BasicConfig.tsx b/frontend/src/components/audit/components/BasicConfig.tsx
index c82968e..b613f85 100644
--- a/frontend/src/components/audit/components/BasicConfig.tsx
+++ b/frontend/src/components/audit/components/BasicConfig.tsx
@@ -9,7 +9,7 @@ import {
 } from "@/components/ui/select";
 import { GitBranch, Zap, Info } from "lucide-react";
 import type { Project, CreateAuditTaskForm } from "@/shared/types";
-import { isRepositoryProject, isZipProject } from "@/shared/utils/projectUtils";
+import { isRepositoryProject, isZipProject, getRepositoryPlatformLabel } from "@/shared/utils/projectUtils";
 import ZipFileSection from "./ZipFileSection";
 import type { ZipFileMeta } from "@/shared/utils/zipStorage";
@@ -138,7 +138,7 @@ function ProjectInfoCard({ project }: { project: Project }) {
       {isRepo && (
         <>

-          仓库平台:{project.repository_type?.toUpperCase() || "OTHER"}
+          仓库平台:{getRepositoryPlatformLabel(project.repository_type)}

默认分支:{project.default_branch}

diff --git a/frontend/src/pages/ProjectDetail.tsx b/frontend/src/pages/ProjectDetail.tsx
index 9d282c2..28ea711 100644
--- a/frontend/src/pages/ProjectDetail.tsx
+++ b/frontend/src/pages/ProjectDetail.tsx
@@ -34,13 +34,13 @@ import { api } from "@/shared/config/database";
 import { runRepositoryAudit, scanStoredZipFile } from "@/features/projects/services";
 import type { Project, AuditTask, CreateProjectForm } from "@/shared/types";
 import { hasZipFile } from "@/shared/utils/zipStorage";
-import { isRepositoryProject, getSourceTypeLabel } from "@/shared/utils/projectUtils";
+import { isRepositoryProject, getSourceTypeLabel, getRepositoryPlatformLabel } from "@/shared/utils/projectUtils";
 import { toast } from "sonner";
 import CreateTaskDialog from "@/components/audit/CreateTaskDialog";
 import FileSelectionDialog from "@/components/audit/FileSelectionDialog";
 import TerminalProgressDialog from "@/components/audit/TerminalProgressDialog";
 import { Dialog, DialogContent, DialogHeader, DialogTitle, DialogFooter } from "@/components/ui/dialog";
-import { SUPPORTED_LANGUAGES } from "@/shared/constants";
+import { SUPPORTED_LANGUAGES, REPOSITORY_PLATFORMS } from "@/shared/constants";
 
 export default function ProjectDetail() {
   const { id } = useParams<{ id: string }>();
@@ -475,8 +475,7 @@ export default function ProjectDetail() {
                  仓库平台
-                  {project.repository_type === 'github' ? 'GitHub' :
-                    project.repository_type === 'gitlab' ? 'GitLab' : '其他'}
+                  {getRepositoryPlatformLabel(project.repository_type)}
@@ -529,12 +528,11 @@ export default function ProjectDetail() {
                  className="flex items-center justify-between p-3 bg-muted/50 rounded-lg hover:bg-muted transition-all group"
                >
-
+                      task.status === 'failed' ? 'bg-rose-500/20' :
+                      'bg-muted'
+                    }`}>
                     {getStatusIcon(task.status)}
@@ -579,12 +577,11 @@ export default function ProjectDetail() {
-
+                      task.status === 'failed' ? 'bg-rose-500/20' :
+                      'bg-muted'
+                    }`}>
                     {getStatusIcon(task.status)}
@@ -676,12 +673,11 @@ export default function ProjectDetail() {
-
+                          issue.severity === 'medium' ? 'bg-amber-500/20 text-amber-600 dark:text-amber-400' :
+                          'bg-sky-500/20 text-sky-600 dark:text-sky-400'
+                        }`}>
@@ -695,13 +691,13 @@ export default function ProjectDetail() { {issue.severity === 'critical' ? '严重' : issue.severity === 'high' ? '高' : - issue.severity === 'medium' ? '中等' : '低'} + issue.severity === 'medium' ? '中等' : '低'}

@@ -783,9 +779,11 @@ export default function ProjectDetail() { - GitHub - GitLab - 其他 + {REPOSITORY_PLATFORMS.map((platform) => ( + + {platform.label} + + ))}

@@ -831,14 +829,14 @@ export default function ProjectDetail() { className={`flex items-center space-x-2 p-3 border cursor-pointer transition-all rounded ${editForm.programming_languages?.includes(lang) ? 'border-primary bg-primary/10 text-primary' : 'border-border hover:border-border text-muted-foreground' - }`} + }`} onClick={() => handleToggleLanguage(lang)} >
{editForm.programming_languages?.includes(lang) && ( diff --git a/frontend/src/pages/Projects.tsx b/frontend/src/pages/Projects.tsx index 979f188..7a94a48 100644 --- a/frontend/src/pages/Projects.tsx +++ b/frontend/src/pages/Projects.tsx @@ -45,7 +45,7 @@ import { Link } from "react-router-dom"; import { toast } from "sonner"; import CreateTaskDialog from "@/components/audit/CreateTaskDialog"; import TerminalProgressDialog from "@/components/audit/TerminalProgressDialog"; -import { SUPPORTED_LANGUAGES } from "@/shared/constants"; +import { SUPPORTED_LANGUAGES, REPOSITORY_PLATFORMS } from "@/shared/constants"; export default function Projects() { const [projects, setProjects] = useState([]); @@ -487,10 +487,11 @@ export default function Projects() { - GITHUB - GITLAB - GITEA - OTHER + {REPOSITORY_PLATFORMS.map((platform) => ( + + {platform.label} + + ))}
@@ -1046,10 +1047,11 @@ export default function Projects() { - GITHUB - GITLAB - GITEA - OTHER + {REPOSITORY_PLATFORMS.map((platform) => ( + + {platform.label} + + ))}
diff --git a/frontend/src/pages/TaskDetail.tsx b/frontend/src/pages/TaskDetail.tsx index 0d05359..21d3372 100644 --- a/frontend/src/pages/TaskDetail.tsx +++ b/frontend/src/pages/TaskDetail.tsx @@ -36,7 +36,7 @@ import type { AuditTask, AuditIssue } from "@/shared/types"; import { toast } from "sonner"; import ExportReportDialog from "@/components/reports/ExportReportDialog"; import { calculateTaskProgress } from "@/shared/utils/utils"; -import { isRepositoryProject, getSourceTypeLabel } from "@/shared/utils/projectUtils"; +import { isRepositoryProject, getSourceTypeLabel, getRepositoryPlatformLabel } from "@/shared/utils/projectUtils"; // AI explanation parser function parseAIExplanation(aiExplanation: string) { @@ -86,12 +86,11 @@ function IssuesList({ issues }: { issues: AuditIssue[] }) {
-
+
{getTypeIcon(issue.issue_type)}
@@ -112,7 +111,7 @@ function IssuesList({ issues }: { issues: AuditIssue[] }) {
                    {issue.severity === 'critical' ? '严重' :
                      issue.severity === 'high' ? '高' :
-                      issue.severity === 'medium' ? '中等' : '低'}
+                      issue.severity === 'medium' ? '中等' : '低'}
@@ -702,7 +701,7 @@ export default function TaskDetail() { {isRepositoryProject(task.project) && (

仓库平台

-

{task.project.repository_type?.toUpperCase() || 'OTHER'}

+

{getRepositoryPlatformLabel(task.project.repository_type)}

             )}
             {task.project.programming_languages && (
diff --git a/frontend/src/shared/constants/index.ts b/frontend/src/shared/constants/index.ts
index 2f48818..401ecb2 100644
--- a/frontend/src/shared/constants/index.ts
+++ b/frontend/src/shared/constants/index.ts
@@ -62,13 +62,6 @@ export const PROJECT_SOURCE_TYPES = {
   ZIP: 'zip',
 } as const;
 
-// 仓库平台类型
-export const REPOSITORY_TYPES = {
-  GITHUB: 'github',
-  GITLAB: 'gitlab',
-  OTHER: 'other',
-} as const;
-
 // 分析深度
 export const ANALYSIS_DEPTH = {
   BASIC: 'basic',
diff --git a/frontend/src/shared/constants/projectTypes.ts b/frontend/src/shared/constants/projectTypes.ts
index 55e6d47..4e78e9b 100644
--- a/frontend/src/shared/constants/projectTypes.ts
+++ b/frontend/src/shared/constants/projectTypes.ts
@@ -22,17 +22,23 @@ export const PROJECT_SOURCE_TYPES: Array<{
   }
 ];
 
+// 仓库平台显示名称
+export const REPOSITORY_PLATFORM_LABELS: Record<RepositoryPlatform, string> = {
+  github: 'GitHub',
+  gitlab: 'GitLab',
+  gitea: 'Gitea',
+  other: '其他',
+};
+
 // 仓库平台选项
 export const REPOSITORY_PLATFORMS: Array<{
   value: RepositoryPlatform;
   label: string;
   icon?: string;
-}> = [
-    { value: 'github', label: 'GitHub' },
-    { value: 'gitlab', label: 'GitLab' },
-    { value: 'gitea', label: 'Gitea' },
-    { value: 'other', label: '其他' }
-  ];
+}> = Object.entries(REPOSITORY_PLATFORM_LABELS).map(([value, label]) => ({
+  value: value as RepositoryPlatform,
+  label
+}));
 
 // 项目来源类型的颜色配置
 export const SOURCE_TYPE_COLORS: Record = {
diff --git a/frontend/src/shared/utils/projectUtils.ts b/frontend/src/shared/utils/projectUtils.ts
-    github: 'GitHub',
-    gitlab: 'GitLab',
-    gitea: 'Gitea',
-    other: '其他'
-  };
-  return labels[platform || 'other'] || '其他';
+  return REPOSITORY_PLATFORM_LABELS[platform as keyof typeof REPOSITORY_PLATFORM_LABELS] || REPOSITORY_PLATFORM_LABELS.other;
 }
 
 /**