From 5974323a7133fa5ff2419eaec032d83e1af31329 Mon Sep 17 00:00:00 2001 From: lintsinghua Date: Tue, 16 Dec 2025 17:31:29 +0800 Subject: [PATCH] =?UTF-8?q?feat(agent):=20=E5=AE=9E=E7=8E=B0=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E5=8F=96=E6=B6=88=E5=92=8C=E8=B6=85=E6=97=B6=E5=A4=84?= =?UTF-8?q?=E7=90=86=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加对Agent任务的取消和超时处理支持,包括: - 在工具执行、子Agent运行和项目初始化阶段检查取消状态 - 为不同工具和Agent类型设置合理的超时时间 - 使用asyncio实现取消检查和超时控制 - 优化取消响应速度,减少资源浪费 --- backend/app/api/v1/endpoints/agent_tasks.py | 123 +++++++++++++++--- backend/app/services/agent/agents/base.py | 93 +++++++++++-- .../app/services/agent/agents/orchestrator.py | 58 ++++++++- .../app/services/agent/agents/verification.py | 7 +- 4 files changed, 245 insertions(+), 36 deletions(-) diff --git a/backend/app/api/v1/endpoints/agent_tasks.py b/backend/app/api/v1/endpoints/agent_tasks.py index bfe8b87..04ed6d8 100644 --- a/backend/app/api/v1/endpoints/agent_tasks.py +++ b/backend/app/api/v1/endpoints/agent_tasks.py @@ -8,7 +8,7 @@ import json import logging import os import shutil -from typing import Any, List, Optional, Dict +from typing import Any, List, Optional, Dict, Set from datetime import datetime, timezone from uuid import uuid4 @@ -222,6 +222,13 @@ class TaskSummaryResponse(BaseModel): _running_orchestrators: Dict[str, Any] = {} # 运行中的事件管理器(用于 SSE 流) _running_event_managers: Dict[str, EventManager] = {} +# 🔥 已取消的任务集合(用于前置操作的取消检查) +_cancelled_tasks: Set[str] = set() + + +def is_task_cancelled(task_id: str) -> bool: + """检查任务是否已被取消""" + return task_id in _cancelled_tasks async def _execute_agent_task(task_id: str): @@ -299,11 +306,16 @@ async def _execute_agent_task(task_id: str): logger.info(f"🚀 Task {task_id} started with Dynamic Agent Tree architecture") + # 🔥 获取项目根目录后检查取消 + if is_task_cancelled(task_id): + logger.info(f"[Cancel] Task {task_id} cancelled after project preparation") + raise asyncio.CancelledError("任务已取消") + # 创建 LLM 服务 llm_service = LLMService(user_config=user_config) # 初始化工具集 - 传递排除模式和目标文件以及预初始化的 sandbox_manager - # 🔥 传递 event_emitter 以发送索引进度 + # 🔥 传递 event_emitter 以发送索引进度,传递 task_id 以支持取消 tools = await _initialize_tools( project_root, llm_service, @@ -313,8 +325,14 @@ async def _execute_agent_task(task_id: str): target_files=task.target_files, project_id=str(project.id), # 🔥 传递 project_id 用于 RAG event_emitter=event_emitter, # 🔥 新增 + task_id=task_id, # 🔥 新增:用于取消检查 ) + # 🔥 初始化工具后检查取消 + if is_task_cancelled(task_id): + logger.info(f"[Cancel] Task {task_id} cancelled after tools initialization") + raise asyncio.CancelledError("任务已取消") + # 创建子 Agent recon_agent = ReconAgent( llm_service=llm_service, @@ -522,6 +540,7 @@ async def _execute_agent_task(task_id: str): _running_tasks.pop(task_id, None) _running_event_managers.pop(task_id, None) _running_asyncio_tasks.pop(task_id, None) # 🔥 清理 asyncio task + _cancelled_tasks.discard(task_id) # 🔥 清理取消标志 # 🔥 清理整个 Agent 注册表(包括所有子 Agent) agent_registry.clear() @@ -571,6 +590,7 @@ async def _initialize_tools( target_files: Optional[List[str]] = None, project_id: Optional[str] = None, # 🔥 用于 RAG collection_name event_emitter: Optional[Any] = None, # 🔥 新增:用于发送实时日志 + task_id: Optional[str] = None, # 🔥 新增:用于取消检查 ) -> Dict[str, Dict[str, Any]]: """初始化工具集 @@ -583,6 +603,7 @@ async def _initialize_tools( target_files: 目标文件列表 project_id: 项目 ID(用于 RAG collection_name) event_emitter: 事件发送器(用于发送实时日志) + task_id: 任务 ID(用于取消检查) """ from app.services.agent.tools import ( FileReadTool, FileSearchTool, ListFilesTool, @@ -691,6 +712,11 @@ async def _initialize_tools( exclude_patterns=exclude_patterns or [], update_mode=IndexUpdateMode.SMART, ): + # 🔥 在索引过程中检查取消状态 + if task_id and is_task_cancelled(task_id): + logger.info(f"[Cancel] RAG indexing cancelled for task {task_id}") + raise asyncio.CancelledError("任务已取消") + index_progress = progress # 每处理 10 个文件或有重要变化时发送进度更新 if progress.processed_files - last_progress_update >= 10 or progress.processed_files == progress.total_files: @@ -1509,14 +1535,18 @@ async def cancel_agent_task( task = await db.get(AgentTask, task_id) if not task: raise HTTPException(status_code=404, detail="任务不存在") - + project = await db.get(Project, task.project_id) if not project or project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权操作此任务") - + if task.status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]: raise HTTPException(status_code=400, detail="任务已结束,无法取消") - + + # 🔥 0. 立即标记任务为已取消(用于前置操作的取消检查) + _cancelled_tasks.add(task_id) + logger.info(f"[Cancel] Added task {task_id} to cancelled set") + # 🔥 1. 设置 Agent 的取消标志 runner = _running_tasks.get(task_id) if runner: @@ -2053,6 +2083,11 @@ async def _get_project_root( elif level == "error": await event_emitter.emit_error(message) + # 🔥 辅助函数:检查取消状态 + def check_cancelled(): + if is_task_cancelled(task_id): + raise asyncio.CancelledError("任务已取消") + base_path = f"/tmp/deepaudit/{task_id}" # 确保目录存在且为空 @@ -2060,9 +2095,13 @@ async def _get_project_root( shutil.rmtree(base_path) os.makedirs(base_path, exist_ok=True) + # 🔥 在开始任何操作前检查取消 + check_cancelled() + # 根据项目类型处理 if project.source_type == "zip": # 🔥 ZIP 项目:解压 ZIP 文件 + check_cancelled() # 🔥 解压前检查 await emit(f"📦 正在解压项目文件...") from app.services.zip_storage import load_project_zip @@ -2070,8 +2109,14 @@ async def _get_project_root( if zip_path and os.path.exists(zip_path): try: + check_cancelled() # 🔥 解压前再次检查 with zipfile.ZipFile(zip_path, 'r') as zip_ref: - zip_ref.extractall(base_path) + # 🔥 逐个文件解压,支持取消检查 + file_list = zip_ref.namelist() + for i, file_name in enumerate(file_list): + if i % 50 == 0: # 每50个文件检查一次 + check_cancelled() + zip_ref.extract(file_name, base_path) logger.info(f"✅ Extracted ZIP project {project.id} to {base_path}") await emit(f"✅ ZIP 文件解压完成") except Exception as e: @@ -2151,6 +2196,9 @@ async def _get_project_root( last_error = "" for branch in branches_to_try: + # 🔥 每次尝试前检查取消 + check_cancelled() + # 清理目录(如果之前尝试失败) if os.path.exists(base_path) and os.listdir(base_path): shutil.rmtree(base_path) @@ -2158,13 +2206,30 @@ async def _get_project_root( logger.info(f"🔄 Trying to clone repository (branch: {branch})...") await emit(f"🔄 尝试克隆分支: {branch}") + + # 🔥 使用 asyncio 包装 subprocess,支持取消 try: - result = subprocess.run( - ["git", "clone", "--depth", "1", "--branch", branch, auth_url, base_path], - capture_output=True, - text=True, - timeout=120, # 缩短超时时间 - ) + async def run_clone(): + return await asyncio.to_thread( + subprocess.run, + ["git", "clone", "--depth", "1", "--branch", branch, auth_url, base_path], + capture_output=True, + text=True, + timeout=120, + ) + + # 🔥 使用 wait_for 添加取消检查循环 + clone_task = asyncio.create_task(run_clone()) + while not clone_task.done(): + check_cancelled() + try: + result = await asyncio.wait_for(asyncio.shield(clone_task), timeout=1.0) + break + except asyncio.TimeoutError: + continue + + if clone_task.done(): + result = clone_task.result() if result.returncode == 0: logger.info(f"✅ Cloned repository {repo_url} (branch: {branch}) to {base_path}") @@ -2179,9 +2244,13 @@ async def _get_project_root( last_error = f"克隆分支 {branch} 超时" logger.warning(last_error) await emit(f"⚠️ 分支 {branch} 克隆超时,尝试其他分支...", "warning") + except asyncio.CancelledError: + logger.info(f"[Cancel] Git clone cancelled for task {task_id}") + raise # 如果所有分支都失败,尝试不指定分支克隆(使用仓库默认分支) if not clone_success: + check_cancelled() # 🔥 检查取消 logger.info(f"🔄 Trying to clone without specifying branch...") await emit(f"🔄 尝试使用仓库默认分支克隆...") if os.path.exists(base_path) and os.listdir(base_path): @@ -2189,12 +2258,27 @@ async def _get_project_root( os.makedirs(base_path, exist_ok=True) try: - result = subprocess.run( - ["git", "clone", "--depth", "1", auth_url, base_path], - capture_output=True, - text=True, - timeout=120, - ) + async def run_default_clone(): + return await asyncio.to_thread( + subprocess.run, + ["git", "clone", "--depth", "1", auth_url, base_path], + capture_output=True, + text=True, + timeout=120, + ) + + # 🔥 使用 wait_for 添加取消检查循环 + clone_task = asyncio.create_task(run_default_clone()) + while not clone_task.done(): + check_cancelled() + try: + result = await asyncio.wait_for(asyncio.shield(clone_task), timeout=1.0) + break + except asyncio.TimeoutError: + continue + + if clone_task.done(): + result = clone_task.result() if result.returncode == 0: logger.info(f"✅ Cloned repository {repo_url} (default branch) to {base_path}") @@ -2205,6 +2289,9 @@ async def _get_project_root( except subprocess.TimeoutExpired: last_error = "克隆仓库超时" await emit(f"⚠️ 克隆超时", "warning") + except asyncio.CancelledError: + logger.info(f"[Cancel] Git clone cancelled for task {task_id}") + raise if not clone_success: # 分析错误原因 diff --git a/backend/app/services/agent/agents/base.py b/backend/app/services/agent/agents/base.py index 9b0ced6..a198374 100644 --- a/backend/app/services/agent/agents/base.py +++ b/backend/app/services/agent/agents/base.py @@ -1006,38 +1006,104 @@ class BaseAgent(ABC): async def execute_tool(self, tool_name: str, tool_input: Dict) -> str: """ - 统一的工具执行方法 - + 统一的工具执行方法 - 支持取消和超时 + Args: tool_name: 工具名称 tool_input: 工具参数 - + Returns: 工具执行结果字符串 """ # 🔥 在执行工具前检查取消 if self.is_cancelled: - return "任务已取消" - + return "⚠️ 任务已取消" + tool = self.tools.get(tool_name) - + if not tool: return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}" - + try: self._tool_calls += 1 await self.emit_tool_call(tool_name, tool_input) - + import time start = time.time() - - result = await tool.execute(**tool_input) - + + # 🔥 根据工具类型设置不同的超时时间 + tool_timeouts = { + "semgrep_scan": 120, # 外部扫描工具需要更长时间 + "bandit_scan": 90, + "gitleaks_scan": 60, + "npm_audit": 90, + "safety_scan": 60, + "kunlun_scan": 180, + "osv_scanner": 60, + "trufflehog_scan": 90, + "sandbox_exec": 60, + "php_test": 30, + "command_injection_test": 30, + "sql_injection_test": 30, + "xss_test": 30, + } + timeout = tool_timeouts.get(tool_name, 30) # 默认30秒 + + # 🔥 使用 asyncio.wait_for 添加超时控制,同时支持取消 + async def execute_with_cancel_check(): + """包装工具执行,定期检查取消状态""" + # 创建工具执行任务 + execute_task = asyncio.create_task(tool.execute(**tool_input)) + + try: + # 使用循环定期检查取消状态 + while not execute_task.done(): + if self.is_cancelled: + execute_task.cancel() + try: + await execute_task + except asyncio.CancelledError: + pass + raise asyncio.CancelledError("任务已取消") + + # 等待任务完成或超时检查间隔 + try: + return await asyncio.wait_for( + asyncio.shield(execute_task), + timeout=0.5 # 每0.5秒检查一次取消状态 + ) + except asyncio.TimeoutError: + continue # 继续循环检查 + + return await execute_task + except asyncio.CancelledError: + if not execute_task.done(): + execute_task.cancel() + raise + + try: + result = await asyncio.wait_for( + execute_with_cancel_check(), + timeout=timeout + ) + except asyncio.TimeoutError: + duration_ms = int((time.time() - start) * 1000) + await self.emit_tool_result(tool_name, f"超时 ({timeout}s)", duration_ms) + return f"⚠️ 工具 '{tool_name}' 执行超时 ({timeout}秒),请尝试其他方法或减小操作范围。" + except asyncio.CancelledError: + duration_ms = int((time.time() - start) * 1000) + await self.emit_tool_result(tool_name, "已取消", duration_ms) + return "⚠️ 任务已取消" + duration_ms = int((time.time() - start) * 1000) # 🔥 修复:确保传递有意义的结果字符串,避免 "None" result_preview = str(result.data)[:200] if result.data is not None else (result.error[:200] if result.error else "") await self.emit_tool_result(tool_name, result_preview, duration_ms) - + + # 🔥 工具执行后再次检查取消 + if self.is_cancelled: + return "⚠️ 任务已取消" + if result.success: output = str(result.data) @@ -1063,6 +1129,9 @@ class BaseAgent(ABC): 请根据错误信息调整参数或尝试其他方法。""" return error_msg + except asyncio.CancelledError: + logger.info(f"[{self.name}] Tool '{tool_name}' execution cancelled") + return "⚠️ 任务已取消" except Exception as e: import traceback logger.error(f"Tool execution error: {e}") diff --git a/backend/app/services/agent/agents/orchestrator.py b/backend/app/services/agent/agents/orchestrator.py index 24d6df0..b99973f 100644 --- a/backend/app/services/agent/agents/orchestrator.py +++ b/backend/app/services/agent/agents/orchestrator.py @@ -647,14 +647,62 @@ Action Input: {{"参数": "值"}} "project_root": self._runtime_context.get("project_root", "."), "previous_results": previous_results, } - + # 🔥 执行子 Agent 前检查取消状态 if self.is_cancelled: return f"## {agent_name} Agent 执行取消\n\n任务已被用户取消" - - # 执行子 Agent - result = await agent.run(sub_input) - + + # 🔥 执行子 Agent - 支持取消和超时 + # 设置子 Agent 超时(根据 Agent 类型) + agent_timeouts = { + "recon": 300, # 5 分钟 + "analysis": 600, # 10 分钟 + "verification": 300, # 5 分钟 + } + timeout = agent_timeouts.get(agent_name, 300) + + async def run_with_cancel_check(): + """包装子 Agent 执行,定期检查取消状态""" + run_task = asyncio.create_task(agent.run(sub_input)) + try: + while not run_task.done(): + if self.is_cancelled: + # 传播取消到子 Agent + if hasattr(agent, 'cancel'): + agent.cancel() + run_task.cancel() + try: + await run_task + except asyncio.CancelledError: + pass + raise asyncio.CancelledError("任务已取消") + + try: + return await asyncio.wait_for( + asyncio.shield(run_task), + timeout=1.0 # 每秒检查一次取消状态 + ) + except asyncio.TimeoutError: + continue + + return await run_task + except asyncio.CancelledError: + if not run_task.done(): + run_task.cancel() + raise + + try: + result = await asyncio.wait_for( + run_with_cancel_check(), + timeout=timeout + ) + except asyncio.TimeoutError: + logger.warning(f"[{self.name}] Sub-agent {agent_name} timed out after {timeout}s") + return f"## {agent_name} Agent 执行超时\n\n子 Agent 执行超过 {timeout} 秒,已强制终止。请尝试更具体的任务或使用其他 Agent。" + except asyncio.CancelledError: + logger.info(f"[{self.name}] Sub-agent {agent_name} was cancelled") + return f"## {agent_name} Agent 执行取消\n\n任务已被用户取消" + # 🔥 执行后再次检查取消状态 if self.is_cancelled: return f"## {agent_name} Agent 执行中断\n\n任务已被用户取消" diff --git a/backend/app/services/agent/agents/verification.py b/backend/app/services/agent/agents/verification.py index dc32e05..c9206e9 100644 --- a/backend/app/services/agent/agents/verification.py +++ b/backend/app/services/agent/agents/verification.py @@ -622,7 +622,12 @@ class VerificationAgent(BaseAgent): # 成功调用,重置失败计数 if tool_call_key in self._failed_tool_calls: del self._failed_tool_calls[tool_call_key] - + + # 🔥 工具执行后检查取消状态 + if self.is_cancelled: + logger.info(f"[{self.name}] Cancelled after tool execution") + break + step.observation = observation # 🔥 发射 LLM 观察事件