feat(RAG): 添加索引任务取消检查并支持目标文件过滤
在索引过程中添加取消检查功能,允许在嵌入批处理时取消任务 支持通过target_files参数限制索引范围
This commit is contained in:
parent
e0689245de
commit
96560e6474
|
|
@ -726,14 +726,20 @@ async def _initialize_tools(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to emit embedding progress: {e}")
|
logger.warning(f"Failed to emit embedding progress: {e}")
|
||||||
|
|
||||||
|
# 🔥 创建取消检查函数,用于在嵌入批处理中检查取消状态
|
||||||
|
def check_cancelled() -> bool:
|
||||||
|
return task_id is not None and is_task_cancelled(task_id)
|
||||||
|
|
||||||
async for progress in indexer.smart_index_directory(
|
async for progress in indexer.smart_index_directory(
|
||||||
directory=project_root,
|
directory=project_root,
|
||||||
exclude_patterns=exclude_patterns or [],
|
exclude_patterns=exclude_patterns or [],
|
||||||
|
include_patterns=target_files, # 🔥 传递 target_files 限制索引范围
|
||||||
update_mode=IndexUpdateMode.SMART,
|
update_mode=IndexUpdateMode.SMART,
|
||||||
embedding_progress_callback=on_embedding_progress,
|
embedding_progress_callback=on_embedding_progress,
|
||||||
|
cancel_check=check_cancelled, # 🔥 传递取消检查函数
|
||||||
):
|
):
|
||||||
# 🔥 在索引过程中检查取消状态
|
# 🔥 在索引过程中检查取消状态
|
||||||
if task_id and is_task_cancelled(task_id):
|
if check_cancelled():
|
||||||
logger.info(f"[Cancel] RAG indexing cancelled for task {task_id}")
|
logger.info(f"[Cancel] RAG indexing cancelled for task {task_id}")
|
||||||
raise asyncio.CancelledError("任务已取消")
|
raise asyncio.CancelledError("任务已取消")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -962,7 +962,7 @@ class CodeIndexer:
|
||||||
progress.status_message = f"🔢 生成 {len(all_chunks)} 个代码块的嵌入向量..."
|
progress.status_message = f"🔢 生成 {len(all_chunks)} 个代码块的嵌入向量..."
|
||||||
yield progress
|
yield progress
|
||||||
|
|
||||||
await self._index_chunks(all_chunks, progress, use_upsert=False, embedding_progress_callback=embedding_progress_callback)
|
await self._index_chunks(all_chunks, progress, use_upsert=False, embedding_progress_callback=embedding_progress_callback, cancel_check=cancel_check)
|
||||||
|
|
||||||
# 更新 collection 元数据
|
# 更新 collection 元数据
|
||||||
project_hash = hashlib.md5(json.dumps(sorted(file_hashes.items())).encode()).hexdigest()
|
project_hash = hashlib.md5(json.dumps(sorted(file_hashes.items())).encode()).hexdigest()
|
||||||
|
|
@ -983,6 +983,7 @@ class CodeIndexer:
|
||||||
progress: IndexingProgress,
|
progress: IndexingProgress,
|
||||||
progress_callback: Optional[Callable[[IndexingProgress], None]],
|
progress_callback: Optional[Callable[[IndexingProgress], None]],
|
||||||
embedding_progress_callback: Optional[Callable[[int, int], None]] = None,
|
embedding_progress_callback: Optional[Callable[[int, int], None]] = None,
|
||||||
|
cancel_check: Optional[Callable[[], bool]] = None,
|
||||||
) -> AsyncGenerator[IndexingProgress, None]:
|
) -> AsyncGenerator[IndexingProgress, None]:
|
||||||
"""增量索引"""
|
"""增量索引"""
|
||||||
logger.info("📝 开始增量索引...")
|
logger.info("📝 开始增量索引...")
|
||||||
|
|
@ -1099,7 +1100,7 @@ class CodeIndexer:
|
||||||
progress.status_message = f"🔢 生成 {len(all_chunks)} 个代码块的嵌入向量..."
|
progress.status_message = f"🔢 生成 {len(all_chunks)} 个代码块的嵌入向量..."
|
||||||
yield progress
|
yield progress
|
||||||
|
|
||||||
await self._index_chunks(all_chunks, progress, use_upsert=True, embedding_progress_callback=embedding_progress_callback)
|
await self._index_chunks(all_chunks, progress, use_upsert=True, embedding_progress_callback=embedding_progress_callback, cancel_check=cancel_check)
|
||||||
|
|
||||||
# 更新 collection 元数据
|
# 更新 collection 元数据
|
||||||
# 移除已删除文件的 hash
|
# 移除已删除文件的 hash
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue