feat(agent): 增强任务执行流程和实时日志反馈

- 在任务执行流程中添加实时事件反馈,包括克隆进度和索引进度
- 实现 RAG v2.0 智能索引功能,支持模型变更检测和增量更新
- 改进文件工具兼容性,支持 path 参数作为 directory 别名
- 扩展前端事件处理逻辑,支持更多事件类型显示
- 修复 tree-sitter 版本兼容性问题
This commit is contained in:
lintsinghua 2025-12-16 16:56:09 +08:00
parent a31372450e
commit e2109647bf
9 changed files with 1101 additions and 223 deletions

View File

@ -244,24 +244,40 @@ async def _execute_agent_task(task_id: str):
sandbox_manager = SandboxManager()
await sandbox_manager.initialize()
logger.info(f"🐳 Global Sandbox Manager initialized (Available: {sandbox_manager.is_available})")
# 🔥 提前创建事件管理器,以便在克隆仓库和索引时发送实时日志
from app.services.agent.event_manager import EventManager, AgentEventEmitter
event_manager = EventManager(db_session_factory=async_session_factory)
event_manager.create_queue(task_id)
event_emitter = AgentEventEmitter(task_id, event_manager)
_running_event_managers[task_id] = event_manager
async with async_session_factory() as db:
orchestrator = None
start_time = time.time()
try:
# 获取任务
task = await db.get(AgentTask, task_id, options=[selectinload(AgentTask.project)])
if not task:
logger.error(f"Task {task_id} not found")
return
# 获取项目
project = task.project
if not project:
logger.error(f"Project not found for task {task_id}")
return
# 🔥 发送任务开始事件 - 使用 phase_start 让前端知道进入准备阶段
await event_emitter.emit_phase_start("preparation", f"🚀 任务开始执行: {project.name}")
# 更新任务阶段为准备中
task.status = AgentTaskStatus.RUNNING
task.started_at = datetime.now(timezone.utc)
task.current_phase = AgentTaskPhase.PLANNING # preparation 对应 PLANNING
await db.commit()
# 获取用户配置(需要在获取项目根目录之前,以便传递 token)
user_config = await _get_user_config(db, task.created_by)
@ -271,30 +287,23 @@ async def _execute_agent_task(task_id: str):
gitlab_token = other_config.get('gitlabToken') or settings.GITLAB_TOKEN
# 获取项目根目录(传递任务指定的分支和认证 token)
# 🔥 传递 event_emitter 以发送克隆进度
project_root = await _get_project_root(
project,
task_id,
task.branch_name,
github_token=github_token,
gitlab_token=gitlab_token,
event_emitter=event_emitter, # 🔥 新增
)
# 更新状态为运行中
task.status = AgentTaskStatus.RUNNING
task.started_at = datetime.now(timezone.utc)
task.current_phase = AgentTaskPhase.PLANNING
await db.commit()
logger.info(f"🚀 Task {task_id} started with Dynamic Agent Tree architecture")
# 创建事件管理器
event_manager = EventManager(db_session_factory=async_session_factory)
event_manager.create_queue(task_id)
event_emitter = AgentEventEmitter(task_id, event_manager)
# 创建 LLM 服务
llm_service = LLMService(user_config=user_config)
# 初始化工具集 - 传递排除模式和目标文件以及预初始化的 sandbox_manager
# 🔥 传递 event_emitter 以发送索引进度
tools = await _initialize_tools(
project_root,
llm_service,
@ -303,27 +312,28 @@ async def _execute_agent_task(task_id: str):
exclude_patterns=task.exclude_patterns,
target_files=task.target_files,
project_id=str(project.id), # 🔥 传递 project_id 用于 RAG
event_emitter=event_emitter, # 🔥 新增
)
# 创建子 Agent
recon_agent = ReconAgent(
llm_service=llm_service,
tools=tools.get("recon", {}),
event_emitter=event_emitter,
)
analysis_agent = AnalysisAgent(
llm_service=llm_service,
tools=tools.get("analysis", {}),
event_emitter=event_emitter,
)
verification_agent = VerificationAgent(
llm_service=llm_service,
tools=tools.get("verification", {}),
event_emitter=event_emitter,
)
# 创建 Orchestrator Agent
orchestrator = OrchestratorAgent(
llm_service=llm_service,
@ -335,7 +345,7 @@ async def _execute_agent_task(task_id: str):
"verification": verification_agent,
},
)
# 注册到全局
_running_orchestrators[task_id] = orchestrator
_running_tasks[task_id] = orchestrator # 兼容旧的取消逻辑
@ -560,6 +570,7 @@ async def _initialize_tools(
exclude_patterns: Optional[List[str]] = None,
target_files: Optional[List[str]] = None,
project_id: Optional[str] = None, # 🔥 用于 RAG collection_name
event_emitter: Optional[Any] = None, # 🔥 新增:用于发送实时日志
) -> Dict[str, Dict[str, Any]]:
"""初始化工具集
@ -571,6 +582,7 @@ async def _initialize_tools(
exclude_patterns: 排除模式列表
target_files: 目标文件列表
project_id: 项目 ID,用于 RAG collection_name
event_emitter: 事件发送器,用于发送实时日志
"""
from app.services.agent.tools import (
FileReadTool, FileSearchTool, ListFilesTool,
@ -588,12 +600,27 @@ async def _initialize_tools(
GetVulnerabilityKnowledgeTool,
)
# 🔥 RAG 相关导入
from app.services.rag import CodeIndexer, CodeRetriever, EmbeddingService
from app.services.rag import CodeIndexer, CodeRetriever, EmbeddingService, IndexUpdateMode
from app.core.config import settings
# 辅助函数:发送事件
async def emit(message: str, level: str = "info"):
    """Forward a progress message to the task's event emitter, if one was given.

    Args:
        message: Text to publish in the real-time task log.
        level: "info", "warning" or "error"; selects which ``emit_*``
            coroutine of ``event_emitter`` is awaited.

    A level outside those three used to match no branch, so the message
    was silently lost. It is now delivered as an info event and the
    unexpected level is recorded in the server log.
    """
    if not event_emitter:
        # No emitter was passed in: nothing can be published, only logged.
        logger.warning(f"[EMIT-TOOLS] No event_emitter, skipping: {message[:60]}...")
        return

    logger.debug(f"[EMIT-TOOLS] Sending {level}: {message[:60]}...")

    # Explicit whitelist so an arbitrary ``level`` string can never select
    # some unrelated ``emit_*`` method on the emitter.
    method_names = {
        "info": "emit_info",
        "warning": "emit_warning",
        "error": "emit_error",
    }
    method_name = method_names.get(level)
    if method_name is None:
        logger.warning(f"[EMIT-TOOLS] Unknown level {level!r}, sending as info")
        method_name = "emit_info"

    await getattr(event_emitter, method_name)(message)
# ============ 🔥 初始化 RAG 系统 ============
retriever = None
try:
await emit(f"🔍 正在初始化 RAG 系统...")
# 从用户配置中获取 embedding 配置
user_llm_config = (user_config or {}).get('llmConfig', {})
user_other_config = (user_config or {}).get('otherConfig', {})
@ -631,6 +658,7 @@ async def _initialize_tools(
)
logger.info(f"RAG 配置: provider={embedding_provider}, model={embedding_model}, base_url={embedding_base_url or '(使用默认)'}")
await emit(f"📊 Embedding 配置: {embedding_provider}/{embedding_model}")
# 创建 Embedding 服务
embedding_service = EmbeddingService(
@ -643,6 +671,47 @@ async def _initialize_tools(
# 创建 collection_name(基于 project_id)
collection_name = f"project_{project_id}" if project_id else "default_project"
# 🔥 v2.0: 创建 CodeIndexer 并进行智能索引
# 智能索引会自动:
# - 检测 embedding 模型变更,如需要则自动重建
# - 对比文件 hash,只更新变化的文件(增量更新)
indexer = CodeIndexer(
collection_name=collection_name,
embedding_service=embedding_service,
persist_directory=settings.VECTOR_DB_PATH,
)
logger.info(f"📝 开始智能索引项目: {project_root}")
await emit(f"📝 正在构建代码向量索引...")
index_progress = None
last_progress_update = 0
async for progress in indexer.smart_index_directory(
directory=project_root,
exclude_patterns=exclude_patterns or [],
update_mode=IndexUpdateMode.SMART,
):
index_progress = progress
# 每处理 10 个文件或有重要变化时发送进度更新
if progress.processed_files - last_progress_update >= 10 or progress.processed_files == progress.total_files:
if progress.total_files > 0:
await emit(
f"📝 索引进度: {progress.processed_files}/{progress.total_files} 文件 "
f"({progress.progress_percentage:.0f}%)"
)
last_progress_update = progress.processed_files
if index_progress:
summary = (
f"✅ 索引完成: 模式={index_progress.update_mode}, "
f"新增={index_progress.added_files}, "
f"更新={index_progress.updated_files}, "
f"删除={index_progress.deleted_files}, "
f"代码块={index_progress.indexed_chunks}"
)
logger.info(summary)
await emit(summary)
# 创建 CodeRetriever(用于搜索)
# 🔥 传递 api_key,用于自动适配 collection 的 embedding 配置
retriever = CodeRetriever(
@ -653,9 +722,13 @@ async def _initialize_tools(
)
logger.info(f"✅ RAG 系统初始化成功: collection={collection_name}")
await emit(f"✅ RAG 系统初始化成功")
except Exception as e:
logger.warning(f"⚠️ RAG 系统初始化失败: {e}")
await emit(f"⚠️ RAG 系统初始化失败: {e}", "warning")
import traceback
logger.debug(f"RAG 初始化异常详情:\n{traceback.format_exc()}")
retriever = None
# 基础工具 - 传递排除模式和目标文件
@ -1942,6 +2015,7 @@ async def _get_project_root(
branch_name: Optional[str] = None,
github_token: Optional[str] = None,
gitlab_token: Optional[str] = None,
event_emitter: Optional[Any] = None, # 🔥 新增:用于发送实时日志
) -> str:
"""
获取项目根目录
@ -1956,6 +2030,7 @@ async def _get_project_root(
branch_name: 分支名称(仓库项目使用,优先于 project.default_branch)
github_token: GitHub 访问令牌(用于私有仓库)
gitlab_token: GitLab 访问令牌(用于私有仓库)
event_emitter: 事件发送器,用于发送实时日志
Returns:
项目根目录路径
@ -1968,6 +2043,16 @@ async def _get_project_root(
import shutil
from urllib.parse import urlparse, urlunparse
# 辅助函数:发送事件
async def emit(message: str, level: str = "info"):
    """Publish a clone/extract progress message through ``event_emitter``.

    Args:
        message: Text to publish in the real-time task log.
        level: "info", "warning" or "error"; selects which ``emit_*``
            coroutine of ``event_emitter`` is awaited.

    Does nothing when no emitter was supplied. A level outside the three
    listed used to match no branch and drop the message; it is now sent
    as an info event instead.
    """
    if not event_emitter:
        return

    # Explicit whitelist so an arbitrary ``level`` string can never select
    # some unrelated ``emit_*`` method on the emitter.
    method_names = {
        "info": "emit_info",
        "warning": "emit_warning",
        "error": "emit_error",
    }
    send = getattr(event_emitter, method_names.get(level, "emit_info"))
    await send(message)
base_path = f"/tmp/deepaudit/{task_id}"
# 确保目录存在且为空
@ -1978,6 +2063,7 @@ async def _get_project_root(
# 根据项目类型处理
if project.source_type == "zip":
# 🔥 ZIP 项目:解压 ZIP 文件
await emit(f"📦 正在解压项目文件...")
from app.services.zip_storage import load_project_zip
zip_path = await load_project_zip(project.id)
@ -1987,11 +2073,14 @@ async def _get_project_root(
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(base_path)
logger.info(f"✅ Extracted ZIP project {project.id} to {base_path}")
await emit(f"✅ ZIP 文件解压完成")
except Exception as e:
logger.error(f"Failed to extract ZIP {zip_path}: {e}")
await emit(f"❌ 解压失败: {e}", "error")
raise RuntimeError(f"无法解压项目文件: {e}")
else:
logger.warning(f"⚠️ ZIP file not found for project {project.id}")
await emit(f"❌ ZIP 文件不存在", "error")
raise RuntimeError(f"项目 ZIP 文件不存在: {project.id}")
elif project.source_type == "repository" and project.repository_url:
@ -1999,6 +2088,8 @@ async def _get_project_root(
repo_url = project.repository_url
repo_type = project.repository_type or "other"
await emit(f"🔄 正在克隆仓库: {repo_url}")
# 检查 git 是否可用(使用 git --version 更可靠)
try:
git_check = subprocess.run(
@ -2008,11 +2099,14 @@ async def _get_project_root(
timeout=10
)
if git_check.returncode != 0:
await emit(f"❌ Git 未安装", "error")
raise RuntimeError("Git 未安装,无法克隆仓库。请在 Docker 容器中安装 git。")
logger.debug(f"Git version: {git_check.stdout.strip()}")
except FileNotFoundError:
await emit(f"❌ Git 未安装", "error")
raise RuntimeError("Git 未安装,无法克隆仓库。请在 Docker 容器中安装 git。")
except subprocess.TimeoutExpired:
await emit(f"❌ Git 检测超时", "error")
raise RuntimeError("Git 检测超时")
# 构建带认证的 URL(用于私有仓库)
@ -2028,6 +2122,7 @@ async def _get_project_root(
parsed.fragment
))
logger.info(f"🔐 Using GitHub token for authentication")
await emit(f"🔐 使用 GitHub Token 认证")
elif repo_type == "gitlab" and gitlab_token:
parsed = urlparse(repo_url)
auth_url = urlunparse((
@ -2039,6 +2134,7 @@ async def _get_project_root(
parsed.fragment
))
logger.info(f"🔐 Using GitLab token for authentication")
await emit(f"🔐 使用 GitLab Token 认证")
# 构建分支尝试顺序
branches_to_try = []
@ -2061,6 +2157,7 @@ async def _get_project_root(
os.makedirs(base_path, exist_ok=True)
logger.info(f"🔄 Trying to clone repository (branch: {branch})...")
await emit(f"🔄 尝试克隆分支: {branch}")
try:
result = subprocess.run(
["git", "clone", "--depth", "1", "--branch", branch, auth_url, base_path],
@ -2071,18 +2168,22 @@ async def _get_project_root(
if result.returncode == 0:
logger.info(f"✅ Cloned repository {repo_url} (branch: {branch}) to {base_path}")
await emit(f"✅ 仓库克隆成功 (分支: {branch})")
clone_success = True
break
else:
last_error = result.stderr
logger.warning(f"Failed to clone branch {branch}: {last_error[:200]}")
await emit(f"⚠️ 分支 {branch} 克隆失败,尝试其他分支...", "warning")
except subprocess.TimeoutExpired:
last_error = f"克隆分支 {branch} 超时"
logger.warning(last_error)
await emit(f"⚠️ 分支 {branch} 克隆超时,尝试其他分支...", "warning")
# 如果所有分支都失败,尝试不指定分支克隆(使用仓库默认分支)
if not clone_success:
logger.info(f"🔄 Trying to clone without specifying branch...")
await emit(f"🔄 尝试使用仓库默认分支克隆...")
if os.path.exists(base_path) and os.listdir(base_path):
shutil.rmtree(base_path)
os.makedirs(base_path, exist_ok=True)
@ -2097,11 +2198,13 @@ async def _get_project_root(
if result.returncode == 0:
logger.info(f"✅ Cloned repository {repo_url} (default branch) to {base_path}")
await emit(f"✅ 仓库克隆成功 (默认分支)")
clone_success = True
else:
last_error = result.stderr
except subprocess.TimeoutExpired:
last_error = "克隆仓库超时"
await emit(f"⚠️ 克隆超时", "warning")
if not clone_success:
# 分析错误原因
@ -2118,12 +2221,15 @@ async def _get_project_root(
error_msg = f"克隆仓库失败: {last_error[:200]}"
logger.error(f"{error_msg}")
await emit(f"{error_msg}", "error")
raise RuntimeError(error_msg)
# 验证目录不为空
if not os.listdir(base_path):
await emit(f"❌ 项目目录为空", "error")
raise RuntimeError(f"项目目录为空,可能是克隆/解压失败: {base_path}")
await emit(f"📁 项目准备完成: {base_path}")
return base_path

View File

@ -169,7 +169,7 @@ TOOL_USAGE_GUIDE = """
#### 第一步快速侦察5%时间)
```
Action: list_files
Action Input: {"path": "."}
Action Input: {"directory": "."}
```
了解项目结构技术栈入口点

View File

@ -502,6 +502,10 @@ class ListFilesTool(AgentTool):
) -> ToolResult:
"""执行文件列表"""
try:
# 🔥 兼容性处理:支持 path 参数作为 directory 的别名
if "path" in kwargs and kwargs["path"]:
directory = kwargs["path"]
target_dir = os.path.normpath(os.path.join(self.project_root, directory))
if not target_dir.startswith(os.path.normpath(self.project_root)):
return ToolResult(

View File

@ -1,11 +1,23 @@
"""
RAG (Retrieval-Augmented Generation) 系统
用于代码索引和语义检索
🔥 v2.0 改进:
- 支持嵌入模型变更检测和自动重建
- 支持增量索引更新(基于文件 hash)
- 支持索引版本控制和状态查询
"""
from .splitter import CodeSplitter, CodeChunk
from .embeddings import EmbeddingService
from .indexer import CodeIndexer
from .indexer import (
CodeIndexer,
IndexingProgress,
IndexingResult,
IndexStatus,
IndexUpdateMode,
INDEX_VERSION,
)
from .retriever import CodeRetriever
__all__ = [
@ -14,5 +26,10 @@ __all__ = [
"EmbeddingService",
"CodeIndexer",
"CodeRetriever",
"IndexingProgress",
"IndexingResult",
"IndexStatus",
"IndexUpdateMode",
"INDEX_VERSION",
]

File diff suppressed because it is too large Load Diff

View File

@ -188,22 +188,34 @@ class TreeSitterParser:
},
}
# tree-sitter-languages 支持的语言列表
SUPPORTED_LANGUAGES = {
"python", "javascript", "typescript", "tsx", "java", "go", "rust",
"c", "cpp", "c_sharp", "php", "ruby", "kotlin", "swift", "bash",
"json", "yaml", "html", "css", "sql", "markdown",
}
def __init__(self):
    """Create the parser wrapper with an empty per-language parser cache."""
    # Parsers keyed by language name; _ensure_initialized adds an entry the
    # first time a supported language is requested.
    self._parsers: Dict[str, Any] = {}
    # Starts as False; nothing in the visible part of the class reads or
    # updates it.
    self._initialized = False
def _ensure_initialized(self, language: str) -> bool:
"""确保语言解析器已初始化"""
if language in self._parsers:
return True
# 检查语言是否受支持
if language not in self.SUPPORTED_LANGUAGES:
# 不是 tree-sitter 支持的语言,静默跳过
return False
try:
from tree_sitter_languages import get_parser, get_language
from tree_sitter_languages import get_parser
parser = get_parser(language)
self._parsers[language] = parser
return True
except ImportError:
logger.warning("tree-sitter-languages not installed, falling back to regex parsing")
return False

View File

@ -60,7 +60,8 @@ dependencies = [
"chromadb>=0.4.22",
# ============ Code Parsing ============
"tree-sitter>=0.21.0",
# tree-sitter-languages 1.10.x 与 tree-sitter 0.22+ 不兼容
"tree-sitter>=0.21.0,<0.22.0",
"tree-sitter-languages>=1.10.0",
"pygments>=2.17.0",

View File

@ -47,7 +47,8 @@ langgraph>=0.0.40
chromadb>=0.4.22
# ============ Code Parsing ============
tree-sitter>=0.21.0
# tree-sitter-languages 1.10.x 与 tree-sitter 0.22+ 不兼容
tree-sitter>=0.21.0,<0.22.0
tree-sitter-languages>=1.10.0
pygments>=2.17.0

View File

@ -392,19 +392,33 @@ function AgentAuditPageContent() {
setCurrentAgentName(event.metadata.agent_name);
}
const dispatchEvents = ['dispatch', 'dispatch_complete', 'node_start', 'phase_start'];
const dispatchEvents = ['dispatch', 'dispatch_complete', 'node_start', 'phase_start', 'phase_complete'];
if (dispatchEvents.includes(event.type)) {
if (event.type === 'dispatch' || event.type === 'dispatch_complete') {
dispatch({
type: 'ADD_LOG',
payload: {
type: 'dispatch',
title: event.message || `Agent dispatch: ${event.metadata?.agent || 'unknown'}`,
agentName: getCurrentAgentName() || undefined,
}
});
}
// 所有 dispatch 类型事件都添加到日志
dispatch({
type: 'ADD_LOG',
payload: {
type: 'dispatch',
title: event.message || `Agent dispatch: ${event.metadata?.agent || 'unknown'}`,
agentName: getCurrentAgentName() || undefined,
}
});
debouncedLoadAgentTree();
return;
}
// 🔥 处理 info、warning、error 类型事件(克隆进度、索引进度等)
const infoEvents = ['info', 'warning', 'error', 'progress'];
if (infoEvents.includes(event.type)) {
dispatch({
type: 'ADD_LOG',
payload: {
type: event.type === 'error' ? 'error' : 'info',
title: event.message || event.type,
agentName: getCurrentAgentName() || undefined,
}
});
return;
}
},
onThinkingStart: () => {