feat(RAG): 添加索引任务取消检查并支持目标文件过滤
在索引过程中添加取消检查功能，允许在嵌入批处理时取消任务；支持通过 target_files 参数限制索引范围。
This commit is contained in:
parent
e0689245de
commit
96560e6474
|
|
@ -726,14 +726,20 @@ async def _initialize_tools(
|
|||
except Exception as e:
|
||||
logger.warning(f"Failed to emit embedding progress: {e}")
|
||||
|
||||
# 🔥 创建取消检查函数,用于在嵌入批处理中检查取消状态
|
||||
def check_cancelled() -> bool:
|
||||
return task_id is not None and is_task_cancelled(task_id)
|
||||
|
||||
async for progress in indexer.smart_index_directory(
|
||||
directory=project_root,
|
||||
exclude_patterns=exclude_patterns or [],
|
||||
include_patterns=target_files, # 🔥 传递 target_files 限制索引范围
|
||||
update_mode=IndexUpdateMode.SMART,
|
||||
embedding_progress_callback=on_embedding_progress,
|
||||
cancel_check=check_cancelled, # 🔥 传递取消检查函数
|
||||
):
|
||||
# 🔥 在索引过程中检查取消状态
|
||||
if task_id and is_task_cancelled(task_id):
|
||||
if check_cancelled():
|
||||
logger.info(f"[Cancel] RAG indexing cancelled for task {task_id}")
|
||||
raise asyncio.CancelledError("任务已取消")
|
||||
|
||||
|
|
|
|||
|
|
@ -962,7 +962,7 @@ class CodeIndexer:
|
|||
progress.status_message = f"🔢 生成 {len(all_chunks)} 个代码块的嵌入向量..."
|
||||
yield progress
|
||||
|
||||
await self._index_chunks(all_chunks, progress, use_upsert=False, embedding_progress_callback=embedding_progress_callback)
|
||||
await self._index_chunks(all_chunks, progress, use_upsert=False, embedding_progress_callback=embedding_progress_callback, cancel_check=cancel_check)
|
||||
|
||||
# 更新 collection 元数据
|
||||
project_hash = hashlib.md5(json.dumps(sorted(file_hashes.items())).encode()).hexdigest()
|
||||
|
|
@ -983,6 +983,7 @@ class CodeIndexer:
|
|||
progress: IndexingProgress,
|
||||
progress_callback: Optional[Callable[[IndexingProgress], None]],
|
||||
embedding_progress_callback: Optional[Callable[[int, int], None]] = None,
|
||||
cancel_check: Optional[Callable[[], bool]] = None,
|
||||
) -> AsyncGenerator[IndexingProgress, None]:
|
||||
"""增量索引"""
|
||||
logger.info("📝 开始增量索引...")
|
||||
|
|
@ -1099,7 +1100,7 @@ class CodeIndexer:
|
|||
progress.status_message = f"🔢 生成 {len(all_chunks)} 个代码块的嵌入向量..."
|
||||
yield progress
|
||||
|
||||
await self._index_chunks(all_chunks, progress, use_upsert=True, embedding_progress_callback=embedding_progress_callback)
|
||||
await self._index_chunks(all_chunks, progress, use_upsert=True, embedding_progress_callback=embedding_progress_callback, cancel_check=cancel_check)
|
||||
|
||||
# 更新 collection 元数据
|
||||
# 移除已删除文件的 hash
|
||||
|
|
|
|||
Loading…
Reference in New Issue