feat: Enhance the embedding service with concurrency control, dynamic batching, and retry logic; improve the indexer with concurrent, incremental file processing.
parent b8e5c96541 · commit 969d899476
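The commit message above names three mechanisms added to the embedding service: concurrency control, dynamic batching, and retry logic. A minimal sketch of that combined pattern follows (illustrative only; `call_provider`, the semaphore size, and the retry count are assumptions standing in for the service's real API):

import asyncio
import httpx

_semaphore = asyncio.Semaphore(30)  # cap on concurrent embedding requests (assumed value)

async def embed_with_retry(call_provider, batch, max_retries: int = 3):
    # Retry a single batch, backing off exponentially when the provider returns HTTP 429.
    for attempt in range(max_retries):
        async with _semaphore:
            try:
                return await call_provider(batch)
            except httpx.HTTPStatusError as e:
                if e.response.status_code == 429 and attempt < max_retries - 1:
                    await asyncio.sleep(2 ** attempt + 1)  # exponential backoff before retrying
                    continue
                raise
    return []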
@@ -788,21 +788,9 @@ async def _initialize_tools(
    last_embedding_progress = [0]  # list so the closure can mutate it
    embedding_total = [0]  # track the total count

    # 🔥 Embedding progress callback (synchronous, but schedules async tasks)
    # Per-file embedding progress logs are no longer emitted, to avoid log spam
    def on_embedding_progress(processed: int, total: int):
        embedding_total[0] = total
        # Update every 50 items processed, or on completion
        if processed - last_embedding_progress[0] >= 50 or processed == total:
            last_embedding_progress[0] = processed
            percentage = (processed / total * 100) if total > 0 else 0
            msg = f"🔢 Embedding progress: {processed}/{total} ({percentage:.0f}%)"
            logger.info(msg)
            # Schedule the async emit with asyncio.create_task
            try:
                loop = asyncio.get_running_loop()
                loop.create_task(emit(msg))
            except Exception as e:
                logger.warning(f"Failed to emit embedding progress: {e}")
                pass

    # 🔥 Cancellation check used inside the embedding batch processing
    def check_cancelled() -> bool:
@@ -822,8 +810,8 @@ async def _initialize_tools(
            raise asyncio.CancelledError("Task cancelled")

        index_progress = progress
        # Send a progress update every 10 files or on significant changes
        if progress.processed_files - last_progress_update >= 10 or progress.processed_files == progress.total_files:
        # 🔥 Update progress for every file (per user request)
        if progress.processed_files - last_progress_update >= 1 or progress.processed_files == progress.total_files:
            if progress.total_files > 0:
                await emit(
                    f"📝 Indexing progress: {progress.processed_files}/{progress.total_files} files "
@@ -636,7 +636,14 @@ class EmbeddingService:
            base_url=self.base_url,
        )

        logger.info(f"Embedding service initialized with {self.provider}/{self.model}")
        # 🔥 Limit the number of concurrent requests (RPS cap)
        self._semaphore = asyncio.Semaphore(30)

        # 🔥 Default batch size (10 for remote models, per user request)
        is_remote = self.provider.lower() in ["openai", "qwen", "azure", "cohere", "jina", "huggingface"]
        self.batch_size = 10 if is_remote else 100

        logger.info(f"Embedding service initialized with {self.provider}/{self.model} (Batch size: {self.batch_size})")

    def _create_provider(
        self,
@@ -755,56 +762,76 @@ class EmbeddingService:

        # Process the uncached texts in batches
        if uncached_texts:
            total_batches = (len(uncached_texts) + batch_size - 1) // batch_size
            processed_batches = 0
            tasks = []
            current_batch_size = batch_size or self.batch_size

            for i in range(0, len(uncached_texts), batch_size):
                # 🔥 Check whether the operation should be cancelled
                if cancel_check and cancel_check():
                    logger.info(f"[Embedding] Cancelled at batch {processed_batches + 1}/{total_batches}")
                    raise asyncio.CancelledError("Embedding operation cancelled")
            for i in range(0, len(uncached_texts), current_batch_size):
                batch = uncached_texts[i:i + current_batch_size]
                batch_indices = uncached_indices[i:i + current_batch_size]
                tasks.append(self._process_batch_with_retry(batch, batch_indices, cancel_check))

                batch = uncached_texts[i:i + batch_size]
                batch_indices = uncached_indices[i:i + batch_size]
            # 🔥 Run all batch tasks concurrently
            all_batch_results = await asyncio.gather(*tasks, return_exceptions=True)

                try:
                    results = await self._provider.embed_texts(batch)
            for i, result_list in enumerate(all_batch_results):
                batch_indices = uncached_indices[i * current_batch_size : (i + 1) * current_batch_size]

                    for idx, result in zip(batch_indices, results):
                if isinstance(result_list, Exception):
                    logger.error(f"Batch processing failed: {result_list}")
                    # Use zero vectors for the failed batch
                    for idx in batch_indices:
                        if embeddings[idx] is None:
                            embeddings[idx] = [0.0] * self.dimension
                    continue

                for idx, result in zip(batch_indices, result_list):
                    embeddings[idx] = result.embedding

                    # Store in the cache
                    if self.cache_enabled:
                        cache_key = self._cache_key(texts[idx])
                        self._cache[cache_key] = result.embedding

                except asyncio.CancelledError:
                    # 🔥 Re-raise the cancellation
                    raise
                except Exception as e:
                    logger.error(f"Batch embedding error: {e}")
                    # Use zero vectors for failures
                    for idx in batch_indices:
                        if embeddings[idx] is None:
                            embeddings[idx] = [0.0] * self.dimension

                processed_batches += 1

                # 🔥 Invoke the progress callback
                if progress_callback:
                    processed_count = min(i + batch_size, len(uncached_texts))
                    processed_count = min((i + 1) * current_batch_size, len(uncached_texts))
                    try:
                        progress_callback(processed_count, len(uncached_texts))
                    except Exception as e:
                        logger.warning(f"Progress callback error: {e}")

                # Small delay to avoid rate limiting
                if self.provider not in ["ollama"]:
                    await asyncio.sleep(0.1)  # no delay for local providers

        # Make sure there are no None entries
        return [e if e is not None else [0.0] * self.dimension for e in embeddings]

    async def _process_batch_with_retry(
        self,
        batch: List[str],
        indices: List[int],
        cancel_check: Optional[callable] = None,
        max_retries: int = 3
    ) -> List[EmbeddingResult]:
        """Process a single batch with retries."""
        for attempt in range(max_retries):
            if cancel_check and cancel_check():
                raise asyncio.CancelledError("Embedding operation cancelled")

            async with self._semaphore:
                try:
                    return await self._provider.embed_texts(batch)
                except httpx.HTTPStatusError as e:
                    if e.response.status_code == 429 and attempt < max_retries - 1:
                        # 429 rate limited: exponential backoff
                        wait_time = (2 ** attempt) + 1
                        logger.warning(f"Rate limited (429), retrying in {wait_time}s... (Attempt {attempt+1}/{max_retries})")
                        await asyncio.sleep(wait_time)
                        continue
                    raise
                except Exception as e:
                    if attempt < max_retries - 1:
                        await asyncio.sleep(1)
                        continue
                    raise
        return []

    def clear_cache(self):
        """Clear the embedding cache."""
        self._cache.clear()
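The hunk above dispatches every batch at once via asyncio.gather(..., return_exceptions=True) and degrades failed batches to zero vectors rather than aborting the whole run. A condensed, standalone sketch of that collect step (`embed_batch_with_retry` and `DIM` are hypothetical stand-ins for the service's retry helper and configured dimension):

import asyncio
from typing import Awaitable, Callable, List, Optional

DIM = 1024  # assumed embedding dimension

async def embed_all(
    texts: List[str],
    embed_batch_with_retry: Callable[[List[str]], Awaitable[List[List[float]]]],
    batch_size: int = 10,
    progress: Optional[Callable[[int, int], None]] = None,
) -> List[List[float]]:
    batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
    # Schedule every batch; a semaphore inside the retry helper bounds real concurrency.
    results = await asyncio.gather(*(embed_batch_with_retry(b) for b in batches),
                                   return_exceptions=True)
    embeddings: List[List[float]] = []
    for i, result in enumerate(results):
        if isinstance(result, Exception):
            # A failed batch degrades to zero vectors instead of failing the run.
            embeddings.extend([[0.0] * DIM] * len(batches[i]))
        else:
            embeddings.extend(result)
        if progress:
            progress(min((i + 1) * batch_size, len(texts)), len(texts))
    return embeddings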
@@ -949,66 +949,59 @@ class CodeIndexer:
        logger.info(f"📁 Found {len(files)} files to index")
        yield progress

        all_chunks: List[CodeChunk] = []
        semaphore = asyncio.Semaphore(20)  # bound file-processing concurrency
        file_hashes: Dict[str, str] = {}

        # Split and process the files
        for file_path in files:
            progress.current_file = file_path

        async def process_file(file_path: str):
            async with semaphore:
                try:
                    relative_path = os.path.relpath(file_path, directory)
                    progress.current_file = relative_path

                    # Read the file asynchronously to avoid blocking the event loop
                    content = await asyncio.to_thread(
                        self._read_file_sync, file_path
                    )

                    # Read the file asynchronously
                    content = await asyncio.to_thread(self._read_file_sync, file_path)
                    if not content.strip():
                        progress.processed_files += 1
                        progress.skipped_files += 1
                        continue
                        return

                    # Compute the file hash
                    file_hash = hashlib.md5(content.encode()).hexdigest()
                    file_hashes[relative_path] = file_hash

                    # Cap the file size
                    # Split into chunks asynchronously
                    if len(content) > 500000:
                        content = content[:500000]

                    # Split asynchronously so Tree-sitter parsing does not block the event loop
                    chunks = await self.splitter.split_file_async(content, relative_path)

                    # Attach the file_hash to every chunk
                    for chunk in chunks:
                        chunk.metadata["file_hash"] = file_hash

                    all_chunks.extend(chunks)
                    # Index this file's chunks right away (enables per-file progress updates)
                    if chunks:
                        await self._index_chunks(chunks, progress, use_upsert=False, embedding_progress_callback=embedding_progress_callback, cancel_check=cancel_check)
                        progress.total_chunks += len(chunks)
                        progress.indexed_chunks += len(chunks)

                    progress.processed_files += 1
                    progress.added_files += 1
                    progress.total_chunks = len(all_chunks)

                    if progress_callback:
                        progress_callback(progress)
                    yield progress

                except Exception as e:
                    logger.warning(f"Failed to process file {file_path}: {e}")
                    progress.errors.append(f"{file_path}: {str(e)}")
                    progress.processed_files += 1

        logger.info(f"📝 Created {len(all_chunks)} code chunks")
        # Run the full index
        tasks = [process_file(f) for f in files]

        # Batch embed and index
        if all_chunks:
            # 🔥 Emit the embedding-generation status
            progress.status_message = f"🔢 Generating embeddings for {len(all_chunks)} code chunks..."
            # 🔥 Use as_completed for true per-file progress updates
            for task in asyncio.as_completed(tasks):
                await task
                yield progress

            await self._index_chunks(all_chunks, progress, use_upsert=False, embedding_progress_callback=embedding_progress_callback, cancel_check=cancel_check)

        # Update the collection metadata
        project_hash = hashlib.md5(json.dumps(sorted(file_hashes.items())).encode()).hexdigest()
        await self.vector_store.update_collection_metadata({
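The full-index hunk above wraps the per-file work in a semaphore-guarded coroutine and drains the tasks with asyncio.as_completed, so progress can be yielded as each file finishes. A self-contained sketch of that scheme (`process_one` is a stand-in for the real read/split/index pipeline):

import asyncio
from typing import AsyncIterator, Awaitable, Callable, List

async def index_files(
    files: List[str],
    process_one: Callable[[str], Awaitable[None]],
    concurrency: int = 20,
) -> AsyncIterator[int]:
    semaphore = asyncio.Semaphore(concurrency)
    done = 0

    async def worker(path: str) -> None:
        async with semaphore:
            await process_one(path)

    tasks = [asyncio.create_task(worker(f)) for f in files]
    for finished in asyncio.as_completed(tasks):
        await finished
        done += 1
        yield done  # one progress tick per completed file

Consumed as `async for processed in index_files(paths, process_one): ...`, which mirrors how the indexer's generator yields a progress object after each completed file.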
@@ -1016,8 +1009,7 @@ class CodeIndexer:
            "file_count": len(file_hashes),
        })

        progress.indexed_chunks = len(all_chunks)
        logger.info(f"✅ Full index complete: {progress.added_files} files, {len(all_chunks)} code chunks")
        logger.info(f"✅ Full index complete: {progress.added_files} files, {progress.indexed_chunks} code chunks")
        yield progress

    async def _incremental_index(
@@ -1091,26 +1083,23 @@ class CodeIndexer:
            progress_callback(progress)
        yield progress

        # Process added and updated files
        files_to_process = files_to_add | files_to_update
        all_chunks: List[CodeChunk] = []
        semaphore = asyncio.Semaphore(20)
        file_hashes: Dict[str, str] = dict(indexed_file_hashes)

        for relative_path in files_to_process:
        async def process_incremental_file(relative_path: str):
            async with semaphore:
                file_path = current_file_map[relative_path]
                progress.current_file = relative_path
                is_update = relative_path in files_to_update

                try:
                    # Read the file asynchronously to avoid blocking the event loop
                    content = await asyncio.to_thread(
                        self._read_file_sync, file_path
                    )
                    # Read the file asynchronously
                    content = await asyncio.to_thread(self._read_file_sync, file_path)

                    if not content.strip():
                        progress.processed_files += 1
                        progress.skipped_files += 1
                        continue
                        return

                    # If this is an update, delete the old chunks first
                    if is_update:
@@ -1124,38 +1113,43 @@ class CodeIndexer:
                    if len(content) > 500000:
                        content = content[:500000]

                    # Split asynchronously so Tree-sitter parsing does not block the event loop
                    # Split into chunks asynchronously
                    chunks = await self.splitter.split_file_async(content, relative_path)

                    # Attach the file_hash to every chunk
                    for chunk in chunks:
                        chunk.metadata["file_hash"] = file_hash

                    all_chunks.extend(chunks)
                    # Index this file right away
                    if chunks:
                        await self._index_chunks(chunks, progress, use_upsert=True, embedding_progress_callback=embedding_progress_callback, cancel_check=cancel_check)
                        progress.total_chunks += len(chunks)
                        progress.indexed_chunks += len(chunks)

                    progress.processed_files += 1
                    if is_update:
                        progress.updated_files += 1
                    else:
                        progress.added_files += 1
                    progress.total_chunks += len(chunks)

                    if progress_callback:
                        progress_callback(progress)
                    yield progress

                except Exception as e:
                    logger.warning(f"Failed to process file {file_path}: {e}")
                    progress.errors.append(f"{file_path}: {str(e)}")
                    progress.processed_files += 1

        # Batch embed and index the new chunks
        if all_chunks:
            # 🔥 Emit the embedding-generation status
            progress.status_message = f"🔢 Generating embeddings for {len(all_chunks)} code chunks..."
            yield progress
        # Process added and updated files
        files_to_process = files_to_add | files_to_update

            await self._index_chunks(all_chunks, progress, use_upsert=True, embedding_progress_callback=embedding_progress_callback, cancel_check=cancel_check)
        # Run the incremental index
        tasks = [process_incremental_file(p) for p in files_to_process]

        # 🔥 Use as_completed for true per-file progress updates
        for task in asyncio.as_completed(tasks):
            await task
            yield progress

        # Update the collection metadata
        # Remove hashes of deleted files
@@ -1168,7 +1162,6 @@ class CodeIndexer:
            "file_count": len(file_hashes),
        })

        progress.indexed_chunks = len(all_chunks)
        logger.info(
            f"✅ Incremental index complete: added {progress.added_files}, "
            f"updated {progress.updated_files}, deleted {progress.deleted_files}"
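The incremental hunks above decide what to reprocess by comparing MD5 hashes of current file contents against the hashes stored with the collection. A small sketch of that change detection, with plain dicts standing in for the indexer's stored metadata:

import hashlib
from typing import Dict, Set, Tuple

def diff_files(
    current: Dict[str, str],         # relative path -> file content
    indexed_hashes: Dict[str, str],  # relative path -> md5 recorded by the last index run
) -> Tuple[Set[str], Set[str], Set[str]]:
    current_hashes = {p: hashlib.md5(c.encode()).hexdigest() for p, c in current.items()}
    to_add = set(current_hashes) - set(indexed_hashes)
    to_update = {p for p in current_hashes.keys() & indexed_hashes.keys()
                 if current_hashes[p] != indexed_hashes[p]}
    to_delete = set(indexed_hashes) - set(current_hashes)
    return to_add, to_update, to_delete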
@@ -37,7 +37,7 @@ services:
    extra_hosts:
      - "host.docker.internal:host-gateway"
    volumes:
      # - ./backend/app:/app/app:ro # mount the code directory; changes take effect automatically
      - ./backend/app:/app/app:ro # mount the code directory; changes take effect automatically
      - backend_uploads:/app/uploads
      - chroma_data:/app/data/vector_db
      - /var/run/docker.sock:/var/run/docker.sock # required for sandbox execution
@@ -52,7 +52,10 @@ services:
      - SANDBOX_ENABLED=true
      - SANDBOX_IMAGE=deepaudit/sandbox:latest # use the locally built sandbox image
      # embedding service endpoint
      - EMBEDDING_BASE_URL=http://host.docker.internal:8003/v1
      - EMBEDDING_PROVIDER=openai
      - EMBEDDING_MODEL=text-embedding-v4
      - EMBEDDING_DIMENSION=1024
      - EMBEDDING_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
      # Gitea configuration
      - GITEA_HOST_URL=http://sl.vrgon.com:3000
      - GITEA_BOT_TOKEN=379a049b8d78965fdff474fc8676bca7e9c70248
@@ -62,7 +65,7 @@ services:
      redis:
        condition: service_healthy
    # development mode: enable --reload hot reloading
    command: sh -c ".venv/bin/alembic upgrade head && .venv/bin/uvicorn app.main:app --host 0.0.0.0 --port 8000"
    command: sh -c ".venv/bin/alembic upgrade head && .venv/bin/uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload"
    networks:
      - deepaudit-network
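The compose changes above switch the embedding endpoint to the DashScope OpenAI-compatible URL and add provider, model, and dimension variables. A hypothetical loader showing how a service could pick these up (the project's real settings code is not part of this diff):

import os
from dataclasses import dataclass

@dataclass
class EmbeddingSettings:
    provider: str
    model: str
    dimension: int
    base_url: str

def load_embedding_settings() -> EmbeddingSettings:
    # Defaults mirror the values set in docker-compose above.
    return EmbeddingSettings(
        provider=os.getenv("EMBEDDING_PROVIDER", "openai"),
        model=os.getenv("EMBEDDING_MODEL", "text-embedding-v4"),
        dimension=int(os.getenv("EMBEDDING_DIMENSION", "1024")),
        base_url=os.getenv("EMBEDDING_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
    )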