feat(agent): 实现任务取消和超时处理机制
添加对Agent任务的取消和超时处理支持,包括: - 在工具执行、子Agent运行和项目初始化阶段检查取消状态 - 为不同工具和Agent类型设置合理的超时时间 - 使用asyncio实现取消检查和超时控制 - 优化取消响应速度,减少资源浪费
This commit is contained in:
parent
a27d37960a
commit
5974323a71
|
|
@ -8,7 +8,7 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any, List, Optional, Dict
|
||||
from typing import Any, List, Optional, Dict, Set
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
|
|
@ -222,6 +222,13 @@ class TaskSummaryResponse(BaseModel):
|
|||
_running_orchestrators: Dict[str, Any] = {}
|
||||
# 运行中的事件管理器(用于 SSE 流)
|
||||
_running_event_managers: Dict[str, EventManager] = {}
|
||||
# 🔥 已取消的任务集合(用于前置操作的取消检查)
|
||||
_cancelled_tasks: Set[str] = set()
|
||||
|
||||
|
||||
def is_task_cancelled(task_id: str) -> bool:
|
||||
"""检查任务是否已被取消"""
|
||||
return task_id in _cancelled_tasks
|
||||
|
||||
|
||||
async def _execute_agent_task(task_id: str):
|
||||
|
|
@ -299,11 +306,16 @@ async def _execute_agent_task(task_id: str):
|
|||
|
||||
logger.info(f"🚀 Task {task_id} started with Dynamic Agent Tree architecture")
|
||||
|
||||
# 🔥 获取项目根目录后检查取消
|
||||
if is_task_cancelled(task_id):
|
||||
logger.info(f"[Cancel] Task {task_id} cancelled after project preparation")
|
||||
raise asyncio.CancelledError("任务已取消")
|
||||
|
||||
# 创建 LLM 服务
|
||||
llm_service = LLMService(user_config=user_config)
|
||||
|
||||
# 初始化工具集 - 传递排除模式和目标文件以及预初始化的 sandbox_manager
|
||||
# 🔥 传递 event_emitter 以发送索引进度
|
||||
# 🔥 传递 event_emitter 以发送索引进度,传递 task_id 以支持取消
|
||||
tools = await _initialize_tools(
|
||||
project_root,
|
||||
llm_service,
|
||||
|
|
@ -313,8 +325,14 @@ async def _execute_agent_task(task_id: str):
|
|||
target_files=task.target_files,
|
||||
project_id=str(project.id), # 🔥 传递 project_id 用于 RAG
|
||||
event_emitter=event_emitter, # 🔥 新增
|
||||
task_id=task_id, # 🔥 新增:用于取消检查
|
||||
)
|
||||
|
||||
# 🔥 初始化工具后检查取消
|
||||
if is_task_cancelled(task_id):
|
||||
logger.info(f"[Cancel] Task {task_id} cancelled after tools initialization")
|
||||
raise asyncio.CancelledError("任务已取消")
|
||||
|
||||
# 创建子 Agent
|
||||
recon_agent = ReconAgent(
|
||||
llm_service=llm_service,
|
||||
|
|
@ -522,6 +540,7 @@ async def _execute_agent_task(task_id: str):
|
|||
_running_tasks.pop(task_id, None)
|
||||
_running_event_managers.pop(task_id, None)
|
||||
_running_asyncio_tasks.pop(task_id, None) # 🔥 清理 asyncio task
|
||||
_cancelled_tasks.discard(task_id) # 🔥 清理取消标志
|
||||
|
||||
# 🔥 清理整个 Agent 注册表(包括所有子 Agent)
|
||||
agent_registry.clear()
|
||||
|
|
@ -571,6 +590,7 @@ async def _initialize_tools(
|
|||
target_files: Optional[List[str]] = None,
|
||||
project_id: Optional[str] = None, # 🔥 用于 RAG collection_name
|
||||
event_emitter: Optional[Any] = None, # 🔥 新增:用于发送实时日志
|
||||
task_id: Optional[str] = None, # 🔥 新增:用于取消检查
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""初始化工具集
|
||||
|
||||
|
|
@ -583,6 +603,7 @@ async def _initialize_tools(
|
|||
target_files: 目标文件列表
|
||||
project_id: 项目 ID(用于 RAG collection_name)
|
||||
event_emitter: 事件发送器(用于发送实时日志)
|
||||
task_id: 任务 ID(用于取消检查)
|
||||
"""
|
||||
from app.services.agent.tools import (
|
||||
FileReadTool, FileSearchTool, ListFilesTool,
|
||||
|
|
@ -691,6 +712,11 @@ async def _initialize_tools(
|
|||
exclude_patterns=exclude_patterns or [],
|
||||
update_mode=IndexUpdateMode.SMART,
|
||||
):
|
||||
# 🔥 在索引过程中检查取消状态
|
||||
if task_id and is_task_cancelled(task_id):
|
||||
logger.info(f"[Cancel] RAG indexing cancelled for task {task_id}")
|
||||
raise asyncio.CancelledError("任务已取消")
|
||||
|
||||
index_progress = progress
|
||||
# 每处理 10 个文件或有重要变化时发送进度更新
|
||||
if progress.processed_files - last_progress_update >= 10 or progress.processed_files == progress.total_files:
|
||||
|
|
@ -1517,6 +1543,10 @@ async def cancel_agent_task(
|
|||
if task.status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]:
|
||||
raise HTTPException(status_code=400, detail="任务已结束,无法取消")
|
||||
|
||||
# 🔥 0. 立即标记任务为已取消(用于前置操作的取消检查)
|
||||
_cancelled_tasks.add(task_id)
|
||||
logger.info(f"[Cancel] Added task {task_id} to cancelled set")
|
||||
|
||||
# 🔥 1. 设置 Agent 的取消标志
|
||||
runner = _running_tasks.get(task_id)
|
||||
if runner:
|
||||
|
|
@ -2053,6 +2083,11 @@ async def _get_project_root(
|
|||
elif level == "error":
|
||||
await event_emitter.emit_error(message)
|
||||
|
||||
# 🔥 辅助函数:检查取消状态
|
||||
def check_cancelled():
|
||||
if is_task_cancelled(task_id):
|
||||
raise asyncio.CancelledError("任务已取消")
|
||||
|
||||
base_path = f"/tmp/deepaudit/{task_id}"
|
||||
|
||||
# 确保目录存在且为空
|
||||
|
|
@ -2060,9 +2095,13 @@ async def _get_project_root(
|
|||
shutil.rmtree(base_path)
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
|
||||
# 🔥 在开始任何操作前检查取消
|
||||
check_cancelled()
|
||||
|
||||
# 根据项目类型处理
|
||||
if project.source_type == "zip":
|
||||
# 🔥 ZIP 项目:解压 ZIP 文件
|
||||
check_cancelled() # 🔥 解压前检查
|
||||
await emit(f"📦 正在解压项目文件...")
|
||||
from app.services.zip_storage import load_project_zip
|
||||
|
||||
|
|
@ -2070,8 +2109,14 @@ async def _get_project_root(
|
|||
|
||||
if zip_path and os.path.exists(zip_path):
|
||||
try:
|
||||
check_cancelled() # 🔥 解压前再次检查
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(base_path)
|
||||
# 🔥 逐个文件解压,支持取消检查
|
||||
file_list = zip_ref.namelist()
|
||||
for i, file_name in enumerate(file_list):
|
||||
if i % 50 == 0: # 每50个文件检查一次
|
||||
check_cancelled()
|
||||
zip_ref.extract(file_name, base_path)
|
||||
logger.info(f"✅ Extracted ZIP project {project.id} to {base_path}")
|
||||
await emit(f"✅ ZIP 文件解压完成")
|
||||
except Exception as e:
|
||||
|
|
@ -2151,6 +2196,9 @@ async def _get_project_root(
|
|||
last_error = ""
|
||||
|
||||
for branch in branches_to_try:
|
||||
# 🔥 每次尝试前检查取消
|
||||
check_cancelled()
|
||||
|
||||
# 清理目录(如果之前尝试失败)
|
||||
if os.path.exists(base_path) and os.listdir(base_path):
|
||||
shutil.rmtree(base_path)
|
||||
|
|
@ -2158,14 +2206,31 @@ async def _get_project_root(
|
|||
|
||||
logger.info(f"🔄 Trying to clone repository (branch: {branch})...")
|
||||
await emit(f"🔄 尝试克隆分支: {branch}")
|
||||
|
||||
# 🔥 使用 asyncio 包装 subprocess,支持取消
|
||||
try:
|
||||
result = subprocess.run(
|
||||
async def run_clone():
|
||||
return await asyncio.to_thread(
|
||||
subprocess.run,
|
||||
["git", "clone", "--depth", "1", "--branch", branch, auth_url, base_path],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120, # 缩短超时时间
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
# 🔥 使用 wait_for 添加取消检查循环
|
||||
clone_task = asyncio.create_task(run_clone())
|
||||
while not clone_task.done():
|
||||
check_cancelled()
|
||||
try:
|
||||
result = await asyncio.wait_for(asyncio.shield(clone_task), timeout=1.0)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if clone_task.done():
|
||||
result = clone_task.result()
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info(f"✅ Cloned repository {repo_url} (branch: {branch}) to {base_path}")
|
||||
await emit(f"✅ 仓库克隆成功 (分支: {branch})")
|
||||
|
|
@ -2179,9 +2244,13 @@ async def _get_project_root(
|
|||
last_error = f"克隆分支 {branch} 超时"
|
||||
logger.warning(last_error)
|
||||
await emit(f"⚠️ 分支 {branch} 克隆超时,尝试其他分支...", "warning")
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[Cancel] Git clone cancelled for task {task_id}")
|
||||
raise
|
||||
|
||||
# 如果所有分支都失败,尝试不指定分支克隆(使用仓库默认分支)
|
||||
if not clone_success:
|
||||
check_cancelled() # 🔥 检查取消
|
||||
logger.info(f"🔄 Trying to clone without specifying branch...")
|
||||
await emit(f"🔄 尝试使用仓库默认分支克隆...")
|
||||
if os.path.exists(base_path) and os.listdir(base_path):
|
||||
|
|
@ -2189,13 +2258,28 @@ async def _get_project_root(
|
|||
os.makedirs(base_path, exist_ok=True)
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
async def run_default_clone():
|
||||
return await asyncio.to_thread(
|
||||
subprocess.run,
|
||||
["git", "clone", "--depth", "1", auth_url, base_path],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
# 🔥 使用 wait_for 添加取消检查循环
|
||||
clone_task = asyncio.create_task(run_default_clone())
|
||||
while not clone_task.done():
|
||||
check_cancelled()
|
||||
try:
|
||||
result = await asyncio.wait_for(asyncio.shield(clone_task), timeout=1.0)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if clone_task.done():
|
||||
result = clone_task.result()
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info(f"✅ Cloned repository {repo_url} (default branch) to {base_path}")
|
||||
await emit(f"✅ 仓库克隆成功 (默认分支)")
|
||||
|
|
@ -2205,6 +2289,9 @@ async def _get_project_root(
|
|||
except subprocess.TimeoutExpired:
|
||||
last_error = "克隆仓库超时"
|
||||
await emit(f"⚠️ 克隆超时", "warning")
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[Cancel] Git clone cancelled for task {task_id}")
|
||||
raise
|
||||
|
||||
if not clone_success:
|
||||
# 分析错误原因
|
||||
|
|
|
|||
|
|
@ -1006,7 +1006,7 @@ class BaseAgent(ABC):
|
|||
|
||||
async def execute_tool(self, tool_name: str, tool_input: Dict) -> str:
|
||||
"""
|
||||
统一的工具执行方法
|
||||
统一的工具执行方法 - 支持取消和超时
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
|
|
@ -1017,7 +1017,7 @@ class BaseAgent(ABC):
|
|||
"""
|
||||
# 🔥 在执行工具前检查取消
|
||||
if self.is_cancelled:
|
||||
return "任务已取消"
|
||||
return "⚠️ 任务已取消"
|
||||
|
||||
tool = self.tools.get(tool_name)
|
||||
|
||||
|
|
@ -1031,13 +1031,79 @@ class BaseAgent(ABC):
|
|||
import time
|
||||
start = time.time()
|
||||
|
||||
result = await tool.execute(**tool_input)
|
||||
# 🔥 根据工具类型设置不同的超时时间
|
||||
tool_timeouts = {
|
||||
"semgrep_scan": 120, # 外部扫描工具需要更长时间
|
||||
"bandit_scan": 90,
|
||||
"gitleaks_scan": 60,
|
||||
"npm_audit": 90,
|
||||
"safety_scan": 60,
|
||||
"kunlun_scan": 180,
|
||||
"osv_scanner": 60,
|
||||
"trufflehog_scan": 90,
|
||||
"sandbox_exec": 60,
|
||||
"php_test": 30,
|
||||
"command_injection_test": 30,
|
||||
"sql_injection_test": 30,
|
||||
"xss_test": 30,
|
||||
}
|
||||
timeout = tool_timeouts.get(tool_name, 30) # 默认30秒
|
||||
|
||||
# 🔥 使用 asyncio.wait_for 添加超时控制,同时支持取消
|
||||
async def execute_with_cancel_check():
|
||||
"""包装工具执行,定期检查取消状态"""
|
||||
# 创建工具执行任务
|
||||
execute_task = asyncio.create_task(tool.execute(**tool_input))
|
||||
|
||||
try:
|
||||
# 使用循环定期检查取消状态
|
||||
while not execute_task.done():
|
||||
if self.is_cancelled:
|
||||
execute_task.cancel()
|
||||
try:
|
||||
await execute_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
raise asyncio.CancelledError("任务已取消")
|
||||
|
||||
# 等待任务完成或超时检查间隔
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
asyncio.shield(execute_task),
|
||||
timeout=0.5 # 每0.5秒检查一次取消状态
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
continue # 继续循环检查
|
||||
|
||||
return await execute_task
|
||||
except asyncio.CancelledError:
|
||||
if not execute_task.done():
|
||||
execute_task.cancel()
|
||||
raise
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
execute_with_cancel_check(),
|
||||
timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
duration_ms = int((time.time() - start) * 1000)
|
||||
await self.emit_tool_result(tool_name, f"超时 ({timeout}s)", duration_ms)
|
||||
return f"⚠️ 工具 '{tool_name}' 执行超时 ({timeout}秒),请尝试其他方法或减小操作范围。"
|
||||
except asyncio.CancelledError:
|
||||
duration_ms = int((time.time() - start) * 1000)
|
||||
await self.emit_tool_result(tool_name, "已取消", duration_ms)
|
||||
return "⚠️ 任务已取消"
|
||||
|
||||
duration_ms = int((time.time() - start) * 1000)
|
||||
# 🔥 修复:确保传递有意义的结果字符串,避免 "None"
|
||||
result_preview = str(result.data)[:200] if result.data is not None else (result.error[:200] if result.error else "")
|
||||
await self.emit_tool_result(tool_name, result_preview, duration_ms)
|
||||
|
||||
# 🔥 工具执行后再次检查取消
|
||||
if self.is_cancelled:
|
||||
return "⚠️ 任务已取消"
|
||||
|
||||
if result.success:
|
||||
output = str(result.data)
|
||||
|
||||
|
|
@ -1063,6 +1129,9 @@ class BaseAgent(ABC):
|
|||
请根据错误信息调整参数或尝试其他方法。"""
|
||||
return error_msg
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.name}] Tool '{tool_name}' execution cancelled")
|
||||
return "⚠️ 任务已取消"
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"Tool execution error: {e}")
|
||||
|
|
|
|||
|
|
@ -652,8 +652,56 @@ Action Input: {{"参数": "值"}}
|
|||
if self.is_cancelled:
|
||||
return f"## {agent_name} Agent 执行取消\n\n任务已被用户取消"
|
||||
|
||||
# 执行子 Agent
|
||||
result = await agent.run(sub_input)
|
||||
# 🔥 执行子 Agent - 支持取消和超时
|
||||
# 设置子 Agent 超时(根据 Agent 类型)
|
||||
agent_timeouts = {
|
||||
"recon": 300, # 5 分钟
|
||||
"analysis": 600, # 10 分钟
|
||||
"verification": 300, # 5 分钟
|
||||
}
|
||||
timeout = agent_timeouts.get(agent_name, 300)
|
||||
|
||||
async def run_with_cancel_check():
|
||||
"""包装子 Agent 执行,定期检查取消状态"""
|
||||
run_task = asyncio.create_task(agent.run(sub_input))
|
||||
try:
|
||||
while not run_task.done():
|
||||
if self.is_cancelled:
|
||||
# 传播取消到子 Agent
|
||||
if hasattr(agent, 'cancel'):
|
||||
agent.cancel()
|
||||
run_task.cancel()
|
||||
try:
|
||||
await run_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
raise asyncio.CancelledError("任务已取消")
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
asyncio.shield(run_task),
|
||||
timeout=1.0 # 每秒检查一次取消状态
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
return await run_task
|
||||
except asyncio.CancelledError:
|
||||
if not run_task.done():
|
||||
run_task.cancel()
|
||||
raise
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
run_with_cancel_check(),
|
||||
timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[{self.name}] Sub-agent {agent_name} timed out after {timeout}s")
|
||||
return f"## {agent_name} Agent 执行超时\n\n子 Agent 执行超过 {timeout} 秒,已强制终止。请尝试更具体的任务或使用其他 Agent。"
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.name}] Sub-agent {agent_name} was cancelled")
|
||||
return f"## {agent_name} Agent 执行取消\n\n任务已被用户取消"
|
||||
|
||||
# 🔥 执行后再次检查取消状态
|
||||
if self.is_cancelled:
|
||||
|
|
|
|||
|
|
@ -623,6 +623,11 @@ class VerificationAgent(BaseAgent):
|
|||
if tool_call_key in self._failed_tool_calls:
|
||||
del self._failed_tool_calls[tool_call_key]
|
||||
|
||||
# 🔥 工具执行后检查取消状态
|
||||
if self.is_cancelled:
|
||||
logger.info(f"[{self.name}] Cancelled after tool execution")
|
||||
break
|
||||
|
||||
step.observation = observation
|
||||
|
||||
# 🔥 发射 LLM 观察事件
|
||||
|
|
|
|||
Loading…
Reference in New Issue