diff --git a/backend/app/api/v1/endpoints/agent_tasks.py b/backend/app/api/v1/endpoints/agent_tasks.py index d9fea88..c4fd398 100644 --- a/backend/app/api/v1/endpoints/agent_tasks.py +++ b/backend/app/api/v1/endpoints/agent_tasks.py @@ -726,14 +726,20 @@ async def _initialize_tools( except Exception as e: logger.warning(f"Failed to emit embedding progress: {e}") + # 🔥 创建取消检查函数,用于在嵌入批处理中检查取消状态 + def check_cancelled() -> bool: + return task_id is not None and is_task_cancelled(task_id) + async for progress in indexer.smart_index_directory( directory=project_root, exclude_patterns=exclude_patterns or [], + include_patterns=target_files, # 🔥 传递 target_files 限制索引范围 update_mode=IndexUpdateMode.SMART, embedding_progress_callback=on_embedding_progress, + cancel_check=check_cancelled, # 🔥 传递取消检查函数 ): # 🔥 在索引过程中检查取消状态 - if task_id and is_task_cancelled(task_id): + if check_cancelled(): logger.info(f"[Cancel] RAG indexing cancelled for task {task_id}") raise asyncio.CancelledError("任务已取消") diff --git a/backend/app/services/rag/indexer.py b/backend/app/services/rag/indexer.py index 405f1cf..2f22c60 100644 --- a/backend/app/services/rag/indexer.py +++ b/backend/app/services/rag/indexer.py @@ -962,7 +962,7 @@ class CodeIndexer: progress.status_message = f"🔢 生成 {len(all_chunks)} 个代码块的嵌入向量..." yield progress - await self._index_chunks(all_chunks, progress, use_upsert=False, embedding_progress_callback=embedding_progress_callback) + await self._index_chunks(all_chunks, progress, use_upsert=False, embedding_progress_callback=embedding_progress_callback, cancel_check=cancel_check) # 更新 collection 元数据 project_hash = hashlib.md5(json.dumps(sorted(file_hashes.items())).encode()).hexdigest() @@ -983,6 +983,7 @@ class CodeIndexer: progress: IndexingProgress, progress_callback: Optional[Callable[[IndexingProgress], None]], embedding_progress_callback: Optional[Callable[[int, int], None]] = None, + cancel_check: Optional[Callable[[], bool]] = None, ) -> AsyncGenerator[IndexingProgress, None]: """增量索引""" logger.info("📝 开始增量索引...") @@ -1099,7 +1100,7 @@ class CodeIndexer: progress.status_message = f"🔢 生成 {len(all_chunks)} 个代码块的嵌入向量..." yield progress - await self._index_chunks(all_chunks, progress, use_upsert=True, embedding_progress_callback=embedding_progress_callback) + await self._index_chunks(all_chunks, progress, use_upsert=True, embedding_progress_callback=embedding_progress_callback, cancel_check=cancel_check) # 更新 collection 元数据 # 移除已删除文件的 hash