diff --git a/backend/app/api/v1/endpoints/agent_tasks.py b/backend/app/api/v1/endpoints/agent_tasks.py
index a9f8880..502d4fa 100644
--- a/backend/app/api/v1/endpoints/agent_tasks.py
+++ b/backend/app/api/v1/endpoints/agent_tasks.py
@@ -75,18 +75,46 @@ class AgentTaskCreate(BaseModel):
 class AgentTaskResponse(BaseModel):
-    """Agent task response"""
+    """Agent task response - includes every field the frontend needs"""
     id: str
     project_id: str
     name: Optional[str]
     description: Optional[str]
+    task_type: str = "agent_audit"
     status: str
     current_phase: Optional[str]
+    current_step: Optional[str] = None
 
-    # Statistics
-    total_findings: int = 0
-    verified_findings: int = 0
-    security_score: Optional[int] = None
+    # Progress statistics
+    total_files: int = 0
+    indexed_files: int = 0
+    analyzed_files: int = 0
+    total_chunks: int = 0
+
+    # Agent statistics
+    total_iterations: int = 0
+    tool_calls_count: int = 0
+    tokens_used: int = 0
+
+    # Finding statistics (both naming schemes supported)
+    findings_count: int = 0
+    total_findings: int = 0  # compatibility field
+    verified_count: int = 0
+    verified_findings: int = 0  # compatibility field
+    false_positive_count: int = 0
+
+    # Severity statistics
+    critical_count: int = 0
+    high_count: int = 0
+    medium_count: int = 0
+    low_count: int = 0
+
+    # Scores
+    quality_score: float = 0.0
+    security_score: Optional[float] = None
+
+    # Progress percentage
+    progress_percentage: float = 0.0
 
     # Timestamps
     created_at: datetime
@@ -94,7 +122,11 @@ class AgentTaskResponse(BaseModel):
     completed_at: Optional[datetime] = None
 
     # Configuration
-    config: Optional[dict] = None
+    audit_scope: Optional[dict] = None
+    target_vulnerabilities: Optional[List[str]] = None
+    verification_level: Optional[str] = None
+    exclude_patterns: Optional[List[str]] = None
+    target_files: Optional[List[str]] = None
 
     # Error information
     error_message: Optional[str] = None
@@ -177,6 +209,12 @@ async def _execute_agent_task(task_id: str, project_root: str):
             logger.error(f"Task {task_id} not found")
             return
 
+        # Mark the task as running
+        task.status = AgentTaskStatus.RUNNING
+        task.started_at = datetime.now(timezone.utc)
+        await db.commit()
+        logger.info(f"Task {task_id} started")
+
         # Create the runner
         runner = AgentRunner(db, task, project_root)
         _running_tasks[task_id] = runner
@@ -184,22 +222,37 @@ async def _execute_agent_task(task_id: str, project_root: str):
         # Execute
         result = await runner.run()
 
-        logger.info(f"Task {task_id} completed: {result.get('success', False)}")
+        # Update the task state
+        await db.refresh(task)
+        if result.get('success', True):  # default to success unless explicitly failed
+            task.status = AgentTaskStatus.COMPLETED
+            task.completed_at = datetime.now(timezone.utc)
+        else:
+            task.status = AgentTaskStatus.FAILED
+            task.error_message = result.get('error', 'Unknown error')
+            task.completed_at = datetime.now(timezone.utc)
+
+        await db.commit()
+        logger.info(f"Task {task_id} completed with status: {task.status}")
 
     except Exception as e:
        logger.error(f"Task {task_id} failed: {e}", exc_info=True)
 
         # Update the task state
-        task = await db.get(AgentTask, task_id)
-        if task:
-            task.status = AgentTaskStatus.FAILED
-            task.error_message = str(e)
-            task.completed_at = datetime.now(timezone.utc)
-            await db.commit()
+        try:
+            task = await db.get(AgentTask, task_id)
+            if task:
+                task.status = AgentTaskStatus.FAILED
+                task.error_message = str(e)[:1000]  # cap the error message length
+                task.completed_at = datetime.now(timezone.utc)
+                await db.commit()
+        except Exception as db_error:
+            logger.error(f"Failed to update task status: {db_error}")
 
     finally:
         # Cleanup
         _running_tasks.pop(task_id, None)
+        logger.debug(f"Task {task_id} cleaned up")
 
 
 # ============ API Endpoints ============
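The create endpoint that schedules `_execute_agent_task` is not part of this hunk, so the launch mechanism is implied rather than shown. A minimal sketch of one plausible wiring, assuming the coroutine manages its own DB session and that fire-and-forget scheduling is acceptable (the handle registry and function name below are hypothetical):

```python
import asyncio

# Hypothetical: keep strong references so tasks are not garbage-collected mid-run.
_background_handles: dict[str, asyncio.Task] = {}

async def start_agent_task(task_id: str, project_root: str) -> None:
    """Schedule the audit without blocking the HTTP response."""
    handle = asyncio.create_task(_execute_agent_task(task_id, project_root))
    _background_handles[task_id] = handle
    # Drop the handle once the coroutine finishes, whatever the outcome.
    handle.add_done_callback(lambda _: _background_handles.pop(task_id, None))
```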
@@ -315,7 +368,61 @@ async def get_agent_task(
     if not project or project.owner_id != current_user.id:
         raise HTTPException(status_code=403, detail="No permission to access this task")
 
-    return task
+    # Build the response, making sure every field is present
+    try:
+        # Compute the progress percentage
+        progress = 0.0
+        if hasattr(task, 'progress_percentage'):
+            progress = task.progress_percentage
+        elif task.status == AgentTaskStatus.COMPLETED:
+            progress = 100.0
+        elif task.status in [AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
+            progress = 0.0
+
+        # Assemble the response data by hand
+        response_data = {
+            "id": task.id,
+            "project_id": task.project_id,
+            "name": task.name,
+            "description": task.description,
+            "task_type": task.task_type or "agent_audit",
+            "status": task.status,
+            "current_phase": task.current_phase,
+            "current_step": task.current_step,
+            "total_files": task.total_files or 0,
+            "indexed_files": task.indexed_files or 0,
+            "analyzed_files": task.analyzed_files or 0,
+            "total_chunks": task.total_chunks or 0,
+            "total_iterations": task.total_iterations or 0,
+            "tool_calls_count": task.tool_calls_count or 0,
+            "tokens_used": task.tokens_used or 0,
+            "findings_count": task.findings_count or 0,
+            "total_findings": task.findings_count or 0,  # compatibility field
+            "verified_count": task.verified_count or 0,
+            "verified_findings": task.verified_count or 0,  # compatibility field
+            "false_positive_count": task.false_positive_count or 0,
+            "critical_count": task.critical_count or 0,
+            "high_count": task.high_count or 0,
+            "medium_count": task.medium_count or 0,
+            "low_count": task.low_count or 0,
+            "quality_score": float(task.quality_score or 0.0),
+            "security_score": float(task.security_score) if task.security_score is not None else None,
+            "progress_percentage": progress,
+            "created_at": task.created_at,
+            "started_at": task.started_at,
+            "completed_at": task.completed_at,
+            "error_message": task.error_message,
+            "audit_scope": task.audit_scope,
+            "target_vulnerabilities": task.target_vulnerabilities,
+            "verification_level": task.verification_level,
+            "exclude_patterns": task.exclude_patterns,
+            "target_files": task.target_files,
+        }
+
+        return AgentTaskResponse(**response_data)
+    except Exception as e:
+        logger.error(f"Error serializing task {task_id}: {e}", exc_info=True)
+        raise HTTPException(status_code=500, detail=f"Failed to serialize task data: {str(e)}")
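The hand-built dict above is explicit but verbose. A minimal alternative sketch, assuming Pydantic v2: enable attribute-based validation on the response model and let `model_validate` pull the fields straight off the ORM object, with the progress fallback folded into a validator. This is an option, not what the patch does:

```python
from pydantic import ConfigDict, model_validator

class AgentTaskResponseV2(AgentTaskResponse):
    # Read fields directly from ORM attributes instead of a hand-built dict.
    model_config = ConfigDict(from_attributes=True)

    @model_validator(mode="after")
    def _default_progress(self):
        # Mirror the manual fallback: completed tasks report 100%.
        if not self.progress_percentage and self.status == "completed":
            self.progress_percentage = 100.0
        return self

# Usage in the endpoint:
#   return AgentTaskResponseV2.model_validate(task)
```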
 
 
 @router.post("/{task_id}/cancel")
@@ -396,10 +503,14 @@ async def stream_agent_events(
             idle_time = 0
             for event in events:
                 last_sequence = event.sequence
+                # event_type is already a string, no .value needed
+                event_type_str = str(event.event_type)
+                phase_str = str(event.phase) if event.phase else None
+
                 data = {
                     "id": event.id,
-                    "type": event.event_type.value if hasattr(event.event_type, 'value') else str(event.event_type),
-                    "phase": event.phase.value if event.phase and hasattr(event.phase, 'value') else event.phase,
+                    "type": event_type_str,
+                    "phase": phase_str,
                     "message": event.message,
                     "sequence": event.sequence,
                     "timestamp": event.created_at.isoformat() if event.created_at else None,
@@ -411,9 +522,12 @@ async def stream_agent_events(
             idle_time += poll_interval
 
             # Check whether the task has finished
-            if task_status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
-                yield f"data: {json.dumps({'type': 'task_end', 'status': task_status.value})}\n\n"
-                break
+            if task_status:
+                # task_status may be a string or an enum; normalize to a string
+                status_str = str(task_status)
+                if status_str in ["completed", "failed", "cancelled"]:
+                    yield f"data: {json.dumps({'type': 'task_end', 'status': status_str})}\n\n"
+                    break
 
             # Check the idle timeout
             if idle_time >= max_idle:
@@ -512,8 +626,8 @@ async def stream_agent_with_thinking(
             for event in events:
                 last_sequence = event.sequence
 
-                # Get the event type as a string
-                event_type = event.event_type.value if hasattr(event.event_type, 'value') else str(event.event_type)
+                # Get the event type as a string (event_type is already a string)
+                event_type = str(event.event_type)
 
                 # Filter events
                 if event_type in skip_types:
@@ -523,7 +637,7 @@ async def stream_agent_with_thinking(
                 data = {
                     "id": event.id,
                     "type": event_type,
-                    "phase": event.phase.value if event.phase and hasattr(event.phase, 'value') else event.phase,
+                    "phase": str(event.phase) if event.phase else None,
                     "message": event.message,
                     "sequence": event.sequence,
                     "timestamp": event.created_at.isoformat() if event.created_at else None,
@@ -552,14 +666,16 @@ async def stream_agent_with_thinking(
             idle_time += poll_interval
 
             # Check whether the task has finished
-            if task_status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
-                end_data = {
-                    "type": "task_end",
-                    "status": task_status.value,
-                    "message": f"Task {'completed' if task_status == AgentTaskStatus.COMPLETED else 'ended'}",
-                }
-                yield f"event: task_end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
-                break
+            if task_status:
+                status_str = str(task_status)
+                if status_str in ["completed", "failed", "cancelled"]:
+                    end_data = {
+                        "type": "task_end",
+                        "status": status_str,
+                        "message": f"Task {'completed' if status_str == 'completed' else 'ended'}",
+                    }
+                    yield f"event: task_end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
+                    break
 
             # Send a heartbeat
             last_heartbeat += poll_interval
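On the consuming side, the stream is plain SSE: lines prefixed with `data: ` carrying JSON, terminated by a `task_end` event. A minimal client sketch with `httpx`; the URL path and auth header are assumptions based on the router above, not confirmed by this diff:

```python
import json
import httpx

async def follow_agent_events(base_url: str, task_id: str, token: str) -> None:
    """Print agent events until the server signals task_end."""
    url = f"{base_url}/api/v1/agent-tasks/{task_id}/events"  # hypothetical path
    headers = {"Authorization": f"Bearer {token}"}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", url, headers=headers) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip event: lines, comments, and keep-alives
                event = json.loads(line[len("data: "):])
                print(event.get("type"), event.get("message", ""))
                if event.get("type") == "task_end":
                    break
```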
@@ -709,8 +825,9 @@ async def get_task_summary(
     verified_count = 0
 
     for f in findings:
-        sev = f.severity.value if hasattr(f.severity, 'value') else str(f.severity)
-        vtype = f.vulnerability_type.value if hasattr(f.vulnerability_type, 'value') else str(f.vulnerability_type)
+        # severity and vulnerability_type are already strings
+        sev = str(f.severity)
+        vtype = str(f.vulnerability_type)
 
         severity_distribution[sev] = severity_distribution.get(sev, 0) + 1
         vulnerability_types[vtype] = vulnerability_types.get(vtype, 0) + 1
@@ -730,11 +847,11 @@ async def get_task_summary(
         .where(AgentEvent.event_type == AgentEventType.PHASE_COMPLETE)
         .distinct()
     )
-    phases = [p[0].value if p[0] and hasattr(p[0], 'value') else str(p[0]) for p in phases_result.fetchall() if p[0]]
+    phases = [str(p[0]) for p in phases_result.fetchall() if p[0]]
 
     return TaskSummaryResponse(
         task_id=task_id,
-        status=task.status.value if hasattr(task.status, 'value') else str(task.status),
+        status=str(task.status),  # status is already a string
         security_score=task.security_score,
         total_findings=len(findings),
         verified_findings=verified_count,
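One caution on the `str(...)` normalization used throughout this file: it is only safe because these columns now store plain strings. For a genuine `Enum` member, `str()` yields `'AgentTaskStatus.COMPLETED'`, not `'completed'`, so comparisons like `status_str in ["completed", ...]` would silently never match. A defensive helper sketch that tolerates either representation:

```python
from enum import Enum

def enum_value(value) -> str:
    """Return the enum's value for enum members, else str(value).

    str(AgentTaskStatus.COMPLETED) == 'AgentTaskStatus.COMPLETED', while
    AgentTaskStatus.COMPLETED.value == 'completed' - this helper keeps the
    string comparisons above robust to either storage format.
    """
    if isinstance(value, Enum):
        return str(value.value)
    return str(value)
```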
diff --git a/backend/app/services/agent/agents/analysis.py b/backend/app/services/agent/agents/analysis.py
index e429bf8..8b0b86e 100644
--- a/backend/app/services/agent/agents/analysis.py
+++ b/backend/app/services/agent/agents/analysis.py
@@ -1,94 +1,153 @@
 """
-Analysis Agent (vulnerability-analysis layer)
-Responsible for code auditing, RAG queries, pattern matching, and dataflow analysis
+Analysis Agent (vulnerability-analysis layer) - LLM-driven version
 
-Type: ReAct
+The LLM is the real security-analysis brain!
+- The LLM decides the analysis strategy
+- The LLM chooses which tools to use
+- The LLM decides which code to analyze in depth
+- The LLM judges whether a finding is a real vulnerability
+
+Type: ReAct (for real this time!)
 """
 
-import asyncio
+import json
 import logging
+import re
 from typing import List, Dict, Any, Optional
+from dataclasses import dataclass
 
 from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
 
 logger = logging.getLogger(__name__)
 
 
-ANALYSIS_SYSTEM_PROMPT = """You are DeepAudit's vulnerability-analysis agent, responsible for deep code security analysis.
+ANALYSIS_SYSTEM_PROMPT = """You are DeepAudit's vulnerability-analysis agent, an **autonomous** security expert.
 
-## Your responsibilities
-1. Run static-analysis tools for a fast scan
-2. Use RAG for semantic code search
-3. Trace data flow (from user input to dangerous functions)
-4. Analyze business-logic vulnerabilities
-5. Assess vulnerability severity
+## Your role
+You are the **core brain** of the security audit, not a tool executor. You must:
+1. Devise the analysis strategy yourself
+2. Pick the most effective tools and methods
+3. Analyze suspicious code in depth
+4. Judge whether something is a real vulnerability
+5. Dynamically adjust the direction of the analysis
 
 ## Tools available to you
+
 ### External scanners
-- semgrep_scan: Semgrep static analysis (recommended first)
-- bandit_scan: Python security scan
+- **semgrep_scan**: Semgrep static analysis (recommended first)
+  Parameters: rules (str), max_results (int)
+- **bandit_scan**: Python security scan
 
 ### RAG semantic search
-- rag_query: semantic code search
-- security_search: security-related code search
-- function_context: function-context analysis
+- **rag_query**: semantic code search
+  Parameters: query (str), top_k (int)
+- **security_search**: security-related code search
+  Parameters: vulnerability_type (str), top_k (int)
+- **function_context**: function-context analysis
+  Parameters: function_name (str)
 
 ### Deep analysis
-- pattern_match: dangerous-pattern matching
-- code_analysis: LLM deep code analysis
-- dataflow_analysis: dataflow tracing
-- vulnerability_validation: vulnerability validation
+- **pattern_match**: dangerous-pattern matching
+  Parameters: pattern (str), file_types (list)
+- **code_analysis**: LLM deep code analysis ⭐
+  Parameters: code (str), file_path (str), focus (str)
+- **dataflow_analysis**: dataflow tracing
+  Parameters: source (str), sink (str)
+- **vulnerability_validation**: vulnerability validation
+  Parameters: code (str), vulnerability_type (str)
 
 ### File operations
-- read_file: read a file
-- search_code: keyword search
+- **read_file**: read file contents
+  Parameters: file_path (str), start_line (int), end_line (int)
+- **search_code**: keyword search over the code
+  Parameters: keyword (str), max_results (int)
+- **list_files**: list files in a directory
+  Parameters: directory (str), pattern (str)
 
-## Analysis strategy
-1. **Fast scan**: start with Semgrep to find issues quickly
-2. **Semantic search**: use RAG to locate relevant code
-3. **Deep analysis**: run LLM analysis on suspicious code
-4. **Dataflow tracing**: follow where user input flows
+## How to work
+At every step, output:
 
-## Focus areas
-- SQL injection, NoSQL injection
-- XSS (reflected, stored, DOM-based)
-- Command injection, code injection
-- Path traversal, arbitrary file access
-- SSRF, XXE
-- Insecure deserialization
-- Authentication/authorization bypass
-- Sensitive-information disclosure
+```
+Thought: [assess the current situation and decide what to do next]
+Action: [tool name]
+Action Input: [parameters as JSON]
+```
 
-## Output format
-When you find a vulnerability, return structured information:
+When your analysis is done, output:
+
+```
+Thought: [summarize all findings]
+Final Answer: [vulnerability report as JSON]
+```
+
+## Final Answer format
 ```json
 {
   "findings": [
     {
-      "vulnerability_type": "vulnerability type",
-      "severity": "critical/high/medium/low",
-      "title": "vulnerability title",
+      "vulnerability_type": "sql_injection",
+      "severity": "high",
+      "title": "SQL injection vulnerability",
       "description": "detailed description",
-      "file_path": "file path",
-      "line_start": line number,
-      "code_snippet": "code snippet",
-      "source": "taint source",
+      "file_path": "path/to/file.py",
+      "line_start": 42,
+      "code_snippet": "the dangerous code",
+      "source": "taint source",
       "sink": "dangerous function",
       "suggestion": "remediation advice",
-      "needs_verification": true/false
+      "confidence": 0.9,
+      "needs_verification": true
     }
-  ]
+  ],
+  "summary": "analysis summary"
 }
 ```
 
-Analyze the code systematically and find real security vulnerabilities."""
+## Suggested analysis strategy
+1. **Fast scan**: start with semgrep_scan for an overview
+2. **Deep dives**: use read_file + code_analysis on suspicious files
+3. **Pattern search**: use search_code for dangerous patterns (eval, exec, query, ...)
+4. **Semantic search**: use RAG to find similar vulnerability patterns
+5. **Dataflow**: use dataflow_analysis to trace user input
+
+## Vulnerability types to prioritize
+- SQL injection (query, execute, raw SQL)
+- XSS (innerHTML, document.write, v-html)
+- Command injection (exec, system, subprocess)
+- Path traversal (open, readFile, path concatenation)
+- SSRF (requests, fetch, http clients)
+- Hardcoded secrets (password, secret, api_key)
+- Insecure deserialization (pickle, yaml.load, eval)
+
+## Key principles
+1. **Quality first** - a few deeply analyzed real vulnerabilities beat a pile of shallow false positives
+2. **Context matters** - when code looks suspicious, read its surroundings and understand the full logic
+3. **Judge for yourself** - do not blindly trust tool output; apply your own expertise
+4. **Keep exploring** - after one finding, ask whether related issues exist
+
+Now begin your security analysis!"""
+
+
+@dataclass
+class AnalysisStep:
+    """A single analysis step"""
+    thought: str
+    action: Optional[str] = None
+    action_input: Optional[Dict] = None
+    observation: Optional[str] = None
+    is_final: bool = False
+    final_answer: Optional[Dict] = None
 
 
 class AnalysisAgent(BaseAgent):
     """
-    Vulnerability-analysis agent
+    Vulnerability-analysis agent - LLM-driven version
 
-    Uses the ReAct pattern for iterative analysis
+    The LLM is involved end to end and autonomously decides:
+    1. what to analyze
+    2. which tools to use
+    3. which code to dig into
+    4. what findings to report
     """
 
     def __init__(
         self,
@@ -103,87 +162,288 @@ class AnalysisAgent(BaseAgent):
             pattern=AgentPattern.REACT,
             max_iterations=30,
             system_prompt=ANALYSIS_SYSTEM_PROMPT,
-            tools=[
-                "semgrep_scan", "bandit_scan",
-                "rag_query", "security_search", "function_context",
-                "pattern_match", "code_analysis", "dataflow_analysis",
-                "vulnerability_validation",
-                "read_file", "search_code",
-            ],
         )
         super().__init__(config, llm_service, tools, event_emitter)
+
+        self._conversation_history: List[Dict[str, str]] = []
+        self._steps: List[AnalysisStep] = []
+
+    def _get_tools_description(self) -> str:
+        """Build the tool descriptions"""
+        tools_info = []
+        for name, tool in self.tools.items():
+            if name.startswith("_"):
+                continue
+            desc = f"- {name}: {getattr(tool, 'description', 'No description')}"
+            tools_info.append(desc)
+        return "\n".join(tools_info)
+
+    def _parse_llm_response(self, response: str) -> AnalysisStep:
+        """Parse the LLM response"""
+        step = AnalysisStep(thought="")
+
+        # Extract Thought
+        thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|Final Answer:|$)', response, re.DOTALL)
+        if thought_match:
+            step.thought = thought_match.group(1).strip()
+
+        # Check for a final answer
+        final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)
+        if final_match:
+            step.is_final = True
+            try:
+                answer_text = final_match.group(1).strip()
+                answer_text = re.sub(r'```json\s*', '', answer_text)
+                answer_text = re.sub(r'```\s*', '', answer_text)
+                step.final_answer = json.loads(answer_text)
+            except json.JSONDecodeError:
+                step.final_answer = {"findings": [], "raw_answer": final_match.group(1).strip()}
+            return step
+
+        # Extract Action
+        action_match = re.search(r'Action:\s*(\w+)', response)
+        if action_match:
+            step.action = action_match.group(1).strip()
+
+        # Extract Action Input
+        input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Action:|Observation:|$)', response, re.DOTALL)
+        if input_match:
+            input_text = input_match.group(1).strip()
+            input_text = re.sub(r'```json\s*', '', input_text)
+            input_text = re.sub(r'```\s*', '', input_text)
+            try:
+                step.action_input = json.loads(input_text)
+            except json.JSONDecodeError:
+                step.action_input = {"raw_input": input_text}
+
+        return step
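To make the parsing contract concrete, here is an illustrative example of feeding a typical ReAct-style reply through `_parse_llm_response`; the reply text is invented for the sketch, not captured from a real run, and `agent` is assumed to be an existing `AnalysisAgent` instance:

```python
# Illustrative only - the reply text below is made up.
sample_reply = """Thought: I should start with a quick Semgrep pass to get an overview.
Action: semgrep_scan
Action Input: {"rules": "p/security-audit", "max_results": 30}"""

step = agent._parse_llm_response(sample_reply)
assert step.is_final is False
assert step.action == "semgrep_scan"
assert step.action_input == {"rules": "p/security-audit", "max_results": 30}
```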
+ """ import time start_time = time.time() - phase_name = input_data.get("phase_name", "analysis") project_info = input_data.get("project_info", {}) config = input_data.get("config", {}) plan = input_data.get("plan", {}) previous_results = input_data.get("previous_results", {}) + task = input_data.get("task", "") + task_context = input_data.get("task_context", "") + + # 从 Recon 结果获取上下文 + recon_data = previous_results.get("recon", {}) + if isinstance(recon_data, dict) and "data" in recon_data: + recon_data = recon_data["data"] - # 从之前的 Recon 结果获取信息 - recon_data = previous_results.get("recon", {}).get("data", {}) - high_risk_areas = recon_data.get("high_risk_areas", plan.get("high_risk_areas", [])) tech_stack = recon_data.get("tech_stack", {}) entry_points = recon_data.get("entry_points", []) + high_risk_areas = recon_data.get("high_risk_areas", plan.get("high_risk_areas", [])) + initial_findings = recon_data.get("initial_findings", []) + + # 构建初始消息 + initial_message = f"""请开始对项目进行安全漏洞分析。 + +## 项目信息 +- 名称: {project_info.get('name', 'unknown')} +- 语言: {tech_stack.get('languages', [])} +- 框架: {tech_stack.get('frameworks', [])} + +## 上下文信息 +### 高风险区域 +{json.dumps(high_risk_areas[:20], ensure_ascii=False)} + +### 入口点 (前10个) +{json.dumps(entry_points[:10], ensure_ascii=False, indent=2)} + +### 初步发现 (如果有) +{json.dumps(initial_findings[:5], ensure_ascii=False, indent=2) if initial_findings else '无'} + +## 任务 +{task_context or task or '进行全面的安全漏洞分析,发现代码中的安全问题。'} + +## 目标漏洞类型 +{config.get('target_vulnerabilities', ['all'])} + +## 可用工具 +{self._get_tools_description()} + +请开始你的安全分析。首先思考分析策略,然后选择合适的工具开始分析。""" + + # 初始化对话历史 + self._conversation_history = [ + {"role": "system", "content": self.config.system_prompt}, + {"role": "user", "content": initial_message}, + ] + + self._steps = [] + all_findings = [] + + await self.emit_thinking("🔬 Analysis Agent 启动,LLM 开始自主安全分析...") try: - all_findings = [] - - # 1. 静态分析阶段 - if phase_name in ["static_analysis", "analysis"]: - await self.emit_thinking("执行静态代码分析...") - static_findings = await self._run_static_analysis(tech_stack) - all_findings.extend(static_findings) - - # 2. 
深度分析阶段 - if phase_name in ["deep_analysis", "analysis"]: - await self.emit_thinking("执行深度漏洞分析...") + for iteration in range(self.config.max_iterations): + if self.is_cancelled: + break - # 分析入口点 - deep_findings = await self._analyze_entry_points(entry_points) - all_findings.extend(deep_findings) + self._iteration = iteration + 1 - # 分析高风险区域(现在会调用 LLM) - risk_findings = await self._analyze_high_risk_areas(high_risk_areas) - all_findings.extend(risk_findings) + # 🔥 发射 LLM 开始思考事件 + await self.emit_llm_start(iteration + 1) - # 语义搜索常见漏洞(现在会调用 LLM) - vuln_types = config.get("target_vulnerabilities", [ - "sql_injection", "xss", "command_injection", - "path_traversal", "ssrf", "hardcoded_secret", - ]) + # 🔥 调用 LLM 进行思考和决策 + response = await self.llm_service.chat_completion_raw( + messages=self._conversation_history, + temperature=0.1, + max_tokens=2048, + ) - for vuln_type in vuln_types[:5]: # 限制数量 - if self.is_cancelled: - break + llm_output = response.get("content", "") + tokens_this_round = response.get("usage", {}).get("total_tokens", 0) + self._total_tokens += tokens_this_round + + # 解析 LLM 响应 + step = self._parse_llm_response(llm_output) + self._steps.append(step) + + # 🔥 发射 LLM 思考内容事件 - 展示安全分析的思考过程 + if step.thought: + await self.emit_llm_thought(step.thought, iteration + 1) + + # 添加 LLM 响应到历史 + self._conversation_history.append({ + "role": "assistant", + "content": llm_output, + }) + + # 检查是否完成 + if step.is_final: + await self.emit_llm_decision("完成安全分析", "LLM 判断分析已充分") + if step.final_answer and "findings" in step.final_answer: + all_findings = step.final_answer["findings"] + # 🔥 发射每个发现的事件 + for finding in all_findings[:5]: # 限制数量 + await self.emit_finding( + finding.get("title", "Unknown"), + finding.get("severity", "medium"), + finding.get("vulnerability_type", "other"), + finding.get("file_path", "") + ) + await self.emit_llm_complete( + f"分析完成,发现 {len(all_findings)} 个潜在漏洞", + self._total_tokens + ) + break + + # 执行工具 + if step.action: + # 🔥 发射 LLM 动作决策事件 + await self.emit_llm_action(step.action, step.action_input or {}) - await self.emit_thinking(f"搜索 {vuln_type} 相关代码...") - vuln_findings = await self._search_vulnerability_pattern(vuln_type) - all_findings.extend(vuln_findings) - - # 🔥 3. 
如果还没有发现,使用 LLM 进行全面扫描 - if len(all_findings) < 3: - await self.emit_thinking("执行 LLM 全面代码扫描...") - llm_findings = await self._llm_comprehensive_scan(tech_stack) - all_findings.extend(llm_findings) - - # 去重 - all_findings = self._deduplicate_findings(all_findings) + observation = await self._execute_tool( + step.action, + step.action_input or {} + ) + + step.observation = observation + + # 🔥 发射 LLM 观察事件 + await self.emit_llm_observation(observation) + + # 添加观察结果到历史 + self._conversation_history.append({ + "role": "user", + "content": f"Observation:\n{observation}", + }) + else: + # LLM 没有选择工具,提示它继续 + await self.emit_llm_decision("继续分析", "LLM 需要更多分析") + self._conversation_history.append({ + "role": "user", + "content": "请继续分析。选择一个工具执行,或者如果分析完成,输出 Final Answer 汇总所有发现。", + }) + # 处理结果 duration_ms = int((time.time() - start_time) * 1000) + # 标准化发现 + standardized_findings = [] + for finding in all_findings: + standardized = { + "vulnerability_type": finding.get("vulnerability_type", "other"), + "severity": finding.get("severity", "medium"), + "title": finding.get("title", "Unknown Finding"), + "description": finding.get("description", ""), + "file_path": finding.get("file_path", ""), + "line_start": finding.get("line_start") or finding.get("line", 0), + "code_snippet": finding.get("code_snippet", ""), + "source": finding.get("source", ""), + "sink": finding.get("sink", ""), + "suggestion": finding.get("suggestion", ""), + "confidence": finding.get("confidence", 0.7), + "needs_verification": finding.get("needs_verification", True), + } + standardized_findings.append(standardized) + await self.emit_event( "info", - f"分析完成: 发现 {len(all_findings)} 个潜在漏洞" + f"🎯 Analysis Agent 完成: {len(standardized_findings)} 个发现, {self._iteration} 轮迭代, {self._tool_calls} 次工具调用" ) return AgentResult( success=True, - data={"findings": all_findings}, + data={ + "findings": standardized_findings, + "steps": [ + { + "thought": s.thought, + "action": s.action, + "action_input": s.action_input, + "observation": s.observation[:500] if s.observation else None, + } + for s in self._steps + ], + }, iterations=self._iteration, tool_calls=self._tool_calls, tokens_used=self._total_tokens, @@ -191,500 +451,13 @@ class AnalysisAgent(BaseAgent): ) except Exception as e: - logger.error(f"Analysis agent failed: {e}", exc_info=True) + logger.error(f"Analysis Agent failed: {e}", exc_info=True) return AgentResult(success=False, error=str(e)) - async def _run_static_analysis(self, tech_stack: Dict) -> List[Dict]: - """运行静态分析工具""" - findings = [] - - # Semgrep 扫描 - semgrep_tool = self.tools.get("semgrep_scan") - if semgrep_tool: - await self.emit_tool_call("semgrep_scan", {"rules": "p/security-audit"}) - - result = await semgrep_tool.execute(rules="p/security-audit", max_results=30) - - if result.success and result.metadata.get("findings_count", 0) > 0: - for finding in result.metadata.get("findings", []): - findings.append({ - "vulnerability_type": self._map_semgrep_rule(finding.get("check_id", "")), - "severity": self._map_semgrep_severity(finding.get("extra", {}).get("severity", "")), - "title": finding.get("check_id", "Semgrep Finding"), - "description": finding.get("extra", {}).get("message", ""), - "file_path": finding.get("path", ""), - "line_start": finding.get("start", {}).get("line", 0), - "code_snippet": finding.get("extra", {}).get("lines", ""), - "source": "semgrep", - "needs_verification": True, - }) - - # Bandit 扫描 (Python) - languages = tech_stack.get("languages", []) - if "Python" in languages: - bandit_tool = 
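`_execute_tool` assumes every entry in `self.tools` exposes an async `execute(**kwargs)` returning an object with `success`, `data`, `error`, and `metadata` attributes. That interface is implied rather than shown in this diff; a minimal sketch of what such a tool could look like, with all names here being assumptions:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

@dataclass
class ToolResult:
    """The shape _execute_tool relies on: success/data/error/metadata."""
    success: bool
    data: Any = None
    error: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

class SearchCodeTool:
    """Hypothetical keyword-search tool matching the implied contract."""
    description = "Search the project for a keyword"

    def __init__(self, index):  # `index` is whatever backs the search
        self._index = index

    async def execute(self, keyword: str, max_results: int = 10) -> ToolResult:
        matches = self._index.search(keyword)[:max_results]
        return ToolResult(
            success=True,
            data=f"{len(matches)} matches for {keyword!r}",
            metadata={"matches": len(matches), "results": matches},
        )
```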
self.tools.get("bandit_scan") - if bandit_tool: - await self.emit_tool_call("bandit_scan", {}) - result = await bandit_tool.execute() - - if result.success and result.metadata.get("findings_count", 0) > 0: - for finding in result.metadata.get("findings", []): - findings.append({ - "vulnerability_type": self._map_bandit_test(finding.get("test_id", "")), - "severity": finding.get("issue_severity", "medium").lower(), - "title": finding.get("test_name", "Bandit Finding"), - "description": finding.get("issue_text", ""), - "file_path": finding.get("filename", ""), - "line_start": finding.get("line_number", 0), - "code_snippet": finding.get("code", ""), - "source": "bandit", - "needs_verification": True, - }) - - return findings + def get_conversation_history(self) -> List[Dict[str, str]]: + """获取对话历史""" + return self._conversation_history - async def _analyze_entry_points(self, entry_points: List[Dict]) -> List[Dict]: - """分析入口点""" - findings = [] - - code_analysis_tool = self.tools.get("code_analysis") - read_tool = self.tools.get("read_file") - - if not code_analysis_tool or not read_tool: - return findings - - # 分析前几个入口点 - for ep in entry_points[:10]: - if self.is_cancelled: - break - - file_path = ep.get("file", "") - line = ep.get("line", 1) - - if not file_path: - continue - - # 读取文件内容 - read_result = await read_tool.execute( - file_path=file_path, - start_line=max(1, line - 20), - end_line=line + 50, - ) - - if not read_result.success: - continue - - # 深度分析 - analysis_result = await code_analysis_tool.execute( - code=read_result.data, - file_path=file_path, - ) - - if analysis_result.success and analysis_result.metadata.get("issues"): - for issue in analysis_result.metadata["issues"]: - findings.append({ - "vulnerability_type": issue.get("type", "unknown"), - "severity": issue.get("severity", "medium"), - "title": issue.get("title", "Security Issue"), - "description": issue.get("description", ""), - "file_path": file_path, - "line_start": issue.get("line", line), - "code_snippet": issue.get("code_snippet", ""), - "suggestion": issue.get("suggestion", ""), - "source": "code_analysis", - "needs_verification": True, - }) - - return findings - - async def _analyze_high_risk_areas(self, high_risk_areas: List[str]) -> List[Dict]: - """分析高风险区域 - 使用 LLM 深度分析""" - findings = [] - - read_tool = self.tools.get("read_file") - search_tool = self.tools.get("search_code") - code_analysis_tool = self.tools.get("code_analysis") - - if not search_tool: - return findings - - # 在高风险区域搜索危险模式 - dangerous_patterns = [ - ("execute(", "sql_injection"), - ("query(", "sql_injection"), - ("eval(", "code_injection"), - ("system(", "command_injection"), - ("exec(", "command_injection"), - ("subprocess", "command_injection"), - ("innerHTML", "xss"), - ("document.write", "xss"), - ("open(", "path_traversal"), - ("requests.get", "ssrf"), - ] - - analyzed_files = set() - - for pattern, vuln_type in dangerous_patterns[:8]: - if self.is_cancelled: - break - - result = await search_tool.execute(keyword=pattern, max_results=10) - - if result.success and result.metadata.get("matches", 0) > 0: - for match in result.metadata.get("results", [])[:5]: - file_path = match.get("file", "") - line = match.get("line", 0) - - # 避免重复分析同一个文件的同一区域 - file_key = f"{file_path}:{line // 50}" - if file_key in analyzed_files: - continue - analyzed_files.add(file_key) - - # 🔥 使用 LLM 深度分析找到的代码 - if read_tool and code_analysis_tool: - await self.emit_thinking(f"LLM 分析 {file_path}:{line} 的 {vuln_type} 风险...") - - # 读取代码上下文 - read_result = await 
read_tool.execute( - file_path=file_path, - start_line=max(1, line - 15), - end_line=line + 25, - ) - - if read_result.success: - # 调用 LLM 分析 - analysis_result = await code_analysis_tool.execute( - code=read_result.data, - file_path=file_path, - focus=vuln_type, - ) - - if analysis_result.success and analysis_result.metadata.get("issues"): - for issue in analysis_result.metadata["issues"]: - findings.append({ - "vulnerability_type": issue.get("type", vuln_type), - "severity": issue.get("severity", "medium"), - "title": issue.get("title", f"LLM 发现: {vuln_type}"), - "description": issue.get("description", ""), - "file_path": file_path, - "line_start": issue.get("line", line), - "code_snippet": issue.get("code_snippet", match.get("match", "")), - "suggestion": issue.get("suggestion", ""), - "ai_explanation": issue.get("ai_explanation", ""), - "source": "llm_analysis", - "needs_verification": True, - }) - elif analysis_result.success: - # LLM 分析了但没发现问题,仍记录原始发现 - findings.append({ - "vulnerability_type": vuln_type, - "severity": "low", - "title": f"疑似 {vuln_type}: {pattern}", - "description": f"在 {file_path} 中发现危险模式,但 LLM 分析未确认", - "file_path": file_path, - "line_start": line, - "code_snippet": match.get("match", ""), - "source": "pattern_search", - "needs_verification": True, - }) - else: - # 没有 LLM 工具,使用基础模式匹配 - findings.append({ - "vulnerability_type": vuln_type, - "severity": "medium", - "title": f"疑似 {vuln_type}: {pattern}", - "description": f"在 {file_path} 中发现危险模式 {pattern}", - "file_path": file_path, - "line_start": line, - "code_snippet": match.get("match", ""), - "source": "pattern_search", - "needs_verification": True, - }) - - return findings - - async def _search_vulnerability_pattern(self, vuln_type: str) -> List[Dict]: - """搜索特定漏洞模式 - 使用 RAG + LLM""" - findings = [] - - security_tool = self.tools.get("security_search") - code_analysis_tool = self.tools.get("code_analysis") - read_tool = self.tools.get("read_file") - - if not security_tool: - return findings - - result = await security_tool.execute( - vulnerability_type=vuln_type, - top_k=10, - ) - - if result.success and result.metadata.get("results_count", 0) > 0: - for item in result.metadata.get("results", [])[:5]: - file_path = item.get("file_path", "") - line_start = item.get("line_start", 0) - content = item.get("content", "")[:2000] - - # 🔥 使用 LLM 验证 RAG 搜索结果 - if code_analysis_tool and content: - await self.emit_thinking(f"LLM 验证 RAG 发现的 {vuln_type}...") - - analysis_result = await code_analysis_tool.execute( - code=content, - file_path=file_path, - focus=vuln_type, - ) - - if analysis_result.success and analysis_result.metadata.get("issues"): - for issue in analysis_result.metadata["issues"]: - findings.append({ - "vulnerability_type": issue.get("type", vuln_type), - "severity": issue.get("severity", "medium"), - "title": issue.get("title", f"LLM 确认: {vuln_type}"), - "description": issue.get("description", ""), - "file_path": file_path, - "line_start": issue.get("line", line_start), - "code_snippet": issue.get("code_snippet", content[:500]), - "suggestion": issue.get("suggestion", ""), - "ai_explanation": issue.get("ai_explanation", ""), - "source": "rag_llm_analysis", - "needs_verification": True, - }) - else: - # RAG 找到但 LLM 未确认 - findings.append({ - "vulnerability_type": vuln_type, - "severity": "low", - "title": f"疑似 {vuln_type} (待确认)", - "description": f"RAG 搜索发现可能存在 {vuln_type},但 LLM 未确认", - "file_path": file_path, - "line_start": line_start, - "code_snippet": content[:500], - "source": "rag_search", - 
"needs_verification": True, - }) - else: - findings.append({ - "vulnerability_type": vuln_type, - "severity": "medium", - "title": f"疑似 {vuln_type}", - "description": f"通过语义搜索发现可能存在 {vuln_type}", - "file_path": file_path, - "line_start": line_start, - "code_snippet": content[:500], - "source": "rag_search", - "needs_verification": True, - }) - - return findings - - async def _llm_comprehensive_scan(self, tech_stack: Dict) -> List[Dict]: - """ - LLM 全面代码扫描 - 当其他方法没有发现足够的问题时,使用 LLM 直接分析关键文件 - """ - findings = [] - - list_tool = self.tools.get("list_files") - read_tool = self.tools.get("read_file") - code_analysis_tool = self.tools.get("code_analysis") - - if not all([list_tool, read_tool, code_analysis_tool]): - return findings - - await self.emit_thinking("LLM 全面扫描关键代码文件...") - - # 确定要扫描的文件类型 - languages = tech_stack.get("languages", []) - file_patterns = [] - - if "Python" in languages: - file_patterns.extend(["*.py"]) - if "JavaScript" in languages or "TypeScript" in languages: - file_patterns.extend(["*.js", "*.ts"]) - if "Go" in languages: - file_patterns.extend(["*.go"]) - if "Java" in languages: - file_patterns.extend(["*.java"]) - if "PHP" in languages: - file_patterns.extend(["*.php"]) - - if not file_patterns: - file_patterns = ["*.py", "*.js", "*.ts", "*.go", "*.java", "*.php"] - - # 扫描关键目录 - key_dirs = ["src", "app", "api", "routes", "controllers", "handlers", "lib", "utils", "."] - scanned_files = 0 - max_files_to_scan = 10 - - for key_dir in key_dirs: - if scanned_files >= max_files_to_scan or self.is_cancelled: - break - - for pattern in file_patterns[:3]: - if scanned_files >= max_files_to_scan or self.is_cancelled: - break - - # 列出文件 - list_result = await list_tool.execute( - directory=key_dir, - pattern=pattern, - recursive=True, - max_files=20, - ) - - if not list_result.success: - continue - - # 从输出中提取文件路径 - output = list_result.data - file_paths = [] - for line in output.split('\n'): - line = line.strip() - if line.startswith('📄 '): - file_paths.append(line[2:].strip()) - - # 分析每个文件 - for file_path in file_paths[:5]: - if scanned_files >= max_files_to_scan or self.is_cancelled: - break - - # 跳过测试文件和配置文件 - if any(skip in file_path.lower() for skip in ['test', 'spec', 'mock', '__pycache__', 'node_modules']): - continue - - await self.emit_thinking(f"LLM 分析文件: {file_path}") - - # 读取文件 - read_result = await read_tool.execute( - file_path=file_path, - max_lines=200, - ) - - if not read_result.success: - continue - - scanned_files += 1 - - # 🔥 LLM 深度分析 - analysis_result = await code_analysis_tool.execute( - code=read_result.data, - file_path=file_path, - ) - - if analysis_result.success and analysis_result.metadata.get("issues"): - for issue in analysis_result.metadata["issues"]: - findings.append({ - "vulnerability_type": issue.get("type", "other"), - "severity": issue.get("severity", "medium"), - "title": issue.get("title", "LLM 发现的安全问题"), - "description": issue.get("description", ""), - "file_path": file_path, - "line_start": issue.get("line", 0), - "code_snippet": issue.get("code_snippet", ""), - "suggestion": issue.get("suggestion", ""), - "ai_explanation": issue.get("ai_explanation", ""), - "source": "llm_comprehensive_scan", - "needs_verification": True, - }) - - await self.emit_thinking(f"LLM 全面扫描完成,分析了 {scanned_files} 个文件") - return findings - - def _deduplicate_findings(self, findings: List[Dict]) -> List[Dict]: - """去重发现""" - seen = set() - unique = [] - - for finding in findings: - key = ( - finding.get("file_path", ""), - finding.get("line_start", 0), - 
finding.get("vulnerability_type", ""), - ) - - if key not in seen: - seen.add(key) - unique.append(finding) - - return unique - - def _map_semgrep_rule(self, rule_id: str) -> str: - """映射 Semgrep 规则到漏洞类型""" - rule_lower = rule_id.lower() - - if "sql" in rule_lower: - return "sql_injection" - elif "xss" in rule_lower: - return "xss" - elif "command" in rule_lower or "injection" in rule_lower: - return "command_injection" - elif "path" in rule_lower or "traversal" in rule_lower: - return "path_traversal" - elif "ssrf" in rule_lower: - return "ssrf" - elif "deserial" in rule_lower: - return "deserialization" - elif "secret" in rule_lower or "password" in rule_lower or "key" in rule_lower: - return "hardcoded_secret" - elif "crypto" in rule_lower: - return "weak_crypto" - else: - return "other" - - def _map_semgrep_severity(self, severity: str) -> str: - """映射 Semgrep 严重程度""" - mapping = { - "ERROR": "high", - "WARNING": "medium", - "INFO": "low", - } - return mapping.get(severity, "medium") - - def _map_bandit_test(self, test_id: str) -> str: - """映射 Bandit 测试到漏洞类型""" - mappings = { - "B101": "assert_used", - "B102": "exec_used", - "B103": "hardcoded_password", - "B104": "hardcoded_bind_all", - "B105": "hardcoded_password", - "B106": "hardcoded_password", - "B107": "hardcoded_password", - "B108": "hardcoded_tmp", - "B301": "deserialization", - "B302": "deserialization", - "B303": "weak_crypto", - "B304": "weak_crypto", - "B305": "weak_crypto", - "B306": "weak_crypto", - "B307": "code_injection", - "B308": "code_injection", - "B310": "ssrf", - "B311": "weak_random", - "B312": "telnet", - "B501": "ssl_verify", - "B502": "ssl_verify", - "B503": "ssl_verify", - "B504": "ssl_verify", - "B505": "weak_crypto", - "B506": "yaml_load", - "B507": "ssh_key", - "B601": "command_injection", - "B602": "command_injection", - "B603": "command_injection", - "B604": "command_injection", - "B605": "command_injection", - "B606": "command_injection", - "B607": "command_injection", - "B608": "sql_injection", - "B609": "sql_injection", - "B610": "sql_injection", - "B611": "sql_injection", - "B701": "xss", - "B702": "xss", - "B703": "xss", - } - return mappings.get(test_id, "other") - + def get_steps(self) -> List[AnalysisStep]: + """获取执行步骤""" + return self._steps diff --git a/backend/app/services/agent/agents/base.py b/backend/app/services/agent/agents/base.py index eb9ae1e..d526c24 100644 --- a/backend/app/services/agent/agents/base.py +++ b/backend/app/services/agent/agents/base.py @@ -1,6 +1,8 @@ """ Agent 基类 定义 Agent 的基本接口和通用功能 + +核心原则:LLM 是 Agent 的大脑,所有日志应该反映 LLM 的参与! """ from abc import ABC, abstractmethod @@ -87,7 +89,11 @@ class AgentResult: class BaseAgent(ABC): """ Agent 基类 - 所有 Agent 需要继承此类并实现核心方法 + + 核心原则: + 1. LLM 是 Agent 的大脑,全程参与决策 + 2. 所有日志应该反映 LLM 的思考过程 + 3. 
工具调用是 LLM 的决策结果 """ def __init__( @@ -146,6 +152,8 @@ class BaseAgent(ABC): def is_cancelled(self) -> bool: return self._cancelled + # ============ 核心事件发射方法 ============ + async def emit_event( self, event_type: str, @@ -161,28 +169,123 @@ class BaseAgent(ABC): **kwargs )) + # ============ LLM 思考相关事件 ============ + async def emit_thinking(self, message: str): - """发射思考事件""" - await self.emit_event("thinking", f"[{self.name}] {message}") + """发射 LLM 思考事件""" + await self.emit_event("thinking", f"🧠 [{self.name}] {message}") + + async def emit_llm_start(self, iteration: int): + """发射 LLM 开始思考事件""" + await self.emit_event( + "llm_start", + f"🤔 [{self.name}] LLM 开始第 {iteration} 轮思考...", + metadata={"iteration": iteration} + ) + + async def emit_llm_thought(self, thought: str, iteration: int): + """发射 LLM 思考内容事件 - 这是核心!展示 LLM 在想什么""" + # 截断过长的思考内容 + display_thought = thought[:500] + "..." if len(thought) > 500 else thought + await self.emit_event( + "llm_thought", + f"💭 [{self.name}] LLM 思考:\n{display_thought}", + metadata={ + "thought": thought, + "iteration": iteration, + } + ) + + async def emit_llm_decision(self, decision: str, reason: str = ""): + """发射 LLM 决策事件 - 展示 LLM 做了什么决定""" + await self.emit_event( + "llm_decision", + f"💡 [{self.name}] LLM 决策: {decision}" + (f" (理由: {reason})" if reason else ""), + metadata={ + "decision": decision, + "reason": reason, + } + ) + + async def emit_llm_action(self, action: str, action_input: Dict): + """发射 LLM 动作事件 - LLM 决定执行什么动作""" + import json + input_str = json.dumps(action_input, ensure_ascii=False)[:200] + await self.emit_event( + "llm_action", + f"⚡ [{self.name}] LLM 动作: {action}\n 参数: {input_str}", + metadata={ + "action": action, + "action_input": action_input, + } + ) + + async def emit_llm_observation(self, observation: str): + """发射 LLM 观察事件 - LLM 看到了什么""" + display_obs = observation[:300] + "..." if len(observation) > 300 else observation + await self.emit_event( + "llm_observation", + f"👁️ [{self.name}] LLM 观察到:\n{display_obs}", + metadata={"observation": observation[:2000]} + ) + + async def emit_llm_complete(self, result_summary: str, tokens_used: int): + """发射 LLM 完成事件""" + await self.emit_event( + "llm_complete", + f"✅ [{self.name}] LLM 完成: {result_summary} (消耗 {tokens_used} tokens)", + metadata={ + "tokens_used": tokens_used, + } + ) + + # ============ 工具调用相关事件 ============ async def emit_tool_call(self, tool_name: str, tool_input: Dict): - """发射工具调用事件""" + """发射工具调用事件 - LLM 决定调用工具""" + import json + input_str = json.dumps(tool_input, ensure_ascii=False)[:300] await self.emit_event( "tool_call", - f"[{self.name}] 调用工具: {tool_name}", + f"🔧 [{self.name}] LLM 调用工具: {tool_name}\n 输入: {input_str}", tool_name=tool_name, tool_input=tool_input, ) async def emit_tool_result(self, tool_name: str, result: str, duration_ms: int): """发射工具结果事件""" + result_preview = result[:200] + "..." 
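`emit_event` ultimately forwards to the `event_emitter` callback injected at construction time. Its exact signature is not shown in this diff; a plausible minimal stand-in for local debugging, assuming the callback receives a single event object (the real emitter presumably persists `AgentEvent` rows so the SSE endpoints above can stream them):

```python
# Hypothetical console emitter for local runs.
async def console_event_emitter(event) -> None:
    print(f"[{getattr(event, 'event_type', '?')}] {getattr(event, 'message', '')}")

# agent = AnalysisAgent(llm_service, tools, event_emitter=console_event_emitter)
```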
+    # ============ Finding events ============
+
+    async def emit_finding(self, title: str, severity: str, vuln_type: str, file_path: str = ""):
+        """Emit a vulnerability-finding event"""
+        severity_emoji = {
+            "critical": "🔴",
+            "high": "🟠",
+            "medium": "🟡",
+            "low": "🟢",
+        }.get(severity.lower(), "⚪")
+
+        await self.emit_event(
+            "finding",
+            f"{severity_emoji} [{self.name}] vulnerability found: [{severity.upper()}] {title}\n  type: {vuln_type}\n  location: {file_path}",
+            metadata={
+                "title": title,
+                "severity": severity,
+                "vulnerability_type": vuln_type,
+                "file_path": file_path,
+            }
+        )
+
+    # ============ Generic helpers ============
+
     async def call_tool(self, tool_name: str, **kwargs) -> Any:
         """
         Call a tool
@@ -208,7 +311,7 @@ class BaseAgent(ABC):
             result = await tool.execute(**kwargs)
 
             duration_ms = int((time.time() - start) * 1000)
-            await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)
+            await self.emit_tool_result(tool_name, str(result.data)[:500], duration_ms)
 
             return result
 
@@ -229,8 +332,9 @@ class BaseAgent(ABC):
         """
         self._iteration += 1
 
-        # This should call the actual LLM service
-        # via LangChain or a direct API call
+        # Emit the LLM-start event
+        await self.emit_llm_start(self._iteration)
+
         try:
             response = await self.llm_service.chat_completion(
                 messages=messages,
@@ -281,4 +385,3 @@ class BaseAgent(ABC):
             "tool_calls": self._tool_calls,
             "tokens_used": self._total_tokens,
         }
-
diff --git a/backend/app/services/agent/agents/orchestrator.py b/backend/app/services/agent/agents/orchestrator.py
index 5ccae14..9354ac5 100644
--- a/backend/app/services/agent/agents/orchestrator.py
+++ b/backend/app/services/agent/agents/orchestrator.py
@@ -1,12 +1,19 @@
 """
-Orchestrator Agent (orchestration layer)
-Responsible for task decomposition, sub-agent dispatch, and result aggregation
+Orchestrator Agent (orchestration layer) - LLM-driven version
 
-Type: Plan-and-Execute
+The LLM is the real brain and takes part in every decision!
+- The LLM decides what to do next
+- The LLM decides which sub-agent to dispatch
+- The LLM decides when it is done
+- The LLM adjusts the strategy based on intermediate results
+
+Type: Autonomous Agent with Dynamic Planning
 """
 
 import asyncio
+import json
 import logging
+import re
 from typing import List, Dict, Any, Optional
 from dataclasses import dataclass
 
@@ -15,69 +22,97 @@ from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
 logger = logging.getLogger(__name__)
 
 
-@dataclass
-class AuditPlan:
-    """Audit plan"""
-    phases: List[Dict[str, Any]]
-    high_risk_areas: List[str]
-    focus_vulnerabilities: List[str]
-    estimated_steps: int
-    priority_files: List[str]
-    metadata: Dict[str, Any]
+ORCHESTRATOR_SYSTEM_PROMPT = """You are DeepAudit's orchestration agent, responsible for **autonomously** coordinating the entire security-audit flow.
 
+## Your role
+You are the **brain** of the whole audit flow, not a mechanical executor. You must:
+1. Think and decide autonomously
+2. Dynamically adjust the strategy based on observations
+3. Decide when to dispatch which sub-agent
+4. Judge when the audit is complete
 
-ORCHESTRATOR_SYSTEM_PROMPT = """You are DeepAudit's orchestration agent, responsible for coordinating the entire security-audit flow.
+## Sub-agents you can dispatch
+1. **recon**: reconnaissance agent - analyzes project structure, tech stack, entry points
+2. **analysis**: analysis agent - deep code auditing, vulnerability detection
+3. **verification**: verification agent - verifies findings, generates PoCs
 
-## Your responsibilities
-1. Analyze the project information and draft the audit plan
-2. Dispatch the sub-agents (Recon, Analysis, Verification) to execute tasks
-3. Aggregate the audit results and generate the report
+## Operations available to you
 
-## Audit flow
-1. **Reconnaissance phase**: dispatch the Recon agent to gather project information
-   - project-structure analysis
-   - tech-stack identification
-   - entry-point identification
-   - dependency analysis
-
-2. **Analysis phase**: dispatch the Analysis agent for code analysis
-   - static code analysis
-   - semantic search
-   - pattern matching
-   - dataflow tracing
-
-3. **Verification phase**: dispatch the Verification agent to validate the findings
-   - vulnerability confirmation
-   - PoC generation
-   - sandbox testing
-
-4. **Reporting phase**: aggregate all findings and generate the final report
-
-## Output format
-When generating the audit plan, return JSON:
-```json
-{
-  "phases": [
-    {"name": "phase name", "description": "description", "agent": "agent_type"}
-  ],
-  "high_risk_areas": ["high-risk directories/files"],
-  "focus_vulnerabilities": ["vulnerability types to focus on"],
-  "priority_files": ["files to audit first"],
-  "estimated_steps": number
-}
-```
-
-Draft a sensible audit plan based on the project information."""
+### 1. Dispatch a sub-agent
+```
+Action: dispatch_agent
+Action Input: {"agent": "recon|analysis|verification", "task": "task description", "context": "task context"}
+```
+
+### 2. Summarize the findings
+```
+Action: summarize
+Action Input: {"findings": [...], "analysis": "your analysis"}
+```
+
+### 3. Finish the audit
+```
+Action: finish
+Action Input: {"conclusion": "audit conclusion", "findings": [...], "recommendations": [...]}
+```
+
+## How to work
+At every step you must:
+
+1. **Thought**: assess the current state and think about what to do next
+   - What information has been collected so far?
+   - What else needs to be understood?
+   - Which areas deserve deeper analysis?
+   - Are there findings that need verification?
+
+2. **Action**: choose one operation
+3. **Action Input**: provide the operation parameters
+
+## Output format
+Every step must strictly follow this format:
+
+```
+Thought: [your reasoning]
+Action: [dispatch_agent|summarize|finish]
+Action Input: [JSON parameters]
+```
+
+## Suggested audit strategy
+- Use the recon agent first to get the full picture of the project
+- Based on the recon results, have the analysis agent focus on the high-risk areas
+- When suspicious vulnerabilities appear, verify them with the verification agent
+- Keep adjusting the strategy as new findings arrive; never execute mechanically
+- When you judge the audit thorough enough, choose finish
+
+## Key principles
+1. **You are the brain, not an executor** - think at every step
+2. **Adjust dynamically** - adapt the strategy to the findings
+3. **Decide proactively** - do not wait; push the audit forward
+4. **Quality first** - a few deeply analyzed real vulnerabilities beat superficial coverage
+
+Now begin your audit work based on the project information!"""
+
+
+@dataclass
+class AgentStep:
+    """An execution step"""
+    thought: str
+    action: str
+    action_input: Dict[str, Any]
+    observation: Optional[str] = None
+    sub_agent_result: Optional[AgentResult] = None
 
 
 class OrchestratorAgent(BaseAgent):
     """
-    Orchestration agent
+    Orchestration agent - LLM-driven version
 
-    Uses the Plan-and-Execute pattern:
-    1. first generate the audit plan
-    2. dispatch the sub-agents according to the plan
-    3. collect and aggregate the results
+    The LLM takes part in every decision:
+    1. the LLM assesses the current state
+    2. the LLM decides the next operation
+    3. the operation runs and returns a result
+    4. the LLM analyzes the result and decides what comes next
+    5. repeat until the LLM decides to finish
     """
 
     def __init__(
         self,
@@ -90,13 +125,16 @@ class OrchestratorAgent(BaseAgent):
         config = AgentConfig(
             name="Orchestrator",
             agent_type=AgentType.ORCHESTRATOR,
-            pattern=AgentPattern.PLAN_AND_EXECUTE,
-            max_iterations=10,
+            pattern=AgentPattern.REACT,  # switched to the ReAct pattern!
+            max_iterations=20,
             system_prompt=ORCHESTRATOR_SYSTEM_PROMPT,
         )
         super().__init__(config, llm_service, tools, event_emitter)
 
         self.sub_agents = sub_agents or {}
+        self._conversation_history: List[Dict[str, str]] = []
+        self._steps: List[AgentStep] = []
+        self._all_findings: List[Dict] = []
 
     def register_sub_agent(self, name: str, agent: BaseAgent):
         """Register a sub-agent"""
@@ -104,7 +142,7 @@ class OrchestratorAgent(BaseAgent):
 
     async def run(self, input_data: Dict[str, Any]) -> AgentResult:
         """
-        Run the orchestration task
+        Run the orchestration task - with the LLM involved end to end!
 
         Args:
             input_data: {
@@ -118,82 +156,136 @@ class OrchestratorAgent(BaseAgent):
         project_info = input_data.get("project_info", {})
         config = input_data.get("config", {})
 
+        # Build the initial message
+        initial_message = self._build_initial_message(project_info, config)
+
+        # Initialize the conversation history
+        self._conversation_history = [
+            {"role": "system", "content": self.config.system_prompt},
+            {"role": "user", "content": initial_message},
+        ]
+
+        self._steps = []
+        self._all_findings = []
+        final_result = None
+
+        await self.emit_thinking("🧠 Orchestrator Agent started; the LLM begins autonomous orchestration decisions...")
+
         try:
-            await self.emit_thinking("Drafting the audit plan...")
-
-            # 1. Generate the audit plan
-            plan = await self._create_audit_plan(project_info, config)
-
-            if not plan:
-                return AgentResult(
-                    success=False,
-                    error="Unable to generate an audit plan",
-                )
-
-            await self.emit_event(
-                "planning",
-                f"Audit plan ready with {len(plan.phases)} phases",
-                metadata={"plan": plan.__dict__}
-            )
-
-            # 2. Execute the phases
-            all_findings = []
-            phase_results = {}
-
-            for phase in plan.phases:
+            for iteration in range(self.config.max_iterations):
                 if self.is_cancelled:
                     break
 
-                phase_name = phase.get("name", "unknown")
-                agent_type = phase.get("agent", "analysis")
+                self._iteration = iteration + 1
 
-                await self.emit_event(
-                    "phase_start",
-                    f"Starting the {phase_name} phase",
-                    phase=phase_name
+                # 🔥 Emit the "LLM starts thinking" event
+                await self.emit_llm_start(iteration + 1)
+
+                # 🔥 Let the LLM think and decide
+                response = await self.llm_service.chat_completion_raw(
+                    messages=self._conversation_history,
+                    temperature=0.1,
+                    max_tokens=2048,
                 )
 
-                # Dispatch the matching sub-agent
-                result = await self._execute_phase(
-                    phase_name=phase_name,
-                    agent_type=agent_type,
-                    project_info=project_info,
-                    config=config,
-                    plan=plan,
-                    previous_results=phase_results,
-                )
+                llm_output = response.get("content", "")
+                tokens_this_round = response.get("usage", {}).get("total_tokens", 0)
+                self._total_tokens += tokens_this_round
 
-                phase_results[phase_name] = result
+                # Parse the LLM's decision
+                step = self._parse_llm_response(llm_output)
 
-                if result.success and result.data:
-                    if isinstance(result.data, dict):
-                        findings = result.data.get("findings", [])
-                        all_findings.extend(findings)
+                if not step:
+                    # Malformed LLM output; ask it to retry
+                    await self.emit_llm_decision("Malformed output", "asking for a re-emit")
+                    self._conversation_history.append({
+                        "role": "assistant",
+                        "content": llm_output,
+                    })
+                    self._conversation_history.append({
+                        "role": "user",
+                        "content": "Please follow the required format: Thought + Action + Action Input",
+                    })
+                    continue
 
-                await self.emit_event(
-                    "phase_complete",
-                    f"The {phase_name} phase finished",
-                    phase=phase_name
-                )
-
-            # 3. Aggregate the results
-            await self.emit_thinking("Aggregating the audit results...")
-
-            summary = await self._generate_summary(
-                plan=plan,
-                phase_results=phase_results,
-                all_findings=all_findings,
-            )
+                self._steps.append(step)
+
+                # 🔥 Emit the LLM-thought event - surfaces the orchestration reasoning
+                if step.thought:
+                    await self.emit_llm_thought(step.thought, iteration + 1)
+
+                # Append the LLM reply to the history
+                self._conversation_history.append({
+                    "role": "assistant",
+                    "content": llm_output,
+                })
+
+                # Execute the operation the LLM chose
+                if step.action == "finish":
+                    # 🔥 The LLM decided the audit is done
+                    await self.emit_llm_decision("Finish the audit", "the LLM judged the audit sufficiently complete")
+                    await self.emit_llm_complete(
+                        f"Orchestration finished with {len(self._all_findings)} vulnerabilities",
+                        self._total_tokens
+                    )
+                    final_result = step.action_input
+                    break
+
+                elif step.action == "dispatch_agent":
+                    # 🔥 The LLM decided to dispatch a sub-agent
+                    agent_name = step.action_input.get("agent", "unknown")
+                    task_desc = step.action_input.get("task", "")
+                    await self.emit_llm_decision(
+                        f"Dispatch the {agent_name} agent",
+                        f"task: {task_desc[:100]}"
+                    )
+                    await self.emit_llm_action("dispatch_agent", step.action_input)
+
+                    observation = await self._dispatch_agent(step.action_input)
+                    step.observation = observation
+
+                    # 🔥 Emit the observation event
+                    await self.emit_llm_observation(observation)
+
+                elif step.action == "summarize":
+                    # The LLM asked for a summary
+                    await self.emit_llm_decision("Summarize the findings", "the LLM asked for the current findings summary")
+                    observation = self._summarize_findings()
+                    step.observation = observation
+                    await self.emit_llm_observation(observation)
+
+                else:
+                    observation = f"Unknown operation: {step.action}; available operations: dispatch_agent, summarize, finish"
+                    await self.emit_llm_decision("Unknown operation", observation)
+
+                # Append the observation to the history
+                self._conversation_history.append({
+                    "role": "user",
+                    "content": f"Observation:\n{step.observation}",
+                })
 
+            # Produce the final result
             duration_ms = int((time.time() - start_time) * 1000)
 
+            await self.emit_event(
+                "info",
+                f"🎯 Orchestrator done: {len(self._all_findings)} findings, {self._iteration} decision rounds"
+            )
+
             return AgentResult(
                 success=True,
                 data={
-                    "plan": plan.__dict__,
-                    "findings": all_findings,
-                    "summary": summary,
-                    "phase_results": {k: v.to_dict() for k, v in phase_results.items()},
+                    "findings": self._all_findings,
+                    "summary": final_result or self._generate_default_summary(),
+                    "steps": [
+                        {
+                            "thought": s.thought,
+                            "action": s.action,
+                            "action_input": s.action_input,
+                            "observation": s.observation[:500] if s.observation else None,
+                        }
+                        for s in self._steps
+                    ],
                 },
                 iterations=self._iteration,
                 tool_calls=self._tool_calls,
@@ -208,174 +300,197 @@ class OrchestratorAgent(BaseAgent):
                 error=str(e),
             )
 
-    async def _create_audit_plan(
+    def _build_initial_message(
         self,
         project_info: Dict[str, Any],
         config: Dict[str, Any],
-    ) -> Optional[AuditPlan]:
-        """Generate the audit plan"""
-        # Build the prompt
-        prompt = f"""Based on the following project information, draft a security-audit plan.
+    ) -> str:
+        """Build the initial message"""
+        msg = f"""Please begin a security audit of the following project.
 
 ## Project information
 - Name: {project_info.get('name', 'unknown')}
 - Languages: {project_info.get('languages', [])}
 - File count: {project_info.get('file_count', 0)}
-- Directory structure: {project_info.get('structure', {})}
+- Directory structure: {json.dumps(project_info.get('structure', {}), ensure_ascii=False, indent=2)}
 
 ## User configuration
-- Target vulnerabilities: {config.get('target_vulnerabilities', [])}
+- Target vulnerabilities: {config.get('target_vulnerabilities', ['all'])}
 - Verification level: {config.get('verification_level', 'sandbox')}
 - Exclusion patterns: {config.get('exclude_patterns', [])}
 
-Produce the audit plan as JSON."""
+## Available sub-agents
+{', '.join(self.sub_agents.keys()) if self.sub_agents else '(no sub-agents registered)'}
+
+Begin your audit work. First think about how to proceed, then decide what the first step is."""
+
+        return msg
+
+    def _parse_llm_response(self, response: str) -> Optional[AgentStep]:
+        """Parse the LLM response"""
+        # Extract Thought
+        thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|$)', response, re.DOTALL)
+        thought = thought_match.group(1).strip() if thought_match else ""
+
+        # Extract Action
+        action_match = re.search(r'Action:\s*(\w+)', response)
+        if not action_match:
+            return None
+        action = action_match.group(1).strip()
+
+        # Extract Action Input
+        input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Observation:|$)', response, re.DOTALL)
+        if not input_match:
+            return None
+
+        input_text = input_match.group(1).strip()
+        # Strip markdown code fences
+        input_text = re.sub(r'```json\s*', '', input_text)
+        input_text = re.sub(r'```\s*', '', input_text)
 
         try:
-            # Call the LLM
-            messages = [
-                {"role": "system", "content": self.config.system_prompt},
-                {"role": "user", "content": prompt},
-            ]
-
-            response = await self.llm_service.chat_completion_raw(
-                messages=messages,
-                temperature=0.1,
-                max_tokens=2000,
-            )
-
-            content = response.get("content", "")
-
-            # Parse the JSON
-            import json
-            import re
-
-            # Extract the JSON
-            json_match = re.search(r'\{[\s\S]*\}', content)
-            if json_match:
-                plan_data = json.loads(json_match.group())
-
-                return AuditPlan(
-                    phases=plan_data.get("phases", self._default_phases()),
-                    high_risk_areas=plan_data.get("high_risk_areas", []),
-                    focus_vulnerabilities=plan_data.get("focus_vulnerabilities", []),
-                    estimated_steps=plan_data.get("estimated_steps", 30),
-                    priority_files=plan_data.get("priority_files", []),
-                    metadata=plan_data,
-                )
-            else:
-                # Fall back to the default plan
-                return AuditPlan(
-                    phases=self._default_phases(),
-                    high_risk_areas=["src/", "api/", "controllers/", "routes/"],
-                    focus_vulnerabilities=["sql_injection", "xss", "command_injection"],
-                    estimated_steps=30,
-                    priority_files=[],
-                    metadata={},
-                )
-
-        except Exception as e:
-            logger.error(f"Failed to create audit plan: {e}")
-            return AuditPlan(
-                phases=self._default_phases(),
-                high_risk_areas=[],
-                focus_vulnerabilities=[],
-                estimated_steps=30,
-                priority_files=[],
-                metadata={},
-            )
+            action_input = json.loads(input_text)
+        except json.JSONDecodeError:
+            action_input = {"raw": input_text}
+
+        return AgentStep(
+            thought=thought,
+            action=action,
+            action_input=action_input,
+        )
json.JSONDecodeError: + action_input = {"raw": input_text} + + return AgentStep( + thought=thought, + action=action, + action_input=action_input, + ) - def _default_phases(self) -> List[Dict[str, Any]]: - """默认审计阶段""" - return [ - { - "name": "recon", - "description": "信息收集 - 分析项目结构和技术栈", - "agent": "recon", - }, - { - "name": "static_analysis", - "description": "静态分析 - 使用外部工具快速扫描", - "agent": "analysis", - }, - { - "name": "deep_analysis", - "description": "深度分析 - AI 驱动的代码审计", - "agent": "analysis", - }, - { - "name": "verification", - "description": "漏洞验证 - 确认发现的漏洞", - "agent": "verification", - }, - ] - - async def _execute_phase( - self, - phase_name: str, - agent_type: str, - project_info: Dict[str, Any], - config: Dict[str, Any], - plan: AuditPlan, - previous_results: Dict[str, AgentResult], - ) -> AgentResult: - """执行审计阶段""" - agent = self.sub_agents.get(agent_type) + async def _dispatch_agent(self, params: Dict[str, Any]) -> str: + """调度子 Agent""" + agent_name = params.get("agent", "") + task = params.get("task", "") + context = params.get("context", "") + + agent = self.sub_agents.get(agent_name) if not agent: - logger.warning(f"Agent not found: {agent_type}") - return AgentResult(success=False, error=f"Agent {agent_type} not found") + available = list(self.sub_agents.keys()) + return f"错误: Agent '{agent_name}' 不存在。可用的 Agent: {available}" - # 构建阶段输入 - phase_input = { - "phase_name": phase_name, - "project_info": project_info, - "config": config, - "plan": plan.__dict__, - "previous_results": {k: v.to_dict() for k, v in previous_results.items()}, - } + await self.emit_event( + "dispatch", + f"📤 调度 {agent_name} Agent: {task[:100]}...", + agent=agent_name, + task=task, + ) - # 执行子 Agent - return await agent.run(phase_input) + self._tool_calls += 1 + + try: + # 构建子 Agent 输入 + sub_input = { + "task": task, + "task_context": context, + "project_info": {}, # 从上下文获取 + "config": {}, + } + + # 执行子 Agent + result = await agent.run(sub_input) + + # 收集发现 + if result.success and result.data: + findings = result.data.get("findings", []) + self._all_findings.extend(findings) + + await self.emit_event( + "dispatch_complete", + f"✅ {agent_name} Agent 完成: {len(findings)} 个发现", + agent=agent_name, + findings_count=len(findings), + ) + + # 构建观察结果 + observation = f"""## {agent_name} Agent 执行结果 + +**状态**: 成功 +**发现数量**: {len(findings)} +**迭代次数**: {result.iterations} +**耗时**: {result.duration_ms}ms + +### 发现摘要 +""" + for i, f in enumerate(findings[:10]): # 最多显示 10 个 + observation += f""" +{i+1}. [{f.get('severity', 'unknown')}] {f.get('title', 'Unknown')} + - 类型: {f.get('vulnerability_type', 'unknown')} + - 文件: {f.get('file_path', 'unknown')} + - 描述: {f.get('description', '')[:200]}... +""" + + if len(findings) > 10: + observation += f"\n... 
还有 {len(findings) - 10} 个发现" + + if result.data.get("summary"): + observation += f"\n\n### Agent 总结\n{result.data['summary']}" + + return observation + else: + return f"## {agent_name} Agent 执行失败\n\n错误: {result.error}" + + except Exception as e: + logger.error(f"Sub-agent dispatch failed: {e}", exc_info=True) + return f"## 调度失败\n\n错误: {str(e)}" - async def _generate_summary( - self, - plan: AuditPlan, - phase_results: Dict[str, AgentResult], - all_findings: List[Dict], - ) -> Dict[str, Any]: - """生成审计摘要""" - # 统计漏洞 + def _summarize_findings(self) -> str: + """汇总当前发现""" + if not self._all_findings: + return "目前还没有发现任何漏洞。" + + # 统计 severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0} type_counts = {} - verified_count = 0 - for finding in all_findings: - sev = finding.get("severity", "low") + for f in self._all_findings: + sev = f.get("severity", "low") severity_counts[sev] = severity_counts.get(sev, 0) + 1 - vtype = finding.get("vulnerability_type", "other") + vtype = f.get("vulnerability_type", "other") type_counts[vtype] = type_counts.get(vtype, 0) + 1 - - if finding.get("is_verified"): - verified_count += 1 - # 计算安全评分 - base_score = 100 - deductions = ( - severity_counts["critical"] * 20 + - severity_counts["high"] * 10 + - severity_counts["medium"] * 5 + - severity_counts["low"] * 2 - ) - security_score = max(0, base_score - deductions) + summary = f"""## 当前发现汇总 + +**总计**: {len(self._all_findings)} 个漏洞 + +### 严重程度分布 +- Critical: {severity_counts['critical']} +- High: {severity_counts['high']} +- Medium: {severity_counts['medium']} +- Low: {severity_counts['low']} + +### 漏洞类型分布 +""" + for vtype, count in type_counts.items(): + summary += f"- {vtype}: {count}\n" + + summary += "\n### 详细列表\n" + for i, f in enumerate(self._all_findings): + summary += f"{i+1}. [{f.get('severity')}] {f.get('title')} ({f.get('file_path')})\n" + + return summary + + def _generate_default_summary(self) -> Dict[str, Any]: + """生成默认摘要""" + severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0} + + for f in self._all_findings: + sev = f.get("severity", "low") + severity_counts[sev] = severity_counts.get(sev, 0) + 1 return { - "total_findings": len(all_findings), - "verified_count": verified_count, + "total_findings": len(self._all_findings), "severity_distribution": severity_counts, - "vulnerability_types": type_counts, - "security_score": security_score, - "phases_completed": len(phase_results), - "high_risk_areas": plan.high_risk_areas, + "conclusion": "审计完成(未通过 LLM 生成结论)", } - + + def get_conversation_history(self) -> List[Dict[str, str]]: + """获取对话历史""" + return self._conversation_history + + def get_steps(self) -> List[AgentStep]: + """获取执行步骤""" + return self._steps diff --git a/backend/app/services/agent/agents/recon.py b/backend/app/services/agent/agents/recon.py index 49b08e2..8ddc609 100644 --- a/backend/app/services/agent/agents/recon.py +++ b/backend/app/services/agent/agents/recon.py @@ -1,72 +1,127 @@ """ -Recon Agent (信息收集层) -负责项目结构分析、技术栈识别、入口点识别 +Recon Agent (信息收集层) - LLM 驱动版 -类型: ReAct +LLM 是真正的大脑! +- LLM 决定收集什么信息 +- LLM 决定使用哪个工具 +- LLM 决定何时信息足够 +- LLM 动态调整收集策略 + +类型: ReAct (真正的!) 
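
Before the implementation, a sketch of the calling side: the orchestrator's `_dispatch_agent` (above) is the single entry point through which this agent receives work. A minimal sketch, assuming an `OrchestratorAgent` whose `sub_agents` registry is already populated; the task text is invented:

```python
async def demo_dispatch(orchestrator) -> None:
    # `orchestrator` is an OrchestratorAgent with sub_agents registered,
    # e.g. {"recon": ReconAgent(...), "analysis": ..., "verification": ...}.
    observation = await orchestrator._dispatch_agent({
        "agent": "recon",                      # must match a sub_agents key
        "task": "收集项目结构和技术栈信息",       # free-text instruction
        "context": "重点关注 API 路由和配置文件",  # optional extra context
    })
    # Success: a markdown report (status, finding count, iterations, duration,
    # plus at most 10 finding summaries). Unknown agent name: an error string
    # listing the valid agents. Either way it flows back as an Observation.
    print(observation[:300])
```

The string return type is deliberate: whatever the sub-agent produces is folded back into the orchestrator's ReAct transcript as an observation, so its LLM can reason over sub-agent results like any other tool output.
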
""" -import asyncio +import json import logging -import os +import re from typing import List, Dict, Any, Optional -from dataclasses import dataclass, field +from dataclasses import dataclass from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern logger = logging.getLogger(__name__) -RECON_SYSTEM_PROMPT = """你是 DeepAudit 的信息收集 Agent,负责在安全审计前收集项目信息。 +RECON_SYSTEM_PROMPT = """你是 DeepAudit 的信息收集 Agent,负责在安全审计前**自主**收集项目信息。 -## 你的职责 -1. 分析项目结构和目录布局 -2. 识别使用的技术栈和框架 -3. 找出应用程序入口点 -4. 分析依赖和第三方库 -5. 识别高风险区域 +## 你的角色 +你是信息收集的**大脑**,不是机械执行者。你需要: +1. 自主思考需要收集什么信息 +2. 选择合适的工具获取信息 +3. 根据发现动态调整策略 +4. 判断何时信息收集足够 ## 你可以使用的工具 -- list_files: 列出目录内容 -- read_file: 读取文件内容 -- search_code: 搜索代码 -- semgrep_scan: Semgrep 扫描 -- npm_audit: npm 依赖审计 -- safety_scan: Python 依赖审计 -- gitleaks_scan: 密钥泄露扫描 -## 信息收集要点 -1. **目录结构**: 了解项目布局,识别源码、配置、测试目录 -2. **技术栈**: 检测语言、框架、数据库等 -3. **入口点**: API 路由、控制器、处理函数 -4. **配置文件**: 环境变量、数据库配置、API 密钥 -5. **依赖**: package.json, requirements.txt, go.mod 等 -6. **安全相关**: 认证、授权、加密相关代码 +### 文件系统 +- **list_files**: 列出目录内容 + 参数: directory (str), recursive (bool), pattern (str), max_files (int) + +- **read_file**: 读取文件内容 + 参数: file_path (str), start_line (int), end_line (int), max_lines (int) + +- **search_code**: 代码关键字搜索 + 参数: keyword (str), max_results (int) -## 输出格式 -完成后返回 JSON: +### 安全扫描 +- **semgrep_scan**: Semgrep 静态分析扫描 +- **npm_audit**: npm 依赖漏洞审计 +- **safety_scan**: Python 依赖漏洞审计 +- **gitleaks_scan**: 密钥/敏感信息泄露扫描 +- **osv_scan**: OSV 通用依赖漏洞扫描 + +## 工作方式 +每一步,你需要输出: + +``` +Thought: [分析当前状态,思考还需要什么信息] +Action: [工具名称] +Action Input: [JSON 格式的参数] +``` + +当你认为信息收集足够时,输出: + +``` +Thought: [总结收集到的信息] +Final Answer: [JSON 格式的收集结果] +``` + +## Final Answer 格式 ```json { - "project_structure": {...}, + "project_structure": { + "directories": [], + "config_files": [], + "total_files": 数量 + }, "tech_stack": { "languages": [], "frameworks": [], "databases": [] }, - "entry_points": [], - "high_risk_areas": [], - "dependencies": {...}, + "entry_points": [ + {"type": "描述", "file": "路径", "line": 行号} + ], + "high_risk_areas": ["路径列表"], + "dependencies": {}, "initial_findings": [] } ``` -请系统性地收集信息,为后续分析做准备。""" +## 信息收集策略建议 +1. 先 list_files 了解项目结构 +2. 读取配置文件 (package.json, requirements.txt, go.mod 等) 识别技术栈 +3. 搜索入口点模式 (routes, controllers, handlers) +4. 运行安全扫描发现初步问题 +5. 根据发现继续深入 + +## 重要原则 +1. **你是大脑** - 每一步都要思考,不要机械执行 +2. **动态调整** - 根据发现调整策略 +3. **效率优先** - 不要重复收集已有信息 +4. **主动探索** - 发现有趣的东西要深入 + +现在开始收集项目信息!""" + + +@dataclass +class ReconStep: + """信息收集步骤""" + thought: str + action: Optional[str] = None + action_input: Optional[Dict] = None + observation: Optional[str] = None + is_final: bool = False + final_answer: Optional[Dict] = None class ReconAgent(BaseAgent): """ - 信息收集 Agent + 信息收集 Agent - LLM 驱动版 - 使用 ReAct 模式迭代收集项目信息 + LLM 全程参与,自主决定: + 1. 收集什么信息 + 2. 使用什么工具 + 3. 
何时足够 """ def __init__( @@ -81,84 +136,219 @@ class ReconAgent(BaseAgent): pattern=AgentPattern.REACT, max_iterations=15, system_prompt=RECON_SYSTEM_PROMPT, - tools=[ - "list_files", "read_file", "search_code", - "semgrep_scan", "npm_audit", "safety_scan", - "gitleaks_scan", "osv_scan", - ], ) super().__init__(config, llm_service, tools, event_emitter) + + self._conversation_history: List[Dict[str, str]] = [] + self._steps: List[ReconStep] = [] + + def _get_tools_description(self) -> str: + """生成工具描述""" + tools_info = [] + for name, tool in self.tools.items(): + if name.startswith("_"): + continue + desc = f"- {name}: {getattr(tool, 'description', 'No description')}" + tools_info.append(desc) + return "\n".join(tools_info) + + def _parse_llm_response(self, response: str) -> ReconStep: + """解析 LLM 响应""" + step = ReconStep(thought="") + + # 提取 Thought + thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|Final Answer:|$)', response, re.DOTALL) + if thought_match: + step.thought = thought_match.group(1).strip() + + # 检查是否是最终答案 + final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL) + if final_match: + step.is_final = True + try: + answer_text = final_match.group(1).strip() + answer_text = re.sub(r'```json\s*', '', answer_text) + answer_text = re.sub(r'```\s*', '', answer_text) + step.final_answer = json.loads(answer_text) + except json.JSONDecodeError: + step.final_answer = {"raw_answer": final_match.group(1).strip()} + return step + + # 提取 Action + action_match = re.search(r'Action:\s*(\w+)', response) + if action_match: + step.action = action_match.group(1).strip() + + # 提取 Action Input + input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Action:|Observation:|$)', response, re.DOTALL) + if input_match: + input_text = input_match.group(1).strip() + input_text = re.sub(r'```json\s*', '', input_text) + input_text = re.sub(r'```\s*', '', input_text) + try: + step.action_input = json.loads(input_text) + except json.JSONDecodeError: + step.action_input = {"raw_input": input_text} + + return step + + async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str: + """执行工具""" + tool = self.tools.get(tool_name) + + if not tool: + return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}" + + try: + self._tool_calls += 1 + await self.emit_tool_call(tool_name, tool_input) + + import time + start = time.time() + + result = await tool.execute(**tool_input) + + duration_ms = int((time.time() - start) * 1000) + await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms) + + if result.success: + output = str(result.data) + if len(output) > 4000: + output = output[:4000] + f"\n\n... [输出已截断,共 {len(str(result.data))} 字符]" + return output + else: + return f"工具执行失败: {result.error}" + + except Exception as e: + logger.error(f"Tool execution error: {e}") + return f"工具执行错误: {str(e)}" async def run(self, input_data: Dict[str, Any]) -> AgentResult: - """执行信息收集""" + """ + 执行信息收集 - LLM 全程参与! 
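
The parsing contract above is easiest to see with a concrete turn. An illustrative round-trip through `_parse_llm_response` (the sample text is invented; `agent` is any constructed `ReconAgent`):

```python
sample = (
    "Thought: 先了解项目结构,再决定读取哪些配置文件。\n"
    "Action: list_files\n"
    'Action Input: {"directory": ".", "recursive": true, "max_files": 100}'
)

step = agent._parse_llm_response(sample)
assert step.thought == "先了解项目结构,再决定读取哪些配置文件。"
assert step.action == "list_files"
assert step.action_input == {"directory": ".", "recursive": True, "max_files": 100}
assert step.is_final is False  # no "Final Answer:" marker in this turn

# Malformed JSON degrades gracefully instead of raising:
bad = agent._parse_llm_response("Thought: ok\nAction: read_file\nAction Input: not-json")
assert bad.action_input == {"raw_input": "not-json"}
```

Note the asymmetry with the orchestrator's parser, which falls back to `{"raw": ...}` rather than `{"raw_input": ...}`; downstream consumers should not rely on one key name across agents.
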
+ """ import time start_time = time.time() project_info = input_data.get("project_info", {}) config = input_data.get("config", {}) + task = input_data.get("task", "") + task_context = input_data.get("task_context", "") + + # 构建初始消息 + initial_message = f"""请开始收集项目信息。 + +## 项目基本信息 +- 名称: {project_info.get('name', 'unknown')} +- 根目录: {project_info.get('root', '.')} + +## 任务上下文 +{task_context or task or '进行全面的信息收集,为安全审计做准备。'} + +## 可用工具 +{self._get_tools_description()} + +请开始你的信息收集工作。首先思考应该收集什么信息,然后选择合适的工具。""" + + # 初始化对话历史 + self._conversation_history = [ + {"role": "system", "content": self.config.system_prompt}, + {"role": "user", "content": initial_message}, + ] + + self._steps = [] + final_result = None + + await self.emit_thinking("🔍 Recon Agent 启动,LLM 开始自主收集信息...") try: - await self.emit_thinking("开始信息收集...") - - # 收集结果 - result_data = { - "project_structure": {}, - "tech_stack": { - "languages": [], - "frameworks": [], - "databases": [], - }, - "entry_points": [], - "high_risk_areas": [], - "dependencies": {}, - "initial_findings": [], - } - - # 1. 分析项目结构 - await self.emit_thinking("分析项目结构...") - structure = await self._analyze_structure() - result_data["project_structure"] = structure - - # 2. 识别技术栈 - await self.emit_thinking("识别技术栈...") - tech_stack = await self._identify_tech_stack(structure) - result_data["tech_stack"] = tech_stack - - # 3. 扫描依赖漏洞 - await self.emit_thinking("扫描依赖漏洞...") - deps_result = await self._scan_dependencies(tech_stack) - result_data["dependencies"] = deps_result.get("dependencies", {}) - if deps_result.get("findings"): - result_data["initial_findings"].extend(deps_result["findings"]) - - # 4. 快速密钥扫描 - await self.emit_thinking("扫描密钥泄露...") - secrets_result = await self._scan_secrets() - if secrets_result.get("findings"): - result_data["initial_findings"].extend(secrets_result["findings"]) - - # 5. 识别入口点 - await self.emit_thinking("识别入口点...") - entry_points = await self._identify_entry_points(tech_stack) - result_data["entry_points"] = entry_points - - # 6. 
识别高风险区域 - result_data["high_risk_areas"] = self._identify_high_risk_areas( - structure, tech_stack, entry_points - ) + for iteration in range(self.config.max_iterations): + if self.is_cancelled: + break + + self._iteration = iteration + 1 + + # 🔥 发射 LLM 开始思考事件 + await self.emit_llm_start(iteration + 1) + + # 🔥 调用 LLM 进行思考和决策 + response = await self.llm_service.chat_completion_raw( + messages=self._conversation_history, + temperature=0.1, + max_tokens=2048, + ) + + llm_output = response.get("content", "") + tokens_this_round = response.get("usage", {}).get("total_tokens", 0) + self._total_tokens += tokens_this_round + + # 解析 LLM 响应 + step = self._parse_llm_response(llm_output) + self._steps.append(step) + + # 🔥 发射 LLM 思考内容事件 - 展示 LLM 在想什么 + if step.thought: + await self.emit_llm_thought(step.thought, iteration + 1) + + # 添加 LLM 响应到历史 + self._conversation_history.append({ + "role": "assistant", + "content": llm_output, + }) + + # 检查是否完成 + if step.is_final: + await self.emit_llm_decision("完成信息收集", "LLM 判断已收集足够信息") + await self.emit_llm_complete( + f"信息收集完成,共 {self._iteration} 轮思考", + self._total_tokens + ) + final_result = step.final_answer + break + + # 执行工具 + if step.action: + # 🔥 发射 LLM 动作决策事件 + await self.emit_llm_action(step.action, step.action_input or {}) + + observation = await self._execute_tool( + step.action, + step.action_input or {} + ) + + step.observation = observation + + # 🔥 发射 LLM 观察事件 + await self.emit_llm_observation(observation) + + # 添加观察结果到历史 + self._conversation_history.append({ + "role": "user", + "content": f"Observation:\n{observation}", + }) + else: + # LLM 没有选择工具,提示它继续 + await self.emit_llm_decision("继续思考", "LLM 需要更多信息") + self._conversation_history.append({ + "role": "user", + "content": "请继续,选择一个工具执行,或者如果信息收集完成,输出 Final Answer。", + }) + # 处理结果 duration_ms = int((time.time() - start_time) * 1000) + # 如果没有最终结果,从历史中汇总 + if not final_result: + final_result = self._summarize_from_steps() + await self.emit_event( "info", - f"信息收集完成: 发现 {len(result_data['entry_points'])} 个入口点, " - f"{len(result_data['high_risk_areas'])} 个高风险区域, " - f"{len(result_data['initial_findings'])} 个初步发现" + f"🎯 Recon Agent 完成: {self._iteration} 轮迭代, {self._tool_calls} 次工具调用" ) return AgentResult( success=True, - data=result_data, + data=final_result, iterations=self._iteration, tool_calls=self._tool_calls, tokens_used=self._total_tokens, @@ -166,270 +356,58 @@ class ReconAgent(BaseAgent): ) except Exception as e: - logger.error(f"Recon agent failed: {e}", exc_info=True) + logger.error(f"Recon Agent failed: {e}", exc_info=True) return AgentResult(success=False, error=str(e)) - async def _analyze_structure(self) -> Dict[str, Any]: - """分析项目结构""" - structure = { - "directories": [], - "files_by_type": {}, - "config_files": [], - "total_files": 0, + def _summarize_from_steps(self) -> Dict[str, Any]: + """从步骤中汇总结果""" + # 默认结果结构 + result = { + "project_structure": {}, + "tech_stack": { + "languages": [], + "frameworks": [], + "databases": [], + }, + "entry_points": [], + "high_risk_areas": [], + "dependencies": {}, + "initial_findings": [], } - # 列出根目录 - list_tool = self.tools.get("list_files") - if not list_tool: - return structure - - result = await list_tool.execute(directory=".", recursive=True, max_files=300) - - if result.success: - structure["total_files"] = result.metadata.get("file_count", 0) - - # 识别配置文件 - config_patterns = [ - "package.json", "requirements.txt", "go.mod", "Cargo.toml", - "pom.xml", "build.gradle", ".env", "config.py", "settings.py", - "docker-compose.yml", "Dockerfile", - ] 
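
For contrast with the removed pipeline below, the rewritten loop above keeps the whole exchange as a plain chat transcript: tool results come back as ordinary user messages prefixed with `Observation:`, not as provider-specific tool/function-call messages, so any chat-completion backend works. An illustrative shape of `_conversation_history` after one turn:

```python
history = [
    {"role": "system", "content": RECON_SYSTEM_PROMPT},
    {"role": "user", "content": "请开始收集项目信息。..."},  # initial message
    {"role": "assistant", "content": "Thought: ...\nAction: list_files\nAction Input: {...}"},
    {"role": "user", "content": "Observation:\nbackend/\nfrontend/\npackage.json\n..."},
    # ...and so on, until the assistant emits "Final Answer: {...}" or the
    # 15-iteration cap is hit, in which case _summarize_from_steps() builds
    # a best-effort result from the accumulated observations instead.
]
```
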
- - # 从输出中解析文件列表 - if isinstance(result.data, str): - for line in result.data.split('\n'): - line = line.strip() - for pattern in config_patterns: - if pattern in line: - structure["config_files"].append(line) - - return structure - - async def _identify_tech_stack(self, structure: Dict) -> Dict[str, Any]: - """识别技术栈""" - tech_stack = { - "languages": [], - "frameworks": [], - "databases": [], - "package_managers": [], - } - - config_files = structure.get("config_files", []) - - # 基于配置文件推断 - for cfg in config_files: - if "package.json" in cfg: - tech_stack["languages"].append("JavaScript/TypeScript") - tech_stack["package_managers"].append("npm") - elif "requirements.txt" in cfg or "setup.py" in cfg: - tech_stack["languages"].append("Python") - tech_stack["package_managers"].append("pip") - elif "go.mod" in cfg: - tech_stack["languages"].append("Go") - elif "Cargo.toml" in cfg: - tech_stack["languages"].append("Rust") - elif "pom.xml" in cfg or "build.gradle" in cfg: - tech_stack["languages"].append("Java") - - # 读取 package.json 识别框架 - read_tool = self.tools.get("read_file") - if read_tool and "package.json" in str(config_files): - result = await read_tool.execute(file_path="package.json", max_lines=100) - if result.success: - content = result.data - if "react" in content.lower(): - tech_stack["frameworks"].append("React") - if "vue" in content.lower(): - tech_stack["frameworks"].append("Vue") - if "express" in content.lower(): - tech_stack["frameworks"].append("Express") - if "fastify" in content.lower(): - tech_stack["frameworks"].append("Fastify") - if "next" in content.lower(): - tech_stack["frameworks"].append("Next.js") - - # 读取 requirements.txt 识别框架 - if read_tool and "requirements.txt" in str(config_files): - result = await read_tool.execute(file_path="requirements.txt", max_lines=50) - if result.success: - content = result.data.lower() - if "django" in content: - tech_stack["frameworks"].append("Django") - if "flask" in content: - tech_stack["frameworks"].append("Flask") - if "fastapi" in content: - tech_stack["frameworks"].append("FastAPI") - if "sqlalchemy" in content: - tech_stack["databases"].append("SQLAlchemy") - if "pymongo" in content: - tech_stack["databases"].append("MongoDB") + # 从步骤的观察结果中提取信息 + for step in self._steps: + if step.observation: + # 尝试从观察中识别技术栈等信息 + obs_lower = step.observation.lower() + + if "package.json" in obs_lower: + result["tech_stack"]["languages"].append("JavaScript/TypeScript") + if "requirements.txt" in obs_lower or "setup.py" in obs_lower: + result["tech_stack"]["languages"].append("Python") + if "go.mod" in obs_lower: + result["tech_stack"]["languages"].append("Go") + + # 识别框架 + if "react" in obs_lower: + result["tech_stack"]["frameworks"].append("React") + if "django" in obs_lower: + result["tech_stack"]["frameworks"].append("Django") + if "fastapi" in obs_lower: + result["tech_stack"]["frameworks"].append("FastAPI") + if "express" in obs_lower: + result["tech_stack"]["frameworks"].append("Express") # 去重 - tech_stack["languages"] = list(set(tech_stack["languages"])) - tech_stack["frameworks"] = list(set(tech_stack["frameworks"])) - tech_stack["databases"] = list(set(tech_stack["databases"])) - - return tech_stack - - async def _scan_dependencies(self, tech_stack: Dict) -> Dict[str, Any]: - """扫描依赖漏洞""" - result = { - "dependencies": {}, - "findings": [], - } - - # npm audit - if "npm" in tech_stack.get("package_managers", []): - npm_tool = self.tools.get("npm_audit") - if npm_tool: - npm_result = await npm_tool.execute() - if 
npm_result.success and npm_result.metadata.get("findings_count", 0) > 0: - result["dependencies"]["npm"] = npm_result.metadata - - # 转换为发现格式 - for sev, count in npm_result.metadata.get("severity_counts", {}).items(): - if count > 0 and sev in ["critical", "high"]: - result["findings"].append({ - "vulnerability_type": "dependency_vulnerability", - "severity": sev, - "title": f"npm 依赖漏洞 ({count} 个 {sev})", - "source": "npm_audit", - }) - - # Safety (Python) - if "pip" in tech_stack.get("package_managers", []): - safety_tool = self.tools.get("safety_scan") - if safety_tool: - safety_result = await safety_tool.execute() - if safety_result.success and safety_result.metadata.get("findings_count", 0) > 0: - result["dependencies"]["pip"] = safety_result.metadata - result["findings"].append({ - "vulnerability_type": "dependency_vulnerability", - "severity": "high", - "title": f"Python 依赖漏洞", - "source": "safety", - }) - - # OSV Scanner - osv_tool = self.tools.get("osv_scan") - if osv_tool: - osv_result = await osv_tool.execute() - if osv_result.success and osv_result.metadata.get("findings_count", 0) > 0: - result["dependencies"]["osv"] = osv_result.metadata + result["tech_stack"]["languages"] = list(set(result["tech_stack"]["languages"])) + result["tech_stack"]["frameworks"] = list(set(result["tech_stack"]["frameworks"])) return result - async def _scan_secrets(self) -> Dict[str, Any]: - """扫描密钥泄露""" - result = {"findings": []} - - gitleaks_tool = self.tools.get("gitleaks_scan") - if gitleaks_tool: - gl_result = await gitleaks_tool.execute() - if gl_result.success and gl_result.metadata.get("findings_count", 0) > 0: - for finding in gl_result.metadata.get("findings", []): - result["findings"].append({ - "vulnerability_type": "hardcoded_secret", - "severity": "high", - "title": f"密钥泄露: {finding.get('rule', 'unknown')}", - "file_path": finding.get("file"), - "line_start": finding.get("line"), - "source": "gitleaks", - }) - - return result + def get_conversation_history(self) -> List[Dict[str, str]]: + """获取对话历史""" + return self._conversation_history - async def _identify_entry_points(self, tech_stack: Dict) -> List[Dict[str, Any]]: - """识别入口点""" - entry_points = [] - search_tool = self.tools.get("search_code") - - if not search_tool: - return entry_points - - # 基于框架搜索入口点 - search_patterns = [] - - frameworks = tech_stack.get("frameworks", []) - - if "Express" in frameworks: - search_patterns.extend([ - ("app.get(", "Express GET route"), - ("app.post(", "Express POST route"), - ("router.get(", "Express router GET"), - ("router.post(", "Express router POST"), - ]) - - if "FastAPI" in frameworks: - search_patterns.extend([ - ("@app.get(", "FastAPI GET endpoint"), - ("@app.post(", "FastAPI POST endpoint"), - ("@router.get(", "FastAPI router GET"), - ("@router.post(", "FastAPI router POST"), - ]) - - if "Django" in frameworks: - search_patterns.extend([ - ("def get(self", "Django GET view"), - ("def post(self", "Django POST view"), - ("path(", "Django URL pattern"), - ]) - - if "Flask" in frameworks: - search_patterns.extend([ - ("@app.route(", "Flask route"), - ("@blueprint.route(", "Flask blueprint route"), - ]) - - # 通用模式 - search_patterns.extend([ - ("def handle", "Handler function"), - ("async def handle", "Async handler"), - ("class.*Controller", "Controller class"), - ("class.*Handler", "Handler class"), - ]) - - for pattern, description in search_patterns[:10]: # 限制搜索数量 - result = await search_tool.execute(keyword=pattern, max_results=10) - if result.success and result.metadata.get("matches", 0) 
> 0: - for match in result.metadata.get("results", [])[:5]: - entry_points.append({ - "type": description, - "file": match.get("file"), - "line": match.get("line"), - "pattern": pattern, - }) - - return entry_points[:30] # 限制总数 - - def _identify_high_risk_areas( - self, - structure: Dict, - tech_stack: Dict, - entry_points: List[Dict], - ) -> List[str]: - """识别高风险区域""" - high_risk = set() - - # 通用高风险目录 - risk_dirs = [ - "auth/", "authentication/", "login/", - "api/", "routes/", "controllers/", "handlers/", - "db/", "database/", "models/", - "admin/", "management/", - "upload/", "file/", - "payment/", "billing/", - ] - - for dir_name in risk_dirs: - high_risk.add(dir_name) - - # 从入口点提取目录 - for ep in entry_points: - file_path = ep.get("file", "") - if "/" in file_path: - dir_path = "/".join(file_path.split("/")[:-1]) + "/" - high_risk.add(dir_path) - - return list(high_risk)[:20] - + def get_steps(self) -> List[ReconStep]: + """获取执行步骤""" + return self._steps diff --git a/backend/app/services/agent/agents/verification.py b/backend/app/services/agent/agents/verification.py index 9250287..e4c17bb 100644 --- a/backend/app/services/agent/agents/verification.py +++ b/backend/app/services/agent/agents/verification.py @@ -1,13 +1,20 @@ """ -Verification Agent (漏洞验证层) -负责漏洞确认、PoC 生成、沙箱测试 +Verification Agent (漏洞验证层) - LLM 驱动版 -类型: ReAct +LLM 是验证的大脑! +- LLM 决定如何验证每个漏洞 +- LLM 构造验证策略 +- LLM 分析验证结果 +- LLM 判断是否为真实漏洞 + +类型: ReAct (真正的!) """ -import asyncio +import json import logging +import re from typing import List, Dict, Any, Optional +from dataclasses import dataclass from datetime import datetime, timezone from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern @@ -15,69 +22,121 @@ from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern logger = logging.getLogger(__name__) -VERIFICATION_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞验证 Agent,负责确认发现的漏洞是否真实存在。 +VERIFICATION_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞验证 Agent,一个**自主**的安全验证专家。 -## 你的职责 -1. 分析漏洞上下文,判断是否为真正的安全问题 -2. 构造 PoC(概念验证)代码 -3. 在沙箱中执行测试 -4. 评估漏洞的实际影响 +## 你的角色 +你是漏洞验证的**大脑**,不是机械验证器。你需要: +1. 理解每个漏洞的上下文 +2. 设计合适的验证策略 +3. 使用工具获取更多信息 +4. 判断漏洞是否真实存在 +5. 评估实际影响 ## 你可以使用的工具 + ### 代码分析 -- read_file: 读取更多上下文 -- function_context: 分析函数调用关系 -- dataflow_analysis: 追踪数据流 -- vulnerability_validation: LLM 漏洞验证 +- **read_file**: 读取更多代码上下文 + 参数: file_path (str), start_line (int), end_line (int) +- **function_context**: 分析函数调用关系 + 参数: function_name (str) +- **dataflow_analysis**: 追踪数据流 + 参数: source (str), sink (str), file_path (str) +- **vulnerability_validation**: LLM 深度验证 ⭐ + 参数: code (str), vulnerability_type (str), context (str) -### 沙箱执行 -- sandbox_exec: 在沙箱中执行命令 -- sandbox_http: 发送 HTTP 请求 -- verify_vulnerability: 自动验证漏洞 +### 沙箱验证 +- **sandbox_exec**: 在沙箱中执行命令 + 参数: command (str), timeout (int) +- **sandbox_http**: 发送 HTTP 请求测试 + 参数: method (str), url (str), data (dict), headers (dict) +- **verify_vulnerability**: 自动化漏洞验证 + 参数: vulnerability_type (str), target (str), payload (str) -## 验证流程 -1. **上下文分析**: 获取更多代码上下文 -2. **可利用性分析**: 判断漏洞是否可被利用 -3. **PoC 构造**: 设计验证方案 -4. **沙箱测试**: 在隔离环境中测试 -5. 
**结果评估**: 确定漏洞是否真实存在 +## 工作方式 +你将收到一批待验证的漏洞发现。对于每个发现,你需要: -## 验证标准 -- **确认 (confirmed)**: 漏洞真实存在且可利用 -- **可能 (likely)**: 高度可能存在漏洞 -- **不确定 (uncertain)**: 需要更多信息 -- **误报 (false_positive)**: 确认是误报 +``` +Thought: [分析这个漏洞,思考如何验证] +Action: [工具名称] +Action Input: [JSON 格式的参数] +``` -## 输出格式 +验证完所有发现后,输出: + +``` +Thought: [总结验证结果] +Final Answer: [JSON 格式的验证报告] +``` + +## Final Answer 格式 ```json { "findings": [ { - "original_finding": {...}, + ...原始发现字段..., "verdict": "confirmed/likely/uncertain/false_positive", "confidence": 0.0-1.0, "is_verified": true/false, "verification_method": "描述验证方法", + "verification_details": "验证过程和结果详情", "poc": { - "code": "PoC 代码", - "description": "描述", - "steps": ["步骤1", "步骤2"] + "description": "PoC 描述", + "steps": ["步骤1", "步骤2"], + "payload": "测试 payload" }, - "impact": "影响分析", + "impact": "实际影响分析", "recommendation": "修复建议" } - ] + ], + "summary": { + "total": 数量, + "confirmed": 数量, + "likely": 数量, + "false_positive": 数量 + } } ``` -请谨慎验证,减少误报,同时不遗漏真正的漏洞。""" +## 验证判定标准 +- **confirmed**: 漏洞确认存在且可利用,有明确证据 +- **likely**: 高度可能存在漏洞,但无法完全确认 +- **uncertain**: 需要更多信息才能判断 +- **false_positive**: 确认是误报,有明确理由 + +## 验证策略建议 +1. **上下文分析**: 用 read_file 获取更多代码上下文 +2. **数据流追踪**: 用 dataflow_analysis 确认污点传播 +3. **LLM 深度分析**: 用 vulnerability_validation 进行专业分析 +4. **沙箱测试**: 对高危漏洞用沙箱进行安全测试 + +## 重要原则 +1. **质量优先** - 宁可漏报也不要误报太多 +2. **深入理解** - 理解代码逻辑,不要表面判断 +3. **证据支撑** - 判定要有依据 +4. **安全第一** - 沙箱测试要谨慎 + +现在开始验证漏洞发现!""" + + +@dataclass +class VerificationStep: + """验证步骤""" + thought: str + action: Optional[str] = None + action_input: Optional[Dict] = None + observation: Optional[str] = None + is_final: bool = False + final_answer: Optional[Dict] = None class VerificationAgent(BaseAgent): """ - 漏洞验证 Agent + 漏洞验证 Agent - LLM 驱动版 - 使用 ReAct 模式验证发现的漏洞 + LLM 全程参与,自主决定: + 1. 如何验证每个漏洞 + 2. 使用什么工具 + 3. 
判断真假 """ def __init__( @@ -90,25 +149,114 @@ class VerificationAgent(BaseAgent): name="Verification", agent_type=AgentType.VERIFICATION, pattern=AgentPattern.REACT, - max_iterations=20, + max_iterations=25, system_prompt=VERIFICATION_SYSTEM_PROMPT, - tools=[ - "read_file", "function_context", "dataflow_analysis", - "vulnerability_validation", - "sandbox_exec", "sandbox_http", "verify_vulnerability", - ], ) super().__init__(config, llm_service, tools, event_emitter) + + self._conversation_history: List[Dict[str, str]] = [] + self._steps: List[VerificationStep] = [] + + def _get_tools_description(self) -> str: + """生成工具描述""" + tools_info = [] + for name, tool in self.tools.items(): + if name.startswith("_"): + continue + desc = f"- {name}: {getattr(tool, 'description', 'No description')}" + tools_info.append(desc) + return "\n".join(tools_info) + + def _parse_llm_response(self, response: str) -> VerificationStep: + """解析 LLM 响应""" + step = VerificationStep(thought="") + + # 提取 Thought + thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|Final Answer:|$)', response, re.DOTALL) + if thought_match: + step.thought = thought_match.group(1).strip() + + # 检查是否是最终答案 + final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL) + if final_match: + step.is_final = True + try: + answer_text = final_match.group(1).strip() + answer_text = re.sub(r'```json\s*', '', answer_text) + answer_text = re.sub(r'```\s*', '', answer_text) + step.final_answer = json.loads(answer_text) + except json.JSONDecodeError: + step.final_answer = {"findings": [], "raw_answer": final_match.group(1).strip()} + return step + + # 提取 Action + action_match = re.search(r'Action:\s*(\w+)', response) + if action_match: + step.action = action_match.group(1).strip() + + # 提取 Action Input + input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Action:|Observation:|$)', response, re.DOTALL) + if input_match: + input_text = input_match.group(1).strip() + input_text = re.sub(r'```json\s*', '', input_text) + input_text = re.sub(r'```\s*', '', input_text) + try: + step.action_input = json.loads(input_text) + except json.JSONDecodeError: + step.action_input = {"raw_input": input_text} + + return step + + async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str: + """执行工具""" + tool = self.tools.get(tool_name) + + if not tool: + return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}" + + try: + self._tool_calls += 1 + await self.emit_tool_call(tool_name, tool_input) + + import time + start = time.time() + + result = await tool.execute(**tool_input) + + duration_ms = int((time.time() - start) * 1000) + await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms) + + if result.success: + output = str(result.data) + + # 包含 metadata + if result.metadata: + if "validation" in result.metadata: + output += f"\n\n验证结果:\n{json.dumps(result.metadata['validation'], ensure_ascii=False, indent=2)}" + + if len(output) > 4000: + output = output[:4000] + f"\n\n... [输出已截断]" + return output + else: + return f"工具执行失败: {result.error}" + + except Exception as e: + logger.error(f"Tool execution error: {e}") + return f"工具执行错误: {str(e)}" async def run(self, input_data: Dict[str, Any]) -> AgentResult: - """执行漏洞验证""" + """ + 执行漏洞验证 - LLM 全程参与! 
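
One detail worth pinning down before the loop: `run()` below collapses each finding's `verdict`/`confidence` pair into the boolean `is_verified` flag that the rest of the system consumes. A standalone restatement of that rule:

```python
def is_verified(finding: dict) -> bool:
    # Mirrors the post-processing in run(): only a confirmed verdict, or a
    # "likely" verdict with confidence >= 0.8, counts as verified.
    verdict = finding.get("verdict")
    return verdict == "confirmed" or (
        verdict == "likely" and finding.get("confidence", 0) >= 0.8
    )

assert is_verified({"verdict": "confirmed", "confidence": 0.5})
assert is_verified({"verdict": "likely", "confidence": 0.85})
assert not is_verified({"verdict": "likely", "confidence": 0.7})
assert not is_verified({"verdict": "false_positive", "confidence": 0.99})
```

Note that `verified_at` is stamped for both "confirmed" and "likely" verdicts, so a "likely" finding at 0.7 confidence can carry a timestamp while still reporting `is_verified: false`.
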
+ """ import time start_time = time.time() previous_results = input_data.get("previous_results", {}) config = input_data.get("config", {}) + task = input_data.get("task", "") + task_context = input_data.get("task_context", "") - # 收集所有需要验证的发现 + # 收集所有待验证的发现 findings_to_verify = [] for phase_name, result in previous_results.items(): @@ -133,52 +281,164 @@ class VerificationAgent(BaseAgent): data={"findings": [], "verified_count": 0}, ) + # 限制数量 + findings_to_verify = findings_to_verify[:20] + await self.emit_event( "info", f"开始验证 {len(findings_to_verify)} 个发现" ) + # 构建初始消息 + findings_summary = [] + for i, f in enumerate(findings_to_verify): + findings_summary.append(f""" +### 发现 {i+1}: {f.get('title', 'Unknown')} +- 类型: {f.get('vulnerability_type', 'unknown')} +- 严重度: {f.get('severity', 'medium')} +- 文件: {f.get('file_path', 'unknown')}:{f.get('line_start', 0)} +- 代码: +``` +{f.get('code_snippet', 'N/A')[:500]} +``` +- 描述: {f.get('description', 'N/A')[:300]} +""") + + initial_message = f"""请验证以下 {len(findings_to_verify)} 个安全发现。 + +## 待验证发现 +{''.join(findings_summary)} + +## 验证要求 +- 验证级别: {config.get('verification_level', 'standard')} + +## 可用工具 +{self._get_tools_description()} + +请开始验证。对于每个发现,思考如何验证它,使用合适的工具获取更多信息,然后判断是否为真实漏洞。""" + + # 初始化对话历史 + self._conversation_history = [ + {"role": "system", "content": self.config.system_prompt}, + {"role": "user", "content": initial_message}, + ] + + self._steps = [] + final_result = None + + await self.emit_thinking("🔐 Verification Agent 启动,LLM 开始自主验证漏洞...") + try: - verified_findings = [] - verification_level = config.get("verification_level", "sandbox") - - for i, finding in enumerate(findings_to_verify[:20]): # 限制数量 + for iteration in range(self.config.max_iterations): if self.is_cancelled: break - await self.emit_thinking( - f"验证 [{i+1}/{min(len(findings_to_verify), 20)}]: {finding.get('title', 'unknown')}" + self._iteration = iteration + 1 + + # 🔥 发射 LLM 开始思考事件 + await self.emit_llm_start(iteration + 1) + + # 🔥 调用 LLM 进行思考和决策 + response = await self.llm_service.chat_completion_raw( + messages=self._conversation_history, + temperature=0.1, + max_tokens=3000, ) - # 执行验证 - verified = await self._verify_finding(finding, verification_level) - verified_findings.append(verified) + llm_output = response.get("content", "") + tokens_this_round = response.get("usage", {}).get("total_tokens", 0) + self._total_tokens += tokens_this_round - # 发射事件 - if verified.get("is_verified"): - await self.emit_event( - "finding_verified", - f"✅ 已确认: {verified.get('title', '')}", - finding_id=verified.get("id"), - metadata={"severity": verified.get("severity")} + # 解析 LLM 响应 + step = self._parse_llm_response(llm_output) + self._steps.append(step) + + # 🔥 发射 LLM 思考内容事件 - 展示验证的思考过程 + if step.thought: + await self.emit_llm_thought(step.thought, iteration + 1) + + # 添加 LLM 响应到历史 + self._conversation_history.append({ + "role": "assistant", + "content": llm_output, + }) + + # 检查是否完成 + if step.is_final: + await self.emit_llm_decision("完成漏洞验证", "LLM 判断验证已充分") + final_result = step.final_answer + await self.emit_llm_complete( + f"验证完成", + self._total_tokens ) - elif verified.get("verdict") == "false_positive": - await self.emit_event( - "finding_false_positive", - f"❌ 误报: {verified.get('title', '')}", - finding_id=verified.get("id"), + break + + # 执行工具 + if step.action: + # 🔥 发射 LLM 动作决策事件 + await self.emit_llm_action(step.action, step.action_input or {}) + + observation = await self._execute_tool( + step.action, + step.action_input or {} ) + + step.observation = observation + + # 🔥 
发射 LLM 观察事件 + await self.emit_llm_observation(observation) + + # 添加观察结果到历史 + self._conversation_history.append({ + "role": "user", + "content": f"Observation:\n{observation}", + }) + else: + # LLM 没有选择工具,提示它继续 + await self.emit_llm_decision("继续验证", "LLM 需要更多验证") + self._conversation_history.append({ + "role": "user", + "content": "请继续验证。如果验证完成,输出 Final Answer 汇总所有验证结果。", + }) + + # 处理结果 + duration_ms = int((time.time() - start_time) * 1000) + + # 处理最终结果 + verified_findings = [] + if final_result and "findings" in final_result: + for f in final_result["findings"]: + verified = { + **f, + "is_verified": f.get("verdict") == "confirmed" or ( + f.get("verdict") == "likely" and f.get("confidence", 0) >= 0.8 + ), + "verified_at": datetime.now(timezone.utc).isoformat() if f.get("verdict") in ["confirmed", "likely"] else None, + } + + # 添加修复建议 + if not verified.get("recommendation"): + verified["recommendation"] = self._get_recommendation(f.get("vulnerability_type", "")) + + verified_findings.append(verified) + else: + # 如果没有最终结果,使用原始发现 + for f in findings_to_verify: + verified_findings.append({ + **f, + "verdict": "uncertain", + "confidence": 0.5, + "is_verified": False, + }) # 统计 - confirmed_count = len([f for f in verified_findings if f.get("is_verified")]) + confirmed_count = len([f for f in verified_findings if f.get("verdict") == "confirmed"]) likely_count = len([f for f in verified_findings if f.get("verdict") == "likely"]) false_positive_count = len([f for f in verified_findings if f.get("verdict") == "false_positive"]) - duration_ms = int((time.time() - start_time) * 1000) - await self.emit_event( "info", - f"验证完成: {confirmed_count} 确认, {likely_count} 可能, {false_positive_count} 误报" + f"🎯 Verification Agent 完成: {confirmed_count} 确认, {likely_count} 可能, {false_positive_count} 误报" ) return AgentResult( @@ -196,167 +456,9 @@ class VerificationAgent(BaseAgent): ) except Exception as e: - logger.error(f"Verification agent failed: {e}", exc_info=True) + logger.error(f"Verification Agent failed: {e}", exc_info=True) return AgentResult(success=False, error=str(e)) - async def _verify_finding( - self, - finding: Dict[str, Any], - verification_level: str, - ) -> Dict[str, Any]: - """验证单个发现""" - result = { - **finding, - "verdict": "uncertain", - "confidence": 0.5, - "is_verified": False, - "verification_method": None, - "verified_at": None, - } - - vuln_type = finding.get("vulnerability_type", "") - file_path = finding.get("file_path", "") - line_start = finding.get("line_start", 0) - code_snippet = finding.get("code_snippet", "") - - try: - # 1. 获取更多上下文 - context = await self._get_context(file_path, line_start) - - # 2. LLM 验证 - validation_result = await self._llm_validation( - finding, context - ) - - result["verdict"] = validation_result.get("verdict", "uncertain") - result["confidence"] = validation_result.get("confidence", 0.5) - result["verification_method"] = "llm_analysis" - - # 3. 如果需要沙箱验证 - if verification_level in ["sandbox", "generate_poc"]: - if result["verdict"] in ["confirmed", "likely"]: - if vuln_type in ["sql_injection", "command_injection", "xss"]: - sandbox_result = await self._sandbox_verification( - finding, validation_result - ) - - if sandbox_result.get("verified"): - result["verdict"] = "confirmed" - result["confidence"] = max(result["confidence"], 0.9) - result["verification_method"] = "sandbox_test" - result["poc"] = sandbox_result.get("poc") - - # 4. 
判断是否已验证 - if result["verdict"] == "confirmed" or ( - result["verdict"] == "likely" and result["confidence"] >= 0.8 - ): - result["is_verified"] = True - result["verified_at"] = datetime.now(timezone.utc).isoformat() - - # 5. 添加修复建议 - if result["is_verified"]: - result["recommendation"] = self._get_recommendation(vuln_type) - - except Exception as e: - logger.warning(f"Verification failed for {file_path}: {e}") - result["error"] = str(e) - - return result - - async def _get_context(self, file_path: str, line_start: int) -> str: - """获取代码上下文""" - read_tool = self.tools.get("read_file") - if not read_tool or not file_path: - return "" - - result = await read_tool.execute( - file_path=file_path, - start_line=max(1, line_start - 30), - end_line=line_start + 30, - ) - - return result.data if result.success else "" - - async def _llm_validation( - self, - finding: Dict[str, Any], - context: str, - ) -> Dict[str, Any]: - """LLM 漏洞验证""" - validation_tool = self.tools.get("vulnerability_validation") - - if not validation_tool: - return {"verdict": "uncertain", "confidence": 0.5} - - code = finding.get("code_snippet", "") or context[:2000] - - result = await validation_tool.execute( - code=code, - vulnerability_type=finding.get("vulnerability_type", "unknown"), - file_path=finding.get("file_path", ""), - line_number=finding.get("line_start"), - context=context[:1000] if context else None, - ) - - if result.success and result.metadata.get("validation"): - validation = result.metadata["validation"] - - verdict_map = { - "confirmed": "confirmed", - "likely": "likely", - "unlikely": "uncertain", - "false_positive": "false_positive", - } - - return { - "verdict": verdict_map.get(validation.get("verdict", ""), "uncertain"), - "confidence": validation.get("confidence", 0.5), - "explanation": validation.get("detailed_analysis", ""), - "exploitation_conditions": validation.get("exploitation_conditions", []), - "poc_idea": validation.get("poc_idea"), - } - - return {"verdict": "uncertain", "confidence": 0.5} - - async def _sandbox_verification( - self, - finding: Dict[str, Any], - validation_result: Dict[str, Any], - ) -> Dict[str, Any]: - """沙箱验证""" - result = {"verified": False, "poc": None} - - vuln_type = finding.get("vulnerability_type", "") - poc_idea = validation_result.get("poc_idea", "") - - # 根据漏洞类型选择验证方法 - sandbox_tool = self.tools.get("sandbox_exec") - http_tool = self.tools.get("sandbox_http") - verify_tool = self.tools.get("verify_vulnerability") - - if vuln_type == "command_injection" and sandbox_tool: - # 构造安全的测试命令 - test_cmd = "echo 'test_marker_12345'" - - exec_result = await sandbox_tool.execute( - command=f"python3 -c \"print('test')\"", - timeout=10, - ) - - if exec_result.success: - result["verified"] = True - result["poc"] = { - "description": "命令注入测试", - "method": "sandbox_exec", - } - - elif vuln_type in ["sql_injection", "xss"] and verify_tool: - # 使用自动验证工具 - # 注意:这需要实际的目标 URL - pass - - return result - def _get_recommendation(self, vuln_type: str) -> str: """获取修复建议""" recommendations = { @@ -369,7 +471,6 @@ class VerificationAgent(BaseAgent): "hardcoded_secret": "使用环境变量或密钥管理服务存储敏感信息", "weak_crypto": "使用强加密算法(AES-256, SHA-256+),避免 MD5/SHA1", } - return recommendations.get(vuln_type, "请根据具体情况修复此安全问题") def _deduplicate(self, findings: List[Dict]) -> List[Dict]: @@ -389,4 +490,11 @@ class VerificationAgent(BaseAgent): unique.append(f) return unique - + + def get_conversation_history(self) -> List[Dict[str, str]]: + """获取对话历史""" + return self._conversation_history + + def get_steps(self) 
-> List[VerificationStep]: + """获取执行步骤""" + return self._steps diff --git a/backend/app/services/agent/event_manager.py b/backend/app/services/agent/event_manager.py index 77898ef..b1ead15 100644 --- a/backend/app/services/agent/event_manager.py +++ b/backend/app/services/agent/event_manager.py @@ -91,6 +91,33 @@ class AgentEventEmitter: metadata=metadata, )) + async def emit_llm_thought(self, thought: str, iteration: int = 0): + """发射 LLM 思考内容事件 - 核心!展示 LLM 在想什么""" + display = thought[:500] + "..." if len(thought) > 500 else thought + await self.emit(AgentEventData( + event_type="llm_thought", + message=f"💭 LLM 思考:\n{display}", + metadata={"thought": thought, "iteration": iteration}, + )) + + async def emit_llm_decision(self, decision: str, reason: str = ""): + """发射 LLM 决策事件""" + await self.emit(AgentEventData( + event_type="llm_decision", + message=f"💡 LLM 决策: {decision}" + (f" ({reason})" if reason else ""), + metadata={"decision": decision, "reason": reason}, + )) + + async def emit_llm_action(self, action: str, action_input: Dict): + """发射 LLM 动作事件""" + import json + input_str = json.dumps(action_input, ensure_ascii=False)[:200] + await self.emit(AgentEventData( + event_type="llm_action", + message=f"⚡ LLM 动作: {action}\n 参数: {input_str}", + metadata={"action": action, "action_input": action_input}, + )) + async def emit_tool_call( self, tool_name: str, diff --git a/backend/app/services/agent/graph/audit_graph.py b/backend/app/services/agent/graph/audit_graph.py index d53a577..8f62338 100644 --- a/backend/app/services/agent/graph/audit_graph.py +++ b/backend/app/services/agent/graph/audit_graph.py @@ -1,12 +1,15 @@ """ -DeepAudit 审计工作流图 -使用 LangGraph 构建状态机式的 Agent 协作流程 +DeepAudit 审计工作流图 - LLM 驱动版 +使用 LangGraph 构建 LLM 驱动的 Agent 协作流程 + +重要改变:路由决策由 LLM 参与,而不是硬编码条件! """ from typing import TypedDict, Annotated, List, Dict, Any, Optional, Literal from datetime import datetime import operator import logging +import json from langgraph.graph import StateGraph, END from langgraph.checkpoint.memory import MemorySaver @@ -56,12 +59,16 @@ class AuditState(TypedDict): verified_findings: List[Finding] false_positives: List[str] - # 控制流 + # 控制流 - 🔥 关键:LLM 可以设置这些来影响路由 current_phase: str iteration: int max_iterations: int should_continue_analysis: bool + # 🔥 新增:LLM 的路由决策 + llm_next_action: Optional[str] # LLM 建议的下一步: "continue_analysis", "verify", "report", "end" + llm_routing_reason: Optional[str] # LLM 的决策理由 + # 消息和事件 messages: Annotated[List[Dict], operator.add] events: Annotated[List[Dict], operator.add] @@ -72,43 +79,233 @@ class AuditState(TypedDict): error: Optional[str] -# ============ 路由函数 ============ +# ============ LLM 路由决策器 ============ + +class LLMRouter: + """ + LLM 路由决策器 + 让 LLM 来决定下一步应该做什么 + """ + + def __init__(self, llm_service): + self.llm_service = llm_service + + async def decide_after_recon(self, state: AuditState) -> Dict[str, Any]: + """Recon 后让 LLM 决定下一步""" + entry_points = state.get("entry_points", []) + high_risk_areas = state.get("high_risk_areas", []) + tech_stack = state.get("tech_stack", {}) + initial_findings = state.get("findings", []) + + prompt = f"""作为安全审计的决策者,基于以下信息收集结果,决定下一步行动。 + +## 信息收集结果 +- 入口点数量: {len(entry_points)} +- 高风险区域: {high_risk_areas[:10]} +- 技术栈: {tech_stack} +- 初步发现: {len(initial_findings)} 个 + +## 选项 +1. "analysis" - 继续进行漏洞分析(推荐:有入口点或高风险区域时) +2. 
"end" - 结束审计(仅当没有任何可分析内容时) + +请返回 JSON 格式: +{{"action": "analysis或end", "reason": "决策理由"}}""" + + try: + response = await self.llm_service.chat_completion_raw( + messages=[ + {"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"}, + {"role": "user", "content": prompt}, + ], + temperature=0.1, + max_tokens=200, + ) + + content = response.get("content", "") + # 提取 JSON + import re + json_match = re.search(r'\{.*\}', content, re.DOTALL) + if json_match: + result = json.loads(json_match.group()) + return result + except Exception as e: + logger.warning(f"LLM routing decision failed: {e}") + + # 默认决策 + if entry_points or high_risk_areas: + return {"action": "analysis", "reason": "有可分析内容"} + return {"action": "end", "reason": "没有发现入口点或高风险区域"} + + async def decide_after_analysis(self, state: AuditState) -> Dict[str, Any]: + """Analysis 后让 LLM 决定下一步""" + findings = state.get("findings", []) + iteration = state.get("iteration", 0) + max_iterations = state.get("max_iterations", 3) + + # 统计发现 + severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0} + for f in findings: + sev = f.get("severity", "medium") + severity_counts[sev] = severity_counts.get(sev, 0) + 1 + + prompt = f"""作为安全审计的决策者,基于以下分析结果,决定下一步行动。 + +## 分析结果 +- 总发现数: {len(findings)} +- 严重程度分布: {severity_counts} +- 当前迭代: {iteration}/{max_iterations} + +## 选项 +1. "verification" - 验证发现的漏洞(推荐:有发现需要验证时) +2. "analysis" - 继续深入分析(推荐:发现较少但还有迭代次数时) +3. "report" - 生成报告(推荐:没有发现或已充分分析时) + +请返回 JSON 格式: +{{"action": "verification/analysis/report", "reason": "决策理由"}}""" + + try: + response = await self.llm_service.chat_completion_raw( + messages=[ + {"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"}, + {"role": "user", "content": prompt}, + ], + temperature=0.1, + max_tokens=200, + ) + + content = response.get("content", "") + import re + json_match = re.search(r'\{.*\}', content, re.DOTALL) + if json_match: + result = json.loads(json_match.group()) + return result + except Exception as e: + logger.warning(f"LLM routing decision failed: {e}") + + # 默认决策 + if not findings: + return {"action": "report", "reason": "没有发现漏洞"} + if len(findings) >= 3 or iteration >= max_iterations: + return {"action": "verification", "reason": "有足够的发现需要验证"} + return {"action": "analysis", "reason": "发现较少,继续分析"} + + async def decide_after_verification(self, state: AuditState) -> Dict[str, Any]: + """Verification 后让 LLM 决定下一步""" + verified_findings = state.get("verified_findings", []) + false_positives = state.get("false_positives", []) + iteration = state.get("iteration", 0) + max_iterations = state.get("max_iterations", 3) + + prompt = f"""作为安全审计的决策者,基于以下验证结果,决定下一步行动。 + +## 验证结果 +- 已确认漏洞: {len(verified_findings)} +- 误报数量: {len(false_positives)} +- 当前迭代: {iteration}/{max_iterations} + +## 选项 +1. "analysis" - 回到分析阶段重新分析(推荐:误报率太高时) +2. 
"report" - 生成最终报告(推荐:验证完成时) + +请返回 JSON 格式: +{{"action": "analysis/report", "reason": "决策理由"}}""" + + try: + response = await self.llm_service.chat_completion_raw( + messages=[ + {"role": "system", "content": "你是安全审计流程的决策者,负责决定下一步行动。"}, + {"role": "user", "content": prompt}, + ], + temperature=0.1, + max_tokens=200, + ) + + content = response.get("content", "") + import re + json_match = re.search(r'\{.*\}', content, re.DOTALL) + if json_match: + result = json.loads(json_match.group()) + return result + except Exception as e: + logger.warning(f"LLM routing decision failed: {e}") + + # 默认决策 + if len(false_positives) > len(verified_findings) and iteration < max_iterations: + return {"action": "analysis", "reason": "误报率较高,需要重新分析"} + return {"action": "report", "reason": "验证完成,生成报告"} + + +# ============ 路由函数 (结合 LLM 决策) ============ def route_after_recon(state: AuditState) -> Literal["analysis", "end"]: - """Recon 后的路由决策""" - # 如果没有发现入口点或高风险区域,直接结束 + """ + Recon 后的路由决策 + 优先使用 LLM 的决策,否则使用默认逻辑 + """ + # 检查 LLM 是否有决策 + llm_action = state.get("llm_next_action") + if llm_action: + logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}") + if llm_action == "end": + return "end" + return "analysis" + + # 默认逻辑(作为 fallback) if not state.get("entry_points") and not state.get("high_risk_areas"): return "end" return "analysis" def route_after_analysis(state: AuditState) -> Literal["verification", "analysis", "report"]: - """Analysis 后的路由决策""" + """ + Analysis 后的路由决策 + 优先使用 LLM 的决策 + """ + # 检查 LLM 是否有决策 + llm_action = state.get("llm_next_action") + if llm_action: + logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}") + if llm_action == "verification": + return "verification" + elif llm_action == "analysis": + return "analysis" + elif llm_action == "report": + return "report" + + # 默认逻辑 findings = state.get("findings", []) iteration = state.get("iteration", 0) max_iterations = state.get("max_iterations", 3) should_continue = state.get("should_continue_analysis", False) - # 如果没有发现,直接生成报告 if not findings: return "report" - # 如果需要继续分析且未达到最大迭代 if should_continue and iteration < max_iterations: return "analysis" - # 有发现需要验证 return "verification" def route_after_verification(state: AuditState) -> Literal["analysis", "report"]: - """Verification 后的路由决策""" - # 如果验证发现了误报,可能需要重新分析 + """ + Verification 后的路由决策 + 优先使用 LLM 的决策 + """ + # 检查 LLM 是否有决策 + llm_action = state.get("llm_next_action") + if llm_action: + logger.info(f"Using LLM routing decision: {llm_action}, reason: {state.get('llm_routing_reason')}") + if llm_action == "analysis": + return "analysis" + return "report" + + # 默认逻辑 false_positives = state.get("false_positives", []) iteration = state.get("iteration", 0) max_iterations = state.get("max_iterations", 3) - # 如果误报率太高且还有迭代次数,回到分析 if len(false_positives) > len(state.get("verified_findings", [])) and iteration < max_iterations: return "analysis" @@ -123,6 +320,7 @@ def create_audit_graph( verification_node, report_node, checkpointer: Optional[MemorySaver] = None, + llm_service=None, # 用于 LLM 路由决策 ) -> StateGraph: """ 创建审计工作流图 @@ -132,7 +330,8 @@ def create_audit_graph( analysis_node: 漏洞分析节点 verification_node: 漏洞验证节点 report_node: 报告生成节点 - checkpointer: 检查点存储器(用于状态持久化) + checkpointer: 检查点存储器 + llm_service: LLM 服务(用于路由决策) Returns: 编译后的 StateGraph @@ -143,19 +342,19 @@ def create_audit_graph( │ ▼ ┌──────┐ - │Recon │ 信息收集 + │Recon │ 信息收集 (LLM 驱动) └──┬───┘ - │ + │ LLM 决定 ▼ ┌──────────┐ - │ Analysis │◄─────┐ 漏洞分析(可循环) + │ 
Analysis │◄─────┐ 漏洞分析 (LLM 驱动,可循环) └────┬─────┘ │ - │ │ + │ LLM 决定 │ ▼ │ ┌────────────┐ │ - │Verification│────┘ 漏洞验证(可回溯) + │Verification│────┘ 漏洞验证 (LLM 驱动,可回溯) └─────┬──────┘ - │ + │ LLM 决定 ▼ ┌──────────┐ │ Report │ 报告生成 @@ -168,10 +367,53 @@ def create_audit_graph( # 创建状态图 workflow = StateGraph(AuditState) + # 如果有 LLM 服务,创建路由决策器 + llm_router = LLMRouter(llm_service) if llm_service else None + + # 包装节点以添加 LLM 路由决策 + async def recon_with_routing(state): + result = await recon_node(state) + + # LLM 决定下一步 + if llm_router: + decision = await llm_router.decide_after_recon({**state, **result}) + result["llm_next_action"] = decision.get("action") + result["llm_routing_reason"] = decision.get("reason") + + return result + + async def analysis_with_routing(state): + result = await analysis_node(state) + + # LLM 决定下一步 + if llm_router: + decision = await llm_router.decide_after_analysis({**state, **result}) + result["llm_next_action"] = decision.get("action") + result["llm_routing_reason"] = decision.get("reason") + + return result + + async def verification_with_routing(state): + result = await verification_node(state) + + # LLM 决定下一步 + if llm_router: + decision = await llm_router.decide_after_verification({**state, **result}) + result["llm_next_action"] = decision.get("action") + result["llm_routing_reason"] = decision.get("reason") + + return result + # 添加节点 - workflow.add_node("recon", recon_node) - workflow.add_node("analysis", analysis_node) - workflow.add_node("verification", verification_node) + if llm_router: + workflow.add_node("recon", recon_with_routing) + workflow.add_node("analysis", analysis_with_routing) + workflow.add_node("verification", verification_with_routing) + else: + workflow.add_node("recon", recon_node) + workflow.add_node("analysis", analysis_node) + workflow.add_node("verification", verification_node) + workflow.add_node("report", report_node) # 设置入口点 @@ -192,7 +434,7 @@ def create_audit_graph( route_after_analysis, { "verification": "verification", - "analysis": "analysis", # 循环 + "analysis": "analysis", "report": "report", } ) @@ -201,7 +443,7 @@ def create_audit_graph( "verification", route_after_verification, { - "analysis": "analysis", # 回溯 + "analysis": "analysis", "report": "report", } ) @@ -225,50 +467,42 @@ def create_audit_graph_with_human( report_node, human_review_node, checkpointer: Optional[MemorySaver] = None, + llm_service=None, ) -> StateGraph: """ 创建带人机协作的审计工作流图 在验证阶段后增加人工审核节点 - - 工作流结构: - - START - │ - ▼ - ┌──────┐ - │Recon │ - └──┬───┘ - │ - ▼ - ┌──────────┐ - │ Analysis │◄─────┐ - └────┬─────┘ │ - │ │ - ▼ │ - ┌────────────┐ │ - │Verification│────┘ - └─────┬──────┘ - │ - ▼ - ┌─────────────┐ - │Human Review │ ← 人工审核(可跳过) - └──────┬──────┘ - │ - ▼ - ┌──────────┐ - │ Report │ - └────┬─────┘ - │ - ▼ - END """ workflow = StateGraph(AuditState) + llm_router = LLMRouter(llm_service) if llm_service else None + + # 包装节点 + async def recon_with_routing(state): + result = await recon_node(state) + if llm_router: + decision = await llm_router.decide_after_recon({**state, **result}) + result["llm_next_action"] = decision.get("action") + result["llm_routing_reason"] = decision.get("reason") + return result + + async def analysis_with_routing(state): + result = await analysis_node(state) + if llm_router: + decision = await llm_router.decide_after_analysis({**state, **result}) + result["llm_next_action"] = decision.get("action") + result["llm_routing_reason"] = decision.get("reason") + return result # 添加节点 - workflow.add_node("recon", recon_node) - 
workflow.add_node("analysis", analysis_node) + if llm_router: + workflow.add_node("recon", recon_with_routing) + workflow.add_node("analysis", analysis_with_routing) + else: + workflow.add_node("recon", recon_node) + workflow.add_node("analysis", analysis_node) + workflow.add_node("verification", verification_node) workflow.add_node("human_review", human_review_node) workflow.add_node("report", report_node) @@ -296,7 +530,6 @@ def create_audit_graph_with_human( # Human Review 后的路由 def route_after_human(state: AuditState) -> Literal["analysis", "report"]: - # 人工可以决定重新分析或继续 if state.get("should_continue_analysis"): return "analysis" return "report" @@ -340,15 +573,6 @@ class AuditGraphRunner: ) -> Dict[str, Any]: """ 执行审计工作流 - - Args: - project_root: 项目根目录 - project_info: 项目信息 - config: 审计配置 - task_id: 任务 ID - - Returns: - 最终状态 """ # 初始状态 initial_state: AuditState = { @@ -367,6 +591,8 @@ class AuditGraphRunner: "iteration": 0, "max_iterations": config.get("max_iterations", 3), "should_continue_analysis": False, + "llm_next_action": None, + "llm_routing_reason": None, "messages": [], "events": [], "summary": None, @@ -374,32 +600,32 @@ class AuditGraphRunner: "error": None, } - # 配置 run_config = { "configurable": { "thread_id": task_id, } } - # 执行图 try: - # 流式执行 async for event in self.graph.astream(initial_state, config=run_config): - # 发射事件 if self.event_emitter: for node_name, node_state in event.items(): await self.event_emitter.emit_info( f"节点 {node_name} 完成" ) - # 发射发现事件 + # 发射 LLM 路由决策事件 + if node_state.get("llm_routing_reason"): + await self.event_emitter.emit_info( + f"🧠 LLM 决策: {node_state.get('llm_next_action')} - {node_state.get('llm_routing_reason')}" + ) + if node_name == "analysis" and node_state.get("findings"): new_findings = node_state["findings"] await self.event_emitter.emit_info( f"发现 {len(new_findings)} 个潜在漏洞" ) - # 获取最终状态 final_state = self.graph.get_state(run_config) return final_state.values @@ -412,44 +638,27 @@ class AuditGraphRunner: initial_state: AuditState, human_feedback_callback, ) -> Dict[str, Any]: - """ - 带人机协作的执行 - - Args: - initial_state: 初始状态 - human_feedback_callback: 人工反馈回调函数 - - Returns: - 最终状态 - """ + """带人机协作的执行""" run_config = { "configurable": { "thread_id": initial_state["task_id"], } } - # 执行到人工审核节点 async for event in self.graph.astream(initial_state, config=run_config): pass - # 获取当前状态 current_state = self.graph.get_state(run_config) - # 如果在人工审核节点暂停 if current_state.next == ("human_review",): - # 调用人工反馈 human_decision = await human_feedback_callback(current_state.values) - # 更新状态并继续 updated_state = { **current_state.values, "should_continue_analysis": human_decision.get("continue_analysis", False), } - # 继续执行 async for event in self.graph.astream(updated_state, config=run_config): pass - # 返回最终状态 return self.graph.get_state(run_config).values - diff --git a/backend/app/services/agent/graph/runner.py b/backend/app/services/agent/graph/runner.py index f8e3560..304121d 100644 --- a/backend/app/services/agent/graph/runner.py +++ b/backend/app/services/agent/graph/runner.py @@ -403,10 +403,19 @@ class AgentRunner: Returns: 最终状态 """ - result = {} - async for _ in self.run_with_streaming(): - pass # 消费所有事件 - return result + final_state = {} + try: + async for event in self.run_with_streaming(): + # 收集最终状态 + if event.event_type == StreamEventType.TASK_COMPLETE: + final_state = event.data + elif event.event_type == StreamEventType.TASK_ERROR: + final_state = {"success": False, "error": event.data.get("error")} + except Exception as e: + 
logger.error(f"Agent run failed: {e}", exc_info=True) + final_state = {"success": False, "error": str(e)} + + return final_state async def run_with_streaming(self) -> AsyncGenerator[StreamEvent, None]: """ diff --git a/backend/app/services/agent/streaming/stream_handler.py b/backend/app/services/agent/streaming/stream_handler.py index 7accd40..2c818fb 100644 --- a/backend/app/services/agent/streaming/stream_handler.py +++ b/backend/app/services/agent/streaming/stream_handler.py @@ -15,12 +15,20 @@ logger = logging.getLogger(__name__) class StreamEventType(str, Enum): """流式事件类型""" - # LLM 相关 + # 🔥 LLM 思考相关 - 这些是最重要的!展示 LLM 的大脑活动 + LLM_START = "llm_start" # LLM 开始思考 + LLM_THOUGHT = "llm_thought" # LLM 思考内容 ⭐ 核心 + LLM_DECISION = "llm_decision" # LLM 决策 ⭐ 核心 + LLM_ACTION = "llm_action" # LLM 动作 + LLM_OBSERVATION = "llm_observation" # LLM 观察结果 + LLM_COMPLETE = "llm_complete" # LLM 完成 + + # LLM Token 流 (实时输出) THINKING_START = "thinking_start" # 开始思考 - THINKING_TOKEN = "thinking_token" # 思考 Token + THINKING_TOKEN = "thinking_token" # 思考 Token (流式) THINKING_END = "thinking_end" # 思考结束 - # 工具调用相关 + # 工具调用相关 - LLM 决定调用工具 TOOL_CALL_START = "tool_call_start" # 工具调用开始 TOOL_CALL_INPUT = "tool_call_input" # 工具输入参数 TOOL_CALL_OUTPUT = "tool_call_output" # 工具输出结果 diff --git a/frontend/src/pages/AgentAudit.tsx b/frontend/src/pages/AgentAudit.tsx index 326ce39..aa911f2 100644 --- a/frontend/src/pages/AgentAudit.tsx +++ b/frontend/src/pages/AgentAudit.tsx @@ -4,13 +4,13 @@ * 支持 LLM 思考过程和工具调用的实时流式展示 */ -import { useState, useEffect, useRef, useCallback } from "react"; +import { useState, useEffect, useRef, useCallback, useMemo } from "react"; import { useParams, useNavigate } from "react-router-dom"; import { - Terminal, Bot, Cpu, Shield, AlertTriangle, CheckCircle2, + Terminal, Bot, Shield, AlertTriangle, CheckCircle2, Loader2, Code, Zap, Activity, ChevronRight, XCircle, - FileCode, Search, Bug, Lock, Play, Square, RefreshCw, - ArrowLeft, Download, ExternalLink, Brain, Wrench, + FileCode, Search, Bug, Square, RefreshCw, + ArrowLeft, ExternalLink, Brain, Wrench, ChevronDown, ChevronUp, Clock, Sparkles } from "lucide-react"; import { Button } from "@/components/ui/button"; @@ -26,42 +26,77 @@ import { getAgentEvents, getAgentFindings, cancelAgentTask, - streamAgentEvents, } from "@/shared/api/agentTasks"; -// 事件类型图标映射 +// 事件类型图标映射 - 🔥 重点展示 LLM 相关事件 const eventTypeIcons: Record = { + // 🧠 LLM 核心事件 - 最重要! 
diff --git a/frontend/src/pages/AgentAudit.tsx b/frontend/src/pages/AgentAudit.tsx
index 326ce39..aa911f2 100644
--- a/frontend/src/pages/AgentAudit.tsx
+++ b/frontend/src/pages/AgentAudit.tsx
@@ -4,13 +4,13 @@
  * Real-time streaming display of the LLM's thinking and tool calls
  */

-import { useState, useEffect, useRef, useCallback } from "react";
+import { useState, useEffect, useRef, useCallback, useMemo } from "react";
 import { useParams, useNavigate } from "react-router-dom";
 import {
-  Terminal, Bot, Cpu, Shield, AlertTriangle, CheckCircle2,
+  Terminal, Bot, Shield, AlertTriangle, CheckCircle2,
   Loader2, Code, Zap, Activity, ChevronRight, XCircle,
-  FileCode, Search, Bug, Lock, Play, Square, RefreshCw,
-  ArrowLeft, Download, ExternalLink, Brain, Wrench,
+  FileCode, Search, Bug, Square, RefreshCw,
+  ArrowLeft, ExternalLink, Brain, Wrench,
   ChevronDown, ChevronUp, Clock, Sparkles
 } from "lucide-react";
 import { Button } from "@/components/ui/button";
@@ -26,42 +26,77 @@ import {
   getAgentEvents,
   getAgentFindings,
   cancelAgentTask,
-  streamAgentEvents,
 } from "@/shared/api/agentTasks";

-// Event type → icon map
+// Event type → icon map - 🔥 highlights the LLM-related events
 const eventTypeIcons: Record<string, React.ReactNode> = {
+  // 🧠 Core LLM events - the most important!
+  llm_start: <Brain />,
+  llm_thought: <Brain />,
+  llm_decision: <Sparkles />,
+  llm_action: <Zap />,
+  llm_observation: <Search />,
+  llm_complete: <CheckCircle2 />,
+
+  // Phases
   phase_start: <ChevronRight />,
   phase_complete: <CheckCircle2 />,
-  thinking: <Brain />,
-  tool_call: <Wrench />,
+  thinking: <Brain />,
+
+  // Tools - calls the LLM decided to make
+  tool_call: <Wrench />,
   tool_result: <Code />,
   tool_error: <XCircle />,
+
+  // Findings
+  finding: <Bug />,
   finding_new: <Bug />,
   finding_verified: <AlertTriangle />,
+
+  // Status
   info: <Activity />,
   warning: <AlertTriangle />,
   error: <XCircle />,
   progress: <Loader2 />,
+
+  // Task lifecycle
   task_complete: <CheckCircle2 />,
   task_error: <XCircle />,
   task_cancel: <Square />,
 };

-// Event type → color map
+// Event type → color map - 🔥 LLM events stand out
 const eventTypeColors: Record<string, string> = {
+  // 🧠 Core LLM events - highlighted in purple
+  llm_start: "text-purple-400 font-semibold",
+  llm_thought: "text-purple-300 bg-purple-950/30 rounded px-1", // thought content gets extra highlighting
+  llm_decision: "text-yellow-300 font-semibold", // decisions stand out most
+  llm_action: "text-orange-300 font-medium",
+  llm_observation: "text-blue-300",
+  llm_complete: "text-green-400 font-semibold",
+
+  // Phases
   phase_start: "text-cyan-400 font-bold",
   phase_complete: "text-green-400",
   thinking: "text-purple-300",
+
+  // Tools
   tool_call: "text-yellow-300",
   tool_result: "text-green-300",
   tool_error: "text-red-400",
+
+  // Findings
+  finding: "text-orange-300 font-medium",
   finding_new: "text-orange-300",
-  finding_verified: "text-red-300",
+  finding_verified: "text-red-300 font-medium",
+
+  // Status
   info: "text-gray-300",
   warning: "text-yellow-300",
   error: "text-red-400",
   progress: "text-cyan-300",
+
+  // Task lifecycle
   task_complete: "text-green-400 font-bold",
   task_error: "text-red-400 font-bold",
   task_cancel: "text-yellow-400",
@@ -99,30 +134,6 @@ export default function AgentAuditPage() {

   const eventsEndRef = useRef<HTMLDivElement>(null);
   const thinkingEndRef = useRef<HTMLDivElement>(null);
-  const abortControllerRef = useRef<AbortController | null>(null);
-
-  // Use the enhanced streaming hook
-  const {
-    thinking,
-    isThinking,
-    toolCalls,
-    currentPhase: streamPhase,
-    progress: streamProgress,
-    connect: connectStream,
-    disconnect: disconnectStream,
-    isConnected: isStreamConnected,
-  } = useAgentStream(taskId || null, {
-    includeThinking: true,
-    includeToolCalls: true,
-    onFinding: () => loadFindings(),
-    onComplete: () => {
-      loadTask();
-      loadFindings();
-    },
-    onError: (err) => {
-      console.error("Stream error:", err);
-    },
-  });

   // Whether the task has finished
   const isComplete = task?.status === "completed" || task?.status === "failed" || task?.status === "cancelled";
@@ -134,11 +145,16 @@ export default function AgentAuditPage() {
     try {
       const taskData = await getAgentTask(taskId);
       setTask(taskData);
-    } catch (error) {
+    } catch (error: any) {
       console.error("Failed to load task:", error);
-      toast.error("Failed to load the task");
+      const errorMessage = error?.response?.data?.detail || error?.message || "Unknown error";
+      toast.error(`Failed to load the task: ${errorMessage}`);
+      // A 404 most likely means the task no longer exists
+      if (error?.response?.status === 404) {
+        setTimeout(() => navigate("/tasks"), 2000);
+      }
     }
-  }, [taskId]);
+  }, [taskId, navigate]);

   // Load events
   const loadEvents = useCallback(async () => {
@@ -164,6 +180,32 @@ export default function AgentAuditPage() {
     }
   }, [taskId]);

+  // 🔥 Stabilize the callbacks so connect/disconnect aren't recreated on every render
+  const streamOptions = useMemo(() => ({
+    includeThinking: true,
+    includeToolCalls: true,
+    onFinding: () => loadFindings(),
+    onComplete: () => {
+      loadTask();
+      loadFindings();
+    },
+    onError: (err: string) => {
+      console.error("Stream error:", err);
+    },
+  }), [loadFindings, loadTask]);
+
+  // Use the enhanced streaming hook
+  const {
+    thinking,
+    isThinking,
+    toolCalls,
+    // currentPhase: streamPhase, // not used yet
+    // progress: streamProgress,  // not used yet
+    connect: connectStream,
+    disconnect: disconnectStream,
+    isConnected: isStreamConnected,
+  } = useAgentStream(taskId || null, streamOptions);

   // Initial load
   useEffect(() => {
     const init = async () => {
@@ -175,10 +217,11 @@ export default function AgentAuditPage() {
     init();
   }, [loadTask, loadEvents, loadFindings]);

-  // Connect the enhanced streaming API
+  // 🔥 Use the enhanced streaming API (preferred)
   useEffect(() => {
     if (!taskId || isComplete || isLoading) return;

+    // Connect the stream
     connectStream();
     setIsStreaming(true);

     return () => {
       disconnectStream();
       setIsStreaming(false);
     };
-  }, [taskId, isComplete, isLoading, connectStream, disconnectStream]);
+    // eslint-disable-next-line react-hooks/exhaustive-deps
+  }, [taskId, isComplete, isLoading]); // connectStream/disconnectStream are stable, so they don't need to be dependencies

-  // Legacy event stream (fallback)
+  // 🔥 Fallback: if the streaming connection fails, poll for events instead
   useEffect(() => {
     if (!taskId || isComplete || isLoading) return;

-    const startStreaming = async () => {
-      abortControllerRef.current = new AbortController();
-
+    // If the streaming connection is up, there is no need to poll
+    if (isStreamConnected) {
+      return;
+    }
+
+    // Poll for events every 5 seconds (fallback mechanism only)
+    const pollInterval = setInterval(async () => {
       try {
         const lastSequence = events.length > 0 ? Math.max(...events.map(e => e.sequence)) : 0;
+        const newEvents = await getAgentEvents(taskId, { after_sequence: lastSequence, limit: 50 });

-        for await (const event of streamAgentEvents(taskId, lastSequence, abortControllerRef.current.signal)) {
+        if (newEvents.length > 0) {
           setEvents(prev => {
-            // Skip duplicates
-            if (prev.some(e => e.id === event.id)) return prev;
-            return [...prev, event];
+            // Merge the new events, skipping duplicates
+            const existingIds = new Set(prev.map(e => e.id));
+            const uniqueNew = newEvents.filter(e => !existingIds.has(e.id));
+            return [...prev, ...uniqueNew];
           });

-          // Refresh the findings list on finding events
-          if (event.event_type.startsWith("finding_")) {
-            loadFindings();
-          }
-
-          // Refresh the task on terminal events
-          if (["task_complete", "task_error", "task_cancel"].includes(event.event_type)) {
-            loadTask();
+          // Refresh the findings list if any finding events arrived
+          if (newEvents.some(e => e.event_type.startsWith("finding_"))) {
             loadFindings();
           }
         }
       } catch (error) {
-        if ((error as Error).name !== "AbortError") {
-          console.error("Event stream error:", error);
-        }
+        console.error("Failed to poll events:", error);
       }
-    };
+    }, 5000);

-    startStreaming();
-
-    return () => {
-      abortControllerRef.current?.abort();
-    };
-  }, [taskId, isComplete, isLoading, loadTask, loadFindings]);
+    return () => clearInterval(pollInterval);
+  }, [taskId, isComplete, isLoading, isStreamConnected, events.length, loadFindings]);

   // Auto-scroll
   useEffect(() => {
@@ -244,17 +282,17 @@ export default function AgentAuditPage() {
     return () => clearInterval(interval);
   }, []);

-  // Poll the task status periodically (fallback for SSE)
+  // Poll the task status periodically (fallback for SSE) - 🔥 longer interval to avoid exhausting resources
   useEffect(() => {
     if (!taskId || isComplete || isLoading) return;

-    // Poll the task status every 3 seconds
+    // 🔥 Poll every 10 seconds (instead of 3) to cut resource usage
    const pollInterval = setInterval(async () => {
      try {
        const taskData = await getAgentTask(taskId);
        setTask(taskData);

-        // If the task completed/failed/was cancelled, refresh the other data
+        // If the task completed/failed/was cancelled, refresh the other data and stop polling
        if (taskData.status === "completed" || taskData.status === "failed" || taskData.status === "cancelled") {
          await loadEvents();
          await loadFindings();
        }
      } catch (error) {
        console.error("Failed to poll task status:", error);
+        // 🔥 Stop polling on failure so a dead backend doesn't keep burning resources
+        clearInterval(pollInterval);
      }
-    }, 3000);
+    }, 10000); // 🔥 now 10 seconds

    return () => clearInterval(pollInterval);
-  }, [taskId, isComplete, isLoading, loadEvents, loadFindings]);
+  }, [taskId, isComplete, isLoading]); // 🔥 function dependencies removed

   // Cancel the task
   const handleCancel = async () => {
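The `after_sequence` contract the poller relies on has to be honored server-side. A sketch of the backend query it implies (the ORM model and session names are assumptions; the point is the "sequence strictly greater than, ordered, limited" shape):

# Hypothetical sketch of the event-pagination query behind
# getAgentEvents(taskId, { after_sequence, limit }).
from sqlalchemy import select

async def get_events_after(db, task_id: str, after_sequence: int, limit: int = 50):
    stmt = (
        select(AgentEvent)  # assumed ORM model for agent events
        .where(AgentEvent.task_id == task_id, AgentEvent.sequence > after_sequence)
        .order_by(AgentEvent.sequence)
        .limit(limit)
    )
    result = await db.execute(stmt)
    return result.scalars().all()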
@@ -371,57 +411,73 @@ export default function AgentAuditPage() {
         {/* Left side: execution log */}

-        {/* Thinking display area */}
+        {/* 🧠 LLM thinking display area - the core: this shows the LLM's brain at work.
+            Collapsible panel (full markup trimmed here): the header toggles via
+            onClick={() => setShowThinking(!showThinking)} and shows a <Brain /> icon,
+            the title "🧠 LLM Thinking" with the subtitle "the agent's brain at work"
+            (previously "AI Thinking"), plus a spinner badge reading "Thinking..."
+            while isThinking (previously "Processing...");
+            {showThinking ? <ChevronUp /> : <ChevronDown />} expands or collapses it.
+            The body streams {thinking || "🤔 Thinking about the next step..."}
+            with a blinking cursor while isThinking. */}
         {(isThinking || thinking) && showThinking && (
           /* … */
         )}

-        {/* Tool-call display area */}
+        {/* 🔧 LLM tool-call display area - the tools the LLM decided to invoke.
+            Collapsible panel (full markup trimmed here): the header toggles via
+            onClick={() => setShowToolDetails(!showToolDetails)} and shows a <Wrench />
+            icon, the title "🔧 LLM Tool Calls" with the subtitle "tools the LLM chose
+            to call" (previously "Tool Calls"), and a badge counting
+            "{toolCalls.length} calls". The body lists the last five calls: */}
         {toolCalls.length > 0 && showToolDetails && (
           /* … */
             {toolCalls.slice(-5).map((tc, idx) => (
-              /* … tool-call row … */
+              /* … tool-call row, now passed its output narrowed to
+                 string | Record<string, unknown> | undefined … */
             ))}
           /* … */
         )}

+        {/* Log header (markup trimmed): a <Terminal /> icon and the title
+            "LLM Execution Log" with the subtitle "LLM thinking & tool-call record"
+            (previously "Execution Log"), a pulsing "LIVE" badge while
+            {(isStreaming || isStreamConnected)}, and the counter
+            "{events.length} records" (previously "{events.length} events"). */}

         {/* Terminal window */}
@@ -649,7 +708,7 @@ function StatusBadge({ status }: { status: string }) {
   );
 }

-// Event line component
+// Event line component - richer display for LLM events
 function EventLine({ event }: { event: AgentEvent }) {
   const icon = eventTypeIcons[event.event_type] || <ChevronRight />;
   const colorClass = eventTypeColors[event.event_type] || "text-gray-400";
   const timestamp = event.timestamp
     ? new Date(event.timestamp).toLocaleTimeString("zh-CN", { hour12: false })
     : "";

+  // Special handling for LLM thinking events - their content can span multiple lines
+  const isLLMThought = event.event_type === "llm_thought";
+  const isLLMDecision = event.event_type === "llm_decision";
+  const isLLMAction = event.event_type === "llm_action";
+  const isImportantLLMEvent = isLLMThought || isLLMDecision || isLLMAction;
+
+  // Background accent per LLM event type
+  const bgClass = isLLMThought
+    ? "bg-purple-950/40 border-l-2 border-purple-600"
+    : isLLMDecision
+      ? "bg-yellow-950/30 border-l-2 border-yellow-600"
+      : isLLMAction
+        ? "bg-orange-950/30 border-l-2 border-orange-600"
+        : "";
+
   return (
-    <div className="flex items-start gap-2 py-0.5">
+    <div className={`flex items-start gap-2 py-0.5 ${bgClass}`}>
       <span className="text-gray-600">{timestamp}</span>
       {icon}
-      <span className={colorClass}>
+      <span className={`${colorClass} ${isImportantLLMEvent ? "whitespace-pre-wrap" : ""}`}>
         {event.message}
         {event.tool_duration_ms && (
           <span className="text-gray-500"> ({event.tool_duration_ms}ms)</span>
         )}
@@ -679,7 +753,7 @@ interface ToolCallProps {
   toolCall: {
     name: string;
     input: Record<string, unknown>;
-    output?: unknown;
+    output?: string | Record<string, unknown>;
     durationMs?: number;
     status: 'running' | 'success' | 'error';
   };
diff --git a/frontend/src/shared/api/agentStream.ts b/frontend/src/shared/api/agentStream.ts
index 82fa28d..ba92fc3 100644
--- a/frontend/src/shared/api/agentStream.ts
+++ b/frontend/src/shared/api/agentStream.ts
@@ -106,6 +106,9 @@ export class AgentStreamHandler {
   private reconnectDelay = 1000;
   private isConnected = false;
   private thinkingBuffer: string[] = [];
+  private reader: ReadableStreamDefaultReader | null = null; // 🔥 keep a reference to the reader
+  private abortController: AbortController | null = null;    // 🔥 used to cancel the request
+  private isDisconnecting = false;                           // 🔥 marks an in-progress disconnect

   constructor(taskId: string, options: StreamOptions = {}) {
     this.taskId = taskId;
@@ -121,6 +124,11 @@ export class AgentStreamHandler {
    * Start listening to the event stream
    */
   connect(): void {
+    // 🔥 Don't connect twice (or while tearing down)
+    if (this.isConnected || this.isDisconnecting) {
+      return;
+    }
+
     const token = localStorage.getItem('access_token');
     if (!token) {
       this.options.onError?.('Not logged in');
@@ -142,14 +150,23 @@ export class AgentStreamHandler {
    * Connect with fetch (supports custom headers)
    */
   private async connectWithFetch(token: string, params: URLSearchParams): Promise<void> {
+    // 🔥 Don't connect while disconnecting
+    if (this.isDisconnecting) {
+      return;
+    }
+
     const url = `/api/v1/agent-tasks/${this.taskId}/stream?${params}`;

+    // 🔥 Create an AbortController so the request can be cancelled
+    this.abortController = new AbortController();
+
     try {
       const response = await fetch(url, {
         headers: {
           'Authorization': `Bearer ${token}`,
           'Accept': 'text/event-stream',
         },
+        signal: this.abortController.signal, // 🔥 cancellable
       });

       if (!response.ok) {
@@ -159,8 +176,8 @@ export class AgentStreamHandler {
       this.isConnected = true;
       this.reconnectAttempts = 0;

-      const reader = response.body?.getReader();
-      if (!reader) {
+      this.reader = response.body?.getReader() || null;
+      if (!this.reader) {
         throw new Error('Could not get the response stream');
       }

       let buffer = '';
@@ -168,7 +185,12 @@ export class AgentStreamHandler {
       while (true) {
-        const { done, value } = await reader.read();
+        // 🔥 Stop reading if a disconnect is in progress
+        if (this.isDisconnecting) {
+          break;
+        }
+
+        const { done, value } = await this.reader.read();

         if (done) {
           break;
         }
@@ -184,17 +206,42 @@ export class AgentStreamHandler {
           this.handleEvent(event);
         }
       }
+
+      // 🔥 Normal end of stream - release the reader
+      if (this.reader) {
+        this.reader.releaseLock();
+        this.reader = null;
+      }
-    } catch (error) {
+    } catch (error: any) {
+      // 🔥 An abort is expected on disconnect - nothing to handle
+      if (error.name === 'AbortError') {
+        return;
+      }
+
       this.isConnected = false;
       console.error('Stream connection error:', error);

-      // Try to reconnect
-      if (this.reconnectAttempts < this.maxReconnectAttempts) {
+      // 🔥 Only reconnect if we are not deliberately disconnecting
+      if (!this.isDisconnecting && this.reconnectAttempts < this.maxReconnectAttempts) {
         this.reconnectAttempts++;
-        setTimeout(() => this.connect(), this.reconnectDelay * this.reconnectAttempts);
+        setTimeout(() => {
+          if (!this.isDisconnecting) {
+            this.connect();
+          }
+        }, this.reconnectDelay * this.reconnectAttempts);
       } else {
         this.options.onError?.(`Connection failed: ${error}`);
       }
+    } finally {
+      // 🔥 Release the reader
+      if (this.reader) {
+        try {
+          this.reader.releaseLock();
+        } catch {
+          // ignore release errors
+        }
+        this.reader = null;
+      }
     }
   }

@@ -357,11 +404,35 @@ export class AgentStreamHandler {
    * Disconnect
    */
   disconnect(): void {
+    // 🔥 Mark as disconnecting to suppress reconnects
+    this.isDisconnecting = true;
     this.isConnected = false;

+    // 🔥 Cancel the in-flight fetch request
+    if (this.abortController) {
+      this.abortController.abort();
+      this.abortController = null;
+    }
+
+    // 🔥 Clean up the reader
+    if (this.reader) {
+      try {
+        this.reader.cancel();
+        this.reader.releaseLock();
+      } catch {
+        // ignore cleanup errors
+      }
+      this.reader = null;
+    }
+
+    // Clean up the EventSource (if one was used)
     if (this.eventSource) {
       this.eventSource.close();
       this.eventSource = null;
     }
+
+    // Reset the reconnect counter
+    this.reconnectAttempts = 0;
   }

   /**
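For a quick end-to-end check of all of the above, the SSE endpoint can be tailed without the frontend at all. A minimal sketch using httpx (the local dev host, placeholder task id, and JWT are assumptions):

# Sketch - tail the agent-task SSE stream and print one JSON event per frame.
import asyncio
import httpx

async def tail_stream(task_id: str, token: str) -> None:
    url = f"http://localhost:8000/api/v1/agent-tasks/{task_id}/stream"
    headers = {"Authorization": f"Bearer {token}", "Accept": "text/event-stream"}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", url, headers=headers) as response:
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    print(line[len("data: "):])

asyncio.run(tail_stream("<task-id>", "<jwt>"))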