feat(agent): 实现任务取消和超时处理机制

添加对Agent任务的取消和超时处理支持,包括:
- 在工具执行、子Agent运行和项目初始化阶段检查取消状态
- 为不同工具和Agent类型设置合理的超时时间
- 使用asyncio实现取消检查和超时控制
- 优化取消响应速度,减少资源浪费
This commit is contained in:
lintsinghua 2025-12-16 17:31:29 +08:00
parent a27d37960a
commit 5974323a71
4 changed files with 245 additions and 36 deletions

View File

@ -8,7 +8,7 @@ import json
import logging
import os
import shutil
from typing import Any, List, Optional, Dict
from typing import Any, List, Optional, Dict, Set
from datetime import datetime, timezone
from uuid import uuid4
@ -222,6 +222,13 @@ class TaskSummaryResponse(BaseModel):
_running_orchestrators: Dict[str, Any] = {}
# 运行中的事件管理器(用于 SSE 流)
_running_event_managers: Dict[str, EventManager] = {}
# 🔥 已取消的任务集合(用于前置操作的取消检查)
_cancelled_tasks: Set[str] = set()
def is_task_cancelled(task_id: str) -> bool:
    """Report whether ``task_id`` has been flagged as cancelled.

    The lookup is a plain membership test against the module-level
    ``_cancelled_tasks`` set.

    Args:
        task_id: Identifier of the agent task to look up.

    Returns:
        ``True`` if ``task_id`` is present in ``_cancelled_tasks``,
        ``False`` otherwise.
    """
    flagged = task_id in _cancelled_tasks
    return flagged
async def _execute_agent_task(task_id: str):
@ -299,11 +306,16 @@ async def _execute_agent_task(task_id: str):
logger.info(f"🚀 Task {task_id} started with Dynamic Agent Tree architecture")
# 🔥 获取项目根目录后检查取消
if is_task_cancelled(task_id):
logger.info(f"[Cancel] Task {task_id} cancelled after project preparation")
raise asyncio.CancelledError("任务已取消")
# 创建 LLM 服务
llm_service = LLMService(user_config=user_config)
# 初始化工具集 - 传递排除模式和目标文件以及预初始化的 sandbox_manager
# 🔥 传递 event_emitter 以发送索引进度
# 🔥 传递 event_emitter 以发送索引进度,传递 task_id 以支持取消
tools = await _initialize_tools(
project_root,
llm_service,
@ -313,8 +325,14 @@ async def _execute_agent_task(task_id: str):
target_files=task.target_files,
project_id=str(project.id), # 🔥 传递 project_id 用于 RAG
event_emitter=event_emitter, # 🔥 新增
task_id=task_id, # 🔥 新增:用于取消检查
)
# 🔥 初始化工具后检查取消
if is_task_cancelled(task_id):
logger.info(f"[Cancel] Task {task_id} cancelled after tools initialization")
raise asyncio.CancelledError("任务已取消")
# 创建子 Agent
recon_agent = ReconAgent(
llm_service=llm_service,
@ -522,6 +540,7 @@ async def _execute_agent_task(task_id: str):
_running_tasks.pop(task_id, None)
_running_event_managers.pop(task_id, None)
_running_asyncio_tasks.pop(task_id, None) # 🔥 清理 asyncio task
_cancelled_tasks.discard(task_id) # 🔥 清理取消标志
# 🔥 清理整个 Agent 注册表(包括所有子 Agent)
agent_registry.clear()
@ -571,6 +590,7 @@ async def _initialize_tools(
target_files: Optional[List[str]] = None,
project_id: Optional[str] = None, # 🔥 用于 RAG collection_name
event_emitter: Optional[Any] = None, # 🔥 新增:用于发送实时日志
task_id: Optional[str] = None, # 🔥 新增:用于取消检查
) -> Dict[str, Dict[str, Any]]:
"""初始化工具集
@ -583,6 +603,7 @@ async def _initialize_tools(
target_files: 目标文件列表
project_id: 项目 ID,用于 RAG collection_name
event_emitter: 事件发送器,用于发送实时日志
task_id: 任务 ID,用于取消检查
"""
from app.services.agent.tools import (
FileReadTool, FileSearchTool, ListFilesTool,
@ -691,6 +712,11 @@ async def _initialize_tools(
exclude_patterns=exclude_patterns or [],
update_mode=IndexUpdateMode.SMART,
):
# 🔥 在索引过程中检查取消状态
if task_id and is_task_cancelled(task_id):
logger.info(f"[Cancel] RAG indexing cancelled for task {task_id}")
raise asyncio.CancelledError("任务已取消")
index_progress = progress
# 每处理 10 个文件或有重要变化时发送进度更新
if progress.processed_files - last_progress_update >= 10 or progress.processed_files == progress.total_files:
@ -1509,14 +1535,18 @@ async def cancel_agent_task(
task = await db.get(AgentTask, task_id)
if not task:
raise HTTPException(status_code=404, detail="任务不存在")
project = await db.get(Project, task.project_id)
if not project or project.owner_id != current_user.id:
raise HTTPException(status_code=403, detail="无权操作此任务")
if task.status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
raise HTTPException(status_code=400, detail="任务已结束,无法取消")
# 🔥 0. 立即标记任务为已取消(用于前置操作的取消检查)
_cancelled_tasks.add(task_id)
logger.info(f"[Cancel] Added task {task_id} to cancelled set")
# 🔥 1. 设置 Agent 的取消标志
runner = _running_tasks.get(task_id)
if runner:
@ -2053,6 +2083,11 @@ async def _get_project_root(
elif level == "error":
await event_emitter.emit_error(message)
# 🔥 辅助函数:检查取消状态
def check_cancelled():
    """Abort the surrounding preparation step if the task was cancelled.

    Raises:
        asyncio.CancelledError: When ``is_task_cancelled(task_id)`` is true
            for the ``task_id`` captured from the enclosing scope.
    """
    # Guard clause: nothing to do while the task is still live.
    if not is_task_cancelled(task_id):
        return
    raise asyncio.CancelledError("任务已取消")
base_path = f"/tmp/deepaudit/{task_id}"
# 确保目录存在且为空
@ -2060,9 +2095,13 @@ async def _get_project_root(
shutil.rmtree(base_path)
os.makedirs(base_path, exist_ok=True)
# 🔥 在开始任何操作前检查取消
check_cancelled()
# 根据项目类型处理
if project.source_type == "zip":
# 🔥 ZIP 项目:解压 ZIP 文件
check_cancelled() # 🔥 解压前检查
await emit(f"📦 正在解压项目文件...")
from app.services.zip_storage import load_project_zip
@ -2070,8 +2109,14 @@ async def _get_project_root(
if zip_path and os.path.exists(zip_path):
try:
check_cancelled() # 🔥 解压前再次检查
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(base_path)
# 🔥 逐个文件解压,支持取消检查
file_list = zip_ref.namelist()
for i, file_name in enumerate(file_list):
if i % 50 == 0: # 每50个文件检查一次
check_cancelled()
zip_ref.extract(file_name, base_path)
logger.info(f"✅ Extracted ZIP project {project.id} to {base_path}")
await emit(f"✅ ZIP 文件解压完成")
except Exception as e:
@ -2151,6 +2196,9 @@ async def _get_project_root(
last_error = ""
for branch in branches_to_try:
# 🔥 每次尝试前检查取消
check_cancelled()
# 清理目录(如果之前尝试失败)
if os.path.exists(base_path) and os.listdir(base_path):
shutil.rmtree(base_path)
@ -2158,13 +2206,30 @@ async def _get_project_root(
logger.info(f"🔄 Trying to clone repository (branch: {branch})...")
await emit(f"🔄 尝试克隆分支: {branch}")
# 🔥 使用 asyncio 包装 subprocess,支持取消
try:
result = subprocess.run(
["git", "clone", "--depth", "1", "--branch", branch, auth_url, base_path],
capture_output=True,
text=True,
timeout=120, # 缩短超时时间
)
async def run_clone():
    """Shallow-clone the current ``branch`` in a worker thread.

    ``subprocess.run`` is blocking, so it is handed to
    ``asyncio.to_thread``. ``branch``, ``auth_url`` and ``base_path``
    are read from the enclosing scope.

    Returns:
        The ``subprocess.CompletedProcess`` produced by ``git clone``,
        with stdout/stderr captured as text.

    Raises:
        subprocess.TimeoutExpired: If the clone takes longer than 120 seconds.
    """
    clone_cmd = ["git", "clone", "--depth", "1", "--branch", branch, auth_url, base_path]
    completed = await asyncio.to_thread(
        subprocess.run,
        clone_cmd,
        capture_output=True,
        text=True,
        timeout=120,
    )
    return completed
# 🔥 使用 wait_for 添加取消检查循环
clone_task = asyncio.create_task(run_clone())
while not clone_task.done():
check_cancelled()
try:
result = await asyncio.wait_for(asyncio.shield(clone_task), timeout=1.0)
break
except asyncio.TimeoutError:
continue
if clone_task.done():
result = clone_task.result()
if result.returncode == 0:
logger.info(f"✅ Cloned repository {repo_url} (branch: {branch}) to {base_path}")
@ -2179,9 +2244,13 @@ async def _get_project_root(
last_error = f"克隆分支 {branch} 超时"
logger.warning(last_error)
await emit(f"⚠️ 分支 {branch} 克隆超时,尝试其他分支...", "warning")
except asyncio.CancelledError:
logger.info(f"[Cancel] Git clone cancelled for task {task_id}")
raise
# 如果所有分支都失败,尝试不指定分支克隆(使用仓库默认分支)
if not clone_success:
check_cancelled() # 🔥 检查取消
logger.info(f"🔄 Trying to clone without specifying branch...")
await emit(f"🔄 尝试使用仓库默认分支克隆...")
if os.path.exists(base_path) and os.listdir(base_path):
@ -2189,12 +2258,27 @@ async def _get_project_root(
os.makedirs(base_path, exist_ok=True)
try:
result = subprocess.run(
["git", "clone", "--depth", "1", auth_url, base_path],
capture_output=True,
text=True,
timeout=120,
)
async def run_default_clone():
    """Shallow-clone the repository without naming a branch, in a worker thread.

    ``subprocess.run`` is blocking, so it is handed to
    ``asyncio.to_thread``. ``auth_url`` and ``base_path`` are read from
    the enclosing scope.

    Returns:
        The ``subprocess.CompletedProcess`` produced by ``git clone``,
        with stdout/stderr captured as text.

    Raises:
        subprocess.TimeoutExpired: If the clone takes longer than 120 seconds.
    """
    clone_cmd = ["git", "clone", "--depth", "1", auth_url, base_path]
    completed = await asyncio.to_thread(
        subprocess.run,
        clone_cmd,
        capture_output=True,
        text=True,
        timeout=120,
    )
    return completed
# 🔥 使用 wait_for 添加取消检查循环
clone_task = asyncio.create_task(run_default_clone())
while not clone_task.done():
check_cancelled()
try:
result = await asyncio.wait_for(asyncio.shield(clone_task), timeout=1.0)
break
except asyncio.TimeoutError:
continue
if clone_task.done():
result = clone_task.result()
if result.returncode == 0:
logger.info(f"✅ Cloned repository {repo_url} (default branch) to {base_path}")
@ -2205,6 +2289,9 @@ async def _get_project_root(
except subprocess.TimeoutExpired:
last_error = "克隆仓库超时"
await emit(f"⚠️ 克隆超时", "warning")
except asyncio.CancelledError:
logger.info(f"[Cancel] Git clone cancelled for task {task_id}")
raise
if not clone_success:
# 分析错误原因

View File

@ -1006,38 +1006,104 @@ class BaseAgent(ABC):
async def execute_tool(self, tool_name: str, tool_input: Dict) -> str:
"""
统一的工具执行方法
统一的工具执行方法 - 支持取消和超时
Args:
tool_name: 工具名称
tool_input: 工具参数
Returns:
工具执行结果字符串
"""
# 🔥 在执行工具前检查取消
if self.is_cancelled:
return "任务已取消"
return "⚠️ 任务已取消"
tool = self.tools.get(tool_name)
if not tool:
return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}"
try:
self._tool_calls += 1
await self.emit_tool_call(tool_name, tool_input)
import time
start = time.time()
result = await tool.execute(**tool_input)
# 🔥 根据工具类型设置不同的超时时间
tool_timeouts = {
"semgrep_scan": 120, # 外部扫描工具需要更长时间
"bandit_scan": 90,
"gitleaks_scan": 60,
"npm_audit": 90,
"safety_scan": 60,
"kunlun_scan": 180,
"osv_scanner": 60,
"trufflehog_scan": 90,
"sandbox_exec": 60,
"php_test": 30,
"command_injection_test": 30,
"sql_injection_test": 30,
"xss_test": 30,
}
timeout = tool_timeouts.get(tool_name, 30) # 默认30秒
# 🔥 使用 asyncio.wait_for 添加超时控制,同时支持取消
async def execute_with_cancel_check():
    """Run the tool call as a task while periodically polling the cancel flag.

    Returns:
        Whatever ``tool.execute(**tool_input)`` returns.

    Raises:
        asyncio.CancelledError: When ``self.is_cancelled`` becomes true
            while the tool is running, or when this coroutine itself is
            cancelled; in both cases the tool task is cancelled as well.
    """
    # Start the tool call as its own task so it can be cancelled on demand.
    execute_task = asyncio.create_task(tool.execute(**tool_input))
    try:
        # Keep polling the cancel flag until the tool task finishes.
        while not execute_task.done():
            if self.is_cancelled:
                execute_task.cancel()
                try:
                    # Let the cancelled task finish unwinding before reporting.
                    await execute_task
                except asyncio.CancelledError:
                    pass
                raise asyncio.CancelledError("任务已取消")
            # Wait for the task to finish, waking up at the polling interval.
            # shield() keeps the short wait_for timeout from cancelling the
            # tool task itself.
            try:
                return await asyncio.wait_for(
                    asyncio.shield(execute_task),
                    timeout=0.5  # re-check the cancel flag every 0.5 seconds
                )
            except asyncio.TimeoutError:
                continue  # not finished yet; loop back and check the flag again
        return await execute_task
    except asyncio.CancelledError:
        # Reached for the explicit raise above and when this coroutine is
        # cancelled from outside; make sure the tool task does not keep running.
        if not execute_task.done():
            execute_task.cancel()
        raise
try:
result = await asyncio.wait_for(
execute_with_cancel_check(),
timeout=timeout
)
except asyncio.TimeoutError:
duration_ms = int((time.time() - start) * 1000)
await self.emit_tool_result(tool_name, f"超时 ({timeout}s)", duration_ms)
return f"⚠️ 工具 '{tool_name}' 执行超时 ({timeout}秒),请尝试其他方法或减小操作范围。"
except asyncio.CancelledError:
duration_ms = int((time.time() - start) * 1000)
await self.emit_tool_result(tool_name, "已取消", duration_ms)
return "⚠️ 任务已取消"
duration_ms = int((time.time() - start) * 1000)
# 🔥 修复:确保传递有意义的结果字符串,避免 "None"
result_preview = str(result.data)[:200] if result.data is not None else (result.error[:200] if result.error else "")
await self.emit_tool_result(tool_name, result_preview, duration_ms)
# 🔥 工具执行后再次检查取消
if self.is_cancelled:
return "⚠️ 任务已取消"
if result.success:
output = str(result.data)
@ -1063,6 +1129,9 @@ class BaseAgent(ABC):
请根据错误信息调整参数或尝试其他方法"""
return error_msg
except asyncio.CancelledError:
logger.info(f"[{self.name}] Tool '{tool_name}' execution cancelled")
return "⚠️ 任务已取消"
except Exception as e:
import traceback
logger.error(f"Tool execution error: {e}")

View File

@ -647,14 +647,62 @@ Action Input: {{"参数": "值"}}
"project_root": self._runtime_context.get("project_root", "."),
"previous_results": previous_results,
}
# 🔥 执行子 Agent 前检查取消状态
if self.is_cancelled:
return f"## {agent_name} Agent 执行取消\n\n任务已被用户取消"
# 执行子 Agent
result = await agent.run(sub_input)
# 🔥 执行子 Agent - 支持取消和超时
# 设置子 Agent 超时(根据 Agent 类型)
agent_timeouts = {
"recon": 300, # 5 分钟
"analysis": 600, # 10 分钟
"verification": 300, # 5 分钟
}
timeout = agent_timeouts.get(agent_name, 300)
async def run_with_cancel_check():
    """Run the sub-agent as a task while periodically polling the cancel flag.

    Returns:
        Whatever ``agent.run(sub_input)`` returns.

    Raises:
        asyncio.CancelledError: When ``self.is_cancelled`` becomes true
            while the sub-agent is running, or when this coroutine itself
            is cancelled; in both cases the sub-agent task is cancelled.
    """
    run_task = asyncio.create_task(agent.run(sub_input))
    try:
        while not run_task.done():
            if self.is_cancelled:
                # Propagate the cancellation to the sub-agent.
                if hasattr(agent, 'cancel'):
                    agent.cancel()
                run_task.cancel()
                try:
                    # Let the cancelled task finish unwinding before reporting.
                    await run_task
                except asyncio.CancelledError:
                    pass
                raise asyncio.CancelledError("任务已取消")
            # shield() keeps the short wait_for timeout from cancelling the
            # sub-agent task itself.
            try:
                return await asyncio.wait_for(
                    asyncio.shield(run_task),
                    timeout=1.0  # re-check the cancel flag once per second
                )
            except asyncio.TimeoutError:
                continue  # not finished yet; loop back and check the flag again
        return await run_task
    except asyncio.CancelledError:
        # Reached for the explicit raise above and when this coroutine is
        # cancelled from outside; make sure the sub-agent task does not keep running.
        if not run_task.done():
            run_task.cancel()
        raise
try:
result = await asyncio.wait_for(
run_with_cancel_check(),
timeout=timeout
)
except asyncio.TimeoutError:
logger.warning(f"[{self.name}] Sub-agent {agent_name} timed out after {timeout}s")
return f"## {agent_name} Agent 执行超时\n\n子 Agent 执行超过 {timeout} 秒,已强制终止。请尝试更具体的任务或使用其他 Agent。"
except asyncio.CancelledError:
logger.info(f"[{self.name}] Sub-agent {agent_name} was cancelled")
return f"## {agent_name} Agent 执行取消\n\n任务已被用户取消"
# 🔥 执行后再次检查取消状态
if self.is_cancelled:
return f"## {agent_name} Agent 执行中断\n\n任务已被用户取消"

View File

@ -622,7 +622,12 @@ class VerificationAgent(BaseAgent):
# 成功调用,重置失败计数
if tool_call_key in self._failed_tool_calls:
del self._failed_tool_calls[tool_call_key]
# 🔥 工具执行后检查取消状态
if self.is_cancelled:
logger.info(f"[{self.name}] Cancelled after tool execution")
break
step.observation = observation
# 🔥 发射 LLM 观察事件