From 96560e6474a164f7a46076bb039b409d3a0ab28f Mon Sep 17 00:00:00 2001 From: lintsinghua Date: Tue, 16 Dec 2025 18:46:34 +0800 Subject: [PATCH] =?UTF-8?q?feat(RAG):=20=E6=B7=BB=E5=8A=A0=E7=B4=A2?= =?UTF-8?q?=E5=BC=95=E4=BB=BB=E5=8A=A1=E5=8F=96=E6=B6=88=E6=A3=80=E6=9F=A5?= =?UTF-8?q?=E5=B9=B6=E6=94=AF=E6=8C=81=E7=9B=AE=E6=A0=87=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E8=BF=87=E6=BB=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在索引过程中添加取消检查功能,允许在嵌入批处理时取消任务 支持通过target_files参数限制索引范围 --- backend/app/api/v1/endpoints/agent_tasks.py | 8 +++++++- backend/app/services/rag/indexer.py | 5 +++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/backend/app/api/v1/endpoints/agent_tasks.py b/backend/app/api/v1/endpoints/agent_tasks.py index d9fea88..c4fd398 100644 --- a/backend/app/api/v1/endpoints/agent_tasks.py +++ b/backend/app/api/v1/endpoints/agent_tasks.py @@ -726,14 +726,20 @@ async def _initialize_tools( except Exception as e: logger.warning(f"Failed to emit embedding progress: {e}") + # 🔥 创建取消检查函数,用于在嵌入批处理中检查取消状态 + def check_cancelled() -> bool: + return task_id is not None and is_task_cancelled(task_id) + async for progress in indexer.smart_index_directory( directory=project_root, exclude_patterns=exclude_patterns or [], + include_patterns=target_files, # 🔥 传递 target_files 限制索引范围 update_mode=IndexUpdateMode.SMART, embedding_progress_callback=on_embedding_progress, + cancel_check=check_cancelled, # 🔥 传递取消检查函数 ): # 🔥 在索引过程中检查取消状态 - if task_id and is_task_cancelled(task_id): + if check_cancelled(): logger.info(f"[Cancel] RAG indexing cancelled for task {task_id}") raise asyncio.CancelledError("任务已取消") diff --git a/backend/app/services/rag/indexer.py b/backend/app/services/rag/indexer.py index 405f1cf..2f22c60 100644 --- a/backend/app/services/rag/indexer.py +++ b/backend/app/services/rag/indexer.py @@ -962,7 +962,7 @@ class CodeIndexer: progress.status_message = f"🔢 生成 {len(all_chunks)} 个代码块的嵌入向量..." yield progress - await self._index_chunks(all_chunks, progress, use_upsert=False, embedding_progress_callback=embedding_progress_callback) + await self._index_chunks(all_chunks, progress, use_upsert=False, embedding_progress_callback=embedding_progress_callback, cancel_check=cancel_check) # 更新 collection 元数据 project_hash = hashlib.md5(json.dumps(sorted(file_hashes.items())).encode()).hexdigest() @@ -983,6 +983,7 @@ class CodeIndexer: progress: IndexingProgress, progress_callback: Optional[Callable[[IndexingProgress], None]], embedding_progress_callback: Optional[Callable[[int, int], None]] = None, + cancel_check: Optional[Callable[[], bool]] = None, ) -> AsyncGenerator[IndexingProgress, None]: """增量索引""" logger.info("📝 开始增量索引...") @@ -1099,7 +1100,7 @@ class CodeIndexer: progress.status_message = f"🔢 生成 {len(all_chunks)} 个代码块的嵌入向量..." yield progress - await self._index_chunks(all_chunks, progress, use_upsert=True, embedding_progress_callback=embedding_progress_callback) + await self._index_chunks(all_chunks, progress, use_upsert=True, embedding_progress_callback=embedding_progress_callback, cancel_check=cancel_check) # 更新 collection 元数据 # 移除已删除文件的 hash