diff --git a/backend/app/api/v1/endpoints/agent_tasks.py b/backend/app/api/v1/endpoints/agent_tasks.py index 502d4fa..66a58a1 100644 --- a/backend/app/api/v1/endpoints/agent_tasks.py +++ b/backend/app/api/v1/endpoints/agent_tasks.py @@ -28,6 +28,7 @@ from app.models.agent_task import ( ) from app.models.project import Project from app.models.user import User +from app.models.user_config import UserConfig from app.services.agent import AgentRunner, EventManager, run_agent_task from app.services.agent.streaming import StreamHandler, StreamEvent, StreamEventType @@ -199,7 +200,7 @@ class TaskSummaryResponse(BaseModel): # ============ 后台任务执行 ============ -async def _execute_agent_task(task_id: str, project_root: str): +async def _execute_agent_task(task_id: str): """在后台执行 Agent 任务""" async with async_session_factory() as db: try: @@ -209,14 +210,57 @@ async def _execute_agent_task(task_id: str, project_root: str): logger.error(f"Task {task_id} not found") return + # 获取项目 + project = task.project + if not project: + logger.error(f"Project not found for task {task_id}") + return + + # 🔥 获取项目根目录(解压 ZIP 或克隆仓库) + project_root = await _get_project_root(project, task_id) + + # 🔥 获取用户配置(从系统配置页面) + # 优先级:1. 数据库用户配置 > 2. 环境变量配置 + user_config = None + if task.created_by: + from app.api.v1.endpoints.config import ( + decrypt_config, + SENSITIVE_LLM_FIELDS, SENSITIVE_OTHER_FIELDS + ) + import json + + result = await db.execute( + select(UserConfig).where(UserConfig.user_id == task.created_by) + ) + config = result.scalar_one_or_none() + + if config and config.llm_config: + # 🔥 有数据库配置:使用数据库配置(优先) + user_llm_config = json.loads(config.llm_config) if config.llm_config else {} + user_other_config = json.loads(config.other_config) if config.other_config else {} + + # 解密敏感字段 + user_llm_config = decrypt_config(user_llm_config, SENSITIVE_LLM_FIELDS) + user_other_config = decrypt_config(user_other_config, SENSITIVE_OTHER_FIELDS) + + user_config = { + "llmConfig": user_llm_config, # 直接使用数据库配置,不合并默认值 + "otherConfig": user_other_config, + } + logger.info(f"✅ Using database user config for task {task_id}, LLM provider: {user_llm_config.get('llmProvider', 'N/A')}") + else: + # 🔥 无数据库配置:传递 None,让 LLMService 使用环境变量 + user_config = None + logger.info(f"⚠️ No database config found for user {task.created_by}, will use environment variables for task {task_id}") + # 更新状态为运行中 task.status = AgentTaskStatus.RUNNING task.started_at = datetime.now(timezone.utc) await db.commit() logger.info(f"Task {task_id} started") - # 创建 Runner - runner = AgentRunner(db, task, project_root) + # 创建 Runner(传入用户配置) + runner = AgentRunner(db, task, project_root, user_config=user_config) _running_tasks[task_id] = runner # 执行 @@ -296,11 +340,8 @@ async def create_agent_task( await db.commit() await db.refresh(task) - # 确定项目根目录 - project_root = _get_project_root(project, task.id) - - # 在后台启动任务 - background_tasks.add_task(_execute_agent_task, task.id, project_root) + # 在后台启动任务(项目根目录在任务内部获取) + background_tasks.add_task(_execute_agent_task, task.id) logger.info(f"Created agent task {task.id} for project {project.name}") @@ -897,24 +938,73 @@ async def update_finding_status( # ============ Helper Functions ============ -def _get_project_root(project: Project, task_id: str) -> str: +async def _get_project_root(project: Project, task_id: str) -> str: """ 获取项目根目录 - TODO: 实际实现中需要: - - 对于 ZIP 项目:解压到临时目录 - - 对于 Git 仓库:克隆到临时目录 + 支持两种项目类型: + - ZIP 项目:解压 ZIP 文件到临时目录 + - 仓库项目:克隆仓库到临时目录 """ + import zipfile + import subprocess + base_path = 
f"/tmp/deepaudit/{task_id}" # 确保目录存在 os.makedirs(base_path, exist_ok=True) - # 如果项目有存储路径,复制过来 - if hasattr(project, 'storage_path') and project.storage_path: - if os.path.exists(project.storage_path): - # 复制项目文件 - shutil.copytree(project.storage_path, base_path, dirs_exist_ok=True) + # 根据项目类型处理 + if project.source_type == "zip": + # 🔥 ZIP 项目:解压 ZIP 文件 + from app.services.zip_storage import load_project_zip + + zip_path = await load_project_zip(project.id) + + if zip_path and os.path.exists(zip_path): + try: + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall(base_path) + logger.info(f"✅ Extracted ZIP project {project.id} to {base_path}") + except Exception as e: + logger.error(f"Failed to extract ZIP {zip_path}: {e}") + else: + logger.warning(f"⚠️ ZIP file not found for project {project.id}") + + elif project.source_type == "repository" and project.repository_url: + # 🔥 仓库项目:克隆仓库 + try: + branch = project.default_branch or "main" + repo_url = project.repository_url + + # 克隆仓库 + result = subprocess.run( + ["git", "clone", "--depth", "1", "--branch", branch, repo_url, base_path], + capture_output=True, + text=True, + timeout=300, + ) + + if result.returncode == 0: + logger.info(f"✅ Cloned repository {repo_url} (branch: {branch}) to {base_path}") + else: + logger.warning(f"Failed to clone branch {branch}, trying default branch: {result.stderr}") + # 如果克隆失败,尝试使用默认分支 + if branch != "main": + result = subprocess.run( + ["git", "clone", "--depth", "1", repo_url, base_path], + capture_output=True, + text=True, + timeout=300, + ) + if result.returncode == 0: + logger.info(f"✅ Cloned repository {repo_url} (default branch) to {base_path}") + else: + logger.error(f"Failed to clone repository: {result.stderr}") + except subprocess.TimeoutExpired: + logger.error(f"Git clone timeout for {project.repository_url}") + except Exception as e: + logger.error(f"Failed to clone repository {project.repository_url}: {e}") return base_path diff --git a/backend/app/services/agent/agents/__init__.py b/backend/app/services/agent/agents/__init__.py index 3009b64..fafec70 100644 --- a/backend/app/services/agent/agents/__init__.py +++ b/backend/app/services/agent/agents/__init__.py @@ -1,9 +1,13 @@ """ 混合 Agent 架构 包含 Orchestrator、Recon、Analysis 和 Verification Agent + +协作机制: +- Agent 之间通过 TaskHandoff 传递结构化上下文 +- 每个 Agent 完成后生成 handoff 给下一个 Agent """ -from .base import BaseAgent, AgentConfig, AgentResult +from .base import BaseAgent, AgentConfig, AgentResult, TaskHandoff from .orchestrator import OrchestratorAgent from .recon import ReconAgent from .analysis import AnalysisAgent @@ -13,6 +17,7 @@ __all__ = [ "BaseAgent", "AgentConfig", "AgentResult", + "TaskHandoff", "OrchestratorAgent", "ReconAgent", "AnalysisAgent", diff --git a/backend/app/services/agent/agents/analysis.py b/backend/app/services/agent/agents/analysis.py index 8b0b86e..ec9ef14 100644 --- a/backend/app/services/agent/agents/analysis.py +++ b/backend/app/services/agent/agents/analysis.py @@ -10,6 +10,7 @@ LLM 是真正的安全分析大脑! 类型: ReAct (真正的!) 
""" +import asyncio import json import logging import re @@ -17,6 +18,7 @@ from typing import List, Dict, Any, Optional from dataclasses import dataclass from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern +from ..json_parser import AgentJsonParser logger = logging.getLogger(__name__) @@ -33,18 +35,13 @@ ANALYSIS_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞分析 Agent,一个**自 ## 你可以使用的工具 -### 外部扫描工具 -- **semgrep_scan**: Semgrep 静态分析(推荐首先使用) - 参数: rules (str), max_results (int) -- **bandit_scan**: Python 安全扫描 - -### RAG 语义搜索 -- **rag_query**: 语义代码搜索 - 参数: query (str), top_k (int) -- **security_search**: 安全相关代码搜索 - 参数: vulnerability_type (str), top_k (int) -- **function_context**: 函数上下文分析 - 参数: function_name (str) +### 文件操作 +- **read_file**: 读取文件内容 + 参数: file_path (str), start_line (int), end_line (int) +- **list_files**: 列出目录文件 + 参数: directory (str), pattern (str) +- **search_code**: 代码关键字搜索 + 参数: keyword (str), max_results (int) ### 深度分析 - **pattern_match**: 危险模式匹配 @@ -53,16 +50,28 @@ ANALYSIS_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞分析 Agent,一个**自 参数: code (str), file_path (str), focus (str) - **dataflow_analysis**: 数据流追踪 参数: source (str), sink (str) -- **vulnerability_validation**: 漏洞验证 - 参数: code (str), vulnerability_type (str) -### 文件操作 -- **read_file**: 读取文件内容 - 参数: file_path (str), start_line (int), end_line (int) -- **search_code**: 代码关键字搜索 - 参数: keyword (str), max_results (int) -- **list_files**: 列出目录文件 - 参数: directory (str), pattern (str) +### 外部静态分析工具 +- **semgrep_scan**: Semgrep 静态分析(推荐首先使用) + 参数: rules (str), max_results (int) +- **bandit_scan**: Python 安全扫描 + 参数: target (str) +- **gitleaks_scan**: Git 密钥泄露扫描 + 参数: target (str) +- **trufflehog_scan**: 敏感信息扫描 + 参数: target (str) +- **npm_audit**: NPM 依赖漏洞扫描 + 参数: target (str) +- **safety_scan**: Python 依赖安全扫描 + 参数: target (str) +- **osv_scan**: OSV 漏洞数据库扫描 + 参数: target (str) + +### RAG 语义搜索 +- **security_search**: 安全相关代码搜索 + 参数: vulnerability_type (str), top_k (int) +- **function_context**: 函数上下文分析 + 参数: function_name (str) ## 工作方式 每一步,你需要输出: @@ -168,15 +177,7 @@ class AnalysisAgent(BaseAgent): self._conversation_history: List[Dict[str, str]] = [] self._steps: List[AnalysisStep] = [] - def _get_tools_description(self) -> str: - """生成工具描述""" - tools_info = [] - for name, tool in self.tools.items(): - if name.startswith("_"): - continue - desc = f"- {name}: {getattr(tool, 'description', 'No description')}" - tools_info.append(desc) - return "\n".join(tools_info) + def _parse_llm_response(self, response: str) -> AnalysisStep: """解析 LLM 响应""" @@ -191,13 +192,20 @@ class AnalysisAgent(BaseAgent): final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL) if final_match: step.is_final = True - try: - answer_text = final_match.group(1).strip() - answer_text = re.sub(r'```json\s*', '', answer_text) - answer_text = re.sub(r'```\s*', '', answer_text) - step.final_answer = json.loads(answer_text) - except json.JSONDecodeError: - step.final_answer = {"findings": [], "raw_answer": final_match.group(1).strip()} + answer_text = final_match.group(1).strip() + answer_text = re.sub(r'```json\s*', '', answer_text) + answer_text = re.sub(r'```\s*', '', answer_text) + # 使用增强的 JSON 解析器 + step.final_answer = AgentJsonParser.parse( + answer_text, + default={"findings": [], "raw_answer": answer_text} + ) + # 确保 findings 格式正确 + if "findings" in step.final_answer: + step.final_answer["findings"] = [ + f for f in step.final_answer["findings"] + if isinstance(f, dict) + ] return step # 提取 Action @@ -211,51 +219,15 @@ class 
AnalysisAgent(BaseAgent): input_text = input_match.group(1).strip() input_text = re.sub(r'```json\s*', '', input_text) input_text = re.sub(r'```\s*', '', input_text) - try: - step.action_input = json.loads(input_text) - except json.JSONDecodeError: - step.action_input = {"raw_input": input_text} + # 使用增强的 JSON 解析器 + step.action_input = AgentJsonParser.parse( + input_text, + default={"raw_input": input_text} + ) return step - async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str: - """执行工具""" - tool = self.tools.get(tool_name) - - if not tool: - return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}" - - try: - self._tool_calls += 1 - await self.emit_tool_call(tool_name, tool_input) - - import time - start = time.time() - - result = await tool.execute(**tool_input) - - duration_ms = int((time.time() - start) * 1000) - await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms) - - if result.success: - output = str(result.data) - - # 如果是代码分析工具,也包含 metadata - if result.metadata: - if "issues" in result.metadata: - output += f"\n\n发现的问题:\n{json.dumps(result.metadata['issues'], ensure_ascii=False, indent=2)}" - if "findings" in result.metadata: - output += f"\n\n发现:\n{json.dumps(result.metadata['findings'][:10], ensure_ascii=False, indent=2)}" - - if len(output) > 6000: - output = output[:6000] + f"\n\n... [输出已截断,共 {len(str(result.data))} 字符]" - return output - else: - return f"工具执行失败: {result.error}" - - except Exception as e: - logger.error(f"Tool execution error: {e}") - return f"工具执行错误: {str(e)}" + async def run(self, input_data: Dict[str, Any]) -> AgentResult: """ @@ -271,6 +243,14 @@ class AnalysisAgent(BaseAgent): task = input_data.get("task", "") task_context = input_data.get("task_context", "") + # 🔥 处理交接信息 + handoff = input_data.get("handoff") + if handoff: + from .base import TaskHandoff + if isinstance(handoff, dict): + handoff = TaskHandoff.from_dict(handoff) + self.receive_handoff(handoff) + # 从 Recon 结果获取上下文 recon_data = previous_results.get("recon", {}) if isinstance(recon_data, dict) and "data" in recon_data: @@ -281,7 +261,9 @@ class AnalysisAgent(BaseAgent): high_risk_areas = recon_data.get("high_risk_areas", plan.get("high_risk_areas", [])) initial_findings = recon_data.get("initial_findings", []) - # 构建初始消息 + # 🔥 构建包含交接上下文的初始消息 + handoff_context = self.get_handoff_context() + initial_message = f"""请开始对项目进行安全漏洞分析。 ## 项目信息 @@ -289,7 +271,7 @@ class AnalysisAgent(BaseAgent): - 语言: {tech_stack.get('languages', [])} - 框架: {tech_stack.get('frameworks', [])} -## 上下文信息 +{handoff_context if handoff_context else f'''## 上下文信息 ### 高风险区域 {json.dumps(high_risk_areas[:20], ensure_ascii=False)} @@ -297,7 +279,7 @@ class AnalysisAgent(BaseAgent): {json.dumps(entry_points[:10], ensure_ascii=False, indent=2)} ### 初步发现 (如果有) -{json.dumps(initial_findings[:5], ensure_ascii=False, indent=2) if initial_findings else '无'} +{json.dumps(initial_findings[:5], ensure_ascii=False, indent=2) if initial_findings else "无"}'''} ## 任务 {task_context or task or '进行全面的安全漏洞分析,发现代码中的安全问题。'} @@ -306,9 +288,12 @@ class AnalysisAgent(BaseAgent): {config.get('target_vulnerabilities', ['all'])} ## 可用工具 -{self._get_tools_description()} +{self.get_tools_description()} 请开始你的安全分析。首先思考分析策略,然后选择合适的工具开始分析。""" + + # 🔥 记录工作开始 + self.record_work("开始安全漏洞分析") # 初始化对话历史 self._conversation_history = [ @@ -328,18 +313,22 @@ class AnalysisAgent(BaseAgent): self._iteration = iteration + 1 - # 🔥 发射 LLM 开始思考事件 - await self.emit_llm_start(iteration + 1) + # 🔥 再次检查取消标志(在LLM调用之前) + if 
self.is_cancelled: + await self.emit_thinking("🛑 任务已取消,停止执行") + break - # 🔥 调用 LLM 进行思考和决策 - response = await self.llm_service.chat_completion_raw( - messages=self._conversation_history, - temperature=0.1, - max_tokens=2048, - ) + # 调用 LLM 进行思考和决策(流式输出) + try: + llm_output, tokens_this_round = await self.stream_llm_call( + self._conversation_history, + temperature=0.1, + max_tokens=2048, + ) + except asyncio.CancelledError: + logger.info(f"[{self.name}] LLM call cancelled") + break - llm_output = response.get("content", "") - tokens_this_round = response.get("usage", {}).get("total_tokens", 0) self._total_tokens += tokens_this_round # 解析 LLM 响应 @@ -369,6 +358,14 @@ class AnalysisAgent(BaseAgent): finding.get("vulnerability_type", "other"), finding.get("file_path", "") ) + # 🔥 记录洞察 + self.add_insight( + f"发现 {finding.get('severity', 'medium')} 级别漏洞: {finding.get('title', 'Unknown')}" + ) + + # 🔥 记录工作完成 + self.record_work(f"完成安全分析,发现 {len(all_findings)} 个潜在漏洞") + await self.emit_llm_complete( f"分析完成,发现 {len(all_findings)} 个潜在漏洞", self._total_tokens @@ -380,7 +377,7 @@ class AnalysisAgent(BaseAgent): # 🔥 发射 LLM 动作决策事件 await self.emit_llm_action(step.action, step.action_input or {}) - observation = await self._execute_tool( + observation = await self.execute_tool( step.action, step.action_input or {} ) @@ -427,7 +424,7 @@ class AnalysisAgent(BaseAgent): await self.emit_event( "info", - f"🎯 Analysis Agent 完成: {len(standardized_findings)} 个发现, {self._iteration} 轮迭代, {self._tool_calls} 次工具调用" + f"Analysis Agent 完成: {len(standardized_findings)} 个发现, {self._iteration} 轮迭代, {self._tool_calls} 次工具调用" ) return AgentResult( diff --git a/backend/app/services/agent/agents/analysis_v2.py b/backend/app/services/agent/agents/analysis_v2.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/services/agent/agents/base.py b/backend/app/services/agent/agents/base.py index d526c24..ca7b37c 100644 --- a/backend/app/services/agent/agents/base.py +++ b/backend/app/services/agent/agents/base.py @@ -2,14 +2,20 @@ Agent 基类 定义 Agent 的基本接口和通用功能 -核心原则:LLM 是 Agent 的大脑,所有日志应该反映 LLM 的参与! +核心原则: +1. LLM 是 Agent 的大脑,全程参与决策 +2. Agent 之间通过 TaskHandoff 传递结构化上下文 +3. 事件分为流式事件(前端展示)和持久化事件(数据库记录) """ from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional, AsyncGenerator +from typing import List, Dict, Any, Optional, AsyncGenerator, Tuple from dataclasses import dataclass, field from enum import Enum +from datetime import datetime, timezone +import asyncio import logging +import uuid logger = logging.getLogger(__name__) @@ -73,6 +79,9 @@ class AgentResult: # 元数据 metadata: Dict[str, Any] = field(default_factory=dict) + # 🔥 协作信息 - Agent 传递给下一个 Agent 的结构化信息 + handoff: Optional["TaskHandoff"] = None + def to_dict(self) -> Dict[str, Any]: return { "success": self.success, @@ -83,9 +92,139 @@ class AgentResult: "tokens_used": self.tokens_used, "duration_ms": self.duration_ms, "metadata": self.metadata, + "handoff": self.handoff.to_dict() if self.handoff else None, } +@dataclass +class TaskHandoff: + """ + 任务交接协议 - Agent 之间传递的结构化信息 + + 设计原则: + 1. 包含足够的上下文让下一个 Agent 理解前序工作 + 2. 提供明确的建议和关注点 + 3. 
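After this change, Recon, Analysis, and Verification all run the same per-iteration shape: check the cancel flag before the LLM call, stream the response through `stream_llm_call`, parse it, then either finish or execute a tool. Stripped of event emission and final-answer handling, the shared skeleton looks like this (illustrative only; names follow the diff):

```python
import asyncio

async def react_loop(agent, history, max_iterations: int):
    """Shared ReAct iteration pattern used by the three agents."""
    for _ in range(max_iterations):
        if agent.is_cancelled:                      # checked before every LLM call
            break
        try:
            output, tokens = await agent.stream_llm_call(
                history, temperature=0.1, max_tokens=2048,
            )
        except asyncio.CancelledError:
            break                                    # cooperative cancellation
        agent._total_tokens += tokens
        step = agent._parse_llm_response(output)
        if step.is_final:
            return step.final_answer
        if step.action:
            observation = await agent.execute_tool(step.action, step.action_input or {})
            history.append({"role": "assistant", "content": output})
            history.append({"role": "user", "content": f"Observation: {observation}"})
```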
可直接转换为 LLM 可理解的 prompt + """ + # 基本信息 + from_agent: str + to_agent: str + + # 工作摘要 + summary: str + work_completed: List[str] = field(default_factory=list) + + # 关键发现和洞察 + key_findings: List[Dict[str, Any]] = field(default_factory=list) + insights: List[str] = field(default_factory=list) + + # 建议和关注点 + suggested_actions: List[Dict[str, Any]] = field(default_factory=list) + attention_points: List[str] = field(default_factory=list) + priority_areas: List[str] = field(default_factory=list) + + # 上下文数据 + context_data: Dict[str, Any] = field(default_factory=dict) + + # 置信度 + confidence: float = 0.8 + + # 时间戳 + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + def to_dict(self) -> Dict[str, Any]: + return { + "from_agent": self.from_agent, + "to_agent": self.to_agent, + "summary": self.summary, + "work_completed": self.work_completed, + "key_findings": self.key_findings, + "insights": self.insights, + "suggested_actions": self.suggested_actions, + "attention_points": self.attention_points, + "priority_areas": self.priority_areas, + "context_data": self.context_data, + "confidence": self.confidence, + "timestamp": self.timestamp.isoformat(), + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "TaskHandoff": + return cls( + from_agent=data.get("from_agent", ""), + to_agent=data.get("to_agent", ""), + summary=data.get("summary", ""), + work_completed=data.get("work_completed", []), + key_findings=data.get("key_findings", []), + insights=data.get("insights", []), + suggested_actions=data.get("suggested_actions", []), + attention_points=data.get("attention_points", []), + priority_areas=data.get("priority_areas", []), + context_data=data.get("context_data", {}), + confidence=data.get("confidence", 0.8), + ) + + def to_prompt_context(self) -> str: + """ + 转换为 LLM 可理解的上下文格式 + 这是关键!让 LLM 能够理解前序 Agent 的工作 + """ + lines = [ + f"## 来自 {self.from_agent} Agent 的任务交接", + "", + f"### 工作摘要", + self.summary, + "", + ] + + if self.work_completed: + lines.append("### 已完成的工作") + for work in self.work_completed: + lines.append(f"- {work}") + lines.append("") + + if self.key_findings: + lines.append("### 关键发现") + for i, finding in enumerate(self.key_findings[:15], 1): + severity = finding.get("severity", "medium") + title = finding.get("title", "Unknown") + file_path = finding.get("file_path", "") + lines.append(f"{i}. [{severity.upper()}] {title}") + if file_path: + lines.append(f" 位置: {file_path}:{finding.get('line_start', '')}") + if finding.get("description"): + lines.append(f" 描述: {finding['description'][:100]}") + lines.append("") + + if self.insights: + lines.append("### 洞察和分析") + for insight in self.insights: + lines.append(f"- {insight}") + lines.append("") + + if self.suggested_actions: + lines.append("### 建议的下一步行动") + for action in self.suggested_actions: + action_type = action.get("type", "general") + description = action.get("description", "") + priority = action.get("priority", "medium") + lines.append(f"- [{priority.upper()}] {action_type}: {description}") + lines.append("") + + if self.attention_points: + lines.append("### ⚠️ 需要特别关注") + for point in self.attention_points: + lines.append(f"- {point}") + lines.append("") + + if self.priority_areas: + lines.append("### 优先分析区域") + for area in self.priority_areas: + lines.append(f"- {area}") + + return "\n".join(lines) + + class BaseAgent(ABC): """ Agent 基类 @@ -94,6 +233,11 @@ class BaseAgent(ABC): 1. LLM 是 Agent 的大脑,全程参与决策 2. 所有日志应该反映 LLM 的思考过程 3. 工具调用是 LLM 的决策结果 + + 协作原则: + 1. 
通过 TaskHandoff 接收前序 Agent 的上下文 + 2. 执行完成后生成 TaskHandoff 传递给下一个 Agent + 3. 洞察和发现应该结构化记录 """ def __init__( @@ -122,6 +266,11 @@ class BaseAgent(ABC): self._total_tokens = 0 self._tool_calls = 0 self._cancelled = False + + # 🔥 协作状态 + self._incoming_handoff: Optional[TaskHandoff] = None + self._insights: List[str] = [] # 收集的洞察 + self._work_completed: List[str] = [] # 完成的工作记录 @property def name(self) -> str: @@ -152,6 +301,103 @@ class BaseAgent(ABC): def is_cancelled(self) -> bool: return self._cancelled + # ============ 协作方法 ============ + + def receive_handoff(self, handoff: TaskHandoff): + """ + 接收来自前序 Agent 的任务交接 + + Args: + handoff: 任务交接对象 + """ + self._incoming_handoff = handoff + logger.info( + f"[{self.name}] Received handoff from {handoff.from_agent}: " + f"{handoff.summary[:50]}..." + ) + + def get_handoff_context(self) -> str: + """ + 获取交接上下文(用于构建 LLM prompt) + + Returns: + 格式化的上下文字符串 + """ + if not self._incoming_handoff: + return "" + return self._incoming_handoff.to_prompt_context() + + def add_insight(self, insight: str): + """记录洞察""" + self._insights.append(insight) + + def record_work(self, work: str): + """记录完成的工作""" + self._work_completed.append(work) + + def create_handoff( + self, + to_agent: str, + summary: str, + key_findings: List[Dict[str, Any]] = None, + suggested_actions: List[Dict[str, Any]] = None, + attention_points: List[str] = None, + priority_areas: List[str] = None, + context_data: Dict[str, Any] = None, + ) -> TaskHandoff: + """ + 创建任务交接 + + Args: + to_agent: 目标 Agent + summary: 工作摘要 + key_findings: 关键发现 + suggested_actions: 建议的行动 + attention_points: 需要关注的点 + priority_areas: 优先分析区域 + context_data: 上下文数据 + + Returns: + TaskHandoff 对象 + """ + return TaskHandoff( + from_agent=self.name, + to_agent=to_agent, + summary=summary, + work_completed=self._work_completed.copy(), + key_findings=key_findings or [], + insights=self._insights.copy(), + suggested_actions=suggested_actions or [], + attention_points=attention_points or [], + priority_areas=priority_areas or [], + context_data=context_data or {}, + ) + + def build_prompt_with_handoff(self, base_prompt: str) -> str: + """ + 构建包含交接上下文的 prompt + + Args: + base_prompt: 基础 prompt + + Returns: + 增强后的 prompt + """ + handoff_context = self.get_handoff_context() + if not handoff_context: + return base_prompt + + return f"""{base_prompt} + +--- +## 前序 Agent 交接信息 + +{handoff_context} + +--- +请基于以上来自前序 Agent 的信息,结合你的专业能力开展工作。 +""" + # ============ 核心事件发射方法 ============ async def emit_event( @@ -173,13 +419,13 @@ class BaseAgent(ABC): async def emit_thinking(self, message: str): """发射 LLM 思考事件""" - await self.emit_event("thinking", f"🧠 [{self.name}] {message}") + await self.emit_event("thinking", f"[{self.name}] {message}") async def emit_llm_start(self, iteration: int): """发射 LLM 开始思考事件""" await self.emit_event( "llm_start", - f"🤔 [{self.name}] LLM 开始第 {iteration} 轮思考...", + f"[{self.name}] 第 {iteration} 轮迭代开始", metadata={"iteration": iteration} ) @@ -189,31 +435,62 @@ class BaseAgent(ABC): display_thought = thought[:500] + "..." 
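Because `TaskHandoff` round-trips through a plain dict, it can cross the graph-state boundary and be rebuilt on the other side. A usage example with sample values (note that `from_dict` does not restore `timestamp`; it defaults back to now):

```python
from app.services.agent.agents import TaskHandoff  # re-exported in __init__ above

handoff = TaskHandoff(
    from_agent="Recon",
    to_agent="Analysis",
    summary="Mapped the project: 12 entry points, 4 high-risk areas.",  # sample values
    priority_areas=["app/api/auth.py", "app/db/raw_sql.py"],
)

payload = handoff.to_dict()                # JSON-safe, storable in graph state
restored = TaskHandoff.from_dict(payload)  # rebuilt by the receiving node
print(restored.to_prompt_context())        # renders the structured prompt section
```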
if len(thought) > 500 else thought await self.emit_event( "llm_thought", - f"💭 [{self.name}] LLM 思考:\n{display_thought}", + f"[{self.name}] 思考: {display_thought}", metadata={ "thought": thought, "iteration": iteration, } ) + async def emit_thinking_start(self): + """发射开始思考事件(流式输出用)""" + await self.emit_event("thinking_start", f"[{self.name}] 开始思考...") + + async def emit_thinking_token(self, token: str, accumulated: str): + """发射思考 token 事件(流式输出用)""" + await self.emit_event( + "thinking_token", + "", # 不需要 message,前端从 metadata 获取 + metadata={ + "token": token, + "accumulated": accumulated, + } + ) + + async def emit_thinking_end(self, full_response: str): + """发射思考结束事件(流式输出用)""" + await self.emit_event( + "thinking_end", + f"[{self.name}] 思考完成", + metadata={"accumulated": full_response} + ) + async def emit_llm_decision(self, decision: str, reason: str = ""): """发射 LLM 决策事件 - 展示 LLM 做了什么决定""" await self.emit_event( "llm_decision", - f"💡 [{self.name}] LLM 决策: {decision}" + (f" (理由: {reason})" if reason else ""), + f"[{self.name}] 决策: {decision}" + (f" ({reason})" if reason else ""), metadata={ "decision": decision, "reason": reason, } ) + async def emit_llm_complete(self, result_summary: str, tokens_used: int): + """发射 LLM 完成事件""" + await self.emit_event( + "llm_complete", + f"[{self.name}] 完成: {result_summary} (消耗 {tokens_used} tokens)", + metadata={ + "tokens_used": tokens_used, + } + ) + async def emit_llm_action(self, action: str, action_input: Dict): - """发射 LLM 动作事件 - LLM 决定执行什么动作""" - import json - input_str = json.dumps(action_input, ensure_ascii=False)[:200] + """发射 LLM 动作决策事件""" await self.emit_event( "llm_action", - f"⚡ [{self.name}] LLM 动作: {action}\n 参数: {input_str}", + f"[{self.name}] 执行动作: {action}", metadata={ "action": action, "action_input": action_input, @@ -221,43 +498,33 @@ class BaseAgent(ABC): ) async def emit_llm_observation(self, observation: str): - """发射 LLM 观察事件 - LLM 看到了什么""" + """发射 LLM 观察事件""" + # 截断过长的观察结果 display_obs = observation[:300] + "..." if len(observation) > 300 else observation await self.emit_event( "llm_observation", - f"👁️ [{self.name}] LLM 观察到:\n{display_obs}", - metadata={"observation": observation[:2000]} - ) - - async def emit_llm_complete(self, result_summary: str, tokens_used: int): - """发射 LLM 完成事件""" - await self.emit_event( - "llm_complete", - f"✅ [{self.name}] LLM 完成: {result_summary} (消耗 {tokens_used} tokens)", + f"[{self.name}] 观察结果: {display_obs}", metadata={ - "tokens_used": tokens_used, + "observation": observation[:2000], # 限制存储长度 } ) # ============ 工具调用相关事件 ============ async def emit_tool_call(self, tool_name: str, tool_input: Dict): - """发射工具调用事件 - LLM 决定调用工具""" - import json - input_str = json.dumps(tool_input, ensure_ascii=False)[:300] + """发射工具调用事件""" await self.emit_event( "tool_call", - f"🔧 [{self.name}] LLM 调用工具: {tool_name}\n 输入: {input_str}", + f"[{self.name}] 调用工具: {tool_name}", tool_name=tool_name, tool_input=tool_input, ) async def emit_tool_result(self, tool_name: str, result: str, duration_ms: int): """发射工具结果事件""" - result_preview = result[:200] + "..." 
if len(result) > 200 else result await self.emit_event( "tool_result", - f"📤 [{self.name}] 工具 {tool_name} 返回 ({duration_ms}ms):\n {result_preview}", + f"[{self.name}] 工具 {tool_name} 完成 ({duration_ms}ms)", tool_name=tool_name, tool_duration_ms=duration_ms, ) @@ -332,9 +599,6 @@ class BaseAgent(ABC): """ self._iteration += 1 - # 发射 LLM 开始事件 - await self.emit_llm_start(self._iteration) - try: response = await self.llm_service.chat_completion( messages=messages, @@ -385,3 +649,124 @@ class BaseAgent(ABC): "tool_calls": self._tool_calls, "tokens_used": self._total_tokens, } + + # ============ 统一的流式 LLM 调用 ============ + + async def stream_llm_call( + self, + messages: List[Dict[str, str]], + temperature: float = 0.1, + max_tokens: int = 2048, + ) -> Tuple[str, int]: + """ + 统一的流式 LLM 调用方法 + + 所有 Agent 共用此方法,避免重复代码 + + Args: + messages: 消息列表 + temperature: 温度 + max_tokens: 最大 token 数 + + Returns: + (完整响应内容, token数量) + """ + accumulated = "" + total_tokens = 0 + + await self.emit_thinking_start() + + try: + async for chunk in self.llm_service.chat_completion_stream( + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + ): + # 检查取消 + if self.is_cancelled: + break + + if chunk["type"] == "token": + token = chunk["content"] + accumulated = chunk["accumulated"] + await self.emit_thinking_token(token, accumulated) + + elif chunk["type"] == "done": + accumulated = chunk["content"] + if chunk.get("usage"): + total_tokens = chunk["usage"].get("total_tokens", 0) + break + + elif chunk["type"] == "error": + accumulated = chunk.get("accumulated", "") + logger.error(f"Stream error: {chunk.get('error')}") + break + + except asyncio.CancelledError: + logger.info(f"[{self.name}] LLM call cancelled") + raise + finally: + await self.emit_thinking_end(accumulated) + + return accumulated, total_tokens + + async def execute_tool(self, tool_name: str, tool_input: Dict) -> str: + """ + 统一的工具执行方法 + + Args: + tool_name: 工具名称 + tool_input: 工具参数 + + Returns: + 工具执行结果字符串 + """ + tool = self.tools.get(tool_name) + + if not tool: + return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}" + + try: + self._tool_calls += 1 + await self.emit_tool_call(tool_name, tool_input) + + import time + start = time.time() + + result = await tool.execute(**tool_input) + + duration_ms = int((time.time() - start) * 1000) + await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms) + + if result.success: + output = str(result.data) + + # 包含 metadata 中的额外信息 + if result.metadata: + if "issues" in result.metadata: + import json + output += f"\n\n发现的问题:\n{json.dumps(result.metadata['issues'], ensure_ascii=False, indent=2)}" + if "findings" in result.metadata: + import json + output += f"\n\n发现:\n{json.dumps(result.metadata['findings'][:10], ensure_ascii=False, indent=2)}" + + # 截断过长输出 + if len(output) > 6000: + output = output[:6000] + f"\n\n... 
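`stream_llm_call` fixes the chunk protocol every agent now depends on: `token` chunks carry `content` plus the running `accumulated` text, and a final `done` chunk carries the full content and optional `usage`. A stub producer in that shape (inferred from the consumer above) makes agents testable without a live model:

```python
from typing import Any, AsyncIterator, Dict

async def fake_stream(text: str) -> AsyncIterator[Dict[str, Any]]:
    """Test double for chat_completion_stream, matching stream_llm_call's protocol."""
    accumulated = ""
    for word in text.split():
        accumulated += word + " "
        yield {"type": "token", "content": word + " ", "accumulated": accumulated}
    yield {
        "type": "done",
        "content": accumulated.strip(),
        "usage": {"total_tokens": len(text.split())},  # placeholder count
    }
```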
[输出已截断,共 {len(str(result.data))} 字符]" + return output + else: + return f"工具执行失败: {result.error}" + + except Exception as e: + logger.error(f"Tool execution error: {e}") + return f"工具执行错误: {str(e)}" + + def get_tools_description(self) -> str: + """生成工具描述文本(用于 prompt)""" + tools_info = [] + for name, tool in self.tools.items(): + if name.startswith("_"): + continue + desc = f"- {name}: {getattr(tool, 'description', 'No description')}" + tools_info.append(desc) + return "\n".join(tools_info) diff --git a/backend/app/services/agent/agents/orchestrator.py b/backend/app/services/agent/agents/orchestrator.py index 9354ac5..b5ee923 100644 --- a/backend/app/services/agent/agents/orchestrator.py +++ b/backend/app/services/agent/agents/orchestrator.py @@ -18,6 +18,7 @@ from typing import List, Dict, Any, Optional from dataclasses import dataclass from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern +from ..json_parser import AgentJsonParser logger = logging.getLogger(__name__) @@ -178,18 +179,22 @@ class OrchestratorAgent(BaseAgent): self._iteration = iteration + 1 - # 🔥 发射 LLM 开始思考事件 - await self.emit_llm_start(iteration + 1) + # 🔥 再次检查取消标志(在LLM调用之前) + if self.is_cancelled: + await self.emit_thinking("🛑 任务已取消,停止执行") + break - # 🔥 调用 LLM 进行思考和决策 - response = await self.llm_service.chat_completion_raw( - messages=self._conversation_history, - temperature=0.1, - max_tokens=2048, - ) + # 调用 LLM 进行思考和决策(流式输出) + try: + llm_output, tokens_this_round = await self.stream_llm_call( + self._conversation_history, + temperature=0.1, + max_tokens=2048, + ) + except asyncio.CancelledError: + logger.info(f"[{self.name}] LLM call cancelled") + break - llm_output = response.get("content", "") - tokens_this_round = response.get("usage", {}).get("total_tokens", 0) self._total_tokens += tokens_this_round # 解析 LLM 的决策 @@ -348,10 +353,11 @@ class OrchestratorAgent(BaseAgent): input_text = re.sub(r'```json\s*', '', input_text) input_text = re.sub(r'```\s*', '', input_text) - try: - action_input = json.loads(input_text) - except json.JSONDecodeError: - action_input = {"raw": input_text} + # 使用增强的 JSON 解析器 + action_input = AgentJsonParser.parse( + input_text, + default={"raw": input_text} + ) return AgentStep( thought=thought, diff --git a/backend/app/services/agent/agents/react_agent.py b/backend/app/services/agent/agents/react_agent.py index e7f5631..37e8e0e 100644 --- a/backend/app/services/agent/agents/react_agent.py +++ b/backend/app/services/agent/agents/react_agent.py @@ -16,6 +16,7 @@ from typing import List, Dict, Any, Optional, Tuple from dataclasses import dataclass from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern +from ..json_parser import AgentJsonParser logger = logging.getLogger(__name__) @@ -182,15 +183,20 @@ class ReActAgent(BaseAgent): final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL) if final_match: step.is_final = True - try: - # 尝试提取 JSON - answer_text = final_match.group(1).strip() - # 移除 markdown 代码块 - answer_text = re.sub(r'```json\s*', '', answer_text) - answer_text = re.sub(r'```\s*', '', answer_text) - step.final_answer = json.loads(answer_text) - except json.JSONDecodeError: - step.final_answer = {"raw_answer": final_match.group(1).strip()} + answer_text = final_match.group(1).strip() + answer_text = re.sub(r'```json\s*', '', answer_text) + answer_text = re.sub(r'```\s*', '', answer_text) + # 使用增强的 JSON 解析器 + step.final_answer = AgentJsonParser.parse( + answer_text, + default={"raw_answer": answer_text} + ) + # 确保 
findings 格式正确 + if "findings" in step.final_answer: + step.final_answer["findings"] = [ + f for f in step.final_answer["findings"] + if isinstance(f, dict) + ] return step # 提取 Action @@ -202,14 +208,13 @@ class ReActAgent(BaseAgent): input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Action:|Observation:|$)', response, re.DOTALL) if input_match: input_text = input_match.group(1).strip() - # 移除 markdown 代码块 input_text = re.sub(r'```json\s*', '', input_text) input_text = re.sub(r'```\s*', '', input_text) - try: - step.action_input = json.loads(input_text) - except json.JSONDecodeError: - # 尝试简单解析 - step.action_input = {"raw_input": input_text} + # 使用增强的 JSON 解析器 + step.action_input = AgentJsonParser.parse( + input_text, + default={"raw_input": input_text} + ) return step diff --git a/backend/app/services/agent/agents/recon.py b/backend/app/services/agent/agents/recon.py index 8ddc609..5279a99 100644 --- a/backend/app/services/agent/agents/recon.py +++ b/backend/app/services/agent/agents/recon.py @@ -10,6 +10,7 @@ LLM 是真正的大脑! 类型: ReAct (真正的!) """ +import asyncio import json import logging import re @@ -17,22 +18,24 @@ from typing import List, Dict, Any, Optional from dataclasses import dataclass from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern +from ..json_parser import AgentJsonParser logger = logging.getLogger(__name__) -RECON_SYSTEM_PROMPT = """你是 DeepAudit 的信息收集 Agent,负责在安全审计前**自主**收集项目信息。 +RECON_SYSTEM_PROMPT = """你是 DeepAudit 的信息收集 Agent,负责在安全审计前收集项目信息。 -## 你的角色 -你是信息收集的**大脑**,不是机械执行者。你需要: -1. 自主思考需要收集什么信息 -2. 选择合适的工具获取信息 -3. 根据发现动态调整策略 -4. 判断何时信息收集足够 +## 你的职责 +你专注于**信息收集**,为后续的漏洞分析提供基础数据: +1. 分析项目结构和目录布局 +2. 识别技术栈(语言、框架、数据库) +3. 找出入口点(API、路由、用户输入处理) +4. 标记高风险区域(认证、数据库操作、文件处理) +5. 收集依赖信息 ## 你可以使用的工具 -### 文件系统 +### 文件系统工具 - **list_files**: 列出目录内容 参数: directory (str), recursive (bool), pattern (str), max_files (int) @@ -42,12 +45,14 @@ RECON_SYSTEM_PROMPT = """你是 DeepAudit 的信息收集 Agent,负责在安 - **search_code**: 代码关键字搜索 参数: keyword (str), max_results (int) -### 安全扫描 -- **semgrep_scan**: Semgrep 静态分析扫描 -- **npm_audit**: npm 依赖漏洞审计 -- **safety_scan**: Python 依赖漏洞审计 -- **gitleaks_scan**: 密钥/敏感信息泄露扫描 -- **osv_scan**: OSV 通用依赖漏洞扫描 +### 语义搜索工具 +- **rag_query**: 语义代码搜索(如果可用) + 参数: query (str), top_k (int) + +## 注意 +- 你只负责信息收集,不要进行漏洞分析 +- 漏洞分析由 Analysis Agent 负责 +- 专注于收集项目结构、技术栈、入口点等信息 ## 工作方式 每一步,你需要输出: @@ -142,15 +147,7 @@ class ReconAgent(BaseAgent): self._conversation_history: List[Dict[str, str]] = [] self._steps: List[ReconStep] = [] - def _get_tools_description(self) -> str: - """生成工具描述""" - tools_info = [] - for name, tool in self.tools.items(): - if name.startswith("_"): - continue - desc = f"- {name}: {getattr(tool, 'description', 'No description')}" - tools_info.append(desc) - return "\n".join(tools_info) + def _parse_llm_response(self, response: str) -> ReconStep: """解析 LLM 响应""" @@ -165,13 +162,14 @@ class ReconAgent(BaseAgent): final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL) if final_match: step.is_final = True - try: - answer_text = final_match.group(1).strip() - answer_text = re.sub(r'```json\s*', '', answer_text) - answer_text = re.sub(r'```\s*', '', answer_text) - step.final_answer = json.loads(answer_text) - except json.JSONDecodeError: - step.final_answer = {"raw_answer": final_match.group(1).strip()} + answer_text = final_match.group(1).strip() + answer_text = re.sub(r'```json\s*', '', answer_text) + answer_text = re.sub(r'```\s*', '', answer_text) + # 使用增强的 JSON 解析器 + step.final_answer = AgentJsonParser.parse( + 
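The same keep-only-dicts guard is now duplicated in the Analysis, ReAct, and Verification parsers; it exists because models occasionally emit bare strings inside `findings`, which would break the severity tallies downstream. It could be factored into one helper, e.g.:

```python
from typing import Any, Dict

def keep_dict_findings(answer: Dict[str, Any]) -> Dict[str, Any]:
    """Drop malformed (non-dict) entries from an LLM's findings list."""
    if isinstance(answer.get("findings"), list):
        answer["findings"] = [f for f in answer["findings"] if isinstance(f, dict)]
    return answer
```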
answer_text, + default={"raw_answer": answer_text} + ) return step # 提取 Action @@ -185,43 +183,15 @@ class ReconAgent(BaseAgent): input_text = input_match.group(1).strip() input_text = re.sub(r'```json\s*', '', input_text) input_text = re.sub(r'```\s*', '', input_text) - try: - step.action_input = json.loads(input_text) - except json.JSONDecodeError: - step.action_input = {"raw_input": input_text} + # 使用增强的 JSON 解析器 + step.action_input = AgentJsonParser.parse( + input_text, + default={"raw_input": input_text} + ) return step - async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str: - """执行工具""" - tool = self.tools.get(tool_name) - - if not tool: - return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}" - - try: - self._tool_calls += 1 - await self.emit_tool_call(tool_name, tool_input) - - import time - start = time.time() - - result = await tool.execute(**tool_input) - - duration_ms = int((time.time() - start) * 1000) - await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms) - - if result.success: - output = str(result.data) - if len(output) > 4000: - output = output[:4000] + f"\n\n... [输出已截断,共 {len(str(result.data))} 字符]" - return output - else: - return f"工具执行失败: {result.error}" - - except Exception as e: - logger.error(f"Tool execution error: {e}") - return f"工具执行错误: {str(e)}" + async def run(self, input_data: Dict[str, Any]) -> AgentResult: """ @@ -246,7 +216,7 @@ class ReconAgent(BaseAgent): {task_context or task or '进行全面的信息收集,为安全审计做准备。'} ## 可用工具 -{self._get_tools_description()} +{self.get_tools_description()} 请开始你的信息收集工作。首先思考应该收集什么信息,然后选择合适的工具。""" @@ -259,7 +229,7 @@ class ReconAgent(BaseAgent): self._steps = [] final_result = None - await self.emit_thinking("🔍 Recon Agent 启动,LLM 开始自主收集信息...") + await self.emit_thinking("Recon Agent 启动,LLM 开始自主收集信息...") try: for iteration in range(self.config.max_iterations): @@ -268,18 +238,22 @@ class ReconAgent(BaseAgent): self._iteration = iteration + 1 - # 🔥 发射 LLM 开始思考事件 - await self.emit_llm_start(iteration + 1) + # 🔥 再次检查取消标志(在LLM调用之前) + if self.is_cancelled: + await self.emit_thinking("🛑 任务已取消,停止执行") + break - # 🔥 调用 LLM 进行思考和决策 - response = await self.llm_service.chat_completion_raw( - messages=self._conversation_history, - temperature=0.1, - max_tokens=2048, - ) + # 调用 LLM 进行思考和决策(使用基类统一方法) + try: + llm_output, tokens_this_round = await self.stream_llm_call( + self._conversation_history, + temperature=0.1, + max_tokens=2048, + ) + except asyncio.CancelledError: + logger.info(f"[{self.name}] LLM call cancelled") + break - llm_output = response.get("content", "") - tokens_this_round = response.get("usage", {}).get("total_tokens", 0) self._total_tokens += tokens_this_round # 解析 LLM 响应 @@ -311,7 +285,7 @@ class ReconAgent(BaseAgent): # 🔥 发射 LLM 动作决策事件 await self.emit_llm_action(step.action, step.action_input or {}) - observation = await self._execute_tool( + observation = await self.execute_tool( step.action, step.action_input or {} ) @@ -341,9 +315,18 @@ class ReconAgent(BaseAgent): if not final_result: final_result = self._summarize_from_steps() + # 🔥 记录工作和洞察 + self.record_work(f"完成项目信息收集,发现 {len(final_result.get('entry_points', []))} 个入口点") + self.record_work(f"识别技术栈: {final_result.get('tech_stack', {})}") + + if final_result.get("high_risk_areas"): + self.add_insight(f"发现 {len(final_result['high_risk_areas'])} 个高风险区域需要重点分析") + if final_result.get("initial_findings"): + self.add_insight(f"初步发现 {len(final_result['initial_findings'])} 个潜在问题") + await self.emit_event( "info", - f"🎯 Recon Agent 完成: 
{self._iteration} 轮迭代, {self._tool_calls} 次工具调用" + f"Recon Agent 完成: {self._iteration} 轮迭代, {self._tool_calls} 次工具调用" ) return AgentResult( diff --git a/backend/app/services/agent/agents/verification.py b/backend/app/services/agent/agents/verification.py index e4c17bb..02360af 100644 --- a/backend/app/services/agent/agents/verification.py +++ b/backend/app/services/agent/agents/verification.py @@ -10,6 +10,7 @@ LLM 是验证的大脑! 类型: ReAct (真正的!) """ +import asyncio import json import logging import re @@ -18,6 +19,7 @@ from dataclasses import dataclass from datetime import datetime, timezone from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern +from ..json_parser import AgentJsonParser logger = logging.getLogger(__name__) @@ -34,15 +36,17 @@ VERIFICATION_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞验证 Agent,一个* ## 你可以使用的工具 -### 代码分析 +### 文件操作 - **read_file**: 读取更多代码上下文 参数: file_path (str), start_line (int), end_line (int) -- **function_context**: 分析函数调用关系 - 参数: function_name (str) -- **dataflow_analysis**: 追踪数据流 - 参数: source (str), sink (str), file_path (str) +- **list_files**: 列出目录文件 + 参数: directory (str), pattern (str) + +### 验证分析 - **vulnerability_validation**: LLM 深度验证 ⭐ 参数: code (str), vulnerability_type (str), context (str) +- **dataflow_analysis**: 追踪数据流 + 参数: source (str), sink (str), file_path (str) ### 沙箱验证 - **sandbox_exec**: 在沙箱中执行命令 @@ -157,16 +161,6 @@ class VerificationAgent(BaseAgent): self._conversation_history: List[Dict[str, str]] = [] self._steps: List[VerificationStep] = [] - def _get_tools_description(self) -> str: - """生成工具描述""" - tools_info = [] - for name, tool in self.tools.items(): - if name.startswith("_"): - continue - desc = f"- {name}: {getattr(tool, 'description', 'No description')}" - tools_info.append(desc) - return "\n".join(tools_info) - def _parse_llm_response(self, response: str) -> VerificationStep: """解析 LLM 响应""" step = VerificationStep(thought="") @@ -180,13 +174,20 @@ class VerificationAgent(BaseAgent): final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL) if final_match: step.is_final = True - try: - answer_text = final_match.group(1).strip() - answer_text = re.sub(r'```json\s*', '', answer_text) - answer_text = re.sub(r'```\s*', '', answer_text) - step.final_answer = json.loads(answer_text) - except json.JSONDecodeError: - step.final_answer = {"findings": [], "raw_answer": final_match.group(1).strip()} + answer_text = final_match.group(1).strip() + answer_text = re.sub(r'```json\s*', '', answer_text) + answer_text = re.sub(r'```\s*', '', answer_text) + # 使用增强的 JSON 解析器 + step.final_answer = AgentJsonParser.parse( + answer_text, + default={"findings": [], "raw_answer": answer_text} + ) + # 确保 findings 格式正确 + if "findings" in step.final_answer: + step.final_answer["findings"] = [ + f for f in step.final_answer["findings"] + if isinstance(f, dict) + ] return step # 提取 Action @@ -200,50 +201,14 @@ class VerificationAgent(BaseAgent): input_text = input_match.group(1).strip() input_text = re.sub(r'```json\s*', '', input_text) input_text = re.sub(r'```\s*', '', input_text) - try: - step.action_input = json.loads(input_text) - except json.JSONDecodeError: - step.action_input = {"raw_input": input_text} + # 使用增强的 JSON 解析器 + step.action_input = AgentJsonParser.parse( + input_text, + default={"raw_input": input_text} + ) return step - async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str: - """执行工具""" - tool = self.tools.get(tool_name) - - if not tool: - return f"错误: 工具 '{tool_name}' 不存在。可用工具: 
{list(self.tools.keys())}" - - try: - self._tool_calls += 1 - await self.emit_tool_call(tool_name, tool_input) - - import time - start = time.time() - - result = await tool.execute(**tool_input) - - duration_ms = int((time.time() - start) * 1000) - await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms) - - if result.success: - output = str(result.data) - - # 包含 metadata - if result.metadata: - if "validation" in result.metadata: - output += f"\n\n验证结果:\n{json.dumps(result.metadata['validation'], ensure_ascii=False, indent=2)}" - - if len(output) > 4000: - output = output[:4000] + f"\n\n... [输出已截断]" - return output - else: - return f"工具执行失败: {result.error}" - - except Exception as e: - logger.error(f"Tool execution error: {e}") - return f"工具执行错误: {str(e)}" - async def run(self, input_data: Dict[str, Any]) -> AgentResult: """ 执行漏洞验证 - LLM 全程参与! @@ -256,20 +221,32 @@ class VerificationAgent(BaseAgent): task = input_data.get("task", "") task_context = input_data.get("task_context", "") + # 🔥 处理交接信息 + handoff = input_data.get("handoff") + if handoff: + from .base import TaskHandoff + if isinstance(handoff, dict): + handoff = TaskHandoff.from_dict(handoff) + self.receive_handoff(handoff) + # 收集所有待验证的发现 findings_to_verify = [] - for phase_name, result in previous_results.items(): - if isinstance(result, dict): - data = result.get("data", {}) - else: - data = result.data if hasattr(result, 'data') else {} - - if isinstance(data, dict): - phase_findings = data.get("findings", []) - for f in phase_findings: - if f.get("needs_verification", True): - findings_to_verify.append(f) + # 🔥 优先从交接信息获取发现 + if self._incoming_handoff and self._incoming_handoff.key_findings: + findings_to_verify = self._incoming_handoff.key_findings.copy() + else: + for phase_name, result in previous_results.items(): + if isinstance(result, dict): + data = result.get("data", {}) + else: + data = result.data if hasattr(result, 'data') else {} + + if isinstance(data, dict): + phase_findings = data.get("findings", []) + for f in phase_findings: + if f.get("needs_verification", True): + findings_to_verify.append(f) # 去重 findings_to_verify = self._deduplicate(findings_to_verify) @@ -289,7 +266,12 @@ class VerificationAgent(BaseAgent): f"开始验证 {len(findings_to_verify)} 个发现" ) - # 构建初始消息 + # 🔥 记录工作开始 + self.record_work(f"开始验证 {len(findings_to_verify)} 个漏洞发现") + + # 🔥 构建包含交接上下文的初始消息 + handoff_context = self.get_handoff_context() + findings_summary = [] for i, f in enumerate(findings_to_verify): findings_summary.append(f""" @@ -306,6 +288,8 @@ class VerificationAgent(BaseAgent): initial_message = f"""请验证以下 {len(findings_to_verify)} 个安全发现。 +{handoff_context if handoff_context else ''} + ## 待验证发现 {''.join(findings_summary)} @@ -313,9 +297,10 @@ class VerificationAgent(BaseAgent): - 验证级别: {config.get('verification_level', 'standard')} ## 可用工具 -{self._get_tools_description()} +{self.get_tools_description()} -请开始验证。对于每个发现,思考如何验证它,使用合适的工具获取更多信息,然后判断是否为真实漏洞。""" +请开始验证。对于每个发现,思考如何验证它,使用合适的工具获取更多信息,然后判断是否为真实漏洞。 +{f"特别注意 Analysis Agent 提到的关注点。" if handoff_context else ""}""" # 初始化对话历史 self._conversation_history = [ @@ -335,18 +320,22 @@ class VerificationAgent(BaseAgent): self._iteration = iteration + 1 - # 🔥 发射 LLM 开始思考事件 - await self.emit_llm_start(iteration + 1) + # 🔥 再次检查取消标志(在LLM调用之前) + if self.is_cancelled: + await self.emit_thinking("🛑 任务已取消,停止执行") + break - # 🔥 调用 LLM 进行思考和决策 - response = await self.llm_service.chat_completion_raw( - messages=self._conversation_history, - temperature=0.1, - max_tokens=3000, - ) + # 调用 
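The new collection order matters: a handoff from Analysis, when present, is the single source of findings to verify; the scan over `previous_results` is only a fallback. Isolated, the rule reads as follows (mirrors `VerificationAgent.run`):

```python
from typing import Any, Dict, List, Optional

def collect_findings_to_verify(
    incoming_handoff: Optional[Any],
    previous_results: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """Prefer the structured handoff; otherwise scan prior phase results."""
    if incoming_handoff and incoming_handoff.key_findings:
        return list(incoming_handoff.key_findings)
    findings: List[Dict[str, Any]] = []
    for result in previous_results.values():
        data = result.get("data", {}) if isinstance(result, dict) else getattr(result, "data", {})
        if isinstance(data, dict):
            findings.extend(
                f for f in data.get("findings", [])
                if f.get("needs_verification", True)
            )
    return findings
```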
LLM 进行思考和决策(流式输出) + try: + llm_output, tokens_this_round = await self.stream_llm_call( + self._conversation_history, + temperature=0.1, + max_tokens=3000, + ) + except asyncio.CancelledError: + logger.info(f"[{self.name}] LLM call cancelled") + break - llm_output = response.get("content", "") - tokens_this_round = response.get("usage", {}).get("total_tokens", 0) self._total_tokens += tokens_this_round # 解析 LLM 响应 @@ -367,6 +356,14 @@ class VerificationAgent(BaseAgent): if step.is_final: await self.emit_llm_decision("完成漏洞验证", "LLM 判断验证已充分") final_result = step.final_answer + + # 🔥 记录洞察和工作 + if final_result and "findings" in final_result: + verified_count = len([f for f in final_result["findings"] if f.get("is_verified")]) + fp_count = len([f for f in final_result["findings"] if f.get("verdict") == "false_positive"]) + self.add_insight(f"验证了 {len(final_result['findings'])} 个发现,{verified_count} 个确认,{fp_count} 个误报") + self.record_work(f"完成漏洞验证: {verified_count} 个确认, {fp_count} 个误报") + await self.emit_llm_complete( f"验证完成", self._total_tokens @@ -378,7 +375,7 @@ class VerificationAgent(BaseAgent): # 🔥 发射 LLM 动作决策事件 await self.emit_llm_action(step.action, step.action_input or {}) - observation = await self._execute_tool( + observation = await self.execute_tool( step.action, step.action_input or {} ) @@ -438,7 +435,7 @@ class VerificationAgent(BaseAgent): await self.emit_event( "info", - f"🎯 Verification Agent 完成: {confirmed_count} 确认, {likely_count} 可能, {false_positive_count} 误报" + f"Verification Agent 完成: {confirmed_count} 确认, {likely_count} 可能, {false_positive_count} 误报" ) return AgentResult( diff --git a/backend/app/services/agent/event_manager.py b/backend/app/services/agent/event_manager.py index b1ead15..fe99e6e 100644 --- a/backend/app/services/agent/event_manager.py +++ b/backend/app/services/agent/event_manager.py @@ -299,8 +299,9 @@ class EventManager: "timestamp": timestamp.isoformat(), } - # 保存到数据库 - if self.db_session_factory: + # 保存到数据库(跳过高频事件如 thinking_token) + skip_db_events = {"thinking_token", "thinking_start", "thinking_end"} + if self.db_session_factory and event_type not in skip_db_events: try: await self._save_event_to_db(event_data) except Exception as e: diff --git a/backend/app/services/agent/graph/audit_graph.py b/backend/app/services/agent/graph/audit_graph.py index 8f62338..90e0e8d 100644 --- a/backend/app/services/agent/graph/audit_graph.py +++ b/backend/app/services/agent/graph/audit_graph.py @@ -69,6 +69,11 @@ class AuditState(TypedDict): llm_next_action: Optional[str] # LLM 建议的下一步: "continue_analysis", "verify", "report", "end" llm_routing_reason: Optional[str] # LLM 的决策理由 + # 🔥 新增:Agent 间协作的任务交接信息 + recon_handoff: Optional[Dict[str, Any]] # Recon -> Analysis 的交接 + analysis_handoff: Optional[Dict[str, Any]] # Analysis -> Verification 的交接 + verification_handoff: Optional[Dict[str, Any]] # Verification -> Report 的交接 + # 消息和事件 messages: Annotated[List[Dict], operator.add] events: Annotated[List[Dict], operator.add] @@ -146,6 +151,9 @@ class LLMRouter: # 统计发现 severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0} for f in findings: + # 跳过非字典类型的 finding + if not isinstance(f, dict): + continue sev = f.get("severity", "medium") severity_counts[sev] = severity_counts.get(sev, 0) + 1 @@ -243,6 +251,11 @@ def route_after_recon(state: AuditState) -> Literal["analysis", "end"]: Recon 后的路由决策 优先使用 LLM 的决策,否则使用默认逻辑 """ + # 🔥 检查是否有错误 + if state.get("error") or state.get("current_phase") == "error": + logger.error(f"Recon phase has error, routing to end: 
{state.get('error')}") + return "end" + # 检查 LLM 是否有决策 llm_action = state.get("llm_next_action") if llm_action: diff --git a/backend/app/services/agent/graph/nodes.py b/backend/app/services/agent/graph/nodes.py index bcf5825..4b750ef 100644 --- a/backend/app/services/agent/graph/nodes.py +++ b/backend/app/services/agent/graph/nodes.py @@ -1,6 +1,8 @@ """ LangGraph 节点实现 每个节点封装一个 Agent 的执行逻辑 + +协作增强:节点之间通过 TaskHandoff 传递结构化的上下文和洞察 """ from typing import Dict, Any, List, Optional @@ -28,6 +30,14 @@ class BaseNode: await self.event_emitter.emit_info(message) except Exception as e: logger.warning(f"Failed to emit event: {e}") + + def _extract_handoff_from_state(self, state: Dict[str, Any], from_phase: str): + """从状态中提取前序 Agent 的 handoff""" + handoff_data = state.get(f"{from_phase}_handoff") + if handoff_data: + from ..agents.base import TaskHandoff + return TaskHandoff.from_dict(handoff_data) + return None class ReconNode(BaseNode): @@ -35,7 +45,7 @@ class ReconNode(BaseNode): 信息收集节点 输入: project_root, project_info, config - 输出: tech_stack, entry_points, high_risk_areas, dependencies + 输出: tech_stack, entry_points, high_risk_areas, dependencies, recon_handoff """ async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]: @@ -52,6 +62,35 @@ class ReconNode(BaseNode): if result.success and result.data: data = result.data + # 🔥 创建交接信息给 Analysis Agent + handoff = self.agent.create_handoff( + to_agent="Analysis", + summary=f"项目信息收集完成。发现 {len(data.get('entry_points', []))} 个入口点,{len(data.get('high_risk_areas', []))} 个高风险区域。", + key_findings=data.get("initial_findings", []), + suggested_actions=[ + { + "type": "deep_analysis", + "description": f"深入分析高风险区域: {', '.join(data.get('high_risk_areas', [])[:5])}", + "priority": "high", + }, + { + "type": "entry_point_audit", + "description": "审计所有入口点的输入验证", + "priority": "high", + }, + ], + attention_points=[ + f"技术栈: {data.get('tech_stack', {}).get('frameworks', [])}", + f"主要语言: {data.get('tech_stack', {}).get('languages', [])}", + ], + priority_areas=data.get("high_risk_areas", [])[:10], + context_data={ + "tech_stack": data.get("tech_stack", {}), + "entry_points": data.get("entry_points", []), + "dependencies": data.get("dependencies", {}), + }, + ) + await self.emit_event( "phase_complete", f"✅ 信息收集完成: 发现 {len(data.get('entry_points', []))} 个入口点" @@ -63,12 +102,15 @@ class ReconNode(BaseNode): "high_risk_areas": data.get("high_risk_areas", []), "dependencies": data.get("dependencies", {}), "current_phase": "recon_complete", - "findings": data.get("initial_findings", []), # 初步发现 + "findings": data.get("initial_findings", []), + # 🔥 保存交接信息 + "recon_handoff": handoff.to_dict(), "events": [{ "type": "recon_complete", "data": { "entry_points_count": len(data.get("entry_points", [])), "high_risk_areas_count": len(data.get("high_risk_areas", [])), + "handoff_summary": handoff.summary, } }], } @@ -90,8 +132,8 @@ class AnalysisNode(BaseNode): """ 漏洞分析节点 - 输入: tech_stack, entry_points, high_risk_areas, previous findings - 输出: findings (累加), should_continue_analysis + 输入: tech_stack, entry_points, high_risk_areas, recon_handoff + 输出: findings (累加), should_continue_analysis, analysis_handoff """ async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]: @@ -104,6 +146,15 @@ class AnalysisNode(BaseNode): ) try: + # 🔥 提取 Recon 的交接信息 + recon_handoff = self._extract_handoff_from_state(state, "recon") + if recon_handoff: + self.agent.receive_handoff(recon_handoff) + await self.emit_event( + "handoff_received", + f"📨 收到 Recon Agent 交接: 
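Severity tallying with the `isinstance` guard now appears inline in `LLMRouter`, `AnalysisNode`, and `ReportNode`. A shared helper would keep the copies from drifting (a sketch of the common logic):

```python
from typing import Any, Dict, List

def severity_counts(findings: List[Any]) -> Dict[str, int]:
    """Count findings per severity, skipping malformed entries."""
    counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
    for f in findings:
        if not isinstance(f, dict):
            continue
        sev = f.get("severity", "medium")
        counts[sev] = counts.get(sev, 0) + 1
    return counts
```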
{recon_handoff.summary[:50]}..." + ) + # 构建分析输入 analysis_input = { "phase_name": "analysis", @@ -121,6 +172,8 @@ class AnalysisNode(BaseNode): } } }, + # 🔥 传递交接信息 + "handoff": recon_handoff, } # 调用 Analysis Agent @@ -130,27 +183,70 @@ class AnalysisNode(BaseNode): new_findings = result.data.get("findings", []) # 判断是否需要继续分析 - # 如果这一轮发现了很多问题,可能还有更多 should_continue = ( len(new_findings) >= 5 and iteration < state.get("max_iterations", 3) ) + # 🔥 创建交接信息给 Verification Agent + # 统计严重程度 + severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0} + for f in new_findings: + if isinstance(f, dict): + sev = f.get("severity", "medium") + severity_counts[sev] = severity_counts.get(sev, 0) + 1 + + handoff = self.agent.create_handoff( + to_agent="Verification", + summary=f"漏洞分析完成。发现 {len(new_findings)} 个潜在漏洞 (Critical: {severity_counts['critical']}, High: {severity_counts['high']}, Medium: {severity_counts['medium']}, Low: {severity_counts['low']})", + key_findings=new_findings[:20], # 传递前20个发现 + suggested_actions=[ + { + "type": "verify_critical", + "description": "优先验证 Critical 和 High 级别的漏洞", + "priority": "critical", + }, + { + "type": "poc_generation", + "description": "为确认的漏洞生成 PoC", + "priority": "high", + }, + ], + attention_points=[ + f"共 {severity_counts['critical']} 个 Critical 级别漏洞需要立即验证", + f"共 {severity_counts['high']} 个 High 级别漏洞需要优先验证", + "注意检查是否有误报,特别是静态分析工具的结果", + ], + priority_areas=[ + f.get("file_path", "") for f in new_findings + if f.get("severity") in ["critical", "high"] + ][:10], + context_data={ + "severity_distribution": severity_counts, + "total_findings": len(new_findings), + "iteration": iteration, + }, + ) + await self.emit_event( "phase_complete", f"✅ 分析迭代 {iteration} 完成: 发现 {len(new_findings)} 个潜在漏洞" ) return { - "findings": new_findings, # 会自动累加 + "findings": new_findings, "iteration": iteration, "should_continue_analysis": should_continue, "current_phase": "analysis_complete", + # 🔥 保存交接信息 + "analysis_handoff": handoff.to_dict(), "events": [{ "type": "analysis_iteration", "data": { "iteration": iteration, "findings_count": len(new_findings), + "severity_distribution": severity_counts, + "handoff_summary": handoff.summary, } }], } @@ -174,8 +270,8 @@ class VerificationNode(BaseNode): """ 漏洞验证节点 - 输入: findings - 输出: verified_findings, false_positives + 输入: findings, analysis_handoff + 输出: verified_findings, false_positives, verification_handoff """ async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]: @@ -195,6 +291,15 @@ class VerificationNode(BaseNode): ) try: + # 🔥 提取 Analysis 的交接信息 + analysis_handoff = self._extract_handoff_from_state(state, "analysis") + if analysis_handoff: + self.agent.receive_handoff(analysis_handoff) + await self.emit_event( + "handoff_received", + f"📨 收到 Analysis Agent 交接: {analysis_handoff.summary[:50]}..." 
+ ) + # 构建验证输入 verification_input = { "previous_results": { @@ -205,16 +310,49 @@ class VerificationNode(BaseNode): } }, "config": state["config"], + # 🔥 传递交接信息 + "handoff": analysis_handoff, } # 调用 Verification Agent result = await self.agent.run(verification_input) if result.success and result.data: - verified = [f for f in result.data.get("findings", []) if f.get("is_verified")] - false_pos = [f["id"] for f in result.data.get("findings", []) + all_verified_findings = result.data.get("findings", []) + verified = [f for f in all_verified_findings if f.get("is_verified")] + false_pos = [f.get("id", f.get("title", "unknown")) for f in all_verified_findings if f.get("verdict") == "false_positive"] + # 🔥 创建交接信息给 Report 节点 + handoff = self.agent.create_handoff( + to_agent="Report", + summary=f"漏洞验证完成。{len(verified)} 个漏洞已确认,{len(false_pos)} 个误报已排除。", + key_findings=verified, + suggested_actions=[ + { + "type": "generate_report", + "description": "生成详细的安全审计报告", + "priority": "high", + }, + { + "type": "remediation_plan", + "description": "为确认的漏洞制定修复计划", + "priority": "high", + }, + ], + attention_points=[ + f"共 {len(verified)} 个漏洞已确认存在", + f"共 {len(false_pos)} 个误报已排除", + "建议按严重程度优先修复 Critical 和 High 级别漏洞", + ], + context_data={ + "verified_count": len(verified), + "false_positive_count": len(false_pos), + "total_analyzed": len(findings), + "verification_rate": len(verified) / len(findings) if findings else 0, + }, + ) + await self.emit_event( "phase_complete", f"✅ 验证完成: {len(verified)} 已确认, {len(false_pos)} 误报" @@ -224,11 +362,14 @@ class VerificationNode(BaseNode): "verified_findings": verified, "false_positives": false_pos, "current_phase": "verification_complete", + # 🔥 保存交接信息 + "verification_handoff": handoff.to_dict(), "events": [{ "type": "verification_complete", "data": { "verified_count": len(verified), "false_positive_count": len(false_pos), + "handoff_summary": handoff.summary, } }], } @@ -269,6 +410,11 @@ class ReportNode(BaseNode): type_counts = {} for finding in findings: + # 跳过非字典类型的 finding(防止数据格式异常) + if not isinstance(finding, dict): + logger.warning(f"Skipping invalid finding (not a dict): {type(finding)}") + continue + sev = finding.get("severity", "medium") severity_counts[sev] = severity_counts.get(sev, 0) + 1 @@ -300,7 +446,7 @@ class ReportNode(BaseNode): await self.emit_event( "phase_complete", - f"✅ 报告生成完成: 安全评分 {security_score}/100" + f"报告生成完成: 安全评分 {security_score}/100" ) return { diff --git a/backend/app/services/agent/graph/runner.py b/backend/app/services/agent/graph/runner.py index 304121d..3299c2e5 100644 --- a/backend/app/services/agent/graph/runner.py +++ b/backend/app/services/agent/graph/runner.py @@ -39,167 +39,8 @@ from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode logger = logging.getLogger(__name__) -class LLMService: - """ - LLM 服务封装 - 提供代码分析、漏洞检测等 AI 功能 - """ - - def __init__(self, model: Optional[str] = None, api_key: Optional[str] = None): - self.model = model or settings.LLM_MODEL or "gpt-4o-mini" - self.api_key = api_key or settings.LLM_API_KEY - self.base_url = settings.LLM_BASE_URL - - async def chat_completion_raw( - self, - messages: List[Dict[str, str]], - temperature: float = 0.1, - max_tokens: int = 4096, - ) -> Dict[str, Any]: - """调用 LLM 生成响应""" - try: - import litellm - - response = await litellm.acompletion( - model=self.model, - messages=messages, - temperature=temperature, - max_tokens=max_tokens, - api_key=self.api_key, - base_url=self.base_url, - ) - - return { - "content": response.choices[0].message.content, - 
"usage": { - "prompt_tokens": response.usage.prompt_tokens, - "completion_tokens": response.usage.completion_tokens, - "total_tokens": response.usage.total_tokens, - } if response.usage else {}, - } - - except Exception as e: - logger.error(f"LLM call failed: {e}") - raise - - async def analyze_code(self, code: str, language: str) -> Dict[str, Any]: - """ - 分析代码安全问题 - - Args: - code: 代码内容 - language: 编程语言 - - Returns: - 分析结果,包含 issues 列表 - """ - prompt = f"""请分析以下 {language} 代码的安全问题。 - -代码: -```{language} -{code[:8000]} -``` - -请识别所有潜在的安全漏洞,包括但不限于: -- SQL 注入 -- XSS (跨站脚本) -- 命令注入 -- 路径遍历 -- 不安全的反序列化 -- 硬编码密钥/密码 -- 不安全的加密 -- SSRF -- 认证/授权问题 - -对于每个发现的问题,请提供: -1. 漏洞类型 -2. 严重程度 (critical/high/medium/low) -3. 问题描述 -4. 具体行号 -5. 修复建议 - -请以 JSON 格式返回结果: -{{ - "issues": [ - {{ - "type": "漏洞类型", - "severity": "严重程度", - "title": "问题标题", - "description": "详细描述", - "line": 行号, - "code_snippet": "相关代码片段", - "suggestion": "修复建议" - }} - ], - "quality_score": 0-100 -}} - -如果没有发现安全问题,返回空的 issues 数组和较高的 quality_score。""" - - try: - result = await self.chat_completion_raw( - messages=[ - {"role": "system", "content": "你是一位专业的代码安全审计专家,擅长发现代码中的安全漏洞。请只返回 JSON 格式的结果,不要包含其他内容。"}, - {"role": "user", "content": prompt}, - ], - temperature=0.1, - max_tokens=4096, - ) - - content = result.get("content", "{}") - - # 尝试提取 JSON - import json - import re - - # 尝试直接解析 - try: - return json.loads(content) - except json.JSONDecodeError: - pass - - # 尝试从 markdown 代码块提取 - json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', content) - if json_match: - try: - return json.loads(json_match.group(1)) - except json.JSONDecodeError: - pass - - # 返回空结果 - return {"issues": [], "quality_score": 80} - - except Exception as e: - logger.error(f"Code analysis failed: {e}") - return {"issues": [], "quality_score": 0, "error": str(e)} - - async def analyze_code_with_custom_prompt( - self, - code: str, - language: str, - prompt: str, - **kwargs - ) -> Dict[str, Any]: - """使用自定义提示词分析代码""" - full_prompt = prompt.replace("{code}", code).replace("{language}", language) - - try: - result = await self.chat_completion_raw( - messages=[ - {"role": "system", "content": "你是一位专业的代码安全审计专家。"}, - {"role": "user", "content": full_prompt}, - ], - temperature=0.1, - ) - - return { - "analysis": result.get("content", ""), - "usage": result.get("usage", {}), - } - - except Exception as e: - logger.error(f"Custom analysis failed: {e}") - return {"analysis": "", "error": str(e)} +# 🔥 使用系统统一的 LLMService(支持用户配置) +from app.services.llm.service import LLMService class AgentRunner: @@ -217,18 +58,22 @@ class AgentRunner: db: AsyncSession, task: AgentTask, project_root: str, + user_config: Optional[Dict[str, Any]] = None, ): self.db = db self.task = task self.project_root = project_root + # 🔥 保存用户配置,供 RAG 初始化使用 + self.user_config = user_config or {} + # 事件管理 - 传入 db_session_factory 以持久化事件 from app.db.session import async_session_factory self.event_manager = EventManager(db_session_factory=async_session_factory) self.event_emitter = AgentEventEmitter(task.id, self.event_manager) - # LLM 服务 - self.llm_service = LLMService() + # 🔥 LLM 服务 - 使用用户配置(从系统配置页面获取) + self.llm_service = LLMService(user_config=self.user_config) # 工具集 self.tools: Dict[str, Any] = {} @@ -248,14 +93,26 @@ class AgentRunner: self._cancelled = False self._running_task: Optional[asyncio.Task] = None + # Agent 引用(用于取消传播) + self._agents: List[Any] = [] + # 流式处理器 self.stream_handler = StreamHandler(task.id) def cancel(self): """取消任务""" self._cancelled = True + + # 🔥 取消所有 Agent + for agent in 
self._agents: + if hasattr(agent, 'cancel'): + agent.cancel() + logger.debug(f"Cancelled agent: {agent.name if hasattr(agent, 'name') else 'unknown'}") + + # 取消运行中的任务 if self._running_task and not self._running_task.done(): self._running_task.cancel() + logger.info(f"Task {self.task.id} cancellation requested") @property @@ -283,11 +140,33 @@ class AgentRunner: await self.event_emitter.emit_info("📚 初始化 RAG 代码检索系统...") try: + # 🔥 从用户配置中获取 LLM 配置(用于 Embedding API Key) + # 优先级:用户配置 > 环境变量 + user_llm_config = self.user_config.get('llmConfig', {}) + + # 获取 Embedding 配置(优先使用用户配置的 LLM API Key) + embedding_provider = getattr(settings, 'EMBEDDING_PROVIDER', 'openai') + embedding_model = getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small') + + # 🔥 API Key 优先级:用户配置 > 环境变量 + embedding_api_key = ( + user_llm_config.get('llmApiKey') or + getattr(settings, 'LLM_API_KEY', '') or + '' + ) + + # 🔥 Base URL 优先级:用户配置 > 环境变量 + embedding_base_url = ( + user_llm_config.get('llmBaseUrl') or + getattr(settings, 'LLM_BASE_URL', None) or + None + ) + embedding_service = EmbeddingService( - provider=settings.EMBEDDING_PROVIDER, - model=settings.EMBEDDING_MODEL, - api_key=settings.LLM_API_KEY, - base_url=settings.LLM_BASE_URL, + provider=embedding_provider, + model=embedding_model, + api_key=embedding_api_key, + base_url=embedding_base_url, ) self.indexer = CodeIndexer( @@ -308,35 +187,59 @@ class AgentRunner: async def _initialize_tools(self): """初始化工具集""" - await self.event_emitter.emit_info("🔧 初始化 Agent 工具集...") + await self.event_emitter.emit_info("初始化 Agent 工具集...") - # 文件工具 - self.tools["read_file"] = FileReadTool(self.project_root) - self.tools["search_code"] = FileSearchTool(self.project_root) - self.tools["list_files"] = ListFilesTool(self.project_root) + # ============ 基础工具(所有 Agent 共享)============ + base_tools = { + "read_file": FileReadTool(self.project_root), + "list_files": ListFilesTool(self.project_root), + } - # RAG 工具 + # ============ Recon Agent 专属工具 ============ + # 职责:信息收集、项目结构分析、技术栈识别 + self.recon_tools = { + **base_tools, + "search_code": FileSearchTool(self.project_root), + } + + # RAG 工具(Recon 用于语义搜索) if self.retriever: - self.tools["rag_query"] = RAGQueryTool(self.retriever) - self.tools["security_search"] = SecurityCodeSearchTool(self.retriever) - self.tools["function_context"] = FunctionContextTool(self.retriever) + self.recon_tools["rag_query"] = RAGQueryTool(self.retriever) - # 分析工具 - self.tools["pattern_match"] = PatternMatchTool(self.project_root) - self.tools["code_analysis"] = CodeAnalysisTool(self.llm_service) - self.tools["dataflow_analysis"] = DataFlowAnalysisTool(self.llm_service) - self.tools["vulnerability_validation"] = VulnerabilityValidationTool(self.llm_service) + # ============ Analysis Agent 专属工具 ============ + # 职责:漏洞分析、代码审计、模式匹配 + self.analysis_tools = { + **base_tools, + "search_code": FileSearchTool(self.project_root), + # 模式匹配和代码分析 + "pattern_match": PatternMatchTool(self.project_root), + "code_analysis": CodeAnalysisTool(self.llm_service), + "dataflow_analysis": DataFlowAnalysisTool(self.llm_service), + # 外部静态分析工具 + "semgrep_scan": SemgrepTool(self.project_root), + "bandit_scan": BanditTool(self.project_root), + "gitleaks_scan": GitleaksTool(self.project_root), + "trufflehog_scan": TruffleHogTool(self.project_root), + "npm_audit": NpmAuditTool(self.project_root), + "safety_scan": SafetyTool(self.project_root), + "osv_scan": OSVScannerTool(self.project_root), + } - # 外部安全工具 - self.tools["semgrep_scan"] = SemgrepTool(self.project_root) - 
self.tools["bandit_scan"] = BanditTool(self.project_root) - self.tools["gitleaks_scan"] = GitleaksTool(self.project_root) - self.tools["trufflehog_scan"] = TruffleHogTool(self.project_root) - self.tools["npm_audit"] = NpmAuditTool(self.project_root) - self.tools["safety_scan"] = SafetyTool(self.project_root) - self.tools["osv_scan"] = OSVScannerTool(self.project_root) + # RAG 工具(Analysis 用于安全相关代码搜索) + if self.retriever: + self.analysis_tools["security_search"] = SecurityCodeSearchTool(self.retriever) + self.analysis_tools["function_context"] = FunctionContextTool(self.retriever) - # 沙箱工具 + # ============ Verification Agent 专属工具 ============ + # 职责:漏洞验证、PoC 执行、误报排除 + self.verification_tools = { + **base_tools, + # 验证工具 + "vulnerability_validation": VulnerabilityValidationTool(self.llm_service), + "dataflow_analysis": DataFlowAnalysisTool(self.llm_service), + } + + # 沙箱工具(仅 Verification Agent 可用) try: self.sandbox_manager = SandboxManager( image=settings.SANDBOX_IMAGE, @@ -344,14 +247,20 @@ class AgentRunner: cpu_limit=settings.SANDBOX_CPU_LIMIT, ) - self.tools["sandbox_exec"] = SandboxTool(self.sandbox_manager) - self.tools["sandbox_http"] = SandboxHttpTool(self.sandbox_manager) - self.tools["verify_vulnerability"] = VulnerabilityVerifyTool(self.sandbox_manager) + self.verification_tools["sandbox_exec"] = SandboxTool(self.sandbox_manager) + self.verification_tools["sandbox_http"] = SandboxHttpTool(self.sandbox_manager) + self.verification_tools["verify_vulnerability"] = VulnerabilityVerifyTool(self.sandbox_manager) except Exception as e: logger.warning(f"Sandbox initialization failed: {e}") - await self.event_emitter.emit_info(f"✅ 已加载 {len(self.tools)} 个工具") + # 统计总工具数 + total_tools = len(set( + list(self.recon_tools.keys()) + + list(self.analysis_tools.keys()) + + list(self.verification_tools.keys()) + )) + await self.event_emitter.emit_info(f"已加载 {total_tools} 个工具") async def _build_graph(self): """构建 LangGraph 审计图""" @@ -360,25 +269,28 @@ class AgentRunner: # 导入 Agent from app.services.agent.agents import ReconAgent, AnalysisAgent, VerificationAgent - # 创建 Agent 实例 + # 创建 Agent 实例(每个 Agent 使用专属工具集) recon_agent = ReconAgent( llm_service=self.llm_service, - tools=self.tools, + tools=self.recon_tools, # Recon 专属工具 event_emitter=self.event_emitter, ) analysis_agent = AnalysisAgent( llm_service=self.llm_service, - tools=self.tools, + tools=self.analysis_tools, # Analysis 专属工具 event_emitter=self.event_emitter, ) verification_agent = VerificationAgent( llm_service=self.llm_service, - tools=self.tools, + tools=self.verification_tools, # Verification 专属工具 event_emitter=self.event_emitter, ) + # 🔥 保存 Agent 引用以便取消时传播信号 + self._agents = [recon_agent, analysis_agent, verification_agent] + # 创建节点 recon_node = ReconNode(recon_agent, self.event_emitter) analysis_node = AnalysisNode(analysis_agent, self.event_emitter) @@ -481,6 +393,10 @@ class AgentRunner: "iteration": 0, "max_iterations": self.task.max_iterations or 50, "should_continue_analysis": False, + # 🔥 Agent 协作交接信息 + "recon_handoff": None, + "analysis_handoff": None, + "verification_handoff": None, "messages": [], "events": [], "summary": None, @@ -556,6 +472,33 @@ class AgentRunner: graph_state = self.graph.get_state(run_config) final_state = graph_state.values if graph_state else {} + # 🔥 检查是否有错误 + error = final_state.get("error") + if error: + # 检查是否是 LLM 认证错误 + error_str = str(error) + if "AuthenticationError" in error_str or "API key" in error_str or "invalid_api_key" in error_str: + error_message = "LLM API 密钥配置错误。请检查环境变量 LLM_API_KEY 
或配置中的 API 密钥是否正确。" + logger.error(f"LLM authentication error: {error}") + else: + error_message = error_str + + duration_ms = int((time.time() - start_time) * 1000) + + # 标记任务为失败 + await self._update_task_status(AgentTaskStatus.FAILED, error_message) + await self.event_emitter.emit_task_error(error_message) + + yield StreamEvent( + event_type=StreamEventType.TASK_ERROR, + sequence=self.stream_handler._next_sequence(), + data={ + "error": error_message, + "message": f"❌ 任务失败: {error_message}", + }, + ) + return + # 6. 保存发现 findings = final_state.get("findings", []) await self._save_findings(findings) diff --git a/backend/app/services/agent/json_parser.py b/backend/app/services/agent/json_parser.py new file mode 100644 index 0000000..ec9c2cd --- /dev/null +++ b/backend/app/services/agent/json_parser.py @@ -0,0 +1,251 @@ +""" +Agent JSON 解析工具 +从 LLM 响应中安全地解析 JSON,参考 llm/service.py 的实现 +""" + +import json +import re +import logging +from typing import Dict, Any, List, Optional, Union + +logger = logging.getLogger(__name__) + +# 尝试导入 json-repair 库 +try: + from json_repair import repair_json + JSON_REPAIR_AVAILABLE = True +except ImportError: + JSON_REPAIR_AVAILABLE = False + logger.debug("json-repair library not available") + + +class AgentJsonParser: + """Agent 专用的 JSON 解析器""" + + @staticmethod + def clean_text(text: str) -> str: + """清理文本中的控制字符""" + if not text: + return "" + # 移除 BOM 和零宽字符 + text = text.replace('\ufeff', '').replace('\u200b', '').replace('\u200c', '').replace('\u200d', '') + return text + + @staticmethod + def fix_json_format(text: str) -> str: + """修复常见的 JSON 格式问题""" + text = text.strip() + # 移除尾部逗号 + text = re.sub(r',(\s*[}\]])', r'\1', text) + # 修复未转义的换行符(在字符串值中) + text = re.sub(r':\s*"([^"]*)\n([^"]*)"', r': "\1\\n\2"', text) + return text + + @classmethod + def extract_from_markdown(cls, text: str) -> Dict[str, Any]: + """从 markdown 代码块提取 JSON""" + match = re.search(r'```(?:json)?\s*(\{[\s\S]*?\})\s*```', text) + if match: + return json.loads(match.group(1)) + raise ValueError("No markdown code block found") + + @classmethod + def extract_json_object(cls, text: str) -> Dict[str, Any]: + """智能提取 JSON 对象""" + start_idx = text.find('{') + if start_idx == -1: + raise ValueError("No JSON object found") + + # 考虑字符串内的花括号和转义字符 + brace_count = 0 + in_string = False + escape_next = False + end_idx = -1 + + for i in range(start_idx, len(text)): + char = text[i] + + if escape_next: + escape_next = False + continue + + if char == '\\': + escape_next = True + continue + + if char == '"' and not escape_next: + in_string = not in_string + continue + + if not in_string: + if char == '{': + brace_count += 1 + elif char == '}': + brace_count -= 1 + if brace_count == 0: + end_idx = i + 1 + break + + if end_idx == -1: + # 如果找不到完整的 JSON,尝试使用最后一个 } + last_brace = text.rfind('}') + if last_brace > start_idx: + end_idx = last_brace + 1 + else: + raise ValueError("Incomplete JSON object") + + json_str = text[start_idx:end_idx] + # 修复格式问题 + json_str = re.sub(r',(\s*[}\]])', r'\1', json_str) + + return json.loads(json_str) + + @classmethod + def fix_truncated_json(cls, text: str) -> Dict[str, Any]: + """修复截断的 JSON""" + start_idx = text.find('{') + if start_idx == -1: + raise ValueError("Cannot fix truncated JSON") + + json_str = text[start_idx:] + + # 计算缺失的闭合符号 + open_braces = json_str.count('{') + close_braces = json_str.count('}') + open_brackets = json_str.count('[') + close_brackets = json_str.count(']') + + # 补全缺失的闭合符号 + json_str += ']' * max(0, open_brackets - close_brackets) + json_str += 
'}' * max(0, open_braces - close_braces) + + # 修复格式 + json_str = re.sub(r',(\s*[}\]])', r'\1', json_str) + return json.loads(json_str) + + @classmethod + def repair_with_library(cls, text: str) -> Dict[str, Any]: + """使用 json-repair 库修复损坏的 JSON""" + if not JSON_REPAIR_AVAILABLE: + raise ValueError("json-repair library not available") + + start_idx = text.find('{') + if start_idx == -1: + raise ValueError("No JSON object found for repair") + + end_idx = text.rfind('}') + if end_idx > start_idx: + json_str = text[start_idx:end_idx + 1] + else: + json_str = text[start_idx:] + + repaired = repair_json(json_str, return_objects=True) + + if isinstance(repaired, dict): + return repaired + + raise ValueError(f"json-repair returned unexpected type: {type(repaired)}") + + @classmethod + def parse(cls, text: str, default: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """ + 从 LLM 响应中解析 JSON(增强版) + + Args: + text: LLM 响应文本 + default: 解析失败时返回的默认值,如果为 None 则抛出异常 + + Returns: + 解析后的字典 + """ + if not text or not text.strip(): + if default is not None: + logger.warning("LLM 响应为空,返回默认值") + return default + raise ValueError("LLM 响应内容为空") + + clean = cls.clean_text(text) + + # 尝试多种方式解析 + attempts = [ + ("直接解析", lambda: json.loads(text)), + ("清理后解析", lambda: json.loads(cls.fix_json_format(clean))), + ("Markdown 提取", lambda: cls.extract_from_markdown(text)), + ("智能提取", lambda: cls.extract_json_object(clean)), + ("截断修复", lambda: cls.fix_truncated_json(clean)), + ("json-repair", lambda: cls.repair_with_library(text)), + ] + + last_error = None + for name, attempt in attempts: + try: + result = attempt() + if result and isinstance(result, dict): + if name != "直接解析": + logger.debug(f"✅ JSON 解析成功(方法: {name})") + return result + except Exception as e: + last_error = e + logger.debug(f"JSON 解析方法 '{name}' 失败: {e}") + + # 所有尝试都失败 + if default is not None: + logger.warning(f"JSON 解析失败,返回默认值。原始内容: {text[:200]}...") + return default + + logger.error(f"❌ 无法解析 JSON,原始内容: {text[:500]}...") + raise ValueError(f"无法解析 JSON: {last_error}") + + @classmethod + def parse_findings(cls, text: str) -> List[Dict[str, Any]]: + """ + 专门解析 findings 列表 + + Args: + text: LLM 响应文本 + + Returns: + findings 列表(每个元素都是字典) + """ + try: + result = cls.parse(text, default={"findings": []}) + findings = result.get("findings", []) + + # 确保每个 finding 都是字典 + valid_findings = [] + for f in findings: + if isinstance(f, dict): + valid_findings.append(f) + elif isinstance(f, str): + # 尝试将字符串解析为 JSON + try: + parsed = json.loads(f) + if isinstance(parsed, dict): + valid_findings.append(parsed) + except json.JSONDecodeError: + logger.warning(f"跳过无效的 finding(字符串): {f[:100]}...") + else: + logger.warning(f"跳过无效的 finding(类型: {type(f)})") + + return valid_findings + + except Exception as e: + logger.error(f"解析 findings 失败: {e}") + return [] + + @classmethod + def safe_get(cls, data: Union[Dict, str, Any], key: str, default: Any = None) -> Any: + """ + 安全地从数据中获取值 + + Args: + data: 可能是字典或其他类型 + key: 要获取的键 + default: 默认值 + + Returns: + 获取的值或默认值 + """ + if isinstance(data, dict): + return data.get(key, default) + return default diff --git a/backend/app/services/agent/tools/code_analysis_tool.py b/backend/app/services/agent/tools/code_analysis_tool.py index c97e51e..6d50391 100644 --- a/backend/app/services/agent/tools/code_analysis_tool.py +++ b/backend/app/services/agent/tools/code_analysis_tool.py @@ -79,9 +79,25 @@ class CodeAnalysisTool(AgentTool): **kwargs ) -> ToolResult: """执行代码分析""" + import asyncio + try: - # 构建分析结果 - analysis = await 
self.llm_service.analyze_code(code, language) + # 限制代码长度,避免超时 + max_code_length = 50000 # 约 50KB + if len(code) > max_code_length: + code = code[:max_code_length] + "\n\n... (代码已截断,仅分析前 50000 字符)" + + # 添加超时保护(5分钟) + try: + analysis = await asyncio.wait_for( + self.llm_service.analyze_code(code, language), + timeout=300.0 # 5分钟超时 + ) + except asyncio.TimeoutError: + return ToolResult( + success=False, + error="代码分析超时(超过5分钟)。代码可能过长或过于复杂,请尝试分析较小的代码片段。", + ) issues = analysis.get("issues", []) diff --git a/backend/app/services/agent/tools/external_tools.py b/backend/app/services/agent/tools/external_tools.py index 8fb3caf..e863ef7 100644 --- a/backend/app/services/agent/tools/external_tools.py +++ b/backend/app/services/agent/tools/external_tools.py @@ -109,10 +109,14 @@ Semgrep 是业界领先的静态分析工具,支持 30+ 种编程语言。 """执行 Semgrep 扫描""" # 检查 semgrep 是否可用 if not await self._check_semgrep(): - return ToolResult( - success=False, - error="Semgrep 未安装。请使用 'pip install semgrep' 安装。", - ) + # 尝试自动安装 + logger.info("Semgrep 未安装,尝试自动安装...") + install_success = await self._try_install_semgrep() + if not install_success: + return ToolResult( + success=False, + error="Semgrep 未安装。请使用 'pip install semgrep' 安装,或联系管理员安装。", + ) # 构建完整路径 full_path = os.path.normpath(os.path.join(self.project_root, target_path)) @@ -216,6 +220,30 @@ Semgrep 是业界领先的静态分析工具,支持 30+ 种编程语言。 return proc.returncode == 0 except: return False + + async def _try_install_semgrep(self) -> bool: + """尝试自动安装 Semgrep""" + try: + logger.info("正在安装 Semgrep...") + proc = await asyncio.create_subprocess_exec( + "pip", "install", "semgrep", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=120) + if proc.returncode == 0: + logger.info("Semgrep 安装成功") + # 验证安装 + return await self._check_semgrep() + else: + logger.warning(f"Semgrep 安装失败: {stderr.decode()[:200]}") + return False + except asyncio.TimeoutError: + logger.warning("Semgrep 安装超时") + return False + except Exception as e: + logger.warning(f"Semgrep 安装出错: {e}") + return False # ============ Bandit 工具 (Python) ============ @@ -422,7 +450,11 @@ Gitleaks 是专业的密钥检测工具,支持 150+ 种密钥类型。 if not await self._check_gitleaks(): return ToolResult( success=False, - error="Gitleaks 未安装。请从 https://github.com/gitleaks/gitleaks 安装。", + error="Gitleaks 未安装。Gitleaks 需要手动安装,请参考: https://github.com/gitleaks/gitleaks/releases\n" + "安装方法:\n" + "- macOS: brew install gitleaks\n" + "- Linux: 下载二进制文件并添加到 PATH\n" + "- Windows: 下载二进制文件并添加到 PATH", ) full_path = os.path.normpath(os.path.join(self.project_root, target_path)) diff --git a/backend/app/services/agent/tools/pattern_tool.py b/backend/app/services/agent/tools/pattern_tool.py index 09a7da8..37e22ed 100644 --- a/backend/app/services/agent/tools/pattern_tool.py +++ b/backend/app/services/agent/tools/pattern_tool.py @@ -291,8 +291,19 @@ class PatternMatchTool(AgentTool): return f"""快速扫描代码中的危险模式和常见漏洞。 使用正则表达式检测已知的不安全代码模式。 +⚠️ 重要:此工具需要代码内容作为输入,不是目录路径! +使用步骤: +1. 先用 read_file 工具读取文件内容 +2. 
然后将读取的代码内容传递给此工具的 code 参数 + 支持的漏洞类型: {vuln_types} +输入参数: +- code (必需): 要扫描的代码内容(字符串) +- file_path (可选): 文件路径,用于上下文 +- pattern_types (可选): 要检测的漏洞类型列表,如 ['sql_injection', 'xss'] +- language (可选): 编程语言,如 'python', 'php', 'javascript' + 这是一个快速扫描工具,可以在分析开始时使用来快速发现潜在问题。 发现的问题需要进一步分析确认。""" diff --git a/backend/app/services/agent/tools/rag_tool.py b/backend/app/services/agent/tools/rag_tool.py index 7c8b9c6..527b95d 100644 --- a/backend/app/services/agent/tools/rag_tool.py +++ b/backend/app/services/agent/tools/rag_tool.py @@ -189,10 +189,27 @@ class SecurityCodeSearchTool(AgentTool): ) except Exception as e: - return ToolResult( - success=False, - error=f"安全代码搜索失败: {str(e)}", - ) + error_msg = str(e) + # 提供更友好的错误信息 + if "401" in error_msg or "Unauthorized" in error_msg: + return ToolResult( + success=False, + error=f"安全代码搜索失败: API 认证失败(401 Unauthorized)。\n" + f"请检查系统配置中的 LLM API Key 是否正确设置。\n" + f"错误详情: {error_msg[:200]}", + ) + elif "403" in error_msg or "Forbidden" in error_msg: + return ToolResult( + success=False, + error=f"安全代码搜索失败: API 访问被拒绝(403 Forbidden)。\n" + f"请检查 API Key 是否有足够的权限。\n" + f"错误详情: {error_msg[:200]}", + ) + else: + return ToolResult( + success=False, + error=f"安全代码搜索失败: {error_msg[:500]}", + ) class FunctionContextInput(BaseModel): diff --git a/backend/app/services/llm/adapters/litellm_adapter.py b/backend/app/services/llm/adapters/litellm_adapter.py index 81a9d90..ea104e6 100644 --- a/backend/app/services/llm/adapters/litellm_adapter.py +++ b/backend/app/services/llm/adapters/litellm_adapter.py @@ -177,6 +177,85 @@ class LiteLLMAdapter(BaseLLMAdapter): finish_reason=choice.finish_reason, ) + async def stream_complete(self, request: LLMRequest): + """ + 流式调用 LLM,逐 token 返回 + + Yields: + dict: {"type": "token", "content": str} 或 {"type": "done", "content": str, "usage": dict} + """ + import litellm + + await self.validate_config() + + litellm.cache = None + litellm.drop_params = True + + messages = [{"role": msg.role, "content": msg.content} for msg in request.messages] + + kwargs = { + "model": self._litellm_model, + "messages": messages, + "temperature": request.temperature if request.temperature is not None else self.config.temperature, + "max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens, + "top_p": request.top_p if request.top_p is not None else self.config.top_p, + "stream": True, # 启用流式输出 + } + + if self.config.api_key and self.config.api_key != "ollama": + kwargs["api_key"] = self.config.api_key + + if self._api_base: + kwargs["api_base"] = self._api_base + + kwargs["timeout"] = self.config.timeout + + accumulated_content = "" + + try: + response = await litellm.acompletion(**kwargs) + + async for chunk in response: + if not chunk.choices: + continue + + delta = chunk.choices[0].delta + content = getattr(delta, "content", "") or "" + finish_reason = chunk.choices[0].finish_reason + + if content: + accumulated_content += content + yield { + "type": "token", + "content": content, + "accumulated": accumulated_content, + } + + if finish_reason: + # 流式完成 + usage = None + if hasattr(chunk, "usage") and chunk.usage: + usage = { + "prompt_tokens": chunk.usage.prompt_tokens or 0, + "completion_tokens": chunk.usage.completion_tokens or 0, + "total_tokens": chunk.usage.total_tokens or 0, + } + + yield { + "type": "done", + "content": accumulated_content, + "usage": usage, + "finish_reason": finish_reason, + } + break + + except Exception as e: + yield { + "type": "error", + "error": str(e), + "accumulated": accumulated_content, + } + 
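The chunk protocol above is easiest to read from the consumer side. A minimal sketch, assuming an adapter and request constructed as in the surrounding code; `consume_stream` is an illustrative name, and only the three chunk "type" values (token / done / error) are taken from the implementation above:

    async def consume_stream(adapter, request):
        """Drains the token/done/error chunk protocol yielded by stream_complete."""
        async for chunk in adapter.stream_complete(request):
            if chunk["type"] == "token":
                # Incremental delta; chunk["accumulated"] holds the full text so far
                print(chunk["content"], end="", flush=True)
            elif chunk["type"] == "done":
                # Final chunk carries the complete text and, when reported, usage stats
                return chunk["content"], chunk.get("usage")
            elif chunk["type"] == "error":
                # Whatever was generated before the failure survives in chunk["accumulated"]
                raise RuntimeError(chunk["error"])

`chat_completion_stream` in llm/service.py re-yields these chunks unchanged, so a single handler like this covers both entry points.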
async def validate_config(self) -> bool: """验证配置""" # Ollama 不需要 API Key diff --git a/backend/app/services/llm/service.py b/backend/app/services/llm/service.py index 610eaef..dee4eec 100644 --- a/backend/app/services/llm/service.py +++ b/backend/app/services/llm/service.py @@ -6,7 +6,7 @@ LLM服务 - 代码分析核心服务 import json import re import logging -from typing import Dict, Any, Optional +from typing import Dict, Any, Optional, List from .types import LLMConfig, LLMProvider, LLMMessage, LLMRequest, DEFAULT_MODELS from .factory import LLMFactory from app.core.config import settings @@ -36,15 +36,23 @@ class LLMService: @property def config(self) -> LLMConfig: - """获取LLM配置(优先使用用户配置,然后使用系统配置)""" + """ + 获取LLM配置 + + 🔥 优先级(从高到低): + 1. 数据库用户配置(系统配置页面保存的配置) + 2. 环境变量配置(.env 文件中的配置) + + 如果用户配置中某个字段为空,则自动回退到环境变量。 + """ if self._config is None: user_llm_config = self._user_config.get('llmConfig', {}) - # 优先使用用户配置的provider,否则使用系统配置 + # 🔥 Provider 优先级:用户配置 > 环境变量 provider_str = user_llm_config.get('llmProvider') or getattr(settings, 'LLM_PROVIDER', 'openai') provider = self._parse_provider(provider_str) - # 获取API Key - 优先级:用户配置 > 系统通用配置 > 系统平台专属配置 + # 🔥 API Key 优先级:用户配置 > 环境变量通用配置 > 环境变量平台专属配置 api_key = ( user_llm_config.get('llmApiKey') or getattr(settings, 'LLM_API_KEY', '') or @@ -52,33 +60,33 @@ class LLMService: self._get_provider_api_key(provider) ) - # 获取Base URL + # 🔥 Base URL 优先级:用户配置 > 环境变量 base_url = ( user_llm_config.get('llmBaseUrl') or getattr(settings, 'LLM_BASE_URL', None) or self._get_provider_base_url(provider) ) - # 获取模型 + # 🔥 Model 优先级:用户配置 > 环境变量 > 默认模型 model = ( user_llm_config.get('llmModel') or getattr(settings, 'LLM_MODEL', '') or DEFAULT_MODELS.get(provider, 'gpt-4o-mini') ) - # 获取超时时间(用户配置是毫秒,系统配置是秒) + # 🔥 Timeout 优先级:用户配置(毫秒) > 环境变量(秒) timeout_ms = user_llm_config.get('llmTimeout') if timeout_ms: # 用户配置是毫秒,转换为秒 timeout = int(timeout_ms / 1000) if timeout_ms > 1000 else int(timeout_ms) else: - # 系统配置是秒 + # 环境变量是秒 timeout = int(getattr(settings, 'LLM_TIMEOUT', 150)) - # 获取温度 + # 🔥 Temperature 优先级:用户配置 > 环境变量 temperature = user_llm_config.get('llmTemperature') if user_llm_config.get('llmTemperature') is not None else float(getattr(settings, 'LLM_TEMPERATURE', 0.1)) - # 获取最大token数 + # 🔥 Max Tokens 优先级:用户配置 > 环境变量 max_tokens = user_llm_config.get('llmMaxTokens') or int(getattr(settings, 'LLM_MAX_TOKENS', 4096)) self._config = LLMConfig( @@ -394,6 +402,83 @@ Please analyze the following code: # 重新抛出异常,让调用者处理 raise + async def chat_completion_raw( + self, + messages: List[Dict[str, str]], + temperature: float = 0.1, + max_tokens: int = 4096, + ) -> Dict[str, Any]: + """ + 🔥 Agent 使用的原始聊天完成接口(兼容旧接口) + + Args: + messages: 消息列表,格式为 [{"role": "user", "content": "..."}] + temperature: 温度参数 + max_tokens: 最大token数 + + Returns: + 包含 content 和 usage 的字典 + """ + # 转换消息格式 + llm_messages = [ + LLMMessage(role=msg["role"], content=msg["content"]) + for msg in messages + ] + + request = LLMRequest( + messages=llm_messages, + temperature=temperature, + max_tokens=max_tokens, + ) + + adapter = LLMFactory.create_adapter(self.config) + response = await adapter.complete(request) + + return { + "content": response.content, + "usage": { + "prompt_tokens": response.usage.prompt_tokens if response.usage else 0, + "completion_tokens": response.usage.completion_tokens if response.usage else 0, + "total_tokens": response.usage.total_tokens if response.usage else 0, + }, + } + + async def chat_completion_stream( + self, + messages: List[Dict[str, str]], + temperature: float = 0.1, + max_tokens: int = 4096, + ): + 
""" + 流式聊天完成接口,逐 token 返回 + + Args: + messages: 消息列表 + temperature: 温度参数 + max_tokens: 最大token数 + + Yields: + dict: {"type": "token", "content": str} 或 {"type": "done", ...} + """ + from .adapters.litellm_adapter import LiteLLMAdapter + + llm_messages = [ + LLMMessage(role=msg["role"], content=msg["content"]) + for msg in messages + ] + + request = LLMRequest( + messages=llm_messages, + temperature=temperature, + max_tokens=max_tokens, + ) + + # 使用 LiteLLM adapter 进行流式调用 + adapter = LiteLLMAdapter(self.config) + + async for chunk in adapter.stream_complete(request): + yield chunk + def _parse_json(self, text: str) -> Dict[str, Any]: """从LLM响应中解析JSON(增强版)""" diff --git a/backend/data/vector_db/ef6dc788-cc23-4a4d-b1a9-5ce4b32248b8/data_level0.bin b/backend/data/vector_db/ef6dc788-cc23-4a4d-b1a9-5ce4b32248b8/data_level0.bin new file mode 100644 index 0000000..c95e62e Binary files /dev/null and b/backend/data/vector_db/ef6dc788-cc23-4a4d-b1a9-5ce4b32248b8/data_level0.bin differ diff --git a/backend/data/vector_db/ef6dc788-cc23-4a4d-b1a9-5ce4b32248b8/header.bin b/backend/data/vector_db/ef6dc788-cc23-4a4d-b1a9-5ce4b32248b8/header.bin new file mode 100644 index 0000000..2349a18 Binary files /dev/null and b/backend/data/vector_db/ef6dc788-cc23-4a4d-b1a9-5ce4b32248b8/header.bin differ diff --git a/backend/data/vector_db/ef6dc788-cc23-4a4d-b1a9-5ce4b32248b8/length.bin b/backend/data/vector_db/ef6dc788-cc23-4a4d-b1a9-5ce4b32248b8/length.bin new file mode 100644 index 0000000..cb3e162 Binary files /dev/null and b/backend/data/vector_db/ef6dc788-cc23-4a4d-b1a9-5ce4b32248b8/length.bin differ diff --git a/backend/data/vector_db/ef6dc788-cc23-4a4d-b1a9-5ce4b32248b8/link_lists.bin b/backend/data/vector_db/ef6dc788-cc23-4a4d-b1a9-5ce4b32248b8/link_lists.bin new file mode 100644 index 0000000..e69de29 diff --git a/frontend/src/pages/AgentAudit.tsx b/frontend/src/pages/AgentAudit.tsx index aa911f2..94450c9 100644 --- a/frontend/src/pages/AgentAudit.tsx +++ b/frontend/src/pages/AgentAudit.tsx @@ -28,14 +28,12 @@ import { cancelAgentTask, } from "@/shared/api/agentTasks"; -// 事件类型图标映射 - 🔥 重点展示 LLM 相关事件 +// 事件类型图标映射 const eventTypeIcons: Record = { - // 🧠 LLM 核心事件 - 最重要! + // LLM 核心事件 llm_start: , llm_thought: , llm_decision: , - llm_action: , - llm_observation: , llm_complete: , // 阶段相关 @@ -43,7 +41,7 @@ const eventTypeIcons: Record = { phase_complete: , thinking: , - // 工具相关 - LLM 决定的工具调用 + // 工具相关 tool_call: , tool_result: , tool_error: , @@ -65,14 +63,12 @@ const eventTypeIcons: Record = { task_cancel: , }; -// 事件类型颜色映射 - 🔥 LLM 事件突出显示 +// 事件类型颜色映射 const eventTypeColors: Record = { - // 🧠 LLM 核心事件 - 使用紫色系突出 + // LLM 核心事件 llm_start: "text-purple-400 font-semibold", - llm_thought: "text-purple-300 bg-purple-950/30 rounded px-1", // 思考内容特别高亮 - llm_decision: "text-yellow-300 font-semibold", // 决策特别突出 - llm_action: "text-orange-300 font-medium", - llm_observation: "text-blue-300", + llm_thought: "text-purple-300 bg-purple-950/30 rounded px-1", + llm_decision: "text-yellow-300 font-semibold", llm_complete: "text-green-400 font-semibold", // 阶段相关 @@ -411,7 +407,7 @@ export default function AgentAuditPage() { {/* 左侧:执行日志 */}
- {/* 🧠 LLM 思考过程展示区域 - 核心!展示 LLM 的大脑活动 */} + {/* LLM 思考过程展示区域 */} {(isThinking || thinking) && showThinking && (
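 {/* Shown only when thinking output exists (in progress or already captured) and the showThinking toggle is on */}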
- 🧠 LLM Thinking - Agent 的大脑正在工作 + LLM Thinking + Agent 思考过程
{isThinking && ( @@ -438,7 +434,7 @@ export default function AgentAuditPage() {
- {thinking || "🤔 正在思考下一步..."} + {thinking || "正在思考下一步..."} {isThinking && }
@@ -446,7 +442,7 @@ export default function AgentAuditPage() {
)} - {/* 🔧 LLM 工具调用展示区域 - LLM 决定调用的工具 */} + {/* 工具调用展示区域 */} {toolCalls.length > 0 && showToolDetails && (
- 🔧 LLM Tool Calls - LLM 决定调用的工具 + Tool Calls + 工具调用记录
{toolCalls.length} 次调用 @@ -488,9 +484,9 @@ export default function AgentAuditPage() {
- LLM Execution Log + Execution Log
- LLM 思考 & 工具调用记录 + 执行日志 {(isStreaming || isStreamConnected) && ( @@ -708,7 +704,7 @@ function StatusBadge({ status }: { status: string }) { ); } -// 事件行组件 - 增强 LLM 事件展示 +// 事件行组件 function EventLine({ event }: { event: AgentEvent }) { const icon = eventTypeIcons[event.event_type] || ; const colorClass = eventTypeColors[event.event_type] || "text-gray-400"; @@ -717,19 +713,19 @@ function EventLine({ event }: { event: AgentEvent }) { ? new Date(event.timestamp).toLocaleTimeString("zh-CN", { hour12: false }) : ""; - // LLM 思考事件特殊处理 - 展示多行内容 + // 特殊事件处理 const isLLMThought = event.event_type === "llm_thought"; const isLLMDecision = event.event_type === "llm_decision"; - const isLLMAction = event.event_type === "llm_action"; - const isImportantLLMEvent = isLLMThought || isLLMDecision || isLLMAction; + const isToolCall = event.event_type === "tool_call"; + const isToolResult = event.event_type === "tool_result"; - // LLM 事件背景色 + // 背景色 const bgClass = isLLMThought ? "bg-purple-950/40 border-l-2 border-purple-600" : isLLMDecision ? "bg-yellow-950/30 border-l-2 border-yellow-600" - : isLLMAction - ? "bg-orange-950/30 border-l-2 border-orange-600" + : isToolCall || isToolResult + ? "bg-gray-900/30" : ""; return ( @@ -738,7 +734,7 @@ function EventLine({ event }: { event: AgentEvent }) { {timestamp} {icon} - + {event.message} {event.tool_duration_ms && ( ({event.tool_duration_ms}ms)
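The thread running through the backend half of this diff is the TaskHandoff structure that ReconNode, AnalysisNode, and VerificationNode serialize into graph state. A minimal sketch of the shape implied by the create_handoff()/from_dict() call sites above; the real class in app/services/agent/agents/base.py may carry additional fields, and from_agent here is an assumption about what create_handoff fills in:

    from dataclasses import dataclass, field, asdict
    from typing import Any, Dict, List

    @dataclass
    class TaskHandoff:
        to_agent: str
        summary: str
        key_findings: List[Dict[str, Any]] = field(default_factory=list)
        suggested_actions: List[Dict[str, Any]] = field(default_factory=list)
        attention_points: List[str] = field(default_factory=list)
        priority_areas: List[str] = field(default_factory=list)
        context_data: Dict[str, Any] = field(default_factory=dict)
        from_agent: str = ""  # assumed: stamped by BaseAgent.create_handoff

        def to_dict(self) -> Dict[str, Any]:
            return asdict(self)

        @classmethod
        def from_dict(cls, data: Dict[str, Any]) -> "TaskHandoff":
            # Tolerate extra keys so older serialized handoffs still load
            known = {k: v for k, v in data.items() if k in cls.__dataclass_fields__}
            return cls(**known)

Each node stores handoff.to_dict() under a phase-specific state key (recon_handoff, analysis_handoff, verification_handoff), which is why _extract_handoff_from_state needs only the phase name to rebuild the object for the next agent.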