feat: Enhance embedding service with concurrency control, dynamic batch sizing, and retry logic; add concurrent per-file processing and per-file progress reporting to full and incremental indexing.

vinland100 2026-01-06 14:50:30 +08:00
parent b8e5c96541
commit 969d899476
4 changed files with 173 additions and 162 deletions

View File

@@ -788,21 +788,9 @@ async def _initialize_tools(
         last_embedding_progress = [0]  # list, so the closure can mutate it
         embedding_total = [0]  # running total
 
-        # 🔥 Embedding progress callback (synchronous, but schedules async tasks)
+        # Stop emitting a separate embedding-progress log for every indexed file, to avoid log spam
         def on_embedding_progress(processed: int, total: int):
-            embedding_total[0] = total
-            # Update every 50 items or on completion
-            if processed - last_embedding_progress[0] >= 50 or processed == total:
-                last_embedding_progress[0] = processed
-                percentage = (processed / total * 100) if total > 0 else 0
-                msg = f"🔢 Embedding progress: {processed}/{total} ({percentage:.0f}%)"
-                logger.info(msg)
-                # Schedule the async emit via asyncio.create_task
-                try:
-                    loop = asyncio.get_running_loop()
-                    loop.create_task(emit(msg))
-                except Exception as e:
-                    logger.warning(f"Failed to emit embedding progress: {e}")
+            pass
 
         # 🔥 Cancellation check used inside embedding batch processing
         def check_cancelled() -> bool:
@@ -822,8 +810,8 @@ async def _initialize_tools(
                 raise asyncio.CancelledError("Task cancelled")
             index_progress = progress
 
-            # Send a progress update every 10 processed files or on significant changes
-            if progress.processed_files - last_progress_update >= 10 or progress.processed_files == progress.total_files:
+            # 🔥 Update progress after every single file (per user request)
+            if progress.processed_files - last_progress_update >= 1 or progress.processed_files == progress.total_files:
                 if progress.total_files > 0:
                     await emit(
                         f"📝 Indexing progress: {progress.processed_files}/{progress.total_files} files "

View File

@@ -636,7 +636,14 @@ class EmbeddingService:
             base_url=self.base_url,
         )
 
-        logger.info(f"Embedding service initialized with {self.provider}/{self.model}")
+        # 🔥 Limit the number of concurrent requests (RPS cap)
+        self._semaphore = asyncio.Semaphore(30)
+
+        # 🔥 Default batch size (10 for remote models, per user request)
+        is_remote = self.provider.lower() in ["openai", "qwen", "azure", "cohere", "jina", "huggingface"]
+        self.batch_size = 10 if is_remote else 100
+
+        logger.info(f"Embedding service initialized with {self.provider}/{self.model} (Batch size: {self.batch_size})")
 
     def _create_provider(
         self,
@@ -755,55 +762,75 @@ class EmbeddingService:
         # Batch-process texts that were not in the cache
         if uncached_texts:
-            total_batches = (len(uncached_texts) + batch_size - 1) // batch_size
-            processed_batches = 0
-
-            for i in range(0, len(uncached_texts), batch_size):
-                # 🔥 Check whether cancellation was requested
-                if cancel_check and cancel_check():
-                    logger.info(f"[Embedding] Cancelled at batch {processed_batches + 1}/{total_batches}")
-                    raise asyncio.CancelledError("Embedding operation cancelled")
-
-                batch = uncached_texts[i:i + batch_size]
-                batch_indices = uncached_indices[i:i + batch_size]
-
-                try:
-                    results = await self._provider.embed_texts(batch)
-                    for idx, result in zip(batch_indices, results):
-                        embeddings[idx] = result.embedding
-
-                        # Store in cache
-                        if self.cache_enabled:
-                            cache_key = self._cache_key(texts[idx])
-                            self._cache[cache_key] = result.embedding
-                except asyncio.CancelledError:
-                    # 🔥 Re-raise the cancellation
-                    raise
-                except Exception as e:
-                    logger.error(f"Batch embedding error: {e}")
-                    # Use zero vectors for failures
-                    for idx in batch_indices:
-                        if embeddings[idx] is None:
-                            embeddings[idx] = [0.0] * self.dimension
-
-                processed_batches += 1
-
-                # 🔥 Invoke the progress callback
-                if progress_callback:
-                    processed_count = min(i + batch_size, len(uncached_texts))
-                    try:
-                        progress_callback(processed_count, len(uncached_texts))
-                    except Exception as e:
-                        logger.warning(f"Progress callback error: {e}")
-
-                # Small delay to avoid rate limiting
-                if self.provider not in ["ollama"]:
-                    await asyncio.sleep(0.1)  # no delay for local providers
+            tasks = []
+            current_batch_size = batch_size or self.batch_size
+
+            for i in range(0, len(uncached_texts), current_batch_size):
+                batch = uncached_texts[i:i + current_batch_size]
+                batch_indices = uncached_indices[i:i + current_batch_size]
+                tasks.append(self._process_batch_with_retry(batch, batch_indices, cancel_check))
+
+            # 🔥 Run all batch tasks concurrently
+            all_batch_results = await asyncio.gather(*tasks, return_exceptions=True)
+
+            for i, result_list in enumerate(all_batch_results):
+                batch_indices = uncached_indices[i * current_batch_size : (i + 1) * current_batch_size]
+
+                if isinstance(result_list, Exception):
+                    logger.error(f"Batch processing failed: {result_list}")
+                    # Use zero vectors for the failed batch
+                    for idx in batch_indices:
+                        if embeddings[idx] is None:
+                            embeddings[idx] = [0.0] * self.dimension
+                    continue
+
+                for idx, result in zip(batch_indices, result_list):
+                    embeddings[idx] = result.embedding
+                    # Store in cache
+                    if self.cache_enabled:
+                        cache_key = self._cache_key(texts[idx])
+                        self._cache[cache_key] = result.embedding
+
+                # 🔥 Invoke the progress callback
+                if progress_callback:
+                    processed_count = min((i + 1) * current_batch_size, len(uncached_texts))
+                    try:
+                        progress_callback(processed_count, len(uncached_texts))
+                    except Exception as e:
+                        logger.warning(f"Progress callback error: {e}")
 
         # Make sure there are no None entries
         return [e if e is not None else [0.0] * self.dimension for e in embeddings]
 
+    async def _process_batch_with_retry(
+        self,
+        batch: List[str],
+        indices: List[int],
+        cancel_check: Optional[callable] = None,
+        max_retries: int = 3
+    ) -> List[EmbeddingResult]:
+        """Process a single batch with retry logic."""
+        for attempt in range(max_retries):
+            if cancel_check and cancel_check():
+                raise asyncio.CancelledError("Embedding operation cancelled")
+
+            async with self._semaphore:
+                try:
+                    return await self._provider.embed_texts(batch)
+                except httpx.HTTPStatusError as e:
+                    if e.response.status_code == 429 and attempt < max_retries - 1:
+                        # 429 rate limited: exponential backoff
+                        wait_time = (2 ** attempt) + 1
+                        logger.warning(f"Rate limited (429), retrying in {wait_time}s... (Attempt {attempt+1}/{max_retries})")
+                        await asyncio.sleep(wait_time)
+                        continue
+                    raise
+                except Exception as e:
+                    if attempt < max_retries - 1:
+                        await asyncio.sleep(1)
+                        continue
+                    raise
+        return []
+
     def clear_cache(self):
         """Clear the cache."""

View File

@@ -949,66 +949,59 @@ class CodeIndexer:
         logger.info(f"📁 Found {len(files)} files to index")
         yield progress
 
-        all_chunks: List[CodeChunk] = []
+        semaphore = asyncio.Semaphore(20)  # limit concurrent file processing
         file_hashes: Dict[str, str] = {}
 
-        # Process files, splitting them into chunks
-        for file_path in files:
-            progress.current_file = file_path
-
-            try:
-                relative_path = os.path.relpath(file_path, directory)
-
-                # Read the file asynchronously to avoid blocking the event loop
-                content = await asyncio.to_thread(
-                    self._read_file_sync, file_path
-                )
-
-                if not content.strip():
-                    progress.processed_files += 1
-                    progress.skipped_files += 1
-                    continue
-
-                # Compute the file hash
-                file_hash = hashlib.md5(content.encode()).hexdigest()
-                file_hashes[relative_path] = file_hash
-
-                # Limit file size
-                if len(content) > 500000:
-                    content = content[:500000]
-
-                # Split asynchronously so Tree-sitter parsing doesn't block the event loop
-                chunks = await self.splitter.split_file_async(content, relative_path)
-
-                # Attach file_hash to every chunk
-                for chunk in chunks:
-                    chunk.metadata["file_hash"] = file_hash
-
-                all_chunks.extend(chunks)
-                progress.processed_files += 1
-                progress.added_files += 1
-                progress.total_chunks = len(all_chunks)
-
-                if progress_callback:
-                    progress_callback(progress)
-                yield progress
-
-            except Exception as e:
-                logger.warning(f"Failed to process file {file_path}: {e}")
-                progress.errors.append(f"{file_path}: {str(e)}")
-                progress.processed_files += 1
-
-        logger.info(f"📝 Created {len(all_chunks)} code chunks")
-
-        # Embed and index in bulk
-        if all_chunks:
-            # 🔥 Report embedding-generation status
-            progress.status_message = f"🔢 Generating embeddings for {len(all_chunks)} code chunks..."
-            yield progress
-
-            await self._index_chunks(all_chunks, progress, use_upsert=False, embedding_progress_callback=embedding_progress_callback, cancel_check=cancel_check)
+        async def process_file(file_path: str):
+            async with semaphore:
+                try:
+                    relative_path = os.path.relpath(file_path, directory)
+                    progress.current_file = relative_path
+
+                    # Read the file asynchronously
+                    content = await asyncio.to_thread(self._read_file_sync, file_path)
+
+                    if not content.strip():
+                        progress.processed_files += 1
+                        progress.skipped_files += 1
+                        return
+
+                    # Compute the file hash
+                    file_hash = hashlib.md5(content.encode()).hexdigest()
+                    file_hashes[relative_path] = file_hash
+
+                    # Split into chunks asynchronously
+                    if len(content) > 500000:
+                        content = content[:500000]
+                    chunks = await self.splitter.split_file_async(content, relative_path)
+                    for chunk in chunks:
+                        chunk.metadata["file_hash"] = file_hash
+
+                    # Index this file's chunks immediately (enables per-file progress updates)
+                    if chunks:
+                        await self._index_chunks(chunks, progress, use_upsert=False, embedding_progress_callback=embedding_progress_callback, cancel_check=cancel_check)
+                        progress.total_chunks += len(chunks)
+                        progress.indexed_chunks += len(chunks)
+
+                    progress.processed_files += 1
+                    progress.added_files += 1
+
+                    if progress_callback:
+                        progress_callback(progress)
+
+                except Exception as e:
+                    logger.warning(f"Failed to process file {file_path}: {e}")
+                    progress.errors.append(f"{file_path}: {str(e)}")
+                    progress.processed_files += 1
+
+        # Run the full index
+        tasks = [process_file(f) for f in files]
+
+        # 🔥 Use as_completed for true per-file progress updates
+        for task in asyncio.as_completed(tasks):
+            await task
+            yield progress
 
         # Update collection metadata
         project_hash = hashlib.md5(json.dumps(sorted(file_hashes.items())).encode()).hexdigest()
         await self.vector_store.update_collection_metadata({
@@ -1016,8 +1009,7 @@ class CodeIndexer:
             "file_count": len(file_hashes),
         })
 
-        progress.indexed_chunks = len(all_chunks)
-        logger.info(f"✅ Full index complete: {progress.added_files} files, {len(all_chunks)} code chunks")
+        logger.info(f"✅ Full index complete: {progress.added_files} files, {progress.indexed_chunks} code chunks")
         yield progress
 
     async def _incremental_index(
@@ -1091,72 +1083,74 @@ class CodeIndexer:
                 progress_callback(progress)
             yield progress
 
-        # Process added and updated files
-        files_to_process = files_to_add | files_to_update
-        all_chunks: List[CodeChunk] = []
+        semaphore = asyncio.Semaphore(20)
         file_hashes: Dict[str, str] = dict(indexed_file_hashes)
 
-        for relative_path in files_to_process:
-            file_path = current_file_map[relative_path]
-            progress.current_file = relative_path
-            is_update = relative_path in files_to_update
-
-            try:
-                # Read the file asynchronously to avoid blocking the event loop
-                content = await asyncio.to_thread(
-                    self._read_file_sync, file_path
-                )
-
-                if not content.strip():
-                    progress.processed_files += 1
-                    progress.skipped_files += 1
-                    continue
-
-                # If this is an update, delete the old entries first
-                if is_update:
-                    await self.vector_store.delete_by_file_path(relative_path)
-
-                # Compute the file hash
-                file_hash = hashlib.md5(content.encode()).hexdigest()
-                file_hashes[relative_path] = file_hash
-
-                # Limit file size
-                if len(content) > 500000:
-                    content = content[:500000]
-
-                # Split asynchronously so Tree-sitter parsing doesn't block the event loop
-                chunks = await self.splitter.split_file_async(content, relative_path)
-
-                # Attach file_hash to every chunk
-                for chunk in chunks:
-                    chunk.metadata["file_hash"] = file_hash
-
-                all_chunks.extend(chunks)
-                progress.processed_files += 1
-                if is_update:
-                    progress.updated_files += 1
-                else:
-                    progress.added_files += 1
-                progress.total_chunks += len(chunks)
-
-                if progress_callback:
-                    progress_callback(progress)
-                yield progress
-
-            except Exception as e:
-                logger.warning(f"Failed to process file {file_path}: {e}")
-                progress.errors.append(f"{file_path}: {str(e)}")
-                progress.processed_files += 1
-
-        # Embed and index the new chunks in bulk
-        if all_chunks:
-            # 🔥 Report embedding-generation status
-            progress.status_message = f"🔢 Generating embeddings for {len(all_chunks)} code chunks..."
-            yield progress
-
-            await self._index_chunks(all_chunks, progress, use_upsert=True, embedding_progress_callback=embedding_progress_callback, cancel_check=cancel_check)
+        async def process_incremental_file(relative_path: str):
+            async with semaphore:
+                file_path = current_file_map[relative_path]
+                progress.current_file = relative_path
+                is_update = relative_path in files_to_update
+
+                try:
+                    # Read the file asynchronously
+                    content = await asyncio.to_thread(self._read_file_sync, file_path)
+
+                    if not content.strip():
+                        progress.processed_files += 1
+                        progress.skipped_files += 1
+                        return
+
+                    # If this is an update, delete the old entries first
+                    if is_update:
+                        await self.vector_store.delete_by_file_path(relative_path)
+
+                    # Compute the file hash
+                    file_hash = hashlib.md5(content.encode()).hexdigest()
+                    file_hashes[relative_path] = file_hash
+
+                    # Limit file size
+                    if len(content) > 500000:
+                        content = content[:500000]
+
+                    # Split into chunks asynchronously
+                    chunks = await self.splitter.split_file_async(content, relative_path)
+
+                    # Attach file_hash to every chunk
+                    for chunk in chunks:
+                        chunk.metadata["file_hash"] = file_hash
+
+                    # Index this file immediately
+                    if chunks:
+                        await self._index_chunks(chunks, progress, use_upsert=True, embedding_progress_callback=embedding_progress_callback, cancel_check=cancel_check)
+                        progress.total_chunks += len(chunks)
+                        progress.indexed_chunks += len(chunks)
+
+                    progress.processed_files += 1
+                    if is_update:
+                        progress.updated_files += 1
+                    else:
+                        progress.added_files += 1
+
+                    if progress_callback:
+                        progress_callback(progress)
+
+                except Exception as e:
+                    logger.warning(f"Failed to process file {file_path}: {e}")
+                    progress.errors.append(f"{file_path}: {str(e)}")
+                    progress.processed_files += 1
+
+        # Process added and updated files
+        files_to_process = files_to_add | files_to_update
+
+        # Run the incremental index
+        tasks = [process_incremental_file(p) for p in files_to_process]
+
+        # 🔥 Use as_completed for true per-file progress updates
+        for task in asyncio.as_completed(tasks):
+            await task
+            yield progress
 
         # Update collection metadata
         # Remove hashes of deleted files
         for relative_path in files_to_delete:
@@ -1168,7 +1162,6 @@ class CodeIndexer:
             "file_count": len(file_hashes),
         })
 
-        progress.indexed_chunks = len(all_chunks)
         logger.info(
             f"✅ Incremental index complete: {progress.added_files} added, "
             f"{progress.updated_files} updated, {progress.deleted_files} deleted"

View File

@@ -37,7 +37,7 @@ services:
     extra_hosts:
       - "host.docker.internal:host-gateway"
     volumes:
-      # - ./backend/app:/app/app:ro  # mount the code directory so changes take effect automatically
+      - ./backend/app:/app/app:ro  # mount the code directory so changes take effect automatically
       - backend_uploads:/app/uploads
       - chroma_data:/app/data/vector_db
       - /var/run/docker.sock:/var/run/docker.sock  # required for sandbox execution
@@ -52,7 +52,10 @@ services:
       - SANDBOX_ENABLED=true
      - SANDBOX_IMAGE=deepaudit/sandbox:latest  # use the locally built sandbox image
       # Embedding service endpoint
-      - EMBEDDING_BASE_URL=http://host.docker.internal:8003/v1
+      - EMBEDDING_PROVIDER=openai
+      - EMBEDDING_MODEL=text-embedding-v4
+      - EMBEDDING_DIMENSION=1024
+      - EMBEDDING_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
       # Gitea configuration
       - GITEA_HOST_URL=http://sl.vrgon.com:3000
       - GITEA_BOT_TOKEN=379a049b8d78965fdff474fc8676bca7e9c70248
@@ -62,7 +65,7 @@ services:
       redis:
         condition: service_healthy
     # Development mode: enable --reload hot reloading
-    command: sh -c ".venv/bin/alembic upgrade head && .venv/bin/uvicorn app.main:app --host 0.0.0.0 --port 8000"
+    command: sh -c ".venv/bin/alembic upgrade head && .venv/bin/uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload"
     networks:
       - deepaudit-network
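
The compose change switches the backend to an OpenAI-compatible remote embedding endpoint via environment variables. A hedged sketch of how such EMBEDDING_* variables might be read on the Python side (only the variable names come from the compose file above; the backend's real settings loader may differ):

    import os

    # Illustrative only: reads the same EMBEDDING_* variables the compose file sets.
    provider = os.getenv("EMBEDDING_PROVIDER", "openai")
    model = os.getenv("EMBEDDING_MODEL", "text-embedding-v4")
    dimension = int(os.getenv("EMBEDDING_DIMENSION", "1024"))
    base_url = os.getenv("EMBEDDING_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")

    print(f"Embedding config: {provider}/{model} dim={dimension} at {base_url}")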